import os
import uuid
import json
import time
import shutil
import numpy as np
import random
import tempfile
import zipfile
from typing import Any, Dict, List, Optional, Tuple

import spaces
import torch
import gradio as gr
from PIL import Image
from diffusers import QwenImageLayeredPipeline
from pptx import Presentation

LOG_DIR = "/tmp/local"
MAX_SEED = np.iinfo(np.int32).max

# Optional HF login (works in Spaces if you set the HF token as secret env var "hf")
from huggingface_hub import login, HfApi, hf_hub_download
if os.environ.get("hf"):
    login(token=os.environ.get("hf"))
dtype = torch.bfloat16
device = "cuda" if torch.cuda.is_available() else "cpu"

pipeline = QwenImageLayeredPipeline.from_pretrained(
    "Qwen/Qwen-Image-Layered", torch_dtype=dtype
).to(device)
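# Note: the pipeline is loaded once at module import (in bfloat16) and the same
# instance is reused for every request; if CUDA is unavailable it simply stays
# on CPU.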
# ----------------------------
# Dataset repo persistence (no /data needed)
# ----------------------------
HF_TOKEN = os.environ.get("hf")  # secret
HF_DATASET_REPO = os.environ.get("HF_DATASET_REPO")  # e.g. "hexware/qwen-layered-sessions"

_hf_api: Optional[HfApi] = None
_persist_enabled = False


def _init_dataset_repo() -> Tuple[bool, str]:
    """
    Returns (enabled, message).
    """
    global _hf_api, _persist_enabled
    if not HF_TOKEN:
        _persist_enabled = False
        return False, "Persistence: disabled (no secret env var 'hf')."
    if not HF_DATASET_REPO:
        _persist_enabled = False
        return False, "Persistence: disabled (set env var HF_DATASET_REPO to enable)."
    try:
        _hf_api = HfApi(token=HF_TOKEN)
        # Create the dataset repo if missing (private). If it exists, this is a no-op.
        # NOTE: create_repo is available via HfApi in most versions.
        _hf_api.create_repo(
            repo_id=HF_DATASET_REPO,
            repo_type="dataset",
            private=True,
            exist_ok=True,
        )
        _persist_enabled = True
        return True, f"Persistence: enabled (dataset repo: {HF_DATASET_REPO})."
    except Exception as e:
        _persist_enabled = False
        return False, f"Persistence: failed to init dataset repo: {type(e).__name__}: {e}"


_enabled, _enabled_msg = _init_dataset_repo()
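# Configuration sketch: persistence only turns on when both values above are
# present, e.g. (hypothetical values, set as a Space secret/variable):
#   hf              = <HF token with write access>
#   HF_DATASET_REPO = "your-username/qwen-layered-sessions"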
# ----------------------------
# Helpers
# ----------------------------
def ensure_dirname(path: str):
    if path and not os.path.exists(path):
        os.makedirs(path, exist_ok=True)


def random_str(length=8):
    return uuid.uuid4().hex[:length]


def _now_ts() -> float:
    return time.time()


def _clamp_int(x, default: int, lo: int, hi: int) -> int:
    try:
        v = int(x)
    except Exception:
        v = default
    return max(lo, min(hi, v))


def _pil_rgba(input_image) -> Image.Image:
    if isinstance(input_image, list):
        input_image = input_image[0]
    if isinstance(input_image, str):
        pil_image = Image.open(input_image).convert("RGB").convert("RGBA")
    elif isinstance(input_image, Image.Image):
        pil_image = input_image.convert("RGB").convert("RGBA")
    elif isinstance(input_image, np.ndarray):
        pil_image = Image.fromarray(input_image).convert("RGB").convert("RGBA")
    else:
        raise ValueError(f"Unsupported input_image type: {type(input_image)}")
    return pil_image
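# _pil_rgba accepts a file path, a PIL image, a numpy array, or a list of any
# of these (first element wins). It converts through "RGB" before "RGBA", so
# any alpha channel in the input is discarded and the pipeline always receives
# a fully opaque image.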
def imagelist_to_pptx(img_files: List[str]) -> str:
    with Image.open(img_files[0]) as img:
        img_width_px, img_height_px = img.size

    def px_to_emu(px, dpi=96):
        inch = px / dpi
        emu = inch * 914400
        return int(emu)

    prs = Presentation()
    prs.slide_width = px_to_emu(img_width_px)
    prs.slide_height = px_to_emu(img_height_px)
    slide = prs.slides.add_slide(prs.slide_layouts[6])
    left = top = 0
    for img_path in img_files:
        slide.shapes.add_picture(
            img_path,
            left,
            top,
            width=px_to_emu(img_width_px),
            height=px_to_emu(img_height_px),
        )
    with tempfile.NamedTemporaryFile(suffix=".pptx", delete=False) as tmp:
        prs.save(tmp.name)
        return tmp.name
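# px_to_emu maps pixels to PowerPoint EMUs at 96 dpi (1 inch = 914,400 EMU).
# For example, a 1024 px wide layer becomes 1024 / 96 * 914400 = 9,753,600 EMU,
# so the slide dimensions match the source image's pixel size.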
def export_node_layers(layers: List[Image.Image]) -> Tuple[str, str]:
    """
    Returns (pptx_path, zip_path).
    """
    temp_files: List[str] = []
    for img in layers:
        tmp = tempfile.NamedTemporaryFile(suffix=".png", delete=False)
        img.save(tmp.name)
        temp_files.append(tmp.name)
    pptx_path = imagelist_to_pptx(temp_files)
    with tempfile.NamedTemporaryFile(suffix=".zip", delete=False) as tmpzip:
        with zipfile.ZipFile(tmpzip.name, "w", zipfile.ZIP_DEFLATED) as zipf:
            for i, img_path in enumerate(temp_files):
                zipf.write(img_path, f"layer_{i+1}.png")
        zip_path = tmpzip.name
    return pptx_path, zip_path
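# Usage sketch: given `layers` (a list of PIL images, as stored in a node),
#   pptx_path, zip_path = export_node_layers(layers)
# produces a PPTX with all layers stacked on one slide and a ZIP containing
# layer_1.png, layer_2.png, ...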
# ----------------------------
# ZeroGPU duration
# ----------------------------
def get_duration(
    input_image,
    seed=777,
    randomize_seed=False,
    prompt=None,
    neg_prompt=" ",
    true_guidance_scale=4.0,
    num_inference_steps=50,
    layer=4,
    cfg_norm=True,
    use_en_prompt=True,
    resolution=640,
    gpu_duration=1000,
):
    return _clamp_int(gpu_duration, default=1000, lo=20, hi=1500)
# ----------------------------
# GPU pipeline runners
# ----------------------------
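# ZeroGPU: reserve a GPU for each call. Passing `get_duration` (which mirrors
# the runner's signature) lets the user-supplied `gpu_duration` field size the
# reservation, clamped to 20..1500 s. This assumes the standard `spaces.GPU`
# dynamic-duration pattern of the `spaces` package on ZeroGPU Spaces.
@spaces.GPU(duration=get_duration)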
def gpu_run_pipeline(
    input_pil_image: Image.Image,
    seed=777,
    randomize_seed=False,
    prompt=None,
    neg_prompt=" ",
    true_guidance_scale=4.0,
    num_inference_steps=50,
    layer=4,
    cfg_norm=True,
    use_en_prompt=True,
    resolution=640,
    gpu_duration=1000,
) -> List[Image.Image]:
    # Seed
    if randomize_seed:
        seed = random.randint(0, MAX_SEED)
    # Normalize resolution input
    resolution = _clamp_int(resolution, default=640, lo=640, hi=1024)
    if resolution not in (640, 1024):
        resolution = 640
    gen_device = "cuda" if torch.cuda.is_available() else "cpu"
    inputs = {
        "image": input_pil_image,
        "generator": torch.Generator(device=gen_device).manual_seed(seed),
        "true_cfg_scale": true_guidance_scale,
        "prompt": prompt,
        "negative_prompt": neg_prompt,
        "num_inference_steps": num_inference_steps,
        "num_images_per_prompt": 1,
        "layers": layer,
        "resolution": resolution,
        "cfg_normalize": cfg_norm,
        "use_en_prompt": use_en_prompt,
    }
    with torch.inference_mode():
        out = pipeline(**inputs)
    # out.images[0] => list of PIL images
    return out.images[0]
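# Call sketch (outside the UI), assuming `img` is an RGBA PIL image:
#   layers = gpu_run_pipeline(img, seed=777, layer=4, resolution=640)
#   # -> list of PIL images, one per decomposed layer.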
# ----------------------------
# Session / History model
# ----------------------------
def new_state() -> Dict[str, Any]:
    sid = uuid.uuid4().hex
    return {
        "session_id": sid,
        "nodes": {},  # node_id -> node dict
        "root_id": None,
        "current_id": None,
        "selected_layer_idx": 0,
        "last_refined_id": None,
    }


def _node_label(node: Dict[str, Any]) -> str:
    name = node.get("name") or node["id"][:8]
    return f"{name} ({node['id'][:8]})"


def _build_history_choices(st: Dict[str, Any]) -> List[Tuple[str, str]]:
    # Returns a list of (label, value=node_id) pairs.
    out = []
    for nid, node in st["nodes"].items():
        out.append((_node_label(node), nid))
    # Stable order by creation time.
    out.sort(key=lambda x: st["nodes"][x[1]].get("created_at", 0.0))
    return out


def _get_node(st: Dict[str, Any], node_id: Optional[str]) -> Optional[Dict[str, Any]]:
    if not node_id:
        return None
    return st["nodes"].get(node_id)


def _current_node(st: Dict[str, Any]) -> Optional[Dict[str, Any]]:
    return _get_node(st, st.get("current_id"))


def _chips_text(st: Dict[str, Any], node_id: Optional[str]) -> str:
    node = _get_node(st, node_id)
    if not node:
        return ""
    chips = []
    if node_id == st.get("root_id"):
        chips.append("[root]")
    if node.get("parent_id"):
        chips.append("[parent]")
    children = node.get("children_ids") or []
    if children:
        chips.append(f"[children:{len(children)}]")
    return " ".join(chips)


def _make_node(
    st: Dict[str, Any],
    layers: List[Image.Image],
    parent_id: Optional[str],
    name: Optional[str] = None,
) -> str:
    nid = uuid.uuid4().hex
    node = {
        "id": nid,
        "name": name or ("root" if parent_id is None else "refine"),
        "parent_id": parent_id,
        "children_ids": [],
        "created_at": _now_ts(),
        "layers": layers,
    }
    st["nodes"][nid] = node
    if parent_id:
        parent = st["nodes"].get(parent_id)
        if parent is not None:
            parent.setdefault("children_ids", []).append(nid)
    return nid


def _set_current(st: Dict[str, Any], node_id: str):
    st["current_id"] = node_id
    st["selected_layer_idx"] = 0
# ----------------------------
# Persistence: save/load the whole session as one zip in the dataset repo
# ----------------------------
def _serialize_session_to_zip(st: Dict[str, Any]) -> str:
    """
    Create a zip file with:
        session.json
        nodes/<node_id>/layer_1.png ...
    Returns the local zip path.
    """
    tmpdir = tempfile.mkdtemp(prefix="sess_")
    try:
        sess_meta = {
            "session_id": st["session_id"],
            "root_id": st["root_id"],
            "current_id": st["current_id"],
            "selected_layer_idx": st.get("selected_layer_idx", 0),
            "last_refined_id": st.get("last_refined_id"),
            "nodes": {},
        }
        for nid, node in st["nodes"].items():
            node_dir = os.path.join(tmpdir, "nodes", nid)
            os.makedirs(node_dir, exist_ok=True)
            layers: List[Image.Image] = node.get("layers") or []
            for i, img in enumerate(layers):
                img_path = os.path.join(node_dir, f"layer_{i+1}.png")
                img.save(img_path)
            sess_meta["nodes"][nid] = {
                "id": nid,
                "name": node.get("name"),
                "parent_id": node.get("parent_id"),
                "children_ids": node.get("children_ids") or [],
                "created_at": node.get("created_at", 0.0),
                "layer_count": len(layers),
            }
        meta_path = os.path.join(tmpdir, "session.json")
        with open(meta_path, "w", encoding="utf-8") as f:
            json.dump(sess_meta, f, ensure_ascii=False, indent=2)
        out_zip = tempfile.NamedTemporaryFile(suffix=".zip", delete=False).name
        with zipfile.ZipFile(out_zip, "w", zipfile.ZIP_DEFLATED) as zf:
            for root, _, files in os.walk(tmpdir):
                for fn in files:
                    abs_path = os.path.join(root, fn)
                    rel_path = os.path.relpath(abs_path, tmpdir)
                    zf.write(abs_path, rel_path)
        return out_zip
    finally:
        shutil.rmtree(tmpdir, ignore_errors=True)
def _deserialize_session_from_zip(zip_path: str) -> Dict[str, Any]:
    tmpdir = tempfile.mkdtemp(prefix="sess_load_")
    try:
        with zipfile.ZipFile(zip_path, "r") as zf:
            zf.extractall(tmpdir)
        meta_path = os.path.join(tmpdir, "session.json")
        with open(meta_path, "r", encoding="utf-8") as f:
            meta = json.load(f)
        st = new_state()
        st["session_id"] = meta["session_id"]
        st["root_id"] = meta.get("root_id")
        st["current_id"] = meta.get("current_id")
        st["selected_layer_idx"] = meta.get("selected_layer_idx", 0)
        st["last_refined_id"] = meta.get("last_refined_id")
        nodes_meta: Dict[str, Any] = meta.get("nodes", {})
        # First pass: create node shells.
        for nid, nm in nodes_meta.items():
            st["nodes"][nid] = {
                "id": nid,
                "name": nm.get("name"),
                "parent_id": nm.get("parent_id"),
                "children_ids": nm.get("children_ids") or [],
                "created_at": nm.get("created_at", 0.0),
                "layers": [],
            }
        # Second pass: load the layer images.
        for nid, nm in nodes_meta.items():
            layer_count = int(nm.get("layer_count", 0))
            node_dir = os.path.join(tmpdir, "nodes", nid)
            layers: List[Image.Image] = []
            for i in range(layer_count):
                p = os.path.join(node_dir, f"layer_{i+1}.png")
                if os.path.exists(p):
                    layers.append(Image.open(p).convert("RGBA"))
            st["nodes"][nid]["layers"] = layers
        return st
    finally:
        shutil.rmtree(tmpdir, ignore_errors=True)
def save_session_to_hub(st: Dict[str, Any]) -> Tuple[str, str]:
    """
    Returns (status_text, session_id).
    """
    if not _persist_enabled or _hf_api is None:
        return "Save: disabled (set HF_DATASET_REPO and secret hf write token).", st.get("session_id", "")
    zip_path = None
    try:
        zip_path = _serialize_session_to_zip(st)
        path_in_repo = f"sessions/{st['session_id']}.zip"
        _hf_api.upload_file(
            path_or_fileobj=zip_path,
            path_in_repo=path_in_repo,
            repo_id=HF_DATASET_REPO,
            repo_type="dataset",
            commit_message=f"Save session {st['session_id']}",
        )
        return f"Saved to dataset repo: {path_in_repo}", st["session_id"]
    except Exception as e:
        return f"Save failed: {type(e).__name__}: {e}", st.get("session_id", "")
    finally:
        try:
            if zip_path and os.path.exists(zip_path):
                os.remove(zip_path)
        except Exception:
            pass


def load_session_from_hub(session_id: str) -> Tuple[Optional[Dict[str, Any]], str]:
    if not _persist_enabled:
        return None, "Load: disabled (set HF_DATASET_REPO and secret hf write token)."
    session_id = (session_id or "").strip()
    if not session_id:
        return None, "Load: please enter a Session ID."
    try:
        filename = f"sessions/{session_id}.zip"
        local_zip = hf_hub_download(
            repo_id=HF_DATASET_REPO,
            repo_type="dataset",
            filename=filename,
            token=HF_TOKEN,
        )
        st = _deserialize_session_from_zip(local_zip)
        return st, f"Loaded session: {session_id}"
    except Exception as e:
        return None, f"Load failed: {type(e).__name__}: {e}"
# ----------------------------
# UI Callbacks
# ----------------------------
def ui_boot() -> Tuple[str, Dict[str, Any]]:
    ensure_dirname(LOG_DIR)
    st = new_state()
    return _enabled_msg, st


def on_new_session(st: Dict[str, Any]) -> Tuple[Dict[str, Any], str, gr.Dropdown, List[Image.Image], List[Image.Image], str, str, str, Optional[str], Optional[str]]:
    st = new_state()
    return (
        st,
        st["session_id"],
        gr.Dropdown(choices=[], value=None),
        [],
        [],
        "",
        "",
        "",
        None,
        None,
    )


def _render_from_state(st: Dict[str, Any]) -> Tuple[
    gr.Dropdown,
    List[Image.Image],
    List[Image.Image],
    gr.Number,
    str,
]:
    choices = _build_history_choices(st)
    current = _current_node(st)
    layers = current["layers"] if current else []
    idx = st.get("selected_layer_idx", 0)
    if layers:
        idx = max(0, min(idx, len(layers) - 1))
    st["selected_layer_idx"] = idx
    chips = _chips_text(st, st.get("current_id"))
    return (
        gr.Dropdown(choices=choices, value=st.get("current_id")),
        layers,
        layers,  # mini gallery mirrors the current layers
        idx,
        chips,
    )
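# Most callbacks below forward this 5-tuple straight to the UI:
# (history dropdown update, layer gallery, mini picker gallery,
#  selected layer index, chips markdown for the current node).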
def on_history_select(node_id: str, st: Dict[str, Any]) -> Tuple[Dict[str, Any], gr.Dropdown, List[Image.Image], List[Image.Image], gr.Number, str]:
    if node_id and node_id in st["nodes"]:
        st["current_id"] = node_id
        st["selected_layer_idx"] = 0
    dd, layers, mini, idx, chips = _render_from_state(st)
    return st, dd, layers, mini, idx, chips


def on_layer_gallery_select(evt: gr.SelectData, st: Dict[str, Any]) -> Tuple[Dict[str, Any], gr.Number]:
    # evt.index is an int for Gallery selections.
    idx = int(evt.index) if evt and evt.index is not None else 0
    current = _current_node(st)
    if current:
        layers = current.get("layers") or []
        if layers:
            idx = max(0, min(idx, len(layers) - 1))
        else:
            idx = 0
    else:
        idx = 0
    st["selected_layer_idx"] = idx
    return st, idx


def on_back_to_parent(st: Dict[str, Any]) -> Tuple[Dict[str, Any], gr.Dropdown, List[Image.Image], List[Image.Image], gr.Number, str]:
    cur = _current_node(st)
    if cur and cur.get("parent_id"):
        st["current_id"] = cur["parent_id"]
        st["selected_layer_idx"] = 0
    dd, layers, mini, idx, chips = _render_from_state(st)
    return st, dd, layers, mini, idx, chips


def on_duplicate_node(st: Dict[str, Any]) -> Tuple[Dict[str, Any], gr.Dropdown, List[Image.Image], List[Image.Image], gr.Number, str]:
    cur = _current_node(st)
    if cur:
        # Duplicate the current node as a sibling (same parent).
        layers = cur.get("layers") or []
        parent_id = cur.get("parent_id")
        name = (cur.get("name") or "node") + " copy"
        new_id = _make_node(st, layers=layers, parent_id=parent_id, name=name)
        _set_current(st, new_id)
        if st.get("root_id") is None and parent_id is None:
            st["root_id"] = new_id
    dd, layers, mini, idx, chips = _render_from_state(st)
    return st, dd, layers, mini, idx, chips


def on_rename_node(new_name: str, st: Dict[str, Any]) -> Tuple[Dict[str, Any], gr.Dropdown]:
    cur = _current_node(st)
    if cur:
        nn = (new_name or "").strip()
        if nn:
            cur["name"] = nn
    dd, _, _, _, _ = _render_from_state(st)
    return st, dd


def on_export_selected(st: Dict[str, Any]) -> Tuple[Optional[str], Optional[str]]:
    cur = _current_node(st)
    if not cur:
        return None, None
    pptx_path, zip_path = export_node_layers(cur.get("layers") or [])
    return pptx_path, zip_path


def on_save_session(st: Dict[str, Any]) -> Tuple[str, str]:
    status, sid = save_session_to_hub(st)
    return status, sid


def on_load_session(session_id: str, st: Dict[str, Any]) -> Tuple[
    Dict[str, Any],
    str,
    gr.Dropdown,
    List[Image.Image],
    List[Image.Image],
    gr.Number,
    str,
]:
    loaded, msg = load_session_from_hub(session_id)
    if loaded is None:
        dd, layers, mini, idx, chips = _render_from_state(st)
        return st, msg, dd, layers, mini, idx, chips
    st = loaded
    dd, layers, mini, idx, chips = _render_from_state(st)
    return st, msg, dd, layers, mini, idx, chips
# GPU click handlers
def on_decompose_click(
    input_image,
    seed=777,
    randomize_seed=False,
    prompt=None,
    neg_prompt=" ",
    true_guidance_scale=4.0,
    num_inference_steps=50,
    layer=4,
    cfg_norm=True,
    use_en_prompt=True,
    resolution=640,
    gpu_duration=1000,
    st: Optional[Dict[str, Any]] = None,
):
    if st is None:
        st = new_state()
    pil_image = _pil_rgba(input_image)
    layers_out = gpu_run_pipeline(
        pil_image,
        seed=seed,
        randomize_seed=randomize_seed,
        prompt=prompt,
        neg_prompt=neg_prompt,
        true_guidance_scale=true_guidance_scale,
        num_inference_steps=num_inference_steps,
        layer=layer,
        cfg_norm=cfg_norm,
        use_en_prompt=use_en_prompt,
        resolution=resolution,
        gpu_duration=gpu_duration,
    )
    # Reset the session tree on Decompose (new root), keeping the session id.
    sid = st.get("session_id") or uuid.uuid4().hex
    st = new_state()
    st["session_id"] = sid
    root_id = _make_node(st, layers=layers_out, parent_id=None, name="root")
    st["root_id"] = root_id
    _set_current(st, root_id)
    dd, layers, mini, idx, chips = _render_from_state(st)
    return (
        st,
        dd,
        layers,
        mini,
        idx,
        chips,
        gr.update(open=False),  # refined accordion closed
        [],  # refined gallery cleared
        None,
        None,
    )
def on_refine_click(
    sub_layers_count: int,
    seed=777,
    randomize_seed=False,
    prompt=None,
    neg_prompt=" ",
    true_guidance_scale=4.0,
    num_inference_steps=50,
    cfg_norm=True,
    use_en_prompt=True,
    resolution=640,
    gpu_duration=1000,
    st: Optional[Dict[str, Any]] = None,
):
    if st is None:
        st = new_state()
    cur = _current_node(st)
    if not cur:
        dd, layers, mini, idx, chips = _render_from_state(st)
        return (
            st,
            dd,
            layers,
            mini,
            idx,
            chips,
            "Refine: no current node.",
            gr.update(open=False),
            [],
            None,
            None,
        )
    layers_list: List[Image.Image] = cur.get("layers") or []
    if not layers_list:
        dd, layers, mini, idx, chips = _render_from_state(st)
        return (
            st,
            dd,
            layers,
            mini,
            idx,
            chips,
            "Refine: current node has no layers.",
            gr.update(open=False),
            [],
            None,
            None,
        )
    idx = int(st.get("selected_layer_idx", 0))
    idx = max(0, min(idx, len(layers_list) - 1))
    selected_layer = layers_list[idx].convert("RGBA")
    # Run the pipeline again on the selected layer, producing sub-layers.
    sub_layers_count = _clamp_int(sub_layers_count, default=3, lo=2, hi=10)
    sub_layers = gpu_run_pipeline(
        selected_layer,
        seed=seed,
        randomize_seed=randomize_seed,
        prompt=prompt,
        neg_prompt=neg_prompt,
        true_guidance_scale=true_guidance_scale,
        num_inference_steps=num_inference_steps,
        layer=sub_layers_count,  # <-- only change: layers = sub_layers_count
        cfg_norm=cfg_norm,
        use_en_prompt=use_en_prompt,
        resolution=resolution,
        gpu_duration=gpu_duration,
    )
    new_id = _make_node(st, layers=sub_layers, parent_id=cur["id"], name=f"refine L{idx+1}")
    _set_current(st, new_id)
    st["last_refined_id"] = new_id
    dd, layers, mini, idx2, chips = _render_from_state(st)
    # Export files for the current node are produced on demand (not automatically).
    return (
        st,
        dd,
        layers,
        mini,
        idx2,
        chips,
        f"Refined: created node {_node_label(st['nodes'][new_id])}",
        gr.update(open=True),  # open the refined accordion
        sub_layers,  # show refined layers
        None,
        None,
    )
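# Refine flow recap: the selected layer of the current node is run through the
# same pipeline with layers=sub_layers_count, and the output becomes a child
# node, so successive refinements grow a tree instead of overwriting results.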
# ----------------------------
# App UI
# ----------------------------
ensure_dirname(LOG_DIR)

examples = [
    "assets/test_images/1.png",
    "assets/test_images/2.png",
    "assets/test_images/3.png",
    "assets/test_images/4.png",
    "assets/test_images/5.png",
    "assets/test_images/6.png",
    "assets/test_images/7.png",
    "assets/test_images/8.png",
    "assets/test_images/9.png",
    "assets/test_images/10.png",
    "assets/test_images/11.png",
    "assets/test_images/12.png",
    "assets/test_images/13.png",
]
with gr.Blocks() as demo:
    st = gr.State(value=new_state())

    with gr.Column(elem_id="col-container"):
        gr.HTML(
            '<img src="https://qianwen-res.oss-cn-beijing.aliyuncs.com/Qwen-Image/layered/qwen-image-layered-logo.png" '
            'alt="Qwen-Image-Layered Logo" width="600" style="display: block; margin: 0 auto;">'
        )
        persist_status = gr.Markdown(_enabled_msg)

        with gr.Row():
            new_session_btn = gr.Button("New session", variant="secondary")
            session_id_box = gr.Textbox(label="Session ID", value="", interactive=False)
            save_btn = gr.Button("Save session to Dataset repo", variant="primary")
            save_status = gr.Textbox(label="Save/Load status", value="", interactive=False)

        with gr.Row():
            load_session_id = gr.Textbox(label="Load Session ID", value="", placeholder="paste Session ID here")
            load_btn = gr.Button("Load", variant="secondary")

        gr.Markdown(
            """
The text prompt is intended to describe the overall content of the input image, including elements that may be partially occluded.
It is not designed to explicitly control the semantic content of individual layers.
            """
        )
        with gr.Row():
            with gr.Column(scale=1):
                input_image = gr.Image(label="Input Image", image_mode="RGBA")
                with gr.Accordion("Advanced Settings", open=False):
                    prompt = gr.Textbox(
                        label="Prompt (Optional)",
                        placeholder="Please enter a prompt describing the image. (Optional)",
                        value="",
                        lines=2,
                    )
                    neg_prompt = gr.Textbox(
                        label="Negative Prompt (Optional)",
                        placeholder="Please enter the negative prompt",
                        value=" ",
                        lines=2,
                    )
                    seed = gr.Slider(
                        label="Seed",
                        minimum=0,
                        maximum=MAX_SEED,
                        step=1,
                        value=0,
                    )
                    randomize_seed = gr.Checkbox(label="Randomize seed", value=True)
                    true_guidance_scale = gr.Slider(
                        label="True guidance scale",
                        minimum=1.0,
                        maximum=10.0,
                        step=0.1,
                        value=4.0,
                    )
                    num_inference_steps = gr.Slider(
                        label="Number of inference steps",
                        minimum=1,
                        maximum=100,
                        step=1,
                        value=50,
                    )
                    layer = gr.Slider(
                        label="Layers",
                        minimum=2,
                        maximum=10,
                        step=1,
                        value=7,
                    )
                    resolution = gr.Radio(
                        label="Processing resolution",
                        choices=[640, 1024],
                        value=640,
                    )
                    cfg_norm = gr.Checkbox(
                        label="Enable CFG normalization", value=True
                    )
                    use_en_prompt = gr.Checkbox(
                        label="Auto-caption language when no prompt is provided (checked = English, unchecked = Chinese)",
                        value=True,
                    )
                    gpu_duration = gr.Textbox(
                        label="GPU duration override (seconds, 20..1500)",
                        value="1000",
                        lines=1,
                        placeholder="e.g. 60, 120, 300, 1000, 1500",
                    )
                run_button = gr.Button("Decompose!", variant="primary")

                gr.Markdown("### History")
                history_dd = gr.Dropdown(
                    label="Nodes",
                    choices=[],
                    value=None,
                    interactive=True,
                )
                chips_md = gr.Markdown("")
                with gr.Row():
                    back_btn = gr.Button("← back to parent")
                    dup_btn = gr.Button("Duplicate node (branch)")
                with gr.Row():
                    rename_inp = gr.Textbox(label="Branch name", value="", placeholder="rename current node")
                    rename_btn = gr.Button("Rename")
                with gr.Row():
                    export_btn = gr.Button("Export selected node")
                    export_file = gr.File(label="Download PPTX")
                    export_zip_file = gr.File(label="Download ZIP")

            with gr.Column(scale=2):
                gr.Markdown("### Layers (current node)")
                gallery = gr.Gallery(label="Layers", columns=4, rows=1, format="png")

                with gr.Accordion("Layer picker (mini, click like Photoshop)", open=True):
                    mini_gallery = gr.Gallery(label="Pick layer to refine", columns=7, rows=1, format="png")
                    selected_layer_idx = gr.Number(label="Selected layer index (0-based)", value=0, interactive=False)

                with gr.Accordion("Refine selected layer", open=True):
                    refine_info = gr.Textbox(label="Refine status", value="", interactive=False)
                    with gr.Row():
                        sub_layers_count = gr.Slider(
                            label="Sub-layers (refine)",
                            minimum=2,
                            maximum=10,
                            step=1,
                            value=3,
                        )
                        refine_btn = gr.Button("Refine selected layer", variant="primary")

                refined_acc = gr.Accordion("Refined layers (latest)", open=False)
                with refined_acc:
                    refined_gallery = gr.Gallery(label="Refined layers", columns=4, rows=1, format="png")
        # Examples (clicking an example loads the image; it does not auto-run Decompose)
        gr.Examples(
            examples=examples,
            inputs=[input_image],
            outputs=[gallery, export_file, export_zip_file],
            fn=lambda img: ([], None, None),
            cache_examples=False,
            run_on_click=False,
        )
    # Boot / init
    demo.load(
        fn=ui_boot,
        inputs=[],
        outputs=[persist_status, st],
    ).then(
        fn=lambda st: st.get("session_id", ""),
        inputs=[st],
        outputs=[session_id_box],
    )

    # New session
    new_session_btn.click(
        fn=on_new_session,
        inputs=[st],
        outputs=[st, session_id_box, history_dd, gallery, mini_gallery, chips_md, refine_info, save_status, export_file, export_zip_file],
    )

    # Decompose
    run_button.click(
        fn=on_decompose_click,
        inputs=[
            input_image,
            seed,
            randomize_seed,
            prompt,
            neg_prompt,
            true_guidance_scale,
            num_inference_steps,
            layer,
            cfg_norm,
            use_en_prompt,
            resolution,
            gpu_duration,
            st,
        ],
        outputs=[
            st,
            history_dd,
            gallery,
            mini_gallery,
            selected_layer_idx,
            chips_md,
            refined_acc,
            refined_gallery,
            export_file,
            export_zip_file,
        ],
    ).then(
        fn=lambda st: st.get("session_id", ""),
        inputs=[st],
        outputs=[session_id_box],
    )

    # History selection
    history_dd.change(
        fn=on_history_select,
        inputs=[history_dd, st],
        outputs=[st, history_dd, gallery, mini_gallery, selected_layer_idx, chips_md],
    )

    # Mini gallery click -> choose layer index
    mini_gallery.select(
        fn=on_layer_gallery_select,
        inputs=[st],
        outputs=[st, selected_layer_idx],
    )

    # Back to parent
    back_btn.click(
        fn=on_back_to_parent,
        inputs=[st],
        outputs=[st, history_dd, gallery, mini_gallery, selected_layer_idx, chips_md],
    )

    # Duplicate node
    dup_btn.click(
        fn=on_duplicate_node,
        inputs=[st],
        outputs=[st, history_dd, gallery, mini_gallery, selected_layer_idx, chips_md],
    )

    # Rename node
    rename_btn.click(
        fn=on_rename_node,
        inputs=[rename_inp, st],
        outputs=[st, history_dd],
    ).then(
        fn=lambda: "",
        inputs=[],
        outputs=[rename_inp],
    )

    # Refine selected layer
    refine_btn.click(
        fn=on_refine_click,
        inputs=[
            sub_layers_count,
            seed,
            randomize_seed,
            prompt,
            neg_prompt,
            true_guidance_scale,
            num_inference_steps,
            cfg_norm,
            use_en_prompt,
            resolution,
            gpu_duration,
            st,
        ],
        outputs=[
            st,
            history_dd,
            gallery,
            mini_gallery,
            selected_layer_idx,
            chips_md,
            refine_info,
            refined_acc,
            refined_gallery,
            export_file,
            export_zip_file,
        ],
    )

    # Export current node
    export_btn.click(
        fn=on_export_selected,
        inputs=[st],
        outputs=[export_file, export_zip_file],
    )

    # Save session
    save_btn.click(
        fn=on_save_session,
        inputs=[st],
        outputs=[save_status, session_id_box],
    )

    # Load session
    load_btn.click(
        fn=on_load_session,
        inputs=[load_session_id, st],
        outputs=[st, save_status, history_dd, gallery, mini_gallery, selected_layer_idx, chips_md],
    ).then(
        fn=lambda st: st.get("session_id", ""),
        inputs=[st],
        outputs=[session_id_box],
    ).then(
        fn=lambda: "",
        inputs=[],
        outputs=[load_session_id],
    )
| if __name__ == "__main__": | |
| demo.launch() |