Spaces:
Sleeping
Sleeping
| import os | |
| os.environ['ANONYMIZED_TELEMETRY'] = 'False' | |
| import zipfile | |
| import chromadb | |
| from sentence_transformers import SentenceTransformer | |
| import gradio as gr | |
| from fastapi import FastAPI | |
| from pydantic import BaseModel | |
| from typing import List, Optional | |
| import re | |
| import time | |
# Extract and load database
DB_PATH = "./medqa_db"


def _extract_bundled_db(db_path, archive_path):
    """Unzip ``archive_path`` into the working directory on first start.

    Does nothing when ``db_path`` already exists or when the archive is absent.
    """
    if os.path.exists(db_path) or not os.path.exists(archive_path):
        return
    print("π¦ Extracting database...")
    with zipfile.ZipFile(archive_path, 'r') as archive:
        # Extracts into "." -- presumably the archive contains the medqa_db/
        # directory itself (assumption; not verified here).
        archive.extractall(".")
    print("β Database extracted")


def _load_search_backend(db_path):
    """Open the persisted Chroma ``medqa`` collection and the MedCPT query encoder.

    Prints progress to stdout and returns ``(client, collection, model)``.
    """
    print("π Loading ChromaDB...")
    chroma_client = chromadb.PersistentClient(path=db_path)
    medqa_collection = chroma_client.get_collection("medqa")
    print(f"β Loaded {medqa_collection.count()} questions")
    print("π§ Loading MedCPT model...")
    encoder = SentenceTransformer('ncbi/MedCPT-Query-Encoder')
    print("β Model ready")
    return chroma_client, medqa_collection, encoder


_extract_bundled_db(DB_PATH, "./medqa_db.zip")
client, collection, model = _load_search_backend(DB_PATH)
# ============================================================================
# Deduplication function
# ============================================================================
def deduplicate_results(results, target_count, text_threshold=0.92,
                        same_answer_threshold=0.85):
    """
    Remove near-duplicate hits from a single-query Chroma result, keeping rank order.

    Hits are visited in the order they appear (best match first). A hit is
    dropped when, compared with a hit that was already kept, either:
    1. High text similarity (> ``text_threshold``, default 0.92) - catches
       near-exact duplicates.
    2. Same non-empty ``answer`` metadata + moderate text similarity
       (> ``same_answer_threshold``, default 0.85) - catches conceptual
       duplicates.

    Text similarity is ``difflib.SequenceMatcher.ratio()`` over the
    lower-cased, whitespace-split words of the two documents (1.0 = identical).

    Args:
        results: dict whose ``documents``/``metadatas``/``distances`` (and
            optionally ``ids``) hold one inner list of hits at index 0.
        target_count: stop as soon as this many unique hits have been kept.
        text_threshold: similarity above which any two hits are duplicates.
        same_answer_threshold: lower similarity bar used when both hits share
            the same non-empty ``answer``.

    Returns:
        ``results`` itself when it has no hits; otherwise a new dict with
        ``documents``, ``metadatas``, ``distances`` and ``ids`` restricted to the
        kept hits (``ids`` is None when the input carries no ids).
    """
    from difflib import SequenceMatcher

    documents = results['documents'][0]
    if not documents:
        return results
    metadatas = results['metadatas'][0]
    distances = results['distances'][0]

    # Tokenise once: word lists are much shorter than character sequences,
    # which keeps SequenceMatcher fast on long vignettes and ignores case and
    # whitespace differences.
    tokens = [(doc or '').lower().split() for doc in documents]
    # Guard against a None metadata entry (assumed possible for records
    # stored without metadata).
    answers = [(meta or {}).get('answer', '') for meta in metadatas]

    def _is_duplicate(i, j):
        # An empty answer never counts as "the same answer".
        same_answer = bool(answers[i]) and answers[i] == answers[j]
        threshold = same_answer_threshold if same_answer else text_threshold
        matcher = SequenceMatcher(None, tokens[i], tokens[j], autojunk=False)
        # real_quick_ratio() and quick_ratio() are cheap upper bounds on
        # ratio(), so most dissimilar pairs are rejected without the full match.
        return (matcher.real_quick_ratio() > threshold
                and matcher.quick_ratio() > threshold
                and matcher.ratio() > threshold)

    selected_indices = []
    for i in range(len(documents)):
        if any(_is_duplicate(i, j) for j in selected_indices):
            continue
        selected_indices.append(i)
        if len(selected_indices) >= target_count:
            break

    ids = results.get('ids')
    return {
        'documents': [[documents[i] for i in selected_indices]],
        'metadatas': [[metadatas[i] for i in selected_indices]],
        'distances': [[distances[i] for i in selected_indices]],
        'ids': [[ids[0][i] for i in selected_indices]] if ids else None,
    }
# ============================================================================
# Search function with deduplication
# ============================================================================
def search(query, num_results=3, source_filter=None, oversample=4, max_fetch=50):
    """
    Embed ``query`` with the MedCPT encoder, query Chroma, and deduplicate hits.

    Args:
        query: free-text medical query.
        num_results: number of unique hits wanted. Coerced with ``int()`` in
            case a caller (e.g. a UI slider) passes a float.
        source_filter: keep only hits whose ``source`` metadata equals this
            value; ``None`` or ``"all"`` disables the filter.
        oversample: fetch ``num_results * oversample`` candidates so that
            deduplication can discard near-duplicates and still fill
            ``num_results``.
        max_fetch: cap on the oversampled candidate count. The fetch size never
            drops below ``num_results`` itself.

    Returns:
        A single-query Chroma-style result dict holding at most ``num_results``
        deduplicated hits.
    """
    num_results = int(num_results)
    emb = model.encode(query).tolist()

    where_clause = None
    if source_filter and source_filter != "all":
        where_clause = {"source": source_filter}

    # The cap bounds query cost; the outer max() keeps large requests from
    # being silently truncated below what was asked for.
    fetch_count = max(num_results, min(num_results * oversample, max_fetch))
    results = collection.query(
        query_embeddings=[emb],
        n_results=fetch_count,
        where=where_clause,
    )
    return deduplicate_results(results, num_results)
# ============================================================================
# Parser to extract question structure
# ============================================================================
def parse_question_document(doc_text, metadata):
    """Split a stored question document into its stem and lettered choices.

    Non-blank lines shaped like ``A.``/``A)`` through ``E.``/``E)`` become
    choices. Non-blank lines before the first choice are joined with single
    spaces into the full question stem (no truncation). Non-choice lines after
    the first choice are ignored.

    Args:
        doc_text: newline-separated document text.
        metadata: mapping read for ``answer_idx`` and ``answer``.

    Returns:
        dict with ``question``, ``choices`` (letter -> text),
        ``correct_answer_letter`` (``answer_idx``, or ``'N/A'``) and
        ``correct_answer_text`` (``answer``, replaced by the matching choice
        text when ``answer`` is itself a choice letter; ``'N/A'`` if missing).
    """
    choice_line = re.compile(r'^([A-E])[\.\)]\s*(.+)$')
    stem = []
    choices = {}
    for raw_line in doc_text.split('\n'):
        text = raw_line.strip()
        if not text:
            continue
        hit = choice_line.match(text)
        if hit:
            choices[hit.group(1)] = hit.group(2).strip()
        elif not choices:
            # No choice seen yet, so the line still belongs to the stem.
            stem.append(text)

    answer_letter = metadata.get('answer_idx', 'N/A')
    answer = metadata.get('answer', 'N/A')
    return {
        'question': ' '.join(stem).strip(),
        'choices': choices,
        'correct_answer_letter': answer_letter,
        'correct_answer_text': choices.get(answer, answer),
    }
# ============================================================================
# Enhanced Gradio UI
# ============================================================================
def ui_search(query, num_results=3, source_filter="all"):
    """Format search hits as a plain-text report for the Gradio output box.

    Args:
        query: user-entered text. Blank or whitespace-only input returns a hint
            message without searching.
        num_results: number of unique hits to request from ``search``.
        source_filter: ``"all"`` for no filter, otherwise a ``source`` metadata
            value to restrict results to.

    Returns:
        A multi-line string containing a hit-count header, then for each hit:
        - a banner with a source label and "Similarity" (computed as
          ``1 - distance``; only meaningful if the collection's distance is
          bounded by 1 -- not checked here);
        - the stored document text;
        - its ``answer`` metadata;
        - its ``explanation`` metadata, when non-blank.
        Any exception raised while searching or formatting is returned as an
        error line instead of propagating.
    """
    if not query.strip():
        return "π‘ Enter a medical query to search"

    def source_label(source):
        # Map the ``source`` metadata value to a display (icon, name) pair.
        if source == 'medgemini':
            return "π¬", "Med-Gemini"
        if source.startswith('medqa_'):
            return "π", f"MedQA {source.replace('medqa_', '').upper()}"
        return "π", source.upper()

    try:
        hits = search(query, num_results, None if source_filter == "all" else source_filter)
        documents = hits['documents'][0]
        if not documents:
            return "β No results found"

        divider = '=' * 70
        report = f"π Found {len(documents)} unique results\n\n"
        for idx, document in enumerate(documents):
            meta = hits['metadatas'][0][idx]
            source = meta.get('source', 'unknown')
            similarity = 1 - hits['distances'][0][idx]
            icon, name = source_label(source)
            report += (
                f"\n{divider}\n"
                f"{icon} Result {idx + 1} | {name} | Similarity: {similarity:.3f}\n"
                f"{divider}\n\n"
            )
            report += document
            report += f"\n\nβ CORRECT ANSWER: {meta.get('answer', 'N/A')}\n"
            explanation = meta.get('explanation', '')
            if explanation and explanation.strip():
                report += f"\nπ‘ EXPLANATION:\n{explanation}\n"
            report += "\n"
        return report
    except Exception as e:
        return f"β Error: {e}"
# Create Gradio interface
# Gradio lays components out in the order they are created inside the nested
# `with` contexts, so the statement order below defines the page layout.
# NOTE(review): this code was recovered without its indentation; the
# Row/Column nesting below is a reconstruction -- confirm against the
# deployed layout.
with gr.Blocks(theme=gr.themes.Soft(), title="MedQA Search") as demo:
    # Header text shown above the search controls.
    gr.Markdown("""
# π₯Ό MedQA Semantic Search
Search across **Med-Gemini** (expert explanations) and **MedQA** (USMLE questions) databases.
Uses medical-specific embeddings (MedCPT) for accurate retrieval.
β¨ **Features**: Automatic deduplication, structured output for AI integration
""")
    # Query box (wide column) next to the result-count slider (narrow column).
    with gr.Row():
        with gr.Column(scale=3):
            query_input = gr.Textbox(
                label="Medical Query",
                placeholder="e.g., hyponatremia, myocardial infarction, diabetes management...",
                lines=2
            )
        with gr.Column(scale=1):
            # Whole numbers 1-10; passed to ui_search as num_results.
            num_results = gr.Slider(
                minimum=1,
                maximum=10,
                value=3,
                step=1,
                label="Number of Results"
            )
    with gr.Row():
        # "all" disables filtering; the other values are compared against the
        # `source` metadata by search().
        source_filter = gr.Radio(
            choices=["all", "medgemini", "medqa_train", "medqa_dev", "medqa_test"],
            value="all",
            label="Filter by Source"
        )
    search_btn = gr.Button("π Search", variant="primary", size="lg")
    output = gr.Textbox(
        label="Search Results",
        lines=25,
        max_lines=50
    )
    # The button and pressing Enter in the query box run the same handler.
    # The `inputs` order must match ui_search(query, num_results, source_filter).
    search_btn.click(
        fn=ui_search,
        inputs=[query_input, num_results, source_filter],
        outputs=output
    )
    query_input.submit(
        fn=ui_search,
        inputs=[query_input, num_results, source_filter],
        outputs=output
    )
    # Static footer; the question total below is hard-coded text, not computed.
    gr.Markdown("""
### π Database Info
**Med-Gemini**: Expert-relabeled questions with detailed explanations
**MedQA**: USMLE-style questions (Train/Dev/Test splits)
**Total Questions**: ~10,000+ USMLE-style questions
""")
    # Each example row is [query, num_results, source_filter], in `inputs` order.
    gr.Examples(
        examples=[
            ["hyponatremia", 3, "all"],
            ["myocardial infarction treatment", 2, "medgemini"],
            ["diabetes complications", 3, "all"],
            ["antibiotics for pneumonia", 2, "medqa_train"]
        ],
        inputs=[query_input, num_results, source_filter]
    )
# ============================================================================
# FastAPI with structured JSON output (for OpenAI integration)
# ============================================================================
from pydantic import Field

app = FastAPI()


class SearchRequest(BaseModel):
    """JSON body for a single search: query text, result count, optional source filter."""

    query: str
    # ge=1: a count below one can never return results, so reject it with a
    # validation error instead of passing it on to the vector query.
    num_results: int = Field(default=3, ge=1)
    # Optional[...] so an explicit JSON null is accepted as well as omission
    # (a plain `str` annotation rejects null under Pydantic v2); this also
    # matches BatchSearchRequest.source_filter.
    source_filter: Optional[str] = None


class BatchSearchRequest(BaseModel):
    """JSON body for a batch search: one query per learning objective."""

    queries: List[str]
    # Per-query result count; same lower bound as SearchRequest.num_results.
    num_results_per_query: int = Field(default=10, ge=1)
    source_filter: Optional[str] = None
# Registered on `app` so the JSON API is actually reachable. It is defined
# before the Gradio UI is mounted at "/" at the end of the file, so this route
# is matched first. The function can still be called directly.
@app.post("/search")
def api_search(req: SearchRequest):
    """
    Search MedQA and return structured exemplars.

    Returns ``{"results": [...]}`` with one item per deduplicated hit, in rank
    order. Each item carries:
    - ``result_number`` (1-based rank);
    - the COMPLETE parsed question text (no truncation) and its lettered
      ``choices``;
    - the correct answer letter and text;
    - ``explanation`` / ``has_explanation``;
    - ``source``, ``exam_type``, ``split`` and ``metamap_phrases`` metadata
      (``'unknown'`` / ``''`` when absent);
    - ``similarity``, rounded from ``1 - distance``.
    """
    r = search(req.query, req.num_results, req.source_filter)

    results = []
    hits = zip(r['documents'][0], r['metadatas'][0], r['distances'][0])
    for rank, (doc_text, metadata, distance) in enumerate(hits, start=1):
        # Guard against a None metadata entry (assumed possible for records
        # stored without metadata).
        metadata = metadata or {}
        parsed = parse_question_document(doc_text, metadata)
        explanation = metadata.get('explanation', '')
        results.append({
            "result_number": rank,
            "question": parsed['question'],  # FULL question text
            "choices": parsed['choices'],
            "correct_answer": parsed['correct_answer_letter'],
            "correct_answer_text": parsed['correct_answer_text'],
            "explanation": explanation,
            "has_explanation": bool(explanation.strip()),
            "source": metadata.get('source', 'unknown'),
            "exam_type": metadata.get('exam_type', 'unknown'),
            "split": metadata.get('split', 'unknown'),
            "similarity": round(1 - distance, 3),
            "metamap_phrases": metadata.get('metamap_phrases', ''),
        })
    return {"results": results}
# Registered on `app` (before the Gradio UI is mounted at "/" at the end of the
# file) so the batch endpoint is reachable. It can also be called directly.
@app.post("/batch_search")
def batch_api_search(req: BatchSearchRequest):
    """
    Batch search for multiple learning objectives.

    Runs ``search`` once per entry of ``req.queries``; objective ids are the
    1-based positions in that list. Hits whose parsed question text starts
    with the same 200 characters are merged into one unique question, and
    coverage/quality metrics are computed over the whole batch.

    Returns:
        dict with
        - results_by_objective: per objective, its hits in search order, the
          hit count and the average similarity
        - unique_questions: merged hits. Each carries ``matches_objectives``
          (every objective id that retrieved it, each listed once),
          ``max_similarity`` and ``first_seen_at``.
        - statistics: totals, deduplication rate, coverage buckets,
          cross-objective overlap, per-source counts and similarity
          distribution
    """
    start_time = time.perf_counter()  # monotonic clock for elapsed time

    # question-text prefix -> merged question data + matching objective ids
    all_questions = {}
    results_by_objective = []

    for objective_id, query in enumerate(req.queries, start=1):
        r = search(query, req.num_results_per_query, req.source_filter)
        objective_results = []
        similarities = []

        hits = zip(r['documents'][0], r['metadatas'][0], r['distances'][0])
        for doc_text, metadata, distance in hits:
            # Guard against a None metadata entry (assumed possible for
            # records stored without metadata).
            metadata = metadata or {}
            similarity = round(1 - distance, 3)
            similarities.append(similarity)

            parsed = parse_question_document(doc_text, metadata)
            # Keyed on question text rather than the Chroma id, so identical
            # questions stored under different ids merge. The 200-char prefix
            # also tolerates differences late in the text.
            question_key = parsed['question'][:200]

            explanation = metadata.get('explanation', '')
            result = {
                "question": parsed['question'],
                "choices": parsed['choices'],
                "correct_answer": parsed['correct_answer_letter'],
                "correct_answer_text": parsed['correct_answer_text'],
                "explanation": explanation,
                "has_explanation": bool(explanation.strip()),
                "source": metadata.get('source', 'unknown'),
                "similarity": similarity,
            }

            merged = all_questions.get(question_key)
            if merged is None:
                # First time seeing this question
                all_questions[question_key] = {
                    **result,
                    'matches_objectives': [objective_id],
                    'max_similarity': similarity,
                    'first_seen_at': objective_id,
                }
            else:
                # Objectives are processed in increasing order, so checking the
                # last entry is enough to list each objective once. Without
                # this, two same-key hits under one objective would be
                # miscounted as a multi-objective question.
                if merged['matches_objectives'][-1] != objective_id:
                    merged['matches_objectives'].append(objective_id)
                merged['max_similarity'] = max(merged['max_similarity'], similarity)

            objective_results.append(result)

        results_by_objective.append({
            "objective_id": objective_id,
            "objective_text": query,
            "num_results": len(objective_results),
            "avg_similarity": round(sum(similarities) / len(similarities), 3) if similarities else 0,
            "results": objective_results,
        })

    unique_questions = list(all_questions.values())

    execution_time = round(time.perf_counter() - start_time, 2)
    total_retrieved = sum(obj['num_results'] for obj in results_by_objective)

    # Coverage buckets by hit count: >=5 excellent, 2-4 moderate, 1 limited, 0 none.
    coverage = {
        "excellent": [obj for obj in results_by_objective if obj['num_results'] >= 5],
        "moderate": [obj for obj in results_by_objective if 2 <= obj['num_results'] < 5],
        "limited": [obj for obj in results_by_objective if obj['num_results'] == 1],
        "none": [obj for obj in results_by_objective if obj['num_results'] == 0],
    }

    multi_objective_questions = [q for q in unique_questions if len(q['matches_objectives']) > 1]

    sources = {}
    for q in unique_questions:
        sources[q['source']] = sources.get(q['source'], 0) + 1

    # Bucketed on each unique question's best similarity across objectives.
    all_similarities = [q['max_similarity'] for q in unique_questions]
    high_sim = sum(1 for s in all_similarities if s > 0.8)
    med_sim = sum(1 for s in all_similarities if 0.7 <= s <= 0.8)
    low_sim = sum(1 for s in all_similarities if s < 0.7)

    statistics = {
        "total_objectives": len(req.queries),
        "total_retrieved": total_retrieved,
        "unique_questions": len(unique_questions),
        "deduplication_rate": round((total_retrieved - len(unique_questions)) / total_retrieved * 100, 1) if total_retrieved > 0 else 0,
        "execution_time_seconds": execution_time,
        "coverage": {
            "excellent_coverage_count": len(coverage["excellent"]),
            "moderate_coverage_count": len(coverage["moderate"]),
            "limited_coverage_count": len(coverage["limited"]),
            "no_coverage_count": len(coverage["none"]),
            "no_coverage_objectives": [obj['objective_id'] for obj in coverage["none"]],
        },
        "cross_objective": {
            "multi_objective_questions": len(multi_objective_questions),
            "multi_objective_percentage": round(len(multi_objective_questions) / len(unique_questions) * 100, 1) if unique_questions else 0,
        },
        "sources": sources,
        "similarity_distribution": {
            "high_similarity_count": high_sim,
            "medium_similarity_count": med_sim,
            "low_similarity_count": low_sim,
            "average_similarity": round(sum(all_similarities) / len(all_similarities), 3) if all_similarities else 0,
        },
    }
    return {
        "results_by_objective": results_by_objective,
        "unique_questions": unique_questions,
        "statistics": statistics,
    }
# Mount the Gradio UI at "/" on the FastAPI app; the combined app is what runs.
app = gr.mount_gradio_app(app, demo, path="/")


def _serve():
    """Run the combined app with uvicorn on host 0.0.0.0, port 7860."""
    import uvicorn

    uvicorn.run(app, host="0.0.0.0", port=7860)


if __name__ == "__main__":
    _serve()