Update demos/musicgen_colab.py
demos/musicgen_colab.py  (+8 -22)
@@ -200,11 +200,11 @@ class Predictor:
                     assert outputs_diffusion.shape[1] == 1  # output is mono
                     outputs_diffusion = rearrange(outputs_diffusion, '(s b) c t -> b (s c) t', s=2)
                     outputs_diffusion = outputs_diffusion.detach().cpu()
-                    return output, outputs_diffusion #Return the task id.
+                    return task_id, (output, outputs_diffusion) #Return the task id.
                 else:
-                    return output, None
+                    return task_id, (output, None)
             except Exception as e:
-                return task_id, e
+                return task_id, e
         else:
             # Use the multiprocessing queue (multi-process mode)
             self.current_task_id += 1
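Note on the changed return contract: in daemon mode, predict() now hands back the task id together with the decoded audio as a single (task_id, payload) pair, where payload is either (wav, diffusion_wav) or, per the except branch above, the exception itself. Below is a minimal sketch of how a caller could consume this shape; the Predictor constructor and the keyword names mirror predict_full() elsewhere in this diff, topk and topp are assumed to be passed through the same way, and all concrete argument values are illustrative only.

    # Hedged sketch: consuming the daemon-mode return of predict().
    predictor = Predictor("facebook/musicgen-melody", 1)       # model name from the diff; depth value illustrative
    task_id, payload = predictor.predict(
        text="80s synth-pop with a warm bassline",             # illustrative prompt
        melody=None, duration=8,
        topk=250, topp=0.0, temperature=1.0, cfg_coef=3.0,     # illustrative sampling settings
    )
    if isinstance(payload, Exception):
        raise payload             # errors come back as values, not raised, in this mode
    wav, diffusion_wav = payload  # diffusion_wav is None when MBD was not requested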
@@ -238,9 +238,7 @@ _default_model_name = "facebook/musicgen-melody"
 def predict_full(model, model_path, depth, use_mbd, text, melody, duration, topk, topp, temperature, cfg_coef):
     # Initialize Predictor *INSIDE* the function
     predictor = Predictor(model, depth)
-
-    # Call predict() - this will return either (wav, diffusion_wav) or a task_id
-    prediction_result = predictor.predict(
+    task_id, (wav, diffusion_wav) = predictor.predict( # Unpack directly!
         text=text,
         melody=melody,
         duration=duration,
@@ -250,16 +248,7 @@ def predict_full(model, model_path, depth, use_mbd, text, melody, duration, topk
         temperature=temperature,
         cfg_coef=cfg_coef,
     )
-
-    # Handle daemon and non-daemon cases
-    if predictor.is_daemon:
-        wav, diffusion_wav = prediction_result # Direct unpacking (daemon mode)
-    else:
-        # Get the result using the task_id (multi-process mode)
-        task_id = prediction_result
-        wav, diffusion_wav = predictor.get_result(task_id)
-
-    # Save and return audio files (rest of the function remains the same)
+    # Save and return audio files
     wav_paths = []
     video_paths = []
     # Save standard output
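The removed branch covered the multi-process path, where predict() only queued the work and the audio had to be fetched later with get_result(). With the new return shape the call site unpacks everything in one step; if the queue-based path ever needed to be dispatched from here again, it could look roughly like the sketch below (is_daemon and get_result() come from the removed lines; everything else is illustrative).

    # Hedged sketch: keeping a fallback for the queue-based (multi-process) mode.
    result = predictor.predict(text=text, melody=melody, duration=duration,
                               topk=topk, topp=topp,
                               temperature=temperature, cfg_coef=cfg_coef)
    if predictor.is_daemon:
        task_id, (wav, diffusion_wav) = result                # in-process: audio comes back directly
    else:
        wav, diffusion_wav = predictor.get_result(result)     # queue mode: result is just the task id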
@@ -274,7 +263,6 @@ def predict_full(model, model_path, depth, use_mbd, text, melody, duration, topk
         video_paths.append(video_path)
         file_cleaner.add(file.name)
         file_cleaner.add(video_path)
-
     # Save MBD output if used
     if diffusion_wav is not None:
         with NamedTemporaryFile("wb", suffix=".wav", delete=False) as file:
@@ -288,15 +276,13 @@ def predict_full(model, model_path, depth, use_mbd, text, melody, duration, topk
             video_paths.append(video_path)
             file_cleaner.add(file.name)
             file_cleaner.add(video_path)
-
     # Shutdown predictor to prevent hanging processes!
-    if not predictor.is_daemon:
+    if not predictor.is_daemon: # Important!
         predictor.shutdown()
-
     if use_mbd:
-
+        return video_paths[0], wav_paths[0], video_paths[1], wav_paths[1]
     return video_paths[0], wav_paths[0], None, None
-
+
 def toggle_audio_src(choice):
     if choice == "mic":
         return gr.update(sources="microphone", value=None, label="Microphone")
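With these changes predict_full() fills all four output slots when use_mbd is enabled, instead of always returning None for the MBD pair. A minimal usage sketch follows, assuming the signature shown in the hunk headers; every argument value below is illustrative.

    # Hedged sketch: calling the updated predict_full() with MBD enabled.
    video, wav, mbd_video, mbd_wav = predict_full(
        model="facebook/musicgen-melody", model_path="", depth=1, use_mbd=True,
        text="upbeat acoustic folk with handclaps", melody=None, duration=10,
        topk=250, topp=0.0, temperature=1.0, cfg_coef=3.0,
    )
    # With use_mbd=False the last two returned values are None.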