finwise-ml / advice.py
minstradamus's picture
Update advice.py
f55e107 verified
import os
import re
import pandas as pd
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer
from common import (
read_json_stdin,
write_json_stdout,
current_month_snapshot,
clean_ru,
)
# Hub id of the checkpoint that _load() downloads.
ALLOWED_MODEL_ID = "google/gemma-3-1b-it"

# Default the OpenMP / MKL thread-count variables to "1". setdefault() leaves
# any value already present in the environment untouched.
# NOTE(review): these run after `import torch`; confirm the native libraries
# still honor them at this point.
os.environ.setdefault("OMP_NUM_THREADS", "1")
os.environ.setdefault("MKL_NUM_THREADS", "1")

# Also ask torch for a single thread. This is best-effort: any failure is
# deliberately ignored.
try:
    torch.set_num_threads(1)
except Exception:
    pass

# All tensors and the model are placed on the CPU.
_DEVICE = torch.device("cpu")

# Lazily populated by _load(); stay None / False until the first successful load.
_tokenizer = None
_model = None
_loaded = False
def _get_hf_token() -> str:
    """Return the Hugging Face access token found in the environment.

    The variables ``HF_TOKEN``, ``token`` and ``HF_HUB_TOKEN`` are checked in
    that order and the first non-blank value wins. Surrounding whitespace is
    stripped, so a secret pasted with a trailing newline still yields a
    usable token, and a whitespace-only value is treated as missing.

    Returns:
        The token string without leading/trailing whitespace.

    Raises:
        RuntimeError: If none of the variables holds a non-blank value.
    """
    for name in ("HF_TOKEN", "token", "HF_HUB_TOKEN"):
        val = (os.getenv(name) or "").strip()
        if val:
            return val
    raise RuntimeError(
        "HF token not found in env. "
        "Set HF_TOKEN (или token / HF_HUB_TOKEN) в секретах Space."
    )
def _load():
    """Load the tokenizer and model once and cache them in module globals.

    Later calls return the cached pair without touching the Hub again. The
    globals are assigned only after BOTH objects loaded successfully, so a
    failed attempt leaves the cache empty and the next call retries cleanly.

    ``trust_remote_code`` is intentionally not enabled: it would allow Python
    code shipped in the model repository to run locally.
    NOTE(review): this assumes the installed transformers release supports
    the ALLOWED_MODEL_ID architecture natively - confirm.

    Returns:
        A ``(tokenizer, model)`` tuple; the model is on ``_DEVICE`` and in
        eval mode.

    Raises:
        RuntimeError: Propagated from ``_get_hf_token`` when no token is set.
    """
    global _tokenizer, _model, _loaded
    if _loaded and _tokenizer is not None and _model is not None:
        return _tokenizer, _model

    hf_token = _get_hf_token()
    tokenizer = AutoTokenizer.from_pretrained(
        ALLOWED_MODEL_ID,
        token=hf_token,
    )
    model = AutoModelForCausalLM.from_pretrained(
        ALLOWED_MODEL_ID,
        token=hf_token,
        torch_dtype=torch.float32,
        low_cpu_mem_usage=True,
    )
    model = model.to(_DEVICE).eval()

    # generate() needs a pad id; fall back to EOS when the tokenizer has none.
    if tokenizer.pad_token_id is None:
        tokenizer.pad_token_id = tokenizer.eos_token_id

    # Publish the cache atomically, only after everything above succeeded.
    _tokenizer, _model, _loaded = tokenizer, model, True
    return _tokenizer, _model
def _stop_token_ids(tok, mdl):
    """Return the de-duplicated end-of-sequence ids of tokenizer and model."""
    candidates = [tok.eos_token_id]
    gen_cfg = getattr(mdl, "generation_config", None)
    cfg_eos = getattr(gen_cfg, "eos_token_id", None)
    if isinstance(cfg_eos, (list, tuple)):
        candidates.extend(cfg_eos)
    else:
        candidates.append(cfg_eos)

    ids = []
    for token_id in candidates:
        if token_id is not None and token_id not in ids:
            ids.append(token_id)
    return ids


def _gen(messages, tok, mdl, max_new_tokens=200, det=True):
    """Generate the assistant reply for a chat and return it as plain text.

    Args:
        messages: Chat messages handed to ``tok.apply_chat_template``.
        tok: Tokenizer providing the chat template and decoding.
        mdl: Model whose ``generate`` method is called.
        max_new_tokens: Upper bound on the number of generated tokens.
        det: When true, decode without sampling using 4 beams; otherwise
            sample with temperature 0.8. The repetition penalty and the
            no-repeat n-gram size also differ slightly between the modes.

    Returns:
        The decoded continuation only (prompt tokens are cut off), with
        special tokens removed.
    """
    inputs = tok.apply_chat_template(
        messages,
        add_generation_prompt=True,
        tokenize=True,
        return_dict=True,
        return_tensors="pt",
    ).to(_DEVICE)

    # An explicit eos_token_id overrides the model's generation config, so
    # passing only the tokenizer's EOS would drop any additional stop ids the
    # checkpoint declares and let generation run past the end of the answer.
    # NOTE(review): chat checkpoints often list an end-of-turn id there -
    # confirm for ALLOWED_MODEL_ID.
    stop_ids = _stop_token_ids(tok, mdl)

    common = dict(
        max_new_tokens=max_new_tokens,
        repetition_penalty=1.08 if det else 1.12,
        no_repeat_ngram_size=5 if det else 6,
        eos_token_id=stop_ids or None,
        pad_token_id=tok.pad_token_id,
    )
    if det:
        mode = dict(do_sample=False, num_beams=4)
    else:
        mode = dict(do_sample=True, temperature=0.8)

    with torch.no_grad():
        out = mdl.generate(**inputs, **mode, **common)

    # Keep only the tokens produced after the prompt.
    prompt_len = inputs["input_ids"].shape[-1]
    gen_ids = out[0, prompt_len:]
    return tok.decode(gen_ids, skip_special_tokens=True)
# Bullets containing any of these fragments (prompt/context labels and role
# names) are discarded by _to_bullets().
_BULLET_KILL = re.compile(
    r"(?i)(учитывай данные|данные пользователя|месяц:|доход:|расход:|нетто:|топ стат|вопрос:|assistant)"
)
# A line made up solely of bullet markers, whitespace and punctuation.
_ONLY_PUNCT = re.compile(r"^[-•\s\.\,\;\:\!\?]+$")
# First list marker: a line starting with "-", "*" or a number, or any "•".
_LIST_START = re.compile(r"(\n\s*[-*]\s+|\n\s*\d+[\).\s]+|•)")


def _to_bullets(text: str) -> str:
    """Normalize raw model output into a clean list of at most 7 bullets.

    Everything before the first list marker is dropped, ``*``, ``•`` and
    numbered markers are unified, and lines that are not bullets, match
    ``_BULLET_KILL`` or consist only of punctuation are removed. Duplicate
    bullets (compared case-insensitively) are kept once.

    Args:
        text: Raw text to clean; may be empty.

    Returns:
        The bullets joined by newlines, each prefixed with "• ", or an empty
        string when nothing usable remains.
    """
    if not text:
        return ""

    # A newline is prepended so a marker on the very first line matches the
    # "\n..." alternatives. The slice must be taken from the SAME padded
    # string: match offsets are shifted by one relative to `text`, and
    # slicing `text` with them cut off the "•" of the first bullet.
    padded = "\n" + text
    m = _LIST_START.search(padded)
    if m:
        text = padded[m.start():]

    # Unify every supported marker to "- ".
    text = re.sub(r"^\s*[*•]\s+", "- ", text, flags=re.M)
    text = re.sub(r"^\s*\d+[\).\s]+", "- ", text, flags=re.M)

    uniq, seen = [], set()
    for ln in text.split("\n"):
        s = ln.strip()
        if not s or not s.startswith("- "):
            continue
        if _BULLET_KILL.search(s) or _ONLY_PUNCT.match(s):
            continue
        s = re.sub(r"\s{2,}", " ", s)      # collapse runs of whitespace
        s = re.sub(r"\.\s*\.+$", ".", s)   # "text..." -> "text."
        key = s.lower()
        if key in seen:
            continue
        seen.add(key)
        uniq.append(s)
        if len(uniq) >= 7:
            break
    return "\n".join(s.replace("- ", "• ", 1) for s in uniq)
def _format_context(snap) -> str:
    """Render the month snapshot as the text block embedded in the prompt."""
    if not snap:
        return "Данных за текущий месяц нет."
    ctx_lines = [
        f"Месяц: {snap['month']}",
        f"Доход: {snap['income_total']:.0f}",
        f"Расход: {abs(snap['expense_total']):.0f}",
        f"Нетто: {snap['net']:.0f}",
    ]
    if snap.get("top_expense_categories"):
        ctx_lines.append("Топ статей расходов:")
        for cat, val in snap["top_expense_categories"]:
            ctx_lines.append(f"- {cat}: {abs(val):.0f}")
    return "\n".join(ctx_lines)


def _build_messages(context: str, question: str):
    """Build the system + user chat messages sent to the model."""
    system_msg = (
        "Ты финансовый помощник. Отвечай по-русски. "
        "Верни ТОЛЬКО список из 5–7 конкретных шагов экономии с цифрами "
        "(лимиты, проценты, частота). "
        "Каждая строка должна начинаться с символов \"- \". Никаких вступлений."
    )
    user_msg = (
        f"Мои данные за текущий месяц:\n{context}\n\nВопрос: {question}\n"
        "Начни ответ сразу со строки, которая начинается с \"- \". Верни только список."
    )
    return [
        {"role": "system", "content": system_msg},
        {"role": "user", "content": user_msg},
    ]


def _count_bullets(text: str) -> int:
    """Return the number of lines in a bullet list (0 for an empty string)."""
    return text.count("\n") + 1 if text else 0


def main():
    """Read a request from stdin and write generated saving advice to stdout.

    The request JSON may contain ``transactions`` (a list used to build a
    DataFrame) and ``question`` (a string). The response JSON has a single
    ``advice`` key holding the bullet list, which may be an empty string.

    A deterministic generation is tried first. If it yields fewer than three
    bullets, a sampled generation is attempted and used only when it is
    non-empty and has at least as many bullets, so the retry can never make
    the answer shorter.
    """
    req = read_json_stdin()
    tx = req.get("transactions") or []
    question = (req.get("question") or "").strip()

    df = pd.DataFrame(tx) if tx else None
    snap = current_month_snapshot(df) if df is not None and not df.empty else {}

    messages = _build_messages(_format_context(snap), question)

    tok, mdl = _load()
    text = _to_bullets(clean_ru(_gen(messages, tok, mdl, det=True)))
    if _count_bullets(text) < 3:
        retry = _to_bullets(clean_ru(_gen(messages, tok, mdl, det=False)))
        if retry and _count_bullets(retry) >= _count_bullets(text):
            text = retry

    write_json_stdout({"advice": text})
# Run the advice pipeline only when this file is executed as a script,
# not when it is imported.
if __name__ == "__main__":
    main()