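"""Gradio demo for DOOMGAN ocular image morphing.

Downloads the pretrained generator, encoder, landmark encoder, and landmark
predictor checkpoints from the Hugging Face Hub, then exposes an interface
for morphing two ocular images with an adjustable interpolation factor.
"""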
import gradio as gr
import torch
import yaml
import os
from huggingface_hub import hf_hub_download, list_repo_files
from collections import OrderedDict

# --- Local project imports ---
from apps.gan_morpher import morph_images_with_gan
from models.models import Generator, Encoder, LandmarkEncoder
from models.landmark_predictor import OcularLMGenerator

# --- 1. Define Constants and Configuration ---
MODEL_REPO_ID = "BharathK333/DOOMGAN"
DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")
EPOCH = 450  # The epoch of the checkpoints we want to load

print("--- Initializing Gradio App: Downloading and Loading Models ---")

# --- 2. Global variables to store loaded models ---
config = None
models = None
| # --- 3. Function to Load Models from Hugging Face Hub --- | |
| def load_models_from_hub(): | |
| """ | |
| Downloads all necessary files from the Hugging Face Hub and loads the models. | |
| Based on the actual directory structure provided. | |
| """ | |
| global config, models | |
| # Check if models are already loaded | |
| if config is not None and models is not None: | |
| return config, models | |
| try: | |
| # --- Download Model Files from the Model Repo (based on actual structure) --- | |
| print(f"Downloading generator model...") | |
| g_path = hf_hub_download(repo_id=MODEL_REPO_ID, filename=f"generator_models/g_epoch_{EPOCH}.pth") | |
| print(f"Downloading encoder model...") | |
| e_path = hf_hub_download(repo_id=MODEL_REPO_ID, filename=f"encoder_models/e_epoch_{EPOCH}.pth") | |
| print(f"Downloading landmark encoder model...") | |
| le_path = hf_hub_download(repo_id=MODEL_REPO_ID, filename=f"landmark_encoder_models/le_epoch_{EPOCH}.pth") | |
| print(f"Downloading landmark predictor model...") | |
| lp_path = hf_hub_download(repo_id=MODEL_REPO_ID, filename="trained_models/Ocular_LM_Generator.pth") | |
| # Load config file (use local file from Space) | |
| print("Loading configuration from local file...") | |
| with open('config/config.yaml', 'r') as f: | |
| config = yaml.safe_load(f) | |
| model_cfg = config['model'] | |
| data_cfg = config['data'] | |
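        # Illustrative shape of config/config.yaml: the keys are the ones read
        # below; the values here are placeholders, not the repository's settings.
        #   model:
        #     nz: 100
        #     ngf: 64
        #     ndf: 64
        #     num_landmarks: 16
        #     landmark_feature_size: 64
        #   data:
        #     nc: 3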
        # Helper function to remove 'module.' prefix from state dict keys
        def remove_module_prefix(state_dict):
            new_state_dict = OrderedDict()
            for k, v in state_dict.items():
                name = k[7:] if k.startswith('module.') else k
                new_state_dict[name] = v
            return new_state_dict

        # --- Initialize and Load the specific GAN Models (G, E, LE) ---
        print("Initializing models...")
        gan_models_init = {
            'netG': Generator(nz=model_cfg['nz'], ngf=model_cfg['ngf'], nc=data_cfg['nc'],
                              landmark_feature_size=model_cfg['landmark_feature_size']),
            'netE': Encoder(nc=data_cfg['nc'], ndf=model_cfg['ndf'], nz=model_cfg['nz'],
                            num_landmarks=model_cfg['num_landmarks']),
            'landmark_encoder': LandmarkEncoder(input_dim=model_cfg['num_landmarks'] * 2,
                                                output_dim=model_cfg['landmark_feature_size'])
        }
        model_paths = {'netG': g_path, 'netE': e_path, 'landmark_encoder': le_path}

        loaded_models = {}
        for name, model in gan_models_init.items():
            print(f"Loading {name} model...")
            state_dict = torch.load(model_paths[name], map_location=DEVICE)
            state_dict = remove_module_prefix(state_dict)
            model.load_state_dict(state_dict)
            model.to(DEVICE).eval()
            loaded_models[name] = model

        # --- Initialize and Load Landmark Predictor ---
        print("Loading Landmark Predictor model...")
        landmark_predictor = OcularLMGenerator().to(DEVICE)
        state_dict_lp = torch.load(lp_path, map_location=DEVICE)
        state_dict_lp = remove_module_prefix(state_dict_lp)
        landmark_predictor.load_state_dict(state_dict_lp)
        landmark_predictor.eval()
        loaded_models['landmark_predictor'] = landmark_predictor

        models = loaded_models
        print(f"--- All models loaded successfully on {DEVICE} ---")
        return config, models

    except Exception as e:
        print(f"Error loading models: {e}")
        print("This could be due to:")
        print("1. Repository not found or private")
        print("2. Model files not available at expected locations")
        print("3. Network connectivity issues")
        print("4. Authentication required")
        print("5. Config file not found")
        # Print expected file paths for debugging
        print("\nExpected file paths:")
        print(f" - generator_models/g_epoch_{EPOCH}.pth")
        print(f" - encoder_models/e_epoch_{EPOCH}.pth")
        print(f" - landmark_encoder_models/le_epoch_{EPOCH}.pth")
        print(" - trained_models/Ocular_LM_Generator.pth")
        print(" - config/config.yaml")
        # Return None to indicate failure
        return None, None

# --- 4. Function to check if models are loaded ---
def check_models_loaded():
    """Check if models are properly loaded."""
    return config is not None and models is not None


# Try to load models (but don't fail if they're not available)
try:
    config, models = load_models_from_hub()
    if not check_models_loaded():
        print("WARNING: Models could not be loaded. App will show error messages.")
except Exception as e:
    print(f"CRITICAL ERROR: Failed to load models: {e}")
    config, models = None, None

# --- 5. Define the core processing function for Gradio ---
def run_gan_morph(image1, image2, alpha):
    if not check_models_loaded():
        raise gr.Error("Models are not loaded. Please check the repository and model files.")
    if image1 is None or image2 is None:
        raise gr.Error("Please upload both source images to generate a morph.")
    print(f"Performing GAN morph with alpha={alpha}...")
    try:
        morphed_image_numpy = morph_images_with_gan(image1, image2, config, DEVICE, models, alpha)
        print("GAN morph complete.")
        return morphed_image_numpy
    except Exception as e:
        raise gr.Error(f"Error during morphing: {str(e)}")


# --- 6. Create a function to show model status ---
def get_model_status():
    if check_models_loaded():
        return "✅ Models loaded successfully"
    else:
        return "❌ Models failed to load. Check console for details."

# --- 7. Build the Gradio Interface ---
with gr.Blocks(theme=gr.themes.Soft(), title="DOOMGAN Morphing") as demo:
    gr.Markdown(
        """
        # DOOMGAN: High-Fidelity Ocular Image Morphing
        An interactive demonstration of the IJCB-accepted **DOOMGAN** project.
        Upload two ocular images, or use the examples below, and move the slider to morph between them.
        """
    )

    # Add model status indicator
    status_text = gr.Markdown(get_model_status())

    with gr.Row():
        img1 = gr.Image(type="pil", label="Source Image 1")
        img2 = gr.Image(type="pil", label="Source Image 2")

    alpha_slider = gr.Slider(
        minimum=0.0, maximum=1.0, value=0.5, step=0.05,
        label="Interpolation Factor (Image 1 <-> Image 2)",
        info="Slide towards 0 to resemble Image 1, or towards 1 to resemble Image 2."
    )
    output_img = gr.Image(type="pil", label="Morphed Result")
    run_button = gr.Button("Generate Morph", variant="primary")

    # Only show examples if models are loaded
    if check_models_loaded():
        gr.Examples(
            examples=[
                ["assets/1144_r_1.png", "assets/1147_r_1.png", 0.5],
                ["assets/1162_r_9.png", "assets/1163_r_17.png", 0.4],
                ["assets/1172_l_1.png", "assets/1177_l_1.png", 0.6],
                ["assets/2517_r_9.png", "assets/3243_l_1.png", 0.7],
            ],
            inputs=[img1, img2, alpha_slider],
            outputs=output_img,
            fn=run_gan_morph,
            cache_examples=True  # Caches the results for instant loading
        )
    else:
        gr.Markdown("⚠️ **Examples disabled**: Models could not be loaded.")

    # --- 8. Connect the UI components to the function ---
    run_button.click(
        fn=run_gan_morph,
        inputs=[img1, img2, alpha_slider],
        outputs=[output_img],
        api_name="morph"
    )

if __name__ == "__main__":
    demo.launch()
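

# --- 9. Optional: querying the hosted demo programmatically ---
# Minimal sketch of calling this app's "morph" endpoint with gradio_client
# (installed alongside recent gradio releases). The Space ID below is a
# placeholder assumption, not a value taken from this file; the example image
# paths reuse the bundled assets. This helper is illustrative only and is
# never called by the app.
def example_remote_morph():
    from gradio_client import Client, handle_file

    client = Client("BharathK333/DOOMGAN")  # hypothetical Space ID
    result = client.predict(
        handle_file("assets/1144_r_1.png"),  # Source Image 1
        handle_file("assets/1147_r_1.png"),  # Source Image 2
        0.5,                                 # Interpolation factor
        api_name="/morph",
    )
    print("Morphed image saved to:", result)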