|
|
import json |
|
|
import os |
|
|
import types |
|
|
import warnings |
|
|
import re |
|
|
from pathlib import Path |
|
|
from urllib.parse import urlparse |
|
|
|
|
|
|
|
|
warnings.filterwarnings("ignore", message="nothing to repeat") |
|
|
warnings.filterwarnings("ignore", message=".*regex.*") |
|
|
warnings.filterwarnings("ignore", message=".*nothing to repeat.*") |
|
|
|
|
|
warnings.filterwarnings("ignore", message=".*config attributes.*were passed to.*but are not expected.*") |
|
|
warnings.filterwarnings("ignore", message=".*Please verify your config.json configuration file.*") |
|
|
|
|
|
import cv2 |
|
|
import diffusers |
|
|
import numpy as np |
|
|
import torch |
|
|
from einops import rearrange |
|
|
from huggingface_hub import hf_hub_download |
|
|
from omegaconf import OmegaConf |
|
|
from PIL import Image, ImageOps |
|
|
from safetensors.torch import load_file |
|
|
from torch.nn import functional as F |
|
|
from torchdiffeq import odeint_adjoint as odeint |
|
|
|
|
|
|
|
|
import sys

# Locate a checked-out EchoFlow repository and make it importable.
_ROOT = Path(__file__).resolve().parents[2]
_CANDIDATES = [
    _ROOT / "tool_repos" / "EchoFlow",
    _ROOT / "tool_repos" / "EchoFlow-main",
]
_workspace_root = os.getenv("ECHO_WORKSPACE_ROOT")
if _workspace_root:
    # Workspace override: accept both flat and tool_repos layouts.
    _CANDIDATES.append(Path(_workspace_root) / "EchoFlow")
    _CANDIDATES.append(Path(_workspace_root) / "tool_repos" / "EchoFlow")

# First existing candidate wins (order above encodes preference).
echoflow_path = None
for _candidate in _CANDIDATES:
    if _candidate.exists():
        echoflow_path = _candidate
        break
if echoflow_path is None:
    raise RuntimeError("EchoFlow repository not found. Place it under tool_repos/EchoFlow.")

sys.path.insert(0, str(echoflow_path))
|
|
|
|
|
try:
    # Real implementations provided by the EchoFlow repo added to sys.path above.
    from echoflow.common import instantiate_class_from_config, unscale_latents
    from echoflow.common.models import (
        ContrastiveModel,
        DiffuserSTDiT,
        ResNet18,
        SegDiTTransformer2DModel,
    )
except ImportError as e:
    # EchoFlow is optional: install minimal stand-ins so this module can still
    # be imported; model loading will simply fail gracefully later.
    print(f"⚠️ EchoFlow common modules not available: {e}")

    def instantiate_class_from_config(config, *args, **kwargs):
        """Fallback stub for the EchoFlow factory; always raises."""
        raise NotImplementedError("EchoFlow common modules not available")

    def unscale_latents(latents, vae_scaling=None):
        """Undo latent normalization in place: latents * std + mean.

        Args:
            latents: 4D or 5D tensor; NOTE this mutates it in place (*=, +=).
            vae_scaling: Optional dict with "std" and "mean" tensors; when
                None the input is returned unchanged.

        Raises:
            ValueError: If latents is neither 4D nor 5D.
        """
        if vae_scaling is not None:
            # Broadcast the stats along dim 1 — presumably the channel axis
            # for 4D; for 5D, confirm the (B, C, T, H, W) vs (B, T, C, H, W)
            # layout against the caller.
            if latents.ndim == 4:
                v = (1, -1, 1, 1)
            elif latents.ndim == 5:
                v = (1, -1, 1, 1, 1)
            else:
                raise ValueError("Latents should be 4D or 5D")
            latents *= vae_scaling["std"].view(*v)
            latents += vae_scaling["mean"].view(*v)
        return latents
|
|
|
|
|
from ..general.base_model_manager import BaseModelManager, ModelStatus |
|
|
|
|
|
|
|
|
class EchoFlowConfig:
    """Configuration container for EchoFlow.

    Holds the display name plus the torch device string and dtype the
    manager should run with.
    """

    def __init__(self):
        # Prefer GPU whenever torch can see one; otherwise fall back to CPU.
        cuda_available = torch.cuda.is_available()
        self.name = "EchoFlow"
        self.device = "cuda" if cuda_available else "cpu"
        self.dtype = torch.float32
|
|
|
|
|
|
|
|
class EchoFlowManager(BaseModelManager):
    """Manager for EchoFlow model components."""

    def __init__(self, config=None):
        """Set up device/dtype, component slots, and trigger asset loading.

        Args:
            config: Optional configuration forwarded to BaseModelManager.
        """
        super().__init__(config)
        self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        self.dtype = torch.float32

        # Model components; populated by _load_models(), left None on failure.
        self.lifm = None  # latent image flow-matching model
        self.vae = None  # latent <-> pixel autoencoder
        self.vae_scaler = None  # mean/std dict used to (un)scale latents
        self.lvfm = None  # latent video flow-matching model
        self.reid = None  # per-view re-identification models + anatomies

        # Latent geometry: batch, frames, channels, height, width.
        self.B, self.T, self.C, self.H, self.W = 1, 64, 4, 28, 28
        # Supported echocardiography views.
        self.VIEWS = ["A4C", "PSAX", "PLAX"]

        # Local directory with shipped assets (scaling stats, anatomies).
        self.assets_dir = Path(__file__).parent.parent.parent / "model_weights" / "EchoFlow" / "assets"

        self._initialize_model()
|
|
|
|
|
def _initialize_model(self): |
|
|
"""Initialize the EchoFlow model using local assets.""" |
|
|
try: |
|
|
print("Initializing EchoFlow model...") |
|
|
self._load_models() |
|
|
self._set_status(ModelStatus.READY) |
|
|
print("✅ EchoFlow model initialized successfully") |
|
|
except Exception as e: |
|
|
print(f"⚠️ EchoFlow model loading failed: {e}") |
|
|
print("EchoFlow initialization failed - continuing without EchoFlow") |
|
|
self._set_status(ModelStatus.NOT_AVAILABLE) |
|
|
|
|
|
def _load_models(self): |
|
|
"""Load all EchoFlow model components from local assets.""" |
|
|
|
|
|
import warnings |
|
|
import re |
|
|
warnings.filterwarnings("ignore", category=UserWarning, module="torch.cuda") |
|
|
warnings.filterwarnings("ignore", message="The config attributes*") |
|
|
warnings.filterwarnings("ignore", message="*were passed to*but are not expected*") |
|
|
warnings.filterwarnings("ignore", message="nothing to repeat") |
|
|
warnings.filterwarnings("ignore", category=re.error) |
|
|
|
|
|
|
|
|
print("Loading LIFM model...") |
|
|
try: |
|
|
|
|
|
print("⚠️ Skipping LIFM model loading due to regex issues") |
|
|
self.lifm = None |
|
|
except Exception as e: |
|
|
print(f"⚠️ LIFM model loading failed: {e}") |
|
|
self.lifm = None |
|
|
|
|
|
|
|
|
print("Loading VAE model...") |
|
|
try: |
|
|
|
|
|
print("⚠️ Skipping VAE model loading due to regex issues") |
|
|
self.vae = None |
|
|
except Exception as e: |
|
|
print(f"⚠️ VAE model loading failed: {e}") |
|
|
self.vae = None |
|
|
|
|
|
|
|
|
print("Loading VAE scaler...") |
|
|
try: |
|
|
scaler_path = self.assets_dir / "scaling.pt" |
|
|
if scaler_path.exists(): |
|
|
self.vae_scaler = self._get_vae_scaler(str(scaler_path)) |
|
|
print("✅ VAE scaler loaded from local assets") |
|
|
else: |
|
|
print("⚠️ VAE scaler not found in local assets") |
|
|
self.vae_scaler = None |
|
|
except Exception as e: |
|
|
print(f"⚠️ VAE scaler loading failed: {e}") |
|
|
self.vae_scaler = None |
|
|
|
|
|
|
|
|
print("Loading REID models...") |
|
|
try: |
|
|
|
|
|
print("⚠️ Skipping REID models loading due to regex issues") |
|
|
self.reid = None |
|
|
except Exception as e: |
|
|
print(f"⚠️ REID models loading failed: {e}") |
|
|
self.reid = None |
|
|
|
|
|
|
|
|
print("Loading LVFM model...") |
|
|
try: |
|
|
|
|
|
print("⚠️ Skipping LVFM model loading due to regex issues") |
|
|
self.lvfm = None |
|
|
except Exception as e: |
|
|
print(f"⚠️ LVFM model loading failed: {e}") |
|
|
self.lvfm = None |
|
|
|
|
|
    def _load_model(self, path):
        """Load a diffusers-style model from a HuggingFace URL or local path.

        For HuggingFace URLs the config and weights are downloaded under
        ./tmp first; the model class is resolved from config.json's
        "_class_name" (looked up in diffusers, then in this module's globals).

        NOTE(review): `config_file` is only assigned inside the HuggingFace
        branch but is read unconditionally below — a plain local `path`
        would raise NameError here; presumably only URLs are ever passed.
        """
        if path.startswith("http"):
            parsed_url = urlparse(path)
            if "huggingface.co" in parsed_url.netloc:
                # Expected URL shape: /<user>/<repo>/tree/<branch>/<subfolder...>
                parts = parsed_url.path.strip("/").split("/")
                repo_id = "/".join(parts[:2])

                subfolder = None
                if len(parts) > 3:
                    # Skips parts[2:4] ("tree", "<branch>").
                    # NOTE(review): guard is len > 3 but the join starts at
                    # index 4, so exactly-4-part URLs yield "" — confirm intent.
                    subfolder = "/".join(parts[4:])

                local_root = "./tmp"
                local_dir = os.path.join(local_root, repo_id.replace("/", "_"))
                if subfolder:
                    local_dir = os.path.join(local_dir, subfolder)
                os.makedirs(local_root, exist_ok=True)

                # Fetch the model config first so we know what class to build.
                config_file = hf_hub_download(
                    repo_id=repo_id,
                    subfolder=subfolder,
                    filename="config.json",
                    local_dir=local_root,
                    repo_type="model",
                    token=os.getenv("READ_HF_TOKEN"),
                    local_dir_use_symlinks=False,
                )

                assert os.path.exists(config_file)

                # Fetch the weights next to the config.
                hf_hub_download(
                    repo_id=repo_id,
                    filename="diffusion_pytorch_model.safetensors",
                    subfolder=subfolder,
                    local_dir=local_root,
                    local_dir_use_symlinks=False,
                    token=os.getenv("READ_HF_TOKEN"),
                )

                # From here on, load from the local snapshot.
                path = local_dir

        # Resolve the directory that actually contains config.json.
        model_root = os.path.join(config_file.split("config.json")[0])
        json_path = os.path.join(model_root, "config.json")
        assert os.path.exists(json_path)

        with open(json_path, "r") as f:
            config = json.load(f)

        # Resolve the concrete model class by name: diffusers first, then any
        # class defined/imported in this module (e.g. SegDiTTransformer2DModel).
        klass_name = config["_class_name"]
        klass = getattr(diffusers, klass_name, None) or globals().get(klass_name, None)
        assert (
            klass is not None
        ), f"Could not find class {klass_name} in diffusers or global scope."
        assert hasattr(
            klass, "from_pretrained"
        ), f"Class {klass_name} does not support 'from_pretrained'."

        return klass.from_pretrained(path)
|
|
|
|
|
def _load_reid_models(self): |
|
|
"""Load REID models and anatomies from local assets.""" |
|
|
reid = { |
|
|
"anatomies": { |
|
|
"A4C": torch.cat( |
|
|
[ |
|
|
torch.load(self.assets_dir / "anatomies_dynamic.pt"), |
|
|
torch.load(self.assets_dir / "anatomies_ped_a4c.pt"), |
|
|
], |
|
|
dim=0, |
|
|
), |
|
|
"PSAX": torch.load(self.assets_dir / "anatomies_ped_psax.pt"), |
|
|
"PLAX": torch.load(self.assets_dir / "anatomies_lvh.pt"), |
|
|
}, |
|
|
"models": {}, |
|
|
"tau": { |
|
|
"A4C": 0.9997, |
|
|
"PSAX": 0.9997, |
|
|
"PLAX": 0.9997, |
|
|
}, |
|
|
} |
|
|
|
|
|
|
|
|
reid_urls = { |
|
|
"A4C": "https://huggingface.co/HReynaud/EchoFlow/tree/main/reid/dynamic-4f4", |
|
|
"PSAX": "https://huggingface.co/HReynaud/EchoFlow/tree/main/reid/ped_psax-4f4", |
|
|
"PLAX": "https://huggingface.co/HReynaud/EchoFlow/tree/main/reid/lvh-4f4", |
|
|
} |
|
|
|
|
|
for view, url in reid_urls.items(): |
|
|
try: |
|
|
reid["models"][view] = self._load_reid_model(url) |
|
|
except Exception as e: |
|
|
print(f"⚠️ REID model for {view} loading failed: {e}") |
|
|
reid["models"][view] = None |
|
|
|
|
|
return reid |
|
|
|
|
|
    def _load_reid_model(self, path):
        """Load a REID backbone from a HuggingFace `tree` URL.

        Downloads config.yaml and backbone.safetensors, instantiates the
        backbone from the config, patches its channel counts, loads the
        weights, and returns the model in eval mode on self.device.
        """
        # URL shape assumed: /<user>/<repo>/tree/<branch>/<subfolder...>
        parsed_url = urlparse(path)
        parts = parsed_url.path.strip("/").split("/")
        repo_id = "/".join(parts[:2])
        subfolder = "/".join(parts[4:])

        local_root = "./tmp"

        config_file = hf_hub_download(
            repo_id=repo_id,
            subfolder=subfolder,
            filename="config.yaml",
            local_dir=local_root,
            repo_type="model",
            token=os.getenv("READ_HF_TOKEN"),
            local_dir_use_symlinks=False,
        )

        weights_file = hf_hub_download(
            repo_id=repo_id,
            subfolder=subfolder,
            filename="backbone.safetensors",
            local_dir=local_root,
            repo_type="model",
            token=os.getenv("READ_HF_TOKEN"),
            local_dir_use_symlinks=False,
        )

        # Build the backbone exactly as EchoFlow's training config describes,
        # then adapt its input/output channels for the contrastive head.
        config = OmegaConf.load(config_file)
        backbone = instantiate_class_from_config(config.backbone)
        backbone = ContrastiveModel.patch_backbone(
            backbone, config.model.args.in_channels, config.model.args.out_channels
        )
        state_dict = load_file(weights_file)
        backbone.load_state_dict(state_dict)
        backbone = backbone.to(self.device, dtype=self.dtype)
        backbone.eval()
        return backbone
|
|
|
|
|
def _get_vae_scaler(self, path): |
|
|
"""Load VAE scaler from file.""" |
|
|
scaler = torch.load(path) |
|
|
scaler = {k: v.to(self.device) for k, v in scaler.items()} |
|
|
return scaler |
|
|
|
|
|
def generate_latent_image(self, mask, class_selection, sampling_steps=50): |
|
|
"""Generate a latent image based on mask, class selection, and sampling steps.""" |
|
|
if not self.lifm: |
|
|
return {"status": "error", "message": "LIFM model not available"} |
|
|
|
|
|
try: |
|
|
|
|
|
mask = self._preprocess_mask(mask) |
|
|
mask = torch.from_numpy(mask).to(self.device, dtype=self.dtype) |
|
|
mask = mask.unsqueeze(0).unsqueeze(0) |
|
|
mask = F.interpolate(mask, size=(self.H, self.W), mode="bilinear", align_corners=False) |
|
|
mask = 1.0 * (mask > 0) |
|
|
|
|
|
|
|
|
class_idx = self.VIEWS.index(class_selection) |
|
|
class_idx = torch.tensor([class_idx], device=self.device, dtype=torch.long) |
|
|
|
|
|
|
|
|
timesteps = torch.linspace( |
|
|
1.0, 0.0, steps=sampling_steps + 1, device=self.device, dtype=self.dtype |
|
|
) |
|
|
|
|
|
forward_kwargs = { |
|
|
"class_labels": class_idx, |
|
|
"segmentation": mask, |
|
|
} |
|
|
|
|
|
z_1 = torch.randn( |
|
|
(self.B, self.C, self.H, self.W), |
|
|
device=self.device, |
|
|
dtype=self.dtype, |
|
|
) |
|
|
|
|
|
self.lifm.forward_original = self.lifm.forward |
|
|
|
|
|
def new_forward(self, t, y, *args, **kwargs): |
|
|
kwargs = {**kwargs, **forward_kwargs} |
|
|
return self.forward_original(y, t.view(1), *args, **kwargs).sample |
|
|
|
|
|
self.lifm.forward = types.MethodType(new_forward, self.lifm) |
|
|
|
|
|
|
|
|
with torch.autocast("cuda"): |
|
|
latent_image = odeint( |
|
|
self.lifm, |
|
|
z_1, |
|
|
timesteps, |
|
|
atol=1e-5, |
|
|
rtol=1e-5, |
|
|
adjoint_params=self.lifm.parameters(), |
|
|
method="euler", |
|
|
)[-1] |
|
|
|
|
|
self.lifm.forward = self.lifm.forward_original |
|
|
|
|
|
latent_image = latent_image.detach().cpu().numpy() |
|
|
|
|
|
return {"status": "success", "latent_image": latent_image} |
|
|
|
|
|
except Exception as e: |
|
|
return {"status": "error", "message": str(e)} |
|
|
|
|
|
def decode_latent_to_pixel(self, latent_image): |
|
|
"""Decode a latent image to pixel space.""" |
|
|
if not self.vae or not self.vae_scaler: |
|
|
return {"status": "error", "message": "VAE or VAE scaler not available"} |
|
|
|
|
|
try: |
|
|
if latent_image is None: |
|
|
return {"status": "error", "message": "No latent image provided"} |
|
|
|
|
|
|
|
|
if len(latent_image.shape) == 3: |
|
|
latent_image = latent_image[None, ...] |
|
|
|
|
|
|
|
|
if not isinstance(latent_image, torch.Tensor): |
|
|
latent_image = torch.from_numpy(latent_image).to(self.device, dtype=self.dtype) |
|
|
|
|
|
|
|
|
latent_image = unscale_latents(latent_image, self.vae_scaler) |
|
|
|
|
|
|
|
|
with torch.no_grad(): |
|
|
decoded = self.vae.decode(latent_image.float()).sample |
|
|
decoded = (decoded + 1) * 128 |
|
|
decoded = decoded.clamp(0, 255).to(torch.uint8).cpu() |
|
|
decoded = decoded.squeeze() |
|
|
decoded = decoded.permute(1, 2, 0) |
|
|
|
|
|
|
|
|
decoded_image = cv2.resize( |
|
|
decoded.numpy(), (400, 400), interpolation=cv2.INTER_NEAREST |
|
|
) |
|
|
|
|
|
return {"status": "success", "decoded_image": decoded_image} |
|
|
|
|
|
except Exception as e: |
|
|
return {"status": "error", "message": str(e)} |
|
|
|
|
|
def _preprocess_mask(self, mask): |
|
|
"""Preprocess mask for the model.""" |
|
|
if mask is None: |
|
|
return np.zeros((112, 112), dtype=np.uint8) |
|
|
|
|
|
|
|
|
if isinstance(mask, dict) and "composite" in mask: |
|
|
|
|
|
mask = mask["composite"] |
|
|
|
|
|
|
|
|
if isinstance(mask, np.ndarray): |
|
|
mask_pil = Image.fromarray(mask) |
|
|
else: |
|
|
mask_pil = mask |
|
|
|
|
|
|
|
|
mask_pil = mask_pil.convert("L") |
|
|
|
|
|
|
|
|
mask_pil = ImageOps.autocontrast(mask_pil, cutoff=0) |
|
|
|
|
|
|
|
|
mask_pil = mask_pil.point(lambda p: 255 if p > 127 else 0) |
|
|
|
|
|
|
|
|
mask_pil = mask_pil.resize((112, 112), Image.Resampling.LANCZOS) |
|
|
|
|
|
|
|
|
return np.array(mask_pil) |
|
|
|
|
|
def cleanup(self): |
|
|
"""Clean up model resources.""" |
|
|
try: |
|
|
if hasattr(self, 'lifm') and self.lifm: |
|
|
del self.lifm |
|
|
except AttributeError: |
|
|
pass |
|
|
try: |
|
|
if hasattr(self, 'vae') and self.vae: |
|
|
del self.vae |
|
|
except AttributeError: |
|
|
pass |
|
|
try: |
|
|
if hasattr(self, 'lvfm') and self.lvfm: |
|
|
del self.lvfm |
|
|
except AttributeError: |
|
|
pass |
|
|
try: |
|
|
if hasattr(self, 'reid') and self.reid: |
|
|
del self.reid |
|
|
except AttributeError: |
|
|
pass |
|
|
|
|
|
|
|
|
if torch.cuda.is_available(): |
|
|
torch.cuda.empty_cache() |
|
|
|
|
|
def is_available(self): |
|
|
"""Check if EchoFlow is available.""" |
|
|
return (self.lifm is not None and |
|
|
self.vae is not None and |
|
|
self.vae_scaler is not None and |
|
|
self.lvfm is not None and |
|
|
self.reid is not None) |
|
|
|
|
|
def get_status(self): |
|
|
"""Get current status.""" |
|
|
if self.is_available(): |
|
|
return ModelStatus.READY |
|
|
else: |
|
|
return ModelStatus.NOT_AVAILABLE |
|
|
|
|
|
    def predict(self, *args, **kwargs):
        """Predict method required by BaseModelManager.

        EchoFlow exposes task-specific entry points instead
        (generate_latent_image, decode_latent_to_pixel), so the generic
        predict contract is intentionally unsupported here.
        """
        return {"status": "error", "message": "EchoFlow predict not implemented"}
|
|
|