finwise-ml / receipt_total_api.py
minstradamus's picture
Update receipt_total_api.py
7e87f71 verified
import sys
import json
import re
from typing import List, Dict, Any, Optional
from PIL import Image
import pytesseract
from transformers import AutoTokenizer, AutoModelForSequenceClassification, pipeline
# Hugging Face model id used for both the model weights and the tokenizer.
CLS_MODEL_ID = "MoritzLaurer/mDeBERTa-v3-base-mnli-xnli"


def _select_device() -> int:
    """Return the pipeline device index: 0 for the first CUDA GPU, -1 for CPU.

    CUDA availability is probed explicitly. Checking an interpreter attribute
    such as ``sys.gettrace`` says nothing about the hardware: that attribute
    always exists on CPython, so a test based on it pins the model to the CPU.
    If torch cannot be imported the function falls back to the CPU.
    """
    try:
        import torch  # local import: keeps the file-level imports untouched
    except ImportError:
        return -1
    return 0 if torch.cuda.is_available() else -1


# Module-level zero-shot classifier shared by every classification call.
# NOTE: building it loads the model at import time.
classifier = pipeline(
    "zero-shot-classification",
    model=CLS_MODEL_ID,
    tokenizer=CLS_MODEL_ID,
    device=_select_device(),
)
def ocr_image_to_text(image: Image.Image) -> str:
    """Run Tesseract OCR on a receipt image and return the recognized text.

    The image is first converted to mode ``"L"`` (single-channel grayscale)
    and then passed to pytesseract with the ``rus`` language and the
    ``--psm 6`` page segmentation option.
    """
    grayscale = image.convert("L")
    return pytesseract.image_to_string(grayscale, lang="rus", config=r"--psm 6")
# Amount token: optional sign, digits optionally grouped in thousands by a
# space, dot or comma, followed by a mandatory two-digit fractional part.
_AMOUNT_RE = re.compile(r"(-?\d{1,3}(?:[ .,\u00A0]?\d{3})*(?:[.,]\d{2}))")
# Per-item line total written as "= 123,45".
_LINE_TOTAL_RE = re.compile(r"=\s*([0-9][0-9 .,\u00A0]*[.,]\d{2})")
_CURRENCY_RE = re.compile(r"[₽€$]")
# Integer part of an amount once spaces are removed: plain digits, optionally
# followed by dot/comma separated groups of exactly three digits.
_INT_PART_RE = re.compile(r"-?\d*(?:[.,]\d{3})*")
# Lower-case substrings that mark a line as holding the grand total.
_STRONG_KEYWORDS = ("итог", "итого", "к оплате", "всего к оплате")
# Lower-case substrings that mark a payment line (cash / card / paid).
_MEDIUM_KEYWORDS = ("наличн", "карта", "безнал", "оплачено")
# Upper bound used to discard implausible amounts.
_MAX_AMOUNT = 1e7


def _parse_amount(raw: str) -> Optional[float]:
    """Convert a matched amount token to float, or return None if malformed.

    The last ``.`` or ``,`` is taken as the decimal mark; any earlier ones are
    thousands separators and are dropped, so ``"1,234.56"``, ``"1.234,56"``
    and ``"1 234,56"`` all parse to ``1234.56``.
    """
    compact = raw.replace(" ", "").replace("\u00a0", "")
    cut = max(compact.rfind("."), compact.rfind(","))
    if cut == -1:
        normalized = compact
    else:
        whole, fraction = compact[:cut], compact[cut + 1:]
        # Reject junk such as "1.5150" where a separator is not followed by
        # exactly three digits; that is not thousands grouping.
        if not _INT_PART_RE.fullmatch(whole):
            return None
        normalized = whole.replace(".", "").replace(",", "") + "." + fraction
    try:
        return float(normalized)
    except ValueError:
        return None


def pick_total_from_text(text: str) -> Optional[float]:
    """Pick the most plausible receipt total from OCR text.

    Every non-empty line is scanned for amounts, except lines mentioning
    change ("сдач"), which are skipped. The result is chosen by priority:

    1. the largest amount on a line containing a total keyword
       ("итог", "к оплате", ...);
    2. the largest amount on a payment line ("наличн", "карта", ...);
    3. the sum of per-item "= amount" line totals, when at least three were
       found (rounded to 2 decimals to drop binary floating-point noise);
    4. the largest amount found anywhere.

    Args:
        text: Raw OCR output; may be empty.

    Returns:
        The selected total, or None when the text is empty or holds no
        usable amount.
    """
    if not text:
        return None

    # Normalize non-breaking spaces and blank out currency signs so they
    # cannot glue onto or split an amount.
    text = _CURRENCY_RE.sub(" ", text.replace("\xa0", " "))

    strong_candidates: List[float] = []
    medium_candidates: List[float] = []
    all_candidates: List[float] = []
    line_totals: List[float] = []

    for raw_line in text.splitlines():
        line = raw_line.strip()
        if not line:
            continue
        lower = line.lower()
        if "сдач" in lower:
            # Change handed back to the customer is never the total.
            continue

        eq_match = _LINE_TOTAL_RE.search(line)
        if eq_match:
            line_value = _parse_amount(eq_match.group(1))
            if line_value is not None and 0 < line_value < _MAX_AMOUNT:
                line_totals.append(line_value)

        # The keyword class depends only on the line, so compute it once.
        is_strong = any(key in lower for key in _STRONG_KEYWORDS)
        is_medium = not is_strong and any(key in lower for key in _MEDIUM_KEYWORDS)

        for token in _AMOUNT_RE.findall(line):
            value = _parse_amount(token)
            if value is None or not 0 < value <= _MAX_AMOUNT:
                continue
            if is_strong:
                strong_candidates.append(value)
            elif is_medium:
                medium_candidates.append(value)
            all_candidates.append(value)

    if strong_candidates:
        return max(strong_candidates)
    if medium_candidates:
        return max(medium_candidates)
    if len(line_totals) >= 3:
        summed = round(sum(line_totals), 2)
        if 0 < summed < _MAX_AMOUNT:
            return summed
    if all_candidates:
        return max(all_candidates)
    return None
def classify_category_zeroshot(
    receipt_text: str,
    categories: List[Dict[str, Any]],
) -> Dict[str, Any]:
    """Assign the receipt text to one of ``categories`` via the zero-shot pipeline.

    Only the first 1500 characters of the text are sent to the classifier,
    with each category's ``"name"`` as a candidate label. The first label in
    the pipeline result is taken as the prediction and mapped back to its
    category dict: exact name match first, then case-insensitive match, and
    finally the last category in the list as a fallback.

    Returns ``{"id": None, "name": None}`` when ``categories`` is empty.
    """
    if not categories:
        return {"id": None, "name": None}

    snippet = receipt_text[:1500]
    candidate_names = [entry["name"] for entry in categories]
    prediction = classifier(
        snippet,
        candidate_labels=candidate_names,
        multi_label=False,
        hypothesis_template="Это покупка по категории {}.",
    )
    top_label = prediction["labels"][0]

    # Pass 1: exact name match.
    for entry in categories:
        if entry["name"] == top_label:
            return entry

    # Pass 2: case-insensitive name match.
    wanted = top_label.lower()
    for entry in categories:
        if entry["name"].lower() == wanted:
            return entry

    # Nothing matched: fall back to the last listed category.
    return categories[-1]
def guess_shop_name(receipt_text: str) -> Optional[str]:
    """Guess the shop name from the top of the receipt text.

    Looks at the first five non-empty lines (whitespace-stripped) and returns
    the first one that is 2 to 40 characters long and contains none of the
    fiscal/address markers ("инн", "ккт", "касса", ...). Returns None when no
    line qualifies.
    """
    skip_markers = ("инн", "ккт", "касса", "рн кк", "россия", "г. ", "ул.")
    stripped_rows = (row.strip() for row in receipt_text.splitlines())
    header_rows = [row for row in stripped_rows if row][:5]

    for row in header_rows:
        lowered = row.lower()
        if any(marker in lowered for marker in skip_markers):
            continue
        if 2 <= len(row) <= 40:
            return row
    return None
def build_description(
    receipt_text: str,
    category_name: Optional[str],
    total: Optional[float],
) -> str:
    """Compose a short Russian description of the purchase.

    The shop name guessed from the receipt text is preferred. Without one,
    the total (formatted with two decimals) is used if it is known; otherwise
    only the category is mentioned. A missing or empty category name is
    replaced by "покупка".
    """
    label = category_name if category_name else "покупка"
    store = guess_shop_name(receipt_text)

    if store:
        return f"Покупка по категории {label} в {store}"
    if total is None:
        return f"Покупка по категории {label}"
    return f"Покупка по категории {label} на {total:.2f}"
# Built-in categories used when the caller supplies none. Ids are 1-based and
# follow the order of the names below.
DEFAULT_CATEGORIES: List[Dict[str, Any]] = [
    {"id": number, "name": title}
    for number, title in enumerate(
        (
            "Еда",
            "Спорт",
            "Обучение",
            "Транспорт",
            "Развлечения",
            "Медицина",
            "Бытовые товары",
            "Прочее",
        ),
        start=1,
    )
]
def extract_info(
    image_path: str,
    categories: Optional[List[Dict[str, Any]]] = None,
) -> Dict[str, Any]:
    """Extract total, category and description from a receipt image.

    Args:
        image_path: Path of the receipt image to open.
        categories: Candidate categories as dicts with at least a ``"name"``
            key (and normally an ``"id"``). ``DEFAULT_CATEGORIES`` is used
            when this is None.

    Returns:
        A dict with the keys ``"total"``, ``"category_id"``,
        ``"category_name"``, ``"description"`` and ``"raw_text"`` (the OCR
        output).
    """
    if categories is None:
        categories = DEFAULT_CATEGORIES

    # convert() returns an independent, fully loaded copy, so the source image
    # can be closed right away. The context manager makes that release
    # deterministic instead of leaving it to garbage collection.
    with Image.open(image_path) as source:
        image = source.convert("RGB")

    text = ocr_image_to_text(image)
    total = pick_total_from_text(text)
    best_cat = classify_category_zeroshot(text, categories)

    # Read the category fields once, with the same tolerant access for both.
    category_id = best_cat.get("id")
    category_name = best_cat.get("name")
    description = build_description(text, category_name, total)

    return {
        "total": total,
        "category_id": category_id,
        "category_name": category_name,
        "description": description,
        "raw_text": text,
    }
def main(argv: Optional[List[str]] = None) -> int:
    """Command-line entry point; returns the process exit status.

    Usage: ``<script> path/to/receipt.jpg [categories.json]``. The optional
    JSON file supplies the category list passed to ``extract_info``. On
    success the result is printed to stdout as one JSON object (without the
    raw OCR text) and 0 is returned; with too few arguments a usage line is
    written to stderr and 1 is returned.

    Args:
        argv: Full argument vector including the program name; defaults to
            ``sys.argv``.
    """
    args = sys.argv if argv is None else argv
    if len(args) < 2:
        # Name the script as it was invoked rather than hard-coding a file
        # name that can drift from the real one.
        prog = args[0] if args else "receipt_total_api.py"
        print(f"Usage: {prog} path/to/receipt.jpg [categories.json]", file=sys.stderr)
        return 1

    image_path = args[1]
    cats = None
    if len(args) >= 3:
        with open(args[2], "r", encoding="utf-8") as f:
            cats = json.load(f)

    info = extract_info(image_path, categories=cats)
    # raw_text is deliberately left out of the CLI output.
    out = {
        "total": info["total"],
        "category_id": info["category_id"],
        "category_name": info["category_name"],
        "description": info["description"],
    }
    print(json.dumps(out, ensure_ascii=False))
    return 0


if __name__ == "__main__":
    sys.exit(main())