import os
import json

import requests
import torch
import torchaudio
from fastapi import FastAPI, File, UploadFile
from fastapi.responses import FileResponse
from starlette.middleware.cors import CORSMiddleware
from transformers import (
    AutoFeatureExtractor,
    AutoModelForAudioClassification,
    Wav2Vec2ForCTC,
    Wav2Vec2Processor,
)

DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
print("Device:", DEVICE)

# Load config. Expected config.json shape (values are placeholders):
# {
#     "eleven_api_key": "...",
#     "eleven_voice_id": "...",
#     "llm_url": "..."
# }
with open("config.json") as f:
    config = json.load(f)

ELEVEN_API_KEY = config["eleven_api_key"]
VOICE_ID = config["eleven_voice_id"]
LLM_URL = config["llm_url"]


def load_audio(audio_path, target_sr=16000):
    """Load an audio file as a mono numpy array resampled to target_sr."""
    wav, sr = torchaudio.load(audio_path)
    if wav.shape[0] > 1:  # downmix multi-channel audio to mono
        wav = wav.mean(dim=0, keepdim=True)
    if sr != target_sr:
        wav = torchaudio.functional.resample(wav, sr, target_sr)
    return wav.squeeze().numpy(), target_sr


# STT MODEL
print("Loading STT model...")
stt_processor = Wav2Vec2Processor.from_pretrained("facebook/mms-1b-all")
stt_model = Wav2Vec2ForCTC.from_pretrained("facebook/mms-1b-all").to(DEVICE)
stt_model.eval()
print("STT loaded")
# Note: mms-1b-all loads its English ("eng") adapter by default. For another
# language, call stt_processor.tokenizer.set_target_lang(<lang>) and
# stt_model.load_adapter(<lang>) before transcribing.


def transcribe(audio_path):
    """Run CTC speech-to-text on an audio file and return the transcript."""
    wav, sr = load_audio(audio_path)
    inputs = stt_processor(wav, sampling_rate=sr, return_tensors="pt", padding=True)
    with torch.no_grad():
        logits = stt_model(inputs.input_values.to(DEVICE)).logits
    ids = torch.argmax(logits, dim=-1)
    return stt_processor.batch_decode(ids)[0].strip()


# EMOTION MODEL
print("Loading Emotion model...")
emotion_extractor = AutoFeatureExtractor.from_pretrained("superb/hubert-base-superb-er")
emotion_model = AutoModelForAudioClassification.from_pretrained(
    "superb/hubert-base-superb-er"
).to(DEVICE)
emotion_model.eval()
print("Emotion model loaded")


def get_emotion(audio_path):
    """Classify the speaker's emotion and return its string label."""
    wav, sr = load_audio(audio_path)
    feats = emotion_extractor(wav, sampling_rate=sr, return_tensors="pt")
    with torch.no_grad():
        out = emotion_model(feats["input_values"].to(DEVICE))
    pred = torch.argmax(out.logits, dim=-1).item()
    return emotion_model.config.id2label[pred]


# LLM CALL
def ask_llm(text):
    """Send the transcript to the LLM backend and return its answer."""
    payload = {"query": text}
    r = requests.post(LLM_URL, json=payload, timeout=200)
    try:
        return r.json()["answer"]
    except (KeyError, ValueError):
        # Fall back to the raw body if the response is not the expected JSON shape.
        return r.text


# TTS
def tts_eleven(text, out_file="response.mp3"):
    """Synthesize speech with the ElevenLabs API and write it to out_file."""
    url = f"https://api.elevenlabs.io/v1/text-to-speech/{VOICE_ID}"
    headers = {
        "xi-api-key": ELEVEN_API_KEY,
        "Content-Type": "application/json",
    }
    payload = {"text": text, "model_id": "eleven_multilingual_v2"}
    resp = requests.post(url, json=payload, headers=headers, timeout=200)
    if resp.status_code != 200:
        raise RuntimeError(f"ElevenLabs API Error: {resp.text}")
    with open(out_file, "wb") as f:
        f.write(resp.content)
    return out_file


# FASTAPI APP
app = FastAPI(title="Voice AI API")
app.add_middleware(
    CORSMiddleware,
    allow_origins=["*"],  # wide open for development; restrict in production
    allow_credentials=True,
    allow_methods=["*"],
    allow_headers=["*"],
)


# Note: model inference here is blocking, so this async handler will stall the
# event loop under concurrent load.
@app.post("/process-audio/")
async def process_audio(file: UploadFile = File(...)):
    # basename() guards against path traversal via a crafted upload filename.
    audio_path = f"temp_{os.path.basename(file.filename or 'upload.wav')}"
    with open(audio_path, "wb") as f:
        f.write(await file.read())
    try:
        transcript = transcribe(audio_path)
        emotion = get_emotion(audio_path)
        llm_response = ask_llm(transcript)
    finally:
        os.remove(audio_path)  # always clean up the temporary upload
    tts_file = tts_eleven(llm_response)
    # The MP3 reply is the response body; the detected emotion rides along in a
    # header so it is not silently discarded.
    return FileResponse(
        tts_file,
        media_type="audio/mpeg",
        filename="response.mp3",
        headers={"X-Emotion": emotion},
    )


@app.get("/")
async def root():
    return {
        "message": "Voice AI API is running. Use /process-audio/ to upload audio."
    }
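
# Example client call, as a minimal sketch: "sample.wav", the localhost URL,
# and the output path are illustrative assumptions, not part of the original
# service configuration.
#
#   import requests
#
#   with open("sample.wav", "rb") as fh:
#       r = requests.post(
#           "http://localhost:8000/process-audio/",
#           files={"file": ("sample.wav", fh, "audio/wav")},
#       )
#   r.raise_for_status()
#   print("Detected emotion:", r.headers.get("X-Emotion"))
#   with open("reply.mp3", "wb") as out:
#       out.write(r.content)  # the endpoint returns the synthesized MP3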
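
# Entry-point sketch: run the API directly with uvicorn. This assumes uvicorn
# is installed; the host and port below are illustrative defaults, not values
# taken from the original source.
if __name__ == "__main__":
    import uvicorn

    uvicorn.run(app, host="0.0.0.0", port=8000)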