File size: 5,701 Bytes
e7bb669
 
 
 
 
 
f7ac86a
 
 
 
 
 
e7bb669
6e3d1dc
e7bb669
 
 
 
 
 
 
 
 
 
 
 
 
f9a7536
f7ac86a
 
 
 
 
 
 
 
 
f9a7536
e7bb669
 
 
 
 
 
f7ac86a
f9b7da1
e7bb669
 
f9b7da1
 
e7bb669
 
 
a25ba7b
e7bb669
 
f9b7da1
e7bb669
 
 
 
 
 
 
 
 
 
f7ac86a
6e3d1dc
 
f7ac86a
 
e7bb669
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
6e3d1dc
f7ac86a
6e3d1dc
e7bb669
 
f7ac86a
 
 
 
 
 
 
 
 
f55e107
f7ac86a
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
e7bb669
 
 
 
 
 
 
6e3d1dc
e7bb669
 
 
 
f7ac86a
e7bb669
 
 
 
 
 
f7ac86a
e7bb669
f7ac86a
 
e7bb669
 
 
f7ac86a
e7bb669
6e3d1dc
 
f7ac86a
e7bb669
f9a7536
e7bb669
f7ac86a
 
 
 
 
 
 
 
e7bb669
 
 
f9a7536
e7bb669
f7ac86a
e7bb669
 
 
f7ac86a
e7bb669
 
 
 
 
f7ac86a
e7bb669
f7ac86a
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
import os
import re
import pandas as pd
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer

from common import (
    read_json_stdin,
    write_json_stdout,
    current_month_snapshot,
    clean_ru,
)

ALLOWED_MODEL_ID = "google/gemma-3-1b-it"

os.environ.setdefault("OMP_NUM_THREADS", "1")
os.environ.setdefault("MKL_NUM_THREADS", "1")
try:
    torch.set_num_threads(1)
except Exception:
    pass

_DEVICE = torch.device("cpu")
_tokenizer = None
_model = None
_loaded = False


def _get_hf_token() -> str:
    for name in ("HF_TOKEN", "token", "HF_HUB_TOKEN"):
        val = os.getenv(name)
        if val:
            return val
    raise RuntimeError(
        "HF token not found in env. "
        "Set HF_TOKEN (или token / HF_HUB_TOKEN) в секретах Space."
    )


def _load():
    global _tokenizer, _model, _loaded
    if _loaded and _tokenizer is not None and _model is not None:
        return _tokenizer, _model

    hf_token = _get_hf_token()

    _tokenizer = AutoTokenizer.from_pretrained(
        ALLOWED_MODEL_ID,
        token=hf_token,
        trust_remote_code=True,
    )
    _model = AutoModelForCausalLM.from_pretrained(
        ALLOWED_MODEL_ID,
        token=hf_token,
        torch_dtype=torch.float32,
        low_cpu_mem_usage=True,
        trust_remote_code=True,
    ).to(_DEVICE).eval()

    if _tokenizer.pad_token_id is None:
        _tokenizer.pad_token_id = _tokenizer.eos_token_id

    _loaded = True
    return _tokenizer, _model


def _gen(messages, tok, mdl, max_new_tokens=200, det=True):
    inputs = tok.apply_chat_template(
        messages,
        add_generation_prompt=True,
        tokenize=True,
        return_dict=True,
        return_tensors="pt",
    ).to(_DEVICE)

    with torch.no_grad():
        common = dict(
            max_new_tokens=max_new_tokens,
            repetition_penalty=1.08 if det else 1.12,
            no_repeat_ngram_size=5 if det else 6,
            eos_token_id=tok.eos_token_id,
            pad_token_id=tok.pad_token_id,
        )
        if det:
            out = mdl.generate(
                **inputs,
                do_sample=False,
                num_beams=4,
                **common,
            )
        else:
            out = mdl.generate(
                **inputs,
                do_sample=True,
                temperature=0.8,
                **common,
            )

    gen_ids = out[0, inputs["input_ids"].shape[-1] :]
    return tok.decode(gen_ids, skip_special_tokens=True)


_BULLET_KILL = re.compile(
    r"(?i)(учитывай данные|данные пользователя|месяц:|доход:|расход:|нетто:|топ стат|вопрос:|assistant)"
)
_ONLY_PUNCT = re.compile(r"^[-•\s\.\,\;\:\!\?]+$")


def _to_bullets(text: str) -> str:
    if not text:
        return ""

    m = re.search(r"(\n\s*[-*]\s+|\n\s*\d+[\).\s]+|•)", "\n" + text)
    if m:
        text = text[m.start() :]

    text = re.sub(r"^\s*[*•]\s+", "- ", text, flags=re.M)
    text = re.sub(r"^\s*\d+[\).\s]+", "- ", text, flags=re.M)

    uniq, seen = [], set()
    for ln in text.split("\n"):
        s = ln.strip()
        if not s or not s.startswith("- "):
            continue
        if _BULLET_KILL.search(s) or _ONLY_PUNCT.match(s):
            continue
        s = re.sub(r"\s{2,}", " ", s)
        s = re.sub(r"\.\s*\.+$", ".", s)
        key = s.lower()
        if key in seen:
            continue
        seen.add(key)
        uniq.append(s)
        if len(uniq) >= 7:
            break

    return "\n".join(s.replace("- ", "• ", 1) for s in uniq)


def main():
    req = read_json_stdin()

    tx = req.get("transactions") or []
    question = (req.get("question") or "").strip()

    df = pd.DataFrame(tx) if tx else None
    snap = current_month_snapshot(df) if df is not None and not df.empty else {}

    if snap:
        ctx_lines = [
            f"Месяц: {snap['month']}",
            f"Доход: {snap['income_total']:.0f}",
            f"Расход: {abs(snap['expense_total']):.0f}",
            f"Нетто: {snap['net']:.0f}",
        ]
        if snap.get("top_expense_categories"):
            ctx_lines.append("Топ статей расходов:")
            for cat, val in snap["top_expense_categories"]:
                ctx_lines.append(f"- {cat}: {abs(val):.0f}")
        context = "\n".join(ctx_lines)
    else:
        context = "Данных за текущий месяц нет."

    system_msg = (
        "Ты финансовый помощник. Отвечай по-русски. "
        "Верни ТОЛЬКО список из 5–7 конкретных шагов экономии с цифрами "
        "(лимиты, проценты, частота). "
        "Каждая строка должна начинаться с символов \"- \". Никаких вступлений."
    )

    messages = [
        {"role": "system", "content": system_msg},
        {
            "role": "user",
            "content": (
                f"Мои данные за текущий месяц:\n{context}\n\nВопрос: {question}\n"
                "Начни ответ сразу со строки, которая начинается с \"- \". Верни только список."
            ),
        },
    ]

    tok, mdl = _load()

    raw = _gen(messages, tok, mdl, det=True)
    text = _to_bullets(clean_ru(raw))

    if text.count("\n") + 1 < 3:
        raw2 = _gen(messages, tok, mdl, det=False)
        text2 = _to_bullets(clean_ru(raw2))
        if text2:
            text = text2

    write_json_stdout({"advice": text})


if __name__ == "__main__":
    main()