Spaces:
Running
on
Zero
Running
on
Zero
| import gradio as gr | |
| import spaces | |
| import sys, pathlib | |
# Directory containing this file; every other path in the app hangs off it.
BASE_DIR = pathlib.Path(__file__).resolve().parent
LOCAL_DIFFUSERS_SRC = BASE_DIR / "code_edit" / "diffusers" / "src"

# The vendored diffusers fork must shadow any installed copy, so it is put at
# the front of sys.path. A missing fork is fatal: the custom FLUX pipelines
# imported right after this only exist there.
if not (LOCAL_DIFFUSERS_SRC / "diffusers").exists():
    raise RuntimeError(f"Local diffusers not found at: {LOCAL_DIFFUSERS_SRC}")
sys.path.insert(0, str(LOCAL_DIFFUSERS_SRC))
| from diffusers.pipelines.flux.pipeline_flux_fill_unmasked_image_condition_version import ( | |
| FluxFillPipeline_token12_depth_only as FluxFillPipeline, | |
| ) | |
| # ==== STAGE-2 ONLY ADDED: import Stage-2 Pipeline (do not touch Stage-1) ==== | |
| from diffusers.pipelines.flux.pipeline_flux_fill_unmasked_image_condition_version import ( | |
| FluxFillPipeline_token12_depth as FluxFillPipelineStage2, | |
| ) | |
| # =========================================================================== | |
| import os | |
| import subprocess | |
| import random | |
| from typing import Optional, Tuple, Dict, Any | |
| import torch | |
| from PIL import Image, ImageOps | |
| import numpy as np | |
| import cv2 | |
| # ---------------- Paths & assets ---------------- | |
# NOTE(review): BASE_DIR is already assigned near the top of the file with the
# same expression; this re-assignment looks redundant — confirm before removing.
BASE_DIR = pathlib.Path(__file__).resolve().parent
CODE_DEPTH = BASE_DIR / "code_depth"
CODE_EDIT = BASE_DIR / "code_edit"
GET_ASSETS = BASE_DIR / "get_assets.sh"

# Files that must be present before the app can run: two depth checkpoints
# and the LoRA weights for Stage-1 and Stage-2.
EXPECTED_ASSETS = [
    CODE_DEPTH / "checkpoints" / "video_depth_anything_vits.pth",
    CODE_DEPTH / "checkpoints" / "video_depth_anything_vitl.pth",
    CODE_EDIT / "stage1" / "checkpoint-4800" / "pytorch_lora_weights.safetensors",
    CODE_EDIT / "stage2" / "checkpoint-20000" / "pytorch_lora_weights.safetensors",
]

# Import depth helper (lives in code_depth/, which is not a package on sys.path
# by default).
if str(CODE_DEPTH) not in sys.path:
    sys.path.insert(0, str(CODE_DEPTH))
from depth_infer import DepthModel  # noqa: E402

# Import your custom diffusers (local fork).
# NOTE(review): the fork's "src" directory is already on sys.path and the same
# class is already imported under the same alias earlier in the file, so this
# path entry and re-import appear to be no-ops — confirm before removing.
if str(CODE_EDIT / "diffusers") not in sys.path:
    sys.path.insert(0, str(CODE_EDIT / "diffusers"))
from diffusers.pipelines.flux.pipeline_flux_fill_unmasked_image_condition_version import (  # type: ignore # noqa: E402
    FluxFillPipeline_token12_depth_only as FluxFillPipeline,
)
| # ---------------- Asset preparation (on-demand) ---------------- | |
def _have_all_assets() -> bool:
    """Return True only when every path in ``EXPECTED_ASSETS`` is a regular file."""
    for asset_path in EXPECTED_ASSETS:
        if not asset_path.is_file():
            return False
    return True
def _ensure_executable(p: pathlib.Path):
    """Add the execute bits (user, group, other) to the file at *p*.

    Raises:
        FileNotFoundError: if *p* does not exist.
    """
    if p.exists():
        current_mode = os.stat(p).st_mode
        # OR-ing with 0o111 keeps the existing permission bits and adds +x.
        os.chmod(p, current_mode | 0o111)
        return
    raise FileNotFoundError(f"Not found: {p}")
def ensure_assets_if_missing():
    """Make sure the checkpoints and LoRA weights are on disk.

    Behaviour:
      * If the environment variable ``SKIP_ASSET_DOWNLOAD`` equals ``"1"``,
        nothing is checked.
      * If every entry of ``EXPECTED_ASSETS`` is already a file, nothing is done.
      * Otherwise ``get_assets.sh`` is made executable and run with ``bash``
        from ``BASE_DIR``, after which the assets are checked again.

    Raises:
        FileNotFoundError: if ``get_assets.sh`` itself is missing.
        subprocess.CalledProcessError: if the script exits non-zero.
        RuntimeError: if assets are still missing after the script ran.
    """
    if os.getenv("SKIP_ASSET_DOWNLOAD") == "1":
        print("↪️ SKIP_ASSET_DOWNLOAD=1 -> skip asset download check")
        return
    if _have_all_assets():
        print("✅ Assets already present")
        return

    print("⬇️ Missing assets, running get_assets.sh ...")
    _ensure_executable(GET_ASSETS)
    subprocess.run(
        ["bash", str(GET_ASSETS)],
        check=True,
        cwd=str(BASE_DIR),
        env={**os.environ, "HF_HUB_DISABLE_TELEMETRY": "1"},
    )

    if not _have_all_assets():
        # Use is_file() here as well so the report agrees with the check in
        # _have_all_assets(): a directory sitting at an asset path is still
        # "missing" and must show up in the error message.
        missing = [
            str(p.relative_to(BASE_DIR)) for p in EXPECTED_ASSETS if not p.is_file()
        ]
        raise RuntimeError(f"Assets missing after get_assets.sh: {missing}")
    print("✅ Assets ready.")
# Asset preparation runs at import time but is best-effort: a failure is only
# reported so the UI can still start (model loading will fail later instead).
try:
    ensure_assets_if_missing()
except Exception as exc:
    print(f"⚠️ Asset prepare failed: {exc}")

# ---------------- Global singletons ----------------
# Depth models, keyed by encoder name; filled lazily by get_model().
_MODELS: Dict[str, DepthModel] = dict()
# Stage-1 pipeline; created on first call to get_pipe().
_PIPE: Optional[FluxFillPipeline] = None
# ==== STAGE-2 ONLY ADDED: singleton ====
# Stage-2 pipeline; created on first call to get_pipe_stage2().
_PIPE_STAGE2: Optional[FluxFillPipelineStage2] = None
# ======================================
def get_model(encoder: str) -> DepthModel:
    """Return the cached ``DepthModel`` for *encoder*, building it on first use."""
    cached = _MODELS.get(encoder)
    if cached is None:
        cached = DepthModel(BASE_DIR, encoder=encoder)
        _MODELS[encoder] = cached
    return cached
def get_pipe() -> FluxFillPipeline:
    """
    Load Stage-1 pipeline (FluxFillPipeline_token12_depth_only) and mount Stage-1 LoRA if present.

    The pipeline is created once and cached in the module-level ``_PIPE``.
    Weights come from ``code_edit/flux_cache`` when that directory exists,
    otherwise from the ``black-forest-labs/FLUX.1-Fill-dev`` repo using the
    ``HF_TOKEN`` environment variable.

    LoRA handling is best-effort: any failure is logged and the base pipeline
    is returned without the adapter.

    Raises:
        RuntimeError: if the base pipeline cannot be loaded.
    """
    global _PIPE
    if _PIPE is not None:
        return _PIPE

    device = "cuda" if torch.cuda.is_available() else "cpu"
    dtype = torch.bfloat16 if device == "cuda" else torch.float32
    local_flux = BASE_DIR / "code_edit" / "flux_cache"
    use_local = local_flux.exists()
    hf_token = os.environ.get("HF_TOKEN")

    # Deliberately best-effort: faster downloads are optional, so any failure
    # here (including the helper not existing) is ignored.
    try:
        from huggingface_hub import hf_hub_enable_hf_transfer
        hf_hub_enable_hf_transfer()
    except Exception:
        pass

    print(f"[pipe] loading FLUX.1-Fill-dev (dtype={dtype}, device={device}, local={use_local})")
    try:
        if use_local:
            pipe = FluxFillPipeline.from_pretrained(local_flux, torch_dtype=dtype).to(device)
        else:
            pipe = FluxFillPipeline.from_pretrained(
                "black-forest-labs/FLUX.1-Fill-dev",
                torch_dtype=dtype,
                token=hf_token,
            ).to(device)
    except Exception as e:
        raise RuntimeError(
            "Failed to load FLUX.1-Fill-dev. "
            "Ensure gated access and HF_TOKEN; or pre-download to local cache."
        ) from e

    # -------- LoRA (Stage-1) --------
    lora_dir = CODE_EDIT / "stage1" / "checkpoint-4800"
    lora_file = "pytorch_lora_weights.safetensors"
    adapter_name = "stage1"

    # Check the weight file itself, not just its directory: an empty
    # checkpoint directory would otherwise lead to a load attempt that can
    # only fail.
    if (lora_dir / lora_file).is_file():
        try:
            import peft  # noqa: F401  # assert backend presence
            print(f"[pipe] loading LoRA from: {lora_dir}/{lora_file}")
            pipe.load_lora_weights(
                str(lora_dir),
                weight_name=lora_file,
                adapter_name=adapter_name,
            )
            # Track whether the adapter was actually activated so the final
            # log line does not claim success after both attempts failed.
            lora_active = False
            try:
                # NOTE(review): upstream diffusers names this argument
                # `adapter_weights`; confirm the local fork accepts `scale=`.
                pipe.set_adapters(adapter_name, scale=1.0)
                print(f"[pipe] set_adapters('{adapter_name}', 1.0)")
                lora_active = True
            except Exception as e_set:
                print(f"[pipe] set_adapters not available ({e_set}); trying fuse_lora()")
                try:
                    pipe.fuse_lora(lora_scale=1.0)
                    print("[pipe] fuse_lora(lora_scale=1.0) done")
                    lora_active = True
                except Exception as e_fuse:
                    print(f"[pipe] fuse_lora failed: {e_fuse}")
            if lora_active:
                print("[pipe] LoRA ready ✅")
            else:
                print("[pipe] LoRA loaded, but set_adapters and fuse_lora both failed ⚠️")
        except ImportError:
            print("[pipe] peft not installed; LoRA skipped (add `peft>=0.11`).")
        except Exception as e:
            print(f"[pipe] load_lora_weights failed (continue without): {e}")
    else:
        print(f"[pipe] LoRA weights not found: {lora_dir / lora_file} (continue without)")

    _PIPE = pipe
    return pipe
| # ==== STAGE-2 ONLY ADDED: Stage-2 loader (no change to Stage-1 logic) ==== | |
def get_pipe_stage2() -> FluxFillPipelineStage2:
    """
    Load Stage-2 FluxFillPipeline_token12_depth and mount Stage-2 LoRA.

    The result is cached in the module-level ``_PIPE_STAGE2``. Unlike the
    Stage-1 loader, every failure here (base model, missing LoRA directory or
    weight file, missing ``peft``, adapter activation) raises ``RuntimeError``.
    """
    global _PIPE_STAGE2
    if _PIPE_STAGE2 is not None:
        return _PIPE_STAGE2

    has_cuda = torch.cuda.is_available()
    device = "cuda" if has_cuda else "cpu"
    dtype = torch.bfloat16 if has_cuda else torch.float32
    local_flux = BASE_DIR / "code_edit" / "flux_cache"
    use_local = local_flux.exists()
    hf_token = os.environ.get("HF_TOKEN")

    # Optional download acceleration; failures are intentionally ignored.
    try:
        from huggingface_hub import hf_hub_enable_hf_transfer
        hf_hub_enable_hf_transfer()
    except Exception:
        pass

    print(f"[stage2] loading FLUX.1-Fill-dev (dtype={dtype}, device={device}, local={use_local})")
    try:
        if use_local:
            model_source, auth_kwargs = local_flux, {}
        else:
            model_source, auth_kwargs = "black-forest-labs/FLUX.1-Fill-dev", {"token": hf_token}
        pipe2 = FluxFillPipelineStage2.from_pretrained(
            model_source, torch_dtype=dtype, **auth_kwargs
        ).to(device)
    except Exception as e:
        raise RuntimeError("Stage-2: Failed to load FLUX.1-Fill-dev.") from e

    # Load Stage-2 LoRA
    lora_dir2 = CODE_EDIT / "stage2" / "checkpoint-20000"
    candidate_names = [
        "pytorch_lora_weights.safetensors",
        "adapter_model.safetensors",
        "lora.safetensors",
    ]
    if not lora_dir2.exists():
        raise RuntimeError(f"Stage-2 LoRA dir not found: {lora_dir2}")
    # First candidate file name that exists in the checkpoint directory.
    weight_name = next(
        (name for name in candidate_names if (lora_dir2 / name).is_file()),
        None,
    )
    if weight_name is None:
        raise RuntimeError(
            f"Stage-2 LoRA weight not found under {lora_dir2}. Tried: {candidate_names}"
        )

    try:
        import peft  # noqa: F401
    except Exception as e:
        raise RuntimeError("peft is not installed (requires peft>=0.11).") from e

    try:
        print(f"[stage2] loading LoRA: {lora_dir2}/{weight_name}")
        pipe2.load_lora_weights(
            str(lora_dir2),
            weight_name=weight_name,
            adapter_name="stage2",
        )
        # Prefer activating the named adapter; fall back to fusing it.
        try:
            pipe2.set_adapters("stage2", scale=1.0)
            print("[stage2] set_adapters('stage2', 1.0)")
        except Exception as e_set:
            print(f"[stage2] set_adapters not available ({e_set}); trying fuse_lora()")
            try:
                pipe2.fuse_lora(lora_scale=1.0)
                print("[stage2] fuse_lora(lora_scale=1.0) done")
            except Exception as e_fuse:
                raise RuntimeError(f"Stage-2 fuse_lora failed: {e_fuse}") from e_fuse
    except Exception as e:
        raise RuntimeError(f"Stage-2 LoRA load failed: {e}") from e

    _PIPE_STAGE2 = pipe2
    return pipe2
| # ========================================================================== | |
| # ---------------- Mask helpers ---------------- | |
def to_grayscale_mask(im: Image.Image) -> Image.Image:
    """
    Reduce an image to a binary single-channel ("L") mask.

    For RGBA input the alpha channel is taken as the mask; any other mode is
    converted to grayscale. Pixels with value > 16 become 255, the rest 0.
    Output: white = region to remove/fill, black = keep (no inversion).
    """
    if im.mode == "RGBA":
        channel = im.getchannel("A")  # alpha as mask
    else:
        channel = im.convert("L")
    # Binarize with a 256-entry lookup table (threshold at 16 suppresses
    # near-black noise).
    threshold_lut = [255 if level > 16 else 0 for level in range(256)]
    return channel.point(threshold_lut)
def dilate_mask(mask_l: Image.Image, px: int) -> Image.Image:
    """Grow the white region of *mask_l* using repeated 3x3 dilation.

    ``px <= 0`` returns the input unchanged. Otherwise ``max(1, px // 2)``
    passes of a 3x3 all-ones kernel are applied (a heuristic mapping from the
    requested pixel amount to iteration count).
    """
    if px <= 0:
        return mask_l
    passes = max(1, int(px // 2))  # heuristic
    pixels = np.array(mask_l, dtype=np.uint8)
    grown = cv2.dilate(pixels, np.ones((3, 3), np.uint8), iterations=passes)
    return Image.fromarray(grown, mode="L")
def _mask_from_red(img: Image.Image, out_size: Tuple[int, int]) -> Image.Image:
    """
    Build a binary "L" mask from near-pure-red, non-transparent pixels.

    A pixel counts as brush when R >= 200, G <= 40, B <= 40 and alpha > 0
    (lenient thresholds to tolerate compression/resampling). Brush pixels
    become 255, everything else 0; the mask is then resized to *out_size*
    with nearest-neighbour sampling so it stays binary.
    """
    rgba = np.array(img.convert("RGBA"))
    red, green, blue, alpha = (rgba[..., channel] for channel in range(4))
    is_stroke = (red >= 200) & (green <= 40) & (blue <= 40) & (alpha > 0)
    binary = is_stroke.astype(np.uint8) * 255
    stroke_mask = Image.fromarray(binary, mode="L")
    return stroke_mask.resize(out_size, Image.NEAREST)
def pick_mask(
    upload_mask: Optional[Image.Image],
    sketch_data: Optional[dict],
    base_image: Image.Image,
    dilate_px: int = 0,
) -> Optional[Image.Image]:
    """
    Choose the removal mask, sized to ``base_image.size``.

    Selection rules:
    1) If a mask is uploaded: use it directly (white=mask)
    2) Else from ImageEditor output, only red strokes are recognized as mask:
       - Try sketch_data['mask'] first (some versions provide it)
       - Else merge red strokes from sketch_data['layers']; each layer may be
         a PIL image or a dict holding one under 'image' / 'mask'
       - If still none, try sketch_data['composite'] for red strokes
    3) Otherwise return None.

    When ``dilate_px > 0`` the chosen mask is passed through ``dilate_mask``.
    """
    # ImageChops provides the per-pixel maximum used to union layer masks.
    # Imported locally because the file's top-level PIL import does not
    # include it.
    from PIL import ImageChops

    def _finalize(mask: Image.Image) -> Image.Image:
        return dilate_mask(mask, dilate_px) if dilate_px > 0 else mask

    target_size = base_image.size

    # 1) Uploaded mask has highest priority
    if isinstance(upload_mask, Image.Image):
        return _finalize(to_grayscale_mask(upload_mask).resize(target_size, Image.NEAREST))

    # 2) Hand-drawn (ImageEditor)
    if isinstance(sketch_data, dict):
        # 2a) explicit mask (still supported)
        explicit = sketch_data.get("mask")
        if isinstance(explicit, Image.Image):
            return _finalize(to_grayscale_mask(explicit).resize(target_size, Image.NEAREST))

        # 2b) merge red strokes from layers
        layers = sketch_data.get("layers")
        if isinstance(layers, list) and layers:
            merged = Image.new("L", target_size, 0)
            for layer in layers:
                # Layers may arrive as PIL images directly or wrapped in a
                # dict; previously bare images were skipped, so drawn strokes
                # were only ever recovered from the composite.
                if isinstance(layer, dict):
                    layer = layer.get("image") or layer.get("mask")
                if isinstance(layer, Image.Image):
                    layer_mask = _mask_from_red(layer, target_size)
                    # Union of masks. ImageOps has no `lighter`; the
                    # per-pixel maximum lives in ImageChops.
                    merged = ImageChops.lighter(merged, layer_mask)
            if merged.getbbox() is not None:
                return _finalize(merged)

        # 2c) finally, search composite for red strokes
        composite = sketch_data.get("composite")
        if isinstance(composite, Image.Image):
            composite_mask = _mask_from_red(composite, target_size)
            if composite_mask.getbbox() is not None:
                return _finalize(composite_mask)

    # 3) No valid mask
    return None
def _round_mult64(x: float, mode: str = "nearest") -> int:
    """
    Align x to a multiple of 64:
    - mode="ceil" round up
    - mode="floor" round down
    - mode="nearest" nearest multiple (exact halves round up)

    Any other *mode* value is treated as "nearest".
    """
    if mode == "ceil":
        # Negated floor-division is a true ceiling for both ints and floats.
        # The former `(x + 63) // 64` was only correct for integers: it mapped
        # 64.5 to 64 and 0.5 to 0.
        return int(-(-x // 64)) * 64
    if mode == "floor":
        return int(x // 64) * 64
    # nearest
    return int((x + 32) // 64) * 64
def prepare_size_for_flux(img: Image.Image, target_max: int = 1024) -> tuple[int, int]:
    """
    Compute the (width, height) handed to the FLUX pipeline.

    Steps:
    1) Round w,h up to multiples of 64, with a floor of 64
    2) Set the longer (rounded) side to target_max (default 1024); ties count
       as width being the long side
    3) Scale the other side by the same factor and snap it to the nearest
       multiple of 64 (>= 64)
    """
    aligned_w, aligned_h = (
        max(64, _round_mult64(side, mode="ceil")) for side in img.size
    )
    if aligned_w >= aligned_h:
        short_side = aligned_h * (target_max / aligned_w)
        return int(target_max), int(max(64, _round_mult64(short_side, mode="nearest")))
    short_side = aligned_w * (target_max / aligned_h)
    return int(max(64, _round_mult64(short_side, mode="nearest"))), int(target_max)
| # ---------------- Preview depth for canvas (colored) ---------------- | |
def preview_depth(image: Optional[Image.Image], encoder: str, max_res: int, input_size: int, fp32: bool):
    """Return the depth model's non-grayscale output for *image*, or None when no image is given."""
    if image is not None:
        return get_model(encoder).infer(
            image=image,
            max_res=max_res,
            input_size=input_size,
            fp32=fp32,
            grayscale=False,
        )
    return None
def prepare_canvas(image, depth_img, source):
    """Load the sketch canvas with the depth preview (``source == "depth"``) or the input image.

    Raises:
        gr.Error: if the selected background is not available yet.
    """
    background = image if source != "depth" else depth_img
    if background is None:
        raise gr.Error('Please upload an image (and wait for the depth preview), then click "Prepare canvas".')
    return gr.update(value=background)
| # ---------------- Stage-1: depth(color) -> fill ---------------- | |
def run_depth_and_fill(
    image: Image.Image,
    mask_upload: Optional[Image.Image],
    sketch: Optional[dict],
    prompt: str,
    encoder: str,
    max_res: int,
    input_size: int,
    fp32: bool,
    max_side: int,
    mask_dilate_px: int,
    guidance_scale: float,
    steps: int,
    seed: Optional[int],
) -> Tuple[Image.Image, Image.Image]:
    """Stage-1: estimate colored depth, then run the FLUX fill pipeline on it.

    Returns:
        ``(result, mask_preview)`` — the pipeline output and the mask (as RGB),
        both resized to the uploaded image's size.

    Raises:
        gr.Error: if no image was uploaded or no non-empty mask was found.
    """
    if image is None:
        raise gr.Error("Please upload an image first.")

    # 1) produce a colored depth map (RGB)
    depth_rgb: Image.Image = (
        get_model(encoder)
        .infer(image=image, max_res=max_res, input_size=input_size, fp32=fp32, grayscale=False)
        .convert("RGB")
    )
    print(f"[DEBUG] Depth RGB: mode={depth_rgb.mode}, size={depth_rgb.size}")

    # 2) extract mask (uploaded > drawn)
    mask_l = pick_mask(mask_upload, sketch, image, dilate_px=mask_dilate_px)
    if mask_l is None or mask_l.getbbox() is None:
        raise gr.Error("No valid mask detected: please draw with the red brush or upload a binary mask.")
    print(f"[DEBUG] Mask: mode={mask_l.mode}, size={mask_l.size}, bbox={mask_l.getbbox()}")

    # 3) decide output size
    width, height = prepare_size_for_flux(depth_rgb, target_max=max_side)
    orig_w, orig_h = image.size
    print(f"[DEBUG] FLUX size: {width}x{height}, original: {orig_w}x{orig_h}")

    # 4) run FLUX pipeline (key: use depth_rgb as both image and depth input)
    pipe = get_pipe()
    # A missing or negative seed means "pick one at random".
    has_fixed_seed = seed is not None and seed >= 0
    seed_value = int(seed) if has_fixed_seed else random.randint(0, 2**31 - 1)
    generator = torch.Generator("cpu").manual_seed(seed_value)

    pipe_output = pipe(
        prompt=prompt,
        image=depth_rgb,  # use the colored depth map instead of original image
        mask_image=mask_l,
        width=width,
        height=height,
        guidance_scale=float(guidance_scale),
        num_inference_steps=int(steps),
        max_sequence_length=512,
        generator=generator,
        depth=depth_rgb,  # feed depth (colored)
    )
    original_size = (orig_w, orig_h)
    final_result = pipe_output.images[0].resize(original_size, Image.BICUBIC)

    # return result and mask preview
    mask_preview = mask_l.resize(original_size, Image.NEAREST).convert("RGB")
    return final_result, mask_preview
def _to_pil_rgb(img_like) -> Image.Image:
    """Normalize input to PIL RGB. Supports PIL/L/RGBA/np.array.

    PIL images are converted with ``convert("RGB")``. Anything else is turned
    into a NumPy array and must end up as H x W (grayscale), H x W x 1,
    H x W x 3 or H x W x 4; values are cast to ``uint8``.

    Raises:
        gr.Error: if the input cannot be interpreted as an image (including
            ``None``); the underlying exception is attached as the cause.
    """
    if isinstance(img_like, Image.Image):
        return img_like.convert("RGB")
    try:
        arr = np.asarray(img_like)
        if arr.ndim == 3 and arr.shape[-1] == 1:
            # Single-channel H x W x 1 -> plain H x W grayscale.
            arr = arr[..., 0]
        if arr.ndim == 2:
            arr = np.stack([arr, arr, arr], axis=-1)
        elif arr.ndim == 3 and arr.shape[-1] == 4:
            # RGBA array: drop alpha. Previously 4-channel arrays were
            # rejected despite the documented RGBA support.
            arr = arr[..., :3]
        if arr.ndim != 3 or arr.shape[-1] != 3:
            raise ValueError(f"unsupported image array shape: {arr.shape}")
        return Image.fromarray(arr.astype(np.uint8), mode="RGB")
    except Exception as exc:
        raise gr.Error("Stage-2: `depth` / `depth_image` is not a valid image object.") from exc
| # ---------------- Stage-2: REQUIRED refine/render ---------------- | |
def run_stage2_refine(
    image: Image.Image,  # original image (RGB)
    stage1_out: Image.Image,  # output from Stage-1
    depth_img_from_stage1_input: Image.Image,  # Stage-1 depth preview (from UI)
    mask_upload: Optional[Image.Image],
    sketch: Optional[dict],
    prompt: str,
    encoder: str,
    max_res: int,
    input_size: int,
    fp32: bool,
    max_side: int,
    guidance_scale: float,
    steps: int,
    seed: Optional[int],
) -> Image.Image:
    """Stage-2: render the final image from the original, Stage-1 output and depth preview.

    ``encoder``, ``max_res``, ``input_size`` and ``fp32`` are accepted for
    signature parity with Stage-1 but are not used here.

    Returns:
        The pipeline output resized to ``(3 * original_width, original_height)``.

    Raises:
        gr.Error: if the original image or Stage-1 output is missing, or if
            one of the image inputs cannot be converted to RGB.
    """
    if image is None or stage1_out is None:
        raise gr.Error("Please complete Stage-1 first (needs original image and Stage-1 output).")

    # Allow refine without mask (use all-black)
    mask_l = pick_mask(mask_upload, sketch, image, dilate_px=0)
    if (mask_l is None) or (mask_l.getbbox() is None):
        mask_l = Image.new("L", image.size, 0)

    # Unify sizes
    width, height = prepare_size_for_flux(image, target_max=max_side)
    orig_w, orig_h = image.size

    # Normalize and validate all image inputs BEFORE loading the Stage-2
    # pipeline, so an invalid input fails fast instead of after an expensive
    # model load.
    depth_pil = _to_pil_rgb(stage1_out)  # for `depth`
    depth_image_pil = _to_pil_rgb(depth_img_from_stage1_input)  # for `depth_image`
    image_rgb = _to_pil_rgb(image)

    # Resize to (width, height)
    depth_pil = depth_pil.resize((width, height), Image.BICUBIC)
    depth_image_pil = depth_image_pil.resize((width, height), Image.BICUBIC)

    pipe2 = get_pipe_stage2()
    # A missing or negative seed means "pick one at random".
    g2 = (
        torch.Generator("cpu").manual_seed(int(seed))
        if (seed is not None and seed >= 0)
        else torch.Generator("cpu").manual_seed(random.randint(0, 2**31 - 1))
    )

    # Mapping:
    #   image       = original RGB
    #   depth       = Stage-1 output (updated geometry)
    #   depth_image = Stage-1 input depth (UI depth preview)
    out2 = pipe2(
        prompt=prompt,
        # Pass the RGB-normalized copy; it was previously computed but the raw
        # `image` was handed to the pipeline instead.
        image=image_rgb,
        mask_image=mask_l,
        width=width,
        height=height,
        guidance_scale=float(guidance_scale),
        num_inference_steps=int(steps),
        max_sequence_length=512,
        generator=g2,
        depth=depth_pil,
        depth_image=depth_image_pil,
    ).images[0]

    # keep your 3× showcase layout
    # NOTE(review): presumably the Stage-2 pipeline returns three panels side
    # by side — confirm against the pipeline implementation.
    out2 = out2.resize((orig_w * 3, orig_h), Image.BICUBIC)
    return out2
| # ---------------- UI ---------------- | |
# Two-column layout: inputs, mask tools and tunables on the left; depth
# preview, mask preview and both stage outputs on the right. Event wiring is
# at the bottom of the block.
with gr.Blocks() as demo:
    gr.Markdown(
        """
        # GeoRemover · Depth-Guided Object Removal (Two-Stage, Stage-2 REQUIRED)
        **Pipeline overview**
        1) Compute a **colored depth map** from your input image.
        2) You create a **removal mask** (red brush or upload).
        3) **Stage-1** runs FLUX Fill with depth guidance to get a first pass.
        4) **Stage-2 (REQUIRED)** renders the final result from depth → image using Stage-1 output and the original depth.
        > ⚠️ **Stage-2 is required.** Always click **Run Stage-2 (Render)** *after* Stage-1 finishes. Stage-1 alone is not the final output.
        ---
        ### Quick start
        1. **Upload image** (left). Wait for **Depth preview (colored)** (right).
        2. In **Draw mask**, pick **Draw on: _image_** or **_depth_**, then click **Prepare canvas**.
        3. Paint the region to remove using the **red brush** (**red = remove**).
        4. Optionally adjust **Mask dilation** for thin edges.
        5. Enter a concise **Prompt** describing the fill content.
        6. Click **Run** → produces **Stage-1** (first pass).
        7. Click **Run Stage-2 (Render)** → produces the **final** result.
        ---
        ### Mask rules & tips
        - Only **red strokes** are treated as mask (**white = remove, black = keep** internally).
        - Paint **slightly larger** than the object boundary to avoid seams/halos.
        - If you have a binary mask already, use **Upload mask**.
        - **Mask dilation (px)** expands the mask to cover thin borders.
        """
    )
    with gr.Row():
        with gr.Column(scale=1):
            # Input image
            img = gr.Image(
                label="Upload image",
                type="pil",
            )
            # Mask: upload or draw
            with gr.Tab("Upload mask"):
                mask_upload = gr.Image(
                    label="Mask (optional)",
                    type="pil",
                )
            with gr.Tab("Draw mask"):
                draw_source = gr.Radio(
                    ["image", "depth"],
                    value="image",
                    label="Draw on",
                )
                prepare_btn = gr.Button("Prepare canvas", variant="secondary")
                gr.Markdown(
                    """
                    **Canvas usage**
                    - Click **Prepare canvas** after selecting *image* or *depth*.
                    - Use the **red brush** only—red strokes are extracted as the removal mask.
                    - Switch tabs anytime if you prefer uploading a ready-made mask.
                    """
                )
                sketch = gr.ImageEditor(
                    label="Sketch mask (red = remove)",
                    type="pil",
                    brush=gr.Brush(colors=["#FF0000"], default_size=24),
                )
            # Prompt
            prompt = gr.Textbox(
                label="Prompt",
                value="A beautiful scene",
                placeholder="don't change it",
            )
            # Tunables
            with gr.Accordion("Advanced (Depth & FLUX)", open=False):
                encoder = gr.Dropdown(
                    ["vits", "vitl"],
                    value="vitl",
                    label="Depth encoder",
                )
                max_res = gr.Slider(
                    512, 2048, value=1280, step=64,
                    label="Depth: max_res",
                )
                input_size = gr.Slider(
                    256, 1024, value=518, step=2,
                    label="Depth: input_size",
                )
                fp32 = gr.Checkbox(
                    False,
                    label="Depth: use FP32 (default FP16)",
                )
                max_side = gr.Slider(
                    512, 1536, value=1024, step=64,
                    label="FLUX: max side (px)",
                )
                mask_dilate_px = gr.Slider(
                    0, 128, value=0, step=1,
                    label="Mask dilation (px)",
                )
                guidance_scale = gr.Slider(
                    0, 50, value=30, step=0.5,
                    label="FLUX: guidance_scale",
                )
                steps = gr.Slider(
                    10, 75, value=50, step=1,
                    label="FLUX: steps",
                )
                seed = gr.Number(
                    value=0, precision=0,
                    label="Seed (>=0 = fixed; empty = random)",
                )
            run_btn = gr.Button("Run", variant="primary")
            # Stage-2 is REQUIRED: keep disabled until Stage-1 finishes
            run_btn_stage2 = gr.Button("Run Stage-2 (Render)", variant="secondary", interactive=False)
        with gr.Column(scale=1):
            depth_preview = gr.Image(
                label="Depth preview (colored)",
                interactive=False,
            )
            mask_preview = gr.Image(
                label="Mask preview (areas to remove)",
                interactive=False,
            )
            out = gr.Image(
                label="Output (Stage-1 first pass)",
            )
            out_stage2 = gr.Image(
                label="Final Output (Stage-2)",
            )
    gr.Markdown(
        """
        ### Why Stage-2 is required
        Stage-1 provides a depth-guided fill that is *not final*. **Stage-2 renders** the definitive image by leveraging:
        - **Stage-1 output** as updated geometry hints, and
        - **Original colored depth** as `depth_image` guidance.
        Skipping Stage-2 will leave the process incomplete.
        ### Troubleshooting
        - **“No valid mask detected”**: Either upload a binary mask (white=remove) **or** draw with **red brush** after clicking **Prepare canvas**.
        - **Seams/halos**: Increase **Mask dilation (px)** (e.g., 8–16) and re-run both stages.
        - **Prompt not followed**: Lower **guidance_scale** (e.g., 18–24) and make the prompt more concrete.
        - **Depth looks noisy**: Use **vitl**, increase **Depth: max_res**, or enable **FP32**.
        """
    )

    # ===== Helpers to toggle Stage-2 button =====
    def _enable_button():
        """Return an update that makes the target button clickable."""
        return gr.update(interactive=True)

    # Auto depth preview on image change
    img.change(
        fn=preview_depth,
        inputs=[img, encoder, max_res, input_size, fp32],
        outputs=[depth_preview],
    )
    # Prepare canvas for drawing on image or depth
    prepare_btn.click(
        fn=prepare_canvas,
        inputs=[img, depth_preview, draw_source],
        outputs=[sketch],
    )
    # Stage-1
    run_btn.click(
        fn=run_depth_and_fill,
        inputs=[img, mask_upload, sketch, prompt, encoder, max_res, input_size, fp32,
                max_side, mask_dilate_px, guidance_scale, steps, seed],
        outputs=[out, mask_preview],
        api_name="run",
    ).success(  # Enable Stage-2 only after Stage-1 completes without error
        # `.then()` would also fire when Stage-1 raised (e.g. "No valid mask
        # detected"), unlocking Stage-2 with no Stage-1 output to refine.
        fn=_enable_button,
        inputs=[],
        outputs=[run_btn_stage2],
    )
    # Stage-2 (REQUIRED; unlocked after Stage-1)
    run_btn_stage2.click(
        fn=run_stage2_refine,
        inputs=[img, out, depth_preview,
                mask_upload, sketch, prompt, encoder, max_res, input_size, fp32,
                max_side, guidance_scale, steps, seed],
        outputs=[out_stage2],
        api_name="run_stage2",
    )
if __name__ == "__main__":
    # Disable Hugging Face Hub telemetry unless the environment already
    # sets a value for it.
    os.environ.setdefault("HF_HUB_DISABLE_TELEMETRY", "1")
    # Listen on all interfaces, port 7860.
    launch_options = {"server_name": "0.0.0.0", "server_port": 7860}
    demo.launch(**launch_options)