import os
from collections import OrderedDict

import torch
from PIL import Image

# Local project imports
from models import Generator, Encoder, LandmarkEncoder
from models.landmark_predictor import OcularLMGenerator


def remove_module_prefix(state_dict):
    """Removes the 'module.' prefix from state dict keys if it exists."""
    new_state_dict = OrderedDict()
    for k, v in state_dict.items():
        name = k[7:] if k.startswith('module.') else k
        new_state_dict[name] = v
    return new_state_dict
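
# Example (hypothetical key names, for illustration only): a checkpoint saved
# from a model wrapped in nn.DataParallel stores parameters as
# 'module.conv1.weight'; remove_module_prefix() renames them to 'conv1.weight'
# so they match a model instantiated without the DataParallel wrapper.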


def load_all_models(config, epoch, device):
    """Loads and initializes all models needed for the Streamlit app."""
    model_cfg = config['model']
    data_cfg = config['data']
    paths_cfg = config['paths']

    # --- 1. Load the main GAN models (G, E, LE) ---
    gan_models = {
        'G': Generator(nz=model_cfg['nz'], ngf=model_cfg['ngf'], nc=data_cfg['nc'],
                       landmark_feature_size=model_cfg['landmark_feature_size']),
        'E': Encoder(nc=data_cfg['nc'], ndf=model_cfg['ndf'], nz=model_cfg['nz'],
                     num_landmarks=model_cfg['num_landmarks']),
        'LE': LandmarkEncoder(input_dim=model_cfg['num_landmarks'] * 2,
                              output_dim=model_cfg['landmark_feature_size']),
    }

    for name, model in gan_models.items():
        model_path = os.path.join(paths_cfg['outputs'][name], f'{name.lower()}_epoch_{epoch}.pth')
        print(f"Loading {name} model from: {model_path}")
        state_dict = torch.load(model_path, map_location=device)
        state_dict = remove_module_prefix(state_dict)
        model.load_state_dict(state_dict)
        model.to(device).eval()

    # --- 2. Load the separate Landmark Predictor model ---
    lp_path = paths_cfg['inputs']['landmark_predictor_model']
    print(f"Loading Landmark Predictor from: {lp_path}")
    landmark_predictor = OcularLMGenerator().to(device)
    state_dict_lp = torch.load(lp_path, map_location=device)
    state_dict_lp = remove_module_prefix(state_dict_lp)
    landmark_predictor.load_state_dict(state_dict_lp)
    landmark_predictor.eval()

    # Return all models in a dictionary for easy access
    return {
        'netG': gan_models['G'],
        'netE': gan_models['E'],
        'landmark_encoder': gan_models['LE'],
        'landmark_predictor': landmark_predictor,
    }
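
# Assumed config layout (keys inferred from the accesses in load_all_models;
# the actual values live in the project's config file, which is not shown here):
#
#   config = {
#       'model': {'nz': ..., 'ngf': ..., 'ndf': ..., 'num_landmarks': ...,
#                 'landmark_feature_size': ...},
#       'data': {'nc': ...},
#       'paths': {
#           'outputs': {'G': ..., 'E': ..., 'LE': ...},   # checkpoint directories
#           'inputs': {'landmark_predictor_model': ...},  # path to the predictor .pth
#       },
#   }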


def load_image(file):
    """Safely loads an image file into a PIL Image object."""
    try:
        image = Image.open(file).convert('RGB')
        return image
    except Exception as e:
        raise ValueError(f"Error loading image: {str(e)}")
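

if __name__ == '__main__':
    # Minimal usage sketch, not part of the Streamlit app itself. It assumes a
    # YAML config matching the layout sketched above; 'config.yaml' and
    # epoch 100 are placeholder values, not project defaults.
    import yaml

    with open('config.yaml', 'r') as f:
        cfg = yaml.safe_load(f)

    dev = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    models = load_all_models(cfg, epoch=100, device=dev)
    print({key: type(m).__name__ for key, m in models.items()})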