"""
context_simplifier.py
---------------------
Optional pre-processor that shortens retrieved RAG evidence context before it
is passed to MedGemma for SOUP generation.

Default: uses ``facebook/bart-large-cnn`` for abstractive summarization
(seq2seq). Set the environment variable ``USE_MISTRAL`` to ``true``, ``1`` or
``yes`` to use ``Mistral-7B-Instruct`` for summarization via text generation
instead.

The selected model is loaded once, when this module is imported.
"""

import logging
import os
from typing import Optional

from transformers import pipeline, AutoTokenizer, AutoModelForCausalLM

logger = logging.getLogger(__name__)

# ---------------------------------------------------------------------
# Configuration toggle
# ---------------------------------------------------------------------
USE_MISTRAL = os.getenv("USE_MISTRAL", "false").lower() in ("true", "1", "yes")

# Approximate number of model tokens per English word; converts the caller's
# word budget into a token budget for generation.
_TOKENS_PER_WORD = 1.3

# Returned verbatim when there is no context to simplify.
_EMPTY_CONTEXT_MESSAGE = "No evidence context available."


def _token_budget(max_words: int) -> int:
    """Convert a word budget into a strictly positive token budget."""
    return max(1, int(max_words * _TOKENS_PER_WORD))


def _is_blank(context_text: Optional[str]) -> bool:
    """Return True when *context_text* is None, empty, or whitespace-only."""
    return context_text is None or not context_text.strip()


# ---------------------------------------------------------------------
# Load summarizer pipeline
# ---------------------------------------------------------------------
if USE_MISTRAL:
    # Decoder-only model path
    print("🔧 Loading Mistral-7B-Instruct as a text-generation summarizer...")
    model_id = "mistralai/Mistral-7B-Instruct-v0.2"
    _simplifier = pipeline(
        "text-generation",
        model=model_id,
        device_map="auto",
        torch_dtype="auto",
    )

    def simplify_rag_context(context_text: str, max_words: int = 400) -> str:
        """Simplify RAG context using a decoder-only model (Mistral).

        Args:
            context_text: Retrieved evidence text. ``None`` or whitespace-only
                input is treated as "no context".
            max_words: Approximate word budget for the summary. It is stated
                in the prompt and converted into ``max_new_tokens``.

        Returns:
            The generated summary. Returns a fixed placeholder message when the
            input is blank. Returns a ``"[Simplification failed: ...]"`` marker
            when generation raises; the error is also logged.
        """
        if _is_blank(context_text):
            return _EMPTY_CONTEXT_MESSAGE

        instruction = (
            "Simplify and condense the following clinical evidence context. "
            "Keep only essential numeric values, treatment names, and short "
            "recommendations relevant to decision-making. "
            f"Limit the summary to about {max_words} words.\n\n"
            f"{context_text}\n\nSimplified summary:"
        )
        # Mistral-Instruct expects "[INST] ... [/INST]" turns. The tokenizer
        # adds the leading <s> BOS token itself, so it is not written here.
        prompt = f"[INST] {instruction} [/INST]"
        try:
            out = _simplifier(
                prompt,
                max_new_tokens=_token_budget(max_words),
                # Greedy decoding. `temperature` has no effect without sampling
                # and only triggers a generation-config warning, so it is not set.
                do_sample=False,
                # Return only the newly generated text, so the echoed prompt
                # (including the user's context) never has to be parsed away.
                return_full_text=False,
            )
            summary = out[0]["generated_text"].strip()
        except Exception as e:  # best-effort: callers get a marker, not a crash
            logger.exception("RAG context simplification (Mistral) failed")
            summary = f"[Simplification failed: {e}]"
        return summary

else:
    # True summarization model (encoder-decoder)
    print("🔧 Loading BART Large CNN summarizer (default)...")
    _simplifier = pipeline(
        "summarization",
        model="facebook/bart-large-cnn",
        device_map="auto",
        torch_dtype="auto",
    )

    # Preferred lower bound on summary length in tokens. It is capped per call
    # so that it always stays below the upper bound.
    _BART_MIN_TOKENS = 60
    # facebook/bart-large-cnn has 1024 position embeddings; generating past
    # that length fails.
    _BART_MAX_POSITIONS = 1024

    def simplify_rag_context(context_text: str, max_words: int = 400) -> str:
        """Simplify RAG context using a true seq2seq summarizer (BART).

        BART-large-CNN is a news summarizer, not an instruction-following
        model. It therefore receives the raw context only, and the length
        budget is enforced through ``max_length`` / ``min_length``. Input
        longer than the model's window is truncated instead of failing.

        Args:
            context_text: Retrieved evidence text. ``None`` or whitespace-only
                input is treated as "no context".
            max_words: Approximate word budget for the summary. It is converted
                into a token ``max_length``.

        Returns:
            The generated summary. Returns a fixed placeholder message when the
            input is blank. Returns a ``"[Simplification failed: ...]"`` marker
            when summarization raises; the error is also logged.
        """
        if _is_blank(context_text):
            return _EMPTY_CONTEXT_MESSAGE

        max_tokens = min(_token_budget(max_words), _BART_MAX_POSITIONS)
        # Keep min_length strictly below max_length even for tiny budgets.
        min_tokens = min(_BART_MIN_TOKENS, max_tokens // 2)
        try:
            result = _simplifier(
                context_text,
                max_length=max_tokens,
                min_length=min_tokens,
                # Greedy/beam decoding from the model's generation config;
                # `temperature` would be ignored, so it is not passed.
                do_sample=False,
                # Long RAG context would otherwise overflow BART's input window.
                truncation=True,
            )
            summary = result[0]["summary_text"].strip()
        except Exception as e:  # best-effort: callers get a marker, not a crash
            logger.exception("RAG context simplification (BART) failed")
            summary = f"[Simplification failed: {e}]"
        return summary