abhaynb92 commited on
Commit
7cfcb41
·
verified ·
1 Parent(s): 7ddef43
Files changed (1) hide show
  1. model.py +21 -19
model.py CHANGED
@@ -1,19 +1,21 @@
1
- import torch
2
- import torch.nn as nn
3
- from transformers import AutoModel
4
-
5
- class EmotionClassifier(nn.Module):
6
- def __init__(self, model_name, num_labels=5):
7
- super().__init__()
8
- self.transformer = AutoModel.from_pretrained(model_name)
9
- hidden_size = self.transformer.config.hidden_size
10
- self.classifier = nn.Linear(hidden_size, num_labels)
11
-
12
- def forward(self, input_ids, attention_mask):
13
- outputs = self.transformer(
14
- input_ids=input_ids,
15
- attention_mask=attention_mask
16
- )
17
- cls_embeddings = outputs.last_hidden_state[:, 0, :]
18
- logits = self.classifier(cls_embeddings)
19
- return logits
 
 
 
1
+ import torch
2
+ import torch.nn as nn
3
+ from transformers import AutoModel
4
+
5
+ class EmotionClassifier(nn.Module):
6
+ def __init__(self, model_name="microsoft/deberta-v3-base"):
7
+ super().__init__()
8
+ self.backbone = AutoModel.from_pretrained(model_name)
9
+ hidden_size = self.backbone.config.hidden_size
10
+ # Must be named `out` because your trained weights use out.weight & out.bias
11
+ self.out = nn.Linear(hidden_size, 5)
12
+
13
+ def forward(self, input_ids, attention_mask):
14
+ outputs = self.backbone(
15
+ input_ids=input_ids,
16
+ attention_mask=attention_mask
17
+ )
18
+ last_hidden_state = outputs.last_hidden_state
19
+ cls_token = last_hidden_state[:, 0, :]
20
+ logits = self.out(cls_token)
21
+ return logits