# File: Echo/models/echo/echoflow_manager.py
# Uploaded by: moein99
# Commit message: Initial Echo Space
# Commit: 8f51ef2
import json
import os
import types
import warnings
import re
from pathlib import Path
from urllib.parse import urlparse
# Silence noisy third-party warnings once, at import time. Every entry is a
# regular expression that `warnings` matches against the warning message.
for _ignored_message in (
    # Regex-related warnings
    "nothing to repeat",
    ".*regex.*",
    ".*nothing to repeat.*",
    # Config attribute warnings from diffusers
    ".*config attributes.*were passed to.*but are not expected.*",
    ".*Please verify your config.json configuration file.*",
):
    warnings.filterwarnings("ignore", message=_ignored_message)
del _ignored_message
import cv2
import diffusers
import numpy as np
import torch
from einops import rearrange
from huggingface_hub import hf_hub_download
from omegaconf import OmegaConf
from PIL import Image, ImageOps
from safetensors.torch import load_file
from torch.nn import functional as F
from torchdiffeq import odeint_adjoint as odeint
# Make the EchoFlow sources (kept under tool_repos) importable.
import sys

_ROOT = Path(__file__).resolve().parents[2]

# Candidate checkout locations, highest priority first.
_CANDIDATES = [
    _ROOT / "tool_repos" / repo_dir for repo_dir in ("EchoFlow", "EchoFlow-main")
]
_workspace_root = os.getenv("ECHO_WORKSPACE_ROOT")
if _workspace_root:
    _CANDIDATES += [
        Path(_workspace_root) / "EchoFlow",
        Path(_workspace_root) / "tool_repos" / "EchoFlow",
    ]

# The first candidate that exists on disk wins.
for echoflow_path in _CANDIDATES:
    if echoflow_path.exists():
        break
else:
    echoflow_path = None

if echoflow_path is None:
    raise RuntimeError("EchoFlow repository not found. Place it under tool_repos/EchoFlow.")
sys.path.insert(0, str(echoflow_path))
# Import the EchoFlow helpers; when they are missing, install fallbacks so the
# module still imports and the failure surfaces only when a helper is used.
try:
    from echoflow.common import instantiate_class_from_config, unscale_latents
    from echoflow.common.models import (
        ContrastiveModel,
        DiffuserSTDiT,
        ResNet18,
        SegDiTTransformer2DModel,
    )
except ImportError as e:
    print(f"⚠️ EchoFlow common modules not available: {e}")

    # Bind the model classes so later references (attribute access or the
    # globals() lookup in EchoFlowManager._load_model) do not raise NameError.
    ContrastiveModel = None
    DiffuserSTDiT = None
    ResNet18 = None
    SegDiTTransformer2DModel = None

    def instantiate_class_from_config(config, *args, **kwargs):
        """Fallback that always raises, since the real factory is unavailable.

        Raises:
            NotImplementedError: always.
        """
        raise NotImplementedError("EchoFlow common modules not available")

    def unscale_latents(latents, vae_scaling=None):
        """Undo per-channel latent normalisation: ``latents * std + mean``.

        Args:
            latents: 4D or 5D tensor whose dimension 1 is the channel axis.
            vae_scaling: Optional mapping with ``"std"`` and ``"mean"``
                tensors (one value per channel). When ``None`` the input is
                returned unchanged.

        Returns:
            A new tensor with the scaling undone (the input is not modified),
            or ``latents`` itself when ``vae_scaling`` is ``None``.

        Raises:
            ValueError: if ``latents`` is neither 4D nor 5D.
        """
        if vae_scaling is None:
            return latents
        if latents.ndim == 4:
            v = (1, -1, 1, 1)
        elif latents.ndim == 5:
            v = (1, -1, 1, 1, 1)
        else:
            raise ValueError("Latents should be 4D or 5D")
        # Out-of-place arithmetic so the caller's tensor is left untouched.
        return latents * vae_scaling["std"].view(*v) + vae_scaling["mean"].view(*v)
from ..general.base_model_manager import BaseModelManager, ModelStatus
class EchoFlowConfig:
    """Default EchoFlow settings: display name, compute device and dtype."""

    def __init__(self):
        # Prefer the GPU when one is visible to torch, otherwise use the CPU.
        if torch.cuda.is_available():
            device = "cuda"
        else:
            device = "cpu"

        self.name = "EchoFlow"
        self.device = device
        self.dtype = torch.float32
class EchoFlowManager(BaseModelManager):
"""Manager for EchoFlow model components."""
def __init__(self, config=None):
    """Set up defaults for the manager and immediately try to load the models.

    Args:
        config: Passed through unchanged to the base manager's initialiser.
    """
    super().__init__(config)

    # Compute placement: GPU when torch can see one, CPU otherwise.
    use_cuda = torch.cuda.is_available()
    self.device = torch.device("cuda" if use_cuda else "cpu")
    self.dtype = torch.float32

    # Model component slots; all start empty.
    self.lifm = None
    self.vae = None
    self.vae_scaler = None
    self.lvfm = None
    self.reid = None

    # Constants from demo.py. B, C, H and W size the noise sample drawn in
    # generate_latent_image; T is not used in this file (presumably the
    # number of video frames - TODO confirm).
    self.B = 1
    self.T = 64
    self.C = 4
    self.H = 28
    self.W = 28
    self.VIEWS = ["A4C", "PSAX", "PLAX"]

    # Assets directory, three levels above this file.
    project_root = Path(__file__).parent.parent.parent
    self.assets_dir = project_root / "model_weights" / "EchoFlow" / "assets"

    self._initialize_model()
def _initialize_model(self):
    """Load the EchoFlow components and record the resulting status.

    The status is READY only when every component reported by
    ``is_available()`` was actually loaded, so the stored status agrees with
    ``get_status()``. Any exception raised while loading is reported on
    stdout and results in NOT_AVAILABLE; it is never propagated.
    """
    try:
        print("Initializing EchoFlow model...")
        self._load_models()
        if self.is_available():
            self._set_status(ModelStatus.READY)
            print("✅ EchoFlow model initialized successfully")
        else:
            # Loading did not raise, but at least one component is missing.
            print("⚠️ EchoFlow components are incomplete - continuing without EchoFlow")
            self._set_status(ModelStatus.NOT_AVAILABLE)
    except Exception as e:
        print(f"⚠️ EchoFlow model loading failed: {e}")
        print("EchoFlow initialization failed - continuing without EchoFlow")
        self._set_status(ModelStatus.NOT_AVAILABLE)
def _load_models(self):
    """Load the EchoFlow components from local assets.

    Only the VAE scaler is actually loaded (from ``assets_dir/scaling.pt``
    when that file exists). LIFM, VAE, REID and LVFM loading is still
    disabled, so those attributes are set to ``None``.
    """
    # Suppress warnings for cleaner output. `message` is compiled as a
    # regular expression, so glob-style patterns with a leading "*" raise
    # re.error("nothing to repeat"); use ".*" instead. `category` must be a
    # Warning subclass, so re.error cannot be filtered this way either.
    warnings.filterwarnings("ignore", category=UserWarning, module="torch.cuda")
    warnings.filterwarnings("ignore", message="The config attributes.*")
    warnings.filterwarnings("ignore", message=".*were passed to.*but are not expected.*")
    warnings.filterwarnings("ignore", message="nothing to repeat")

    # TODO: LIFM/VAE/REID/LVFM loading was switched off because of the
    # "regex issues" caused by the invalid filter patterns fixed above;
    # re-enable it once the model sources are wired in.

    # Load LIFM (Latent Image Flow Model)
    print("Loading LIFM model...")
    print("⚠️ Skipping LIFM model loading due to regex issues")
    self.lifm = None

    # Load VAE
    print("Loading VAE model...")
    print("⚠️ Skipping VAE model loading due to regex issues")
    self.vae = None

    # Load VAE scaler from local assets
    print("Loading VAE scaler...")
    self.vae_scaler = None
    try:
        scaler_path = self.assets_dir / "scaling.pt"
        if scaler_path.exists():
            self.vae_scaler = self._get_vae_scaler(str(scaler_path))
            print("✅ VAE scaler loaded from local assets")
        else:
            print("⚠️ VAE scaler not found in local assets")
    except Exception as e:
        # Best effort: a broken scaler file must not abort initialisation.
        print(f"⚠️ VAE scaler loading failed: {e}")
        self.vae_scaler = None

    # Load REID models and anatomies
    print("Loading REID models...")
    print("⚠️ Skipping REID models loading due to regex issues")
    self.reid = None

    # Load LVFM (Latent Video Flow Model)
    print("Loading LVFM model...")
    print("⚠️ Skipping LVFM model loading due to regex issues")
    self.lvfm = None
def _load_model(self, path):
    """Load a diffusers-style model from a Hugging Face URL or a local directory.

    Args:
        path: Either a ``https://huggingface.co/<user>/<repo>/tree/<rev>/<subfolder>``
            URL, or a local directory containing ``config.json`` and the
            model weights.

    Returns:
        The object returned by ``from_pretrained`` of the class named in the
        ``_class_name`` entry of ``config.json``.

    Raises:
        ValueError: for an http(s) URL that is not on huggingface.co, or
            when the configured class cannot be found.
        FileNotFoundError: when no ``config.json`` exists in the model directory.
        TypeError: when the configured class has no ``from_pretrained``.
    """
    path = os.fspath(path)

    if path.startswith("http"):
        parsed_url = urlparse(path)
        if "huggingface.co" not in parsed_url.netloc:
            raise ValueError(f"Unsupported model URL (only huggingface.co is handled): {path}")

        # URL path layout: <user>/<repo>/tree/<revision>/<subfolder...>
        parts = parsed_url.path.strip("/").split("/")
        repo_id = "/".join(parts[:2])
        subfolder = "/".join(parts[4:]) or None

        local_root = "./tmp"
        os.makedirs(local_root, exist_ok=True)
        download_kwargs = {
            "repo_id": repo_id,
            "subfolder": subfolder,
            "local_dir": local_root,
            "repo_type": "model",
            "token": os.getenv("READ_HF_TOKEN"),
            "local_dir_use_symlinks": False,
        }
        config_file = hf_hub_download(filename="config.json", **download_kwargs)
        hf_hub_download(filename="diffusion_pytorch_model.safetensors", **download_kwargs)

        # Both files are downloaded with the same local_dir and subfolder, so
        # the directory holding config.json is the model directory.
        path = os.path.dirname(config_file)

    json_path = os.path.join(path, "config.json")
    if not os.path.exists(json_path):
        raise FileNotFoundError(f"No config.json found at {json_path}")
    with open(json_path, "r", encoding="utf-8") as f:
        config = json.load(f)

    # Resolve the model class by name: diffusers first, then this module.
    klass_name = config["_class_name"]
    klass = getattr(diffusers, klass_name, None) or globals().get(klass_name, None)
    if klass is None:
        raise ValueError(f"Could not find class {klass_name} in diffusers or global scope.")
    if not hasattr(klass, "from_pretrained"):
        raise TypeError(f"Class {klass_name} does not support 'from_pretrained'.")
    return klass.from_pretrained(path)
def _load_reid_models(self):
    """Assemble the REID bundle: anatomy tensors, per-view models and thresholds.

    Anatomy tensors are read from ``self.assets_dir``; the per-view models
    are fetched through ``_load_reid_model``. A view whose model cannot be
    loaded gets ``None`` and a message on stdout.

    Returns:
        dict with the keys ``"anatomies"``, ``"models"`` and ``"tau"``, each
        a dict keyed by view name (``"A4C"``, ``"PSAX"``, ``"PLAX"``).
    """
    def load_asset(filename):
        return torch.load(self.assets_dir / filename)

    # A4C combines two anatomy files, concatenated along dimension 0.
    a4c_parts = [
        load_asset("anatomies_dynamic.pt"),
        load_asset("anatomies_ped_a4c.pt"),
    ]
    anatomies = {
        "A4C": torch.cat(a4c_parts, dim=0),
        "PSAX": load_asset("anatomies_ped_psax.pt"),
        "PLAX": load_asset("anatomies_lvh.pt"),
    }
    tau = {view: 0.9997 for view in ("A4C", "PSAX", "PLAX")}
    reid = {"anatomies": anatomies, "models": {}, "tau": tau}

    # Try to load REID models from HuggingFace
    reid_urls = {
        "A4C": "https://huggingface.co/HReynaud/EchoFlow/tree/main/reid/dynamic-4f4",
        "PSAX": "https://huggingface.co/HReynaud/EchoFlow/tree/main/reid/ped_psax-4f4",
        "PLAX": "https://huggingface.co/HReynaud/EchoFlow/tree/main/reid/lvh-4f4",
    }
    for view, url in reid_urls.items():
        try:
            model = self._load_reid_model(url)
        except Exception as e:
            print(f"⚠️ REID model for {view} loading failed: {e}")
            model = None
        reid["models"][view] = model

    return reid
def _load_reid_model(self, path):
    """Download one REID backbone from Hugging Face and prepare it for inference.

    Args:
        path: Hugging Face URL; the first two path segments are used as the
            repo id and everything after the fourth segment as the subfolder.

    Returns:
        The backbone with its weights loaded, moved to ``self.device`` /
        ``self.dtype`` and switched to eval mode.
    """
    url_parts = urlparse(path).path.strip("/").split("/")
    repo_id = "/".join(url_parts[:2])
    subfolder = "/".join(url_parts[4:])

    def fetch(filename):
        # Both files come from the same repo subfolder and land under ./tmp.
        return hf_hub_download(
            repo_id=repo_id,
            subfolder=subfolder,
            filename=filename,
            local_dir="./tmp",
            repo_type="model",
            token=os.getenv("READ_HF_TOKEN"),
            local_dir_use_symlinks=False,
        )

    config_file = fetch("config.yaml")
    weights_file = fetch("backbone.safetensors")

    # Build the backbone described by the config, then patch its channels.
    config = OmegaConf.load(config_file)
    raw_backbone = instantiate_class_from_config(config.backbone)
    backbone = ContrastiveModel.patch_backbone(
        raw_backbone,
        config.model.args.in_channels,
        config.model.args.out_channels,
    )

    backbone.load_state_dict(load_file(weights_file))
    backbone = backbone.to(self.device, dtype=self.dtype)
    backbone.eval()
    return backbone
def _get_vae_scaler(self, path):
    """Load the VAE scaling tensors from ``path`` onto ``self.device``.

    Args:
        path: File readable by ``torch.load`` that contains a mapping of
            names to tensors.

    Returns:
        dict with the same keys, every value moved to ``self.device``.
    """
    # NOTE(review): torch.load unpickles the file; only load trusted assets.
    # map_location lets tensors saved on a GPU load on CPU-only hosts.
    scaler = torch.load(path, map_location=self.device)
    return {key: value.to(self.device) for key, value in scaler.items()}
def generate_latent_image(self, mask, class_selection, sampling_steps=50):
    """Sample a latent image conditioned on a mask and a view class.

    Gaussian noise is integrated with Euler steps from t=1.0 to t=0.0
    through the LIFM model.

    Args:
        mask: Mask input accepted by ``_preprocess_mask``.
        class_selection: One of ``self.VIEWS``.
        sampling_steps: Number of Euler steps.

    Returns:
        ``{"status": "success", "latent_image": ndarray}`` holding the final
        state of the integration, or ``{"status": "error", "message": str}``
        when the model is missing or anything fails.
    """
    if self.lifm is None:
        return {"status": "error", "message": "LIFM model not available"}
    try:
        # Preprocess mask: resize to the latent resolution and binarise.
        mask = self._preprocess_mask(mask)
        mask = torch.from_numpy(mask).to(self.device, dtype=self.dtype)
        mask = mask.unsqueeze(0).unsqueeze(0)
        mask = F.interpolate(mask, size=(self.H, self.W), mode="bilinear", align_corners=False)
        mask = 1.0 * (mask > 0)

        # Class
        class_idx = self.VIEWS.index(class_selection)
        class_idx = torch.tensor([class_idx], device=self.device, dtype=torch.long)

        # Timesteps
        timesteps = torch.linspace(
            1.0, 0.0, steps=sampling_steps + 1, device=self.device, dtype=self.dtype
        )
        forward_kwargs = {
            "class_labels": class_idx,  # B x 1
            "segmentation": mask,  # B x 1 x H x W
        }
        z_1 = torch.randn(
            (self.B, self.C, self.H, self.W),
            device=self.device,
            dtype=self.dtype,
        )

        # odeint calls the model as f(t, y); adapt that to the model's own
        # forward(y, t, **conditioning) for the duration of the integration.
        original_forward = self.lifm.forward

        def new_forward(module, t, y, *args, **kwargs):
            kwargs = {**kwargs, **forward_kwargs}
            return original_forward(y, t.view(1), *args, **kwargs).sample

        self.lifm.forward = types.MethodType(new_forward, self.lifm)
        try:
            # Autocast only when actually running on CUDA.
            use_cuda_autocast = torch.device(self.device).type == "cuda"
            with torch.autocast("cuda", enabled=use_cuda_autocast):
                latent_image = odeint(
                    self.lifm,
                    z_1,
                    timesteps,
                    atol=1e-5,
                    rtol=1e-5,
                    adjoint_params=self.lifm.parameters(),
                    method="euler",
                )[-1]
        finally:
            # Always undo the patch; otherwise a failed run would leave the
            # model wrapped and the next call would wrap the wrapper.
            self.lifm.forward = original_forward

        latent_image = latent_image.detach().cpu().numpy()
        return {"status": "success", "latent_image": latent_image}
    except Exception as e:
        return {"status": "error", "message": str(e)}
def decode_latent_to_pixel(self, latent_image):
    """Decode a latent image to a 400x400 pixel image.

    Args:
        latent_image: numpy array or torch tensor, with or without a leading
            batch dimension. It is never modified.

    Returns:
        ``{"status": "success", "decoded_image": ndarray}`` on success, or
        ``{"status": "error", "message": str}`` when the VAE or its scaler is
        missing, no latent is given, or decoding fails.
    """
    if not self.vae or not self.vae_scaler:
        return {"status": "error", "message": "VAE or VAE scaler not available"}
    if latent_image is None:
        return {"status": "error", "message": "No latent image provided"}
    try:
        # Work on a private copy placed on the model device: the unscaling
        # step may operate in place and must not alter the caller's data.
        latent = torch.as_tensor(latent_image).to(
            self.device, dtype=self.dtype, copy=True
        )

        # Add batch dimension if needed
        if latent.ndim == 3:
            latent = latent[None, ...]

        # Unscale latents
        latent = unscale_latents(latent, self.vae_scaler)

        # Decode using VAE
        with torch.no_grad():
            decoded = self.vae.decode(latent.float()).sample

        # Map the decoder output from [-1, 1] to 8-bit pixel values.
        decoded = (decoded + 1) * 128
        decoded = decoded.clamp(0, 255).to(torch.uint8).cpu()

        # Drop only the batch dimension, then go channels-last for OpenCV.
        decoded = decoded.squeeze(0)
        decoded = decoded.permute(1, 2, 0)

        # Resize to 400x400
        decoded_image = cv2.resize(
            decoded.numpy(), (400, 400), interpolation=cv2.INTER_NEAREST
        )
        return {"status": "success", "decoded_image": decoded_image}
    except Exception as e:
        return {"status": "error", "message": str(e)}
def _preprocess_mask(self, mask):
    """Convert a user-supplied mask into a 112x112 uint8 array.

    Args:
        mask: ``None``, a numpy array, a PIL image, or an ImageEditor value
            (dict) whose ``"composite"`` entry holds the image.

    Returns:
        A 112x112 uint8 numpy array; all zeros when no mask is available.
        The image is thresholded to 0/255 before resizing, but LANCZOS
        resampling can reintroduce intermediate values, so callers that need
        a strictly binary mask must threshold again.
    """
    # Check if mask is an EditorValue with multiple parts and, if so, use the
    # composite image from the ImageEditor.
    if isinstance(mask, dict) and "composite" in mask:
        mask = mask["composite"]

    # Nothing supplied (or an editor value with an empty composite).
    if mask is None:
        return np.zeros((112, 112), dtype=np.uint8)

    # If mask is already a numpy array, convert to PIL for processing
    mask_pil = Image.fromarray(mask) if isinstance(mask, np.ndarray) else mask

    # Ensure the mask is in L mode (grayscale)
    mask_pil = mask_pil.convert("L")
    # Stretch the contrast, then threshold to 0 or 255
    mask_pil = ImageOps.autocontrast(mask_pil, cutoff=0)
    mask_pil = mask_pil.point(lambda p: 255 if p > 127 else 0)
    # Resize to 112x112 for the model
    mask_pil = mask_pil.resize((112, 112), Image.Resampling.LANCZOS)
    # Convert back to numpy array
    return np.array(mask_pil)
def cleanup(self):
    """Release the loaded model components and free cached CUDA memory.

    Components are reset to ``None`` rather than deleted, so the attributes
    still exist afterwards and ``is_available()`` / ``get_status()`` keep
    working (they simply report that the model is unavailable).
    """
    for component in ("lifm", "vae", "lvfm", "reid"):
        if getattr(self, component, None) is not None:
            setattr(self, component, None)

    # Clear CUDA cache if available
    if torch.cuda.is_available():
        torch.cuda.empty_cache()
def is_available(self):
    """Return True only when every EchoFlow component has been loaded."""
    required = ("lifm", "vae", "vae_scaler", "lvfm", "reid")
    return all(getattr(self, name) is not None for name in required)
def get_status(self):
    """Report READY when all components are loaded, NOT_AVAILABLE otherwise."""
    return ModelStatus.READY if self.is_available() else ModelStatus.NOT_AVAILABLE
def predict(self, *args, **kwargs):
    """Satisfy the BaseModelManager interface; always returns an error payload."""
    response = {"status": "error"}
    response["message"] = "EchoFlow predict not implemented"
    return response