Cardiosense-AG committed on
Commit
eae09d2
·
verified ·
1 Parent(s): 9536490

Update src/explainability.py

Browse files
Files changed (1) hide show
  1. src/explainability.py +22 -24
src/explainability.py CHANGED
@@ -1,18 +1,34 @@
1
  # src/explainability.py
2
  from __future__ import annotations
3
- """Explainability helpers (post-hoc only).
4
 
5
- Provides deterministic "chips" extracted from assessment/plan text.
6
- Caching by (case_id, section, text_hash) can be layered on top by the UI.
7
  """
8
 
 
9
  import math
10
  import re
11
  from typing import Dict, List
12
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
13
  def _tokenize(s: str) -> List[str]:
14
  s = s.lower()
15
- # Keep simple alphanumerics
16
  toks = re.findall(r"[a-z0-9]+", s)
17
  return [t for t in toks if len(t) >= 3]
18
 
@@ -20,7 +36,6 @@ def segment_claims(text: str) -> List[str]:
20
  """Split text into claim-like sentences/lines."""
21
  if not text:
22
  return []
23
- # Split by newline or period, keep moderately long segments
24
  raw = re.split(r"[.\n]+", text)
25
  claims = [c.strip() for c in raw if len(c.strip()) >= 12]
26
  return claims[:10]
@@ -49,40 +64,23 @@ def chips_from_text(text: str, top_n: int = 10, min_weight: float = 0.02) -> Lis
49
  return []
50
  docs = [_tokenize(c) for c in claims]
51
  idf = _idf(docs)
52
- # Weight tokens by TF * average claim-length proxy
53
  agg: Dict[str, float] = {}
54
  for toks in docs:
55
  tf = _tf(toks)
56
  for t, tv in tf.items():
57
  agg[t] = agg.get(t, 0.0) + tv * idf.get(t, 1.0)
58
- # Normalize L1
59
  s = sum(agg.values()) or 1.0
60
  for k in list(agg.keys()):
61
  agg[k] /= s
62
  ranked = sorted(agg.items(), key=lambda kv: kv[1], reverse=True)
63
  return [{"token": tok, "weight": round(w, 4)} for tok, w in ranked if w >= min_weight][:top_n]
64
 
65
- # --- V2 helpers (post-hoc only, deterministic) ---
66
def chip_cache_key(case_id: str, section: str, text: str) -> str:
    """Deterministic cache key for explainability chips."""
    # Local imports preserved from the original: keeps the helper
    # self-contained for callers that import this module lazily.
    import hashlib, json

    # sort_keys makes the JSON blob byte-stable, so equal inputs always
    # hash to the same key.
    payload = {"case_id": case_id, "section": section, "text": text}
    blob = json.dumps(payload, sort_keys=True).encode("utf-8")
    return hashlib.sha256(blob).hexdigest()
71
 
72
def ensure_chip_schema(chips):
    """Force a consistent chip schema: [{token, weight}] sorted by weight desc.

    Sanitizes loosely-structured input:
    - non-list/tuple inputs yield []
    - non-dict items are dropped
    - items with an empty token are dropped
    - items with a non-numeric weight are dropped (previously this raised
      TypeError/ValueError, which contradicted the "force a schema" contract)

    Returns a new list; the input is never mutated.
    """
    if not isinstance(chips, (list, tuple)):
        return []
    norm = []
    for c in chips:
        if not isinstance(c, dict):
            continue
        tok = str(c.get("token", "")).strip()
        try:
            w = float(c.get("weight", 0.0))
        except (TypeError, ValueError):
            # Malformed weight (e.g. None or "n/a"): drop rather than crash.
            continue
        if tok:
            norm.append({"token": tok, "weight": round(w, 4)})
    norm.sort(key=lambda x: x["weight"], reverse=True)
    return norm
86
 
87
 
88
 
 
1
  # src/explainability.py
2
  from __future__ import annotations
3
+ """Explainability helpers.
4
 
5
+ V3 adds simple hash-based staleness, while keeping deterministic token "chips"
6
+ as a fallback utility (used only if the model omits a rationale).
7
  """
8
 
9
+ import hashlib
10
  import math
11
  import re
12
  from typing import Dict, List
13
 
14
+ # -------------------- NEW: staleness helpers --------------------
15
+
16
def normalize_text(s: str) -> str:
    """Collapse internal whitespace runs to single spaces and trim the ends."""
    collapsed = (s or "").strip()
    return re.sub(r"\s+", " ", collapsed)


def text_hash(s: str) -> str:
    """Short deterministic fingerprint of *s*: first 16 hex chars of the
    SHA-256 of the whitespace-normalized text."""
    digest = hashlib.sha256(normalize_text(s).encode("utf-8"))
    return digest.hexdigest()[:16]


def is_stale(current_text: str, baseline_hash: str | None) -> bool:
    """Report whether *current_text* no longer matches *baseline_hash*.

    A missing/empty baseline means there is nothing to compare against,
    so the text is treated as not stale.
    """
    return bool(baseline_hash) and text_hash(current_text) != baseline_hash
27
+
28
+ # -------------------- legacy token chips (fallback) ---------------
29
+
30
  def _tokenize(s: str) -> List[str]:
31
  s = s.lower()
 
32
  toks = re.findall(r"[a-z0-9]+", s)
33
  return [t for t in toks if len(t) >= 3]
34
 
 
36
  """Split text into claim-like sentences/lines."""
37
  if not text:
38
  return []
 
39
  raw = re.split(r"[.\n]+", text)
40
  claims = [c.strip() for c in raw if len(c.strip()) >= 12]
41
  return claims[:10]
 
64
  return []
65
  docs = [_tokenize(c) for c in claims]
66
  idf = _idf(docs)
 
67
  agg: Dict[str, float] = {}
68
  for toks in docs:
69
  tf = _tf(toks)
70
  for t, tv in tf.items():
71
  agg[t] = agg.get(t, 0.0) + tv * idf.get(t, 1.0)
 
72
  s = sum(agg.values()) or 1.0
73
  for k in list(agg.keys()):
74
  agg[k] /= s
75
  ranked = sorted(agg.items(), key=lambda kv: kv[1], reverse=True)
76
  return [{"token": tok, "weight": round(w, 4)} for tok, w in ranked if w >= min_weight][:top_n]
77
 
 
78
def chip_cache_key(case_id: str, section: str, text: str) -> str:
    """Deterministic cache key for explainability chips."""
    import json

    # Normalizing the text means whitespace-only edits reuse the same cache
    # entry; sort_keys keeps the JSON blob byte-stable across runs.
    payload = {"case_id": case_id, "section": section, "text": normalize_text(text)}
    blob = json.dumps(payload, sort_keys=True).encode("utf-8")
    return hashlib.sha256(blob).hexdigest()
83
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
84
 
85
 
86