Ellie5757575757's picture
Update app.py
89bed5d verified
#!/usr/bin/env python3
"""
Lightweight Aphasia Classification App
Optimized for Hugging Face Spaces with lazy loading and fallbacks
"""
import os
# Process-wide settings, applied before any heavy library is imported so the
# values are already in the environment when those libraries start up.
os.environ.update({
    # CPU-only execution and memory tuning
    'CUDA_VISIBLE_DEVICES': '',  # empty value hides every GPU
    'PYTORCH_CUDA_ALLOC_CONF': 'max_split_size_mb:128',
    'OMP_NUM_THREADS': '2',  # cap CPU worker threads
    'MKL_NUM_THREADS': '2',
    'NUMEXPR_NUM_THREADS': '2',
    'TOKENIZERS_PARALLELISM': 'false',  # silences tokenizer warnings
    # Batchalign / Hugging Face caches live under /tmp
    'BATCHALIGN_CACHE': '/tmp/batchalign_cache',
    'HF_HUB_CACHE': '/tmp/hf_cache',
    'TRANSFORMERS_CACHE': '/tmp/transformers_cache',
    # Whisper cache, also under /tmp
    'WHISPER_CACHE': '/tmp/whisper_cache',
})

print("πŸ”§ Environment configured for CPU-only processing")
print("πŸ’Ύ Model caches set to /tmp/ to save space")
from flask import Flask, request, render_template_string, jsonify
import os
import tempfile
import logging
import json
import threading
import time
from pathlib import Path
# Logging: INFO level, each record prefixed with timestamp and level name.
logging.basicConfig(
    format='%(asctime)s - %(levelname)s - %(message)s',
    level=logging.INFO,
)
logger = logging.getLogger(__name__)

# Flask application; request bodies are capped at 50 MB.
app = Flask(__name__)
app.config.update(MAX_CONTENT_LENGTH=50 * 1024 * 1024)

print("πŸš€ Starting Lightweight Aphasia Classification System")

# Module-level state shared between the loader thread and the request handlers:
#   MODULES        - pipeline callables, keyed by function name
#   MODELS_LOADED  - flips to True once every pipeline module is imported
#   LOADING_STATUS - human-readable progress / error text
MODULES = dict()
MODELS_LOADED = False
LOADING_STATUS = "Starting up..."
def lazy_import_modules():
    """Import the pipeline modules once and cache their entry points.

    For each stage the function updates ``LOADING_STATUS``, imports the
    stage's module and stores the named callable in ``MODULES`` under the
    callable's own name. When every stage has been imported,
    ``MODELS_LOADED`` is set and ``LOADING_STATUS`` becomes ``"Ready!"``.

    Any exception raised while importing is logged with its full traceback,
    recorded in ``LOADING_STATUS`` as ``"Error: <message>"`` and swallowed.

    Returns:
        bool: True if all modules are loaded (including when an earlier call
        already loaded them), False if an import failed.
    """
    global MODELS_LOADED, LOADING_STATUS
    # Local import keeps this block self-contained; importlib is stdlib.
    import importlib

    if MODELS_LOADED:
        return True

    # (label used in status/log text, module to import, callable to cache)
    stages = (
        ("audio processing", "utils_audio", "convert_to_wav"),
        ("speech analysis", "to_cha", "to_cha_from_wav"),
        ("data conversion", "cha_json", "cha_to_json_file"),
        ("AI model", "output", "predict_from_chajson"),
    )

    try:
        for label, module_name, attr_name in stages:
            LOADING_STATUS = f"Loading {label}..."
            logger.info("Importing %s...", module_name)
            module = importlib.import_module(module_name)
            MODULES[attr_name] = getattr(module, attr_name)
            # Upper-case only the first character so "AI model" stays intact.
            logger.info("βœ“ %s loaded", label[:1].upper() + label[1:])
    except Exception as e:
        # logger.exception keeps the traceback; a bare message is rarely
        # enough to diagnose a failed import of a heavy dependency.
        logger.exception("Failed to load modules: %s", e)
        LOADING_STATUS = f"Error: {str(e)}"
        return False

    MODELS_LOADED = True
    LOADING_STATUS = "Ready!"
    logger.info("πŸŽ‰ All modules loaded successfully!")
    return True
def background_loader():
    """Thread target: log the start of loading, then import the pipeline modules."""
    logger.info("Starting background module loading...")
    # The boolean result is not needed here: lazy_import_modules records the
    # outcome in MODELS_LOADED / LOADING_STATUS itself.
    lazy_import_modules()
# Begin importing the pipeline modules right away on a daemon thread, so the
# rest of this module keeps executing while the imports run.
loading_thread = threading.Thread(
    daemon=True,
    target=background_loader,
)
loading_thread.start()
# HTML template for the single-page UI (all CSS and JavaScript are inline).
#
# Template variable:
#   status_message - text shown in the status box when the page is first
#                    rendered.
#
# Client-side behaviour implemented by the embedded script:
#   * checkStatus() fetches /status and updates the status box; it runs every
#     5 seconds and when the "Refresh Status" button is clicked.
#   * Submitting the form is intercepted: the script first checks that /status
#     reports ready, then POSTs the chosen file to /analyze as the multipart
#     field "audio" and shows either `result` or `error` from the JSON reply.
#   * Picking a file replaces the label text with the file's name.
HTML_TEMPLATE = """
<!DOCTYPE html>
<html lang="en">
<head>
<meta charset="UTF-8">
<meta name="viewport" content="width=device-width, initial-scale=1.0">
<title>🧠 Aphasia Classification</title>
<style>
body {
font-family: -apple-system, BlinkMacSystemFont, 'Segoe UI', Roboto, sans-serif;
background: linear-gradient(135deg, #667eea 0%, #764ba2 100%);
min-height: 100vh;
padding: 20px;
margin: 0;
}
.container {
max-width: 800px;
margin: 0 auto;
background: white;
border-radius: 20px;
box-shadow: 0 20px 60px rgba(0,0,0,0.1);
overflow: hidden;
}
.header {
background: linear-gradient(135deg, #667eea 0%, #764ba2 100%);
color: white;
padding: 40px 30px;
text-align: center;
}
.content {
padding: 40px 30px;
}
.status {
background: #f8f9fa;
border-radius: 10px;
padding: 20px;
margin-bottom: 30px;
border-left: 4px solid #28a745;
}
.status.loading {
border-left-color: #ffc107;
}
.status.error {
border-left-color: #dc3545;
}
.upload-section {
background: #f8f9fa;
border-radius: 15px;
padding: 30px;
text-align: center;
margin-bottom: 30px;
}
.file-input {
display: none;
}
.file-label {
display: inline-block;
background: linear-gradient(135deg, #667eea 0%, #764ba2 100%);
color: white;
padding: 15px 30px;
border-radius: 50px;
cursor: pointer;
font-weight: 600;
transition: transform 0.2s ease;
}
.file-label:hover {
transform: translateY(-2px);
}
.analyze-btn {
background: #28a745;
color: white;
border: none;
padding: 15px 40px;
border-radius: 50px;
font-weight: 600;
cursor: pointer;
margin-top: 20px;
transition: all 0.2s ease;
}
.analyze-btn:disabled {
background: #6c757d;
cursor: not-allowed;
}
.results {
background: #f8f9fa;
border-radius: 15px;
padding: 30px;
margin-top: 30px;
display: none;
white-space: pre-wrap;
font-family: monospace;
}
.loading {
text-align: center;
padding: 40px;
display: none;
}
.spinner {
border: 4px solid #f3f3f3;
border-top: 4px solid #667eea;
border-radius: 50%;
width: 50px;
height: 50px;
animation: spin 1s linear infinite;
margin: 0 auto 20px;
}
@keyframes spin {
0% { transform: rotate(0deg); }
100% { transform: rotate(360deg); }
}
.refresh-btn {
background: #17a2b8;
color: white;
border: none;
padding: 10px 20px;
border-radius: 25px;
cursor: pointer;
margin-left: 10px;
}
</style>
</head>
<body>
<div class="container">
<div class="header">
<h1>🧠 Aphasia Classification</h1>
<p>AI-powered speech analysis for aphasia identification</p>
</div>
<div class="content">
<div class="status" id="statusBox">
<h3 id="statusTitle">πŸ”„ System Status</h3>
<div id="statusText">{{ status_message }}</div>
<button class="refresh-btn" onclick="checkStatus()">Refresh Status</button>
</div>
<div class="upload-section">
<h3>πŸ“ Upload Audio File</h3>
<p>Upload speech audio for aphasia classification</p>
<form id="uploadForm" enctype="multipart/form-data">
<input type="file" id="audioFile" name="audio" class="file-input" accept="audio/*" required>
<label for="audioFile" class="file-label">
🎡 Choose Audio File
</label>
<br>
<button type="submit" class="analyze-btn" id="analyzeBtn">
πŸ” Analyze Speech
</button>
</form>
<p style="color: #666; margin-top: 15px; font-size: 0.9em;">
Supported: MP3, WAV, M4A (max 50MB)
</p>
</div>
<div class="loading" id="loading">
<div class="spinner"></div>
<h3>πŸ”„ Processing Audio...</h3>
<p>This may take 2-5 minutes. Please be patient.</p>
</div>
<div class="results" id="results"></div>
</div>
</div>
<script>
// Check status periodically
function checkStatus() {
fetch('/status')
.then(response => response.json())
.then(data => {
const statusBox = document.getElementById('statusBox');
const statusTitle = document.getElementById('statusTitle');
const statusText = document.getElementById('statusText');
if (data.ready) {
statusBox.className = 'status';
statusTitle.textContent = '🟒 System Ready';
statusText.textContent = 'All components loaded. Ready to process audio files.';
} else {
statusBox.className = 'status loading';
statusTitle.textContent = '🟑 Loading...';
statusText.textContent = data.status;
}
})
.catch(error => {
const statusBox = document.getElementById('statusBox');
statusBox.className = 'status error';
document.getElementById('statusTitle').textContent = 'πŸ”΄ Error';
document.getElementById('statusText').textContent = 'Failed to check status';
});
}
// Check status every 5 seconds
setInterval(checkStatus, 5000);
// Form submission
document.getElementById('uploadForm').addEventListener('submit', async function(e) {
e.preventDefault();
const fileInput = document.getElementById('audioFile');
const loading = document.getElementById('loading');
const results = document.getElementById('results');
const analyzeBtn = document.getElementById('analyzeBtn');
if (!fileInput.files[0]) {
alert('Please select an audio file');
return;
}
// Check if system is ready
const statusCheck = await fetch('/status');
const status = await statusCheck.json();
if (!status.ready) {
alert('System is still loading. Please wait and try again.');
return;
}
// Show loading
loading.style.display = 'block';
results.style.display = 'none';
analyzeBtn.disabled = true;
analyzeBtn.textContent = 'Processing...';
try {
const formData = new FormData();
formData.append('audio', fileInput.files[0]);
const response = await fetch('/analyze', {
method: 'POST',
body: formData
});
const data = await response.json();
loading.style.display = 'none';
if (data.success) {
results.textContent = data.result;
results.style.borderLeft = '4px solid #28a745';
} else {
results.textContent = 'Error: ' + data.error;
results.style.borderLeft = '4px solid #dc3545';
}
results.style.display = 'block';
} catch (error) {
loading.style.display = 'none';
results.textContent = 'Network error: ' + error.message;
results.style.borderLeft = '4px solid #dc3545';
results.style.display = 'block';
}
analyzeBtn.disabled = false;
analyzeBtn.textContent = 'πŸ” Analyze Speech';
});
// File selection feedback
document.getElementById('audioFile').addEventListener('change', function(e) {
const label = document.querySelector('.file-label');
if (e.target.files[0]) {
label.textContent = 'βœ“ ' + e.target.files[0].name;
} else {
label.textContent = '🎡 Choose Audio File';
}
});
</script>
</body>
</html>
"""
@app.route('/')
def index():
    """Render the upload page, seeding its status box with the current loading status."""
    page = render_template_string(
        HTML_TEMPLATE,
        status_message=LOADING_STATUS,
    )
    return page
@app.route('/status')
def status():
    """Report loader progress as JSON.

    The payload carries ``ready`` (whether all modules are loaded),
    ``status`` (the current status text) and ``modules_loaded`` (how many
    pipeline callables have been cached so far).
    """
    payload = dict(
        ready=MODELS_LOADED,
        status=LOADING_STATUS,
        modules_loaded=len(MODULES),
    )
    return jsonify(payload)
def _remove_temp_files(paths):
    """Best-effort deletion of every path in *paths*; unset entries are skipped."""
    for path in paths:
        if not path:
            continue
        try:
            os.unlink(path)
        except OSError as exc:
            # Cleanup must never mask the real outcome of the request; a
            # missing or already-deleted file is expected here.
            logger.debug("Could not remove temp file %s: %s", path, exc)


def _format_prediction(pred):
    """Render one prediction record as the plain-text report shown in the UI."""
    verdict = pred["prediction"]
    extras = pred["additional_predictions"]
    summary = pred["class_description"]

    lines = [
        "🧠 APHASIA CLASSIFICATION RESULTS",
        f"🎯 Classification: {verdict['predicted_class']}",
        f"πŸ“Š Confidence: {verdict['confidence_percentage']}",
        f"πŸ“‹ Type: {summary['name']}",
        f"πŸ“ˆ Severity: {extras['predicted_severity_level']}/3",
        f"πŸ—£οΈ Fluency: {extras['fluency_rating']}",
        "πŸ“Š Top 3 Probabilities:",
    ]

    # Takes the first three entries in the mapping's own order.
    # NOTE(review): presumably the predictor already sorts by probability -
    # confirm against predict_from_chajson.
    top_three = list(pred["probability_distribution"].items())[:3]
    for rank, (atype, info) in enumerate(top_three, 1):
        lines.append(f"{rank}. {atype}: {info['percentage']}")

    lines.append("πŸ“ Description:")
    lines.append(f"{summary['description']}")
    lines.append("βœ… Processing completed successfully!")
    return "\n".join(lines) + "\n"


@app.route('/analyze', methods=['POST'])
def analyze_audio():
    """Run the uploaded audio through the classification pipeline.

    Expects a multipart POST with the file in the ``audio`` field. The file
    is written to a temporary path, converted to WAV, transcribed to CHA,
    converted to JSON and finally classified, using the callables cached in
    ``MODULES``.

    Returns:
        A JSON response. On success: ``{'success': True, 'result': <text>}``.
        On any failure (modules not loaded yet, missing file, pipeline
        error, no predictions): ``{'success': False, 'error': <message>}``.
        Exceptions are caught and reported this way rather than propagated.
    """
    # Every file created along the way is recorded here so the ``finally``
    # clause can remove all of them, whichever step fails.
    temp_files = []
    try:
        # Refuse work until the background loader has finished.
        if not MODELS_LOADED:
            return jsonify({
                'success': False,
                'error': f'System still loading: {LOADING_STATUS}'
            })

        if 'audio' not in request.files:
            return jsonify({'success': False, 'error': 'No audio file uploaded'})

        audio_file = request.files['audio']
        if not audio_file.filename:
            return jsonify({'success': False, 'error': 'No file selected'})

        # Reserve a temp path that keeps the upload's extension, then write
        # the upload to it once the reserving handle has been closed.
        suffix = os.path.splitext(audio_file.filename)[1]
        with tempfile.NamedTemporaryFile(delete=False, suffix=suffix) as tmp_file:
            temp_path = tmp_file.name
        temp_files.append(temp_path)
        audio_file.save(temp_path)

        logger.info("🎡 Starting audio processing...")

        # Step 1: Convert to WAV
        logger.info("Converting to WAV...")
        wav_path = MODULES['convert_to_wav'](temp_path, sr=16000, mono=True)
        temp_files.append(wav_path)

        # Step 2: Generate CHA
        logger.info("Generating CHA file...")
        cha_path = MODULES['to_cha_from_wav'](wav_path, lang="eng")
        temp_files.append(cha_path)

        # Step 3: Convert to JSON
        logger.info("Converting to JSON...")
        json_path, _ = MODULES['cha_to_json_file'](cha_path)
        temp_files.append(json_path)

        # Step 4: Classification
        logger.info("Running classification...")
        results = MODULES['predict_from_chajson'](".", json_path, output_file=None)

        if "predictions" in results and results["predictions"]:
            report = _format_prediction(results["predictions"][0])
            return jsonify({'success': True, 'result': report})
        return jsonify({'success': False, 'error': 'No predictions generated'})
    except Exception as e:
        # Top-level boundary for the request: keep the traceback in the log
        # and hand the client a JSON error instead of an HTML 500 page.
        logger.exception("Processing error: %s", e)
        return jsonify({'success': False, 'error': str(e)})
    finally:
        _remove_temp_files(temp_files)
if __name__ == '__main__':
    # The PORT environment variable selects the port; 7860 is the fallback.
    port = int(os.getenv('PORT', 7860))
    print(f"πŸš€ Starting on port {port}")
    print("πŸ”„ Models loading in background...")
    app.run(
        debug=False,
        threaded=True,
        host='0.0.0.0',
        port=port,
    )