EJ-L committed
Commit e0dfa99 · 1 Parent(s): d989c20

add app file

Files changed (2)
  1. app.py +170 -0
  2. requirements.txt +8 -0
app.py ADDED
@@ -0,0 +1,170 @@
+ import gradio as gr
+ import spaces
+ from transformers import Qwen2VLForConditionalGeneration, AutoTokenizer, AutoProcessor
+ from qwen_vl_utils import process_vision_info
+ import torch
+ import base64
+ from PIL import Image, ImageDraw
+ from io import BytesIO
+ import re
+
+
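+ # The model and processor are loaded once at startup and kept in dicts keyed
+ # by model id, so the UI dropdown can select them by name.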
+ models = {
+     "OS-Copilot/OS-Atlas-Base-7B": Qwen2VLForConditionalGeneration.from_pretrained("OS-Copilot/OS-Atlas-Base-7B", torch_dtype="auto", device_map="auto"),
+ }
+
+ processors = {
+     "OS-Copilot/OS-Atlas-Base-7B": AutoProcessor.from_pretrained("OS-Copilot/OS-Atlas-Base-7B")
+ }
+
+
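+ # Encode a PIL image as a base64 PNG string so it can be passed inline to the
+ # chat template as a data URI.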
+ def image_to_base64(image):
+     buffered = BytesIO()
+     image.save(buffered, format="PNG")
+     img_str = base64.b64encode(buffered.getvalue()).decode("utf-8")
+     return img_str
+
+
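+ # Draw each (xmin, ymin, xmax, ymax) box onto the image in place.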
+ def draw_bounding_boxes(image, bounding_boxes, outline_color="red", line_width=2):
+     draw = ImageDraw.Draw(image)
+     for box in bounding_boxes:
+         xmin, ymin, xmax, ymax = box
+         draw.rectangle([xmin, ymin, xmax, ymax], outline=outline_color, width=line_width)
+     return image
+
+
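+ # The model reports coordinates on a fixed 1000x1000 grid; map them back to
+ # the original image resolution before drawing.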
+ def rescale_bounding_boxes(bounding_boxes, original_width, original_height, scaled_width=1000, scaled_height=1000):
+     x_scale = original_width / scaled_width
+     y_scale = original_height / scaled_height
+     rescaled_boxes = []
+     for box in bounding_boxes:
+         xmin, ymin, xmax, ymax = box
+         rescaled_box = [
+             xmin * x_scale,
+             ymin * y_scale,
+             xmax * x_scale,
+             ymax * y_scale
+         ]
+         rescaled_boxes.append(rescaled_box)
+     return rescaled_boxes
+
+
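+ # Run one grounding query: build the chat prompt, generate, then parse the
+ # <|object_ref_*|> and <|box_*|> special tokens out of the raw output.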
+ @spaces.GPU
+ def run_example(image, text_input, model_id="OS-Copilot/OS-Atlas-Base-7B"):
+     model = models[model_id].eval()
+     processor = processors[model_id]
+     prompt = f"In this UI screenshot, what is the position of the element corresponding to the command \"{text_input}\" (with bbox)?"
+     messages = [
+         {
+             "role": "user",
+             "content": [
+                 {"type": "image", "image": f"data:image;base64,{image_to_base64(image)}"},
+                 {"type": "text", "text": prompt},
+             ],
+         }
+     ]
+
+     text = processor.apply_chat_template(
+         messages, tokenize=False, add_generation_prompt=True
+     )
+     image_inputs, video_inputs = process_vision_info(messages)
+     inputs = processor(
+         text=[text],
+         images=image_inputs,
+         videos=video_inputs,
+         padding=True,
+         return_tensors="pt",
+     )
+     inputs = inputs.to("cuda")
+
+     generated_ids = model.generate(**inputs, max_new_tokens=128)
+     generated_ids_trimmed = [
+         out_ids[len(in_ids):] for in_ids, out_ids in zip(inputs.input_ids, generated_ids)
+     ]
+     output_text = processor.batch_decode(
+         generated_ids_trimmed, skip_special_tokens=False, clean_up_tokenization_spaces=False
+     )
+     print(output_text)
+     text = output_text[0]
+
+     object_ref_pattern = r"<\|object_ref_start\|>(.*?)<\|object_ref_end\|>"
+     box_pattern = r"<\|box_start\|>(.*?)<\|box_end\|>"
+
+     object_ref = re.search(object_ref_pattern, text).group(1)
+     box_content = re.search(box_pattern, text).group(1)
+
+     boxes = [tuple(map(int, pair.strip("()").split(','))) for pair in box_content.split("),(")]
+     boxes = [[boxes[0][0], boxes[0][1], boxes[1][0], boxes[1][1]]]
+
+     scaled_boxes = rescale_bounding_boxes(boxes, image.width, image.height)
+     return object_ref, scaled_boxes, draw_bounding_boxes(image, scaled_boxes)
+
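+ # Gradio UI: a screenshot and an instruction go in; the parsed element
+ # reference, the rescaled box, and an annotated image come out.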
+ css = """
+ #output {
+     height: 500px;
+     overflow: auto;
+     border: 1px solid #ccc;
+ }
+ """
+ with gr.Blocks(css=css) as demo:
+     gr.Markdown(
+         """
+         # Demo for OS-ATLAS: A Foundation Action Model For Generalist GUI Agents
+         """)
+     with gr.Row():
+         with gr.Column():
+             input_img = gr.Image(label="Input Image", type="pil")
+             model_selector = gr.Dropdown(choices=list(models.keys()), label="Model", value="OS-Copilot/OS-Atlas-Base-7B")
+             text_input = gr.Textbox(label="User Prompt")
+             submit_btn = gr.Button(value="Submit")
+         with gr.Column():
+             model_output_text = gr.Textbox(label="Model Output Text")
+             model_output_box = gr.Textbox(label="Model Output Box")
+             annotated_image = gr.Image(label="Annotated Image")
+
+     gr.Examples(
+         examples=[
+             ["assets/web_6f93090a-81f6-489e-bb35-1a2838b18c01.png", "select search textfield"],
+             ["assets/web_6f93090a-81f6-489e-bb35-1a2838b18c01.png", "switch to discussions"],
+         ],
+         inputs=[input_img, text_input],
+         outputs=[model_output_text, model_output_box, annotated_image],
+         fn=run_example,
+         cache_examples=True,
+         label="Try examples"
+     )
+
+     submit_btn.click(run_example, [input_img, text_input, model_selector], [model_output_text, model_output_box, annotated_image])
+
+ demo.launch(debug=True)
requirements.txt ADDED
@@ -0,0 +1,8 @@
+ numpy==1.24.4
+ Pillow==10.3.0
+ Requests==2.31.0
+ torch
+ torchvision
+ transformers
+ accelerate==0.30.0
+ qwen-vl-utils
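
For reference, a minimal sketch of what the parsing step in run_example does to a typical response; the response string below is a made-up example for illustration, not actual model output:

import re

# Hypothetical OS-Atlas-style output (illustrative only); the app expects the
# element reference and the two box corners wrapped in these special tokens.
text = "<|object_ref_start|>search textfield<|object_ref_end|><|box_start|>(412,86),(587,122)<|box_end|>"

object_ref = re.search(r"<\|object_ref_start\|>(.*?)<\|object_ref_end\|>", text).group(1)
box_content = re.search(r"<\|box_start\|>(.*?)<\|box_end\|>", text).group(1)

# "(412,86),(587,122)" -> [(412, 86), (587, 122)] -> [[412, 86, 587, 122]]
boxes = [tuple(map(int, pair.strip("()").split(","))) for pair in box_content.split("),(")]
boxes = [[boxes[0][0], boxes[0][1], boxes[1][0], boxes[1][1]]]

print(object_ref)  # search textfield
print(boxes)       # [[412, 86, 587, 122]]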