File size: 3,318 Bytes
0e84104
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
f87dfb8
 
 
 
 
0e84104
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
13989c9
 
9b026cf
13989c9
0e84104
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
import spaces  # must be first!
import sys
import os
import torch
from PIL import Image
import gradio as gr
from glob import glob
from contextlib import nullcontext
from pipeline import Lotus2Pipeline
from diffusers import (
    FlowMatchEulerDiscreteScheduler,
    FluxTransformer2DModel,
)
from infer import (
    load_lora_and_lcm_weights,
    process_single_image
)

from huggingface_hub import login

# Authenticate with the Hugging Face Hub so the gated FLUX.1-dev weights can
# be downloaded. Only call login() when HF_TOKEN is actually set: passing
# token=None would fall back to an interactive prompt, which cannot work in a
# headless Space. (os is already imported at the top of the file.)
_hf_token = os.getenv("HF_TOKEN")
if _hf_token:
    login(token=_hf_token)

# Module-level state shared by the Gradio callbacks below.
pipeline = None  # Lotus2Pipeline instance, populated by load_pipeline()
device = "cuda" if torch.cuda.is_available() else "cpu"
weight_dtype = torch.bfloat16
task = None  # "depth" or "normal"; set by main() before anything runs

@spaces.GPU
def load_pipeline():
    """Build the module-global Lotus2Pipeline for the current `task`.

    Loads the FLUX.1-dev flow-matching scheduler and transformer, applies the
    task-specific LoRA / local-continuity weights, and moves everything to
    `device` in `weight_dtype`. The result is stored in the global `pipeline`.
    """
    global pipeline, device, weight_dtype, task

    base_model = 'black-forest-labs/FLUX.1-dev'

    scheduler = FlowMatchEulerDiscreteScheduler.from_pretrained(
        base_model, subfolder="scheduler", num_train_timesteps=10
    )

    dit = FluxTransformer2DModel.from_pretrained(
        base_model, subfolder="transformer", revision=None, variant=None
    )
    dit.requires_grad_(False)  # inference only — no gradients needed
    dit.to(device=device, dtype=weight_dtype)
    # Patch the transformer with the task's LoRA weights and obtain the
    # accompanying local-continuity module.
    dit, lcm = load_lora_and_lcm_weights(dit, None, None, None, task)

    pipeline = Lotus2Pipeline.from_pretrained(
        base_model,
        scheduler=scheduler,
        transformer=dit,
        revision=None,
        variant=None,
        torch_dtype=weight_dtype,
    )
    pipeline.local_continuity_module = lcm
    pipeline = pipeline.to(device)

@spaces.GPU
def fn(image_path):
    """Gradio callback: run dense prediction on a single input image.

    Args:
        image_path: Filesystem path to the input image (gr.Image `filepath`).

    Returns:
        A two-element list [input PIL.Image, prediction visualization] for
        display in a gr.ImageSlider.

    Raises:
        RuntimeError: if load_pipeline() has not been called yet.
    """
    global pipeline, device, task
    # Fail loudly and clearly instead of with an AttributeError on None.
    if pipeline is None:
        raise RuntimeError("Pipeline is not initialized; call load_pipeline() first.")
    pipeline.set_progress_bar_config(disable=True)
    # The original wrapped this call in a no-op `with nullcontext():`; removed
    # as dead code. process_single_image returns a 3-tuple; only the middle
    # visualization element is shown in the slider.
    _, output_vis, _ = process_single_image(
        image_path, pipeline,
        task_name=task,
        device=device,
        num_inference_steps=10,
        process_res=1024
    )
    return [Image.open(image_path), output_vis]

def build_demo():
    """Assemble and return the Gradio Interface for the current `task`.

    Note: `task` is only read here (no `global` declaration needed); it must
    have been set by main() before this is called.
    """
    task_label = task.title()

    image_input = gr.Image(label="Image", type="filepath")
    slider_output = gr.ImageSlider(
        label=f"{task_label}",
        type="pil",
        slider_position=20,
    )

    # Example images shipped with the repo for this task.
    example_dir = f"assets/demo_examples/{task}"
    example_images = glob(f"{example_dir}/*.png") + glob(f"{example_dir}/*.jpg")

    description_html = f"""
            <strong>Please consider starring <span style="color: orange">&#9733;</span> our <a href="https://github.com/EnVision-Research/Lotus-2" target="_blank" rel="noopener noreferrer">GitHub Repo</a> if you find this demo useful! 😊</strong>
            <br>
            <strong>Current Task: </strong><strong style="color: red;">{task_label}</strong>
        """

    return gr.Interface(
        fn=fn,
        title="Lotus-2: Advancing Geometric Dense Prediction with Powerful Image Generative Model",
        description=description_html,
        inputs=[image_input],
        outputs=[slider_output],
        examples=example_images,
        examples_per_page=10
    )

def main(task_name):
    """Entry point: select the task, warm up the pipeline, serve the demo.

    Args:
        task_name: Which prediction task to serve ("depth" or "normal").
    """
    global task
    task = task_name
    load_pipeline()
    build_demo().launch()

if __name__ == "__main__":
    # Task is hard-coded for this Space; the only supported geometric
    # dense-prediction tasks are depth and surface-normal estimation.
    task_name = "normal"
    # Idiomatic membership test (`x not in ...` rather than `not x in ...`).
    if task_name not in ('depth', 'normal'):
        raise ValueError("Invalid task. Please choose from 'depth' and 'normal'.")
    main(task_name)