BharathK333 commited on
Commit
081ea86
·
verified ·
1 Parent(s): 2312f6d

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +143 -73
app.py CHANGED
@@ -1,7 +1,8 @@
1
  import gradio as gr
2
  import torch
3
  import yaml
4
- from huggingface_hub import hf_hub_download
 
5
  from collections import OrderedDict
6
 
7
  # --- Local project imports ---
@@ -24,6 +25,7 @@ models = None
24
  def load_models_from_hub():
25
  """
26
  Downloads all necessary files from the Hugging Face Hub and loads the models.
 
27
  """
28
  global config, models
29
 
@@ -31,73 +33,130 @@ def load_models_from_hub():
31
  if config is not None and models is not None:
32
  return config, models
33
 
34
- # --- Download Model Files from the Model Repo ---
35
- g_path = hf_hub_download(repo_id=MODEL_REPO_ID, filename=f"G_{EPOCH}.pth")
36
- e_path = hf_hub_download(repo_id=MODEL_REPO_ID, filename=f"E_{EPOCH}.pth")
37
- le_path = hf_hub_download(repo_id=MODEL_REPO_ID, filename=f"LE_{EPOCH}.pth")
38
- lp_path = hf_hub_download(repo_id=MODEL_REPO_ID, filename="landmark_predictor.pth")
39
-
40
- with open('config/config.yaml', 'r') as f:
41
- config = yaml.safe_load(f)
42
-
43
- model_cfg = config['model']
44
- data_cfg = config['data']
45
-
46
- # Helper function to remove 'module.' prefix from state dict keys
47
- def remove_module_prefix(state_dict):
48
- new_state_dict = OrderedDict()
49
- for k, v in state_dict.items():
50
- name = k[7:] if k.startswith('module.') else k
51
- new_state_dict[name] = v
52
- return new_state_dict
53
-
54
- # --- Initialize and Load the specific GAN Models (G, E, LE) ---
55
- gan_models_init = {
56
- 'netG': Generator(nz=model_cfg['nz'], ngf=model_cfg['ngf'], nc=data_cfg['nc'], landmark_feature_size=model_cfg['landmark_feature_size']),
57
- 'netE': Encoder(nc=data_cfg['nc'], ndf=model_cfg['ndf'], nz=model_cfg['nz'], num_landmarks=model_cfg['num_landmarks']),
58
- 'landmark_encoder': LandmarkEncoder(input_dim=model_cfg['num_landmarks'] * 2, output_dim=model_cfg['landmark_feature_size'])
59
- }
60
-
61
- model_paths = {'netG': g_path, 'netE': e_path, 'landmark_encoder': le_path}
62
- loaded_models = {}
63
-
64
- for name, model in gan_models_init.items():
65
- print(f"Loading {name} model...")
66
- state_dict = torch.load(model_paths[name], map_location=DEVICE)
67
- state_dict = remove_module_prefix(state_dict)
68
- model.load_state_dict(state_dict)
69
- model.to(DEVICE).eval()
70
- loaded_models[name] = model
71
-
72
- # --- Initialize and Load Landmark Predictor ---
73
- print("Loading Landmark Predictor model...")
74
- landmark_predictor = OcularLMGenerator().to(DEVICE)
75
- state_dict_lp = torch.load(lp_path, map_location=DEVICE)
76
- state_dict_lp = remove_module_prefix(state_dict_lp)
77
- landmark_predictor.load_state_dict(state_dict_lp)
78
- landmark_predictor.eval()
79
- loaded_models['landmark_predictor'] = landmark_predictor
80
-
81
- models = loaded_models
82
-
83
- print(f"--- All models loaded successfully on {DEVICE} ---")
84
- return config, models
85
 
86
- # Load everything when the app starts
87
- config, models = load_models_from_hub()
 
 
 
 
 
 
 
 
88
 
89
- # --- 4. Define the core processing function for Gradio ---
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
90
  def run_gan_morph(image1, image2, alpha):
 
 
 
91
  if image1 is None or image2 is None:
92
  raise gr.Error("Please upload both source images to generate a morph.")
93
 
94
  print(f"Performing GAN morph with alpha={alpha}...")
95
- morphed_image_numpy = morph_images_with_gan(image1, image2, config, DEVICE, models, alpha)
96
- print("GAN morph complete.")
97
- return morphed_image_numpy
 
 
 
98
 
 
 
 
 
 
 
99
 
100
- # --- 5. Build the Gradio Interface ---
101
  with gr.Blocks(theme=gr.themes.Soft(), title="DOOMGAN Morphing") as demo:
102
  gr.Markdown(
103
  """
@@ -106,6 +165,10 @@ with gr.Blocks(theme=gr.themes.Soft(), title="DOOMGAN Morphing") as demo:
106
  Upload two ocular images, or use the examples below, and use the slider to morph between them.
107
  """
108
  )
 
 
 
 
109
  with gr.Row():
110
  img1 = gr.Image(type="pil", label="Source Image 1")
111
  img2 = gr.Image(type="pil", label="Source Image 2")
@@ -119,23 +182,30 @@ with gr.Blocks(theme=gr.themes.Soft(), title="DOOMGAN Morphing") as demo:
119
  output_img = gr.Image(type="pil", label="Morphed Result")
120
  run_button = gr.Button("Generate Morph", variant="primary")
121
 
122
- gr.Examples(
123
- examples=[
124
- ["assets/1144_r_1.png", "assets/1147_r_1.png", 0.5],
125
- ["assets/1162_r_9.png", "assets/1163_r_17.png", 0.3],
126
- ["assets/1172_l_1.png", "assets/1177_l_1.png", 0.7],
127
- ["assets/2517_r_9.png", "assets/3243_l_1.png", 0.5],
128
- ],
129
- inputs=[img1, img2, alpha_slider],
130
- outputs=output_img,
131
- fn=run_gan_morph,
132
- cache_examples=True # Caches the results for instant loading
133
- )
 
 
 
 
134
 
135
- # --- 6. Connect the UI components to the function ---
136
  run_button.click(
137
  fn=run_gan_morph,
138
  inputs=[img1, img2, alpha_slider],
139
  outputs=[output_img],
140
  api_name="morph"
141
- )
 
 
 
 
1
  import gradio as gr
2
  import torch
3
  import yaml
4
+ import os
5
+ from huggingface_hub import hf_hub_download, list_repo_files
6
  from collections import OrderedDict
7
 
8
  # --- Local project imports ---
 
25
  def load_models_from_hub():
26
  """
27
  Downloads all necessary files from the Hugging Face Hub and loads the models.
28
+ Based on the actual directory structure provided.
29
  """
30
  global config, models
31
 
 
33
  if config is not None and models is not None:
34
  return config, models
35
 
36
+ try:
37
+ # --- Download Model Files from the Model Repo (based on actual structure) ---
38
+ print(f"Downloading generator model...")
39
+ g_path = hf_hub_download(repo_id=MODEL_REPO_ID, filename=f"generator_models/g_epoch_{EPOCH}.pth")
40
+
41
+ print(f"Downloading encoder model...")
42
+ e_path = hf_hub_download(repo_id=MODEL_REPO_ID, filename=f"encoder_models/e_epoch_{EPOCH}.pth")
43
+
44
+ print(f"Downloading landmark encoder model...")
45
+ le_path = hf_hub_download(repo_id=MODEL_REPO_ID, filename=f"landmark_encoder_models/le_epoch_{EPOCH}.pth")
46
+
47
+ print(f"Downloading landmark predictor model...")
48
+ lp_path = hf_hub_download(repo_id=MODEL_REPO_ID, filename="trained_models/Ocular_LM_Generator.pth")
49
+
50
+ # Load config file
51
+ print("Loading configuration...")
52
+ config_path = hf_hub_download(repo_id=MODEL_REPO_ID, filename="config/config.yaml")
53
+ with open(config_path, 'r') as f:
54
+ config = yaml.safe_load(f)
55
+
56
+ model_cfg = config['model']
57
+ data_cfg = config['data']
58
+
59
+ # Helper function to remove 'module.' prefix from state dict keys
60
+ def remove_module_prefix(state_dict):
61
+ new_state_dict = OrderedDict()
62
+ for k, v in state_dict.items():
63
+ name = k[7:] if k.startswith('module.') else k
64
+ new_state_dict[name] = v
65
+ return new_state_dict
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
66
 
67
+ # --- Initialize and Load the specific GAN Models (G, E, LE) ---
68
+ print("Initializing models...")
69
+ gan_models_init = {
70
+ 'netG': Generator(nz=model_cfg['nz'], ngf=model_cfg['ngf'], nc=data_cfg['nc'], landmark_feature_size=model_cfg['landmark_feature_size']),
71
+ 'netE': Encoder(nc=data_cfg['nc'], ndf=model_cfg['ndf'], nz=model_cfg['nz'], num_landmarks=model_cfg['num_landmarks']),
72
+ 'landmark_encoder': LandmarkEncoder(input_dim=model_cfg['num_landmarks'] * 2, output_dim=model_cfg['landmark_feature_size'])
73
+ }
74
+
75
+ model_paths = {'netG': g_path, 'netE': e_path, 'landmark_encoder': le_path}
76
+ loaded_models = {}
77
 
78
+ for name, model in gan_models_init.items():
79
+ print(f"Loading {name} model...")
80
+ state_dict = torch.load(model_paths[name], map_location=DEVICE)
81
+ state_dict = remove_module_prefix(state_dict)
82
+ model.load_state_dict(state_dict)
83
+ model.to(DEVICE).eval()
84
+ loaded_models[name] = model
85
+
86
+ # --- Initialize and Load Landmark Predictor ---
87
+ print("Loading Landmark Predictor model...")
88
+ landmark_predictor = OcularLMGenerator().to(DEVICE)
89
+ state_dict_lp = torch.load(lp_path, map_location=DEVICE)
90
+ state_dict_lp = remove_module_prefix(state_dict_lp)
91
+ landmark_predictor.load_state_dict(state_dict_lp)
92
+ landmark_predictor.eval()
93
+ loaded_models['landmark_predictor'] = landmark_predictor
94
+
95
+ models = loaded_models
96
+
97
+ print(f"--- All models loaded successfully on {DEVICE} ---")
98
+ return config, models
99
+
100
+ except Exception as e:
101
+ print(f"Error loading models: {e}")
102
+ print("This could be due to:")
103
+ print("1. Repository not found or private")
104
+ print("2. Model files not available at expected locations")
105
+ print("3. Network connectivity issues")
106
+ print("4. Authentication required")
107
+ print("5. Config file not found")
108
+
109
+ # Print expected file paths for debugging
110
+ print("\nExpected file paths:")
111
+ print(f" - generator_models/g_epoch_{EPOCH}.pth")
112
+ print(f" - encoder_models/e_epoch_{EPOCH}.pth")
113
+ print(f" - landmark_encoder_models/le_epoch_{EPOCH}.pth")
114
+ print(f" - trained_models/Ocular_LM_Generator.pth")
115
+ print(f" - config/config.yaml")
116
+
117
+ # Return None to indicate failure
118
+ return None, None
119
+
120
+ # --- 4. Function to check if models are loaded ---
121
+ def check_models_loaded():
122
+ """
123
+ Check if models are properly loaded.
124
+ """
125
+ return config is not None and models is not None
126
+
127
+ # Try to load models (but don't fail if they're not available)
128
+ try:
129
+ config, models = load_models_from_hub()
130
+ if not check_models_loaded():
131
+ print("WARNING: Models could not be loaded. App will show error messages.")
132
+ except Exception as e:
133
+ print(f"CRITICAL ERROR: Failed to load models: {e}")
134
+ config, models = None, None
135
+
136
+ # --- 5. Define the core processing function for Gradio ---
137
  def run_gan_morph(image1, image2, alpha):
138
+ if not check_models_loaded():
139
+ raise gr.Error("Models are not loaded. Please check the repository and model files.")
140
+
141
  if image1 is None or image2 is None:
142
  raise gr.Error("Please upload both source images to generate a morph.")
143
 
144
  print(f"Performing GAN morph with alpha={alpha}...")
145
+ try:
146
+ morphed_image_numpy = morph_images_with_gan(image1, image2, config, DEVICE, models, alpha)
147
+ print("GAN morph complete.")
148
+ return morphed_image_numpy
149
+ except Exception as e:
150
+ raise gr.Error(f"Error during morphing: {str(e)}")
151
 
152
+ # --- 6. Create a function to show model status ---
153
+ def get_model_status():
154
+ if check_models_loaded():
155
+ return "✅ Models loaded successfully"
156
+ else:
157
+ return "❌ Models failed to load. Check console for details."
158
 
159
+ # --- 7. Build the Gradio Interface ---
160
  with gr.Blocks(theme=gr.themes.Soft(), title="DOOMGAN Morphing") as demo:
161
  gr.Markdown(
162
  """
 
165
  Upload two ocular images, or use the examples below, and use the slider to morph between them.
166
  """
167
  )
168
+
169
+ # Add model status indicator
170
+ status_text = gr.Markdown(get_model_status())
171
+
172
  with gr.Row():
173
  img1 = gr.Image(type="pil", label="Source Image 1")
174
  img2 = gr.Image(type="pil", label="Source Image 2")
 
182
  output_img = gr.Image(type="pil", label="Morphed Result")
183
  run_button = gr.Button("Generate Morph", variant="primary")
184
 
185
+ # Only show examples if models are loaded
186
+ if check_models_loaded():
187
+ gr.Examples(
188
+ examples=[
189
+ ["assets/1144_r_1.png", "assets/1147_r_1.png", 0.5],
190
+ ["assets/1162_r_9.png", "assets/1163_r_17.png", 0.3],
191
+ ["assets/1172_l_1.png", "assets/1177_l_1.png", 0.7],
192
+ ["assets/2517_r_9.png", "assets/3243_l_1.png", 0.5],
193
+ ],
194
+ inputs=[img1, img2, alpha_slider],
195
+ outputs=output_img,
196
+ fn=run_gan_morph,
197
+ cache_examples=True # Caches the results for instant loading
198
+ )
199
+ else:
200
+ gr.Markdown("⚠️ **Examples disabled**: Models could not be loaded.")
201
 
202
+ # --- 8. Connect the UI components to the function ---
203
  run_button.click(
204
  fn=run_gan_morph,
205
  inputs=[img1, img2, alpha_slider],
206
  outputs=[output_img],
207
  api_name="morph"
208
+ )
209
+
210
+ if __name__ == "__main__":
211
+ demo.launch()