import ast
import logging
import os
import re
import sqlite3
from collections import defaultdict
from dataclasses import dataclass
from datetime import datetime
from typing import Dict, List, Optional, Tuple

import gradio as gr
import litellm
import pandas as pd
import plotly.express as px
from datasets import load_dataset
from plotly.graph_objects import Figure

logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(levelname)s - %(message)s')
logger = logging.getLogger(__name__)
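
# litellm resolves provider credentials from environment variables
# (DEEPSEEK_API_KEY for deepseek/* models), so no explicit client
# configuration is needed here.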


@dataclass
class EvaluationResult:
    """Aggregated outcome of a single evaluation run."""
    accuracy: float
    subject_accuracy: Dict[str, float]
    detailed_results: List[Dict]
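
# Example shape (hypothetical values):
#   EvaluationResult(accuracy=62.5,
#                    subject_accuracy={'geography': 70.0, 'global_facts': 55.0},
#                    detailed_results=[{'subject': 'geography', 'is_correct': True, ...}])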


class DatabaseManager:
    """Thin wrapper around a local SQLite results store."""

    def __init__(self, db_path: str = 'afrimmlu_results.db'):
        self.db_path = db_path
        self._initialize_database()

    def _initialize_database(self) -> None:
        """Initialize SQLite database with required tables."""
        try:
            with sqlite3.connect(self.db_path) as conn:
                cursor = conn.cursor()

                cursor.execute('''
                    CREATE TABLE IF NOT EXISTS summary_results (
                        id INTEGER PRIMARY KEY AUTOINCREMENT,
                        language TEXT NOT NULL,
                        subject TEXT NOT NULL,
                        accuracy REAL NOT NULL,
                        timestamp TEXT NOT NULL
                    )
                ''')

                cursor.execute('''
                    CREATE TABLE IF NOT EXISTS detailed_results (
                        id INTEGER PRIMARY KEY AUTOINCREMENT,
                        language TEXT NOT NULL,
                        timestamp TEXT NOT NULL,
                        subject TEXT NOT NULL,
                        question TEXT NOT NULL,
                        model_answer TEXT,
                        correct_answer TEXT NOT NULL,
                        is_correct INTEGER NOT NULL,
                        total_tokens INTEGER
                    )
                ''')

                conn.commit()
        except sqlite3.Error as e:
            logger.error(f"Database initialization failed: {str(e)}")
            raise
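
    # Both CREATE TABLE statements use IF NOT EXISTS, so re-instantiating
    # DatabaseManager against an existing database file is harmless.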

    def save_results(self, language: str, summary_results: Dict[str, float],
                     detailed_results: List[Dict]) -> None:
        """Save evaluation results to database."""
        try:
            with sqlite3.connect(self.db_path) as conn:
                cursor = conn.cursor()
                timestamp = datetime.now().isoformat()

                cursor.executemany('''
                    INSERT INTO summary_results (language, subject, accuracy, timestamp)
                    VALUES (?, ?, ?, ?)
                ''', [(language, subject, accuracy, timestamp)
                      for subject, accuracy in summary_results.items()])

                cursor.executemany('''
                    INSERT INTO detailed_results (
                        language, timestamp, subject, question, model_answer,
                        correct_answer, is_correct, total_tokens
                    ) VALUES (?, ?, ?, ?, ?, ?, ?, ?)
                ''', [(language, result['timestamp'], result['subject'], result['question'],
                       result['model_answer'], result['correct_answer'],
                       int(result['is_correct']), result['total_tokens'])
                      for result in detailed_results])

                conn.commit()
        except sqlite3.Error as e:
            logger.error(f"Failed to save results to database: {str(e)}")
            raise
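
    # executemany writes all rows within a single transaction, and the ?
    # placeholders keep question/answer strings safely parameterized.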

    def query(self, query: str) -> pd.DataFrame:
        """Execute SQL query and return results as DataFrame."""
        try:
            with sqlite3.connect(self.db_path) as conn:
                return pd.read_sql_query(query, conn)
        except sqlite3.Error as e:
            logger.error(f"Query execution failed: {str(e)}")
            return pd.DataFrame({'Error': [str(e)]})
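
    # Usage sketch (hypothetical):
    #   db = DatabaseManager()
    #   db.query("SELECT subject, accuracy FROM summary_results LIMIT 5")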


class AfriMMLUEvaluator:
    """Runs an LLM over the AfriMMLU benchmark and records the results."""

    def __init__(self, model_name: str = "deepseek/deepseek-chat"):
        self.model_name = model_name
        self.db_manager = DatabaseManager()

    def load_data(self, language_code: str = "swa") -> Optional[List[Dict]]:
        """Load AfriMMLU dataset for specified language."""
        try:
            dataset = load_dataset(
                'masakhane/afrimmlu',
                language_code,
                token=os.getenv('HF_TOKEN')
            )
            return dataset['test'].to_list()
        except Exception as e:
            logger.error(f"Failed to load dataset for {language_code}: {str(e)}")
            return None
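
    # HF_TOKEN is forwarded in case the dataset requires authentication;
    # for a fully public dataset a None token is simply ignored.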

    @staticmethod
    def preprocess_data(test_data: List[Dict]) -> List[Dict]:
        """Preprocess dataset to convert stringified choices fields into lists."""
        preprocessed_data = []
        for example in test_data:
            try:
                if isinstance(example['choices'], str):
                    choices_str = example['choices'].strip("'\"").replace("\\'", "'")
                    example['choices'] = ast.literal_eval(choices_str)
                preprocessed_data.append(example)
            except (ValueError, SyntaxError) as e:
                logger.warning(f"Skipping invalid choices {example['choices']!r}: {e}")
                continue
        return preprocessed_data
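
    # Hypothetical example: a raw record may arrive with
    #   choices = "['500', '600', '700', '800']"
    # and ast.literal_eval turns that string into a proper four-element list.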

    def evaluate(self, test_data: List[Dict], language: str) -> EvaluationResult:
        """Evaluate model on AfriMMLU dataset."""
        results = []
        correct = 0
        total = 0
        subject_results = defaultdict(lambda: {"correct": 0, "total": 0})

        for example in test_data:
            try:
                prompt = self._create_prompt(example)
                response = litellm.completion(
                    model=self.model_name,
                    messages=[{"role": "user", "content": prompt}]
                )

                model_answer = self._parse_model_answer(response.choices[0].message.content)
                is_correct = model_answer == example['answer'].upper()

                if is_correct:
                    correct += 1
                    subject_results[example['subject']]["correct"] += 1
                total += 1
                subject_results[example['subject']]["total"] += 1

                results.append({
                    'timestamp': datetime.now().isoformat(),
                    'subject': example['subject'],
                    'question': example['question'],
                    'model_answer': model_answer,
                    'correct_answer': example['answer'].upper(),
                    'is_correct': is_correct,
                    'total_tokens': response.usage.total_tokens
                })

            except Exception as e:
                logger.warning(f"Error processing question: {str(e)}")
                continue

        accuracy = (correct / total * 100) if total > 0 else 0
        subject_accuracy = {
            subject: (stats["correct"] / stats["total"] * 100) if stats["total"] > 0 else 0
            for subject, stats in subject_results.items()
        }

        self.db_manager.save_results(language, {**subject_accuracy, 'Overall': accuracy}, results)

        return EvaluationResult(
            accuracy=accuracy,
            subject_accuracy=subject_accuracy,
            detailed_results=results
        )
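
    # Requests are issued sequentially, one completion call per question;
    # failed calls are logged and skipped so a single API error does not
    # abort the whole run.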

    @staticmethod
    def _create_prompt(example: Dict) -> str:
        """Create formatted prompt for model evaluation."""
        return (
            "Answer the following multiple-choice question. "
            "Return only the letter corresponding to the correct answer (A, B, C, or D).\n"
            f"Question: {example['question']}\n"
            "Options:\n"
            f"A. {example['choices'][0]}\n"
            f"B. {example['choices'][1]}\n"
            f"C. {example['choices'][2]}\n"
            f"D. {example['choices'][3]}\n"
            "Answer:"
        )

    @staticmethod
    def _parse_model_answer(output: str) -> Optional[str]:
        """Parse model output to extract the answer letter.

        Matches A-D only as a standalone word, so a verbose reply such as
        "The answer is C" resolves to 'C' rather than the stray 'A' in "answer".
        """
        match = re.search(r'\b([ABCD])\b', output.strip().upper())
        return match.group(1) if match else None
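
    # Examples:
    #   _parse_model_answer("B")                -> 'B'
    #   _parse_model_answer("The answer is C.") -> 'C'
    #   _parse_model_answer("No idea")          -> None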


class VisualizationManager:
    @staticmethod
    def create_visualization(results: EvaluationResult) -> Tuple[pd.DataFrame, Figure]:
        """Create visualization from evaluation results."""
        summary_data = [
            {'Subject': subject, 'Accuracy (%)': accuracy}
            for subject, accuracy in results.subject_accuracy.items()
        ]
        summary_data.append({'Subject': 'Overall', 'Accuracy (%)': results.accuracy})
        summary_df = pd.DataFrame(summary_data)

        fig = px.bar(
            summary_df,
            x='Subject',
            y='Accuracy (%)',
            title='AfriMMLU Evaluation Results',
            labels={'Subject': 'Subject', 'Accuracy (%)': 'Accuracy (%)'},
            template='plotly_white'
        )
        fig.update_layout(
            xaxis_tickangle=-45,
            showlegend=False,
            height=600,
            margin=dict(b=200)
        )

        return summary_df, fig
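
    # The rotated tick labels and generous bottom margin keep long subject
    # names readable when several subjects share the x-axis.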


def create_gradio_interface() -> gr.Blocks:
    """Create Gradio interface for AfriMMLU evaluation."""
    evaluator = AfriMMLUEvaluator()
    vis_manager = VisualizationManager()

    language_options = {
        "swa": "Swahili",
        "yor": "Yoruba",
        "wol": "Wolof",
        "lin": "Lingala",
        "ewe": "Ewe",
        "ibo": "Igbo"
    }

    with gr.Blocks(theme=gr.themes.Soft()) as demo:
        gr.Markdown("# AfriMMLU Evaluation Dashboard")

        with gr.Tabs():
            with gr.Tab("Model Evaluation"):
                with gr.Row():
                    language_input = gr.Dropdown(
                        # (label, value) tuples display the language name
                        # while passing the dataset code to the handler.
                        choices=[(name, code) for code, name in language_options.items()],
                        label="Select Language",
                        value="swa"
                    )
                    model_input = gr.Dropdown(
                        choices=["deepseek/deepseek-chat"],
                        label="Select Model",
                        value="deepseek/deepseek-chat"
                    )
                evaluate_btn = gr.Button("Evaluate", variant="primary")

                summary_table = gr.Dataframe(label="Summary Results")
                summary_plot = gr.Plot(label="Performance by Subject")
                detailed_results = gr.Dataframe(label="Detailed Results", wrap=True)

            with gr.Tab("Database Analysis"):
                example_queries = gr.Dropdown(
                    choices=[
                        "SELECT language, AVG(accuracy) as avg_accuracy FROM summary_results WHERE subject='Overall' GROUP BY language",
                        "SELECT subject, AVG(accuracy) as avg_accuracy FROM summary_results GROUP BY subject",
                        "SELECT language, subject, accuracy, timestamp FROM summary_results ORDER BY timestamp DESC LIMIT 10",
                        "SELECT language, COUNT(*) as total_questions, SUM(is_correct) as correct_answers FROM detailed_results GROUP BY language",
                        "SELECT subject, COUNT(*) as total_evaluations FROM summary_results GROUP BY subject"
                    ],
                    label="Example Queries"
                )
                query_input = gr.Textbox(
                    label="SQL Query",
                    placeholder="Enter your SQL query here",
                    lines=3
                )
                query_button = gr.Button("Run Query", variant="primary")
                query_output = gr.Dataframe(label="Query Results", wrap=True)

                gr.Markdown("""
                ### Available Tables:
                - summary_results (id, language, subject, accuracy, timestamp)
                - detailed_results (id, language, timestamp, subject, question, model_answer, correct_answer, is_correct, total_tokens)
                """)

        def evaluate_language(language_code: str, model_name: str):
            evaluator.model_name = model_name
            test_data = evaluator.load_data(language_code)
            if not test_data:
                return None, None, None

            preprocessed_data = evaluator.preprocess_data(test_data)
            results = evaluator.evaluate(preprocessed_data, language_code)
            summary_df, plot = vis_manager.create_visualization(results)
            detailed_df = pd.DataFrame(results.detailed_results)

            return summary_df, plot, detailed_df

        evaluate_btn.click(
            fn=evaluate_language,
            inputs=[language_input, model_input],
            outputs=[summary_table, summary_plot, detailed_results]
        )

        # Selecting an example query simply copies it into the editable textbox.
        example_queries.change(
            fn=lambda x: x,
            inputs=[example_queries],
            outputs=[query_input]
        )

        query_button.click(
            fn=evaluator.db_manager.query,
            inputs=[query_input],
            outputs=[query_output]
        )

    return demo


if __name__ == "__main__":
    try:
        # Fail fast if the API and dataset credentials are missing.
        required_env_vars = ['DEEPSEEK_API_KEY', 'HF_TOKEN']
        for var in required_env_vars:
            if not os.getenv(var):
                raise EnvironmentError(f"Missing required environment variable: {var}")

        demo = create_gradio_interface()
        demo.launch(share=True)
    except Exception as e:
        logger.error(f"Application failed to start: {str(e)}")
        raise