taegyeonglee commited on
Commit
c62ec7b
ยท
verified ยท
1 Parent(s): 8643c84

Update README.md

Browse files
Files changed (1) hide show
  1. README.md +100 -0
README.md CHANGED
@@ -81,6 +81,106 @@ base_model:
81
  - ๋งค์ˆ˜/๋งค๋„/๋ณด์œ /ํšŒํ”ผ/์งˆ๋ฌธ/์ •๋ณด.
82
 
83
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
84
  ---
85
 
86
  ### ์˜ˆ์‹œ
 
81
  - ๋งค์ˆ˜/๋งค๋„/๋ณด์œ /ํšŒํ”ผ/์งˆ๋ฌธ/์ •๋ณด.
82
 
83
 
84
+ ---
85
+ ## How to use the model
86
+ ```
87
+ import torch
88
+ import torch.nn as nn
89
+ import numpy as np
90
+ from transformers import AutoTokenizer, AutoModel
91
+ from huggingface_hub import hf_hub_download
92
+
93
+ # ---- ์ƒ์ˆ˜ ์ •์˜ ----
94
+ REPO_ID = "langquant/LQ-Kbert-base"
95
+ CKPT_RELPATH = "model/lq-kbert-base.pt"
96
+
97
+ SENTI_MAP = {'strong_pos':0,'weak_pos':1,'neutral':2,'weak_neg':3,'strong_neg':4}
98
+ ACT_MAP = {'buy':0,'hold':1,'sell':2,'avoid':3,'info_only':4,'ask_info':5}
99
+ EMO_LIST = ['greed','fear','confidence','doubt','anger','hope','sarcasm']
100
+ IDX2SENTI = {v:k for k,v in SENTI_MAP.items()}
101
+ IDX2ACT = {v:k for k,v in ACT_MAP.items()}
102
+
103
+ def sigmoid(x): return 1/(1+np.exp(-x))
104
+
105
+ # ---- ๋ชจ๋ธ ์ •์˜ ----
106
+ class KbertMTL(nn.Module):
107
+ def __init__(self, base_model, hidden=768):
108
+ super().__init__()
109
+ self.bert = base_model
110
+ self.head_senti = nn.Linear(hidden, 5)
111
+ self.head_act = nn.Linear(hidden, 6)
112
+ self.head_emo = nn.Linear(hidden, 7)
113
+ self.head_reg = nn.Linear(hidden, 3)
114
+ self.has_token_type = getattr(self.bert.embeddings, "token_type_embeddings", None) is not None
115
+
116
+ def forward(self, input_ids, attention_mask, token_type_ids=None):
117
+ kwargs = dict(input_ids=input_ids, attention_mask=attention_mask)
118
+ if self.has_token_type and token_type_ids is not None:
119
+ kwargs["token_type_ids"] = token_type_ids
120
+ out = self.bert(**kwargs)
121
+ h = out.last_hidden_state[:, 0] # [CLS]
122
+ return {
123
+ "logits_senti": self.head_senti(h),
124
+ "logits_act": self.head_act(h),
125
+ "logits_emo": self.head_emo(h),
126
+ "pred_reg": self.head_reg(h)
127
+ }
128
+
129
+ # ---- ์ฒดํฌํฌ์ธํŠธ ๋กœ๋“œ ----
130
+ def load_ckpt_from_hub():
131
+ path = hf_hub_download(repo_id=REPO_ID, filename=CKPT_RELPATH)
132
+ obj = torch.load(path, map_location="cpu")
133
+ return obj
134
+
135
+ # ---- ๋ชจ๋ธ ๋ฐ ํ† ํฌ๋‚˜์ด์ € ๊ตฌ์„ฑ ----
136
+ def build_model_and_tokenizer(ckpt_obj, hidden=768):
137
+ model_name = ckpt_obj.get("model_name", "klue/bert-base")
138
+ tokenizer = AutoTokenizer.from_pretrained(model_name)
139
+ base = AutoModel.from_pretrained(model_name)
140
+ model = KbertMTL(base_model=base, hidden=hidden)
141
+ state_dict = ckpt_obj["state_dict"] if "state_dict" in ckpt_obj else ckpt_obj
142
+ model.load_state_dict(state_dict, strict=False)
143
+ emo_thr = float(ckpt_obj.get("emo_threshold", 0.5))
144
+ return model, tokenizer, emo_thr
145
+
146
+ # ---- ์ถ”๋ก  ----
147
+ @torch.no_grad()
148
+ def predict(text, model, tokenizer, device="cpu", max_len=200, emo_threshold=0.5):
149
+ model.to(device).eval()
150
+ enc = tokenizer([text], padding=True, truncation=True, max_length=max_len, return_tensors="pt").to(device)
151
+ out = model(**enc)
152
+
153
+ senti = out["logits_senti"].argmax(-1).item()
154
+ act = out["logits_act"].argmax(-1).item()
155
+ emo_p = sigmoid(out["logits_emo"].cpu().numpy())[0]
156
+ reg = out["pred_reg"].cpu().numpy()[0]
157
+
158
+ emos = [EMO_LIST[i] for i,p in enumerate(emo_p) if p >= emo_threshold]
159
+
160
+ return {
161
+ "text": text,
162
+ "pred_sentiment_strength": IDX2SENTI[senti],
163
+ "pred_action_signal": IDX2ACT[act],
164
+ "pred_emotions": emos,
165
+ "pred_certainty": float(np.clip(reg[0], 0, 1)),
166
+ "pred_relevance": float(np.clip(reg[1], 0, 1)),
167
+ "pred_toxicity": float(np.clip(reg[2], 0, 1)),
168
+ }
169
+
170
+ # ---- ๋ฉ”์ธ ----
171
+ if __name__ == "__main__":
172
+ text = input("๋ถ„์„ํ•  ๋ฌธ์žฅ์„ ์ž…๋ ฅํ•˜์„ธ์š”: ").strip()
173
+ print("[๋ชจ๋ธ ๋กœ๋“œ ์ค‘...]")
174
+ ckpt = load_ckpt_from_hub()
175
+ model, tokenizer, emo_thr = build_model_and_tokenizer(ckpt)
176
+
177
+ print("[์ถ”๋ก  ์ค‘...]")
178
+ result = predict(text, model, tokenizer, device="cuda" if torch.cuda.is_available() else "cpu", emo_threshold=emo_thr)
179
+
180
+ print("\n=== ๊ฒฐ๊ณผ ===")
181
+ for k,v in result.items():
182
+ print(f"{k}: {v}")
183
+ ```
184
  ---
185
 
186
  ### ์˜ˆ์‹œ