Spaces:
Running
Running
| import gradio as gr | |
| import os | |
| from PIL import Image | |
| import tempfile | |
| from gradio_client import Client, handle_file | |
| import torch | |
| from transformers import VitsModel, AutoTokenizer | |
| import scipy.io.wavfile as wavfile | |
| import traceback | |
| # Загрузка моделей при старте | |
| device = "cuda" if torch.cuda.is_available() else "cpu" | |
| print(f"Using device: {device}") | |
| try: | |
| # TTS модель для казахского (исправлено с rus на kaz) | |
| tts_model = VitsModel.from_pretrained("facebook/mms-tts-kaz").to(device) | |
| tts_tokenizer = AutoTokenizer.from_pretrained("facebook/mms-tts-kaz") | |
| print("TTS модель (kaz) загружена успешно!") | |
| except Exception as e: | |
| raise RuntimeError(f"Ошибка загрузки TTS модели: {str(e)}") | |
| # Space для talking-head | |
| TALKING_HEAD_SPACE = "Skywork/skyreels-a1-talking-head" | |
| def inference(image: Image.Image, text: str): | |
| error_msg = "" | |
| video_path = None | |
| audio_path = None | |
| img_path = None | |
| try: | |
| # Валидация входных данных | |
| if image is None: | |
| raise ValueError("Загрузите изображение лектора!") | |
| if not text or not text.strip(): | |
| raise ValueError("Введите текст лекции!") | |
| if len(text) > 500: | |
| raise ValueError("Текст слишком длинный! Используйте до 500 символов.") | |
| print(f"Генерация TTS для текста: '{text[:50]}...'") | |
| # Шаг 1: Генерация аудио через TTS | |
| torch.manual_seed(42) | |
| inputs = tts_tokenizer(text, return_tensors="pt").to(device) | |
| with torch.no_grad(): | |
| output = tts_model(**inputs) | |
| waveform = output.waveform.squeeze().cpu().numpy() | |
| if waveform.size == 0: | |
| raise ValueError("TTS сгенерировал пустое аудио! Попробуйте другой текст.") | |
| # Конвертация в int16 для WAV | |
| audio = (waveform * 32767).astype("int16") | |
| sampling_rate = tts_model.config.sampling_rate | |
| # Сохранение аудио | |
| with tempfile.NamedTemporaryFile(suffix=".wav", delete=False) as audio_file: | |
| wavfile.write(audio_file.name, sampling_rate, audio) | |
| audio_path = audio_file.name | |
| print(f"TTS аудио сохранено: {audio_path} (длина: {len(waveform)/sampling_rate:.1f} сек)") | |
| # Шаг 2: Сохранение изображения | |
| with tempfile.NamedTemporaryFile(suffix=".png", delete=False) as img_file: | |
| # Конвертация в RGB если нужно | |
| if image.mode != 'RGB': | |
| image = image.convert('RGB') | |
| image.save(img_file.name, format='PNG') | |
| img_path = img_file.name | |
| print(f"Изображение сохранено: {img_path}") | |
| # Шаг 3: Вызов talking-head API | |
| print(f"Подключение к {TALKING_HEAD_SPACE}...") | |
| client = Client(TALKING_HEAD_SPACE) | |
| # Проверяем доступные API endpoints | |
| print("Доступные API методы:", client.view_api()) | |
| # Вызов API с правильными параметрами | |
| result = client.predict( | |
| image_path=handle_file(img_path), | |
| audio_path=handle_file(audio_path), | |
| guidance_scale=3.0, | |
| steps=10, | |
| api_name="/process_image_audio" | |
| ) | |
| print(f"Результат API: {type(result)}") | |
| # Обработка результата | |
| if isinstance(result, tuple) and len(result) > 0: | |
| video_data = result[0] | |
| if isinstance(video_data, dict) and 'video' in video_data: | |
| video_path = video_data['video'] | |
| elif isinstance(video_data, dict) and 'path' in video_data: | |
| video_path = video_data['path'] | |
| elif isinstance(video_data, str): | |
| video_path = video_data | |
| else: | |
| video_path = video_data | |
| else: | |
| video_path = result | |
| print(f"Видео сгенерировано: {video_path}") | |
| error_msg = "✅ Видео успешно сгенерировано!" | |
| except Exception as e: | |
| error_msg = f"❌ Ошибка: {str(e)}" | |
| print(f"ОШИБКА: {error_msg}") | |
| traceback.print_exc() | |
| finally: | |
| # Очистка временных файлов | |
| if audio_path and os.path.exists(audio_path): | |
| try: | |
| os.remove(audio_path) | |
| print(f"Удален временный файл: {audio_path}") | |
| except: | |
| pass | |
| if img_path and os.path.exists(img_path): | |
| try: | |
| os.remove(img_path) | |
| print(f"Удален временный файл: {img_path}") | |
| except: | |
| pass | |
| return video_path, error_msg | |
| # Интерфейс Gradio | |
| title = "🎓 Видео-лектор с TTS (Русский)" | |
| description = """Загрузите фото лектора и введите текст лекции. Система сгенерирует видео, где лектор "произносит" ваш текст! | |
| **Требования:** | |
| - Фото: фронтальное изображение лица | |
| - Текст: до 500 символов на русском языке""" | |
| iface = gr.Interface( | |
| fn=inference, | |
| inputs=[ | |
| gr.Image(type="pil", label="📸 Фото лектора"), | |
| gr.Textbox( | |
| lines=5, | |
| placeholder="Введите текст лекции на русском языке (до 500 символов)...", | |
| label="📝 Текст лекции" | |
| ) | |
| ], | |
| outputs=[ | |
| gr.Video(label="🎬 Готовое видео"), | |
| gr.Textbox(label="ℹ️ Статус", interactive=False) | |
| ], | |
| title=title, | |
| description=description, | |
| cache_examples=False | |
| ) | |
| if __name__ == "__main__": | |
| iface.launch() |