# GeoRemover / app.py
import gradio as gr
import spaces
import sys, pathlib
BASE_DIR = pathlib.Path(__file__).resolve().parent
LOCAL_DIFFUSERS_SRC = BASE_DIR / "code_edit" / "diffusers" / "src"
# Ensure local diffusers is importable
if (LOCAL_DIFFUSERS_SRC / "diffusers").exists():
sys.path.insert(0, str(LOCAL_DIFFUSERS_SRC))
else:
raise RuntimeError(f"Local diffusers not found at: {LOCAL_DIFFUSERS_SRC}")
from diffusers.pipelines.flux.pipeline_flux_fill_unmasked_image_condition_version import (
FluxFillPipeline_token12_depth_only as FluxFillPipeline,
)
# ==== STAGE-2 ONLY ADDED: import Stage-2 Pipeline (do not touch Stage-1) ====
from diffusers.pipelines.flux.pipeline_flux_fill_unmasked_image_condition_version import (
FluxFillPipeline_token12_depth as FluxFillPipelineStage2,
)
# ===========================================================================
import os
import subprocess
import random
from typing import Optional, Tuple, Dict
import torch
from PIL import Image, ImageChops
import numpy as np
import cv2
# ---------------- Paths & assets ----------------
CODE_DEPTH = BASE_DIR / "code_depth"
CODE_EDIT = BASE_DIR / "code_edit"
GET_ASSETS = BASE_DIR / "get_assets.sh"
EXPECTED_ASSETS = [
BASE_DIR / "code_depth" / "checkpoints" / "video_depth_anything_vits.pth",
BASE_DIR / "code_depth" / "checkpoints" / "video_depth_anything_vitl.pth",
BASE_DIR / "code_edit" / "stage1" / "checkpoint-4800" / "pytorch_lora_weights.safetensors",
BASE_DIR / "code_edit" / "stage2" / "checkpoint-20000" / "pytorch_lora_weights.safetensors",
]
# Import depth helper
if str(CODE_DEPTH) not in sys.path:
sys.path.insert(0, str(CODE_DEPTH))
from depth_infer import DepthModel # noqa: E402
# ---------------- Asset preparation (on-demand) ----------------
def _have_all_assets() -> bool:
return all(p.is_file() for p in EXPECTED_ASSETS)
def _ensure_executable(p: pathlib.Path):
if not p.exists():
raise FileNotFoundError(f"Not found: {p}")
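    # 0o111 adds the execute bit for user, group, and other (i.e. `chmod +x`).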
os.chmod(p, os.stat(p).st_mode | 0o111)
def ensure_assets_if_missing():
"""
If SKIP_ASSET_DOWNLOAD=1 -> skip checks.
Otherwise ensure checkpoints/LoRAs exist; if missing, run get_assets.sh.
"""
if os.getenv("SKIP_ASSET_DOWNLOAD") == "1":
print("↪️ SKIP_ASSET_DOWNLOAD=1 -> skip asset download check")
return
if _have_all_assets():
print("✅ Assets already present")
return
print("⬇️ Missing assets, running get_assets.sh ...")
_ensure_executable(GET_ASSETS)
subprocess.run(
["bash", str(GET_ASSETS)],
check=True,
cwd=str(BASE_DIR),
env={**os.environ, "HF_HUB_DISABLE_TELEMETRY": "1"},
)
if not _have_all_assets():
missing = [str(p.relative_to(BASE_DIR)) for p in EXPECTED_ASSETS if not p.exists()]
raise RuntimeError(f"Assets missing after get_assets.sh: {missing}")
print("✅ Assets ready.")
try:
ensure_assets_if_missing()
except Exception as e:
print(f"⚠️ Asset prepare failed: {e}")
# ---------------- Global singletons ----------------
_MODELS: Dict[str, DepthModel] = {}
_PIPE: Optional[FluxFillPipeline] = None
# ==== STAGE-2 ONLY ADDED: singleton ====
_PIPE_STAGE2: Optional[FluxFillPipelineStage2] = None
# ======================================
def get_model(encoder: str) -> DepthModel:
if encoder not in _MODELS:
_MODELS[encoder] = DepthModel(BASE_DIR, encoder=encoder)
return _MODELS[encoder]
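# e.g. get_model("vitl") builds the ViT-L depth model on first use; later calls
# return the cached instance from _MODELS, so switching encoders stays cheap.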
def get_pipe() -> FluxFillPipeline:
"""
Load Stage-1 pipeline (FluxFillPipeline_token12_depth_only) and mount Stage-1 LoRA if present.
"""
global _PIPE
if _PIPE is not None:
return _PIPE
device = "cuda" if torch.cuda.is_available() else "cpu"
dtype = torch.bfloat16 if device == "cuda" else torch.float32
local_flux = BASE_DIR / "code_edit" / "flux_cache"
use_local = local_flux.exists()
hf_token = os.environ.get("HF_TOKEN")
    # Optionally enable accelerated downloads: huggingface_hub reads this env
    # var, and it only takes effect when the `hf_transfer` backend is installed.
    try:
        import hf_transfer  # noqa: F401
        os.environ.setdefault("HF_HUB_ENABLE_HF_TRANSFER", "1")
    except ImportError:
        pass
print(f"[pipe] loading FLUX.1-Fill-dev (dtype={dtype}, device={device}, local={use_local})")
try:
if use_local:
pipe = FluxFillPipeline.from_pretrained(local_flux, torch_dtype=dtype).to(device)
else:
pipe = FluxFillPipeline.from_pretrained(
"black-forest-labs/FLUX.1-Fill-dev",
torch_dtype=dtype,
token=hf_token,
).to(device)
except Exception as e:
        raise RuntimeError(
            "Failed to load FLUX.1-Fill-dev. Make sure you have accepted the gated "
            "model license and set HF_TOKEN, or pre-download the weights to "
            "code_edit/flux_cache."
        ) from e
# -------- LoRA (Stage-1) --------
lora_dir = CODE_EDIT / "stage1" / "checkpoint-4800"
lora_file = "pytorch_lora_weights.safetensors"
adapter_name = "stage1"
if lora_dir.exists():
try:
            import peft  # noqa: F401  # assert the LoRA backend is present
print(f"[pipe] loading LoRA from: {lora_dir}/{lora_file}")
pipe.load_lora_weights(
str(lora_dir),
weight_name=lora_file,
adapter_name=adapter_name,
)
try:
                pipe.set_adapters(adapter_name, adapter_weights=1.0)
print(f"[pipe] set_adapters('{adapter_name}', 1.0)")
except Exception as e_set:
print(f"[pipe] set_adapters not available ({e_set}); trying fuse_lora()")
try:
pipe.fuse_lora(lora_scale=1.0)
print("[pipe] fuse_lora(lora_scale=1.0) done")
except Exception as e_fuse:
print(f"[pipe] fuse_lora failed: {e_fuse}")
print("[pipe] LoRA ready ✅")
except ImportError:
print("[pipe] peft not installed; LoRA skipped (add `peft>=0.11`).")
except Exception as e:
print(f"[pipe] load_lora_weights failed (continue without): {e}")
else:
print(f"[pipe] LoRA path not found: {lora_dir} (continue without)")
_PIPE = pipe
return pipe
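# Note: get_pipe() is lazy -- the first Stage-1 request pays the full load cost
# (base weights + LoRA); every later request reuses the cached _PIPE singleton.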
# ==== STAGE-2 ONLY ADDED: Stage-2 loader (no change to Stage-1 logic) ====
def get_pipe_stage2() -> FluxFillPipelineStage2:
"""
Load Stage-2 FluxFillPipeline_token12_depth and mount Stage-2 LoRA.
"""
global _PIPE_STAGE2
if _PIPE_STAGE2 is not None:
return _PIPE_STAGE2
device = "cuda" if torch.cuda.is_available() else "cpu"
dtype = torch.bfloat16 if device == "cuda" else torch.float32
local_flux = BASE_DIR / "code_edit" / "flux_cache"
use_local = local_flux.exists()
hf_token = os.environ.get("HF_TOKEN")
    # Same optional download acceleration as in get_pipe(): only enable the env
    # var when the `hf_transfer` backend is actually installed.
    try:
        import hf_transfer  # noqa: F401
        os.environ.setdefault("HF_HUB_ENABLE_HF_TRANSFER", "1")
    except ImportError:
        pass
print(f"[stage2] loading FLUX.1-Fill-dev (dtype={dtype}, device={device}, local={use_local})")
try:
if use_local:
pipe2 = FluxFillPipelineStage2.from_pretrained(local_flux, torch_dtype=dtype).to(device)
else:
pipe2 = FluxFillPipelineStage2.from_pretrained(
"black-forest-labs/FLUX.1-Fill-dev",
torch_dtype=dtype,
token=hf_token,
).to(device)
except Exception as e:
raise RuntimeError("Stage-2: Failed to load FLUX.1-Fill-dev.") from e
# Load Stage-2 LoRA
lora_dir2 = CODE_EDIT / "stage2" / "checkpoint-20000"
candidate_names = [
"pytorch_lora_weights.safetensors",
"adapter_model.safetensors",
"lora.safetensors",
]
    if not lora_dir2.exists():
        raise RuntimeError(f"Stage-2 LoRA dir not found: {lora_dir2}")
    weight_name = None
    for name in candidate_names:
        if (lora_dir2 / name).is_file():
            weight_name = name
            break
    if weight_name is None:
        raise RuntimeError(
            f"Stage-2 LoRA weight not found under {lora_dir2}. Tried: {candidate_names}"
        )
try:
import peft # noqa: F401
except Exception as e:
raise RuntimeError("peft is not installed (requires peft>=0.11).") from e
try:
print(f"[stage2] loading LoRA: {lora_dir2}/{weight_name}")
pipe2.load_lora_weights(
str(lora_dir2),
weight_name=weight_name,
adapter_name="stage2",
)
try:
            pipe2.set_adapters("stage2", adapter_weights=1.0)
print("[stage2] set_adapters('stage2', 1.0)")
except Exception as e_set:
print(f"[stage2] set_adapters not available ({e_set}); trying fuse_lora()")
try:
pipe2.fuse_lora(lora_scale=1.0)
print("[stage2] fuse_lora(lora_scale=1.0) done")
except Exception as e_fuse:
raise RuntimeError(f"Stage-2 fuse_lora failed: {e_fuse}") from e_fuse
except Exception as e:
raise RuntimeError(f"Stage-2 LoRA load failed: {e}") from e
_PIPE_STAGE2 = pipe2
return pipe2
# ==========================================================================
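# Note: Stage-1 and Stage-2 each hold a full copy of the FLUX.1-Fill-dev base
# weights (same source, different pipeline class + LoRA), so running both in
# one process roughly doubles the base-model memory footprint.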
# ---------------- Mask helpers ----------------
def to_grayscale_mask(im: Image.Image) -> Image.Image:
"""
Convert any RGBA/RGB/L image to L mode.
Output: white = region to remove/fill, black = keep.
"""
if im.mode == "RGBA":
mask = im.split()[-1] # alpha as mask
else:
mask = im.convert("L")
    # Simple binarization: anything brighter than 16 becomes part of the mask
    mask = mask.point(lambda p: 255 if p > 16 else 0)
return mask # Do not invert; white = mask region
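# e.g. for an RGBA brush layer the alpha channel becomes the mask: any stroke
# with alpha > 16 turns white (remove) and untouched pixels stay black (keep).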
def dilate_mask(mask_l: Image.Image, px: int) -> Image.Image:
"""Dilate the white region by ~px pixels."""
if px <= 0:
return mask_l
arr = np.array(mask_l, dtype=np.uint8)
kernel = np.ones((3, 3), np.uint8)
iters = max(1, int(px // 2)) # heuristic
dilated = cv2.dilate(arr, kernel, iterations=iters)
return Image.fromarray(dilated, mode="L")
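# e.g. px=16 -> 8 iterations of the 3x3 kernel, i.e. roughly 8 px of growth per
# side of the stroke (~16 px added across its width).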
def _mask_from_red(img: Image.Image, out_size: Tuple[int, int]) -> Image.Image:
"""
Extract "pure red strokes" as a binary mask (white=brush, black=others) from RGBA/RGB.
Thresholds are lenient to tolerate compression/resampling.
"""
arr = np.array(img.convert("RGBA"))
r, g, b, a = arr[..., 0], arr[..., 1], arr[..., 2], arr[..., 3]
red_hit = (r >= 200) & (g <= 40) & (b <= 40) & (a > 0)
mask = (red_hit.astype(np.uint8) * 255)
m = Image.fromarray(mask, mode="L").resize(out_size, Image.NEAREST)
return m
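# e.g. a slightly compressed stroke pixel (230, 25, 10, 255) still counts as
# red, while a pink anti-aliasing fringe like (255, 120, 120) is rejected (g > 40).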
def pick_mask(
upload_mask: Optional[Image.Image],
sketch_data: Optional[dict],
base_image: Image.Image,
dilate_px: int = 0,
) -> Optional[Image.Image]:
"""
Selection rules:
1) If a mask is uploaded: use it directly (white=mask)
2) Else from ImageEditor output, only red strokes are recognized as mask:
- Try sketch_data['mask'] first (some versions provide it)
- Else merge red strokes from sketch_data['layers'][*]['image']
- If still none, try sketch_data['composite'] for red strokes
"""
# 1) Uploaded mask has highest priority
if isinstance(upload_mask, Image.Image):
m = to_grayscale_mask(upload_mask).resize(base_image.size, Image.NEAREST)
return dilate_mask(m, dilate_px) if dilate_px > 0 else m
# 2) Hand-drawn (ImageEditor)
if isinstance(sketch_data, dict):
# 2a) explicit mask (still supported)
m = sketch_data.get("mask")
if isinstance(m, Image.Image):
m = to_grayscale_mask(m).resize(base_image.size, Image.NEAREST)
return dilate_mask(m, dilate_px) if dilate_px > 0 else m
# 2b) merge red strokes from layers
layers = sketch_data.get("layers")
acc = None
if isinstance(layers, list) and layers:
acc = Image.new("L", base_image.size, 0)
for lyr in layers:
if not isinstance(lyr, dict):
continue
li = lyr.get("image") or lyr.get("mask")
if isinstance(li, Image.Image):
m_layer = _mask_from_red(li, base_image.size)
                    acc = ImageChops.lighter(acc, m_layer)  # union of strokes (per-pixel max)
if acc.getbbox() is not None:
return dilate_mask(acc, dilate_px) if dilate_px > 0 else acc
# 2c) finally, search composite for red strokes
comp = sketch_data.get("composite")
if isinstance(comp, Image.Image):
m_comp = _mask_from_red(comp, base_image.size)
if m_comp.getbbox() is not None:
return dilate_mask(m_comp, dilate_px) if dilate_px > 0 else m_comp
# 3) No valid mask
return None
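# Typical Stage-1 call: pick_mask(uploaded, editor_value, original, dilate_px=8)
# yields an L-mode mask at original.size, or None when nothing usable was found.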
def _round_mult64(x: float, mode: str = "nearest") -> int:
"""
Align x to a multiple of 64:
- mode="ceil" round up
- mode="floor" round down
- mode="nearest" nearest multiple
"""
if mode == "ceil":
return int((x + 63) // 64) * 64
elif mode == "floor":
return int(x // 64) * 64
else: # nearest
return int((x + 32) // 64) * 64
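# Worked examples: _round_mult64(100, "ceil") -> 128, "floor" -> 64, and
# "nearest" -> 128 (100 is 36 away from 64 but only 28 away from 128).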
def prepare_size_for_flux(img: Image.Image, target_max: int = 1024) -> tuple[int, int]:
"""
Steps:
1) Round w,h up to multiples of 64 (avoid too-small sizes)
2) Fix the long side to target_max (default 1024)
3) Scale the short side proportionally and align to a multiple of 64 (>= 64)
"""
w, h = img.size
w1 = max(64, _round_mult64(w, mode="ceil"))
h1 = max(64, _round_mult64(h, mode="ceil"))
if w1 >= h1:
out_w = target_max
scaled_h = h1 * (target_max / w1)
out_h = max(64, _round_mult64(scaled_h, mode="nearest"))
else:
out_h = target_max
scaled_w = w1 * (target_max / h1)
out_w = max(64, _round_mult64(scaled_w, mode="nearest"))
return int(out_w), int(out_h)
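# e.g. a 1920x1080 input: h is first ceiled to 1088, the long side pins to
# 1024, and 1088 * (1024 / 1920) ≈ 580 rounds to the nearest multiple of 64,
# so the pipeline runs at 1024x576 and results are resized back afterwards.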
# ---------------- Preview depth for canvas (colored) ----------------
@spaces.GPU
def preview_depth(image: Optional[Image.Image], encoder: str, max_res: int, input_size: int, fp32: bool):
if image is None:
return None
dm = get_model(encoder)
d_rgb = dm.infer(image=image, max_res=max_res, input_size=input_size, fp32=fp32, grayscale=False)
return d_rgb
def prepare_canvas(image, depth_img, source):
base = depth_img if source == "depth" else image
if base is None:
raise gr.Error('Please upload an image (and wait for the depth preview), then click "Prepare canvas".')
return gr.update(value=base)
# ---------------- Stage-1: depth(color) -> fill ----------------
@spaces.GPU
def run_depth_and_fill(
image: Image.Image,
mask_upload: Optional[Image.Image],
sketch: Optional[dict],
prompt: str,
encoder: str,
max_res: int,
input_size: int,
fp32: bool,
max_side: int,
mask_dilate_px: int,
guidance_scale: float,
steps: int,
seed: Optional[int],
) -> Tuple[Image.Image, Image.Image]:
if image is None:
raise gr.Error("Please upload an image first.")
# 1) produce a colored depth map (RGB)
depth_model = get_model(encoder)
depth_rgb: Image.Image = depth_model.infer(
image=image, max_res=max_res, input_size=input_size, fp32=fp32, grayscale=False
).convert("RGB")
print(f"[DEBUG] Depth RGB: mode={depth_rgb.mode}, size={depth_rgb.size}")
# 2) extract mask (uploaded > drawn)
mask_l = pick_mask(mask_upload, sketch, image, dilate_px=mask_dilate_px)
if (mask_l is None) or (mask_l.getbbox() is None):
raise gr.Error("No valid mask detected: please draw with the red brush or upload a binary mask.")
print(f"[DEBUG] Mask: mode={mask_l.mode}, size={mask_l.size}, bbox={mask_l.getbbox()}")
# 3) decide output size
width, height = prepare_size_for_flux(depth_rgb, target_max=max_side)
orig_w, orig_h = image.size
print(f"[DEBUG] FLUX size: {width}x{height}, original: {orig_w}x{orig_h}")
# 4) run FLUX pipeline (key: use depth_rgb as both image and depth input)
pipe = get_pipe()
generator = (
torch.Generator("cpu").manual_seed(int(seed))
if (seed is not None and seed >= 0)
else torch.Generator("cpu").manual_seed(random.randint(0, 2**31 - 1))
)
result = pipe(
prompt=prompt,
image=depth_rgb, # use the colored depth map instead of original image
mask_image=mask_l,
width=width,
height=height,
guidance_scale=float(guidance_scale),
num_inference_steps=int(steps),
max_sequence_length=512,
generator=generator,
depth=depth_rgb, # feed depth (colored)
).images[0]
final_result = result.resize((orig_w, orig_h), Image.BICUBIC)
# return result and mask preview
mask_preview = mask_l.resize((orig_w, orig_h), Image.NEAREST).convert("RGB")
return final_result, mask_preview
def _to_pil_rgb(img_like) -> Image.Image:
    """Normalize input to a PIL RGB image. Accepts PIL (L/RGB/RGBA) or numpy arrays."""
    if isinstance(img_like, Image.Image):
        return img_like.convert("RGB")
    try:
        arr = np.asarray(img_like)
        if arr.ndim == 2:  # grayscale -> replicate to three channels
            arr = np.stack([arr] * 3, axis=-1)
        # Let PIL infer the mode (handles 4-channel RGBA), then normalize to RGB.
        return Image.fromarray(arr.astype(np.uint8)).convert("RGB")
    except Exception:
        raise gr.Error("Stage-2: `depth` / `depth_image` is not a valid image object.")
# ---------------- Stage-2: REQUIRED refine/render ----------------
@spaces.GPU
def run_stage2_refine(
image: Image.Image, # original image (RGB)
stage1_out: Image.Image, # output from Stage-1
depth_img_from_stage1_input: Image.Image, # Stage-1 depth preview (from UI)
mask_upload: Optional[Image.Image],
sketch: Optional[dict],
prompt: str,
encoder: str,
max_res: int,
input_size: int,
fp32: bool,
max_side: int,
guidance_scale: float,
steps: int,
seed: Optional[int],
) -> Image.Image:
if image is None or stage1_out is None:
raise gr.Error("Please complete Stage-1 first (needs original image and Stage-1 output).")
# Allow refine without mask (use all-black)
mask_l = pick_mask(mask_upload, sketch, image, dilate_px=0)
if (mask_l is None) or (mask_l.getbbox() is None):
mask_l = Image.new("L", image.size, 0)
# Unify sizes
width, height = prepare_size_for_flux(image, target_max=max_side)
orig_w, orig_h = image.size
pipe2 = get_pipe_stage2()
g2 = (
torch.Generator("cpu").manual_seed(int(seed))
if (seed is not None and seed >= 0)
else torch.Generator("cpu").manual_seed(random.randint(0, 2**31 - 1))
)
depth_pil = _to_pil_rgb(stage1_out) # for `depth`
depth_image_pil = _to_pil_rgb(depth_img_from_stage1_input) # for `depth_image`
image_rgb = _to_pil_rgb(image)
# Resize to (width, height)
depth_pil = depth_pil.resize((width, height), Image.BICUBIC)
depth_image_pil = depth_image_pil.resize((width, height), Image.BICUBIC)
# Mapping:
# image = original RGB
# depth = Stage-1 output (updated geometry)
# depth_image = Stage-1 input depth (UI depth preview)
out2 = pipe2(
prompt=prompt,
        image=image_rgb,  # original image, normalized to RGB
mask_image=mask_l,
width=width,
height=height,
guidance_scale=float(guidance_scale),
num_inference_steps=int(steps),
max_sequence_length=512,
generator=g2,
depth=depth_pil,
depth_image=depth_image_pil,
).images[0]
    out2 = out2.resize((orig_w * 3, orig_h), Image.BICUBIC)  # preserve the 3x-wide showcase layout
return out2
# ---------------- UI ----------------
with gr.Blocks() as demo:
gr.Markdown(
"""
# GeoRemover · Depth-Guided Object Removal (Two-Stage, Stage-2 REQUIRED)
**Pipeline overview**
1) Compute a **colored depth map** from your input image.
2) You create a **removal mask** (red brush or upload).
3) **Stage-1** runs FLUX Fill with depth guidance to get a first pass.
4) **Stage-2 (REQUIRED)** renders the final result from depth → image using Stage-1 output and the original depth.
> ⚠️ **Stage-2 is required.** Always click **Run Stage-2 (Render)** *after* Stage-1 finishes. Stage-1 alone is not the final output.
---
### Quick start
1. **Upload image** (left). Wait for **Depth preview (colored)** (right).
2. In **Draw mask**, pick **Draw on: _image_** or **_depth_**, then click **Prepare canvas**.
3. Paint the region to remove using the **red brush** (**red = remove**).
4. Optionally adjust **Mask dilation** for thin edges.
5. Enter a concise **Prompt** describing the fill content.
6. Click **Run** → produces **Stage-1** (first pass).
7. Click **Run Stage-2 (Render)** → produces the **final** result.
---
### Mask rules & tips
- Only **red strokes** are treated as mask (**white = remove, black = keep** internally).
- Paint **slightly larger** than the object boundary to avoid seams/halos.
- If you have a binary mask already, use **Upload mask**.
- **Mask dilation (px)** expands the mask to cover thin borders.
"""
)
with gr.Row():
with gr.Column(scale=1):
# Input image
img = gr.Image(
label="Upload image",
type="pil",
)
# Mask: upload or draw
with gr.Tab("Upload mask"):
mask_upload = gr.Image(
label="Mask (optional)",
type="pil",
)
with gr.Tab("Draw mask"):
draw_source = gr.Radio(
["image", "depth"],
value="image",
label="Draw on",
)
prepare_btn = gr.Button("Prepare canvas", variant="secondary")
gr.Markdown(
"""
**Canvas usage**
- Click **Prepare canvas** after selecting *image* or *depth*.
- Use the **red brush** only—red strokes are extracted as the removal mask.
- Switch tabs anytime if you prefer uploading a ready-made mask.
"""
)
sketch = gr.ImageEditor(
label="Sketch mask (red = remove)",
type="pil",
brush=gr.Brush(colors=["#FF0000"], default_size=24),
)
# Prompt
prompt = gr.Textbox(
label="Prompt",
value="A beautiful scene",
        placeholder="Keep the default prompt (recommended)",
)
# Tunables
with gr.Accordion("Advanced (Depth & FLUX)", open=False):
encoder = gr.Dropdown(
["vits", "vitl"],
value="vitl",
label="Depth encoder",
)
max_res = gr.Slider(
512, 2048, value=1280, step=64,
label="Depth: max_res",
)
input_size = gr.Slider(
256, 1024, value=518, step=2,
label="Depth: input_size",
)
fp32 = gr.Checkbox(
False,
label="Depth: use FP32 (default FP16)",
)
max_side = gr.Slider(
512, 1536, value=1024, step=64,
label="FLUX: max side (px)",
)
mask_dilate_px = gr.Slider(
0, 128, value=0, step=1,
label="Mask dilation (px)",
)
guidance_scale = gr.Slider(
0, 50, value=30, step=0.5,
label="FLUX: guidance_scale",
)
steps = gr.Slider(
10, 75, value=50, step=1,
label="FLUX: steps",
)
seed = gr.Number(
value=0, precision=0,
label="Seed (>=0 = fixed; empty = random)",
)
run_btn = gr.Button("Run", variant="primary")
# Stage-2 is REQUIRED: keep disabled until Stage-1 finishes
run_btn_stage2 = gr.Button("Run Stage-2 (Render)", variant="secondary", interactive=False)
with gr.Column(scale=1):
depth_preview = gr.Image(
label="Depth preview (colored)",
interactive=False,
)
mask_preview = gr.Image(
label="Mask preview (areas to remove)",
interactive=False,
)
out = gr.Image(
label="Output (Stage-1 first pass)",
)
out_stage2 = gr.Image(
label="Final Output (Stage-2)",
)
gr.Markdown(
"""
### Why Stage-2 is required
Stage-1 provides a depth-guided fill that is *not final*. **Stage-2 renders** the definitive image by leveraging:
- **Stage-1 output** as updated geometry hints, and
- **Original colored depth** as `depth_image` guidance.
Skipping Stage-2 will leave the process incomplete.
### Troubleshooting
- **“No valid mask detected”**: Either upload a binary mask (white=remove) **or** draw with **red brush** after clicking **Prepare canvas**.
- **Seams/halos**: Increase **Mask dilation (px)** (e.g., 8–16) and re-run both stages.
- **Prompt not followed**: Lower **guidance_scale** (e.g., 18–24) and make the prompt more concrete.
- **Depth looks noisy**: Use **vitl**, increase **Depth: max_res**, or enable **FP32**.
"""
)
# ===== Helpers to toggle Stage-2 button =====
def _enable_button():
return gr.update(interactive=True)
# Auto depth preview on image change
img.change(
fn=preview_depth,
inputs=[img, encoder, max_res, input_size, fp32],
outputs=[depth_preview],
)
# Prepare canvas for drawing on image or depth
prepare_btn.click(
fn=prepare_canvas,
inputs=[img, depth_preview, draw_source],
outputs=[sketch],
)
# Stage-1
run_btn.click(
fn=run_depth_and_fill,
inputs=[img, mask_upload, sketch, prompt, encoder, max_res, input_size, fp32,
max_side, mask_dilate_px, guidance_scale, steps, seed],
outputs=[out, mask_preview],
api_name="run",
).then( # Enable Stage-2 only after Stage-1 completes
fn=_enable_button,
inputs=[],
outputs=[run_btn_stage2],
)
# Stage-2 (REQUIRED; unlocked after Stage-1)
run_btn_stage2.click(
fn=run_stage2_refine,
inputs=[img, out, depth_preview,
mask_upload, sketch, prompt, encoder, max_res, input_size, fp32,
max_side, guidance_scale, steps, seed],
outputs=[out_stage2],
api_name="run_stage2",
)
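# ---------------- Programmatic access (sketch) ----------------
# Both buttons are exposed as named API endpoints via `api_name`. A minimal
# client sketch, assuming a recent `gradio_client`; the argument order must
# match the `inputs` list wired to each button, and how an ImageEditor value is
# passed over the API varies by Gradio version (None skips the sketch):
#
#   from gradio_client import Client, handle_file
#   client = Client("http://localhost:7860")
#   stage1_img, mask_img = client.predict(
#       handle_file("input.png"),   # img
#       handle_file("mask.png"),    # mask_upload (white = remove)
#       None,                       # sketch (unused when a mask is uploaded)
#       "A beautiful scene",        # prompt
#       "vitl", 1280, 518, False,   # depth encoder / max_res / input_size / fp32
#       1024, 8, 30.0, 50, 0,       # max_side / dilate / guidance / steps / seed
#       api_name="/run",
#   )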
if __name__ == "__main__":
os.environ.setdefault("HF_HUB_DISABLE_TELEMETRY", "1")
demo.launch(server_name="0.0.0.0", server_port=7860)