# src/explainability.py
from __future__ import annotations
from typing import List, Dict, Any, Tuple
import math, re, os

import torch

from .rag_index import search_index
from .paths import faiss_index_dir

_WORD_RE = re.compile(r"[A-Za-z0-9\-\%]+")


def segment_claims(text: str) -> List[str]:
    """Split Assessment/Plan into claims: sentence-like spans with minimal heuristics."""
    if not text:
        return []
    # Normalize bullets and newlines
    t = text.replace("\n", " ").replace("•", ". ").replace(" - ", ". ")
    # Split on period/question/exclamation while retaining meaningful spans
    parts = re.split(r"(?<=[\.\?\!])\s+", t)
    claims = [p.strip() for p in parts if len(p.strip()) > 15]
    # Cap to a reasonable number for the UI
    return claims[:12]


def _tokens(s: str) -> List[str]:
    return [w.lower() for w in _WORD_RE.findall(s or "") if w.strip()]


def _idf(corpus_tokens: List[List[str]]) -> Dict[str, float]:
    # Simple IDF over a small corpus
    N = max(1, len(corpus_tokens))
    df: Dict[str, int] = {}
    for toks in corpus_tokens:
        for t in set(toks):
            df[t] = df.get(t, 0) + 1
    return {t: math.log((N + 1) / df.get(t, 1)) + 1.0 for t in df.keys()}
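# Illustrative sketch only (kept as comments, never executed) of what
# segment_claims() and _idf() above produce:
#
#   segment_claims("Start metformin 500 mg twice daily. Recheck A1c in 3 months.")
#   -> ["Start metformin 500 mg twice daily.", "Recheck A1c in 3 months."]
#   (spans of 15 characters or fewer are dropped; at most 12 claims are kept)
#
#   _idf() scores a token as log((N + 1) / df) + 1.0, so with N = 3 texts a
#   token seen in one text scores ~2.39 and one seen in all three scores ~1.29.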
def build_citation_index(citations: List[Dict[str, Any]]) -> Dict[Tuple[str, int], int]:
    """Map (doc, page) -> bracket number n for endnotes."""
    idx: Dict[Tuple[str, int], int] = {}
    for i, c in enumerate(citations or [], start=1):
        key = (c.get("doc"), int(c.get("page", 0)))
        idx[key] = i
    return idx


def _guideline_chunks_for_claim(claim: str, specialty: str, top_k: int = 3) -> List[Dict[str, Any]]:
    """
    Retrieve the top guideline chunks for a claim using the FAISS index.
    Works with older search_index() versions and ensures a valid embed_model.
    """
    try:
        # Try calling with specialty_filter (newer versions)
        results = search_index(
            faiss_index_dir(),
            claim,
            top_k=top_k,
            embed_model=None,
            device="cpu",
            specialty_filter=specialty or None,
        )
    except TypeError:
        # Fallback for older signature (no specialty_filter):
        # ensure embed_model is passed from the saved index metadata.
        try:
            from .rag_index import load_index
            _, _, info = load_index(faiss_index_dir())
            embed_model = (info or {}).get("embed_model", "intfloat/e5-large-v2")
        except Exception:
            embed_model = "intfloat/e5-large-v2"
        results = search_index(
            faiss_index_dir(),
            claim,
            top_k=top_k,
            embed_model=embed_model,
            device="cpu",
        )
    chunks = []
    for r in results:
        chunks.append({
            "text": r.get("text", ""),
            "doc_name": r.get("doc_name"),
            "page": int(r.get("page", 0)),
            "score": float(r.get("score", 0.0)),
        })
    return chunks


def explain_claims_sim_only(
    claims: List[str],
    pcp_summary: str,
    specialty: str,
    top_terms: int = 5,
    top_k_guideline_chunks: int = 3,
) -> List[Dict[str, Any]]:
    """Similarity-only explanation: token salience from overlap with the PCP summary and retrieved guideline chunks."""
    out: List[Dict[str, Any]] = []
    pcp_toks = _tokens(pcp_summary)
    for cl in claims:
        cl_toks = _tokens(cl)
        chunks = _guideline_chunks_for_claim(cl, specialty, top_k=top_k_guideline_chunks)

        # Build corpus for IDF
        corpus = [cl_toks, pcp_toks] + [_tokens(ch["text"]) for ch in chunks]
        idf = _idf(corpus)

        # Compute per-token weights.
        # Guideline weight: sum over chunks (normalized by score) if the token appears in the chunk text
        max_score = max([ch["score"] for ch in chunks], default=1.0) or 1.0
        g_weights: Dict[str, float] = {}
        for ch in chunks:
            ctoks = set(_tokens(ch["text"]))
            s = ch["score"] / max_score
            for t in cl_toks:
                if t in ctoks:
                    g_weights[t] = g_weights.get(t, 0.0) + s

        # PCP weight: presence in the PCP summary
        p_weights: Dict[str, float] = {}
        pset = set(pcp_toks)
        for t in cl_toks:
            if t in pset:
                p_weights[t] = p_weights.get(t, 0.0) + 1.0

        # Apply IDF to emphasize informative tokens
        for d in (g_weights, p_weights):
            for t in list(d.keys()):
                d[t] = d[t] * idf.get(t, 1.0)

        # Select top tokens
        gtops = sorted(g_weights.items(), key=lambda kv: kv[1], reverse=True)[:top_terms]
        ptops = sorted(p_weights.items(), key=lambda kv: kv[1], reverse=True)[:top_terms]

        # Top guideline refs (by score)
        top_refs = [{
            "doc": ch["doc_name"],
            "page": ch["page"],
            "score": round(ch["score"], 3),
        } for ch in sorted(chunks, key=lambda c: c["score"], reverse=True)[:2]]

        out.append({
            "claim": cl,
            "top_tokens_guideline": gtops,  # list[(token, weight)]
            "top_tokens_pcp": ptops,
            "guideline_refs": top_refs,
        })
    return out


def maybe_attn_rerank(explanations: List[Dict[str, Any]], mode: str = "sim") -> List[Dict[str, Any]]:
    """
    Optional final-layer attention re-weighting.
    Always uses the 4B probe model, even if SOAP generation used 27B.
    Falls back silently to similarity-only on failure or CPU-only systems.
    """
    mode = (mode or "sim").lower()
    if mode != "attn":
        return explanations
    # Early exits for disabled or CPU-only environments
    if os.getenv("EXPLAINABILITY_ENABLED", "1").lower() not in {"1", "true", "yes"}:
        return explanations
    if not torch.cuda.is_available():
        return explanations

    attn_model_id = os.getenv("ATTN_MODEL_ID", "google/medgemma-4b-it")
    try:
        from transformers import AutoTokenizer, AutoModelForCausalLM
        tok = AutoTokenizer.from_pretrained(attn_model_id, trust_remote_code=True)
        mdl = AutoModelForCausalLM.from_pretrained(
            attn_model_id,
            torch_dtype=torch.float16,
            device_map="auto",
            trust_remote_code=True,
            output_attentions=True,
        )

        new_exps: List[Dict[str, Any]] = []
        for e in explanations:
            cl = e["claim"]
            enc = tok(cl, return_tensors="pt", truncation=True, max_length=200).to(mdl.device)
            with torch.no_grad():
                out = mdl(**enc, output_attentions=True, use_cache=False)
            # Head-averaged last layer: attention from the final token to every token
            last = out.attentions[-1].mean(dim=1)[0, -1, :]
            toks = tok.convert_ids_to_tokens(enc["input_ids"][0])
            weights = {}
            for t, w in zip(toks, last.tolist()):
                tclean = t.replace("▁", "").lower()
                if tclean.isalpha():
                    weights[tclean] = max(weights.get(tclean, 0.0), float(w))

            def _rerank(pairs):
                if not pairs:
                    return pairs
                # default=0.0 guards against an empty weights dict (e.g., all-numeric claims)
                mx = max(weights.values(), default=0.0) or 1.0
                outp = []
                for tok0, score0 in pairs:
                    boost = weights.get(tok0.lower(), 0.0) / mx
                    outp.append((tok0, float(score0 * (1.0 + 0.5 * boost))))
                return sorted(outp, key=lambda kv: kv[1], reverse=True)

            e["top_tokens_guideline"] = _rerank(e.get("top_tokens_guideline", []))
            e["top_tokens_pcp"] = _rerank(e.get("top_tokens_pcp", []))
            new_exps.append(e)

        # Cleanup to release GPU memory quickly
        del mdl
        torch.cuda.empty_cache()
        return new_exps
    except Exception as e:
        print(f"[Explainability] Attention rerank fallback: {e}")
        return explanations
def render_claims_html(explanations: List[Dict[str, Any]],
                       citation_index: Dict[Tuple[str, int], int]) -> Tuple[str, float]:
    """
    Render Assessment/Plan claims with hoverable tooltips showing token salience.
    Works in Streamlit using pure CSS (no JS).
    """
    # Minimal pure-CSS tooltip styles: the tooltip is hidden until its claim is hovered.
    html_lines = [
        """
        <style>
        .claim { position: relative; display: inline-block; margin: 2px 0; padding: 2px 4px; }
        .claim.hasref { background: #eef7ee; border-radius: 4px; }
        .claim .tooltip {
            visibility: hidden; position: absolute; left: 0; top: 100%; z-index: 1;
            background: #333; color: #fff; padding: 6px 8px; border-radius: 4px;
            font-size: 0.85em; min-width: 260px;
        }
        .claim:hover .tooltip { visibility: visible; }
        </style>
        """
    ]
    supported = 0
    for e in explanations:
        cl = e.get("claim", "")
        # Map guideline refs for this claim to endnote numbers
        nums = []
        for r in e.get("guideline_refs", []):
            key = (r.get("doc"), int(r.get("page", 0)))
            if key in citation_index and citation_index[key] not in nums:
                nums.append(citation_index[key])
        nums = nums[:2]

        # Build tooltip text
        gtoks = ", ".join([f"{t} ({w:.2f})" for t, w in e.get("top_tokens_guideline", [])[:5]]) or "—"
        ptoks = ", ".join([f"{t} ({w:.2f})" for t, w in e.get("top_tokens_pcp", [])[:5]]) or "—"
        tooltip_html = (
            f"<span class='tooltip'>"
            f"Guideline terms: {gtoks}<br/>"
            f"PCP terms: {ptoks}"
            f"</span>"
        )

        badges = "".join([f"[{n}]" for n in nums])
        css = "claim hasref" if nums else "claim"
        html_lines.append(f"<div class='{css}'>{cl} {badges}{tooltip_html}</div>")
        if nums:
            supported += 1

    cov = supported / max(1, len(explanations))
    return ("\n".join(html_lines), cov)