# Exported from the Hugging Face Space "ai_econsult_demo", file pages/02_Workflow_UI.py
# (raw-view page chrome removed; last update commit 0719965, ~9.18 kB)
# pages/02_Workflow_UI.py
# -----------------------------------------------------------------------------
# Why this change
# -----------------------------------------------------------------------------
# - Aligns the UI with the Phase 2 orchestration:
# * Removed obsolete arguments (e.g., rag_top_k).
# * Handles updated return keys: context_text, mapping, timings.
# * Adds a checkbox to toggle Explainability (default OFF) and calls the
# existing render_explainability_block() helper when available.
# - Keeps logging minimal and UI layout consistent (two-column intake, single
# action button, JSON + evidence panel rendering).
# - Notes a minimum recommended transformers version and handles version-safe
# imports without adding dependencies.
# -----------------------------------------------------------------------------
from __future__ import annotations
import re
import time
from typing import Dict, Any
import streamlit as st
# Core orchestration
from src import ai_core
# Optional panels: version-safe imports with soft failure.
# Each flag records whether the corresponding renderer is importable in this build.

# Evidence panel (hybrid mapping visualization)
try:
    from src.reasoning_panel import render_reasoning_panel  # type: ignore
except Exception:
    _HAS_REASONING_PANEL = False
else:
    _HAS_REASONING_PANEL = True

# Explainability provider (render_explainability_block)
try:
    from src.explainability import render_explainability_block  # type: ignore
except Exception:
    _HAS_EXPLAINABILITY = False
else:
    _HAS_EXPLAINABILITY = True
# ------------------------------ utils ----------------------------------------
def _parse_semver(v: str) -> tuple:
"""Parse 'x.y.z' → (x, y, z) best-effort, ignoring suffixes."""
if not v:
return (0, 0, 0)
m = re.match(r"^\s*(\d+)\.(\d+)\.(\d+)", v)
if m:
return tuple(int(x) for x in m.groups())
# fallback: split by '.' and cast digits only
parts = re.split(r"[.\-+]", v)
nums = []
for p in parts[:3]:
m2 = re.search(r"(\d+)", p)
nums.append(int(m2.group(1)) if m2 else 0)
while len(nums) < 3:
nums.append(0)
return tuple(nums[:3])
def _transformers_version_note():
    """Show a small caption with the installed transformers version.

    Any failure (missing package, odd version attribute) degrades to an
    "unavailable" caption — the page keeps running either way.
    """
    try:
        import transformers  # noqa
        installed = getattr(transformers, "__version__", "unknown")
        recommended = (4, 40, 0)  # recommended for med-gemma paths + generate kwargs
        verdict = "OK" if _parse_semver(installed) >= recommended else "upgrade recommended (≥ 4.40.0)"
        st.caption(f"transformers detected: {installed} — {verdict}")
    except Exception:
        st.caption("transformers: unavailable (UI will still run, model load in core)")
def _build_intake_from_form() -> Dict[str, Any]:
    """Collect the intake payload from Streamlit session state.

    Missing keys default to "" and every value is stripped; insertion
    order matches the form layout so downstream consumers see a stable dict.
    """
    field_names = (
        "age",
        "sex",
        "chief_complaint",
        "history",
        "medications",
        "allergies",
        "labs",
        "imaging",
        "question",
        # Any extra context fields from prior versions are preserved if present
        "context",
        "referrer",
        "priority",
    )
    return {name: st.session_state.get(name, "").strip() for name in field_names}
def _render_mapping_panel(mapping: Dict[str, Any]):
    """Render the hybrid-mapping evidence panel, tolerating signature drift.

    No-ops with an info box when there is nothing to show or the panel
    module was not importable in this build.
    """
    if not mapping:
        st.info("No mapping results to display.")
        return
    if not _HAS_REASONING_PANEL:
        st.info("Evidence panel unavailable in this build.")
        return
    # Version-safe call: keyword only when the signature has extra params
    # AND accepts 'mapping'; otherwise a single positional argument.
    try:
        import inspect
        params = inspect.signature(render_reasoning_panel).parameters
        if len(params) != 1 and "mapping" in params:
            render_reasoning_panel(mapping=mapping)
        else:
            render_reasoning_panel(mapping)  # type: ignore
    except Exception as e:
        st.warning(f"Could not render evidence panel: {e}")
def _render_explainability(soap: Dict[str, Any], mapping: Dict[str, Any], raw_text: str, summary: str):
    """Call render_explainability_block with only the kwargs its signature accepts.

    Falls back to a single positional 'soap' argument when the provider
    accepts none of the known names or rejects the keyword call.
    """
    if not _HAS_EXPLAINABILITY:
        st.info("Explainability provider not available.")
        return
    try:
        import inspect
        accepted = inspect.signature(render_explainability_block).parameters
        candidates = {
            "soap": soap,
            "mapping": mapping,
            "raw_text": raw_text,
            "summary": summary,
        }
        kwargs = {name: value for name, value in candidates.items() if name in accepted}
        if kwargs:
            render_explainability_block(**kwargs)
        else:
            # Positional fallback: try soap first
            render_explainability_block(soap)  # type: ignore
    except TypeError:
        # Keyword call rejected — retry positionally before giving up.
        try:
            render_explainability_block(soap)  # type: ignore
        except Exception as e:
            st.warning(f"Explainability rendering failed: {e}")
    except Exception as e:
        st.warning(f"Explainability unavailable: {e}")
# ------------------------------ main UI --------------------------------------
def main() -> None:
    """Render the Step-2 workflow page: intake form → ai_core call → output panels."""
    st.title("Step 2 — Workflow: AI Draft + Mapping / RAG")
    _transformers_version_note()
    # ---- Controls (consistent layout) ---------------------------------------
    # Two-column intake form; widget keys double as session_state keys read by
    # _build_intake_from_form().
    col_left, col_right = st.columns(2)
    with col_left:
        st.text_input("Age", key="age", placeholder="e.g., 58")
        st.text_input("Sex", key="sex", placeholder="F/M")
        st.text_area("Chief complaint", key="chief_complaint", height=80)
        st.text_area("History (HPI / background)", key="history", height=140)
        st.text_area("Medications", key="medications", height=100)
    with col_right:
        st.text_area("Allergies", key="allergies", height=80)
        st.text_area("Pertinent labs", key="labs", height=120)
        st.text_area("Imaging", key="imaging", height=100)
        st.text_area("Consult question", key="question", height=100)
    st.divider()
    # Orchestration options
    mode = st.radio(
        "Mode",
        options=["AI with Guideline Mapping", "AI + RAG (Context Injection)"],
        index=0,
        horizontal=True,
    )
    # Translate the display label into the ai_core mode key.
    mode_key = "mapping" if mode.startswith("AI with") else "rag"
    # Explainability is opt-in (default OFF).
    explain = st.checkbox("Generate natural‑language explanation (Explainability)", value=False)
    col_a, col_b, col_c = st.columns([1, 1, 2])
    with col_a:
        max_new_tokens = st.number_input("Max tokens", min_value=300, max_value=1200, value=700, step=50)
    with col_b:
        temperature = st.number_input("Temperature", min_value=0.0, max_value=1.0, value=0.2, step=0.1, format="%.1f")
    with col_c:
        top_p = st.number_input("Top‑p", min_value=0.5, max_value=1.0, value=0.95, step=0.05, format="%.2f")
    run = st.button("Generate Draft")
    if not run:
        # Nothing to do until the single action button is pressed.
        return
    # ---- Orchestration call --------------------------------------------------
    intake = _build_intake_from_form()
    t0 = time.time()
    out = ai_core.generate_soap_draft(
        intake=intake,
        mode=mode_key,
        max_new_tokens=int(max_new_tokens),
        temperature=float(temperature),
        top_p=float(top_p),
        explain=bool(explain),  # default OFF respected at core if False
    )
    dt = time.time() - t0
    # Defensive unpacking: each key may be missing/None in older core builds,
    # so fall back to an empty value of the expected shape.
    soap = out.get("soap") or {}
    raw_text = out.get("raw_text") or ""
    summary = out.get("summary") or ""
    context_text = out.get("context_text") or None
    mapping = out.get("mapping") or {}
    timings = out.get("timings") or {}
    # Console log (minimal)
    print(
        f"[page02] mode={mode_key} | gen={timings.get('generate_secs', 0)}s "
        f"| map={timings.get('map_secs', 0)}s | total={dt:.3f}s"
    )
    # ---- Output blocks -------------------------------------------------------
    with st.expander("Referral summary", expanded=False):
        st.write(summary if summary else "(empty)")
    # RAG context excerpts only make sense in RAG mode and when the core
    # actually returned context.
    if mode_key == "rag" and context_text:
        with st.expander("Context excerpts (RAG)", expanded=False):
            st.write(context_text)
    st.subheader("SOAP (JSON)")
    st.json(soap, expanded=False)
    # Evidence panel for hybrid mapping
    if mode_key == "mapping" and mapping:
        st.subheader("Evidence Mapping")
        _render_mapping_panel(mapping)
    # Explainability (optional)
    if explain:
        st.subheader("Explainability")
        _render_explainability(soap, mapping, raw_text, summary)
    # Timings
    # 'total_runtime' comes from the core when present; otherwise use the
    # wall-clock time measured around the call above.
    st.caption(
        f"Generate: {timings.get('generate_secs', 0)} s | "
        f"Map: {timings.get('map_secs', 0)} s | "
        f"Total: {timings.get('total_runtime', round(dt,3))} s"
    )
# Streamlit executes page modules top-to-bottom; the guard also allows
# running this file directly as a script.
if __name__ == "__main__":
    main()