imurra commited on
Commit
9f98759
Β·
verified Β·
1 Parent(s): 90c275f

You need to use YOUR ORIGINAL app.py that works with ChromaDB!

Browse files
Files changed (1) hide show
  1. app.py +189 -436
app.py CHANGED
@@ -1,436 +1,189 @@
1
- import gradio as gr
2
- import json
3
- import zipfile
4
- from pathlib import Path
5
- import pandas as pd
6
- from typing import Dict, List, Tuple
7
- import random
8
-
9
- class MedQADatabase:
10
- """Handler for MedQA and Med-Gemini databases"""
11
-
12
- def __init__(self, zip_path="medqa_databases.zip"):
13
- self.data = {
14
- 'medgemini': [],
15
- 'medqa_train': [],
16
- 'medqa_dev': [],
17
- 'medqa_test': []
18
- }
19
- self.load_databases(zip_path)
20
-
21
- def load_databases(self, zip_path):
22
- """Load all databases from the ZIP file"""
23
- print("πŸ“¦ Loading databases from ZIP...")
24
-
25
- try:
26
- with zipfile.ZipFile(zip_path, 'r') as zip_ref:
27
- # Extract to temporary directory
28
- zip_ref.extractall('temp_data')
29
-
30
- # Load Med-Gemini
31
- medgemini_path = Path('temp_data/medqa_databases/med_gemini/medqa_relabelling.json')
32
- if medgemini_path.exists():
33
- with open(medgemini_path, 'r', encoding='utf-8') as f:
34
- self.data['medgemini'] = json.load(f)
35
- print(f"βœ… Loaded {len(self.data['medgemini'])} Med-Gemini questions")
36
-
37
- # Load MedQA splits
38
- medqa_base = Path('temp_data/medqa_databases/medqa_original')
39
- for split in ['train', 'dev', 'test']:
40
- split_path = medqa_base / f"{split}.json"
41
- if split_path.exists():
42
- with open(split_path, 'r', encoding='utf-8') as f:
43
- self.data[f'medqa_{split}'] = json.load(f)
44
- print(f"βœ… Loaded {len(self.data[f'medqa_{split}'])} MedQA {split} questions")
45
-
46
- except Exception as e:
47
- print(f"❌ Error loading databases: {e}")
48
- raise
49
-
50
- def get_stats(self) -> str:
51
- """Get database statistics"""
52
- stats = "## πŸ“Š Database Statistics\n\n"
53
- stats += f"**Med-Gemini**: {len(self.data['medgemini']):,} questions\n\n"
54
- stats += f"**MedQA Original**:\n"
55
- stats += f"- Training: {len(self.data['medqa_train']):,} questions\n"
56
- stats += f"- Development: {len(self.data['medqa_dev']):,} questions\n"
57
- stats += f"- Test: {len(self.data['medqa_test']):,} questions\n"
58
- stats += f"- **Total**: {sum(len(self.data[f'medqa_{s}']) for s in ['train', 'dev', 'test']):,} questions\n\n"
59
- stats += f"**Grand Total**: {sum(len(v) for v in self.data.values()):,} questions"
60
- return stats
61
-
62
- def get_question(self, dataset: str, index: int) -> Dict:
63
- """Get a specific question from a dataset"""
64
- try:
65
- return self.data[dataset][index]
66
- except (KeyError, IndexError):
67
- return None
68
-
69
- def search_questions(self, query: str, dataset: str = 'all', max_results: int = 50) -> List[Tuple[str, int, str]]:
70
- """Search questions by keyword"""
71
- results = []
72
- query_lower = query.lower()
73
-
74
- datasets_to_search = list(self.data.keys()) if dataset == 'all' else [dataset]
75
-
76
- for ds in datasets_to_search:
77
- for idx, q in enumerate(self.data[ds]):
78
- # Search in question text
79
- question_text = q.get('question', q.get('Question', ''))
80
- if query_lower in question_text.lower():
81
- preview = question_text[:100] + "..." if len(question_text) > 100 else question_text
82
- results.append((ds, idx, preview))
83
-
84
- if len(results) >= max_results:
85
- return results
86
-
87
- return results
88
-
89
- # Initialize database
90
- print("πŸš€ Initializing MedQA Explorer...")
91
- db = MedQADatabase()
92
-
93
- # ============================================================================
94
- # GRADIO INTERFACE FUNCTIONS
95
- # ============================================================================
96
-
97
- def format_question_display(question_data: Dict, dataset: str) -> str:
98
- """Format question data for display"""
99
-
100
- if not question_data:
101
- return "❌ Question not found"
102
-
103
- # Handle different data formats
104
- if dataset == 'medgemini':
105
- return format_medgemini_question(question_data)
106
- else:
107
- return format_medqa_question(question_data)
108
-
109
- def format_medgemini_question(q: Dict) -> str:
110
- """Format Med-Gemini question"""
111
- html = f"""
112
- <div style="background: linear-gradient(135deg, #667eea 0%, #764ba2 100%); padding: 20px; border-radius: 10px; margin-bottom: 20px;">
113
- <h2 style="color: white; margin: 0;">πŸ”¬ Med-Gemini Question</h2>
114
- </div>
115
-
116
- <div style="background: #f8f9fa; padding: 20px; border-radius: 8px; margin-bottom: 20px;">
117
- <h3>πŸ“‹ Question</h3>
118
- <p style="font-size: 16px; line-height: 1.6;">{q.get('question', 'N/A')}</p>
119
- </div>
120
-
121
- <div style="background: #fff; padding: 20px; border-radius: 8px; margin-bottom: 20px; border: 2px solid #e0e0e0;">
122
- <h3>πŸ”€ Answer Options</h3>
123
- """
124
-
125
- # Display options
126
- options = q.get('options', {})
127
- correct_answer = q.get('answer_idx', 'N/A')
128
-
129
- option_labels = ['A', 'B', 'C', 'D', 'E']
130
- for label in option_labels:
131
- option_key = f'opa' if label == 'A' else f'op{label.lower()}'
132
- if option_key in options:
133
- is_correct = (label == correct_answer)
134
- color = '#d4edda' if is_correct else '#fff'
135
- icon = 'βœ…' if is_correct else 'β­•'
136
-
137
- html += f"""
138
- <div style="background: {color}; padding: 12px; margin: 8px 0; border-radius: 5px; border: 1px solid #ccc;">
139
- {icon} <strong>{label}.</strong> {options[option_key]}
140
- </div>
141
- """
142
-
143
- html += "</div>"
144
-
145
- # Show correct answer
146
- html += f"""
147
- <div style="background: #d4edda; padding: 15px; border-radius: 8px; margin-bottom: 20px; border-left: 4px solid #28a745;">
148
- <h3 style="margin-top: 0;">βœ… Correct Answer</h3>
149
- <p style="font-size: 18px; font-weight: bold; margin: 0;">{correct_answer}</p>
150
- </div>
151
- """
152
-
153
- # Show explanation if available
154
- explanation = q.get('explanation', q.get('Explanation', ''))
155
- if explanation:
156
- html += f"""
157
- <div style="background: #e7f3ff; padding: 20px; border-radius: 8px; border-left: 4px solid #2196F3;">
158
- <h3 style="margin-top: 0;">πŸ’‘ Explanation</h3>
159
- <p style="line-height: 1.6;">{explanation}</p>
160
- </div>
161
- """
162
-
163
- return html
164
-
165
- def format_medqa_question(q: Dict) -> str:
166
- """Format MedQA original question"""
167
- html = f"""
168
- <div style="background: linear-gradient(135deg, #f093fb 0%, #f5576c 100%); padding: 20px; border-radius: 10px; margin-bottom: 20px;">
169
- <h2 style="color: white; margin: 0;">πŸ“š MedQA USMLE Question</h2>
170
- </div>
171
-
172
- <div style="background: #f8f9fa; padding: 20px; border-radius: 8px; margin-bottom: 20px;">
173
- <h3>πŸ“‹ Question</h3>
174
- <p style="font-size: 16px; line-height: 1.6;">{q.get('question', 'N/A')}</p>
175
- </div>
176
-
177
- <div style="background: #fff; padding: 20px; border-radius: 8px; margin-bottom: 20px; border: 2px solid #e0e0e0;">
178
- <h3>πŸ”€ Answer Options</h3>
179
- """
180
-
181
- # Display options
182
- options = q.get('options', {})
183
- correct_answer = q.get('answer_idx', 'N/A')
184
-
185
- for key, value in options.items():
186
- label = key.replace('op', '').upper() if key.startswith('op') else key
187
- is_correct = (label == correct_answer)
188
- color = '#d4edda' if is_correct else '#fff'
189
- icon = 'βœ…' if is_correct else 'β­•'
190
-
191
- html += f"""
192
- <div style="background: {color}; padding: 12px; margin: 8px 0; border-radius: 5px; border: 1px solid #ccc;">
193
- {icon} <strong>{label}.</strong> {value}
194
- </div>
195
- """
196
-
197
- html += "</div>"
198
-
199
- # Show correct answer
200
- html += f"""
201
- <div style="background: #d4edda; padding: 15px; border-radius: 8px; margin-bottom: 20px; border-left: 4px solid #28a745;">
202
- <h3 style="margin-top: 0;">βœ… Correct Answer</h3>
203
- <p style="font-size: 18px; font-weight: bold; margin: 0;">{correct_answer}</p>
204
- </div>
205
- """
206
-
207
- # Show metamap if available
208
- metamap = q.get('metamap_phrases')
209
- if metamap:
210
- html += f"""
211
- <div style="background: #fff3cd; padding: 15px; border-radius: 8px; border-left: 4px solid #ffc107;">
212
- <h3 style="margin-top: 0;">πŸ₯ Medical Concepts (MetaMap)</h3>
213
- <p style="line-height: 1.6;">{', '.join(metamap)}</p>
214
- </div>
215
- """
216
-
217
- return html
218
-
219
- def browse_questions(dataset: str, index: int) -> Tuple[str, str]:
220
- """Browse questions by index"""
221
- total = len(db.data.get(dataset, []))
222
-
223
- if total == 0:
224
- return "❌ No questions in this dataset", f"Dataset: {dataset} (empty)"
225
-
226
- # Clamp index to valid range
227
- index = max(0, min(index, total - 1))
228
-
229
- question = db.get_question(dataset, index)
230
- html = format_question_display(question, dataset)
231
- info = f"πŸ“Š Question {index + 1} of {total} | Dataset: {dataset}"
232
-
233
- return html, info
234
-
235
- def random_question(dataset: str) -> Tuple[str, str, int]:
236
- """Get a random question"""
237
- total = len(db.data.get(dataset, []))
238
-
239
- if total == 0:
240
- return "❌ No questions in this dataset", f"Dataset: {dataset} (empty)", 0
241
-
242
- index = random.randint(0, total - 1)
243
- question = db.get_question(dataset, index)
244
- html = format_question_display(question, dataset)
245
- info = f"🎲 Random Question {index + 1} of {total} | Dataset: {dataset}"
246
-
247
- return html, info, index
248
-
249
- def search_interface(query: str, dataset: str) -> str:
250
- """Search interface"""
251
- if not query.strip():
252
- return "πŸ’‘ Enter a search query to find questions"
253
-
254
- results = db.search_questions(query, dataset)
255
-
256
- if not results:
257
- return f"❌ No results found for '{query}' in {dataset}"
258
-
259
- html = f"""
260
- <div style="background: linear-gradient(135deg, #667eea 0%, #764ba2 100%); padding: 20px; border-radius: 10px; margin-bottom: 20px;">
261
- <h2 style="color: white; margin: 0;">πŸ” Search Results: "{query}"</h2>
262
- <p style="color: white; margin: 5px 0 0 0;">Found {len(results)} results in {dataset}</p>
263
- </div>
264
- """
265
-
266
- for ds, idx, preview in results[:20]: # Show top 20
267
- dataset_name = ds.replace('_', ' ').title()
268
- html += f"""
269
- <div style="background: #fff; padding: 15px; margin: 10px 0; border-radius: 8px; border-left: 4px solid #667eea;">
270
- <p style="margin: 0; color: #666; font-size: 12px;"><strong>{dataset_name}</strong> - Question #{idx + 1}</p>
271
- <p style="margin: 5px 0 0 0;">{preview}</p>
272
- </div>
273
- """
274
-
275
- if len(results) > 20:
276
- html += f"<p>... and {len(results) - 20} more results</p>"
277
-
278
- return html
279
-
280
- # ============================================================================
281
- # GRADIO APP
282
- # ============================================================================
283
-
284
- with gr.Blocks(theme=gr.themes.Soft(), title="MedQA Database Explorer") as app:
285
-
286
- gr.Markdown("""
287
- # πŸ₯ MedQA Database Explorer
288
-
289
- Explore medical question-answering databases including **Med-Gemini** and **MedQA USMLE**.
290
- """)
291
-
292
- # Statistics
293
- with gr.Accordion("πŸ“Š Database Statistics", open=False):
294
- gr.Markdown(db.get_stats())
295
-
296
- # Main interface
297
- with gr.Tabs():
298
-
299
- # Browse Tab
300
- with gr.Tab("πŸ“– Browse Questions"):
301
- with gr.Row():
302
- with gr.Column(scale=1):
303
- dataset_dropdown = gr.Dropdown(
304
- choices=['medgemini', 'medqa_train', 'medqa_dev', 'medqa_test'],
305
- value='medgemini',
306
- label="Select Database"
307
- )
308
-
309
- question_slider = gr.Slider(
310
- minimum=0,
311
- maximum=len(db.data['medgemini']) - 1,
312
- value=0,
313
- step=1,
314
- label="Question Number"
315
- )
316
-
317
- with gr.Row():
318
- prev_btn = gr.Button("⬅️ Previous", size="sm")
319
- random_btn = gr.Button("🎲 Random", size="sm", variant="primary")
320
- next_btn = gr.Button("Next ➑️", size="sm")
321
-
322
- info_text = gr.Textbox(label="Info", interactive=False)
323
-
324
- with gr.Column(scale=2):
325
- question_display = gr.HTML()
326
-
327
- # Update slider max when dataset changes
328
- def update_slider(dataset):
329
- max_val = len(db.data.get(dataset, [])) - 1
330
- return gr.Slider(maximum=max_val, value=0)
331
-
332
- dataset_dropdown.change(
333
- fn=update_slider,
334
- inputs=[dataset_dropdown],
335
- outputs=[question_slider]
336
- )
337
-
338
- # Browse functions
339
- def show_question(dataset, index):
340
- return browse_questions(dataset, int(index))
341
-
342
- question_slider.change(
343
- fn=show_question,
344
- inputs=[dataset_dropdown, question_slider],
345
- outputs=[question_display, info_text]
346
- )
347
-
348
- dataset_dropdown.change(
349
- fn=show_question,
350
- inputs=[dataset_dropdown, question_slider],
351
- outputs=[question_display, info_text]
352
- )
353
-
354
- # Navigation buttons
355
- def prev_question(dataset, index):
356
- new_index = max(0, int(index) - 1)
357
- html, info = browse_questions(dataset, new_index)
358
- return html, info, new_index
359
-
360
- def next_question(dataset, index):
361
- max_idx = len(db.data.get(dataset, [])) - 1
362
- new_index = min(max_idx, int(index) + 1)
363
- html, info = browse_questions(dataset, new_index)
364
- return html, info, new_index
365
-
366
- prev_btn.click(
367
- fn=prev_question,
368
- inputs=[dataset_dropdown, question_slider],
369
- outputs=[question_display, info_text, question_slider]
370
- )
371
-
372
- next_btn.click(
373
- fn=next_question,
374
- inputs=[dataset_dropdown, question_slider],
375
- outputs=[question_display, info_text, question_slider]
376
- )
377
-
378
- random_btn.click(
379
- fn=random_question,
380
- inputs=[dataset_dropdown],
381
- outputs=[question_display, info_text, question_slider]
382
- )
383
-
384
- # Load first question on start
385
- app.load(
386
- fn=show_question,
387
- inputs=[dataset_dropdown, question_slider],
388
- outputs=[question_display, info_text]
389
- )
390
-
391
- # Search Tab
392
- with gr.Tab("πŸ” Search"):
393
- with gr.Row():
394
- search_query = gr.Textbox(
395
- label="Search Query",
396
- placeholder="Enter keywords (e.g., 'diabetes', 'heart failure', 'treatment')...",
397
- scale=3
398
- )
399
- search_dataset = gr.Dropdown(
400
- choices=['all', 'medgemini', 'medqa_train', 'medqa_dev', 'medqa_test'],
401
- value='all',
402
- label="Search In",
403
- scale=1
404
- )
405
-
406
- search_btn = gr.Button("πŸ” Search", variant="primary")
407
- search_results = gr.HTML()
408
-
409
- search_btn.click(
410
- fn=search_interface,
411
- inputs=[search_query, search_dataset],
412
- outputs=[search_results]
413
- )
414
-
415
- # Also search on Enter key
416
- search_query.submit(
417
- fn=search_interface,
418
- inputs=[search_query, search_dataset],
419
- outputs=[search_results]
420
- )
421
-
422
- gr.Markdown("""
423
- ---
424
- ### πŸ“š About the Databases
425
-
426
- **Med-Gemini**: Expert-relabeled medical questions with detailed explanations from Google's Med-Gemini project.
427
-
428
- **MedQA**: Original USMLE-style medical questions from the MedQA dataset.
429
-
430
- ### πŸ”— Sources
431
- - [Med-Gemini Paper](https://arxiv.org/abs/2404.18416)
432
- - [MedQA Dataset](https://github.com/jind11/MedQA)
433
- """)
434
-
435
- if __name__ == "__main__":
436
- app.launch()
 
1
+ import os
2
+ os.environ['ANONYMIZED_TELEMETRY'] = 'False'
3
+
4
+ import zipfile
5
+ import chromadb
6
+ from sentence_transformers import SentenceTransformer
7
+ import gradio as gr
8
+ from fastapi import FastAPI
9
+ from pydantic import BaseModel
10
+
11
+ # Extract and load database
12
+ DB_PATH = "./medqa_db"
13
+ if not os.path.exists(DB_PATH) and os.path.exists("./medqa_db.zip"):
14
+ print("πŸ“¦ Extracting database...")
15
+ with zipfile.ZipFile("./medqa_db.zip", 'r') as z:
16
+ z.extractall(".")
17
+ print("βœ… Database extracted")
18
+
19
+ print("πŸ”Œ Loading ChromaDB...")
20
+ client = chromadb.PersistentClient(path=DB_PATH)
21
+ collection = client.get_collection("medqa")
22
+ print(f"βœ… Loaded {collection.count()} questions")
23
+
24
+ print("🧠 Loading MedCPT model...")
25
+ model = SentenceTransformer('ncbi/MedCPT-Query-Encoder')
26
+ print("βœ… Model ready")
27
+
28
+ # Search function
29
+ def search(query, num_results=3, source_filter=None):
30
+ emb = model.encode(query).tolist()
31
+
32
+ # Apply source filter if specified
33
+ where_clause = None
34
+ if source_filter and source_filter != "all":
35
+ where_clause = {"source": source_filter}
36
+
37
+ return collection.query(
38
+ query_embeddings=[emb],
39
+ n_results=int(num_results),
40
+ where=where_clause
41
+ )
42
+
43
+ # Enhanced Gradio UI
44
+ def ui_search(query, num_results=3, source_filter="all"):
45
+ if not query.strip():
46
+ return "πŸ’‘ Enter a medical query to search"
47
+
48
+ try:
49
+ r = search(query, num_results, source_filter if source_filter != "all" else None)
50
+
51
+ if not r['documents'][0]:
52
+ return "❌ No results found"
53
+
54
+ out = f"πŸ” Found {len(r['documents'][0])} results\n\n"
55
+
56
+ for i in range(len(r['documents'][0])):
57
+ source = r['metadatas'][0][i].get('source', 'unknown')
58
+ distance = r['distances'][0][i]
59
+ similarity = 1 - distance
60
+
61
+ # Source emoji
62
+ if source == 'medgemini':
63
+ source_icon = "πŸ”¬"
64
+ source_name = "Med-Gemini"
65
+ elif source.startswith('medqa_'):
66
+ source_icon = "πŸ“š"
67
+ split = source.replace('medqa_', '').upper()
68
+ source_name = f"MedQA {split}"
69
+ else:
70
+ source_icon = "πŸ“„"
71
+ source_name = source.upper()
72
+
73
+ out += f"\n{'='*70}\n"
74
+ out += f"{source_icon} Result {i+1} | {source_name} | Similarity: {similarity:.3f}\n"
75
+ out += f"{'='*70}\n\n"
76
+ out += r['documents'][0][i]
77
+
78
+ # Show answer
79
+ answer = r['metadatas'][0][i].get('answer', 'N/A')
80
+ out += f"\n\nβœ… CORRECT ANSWER: {answer}\n"
81
+
82
+ # Show explanation if available (Med-Gemini)
83
+ explanation = r['metadatas'][0][i].get('explanation', '')
84
+ if explanation and explanation.strip():
85
+ out += f"\nπŸ’‘ EXPLANATION:\n{explanation}\n"
86
+
87
+ out += "\n"
88
+
89
+ return out
90
+
91
+ except Exception as e:
92
+ return f"❌ Error: {e}"
93
+
94
+ # Create Gradio interface
95
+ with gr.Blocks(theme=gr.themes.Soft(), title="MedQA Search") as demo:
96
+ gr.Markdown("""
97
+ # πŸ₯ MedQA Semantic Search
98
+
99
+ Search across **Med-Gemini** (expert explanations) and **MedQA** (USMLE questions) databases.
100
+ Uses medical-specific embeddings (MedCPT) for accurate retrieval.
101
+ """)
102
+
103
+ with gr.Row():
104
+ with gr.Column(scale=3):
105
+ query_input = gr.Textbox(
106
+ label="Medical Query",
107
+ placeholder="e.g., hyponatremia, myocardial infarction, diabetes management...",
108
+ lines=2
109
+ )
110
+ with gr.Column(scale=1):
111
+ num_results = gr.Slider(
112
+ minimum=1,
113
+ maximum=10,
114
+ value=3,
115
+ step=1,
116
+ label="Number of Results"
117
+ )
118
+
119
+ with gr.Row():
120
+ source_filter = gr.Radio(
121
+ choices=["all", "medgemini", "medqa_train", "medqa_dev", "medqa_test"],
122
+ value="all",
123
+ label="Filter by Source"
124
+ )
125
+
126
+ search_btn = gr.Button("πŸ” Search", variant="primary", size="lg")
127
+
128
+ output = gr.Textbox(
129
+ label="Search Results",
130
+ lines=25,
131
+ max_lines=50
132
+ )
133
+
134
+ search_btn.click(
135
+ fn=ui_search,
136
+ inputs=[query_input, num_results, source_filter],
137
+ outputs=output
138
+ )
139
+
140
+ query_input.submit(
141
+ fn=ui_search,
142
+ inputs=[query_input, num_results, source_filter],
143
+ outputs=output
144
+ )
145
+
146
+ gr.Markdown("""
147
+ ### πŸ“Š Database Info
148
+
149
+ **Med-Gemini**: Expert-relabeled questions with detailed explanations
150
+ **MedQA**: USMLE-style questions (Train/Dev/Test splits)
151
+
152
+ **Total Questions**: Use the database you built with `build_combined_db.py`
153
+ """)
154
+
155
+ gr.Examples(
156
+ examples=[
157
+ ["hyponatremia", 3, "all"],
158
+ ["myocardial infarction treatment", 2, "medgemini"],
159
+ ["diabetes complications", 3, "all"],
160
+ ["antibiotics for pneumonia", 2, "medqa_train"]
161
+ ],
162
+ inputs=[query_input, num_results, source_filter]
163
+ )
164
+
165
+ # FastAPI
166
+ app = FastAPI()
167
+
168
+ class SearchRequest(BaseModel):
169
+ query: str
170
+ num_results: int = 3
171
+ source_filter: str = None
172
+
173
+ @app.post("/search_medqa")
174
+ def api_search(req: SearchRequest):
175
+ r = search(req.query, req.num_results, req.source_filter)
176
+ return {"results": [{
177
+ "result_number": i+1,
178
+ "question": r['documents'][0][i],
179
+ "answer": r['metadatas'][0][i].get('answer', 'N/A'),
180
+ "source": r['metadatas'][0][i].get('source', 'unknown'),
181
+ "similarity": 1 - r['distances'][0][i]
182
+ } for i in range(len(r['documents'][0]))]}
183
+
184
+ app = gr.mount_gradio_app(app, demo, path="/")
185
+
186
+ # Launch
187
+ if __name__ == "__main__":
188
+ import uvicorn
189
+ uvicorn.run(app, host="0.0.0.0", port=7860)