AlserFurma commited on
Commit
80ff644
·
verified ·
1 Parent(s): 534e83d

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +112 -76
app.py CHANGED
@@ -4,69 +4,109 @@ from PIL import Image
4
  import tempfile
5
  from gradio_client import Client, handle_file
6
  import torch
7
- from transformers import VitsModel, AutoTokenizer
8
  import scipy.io.wavfile as wavfile
9
  import traceback
10
 
11
- # Загрузка моделей при старте
 
 
 
 
12
  device = "cuda" if torch.cuda.is_available() else "cpu"
13
  print(f"Using device: {device}")
14
 
15
  try:
16
- # TTS модель для казахского (исправлено с rus на kaz)
17
  tts_model = VitsModel.from_pretrained("facebook/mms-tts-kaz").to(device)
18
  tts_tokenizer = AutoTokenizer.from_pretrained("facebook/mms-tts-kaz")
19
- print("TTS модель (kaz) загружена успешно!")
 
 
 
 
 
 
 
 
 
20
  except Exception as e:
21
- raise RuntimeError(f"Ошибка загрузки TTS модели: {str(e)}")
 
 
 
 
 
22
 
23
- # Space для talking-head
24
  TALKING_HEAD_SPACE = "Skywork/skyreels-a1-talking-head"
25
 
 
 
 
 
 
26
  def inference(image: Image.Image, text: str):
 
27
  error_msg = ""
28
  video_path = None
29
  audio_path = None
30
  img_path = None
 
31
  try:
32
- # Валидация входных данных
33
  if image is None:
34
  raise ValueError("Загрузите изображение лектора!")
 
35
  if not text or not text.strip():
36
  raise ValueError("Введите текст лекции!")
 
37
  if len(text) > 500:
38
- raise ValueError("Текст слишком длинный! Используйте до 500 символов.")
39
- print(f"Генерация TTS для текста: '{text[:50]}...'")
40
- # Шаг 1: Генерация аудио через TTS
41
- torch.manual_seed(42)
42
- inputs = tts_tokenizer(text, return_tensors="pt").to(device)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
43
  with torch.no_grad():
44
  output = tts_model(**inputs)
45
- waveform = output.waveform.squeeze().cpu().numpy()
46
- if waveform.size == 0:
47
- raise ValueError("TTS сгенерировал пустое аудио! Попробуйте другой текст.")
48
- # Конвертация в int16 для WAV
49
  audio = (waveform * 32767).astype("int16")
50
  sampling_rate = tts_model.config.sampling_rate
51
- # Сохранение аудио
52
- with tempfile.NamedTemporaryFile(suffix=".wav", delete=False) as audio_file:
53
- wavfile.write(audio_file.name, sampling_rate, audio)
54
- audio_path = audio_file.name
55
- print(f"TTS аудио сохранено: {audio_path} (длина: {len(waveform)/sampling_rate:.1f} сек)")
56
- # Шаг 2: Сохранение изображения
57
- with tempfile.NamedTemporaryFile(suffix=".png", delete=False) as img_file:
58
- # Конвертация в RGB если нужно
59
- if image.mode != 'RGB':
60
- image = image.convert('RGB')
61
- image.save(img_file.name, format='PNG')
62
- img_path = img_file.name
63
- print(f"Изображение сохранено: {img_path}")
64
- # Шаг 3: Вызов talking-head API
65
- print(f"Подключение к {TALKING_HEAD_SPACE}...")
 
 
66
  client = Client(TALKING_HEAD_SPACE)
67
- # Проверяем доступные API endpoints
68
- print("Доступные API методы:", client.view_api())
69
- # Вызов API с правильными параметрами
70
  result = client.predict(
71
  image_path=handle_file(img_path),
72
  audio_path=handle_file(audio_path),
@@ -74,67 +114,63 @@ def inference(image: Image.Image, text: str):
74
  steps=10,
75
  api_name="/process_image_audio"
76
  )
77
- print(f"Результат API: {type(result)}")
78
- # Обработка результата
79
- if isinstance(result, tuple) and len(result) > 0:
80
- video_data = result[0]
81
- if isinstance(video_data, dict) and 'video' in video_data:
82
- video_path = video_data['video']
83
- elif isinstance(video_data, dict) and 'path' in video_data:
84
- video_path = video_data['path']
85
- elif isinstance(video_data, str):
86
- video_path = video_data
87
- else:
88
- video_path = video_data
89
  else:
90
- video_path = result
91
- print(f"Видео сгенерировано: {video_path}")
92
- error_msg = "✅ Видео успешно сгенерировано!"
 
93
  except Exception as e:
94
  error_msg = f"❌ Ошибка: {str(e)}"
95
- print(f"ОШИБКА: {error_msg}")
96
  traceback.print_exc()
 
97
  finally:
98
- # Очистка временных файлов
99
- if audio_path and os.path.exists(audio_path):
100
- try:
101
- os.remove(audio_path)
102
- print(f"Удален временный файл: {audio_path}")
103
- except:
104
- pass
105
- if img_path and os.path.exists(img_path):
106
- try:
107
- os.remove(img_path)
108
- print(f"Удален временный файл: {img_path}")
109
- except:
110
- pass
111
  return video_path, error_msg
112
 
113
- # Интерфейс Gradio
114
- title = "🎓 Видео-лектор с TTS (Русский)"
115
- description = """Загрузите фото лектора и введите текст лекции. Система сгенерирует видео, где лектор "произносит" ваш текст!
116
- **Требования:**
117
- - Фото: фронтальное изображение лица
118
- - Текст: до 500 символов на русском языке"""
 
 
 
 
 
 
 
 
 
119
 
120
  iface = gr.Interface(
121
  fn=inference,
122
  inputs=[
123
- gr.Image(type="pil", label="📸 Фото лектора"),
124
  gr.Textbox(
125
  lines=5,
126
- placeholder="Введите текст лекции на русском языке (до 500 символов)...",
127
- label="📝 Текст лекции"
128
  )
129
  ],
130
  outputs=[
131
- gr.Video(label="🎬 Готовое видео"),
132
- gr.Textbox(label="ℹ️ Статус", interactive=False)
133
  ],
134
  title=title,
135
  description=description,
136
- cache_examples=False
 
137
  )
138
 
139
  if __name__ == "__main__":
140
- iface.launch()
 
4
  import tempfile
5
  from gradio_client import Client, handle_file
6
  import torch
7
+ from transformers import VitsModel, AutoTokenizer, pipeline
8
  import scipy.io.wavfile as wavfile
9
  import traceback
10
 
11
+
12
+ # =========================
13
+ # Загрузка моделей
14
+ # =========================
15
+
16
  device = "cuda" if torch.cuda.is_available() else "cpu"
17
  print(f"Using device: {device}")
18
 
19
  try:
20
+ # TTS модель казахского языка
21
  tts_model = VitsModel.from_pretrained("facebook/mms-tts-kaz").to(device)
22
  tts_tokenizer = AutoTokenizer.from_pretrained("facebook/mms-tts-kaz")
23
+
24
+ # Модель перевода ru -> kk
25
+ translator = pipeline(
26
+ "translation",
27
+ model="facebook/nllb-200-distilled-600M",
28
+ device=0 if device == "cuda" else -1
29
+ )
30
+
31
+ print("✅ Все модели успешно загружены!")
32
+
33
  except Exception as e:
34
+ raise RuntimeError(f"Ошибка загрузки моделей: {str(e)}")
35
+
36
+
37
+ # =========================
38
+ # Talking Head Space
39
+ # =========================
40
 
 
41
  TALKING_HEAD_SPACE = "Skywork/skyreels-a1-talking-head"
42
 
43
+
44
+ # =========================
45
+ # Основная функция
46
+ # =========================
47
+
48
  def inference(image: Image.Image, text: str):
49
+
50
  error_msg = ""
51
  video_path = None
52
  audio_path = None
53
  img_path = None
54
+
55
  try:
56
+ # Проверки
57
  if image is None:
58
  raise ValueError("Загрузите изображение лектора!")
59
+
60
  if not text or not text.strip():
61
  raise ValueError("Введите текст лекции!")
62
+
63
  if len(text) > 500:
64
+ raise ValueError("Текст превышает 500 символов!")
65
+
66
+ print("Ввод (RU):", text)
67
+
68
+ # =========================
69
+ # Шаг 1 — Перевод
70
+ # =========================
71
+ translation = translator(
72
+ text,
73
+ src_lang="rus_Cyrl",
74
+ tgt_lang="kaz_Cyrl"
75
+ )
76
+
77
+ translated_text = translation[0]["translation_text"]
78
+ print("Перевод (KK):", translated_text)
79
+
80
+ # =========================
81
+ # Шаг 2 — Озвучка
82
+ # =========================
83
+ inputs = tts_tokenizer(translated_text, return_tensors="pt").to(device)
84
+
85
  with torch.no_grad():
86
  output = tts_model(**inputs)
87
+
88
+ waveform = output.waveform.squeeze().cpu().numpy()
 
 
89
  audio = (waveform * 32767).astype("int16")
90
  sampling_rate = tts_model.config.sampling_rate
91
+
92
+ with tempfile.NamedTemporaryFile(suffix=".wav", delete=False) as f:
93
+ wavfile.write(f.name, sampling_rate, audio)
94
+ audio_path = f.name
95
+
96
+ # =========================
97
+ # Шаг 3 Сохранение изображения
98
+ # =========================
99
+ with tempfile.NamedTemporaryFile(suffix=".png", delete=False) as f:
100
+ if image.mode != "RGB":
101
+ image = image.convert("RGB")
102
+ image.save(f.name)
103
+ img_path = f.name
104
+
105
+ # =========================
106
+ # Шаг 4 — Генерация видео
107
+ # =========================
108
  client = Client(TALKING_HEAD_SPACE)
109
+
 
 
110
  result = client.predict(
111
  image_path=handle_file(img_path),
112
  audio_path=handle_file(audio_path),
 
114
  steps=10,
115
  api_name="/process_image_audio"
116
  )
117
+
118
+ if isinstance(result, tuple):
119
+ video_path = result[0]
 
 
 
 
 
 
 
 
 
120
  else:
121
+ raise ValueError("Видео не получено!")
122
+
123
+ error_msg = "✅ Видео успешно создано!"
124
+
125
  except Exception as e:
126
  error_msg = f"❌ Ошибка: {str(e)}"
 
127
  traceback.print_exc()
128
+
129
  finally:
130
+ for p in [audio_path, img_path]:
131
+ if p and os.path.exists(p):
132
+ try:
133
+ os.remove(p)
134
+ except:
135
+ pass
136
+
 
 
 
 
 
 
137
  return video_path, error_msg
138
 
139
+
140
+ # =========================
141
+ # Gradio Интерфейс
142
+ # =========================
143
+
144
+ title = "Бейне Оқытушы"
145
+
146
+ description = """
147
+ Суретіңізді жүктеп, дәріс мәтінін орыс тілінде енгізіңіз.
148
+ Жүйе автоматты түрде қазақ тіліне аударады және бейне жасайды!
149
+
150
+ **Талаптар:**
151
+ - Фото: бет анық көрінетін
152
+ - Мәтін: орыс тілінде (500 таңбаға дейін)
153
+ """
154
 
155
  iface = gr.Interface(
156
  fn=inference,
157
  inputs=[
158
+ gr.Image(type="pil", label="📸 Фото дәріскер"),
159
  gr.Textbox(
160
  lines=5,
161
+ label="📝 Дәріс мәтіні (орыс тілінде)",
162
+ placeholder="500 таңбаға дейін..."
163
  )
164
  ],
165
  outputs=[
166
+ gr.Video(label="🎬 Дайын бейне"),
167
+ gr.Textbox(label="ℹ️ Мәртебе")
168
  ],
169
  title=title,
170
  description=description,
171
+ cache_examples=False,
172
+ flagging_mode="never"
173
  )
174
 
175
  if __name__ == "__main__":
176
+ iface.launch()