# NOTE: "Spaces: Sleeping" — Hugging Face Space status banner captured with the page scrape; not part of the script.
import os
import re

import pandas as pd
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer

from common import (
    read_json_stdin,
    write_json_stdout,
    current_month_snapshot,
    clean_ru,
)

# The only model this worker is allowed to load.
ALLOWED_MODEL_ID = "google/gemma-3-1b-it"

# Keep BLAS/OpenMP single-threaded on the small CPU Space.
# NOTE(review): these are set after `import torch`, which may be too late for
# the OpenMP runtime to observe them — torch.set_num_threads(1) below is the
# reliable knob; confirm whether the env vars are still needed.
for _var in ("OMP_NUM_THREADS", "MKL_NUM_THREADS"):
    os.environ.setdefault(_var, "1")
try:
    torch.set_num_threads(1)
except Exception:
    pass  # best-effort: keep going even if the build does not support it

_DEVICE = torch.device("cpu")

# Lazily-initialized singletons, populated by _load() on first use.
_tokenizer = None
_model = None
_loaded = False
| def _get_hf_token() -> str: | |
| for name in ("HF_TOKEN", "token", "HF_HUB_TOKEN"): | |
| val = os.getenv(name) | |
| if val: | |
| return val | |
| raise RuntimeError( | |
| "HF token not found in env. " | |
| "Set HF_TOKEN (или token / HF_HUB_TOKEN) в секретах Space." | |
| ) | |
def _load():
    """Load the tokenizer/model pair for ALLOWED_MODEL_ID once and cache it.

    Returns:
        (tokenizer, model) — the model is moved to CPU and put in eval mode.
    Raises:
        RuntimeError when no HF token is present in the environment.
    """
    global _tokenizer, _model, _loaded
    # Fast path: everything already loaded by a previous call.
    if _loaded and _tokenizer is not None and _model is not None:
        return _tokenizer, _model

    hf_token = _get_hf_token()
    _tokenizer = AutoTokenizer.from_pretrained(
        ALLOWED_MODEL_ID,
        token=hf_token,
        trust_remote_code=True,
    )
    _model = (
        AutoModelForCausalLM.from_pretrained(
            ALLOWED_MODEL_ID,
            token=hf_token,
            torch_dtype=torch.float32,  # full precision — CPU inference
            low_cpu_mem_usage=True,
            trust_remote_code=True,
        )
        .to(_DEVICE)
        .eval()
    )
    # Some tokenizers ship without a pad token; generation needs one,
    # so fall back to EOS.
    if _tokenizer.pad_token_id is None:
        _tokenizer.pad_token_id = _tokenizer.eos_token_id

    _loaded = True
    return _tokenizer, _model
def _gen(messages, tok, mdl, max_new_tokens=200, det=True):
    """Run one chat completion and return only the decoded continuation.

    Args:
        messages: chat messages for the tokenizer's chat template.
        tok, mdl: tokenizer and model as returned by _load().
        max_new_tokens: generation budget.
        det: True -> deterministic beam search (4 beams);
             False -> sampling at temperature 0.8.
    """
    inputs = tok.apply_chat_template(
        messages,
        add_generation_prompt=True,
        tokenize=True,
        return_dict=True,
        return_tensors="pt",
    ).to(_DEVICE)

    gen_kwargs = {
        "max_new_tokens": max_new_tokens,
        # The sampling path gets slightly stronger anti-repetition settings.
        "repetition_penalty": 1.08 if det else 1.12,
        "no_repeat_ngram_size": 5 if det else 6,
        "eos_token_id": tok.eos_token_id,
        "pad_token_id": tok.pad_token_id,
    }
    if det:
        gen_kwargs.update(do_sample=False, num_beams=4)
    else:
        gen_kwargs.update(do_sample=True, temperature=0.8)

    with torch.no_grad():
        out = mdl.generate(**inputs, **gen_kwargs)

    # Drop the prompt tokens and decode only the newly generated tail.
    prompt_len = inputs["input_ids"].shape[-1]
    return tok.decode(out[0, prompt_len:], skip_special_tokens=True)
| _BULLET_KILL = re.compile( | |
| r"(?i)(учитывай данные|данные пользователя|месяц:|доход:|расход:|нетто:|топ стат|вопрос:|assistant)" | |
| ) | |
| _ONLY_PUNCT = re.compile(r"^[-•\s\.\,\;\:\!\?]+$") | |
| def _to_bullets(text: str) -> str: | |
| if not text: | |
| return "" | |
| m = re.search(r"(\n\s*[-*]\s+|\n\s*\d+[\).\s]+|•)", "\n" + text) | |
| if m: | |
| text = text[m.start() :] | |
| text = re.sub(r"^\s*[*•]\s+", "- ", text, flags=re.M) | |
| text = re.sub(r"^\s*\d+[\).\s]+", "- ", text, flags=re.M) | |
| uniq, seen = [], set() | |
| for ln in text.split("\n"): | |
| s = ln.strip() | |
| if not s or not s.startswith("- "): | |
| continue | |
| if _BULLET_KILL.search(s) or _ONLY_PUNCT.match(s): | |
| continue | |
| s = re.sub(r"\s{2,}", " ", s) | |
| s = re.sub(r"\.\s*\.+$", ".", s) | |
| key = s.lower() | |
| if key in seen: | |
| continue | |
| seen.add(key) | |
| uniq.append(s) | |
| if len(uniq) >= 7: | |
| break | |
| return "\n".join(s.replace("- ", "• ", 1) for s in uniq) | |
def main():
    """Read a JSON request from stdin, generate savings advice, write JSON to stdout.

    Expected request keys: "transactions" (list of rows) and "question" (str).
    Response shape: {"advice": "<bulleted text>"}.
    """
    req = read_json_stdin()
    tx = req.get("transactions") or []
    question = (req.get("question") or "").strip()

    # Build the numeric context block from the current-month snapshot.
    df = pd.DataFrame(tx) if tx else None
    snap = current_month_snapshot(df) if df is not None and not df.empty else {}
    if snap:
        ctx_lines = [
            f"Месяц: {snap['month']}",
            f"Доход: {snap['income_total']:.0f}",
            f"Расход: {abs(snap['expense_total']):.0f}",
            f"Нетто: {snap['net']:.0f}",
        ]
        if snap.get("top_expense_categories"):
            ctx_lines.append("Топ статей расходов:")
            ctx_lines.extend(
                f"- {cat}: {abs(val):.0f}"
                for cat, val in snap["top_expense_categories"]
            )
        context = "\n".join(ctx_lines)
    else:
        context = "Данных за текущий месяц нет."

    system_msg = (
        "Ты финансовый помощник. Отвечай по-русски. "
        "Верни ТОЛЬКО список из 5–7 конкретных шагов экономии с цифрами "
        "(лимиты, проценты, частота). "
        "Каждая строка должна начинаться с символов \"- \". Никаких вступлений."
    )
    messages = [
        {"role": "system", "content": system_msg},
        {
            "role": "user",
            "content": (
                f"Мои данные за текущий месяц:\n{context}\n\nВопрос: {question}\n"
                "Начни ответ сразу со строки, которая начинается с \"- \". Верни только список."
            ),
        },
    ]

    tok, mdl = _load()
    # First pass: deterministic beam search.
    raw = _gen(messages, tok, mdl, det=True)
    text = _to_bullets(clean_ru(raw))
    # Fewer than 3 bullets survived filtering -> retry once with sampling.
    if text.count("\n") + 1 < 3:
        retry = _to_bullets(clean_ru(_gen(messages, tok, mdl, det=False)))
        if retry:
            text = retry

    write_json_stdout({"advice": text})


if __name__ == "__main__":
    main()