import gradio as gr
import torch
import yaml
from huggingface_hub import hf_hub_download
from collections import OrderedDict

# --- Local project imports ---
from apps.gan_morpher import morph_images_with_gan
from models.models import Generator, Encoder, LandmarkEncoder
from models.landmark_predictor import OcularLMGenerator
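# Note: apps/ and models/ are local packages that must ship in the Space
# repository alongside this script.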

# --- 1. Define Constants and Configuration ---
MODEL_REPO_ID = "BharathK333/DOOMGAN"
DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")
EPOCH = 450  # Checkpoint epoch to load for the generator, encoder, and landmark encoder

print("--- Initializing Gradio App: Downloading and Loading Models ---")

# --- 2. Global variables to store loaded models ---
config = None
models = None

# --- 3. Function to Load Models from Hugging Face Hub ---
def load_models_from_hub():
    """
    Downloads all necessary files from the Hugging Face Hub and loads the models.
    Based on the actual directory structure provided.
    """
    global config, models
    
    # Check if models are already loaded
    if config is not None and models is not None:
        return config, models
    
    try:
        # --- Download model checkpoints from the Hub model repo ---
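        # hf_hub_download caches each file locally (under ~/.cache/huggingface
        # by default), so subsequent launches reuse the cached checkpoints
        # instead of re-downloading them.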
        print(f"Downloading generator model...")
        g_path = hf_hub_download(repo_id=MODEL_REPO_ID, filename=f"generator_models/g_epoch_{EPOCH}.pth")
        
        print(f"Downloading encoder model...")
        e_path = hf_hub_download(repo_id=MODEL_REPO_ID, filename=f"encoder_models/e_epoch_{EPOCH}.pth")
        
        print(f"Downloading landmark encoder model...")
        le_path = hf_hub_download(repo_id=MODEL_REPO_ID, filename=f"landmark_encoder_models/le_epoch_{EPOCH}.pth")
        
        print(f"Downloading landmark predictor model...")
        lp_path = hf_hub_download(repo_id=MODEL_REPO_ID, filename="trained_models/Ocular_LM_Generator.pth")

        # Load config file (use local file from Space)
        print("Loading configuration from local file...")
        with open('config/config.yaml', 'r') as f:
            config = yaml.safe_load(f)
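        # The local YAML is expected to provide at least the keys read below,
        # roughly (values omitted; see config/config.yaml in the Space):
        #
        #   model:
        #     nz: ...
        #     ngf: ...
        #     ndf: ...
        #     num_landmarks: ...
        #     landmark_feature_size: ...
        #   data:
        #     nc: ...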
        
        model_cfg = config['model']
        data_cfg = config['data']

        # Helper to strip the 'module.' prefix that torch.nn.DataParallel adds
        # to state-dict keys during multi-GPU training
        def remove_module_prefix(state_dict):
            new_state_dict = OrderedDict()
            for k, v in state_dict.items():
                name = k[7:] if k.startswith('module.') else k
                new_state_dict[name] = v
            return new_state_dict

        # --- Initialize and Load the specific GAN Models (G, E, LE) ---
        print("Initializing models...")
        gan_models_init = {
            'netG': Generator(nz=model_cfg['nz'], ngf=model_cfg['ngf'], nc=data_cfg['nc'], landmark_feature_size=model_cfg['landmark_feature_size']),
            'netE': Encoder(nc=data_cfg['nc'], ndf=model_cfg['ndf'], nz=model_cfg['nz'], num_landmarks=model_cfg['num_landmarks']),
            'landmark_encoder': LandmarkEncoder(input_dim=model_cfg['num_landmarks'] * 2, output_dim=model_cfg['landmark_feature_size'])
        }
        
        model_paths = {'netG': g_path, 'netE': e_path, 'landmark_encoder': le_path}
        loaded_models = {}

        for name, model in gan_models_init.items():
            print(f"Loading {name} model...")
            # map_location lets CUDA-trained weights load on CPU-only hosts
            state_dict = torch.load(model_paths[name], map_location=DEVICE)
            state_dict = remove_module_prefix(state_dict)
            model.load_state_dict(state_dict)
            model.to(DEVICE).eval()  # eval() disables dropout / freezes batch-norm stats
            loaded_models[name] = model

        # --- Initialize and Load Landmark Predictor ---
        print("Loading Landmark Predictor model...")
        landmark_predictor = OcularLMGenerator().to(DEVICE)
        state_dict_lp = torch.load(lp_path, map_location=DEVICE)
        state_dict_lp = remove_module_prefix(state_dict_lp)
        landmark_predictor.load_state_dict(state_dict_lp)
        landmark_predictor.eval()
        loaded_models['landmark_predictor'] = landmark_predictor
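        # All models are kept in eval() mode and are used only for inference;
        # callers may additionally wrap forward passes in torch.no_grad() to
        # avoid building autograd graphs (whether the morphing helper already
        # does this depends on its implementation).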
        
        models = loaded_models
        
        print(f"--- All models loaded successfully on {DEVICE} ---")
        return config, models
        
    except Exception as e:
        print(f"Error loading models: {e}")
        print("This could be due to:")
        print("1. Repository not found or private")
        print("2. Model files not available at expected locations")
        print("3. Network connectivity issues")
        print("4. Authentication required")
        print("5. Config file not found")
        
        # Print expected file paths for debugging
        print("\nExpected file paths:")
        print(f"  - generator_models/g_epoch_{EPOCH}.pth")
        print(f"  - encoder_models/e_epoch_{EPOCH}.pth")
        print(f"  - landmark_encoder_models/le_epoch_{EPOCH}.pth")
        print(f"  - trained_models/Ocular_LM_Generator.pth")
        print(f"  - config/config.yaml")
        
        # Return None to indicate failure
        return None, None

# --- 4. Function to check if models are loaded ---
def check_models_loaded():
    """
    Check if models are properly loaded.
    """
    return config is not None and models is not None

# Try to load models at import time (but don't crash the app if they're unavailable)
try:
    config, models = load_models_from_hub()
    if not check_models_loaded():
        print("WARNING: Models could not be loaded. App will show error messages.")
except Exception as e:
    print(f"CRITICAL ERROR: Failed to load models: {e}")
    config, models = None, None

# --- 5. Define the core processing function for Gradio ---
def run_gan_morph(image1, image2, alpha):
    if not check_models_loaded():
        raise gr.Error("Models are not loaded. Please check the repository and model files.")
    
    if image1 is None or image2 is None:
        raise gr.Error("Please upload both source images to generate a morph.")
    
    print(f"Performing GAN morph with alpha={alpha}...")
    try:
        morphed_image_numpy = morph_images_with_gan(image1, image2, config, DEVICE, models, alpha)
        print("GAN morph complete.")
        return morphed_image_numpy
    except Exception as e:
        raise gr.Error(f"Error during morphing: {str(e)}")

# --- 6. Create a function to show model status ---
def get_model_status():
    if check_models_loaded():
        return "✅ Models loaded successfully"
    else:
        return "❌ Models failed to load. Check console for details."

# --- 7. Build the Gradio Interface ---
with gr.Blocks(theme=gr.themes.Soft(), title="DOOMGAN Morphing") as demo:
    gr.Markdown(
        """
        # DOOMGAN: High-Fidelity Ocular Image Morphing
        An interactive demonstration of the IJCB-accepted **DOOMGAN** project.
        Upload two ocular images (or pick an example below) and use the slider to morph between them.
        """
    )
    
    # Add model status indicator
    status_text = gr.Markdown(get_model_status())
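    # (Computed once at build time; the status text does not refresh
    # automatically if model availability changes later.)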
    
    with gr.Row():
        img1 = gr.Image(type="pil", label="Source Image 1")
        img2 = gr.Image(type="pil", label="Source Image 2")
    
    alpha_slider = gr.Slider(
        minimum=0.0, maximum=1.0, value=0.5, step=0.05,
        label="Interpolation Factor (Image 1 <-> Image 2)",
        info="Slide towards 0 to resemble Image 1, or towards 1 to resemble Image 2."
    )
    
    output_img = gr.Image(type="pil", label="Morphed Result")
    run_button = gr.Button("Generate Morph", variant="primary")
    
    # Only show examples if models are loaded
    if check_models_loaded():
        gr.Examples(
            examples=[
                ["assets/1144_r_1.png", "assets/1147_r_1.png", 0.5],
                ["assets/1162_r_9.png", "assets/1163_r_17.png", 0.4],
                ["assets/1172_l_1.png", "assets/1177_l_1.png", 0.6],
                ["assets/2517_r_9.png", "assets/3243_l_1.png", 0.7],
            ],
            inputs=[img1, img2, alpha_slider],
            outputs=output_img,
            fn=run_gan_morph,
            cache_examples=True # Caches the results for instant loading
        )
    else:
        gr.Markdown("⚠️ **Examples disabled**: Models could not be loaded.")

    # --- 8. Connect the UI components to the function ---
    run_button.click(
        fn=run_gan_morph,
        inputs=[img1, img2, alpha_slider],
        outputs=[output_img],
        api_name="morph" 
    )
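
    # Because api_name="morph" is set above, the Space also exposes this
    # function as a programmatic endpoint. A minimal client-side sketch,
    # assuming a recent `gradio_client` package and a hypothetical Space id
    # (adapt the id to the actual deployment):
    #
    #   from gradio_client import Client, handle_file
    #   client = Client("BharathK333/DOOMGAN")  # hypothetical Space id
    #   result = client.predict(
    #       handle_file("eye1.png"),  # Source Image 1
    #       handle_file("eye2.png"),  # Source Image 2
    #       0.5,                      # interpolation factor
    #       api_name="/morph",
    #   )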

if __name__ == "__main__":
    demo.launch()