# hexware's picture
# Update app.py
# 230a765 verified
import os
import uuid
import json
import time
import shutil
import numpy as np
import random
import tempfile
import zipfile
from typing import Any, Dict, List, Optional, Tuple
import spaces
import torch
import gradio as gr
from PIL import Image
from diffusers import QwenImageLayeredPipeline
from pptx import Presentation
LOG_DIR = "/tmp/local"
MAX_SEED = np.iinfo(np.int32).max
# Optional HF login (works in Spaces if you set HF token as secret env var "hf").
# huggingface_hub.login(token=None) falls back to an interactive prompt, which
# cannot be answered in a headless Space, so only log in when the secret is set.
from huggingface_hub import login, HfApi, hf_hub_download
_login_token = os.environ.get("hf")
if _login_token:
    login(token=_login_token)
dtype = torch.bfloat16
device = "cuda" if torch.cuda.is_available() else "cpu"
# The model is loaded once at import time and moved to the selected device.
pipeline = QwenImageLayeredPipeline.from_pretrained(
    "Qwen/Qwen-Image-Layered", torch_dtype=dtype
).to(device)
# ----------------------------
# Dataset repo persistence (no /data needed)
# ----------------------------
# Sessions are saved as zip files inside a Hugging Face dataset repo rather than
# on local disk. Both settings come from environment variables.
HF_TOKEN = os.environ.get("hf") # secret token (env var "hf")
HF_DATASET_REPO = os.environ.get("HF_DATASET_REPO") # e.g. "hexware/qwen-layered-sessions"
# Filled in by _init_dataset_repo(); they stay None / False while persistence is off.
_hf_api: Optional[HfApi] = None
_persist_enabled = False
def _init_dataset_repo() -> Tuple[bool, str]:
    """Try to enable session persistence in a private HF dataset repo.

    Side effects: sets the module globals ``_hf_api`` (the Hub client; it is
    kept even if repo creation then fails) and ``_persist_enabled`` (True
    only when everything succeeded).

    Returns:
        (enabled, status_message). Exceptions from the Hub client are caught
        and reported in the message.
    """
    global _hf_api, _persist_enabled

    # Start pessimistic; only the fully successful path turns persistence on.
    _persist_enabled = False

    if not HF_TOKEN:
        return False, "Persistence: disabled (no secret env var 'hf')."
    if not HF_DATASET_REPO:
        return False, "Persistence: disabled (set env var HF_DATASET_REPO to enable)."

    try:
        client = HfApi(token=HF_TOKEN)
        _hf_api = client
        # exist_ok=True makes this a no-op when the repo already exists.
        client.create_repo(
            repo_id=HF_DATASET_REPO,
            repo_type="dataset",
            private=True,
            exist_ok=True,
        )
    except Exception as e:
        return False, f"Persistence: failed to init dataset repo: {type(e).__name__}: {e}"

    _persist_enabled = True
    return True, f"Persistence: enabled (dataset repo: {HF_DATASET_REPO})."


# Evaluated once at import; _enabled_msg is what the UI displays as status.
_enabled, _enabled_msg = _init_dataset_repo()
# ----------------------------
# Helpers
# ----------------------------
def ensure_dirname(path: str):
    """Create directory ``path`` (with parents) unless it is empty or already exists."""
    if not path or os.path.exists(path):
        return
    os.makedirs(path, exist_ok=True)
def random_str(length=8):
    """Return a random lowercase-hex string of ``length`` characters (at most 32)."""
    token = uuid.uuid4().hex
    return token[:length]
def _now_ts() -> float:
return time.time()
def _clamp_int(x, default: int, lo: int, hi: int) -> int:
try:
v = int(x)
except Exception:
v = default
return max(lo, min(hi, v))
def _pil_rgba(input_image) -> Image.Image:
    """Normalize a Gradio image input into an RGBA ``PIL.Image``.

    Accepts a file path, a PIL image, a numpy array, or a list whose first
    element is one of those. The image is flattened to RGB first, so any
    existing alpha channel is dropped and the result is fully opaque.

    Raises:
        ValueError: if no image was given (``None`` or an empty list) or the
            input type is not supported.
    """
    if isinstance(input_image, list):
        if not input_image:
            raise ValueError("No input image provided (empty list).")
        input_image = input_image[0]
    if input_image is None:
        raise ValueError("No input image provided.")
    if isinstance(input_image, str):
        # Close the file handle; convert() loads the pixel data before the
        # context exits, so the returned image stays valid.
        with Image.open(input_image) as src:
            return src.convert("RGB").convert("RGBA")
    if isinstance(input_image, Image.Image):
        return input_image.convert("RGB").convert("RGBA")
    if isinstance(input_image, np.ndarray):
        return Image.fromarray(input_image).convert("RGB").convert("RGBA")
    raise ValueError(f"Unsupported input_image type: {type(input_image)}")
def imagelist_to_pptx(img_files: List[str]) -> str:
    """Place every image file as a full-slide picture on one PPTX slide.

    The slide size comes from the first image (converted at 96 DPI). Every
    picture is anchored at the top-left corner and sized to the slide.
    Pictures are added in list order.

    Returns:
        Path of a new ``.pptx`` file in the temp directory (the caller owns it).

    Raises:
        ValueError: if ``img_files`` is empty.
    """
    if not img_files:
        raise ValueError("imagelist_to_pptx: need at least one image file.")

    def px_to_emu(px, dpi=96):
        # 1 inch == 914400 EMU (English Metric Units used by OOXML).
        return int(px / dpi * 914400)

    with Image.open(img_files[0]) as img:
        img_width_px, img_height_px = img.size
    # Computed once; the same size is used for the slide and every picture.
    slide_w = px_to_emu(img_width_px)
    slide_h = px_to_emu(img_height_px)

    prs = Presentation()
    prs.slide_width = slide_w
    prs.slide_height = slide_h
    # Layout index 6 is the blank layout of python-pptx's default template.
    slide = prs.slides.add_slide(prs.slide_layouts[6])
    for img_path in img_files:
        slide.shapes.add_picture(img_path, 0, 0, width=slide_w, height=slide_h)

    # mkstemp returns an open descriptor; close it before saving via the path.
    fd, out_path = tempfile.mkstemp(suffix=".pptx")
    os.close(fd)
    prs.save(out_path)
    return out_path
def export_node_layers(layers: List[Image.Image]) -> Tuple[str, str]:
    """Export layer images as a PPTX (one slide) and as a ZIP of PNG files.

    The intermediate PNGs go to a private temp directory, which is removed
    before returning. Only the two output files are left on disk.

    Returns:
        (pptx_path, zip_path). The ZIP holds ``layer_1.png``, ``layer_2.png``, ...
    """
    work_dir = tempfile.mkdtemp(prefix="export_")
    try:
        png_paths: List[str] = []
        for i, img in enumerate(layers):
            png_path = os.path.join(work_dir, f"layer_{i+1}.png")
            img.save(png_path)
            png_paths.append(png_path)

        pptx_path = imagelist_to_pptx(png_paths)

        fd, zip_path = tempfile.mkstemp(suffix=".zip")
        os.close(fd)
        with zipfile.ZipFile(zip_path, "w", zipfile.ZIP_DEFLATED) as zipf:
            for i, img_path in enumerate(png_paths):
                zipf.write(img_path, f"layer_{i+1}.png")
        return pptx_path, zip_path
    finally:
        # Both outputs are fully written at this point, so the PNG sources
        # (previously leaked one temp file per layer) can be removed.
        shutil.rmtree(work_dir, ignore_errors=True)
# ----------------------------
# ZeroGPU duration
# ----------------------------
def get_duration(
    input_image,
    seed=777,
    randomize_seed=False,
    prompt=None,
    neg_prompt=" ",
    true_guidance_scale=4.0,
    num_inference_steps=50,
    layer=4,
    cfg_norm=True,
    use_en_prompt=True,
    resolution=640,
    gpu_duration=1000,
):
    """ZeroGPU duration callback used by ``@spaces.GPU`` on ``gpu_run_pipeline``.

    The signature mirrors ``gpu_run_pipeline`` so it can receive the same
    arguments. Only ``gpu_duration`` is used: it is parsed as an int and
    clamped to 20..1500 seconds, or becomes 1000 if it cannot be parsed.
    """
    floor_s, ceiling_s, fallback_s = 20, 1500, 1000
    return _clamp_int(gpu_duration, default=fallback_s, lo=floor_s, hi=ceiling_s)
# ----------------------------
# GPU pipeline runners
# ----------------------------
@spaces.GPU(duration=get_duration)
def gpu_run_pipeline(
    input_pil_image: Image.Image,
    seed=777,
    randomize_seed=False,
    prompt=None,
    neg_prompt=" ",
    true_guidance_scale=4.0,
    num_inference_steps=50,
    layer=4,
    cfg_norm=True,
    use_en_prompt=True,
    resolution=640,
    gpu_duration=1000,
) -> List[Image.Image]:
    """Run Qwen-Image-Layered on one image and return its layers.

    ``gpu_duration`` is not used in the body. It is read by ``get_duration``
    to size the ZeroGPU reservation.

    Integer-valued settings (seed, steps, layer count) are coerced with
    ``int()``, because UI widgets may deliver them as floats.
    ``torch.Generator.manual_seed`` and the pipeline's counts expect integers.

    Returns:
        The layer images of the single generated sample (``out.images[0]``).
    """
    # Seed: fresh random value or the caller's, always a Python int.
    if randomize_seed:
        seed = random.randint(0, MAX_SEED)
    seed = int(seed)

    num_inference_steps = int(num_inference_steps)
    layer = int(layer)

    # Normalize resolution input: only 640 and 1024 are accepted.
    resolution = _clamp_int(resolution, default=640, lo=640, hi=1024)
    if resolution not in (640, 1024):
        resolution = 640

    gen_device = "cuda" if torch.cuda.is_available() else "cpu"
    inputs = {
        "image": input_pil_image,
        "generator": torch.Generator(device=gen_device).manual_seed(seed),
        "true_cfg_scale": true_guidance_scale,
        "prompt": prompt,
        "negative_prompt": neg_prompt,
        "num_inference_steps": num_inference_steps,
        "num_images_per_prompt": 1,
        "layers": layer,
        "resolution": resolution,
        "cfg_normalize": cfg_norm,
        "use_en_prompt": use_en_prompt,
    }
    with torch.inference_mode():
        out = pipeline(**inputs)
    # out.images[0] => list of PIL images (one per layer) for the only sample
    return out.images[0]
# ----------------------------
# Session / History model
# ----------------------------
def new_state() -> Dict[str, Any]:
    """Build an empty session: a fresh random id (uuid4 hex) and no history nodes."""
    state: Dict[str, Any] = dict(
        session_id=uuid.uuid4().hex,
        nodes={},  # node_id -> node dict
        root_id=None,
        current_id=None,
        selected_layer_idx=0,
        last_refined_id=None,
    )
    return state
def _node_label(node: Dict[str, Any]) -> str:
name = node.get("name") or node["id"][:8]
return f"{name} ({node['id'][:8]})"
def _build_history_choices(st: Dict[str, Any]) -> List[Tuple[str, str]]:
    """Dropdown choices ``(label, node_id)`` for every node, oldest first.

    Nodes without ``created_at`` sort as 0.0. Ties keep dict insertion order
    because ``sorted`` is stable.
    """
    nodes = st["nodes"]
    by_age = sorted(nodes, key=lambda node_id: nodes[node_id].get("created_at", 0.0))
    return [(_node_label(nodes[node_id]), node_id) for node_id in by_age]
def _get_node(st: Dict[str, Any], node_id: Optional[str]) -> Optional[Dict[str, Any]]:
if not node_id:
return None
return st["nodes"].get(node_id)
def _current_node(st: Dict[str, Any]) -> Optional[Dict[str, Any]]:
    """Return the node selected by ``st["current_id"]``, or ``None``."""
    current_id = st.get("current_id")
    return _get_node(st, current_id)
def _chips_text(st: Dict[str, Any], node_id: Optional[str]) -> str:
    """Badge string for a node, e.g. ``"[root] [children:2]"``.

    ``[root]`` marks the session root, ``[parent]`` means the node has a
    parent, and ``[children:N]`` counts its children. Unknown nodes give "".
    """
    node = _get_node(st, node_id)
    if not node:
        return ""
    n_children = len(node.get("children_ids") or [])
    badges = [
        badge
        for badge, show in (
            ("[root]", node_id == st.get("root_id")),
            ("[parent]", bool(node.get("parent_id"))),
            (f"[children:{n_children}]", n_children > 0),
        )
        if show
    ]
    return " ".join(badges)
def _make_node(
    st: Dict[str, Any],
    layers: List[Image.Image],
    parent_id: Optional[str],
    name: Optional[str] = None,
) -> str:
    """Create a history node, register it in ``st`` and link it to its parent.

    The default name is ``"root"`` when ``parent_id`` is None and
    ``"refine"`` otherwise. The new id is appended to the parent's
    ``children_ids`` only if that parent exists in ``st["nodes"]``. The
    ``layers`` list is stored as-is, not copied.

    Returns:
        The new node id (uuid4 hex).
    """
    node_id = uuid.uuid4().hex
    default_name = "root" if parent_id is None else "refine"
    st["nodes"][node_id] = {
        "id": node_id,
        "name": name or default_name,
        "parent_id": parent_id,
        "children_ids": [],
        "created_at": _now_ts(),
        "layers": layers,
    }
    parent = st["nodes"].get(parent_id) if parent_id else None
    if parent is not None:
        parent.setdefault("children_ids", []).append(node_id)
    return node_id
def _set_current(st: Dict[str, Any], node_id: str):
st["current_id"] = node_id
st["selected_layer_idx"] = 0
# ----------------------------
# Persistence: save/load whole session as one zip in dataset repo
# ----------------------------
def _serialize_session_to_zip(st: Dict[str, Any]) -> str:
"""
Create a zip file with:
session.json
nodes/<node_id>/layer_1.png ...
Returns local zip path.
"""
tmpdir = tempfile.mkdtemp(prefix="sess_")
try:
sess_meta = {
"session_id": st["session_id"],
"root_id": st["root_id"],
"current_id": st["current_id"],
"selected_layer_idx": st.get("selected_layer_idx", 0),
"last_refined_id": st.get("last_refined_id"),
"nodes": {},
}
for nid, node in st["nodes"].items():
node_dir = os.path.join(tmpdir, "nodes", nid)
os.makedirs(node_dir, exist_ok=True)
layers: List[Image.Image] = node.get("layers") or []
for i, img in enumerate(layers):
img_path = os.path.join(node_dir, f"layer_{i+1}.png")
img.save(img_path)
sess_meta["nodes"][nid] = {
"id": nid,
"name": node.get("name"),
"parent_id": node.get("parent_id"),
"children_ids": node.get("children_ids") or [],
"created_at": node.get("created_at", 0.0),
"layer_count": len(layers),
}
meta_path = os.path.join(tmpdir, "session.json")
with open(meta_path, "w", encoding="utf-8") as f:
json.dump(sess_meta, f, ensure_ascii=False, indent=2)
out_zip = tempfile.NamedTemporaryFile(suffix=".zip", delete=False).name
with zipfile.ZipFile(out_zip, "w", zipfile.ZIP_DEFLATED) as zf:
for root, _, files in os.walk(tmpdir):
for fn in files:
abs_path = os.path.join(root, fn)
rel_path = os.path.relpath(abs_path, tmpdir)
zf.write(abs_path, rel_path)
return out_zip
finally:
shutil.rmtree(tmpdir, ignore_errors=True)
def _deserialize_session_from_zip(zip_path: str) -> Dict[str, Any]:
    """Rebuild a session dict from a zip written by ``_serialize_session_to_zip``.

    The archive is extracted into a temporary directory that is always
    removed afterwards, so every layer image is fully loaded into memory
    (as RGBA) before returning. Layer files counted in ``layer_count`` but
    missing from the archive are skipped.

    Raises:
        KeyError: if ``session.json`` has no ``session_id``.
    """
    tmpdir = tempfile.mkdtemp(prefix="sess_load_")
    try:
        with zipfile.ZipFile(zip_path, "r") as zf:
            zf.extractall(tmpdir)
        meta_path = os.path.join(tmpdir, "session.json")
        with open(meta_path, "r", encoding="utf-8") as f:
            meta = json.load(f)

        st = new_state()
        st["session_id"] = meta["session_id"]
        st["root_id"] = meta.get("root_id")
        st["current_id"] = meta.get("current_id")
        st["selected_layer_idx"] = meta.get("selected_layer_idx", 0)
        st["last_refined_id"] = meta.get("last_refined_id")

        nodes_meta: Dict[str, Any] = meta.get("nodes", {})
        for nid, nm in nodes_meta.items():
            node_dir = os.path.join(tmpdir, "nodes", nid)
            layers = []
            for i in range(int(nm.get("layer_count", 0))):
                p = os.path.join(node_dir, f"layer_{i+1}.png")
                if not os.path.exists(p):
                    continue
                # Close the file handle explicitly; convert() loads the
                # pixels, so the image survives removal of tmpdir.
                with Image.open(p) as im:
                    layers.append(im.convert("RGBA"))
            st["nodes"][nid] = {
                "id": nid,
                "name": nm.get("name"),
                "parent_id": nm.get("parent_id"),
                "children_ids": nm.get("children_ids") or [],
                "created_at": nm.get("created_at", 0.0),
                "layers": layers,
            }
        return st
    finally:
        shutil.rmtree(tmpdir, ignore_errors=True)
def save_session_to_hub(st: Dict[str, Any]) -> Tuple[str, str]:
    """Upload the session as ``sessions/<session_id>.zip`` to the dataset repo.

    Errors while serializing or uploading are reported in the status string
    rather than raised. The local temporary zip is removed in all cases
    (best effort).

    Returns:
        (status_text, session_id)
    """
    if not _persist_enabled or _hf_api is None:
        return "Save: disabled (set HF_DATASET_REPO and secret hf write token).", st.get("session_id", "")
    zip_path: Optional[str] = None
    try:
        zip_path = _serialize_session_to_zip(st)
        path_in_repo = f"sessions/{st['session_id']}.zip"
        _hf_api.upload_file(
            path_or_fileobj=zip_path,
            path_in_repo=path_in_repo,
            repo_id=HF_DATASET_REPO,
            repo_type="dataset",
            commit_message=f"Save session {st['session_id']}",
        )
        return f"Saved to dataset repo: {path_in_repo}", st["session_id"]
    except Exception as e:
        return f"Save failed: {type(e).__name__}: {e}", st.get("session_id", "")
    finally:
        if zip_path is not None:
            try:
                os.remove(zip_path)
            except OSError:
                pass  # best-effort cleanup of a temp file
def load_session_from_hub(session_id: str) -> Tuple[Optional[Dict[str, Any]], str]:
    """Download ``sessions/<session_id>.zip`` from the dataset repo and rebuild it.

    ``session_id`` comes straight from a user textbox and is embedded in a
    repo path. It is therefore limited to ASCII letters, digits, ``-`` and
    ``_``; ids produced by this app are uuid4 hex strings.

    Returns:
        (state, status_text). ``state`` is ``None`` when loading is disabled,
        the id is missing or invalid, or the download or parsing fails.
    """
    if not _persist_enabled:
        return None, "Load: disabled (set HF_DATASET_REPO and secret hf write token)."
    session_id = (session_id or "").strip()
    if not session_id:
        return None, "Load: please enter a Session ID."
    if not (session_id.isascii() and session_id.replace("-", "").replace("_", "").isalnum()):
        return None, "Load: invalid Session ID (use letters, digits, '-' or '_')."
    try:
        filename = f"sessions/{session_id}.zip"
        local_zip = hf_hub_download(
            repo_id=HF_DATASET_REPO,
            repo_type="dataset",
            filename=filename,
            token=HF_TOKEN,
        )
        st = _deserialize_session_from_zip(local_zip)
        return st, f"Loaded session: {session_id}"
    except Exception as e:
        return None, f"Load failed: {type(e).__name__}: {e}"
# ----------------------------
# UI Callbacks
# ----------------------------
def ui_boot() -> Tuple[str, Dict[str, Any]]:
    """Page-load hook: ensure LOG_DIR exists, then return (persistence status, fresh session)."""
    ensure_dirname(LOG_DIR)
    return _enabled_msg, new_state()
def on_new_session(st: Dict[str, Any]) -> Tuple[Dict[str, Any], str, gr.Dropdown, List[Image.Image], List[Image.Image], str, str, str, Optional[str], Optional[str]]:
    """"New session" handler: discard ``st`` (ignored) and reset every dependent widget."""
    fresh = new_state()
    return (
        fresh,
        fresh["session_id"],
        gr.Dropdown(choices=[], value=None),  # history dropdown
        [],  # main gallery
        [],  # mini gallery
        "",  # chips
        "",  # refine status
        "",  # save/load status
        None,  # PPTX download
        None,  # ZIP download
    )
def _render_from_state(st: Dict[str, Any]) -> Tuple[gr.Dropdown, List[Image.Image], List[Image.Image], gr.Number, str]:
    """Compute the widget values that reflect ``st``'s current node.

    When the current node has layers, ``st["selected_layer_idx"]`` is also
    clamped into range and written back.

    Returns:
        (history dropdown, main gallery layers, mini gallery layers,
        selected layer index, chips text). Both galleries get the same list.
    """
    node = _current_node(st)
    node_layers = node["layers"] if node else []
    sel = st.get("selected_layer_idx", 0)
    if node_layers:
        sel = max(0, min(sel, len(node_layers) - 1))
        st["selected_layer_idx"] = sel
    dropdown = gr.Dropdown(
        choices=_build_history_choices(st),
        value=st.get("current_id"),
    )
    chips = _chips_text(st, st.get("current_id"))
    return dropdown, node_layers, node_layers, sel, chips
def on_history_select(node_id: str, st: Dict[str, Any]) -> Tuple[Dict[str, Any], gr.Dropdown, List[Image.Image], List[Image.Image], gr.Number, str]:
    """History-dropdown handler: show the chosen node (unknown ids are ignored) and re-render."""
    if node_id and node_id in st["nodes"]:
        _set_current(st, node_id)
    return (st, *_render_from_state(st))
def on_layer_gallery_select(evt: gr.SelectData, st: Dict[str, Any]) -> Tuple[Dict[str, Any], gr.Number]:
    """Mini-gallery click handler: remember which layer was picked.

    The clicked index is clamped to the current node's layers. It becomes
    0 when there is no current node or the node has no layers.
    """
    # evt.index is an int for Gallery selections
    picked = int(evt.index) if evt and evt.index is not None else 0
    node = _current_node(st)
    layers = (node.get("layers") or []) if node else []
    picked = max(0, min(picked, len(layers) - 1)) if layers else 0
    st["selected_layer_idx"] = picked
    return st, picked
def on_back_to_parent(st: Dict[str, Any]) -> Tuple[Dict[str, Any], gr.Dropdown, List[Image.Image], List[Image.Image], gr.Number, str]:
    """Move to the current node's parent (no-op for parentless nodes), then re-render."""
    node = _current_node(st)
    parent_id = node.get("parent_id") if node else None
    if parent_id:
        _set_current(st, parent_id)
    return (st, *_render_from_state(st))
def on_duplicate_node(st: Dict[str, Any]) -> Tuple[Dict[str, Any], gr.Dropdown, List[Image.Image], List[Image.Image], gr.Number, str]:
    """Branch the history: add a sibling copy of the current node and select it.

    The copy has the same parent, shares the original's layer list (the
    images are not copied), and is named ``"<name> copy"`` (``"node copy"``
    when unnamed). If the session has no root yet and the copy is
    parentless, the copy becomes the root.
    """
    source = _current_node(st)
    if source:
        parent_id = source.get("parent_id")
        copy_id = _make_node(
            st,
            layers=source.get("layers") or [],
            parent_id=parent_id,
            name=(source.get("name") or "node") + " copy",
        )
        _set_current(st, copy_id)
        if parent_id is None and st.get("root_id") is None:
            st["root_id"] = copy_id
    return (st, *_render_from_state(st))
def on_rename_node(new_name: str, st: Dict[str, Any]) -> Tuple[Dict[str, Any], gr.Dropdown]:
    """Rename the current node (blank input is ignored) and refresh the dropdown labels."""
    node = _current_node(st)
    cleaned = (new_name or "").strip() if node else ""
    if cleaned:
        node["name"] = cleaned
    dropdown = _render_from_state(st)[0]
    return st, dropdown
def on_export_selected(st: Dict[str, Any]) -> Tuple[Optional[str], Optional[str]]:
    """Export the current node's layers and return (PPTX path, ZIP path) for download.

    Returns ``(None, None)``, which clears both download widgets, when there
    is no current node or the node has no layers. Exporting an empty node
    would otherwise crash in the PPTX builder, which needs at least one
    image to size the slide.
    """
    node = _current_node(st)
    layers = (node.get("layers") or []) if node else []
    if not layers:
        return None, None
    return export_node_layers(layers)
def on_save_session(st: Dict[str, Any]) -> Tuple[str, str]:
    """Save-button handler: persist ``st`` and return (status text, session id)."""
    return save_session_to_hub(st)
def on_load_session(session_id: str, st: Dict[str, Any]) -> Tuple[Dict[str, Any], str, gr.Dropdown, List[Image.Image], List[Image.Image], gr.Number, str]:
    """Load-button handler: swap in the stored session when it loads.

    On failure the existing state is kept and re-rendered. The status
    message describes the outcome either way.
    """
    loaded, msg = load_session_from_hub(session_id)
    active = st if loaded is None else loaded
    return (active, msg, *_render_from_state(active))
# GPU click handlers
def on_decompose_click(
    input_image,
    seed=777,
    randomize_seed=False,
    prompt=None,
    neg_prompt=" ",
    true_guidance_scale=4.0,
    num_inference_steps=50,
    layer=4,
    cfg_norm=True,
    use_en_prompt=True,
    resolution=640,
    gpu_duration=1000,
    st: Optional[Dict[str, Any]] = None,
):
    """Decompose the input image into layers and start a new history tree.

    The session id is kept, but all previous nodes are discarded. The
    result becomes the new ``root`` node and is selected.

    Returns the 10 values wired to the Decompose button's outputs: state,
    history dropdown, main and mini galleries, selected index, chips, a
    closed refined-layers accordion, an empty refined gallery and two
    cleared download files.

    Raises:
        gr.Error: when no usable input image was provided. The check runs
            before any GPU work, and Gradio shows the message to the user.
    """
    if st is None:
        st = new_state()
    if input_image is None or (isinstance(input_image, list) and not input_image):
        raise gr.Error("Please upload an input image first.")
    try:
        pil_image = _pil_rgba(input_image)
    except ValueError as e:
        raise gr.Error(str(e)) from e

    layers_out = gpu_run_pipeline(
        pil_image,
        seed=seed,
        randomize_seed=randomize_seed,
        prompt=prompt,
        neg_prompt=neg_prompt,
        true_guidance_scale=true_guidance_scale,
        num_inference_steps=num_inference_steps,
        layer=layer,
        cfg_norm=cfg_norm,
        use_en_prompt=use_en_prompt,
        resolution=resolution,
        gpu_duration=gpu_duration,
    )

    # Reset session tree on Decompose (new root) but keep the session id.
    sid = st.get("session_id") or uuid.uuid4().hex
    st = new_state()
    st["session_id"] = sid
    root_id = _make_node(st, layers=layers_out, parent_id=None, name="root")
    st["root_id"] = root_id
    _set_current(st, root_id)

    dd, layers, mini, idx, chips = _render_from_state(st)
    return (
        st,
        dd,
        layers,
        mini,
        idx,
        chips,
        gr.update(open=False),  # refined accordion closed
        [],  # refined gallery cleared
        None,
        None,
    )
def on_refine_click(
    sub_layers_count: int,
    seed=777,
    randomize_seed=False,
    prompt=None,
    neg_prompt=" ",
    true_guidance_scale=4.0,
    num_inference_steps=50,
    cfg_norm=True,
    use_en_prompt=True,
    resolution=640,
    gpu_duration=1000,
    st: Optional[Dict[str, Any]] = None,
):
    """Decompose the selected layer of the current node into sub-layers.

    The result becomes a new child node named ``"refine L<k>"`` (k is the
    1-based index of the refined layer), which is then selected.

    Returns the 11 values wired to the refine button's outputs: state,
    history dropdown, main and mini galleries, selected index, chips, refine
    status text, refined-accordion update, refined gallery, and two cleared
    download files.
    """
    if st is None:
        st = new_state()

    def _pack(status, refined, open_refined):
        # Every exit path returns the same 11-value layout. Rendering happens
        # here so it sees the final state (it may clamp selected_layer_idx).
        dd, layers, mini, sel, chips = _render_from_state(st)
        return (st, dd, layers, mini, sel, chips, status, gr.update(open=open_refined), refined, None, None)

    cur = _current_node(st)
    if not cur:
        return _pack("Refine: no current node.", [], False)
    source_layers = cur.get("layers") or []
    if not source_layers:
        return _pack("Refine: current node has no layers.", [], False)

    pick = int(st.get("selected_layer_idx", 0))
    pick = max(0, min(pick, len(source_layers) - 1))
    target = source_layers[pick].convert("RGBA")

    # Same settings as Decompose; only the layer count comes from the refine slider.
    sub_layers = gpu_run_pipeline(
        target,
        seed=seed,
        randomize_seed=randomize_seed,
        prompt=prompt,
        neg_prompt=neg_prompt,
        true_guidance_scale=true_guidance_scale,
        num_inference_steps=num_inference_steps,
        layer=_clamp_int(sub_layers_count, default=3, lo=2, hi=10),
        cfg_norm=cfg_norm,
        use_en_prompt=use_en_prompt,
        resolution=resolution,
        gpu_duration=gpu_duration,
    )

    child_id = _make_node(st, layers=sub_layers, parent_id=cur["id"], name=f"refine L{pick+1}")
    _set_current(st, child_id)
    st["last_refined_id"] = child_id
    status = f"Refined: created node {_node_label(st['nodes'][child_id])}"
    # Export files are produced on demand via the Export button, not here.
    return _pack(status, sub_layers, True)
# ----------------------------
# App UI
# ----------------------------
# Create the log directory at import time too (ui_boot repeats this on each page load).
ensure_dirname(LOG_DIR)
# Bundled sample inputs offered under the UI: assets/test_images/1.png .. 13.png.
examples = [f"assets/test_images/{n}.png" for n in range(1, 14)]
# Gradio UI: layout first, then event wiring. Every handler receives and
# returns the per-session `st` state alongside its widget values.
with gr.Blocks() as demo:
    # Per-session state; ui_boot() (demo.load below) replaces it on page load.
    st = gr.State(value=new_state())
    with gr.Column(elem_id="col-container"):
        gr.HTML(
            '<img src="https://qianwen-res.oss-cn-beijing.aliyuncs.com/Qwen-Image/layered/qwen-image-layered-logo.png" '
            'alt="Qwen-Image-Layered Logo" width="600" style="display: block; margin: 0 auto;">'
        )
        persist_status = gr.Markdown(_enabled_msg)
        # Session controls: start over, show the id, save to the dataset repo.
        with gr.Row():
            new_session_btn = gr.Button("New session", variant="secondary")
            session_id_box = gr.Textbox(label="Session ID", value="", interactive=False)
            save_btn = gr.Button("Save session to Dataset repo", variant="primary")
            save_status = gr.Textbox(label="Save/Load status", value="", interactive=False)
        # Restore a previously saved session by id.
        with gr.Row():
            load_session_id = gr.Textbox(label="Load Session ID", value="", placeholder="paste Session ID here")
            load_btn = gr.Button("Load", variant="secondary")
        gr.Markdown(
            """
The text prompt is intended to describe the overall content of the input image—including elements that may be partially occluded.
It is not designed to control the semantic content of individual layers explicitly.
"""
        )
        with gr.Row():
            # Left column: input, generation settings, history and export controls.
            with gr.Column(scale=1):
                input_image = gr.Image(label="Input Image", image_mode="RGBA")
                with gr.Accordion("Advanced Settings", open=False):
                    prompt = gr.Textbox(
                        label="Prompt (Optional)",
                        placeholder="Please enter the prompt to describe the image. (Optional)",
                        value="",
                        lines=2,
                    )
                    neg_prompt = gr.Textbox(
                        label="Negative Prompt (Optional)",
                        placeholder="Please enter the negative prompt",
                        value=" ",
                        lines=2,
                    )
                    seed = gr.Slider(
                        label="Seed",
                        minimum=0,
                        maximum=MAX_SEED,
                        step=1,
                        value=0,
                    )
                    randomize_seed = gr.Checkbox(label="Randomize seed", value=True)
                    true_guidance_scale = gr.Slider(
                        label="True guidance scale",
                        minimum=1.0,
                        maximum=10.0,
                        step=0.1,
                        value=4.0,
                    )
                    num_inference_steps = gr.Slider(
                        label="Number of inference steps",
                        minimum=1,
                        maximum=100,
                        step=1,
                        value=50,
                    )
                    layer = gr.Slider(
                        label="Layers",
                        minimum=2,
                        maximum=10,
                        step=1,
                        value=7,
                    )
                    resolution = gr.Radio(
                        label="Processing resolution",
                        choices=[640, 1024],
                        value=640,
                    )
                    cfg_norm = gr.Checkbox(
                        label="Whether enable CFG normalization", value=True
                    )
                    use_en_prompt = gr.Checkbox(
                        label="Automatic caption language if no prompt provided, True for EN, False for ZH",
                        value=True,
                    )
                    # Free text; get_duration() parses it and clamps it to 20..1500 s.
                    gpu_duration = gr.Textbox(
                        label="GPU duration override (seconds, 20..1500)",
                        value="1000",
                        lines=1,
                        placeholder="e.g. 60, 120, 300, 1000, 1500",
                    )
                run_button = gr.Button("Decompose!", variant="primary")
                gr.Markdown("### History")
                history_dd = gr.Dropdown(
                    label="Nodes",
                    choices=[],
                    value=None,
                    interactive=True,
                )
                chips_md = gr.Markdown("")
                with gr.Row():
                    back_btn = gr.Button("← back to parent")
                    dup_btn = gr.Button("Duplicate node (branch)")
                with gr.Row():
                    rename_inp = gr.Textbox(label="Branch name", value="", placeholder="rename current node")
                    rename_btn = gr.Button("Rename")
                with gr.Row():
                    export_btn = gr.Button("Export selected node")
                    export_file = gr.File(label="Download PPTX")
                    export_zip_file = gr.File(label="Download ZIP")
            # Right column: layer galleries and refine controls.
            with gr.Column(scale=2):
                gr.Markdown("### Layers (current node)")
                gallery = gr.Gallery(label="Layers", columns=4, rows=1, format="png")
                with gr.Accordion("Layer picker (mini, click like Photoshop)", open=True):
                    mini_gallery = gr.Gallery(label="Pick layer to refine", columns=7, rows=1, format="png")
                    selected_layer_idx = gr.Number(label="Selected layer index (0-based)", value=0, interactive=False)
                with gr.Accordion("Refine selected layer", open=True):
                    refine_info = gr.Textbox(label="Refine status", value="", interactive=False)
                    with gr.Row():
                        sub_layers_count = gr.Slider(
                            label="Sub-layers (refine)",
                            minimum=2,
                            maximum=10,
                            step=1,
                            value=3,
                        )
                        refine_btn = gr.Button("Refine selected layer", variant="primary")
                    # Opened/closed by the Decompose and Refine handlers via gr.update(open=...).
                    refined_acc = gr.Accordion("Refined layers (latest)", open=False)
                    with refined_acc:
                        refined_gallery = gr.Gallery(label="Refined layers", columns=4, rows=1, format="png")
        # Clicking an example only loads it into input_image (run_on_click=False,
        # cache_examples=False); the user still presses "Decompose!". The
        # placeholder fn/outputs are not expected to run in this mode.
        gr.Examples(
            examples=examples,
            inputs=[input_image],
            outputs=[gallery, export_file, export_zip_file],
            fn=lambda img: ([], None, None),
            cache_examples=False,
            run_on_click=False,
        )
    # Boot / init: fresh per-visitor state, then show its id.
    demo.load(
        fn=ui_boot,
        inputs=[],
        outputs=[persist_status, st],
    ).then(
        fn=lambda st: st.get("session_id", ""),
        inputs=[st],
        outputs=[session_id_box],
    )
    # New session
    new_session_btn.click(
        fn=on_new_session,
        inputs=[st],
        outputs=[st, session_id_box, history_dd, gallery, mini_gallery, chips_md, refine_info, save_status, export_file, export_zip_file],
    )
    # Decompose (runs on GPU), then refresh the displayed session id.
    run_button.click(
        fn=on_decompose_click,
        inputs=[
            input_image,
            seed,
            randomize_seed,
            prompt,
            neg_prompt,
            true_guidance_scale,
            num_inference_steps,
            layer,
            cfg_norm,
            use_en_prompt,
            resolution,
            gpu_duration,
            st,
        ],
        outputs=[
            st,
            history_dd,
            gallery,
            mini_gallery,
            selected_layer_idx,
            chips_md,
            refined_acc,
            refined_gallery,
            export_file,
            export_zip_file,
        ],
    ).then(
        fn=lambda st: st.get("session_id", ""),
        inputs=[st],
        outputs=[session_id_box],
    )
    # History selection
    history_dd.change(
        fn=on_history_select,
        inputs=[history_dd, st],
        outputs=[st, history_dd, gallery, mini_gallery, selected_layer_idx, chips_md],
    )
    # Mini gallery click -> choose layer index
    mini_gallery.select(
        fn=on_layer_gallery_select,
        inputs=[st],
        outputs=[st, selected_layer_idx],
    )
    # Back to parent
    back_btn.click(
        fn=on_back_to_parent,
        inputs=[st],
        outputs=[st, history_dd, gallery, mini_gallery, selected_layer_idx, chips_md],
    )
    # Duplicate node
    dup_btn.click(
        fn=on_duplicate_node,
        inputs=[st],
        outputs=[st, history_dd, gallery, mini_gallery, selected_layer_idx, chips_md],
    )
    # Rename node, then clear the name box.
    rename_btn.click(
        fn=on_rename_node,
        inputs=[rename_inp, st],
        outputs=[st, history_dd],
    ).then(
        fn=lambda: "",
        inputs=[],
        outputs=[rename_inp],
    )
    # Refine selected layer (runs on GPU); no layer slider: the refine count is used instead.
    refine_btn.click(
        fn=on_refine_click,
        inputs=[
            sub_layers_count,
            seed,
            randomize_seed,
            prompt,
            neg_prompt,
            true_guidance_scale,
            num_inference_steps,
            cfg_norm,
            use_en_prompt,
            resolution,
            gpu_duration,
            st,
        ],
        outputs=[
            st,
            history_dd,
            gallery,
            mini_gallery,
            selected_layer_idx,
            chips_md,
            refine_info,
            refined_acc,
            refined_gallery,
            export_file,
            export_zip_file,
        ],
    )
    # Export current node
    export_btn.click(
        fn=on_export_selected,
        inputs=[st],
        outputs=[export_file, export_zip_file],
    )
    # Save session
    save_btn.click(
        fn=on_save_session,
        inputs=[st],
        outputs=[save_status, session_id_box],
    )
    # Load session, then show its id and clear the input box.
    load_btn.click(
        fn=on_load_session,
        inputs=[load_session_id, st],
        outputs=[st, save_status, history_dd, gallery, mini_gallery, selected_layer_idx, chips_md],
    ).then(
        fn=lambda st: st.get("session_id", ""),
        inputs=[st],
        outputs=[session_id_box],
    ).then(
        fn=lambda: "",
        inputs=[],
        outputs=[load_session_id],
    )
# Start the Gradio server when this file is executed as a script.
if __name__ == "__main__":
    demo.launch()