xinjie.wang commited on
Commit
53cdb99
·
1 Parent(s): 7b5cba5
Files changed (3) hide show
  1. app.py +9 -9
  2. common.py +1 -1
  3. embodied_gen/models/sam3d.py +4 -2
app.py CHANGED
@@ -41,11 +41,11 @@ from common import (
41
 
42
  app_name = os.getenv("GRADIO_APP")
43
  if app_name == "imageto3d_sam3d":
44
- enable_pre_resize = gr.State(False)
45
  sample_step = 25
46
  bg_rm_model_name = "rembg" # "rembg", "rmbg14"
47
  elif app_name == "imageto3d":
48
- enable_pre_resize = gr.State(True)
49
  sample_step = 12
50
  bg_rm_model_name = "rembg" # "rembg", "rmbg14"
51
 
@@ -77,7 +77,7 @@ with gr.Blocks(delete_cache=(43200, 43200), theme=custom_theme) as demo:
77
  ),
78
  elem_classes=["header"],
79
  )
80
-
81
  with gr.Row():
82
  with gr.Column(scale=3):
83
  with gr.Tabs() as input_tabs:
@@ -271,9 +271,9 @@ with gr.Blocks(delete_cache=(43200, 43200), theme=custom_theme) as demo:
271
  glob("assets/example_image/*")
272
  )
273
  ],
274
- inputs=[image_prompt, rmbg_tag, enable_pre_resize],
275
- fn=preprocess_image_fn,
276
- outputs=[image_prompt, raw_image_cache, enable_pre_resize],
277
  run_on_click=True,
278
  examples_per_page=10,
279
  cache_examples=True,
@@ -339,9 +339,9 @@ with gr.Blocks(delete_cache=(43200, 43200), theme=custom_theme) as demo:
339
  )
340
 
341
  image_prompt.upload(
342
- preprocess_image_fn,
343
- inputs=[image_prompt, rmbg_tag, enable_pre_resize],
344
- outputs=[image_prompt, raw_image_cache, enable_pre_resize],
345
  )
346
  image_prompt.change(
347
  lambda: tuple(
 
41
 
42
  app_name = os.getenv("GRADIO_APP")
43
  if app_name == "imageto3d_sam3d":
44
+ _enable_pre_resize_default = False
45
  sample_step = 25
46
  bg_rm_model_name = "rembg" # "rembg", "rmbg14"
47
  elif app_name == "imageto3d":
48
+ _enable_pre_resize_default = True
49
  sample_step = 12
50
  bg_rm_model_name = "rembg" # "rembg", "rmbg14"
51
 
 
77
  ),
78
  elem_classes=["header"],
79
  )
80
+ enable_pre_resize = gr.State(_enable_pre_resize_default)
81
  with gr.Row():
82
  with gr.Column(scale=3):
83
  with gr.Tabs() as input_tabs:
 
271
  glob("assets/example_image/*")
272
  )
273
  ],
274
+ inputs=[image_prompt, rmbg_tag],
275
+ fn=lambda img, rmbg: preprocess_image_fn(img, rmbg, _enable_pre_resize_default),
276
+ outputs=[image_prompt, raw_image_cache],
277
  run_on_click=True,
278
  examples_per_page=10,
279
  cache_examples=True,
 
339
  )
340
 
341
  image_prompt.upload(
342
+ lambda img, rmbg: preprocess_image_fn(img, rmbg, _enable_pre_resize_default),
343
+ inputs=[image_prompt, rmbg_tag],
344
+ outputs=[image_prompt, raw_image_cache],
345
  )
346
  image_prompt.change(
347
  lambda: tuple(
common.py CHANGED
@@ -177,7 +177,7 @@ def preprocess_image_fn(
177
  if preprocess:
178
  image = trellis_preprocess(image)
179
 
180
- return image, image_cache, preprocess
181
 
182
 
183
  def preprocess_sam_image_fn(
 
177
  if preprocess:
178
  image = trellis_preprocess(image)
179
 
180
+ return image, image_cache
181
 
182
 
183
  def preprocess_sam_image_fn(
embodied_gen/models/sam3d.py CHANGED
@@ -22,7 +22,8 @@ import sys
22
 
23
  import numpy as np
24
  from hydra.utils import instantiate
25
- from modelscope import snapshot_download
 
26
  from omegaconf import OmegaConf
27
  from PIL import Image
28
 
@@ -65,7 +66,8 @@ class Sam3dInference:
65
  self, local_dir: str = "weights/sam-3d-objects", compile: bool = False
66
  ) -> None:
67
  if not os.path.exists(local_dir):
68
- snapshot_download("facebook/sam-3d-objects", local_dir=local_dir)
 
69
  config_file = os.path.join(local_dir, "checkpoints/pipeline.yaml")
70
  config = OmegaConf.load(config_file)
71
  config.rendering_engine = "nvdiffrast"
 
22
 
23
  import numpy as np
24
  from hydra.utils import instantiate
25
+ # from modelscope import snapshot_download
26
+ from huggingface_hub import snapshot_download
27
  from omegaconf import OmegaConf
28
  from PIL import Image
29
 
 
66
  self, local_dir: str = "weights/sam-3d-objects", compile: bool = False
67
  ) -> None:
68
  if not os.path.exists(local_dir):
69
+ # snapshot_download("facebook/sam-3d-objects", local_dir=local_dir)
70
+ snapshot_download(repo_id="tuandao-zenai/sam-3d-objects", local_dir=local_dir)
71
  config_file = os.path.join(local_dir, "checkpoints/pipeline.yaml")
72
  config = OmegaConf.load(config_file)
73
  config.rendering_engine = "nvdiffrast"