import gradio as gr
import torch
import yaml
from huggingface_hub import hf_hub_download
from collections import OrderedDict

# --- Local project imports ---
from apps.gan_morpher import morph_images_with_gan
from models.models import Generator, Encoder, LandmarkEncoder
from models.landmark_predictor import OcularLMGenerator
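
# Note on the local modules (roles inferred from how they are used below):
#   - morph_images_with_gan: runs the end-to-end morph between two images
#   - Generator / Encoder / LandmarkEncoder: the trained DOOMGAN networks
#   - OcularLMGenerator: the ocular landmark predictor
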
# --- 1. Define Constants and Configuration ---
MODEL_REPO_ID = "BharathK333/DOOMGAN"
DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")
EPOCH = 450  # The epoch for the models we want to load

print("--- Initializing Gradio App: Downloading and Loading Models ---")

# --- 2. Global variables to store loaded models ---
config = None
models = None

# --- 3. Function to Load Models from Hugging Face Hub ---
def load_models_from_hub():
    """
    Download all required files from the Hugging Face Hub and load the models.
    The checkpoint filenames follow the directory layout of the model repo.
    """
    global config, models

    # Reuse the cached objects if the models are already loaded
    if config is not None and models is not None:
        return config, models

    try:
        # --- Download the model checkpoints from the model repo ---
        print("Downloading generator model...")
        g_path = hf_hub_download(repo_id=MODEL_REPO_ID, filename=f"generator_models/g_epoch_{EPOCH}.pth")
        print("Downloading encoder model...")
        e_path = hf_hub_download(repo_id=MODEL_REPO_ID, filename=f"encoder_models/e_epoch_{EPOCH}.pth")
        print("Downloading landmark encoder model...")
        le_path = hf_hub_download(repo_id=MODEL_REPO_ID, filename=f"landmark_encoder_models/le_epoch_{EPOCH}.pth")
        print("Downloading landmark predictor model...")
        lp_path = hf_hub_download(repo_id=MODEL_REPO_ID, filename="trained_models/Ocular_LM_Generator.pth")

        # Load the config file shipped locally with the Space
        print("Loading configuration from local file...")
        with open('config/config.yaml', 'r') as f:
            config = yaml.safe_load(f)
        model_cfg = config['model']
        data_cfg = config['data']

        # Helper function to remove 'module.' prefix from state dict keys
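        # (checkpoints saved from nn.DataParallel-wrapped models prefix every
        # key with 'module.'; stripping it lets the weights load into plain
        # single-device model instances)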
        def remove_module_prefix(state_dict):
            new_state_dict = OrderedDict()
            for k, v in state_dict.items():
                name = k[7:] if k.startswith('module.') else k
                new_state_dict[name] = v
            return new_state_dict

        # --- Initialize the GAN models (G, E, LE) with hyperparameters from the config ---
        print("Initializing models...")
        gan_models_init = {
            'netG': Generator(nz=model_cfg['nz'], ngf=model_cfg['ngf'], nc=data_cfg['nc'],
                              landmark_feature_size=model_cfg['landmark_feature_size']),
            'netE': Encoder(nc=data_cfg['nc'], ndf=model_cfg['ndf'], nz=model_cfg['nz'],
                            num_landmarks=model_cfg['num_landmarks']),
            'landmark_encoder': LandmarkEncoder(input_dim=model_cfg['num_landmarks'] * 2,
                                                output_dim=model_cfg['landmark_feature_size']),
        }
        model_paths = {'netG': g_path, 'netE': e_path, 'landmark_encoder': le_path}
        loaded_models = {}
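
        # Load each checkpoint onto DEVICE and switch the model to eval mode
        # (disables dropout and freezes batch-norm statistics for inference)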
        for name, model in gan_models_init.items():
            print(f"Loading {name} model...")
            state_dict = torch.load(model_paths[name], map_location=DEVICE)
            state_dict = remove_module_prefix(state_dict)
            model.load_state_dict(state_dict)
            model.to(DEVICE).eval()
            loaded_models[name] = model

        # --- Initialize and load the landmark predictor ---
        print("Loading Landmark Predictor model...")
        landmark_predictor = OcularLMGenerator().to(DEVICE)
        state_dict_lp = torch.load(lp_path, map_location=DEVICE)
        state_dict_lp = remove_module_prefix(state_dict_lp)
        landmark_predictor.load_state_dict(state_dict_lp)
        landmark_predictor.eval()
        loaded_models['landmark_predictor'] = landmark_predictor

        models = loaded_models
        print(f"--- All models loaded successfully on {DEVICE} ---")
        return config, models
    except Exception as e:
        print(f"Error loading models: {e}")
        print("This could be due to:")
        print("1. Repository not found or private")
        print("2. Model files not available at expected locations")
        print("3. Network connectivity issues")
        print("4. Authentication required")
        print("5. Config file not found")
        # Print the expected file paths for debugging
        print("\nExpected file paths:")
        print(f"  - generator_models/g_epoch_{EPOCH}.pth")
        print(f"  - encoder_models/e_epoch_{EPOCH}.pth")
        print(f"  - landmark_encoder_models/le_epoch_{EPOCH}.pth")
        print("  - trained_models/Ocular_LM_Generator.pth")
        print("  - config/config.yaml")
        # Return None to signal failure to the caller
        return None, None


# --- 4. Function to check if models are loaded ---
def check_models_loaded():
    """Return True if both the config and the models were loaded successfully."""
    return config is not None and models is not None


# Try to load the models at startup (but don't crash the app if they're unavailable)
try:
    config, models = load_models_from_hub()
    if not check_models_loaded():
        print("WARNING: Models could not be loaded. App will show error messages.")
except Exception as e:
    print(f"CRITICAL ERROR: Failed to load models: {e}")
    config, models = None, None


# --- 5. Define the core processing function for Gradio ---
def run_gan_morph(image1, image2, alpha):
    if not check_models_loaded():
        raise gr.Error("Models are not loaded. Please check the repository and model files.")
    if image1 is None or image2 is None:
        raise gr.Error("Please upload both source images to generate a morph.")
print(f"Performing GAN morph with alpha={alpha}...")
try:
morphed_image_numpy = morph_images_with_gan(image1, image2, config, DEVICE, models, alpha)
print("GAN morph complete.")
return morphed_image_numpy
except Exception as e:
raise gr.Error(f"Error during morphing: {str(e)}")


# --- 6. Create a function to show model status ---
def get_model_status():
    if check_models_loaded():
        return "✅ Models loaded successfully"
    else:
        return "❌ Models failed to load. Check console for details."


# --- 7. Build the Gradio Interface ---
with gr.Blocks(theme=gr.themes.Soft(), title="DOOMGAN Morphing") as demo:
    gr.Markdown(
        """
        # DOOMGAN: High-Fidelity Ocular Image Morphing
        An interactive demonstration of the IJCB-accepted **DOOMGAN** project.
        Upload two ocular images, or use the examples below, and move the slider to morph between them.
        """
    )
    # Add model status indicator
    status_text = gr.Markdown(get_model_status())

    with gr.Row():
        img1 = gr.Image(type="pil", label="Source Image 1")
        img2 = gr.Image(type="pil", label="Source Image 2")

    alpha_slider = gr.Slider(
        minimum=0.0, maximum=1.0, value=0.5, step=0.05,
        label="Interpolation Factor (Image 1 <-> Image 2)",
        info="Slide towards 0 to resemble Image 1, or towards 1 to resemble Image 2.",
    )
    output_img = gr.Image(type="pil", label="Morphed Result")
    run_button = gr.Button("Generate Morph", variant="primary")

    # Only show examples if models are loaded
    if check_models_loaded():
        gr.Examples(
            examples=[
                ["assets/1144_r_1.png", "assets/1147_r_1.png", 0.5],
                ["assets/1162_r_9.png", "assets/1163_r_17.png", 0.4],
                ["assets/1172_l_1.png", "assets/1177_l_1.png", 0.6],
                ["assets/2517_r_9.png", "assets/3243_l_1.png", 0.7],
            ],
            inputs=[img1, img2, alpha_slider],
            outputs=output_img,
            fn=run_gan_morph,
            cache_examples=True,  # Caches the results for instant loading
        )
    else:
        gr.Markdown("⚠️ **Examples disabled**: Models could not be loaded.")

    # --- 8. Connect the UI components to the function ---
    run_button.click(
        fn=run_gan_morph,
        inputs=[img1, img2, alpha_slider],
        outputs=[output_img],
        api_name="morph",
    )
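    # api_name="morph" also registers this event as a named API endpoint, so
    # the morph can be triggered programmatically via the Gradio client.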


if __name__ == "__main__":
    demo.launch()
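
# A minimal sketch of calling the endpoint programmatically with gradio_client
# (the Space id below is an assumption, and exact file handling depends on the
# installed gradio_client version):
#
#   from gradio_client import Client, handle_file
#   client = Client("BharathK333/DOOMGAN")  # hypothetical Space id
#   result = client.predict(handle_file("eye1.png"), handle_file("eye2.png"),
#                           0.5, api_name="/morph")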