Spaces:
Sleeping
Sleeping
| import sys | |
| import json | |
| import re | |
| from typing import List, Dict, Any, Optional | |
| from PIL import Image | |
| import pytesseract | |
| from transformers import AutoTokenizer, AutoModelForSequenceClassification, pipeline | |
# Hugging Face model id for the zero-shot NLI classifier used to map receipt
# text onto category names.
CLS_MODEL_ID = "MoritzLaurer/mDeBERTa-v3-base-mnli-xnli"

# Device index for the pipeline: -1 = CPU, 0 = first CUDA GPU.
# The previous expression `0 if hasattr(sys, "gettrace") is False else -1`
# could never select the GPU: `sys.gettrace` exists on CPython, so the
# condition was always False. The effective value (CPU) is now explicit;
# set this to 0 to run on a GPU.
_CLS_DEVICE = -1

# Loaded once at import time and shared by every classification call.
classifier = pipeline(
    "zero-shot-classification",
    model=CLS_MODEL_ID,
    tokenizer=CLS_MODEL_ID,
    device=_CLS_DEVICE,
)
def ocr_image_to_text(
    image: Image.Image,
    lang: str = "rus",
    config: str = r"--psm 6",
) -> str:
    """Run Tesseract OCR on *image* and return the recognized text.

    The image is converted to grayscale ("L" mode) before recognition.

    Args:
        image: PIL image to recognize.
        lang: Tesseract language code(s), e.g. "rus" or "rus+eng".
            Defaults to "rus", the value previously hard-coded.
        config: Extra Tesseract command-line options. The default "--psm 6"
            selects page-segmentation mode 6 (a single uniform block of text).

    Returns:
        The raw text returned by ``pytesseract.image_to_string``.
    """
    grayscale = image.convert("L")
    return pytesseract.image_to_string(grayscale, lang=lang, config=config)
# Amount with a mandatory two-digit fraction, optionally grouped in thousands
# by space, dot, comma or NBSP (e.g. "1 234,56", "1.234,56", "12,50").
_AMOUNT_RE = re.compile(r"(-?\d{1,3}(?:[ .,\u00A0]?\d{3})*(?:[.,]\d{2}))")
# Amount that follows an "=" sign, as on item lines like "2 x 10,00 = 20,00".
_LINE_TOTAL_RE = re.compile(r"=\s*([0-9][0-9 .,\u00A0]*[.,]\d{2})")
# Fallback parser: integer part with 3-digit groups, then decimal mark + 2 digits.
_GROUPED_AMOUNT_RE = re.compile(r"(-?\d+(?:[ .,\u00A0]\d{3})*)\s*[.,](\d{2})")
# Lines naming the receipt total ("итог" = total, "к оплате" = amount due).
_STRONG_KEYWORDS = ("итог", "итого", "к оплате", "всего к оплате")
# Payment lines ("наличн" = cash, "карта" = card, "безнал" = cashless,
# "оплачено" = paid).
_MEDIUM_KEYWORDS = ("наличн", "карта", "безнал", "оплачено")
# Upper bound used to reject implausible amounts (OCR noise, ids, phone numbers).
_MAX_AMOUNT = 1e7


def _parse_amount(raw: str) -> Optional[float]:
    """Convert an OCR'd amount string to float, or return None if unparseable."""
    # Fast path: strip spaces and treat "," as the decimal mark ("1 234,56").
    try:
        return float(raw.replace(" ", "").replace(",", "."))
    except ValueError:
        pass
    # Fallback for dot/comma thousands separators ("1.234,56", "1,234.56"),
    # which the fast path turns into an unparseable "1.234.56".
    match = _GROUPED_AMOUNT_RE.fullmatch(raw.strip())
    if match is None:
        return None
    integer_part = re.sub(r"[ .,\u00A0]", "", match.group(1))
    return float(f"{integer_part}.{match.group(2)}")


def pick_total_from_text(text: str) -> Optional[float]:
    """Extract the most likely purchase total from OCR'd receipt text.

    Candidates are numbers with a two-digit fractional part. Lines that
    mention change given back ("сдач") are ignored. Amounts must lie in
    (0, 1e7]; "=" item totals must lie in (0, 1e7).

    The first rule that yields a value wins:

    1. the largest amount on a line containing a total keyword;
    2. the largest amount on a payment line (cash/card/...);
    3. the sum of "= amount" item totals, if at least three were found;
    4. the largest amount anywhere in the text.

    Args:
        text: Raw OCR output.

    Returns:
        The chosen total, or None if *text* is empty or contains no amount.
    """
    if not text:
        return None
    text = text.replace("\xa0", " ")
    # Currency symbols are replaced by spaces so they never glue onto digits.
    text = re.sub(r"[₽€$]", " ", text)

    strong_candidates: List[float] = []
    medium_candidates: List[float] = []
    all_candidates: List[float] = []
    line_totals: List[float] = []

    for line in text.splitlines():
        line_clean = line.strip()
        if not line_clean:
            continue
        lower = line_clean.lower()
        if "сдач" in lower:  # "сдача" = change returned to the customer
            continue

        m_eq = _LINE_TOTAL_RE.search(line_clean)
        if m_eq:
            v_eq = _parse_amount(m_eq.group(1))
            if v_eq and 0 < v_eq < _MAX_AMOUNT:
                line_totals.append(v_eq)

        is_strong = any(k in lower for k in _STRONG_KEYWORDS)
        is_medium = any(k in lower for k in _MEDIUM_KEYWORDS)
        for raw in _AMOUNT_RE.findall(line_clean):
            v = _parse_amount(raw)
            if not v or v <= 0 or v > _MAX_AMOUNT:
                continue
            if is_strong:
                strong_candidates.append(v)
            elif is_medium:
                medium_candidates.append(v)
            all_candidates.append(v)

    if strong_candidates:
        return max(strong_candidates)
    if medium_candidates:
        return max(medium_candidates)
    if len(line_totals) >= 3:
        s = sum(line_totals)
        if 0 < s < _MAX_AMOUNT:
            return s
    if all_candidates:
        return max(all_candidates)
    return None
def classify_category_zeroshot(
    receipt_text: str,
    categories: List[Dict[str, Any]],
) -> Dict[str, Any]:
    """Choose the category whose name best matches *receipt_text*.

    Uses the module-level zero-shot ``classifier`` with each category's
    "name" as a candidate label and a Russian hypothesis template
    ("This is a purchase in category {}.").

    Args:
        receipt_text: OCR'd receipt text; only the first 1500 characters
            are sent to the model.
        categories: Category dicts, each with at least a "name" key.

    Returns:
        The matching dict from *categories* (the same object, not a copy).
        The last category is the fallback when the text is empty or
        whitespace-only, or when the predicted label cannot be mapped back
        to a category. ``{"id": None, "name": None}`` is returned when
        *categories* is empty.
    """
    if not categories:
        return {"id": None, "name": None}
    fallback = categories[-1]
    # transformers' zero-shot pipeline raises ValueError for an empty input
    # sequence (e.g. OCR of a blank image), and whitespace-only text carries
    # no signal, so skip the model entirely.
    if not receipt_text or not receipt_text.strip():
        return fallback

    receipt_short = receipt_text[:1500]
    labels = [cat["name"] for cat in categories]
    result = classifier(
        receipt_short,
        candidate_labels=labels,
        multi_label=False,
        hypothesis_template="Это покупка по категории {}.",
    )
    # The pipeline returns labels ordered from most to least likely.
    best_label = result["labels"][0]

    exact = next((c for c in categories if c["name"] == best_label), None)
    if exact is not None:
        return exact
    best_label_low = best_label.lower()
    return next(
        (c for c in categories if c["name"].lower() == best_label_low),
        fallback,
    )
def guess_shop_name(receipt_text: str) -> Optional[str]:
    """Heuristically guess the shop name from the header of a receipt.

    Only the first five non-blank lines (whitespace-stripped) are examined.
    The first one that is 2-40 characters long and contains none of the
    marker substrings is returned. The markers look like tax-id,
    cash-register, country, city and street lines. Returns None when no
    header line qualifies.
    """
    service_markers = ("инн", "ккт", "касса", "рн кк", "россия", "г. ", "ул.")

    def _looks_like_name(candidate: str) -> bool:
        lowered = candidate.lower()
        if any(marker in lowered for marker in service_markers):
            return False
        return 2 <= len(candidate) <= 40

    stripped = (raw.strip() for raw in receipt_text.splitlines())
    header = [piece for piece in stripped if piece][:5]
    return next((piece for piece in header if _looks_like_name(piece)), None)
def build_description(
    receipt_text: str,
    category_name: Optional[str],
    total: Optional[float],
) -> str:
    """Compose a short Russian description of a purchase.

    The result always starts with "Покупка по категории <category>", where a
    missing or empty category name becomes "покупка". It is then followed by:

    - " в <shop>" when a shop name can be guessed from *receipt_text*;
    - otherwise " на <total>" with two decimals when *total* is known;
    - otherwise nothing.
    """
    label = category_name if category_name else "покупка"
    shop_name = guess_shop_name(receipt_text)
    if shop_name:
        return f"Покупка по категории {label} в {shop_name}"
    if total is None:
        return f"Покупка по категории {label}"
    return f"Покупка по категории {label} на {total:.2f}"
# Built-in categories used when the caller supplies none. Ids are 1-based and
# follow list order. The last entry ("Прочее", "Other") is also what
# classify_category_zeroshot falls back to when it cannot map a label.
DEFAULT_CATEGORIES: List[Dict[str, Any]] = [
    {"id": position, "name": title}
    for position, title in enumerate(
        (
            "Еда",  # Food
            "Спорт",  # Sport
            "Обучение",  # Education
            "Транспорт",  # Transport
            "Развлечения",  # Entertainment
            "Медицина",  # Medicine
            "Бытовые товары",  # Household goods
            "Прочее",  # Other
        ),
        start=1,
    )
]
def extract_info(
    image_path: str,
    categories: Optional[List[Dict[str, Any]]] = None,
) -> Dict[str, Any]:
    """Run the full receipt pipeline on the image at *image_path*.

    Steps: load image -> OCR -> total extraction -> zero-shot category
    -> human-readable description.

    Args:
        image_path: Path to a receipt image readable by PIL.
        categories: Candidate category dicts with "id" and "name" keys.
            ``DEFAULT_CATEGORIES`` is used when None.

    Returns:
        Dict with keys "total" (float or None), "category_id",
        "category_name", "description" and "raw_text" (the OCR output).
    """
    if categories is None:
        categories = DEFAULT_CATEGORIES
    # Open via a context manager so the file handle is released. convert()
    # returns a new, fully loaded image that stays usable after the close.
    with Image.open(image_path) as source:
        image = source.convert("RGB")
    text = ocr_image_to_text(image)
    total = pick_total_from_text(text)
    best_cat = classify_category_zeroshot(text, categories)
    description = build_description(text, best_cat.get("name"), total)
    return {
        "total": total,
        "category_id": best_cat.get("id"),
        "category_name": best_cat.get("name"),
        "description": description,
        "raw_text": text,
    }
if __name__ == "__main__":
    # Command-line entry point:
    #   receipt_info_api.py <image> [categories.json]
    # Prints a JSON summary (without the raw OCR text) to stdout.
    if len(sys.argv) < 2:
        print(
            "Usage: receipt_info_api.py path/to/receipt.jpg [categories.json]",
            file=sys.stderr,
        )
        sys.exit(1)

    receipt_path = sys.argv[1]
    custom_categories = None
    if len(sys.argv) > 2:
        with open(sys.argv[2], "r", encoding="utf-8") as fh:
            custom_categories = json.load(fh)

    result = extract_info(receipt_path, categories=custom_categories)
    summary_keys = ("total", "category_id", "category_name", "description")
    summary = {key: result[key] for key in summary_keys}
    print(json.dumps(summary, ensure_ascii=False))