Spaces:
Running
Running
| import torch | |
| from transformers import BertForSequenceClassification, BertTokenizer | |
| import numpy as np | |
| import re | |
| from datetime import datetime | |
| import os | |
| import logging | |
| from typing import Tuple, Dict, Any | |
| import json | |
| # import pyttsx3 # Comentado para no usar TTS en el servidor | |
class MentalHealthChatbot:
    """Emotion-aware chatbot backed by a fine-tuned BERT sequence classifier.

    For each user message the bot either flags an emergency (keyword match)
    or asks the classifier for an emotion label, then answers with a random
    predefined response for that label. The last ``MAX_HISTORY`` exchanges
    are kept in memory and dumped to a JSON file after every reply.
    """

    # Label for each classifier output index.
    # NOTE(review): the order must match the class ids used at fine-tuning
    # time -- confirm against the training script.
    EMOTION_LABELS = (
        'FELICIDAD', 'NEUTRAL', 'DEPRESIÓN', 'ANSIEDAD', 'ESTRÉS',
        'EMERGENCIA', 'CONFUSIÓN', 'IRA', 'MIEDO', 'SORPRESA', 'DISGUSTO',
    )

    # Substrings that short-circuit the classifier and force 'EMERGENCIA'.
    # Matching is by substring on purpose, so 'suicidar' also catches
    # inflected forms such as 'suicidarme'.
    EMERGENCY_KEYWORDS = (
        'suicidar', 'morir', 'muerte', 'matar', 'dolor',
        'ayuda', 'emergencia', 'crisis', 'grave',
    )

    # Maximum number of exchanges retained in ``conversation_history``.
    MAX_HISTORY = 10

    def __init__(self, model_path: str = 'models/bert_emotion_model'):
        """Load the tokenizer, the model and the predefined responses.

        Args:
            model_path: Directory (or hub id) of the fine-tuned BERT model.

        Raises:
            Exception: Whatever the model, tokenizer or response loading
                raises; it is logged and re-raised unchanged.
        """
        self.device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

        # Logging goes to a file under /tmp. The handler is only created when
        # the logger has none yet, so repeated instantiation neither
        # duplicates log lines nor leaks an open file handle.
        self.logger = logging.getLogger(__name__)
        self.logger.setLevel(logging.INFO)
        if not self.logger.handlers:
            handler = logging.FileHandler('/tmp/chatbot.log')
            handler.setFormatter(
                logging.Formatter('%(asctime)s - %(levelname)s - %(message)s')
            )
            self.logger.addHandler(handler)

        # Working directories live under /tmp to avoid the read-only /app tree.
        self.conversations_dir = "/tmp/conversations"
        self.audio_dir = "/tmp/audio"

        try:
            self.logger.info("Cargando el tokenizador y el modelo BERT fine-tuned...")
            os.makedirs(self.conversations_dir, exist_ok=True)
            self.tokenizer = BertTokenizer.from_pretrained(model_path)
            self.model = BertForSequenceClassification.from_pretrained(model_path).to(self.device)
            self.load_responses()
            self.conversation_history = []
            self.logger.info("Chatbot inicializado correctamente.")
        except Exception as e:
            self.logger.error(f"Error al cargar el modelo: {str(e)}")
            # Bare raise keeps the original traceback intact.
            raise

    def load_responses(self):
        """Load the predefined responses from ``models/responses.json``.

        Raises:
            FileNotFoundError: If the JSON file does not exist.
            json.JSONDecodeError: If the file is not valid JSON.
        """
        try:
            with open('models/responses.json', 'r', encoding='utf-8') as f:
                self.responses = json.load(f)
            self.logger.info("Respuestas cargadas desde 'responses.json'.")
        except FileNotFoundError:
            self.logger.error(
                "Archivo 'responses.json' no encontrado. "
                "Asegúrate de que el archivo existe en la ruta especificada."
            )
            raise
        except json.JSONDecodeError as e:
            self.logger.error(f"Error al decodificar 'responses.json': {str(e)}")
            raise

    def preprocess_text(self, text: str) -> str:
        """Lower-case ``text``, drop punctuation and trim surrounding whitespace.

        On any failure (e.g. ``text`` is not a string) the error is logged and
        the input is returned unchanged.
        """
        try:
            text = text.lower()
            # \w is Unicode-aware for str patterns, so accented letters survive.
            text = re.sub(r'[^\w\s]', '', text)
            return text.strip()
        except Exception as e:
            self.logger.error(f"Error al preprocesar el texto: {str(e)}")
            return text

    def detect_emergency(self, text: str) -> bool:
        """Return True if ``text`` contains any emergency keyword.

        Matching is case-insensitive and by substring. Errors are logged and
        reported as ``False``.
        """
        try:
            lowered = text.lower()
            return any(keyword in lowered for keyword in self.EMERGENCY_KEYWORDS)
        except Exception as e:
            self.logger.error(f"Error al detectar emergencia: {str(e)}")
            return False

    def get_emotion_prediction(self, text: str) -> Tuple[str, float]:
        """Classify ``text`` with the fine-tuned model.

        Returns:
            ``(label, confidence)`` where ``label`` is the entry of
            ``EMOTION_LABELS`` with the highest softmax probability and
            ``confidence`` is that probability. On any error the fallback
            ``('CONFUSIÓN', 0.0)`` is returned.
        """
        try:
            inputs = self.tokenizer.encode_plus(
                text,
                add_special_tokens=True,
                max_length=128,
                padding='max_length',
                truncation=True,
                return_tensors='pt'
            ).to(self.device)
            with torch.no_grad():
                outputs = self.model(**inputs)
            probs = torch.softmax(outputs.logits, dim=1)
            predicted_class = torch.argmax(probs, dim=1).item()
            confidence = probs[0][predicted_class].item()
            emotion = self.EMOTION_LABELS[predicted_class]
            self.logger.info(f"Emoción predicha: {emotion} con confianza {confidence:.2f}")
            return emotion, confidence
        except Exception as e:
            self.logger.error(f"Error en la predicción de emoción: {str(e)}")
            return 'CONFUSIÓN', 0.0

    def generate_response(self, user_input: str) -> Dict[str, Any]:
        """Build the reply for ``user_input`` (text only, no server-side audio).

        Returns:
            A dict with keys ``text``, ``emotion``, ``confidence`` and
            ``timestamp`` (ISO-8601). If anything fails, a generic apology
            with ``emotion == 'ERROR'`` and ``confidence == 0.0`` is returned
            instead of raising.
        """
        try:
            processed_text = self.preprocess_text(user_input)
            self.logger.info(f"Texto procesado: {processed_text}")

            # Keyword-based emergency detection takes priority over the model.
            if self.detect_emergency(processed_text):
                emotion = 'EMERGENCIA'
                confidence = 1.0
                self.logger.info("Emergencia detectada en el mensaje del usuario.")
            else:
                emotion, confidence = self.get_emotion_prediction(processed_text)

            # Unknown labels fall back to the 'CONFUSIÓN' responses, and to a
            # hard-coded sentence if even those are missing.
            responses = self.responses.get(
                emotion,
                self.responses.get('CONFUSIÓN', ["Lo siento, no he entendido tu mensaje."])
            )
            response = np.random.choice(responses)
            self.logger.info(f"Respuesta seleccionada: {response}")

            self.update_conversation_history(user_input, response, emotion)
            self.save_conversation_history()

            return {
                'text': response,
                'emotion': emotion,
                'confidence': confidence,
                'timestamp': datetime.now().isoformat()
            }
        except Exception as e:
            self.logger.error(f"Error al generar respuesta: {str(e)}")
            return {
                'text': "Lo siento, ha ocurrido un error. ¿Podrías intentarlo de nuevo?",
                'emotion': 'ERROR',
                'confidence': 0.0,
                'timestamp': datetime.now().isoformat()
            }

    # NOTE: server-side text-to-speech (pyttsx3 ``generate_audio``) is
    # disabled; ``audio_dir`` is kept so it can be reinstated later.

    def update_conversation_history(self, user_input: str, response: str, emotion: str):
        """Append one exchange to the in-memory history, keeping the newest ``MAX_HISTORY``."""
        try:
            self.conversation_history.append({
                'user_input': user_input,
                'response': response,
                'emotion': emotion,
                'timestamp': datetime.now().isoformat()
            })
            overflow = len(self.conversation_history) - self.MAX_HISTORY
            if overflow > 0:
                # Drop the oldest entries in place so external references to
                # the list stay valid.
                del self.conversation_history[:overflow]
            self.logger.info("Historial de conversación actualizado.")
        except Exception as e:
            self.logger.error(f"Error al actualizar el historial de conversación: {str(e)}")

    def save_conversation_history(self):
        """Write the current history to a timestamped JSON file in ``conversations_dir``.

        Failures are logged and swallowed so that persistence problems never
        break the reply to the user.
        """
        try:
            filename = f"{self.conversations_dir}/chat_{datetime.now().strftime('%Y%m%d_%H%M%S')}.json"
            os.makedirs(os.path.dirname(filename), exist_ok=True)
            with open(filename, 'w', encoding='utf-8') as f:
                json.dump(self.conversation_history, f, ensure_ascii=False, indent=2)
            self.logger.info(f"Historial de conversación guardado en {filename}")
        except Exception as e:
            self.logger.error(f"Error al guardar el historial: {str(e)}")