LipSyncAI / app.py
AlserFurma's picture
Update app.py
af30315 verified
raw
history blame
6.15 kB
import gradio as gr
import os
from PIL import Image
import tempfile
from gradio_client import Client, handle_file
import torch
from transformers import VitsModel, AutoTokenizer
import scipy.io.wavfile as wavfile
import traceback
# Загрузка моделей при старте
device = "cuda" if torch.cuda.is_available() else "cpu"
print(f"Using device: {device}")
try:
# TTS модель для казахского (исправлено с rus на kaz)
tts_model = VitsModel.from_pretrained("facebook/mms-tts-kaz").to(device)
tts_tokenizer = AutoTokenizer.from_pretrained("facebook/mms-tts-kaz")
print("TTS модель (kaz) загружена успешно!")
except Exception as e:
raise RuntimeError(f"Ошибка загрузки TTS модели: {str(e)}")
# Space для talking-head
TALKING_HEAD_SPACE = "Skywork/skyreels-a1-talking-head"
def inference(image: Image.Image, text: str):
error_msg = ""
video_path = None
audio_path = None
img_path = None
try:
# Валидация входных данных
if image is None:
raise ValueError("Загрузите изображение лектора!")
if not text or not text.strip():
raise ValueError("Введите текст лекции!")
if len(text) > 500:
raise ValueError("Текст слишком длинный! Используйте до 500 символов.")
print(f"Генерация TTS для текста: '{text[:50]}...'")
# Шаг 1: Генерация аудио через TTS
torch.manual_seed(42)
inputs = tts_tokenizer(text, return_tensors="pt").to(device)
with torch.no_grad():
output = tts_model(**inputs)
waveform = output.waveform.squeeze().cpu().numpy()
if waveform.size == 0:
raise ValueError("TTS сгенерировал пустое аудио! Попробуйте другой текст.")
# Конвертация в int16 для WAV
audio = (waveform * 32767).astype("int16")
sampling_rate = tts_model.config.sampling_rate
# Сохранение аудио
with tempfile.NamedTemporaryFile(suffix=".wav", delete=False) as audio_file:
wavfile.write(audio_file.name, sampling_rate, audio)
audio_path = audio_file.name
print(f"TTS аудио сохранено: {audio_path} (длина: {len(waveform)/sampling_rate:.1f} сек)")
# Шаг 2: Сохранение изображения
with tempfile.NamedTemporaryFile(suffix=".png", delete=False) as img_file:
# Конвертация в RGB если нужно
if image.mode != 'RGB':
image = image.convert('RGB')
image.save(img_file.name, format='PNG')
img_path = img_file.name
print(f"Изображение сохранено: {img_path}")
# Шаг 3: Вызов talking-head API
print(f"Подключение к {TALKING_HEAD_SPACE}...")
client = Client(TALKING_HEAD_SPACE)
# Проверяем доступные API endpoints
print("Доступные API методы:", client.view_api())
# Вызов API с правильными параметрами
result = client.predict(
image_path=handle_file(img_path),
audio_path=handle_file(audio_path),
guidance_scale=3.0,
steps=10,
api_name="/process_image_audio"
)
print(f"Результат API: {type(result)}")
# Обработка результата
if isinstance(result, tuple) and len(result) > 0:
video_data = result[0]
if isinstance(video_data, dict) and 'video' in video_data:
video_path = video_data['video']
elif isinstance(video_data, dict) and 'path' in video_data:
video_path = video_data['path']
elif isinstance(video_data, str):
video_path = video_data
else:
video_path = video_data
else:
video_path = result
print(f"Видео сгенерировано: {video_path}")
error_msg = "✅ Видео успешно сгенерировано!"
except Exception as e:
error_msg = f"❌ Ошибка: {str(e)}"
print(f"ОШИБКА: {error_msg}")
traceback.print_exc()
finally:
# Очистка временных файлов
if audio_path and os.path.exists(audio_path):
try:
os.remove(audio_path)
print(f"Удален временный файл: {audio_path}")
except:
pass
if img_path and os.path.exists(img_path):
try:
os.remove(img_path)
print(f"Удален временный файл: {img_path}")
except:
pass
return video_path, error_msg
# Интерфейс Gradio
title = "🎓 Видео-лектор с TTS (Русский)"
description = """Загрузите фото лектора и введите текст лекции. Система сгенерирует видео, где лектор "произносит" ваш текст!
**Требования:**
- Фото: фронтальное изображение лица
- Текст: до 500 символов на русском языке"""
iface = gr.Interface(
fn=inference,
inputs=[
gr.Image(type="pil", label="📸 Фото лектора"),
gr.Textbox(
lines=5,
placeholder="Введите текст лекции на русском языке (до 500 символов)...",
label="📝 Текст лекции"
)
],
outputs=[
gr.Video(label="🎬 Готовое видео"),
gr.Textbox(label="ℹ️ Статус", interactive=False)
],
title=title,
description=description,
cache_examples=False
)
if __name__ == "__main__":
iface.launch()