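"""Gradio dashboard for evaluating chat models on the Masakhane AfriMMLU benchmark.

Loads a language split of masakhane/afrimmlu, queries the selected model through
litellm, stores per-question and per-subject results in SQLite, and visualizes
accuracy with Plotly.
"""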
import ast
import pandas as pd
import gradio as gr
import litellm
import plotly.express as px
from collections import defaultdict
from datetime import datetime
import os
from datasets import load_dataset
import sqlite3
from typing import Dict, List, Optional, Tuple
from dataclasses import dataclass
from pathlib import Path
import logging
from plotly.graph_objects import Figure
# Configure logging
logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(levelname)s - %(message)s')
logger = logging.getLogger(__name__)
@dataclass
class EvaluationResult:
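    """Results of one evaluation run: overall accuracy, per-subject accuracy, and per-question records."""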
    accuracy: float
    subject_accuracy: Dict[str, float]
    detailed_results: List[Dict]


class DatabaseManager:
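    """SQLite-backed storage for summary and detailed evaluation results."""
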
    def __init__(self, db_path: str = 'afrimmlu_results.db'):
        self.db_path = db_path
        self._initialize_database()

    def _initialize_database(self) -> None:
        """Initialize SQLite database with required tables."""
        try:
            with sqlite3.connect(self.db_path) as conn:
                cursor = conn.cursor()
                cursor.execute('''
                    CREATE TABLE IF NOT EXISTS summary_results (
                        id INTEGER PRIMARY KEY AUTOINCREMENT,
                        language TEXT NOT NULL,
                        subject TEXT NOT NULL,
                        accuracy REAL NOT NULL,
                        timestamp TEXT NOT NULL
                    )
                ''')
                cursor.execute('''
                    CREATE TABLE IF NOT EXISTS detailed_results (
                        id INTEGER PRIMARY KEY AUTOINCREMENT,
                        language TEXT NOT NULL,
                        timestamp TEXT NOT NULL,
                        subject TEXT NOT NULL,
                        question TEXT NOT NULL,
                        model_answer TEXT,
                        correct_answer TEXT NOT NULL,
                        is_correct INTEGER NOT NULL,
                        total_tokens INTEGER
                    )
                ''')
                conn.commit()
        except sqlite3.Error as e:
            logger.error(f"Database initialization failed: {str(e)}")
            raise

    def save_results(self, language: str, summary_results: Dict[str, float],
                     detailed_results: List[Dict]) -> None:
        """Save evaluation results to database."""
        try:
            with sqlite3.connect(self.db_path) as conn:
                cursor = conn.cursor()
                timestamp = datetime.now().isoformat()
                # Save summary results
                cursor.executemany('''
                    INSERT INTO summary_results (language, subject, accuracy, timestamp)
                    VALUES (?, ?, ?, ?)
                ''', [(language, subject, accuracy, timestamp)
                      for subject, accuracy in summary_results.items()])
                # Save detailed results
                cursor.executemany('''
                    INSERT INTO detailed_results (
                        language, timestamp, subject, question, model_answer,
                        correct_answer, is_correct, total_tokens
                    ) VALUES (?, ?, ?, ?, ?, ?, ?, ?)
                ''', [(language, result['timestamp'], result['subject'], result['question'],
                       result['model_answer'], result['correct_answer'],
                       int(result['is_correct']), result['total_tokens'])
                      for result in detailed_results])
                conn.commit()
        except sqlite3.Error as e:
            logger.error(f"Failed to save results to database: {str(e)}")
            raise

    def query(self, query: str) -> pd.DataFrame:
        """Execute SQL query and return results as DataFrame."""
        try:
            with sqlite3.connect(self.db_path) as conn:
                return pd.read_sql_query(query, conn)
        except sqlite3.Error as e:
            logger.error(f"Query execution failed: {str(e)}")
            return pd.DataFrame({'Error': [str(e)]})


class AfriMMLUEvaluator:
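    """Runs the AfriMMLU multiple-choice evaluation against a model served via litellm."""
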
    def __init__(self, model_name: str = "deepseek/deepseek-chat"):
        self.model_name = model_name
        self.db_manager = DatabaseManager()

    def load_data(self, language_code: str = "swa") -> Optional[List[Dict]]:
        """Load AfriMMLU dataset for specified language."""
        try:
            dataset = load_dataset(
                'masakhane/afrimmlu',
                language_code,
                token=os.getenv('HF_TOKEN')
            )
            return dataset['test'].to_list()
        except Exception as e:
            logger.error(f"Failed to load dataset for {language_code}: {str(e)}")
            return None

    @staticmethod
    def preprocess_data(test_data: List[Dict]) -> List[Dict]:
        """Preprocess dataset to convert choices field to list."""
        preprocessed_data = []
        for example in test_data:
            try:
                if isinstance(example['choices'], str):
                    choices_str = example['choices'].strip("'\"").replace("\\'", "'")
                    example['choices'] = ast.literal_eval(choices_str)
                preprocessed_data.append(example)
            except (ValueError, SyntaxError):
                logger.warning(f"Skipping invalid choices: {example['choices']}")
                continue
        return preprocessed_data

    def evaluate(self, test_data: List[Dict], language: str) -> EvaluationResult:
        """Evaluate model on AfriMMLU dataset."""
        results = []
        correct = 0
        total = 0
        subject_results = defaultdict(lambda: {"correct": 0, "total": 0})
        for example in test_data:
            try:
                prompt = self._create_prompt(example)
                response = litellm.completion(
                    model=self.model_name,
                    messages=[{"role": "user", "content": prompt}]
                )
                model_answer = self._parse_model_answer(response.choices[0].message.content)
                is_correct = model_answer == example['answer'].upper()
                if is_correct:
                    correct += 1
                    subject_results[example['subject']]["correct"] += 1
                total += 1
                subject_results[example['subject']]["total"] += 1
                results.append({
                    'timestamp': datetime.now().isoformat(),
                    'subject': example['subject'],
                    'question': example['question'],
                    'model_answer': model_answer,
                    'correct_answer': example['answer'].upper(),
                    'is_correct': is_correct,
                    'total_tokens': response.usage.total_tokens
                })
            except Exception as e:
                logger.warning(f"Error processing question: {str(e)}")
                continue
        accuracy = (correct / total * 100) if total > 0 else 0
        subject_accuracy = {
            subject: (stats["correct"] / stats["total"] * 100) if stats["total"] > 0 else 0
            for subject, stats in subject_results.items()
        }
        self.db_manager.save_results(language, {**subject_accuracy, 'Overall': accuracy}, results)
        return EvaluationResult(
            accuracy=accuracy,
            subject_accuracy=subject_accuracy,
            detailed_results=results
        )

    @staticmethod
    def _create_prompt(example: Dict) -> str:
        """Create formatted prompt for model evaluation."""
        return (
            f"Answer the following multiple-choice question. "
            f"Return only the letter corresponding to the correct answer (A, B, C, or D).\n"
            f"Question: {example['question']}\n"
            f"Options:\n"
            f"A. {example['choices'][0]}\n"
            f"B. {example['choices'][1]}\n"
            f"C. {example['choices'][2]}\n"
            f"D. {example['choices'][3]}\n"
            f"Answer:"
        )

    @staticmethod
    def _parse_model_answer(output: str) -> Optional[str]:
        """Parse model output to extract answer letter."""
        output = output.strip().upper()
        for char in output:
            if char in ['A', 'B', 'C', 'D']:
                return char
        return None


class VisualizationManager:
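    """Builds the summary table and Plotly bar chart from an EvaluationResult."""
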
    @staticmethod
    def create_visualization(results: EvaluationResult) -> Tuple[pd.DataFrame, Figure]:
        """Create visualization from evaluation results."""
        summary_data = [
            {'Subject': subject, 'Accuracy (%)': accuracy}
            for subject, accuracy in results.subject_accuracy.items()
        ]
        summary_data.append({'Subject': 'Overall', 'Accuracy (%)': results.accuracy})
        summary_df = pd.DataFrame(summary_data)
        fig = px.bar(
            summary_df,
            x='Subject',
            y='Accuracy (%)',
            title='AfriMMLU Evaluation Results',
            labels={'Subject': 'Subject', 'Accuracy (%)': 'Accuracy (%)'},
            template='plotly_white'
        )
        fig.update_layout(
            xaxis_tickangle=-45,
            showlegend=False,
            height=600,
            margin=dict(b=200)
        )
        return summary_df, fig


def create_gradio_interface() -> gr.Blocks:
"""Create Gradio interface for AfriMMLU evaluation."""
evaluator = AfriMMLUEvaluator()
vis_manager = VisualizationManager()
language_options = {
"swa": "Swahili",
"yor": "Yoruba",
"wol": "Wolof",
"lin": "Lingala",
"ewe": "Ewe",
"ibo": "Igbo"
}
with gr.Blocks(theme=gr.themes.Soft()) as demo:
gr.Markdown("# AfriMMLU Evaluation Dashboard")
with gr.Tabs():
with gr.Tab("Model Evaluation"):
with gr.Row():
language_input = gr.Dropdown(
choices=list(language_options.keys()),
label="Select Language",
value="swa"
)
model_input = gr.Dropdown(
choices=["deepseek/deepseek-chat"],
label="Select Model",
value="deepseek/deepseek-chat"
)
evaluate_btn = gr.Button("Evaluate", variant="primary")
summary_table = gr.Dataframe(label="Summary Results")
summary_plot = gr.Plot(label="Performance by Subject")
detailed_results = gr.Dataframe(label="Detailed Results", wrap=True)
with gr.Tab("Database Analysis"):
example_queries = gr.Dropdown(
choices=[
"SELECT language, AVG(accuracy) as avg_accuracy FROM summary_results WHERE subject='Overall' GROUP BY language",
"SELECT subject, AVG(accuracy) as avg_accuracy FROM summary_results GROUP BY subject",
"SELECT language, subject, accuracy, timestamp FROM summary_results ORDER BY timestamp DESC LIMIT 10",
"SELECT language, COUNT(*) as total_questions, SUM(is_correct) as correct_answers FROM detailed_results GROUP BY language",
"SELECT subject, COUNT(*) as total_evaluations FROM summary_results GROUP BY subject"
],
label="Example Queries"
)
query_input = gr.Textbox(
label="SQL Query",
placeholder="Enter your SQL query here",
lines=3
)
query_button = gr.Button("Run Query", variant="primary")
query_output = gr.Dataframe(label="Query Results", wrap=True)
gr.Markdown("""
### Available Tables:
- summary_results (id, language, subject, accuracy, timestamp)
- detailed_results (id, language, timestamp, subject, question, model_answer, correct_answer, is_correct, total_tokens)
""")
def evaluate_language(language_code: str, model_name: str):
evaluator.model_name = model_name
test_data = evaluator.load_data(language_code)
if not test_data:
return None, None, None
preprocessed_data = evaluator.preprocess_data(test_data)
results = evaluator.evaluate(preprocessed_data, language_code)
summary_df, plot = vis_manager.create_visualization(results)
detailed_df = pd.DataFrame(results.detailed_results)
return summary_df, plot, detailed_df
evaluate_btn.click(
fn=evaluate_language,
inputs=[language_input, model_input],
outputs=[summary_table, summary_plot, detailed_results]
)
example_queries.change(
fn=lambda x: x,
inputs=[example_queries],
outputs=[query_input]
)
query_button.click(
fn=evaluator.db_manager.query,
inputs=[query_input],
outputs=[query_output]
)
return demo
if __name__ == "__main__":
    try:
        # Validate environment variables
        required_env_vars = ['DEEPSEEK_API_KEY', 'HF_TOKEN']
        for var in required_env_vars:
            if not os.getenv(var):
                raise EnvironmentError(f"Missing required environment variable: {var}")
        demo = create_gradio_interface()
        demo.launch(share=True)
    except Exception as e:
        logger.error(f"Application failed to start: {str(e)}")
        raise