"""Extract the total, a spending category and a short description from a receipt photo.

Pipeline: Tesseract OCR (Russian) -> regex heuristics for the total amount ->
zero-shot NLI classification into caller-supplied categories -> template text.

Usage:
    receipt_info_api.py path/to/receipt.jpg [categories.json]
"""

import sys
import json
import re
from typing import List, Dict, Any, Optional

from PIL import Image
import pytesseract
from transformers import AutoTokenizer, AutoModelForSequenceClassification, pipeline

CLS_MODEL_ID = "MoritzLaurer/mDeBERTa-v3-base-mnli-xnli"


def _select_device() -> int:
    """Return the ``pipeline`` device index: ``0`` (first CUDA GPU) or ``-1`` (CPU).

    The GPU is used only when no debugger/tracer is active (``sys.gettrace()``
    is ``None``) and torch reports a usable CUDA device; otherwise the CPU is
    used, so machines without a GPU never fail at import time.
    """
    if sys.gettrace() is not None:
        return -1
    try:
        import torch  # backend of the transformers pipeline; guarded so CPU fallback always works
    except ImportError:
        return -1
    return 0 if torch.cuda.is_available() else -1


# Loaded eagerly at import time so the module-level ``classifier`` name stays available.
classifier = pipeline(
    "zero-shot-classification",
    model=CLS_MODEL_ID,
    tokenizer=CLS_MODEL_ID,
    device=_select_device(),
)


def ocr_image_to_text(
    image: Image.Image,
    lang: str = "rus",
    config: str = r"--psm 6",
) -> str:
    """Run Tesseract OCR on *image* (converted to grayscale) and return the raw text.

    Args:
        image: Source image; it is converted to mode ``"L"`` before OCR.
        lang: Tesseract language code(s), passed to ``pytesseract.image_to_string``.
        config: Extra Tesseract CLI flags (default: page segmentation mode 6).
    """
    gray = image.convert("L")
    return pytesseract.image_to_string(gray, lang=lang, config=config)


# A money amount: optional minus, digit groups with an optional space/./, thousands
# separator, and exactly two decimals after '.' or ','.
_AMOUNT_RE = re.compile(r"(-?\d{1,3}(?:[ .,\u00A0]?\d{3})*(?:[.,]\d{2}))")
# Same shape, split into (sign, integer part, cents) for parsing.
_AMOUNT_PARTS_RE = re.compile(r"(-?)(\d{1,3}(?:[ .,\u00A0]?\d{3})*)[.,](\d{2})")
# "... = 300.00" style per-item line totals.
_EQ_TOTAL_RE = re.compile(r"=\s*([0-9][0-9 .,\u00A0]*[.,]\d{2})")
_MAX_AMOUNT = 1e7
# "итог" also matches "итого"/"подитог"; "к оплате" also matches "всего к оплате".
_STRONG_KEYWORDS = ("итог", "к оплате")
# Payment lines: cash, card ("карта"/"картой"/...), cashless, electronic, paid.
_MEDIUM_KEYWORDS = ("наличн", "карт", "безнал", "электрон", "оплачено")
# Lines about change handed back ("сдача") never carry the total.
_CHANGE_KEYWORD = "сдач"


def _parse_amount(raw: str) -> Optional[float]:
    """Convert an OCR'd amount such as ``"1 234,56"`` or ``"1.234,56"`` to float.

    The final ``.``/``,`` followed by two digits is the decimal separator; any other
    spaces, dots or commas in the integer part are treated as thousands separators.
    Strings outside that shape fall back to a lenient parse (drop spaces, ``,``
    -> ``.``). Returns ``None`` if the text is not a number.
    """
    match = _AMOUNT_PARTS_RE.fullmatch(raw.strip())
    if match:
        sign, int_part, cents = match.groups()
        digits = re.sub(r"\D", "", int_part)
        return float(f"{sign}{digits}.{cents}")
    compact = raw.replace(" ", "").replace("\u00a0", "").replace(",", ".")
    try:
        return float(compact)
    except ValueError:
        return None


def pick_total_from_text(text: str) -> Optional[float]:
    """Guess the receipt total from OCR text.

    Priority (the first non-empty source wins):
      1. the largest amount on a line containing "итог" / "к оплате";
      2. the largest amount on a payment line (cash, card, cashless, electronic, "оплачено");
      3. the sum of "= amount" line totals, when at least three were found (rounded to cents);
      4. the largest amount anywhere.

    Lines mentioning change ("сдач...") are skipped, and only positive amounts up to
    1e7 are considered. Returns ``None`` when nothing plausible is found.
    """
    if not text:
        return None

    text = text.replace("\xa0", " ")
    text = re.sub(r"[₽€$]", " ", text)

    strong: List[float] = []
    medium: List[float] = []
    every: List[float] = []
    line_totals: List[float] = []

    for line in text.splitlines():
        stripped = line.strip()
        if not stripped:
            continue
        lower = stripped.lower()
        if _CHANGE_KEYWORD in lower:
            continue

        eq_match = _EQ_TOTAL_RE.search(stripped)
        if eq_match:
            eq_value = _parse_amount(eq_match.group(1))
            if eq_value is not None and 0 < eq_value < _MAX_AMOUNT:
                line_totals.append(eq_value)

        # Keyword classification depends only on the line, so decide it once.
        is_strong = any(k in lower for k in _STRONG_KEYWORDS)
        is_medium = not is_strong and any(k in lower for k in _MEDIUM_KEYWORDS)

        for raw in _AMOUNT_RE.findall(stripped):
            value = _parse_amount(raw)
            if value is None or not 0 < value <= _MAX_AMOUNT:
                continue
            if is_strong:
                strong.append(value)
            elif is_medium:
                medium.append(value)
            every.append(value)

    if strong:
        return max(strong)
    if medium:
        return max(medium)
    if len(line_totals) >= 3:
        summed = sum(line_totals)
        if 0 < summed < _MAX_AMOUNT:
            # Amounts are in cents; rounding removes float accumulation noise.
            return round(summed, 2)
    if every:
        return max(every)
    return None


def classify_category_zeroshot(
    receipt_text: str,
    categories: List[Dict[str, Any]],
    *,
    max_chars: int = 1500,
    hypothesis_template: str = "Это покупка по категории {}.",
) -> Dict[str, Any]:
    """Pick the best-matching category for the receipt text with the zero-shot model.

    Args:
        receipt_text: OCR text; only the first *max_chars* characters are classified.
        categories: Dicts with at least a ``"name"`` key; names are the candidate labels.
        max_chars: Truncation budget for the text sent to the model.
        hypothesis_template: NLI hypothesis; ``{}`` is replaced by each label.

    Returns:
        The matching category dict. If *categories* is empty, returns
        ``{"id": None, "name": None}``. If the text is blank or the predicted label
        cannot be mapped back (exactly or case-insensitively), returns ``categories[-1]``.
    """
    if not categories:
        return {"id": None, "name": None}

    snippet = (receipt_text or "")[:max_chars]
    if not snippet.strip():
        # The zero-shot pipeline rejects empty input; blank OCR output has nothing to
        # classify, so use the same last-category fallback as an unmatched label.
        return categories[-1]

    labels = [cat["name"] for cat in categories]
    result = classifier(
        snippet,
        candidate_labels=labels,
        multi_label=False,
        hypothesis_template=hypothesis_template,
    )
    best_label = result["labels"][0]

    exact = next((c for c in categories if c["name"] == best_label), None)
    if exact is not None:
        return exact
    folded = best_label.lower()
    for cat in categories:
        if cat["name"].lower() == folded:
            return cat
    return categories[-1]


# Substrings marking service lines (tax id, fiscal register, cashier/receipt header,
# address) that must not be taken for the shop name.
_SHOP_SKIP_MARKERS = ("инн", "ккт", "касс", "чек", "рн кк", "россия", "г. ", "ул.")


def guess_shop_name(receipt_text: str, max_lines: int = 5) -> Optional[str]:
    """Return the first plausible shop-name line among the first *max_lines* non-empty lines.

    A line qualifies if it contains none of the service-line markers, is 2-40
    characters long and contains at least one letter (so OCR noise like ``"***"``
    is rejected). Returns ``None`` if no line qualifies.
    """
    lines = [ln.strip() for ln in receipt_text.splitlines() if ln.strip()]
    for ln in lines[:max_lines]:
        lower = ln.lower()
        if any(marker in lower for marker in _SHOP_SKIP_MARKERS):
            continue
        if 2 <= len(ln) <= 40 and any(ch.isalpha() for ch in ln):
            return ln
    return None


def build_description(
    receipt_text: str,
    category_name: Optional[str],
    total: Optional[float],
) -> str:
    """Build a one-line Russian description of the purchase.

    Uses the guessed shop name when available, otherwise the total (2 decimals),
    otherwise only the category. A missing category name becomes "покупка".
    """
    cat = category_name or "покупка"
    shop = guess_shop_name(receipt_text)
    if shop:
        return f"Покупка по категории {cat} в {shop}"
    if total is not None:
        return f"Покупка по категории {cat} на {total:.2f}"
    return f"Покупка по категории {cat}"


DEFAULT_CATEGORIES: List[Dict[str, Any]] = [
    {"id": 1, "name": "Еда"},
    {"id": 2, "name": "Спорт"},
    {"id": 3, "name": "Обучение"},
    {"id": 4, "name": "Транспорт"},
    {"id": 5, "name": "Развлечения"},
    {"id": 6, "name": "Медицина"},
    {"id": 7, "name": "Бытовые товары"},
    {"id": 8, "name": "Прочее"},
]


def extract_info(
    image_path: str,
    categories: Optional[List[Dict[str, Any]]] = None,
) -> Dict[str, Any]:
    """Run the full pipeline on the receipt image at *image_path*.

    Args:
        image_path: Path to the receipt image.
        categories: Candidate categories; defaults to ``DEFAULT_CATEGORIES``.

    Returns:
        Dict with ``total``, ``category_id``, ``category_name``, ``description``
        and the OCR ``raw_text``.
    """
    if categories is None:
        categories = DEFAULT_CATEGORIES

    # The context manager releases the file handle that Image.open keeps.
    with Image.open(image_path) as img:
        image = img.convert("RGB")

    text = ocr_image_to_text(image)
    total = pick_total_from_text(text)
    best_cat = classify_category_zeroshot(text, categories)
    description = build_description(text, best_cat.get("name"), total)

    return {
        "total": total,
        "category_id": best_cat.get("id"),
        "category_name": best_cat.get("name"),
        "description": description,
        "raw_text": text,
    }


def main(argv: Optional[List[str]] = None) -> int:
    """CLI entry point: print the extracted info (without raw text) as JSON.

    Args:
        argv: Arguments after the program name; defaults to ``sys.argv[1:]``.

    Returns:
        Process exit code (1 on usage error, 0 on success).
    """
    args = sys.argv[1:] if argv is None else argv
    if not args:
        print(
            "Usage: receipt_info_api.py path/to/receipt.jpg [categories.json]",
            file=sys.stderr,
        )
        return 1

    image_path = args[0]
    cats = None
    if len(args) >= 2:
        with open(args[1], "r", encoding="utf-8") as f:
            cats = json.load(f)

    info = extract_info(image_path, categories=cats)
    out = {key: info[key] for key in ("total", "category_id", "category_name", "description")}
    print(json.dumps(out, ensure_ascii=False))
    return 0


if __name__ == "__main__":
    sys.exit(main())