Upload 4 files
- app.py +53 -0
- main.py +41 -0
- model.py +371 -0
- switch_transformer.pt +3 -0
app.py
ADDED
@@ -0,0 +1,53 @@
import gradio as gr
from main import tokenizer, model, device
import torch

def qa_pipeline(text, question):
    inputs = tokenizer(question, text, return_tensors="pt")
    input_ids = inputs['input_ids'].to(device)
    attention_mask = inputs['attention_mask'].to(device)
    batch = {
        "input_ids": input_ids,
        "attention_mask": attention_mask
    }
    # SwitchTransformer.forward returns a (start_logits, end_logits, loss) tuple
    # rather than a Hugging Face output object, so unpack it directly.
    with torch.no_grad():
        start_logits, end_logits, _ = model(batch)

    start_index = torch.argmax(start_logits, dim=-1).item()
    end_index = torch.argmax(end_logits, dim=-1).item()

    predict_answer_tokens = inputs.input_ids[0, start_index : end_index + 1]
    return tokenizer.decode(predict_answer_tokens)

def answer_question(context, question):
    result = qa_pipeline(context, question)
    return result

example_contexts = [
    "Қазақстанның ұлттық құрамы алуан түрлі. Халықтың басым бөлігін тұрғылықты қазақ халқы құрайды, пайыздық үлесі — 70,18%[10], орыстар — 18,42%, өзбектер — 3,29%, украиндар — 1,36%, ұйғырлар — 1,48%, татарлар — 1,06%, басқа халықтар 5,38%.[11] Халықтың 75% астамын мұсылмандар құрайды, православты христиандар — 21%, қалғаны басқа да дін өкілдері.[12]",
    "Қазақстан бес мемлекетпен шекаралас, соның ішінде әлемдегі құрлықтағы ең ұзын шекара, солтүстігінде және батысында Ресеймен — 7591 км құрайды. Оңтүстігінде: Түрікменстан — 426 км, Өзбекстан — 2354 км және Қырғызстан — 1241 км, ал шығысында: Қытаймен — 1782 км шектеседі. Жалпы құрлық шекарасының ұзындығы — 13394 км. Батыста Каспий көлімен (2000 км), оңтүстік батыста Арал теңізімен шайылады.[9] 2024 жылдың 1 наурыздағы елдегі тұрғындар саны — 20 075 271[4], бұл әлем бойынша 64-орын. Жер көлемі жағынан әлем елдерінің ішінде 9-орын алады (2 724 902 км²).",
    "Қазақстан — 1995 жылғы 30 тамыздағы республикалық референдумда қабылданған Конституция бойынша — өзін демократиялы, зайырлы, құқықты және әлеуметті мемлекет ретінде орнықтырды. Қазақстан Республикасы – президенттік басқару формасындағы біртұтас мемлекет. Республиканың ең жоғарғы өкілді органы — Парламент. Ол республиканың заң шығару құзіретін жүзеге асырады."
]
example_questions = [
    "Қазақстанның халқы неше пайызды қазақтар құрайды?",
    "Қазақстан нешеу мемлекетпен шекаралас?",
    "Қазақстандағы басқару формасы қандай?",
]


examples = [[context, question] for context, question in zip(example_contexts, example_questions)]

# Create the interface
iface = gr.Interface(
    fn=answer_question,
    inputs=[
        gr.Textbox(lines=10, label="Context"),
        gr.Textbox(lines=2, label="Question")
    ],
    outputs="text",
    title="Question Answering Model",
    description="Enter a context and ask a question to get an answer.",
    examples=examples
)

# Launch the interface
iface.launch()
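A quick way to sanity-check the pipeline without the web UI is to call qa_pipeline directly before iface.launch() (a minimal sketch; the exact span depends on the checkpoint, but the first example asks for the Kazakh population share, so a span like "70,18%" would be expected):

answer = qa_pipeline(example_contexts[0], example_questions[0])
print(answer)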
main.py
ADDED
@@ -0,0 +1,41 @@
import torch
from model import (
    SwitchTransformer,
    SwitchTransformerLayer,
    MultiHeadAttention,
    SwitchFeedForward,
    FeedForward,
)
from transformers import AutoTokenizer

device = 'cpu'

ff = FeedForward(768, 768 * 4)
attn = MultiHeadAttention(8, 768, 0.2)
st_ff = SwitchFeedForward(
    capacity_factor=1.25,
    drop_tokens=False,
    n_experts=4,
    expert=ff,
    d_model=768,
    is_scale_prob=True,
)
st_layer = SwitchTransformerLayer(
    d_model=768,
    attn=attn,
    feed_forward=st_ff,
    dropout_prob=0.2
)
model = SwitchTransformer(
    layer=st_layer,
    n_layers=4,
    n_experts=4,
    device=device,
    load_balancing_loss_coef=0.05,
).to(device)

# map_location keeps a checkpoint saved on GPU loadable on the CPU-only Space.
model.load_state_dict(torch.load("switch_transformer.pt", map_location=device))
model.eval()
tokenizer = AutoTokenizer.from_pretrained("Kyrmasch/kaz-roberta-squad2-kaz")
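Importing main builds the network and loads the checkpoint as a side effect, so a short smoke test is enough to confirm the wiring (a sketch with placeholder Kazakh strings; with no answer positions in the batch, the returned loss is just the load-balancing term):

from main import tokenizer, model
enc = tokenizer("Сұрақ?", "Контекст мәтіні.", return_tensors="pt")
start_logits, end_logits, loss = model(
    {"input_ids": enc["input_ids"], "attention_mask": enc["attention_mask"]}
)
print(start_logits.shape, end_logits.shape)  # both torch.Size([1, seq_len])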
model.py
ADDED
@@ -0,0 +1,371 @@
import copy
import math

import torch
import torch.nn as nn
from transformers import AutoModelForQuestionAnswering


class MultiHeadAttention(nn.Module):
    def __init__(self, n_heads, dim, dropout_prob):
        super().__init__()

        self.n_heads = n_heads
        self.dim = dim
        self.dropout = nn.Dropout(p=dropout_prob)

        assert self.dim % self.n_heads == 0
        self.q_lin = nn.Linear(in_features=self.dim, out_features=self.dim)
        self.k_lin = nn.Linear(in_features=self.dim, out_features=self.dim)
        self.v_lin = nn.Linear(in_features=self.dim, out_features=self.dim)
        self.out_lin = nn.Linear(in_features=self.dim, out_features=self.dim)

    def forward(self, query, key, value, mask, head_mask=None, output_attentions=False):
        """
        Parameters:
            query: torch.tensor(bs, seq_length, dim)
            key: torch.tensor(bs, seq_length, dim)
            value: torch.tensor(bs, seq_length, dim)
            mask: torch.tensor(bs, seq_length)
        Returns:
            context: torch.tensor(bs, seq_length, dim) Contextualized layer.
            weights: torch.tensor(bs, n_heads, seq_length, seq_length) Attention
                weights. Optional: only returned if `output_attentions=True`.
        """
        bs, q_length, dim = query.size()
        k_length = key.size(1)

        dim_per_head = self.dim // self.n_heads

        mask_reshp = (bs, 1, 1, k_length)

        def shape(x):
            """Separate heads."""
            return x.view(bs, -1, self.n_heads, dim_per_head).transpose(1, 2)

        def unshape(x):
            """Group heads."""
            return (
                x.transpose(1, 2).contiguous().view(bs, -1, self.n_heads * dim_per_head)
            )

        q = shape(self.q_lin(query))  # (bs, n_heads, q_length, dim_per_head)
        k = shape(self.k_lin(key))  # (bs, n_heads, k_length, dim_per_head)
        v = shape(self.v_lin(value))  # (bs, n_heads, k_length, dim_per_head)

        q = q / math.sqrt(dim_per_head)  # (bs, n_heads, q_length, dim_per_head)
        scores = torch.matmul(q, k.transpose(2, 3))  # (bs, n_heads, q_length, k_length)
        mask = (
            (mask == 0).view(mask_reshp).expand_as(scores)
        )  # (bs, n_heads, q_length, k_length)
        scores = scores.masked_fill(
            mask, -float("inf")
        )  # (bs, n_heads, q_length, k_length)

        weights = nn.functional.softmax(
            scores, dim=-1
        )  # (bs, n_heads, q_length, k_length)
        weights = self.dropout(weights)  # (bs, n_heads, q_length, k_length)

        # Mask heads if we want to
        if head_mask is not None:
            weights = weights * head_mask

        context = torch.matmul(weights, v)  # (bs, n_heads, q_length, dim_per_head)
        context = unshape(context)  # (bs, q_length, dim)
        context = self.out_lin(context)  # (bs, q_length, dim)

        if output_attentions:
            return (context, weights)
        else:
            return context

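# Shape walkthrough (illustrative numbers, not from the code): with bs=1,
# seq_length=384, dim=768 and n_heads=8, dim_per_head is 96; shape() turns each
# projection into (1, 8, 384, 96), the score matrix q @ k^T is (1, 8, 384, 384),
# and after softmax, dropout and weights @ v, unshape() regroups the context to
# (1, 384, 768) before the final out_lin projection.
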
class FeedForward(nn.Module):
    def __init__(self, dim_input: int = 768, dim_feedforward: int = 4 * 768):
        super().__init__()

        self.linear1 = nn.Linear(dim_input, dim_feedforward)
        self.relu = nn.ReLU()
        self.linear2 = nn.Linear(dim_feedforward, dim_input)

    def forward(self, x):
        return self.linear2(self.relu(self.linear1(x)))


class SwitchFeedForward(nn.Module):
    """
    ## Routing among multiple FFNs
    """

    def __init__(
        self,
        *,
        capacity_factor: float,
        drop_tokens: bool,
        is_scale_prob: bool,
        n_experts: int,
        expert: FeedForward,
        d_model: int
    ):
        """
        * `capacity_factor` is the capacity of each expert as a factor relative to ideally balanced load
        * `drop_tokens` specifies whether to drop tokens if more tokens are routed to an expert than the capacity
        * `is_scale_prob` specifies whether to multiply the input to the FFN by the routing probability
        * `n_experts` is the number of experts
        * `expert` is the expert layer, a [FFN module](../feed_forward.html)
        * `d_model` is the number of features in a token embedding
        """
        super().__init__()

        self.capacity_factor = capacity_factor
        self.is_scale_prob = is_scale_prob
        self.n_experts = n_experts
        self.drop_tokens = drop_tokens

        # Make copies of the FFNs
        self.experts = nn.ModuleList([copy.deepcopy(expert) for _ in range(n_experts)])
        # Routing layer and softmax
        self.switch = nn.Linear(d_model, n_experts)
        self.softmax = nn.Softmax(dim=-1)

    def forward(self, x: torch.Tensor):
        """
        * `x` is the input to the switching module with shape `[seq_len, batch_size, d_model]`
          (a `[batch_size, seq_len, d_model]` layout works the same way, since the routing
          below only depends on the flattened token count)
        """

        # Capture the shape to change shapes later
        seq_len, batch_size, d_model = x.shape
        # Flatten the sequence and batch dimensions
        x = x.view(-1, d_model)

        # Get routing probabilities for each of the tokens.
        # $$p_i(x) = \frac{e^{h(x)_i}}{\sum^N_j e^{h(x)_j}}$$
        # where $N$ is the number of experts `n_experts` and
        # $h(\cdot)$ is the linear transformation of token embeddings.
        route_prob = self.softmax(self.switch(x))

        # Get the maximum routing probabilities and the routes.
        # We route to the expert with the highest probability
        route_prob_max, routes = torch.max(route_prob, dim=-1)

        # Get indexes of tokens going to each expert
        indexes_list = [
            torch.eq(routes, i).nonzero(as_tuple=True)[0] for i in range(self.n_experts)
        ]

        # Initialize an empty tensor to store outputs
        final_output = x.new_zeros(x.shape)

        # Capacity of each expert.
        # $$\mathrm{expert\;capacity} =
        # \frac{\mathrm{tokens\;per\;batch}}{\mathrm{number\;of\;experts}}
        # \times \mathrm{capacity\;factor}$$
        capacity = int(self.capacity_factor * len(x) / self.n_experts)
        # Number of tokens routed to each expert.
        counts = x.new_tensor([len(indexes_list[i]) for i in range(self.n_experts)])

        # Initialize an empty list of dropped tokens
        dropped = []
        # Only drop tokens if `drop_tokens` is `True`.
        if self.drop_tokens:
            # Drop tokens in each of the experts
            for i in range(self.n_experts):
                # Ignore if the expert is not over capacity
                if len(indexes_list[i]) <= capacity:
                    continue
                # Shuffle indexes before dropping
                indexes_list[i] = indexes_list[i][torch.randperm(len(indexes_list[i]))]
                # Collect the tokens over capacity as dropped tokens
                dropped.append(indexes_list[i][capacity:])
                # Keep only the tokens up to the capacity of the expert
                indexes_list[i] = indexes_list[i][:capacity]

        # Get outputs of the expert FFNs
        expert_output = [
            self.experts[i](x[indexes_list[i], :]) for i in range(self.n_experts)
        ]

        # Assign to final output
        for i in range(self.n_experts):
            final_output[indexes_list[i], :] = expert_output[i]

        # Pass the dropped tokens through unchanged
        if dropped:
            dropped = torch.cat(dropped)
            final_output[dropped, :] = x[dropped, :]

        if self.is_scale_prob:
            # Scale the expert outputs by the routing probabilities $y = p_i(x) E_i(x)$
            final_output = final_output * route_prob_max.view(-1, 1)
        else:
            # Don't scale the values but multiply by $\frac{p}{\hat{p}} = 1$ so that the gradients flow
            # (this is something we experimented with).
            final_output = final_output * (
                route_prob_max / route_prob_max.detach()
            ).view(-1, 1)

        # Change the shape of the final output back to `[seq_len, batch_size, d_model]`
        final_output = final_output.view(seq_len, batch_size, d_model)

        # Return
        #
        # * the final output
        # * number of tokens routed to each expert
        # * sum of probabilities for each expert
        # * number of tokens dropped
        # * routing probabilities of the selected experts
        #
        # These are used for the load balancing loss and logging
        return final_output, counts, route_prob.sum(0), len(dropped), route_prob_max


class SwitchTransformerLayer(nn.Module):
    """
    # Switch Transformer Block
    This is the same as a [normal transformer block](../models.html#TransformerLayer)
    with handling of the extra outputs of the switch feed-forward module.
    """

    def __init__(
        self,
        *,
        d_model: int,
        attn: MultiHeadAttention,
        feed_forward: SwitchFeedForward,
        dropout_prob: float
    ):
        """
        * `d_model` is the token embedding size
        * `attn` is the attention module
        * `feed_forward` is the feed-forward module (which is the switching module in this case)
        * `dropout_prob` is the probability of dropping out after self-attention and the FFN
        """
        super().__init__()
        self.size = d_model
        self.attn = attn
        self.feed_forward = feed_forward
        self.dropout = nn.Dropout(dropout_prob)
        self.norm_self_attn = nn.LayerNorm([d_model])
        self.norm_ff = nn.LayerNorm([d_model])

    def forward(self, *, x: torch.Tensor, mask: torch.Tensor):
        # Normalize the vectors before doing self-attention
        z = self.norm_self_attn(x)
        # Run through self-attention, i.e. keys and values are from self
        self_attn = self.attn(query=z, key=z, value=z, mask=mask)
        # Add the self-attention results
        x = x + self.dropout(self_attn)

        # Normalize for the feed-forward network
        z = self.norm_ff(x)
        # Pass through the switching feed-forward network
        ff, counts, route_prob, n_dropped, route_prob_max = self.feed_forward(z)
        # Add the feed-forward results back
        x = x + self.dropout(ff)

        return x, counts, route_prob, n_dropped, route_prob_max


class SwitchTransformer(nn.Module):
    """
    ## Switch Transformer
    """

    def __init__(self, layer, n_layers, n_experts, device, load_balancing_loss_coef):
        super().__init__()
        # Make copies of the transformer layer
        self.layers = nn.ModuleList([copy.deepcopy(layer) for _ in range(n_layers)])
        # Final normalization layer
        self.norm = nn.LayerNorm([layer.size])
        self.qa_outputs = nn.Linear(768, 2)
        self.base_model = AutoModelForQuestionAnswering.from_pretrained(
            "Kyrmasch/kaz-roberta-squad2-kaz"
        ).to(device)
        self.device = device
        self.load_balancing_loss_coef = load_balancing_loss_coef
        self.n_experts = n_experts  # used to calculate the load balancing loss

    def freeze_base_model(self):
        for param in self.base_model.parameters():
            param.requires_grad = False

    def freeze_experts(self):
        # TODO: find out how to freeze the experts in the SwitchTransformer
        pass

    def forward(self, batch):
        input_ids = batch["input_ids"].to(self.device)
        attention_mask = batch["attention_mask"].to(self.device)
        start_positions = (
            batch["start_positions"].to(self.device)
            if "start_positions" in batch
            else None
        )
        end_positions = (
            batch["end_positions"].to(self.device)
            if "end_positions" in batch
            else None
        )

        # Encode with the base QA model; only its last hidden state is used.
        outputs = self.base_model(
            input_ids,
            attention_mask=attention_mask,
            start_positions=None,
            end_positions=None,
            output_hidden_states=True,
        )
        x = outputs.hidden_states[-1]
        # Run through each transformer layer
        counts, route_prob, n_dropped, route_prob_max = [], [], [], []
        for layer in self.layers:
            x, f, p, n_d, p_max = layer(x=x, mask=attention_mask)
            counts.append(f)
            route_prob.append(p)
            n_dropped.append(n_d)
            route_prob_max.append(p_max)
        # Finally, normalize the vectors
        output = self.norm(x)

        logits = self.qa_outputs(output)
        start_logits, end_logits = logits.split(1, dim=-1)
        start_logits = start_logits.squeeze(-1).contiguous()  # (bs, max_query_len)
        end_logits = end_logits.squeeze(-1).contiguous()  # (bs, max_query_len)

        loss = None
        if start_positions is not None and end_positions is not None:
            if len(start_positions.size()) > 1:
                start_positions = start_positions.squeeze(-1)
            if len(end_positions.size()) > 1:
                end_positions = end_positions.squeeze(-1)
            # Sometimes the start/end positions are outside our model inputs; we ignore these terms
            ignored_index = start_logits.size(1)
            start_positions = start_positions.clamp(0, ignored_index)
            end_positions = end_positions.clamp(0, ignored_index)

            loss_fct = nn.CrossEntropyLoss(ignore_index=ignored_index)
            start_loss = loss_fct(start_logits, start_positions)
            end_loss = loss_fct(end_logits, end_positions)
            loss = (start_loss + end_loss) / 2

        # Load balancing loss, accumulated across all layers
        counts = torch.stack(counts)
        route_prob = torch.stack(route_prob)
        route_prob_max = torch.stack(route_prob_max)
        total = counts.sum(dim=-1, keepdim=True)
        # Fraction of tokens and mean routing probability per expert
        route_frac = counts / total
        route_prob = route_prob / total
        load_balancing_loss = self.n_experts * (route_frac * route_prob).sum()
        loss = (
            load_balancing_loss
            if loss is None
            else loss + self.load_balancing_loss_coef * load_balancing_loss
        )
        return start_logits, end_logits, loss
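To make the routing arithmetic concrete, a small illustrative calculation (assumed numbers, using the configuration from main.py):

# Illustrative only: a batch of 2 sequences of 256 tokens, main.py settings.
n_tokens, n_experts, n_layers = 2 * 256, 4, 4
capacity = int(1.25 * n_tokens / n_experts)  # 160 tokens per expert
# (not enforced in this app, since drop_tokens=False)
# With a perfectly uniform router, route_frac and the normalized route_prob are
# both 1/4 for every (layer, expert) entry, so the load-balancing loss bottoms
# out at n_experts times the sum of 16 entries of 1/16:
lb_floor = n_experts * (n_layers * n_experts) * (1 / n_experts) ** 2  # 4.0
weighted = 0.05 * lb_floor  # contributes 0.2 to the training loss at perfect balance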
switch_transformer.pt
ADDED
@@ -0,0 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:18db93cbc33e8aab35f5583010b67d2ca0c44cd93445e0bfd5d886382708d9ba
size 671685785