from __future__ import annotations

from typing import List, Dict, Any, Tuple

import math
import os
import re

import torch

from .rag_index import search_index
from .paths import faiss_index_dir

_WORD_RE = re.compile(r"[A-Za-z0-9\-\%]+")
|
|
def segment_claims(text: str) -> List[str]:
    """Split Assessment/Plan text into claims: sentence-like spans, using minimal heuristics."""
    if not text:
        return []

    # Normalize newlines and bullet/dash markers into sentence boundaries.
    t = text.replace("\n", " ").replace("•", ". ").replace(" - ", ". ")

    # Split on sentence-ending punctuation and drop very short fragments.
    parts = re.split(r"(?<=[\.\?\!])\s+", t)
    claims = [p.strip() for p in parts if len(p.strip()) > 15]

    return claims[:12]
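# Illustrative example (hypothetical input): segment_claims(
#     "Start lisinopril 10 mg daily. Recheck BMP in 2 weeks.\n• Continue statin.")
# yields three claims; fragments of 15 characters or fewer are dropped and at most
# 12 claims are returned.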
|
|
def _tokens(s: str) -> List[str]:
    return [w.lower() for w in _WORD_RE.findall(s or "") if w.strip()]
|
|
def _idf(corpus_tokens: List[List[str]]) -> Dict[str, float]:
    """Smoothed inverse document frequency over the small per-claim corpus."""
    N = max(1, len(corpus_tokens))
    df: Dict[str, int] = {}
    for toks in corpus_tokens:
        for t in set(toks):
            df[t] = df.get(t, 0) + 1
    # Smoothed IDF: log((N + 1) / df(t)) + 1.0
    return {t: math.log((N + 1) / df[t]) + 1.0 for t in df}
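# Worked example (illustrative): with N = 3 documents, a term found in one document
# gets idf = log(4 / 1) + 1.0 ≈ 2.39, while a term found in all three gets
# log(4 / 3) + 1.0 ≈ 1.29, so ubiquitous terms are down-weighted but never zeroed.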
|
|
def build_citation_index(citations: List[Dict[str, Any]]) -> Dict[Tuple[str, int], int]:
    """Map (doc, page) -> bracket number n for endnotes."""
    idx: Dict[Tuple[str, int], int] = {}
    for i, c in enumerate(citations or [], start=1):
        key = (c.get("doc"), int(c.get("page", 0)))
        idx[key] = i
    return idx
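# Illustrative mapping (hypothetical docs): [{"doc": "htn_guideline.pdf", "page": 4},
# {"doc": "ckd_guideline.pdf", "page": 12}] becomes
# {("htn_guideline.pdf", 4): 1, ("ckd_guideline.pdf", 12): 2}, so the first citation renders as [1].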
|
|
def _guideline_chunks_for_claim(claim: str, specialty: str, top_k: int = 3) -> List[Dict[str, Any]]:
    """
    Retrieve the top guideline chunks for a claim from the FAISS index.
    Falls back to the older search_index() signature (explicit embed_model,
    no specialty filter) and ensures a valid embed_model.
    """
    try:
        results = search_index(
            faiss_index_dir(),
            claim,
            top_k=top_k,
            embed_model=None,
            device="cpu",
            specialty_filter=specialty or None,
        )
    except TypeError:
        # Older search_index() versions require an explicit embed_model and do not
        # accept a specialty filter; recover the model name from the index metadata.
        try:
            from .rag_index import load_index
            _, _, info = load_index(faiss_index_dir())
            embed_model = (info or {}).get("embed_model", "intfloat/e5-large-v2")
        except Exception:
            embed_model = "intfloat/e5-large-v2"
        results = search_index(
            faiss_index_dir(),
            claim,
            top_k=top_k,
            embed_model=embed_model,
            device="cpu",
        )

    chunks = []
    for r in results:
        chunks.append({
            "text": r.get("text", ""),
            "doc_name": r.get("doc_name"),
            "page": int(r.get("page", 0)),
            "score": float(r.get("score", 0.0)),
        })
    return chunks
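# Each returned chunk is a plain dict, e.g. (illustrative values):
#   {"text": "ACE inhibitors are first-line ...", "doc_name": "htn_guideline.pdf",
#    "page": 4, "score": 0.82}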
|
|
def explain_claims_sim_only(
    claims: List[str],
    pcp_summary: str,
    specialty: str,
    top_terms: int = 5,
    top_k_guideline_chunks: int = 3,
) -> List[Dict[str, Any]]:
    """Similarity-only explanation: per-claim token salience from overlap with the PCP summary and with retrieved guideline chunks."""
    out: List[Dict[str, Any]] = []
    pcp_toks = _tokens(pcp_summary)
    for cl in claims:
        cl_toks = _tokens(cl)
        chunks = _guideline_chunks_for_claim(cl, specialty, top_k=top_k_guideline_chunks)

        # IDF over the claim, the PCP summary, and the retrieved chunks.
        corpus = [cl_toks, pcp_toks] + [_tokens(ch["text"]) for ch in chunks]
        idf = _idf(corpus)

        # Guideline salience: credit claim tokens that appear in a retrieved chunk,
        # scaled by that chunk's normalized retrieval score.
        max_score = max([ch["score"] for ch in chunks], default=1.0) or 1.0
        g_weights: Dict[str, float] = {}
        for ch in chunks:
            ctoks = set(_tokens(ch["text"]))
            s = ch["score"] / max_score
            for t in cl_toks:
                if t in ctoks:
                    g_weights[t] = g_weights.get(t, 0.0) + s

        # PCP salience: credit claim tokens that also appear in the PCP summary.
        p_weights: Dict[str, float] = {}
        pset = set(pcp_toks)
        for t in cl_toks:
            if t in pset:
                p_weights[t] = p_weights.get(t, 0.0) + 1.0

        # Down-weight ubiquitous terms with IDF.
        for d in (g_weights, p_weights):
            for t in list(d.keys()):
                d[t] = d[t] * idf.get(t, 1.0)

        gtops = sorted(g_weights.items(), key=lambda kv: kv[1], reverse=True)[:top_terms]
        ptops = sorted(p_weights.items(), key=lambda kv: kv[1], reverse=True)[:top_terms]

        # Keep the two highest-scoring chunks as reference candidates for this claim.
        top_refs = [{
            "doc": ch["doc_name"],
            "page": ch["page"],
            "score": round(ch["score"], 3),
        } for ch in sorted(chunks, key=lambda c: c["score"], reverse=True)[:2]]

        out.append({
            "claim": cl,
            "top_tokens_guideline": gtops,
            "top_tokens_pcp": ptops,
            "guideline_refs": top_refs,
        })
    return out
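# Illustrative output entry (values are made up):
#   {"claim": "Start lisinopril 10 mg daily.",
#    "top_tokens_guideline": [("lisinopril", 2.1), ("daily", 0.9)],
#    "top_tokens_pcp": [("lisinopril", 2.1)],
#    "guideline_refs": [{"doc": "htn_guideline.pdf", "page": 4, "score": 0.82}]}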
|
|
def maybe_attn_rerank(explanations: List[Dict[str, Any]], mode: str = "sim") -> List[Dict[str, Any]]:
    """
    Optional final-layer attention re-weighting, gated by EXPLAINABILITY_ENABLED.
    Uses the 4B probe model (ATTN_MODEL_ID) even if SOAP generation used 27B.
    Falls back silently to similarity-only scores on failure or CPU-only systems.
    """
    mode = (mode or "sim").lower()
    if mode != "attn":
        return explanations

    if os.getenv("EXPLAINABILITY_ENABLED", "1").lower() not in {"1", "true", "yes"}:
        return explanations
    if not torch.cuda.is_available():
        return explanations

    attn_model_id = os.getenv("ATTN_MODEL_ID", "google/medgemma-4b-it")

    try:
        from transformers import AutoTokenizer, AutoModelForCausalLM
        tok = AutoTokenizer.from_pretrained(attn_model_id, trust_remote_code=True)
        mdl = AutoModelForCausalLM.from_pretrained(
            attn_model_id,
            torch_dtype=torch.float16,
            device_map="auto",
            trust_remote_code=True,
            output_attentions=True,
        )

        new_exps: List[Dict[str, Any]] = []
        for e in explanations:
            cl = e["claim"]
            enc = tok(cl, return_tensors="pt", truncation=True, max_length=200).to(mdl.device)
            with torch.no_grad():
                out = mdl(**enc, output_attentions=True, use_cache=False)
            # Last layer, mean over heads: attention from the final token to every position.
            last = out.attentions[-1].mean(dim=1)[0, -1, :]
            toks = tok.convert_ids_to_tokens(enc["input_ids"][0])
            weights: Dict[str, float] = {}
            for t, w in zip(toks, last.tolist()):
                # Strip the SentencePiece word-boundary marker before matching.
                tclean = t.replace("▁", "").lower()
                if tclean.isalpha():
                    weights[tclean] = max(weights.get(tclean, 0.0), float(w))

            def _rerank(pairs):
                if not pairs:
                    return pairs
                mx = max(weights.values(), default=1.0) or 1.0
                outp = []
                for tok0, score0 in pairs:
                    boost = weights.get(tok0.lower(), 0.0) / mx
                    outp.append((tok0, float(score0 * (1.0 + 0.5 * boost))))
                return sorted(outp, key=lambda kv: kv[1], reverse=True)

            e["top_tokens_guideline"] = _rerank(e.get("top_tokens_guideline", []))
            e["top_tokens_pcp"] = _rerank(e.get("top_tokens_pcp", []))
            new_exps.append(e)

        del mdl
        torch.cuda.empty_cache()
        return new_exps

    except Exception as e:
        print(f"[Explainability] Attention rerank fallback: {e}")
        return explanations
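# Illustrative call (mode is caller-supplied; anything other than "attn" is a pass-through):
#   explanations = maybe_attn_rerank(explanations, mode="attn")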
|
|
def render_claims_html(
    explanations: List[Dict[str, Any]],
    citation_index: Dict[Tuple[str, int], int],
) -> Tuple[str, float]:
    """
    Render Assessment/Plan claims with hoverable tooltips showing token salience.
    Works in Streamlit using pure CSS (no JS).
    Returns (html, coverage), where coverage is the fraction of claims with at least one matched citation.
    """
    html_lines = [
        """
<style>
.claim { margin:0.2rem 0; line-height:1.4; position:relative; }
.tooltiptext {
  visibility:hidden;
  width:280px;
  background-color:#333;
  color:#fff;
  text-align:left;
  border-radius:6px;
  padding:8px;
  position:absolute;
  z-index:10;
  top:100%;
  left:0;
  opacity:0;
  transition:opacity 0.3s;
  font-size:0.8em;
  box-shadow:0 2px 8px rgba(0,0,0,0.25);
}
.claim:hover .tooltiptext {
  visibility:visible;
  opacity:1;
}
.hasref {
  background:rgba(255,229,100,0.25);
  border-left:3px solid rgba(255,196,0,0.9);
  padding-left:6px;
  border-radius:4px;
}
.badge {
  display:inline-block;
  font-size:0.8em;
  margin-left:2px;
  padding:0 4px;
  background:#eee;
  border-radius:3px;
  color:#333;
  border:1px solid #ccc;
}
</style>
"""
    ]

    supported = 0
    for e in explanations:
        cl = e.get("claim", "")

        # Map this claim's guideline refs to endnote numbers (at most two badges per claim).
        nums = []
        for r in e.get("guideline_refs", []):
            key = (r.get("doc"), int(r.get("page", 0)))
            if key in citation_index and citation_index[key] not in nums:
                nums.append(citation_index[key])
        nums = nums[:2]

        gtoks = ", ".join([f"{t} ({w:.2f})" for t, w in e.get("top_tokens_guideline", [])[:5]]) or "—"
        ptoks = ", ".join([f"{t} ({w:.2f})" for t, w in e.get("top_tokens_pcp", [])[:5]]) or "—"
        tooltip_html = (
            f"<div class='tooltiptext'>"
            f"<b>Guideline terms:</b> {gtoks}<br>"
            f"<b>PCP terms:</b> {ptoks}"
            f"</div>"
        )

        badges = "".join([f"<span class='badge'>[{n}]</span>" for n in nums])
        css = "claim hasref" if nums else "claim"
        html_lines.append(f"<div class='{css}'>{cl} {badges}{tooltip_html}</div>")

        if nums:
            supported += 1

    cov = supported / max(1, len(explanations))
    return ("\n".join(html_lines), cov)
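# Illustrative Streamlit usage (streamlit imported by the caller as st):
#   html, coverage = render_claims_html(explanations, build_citation_index(citations))
#   st.markdown(html, unsafe_allow_html=True)
#   st.caption(f"{coverage:.0%} of claims carry a guideline citation")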
|