#!/usr/bin/env python
"""Gradio demo for PartEdit on SDXL (Hugging Face ZeroGPU-ready)."""

import base64
import os
import pathlib
import random
import tempfile
import uuid
from io import BytesIO
from typing import List, Optional, Tuple, Union

import gradio as gr
import numpy as np
import PIL.Image
import spaces  # ZeroGPU support: provides the @spaces.GPU decorator on Spaces
import torch
from datasets import load_dataset

from model import PART_TOKENS, PartEditSDXLModel

# Largest value usable as a reproducible RNG seed (fits in a signed int32).
MAX_SEED = np.iinfo(np.int32).max
# Pre-render the example gallery only when explicitly enabled via the env.
CACHE_EXAMPLES = os.environ.get("CACHE_EXAMPLES") == "1"
# Part tokens the model knows how to localize and edit.
AVAILABLE_TOKENS = list(PART_TOKENS.keys())

# Examples are downloaded from the Hugging Face PartEdit-Bench dataset.
# Login using e.g. `huggingface-cli login` or `hf login` if needed.
bench = load_dataset("Aleksandar/PartEdit-Bench", revision="v1.1", split="synth", cache_dir="~/.cache/huggingface/hub") use_examples = None # all with None logo = "assets/partedit.png" loaded_logo = PIL.Image.open(logo).convert("RGB") # base encoded logo logo_encoded = None with open(logo, "rb") as f: logo_encoded = base64.b64encode(f.read()).decode() def _save_image_for_download(edited: Union[PIL.Image.Image, np.ndarray, str, List]) -> str: item = edited[0] if isinstance(edited, list) else edited # pick first if isinstance(item, str): return item # already a path if isinstance(item, np.ndarray): item = PIL.Image.fromarray(item) assert isinstance(item, PIL.Image.Image), "Edited output must be PIL, ndarray, str path, or list of these." out_path = os.path.join(tempfile.gettempdir(), f"partedit_{uuid.uuid4().hex}.png") item.save(out_path) return out_path def get_example(idx, bench): # [prompt_original, subject, token_cls, edit, "", 50, 7.5, seed, 50] example = bench[idx] return [ example["prompt_original"], example["subject"], example["token_cls"], example["edit"], "", 50, 7.5, example["seed"], 50, ] examples = [get_example(idx, bench) for idx in (use_examples if use_examples is not None else range(len(bench)))] first_ex = examples[0] if len(examples) else ["", "", AVAILABLE_TOKENS[0], "", "", 50, 7.5, 0, 50] title = f"""
Official demo for the PartEdit paper.
It simultaneously predicts the part-localization mask and edits the original trajectory. Supports Hugging Face ZeroGPU and one-click Duplicate for private use.
Running on CPU 🥶 This demo does not work on CPU. On ZeroGPU Spaces, a GPU will be requested when you click Apply Edit.
" def running_in_hf_space() -> bool: # Common env vars present on Hugging Face Spaces return ( os.getenv("SYSTEM") == "spaces" or any(os.getenv(k) for k in ( "SPACE_ID", "HF_SPACE_ID", "SPACE_REPO_ID", "SPACE_REPO_NAME", "SPACE_AUTHOR_NAME", "SPACE_TITLE" )) ) if __name__ == "__main__": model = PartEditSDXLModel() with gr.Blocks(css="style.css") as demo: gr.Markdown(DESCRIPTION) # Always show Duplicate button on Spaces gr.DuplicateButton( value="Duplicate Space for private use", elem_id="duplicate-button", variant="huggingface", size="lg", visible=running_in_hf_space(), ) # Single tab: PartEdit only with gr.Tabs(): with gr.Tab(label="PartEdit", id="edit"): edit_demo(model) demo.queue(max_size=20).launch(server_name="0.0.0.0")