Update src/explainability.py

src/explainability.py (+22 -24)
@@ -1,18 +1,34 @@
 # src/explainability.py
 from __future__ import annotations
-"""Explainability helpers
+"""Explainability helpers.
 
-
-
+V3 adds simple hash-based staleness, while keeping deterministic token "chips"
+as a fallback utility (used only if the model omits a rationale).
 """
 
+import hashlib
 import math
 import re
 from typing import Dict, List
 
+# -------------------- NEW: staleness helpers --------------------
+
+def normalize_text(s: str) -> str:
+    return re.sub(r"\s+", " ", (s or "").strip())
+
+def text_hash(s: str) -> str:
+    s_norm = normalize_text(s)
+    return hashlib.sha256(s_norm.encode("utf-8")).hexdigest()[:16]
+
+def is_stale(current_text: str, baseline_hash: str | None) -> bool:
+    if not baseline_hash:
+        return False
+    return text_hash(current_text) != baseline_hash
+
+# -------------------- legacy token chips (fallback) ---------------
+
 def _tokenize(s: str) -> List[str]:
     s = s.lower()
-    # Keep simple alphanumerics
     toks = re.findall(r"[a-z0-9]+", s)
     return [t for t in toks if len(t) >= 3]
 
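Usage sketch for the new staleness helpers (assumes src is importable as a package; where the baseline hash is persisted is the caller's concern and is not part of this diff):

    from src.explainability import text_hash, is_stale

    baseline = text_hash("Chest pain, likely musculoskeletal.")

    # Whitespace-only edits normalize away, so the text is not stale.
    assert not is_stale("Chest  pain, likely musculoskeletal. ", baseline)

    # A substantive edit changes the hash, so the text is stale.
    assert is_stale("Chest pain, likely cardiac.", baseline)

    # No recorded baseline means nothing to compare against: not stale.
    assert not is_stale("anything", None)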
@@ -20,7 +36,6 @@ def segment_claims(text: str) -> List[str]:
     """Split text into claim-like sentences/lines."""
     if not text:
         return []
-    # Split by newline or period, keep moderately long segments
     raw = re.split(r"[.\n]+", text)
     claims = [c.strip() for c in raw if len(c.strip()) >= 12]
     return claims[:10]
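For context, segment_claims is unchanged apart from the dropped comment; with an illustrative input (thresholds come straight from the code above):

    from src.explainability import segment_claims

    text = "Chest pain reported. Worse on exertion. Ok.\nNo fever or chills observed."
    # Segments shorter than 12 characters ("Ok") are dropped; at most 10 claims kept.
    print(segment_claims(text))
    # -> ['Chest pain reported', 'Worse on exertion', 'No fever or chills observed']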
@@ -49,40 +64,23 @@ def chips_from_text(text: str, top_n: int = 10, min_weight: float = 0.02) -> Lis
         return []
     docs = [_tokenize(c) for c in claims]
     idf = _idf(docs)
-    # Weight tokens by TF * average claim-length proxy
     agg: Dict[str, float] = {}
     for toks in docs:
         tf = _tf(toks)
         for t, tv in tf.items():
             agg[t] = agg.get(t, 0.0) + tv * idf.get(t, 1.0)
-    # Normalize L1
     s = sum(agg.values()) or 1.0
     for k in list(agg.keys()):
         agg[k] /= s
     ranked = sorted(agg.items(), key=lambda kv: kv[1], reverse=True)
     return [{"token": tok, "weight": round(w, 4)} for tok, w in ranked if w >= min_weight][:top_n]
 
-# --- V2 helpers (post-hoc only, deterministic) ---
 def chip_cache_key(case_id: str, section: str, text: str) -> str:
     """Deterministic cache key for explainability chips."""
-    import hashlib, json
-    blob = json.dumps({"case_id": case_id, "section": section, "text": text}, sort_keys=True).encode("utf-8")
+    import json
+    blob = json.dumps({"case_id": case_id, "section": section, "text": normalize_text(text)}, sort_keys=True).encode("utf-8")
     return hashlib.sha256(blob).hexdigest()
 
-def ensure_chip_schema(chips):
-    """Force a consistent chip schema: [{token, weight}] sorted by weight desc."""
-    if not isinstance(chips, (list, tuple)):
-        return []
-    norm = []
-    for c in chips:
-        if not isinstance(c, dict):
-            continue
-        tok = str(c.get("token", "")).strip()
-        w = float(c.get("weight", 0.0))
-        if tok:
-            norm.append({"token": tok, "weight": round(w, 4)})
-    norm.sort(key=lambda x: x["weight"], reverse=True)
-    return norm
 
 
 
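A quick illustration of the behavioral change in chip_cache_key: because the text now passes through normalize_text, whitespace variants of the same note map to one cache entry (the case and section values below are made up):

    from src.explainability import chip_cache_key

    k1 = chip_cache_key("case-001", "assessment", "Chest pain,  likely benign.")
    k2 = chip_cache_key("case-001", "assessment", " Chest pain, likely benign.")
    assert k1 == k2  # normalize_text collapses the whitespace difference

    k3 = chip_cache_key("case-001", "plan", "Chest pain, likely benign.")
    assert k3 != k1  # any field change produces a different key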
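Finally, a sketch of the retained chip fallback (used only if the model omits a rationale, per the new docstring); the example output is illustrative, not actual values:

    from src.explainability import chips_from_text

    note = "Chest pain reported on exertion. Chest pain resolves at rest."
    for chip in chips_from_text(note, top_n=5):
        # Each chip is {"token": ..., "weight": ...}; weights are normalized to
        # sum to 1, thresholded at min_weight, and sorted descending, e.g.
        # {"token": "chest", "weight": 0.2} (illustrative).
        print(chip["token"], chip["weight"])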