|
|
|
|
|
from __future__ import annotations |
|
|
|
|
|
import json |
|
|
import os |
|
|
import re |
|
|
import time |
|
|
from typing import Any, Dict, List, Tuple |
|
|
|
|
|
from .prompt_builder import build_referral_summary, normalize_intake |
|
|
from .model_loader import generate_chat |
|
|
|
|
|
|
|
|
try: |
|
|
from . import rag_index |
|
|
except Exception: |
|
|
rag_index = None |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
_SPARSE_MIN_CHARS = 160 |
|
|
_DEFAULT_MAX_NEW_TOKENS = 700 |
|
|
_DEFAULT_TEMPERATURE = 0.2 |
|
|
_DEFAULT_TOP_P = 0.95 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def generate_soap_draft(
    intake: Dict[str, Any],
    mode: str = "mapping",
    *,
    max_new_tokens: int = _DEFAULT_MAX_NEW_TOKENS,
    temperature: float = _DEFAULT_TEMPERATURE,
    top_p: float = _DEFAULT_TOP_P,
    explain: bool = False,
) -> Dict[str, Any]:
    """Orchestrate the draft pipeline: intake -> summary -> LLM JSON -> optional mapping.

    Args:
        intake: Raw intake payload (flat or nested; normalized downstream).
        mode: "mapping" runs hybrid claim-to-evidence mapping; "rag" injects
            retrieved context into the prompt instead.
        max_new_tokens: Generation budget passed to the LLM call.
        temperature: Sampling temperature passed to the LLM call.
        top_p: Nucleus-sampling threshold passed to the LLM call.
        explain: Accepted for API compatibility; not used in this function.

    Returns:
        dict with keys:
            'soap': {'subjective': str, 'objective': str,
                     'assessment': list, 'plan': list}
            'raw_text': str, 'summary': str, 'context_text': str | None,
            'mapping': dict or None,
            'timings': {'generate_secs', 'map_secs', 'total_runtime'}
                (seconds, rounded to 3 decimal places)
    """
    started = time.time()

    summary = _maybe_external_summary(intake)

    # Retrieval context is only built (and injected) in RAG mode.
    context_text = _maybe_build_context(summary) if mode == "rag" else None
    messages = _messages_for_mode(mode, summary, context_text=context_text)

    gen_started = time.time()
    raw_text = generate_chat(
        messages,
        max_new_tokens=max_new_tokens,
        temperature=temperature,
        top_p=top_p,
    )
    gen_secs = time.time() - gen_started

    soap = _safe_json(raw_text)

    map_started = time.time()
    if mode == "mapping":
        mapping = _hybrid_map(soap)
    elif mode == "rag" and context_text:
        # RAG already grounds the note via the injected excerpts, so the
        # per-claim hybrid mapping is skipped in this mode.
        mapping = {
            "mode": "rag",
            "claims_count": _count_assessment_plan_items(soap),
            "registry_cap": None,
            "unique_evidence_count": None,
            "k": None,
            "note": "Citations available via injected context; hybrid mapping skipped.",
        }
    else:
        mapping = None
    map_secs = time.time() - map_started

    return {
        "soap": soap,
        "raw_text": raw_text,
        "summary": summary,
        "context_text": context_text,
        "mapping": mapping,
        "timings": {
            "generate_secs": round(gen_secs, 3),
            "map_secs": round(map_secs, 3),
            "total_runtime": round(time.time() - started, 3),
        },
    }
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def _maybe_external_summary(intake: Dict[str, Any]) -> str:
    """Build the unified referral summary, falling back when it is unusable.

    Falls back to ``_minimal_summary`` when the external builder raises,
    when none of the key consult fields (question / chief_complaint /
    history) carry a truthy value, or when the produced summary is shorter
    than ``_SPARSE_MIN_CHARS``.
    """
    try:
        s = build_referral_summary(intake)
    except Exception as e:
        print(f"[ai_core] Summary builder error: {e}. Falling back to minimal.")
        return _minimal_summary(intake)

    # Normalize once. The original re-ran normalize_intake() inside the
    # any() generator -- twice per candidate key -- for the same result.
    consult = normalize_intake(intake).get("consult", {})
    has_key = any(consult.get(k) for k in ("question", "chief_complaint", "history"))
    if (not has_key) or (len(s) < _SPARSE_MIN_CHARS):
        print(f"[ai_core] Summary too sparse (len={len(s)}). Using minimal summary.")
        return _minimal_summary(intake)

    return s
|
|
|
|
|
|
|
|
def _minimal_summary(intake: Dict[str, Any]) -> str:
    """Very compact fallback summary; tolerates flat or nested intake shapes."""
    norm = normalize_intake(intake)
    patient = norm["patient"]
    consult = norm["consult"]

    lines: List[str] = []
    demo = "; ".join(part for part in (patient.get("age"), patient.get("sex")) if part)
    if demo:
        lines.append(f"Patient: {demo}")

    labeled_fields = (
        ("question", "Key question"),
        ("chief_complaint", "Chief complaint"),
        ("history", "Background"),
        ("medications", "Medications"),
        ("labs", "Pertinent labs"),
    )
    for key, label in labeled_fields:
        # Consult fields take precedence; fall back to the patient record.
        value = (consult.get(key) if key in consult else patient.get(key)) or ""
        if value:
            lines.append(f"{label}: {value}")

    summary = "\n".join(lines)
    print(f"[ai_core] Minimal summary length={len(summary)} chars")
    return summary
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def _messages_for_mode(mode: str, summary: str, *, context_text: str | None = None) -> List[Dict[str, str]]: |
|
|
sys = ( |
|
|
"You are a clinical decision support assistant. " |
|
|
"Given a referral summary (and optionally context excerpts), " |
|
|
"return a STRICT JSON object representing a SOAP note. " |
|
|
"Keys: subjective (string), objective (string), assessment (array of strings), plan (array of strings). " |
|
|
"Do not include any text outside the JSON object. Do not use markdown. Do not add commentary." |
|
|
) |
|
|
user_lines = ["Referral summary:", summary] |
|
|
if mode == "rag" and context_text: |
|
|
user_lines.append("\nGuideline/excerpt context:\n" + context_text) |
|
|
user_lines.append( |
|
|
"\nProduce only valid JSON with the exact keys: subjective, objective, assessment, plan. " |
|
|
"Example: {\"subjective\":\"...\",\"objective\":\"...\",\"assessment\":[\"...\"],\"plan\":[\"...\"]}" |
|
|
) |
|
|
return [ |
|
|
{"role": "system", "content": sys}, |
|
|
{"role": "user", "content": "\n".join(user_lines)}, |
|
|
] |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def _maybe_build_context(summary: str, top_k: int = 6) -> str | None:
    """Retrieve up to ``top_k`` index snippets for the summary, or None.

    Best-effort: returns None when the index module or its search API is
    unavailable, when the search yields nothing usable, or when any
    retrieval error occurs.
    """
    if rag_index is None:
        return None
    try:
        if not hasattr(rag_index, "search_index"):
            return None
        hits = rag_index.search_index(summary, top_k=top_k)
        if not hits:
            return None

        snippets = [
            stripped
            for stripped in (
                (h.get("text") or h.get("page_content") or "").strip() for h in hits
            )
            if stripped
        ]
        return "\n---\n".join(snippets[:top_k]) or None
    except Exception as e:
        print(f"[ai_core] Context build skipped: {e}")
        return None
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def _strip_code_fences(s: str) -> str: |
|
|
|
|
|
return re.sub(r"```(?:json)?\s*([\s\S]*?)\s*```", r"\1", s, flags=re.IGNORECASE) |
|
|
|
|
|
|
|
|
def _largest_balanced_json(s: str) -> str | None: |
|
|
""" |
|
|
Return the largest balanced {...} substring, respecting strings. |
|
|
If not found, return None. |
|
|
""" |
|
|
start = None |
|
|
depth = 0 |
|
|
in_str = False |
|
|
esc = False |
|
|
last_good = None |
|
|
for i, ch in enumerate(s): |
|
|
if in_str: |
|
|
if esc: |
|
|
esc = False |
|
|
elif ch == "\\": |
|
|
esc = True |
|
|
elif ch == '"': |
|
|
in_str = False |
|
|
continue |
|
|
else: |
|
|
if ch == '"': |
|
|
in_str = True |
|
|
elif ch == "{": |
|
|
if depth == 0: |
|
|
start = i |
|
|
depth += 1 |
|
|
elif ch == "}": |
|
|
if depth > 0: |
|
|
depth -= 1 |
|
|
if depth == 0 and start is not None: |
|
|
last_good = s[start : i + 1] |
|
|
return last_good |
|
|
|
|
|
|
|
|
def _trim_trailing_commas(s: str) -> str: |
|
|
s = re.sub(r",\s*([}\]])", r"\1", s) |
|
|
return s |
|
|
|
|
|
|
|
|
def _coerce_soap_shape(obj: Dict[str, Any], raw_text: str) -> Dict[str, Any]: |
|
|
|
|
|
out = { |
|
|
"subjective": obj.get("subjective") or "", |
|
|
"objective": obj.get("objective") or "", |
|
|
"assessment": obj.get("assessment") or [], |
|
|
"plan": obj.get("plan") or [], |
|
|
} |
|
|
|
|
|
for k in ("assessment", "plan"): |
|
|
v = out[k] |
|
|
if isinstance(v, str) and v.strip(): |
|
|
out[k] = [v.strip()] |
|
|
elif isinstance(v, dict): |
|
|
|
|
|
text = v.get("text") or "" |
|
|
out[k] = [text] if text else [] |
|
|
elif isinstance(v, list): |
|
|
coerced = [] |
|
|
for item in v: |
|
|
if isinstance(item, str): |
|
|
s = item.strip() |
|
|
if s: |
|
|
coerced.append(s) |
|
|
elif isinstance(item, dict): |
|
|
s = (item.get("text") or item.get("summary") or "") |
|
|
if s: |
|
|
coerced.append(s.strip()) |
|
|
out[k] = coerced |
|
|
else: |
|
|
out[k] = [] |
|
|
|
|
|
if not out["assessment"] and not out["plan"] and raw_text.strip(): |
|
|
out["subjective"] = out["subjective"] or raw_text.strip()[:800] |
|
|
return out |
|
|
|
|
|
|
|
|
def _safe_json(raw_text: str) -> Dict[str, Any]:
    """Parse model output into the SOAP dict, repairing common defects.

    Pipeline: strip markdown code fences -> locate the last balanced JSON
    object -> parse, retrying once with trailing commas removed. Any
    failure degrades to an empty-shape coercion that falls back to the
    raw text (see _coerce_soap_shape), so this never raises.
    """
    s = _strip_code_fences(raw_text or "")
    candidate = _largest_balanced_json(s)
    if not candidate:
        # No balanced object found; try the whole (trimmed) string as JSON.
        try:
            j = json.loads(_trim_trailing_commas(s.strip()))
            if isinstance(j, dict):
                return _coerce_soap_shape(j, raw_text)
        except Exception:
            pass
        return _coerce_soap_shape({}, raw_text)

    # Parse the untouched candidate first, and only trim trailing commas on
    # failure. (The original trimmed up front and then "retried" with the
    # identical string, making the second attempt dead code -- and the eager
    # trim could corrupt ",}"/",]" sequences inside string literals.)
    try:
        return _coerce_soap_shape(json.loads(candidate), raw_text)
    except Exception:
        try:
            return _coerce_soap_shape(json.loads(_trim_trailing_commas(candidate)), raw_text)
        except Exception as e:
            print(f"[ai_core] JSON repair failed: {e}")
            return _coerce_soap_shape({}, raw_text)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def _map_mode_and_cap() -> Tuple[str, int, int]: |
|
|
""" |
|
|
Returns (mode, cap, per-claim top_k). |
|
|
validation β cap=20, top_k=5 |
|
|
production β cap=5, top_k=5 |
|
|
""" |
|
|
mode = os.environ.get("MAP_MODE", "production").strip().lower() |
|
|
if mode == "validation": |
|
|
return "validation", 20, 5 |
|
|
return "production", 5, 5 |
|
|
|
|
|
|
|
|
def _count_assessment_plan_items(soap: Dict[str, Any]) -> int: |
|
|
a = soap.get("assessment") or [] |
|
|
p = soap.get("plan") or [] |
|
|
return (len(a) if isinstance(a, list) else 0) + (len(p) if isinstance(p, list) else 0) |
|
|
|
|
|
|
|
|
def _extract_claims(soap: Dict[str, Any]) -> List[str]: |
|
|
claims: List[str] = [] |
|
|
for k in ("assessment", "plan"): |
|
|
v = soap.get(k) or [] |
|
|
if isinstance(v, list): |
|
|
for item in v: |
|
|
if isinstance(item, str) and item.strip(): |
|
|
claims.append(item.strip()) |
|
|
elif isinstance(item, dict): |
|
|
t = (item.get("text") or item.get("summary") or "").strip() |
|
|
if t: |
|
|
claims.append(t) |
|
|
|
|
|
seen = set() |
|
|
uniq = [] |
|
|
for c in claims: |
|
|
if c not in seen: |
|
|
seen.add(c) |
|
|
uniq.append(c) |
|
|
return uniq |
|
|
|
|
|
|
|
|
def _evidence_id(hit: Dict[str, Any]) -> str:
    """Best-effort stable identifier for one retrieval hit.

    Prefers explicit ids/sources (top-level, then metadata); falls back to
    a hash of the first 120 chars of the snippet text.
    """
    meta = hit.get("metadata", {}) or {}
    return (
        hit.get("id")
        or hit.get("source_id")
        or hit.get("source")
        or meta.get("id")
        or meta.get("source")
        or str(hash((hit.get("text") or hit.get("page_content") or "")[:120]))
    )


def _hybrid_map(soap: Dict[str, Any]) -> Dict[str, Any]:
    """Map each assessment/plan claim to retrieved evidence.

    Runs one index search per claim, builds a capped, insertion-ordered
    evidence registry, then resolves each claim's hits to registry indices.
    Cap and per-claim top_k come from _map_mode_and_cap(). Returns a dict
    with 'registry' and 'claim_to_indices' (plus bookkeeping counters);
    degrades gracefully when there are no claims or no index.

    Refactor note: the evidence-id fallback chain was previously duplicated
    verbatim in both passes; it now lives in _evidence_id so the two passes
    cannot drift apart.
    """
    mode, cap, top_k = _map_mode_and_cap()

    claims = _extract_claims(soap)
    if not claims:
        print("[ai_core] [MAP] No claims available for mapping.")
        return {
            "mode": "mapping",
            "claims_count": 0,
            "registry_cap": cap,
            "unique_evidence_count": 0,
            "k": top_k,
            "registry": {},
            "claim_to_indices": {},
        }

    if rag_index is None or not hasattr(rag_index, "search_index"):
        # (Log message mojibake fixed: was a garbled arrow character.)
        print("[ai_core] [MAP] Index unavailable -> skipping mapping.")
        return {
            "mode": "mapping",
            "claims_count": len(claims),
            "registry_cap": cap,
            "unique_evidence_count": 0,
            "k": top_k,
            "registry": {},
            "claim_to_indices": {},
        }

    t0 = time.time()
    claim_to_hits: Dict[str, List[Dict[str, Any]]] = {}
    registry_order: List[str] = []
    registry: Dict[str, Dict[str, Any]] = {}

    # Pass 1: search per claim and accumulate first-seen evidence.
    for claim in claims:
        try:
            hits = rag_index.search_index(claim, top_k=top_k)
        except Exception as e:
            print(f"[ai_core] [MAP] search_index error: {e}")
            hits = []
        claim_to_hits[claim] = hits or []

        for h in hits or []:
            evid_id = _evidence_id(h)
            if evid_id not in registry:
                registry[evid_id] = {
                    "source": h.get("source")
                    or (h.get("metadata", {}) or {}).get("source")
                    or "unknown",
                    "score": h.get("score") or h.get("similarity") or None,
                    "text": (h.get("text") or h.get("page_content") or ""),
                }
                registry_order.append(evid_id)

    # Enforce the registry cap, keeping first-seen evidence.
    if len(registry_order) > cap:
        for evid_id in registry_order[cap:]:
            registry.pop(evid_id, None)
        registry_order = registry_order[:cap]

    # Pass 2: resolve each claim's hits to surviving registry indices.
    registry_index = {eid: i for i, eid in enumerate(registry_order)}
    claim_to_indices: Dict[str, List[int]] = {}
    for claim, hits in claim_to_hits.items():
        claim_to_indices[claim] = [
            idx
            for h in hits
            if (idx := registry_index.get(_evidence_id(h))) is not None
        ]

    dt = time.time() - t0
    print(
        f"[ai_core] [MAP] Claims={len(claims)} unique_evidence={len(registry_order)}"
        f" top_k={top_k} cap={cap} time={dt:.3f}s mode={mode}"
    )

    return {
        "mode": "mapping",
        "claims_count": len(claims),
        "registry_cap": cap,
        "unique_evidence_count": len(registry_order),
        "k": top_k,
        "registry": {eid: registry[eid] for eid in registry_order},
        "claim_to_indices": claim_to_indices,
    }
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|