# DOOMGAN / app.py
# Last commit by BharathK333: "Update app.py" (7ee35f0, verified)
import gradio as gr
import torch
import yaml
import os
from huggingface_hub import hf_hub_download, list_repo_files
from collections import OrderedDict
# --- Local project imports ---
from apps.gan_morpher import morph_images_with_gan
from models.models import Generator, Encoder, LandmarkEncoder
from models.landmark_predictor import OcularLMGenerator
# --- 1. Define Constants and Configuration ---
# The Hub repository and checkpoint epoch can be overridden through the
# DOOMGAN_MODEL_REPO_ID / DOOMGAN_EPOCH environment variables, so a different
# repo or epoch can be tried without editing this file. The defaults are the
# values this app was published with.
MODEL_REPO_ID = os.environ.get("DOOMGAN_MODEL_REPO_ID", "BharathK333/DOOMGAN")
# Use the GPU when PyTorch can see one, otherwise fall back to the CPU.
DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")
# The epoch for the models we want to load.
EPOCH = int(os.environ.get("DOOMGAN_EPOCH", "450"))
print("--- Initializing Gradio App: Downloading and Loading Models ---")
# --- 2. Global variables to store loaded models ---
# Filled in by load_models_from_hub(); None means "not loaded".
config = None
models = None
# --- 3. Function to Load Models from Hugging Face Hub ---
def _checkpoint_files():
    """Return ``(model key, label, repo filename)`` for every checkpoint to download.

    Built on each call so the filenames always reflect the current ``EPOCH``.
    The same list drives both the downloads and the failure diagnostics.
    """
    return [
        ("netG", "generator", f"generator_models/g_epoch_{EPOCH}.pth"),
        ("netE", "encoder", f"encoder_models/e_epoch_{EPOCH}.pth"),
        ("landmark_encoder", "landmark encoder", f"landmark_encoder_models/le_epoch_{EPOCH}.pth"),
        ("landmark_predictor", "landmark predictor", "trained_models/Ocular_LM_Generator.pth"),
    ]


def _strip_module_prefix(state_dict):
    """Return a copy of *state_dict* with a leading ``'module.'`` removed from each key.

    ``torch.nn.DataParallel`` adds this prefix to parameter names. Checkpoints
    saved from a wrapped model would otherwise not match the bare modules built
    here. Keys without the prefix are kept unchanged.
    """
    prefix = "module."
    return OrderedDict(
        (key[len(prefix):] if key.startswith(prefix) else key, value)
        for key, value in state_dict.items()
    )


def _load_checkpoint(model, path):
    """Load the state dict stored at *path* into *model*; return it on DEVICE in eval mode."""
    # NOTE(review): torch.load unpickles the file. The checkpoints come from the
    # project's own Hub repo. If that ever becomes untrusted input, pass
    # weights_only=True (supported by recent PyTorch releases) to refuse
    # arbitrary pickled objects.
    state_dict = torch.load(path, map_location=DEVICE)
    model.load_state_dict(_strip_module_prefix(state_dict))
    return model.to(DEVICE).eval()


def load_models_from_hub():
    """Download the DOOMGAN checkpoints from the Hugging Face Hub and build the models.

    The result is cached in the module globals ``config`` and ``models``. Once
    both are set, later calls return them without downloading anything.

    Returns:
        ``(config, models)`` on success. ``config`` is the YAML parsed from the
        local ``config/config.yaml``. ``models`` maps ``'netG'``, ``'netE'``,
        ``'landmark_encoder'`` and ``'landmark_predictor'`` to modules moved to
        ``DEVICE`` and switched to eval mode.

        ``(None, None)`` if any step fails. In that case the error, its
        traceback and the list of expected files are printed, and the module
        globals are left unchanged.
    """
    import traceback

    global config, models
    # Reuse the cached result once both globals have been populated.
    if config is not None and models is not None:
        return config, models
    try:
        # --- Download model files from the model repo ---
        paths = {}
        for key, label, filename in _checkpoint_files():
            print(f"Downloading {label} model...")
            paths[key] = hf_hub_download(repo_id=MODEL_REPO_ID, filename=filename)

        # The config comes from the Space's own files, not from the Hub. The
        # relative path is resolved against the current working directory.
        print("Loading configuration from local file...")
        with open("config/config.yaml", "r", encoding="utf-8") as f:
            loaded_config = yaml.safe_load(f)
        model_cfg = loaded_config["model"]
        data_cfg = loaded_config["data"]

        # --- Initialize and load the GAN models (G, E, LE) ---
        print("Initializing models...")
        gan_models = {
            "netG": Generator(
                nz=model_cfg["nz"],
                ngf=model_cfg["ngf"],
                nc=data_cfg["nc"],
                landmark_feature_size=model_cfg["landmark_feature_size"],
            ),
            "netE": Encoder(
                nc=data_cfg["nc"],
                ndf=model_cfg["ndf"],
                nz=model_cfg["nz"],
                num_landmarks=model_cfg["num_landmarks"],
            ),
            # Landmarks are fed as flattened coordinate pairs, hence the * 2.
            "landmark_encoder": LandmarkEncoder(
                input_dim=model_cfg["num_landmarks"] * 2,
                output_dim=model_cfg["landmark_feature_size"],
            ),
        }
        loaded_models = {}
        for name, model in gan_models.items():
            print(f"Loading {name} model...")
            loaded_models[name] = _load_checkpoint(model, paths[name])

        # --- Initialize and load the landmark predictor ---
        print("Loading Landmark Predictor model...")
        loaded_models["landmark_predictor"] = _load_checkpoint(
            OcularLMGenerator(), paths["landmark_predictor"]
        )
    except Exception as e:
        # Start-up boundary: report everything that helps diagnose the failure,
        # then signal it with (None, None) instead of crashing the app.
        print(f"Error loading models: {e}")
        traceback.print_exc()
        print("This could be due to:")
        print("1. Repository not found or private")
        print("2. Model files not available at expected locations")
        print("3. Network connectivity issues")
        print("4. Authentication required")
        print("5. Config file not found")
        print("\nExpected file paths:")
        for _, _, filename in _checkpoint_files():
            print(f" - {filename}")
        print(" - config/config.yaml")
        return None, None

    # Publish both globals together and only on success. A failure part-way
    # through can then never leave `config` set while `models` is still None.
    config, models = loaded_config, loaded_models
    print(f"--- All models loaded successfully on {DEVICE} ---")
    return config, models
# --- 4. Function to check if models are loaded ---
def check_models_loaded():
    """Report whether the app is ready to morph images.

    Returns:
        ``True`` only when neither module global ``config`` nor ``models``
        is ``None``; ``False`` otherwise.
    """
    return not (config is None or models is None)
# Try to load models (but don't fail if they're not available).
# Runs once at import time. Any failure is reported rather than raised, so the
# Gradio app still starts and can show an error state.
try:
    config, models = load_models_from_hub()
except Exception as exc:
    print(f"CRITICAL ERROR: Failed to load models: {exc}")
    config, models = None, None
else:
    if not check_models_loaded():
        print("WARNING: Models could not be loaded. App will show error messages.")
# --- 5. Define the core processing function for Gradio ---
def run_gan_morph(image1, image2, alpha):
    """Gradio handler: morph *image1* and *image2* with the loaded GAN.

    Args:
        image1: First source image from the UI. Presumably a PIL image, since
            the inputs are declared as ``gr.Image(type="pil")``.
        image2: Second source image, in the same form as *image1*.
        alpha: Interpolation factor from the slider. It is passed unchanged
            to ``morph_images_with_gan``.

    Returns:
        The morphed image exactly as returned by ``morph_images_with_gan``.

    Raises:
        gr.Error: If the models are not loaded, either image is missing, or the
            morph itself raises. Gradio surfaces the message to the user.
    """
    import traceback

    if not check_models_loaded():
        raise gr.Error("Models are not loaded. Please check the repository and model files.")
    if image1 is None or image2 is None:
        raise gr.Error("Please upload both source images to generate a morph.")
    print(f"Performing GAN morph with alpha={alpha}...")
    try:
        morphed_image_numpy = morph_images_with_gan(image1, image2, config, DEVICE, models, alpha)
    except Exception as e:
        # The user only sees the short message, so keep the full traceback in
        # the server log. `from e` preserves the original cause on the chain.
        traceback.print_exc()
        raise gr.Error(f"Error during morphing: {e}") from e
    print("GAN morph complete.")
    return morphed_image_numpy
# --- 6. Create a function to show model status ---
def get_model_status():
    """Return a one-line Markdown status message saying whether the models loaded."""
    ready = check_models_loaded()
    return (
        "✅ Models loaded successfully"
        if ready
        else "❌ Models failed to load. Check console for details."
    )
# --- 7. Build the Gradio Interface ---
# Intro text for the top of the page. It is kept at module level so the
# contents of the triple-quoted literal do not pick up the indentation of the
# layout code below.
_INTRO_MARKDOWN = """
# DOOMGAN: High-Fidelity Ocular Image Morphing
An interactive demonstration of the IJCB-accepted **DOOMGAN** project.
Upload two ocular images, or use the examples below, and use the slider to morph between them.
"""

# Example rows, in input order: [source image 1, source image 2, interpolation factor].
_EXAMPLE_ROWS = [
    ["assets/1144_r_1.png", "assets/1147_r_1.png", 0.5],
    ["assets/1162_r_9.png", "assets/1163_r_17.png", 0.4],
    ["assets/1172_l_1.png", "assets/1177_l_1.png", 0.6],
    ["assets/2517_r_9.png", "assets/3243_l_1.png", 0.7],
]

# Gradio lays components out in creation order, so the statements inside the
# `with` blocks must stay in this sequence.
with gr.Blocks(theme=gr.themes.Soft(), title="DOOMGAN Morphing") as demo:
    gr.Markdown(_INTRO_MARKDOWN)

    # Model status indicator, computed once when the page is built.
    status_text = gr.Markdown(get_model_status())

    # NOTE(review): the pasted source lost its indentation. This layout assumes
    # only the two source images share the row; confirm against the deployed app.
    with gr.Row():
        img1 = gr.Image(type="pil", label="Source Image 1")
        img2 = gr.Image(type="pil", label="Source Image 2")

    alpha_slider = gr.Slider(
        minimum=0.0,
        maximum=1.0,
        value=0.5,
        step=0.05,
        label="Interpolation Factor (Image 1 <-> Image 2)",
        info="Slide towards 0 to resemble Image 1, or towards 1 to resemble Image 2.",
    )
    output_img = gr.Image(type="pil", label="Morphed Result")
    run_button = gr.Button("Generate Morph", variant="primary")

    # Examples are only offered (and their results cached) when the models
    # are available; otherwise a notice is shown in their place.
    if not check_models_loaded():
        gr.Markdown("⚠️ **Examples disabled**: Models could not be loaded.")
    else:
        gr.Examples(
            examples=_EXAMPLE_ROWS,
            inputs=[img1, img2, alpha_slider],
            outputs=output_img,
            fn=run_gan_morph,
            cache_examples=True,  # Caches the results for instant loading
        )

    # --- 8. Connect the UI components to the function ---
    run_button.click(
        fn=run_gan_morph,
        inputs=[img1, img2, alpha_slider],
        outputs=[output_img],
        api_name="morph",
    )

# Start the server only when executed as a script, not when imported.
if __name__ == "__main__":
    demo.launch()