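"""Step 2 workflow page: AI SOAP draft with guideline mapping or RAG context injection.

Collects a structured referral intake, calls ``ai_core.generate_soap_draft``, and renders
the referral summary, the SOAP JSON, the evidence mapping panel, and an optional
explainability block. Server-side timing is logged with the ``[page02]`` tag.
"""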
from __future__ import annotations

import re
import time
from typing import Dict, Any

import streamlit as st

from src import ai_core

# Optional UI helpers: degrade gracefully when these modules are absent from the build.
try:
    from src.reasoning_panel import render_reasoning_panel
    _HAS_REASONING_PANEL = True
except Exception:
    _HAS_REASONING_PANEL = False

try:
    from src.explainability import render_explainability_block
    _HAS_EXPLAINABILITY = True
except Exception:
    _HAS_EXPLAINABILITY = False


def _parse_semver(v: str) -> tuple:
    """Parse 'x.y.z' → (x, y, z) best-effort, ignoring suffixes."""
    if not v:
        return (0, 0, 0)
    m = re.match(r"^\s*(\d+)\.(\d+)\.(\d+)", v)
    if m:
        return tuple(int(x) for x in m.groups())

    # Fallback: split on '.', '-' or '+' and take the first number found in each part.
    parts = re.split(r"[.\-+]", v)
    nums = []
    for p in parts[:3]:
        m2 = re.search(r"(\d+)", p)
        nums.append(int(m2.group(1)) if m2 else 0)
    while len(nums) < 3:
        nums.append(0)
    return tuple(nums[:3])


def _transformers_version_note():
    """Show a caption with the detected transformers version, or a note if it is missing."""
    try:
        import transformers
        ver = getattr(transformers, "__version__", "unknown")
        min_needed = (4, 40, 0)
        ok = _parse_semver(ver) >= min_needed
        tip = (
            f"transformers detected: {ver} — "
            + ("OK" if ok else "upgrade recommended (≥ 4.40.0)")
        )
        st.caption(tip)
    except Exception:
        st.caption("transformers: unavailable (UI will still run, model load in core)")


def _build_intake_from_form() -> Dict[str, Any]:
    """Collect intake fields from Streamlit session state; missing keys default to empty strings."""
    return {
        "age": st.session_state.get("age", "").strip(),
        "sex": st.session_state.get("sex", "").strip(),
        "chief_complaint": st.session_state.get("chief_complaint", "").strip(),
        "history": st.session_state.get("history", "").strip(),
        "medications": st.session_state.get("medications", "").strip(),
        "allergies": st.session_state.get("allergies", "").strip(),
        "labs": st.session_state.get("labs", "").strip(),
        "imaging": st.session_state.get("imaging", "").strip(),
        "question": st.session_state.get("question", "").strip(),
        "context": st.session_state.get("context", "").strip(),
        "referrer": st.session_state.get("referrer", "").strip(),
        "priority": st.session_state.get("priority", "").strip(),
    }


def _render_mapping_panel(mapping: Dict[str, Any]):
    """Render the evidence panel, tolerating different render_reasoning_panel signatures."""
    if not mapping:
        st.info("No mapping results to display.")
        return
    if not _HAS_REASONING_PANEL:
        st.info("Evidence panel unavailable in this build.")
        return

    try:
        import inspect
        sig = inspect.signature(render_reasoning_panel)
        if len(sig.parameters) == 1:
            render_reasoning_panel(mapping)
        elif "mapping" in sig.parameters:
            render_reasoning_panel(mapping=mapping)
        else:
            # Unknown signature: fall back to a positional call.
            render_reasoning_panel(mapping)
    except Exception as e:
        st.warning(f"Could not render evidence panel: {e}")


def _render_explainability(soap: Dict[str, Any], mapping: Dict[str, Any], raw_text: str, summary: str):
    """Render the explainability block, passing only the keyword arguments the provider accepts."""
    if not _HAS_EXPLAINABILITY:
        st.info("Explainability provider not available.")
        return
    try:
        import inspect
        sig = inspect.signature(render_explainability_block)
        kwargs = {}
        if "soap" in sig.parameters:
            kwargs["soap"] = soap
        if "mapping" in sig.parameters:
            kwargs["mapping"] = mapping
        if "raw_text" in sig.parameters:
            kwargs["raw_text"] = raw_text
        if "summary" in sig.parameters:
            kwargs["summary"] = summary
        if kwargs:
            render_explainability_block(**kwargs)
        else:
            # No recognised parameters: try the simplest positional call.
            render_explainability_block(soap)
    except TypeError:
        # Signature mismatch: retry with a single positional argument.
        try:
            render_explainability_block(soap)
        except Exception as e:
            st.warning(f"Explainability rendering failed: {e}")
    except Exception as e:
        st.warning(f"Explainability unavailable: {e}")


def main():
    st.title("Step 2 — Workflow: AI Draft + Mapping / RAG")
    _transformers_version_note()

    # Intake form
    col_left, col_right = st.columns(2)
    with col_left:
        st.text_input("Age", key="age", placeholder="e.g., 58")
        st.text_input("Sex", key="sex", placeholder="F/M")
        st.text_area("Chief complaint", key="chief_complaint", height=80)
        st.text_area("History (HPI / background)", key="history", height=140)
        st.text_area("Medications", key="medications", height=100)
    with col_right:
        st.text_area("Allergies", key="allergies", height=80)
        st.text_area("Pertinent labs", key="labs", height=120)
        st.text_area("Imaging", key="imaging", height=100)
        st.text_area("Consult question", key="question", height=100)

    st.divider()

    # Generation mode and decoding options
    mode = st.radio(
        "Mode",
        options=["AI with Guideline Mapping", "AI + RAG (Context Injection)"],
        index=0,
        horizontal=True,
    )
    mode_key = "mapping" if mode.startswith("AI with") else "rag"

    explain = st.checkbox("Generate natural‑language explanation (Explainability)", value=False)

    col_a, col_b, col_c = st.columns([1, 1, 2])
    with col_a:
        max_new_tokens = st.number_input("Max tokens", min_value=300, max_value=1200, value=700, step=50)
    with col_b:
        temperature = st.number_input("Temperature", min_value=0.0, max_value=1.0, value=0.2, step=0.1, format="%.1f")
    with col_c:
        top_p = st.number_input("Top‑p", min_value=0.5, max_value=1.0, value=0.95, step=0.05, format="%.2f")

    run = st.button("Generate Draft")
    if not run:
        return

    intake = _build_intake_from_form()

    t0 = time.time()
    out = ai_core.generate_soap_draft(
        intake=intake,
        mode=mode_key,
        max_new_tokens=int(max_new_tokens),
        temperature=float(temperature),
        top_p=float(top_p),
        explain=bool(explain),
    )
    dt = time.time() - t0

    soap = out.get("soap") or {}
    raw_text = out.get("raw_text") or ""
    summary = out.get("summary") or ""
    context_text = out.get("context_text") or None
    mapping = out.get("mapping") or {}
    timings = out.get("timings") or {}

    # Server-side timing log (stdout only; the UI shows timings in the caption below).
    print(
        f"[page02] mode={mode_key} | gen={timings.get('generate_secs', 0)}s "
        f"| map={timings.get('map_secs', 0)}s | total={dt:.3f}s"
    )

    with st.expander("Referral summary", expanded=False):
        st.write(summary if summary else "(empty)")

    if mode_key == "rag" and context_text:
        with st.expander("Context excerpts (RAG)", expanded=False):
            st.write(context_text)

    st.subheader("SOAP (JSON)")
    st.json(soap, expanded=False)

    if mode_key == "mapping" and mapping:
        st.subheader("Evidence Mapping")
        _render_mapping_panel(mapping)

    if explain:
        st.subheader("Explainability")
        _render_explainability(soap, mapping, raw_text, summary)

    st.caption(
        f"Generate: {timings.get('generate_secs', 0)} s | "
        f"Map: {timings.get('map_secs', 0)} s | "
        f"Total: {timings.get('total_runtime', round(dt, 3))} s"
    )


if __name__ == "__main__":
    main()