|
|
""" |
|
|
context_simplifier.py |
|
|
--------------------- |
|
|
Optional pre-processor to shorten retrieved RAG evidence context before it's
|
|
passed to MedGemma for SOUP generation. |
|
|
|
|
|
Default: uses facebook/bart-large-cnn for summarization (seq2seq). |
|
|
Optionally, set the USE_MISTRAL environment variable to "true", "1", or "yes" to use Mistral-7B-Instruct-v0.2
|
|
for summarization via text generation instead. |
|
|
""" |
|
|
|
|
|
import logging
import os

from transformers import pipeline, AutoTokenizer, AutoModelForCausalLM
|
|
|
|
|
|
|
|
|
|
|
|
|
|
# Backend switch, read once at import time: "true", "1" or "yes"
# (case-insensitive) selects Mistral; anything else, or an unset
# variable, keeps the default BART summarizer.
_use_mistral_flag = os.environ.get("USE_MISTRAL", "false")
USE_MISTRAL = _use_mistral_flag.lower() in {"true", "1", "yes"}
|
|
|
|
|
|
|
|
|
|
|
|
|
|
_logger = logging.getLogger(__name__)

# Word -> token budget multiplier: subword tokenizers usually emit more than
# one token per English word, so the generation budget gets some headroom.
_TOKENS_PER_WORD = 1.3

# Preferred lower bound on BART summary length (tokens). It is clamped to the
# upper bound so a small max_words never yields min_length > max_length.
_BART_MIN_LENGTH = 60


def _token_budget(max_words: int) -> int:
    """Convert a word budget into a strictly positive token budget."""
    return max(1, int(max_words * _TOKENS_PER_WORD))


if USE_MISTRAL:
    print("🧠 Loading Mistral-7B-Instruct as a text-generation summarizer...")
    model_id = "mistralai/Mistral-7B-Instruct-v0.2"
    _simplifier = pipeline(
        "text-generation",
        model=model_id,
        device_map="auto",
        torch_dtype="auto",
    )

    def _run_simplifier(context_text: str, max_words: int) -> str:
        """Condense *context_text* with the decoder-only Mistral pipeline.

        NOTE(review): Mistral-Instruct was trained on the ``[INST] ... [/INST]``
        chat format; wrapping the prompt with
        ``_simplifier.tokenizer.apply_chat_template`` would likely improve
        instruction following -- confirm against the installed transformers
        version before switching.
        """
        prompt = (
            "Simplify and condense the following clinical evidence context. "
            "Keep only essential numeric values, treatment names, and short "
            "recommendations relevant to decision-making. "
            f"Limit the summary to about {max_words} words.\n\n"
            f"{context_text}\n\nSimplified summary:"
        )
        out = _simplifier(
            prompt,
            max_new_tokens=_token_budget(max_words),
            # Greedy decoding. No temperature: it is only used when sampling,
            # and passing 0.0 merely triggers a generation-config warning.
            do_sample=False,
            # Return only the continuation, so the prompt (and the raw
            # context inside it) cannot leak into the summary.
            return_full_text=False,
        )
        generated = out[0]["generated_text"]
        # Drop anything up to the marker in case the model echoes it back.
        return generated.split("Simplified summary:")[-1].strip()

else:
    print("🧠 Loading BART Large CNN summarizer (default)...")
    _simplifier = pipeline(
        "summarization",
        model="facebook/bart-large-cnn",
        device_map="auto",
        torch_dtype="auto",
    )

    def _run_simplifier(context_text: str, max_words: int) -> str:
        """Condense *context_text* with the seq2seq BART summarizer.

        bart-large-cnn is a plain summarizer, not an instruction-tuned model,
        so the raw context is fed directly: an instruction prefix would only
        consume input tokens and may itself end up paraphrased in the output.
        """
        max_length = _token_budget(max_words)
        # Generating past the decoder's position table fails, so cap there.
        position_limit = getattr(
            _simplifier.model.config, "max_position_embeddings", None
        )
        if position_limit:
            max_length = min(max_length, position_limit)
        result = _simplifier(
            context_text,
            max_length=max_length,
            min_length=min(_BART_MIN_LENGTH, max_length),
            do_sample=False,
            # The pipeline does not truncate by default; inputs longer than
            # the encoder limit would otherwise raise instead of summarizing.
            truncation=True,
        )
        return result[0]["summary_text"].strip()


def simplify_rag_context(context_text: str, max_words: int = 400) -> str:
    """Shorten retrieved RAG evidence before it is passed downstream.

    The backend (Mistral text generation or BART summarization) is chosen at
    import time from ``USE_MISTRAL``.

    Args:
        context_text: Retrieved evidence text. ``None``, empty, or
            whitespace-only input short-circuits without calling the model.
        max_words: Approximate word budget for the summary; converted to a
            token budget of ``max_words * 1.3`` (at least 1).

    Returns:
        The condensed text; ``"No evidence context available."`` for blank
        input; or ``"[Simplification failed: <error>]"`` if the model call
        raised (the failure is also logged with its traceback).
    """
    if not context_text or not context_text.strip():
        return "No evidence context available."
    try:
        return _run_simplifier(context_text, max_words)
    except Exception as e:  # best-effort pre-processor: never break the caller
        _logger.warning("RAG context simplification failed", exc_info=True)
        return f"[Simplification failed: {e}]"
|
|
|