Spaces:
Sleeping
Sleeping
| import os | |
| os.environ['ANONYMIZED_TELEMETRY'] = 'False' | |
| import zipfile | |
| import chromadb | |
| from sentence_transformers import SentenceTransformer | |
| import gradio as gr | |
| from fastapi import FastAPI | |
| from pydantic import BaseModel | |
# Extract and load database
DB_PATH = "./medqa_db"
# Zip bundled with the Space; unpacked next to this file on first start-up.
_DB_ARCHIVE = "./medqa_db.zip"


def _extract_database_if_needed():
    """Unpack the bundled database archive when the DB directory is missing.

    Does nothing if ``DB_PATH`` already exists or if there is no archive to
    unpack. The archive is extracted into the current directory, so it is
    expected to contain the ``medqa_db`` folder at its top level (assumption
    taken from the paths used here — confirm against how the zip is built).
    """
    if os.path.exists(DB_PATH):
        return
    if not os.path.exists(_DB_ARCHIVE):
        return
    print("π¦ Extracting database...")
    with zipfile.ZipFile(_DB_ARCHIVE, 'r') as archive:
        archive.extractall(".")
    print("β Database extracted")


_extract_database_if_needed()

# Open the persistent Chroma store and the "medqa" collection inside it.
print("π Loading ChromaDB...")
client = chromadb.PersistentClient(path=DB_PATH)
collection = client.get_collection("medqa")
print(f"β Loaded {collection.count()} questions")

# MedCPT query encoder: turns free-text queries into embeddings for search().
print("π§ Loading MedCPT model...")
model = SentenceTransformer('ncbi/MedCPT-Query-Encoder')
print("β Model ready")
# ============================================================================
# Deduplication of search hits
# ============================================================================
import difflib
import math


def _first_query_embeddings(results, expected_len):
    """Return the per-hit embeddings of the first query, or ``None`` if absent.

    Chroma only fills ``results['embeddings']`` when the query requested them
    via ``include``; otherwise the value is ``None``. The batch may be a list
    or a NumPy array, so emptiness is checked with ``len()``, never with
    truthiness (which is ambiguous for arrays).
    """
    batch = results.get('embeddings')
    if batch is None or len(batch) == 0:
        return None
    first = batch[0]
    if first is None or len(first) != expected_len:
        return None
    return first


def _cosine_similarity(u, v, norm_u, norm_v):
    """Cosine similarity of vectors *u* and *v* given their precomputed norms."""
    if not norm_u or not norm_v:
        return 0.0
    return sum(a * b for a, b in zip(u, v)) / (norm_u * norm_v)


def _text_similarity(a, b, floor=0.0):
    """Character-level similarity in [0, 1] of two document strings.

    ``quick_ratio()`` is an upper bound on ``ratio()``. When that bound is
    already below *floor* (the lowest threshold the caller tests), the exact
    value cannot change any decision, so the expensive ``ratio()`` is skipped.
    """
    matcher = difflib.SequenceMatcher(None, a or '', b or '')
    upper_bound = matcher.quick_ratio()
    if upper_bound < floor:
        return upper_bound
    return matcher.ratio()


def deduplicate_results(results, target_count, near_duplicate_threshold=0.92,
                        same_answer_threshold=0.85):
    """Drop near-duplicate hits from a Chroma ``query`` result.

    Hits are scanned in the order given (closest-first for Chroma query
    results). A hit is discarded when, compared with any hit already kept:

    1. its similarity is above ``near_duplicate_threshold``, which catches
       near-exact duplicates; or
    2. it has the same non-empty ``answer`` metadata and a similarity above
       ``same_answer_threshold``, which catches conceptual duplicates.

    Similarity compares the two hits with each other. It is the cosine
    similarity of their embeddings when ``results`` carries them, and
    otherwise a ``difflib`` ratio of the document texts.

    NOTE(review): the 0.92 / 0.85 defaults come from the original design and
    have not been calibrated for MedCPT cosine scores. Tune them if too
    many or too few hits are merged.

    Args:
        results: Chroma-style dict whose ``documents``/``metadatas``/
            ``distances`` (and optionally ``ids``/``embeddings``) values are
            one-element lists holding the per-hit lists for a single query.
        target_count: stop once this many unique hits have been kept.
        near_duplicate_threshold: similarity above which two hits are
            always treated as duplicates.
        same_answer_threshold: similarity above which two hits sharing an
            answer are treated as duplicates.

    Returns:
        ``results`` unchanged when it has no documents; otherwise a new dict
        of the same shape with only the kept hits (``ids`` is ``None`` if the
        input had none).
    """
    documents_batch = results.get('documents')
    if not documents_batch or not documents_batch[0]:
        return results
    documents = documents_batch[0]
    metadatas = results['metadatas'][0]
    distances = results['distances'][0]

    # Compare hits with each other. Each hit's distance to the *query* says
    # nothing about whether two hits resemble one another.
    embeddings = _first_query_embeddings(results, len(documents))
    norms = None
    if embeddings is not None:
        norms = [math.sqrt(sum(x * x for x in vec)) for vec in embeddings]
    floor = min(near_duplicate_threshold, same_answer_threshold)

    def similarity(i, j):
        if embeddings is not None:
            return _cosine_similarity(embeddings[i], embeddings[j], norms[i], norms[j])
        return _text_similarity(documents[i], documents[j], floor)

    def answer_of(i):
        return (metadatas[i] or {}).get('answer', '')

    def is_duplicate_of(i, j):
        score = similarity(i, j)
        if score > near_duplicate_threshold:
            return True
        answer = answer_of(i)
        # A missing answer on both hits is not evidence that they match.
        return bool(answer) and answer == answer_of(j) and score > same_answer_threshold

    selected = []
    for i in range(len(documents)):
        if any(is_duplicate_of(i, j) for j in selected):
            continue
        selected.append(i)
        if len(selected) >= target_count:
            break

    ids_batch = results.get('ids')
    return {
        'documents': [[documents[i] for i in selected]],
        'metadatas': [[metadatas[i] for i in selected]],
        'distances': [[distances[i] for i in selected]],
        'ids': [[ids_batch[0][i] for i in selected]] if ids_batch else None,
    }
# ============================================================================
# Search function with deduplication
# ============================================================================
def search(query, num_results=3, source_filter=None):
    """Embed *query* with MedCPT and return up to *num_results* de-duplicated hits.

    Args:
        query: free-text medical query.
        num_results: number of unique hits wanted. Coerced to ``int`` (UI
            sliders and JSON clients may send floats) and clamped to at
            least 1, because Chroma needs a positive integer ``n_results``.
        source_filter: ``source`` metadata value to restrict the search to;
            ``None``/empty or ``"all"`` searches every source.

    Returns:
        A Chroma-style result dict (``documents``/``metadatas``/``distances``/
        ``ids``, each a one-element list of per-hit lists), as produced by
        ``deduplicate_results``.
    """
    num_results = max(1, int(num_results))
    emb = model.encode(query).tolist()

    where_clause = None
    if source_filter and source_filter != "all":
        where_clause = {"source": source_filter}

    # Over-fetch so enough hits survive deduplication: 4x requested, capped at 50.
    fetch_count = min(num_results * 4, 50)
    results = collection.query(
        query_embeddings=[emb],
        n_results=fetch_count,
        where=where_clause,
        # Embeddings are included so deduplication can compare hits with each
        # other, not only with the query.
        include=["documents", "metadatas", "distances", "embeddings"],
    )
    return deduplicate_results(results, num_results)
# Enhanced Gradio UI
import traceback


def _source_label(source):
    """Map a hit's ``source`` metadata tag to a (display icon, display name) pair."""
    source = str(source)
    if source == 'medgemini':
        return "π¬", "Med-Gemini"
    if source.startswith('medqa_'):
        split = source.replace('medqa_', '').upper()
        return "π", f"MedQA {split}"
    return "π", source.upper()


def _format_results(r):
    """Render a non-empty ``search()`` result dict as the plain-text report shown in the UI."""
    documents = r['documents'][0]
    metadatas = r['metadatas'][0]
    distances = r['distances'][0]
    banner = '=' * 70

    parts = [f"π Found {len(documents)} unique results\n\n"]
    for rank, (document, metadata, distance) in enumerate(
            zip(documents, metadatas, distances), start=1):
        metadata = metadata or {}  # Chroma may return None for hits stored without metadata
        icon, name = _source_label(metadata.get('source', 'unknown'))
        # Shown as 1 - distance; this is only a similarity if the collection
        # uses cosine distance — TODO confirm the collection's metric.
        similarity = 1 - distance

        parts.append(f"\n{banner}\n")
        parts.append(f"{icon} Result {rank} | {name} | Similarity: {similarity:.3f}\n")
        parts.append(f"{banner}\n\n")
        parts.append(document or "")

        answer = metadata.get('answer', 'N/A')
        parts.append(f"\n\nβ CORRECT ANSWER: {answer}\n")

        # Only some sources (e.g. Med-Gemini) carry an explanation.
        explanation = metadata.get('explanation', '')
        if explanation and str(explanation).strip():
            parts.append(f"\nπ‘ EXPLANATION:\n{explanation}\n")
        parts.append("\n")
    return "".join(parts)


def ui_search(query, num_results=3, source_filter="all"):
    """Gradio callback: run ``search`` and render the hits as plain text.

    Returns a prompt for a blank or missing query and a notice when nothing
    matches. If the search raises, it prints the traceback to the server log
    and returns an error string instead.
    """
    if not query or not query.strip():
        return "π‘ Enter a medical query to search"
    try:
        r = search(query, num_results, source_filter if source_filter != "all" else None)
        if not r['documents'][0]:
            return "β No results found"
        return _format_results(r)
    except Exception as e:
        # UI boundary: show the error to the user, keep the trace for operators.
        traceback.print_exc()
        return f"β Error: {e}"
# Create Gradio interface
with gr.Blocks(theme=gr.themes.Soft(), title="MedQA Search") as demo:
    # Blank lines separate the Markdown paragraphs; without them, consecutive
    # lines render run together as a single paragraph.
    gr.Markdown("""
# π₯ MedQA Semantic Search

Search across **Med-Gemini** (expert explanations) and **MedQA** (USMLE questions) databases.

Uses medical-specific embeddings (MedCPT) for accurate retrieval.

β¨ **New**: Automatic deduplication removes similar/duplicate questions
""")

    with gr.Row():
        with gr.Column(scale=3):
            query_input = gr.Textbox(
                label="Medical Query",
                placeholder="e.g., hyponatremia, myocardial infarction, diabetes management...",
                lines=2,
            )
        with gr.Column(scale=1):
            num_results = gr.Slider(
                minimum=1,
                maximum=10,
                value=3,
                step=1,
                label="Number of Results",
            )
    with gr.Row():
        source_filter = gr.Radio(
            choices=["all", "medgemini", "medqa_train", "medqa_dev", "medqa_test"],
            value="all",
            label="Filter by Source",
        )

    # NOTE(review): the pasted source lost its indentation, so placing the
    # button and output at Blocks level (not inside the Row) is a
    # reconstruction — confirm the intended layout.
    search_btn = gr.Button("π Search", variant="primary", size="lg")
    output = gr.Textbox(
        label="Search Results",
        lines=25,
        max_lines=50,
    )

    # Clicking the button and pressing Enter in the query box run the same search.
    search_inputs = [query_input, num_results, source_filter]
    search_btn.click(fn=ui_search, inputs=search_inputs, outputs=output)
    query_input.submit(fn=ui_search, inputs=search_inputs, outputs=output)

    # Report the real collection size instead of a leftover build instruction.
    gr.Markdown(f"""
### π Database Info

**Med-Gemini**: Expert-relabeled questions with detailed explanations

**MedQA**: USMLE-style questions (Train/Dev/Test splits)

**Total Questions**: {collection.count():,}
""")

    gr.Examples(
        examples=[
            ["hyponatremia", 3, "all"],
            ["myocardial infarction treatment", 2, "medgemini"],
            ["diabetes complications", 3, "all"],
            ["antibiotics for pneumonia", 2, "medqa_train"],
        ],
        inputs=search_inputs,
    )
# FastAPI
from typing import Optional

app = FastAPI()


class SearchRequest(BaseModel):
    """JSON body accepted by ``POST /search``."""

    query: str
    num_results: int = 3
    # Optional so clients may send an explicit null (pydantic v2 rejects null
    # for a plain ``str``); None or "all" searches every source.
    source_filter: Optional[str] = None


# Registered before the Gradio app is mounted at "/" below, so this route is
# matched first. Without a route decorator the function was never exposed.
@app.post("/search")
def api_search(req: SearchRequest):
    """Run ``search`` and return the hits as JSON-serialisable dicts.

    Returns:
        ``{"results": [...]}``, where each item has ``result_number`` (1-based),
        ``question``, ``answer``, ``source`` and ``similarity`` (1 - distance).
    """
    r = search(req.query, req.num_results, req.source_filter)
    documents = r['documents'][0]
    metadatas = r['metadatas'][0]
    distances = r['distances'][0]
    return {"results": [
        {
            "result_number": rank,
            "question": document,
            "answer": (metadata or {}).get('answer', 'N/A'),
            "source": (metadata or {}).get('source', 'unknown'),
            # float() so a NumPy scalar from the store still serialises to JSON.
            "similarity": float(1 - distance),
        }
        for rank, (document, metadata, distance)
        in enumerate(zip(documents, metadatas, distances), start=1)
    ]}


app = gr.mount_gradio_app(app, demo, path="/")
# Launch
if __name__ == "__main__":
    # Imported lazily: uvicorn is only needed when this file is run directly.
    from uvicorn import run as serve

    # Serve the combined FastAPI + Gradio app on every interface, port 7860.
    serve(app, host="0.0.0.0", port=7860)