abhaynb92 committed
Commit 6a5f60d · verified · 1 Parent(s): 7cfcb41

Update model.py

Files changed (1):
  1. model.py +11 -8
model.py CHANGED
@@ -5,17 +5,20 @@ from transformers import AutoModel
 class EmotionClassifier(nn.Module):
     def __init__(self, model_name="microsoft/deberta-v3-base"):
         super().__init__()
-        self.backbone = AutoModel.from_pretrained(model_name)
-        hidden_size = self.backbone.config.hidden_size
-        # Must be named `out` because your trained weights use out.weight & out.bias
-        self.out = nn.Linear(hidden_size, 5)
+        # IMPORTANT: use the SAME NAME you used during training
+        self.transformer = AutoModel.from_pretrained(model_name)
+
+        hidden = self.transformer.config.hidden_size
+
+        # IMPORTANT: your saved checkpoint uses out.weight & out.bias
+        self.out = nn.Linear(hidden, 5)

     def forward(self, input_ids, attention_mask):
-        outputs = self.backbone(
+        outputs = self.transformer(
            input_ids=input_ids,
            attention_mask=attention_mask
         )
-        last_hidden_state = outputs.last_hidden_state
-        cls_token = last_hidden_state[:, 0, :]
-        logits = self.out(cls_token)
+
+        cls_rep = outputs.last_hidden_state[:, 0, :]
+        logits = self.out(cls_rep)
         return logits
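
For reference, a minimal loading sketch showing why this rename matters: load_state_dict matches parameters by key, so a checkpoint saved from a model with attributes named `transformer` and `out` only loads cleanly into a module carrying those exact attribute names. The checkpoint filename below is a placeholder, not part of this commit.

import torch
from model import EmotionClassifier

model = EmotionClassifier()

# Hypothetical checkpoint path -- substitute the file produced during training.
# The state_dict keys look like "transformer.embeddings...", "out.weight",
# and "out.bias"; keeping the old attribute name `backbone` would turn every
# "transformer.*" key into a missing/unexpected key in the call below.
state_dict = torch.load("emotion_classifier.pt", map_location="cpu")
model.load_state_dict(state_dict)
model.eval()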