# ai_econsult_demo/src/context_simplifier.py
"""
context_simplifier.py
---------------------
Optional pre-processor to shorten retrieved RAG evidence context before it’s
passed to MedGemma for SOUP generation.
Default: uses facebook/bart-large-cnn for summarization (seq2seq).
Optionally, you can set USE_MISTRAL = True to use Mistral-7B-Instruct
for summarization via text generation instead.
"""
import os

from transformers import pipeline
# ---------------------------------------------------------------------
# ⚙️ Configuration toggle
# ---------------------------------------------------------------------
USE_MISTRAL = os.getenv("USE_MISTRAL", "false").lower() in ("true", "1", "yes")
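# Example (shell; the entry-point name is illustrative, not taken from this repo):
#   USE_MISTRAL=true python app.py   # summarize with Mistral-7B-Instruct
#   python app.py                    # default: facebook/bart-large-cnn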
# ---------------------------------------------------------------------
# ✅ Load summarizer pipeline
# ---------------------------------------------------------------------
if USE_MISTRAL:
    # Decoder-only model path (instruction-following summarization)
    print("🔧 Loading Mistral-7B-Instruct as a text-generation summarizer...")
    model_id = "mistralai/Mistral-7B-Instruct-v0.2"
    _simplifier = pipeline(
        "text-generation",
        model=model_id,
        device_map="auto",
        torch_dtype="auto",
    )

    def simplify_rag_context(context_text: str, max_words: int = 400) -> str:
        """Simplify RAG context using a decoder-only model (Mistral)."""
        if not context_text.strip():
            return "No evidence context available."
        prompt = (
            "Simplify and condense the following clinical evidence context. "
            "Keep only essential numeric values, treatment names, and short "
            "recommendations relevant to decision-making. "
            f"Limit the summary to about {max_words} words.\n\n"
            f"{context_text}\n\nSimplified summary:"
        )
        try:
            out = _simplifier(
                prompt,
                # ~1.3 tokens per word is a rough budget for the summary length
                max_new_tokens=int(max_words * 1.3),
                do_sample=False,  # greedy decoding; no sampling temperature needed
            )
            # Keep only the text generated after the summarization cue
            summary = out[0]["generated_text"].split("Simplified summary:")[-1].strip()
        except Exception as e:
            summary = f"[Simplification failed: {e}]"
        return summary
else:
    # True summarization model (encoder-decoder)
    print("🔧 Loading BART Large CNN summarizer (default)...")
    _simplifier = pipeline(
        "summarization",
        model="facebook/bart-large-cnn",
        device_map="auto",
        torch_dtype="auto",
    )

    def simplify_rag_context(context_text: str, max_words: int = 400) -> str:
        """Simplify RAG context using a true seq2seq summarizer (BART)."""
        if not context_text.strip():
            return "No evidence context available."
        # BART is not instruction-tuned, so the leading instruction is simply
        # summarized along with the context; it mainly nudges the output toward
        # numeric values and drug names.
        prompt = (
            "Simplify and condense the following clinical evidence context. "
            "Keep key numeric values, drug names, and brief recommendations. "
            f"Limit to about {max_words} words.\n\n{context_text}"
        )
        try:
            result = _simplifier(
                prompt,
                max_length=int(max_words * 1.3),
                min_length=60,
                do_sample=False,  # greedy decoding
                truncation=True,  # bart-large-cnn accepts roughly 1024 input tokens
            )
            summary = result[0]["summary_text"].strip()
        except Exception as e:
            summary = f"[Simplification failed: {e}]"
        return summary
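
# ---------------------------------------------------------------------
# 🧪 Minimal smoke test (sketch). The sample context below is illustrative
# only and is not part of the demo's RAG corpus.
# ---------------------------------------------------------------------
if __name__ == "__main__":
    sample_context = (
        "Guideline excerpt: for adults with stage 2 hypertension, starting two "
        "first-line agents from different classes is recommended, with a target "
        "blood pressure below 130/80 mmHg for most patients."
    )
    print(simplify_rag_context(sample_context, max_words=80))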