# src/explainability.py
from __future__ import annotations
from typing import List, Dict, Any, Tuple
import math, re, os
import torch
from .rag_index import search_index
from .paths import faiss_index_dir
_WORD_RE = re.compile(r"[A-Za-z0-9\-\%]+")

def segment_claims(text: str) -> List[str]:
    """Split Assessment/Plan into claims: sentence-like spans with minimal heuristics."""
    if not text:
        return []
    # Normalize bullets and newlines
    t = text.replace("\n", " ").replace("•", ". ").replace(" - ", ". ")
    # Split on period/question/exclamation while retaining meaningful spans
    parts = re.split(r"(?<=[\.\?\!])\s+", t)
    claims = [p.strip() for p in parts if len(p.strip()) > 15]
    # Cap to a reasonable number for UI
    return claims[:12]
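
# Illustrative behavior (invented sample text, not from the repo's tests): bullets are
# folded into sentence boundaries, fragments of 15 characters or fewer are dropped,
# and at most 12 claims are kept, e.g.
#   segment_claims("Start low-dose beta blocker. • Order TTE within 2 weeks.")
#   -> ["Start low-dose beta blocker.", "Order TTE within 2 weeks."]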

def _tokens(s: str) -> List[str]:
    return [w.lower() for w in _WORD_RE.findall(s or "") if w.strip()]
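
# Example (illustrative): _tokens("BP 140/90 on ACE-i") -> ["bp", "140", "90", "on", "ace-i"],
# since _WORD_RE keeps alphanumerics, hyphens and '%' and splits on everything else.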

def _idf(corpus_tokens: List[List[str]]) -> Dict[str, float]:
    # simple IDF over small corpus
    N = max(1, len(corpus_tokens))
    df: Dict[str, int] = {}
    for toks in corpus_tokens:
        for t in set(toks):
            df[t] = df.get(t, 0) + 1
    return {t: math.log((N + 1) / (df.get(t, 1))) + 1.0 for t in df.keys()}
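
# Worked example of the formula above: with N = 3 token lists, a token seen in only one
# of them gets math.log((3 + 1) / 1) + 1.0 ≈ 2.39, while a token present in all three
# gets math.log(4 / 3) + 1.0 ≈ 1.29, so rarer (more informative) tokens weigh more.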

def build_citation_index(citations: List[Dict[str, Any]]) -> Dict[Tuple[str, int], int]:
    """Map (doc, page) -> bracket number n for endnotes."""
    idx: Dict[Tuple[str, int], int] = {}
    for i, c in enumerate(citations or [], start=1):
        key = (c.get("doc"), int(c.get("page", 0)))
        idx[key] = i
    return idx
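
# Illustrative mapping (document names here are made up): citations like
#   [{"doc": "esc_hf_2021.pdf", "page": 12}, {"doc": "acc_aha_hf_2022.pdf", "page": 3}]
# become {("esc_hf_2021.pdf", 12): 1, ("acc_aha_hf_2022.pdf", 3): 2}, i.e. endnotes [1] and [2].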

def _guideline_chunks_for_claim(claim: str, specialty: str, top_k: int = 3) -> List[Dict[str, Any]]:
    """
    Retrieve top guideline chunks for a claim using FAISS index.
    Works with older search_index() versions and ensures a valid embed_model.
    """
    try:
        # Try calling with specialty_filter (newer versions)
        results = search_index(
            faiss_index_dir(),
            claim,
            top_k=top_k,
            embed_model=None,
            device="cpu",
            specialty_filter=specialty or None,
        )
    except TypeError:
        # Fallback for older signature (no specialty_filter)
        # ✅ Ensure embed_model is passed from saved index metadata
        try:
            from .rag_index import load_index
            _, _, info = load_index(faiss_index_dir())
            embed_model = (info or {}).get("embed_model", "intfloat/e5-large-v2")
        except Exception:
            embed_model = "intfloat/e5-large-v2"
        results = search_index(
            faiss_index_dir(),
            claim,
            top_k=top_k,
            embed_model=embed_model,
            device="cpu",
        )
    chunks = []
    for r in results:
        chunks.append({
            "text": r.get("text", ""),
            "doc_name": r.get("doc_name"),
            "page": int(r.get("page", 0)),
            "score": float(r.get("score", 0.0)),
        })
    return chunks
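
# Each returned chunk is normalized to the shape consumed downstream, roughly:
#   {"text": "...", "doc_name": "esc_hf_2021.pdf", "page": 12, "score": 0.83}
# (doc name and score above are placeholders, not real index contents).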

def explain_claims_sim_only(
    claims: List[str],
    pcp_summary: str,
    specialty: str,
    top_terms: int = 5,
    top_k_guideline_chunks: int = 3,
) -> List[Dict[str, Any]]:
    """Similarity-only explanation: token salience from overlap with PCP summary and retrieved guideline chunks."""
    out: List[Dict[str, Any]] = []
    pcp_toks = _tokens(pcp_summary)
    for cl in claims:
        cl_toks = _tokens(cl)
        chunks = _guideline_chunks_for_claim(cl, specialty, top_k=top_k_guideline_chunks)
        # Build corpus for IDF
        corpus = [cl_toks, pcp_toks] + [_tokens(ch["text"]) for ch in chunks]
        idf = _idf(corpus)
        # Compute per-token weights
        # guideline weight: sum over chunks (normalized by score) if token appears in chunk text
        max_score = max([ch["score"] for ch in chunks], default=1.0) or 1.0
        g_weights: Dict[str, float] = {}
        for ch in chunks:
            ctoks = set(_tokens(ch["text"]))
            s = ch["score"] / max_score
            for t in cl_toks:
                if t in ctoks:
                    g_weights[t] = g_weights.get(t, 0.0) + s
        # pcp weight: presence in PCP summary
        p_weights: Dict[str, float] = {}
        pset = set(pcp_toks)
        for t in cl_toks:
            if t in pset:
                p_weights[t] = p_weights.get(t, 0.0) + 1.0
        # apply idf to emphasize informative tokens
        for d in (g_weights, p_weights):
            for t in list(d.keys()):
                d[t] = d[t] * idf.get(t, 1.0)
        # Select top tokens
        gtops = sorted(g_weights.items(), key=lambda kv: kv[1], reverse=True)[:top_terms]
        ptops = sorted(p_weights.items(), key=lambda kv: kv[1], reverse=True)[:top_terms]
        # top guideline refs (by score)
        top_refs = [{
            "doc": ch["doc_name"],
            "page": ch["page"],
            "score": round(ch["score"], 3),
        } for ch in sorted(chunks, key=lambda c: c["score"], reverse=True)[:2]]
        out.append({
            "claim": cl,
            "top_tokens_guideline": gtops,  # list[(token, weight)]
            "top_tokens_pcp": ptops,
            "guideline_refs": top_refs,
        })
    return out
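
# Sketch of one output entry (tokens, weights and refs below are placeholders):
#   {"claim": "Order TTE within 2 weeks.",
#    "top_tokens_guideline": [("tte", 1.7), ("weeks", 0.9)],
#    "top_tokens_pcp": [("tte", 2.4)],
#    "guideline_refs": [{"doc": "esc_hf_2021.pdf", "page": 12, "score": 0.83}]}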

def maybe_attn_rerank(explanations: List[Dict[str, Any]], mode: str = "sim") -> List[Dict[str, Any]]:
    """
    Optional final-layer attention re-weighting.
    Always uses the 4B probe model, even if SOAP generation used 27B.
    Falls back silently to similarity-only on failure or CPU-only systems.
    """
    mode = (mode or "sim").lower()
    if mode != "attn":
        return explanations
    # Early exits for disabled or CPU
    if os.getenv("EXPLAINABILITY_ENABLED", "1").lower() not in {"1", "true", "yes"}:
        return explanations
    if not torch.cuda.is_available():
        return explanations
    attn_model_id = os.getenv("ATTN_MODEL_ID", "google/medgemma-4b-it")
    try:
        from transformers import AutoTokenizer, AutoModelForCausalLM
        tok = AutoTokenizer.from_pretrained(attn_model_id, trust_remote_code=True)
        mdl = AutoModelForCausalLM.from_pretrained(
            attn_model_id,
            torch_dtype=torch.float16,
            device_map="auto",
            trust_remote_code=True,
            output_attentions=True,
        )
        new_exps: List[Dict[str, Any]] = []
        for e in explanations:
            cl = e["claim"]
            enc = tok(cl, return_tensors="pt", truncation=True, max_length=200).to(mdl.device)
            with torch.no_grad():
                out = mdl(**enc, output_attentions=True, use_cache=False)
            # Last layer, averaged over heads, attention from the final token to all positions
            last = out.attentions[-1].mean(dim=1)[0, -1, :]
            toks = tok.convert_ids_to_tokens(enc["input_ids"][0])
            weights = {}
            for t, w in zip(toks, last.tolist()):
                tclean = t.replace("▁", "").lower()
                if tclean.isalpha():
                    weights[tclean] = max(weights.get(tclean, 0.0), float(w))

            def _rerank(pairs):
                if not pairs:
                    return pairs
                # default=0.0 guards against an empty weights dict (e.g. purely numeric claims)
                mx = max(weights.values(), default=0.0) or 1.0
                outp = []
                for tok0, score0 in pairs:
                    boost = weights.get(tok0.lower(), 0.0) / mx
                    outp.append((tok0, float(score0 * (1.0 + 0.5 * boost))))
                return sorted(outp, key=lambda kv: kv[1], reverse=True)

            e["top_tokens_guideline"] = _rerank(e.get("top_tokens_guideline", []))
            e["top_tokens_pcp"] = _rerank(e.get("top_tokens_pcp", []))
            new_exps.append(e)
        # Cleanup to release GPU memory quickly
        del mdl
        torch.cuda.empty_cache()
        return new_exps
    except Exception as e:
        print(f"[Explainability] Attention rerank fallback: {e}")
        return explanations
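
# Re-weighting intuition: a token whose cleaned attention weight equals the per-claim
# maximum gets boost = 1.0, so its similarity score is multiplied by 1.5; tokens the
# probe model never attends to keep their original score (boost = 0.0).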

def render_claims_html(explanations, citation_index):
    """
    Render Assessment/Plan claims with hoverable tooltips showing token salience.
    Works in Streamlit using pure CSS (no JS).
    """
    html_lines = [
        """
        <style>
        .claim { margin:0.2rem 0; line-height:1.4; position:relative; }
        .tooltiptext {
            visibility:hidden;
            width:280px;
            background-color:#333;
            color:#fff;
            text-align:left;
            border-radius:6px;
            padding:8px;
            position:absolute;
            z-index:10;
            top:100%;
            left:0;
            opacity:0;
            transition:opacity 0.3s;
            font-size:0.8em;
            box-shadow:0 2px 8px rgba(0,0,0,0.25);
        }
        .claim:hover .tooltiptext {
            visibility:visible;
            opacity:1;
        }
        .hasref {
            background:rgba(255,229,100,0.25);
            border-left:3px solid rgba(255,196,0,0.9);
            padding-left:6px;
            border-radius:4px;
        }
        .badge {
            display:inline-block;
            font-size:0.8em;
            margin-left:2px;
            padding:0 4px;
            background:#eee;
            border-radius:3px;
            color:#333;
            border:1px solid #ccc;
        }
        </style>
        """
    ]
    supported = 0
    for e in explanations:
        cl = e.get("claim", "")
        # Map guideline refs for this claim to endnote numbers
        nums = []
        for r in e.get("guideline_refs", []):
            key = (r.get("doc"), int(r.get("page", 0)))
            if key in citation_index and citation_index[key] not in nums:
                nums.append(citation_index[key])
        nums = nums[:2]
        # Build tooltip text
        gtoks = ", ".join([f"{t} ({w:.2f})" for t, w in e.get("top_tokens_guideline", [])[:5]]) or "—"
        ptoks = ", ".join([f"{t} ({w:.2f})" for t, w in e.get("top_tokens_pcp", [])[:5]]) or "—"
        tooltip_html = (
            f"<div class='tooltiptext'>"
            f"<b>Guideline terms:</b> {gtoks}<br>"
            f"<b>PCP terms:</b> {ptoks}"
            f"</div>"
        )
        badges = "".join([f"<span class='badge'>[{n}]</span>" for n in nums])
        css = "claim hasref" if nums else "claim"
        html_lines.append(f"<div class='{css}'>{cl} {badges}{tooltip_html}</div>")
        if nums:
            supported += 1
    cov = supported / max(1, len(explanations))
    return ("\n".join(html_lines), cov)