BharathK333 committed
Commit 24c8da0 · verified · 1 Parent(s): cf3edca

Upload 29 files

.gitattributes CHANGED
@@ -33,3 +33,10 @@ saved_model/**/* filter=lfs diff=lfs merge=lfs -text
 *.zip filter=lfs diff=lfs merge=lfs -text
 *.zst filter=lfs diff=lfs merge=lfs -text
 *tfevents* filter=lfs diff=lfs merge=lfs -text
+assets/1144_r_1.png filter=lfs diff=lfs merge=lfs -text
+assets/1147_r_1.png filter=lfs diff=lfs merge=lfs -text
+assets/1162_r_9.png filter=lfs diff=lfs merge=lfs -text
+assets/1163_r_17.png filter=lfs diff=lfs merge=lfs -text
+assets/1172_l_1.png filter=lfs diff=lfs merge=lfs -text
+assets/1177_l_1.png filter=lfs diff=lfs merge=lfs -text
+assets/2517_r_9.png filter=lfs diff=lfs merge=lfs -text
app.py ADDED
@@ -0,0 +1,130 @@
+import gradio as gr
+import torch
+import yaml
+from huggingface_hub import hf_hub_download
+from collections import OrderedDict
+
+# --- Local project imports ---
+from apps.gan_morpher import morph_images_with_gan
+from models.models import Generator, Encoder, LandmarkEncoder
+from models.landmark_predictor import OcularLMGenerator
+
+# --- 1. Define Constants and Configuration ---
+MODEL_REPO_ID = "BharathK333/DOOMGAN"
+DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")
+EPOCH = 450  # The epoch for the models we want to load
+
+print("--- Initializing Gradio App: Downloading and Loading Models ---")
+
+# --- 2. Function to Load Models from Hugging Face Hub ---
+# Called once at startup (module import), so the models are only loaded a single time.
+def load_models_from_hub():
+    """
+    Downloads all necessary files from the Hugging Face Hub and loads the models.
+    """
+    # --- Download Model Files from the Model Repo ---
+    g_path = hf_hub_download(repo_id=MODEL_REPO_ID, filename=f"G_{EPOCH}.pth")
+    e_path = hf_hub_download(repo_id=MODEL_REPO_ID, filename=f"E_{EPOCH}.pth")
+    le_path = hf_hub_download(repo_id=MODEL_REPO_ID, filename=f"LE_{EPOCH}.pth")
+    lp_path = hf_hub_download(repo_id=MODEL_REPO_ID, filename="landmark_predictor.pth")
+
+    with open('config/config.yaml', 'r') as f:
+        config = yaml.safe_load(f)
+
+    model_cfg = config['model']
+    data_cfg = config['data']
+
+    # Helper function to remove the 'module.' prefix from state dict keys
+    def remove_module_prefix(state_dict):
+        new_state_dict = OrderedDict()
+        for k, v in state_dict.items():
+            name = k[7:] if k.startswith('module.') else k
+            new_state_dict[name] = v
+        return new_state_dict
+
+    # --- Initialize and Load the specific GAN Models (G, E, LE) ---
+    gan_models_init = {
+        'netG': Generator(nz=model_cfg['nz'], ngf=model_cfg['ngf'], nc=data_cfg['nc'], landmark_feature_size=model_cfg['landmark_feature_size']),
+        'netE': Encoder(nc=data_cfg['nc'], ndf=model_cfg['ndf'], nz=model_cfg['nz'], num_landmarks=model_cfg['num_landmarks']),
+        'landmark_encoder': LandmarkEncoder(input_dim=model_cfg['num_landmarks'] * 2, output_dim=model_cfg['landmark_feature_size'])
+    }
+
+    model_paths = {'netG': g_path, 'netE': e_path, 'landmark_encoder': le_path}
+    loaded_models = {}
+
+    for name, model in gan_models_init.items():
+        print(f"Loading {name} model...")
+        state_dict = torch.load(model_paths[name], map_location=DEVICE)
+        state_dict = remove_module_prefix(state_dict)
+        model.load_state_dict(state_dict)
+        model.to(DEVICE).eval()
+        loaded_models[name] = model
+
+    # --- Initialize and Load Landmark Predictor ---
+    print("Loading Landmark Predictor model...")
+    landmark_predictor = OcularLMGenerator().to(DEVICE)
+    state_dict_lp = torch.load(lp_path, map_location=DEVICE)
+    state_dict_lp = remove_module_prefix(state_dict_lp)
+    landmark_predictor.load_state_dict(state_dict_lp)
+    landmark_predictor.eval()
+    loaded_models['landmark_predictor'] = landmark_predictor
+
+    print(f"--- All models loaded successfully on {DEVICE} ---")
+    return config, loaded_models
+
+# Load everything when the app starts
+config, models = load_models_from_hub()
+
+# --- 3. Define the core processing function for Gradio ---
+def run_gan_morph(image1, image2, alpha):
+    if image1 is None or image2 is None:
+        raise gr.Error("Please upload both source images to generate a morph.")
+
+    print(f"Performing GAN morph with alpha={alpha}...")
+    morphed_image_numpy = morph_images_with_gan(image1, image2, config, DEVICE, models, alpha)
+    print("GAN morph complete.")
+    return morphed_image_numpy
+
+
+# --- 4. Build the Gradio Interface ---
+with gr.Blocks(theme=gr.themes.Soft(), title="DOOMGAN Morphing") as demo:
+    gr.Markdown(
+        """
+        # DOOMGAN: High-Fidelity Ocular Image Morphing
+        An interactive demonstration of the IJCB-accepted **DOOMGAN** project.
+        Upload two ocular images, or pick an example below, then use the slider to morph between them.
+        """
+    )
+    with gr.Row():
+        img1 = gr.Image(type="pil", label="Source Image 1")
+        img2 = gr.Image(type="pil", label="Source Image 2")
+
+    alpha_slider = gr.Slider(
+        minimum=0.0, maximum=1.0, value=0.5, step=0.05,
+        label="Interpolation Factor (Image 1 <-> Image 2)",
+        info="Slide towards 0 to resemble Image 1, or towards 1 to resemble Image 2."
+    )
+
+    output_img = gr.Image(type="pil", label="Morphed Result")
+    run_button = gr.Button("Generate Morph", variant="primary")
+
+    gr.Examples(
+        examples=[
+            ["assets/1144_r_1.png", "assets/1147_r_1.png", 0.5],
+            ["assets/1162_r_9.png", "assets/1163_r_17.png", 0.3],
+            ["assets/1172_l_1.png", "assets/1177_l_1.png", 0.7],
+            ["assets/2517_r_9.png", "assets/3243_l_1.png", 0.5],
+        ],
+        inputs=[img1, img2, alpha_slider],
+        outputs=output_img,
+        fn=run_gan_morph,
+        cache_examples=True  # Caches the results for instant loading
+    )
+
+    # --- 5. Connect the UI components to the function ---
+    run_button.click(
+        fn=run_gan_morph,
+        inputs=[img1, img2, alpha_slider],
+        outputs=[output_img],
+        api_name="morph"
+    )
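Because the click handler registers api_name="morph", the demo also exposes a programmatic endpoint. A minimal sketch of calling it with gradio_client, assuming the app is hosted in a Space (the Space ID below is a placeholder, not something stated in this commit):

# Sketch only: substitute the actual Space ID hosting this app.py.
from gradio_client import Client, handle_file

client = Client("BharathK333/DOOMGAN")  # placeholder Space ID
result = client.predict(
    handle_file("assets/1144_r_1.png"),  # Source Image 1
    handle_file("assets/1147_r_1.png"),  # Source Image 2
    0.5,                                 # Interpolation factor
    api_name="/morph",
)
print(result)  # for image outputs this is typically a local file path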
apps/__init__.py ADDED
File without changes
apps/__pycache__/__init__.cpython-310.pyc ADDED
Binary file (136 Bytes).
 
apps/__pycache__/classical_morpher.cpython-310.pyc ADDED
Binary file (4.07 kB).
 
apps/__pycache__/gan_morpher.cpython-310.pyc ADDED
Binary file (2.25 kB).
 
apps/__pycache__/utils.cpython-310.pyc ADDED
Binary file (2.17 kB).
 
apps/classical_morpher.py ADDED
@@ -0,0 +1,107 @@
+import cv2
+import numpy as np
+import torch
+from torchvision.transforms import ToTensor, Resize, Compose, Normalize
+
+def predict_landmarks_for_classical(image, landmark_predictor_model, device):
+    """Predicts landmarks and returns them as an unnormalized tensor for OpenCV."""
+    transform = Compose([
+        Resize((256, 256)),
+        ToTensor(),
+        Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
+    ])
+    image_transformed = transform(image).unsqueeze(0).to(device)
+
+    with torch.no_grad():
+        landmarks = landmark_predictor_model(image_transformed).squeeze(0).cpu()
+    # Use 38 landmarks (19 pairs) to match the original LM-1 behavior
+    landmarks = landmarks[:38]
+    landmarks = landmarks.view(-1, 2)
+    return landmarks
+
+def _extract_index_nparray(nparray):
+    """Helper function to extract an index from a numpy where() result."""
+    return nparray[0][0] if len(nparray[0]) > 0 else None
+
+def _tensor_to_int_array(tensor):
+    """Converts a landmark tensor to a list of integer tuples."""
+    return [(int(x[0]), int(x[1])) for x in tensor.numpy()]
+
+def ocular_morph_classical(img1_pil, img2_pil, landmarks1_tensor, landmarks2_tensor):
+    """Performs landmark-based morphing using Delaunay triangulation and seamless cloning."""
+    img1 = cv2.cvtColor(np.array(img1_pil), cv2.COLOR_RGB2BGR)
+    img2 = cv2.cvtColor(np.array(img2_pil), cv2.COLOR_RGB2BGR)
+
+    points1 = _tensor_to_int_array(landmarks1_tensor)
+    points2 = _tensor_to_int_array(landmarks2_tensor)
+
+    img1_gray = cv2.cvtColor(img1, cv2.COLOR_BGR2GRAY)
+    # --- FIX: Define img2_gray, which was previously missing. ---
+    img2_gray = cv2.cvtColor(img2, cv2.COLOR_BGR2GRAY)
+
+    mask = np.zeros_like(img1_gray)
+
+    points_np = np.array(points1, np.int32)
+    convexhull = cv2.convexHull(points_np)
+    cv2.fillConvexPoly(mask, convexhull, 255)
+
+    rect = cv2.boundingRect(convexhull)
+    subdiv = cv2.Subdiv2D(rect)
+    subdiv.insert(points1)
+    triangles = subdiv.getTriangleList()
+    triangles = np.array(triangles, dtype=np.int32)
+
+    indexes_triangles = []
+    for t in triangles:
+        pt1, pt2, pt3 = (t[0], t[1]), (t[2], t[3]), (t[4], t[5])
+        index_pt1 = _extract_index_nparray(np.where((points_np == pt1).all(axis=1)))
+        index_pt2 = _extract_index_nparray(np.where((points_np == pt2).all(axis=1)))
+        index_pt3 = _extract_index_nparray(np.where((points_np == pt3).all(axis=1)))
+
+        if all(idx is not None for idx in [index_pt1, index_pt2, index_pt3]):
+            indexes_triangles.append([index_pt1, index_pt2, index_pt3])
+
+    img2_new_face = np.zeros_like(img2)
+
+    for triangle_index in indexes_triangles:
+        tr1_pt1, tr1_pt2, tr1_pt3 = points1[triangle_index[0]], points1[triangle_index[1]], points1[triangle_index[2]]
+        tr2_pt1, tr2_pt2, tr2_pt3 = points2[triangle_index[0]], points2[triangle_index[1]], points2[triangle_index[2]]
+
+        triangle1 = np.array([tr1_pt1, tr1_pt2, tr1_pt3], np.int32)
+        triangle2 = np.array([tr2_pt1, tr2_pt2, tr2_pt3], np.int32)
+
+        rect1 = cv2.boundingRect(triangle1)
+        (x1, y1, w1, h1) = rect1
+        cropped_triangle = img1[y1: y1 + h1, x1: x1 + w1]
+        points_rel1 = np.array([[tr1_pt1[0] - x1, tr1_pt1[1] - y1], [tr1_pt2[0] - x1, tr1_pt2[1] - y1], [tr1_pt3[0] - x1, tr1_pt3[1] - y1]], np.float32)
+
+        rect2 = cv2.boundingRect(triangle2)
+        (x2, y2, w2, h2) = rect2
+        points_rel2 = np.array([[tr2_pt1[0] - x2, tr2_pt1[1] - y2], [tr2_pt2[0] - x2, tr2_pt2[1] - y2], [tr2_pt3[0] - x2, tr2_pt3[1] - y2]], np.float32)
+
+        M = cv2.getAffineTransform(points_rel1, points_rel2)
+        warped_triangle = cv2.warpAffine(cropped_triangle, M, (w2, h2))
+
+        cropped_tr2_mask = np.zeros((h2, w2), np.uint8)
+        cv2.fillConvexPoly(cropped_tr2_mask, np.int32(points_rel2), 255)
+
+        warped_triangle = cv2.bitwise_and(warped_triangle, warped_triangle, mask=cropped_tr2_mask)
+
+        img2_new_face_rect_area = img2_new_face[y2: y2 + h2, x2: x2 + w2]
+        img2_new_face_rect_area_gray = cv2.cvtColor(img2_new_face_rect_area, cv2.COLOR_BGR2GRAY)
+        _, mask_triangles_designed = cv2.threshold(img2_new_face_rect_area_gray, 1, 255, cv2.THRESH_BINARY_INV)
+        warped_triangle = cv2.bitwise_and(warped_triangle, warped_triangle, mask=mask_triangles_designed)
+
+        img2_new_face_rect_area = cv2.add(img2_new_face_rect_area, warped_triangle)
+        img2_new_face[y2: y2 + h2, x2: x2 + w2] = img2_new_face_rect_area
+
+    img2_face_mask = np.zeros_like(img2_gray)
+    convexhull2 = cv2.convexHull(np.array(points2, np.int32))
+    img2_head_mask = cv2.fillConvexPoly(img2_face_mask, convexhull2, 255)
+
+    (x, y, w, h) = cv2.boundingRect(convexhull2)
+    center_face2 = (int(x + w / 2), int(y + h / 2))
+
+    seamlessclone = cv2.seamlessClone(img2_new_face, img2, img2_head_mask, center_face2, cv2.NORMAL_CLONE)
+
+    return cv2.cvtColor(seamlessclone, cv2.COLOR_BGR2RGB)
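The classical morpher is not wired into app.py in this commit. An illustrative usage sketch, assuming the landmark_predictor.pth checkpoint from the same model repo app.py uses, and 256x256 inputs so the predicted pixel coordinates line up with the arrays handed to OpenCV:

# Illustrative only; not part of the committed app.
import torch
from PIL import Image
from huggingface_hub import hf_hub_download
from apps.utils import remove_module_prefix
from apps.classical_morpher import predict_landmarks_for_classical, ocular_morph_classical
from models.landmark_predictor import OcularLMGenerator

device = torch.device("cpu")
predictor = OcularLMGenerator().to(device)
lp_path = hf_hub_download(repo_id="BharathK333/DOOMGAN", filename="landmark_predictor.pth")
predictor.load_state_dict(remove_module_prefix(torch.load(lp_path, map_location=device)))
predictor.eval()

# Resize to 256x256 so landmark coordinates match the images passed to OpenCV
img1 = Image.open("assets/1144_r_1.png").convert("RGB").resize((256, 256))
img2 = Image.open("assets/1147_r_1.png").convert("RGB").resize((256, 256))

lm1 = predict_landmarks_for_classical(img1, predictor, device)
lm2 = predict_landmarks_for_classical(img2, predictor, device)
morphed_rgb = ocular_morph_classical(img1, img2, lm1, lm2)  # numpy array in RGB order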
apps/gan_morpher.py ADDED
@@ -0,0 +1,58 @@
+import torch
+from torchvision.transforms import ToTensor, Resize, Compose, Normalize
+from utils import create_landmark_heatmaps
+
+def predict_landmarks_for_gan(image, landmark_predictor_model, device):
+    """Predicts landmarks and formats them specifically for the GAN pipeline."""
+    transform = Compose([
+        Resize((256, 256)),
+        ToTensor(),
+        Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
+    ])
+    image_transformed = transform(image).unsqueeze(0).to(device)
+
+    with torch.no_grad():
+        landmarks = landmark_predictor_model(image_transformed).squeeze(0).cpu()
+    landmarks = landmarks[:38]  # Corresponds to 19 landmarks (x, y)
+    landmarks = landmarks.view(-1, 2)
+    # Normalize to [0, 1] range for heatmap generation
+    landmarks[:, 0] /= 256.0
+    landmarks[:, 1] /= 256.0
+    landmarks = landmarks.flatten()
+
+    return landmarks.unsqueeze(0)  # Return with a batch dimension
+
+def process_image_for_gan(image, config, device, models):
+    """Processes a single image to get its latent vector (z) and landmark features (lf)."""
+    image_tensor = Compose([
+        Resize((config['data']['image_size'], config['data']['image_size'])),
+        ToTensor(),
+        Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)),
+    ])(image).unsqueeze(0).to(device)
+
+    landmarks = predict_landmarks_for_gan(image, models['landmark_predictor'], device).to(device)
+    heatmap = create_landmark_heatmaps(landmarks, image_size=config['data']['image_size']).to(device)
+
+    with torch.no_grad():
+        landmark_features = models['landmark_encoder'](landmarks)
+        z = models['netE'](image_tensor, heatmap)
+
+    return z, landmark_features
+
+def morph_images_with_gan(image1, image2, config, device, models, alpha=0.5):
+    """Generates a morphed image using the GAN with a given interpolation factor."""
+    z1, lf1 = process_image_for_gan(image1, config, device, models)
+    z2, lf2 = process_image_for_gan(image2, config, device, models)
+
+    # Interpolate in both latent and landmark feature spaces
+    z_morph = (1 - alpha) * z1 + alpha * z2
+    lf_morph = (1 - alpha) * lf1 + alpha * lf2
+
+    with torch.no_grad():
+        morphed_image_tensor = models['netG'](z_morph, lf_morph)
+
+    # Denormalize from [-1, 1] to [0, 1] for display
+    morphed_image = (morphed_image_tensor * 0.5 + 0.5).clamp(0, 1)
+    morphed_image_numpy = morphed_image.squeeze(0).permute(1, 2, 0).cpu().numpy()
+
+    return morphed_image_numpy
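Note that create_landmark_heatmaps is imported from a utils module, but none of the 29 uploaded files defines it. Purely as an illustration of the interface this code expects (a (B, 2*K) tensor of normalized (x, y) pairs in, a (B, K, image_size, image_size) heatmap stack out, matching the Encoder's nc + num_landmarks input channels), one possible Gaussian-blob sketch; the actual implementation, including the spread parameter, is not part of this commit:

# Hypothetical sketch, not the project's implementation; sigma is an assumed parameter.
import torch

def create_landmark_heatmaps(landmarks, image_size=256, sigma=4.0):
    b = landmarks.size(0)
    pts = landmarks.view(b, -1, 2) * (image_size - 1)          # back to pixel coordinates
    ys = torch.arange(image_size, device=landmarks.device).view(1, 1, -1, 1)
    xs = torch.arange(image_size, device=landmarks.device).view(1, 1, 1, -1)
    cx = pts[..., 0].view(b, -1, 1, 1)
    cy = pts[..., 1].view(b, -1, 1, 1)
    d2 = (xs - cx) ** 2 + (ys - cy) ** 2
    return torch.exp(-d2 / (2 * sigma ** 2))                    # (B, K, H, W), one channel per landmark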
apps/utils.py ADDED
@@ -0,0 +1,61 @@
+import os
+import torch
+from collections import OrderedDict
+from PIL import Image
+
+# Local project imports
+from models import Generator, Encoder, LandmarkEncoder
+from models.landmark_predictor import OcularLMGenerator
+
+def remove_module_prefix(state_dict):
+    """Removes the 'module.' prefix from state dict keys if it exists."""
+    new_state_dict = OrderedDict()
+    for k, v in state_dict.items():
+        name = k[7:] if k.startswith('module.') else k
+        new_state_dict[name] = v
+    return new_state_dict
+
+def load_all_models(config, epoch, device):
+    """Loads and initializes all models needed for the Streamlit app."""
+    model_cfg = config['model']
+    data_cfg = config['data']
+    paths_cfg = config['paths']
+
+    # --- 1. Load the main GAN models (G, E, LE) ---
+    gan_models = {
+        'G': Generator(nz=model_cfg['nz'], ngf=model_cfg['ngf'], nc=data_cfg['nc'], landmark_feature_size=model_cfg['landmark_feature_size']),
+        'E': Encoder(nc=data_cfg['nc'], ndf=model_cfg['ndf'], nz=model_cfg['nz'], num_landmarks=model_cfg['num_landmarks']),
+        'LE': LandmarkEncoder(input_dim=model_cfg['num_landmarks'] * 2, output_dim=model_cfg['landmark_feature_size'])
+    }
+    for name, model in gan_models.items():
+        model_path = os.path.join(paths_cfg['outputs'][name], f'{name.lower()}_epoch_{epoch}.pth')
+        print(f"Loading {name} model from: {model_path}")
+        state_dict = torch.load(model_path, map_location=device)
+        state_dict = remove_module_prefix(state_dict)
+        model.load_state_dict(state_dict)
+        model.to(device).eval()
+
+    # --- 2. Load the separate Landmark Predictor model ---
+    lp_path = paths_cfg['inputs']['landmark_predictor_model']
+    print(f"Loading Landmark Predictor from: {lp_path}")
+    landmark_predictor = OcularLMGenerator().to(device)
+    state_dict_lp = torch.load(lp_path, map_location=device)
+    state_dict_lp = remove_module_prefix(state_dict_lp)
+    landmark_predictor.load_state_dict(state_dict_lp)
+    landmark_predictor.eval()
+
+    # Return all models in a dictionary for easy access
+    return {
+        'netG': gan_models['G'],
+        'netE': gan_models['E'],
+        'landmark_encoder': gan_models['LE'],
+        'landmark_predictor': landmark_predictor
+    }
+
+def load_image(file):
+    """Safely loads an image file into a PIL Image object."""
+    try:
+        image = Image.open(file).convert('RGB')
+        return image
+    except Exception as e:
+        raise ValueError(f"Error loading image: {str(e)}")
assets/1144_r_1.png ADDED

Git LFS Details

  • SHA256: 793102c6c55225ca3af28bb77cf6526ac29f2ccc23a4b2848ab3db5a17f46e72
  • Pointer size: 131 Bytes
  • Size of remote file: 107 kB
assets/1147_r_1.png ADDED

Git LFS Details

  • SHA256: 2e4b7372287ade634ac68f44576d94983a7eead7dd82a59b61bd2b2159c72671
  • Pointer size: 131 Bytes
  • Size of remote file: 106 kB
assets/1162_r_9.png ADDED

Git LFS Details

  • SHA256: b150989ef4391a4c03b14a234c16fce1cb27ced6154f774653ff840726982314
  • Pointer size: 131 Bytes
  • Size of remote file: 105 kB
assets/1163_r_17.png ADDED

Git LFS Details

  • SHA256: 2b5e08fd6ff801759897de0887943ffe3766b8549c4059b140fef5ca06ed31c8
  • Pointer size: 131 Bytes
  • Size of remote file: 105 kB
assets/1172_l_1.png ADDED

Git LFS Details

  • SHA256: d51f9b73b6e048f15089f375251a8b5599ec1146cd785bae05cfc3832f4cca4f
  • Pointer size: 131 Bytes
  • Size of remote file: 104 kB
assets/1177_l_1.png ADDED

Git LFS Details

  • SHA256: 0da794a833e1020711fd2264948934203924dcdcfae7d3e507801956bdada2d9
  • Pointer size: 131 Bytes
  • Size of remote file: 108 kB
assets/2517_r_9.png ADDED

Git LFS Details

  • SHA256: 358494196fca72f3c4e6b6ac62afa8232a989d1078c7239c61604d126839dc44
  • Pointer size: 131 Bytes
  • Size of remote file: 103 kB
assets/3243_l_1.png ADDED
config/__init__.py ADDED
File without changes
config/config.yaml ADDED
@@ -0,0 +1,62 @@
+# config/config.yaml
+
+# --- General Settings ---
+project_name: "OcularMorph-DOOMGAN"
+manual_seed: 42
+ngpu: 1
+use_deterministic_algorithms: true
+device: "cuda:1"
+
+# --- Data Settings ---
+data:
+  image_root: "data/filtered_output"
+  landmark_json_path: "data/landmarks_GAN.json"
+  image_size: 256
+  nc: 3
+  workers: 4
+
+# --- Model Hyperparameters ---
+model:
+  nz: 200
+  ngf: 64
+  ndf: 64
+  num_landmarks: 19
+  landmark_feature_size: 128
+
+# --- Training Hyperparameters ---
+training:
+  num_epochs: 501
+  batch_size: 64
+  optimizer:
+    lr_g: 0.0002
+    lr_d: 0.00001
+    lr_e: 0.0002
+    lr_le: 0.0001
+    beta1: 0.5
+    beta2: 0.999
+    weight_decay: 0.00001
+  scheduler:
+    gamma_d: 0.9998
+    gamma_g: 0.9998
+    gamma_e: 0.9998
+    gamma_le: 0.9998
+  loss_weights:
+    gp: 10.0
+    initial_dynamic:
+      base: 50.0
+      ms_ssim: 30.0
+      perceptual: 50.0
+      reconstruction: 10.0
+      identity: 50.0
+      identity_diff: 40.0
+
+# --- Paths ---
+paths:
+  inputs:
+    arcface_model: "trained_models/resnet50_arcface.pth"
+    landmark_predictor_model: "trained_models/Ocular_LM_Generator.pth"
+  outputs:
+    G: "generator_models"
+    D: "discriminator_models"
+    E: "encoder_models"
+    LE: "landmark_encoder_models"
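The model block feeds the constructors in models/models.py exactly as app.py wires them; a small illustrative check (not part of the commit) of the sizes derived from these values:

import yaml

with open("config/config.yaml") as f:
    cfg = yaml.safe_load(f)

m, d = cfg["model"], cfg["data"]
print(d["nc"] + m["num_landmarks"])          # 22: input channels of Encoder/Discriminator (image + one heatmap per landmark)
print(m["num_landmarks"] * 2)                # 38: input_dim of LandmarkEncoder (19 normalized (x, y) pairs)
print(m["nz"] + m["landmark_feature_size"])  # 328: width of the concatenated vector fed to the Generator's fc layer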
models/ResNet_Model.py ADDED
@@ -0,0 +1,51 @@
+import torch
+import torch.nn as nn
+import torchvision.models as models
+
+class ResNet50_ArcFace(nn.Module):
+    """
+    ResNet-50 model modified for ArcFace loss.
+    Outputs embeddings that can be used for feature extraction.
+    """
+    def __init__(self, embedding_size=512, pretrained=True):
+        super(ResNet50_ArcFace, self).__init__()
+        self.embedding_size = embedding_size
+
+        # Load a pre-trained ResNet-50 model
+        self.backbone = models.resnet50(pretrained=pretrained)
+
+        # Modify the final fully connected layer:
+        # replace it with a linear layer that produces the embeddings
+        in_features = self.backbone.fc.in_features
+        self.backbone.fc = nn.Linear(in_features, self.embedding_size)
+
+        # Normalize the embedding vectors
+        self.l2_norm = nn.functional.normalize
+
+    def forward(self, x):
+        x = self.backbone(x)
+        # Normalize embeddings to have unit length
+        x = self.l2_norm(x, p=2, dim=1)
+        return x
+
+# Example usage
+if __name__ == "__main__":
+    # Load config parameters if needed
+    import yaml
+    with open('config.yml', 'r') as f:
+        config = yaml.safe_load(f)
+
+    device = torch.device(config['device'] if torch.cuda.is_available() else 'cpu')
+
+    model = ResNet50_ArcFace(
+        embedding_size=config['embedding_size'],
+        pretrained=True
+    ).to(device)
+
+    # Print model architecture
+    print(model)
+
+    # Test with a random input
+    dummy_input = torch.randn(1, 3, config['image_height'], config['image_width']).to(device)
+    embeddings = model(dummy_input)
+    print("Embeddings shape:", embeddings.shape)
models/__init__.py ADDED
@@ -0,0 +1,22 @@
+# All required model imports
+from .models import (
+    weights_init,
+    SelfAttention,
+    ResidualBlock,
+    Encoder,
+    Generator,
+    Discriminator,
+    LandmarkEncoder
+)
+
+# Import for the ArcFace model
+try:
+    from .ResNet_Model import ResNet50_ArcFace
+except ImportError:
+    ResNet50_ArcFace = None
+
+# Import the LM Predictor model for the app
+try:
+    from .landmark_predictor import OcularLMGenerator
+except ImportError:
+    OcularLMGenerator = None
models/__pycache__/ResNet_Model.cpython-310.pyc ADDED
Binary file (1.57 kB).
 
models/__pycache__/__init__.cpython-310.pyc ADDED
Binary file (515 Bytes).
 
models/__pycache__/landmark_predictor.cpython-310.pyc ADDED
Binary file (1.17 kB).
 
models/__pycache__/models.cpython-310.pyc ADDED
Binary file (6.28 kB).
 
models/__pycache__/test_models.cpython-310.pyc ADDED
Binary file (1.58 kB).
 
models/landmark_predictor.py ADDED
@@ -0,0 +1,24 @@
+import torch
+from torch import nn
+import torch.nn.functional as F
+
+class OcularLMGenerator(nn.Module):
+    def __init__(self):
+        super(OcularLMGenerator, self).__init__()
+        self.conv1 = nn.Conv2d(3, 32, kernel_size=3, stride=1, padding=1)
+        self.pool = nn.MaxPool2d(2, 2)
+        self.conv2 = nn.Conv2d(32, 64, kernel_size=3, stride=1, padding=1)
+        self.fc1 = nn.Linear(64 * 64 * 64, 500)
+        self.fc2 = nn.Linear(500, 66)  # Output the maximum number of landmarks
+
+    def forward(self, x):
+        x = self.pool(F.relu(self.conv1(x)))
+        x = self.pool(F.relu(self.conv2(x)))
+        x = x.view(-1, 64 * 64 * 64)
+        x = F.relu(self.fc1(x))
+        x = self.fc2(x)
+        return x
+
+if __name__ == "__main__":
+    model = OcularLMGenerator()
+    print(model)
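A quick illustrative shape check (not part of the commit): the predictor expects 256x256 RGB input, since two 2x2 poolings yield a 64-channel 64x64 feature map matching the fc1 input of 64*64*64, and the morphing code keeps only the first 38 of the 66 outputs as 19 (x, y) landmark pairs.

import torch
from models.landmark_predictor import OcularLMGenerator

model = OcularLMGenerator().eval()
with torch.no_grad():
    out = model(torch.randn(1, 3, 256, 256))
print(out.shape)                        # torch.Size([1, 66])
landmarks = out[0, :38].view(-1, 2)     # the slice used downstream
print(landmarks.shape)                  # torch.Size([19, 2])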
models/models.py ADDED
@@ -0,0 +1,132 @@
+import torch
+import torch.nn as nn
+
+def weights_init(m):
+    """
+    Applies custom weights initialization to a model's modules.
+    - Conv layers: He (Kaiming) normal initialization.
+    - InstanceNorm layers: Normal distribution for weights, constant for biases.
+    """
+    classname = m.__class__.__name__
+    if classname.find('Conv') != -1:
+        # Use a fan-in He initialization for Conv layers
+        nn.init.kaiming_normal_(m.weight.data, a=0.2, mode='fan_in')
+    elif classname.find('InstanceNorm') != -1:
+        if m.weight is not None:
+            nn.init.normal_(m.weight.data, 1.0, 0.02)
+        if m.bias is not None:
+            nn.init.constant_(m.bias.data, 0)
+
+class SelfAttention(nn.Module):
+    def __init__(self, in_channels):
+        super(SelfAttention, self).__init__()
+        self.query = nn.Conv2d(in_channels, in_channels // 8, 1)
+        self.key = nn.Conv2d(in_channels, in_channels // 8, 1)
+        self.value = nn.Conv2d(in_channels, in_channels, 1)
+        self.gamma = nn.Parameter(torch.zeros(1))
+        self.softmax = nn.Softmax(dim=-1)
+    def forward(self, x):
+        B, C, W, H = x.size()
+        proj_query = self.query(x).view(B, -1, W * H).permute(0, 2, 1)
+        proj_key = self.key(x).view(B, -1, W * H)
+        attention = self.softmax(torch.bmm(proj_query, proj_key))
+        proj_value = self.value(x).view(B, -1, W * H)
+        out = torch.bmm(proj_value, attention.permute(0, 2, 1))
+        out = out.view(B, C, W, H)
+        return self.gamma * out + x
+
+class ResidualBlock(nn.Module):
+    def __init__(self, in_channels):
+        super(ResidualBlock, self).__init__()
+        self.block = nn.Sequential(
+            nn.Conv2d(in_channels, in_channels, 3, padding=1, bias=False),
+            nn.InstanceNorm2d(in_channels),
+            nn.ReLU(),
+            nn.Conv2d(in_channels, in_channels, 3, padding=1, bias=False),
+            nn.InstanceNorm2d(in_channels)
+        )
+    def forward(self, x):
+        return x + self.block(x)
+
+class Encoder(nn.Module):
+    def __init__(self, nc=3, ndf=64, nz=200, num_landmarks=19):
+        super(Encoder, self).__init__()
+        in_channels = nc + num_landmarks
+        self.model = nn.Sequential(
+            nn.Conv2d(in_channels, ndf, 4, 2, 1), nn.LeakyReLU(0.2),
+            nn.Conv2d(ndf, ndf*2, 4, 2, 1, bias=False), nn.InstanceNorm2d(ndf*2), nn.LeakyReLU(0.2),
+            nn.Conv2d(ndf*2, ndf*4, 4, 2, 1, bias=False), nn.InstanceNorm2d(ndf*4), nn.LeakyReLU(0.2),
+            nn.Conv2d(ndf*4, ndf*8, 4, 2, 1, bias=False), nn.InstanceNorm2d(ndf*8), nn.LeakyReLU(0.2),
+            ResidualBlock(ndf*8), SelfAttention(ndf*8),
+            nn.Conv2d(ndf*8, ndf*16, 4, 2, 1, bias=False), nn.InstanceNorm2d(ndf*16), nn.LeakyReLU(0.2),
+            nn.Conv2d(ndf*16, ndf*16, 4, 2, 1), nn.LeakyReLU(0.2),
+            nn.AdaptiveAvgPool2d(1),
+            nn.Conv2d(ndf * 16, nz, 1)
+        )
+    def forward(self, img, heatmaps):
+        return self.model(torch.cat([img, heatmaps], 1)).view(img.size(0), -1)
+
+class Generator(nn.Module):
+    def __init__(self, nz=200, ngf=64, nc=3, landmark_feature_size=128):
+        super(Generator, self).__init__()
+        self.ngf = ngf
+        self.fc = nn.Sequential(
+            nn.Linear(nz + landmark_feature_size, ngf * 32 * 4 * 4),
+            nn.ReLU()  # Removed inplace=True
+        )
+        def block(in_c, out_c):
+            return [nn.ConvTranspose2d(in_c, out_c, 4, 2, 1, bias=False), nn.InstanceNorm2d(out_c), nn.ReLU()]
+        self.main = nn.Sequential(
+            ResidualBlock(ngf * 32),   # 4x4
+            *block(ngf*32, ngf*16),    # 8x8
+            *block(ngf*16, ngf*8),     # 16x16
+            SelfAttention(ngf * 8),
+            *block(ngf*8, ngf*4),      # 32x32
+            *block(ngf*4, ngf*2),      # 64x64
+            *block(ngf*2, ngf),        # 128x128
+            nn.ConvTranspose2d(ngf, nc, 4, 2, 1), nn.Tanh()  # 256x256
+        )
+    def forward(self, z, landmark_features):
+        x = self.fc(torch.cat([z, landmark_features], 1)).view(-1, self.ngf*32, 4, 4)
+        return self.main(x)
+
+class Discriminator(nn.Module):
+    def __init__(self, nc=3, ndf=64, num_landmarks=19):
+        super(Discriminator, self).__init__()
+        in_channels = nc + num_landmarks
+        def block(in_c, out_c, norm=True, dropout=True):
+            layers = [nn.utils.spectral_norm(nn.Conv2d(in_c, out_c, 4, 2, 1)) if norm else nn.Conv2d(in_c, out_c, 4, 2, 1)]
+            layers.append(nn.LeakyReLU(0.2))
+            if dropout: layers.append(nn.Dropout(0.5))
+            layers.append(ResidualBlock(out_c))
+            return layers
+        self.model = nn.ModuleList([
+            nn.Sequential(*block(in_channels, ndf, norm=False, dropout=False)),  # 128
+            nn.Sequential(*block(ndf, ndf * 2)),                                 # 64
+            nn.Sequential(*block(ndf*2, ndf * 4)),                               # 32
+            nn.Sequential(SelfAttention(ndf*4), *block(ndf*4, ndf * 8)),         # 16
+            nn.Sequential(*block(ndf*8, ndf * 16)),                              # 8
+        ])
+        self.out_layers = nn.ModuleList([
+            nn.utils.spectral_norm(nn.Conv2d(ndf*2, 1, 3, 1, 1)),
+            nn.utils.spectral_norm(nn.Conv2d(ndf*4, 1, 3, 1, 1)),
+            nn.utils.spectral_norm(nn.Conv2d(ndf*8, 1, 3, 1, 1)),
+            nn.utils.spectral_norm(nn.Conv2d(ndf*16, 1, 3, 1, 1)),
+        ])
+    def forward(self, img, heatmaps):
+        x = torch.cat([img, heatmaps], 1)
+        outputs = []
+        for i, layer in enumerate(self.model):
+            x = layer(x)
+            if i > 0: outputs.append(self.out_layers[i-1](x))
+        return outputs
+
+class LandmarkEncoder(nn.Module):
+    def __init__(self, input_dim, output_dim):
+        super(LandmarkEncoder, self).__init__()
+        self.encoder = nn.Sequential(
+            nn.Linear(input_dim, 128), nn.LeakyReLU(0.2),
+            nn.Linear(128, output_dim), nn.LeakyReLU(0.2)
+        )
+    def forward(self, landmarks):
+        return self.encoder(landmarks)
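An illustrative smoke test (not part of the commit) with dummy tensors and the hyperparameters from config/config.yaml, showing the shapes that flow through the morphing pipeline:

import torch
from models.models import Generator, Encoder, LandmarkEncoder

E = Encoder(nc=3, ndf=64, nz=200, num_landmarks=19).eval()
G = Generator(nz=200, ngf=64, nc=3, landmark_feature_size=128).eval()
LE = LandmarkEncoder(input_dim=19 * 2, output_dim=128).eval()

img = torch.randn(1, 3, 256, 256)
heatmaps = torch.randn(1, 19, 256, 256)   # one channel per landmark
landmarks = torch.rand(1, 38)             # 19 normalized (x, y) pairs

with torch.no_grad():
    z = E(img, heatmaps)                  # -> (1, 200)
    lf = LE(landmarks)                    # -> (1, 128)
    fake = G(z, lf)                       # -> (1, 3, 256, 256), values in [-1, 1]
print(z.shape, lf.shape, fake.shape)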
requirements.txt ADDED
@@ -0,0 +1,10 @@
+--find-links https://download.pytorch.org/whl/torch_stable.html
+torch==2.2.1+cu118
+torchvision==0.17.1+cu118
+gradio
+huggingface_hub
+pyyaml
+numpy
+Pillow
+scipy
+opencv-python-headless