# NOTE: Hugging Face file-viewer header captured along with the source; kept as
# a comment so the module stays valid Python.
# ai_econsult_demo / src/ai_core.py — uploaded by Cardiosense-AG,
# commit ced1708 (verified), raw / history / blame, 15.5 kB
# src/ai_core.py
from __future__ import annotations
import json
import os
import re
import time
from typing import Any, Dict, List, Tuple
from .prompt_builder import build_referral_summary, normalize_intake
from .model_loader import generate_chat
# rag_index is optional at import-time; map step handles unavailability gracefully.
try:
from . import rag_index
except Exception: # pragma: no cover
rag_index = None # type: ignore
# ------------------------------ configuration -------------------------------
_SPARSE_MIN_CHARS = 160  # if summary < this, we fall back to minimal builder
_DEFAULT_MAX_NEW_TOKENS = 700  # generation budget forwarded to generate_chat
_DEFAULT_TEMPERATURE = 0.2  # low temperature keeps clinical drafts conservative
_DEFAULT_TOP_P = 0.95  # nucleus-sampling cutoff forwarded to generate_chat
# ------------------------------ public API ----------------------------------
def generate_soap_draft(
    intake: Dict[str, Any],
    mode: str = "mapping",  # or "rag"
    *,
    max_new_tokens: int = _DEFAULT_MAX_NEW_TOKENS,
    temperature: float = _DEFAULT_TEMPERATURE,
    top_p: float = _DEFAULT_TOP_P,
    explain: bool = False,
) -> Dict[str, Any]:
    """Run the full pipeline: intake → summary → LLM JSON → optional mapping.

    Args:
        intake: raw intake payload (flat or nested; normalized downstream).
        mode: "mapping" runs hybrid claim-to-evidence mapping; "rag" injects
            retrieved context into the prompt and reports citations only.
        max_new_tokens / temperature / top_p: forwarded to ``generate_chat``.
        explain: accepted for interface compatibility; currently unused.

    Returns:
        dict with keys ``soap`` (subjective/objective strings, assessment/plan
        lists), ``raw_text``, ``summary``, ``context_text`` (rag mode only),
        ``mapping`` (dict or None) and ``timings`` (seconds, 3 dp).
    """
    started = time.time()

    # 1) Robust referral summary (falls back to a minimal one when sparse).
    summary = _maybe_external_summary(intake)

    # 2) Optional retrieval context, injected only in RAG mode.
    context_text = _maybe_build_context(summary) if mode == "rag" else None
    messages = _messages_for_mode(mode, summary, context_text=context_text)

    # 3) Model call, timed separately from the rest of the pipeline.
    llm_started = time.time()
    raw_text = generate_chat(
        messages,
        max_new_tokens=max_new_tokens,
        temperature=temperature,
        top_p=top_p,
    )
    gen_secs = time.time() - llm_started

    # 4) Extract and repair the model's JSON into the SOAP shape.
    soap = _safe_json(raw_text)

    # 5) Evidence mapping (hybrid) or citation pass-through (RAG).
    map_started = time.time()
    if mode == "mapping":
        mapping = _hybrid_map(soap)
    elif mode == "rag" and context_text:
        # In RAG mode the retrieved context itself serves as the citations.
        mapping = {
            "mode": "rag",
            "claims_count": _count_assessment_plan_items(soap),
            "registry_cap": None,
            "unique_evidence_count": None,
            "k": None,
            "note": "Citations available via injected context; hybrid mapping skipped.",
        }
    else:
        mapping = None
    map_secs = time.time() - map_started

    return {
        "soap": soap,
        "raw_text": raw_text,
        "summary": summary,
        "context_text": context_text,
        "mapping": mapping,
        "timings": {
            "generate_secs": round(gen_secs, 3),
            "map_secs": round(map_secs, 3),
            "total_runtime": round(time.time() - started, 3),
        },
    }
# ------------------------------ summarization --------------------------------
def _maybe_external_summary(intake: Dict[str, Any]) -> str:
    """Build the unified referral summary; fall back to the minimal builder.

    Falls back when the rich builder raises, when none of the key consult
    fields are populated, or when the result is shorter than
    ``_SPARSE_MIN_CHARS``.
    """
    try:
        s = build_referral_summary(intake)
    except Exception as e:  # defensive: summarization must never kill the run
        print(f"[ai_core] Summary builder error: {e}. Falling back to minimal.")
        return _minimal_summary(intake)
    # Guard: ensure key fields and minimum length.
    # Normalize ONCE (previously re-normalized for every key checked).
    consult = normalize_intake(intake).get("consult", {}) or {}
    has_key = any(consult.get(k) for k in ("question", "chief_complaint", "history"))
    if (not has_key) or (len(s) < _SPARSE_MIN_CHARS):
        print(f"[ai_core] Summary too sparse (len={len(s)}). Using minimal summary.")
        return _minimal_summary(intake)
    return s
def _minimal_summary(intake: Dict[str, Any]) -> str:
    """Very compact fallback summary built from normalized intake fields.

    Works with both flat and nested intake shapes (via ``normalize_intake``).
    Returns a newline-joined "Label: value" block; may be empty if nothing
    is populated.
    """
    norm = normalize_intake(intake)
    p, c = norm["patient"], norm["consult"]
    bits: List[str] = []
    # str() guards against non-string values (e.g. integer age) crashing join.
    demo = "; ".join(str(s) for s in (p.get("age"), p.get("sex")) if s)
    if demo:
        bits.append(f"Patient: {demo}")
    for k, label in [
        ("question", "Key question"),
        ("chief_complaint", "Chief complaint"),
        ("history", "Background"),
        ("medications", "Medications"),
        ("labs", "Pertinent labs"),
    ]:
        # Prefer the consult section; fall back to the patient section.
        v = (c.get(k) if k in c else p.get(k)) or ""
        if v:
            bits.append(f"{label}: {v}")
    summary = "\n".join(bits)
    print(f"[ai_core] Minimal summary length={len(summary)} chars")
    return summary
# ------------------------------ prompting -----------------------------------
def _messages_for_mode(mode: str, summary: str, *, context_text: str | None = None) -> List[Dict[str, str]]:
sys = (
"You are a clinical decision support assistant. "
"Given a referral summary (and optionally context excerpts), "
"return a STRICT JSON object representing a SOAP note. "
"Keys: subjective (string), objective (string), assessment (array of strings), plan (array of strings). "
"Do not include any text outside the JSON object. Do not use markdown. Do not add commentary."
)
user_lines = ["Referral summary:", summary]
if mode == "rag" and context_text:
user_lines.append("\nGuideline/excerpt context:\n" + context_text)
user_lines.append(
"\nProduce only valid JSON with the exact keys: subjective, objective, assessment, plan. "
"Example: {\"subjective\":\"...\",\"objective\":\"...\",\"assessment\":[\"...\"],\"plan\":[\"...\"]}"
)
return [
{"role": "system", "content": sys},
{"role": "user", "content": "\n".join(user_lines)},
]
# ------------------------------- RAG helper ---------------------------------
def _maybe_build_context(summary: str, top_k: int = 6) -> str | None:
    """Retrieve up to ``top_k`` text snippets for the summary from the index.

    Returns the snippets joined by "---" separators, or None when the index
    module is unavailable, exposes no ``search_index``, finds nothing, or
    raises (best-effort: failures are logged and swallowed).
    """
    if rag_index is None:
        return None
    try:
        # Signatures can vary across rag_index versions; only the
        # `search_index` pattern is handled here.
        if not hasattr(rag_index, "search_index"):
            return None
        hits = rag_index.search_index(summary, top_k=top_k)  # type: ignore
        if not hits:
            return None
        # Each hit may carry the payload under 'text' or 'page_content'.
        snippets = []
        for hit in hits:
            body = (hit.get("text") or hit.get("page_content") or "").strip()
            if body:
                snippets.append(body)
        joined = "\n---\n".join(snippets[:top_k])
        return joined or None
    except Exception as e:
        print(f"[ai_core] Context build skipped: {e}")
        return None
# ------------------------------ JSON repair ---------------------------------
def _strip_code_fences(s: str) -> str:
# Remove ```json ... ``` or ``` ... ``` fences if present.
return re.sub(r"```(?:json)?\s*([\s\S]*?)\s*```", r"\1", s, flags=re.IGNORECASE)
def _largest_balanced_json(s: str) -> str | None:
"""
Return the largest balanced {...} substring, respecting strings.
If not found, return None.
"""
start = None
depth = 0
in_str = False
esc = False
last_good = None
for i, ch in enumerate(s):
if in_str:
if esc:
esc = False
elif ch == "\\":
esc = True
elif ch == '"':
in_str = False
continue
else:
if ch == '"':
in_str = True
elif ch == "{":
if depth == 0:
start = i
depth += 1
elif ch == "}":
if depth > 0:
depth -= 1
if depth == 0 and start is not None:
last_good = s[start : i + 1]
return last_good
def _trim_trailing_commas(s: str) -> str:
s = re.sub(r",\s*([}\]])", r"\1", s) # ,} or ,]
return s
def _coerce_soap_shape(obj: Dict[str, Any], raw_text: str) -> Dict[str, Any]:
# Ensure required keys and types.
out = {
"subjective": obj.get("subjective") or "",
"objective": obj.get("objective") or "",
"assessment": obj.get("assessment") or [],
"plan": obj.get("plan") or [],
}
# Coerce assessment/plan to list[str]
for k in ("assessment", "plan"):
v = out[k]
if isinstance(v, str) and v.strip():
out[k] = [v.strip()]
elif isinstance(v, dict):
# sometimes LLM returns {'text': '...'}
text = v.get("text") or ""
out[k] = [text] if text else []
elif isinstance(v, list):
coerced = []
for item in v:
if isinstance(item, str):
s = item.strip()
if s:
coerced.append(s)
elif isinstance(item, dict):
s = (item.get("text") or item.get("summary") or "")
if s:
coerced.append(s.strip())
out[k] = coerced
else:
out[k] = []
# If still empty and we have raw text, salvage subjective
if not out["assessment"] and not out["plan"] and raw_text.strip():
out["subjective"] = out["subjective"] or raw_text.strip()[:800]
return out
def _safe_json(raw_text: str) -> Dict[str, Any]:
    """Parse model output into the SOAP dict, repairing common LLM faults.

    Pipeline: strip markdown fences -> locate the largest balanced {...} ->
    drop trailing commas -> ``json.loads`` -> coerce into the SOAP shape.
    Any failure degrades gracefully to an empty shape salvaged from the raw
    text (see ``_coerce_soap_shape``).

    Fix: the previous except-branch retried ``json.loads`` on a re-trimmed
    copy of the identical string; since trimming is idempotent the retry
    could never succeed and has been removed.
    """
    s = _strip_code_fences(raw_text or "")
    candidate = _largest_balanced_json(s)
    if not candidate:
        # No object found: maybe the whole payload is already parseable JSON.
        try:
            j = json.loads(_trim_trailing_commas(s.strip()))
            if isinstance(j, dict):
                return _coerce_soap_shape(j, raw_text)
        except Exception:
            pass
        return _coerce_soap_shape({}, raw_text)
    try:
        obj = json.loads(_trim_trailing_commas(candidate))
        return _coerce_soap_shape(obj, raw_text)
    except Exception as e:
        print(f"[ai_core] JSON repair failed: {e}")
        return _coerce_soap_shape({}, raw_text)
# ------------------------------ hybrid mapping ------------------------------
def _map_mode_and_cap() -> Tuple[str, int, int]:
"""
Returns (mode, cap, per-claim top_k).
validation β†’ cap=20, top_k=5
production β†’ cap=5, top_k=5
"""
mode = os.environ.get("MAP_MODE", "production").strip().lower()
if mode == "validation":
return "validation", 20, 5
return "production", 5, 5
def _count_assessment_plan_items(soap: Dict[str, Any]) -> int:
a = soap.get("assessment") or []
p = soap.get("plan") or []
return (len(a) if isinstance(a, list) else 0) + (len(p) if isinstance(p, list) else 0)
def _extract_claims(soap: Dict[str, Any]) -> List[str]:
claims: List[str] = []
for k in ("assessment", "plan"):
v = soap.get(k) or []
if isinstance(v, list):
for item in v:
if isinstance(item, str) and item.strip():
claims.append(item.strip())
elif isinstance(item, dict):
t = (item.get("text") or item.get("summary") or "").strip()
if t:
claims.append(t)
# Deduplicate while preserving order
seen = set()
uniq = []
for c in claims:
if c not in seen:
seen.add(c)
uniq.append(c)
return uniq
def _hybrid_map(soap: Dict[str, Any]) -> Dict[str, Any]:
    """Map each assessment/plan claim to retrieved evidence via the RAG index.

    For every deduplicated claim, run a top-k ``search_index`` query, build a
    capped registry of unique evidence items (insertion order preserved), and
    record which registry indices support each claim.

    Returns:
        {'mode', 'claims_count', 'registry_cap', 'unique_evidence_count', 'k',
         'registry': {evidence_id: {'source','score','text'}},
         'claim_to_indices': {claim: [registry positions]}}
    Degrades to empty structures when there are no claims or no usable index.

    Fixes over the previous version: the evidence-id derivation was duplicated
    verbatim in two places (now the ``_evidence_id`` helper); ``metadata`` set
    to None no longer raises (``or {}`` guard); a legitimate 0.0 score is no
    longer collapsed to None.
    """
    mode, cap, top_k = _map_mode_and_cap()
    claims = _extract_claims(soap)

    def _empty_result(claims_count: int) -> Dict[str, Any]:
        # Shared shape for the two bail-out paths below.
        return {
            "mode": "mapping",
            "claims_count": claims_count,
            "registry_cap": cap,
            "unique_evidence_count": 0,
            "k": top_k,
            "registry": {},
            "claim_to_indices": {},
        }

    def _evidence_id(h: Dict[str, Any]) -> str:
        # Stable-ish identifier for a hit; falls back to a hash of the text.
        meta = h.get("metadata") or {}
        return (
            h.get("id")
            or h.get("source_id")
            or h.get("source")
            or meta.get("id")
            or meta.get("source")
            or str(hash((h.get("text") or h.get("page_content") or "")[:120]))
        )

    if not claims:
        print("[ai_core] [MAP] No claims available for mapping.")
        return _empty_result(0)
    if rag_index is None or not hasattr(rag_index, "search_index"):
        print("[ai_core] [MAP] Index unavailable — skipping mapping.")
        return _empty_result(len(claims))

    t0 = time.time()
    claim_to_hits: Dict[str, List[Dict[str, Any]]] = {}
    registry_order: List[str] = []  # evidence ids in insertion order
    registry: Dict[str, Dict[str, Any]] = {}
    for claim in claims:
        try:
            hits = rag_index.search_index(claim, top_k=top_k)  # type: ignore
        except Exception as e:
            print(f"[ai_core] [MAP] search_index error: {e}")
            hits = []
        claim_to_hits[claim] = hits or []
        # Register previously-unseen evidence.
        for h in hits or []:
            evid_id = _evidence_id(h)
            if evid_id not in registry:
                meta = h.get("metadata") or {}
                score = h.get("score")
                if score is None:  # preserve a legitimate 0.0 score
                    score = h.get("similarity")
                registry[evid_id] = {
                    "source": h.get("source") or meta.get("source") or "unknown",
                    "score": score,
                    "text": (h.get("text") or h.get("page_content") or ""),
                }
                registry_order.append(evid_id)
    # Apply the cap, keeping the earliest-seen evidence.
    if len(registry_order) > cap:
        for evid_id in registry_order[cap:]:
            registry.pop(evid_id, None)
        registry_order = registry_order[:cap]
    # Claim -> positions into the capped registry.
    registry_index = {eid: i for i, eid in enumerate(registry_order)}
    claim_to_indices: Dict[str, List[int]] = {}
    for claim, hits in claim_to_hits.items():
        claim_to_indices[claim] = [
            registry_index[eid]
            for eid in (_evidence_id(h) for h in hits)
            if eid in registry_index
        ]
    dt = time.time() - t0
    print(
        f"[ai_core] [MAP] Claims={len(claims)} unique_evidence={len(registry_order)}"
        f" top_k={top_k} cap={cap} time={dt:.3f}s mode={mode}"
    )
    return {
        "mode": "mapping",
        "claims_count": len(claims),
        "registry_cap": cap,
        "unique_evidence_count": len(registry_order),
        "k": top_k,
        "registry": {eid: registry[eid] for eid in registry_order},
        "claim_to_indices": claim_to_indices,
    }