minstradamus commited on
Commit
6e3d1dc
·
verified ·
1 Parent(s): 1142191

Update advice.py

Browse files
Files changed (1) hide show
  1. advice.py +13 -10
advice.py CHANGED
@@ -6,7 +6,7 @@ from transformers import AutoModelForCausalLM, AutoTokenizer
6
 
7
  from common import read_json_stdin, write_json_stdout, current_month_snapshot, clean_ru
8
 
9
- ALLOWED_MODEL_ID = "Qwen/Qwen2.5-0.5B-Instruct"
10
 
11
  os.environ.setdefault("OMP_NUM_THREADS", "1")
12
  os.environ.setdefault("MKL_NUM_THREADS", "1")
@@ -28,13 +28,11 @@ def _load():
28
 
29
  _tokenizer = AutoTokenizer.from_pretrained(
30
  ALLOWED_MODEL_ID,
31
- trust_remote_code=True,
32
  )
33
  _model = AutoModelForCausalLM.from_pretrained(
34
  ALLOWED_MODEL_ID,
35
  torch_dtype=torch.float32,
36
  low_cpu_mem_usage=True,
37
- trust_remote_code=True,
38
  ).to(_DEVICE).eval()
39
 
40
  if _tokenizer.pad_token_id is None:
@@ -45,9 +43,13 @@ def _load():
45
 
46
 
47
  def _gen(messages, tok, mdl, max_new_tokens=200, det=True):
 
48
  txt = tok.apply_chat_template(
49
- messages, tokenize=False, add_generation_prompt=True
 
 
50
  )
 
51
  inputs = tok(
52
  txt,
53
  return_tensors="pt",
@@ -76,11 +78,11 @@ def _gen(messages, tok, mdl, max_new_tokens=200, det=True):
76
  **inputs,
77
  do_sample=True,
78
  temperature=0.8,
79
- top_p=0.9,
80
- top_k=50,
81
  **common,
82
  )
83
- return tok.decode(out[0], skip_special_tokens=True)
 
 
84
 
85
 
86
  _BULLET_KILL = re.compile(
@@ -94,7 +96,7 @@ def _to_bullets(text: str) -> str:
94
  return ""
95
  m = re.search(r"(\n\s*[-*]\s+|\n\s*\d+[\).\s]+|•)", "\n" + text)
96
  if m:
97
- text = text[m.start() :]
98
 
99
  text = re.sub(r"^\s*[*•]\s+", "- ", text, flags=re.M)
100
  text = re.sub(r"^\s*\d+[\).\s]+", "- ", text, flags=re.M)
@@ -124,7 +126,7 @@ def main():
124
 
125
  tx = req.get("transactions") or []
126
  question = (req.get("question") or "").strip()
127
-
128
  df = pd.DataFrame(tx) if tx else None
129
  snap = current_month_snapshot(df) if df is not None and not df.empty else {}
130
 
@@ -145,7 +147,8 @@ def main():
145
 
146
  system_msg = (
147
  "Ты финансовый помощник. Отвечай по-русски. "
148
- "Верни ТОЛЬКО список из 5–7 конкретных шагов экономии с цифрами (лимиты, проценты, частота). "
 
149
  "Каждая строка должна начинаться с символов \"- \". Никаких вступлений."
150
  )
151
  messages = [
 
6
 
7
  from common import read_json_stdin, write_json_stdout, current_month_snapshot, clean_ru
8
 
9
+ ALLOWED_MODEL_ID = "google/gemma-3-1b-it"
10
 
11
  os.environ.setdefault("OMP_NUM_THREADS", "1")
12
  os.environ.setdefault("MKL_NUM_THREADS", "1")
 
28
 
29
  _tokenizer = AutoTokenizer.from_pretrained(
30
  ALLOWED_MODEL_ID,
 
31
  )
32
  _model = AutoModelForCausalLM.from_pretrained(
33
  ALLOWED_MODEL_ID,
34
  torch_dtype=torch.float32,
35
  low_cpu_mem_usage=True,
 
36
  ).to(_DEVICE).eval()
37
 
38
  if _tokenizer.pad_token_id is None:
 
43
 
44
 
45
  def _gen(messages, tok, mdl, max_new_tokens=200, det=True):
46
+
47
  txt = tok.apply_chat_template(
48
+ messages,
49
+ tokenize=False,
50
+ add_generation_prompt=True,
51
  )
52
+
53
  inputs = tok(
54
  txt,
55
  return_tensors="pt",
 
78
  **inputs,
79
  do_sample=True,
80
  temperature=0.8,
 
 
81
  **common,
82
  )
83
+
84
+ gen_ids = out[0, inputs["input_ids"].shape[-1]:]
85
+ return tok.decode(gen_ids, skip_special_tokens=True)
86
 
87
 
88
  _BULLET_KILL = re.compile(
 
96
  return ""
97
  m = re.search(r"(\n\s*[-*]\s+|\n\s*\d+[\).\s]+|•)", "\n" + text)
98
  if m:
99
+ text = text[m.start():]
100
 
101
  text = re.sub(r"^\s*[*•]\s+", "- ", text, flags=re.M)
102
  text = re.sub(r"^\s*\d+[\).\s]+", "- ", text, flags=re.M)
 
126
 
127
  tx = req.get("transactions") or []
128
  question = (req.get("question") or "").strip()
129
+
130
  df = pd.DataFrame(tx) if tx else None
131
  snap = current_month_snapshot(df) if df is not None and not df.empty else {}
132
 
 
147
 
148
  system_msg = (
149
  "Ты финансовый помощник. Отвечай по-русски. "
150
+ "Верни ТОЛЬКО список из 5–7 конкретных шагов экономии с цифрами "
151
+ "(лимиты, проценты, частота). "
152
  "Каждая строка должна начинаться с символов \"- \". Никаких вступлений."
153
  )
154
  messages = [