ai_econsult_demo / tests /test_hybrid_mapping.py
Cardiosense-AG's picture
Update tests/test_hybrid_mapping.py
af8b2e9 verified
# tests/test_hybrid_mapping.py
from __future__ import annotations
import os, json, time, csv
from pathlib import Path
from typing import Dict, List
from src.ai_core import generate_soap_draft
# Directory that receives every artifact this module writes: the per-case
# "<id>_result.json" files, the results CSV, and the run log.
BASE_DIR = Path("/data/econsult/tests")
# NOTE: executed at import time, so merely importing this module creates the
# directory (and any missing parents) on disk.
BASE_DIR.mkdir(parents=True, exist_ok=True)
# Destination of the per-case summary rows written by write_csv().
CSV_PATH = BASE_DIR / "results.csv"
# File that save_logs() appends captured console output to.
LOG_PATH = BASE_DIR / "run_logs.txt"
# Intake records exercised by run_all().  Each dict is passed unchanged to
# generate_soap_draft() by run_case(); its "id" value additionally labels the
# case in the CSV row ("case_id") and names the "<id>_result.json" output file.
# NOTE(review): the remaining keys are presumably the intake schema that
# generate_soap_draft() expects -- confirm against src.ai_core.
CASES: List[Dict[str, str]] = [
    {
        "id": "lipids",
        "age": "58",
        "sex": "Male",
        "specialist": "Cardiology",
        "chief_complaint": "Exertional chest tightness for ~2 months",
        "history": "Type 2 diabetes, hyperlipidemia, no rest pain, no syncope.",
        "medications": "Atorvastatin 20 mg nightly; Metformin 1000 mg BID.",
        "vitals": "BP 132/78, HR 72, BMI 29",
        "labs": "LDL 155 mg/dL, A1C 7.8%, eGFR 52",
        "comorbidities": "DM2, CKD3a, hyperlipidemia",
        "question": "Should we escalate to high-intensity statin and start low-dose aspirin?",
    },
    {
        "id": "ckd_dose",
        "age": "63",
        "sex": "Male",
        "specialist": "Cardiology",
        "chief_complaint": "Medication dosing in CKD3a",
        "history": "63 y/o M with DM2, CKD3a, HTN; needs metformin and statin dosing guidance.",
        "medications": "Atorvastatin 20 mg nightly; Metformin 1000 mg BID.",
        "vitals": "BP 128/80 mmHg, HR 70 bpm",
        "labs": "A1C 7.5%, eGFR 50 mL/min/1.73 m2",
        "comorbidities": "DM2, CKD3a, HTN",
        "question": "What are recommended statin intensity and metformin dosing for eGFR ≈ 50?",
    },
]
# ---------------------------------------------------------------------
# Helpers
# ---------------------------------------------------------------------
def count_annotated(meta: Dict[str, object]) -> int:
    """Return the combined length of the annotated assessment and plan entries.

    Looks up ``meta["annotated"]`` and sums the lengths of the values stored
    under its ``"assessment_html"`` and ``"plan_html"`` keys.

    A missing or falsy (e.g. ``None``) ``"annotated"`` entry, and a missing or
    falsy value under either key, each count as zero rather than raising
    ``TypeError`` from ``len(None)``.

    Args:
        meta: Metadata mapping; in this file it is the ``meta`` attribute of
            the object returned by ``generate_soap_draft``.

    Returns:
        ``len(assessment_html) + len(plan_html)``, with absent parts as 0.
    """
    annotated = meta.get("annotated") or {}
    # ``or ()`` covers a key that is present but explicitly set to None,
    # which ``dict.get(key, [])`` alone does not protect against.
    return sum(
        len(annotated.get(key) or ())
        for key in ("assessment_html", "plan_html")
    )
def run_case(intake: Dict[str, str]) -> Dict[str, object]:
    """Run one intake through ``generate_soap_draft`` and summarise the result.

    The draft is generated in ``"mapping"`` mode with ``rag_top_k=5`` and
    ``max_new_tokens=700``.  The returned object's ``soap`` payload is also
    dumped as pretty-printed JSON to ``BASE_DIR / "<id>_result.json"``.

    Args:
        intake: Case description; must contain an ``"id"`` key, which labels
            the summary row and names the JSON output file.

    Returns:
        A flat summary dict with the keys ``case_id``, ``generate_secs``,
        ``map_secs``, ``total_runtime`` (wall-clock seconds around the call,
        rounded to 2 decimals), ``assessment_items``, ``plan_items``,
        ``annotated_items``, ``unique_evidence`` and ``cache_stub``.
    """
    started = time.perf_counter()
    result = generate_soap_draft(intake, mode="mapping", rag_top_k=5, max_new_tokens=700)
    finished = time.perf_counter()

    meta = result.meta
    # Tolerate a "timings" entry that is present but None.
    timings = meta.get("timings") or {}
    rec = {
        "case_id": intake["id"],
        "generate_secs": timings.get("generate_secs", 0),
        "map_secs": timings.get("map_secs", 0),
        "total_runtime": round(finished - started, 2),
        "assessment_items": len(result.soap.get("assessment", [])),
        "plan_items": len(result.soap.get("plan", [])),
        "annotated_items": count_annotated(meta),
        "unique_evidence": len(result.citations),
        "cache_stub": meta.get("stub", ""),
    }

    # ensure_ascii=False emits non-ASCII characters verbatim, so the file must
    # be written as UTF-8 explicitly; the locale default encoding could fail
    # or corrupt the text on platforms where it is not UTF-8.
    out_path = BASE_DIR / f"{intake['id']}_result.json"
    out_path.write_text(
        json.dumps(result.soap, ensure_ascii=False, indent=2),
        encoding="utf-8",
    )
    return rec
def write_csv(rows: List[Dict[str, object]]) -> None:
    """Write ``rows`` to ``CSV_PATH``, replacing any existing file.

    The header row is taken from the keys of the first record, in order.
    When ``rows`` is empty nothing is written and the file is left untouched.
    """
    if rows:
        header = list(rows[0])
        with CSV_PATH.open("w", newline="", encoding="utf-8") as handle:
            writer = csv.DictWriter(handle, fieldnames=header)
            writer.writeheader()
            for record in rows:
                writer.writerow(record)
def save_logs(log_text: str) -> None:
    """Append captured console output to the persistent log at ``LOG_PATH``.

    The text is stripped of surrounding whitespace and written with one
    newline before and one after it.  The log's parent directory is created
    first if it does not exist.
    """
    log_dir = LOG_PATH.parent
    log_dir.mkdir(parents=True, exist_ok=True)
    with LOG_PATH.open("a", encoding="utf-8") as handle:
        entry = "".join(("\n", log_text.strip(), "\n"))
        handle.write(entry)
# ---------------------------------------------------------------------
# Entry point
# ---------------------------------------------------------------------
def run_all() -> str:
    """Run every entry in ``CASES`` and persist the summaries as CSV.

    Progress and each per-case summary are printed to stdout.  After all
    cases have run, the collected summaries are handed to ``write_csv``.

    Returns:
        The CSV file location (``CSV_PATH``) as a string.
    """
    print("=== Hybrid Mapping Validation Run ===")
    summaries: List[Dict[str, object]] = []
    for intake in CASES:
        print(f"\n--- Running case: {intake['id']} ---")
        summary = run_case(intake)
        summaries.append(summary)
        print(f"Result: {summary}")

    write_csv(summaries)
    print(f"\nResults saved to: {CSV_PATH}")
    csv_location = str(CSV_PATH)
    return csv_location
# Run the full validation suite only when this file is executed as a script,
# not when it is imported.
if __name__ == "__main__":
    run_all()