# qwen-image-lora-dlc-v3 / optimization.py
# Commit d347d91 (Lisandro): feat: Enhance optimize_pipeline_ function for
# improved dynamic shape handling and stability in ZeroGPU
# optimization.py
import traceback
from typing import Any, Callable, ParamSpec

import spaces
import torch
from torch.utils._pytree import tree_map
P = ParamSpec("P")
TEXT_SEQ_LENGTH = 12
IMAGE_SEQ_LENGTH = 4096
INDUCTOR_CONFIGS = {
"conv_1x1_as_mm": True,
"epilogue_fusion": False,
"coordinate_descent_tuning": True,
"coordinate_descent_check_all_directions": True,
"max_autotune": True,
"triton.cudagraphs": True,
}
def _build_dynamic_shapes(call_kwargs: dict[str, Any]) -> dict[str, Any]:
    """Build the ``dynamic_shapes`` spec passed to ``torch.export.export``.

    Every captured kwarg is first mapped to ``None`` via ``tree_map``. That
    keeps the exact container nesting of each kwarg (tensors, lists, tuples
    of ints, ...) and marks everything as static. Then dim 1 of the known
    sequence tensors is pinned to a fixed int size. Only keys actually present
    in ``call_kwargs`` are touched.

    Args:
        call_kwargs: The keyword arguments captured from the transformer call.

    Returns:
        A dict with the same keys as ``call_kwargs``.
    """
    dynamic_shapes = tree_map(lambda _: None, call_kwargs)
    pinned_seq_dims = {
        "hidden_states": {1: IMAGE_SEQ_LENGTH},
        "encoder_hidden_states": {1: TEXT_SEQ_LENGTH},
        "encoder_hidden_states_mask": {1: TEXT_SEQ_LENGTH},
    }
    dynamic_shapes.update(
        {key: spec for key, spec in pinned_seq_dims.items() if key in call_kwargs}
    )
    # Non-tensor kwargs such as ``img_shapes`` deliberately keep the tree_map
    # result. A hand-written all-None override adds no constraint, and it
    # breaks export whenever its nesting differs from the real value.
    return dynamic_shapes


def optimize_pipeline_(pipeline: Callable[P, Any], *args: P.args, **kwargs: P.kwargs):
    """Compile ``pipeline.transformer`` ahead of time with AOTInductor, in place.

    Steps:

    1. Run ``pipeline(*args, **kwargs)`` once under ``spaces.aoti_capture``
       to record the inputs the transformer receives.
    2. Export the transformer with ``torch.export.export``.
    3. Compile it with ``spaces.aoti_compile`` using ``INDUCTOR_CONFIGS``.
    4. Install the result with ``spaces.aoti_apply``.

    This is best-effort. If CUDA is unavailable, or any step raises, a
    message is printed (plus the traceback on failure). The pipeline is then
    returned without the compiled transformer.

    Args:
        pipeline: Object exposing a ``transformer`` attribute and callable
            with ``*args`` / ``**kwargs``.
        *args, **kwargs: Example inputs for the capture run. Every dimension
            in the export spec is static (``None`` or a fixed int), so the
            compiled transformer is tied to the shapes seen in this run.

    Returns:
        The same ``pipeline`` object.
    """
    if not torch.cuda.is_available():
        print("⚠️ CUDA no disponible. Se omite AOT.")
        return pipeline

    try:
        @spaces.GPU(duration=1500)
        def compile_transformer():
            print("🏗️ Capturando modelo para AOT...")
            with spaces.aoti_capture(pipeline.transformer) as call:
                pipeline(*args, **kwargs)

            dynamic_shapes = _build_dynamic_shapes(call.kwargs)

            print("🚀 Exportando modelo con torch.export...")
            exported = torch.export.export(
                mod=pipeline.transformer,
                args=call.args,
                kwargs=call.kwargs,
                dynamic_shapes=dynamic_shapes,
            )
            print("⚙️ Compilando con AOTInductor...")
            return spaces.aoti_compile(exported, INDUCTOR_CONFIGS)

        print("🧠 Aplicando AOT al transformer...")
        spaces.aoti_apply(compile_transformer(), pipeline.transformer)
        print("✅ AOT aplicado correctamente al transformer de Qwen-Image.")
    except Exception as e:
        # Deliberate best-effort fallback to the uncompiled pipeline. Keep the
        # full traceback so export/compile failures can actually be diagnosed.
        print(f"⚠️ Error al aplicar AOT: {e}")
        traceback.print_exc()
    return pipeline