import os
os.environ['ANONYMIZED_TELEMETRY'] = 'False'
import zipfile
import chromadb
from sentence_transformers import SentenceTransformer
import gradio as gr
from fastapi import FastAPI
from pydantic import BaseModel
from typing import List, Optional
import re
import time
# Extract and load database
DB_PATH = "./medqa_db"
if not os.path.exists(DB_PATH) and os.path.exists("./medqa_db.zip"):
    print("📦 Extracting database...")
    with zipfile.ZipFile("./medqa_db.zip", 'r') as z:
        z.extractall(".")
    print("✅ Database extracted")

print("📌 Loading ChromaDB...")
client = chromadb.PersistentClient(path=DB_PATH)
collection = client.get_collection("medqa")
print(f"✅ Loaded {collection.count()} questions")

print("🧠 Loading MedCPT model...")
model = SentenceTransformer('ncbi/MedCPT-Query-Encoder')
print("✅ Model ready")
# ============================================================================
# Deduplication function
# ============================================================================
def deduplicate_results(results, target_count):
    """
    Remove near-duplicate questions from a ChromaDB result set.

    Uses query distance as a cheap proxy for similarity between results:
    1. Distance difference < 0.08 - flags near-exact duplicates
    2. Same answer and distance difference < 0.15 - flags conceptual duplicates
    """
    if not results['documents'][0]:
        return results

    documents = results['documents'][0]
    metadatas = results['metadatas'][0]
    distances = results['distances'][0]

    selected_indices = []
    for i in range(len(documents)):
        is_duplicate = False
        current_answer = metadatas[i].get('answer', '')
        for j in selected_indices:
            selected_answer = metadatas[j].get('answer', '')
            dist_diff = abs(distances[i] - distances[j])
            if dist_diff < 0.08:
                is_duplicate = True
                break
            if current_answer == selected_answer and dist_diff < 0.15:
                is_duplicate = True
                break
        if not is_duplicate:
            selected_indices.append(i)
            if len(selected_indices) >= target_count:
                break

    return {
        'documents': [[documents[i] for i in selected_indices]],
        'metadatas': [[metadatas[i] for i in selected_indices]],
        'distances': [[distances[i] for i in selected_indices]],
        'ids': [[results['ids'][0][i] for i in selected_indices]] if 'ids' in results else None
    }
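
# Illustrative sketch (hypothetical data): with distances [0.10, 0.13, 0.30] and
# answers ['A', 'A', 'B'], index 1 is dropped because abs(0.13 - 0.10) = 0.03
# falls below the 0.08 cutoff, so deduplicate_results(results, 2) keeps
# indices 0 and 2.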
# ============================================================================
# Search function with deduplication
# ============================================================================
def search(query, num_results=3, source_filter=None):
    emb = model.encode(query).tolist()

    where_clause = None
    if source_filter and source_filter != "all":
        where_clause = {"source": source_filter}

    # Over-fetch so deduplication can still return num_results unique hits
    fetch_count = min(num_results * 4, 50)
    results = collection.query(
        query_embeddings=[emb],
        n_results=fetch_count,
        where=where_clause
    )
    return deduplicate_results(results, num_results)
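
# Example usage (hypothetical query): r = search("hyponatremia", num_results=2)
# returns a ChromaDB-style dict; r['documents'][0] holds the question texts and
# r['distances'][0] the corresponding query distances.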
# ============================================================================
# Parser to extract question structure
# ============================================================================
def parse_question_document(doc_text, metadata):
    """Extract question and choices from document text - NO TRUNCATION."""
    lines = doc_text.split('\n')

    question_lines = []
    options_started = False
    options = {}

    for line in lines:
        line = line.strip()
        if not line:
            continue
        # Check if this is an option line (A., B., C., etc.)
        option_match = re.match(r'^([A-E])[\.\)]\s*(.+)$', line)
        if option_match:
            options_started = True
            letter = option_match.group(1)
            text = option_match.group(2).strip()
            options[letter] = text
        elif not options_started:
            question_lines.append(line)

    # Reconstruct FULL question text - no truncation
    question_text = ' '.join(question_lines).strip()

    answer_idx = metadata.get('answer_idx', 'N/A')
    answer_text = metadata.get('answer', 'N/A')

    # If answer_text is just the letter, map it to the actual option text
    if answer_text in options:
        answer_text = options[answer_text]

    return {
        'question': question_text,
        'choices': options,
        'correct_answer_letter': answer_idx,
        'correct_answer_text': answer_text
    }
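
# Illustrative sketch (hypothetical input): a document such as
#   "A 45-year-old man presents with...\nA. Aspirin\nB. Heparin"
# with metadata {'answer_idx': 'B', 'answer': 'B'} parses to
#   {'question': 'A 45-year-old man presents with...',
#    'choices': {'A': 'Aspirin', 'B': 'Heparin'},
#    'correct_answer_letter': 'B', 'correct_answer_text': 'Heparin'}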
# ============================================================================
# Enhanced Gradio UI
# ============================================================================
def ui_search(query, num_results=3, source_filter="all"):
    if not query.strip():
        return "💡 Enter a medical query to search"
    try:
        # search() already treats "all" as no filter; cast the slider value to int
        r = search(query, int(num_results), source_filter)
        if not r['documents'][0]:
            return "❌ No results found"

        out = f"🔍 Found {len(r['documents'][0])} unique results\n\n"
        for i in range(len(r['documents'][0])):
            source = r['metadatas'][0][i].get('source', 'unknown')
            distance = r['distances'][0][i]
            similarity = 1 - distance

            # Source emoji
            if source == 'medgemini':
                source_icon = "🔬"
                source_name = "Med-Gemini"
            elif source.startswith('medqa_'):
                source_icon = "📚"
                split = source.replace('medqa_', '').upper()
                source_name = f"MedQA {split}"
            else:
                source_icon = "📄"
                source_name = source.upper()

            out += f"\n{'='*70}\n"
            out += f"{source_icon} Result {i+1} | {source_name} | Similarity: {similarity:.3f}\n"
            out += f"{'='*70}\n\n"
            out += r['documents'][0][i]

            answer = r['metadatas'][0][i].get('answer', 'N/A')
            out += f"\n\n✅ CORRECT ANSWER: {answer}\n"

            explanation = r['metadatas'][0][i].get('explanation', '')
            if explanation and explanation.strip():
                out += f"\n💡 EXPLANATION:\n{explanation}\n"
            out += "\n"
        return out
    except Exception as e:
        return f"❌ Error: {e}"
# Create Gradio interface
with gr.Blocks(theme=gr.themes.Soft(), title="MedQA Search") as demo:
    gr.Markdown("""
    # 🥼 MedQA Semantic Search

    Search across **Med-Gemini** (expert explanations) and **MedQA** (USMLE questions) databases.
    Uses medical-specific embeddings (MedCPT) for accurate retrieval.

    ✨ **Features**: Automatic deduplication, structured output for AI integration
    """)

    with gr.Row():
        with gr.Column(scale=3):
            query_input = gr.Textbox(
                label="Medical Query",
                placeholder="e.g., hyponatremia, myocardial infarction, diabetes management...",
                lines=2
            )
        with gr.Column(scale=1):
            num_results = gr.Slider(
                minimum=1,
                maximum=10,
                value=3,
                step=1,
                label="Number of Results"
            )

    with gr.Row():
        source_filter = gr.Radio(
            choices=["all", "medgemini", "medqa_train", "medqa_dev", "medqa_test"],
            value="all",
            label="Filter by Source"
        )

    search_btn = gr.Button("🔍 Search", variant="primary", size="lg")
    output = gr.Textbox(
        label="Search Results",
        lines=25,
        max_lines=50
    )

    search_btn.click(
        fn=ui_search,
        inputs=[query_input, num_results, source_filter],
        outputs=output
    )
    query_input.submit(
        fn=ui_search,
        inputs=[query_input, num_results, source_filter],
        outputs=output
    )

    gr.Markdown("""
    ### 📊 Database Info
    **Med-Gemini**: Expert-relabeled questions with detailed explanations
    **MedQA**: USMLE-style questions (Train/Dev/Test splits)
    **Total Questions**: 10,000+ USMLE-style questions
    """)

    gr.Examples(
        examples=[
            ["hyponatremia", 3, "all"],
            ["myocardial infarction treatment", 2, "medgemini"],
            ["diabetes complications", 3, "all"],
            ["antibiotics for pneumonia", 2, "medqa_train"]
        ],
        inputs=[query_input, num_results, source_filter]
    )
# ============================================================================
# FastAPI with structured JSON output (for OpenAI integration)
# ============================================================================
app = FastAPI()

class SearchRequest(BaseModel):
    query: str
    num_results: int = 3
    source_filter: Optional[str] = None

class BatchSearchRequest(BaseModel):
    queries: List[str]
    num_results_per_query: int = 10
    source_filter: Optional[str] = None
@app.post("/search_medqa")
def api_search(req: SearchRequest):
"""
Search MedQA and return structured exemplars.
Returns COMPLETE question text with no truncation.
"""
r = search(req.query, req.num_results, req.source_filter)
if not r['documents'][0]:
return {"results": []}
results = []
for i in range(len(r['documents'][0])):
doc_text = r['documents'][0][i]
metadata = r['metadatas'][0][i]
# Parse the document into structured format
parsed = parse_question_document(doc_text, metadata)
# Build complete result object
result = {
"result_number": i + 1,
"question": parsed['question'], # FULL question text
"choices": parsed['choices'],
"correct_answer": parsed['correct_answer_letter'],
"correct_answer_text": parsed['correct_answer_text'],
"explanation": metadata.get('explanation', ''),
"has_explanation": bool(metadata.get('explanation', '').strip()),
"source": metadata.get('source', 'unknown'),
"exam_type": metadata.get('exam_type', 'unknown'),
"split": metadata.get('split', 'unknown'),
"similarity": round(1 - r['distances'][0][i], 3),
"metamap_phrases": metadata.get('metamap_phrases', '')
}
results.append(result)
return {"results": results}
@app.post("/batch_search_medqa")
def batch_api_search(req: BatchSearchRequest):
"""
NEW: Batch search for multiple learning objectives.
Processes all queries, tracks duplicates, and returns organized results.
Returns:
- results_by_objective: List of results organized by each objective
- unique_questions: Deduplicated list of all questions
- statistics: Coverage and quality metrics
"""
start_time = time.time()
# Track all questions and their objective mappings
all_questions = {} # key: question_text, value: question data + objectives list
results_by_objective = []
for obj_idx, query in enumerate(req.queries):
objective_id = obj_idx + 1
# Search for this objective
r = search(query, req.num_results_per_query, req.source_filter)
objective_results = []
similarities = []
if r['documents'][0]:
for i in range(len(r['documents'][0])):
doc_text = r['documents'][0][i]
metadata = r['metadatas'][0][i]
similarity = round(1 - r['distances'][0][i], 3)
similarities.append(similarity)
# Parse the document
parsed = parse_question_document(doc_text, metadata)
# Create unique key for deduplication
question_key = parsed['question'][:200] # Use first 200 chars as key
# Build result object
result = {
"question": parsed['question'],
"choices": parsed['choices'],
"correct_answer": parsed['correct_answer_letter'],
"correct_answer_text": parsed['correct_answer_text'],
"explanation": metadata.get('explanation', ''),
"has_explanation": bool(metadata.get('explanation', '').strip()),
"source": metadata.get('source', 'unknown'),
"similarity": similarity
}
# Track for global deduplication
if question_key in all_questions:
# This question already exists - add this objective to its list
all_questions[question_key]['matches_objectives'].append(objective_id)
# Update similarity if higher
if similarity > all_questions[question_key]['max_similarity']:
all_questions[question_key]['max_similarity'] = similarity
else:
# First time seeing this question
all_questions[question_key] = {
**result,
'matches_objectives': [objective_id],
'max_similarity': similarity,
'first_seen_at': objective_id
}
objective_results.append(result)
# Store results for this objective
results_by_objective.append({
"objective_id": objective_id,
"objective_text": query,
"num_results": len(objective_results),
"avg_similarity": round(sum(similarities) / len(similarities), 3) if similarities else 0,
"results": objective_results
})
# Prepare unique questions list
unique_questions = list(all_questions.values())
# Calculate statistics
execution_time = round(time.time() - start_time, 2)
total_retrieved = sum(obj['num_results'] for obj in results_by_objective)
# Coverage analysis
coverage = {
"excellent": [obj for obj in results_by_objective if obj['num_results'] >= 5],
"moderate": [obj for obj in results_by_objective if 2 <= obj['num_results'] < 5],
"limited": [obj for obj in results_by_objective if obj['num_results'] == 1],
"none": [obj for obj in results_by_objective if obj['num_results'] == 0]
}
# Multi-objective questions
multi_objective_questions = [q for q in unique_questions if len(q['matches_objectives']) > 1]
# Source distribution
sources = {}
for q in unique_questions:
source = q['source']
sources[source] = sources.get(source, 0) + 1
# Similarity distribution
all_similarities = [q['max_similarity'] for q in unique_questions]
high_sim = len([s for s in all_similarities if s > 0.8])
med_sim = len([s for s in all_similarities if 0.7 <= s <= 0.8])
low_sim = len([s for s in all_similarities if s < 0.7])
statistics = {
"total_objectives": len(req.queries),
"total_retrieved": total_retrieved,
"unique_questions": len(unique_questions),
"deduplication_rate": round((total_retrieved - len(unique_questions)) / total_retrieved * 100, 1) if total_retrieved > 0 else 0,
"execution_time_seconds": execution_time,
"coverage": {
"excellent_coverage_count": len(coverage["excellent"]),
"moderate_coverage_count": len(coverage["moderate"]),
"limited_coverage_count": len(coverage["limited"]),
"no_coverage_count": len(coverage["none"]),
"no_coverage_objectives": [obj['objective_id'] for obj in coverage["none"]]
},
"cross_objective": {
"multi_objective_questions": len(multi_objective_questions),
"multi_objective_percentage": round(len(multi_objective_questions) / len(unique_questions) * 100, 1) if unique_questions else 0
},
"sources": sources,
"similarity_distribution": {
"high_similarity_count": high_sim,
"medium_similarity_count": med_sim,
"low_similarity_count": low_sim,
"average_similarity": round(sum(all_similarities) / len(all_similarities), 3) if all_similarities else 0
}
}
return {
"results_by_objective": results_by_objective,
"unique_questions": unique_questions,
"statistics": statistics
}
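
# Example request (hypothetical values; adjust host/port to your deployment):
#   curl -X POST http://localhost:7860/batch_search_medqa \
#        -H "Content-Type: application/json" \
#        -d '{"queries": ["hyponatremia", "diabetes complications"], "num_results_per_query": 5}'
# Responds with per-objective results plus a globally deduplicated question list
# and coverage/similarity statistics.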
app = gr.mount_gradio_app(app, demo, path="/")

if __name__ == "__main__":
    import uvicorn
    uvicorn.run(app, host="0.0.0.0", port=7860)