DOOMGAN / apps /utils.py
BharathK333's picture
Upload 29 files
24c8da0 verified
import os
import torch
from collections import OrderedDict
from PIL import Image
# Local project imports
from models import Generator, Encoder, LandmarkEncoder
from models.landmark_predictor import OcularLMGenerator
def remove_module_prefix(state_dict):
    """Return a copy of ``state_dict`` with any leading ``'module.'`` removed from its keys.

    Checkpoints saved from a model wrapped in ``nn.DataParallel`` /
    ``DistributedDataParallel`` prefix every entry name with ``'module.'``;
    stripping it lets the weights load into the unwrapped model. Keys that do
    not start with the prefix are copied unchanged, and entry order is kept.

    If the input carries the ``_metadata`` attribute that
    ``nn.Module.state_dict()`` attaches (per-submodule version info read by
    ``load_state_dict``), it is copied onto the result with its keys renamed
    the same way instead of being silently dropped.

    Args:
        state_dict: Mapping from parameter/buffer names to values (tensors).

    Returns:
        A new ``OrderedDict``; the input mapping is not modified.
    """
    prefix = 'module.'
    new_state_dict = OrderedDict(
        (key[len(prefix):] if key.startswith(prefix) else key, value)
        for key, value in state_dict.items()
    )

    metadata = getattr(state_dict, '_metadata', None)
    if metadata is not None:
        # Metadata keys are module paths without a trailing dot: the wrapped
        # module's root is recorded as 'module' and its children as
        # 'module.<path>'. Copy the unprefixed entries first, then let the
        # stripped entries overwrite them, so the wrapped module's own info
        # wins over the wrapper's root ('') entry.
        new_metadata = OrderedDict(
            (key, value)
            for key, value in metadata.items()
            if key != 'module' and not key.startswith(prefix)
        )
        for key, value in metadata.items():
            if key == 'module':
                new_metadata[''] = value
            elif key.startswith(prefix):
                new_metadata[key[len(prefix):]] = value
        new_state_dict._metadata = new_metadata

    return new_state_dict
def _load_checkpoint_into(model, path, device):
    """Load the state dict stored at ``path`` into ``model``, move it to ``device`` and set eval mode.

    Any ``'module.'`` key prefix is stripped first (see ``remove_module_prefix``).
    ``load_state_dict`` runs with its default ``strict=True``, so missing or
    unexpected keys raise.

    Returns:
        ``model`` itself (``Module.to`` and ``Module.eval`` return the module).
    """
    # NOTE(review): torch.load unpickles the file, which can execute arbitrary
    # code. Only point it at trusted checkpoints, or consider weights_only=True
    # on a torch version that supports it.
    state_dict = torch.load(path, map_location=device)
    model.load_state_dict(remove_module_prefix(state_dict))
    return model.to(device).eval()


def load_all_models(config, epoch, device):
    """Build every model the Streamlit app needs, load its weights and return it in eval mode.

    Args:
        config: Nested config mapping. Keys read here:
            ``config['model']``: ``nz``, ``ngf``, ``ndf``, ``num_landmarks``,
            ``landmark_feature_size``;
            ``config['data']``: ``nc``;
            ``config['paths']['outputs']``: checkpoint directory per GAN part,
            keyed ``'G'``, ``'E'`` and ``'LE'``;
            ``config['paths']['inputs']['landmark_predictor_model']``: path of
            the landmark predictor checkpoint.
        epoch: Epoch tag of the GAN checkpoints. Files are expected at
            ``<outputs dir>/{g,e,le}_epoch_{epoch}.pth``.
        device: Target device, passed to ``torch.load(map_location=...)`` and
            ``Module.to``.

    Returns:
        dict with keys ``'netG'``, ``'netE'``, ``'landmark_encoder'`` and
        ``'landmark_predictor'``, each model on ``device`` in eval mode.

    Raises:
        Whatever ``torch.load`` / ``load_state_dict`` raise, e.g. for a missing
        checkpoint file or mismatched state-dict keys (loading is strict).
    """
    model_cfg = config['model']
    data_cfg = config['data']
    paths_cfg = config['paths']

    # --- 1. Load the main GAN models (G, E, LE) ---
    gan_models = {
        'G': Generator(nz=model_cfg['nz'], ngf=model_cfg['ngf'], nc=data_cfg['nc'], landmark_feature_size=model_cfg['landmark_feature_size']),
        'E': Encoder(nc=data_cfg['nc'], ndf=model_cfg['ndf'], nz=model_cfg['nz'], num_landmarks=model_cfg['num_landmarks']),
        # presumably two coordinates (x, y) per landmark -- TODO confirm
        'LE': LandmarkEncoder(input_dim=model_cfg['num_landmarks'] * 2, output_dim=model_cfg['landmark_feature_size'])
    }
    for name, model in gan_models.items():
        model_path = os.path.join(paths_cfg['outputs'][name], f'{name.lower()}_epoch_{epoch}.pth')
        print(f"Loading {name} model from: {model_path}")
        _load_checkpoint_into(model, model_path, device)

    # --- 2. Load the separate Landmark Predictor model ---
    lp_path = paths_cfg['inputs']['landmark_predictor_model']
    print(f"Loading Landmark Predictor from: {lp_path}")
    landmark_predictor = _load_checkpoint_into(OcularLMGenerator(), lp_path, device)

    # Return all models in a dictionary for easy access
    return {
        'netG': gan_models['G'],
        'netE': gan_models['E'],
        'landmark_encoder': gan_models['LE'],
        'landmark_predictor': landmark_predictor
    }
def load_image(file):
    """Open an image and return it as a fully loaded RGB ``PIL.Image.Image``.

    Args:
        file: Anything ``PIL.Image.open`` accepts, such as a filesystem path or
            a binary file-like object (e.g. an uploaded file).

    Returns:
        A new RGB image. ``convert`` returns a copy, so the result does not
        depend on the source image staying open.

    Raises:
        ValueError: If the image cannot be opened or converted, for any reason.
            The underlying exception is chained as ``__cause__``.
    """
    try:
        # The context manager closes a file handle that Pillow opened itself
        # (path input) even for multi-frame formats, where it would otherwise
        # stay open after the data is loaded.
        with Image.open(file) as image:
            return image.convert('RGB')
    except Exception as e:
        # Deliberately broad: callers get a single ValueError for any failure.
        raise ValueError(f"Error loading image: {str(e)}") from e