Spaces:
Sleeping
Sleeping
import json
import random
import zipfile
from pathlib import Path
from typing import Dict, List, Optional, Tuple

import gradio as gr
import pandas as pd
class MedQADatabase:
    """In-memory store for the Med-Gemini and MedQA question sets.

    ``data`` maps a dataset name (``'medgemini'``, ``'medqa_train'``,
    ``'medqa_dev'``, ``'medqa_test'``) to whatever ``json.load`` returned for
    the matching file. A dataset whose file is missing from the archive stays
    an empty list.
    """

    def __init__(self, zip_path="medqa_databases.zip"):
        """Create the empty dataset table and fill it from ``zip_path``."""
        self.data = {
            'medgemini': [],
            'medqa_train': [],
            'medqa_dev': [],
            'medqa_test': [],
        }
        self.load_databases(zip_path)

    def load_databases(self, zip_path):
        """Extract the archive into ``temp_data`` and parse the JSON files.

        Files that are absent after extraction are skipped silently. Any
        failure (bad archive, invalid JSON, I/O error) is printed and then
        re-raised to the caller.
        """
        print("π¦ Loading databases from ZIP...")
        try:
            # Extract to temporary directory
            with zipfile.ZipFile(zip_path, 'r') as zip_ref:
                zip_ref.extractall('temp_data')

            # Load Med-Gemini
            medgemini_path = Path('temp_data/medqa_databases/med_gemini/medqa_relabelling.json')
            if medgemini_path.exists():
                with open(medgemini_path, 'r', encoding='utf-8') as f:
                    self.data['medgemini'] = json.load(f)
                print(f"β Loaded {len(self.data['medgemini'])} Med-Gemini questions")

            # Load MedQA splits
            medqa_base = Path('temp_data/medqa_databases/medqa_original')
            for split in ['train', 'dev', 'test']:
                split_path = medqa_base / f"{split}.json"
                if split_path.exists():
                    with open(split_path, 'r', encoding='utf-8') as f:
                        self.data[f'medqa_{split}'] = json.load(f)
                    print(f"β Loaded {len(self.data[f'medqa_{split}'])} MedQA {split} questions")
        except Exception as e:
            print(f"β Error loading databases: {e}")
            raise

    def get_stats(self) -> str:
        """Return a Markdown summary of how many questions each dataset holds."""
        medqa_total = sum(len(self.data[f'medqa_{s}']) for s in ['train', 'dev', 'test'])
        grand_total = sum(len(v) for v in self.data.values())
        return "".join([
            "## π Database Statistics\n\n",
            f"**Med-Gemini**: {len(self.data['medgemini']):,} questions\n\n",
            "**MedQA Original**:\n",
            f"- Training: {len(self.data['medqa_train']):,} questions\n",
            f"- Development: {len(self.data['medqa_dev']):,} questions\n",
            f"- Test: {len(self.data['medqa_test']):,} questions\n",
            f"- **Total**: {medqa_total:,} questions\n\n",
            f"**Grand Total**: {grand_total:,} questions",
        ])

    def get_question(self, dataset: str, index: int) -> Optional[Dict]:
        """Return the question at ``index`` in ``dataset``.

        Returns ``None`` when the dataset name is unknown or the index is
        outside ``0 .. len(dataset) - 1``. Negative indexes are rejected
        rather than wrapping around to the end of the list.
        """
        questions = self.data.get(dataset)
        if questions is None or not 0 <= index < len(questions):
            return None
        return questions[index]

    def search_questions(self, query: str, dataset: str = 'all', max_results: int = 50) -> List[Tuple[str, int, str]]:
        """Find questions whose text contains ``query`` (case-insensitive).

        Args:
            query: Substring to look for in each question's text.
            dataset: A dataset name, or ``'all'`` to scan every dataset.
                An unknown name yields no results.
            max_results: Upper bound on the number of hits returned; a
                non-positive value yields no results.

        Returns:
            ``(dataset_name, index, preview)`` tuples in scan order, where
            ``preview`` is the question text cut to 100 characters plus
            ``"..."`` when it is longer than that.
        """
        if max_results <= 0:
            return []

        results = []
        query_lower = query.lower()
        datasets_to_search = list(self.data) if dataset == 'all' else [dataset]
        for ds in datasets_to_search:
            for idx, q in enumerate(self.data.get(ds, [])):
                # Skip malformed records instead of aborting the whole search.
                if not isinstance(q, dict):
                    continue
                # Search in question text (either key capitalisation)
                question_text = q.get('question', q.get('Question', ''))
                if not isinstance(question_text, str):
                    continue
                if query_lower in question_text.lower():
                    preview = question_text[:100] + "..." if len(question_text) > 100 else question_text
                    results.append((ds, idx, preview))
                    if len(results) >= max_results:
                        return results
        return results
# Initialize database
# Built once at import time; the interface functions and the Gradio app
# defined further down read this module-level ``db`` instance.
print("π Initializing MedQA Explorer...")
db = MedQADatabase()
| # ============================================================================ | |
| # GRADIO INTERFACE FUNCTIONS | |
| # ============================================================================ | |
def format_question_display(question_data: Dict, dataset: str) -> str:
    """Render one question record as HTML using the formatter for its dataset.

    A missing or empty record produces a short "not found" message. Records
    from ``'medgemini'`` go through the Med-Gemini formatter; every other
    dataset name goes through the MedQA formatter.
    """
    if not question_data:
        return "β Question not found"

    # Pick the renderer that matches the record layout of the dataset.
    renderer = format_medgemini_question if dataset == 'medgemini' else format_medqa_question
    return renderer(question_data)
def format_medgemini_question(q: Dict) -> str:
    """Render a Med-Gemini question record as an HTML fragment.

    Reads ``'question'``, ``'options'``, ``'answer_idx'`` and
    ``'explanation'`` (or ``'Explanation'``) from ``q``. Options are shown
    for labels A-E, looked up under ``'op<letter>'`` first and the bare
    letter second, and the option whose label equals ``answer_idx`` is
    highlighted. All record text is HTML-escaped so characters such as
    ``<`` or ``&`` in the data cannot break the markup.

    Args:
        q: The question record.

    Returns:
        The HTML fragment as a string.
    """
    # Imported locally under a distinct name so the fragment builder below
    # never collides with the ``html`` module name.
    from html import escape

    def safe(value) -> str:
        """Escape any value for use as HTML text."""
        return escape(str(value))

    options = q.get('options')
    if not isinstance(options, dict):
        # A null or otherwise malformed value is treated as "no options".
        options = {}
    correct_answer = q.get('answer_idx', 'N/A')

    parts = [f"""
    <div style="background: linear-gradient(135deg, #667eea 0%, #764ba2 100%); padding: 20px; border-radius: 10px; margin-bottom: 20px;">
    <h2 style="color: white; margin: 0;">π¬ Med-Gemini Question</h2>
    </div>
    <div style="background: #f8f9fa; padding: 20px; border-radius: 8px; margin-bottom: 20px;">
    <h3>π Question</h3>
    <p style="font-size: 16px; line-height: 1.6;">{safe(q.get('question', 'N/A'))}</p>
    </div>
    <div style="background: #fff; padding: 20px; border-radius: 8px; margin-bottom: 20px; border: 2px solid #e0e0e0;">
    <h3>π€ Answer Options</h3>
    """]

    # Display options
    for label in ['A', 'B', 'C', 'D', 'E']:
        # Prefer the 'op<letter>' key; fall back to the bare letter.
        option_key = next(
            (key for key in (f'op{label.lower()}', label) if key in options),
            None,
        )
        if option_key is None:
            continue
        is_correct = (label == correct_answer)
        color = '#d4edda' if is_correct else '#fff'
        icon = 'β ' if is_correct else 'β'
        parts.append(f"""
    <div style="background: {color}; padding: 12px; margin: 8px 0; border-radius: 5px; border: 1px solid #ccc;">
    {icon} <strong>{label}.</strong> {safe(options[option_key])}
    </div>
    """)
    parts.append("</div>")

    # Show correct answer
    parts.append(f"""
    <div style="background: #d4edda; padding: 15px; border-radius: 8px; margin-bottom: 20px; border-left: 4px solid #28a745;">
    <h3 style="margin-top: 0;">β Correct Answer</h3>
    <p style="font-size: 18px; font-weight: bold; margin: 0;">{safe(correct_answer)}</p>
    </div>
    """)

    # Show explanation if available
    explanation = q.get('explanation', q.get('Explanation', ''))
    if explanation:
        parts.append(f"""
    <div style="background: #e7f3ff; padding: 20px; border-radius: 8px; border-left: 4px solid #2196F3;">
    <h3 style="margin-top: 0;">π‘ Explanation</h3>
    <p style="line-height: 1.6;">{safe(explanation)}</p>
    </div>
    """)
    return "".join(parts)
def format_medqa_question(q: Dict) -> str:
    """Render an original MedQA question record as an HTML fragment.

    Reads ``'question'``, ``'options'``, ``'answer_idx'`` and
    ``'metamap_phrases'`` from ``q``. Every entry of ``options`` is shown in
    dict order; a key starting with ``'op'`` is displayed as the upper-cased
    remainder, any other key is displayed unchanged. The option whose label
    equals ``answer_idx`` is highlighted. All record text is HTML-escaped so
    characters such as ``<`` or ``&`` in the data cannot break the markup.

    Args:
        q: The question record.

    Returns:
        The HTML fragment as a string.
    """
    # Imported locally under a distinct name so the fragment builder below
    # never collides with the ``html`` module name.
    from html import escape

    def safe(value) -> str:
        """Escape any value for use as HTML text."""
        return escape(str(value))

    options = q.get('options')
    if not isinstance(options, dict):
        # A null or otherwise malformed value is treated as "no options".
        options = {}
    correct_answer = q.get('answer_idx', 'N/A')

    parts = [f"""
    <div style="background: linear-gradient(135deg, #f093fb 0%, #f5576c 100%); padding: 20px; border-radius: 10px; margin-bottom: 20px;">
    <h2 style="color: white; margin: 0;">π MedQA USMLE Question</h2>
    </div>
    <div style="background: #f8f9fa; padding: 20px; border-radius: 8px; margin-bottom: 20px;">
    <h3>π Question</h3>
    <p style="font-size: 16px; line-height: 1.6;">{safe(q.get('question', 'N/A'))}</p>
    </div>
    <div style="background: #fff; padding: 20px; border-radius: 8px; margin-bottom: 20px; border: 2px solid #e0e0e0;">
    <h3>π€ Answer Options</h3>
    """]

    # Display options
    for key, value in options.items():
        key_text = str(key)
        # Strip only the leading 'op' marker, not every "op" inside the key.
        label = key_text[2:].upper() if key_text.startswith('op') else key_text
        is_correct = (label == correct_answer)
        color = '#d4edda' if is_correct else '#fff'
        icon = 'β ' if is_correct else 'β'
        parts.append(f"""
    <div style="background: {color}; padding: 12px; margin: 8px 0; border-radius: 5px; border: 1px solid #ccc;">
    {icon} <strong>{safe(label)}.</strong> {safe(value)}
    </div>
    """)
    parts.append("</div>")

    # Show correct answer
    parts.append(f"""
    <div style="background: #d4edda; padding: 15px; border-radius: 8px; margin-bottom: 20px; border-left: 4px solid #28a745;">
    <h3 style="margin-top: 0;">β Correct Answer</h3>
    <p style="font-size: 18px; font-weight: bold; margin: 0;">{safe(correct_answer)}</p>
    </div>
    """)

    # Show metamap if available
    metamap = q.get('metamap_phrases')
    if metamap:
        # A bare string is shown as-is; joining it would split it into
        # single characters. Other iterables are joined item by item.
        if isinstance(metamap, str):
            concepts = metamap
        else:
            concepts = ', '.join(str(phrase) for phrase in metamap)
        parts.append(f"""
    <div style="background: #fff3cd; padding: 15px; border-radius: 8px; border-left: 4px solid #ffc107;">
    <h3 style="margin-top: 0;">π₯ Medical Concepts (MetaMap)</h3>
    <p style="line-height: 1.6;">{safe(concepts)}</p>
    </div>
    """)
    return "".join(parts)
def browse_questions(dataset: str, index: int) -> Tuple[str, str]:
    """Return the rendered question at ``index`` plus a status line.

    An index outside the dataset is pulled back to the nearest valid
    position. An empty or unknown dataset yields a placeholder message
    instead of a question.
    """
    total = len(db.data.get(dataset, []))
    if total == 0:
        return "β No questions in this dataset", f"Dataset: {dataset} (empty)"

    # Clamp from above first, then from below.
    index = min(index, total - 1)
    index = max(0, index)

    rendered = format_question_display(db.get_question(dataset, index), dataset)
    status = f"π Question {index + 1} of {total} | Dataset: {dataset}"
    return rendered, status
def random_question(dataset: str) -> Tuple[str, str, int]:
    """Pick a uniformly random question from ``dataset``.

    Returns the rendered HTML, a status line, and the chosen zero-based
    index. For an empty or unknown dataset the index is 0 and the HTML is a
    placeholder message.
    """
    total = len(db.data.get(dataset, []))
    if total == 0:
        return "β No questions in this dataset", f"Dataset: {dataset} (empty)", 0

    picked = random.randrange(total)
    rendered = format_question_display(db.get_question(dataset, picked), dataset)
    status = f"π² Random Question {picked + 1} of {total} | Dataset: {dataset}"
    return rendered, status, picked
def search_interface(query: str, dataset: str) -> str:
    """Run a keyword search and render the hits as an HTML fragment.

    A missing or blank query returns a hint without searching. Otherwise the
    module-level ``db`` is searched and at most the first 20 hits are listed,
    followed by a count of the remaining ones. The query, dataset name and
    previews are HTML-escaped because the query is typed by the user and the
    previews come straight from the dataset text.

    Args:
        query: Text typed into the search box.
        dataset: Dataset name to search, or ``'all'``.

    Returns:
        The HTML fragment (or a one-line message) as a string.
    """
    # Imported locally under a distinct name so the fragment builder below
    # never collides with the ``html`` module name.
    from html import escape

    if not query or not query.strip():
        return "π‘ Enter a search query to find questions"

    results = db.search_questions(query, dataset)
    safe_query = escape(query)
    safe_dataset = escape(str(dataset))
    if not results:
        return f"β No results found for '{safe_query}' in {safe_dataset}"

    parts = [f"""
    <div style="background: linear-gradient(135deg, #667eea 0%, #764ba2 100%); padding: 20px; border-radius: 10px; margin-bottom: 20px;">
    <h2 style="color: white; margin: 0;">π Search Results: "{safe_query}"</h2>
    <p style="color: white; margin: 5px 0 0 0;">Found {len(results)} results in {safe_dataset}</p>
    </div>
    """]
    for ds, idx, preview in results[:20]:  # Show top 20
        dataset_name = ds.replace('_', ' ').title()
        parts.append(f"""
    <div style="background: #fff; padding: 15px; margin: 10px 0; border-radius: 8px; border-left: 4px solid #667eea;">
    <p style="margin: 0; color: #666; font-size: 12px;"><strong>{escape(dataset_name)}</strong> - Question #{idx + 1}</p>
    <p style="margin: 5px 0 0 0;">{escape(preview)}</p>
    </div>
    """)
    if len(results) > 20:
        parts.append(f"<p>... and {len(results) - 20} more results</p>")
    return "".join(parts)
| # ============================================================================ | |
| # GRADIO APP | |
| # ============================================================================ | |
# Build the whole UI: a statistics accordion, a "Browse" tab (dataset picker,
# index slider, prev/random/next buttons) and a "Search" tab. Slider maxima
# are floored at 0 so an empty dataset never produces maximum < minimum.
with gr.Blocks(theme=gr.themes.Soft(), title="MedQA Database Explorer") as app:
    gr.Markdown("""
    # π₯ MedQA Database Explorer
    Explore medical question-answering databases including **Med-Gemini** and **MedQA USMLE**.
    """)

    # Statistics
    with gr.Accordion("π Database Statistics", open=False):
        gr.Markdown(db.get_stats())

    # Main interface
    with gr.Tabs():
        # Browse Tab
        with gr.Tab("π Browse Questions"):
            with gr.Row():
                with gr.Column(scale=1):
                    dataset_dropdown = gr.Dropdown(
                        choices=['medgemini', 'medqa_train', 'medqa_dev', 'medqa_test'],
                        value='medgemini',
                        label="Select Database"
                    )
                    question_slider = gr.Slider(
                        minimum=0,
                        # Floor at 0: an empty dataset would otherwise give -1.
                        maximum=max(0, len(db.data['medgemini']) - 1),
                        value=0,
                        step=1,
                        label="Question Number"
                    )
                    with gr.Row():
                        prev_btn = gr.Button("β¬ οΈ Previous", size="sm")
                        random_btn = gr.Button("π² Random", size="sm", variant="primary")
                        next_btn = gr.Button("Next β‘οΈ", size="sm")
                    info_text = gr.Textbox(label="Info", interactive=False)
                with gr.Column(scale=2):
                    question_display = gr.HTML()

            # Update slider max when dataset changes
            def update_slider(dataset):
                """Resize the slider to the chosen dataset and reset it to 0."""
                max_val = max(0, len(db.data.get(dataset, [])) - 1)
                return gr.Slider(maximum=max_val, value=0)

            dataset_dropdown.change(
                fn=update_slider,
                inputs=[dataset_dropdown],
                outputs=[question_slider]
            )

            # Browse functions
            def show_question(dataset, index):
                """Render the question at the slider position."""
                # The slider value is cast to int before it is used as an index.
                return browse_questions(dataset, int(index))

            question_slider.change(
                fn=show_question,
                inputs=[dataset_dropdown, question_slider],
                outputs=[question_display, info_text]
            )
            dataset_dropdown.change(
                fn=show_question,
                inputs=[dataset_dropdown, question_slider],
                outputs=[question_display, info_text]
            )

            # Navigation buttons
            def prev_question(dataset, index):
                """Step one question back, stopping at the first one."""
                new_index = max(0, int(index) - 1)
                html, info = browse_questions(dataset, new_index)
                return html, info, new_index

            def next_question(dataset, index):
                """Step one question forward, stopping at the last one."""
                max_idx = max(0, len(db.data.get(dataset, [])) - 1)
                new_index = min(max_idx, int(index) + 1)
                html, info = browse_questions(dataset, new_index)
                return html, info, new_index

            prev_btn.click(
                fn=prev_question,
                inputs=[dataset_dropdown, question_slider],
                outputs=[question_display, info_text, question_slider]
            )
            next_btn.click(
                fn=next_question,
                inputs=[dataset_dropdown, question_slider],
                outputs=[question_display, info_text, question_slider]
            )
            random_btn.click(
                fn=random_question,
                inputs=[dataset_dropdown],
                outputs=[question_display, info_text, question_slider]
            )

            # Load first question on start
            app.load(
                fn=show_question,
                inputs=[dataset_dropdown, question_slider],
                outputs=[question_display, info_text]
            )

        # Search Tab
        with gr.Tab("π Search"):
            with gr.Row():
                search_query = gr.Textbox(
                    label="Search Query",
                    placeholder="Enter keywords (e.g., 'diabetes', 'heart failure', 'treatment')...",
                    scale=3
                )
                search_dataset = gr.Dropdown(
                    choices=['all', 'medgemini', 'medqa_train', 'medqa_dev', 'medqa_test'],
                    value='all',
                    label="Search In",
                    scale=1
                )
            search_btn = gr.Button("π Search", variant="primary")
            search_results = gr.HTML()

            search_btn.click(
                fn=search_interface,
                inputs=[search_query, search_dataset],
                outputs=[search_results]
            )
            # Also search on Enter key
            search_query.submit(
                fn=search_interface,
                inputs=[search_query, search_dataset],
                outputs=[search_results]
            )

    gr.Markdown("""
    ---
    ### π About the Databases
    **Med-Gemini**: Expert-relabeled medical questions with detailed explanations from Google's Med-Gemini project.
    **MedQA**: Original USMLE-style medical questions from the MedQA dataset.
    ### π Sources
    - [Med-Gemini Paper](https://arxiv.org/abs/2404.18416)
    - [MedQA Dataset](https://github.com/jind11/MedQA)
    """)
# Start the Gradio server only when this file is executed as a script.
if __name__ == "__main__":
    app.launch()