- app.py +16 -37
- resources/mean_vector_list_ae_model_tf_2024-03-05_00-35-21.pth.npy +3 -0
- resources/mean_vector_list_autoencoder-epoch=09-train_loss=1.00.ckpt.npy +3 -0
- resources/mean_vector_list_autoencoder-epoch=29-train_loss=1.01.ckpt.npy +3 -0
- resources/mean_vector_list_autoencoder-epoch=49-train_loss=1.01.ckpt.npy +3 -0
- utils.py +1 -1
app.py
CHANGED
@@ -8,6 +8,7 @@ import numpy as np
 from PIL import Image
 import base64
 from io import BytesIO
+import os
 
 import dataset
 from dataset import MyDataset, ImageKeypointDataset, load_filenames, load_keypoints
@@ -92,8 +93,7 @@ model_index = 0
 
 # heatmap generation function
 @spaces.GPU
-def get_heatmaps(source_num, x_coords, y_coords, uploaded_image):
-    global model_index, mean_vector_list
+def get_heatmaps(model_info, source_num, x_coords, y_coords, uploaded_image):
     if type(uploaded_image) == str:
         uploaded_image = Image.open(uploaded_image)
     if type(source_num) == str:
@@ -102,6 +102,13 @@ def get_heatmaps(source_num, x_coords, y_coords, uploaded_image):
         x_coords = int(x_coords)
     if type(y_coords) == str:
        y_coords = int(y_coords)
+
+    if type(model_info) == str:
+        model_info = eval(model_info)
+    model_index = models_info.index(model_info)
+
+    mean_vector_list = np.load(f"resources/mean_vector_list_{model_info['name']}.npy", allow_pickle=True)
+    mean_vector_list = torch.tensor(mean_vector_list).to(device)
 
     dec5, _ = models[model_index](x)
     feature_map = dec5
@@ -138,24 +145,6 @@ def get_heatmaps(source_num, x_coords, y_coords, uploaded_image):
     plt.close(fig)
     return fig
 
-@spaces.GPU
-def setup(model_info, input_image=None):
-    global model_index, mean_vector_list
-    # convert str -> dict
-    if type(model_info) == str:
-        model_info = eval(model_info)
-
-    model_index = models_info.index(model_info)
-
-    feature_map, _ = models[model_index](test_imgs)
-    mean_vector_list = utils.get_mean_vector(feature_map, points)
-
-    if input_image is not None:
-        fig = get_heatmaps(0, image_size // 2, image_size // 2, input_image)
-        return fig
-
-    print("setup done.")
-
 with gr.Blocks() as demo:
     # title
     gr.Markdown("# TripletGeoEncoder Feature Map Visualization")
@@ -168,24 +157,17 @@ with gr.Blocks() as demo:
                 "For further information, please contact me on X (formerly Twitter): @Yeq6X.")
 
     gr.Markdown("## Heatmap Visualization")
-
+
+    model_info = gr.Dropdown(
+        choices=[str(model_info) for model_info in models_info],
+        container=False
+    )
     input_image = gr.ImageEditor(label="Cropped Image", elem_id="input_image", crop_size=(112, 112), show_fullscreen_button=True)
-    output_plot = gr.Plot(value=None, elem_id="output_plot", show_label=False)
-    with gr.Row():
-        with gr.Column():
-            with gr.Row():
-                model_name = gr.Dropdown(
-                    choices=[str(model_info) for model_info in models_info],
-                    container=False
-                )
-                load_button = gr.Button("Load Model")
-                load_button.click(setup, inputs=[model_name, input_image], outputs=[output_plot])
-            with gr.Row():
-                pass
-
+    output_plot = gr.Plot(value=None, elem_id="output_plot", show_label=False)
     inference = gr.Interface(
         get_heatmaps,
         inputs=[
+            model_info,
             gr.Slider(0, batch_size - 1, step=1, label="Source Image Index"),
             gr.Slider(0, image_size - 1, step=1, value=image_size // 2, label="X Coordinate"),
             gr.Slider(0, image_size - 1, step=1, value=image_size // 2, label="Y Coordinate"),
@@ -205,8 +187,5 @@ with gr.Blocks() as demo:
         inputs=[input_image],
     )
 
-setup(models_info[0])
-print(mean_vector_list)
-
 demo.launch()
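This commit replaces the runtime setup() pass with per-model mean-vector files under resources/, which get_heatmaps now loads on every call. Below is a minimal offline sketch of how such a file could have been produced (assumptions: precompute_mean_vectors is a hypothetical helper, not part of this commit, and test_imgs / points are the same reference batch and keypoints the removed setup() used):

import numpy as np
import torch
import utils

def precompute_mean_vectors(model, model_name, test_imgs, points):
    # Forward pass over the reference batch, as the removed setup() did at runtime.
    with torch.no_grad():
        feature_map, _ = model(test_imgs)
    # get_mean_vector now returns host-side NumPy arrays (see the utils.py change below),
    # so the list can be stacked and saved directly.
    mean_vector_list = utils.get_mean_vector(feature_map, points)
    np.save(f"resources/mean_vector_list_{model_name}.npy", np.array(mean_vector_list))

The file name must match what get_heatmaps expects, i.e. resources/mean_vector_list_{model_info['name']}.npy.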
resources/mean_vector_list_ae_model_tf_2024-03-05_00-35-21.pth.npy
ADDED
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:1a8f81b924edda413139a5743408c6f38ebd7930b5d39cc98a6b4dd49bd42dae
+size 3328
resources/mean_vector_list_autoencoder-epoch=09-train_loss=1.00.ckpt.npy
ADDED
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:21c0d7cd81ee6e7fac9f0333209daec9e74a7c1e72e358b7732e8ecb3efea5f2
+size 6528
resources/mean_vector_list_autoencoder-epoch=29-train_loss=1.01.ckpt.npy
ADDED
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:f9edb6ad2b6a0b9121905ea04ac1a39618d709329045c9cf00673f6281fc412c
+size 6528
resources/mean_vector_list_autoencoder-epoch=49-train_loss=1.01.ckpt.npy
ADDED
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:4141deceac789a26c3d6cf02f20068e8caf37fb14b8af3f63ae365b32462d76f
+size 6528
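Each of the four added resources/*.npy entries is stored as a Git LFS pointer: only the three-line pointer text (spec version, sha256 oid, payload size in bytes) lives in the repository, and LFS substitutes the real binary at checkout. Once materialized, a file loads as an ordinary NumPy array; a small usage sketch, assuming LFS has already fetched the payload:

import numpy as np

vecs = np.load("resources/mean_vector_list_ae_model_tf_2024-03-05_00-35-21.pth.npy", allow_pickle=True)
print(vecs.shape)  # one mean feature vector per keypoint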
utils.py
CHANGED
@@ -132,7 +132,7 @@ def get_mean_vector(feature_map, points):
         x_coords, y_coords = torch.round(points[:,i].t()).to(torch.long)
         vectors = feature_map[torch.arange(feature_map.size(0)), :, y_coords, x_coords]  # adjust the size to match the 1-D vectors
         # mean_vector = vectors[0:10].mean(0)  # take the mean vector over 10 feature maps
-        mean_vector = vectors.mean(0)
+        mean_vector = vectors.mean(0).detach().cpu().numpy()
         mean_vector_list.append(mean_vector)
     return mean_vector_list
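The one-line change in get_mean_vector is what makes the precomputed .npy files above possible: vectors.mean(0) is a (potentially CUDA, autograd-tracked) tensor, while np.save needs a host-side NumPy array. A minimal illustration of the conversion chain (assumption: a CUDA device is available):

import numpy as np
import torch

v = torch.randn(8, 64, requires_grad=True, device="cuda")
# detach() drops the autograd graph, cpu() copies the data off the GPU,
# and numpy() exposes the result as a plain ndarray that np.save accepts.
arr = v.mean(0).detach().cpu().numpy()
np.save("mean_vector.npy", arr)

On the loading side, app.py converts back with torch.tensor(mean_vector_list).to(device), as shown in the get_heatmaps hunk above.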