|
|
import os |
|
|
import gradio as gr |
|
|
|
|
|
|
|
|
import gradio_client.utils as gcu |
|
|
|
|
|
# Keep a reference to gradio_client's own schema converter so the
# boolean-tolerant wrapper can delegate every other schema to it.
orig_json_schema_to_python_type = gcu._json_schema_to_python_type
|
|
|
|
|
def _safe_json_schema_to_python_type(schema, defs):
    """Convert a JSON schema to a type string, accepting boolean schemas.

    A schema that is literally ``True`` maps to ``"Any"`` and ``False`` maps
    to ``"Never"``. Any other schema is passed, together with ``defs``, to
    the converter saved in ``orig_json_schema_to_python_type`` and that
    result is returned unchanged.
    """
    # Non-boolean schemas are the normal case: hand them straight through.
    if not isinstance(schema, bool):
        return orig_json_schema_to_python_type(schema, defs)

    if schema:
        return "Any"
    return "Never"
|
|
|
|
|
# Install the boolean-tolerant wrapper in place of gradio_client's converter.
gcu._json_schema_to_python_type = _safe_json_schema_to_python_type
|
|
|
|
|
|
|
|
from transformers import Qwen2VLForConditionalGeneration, AutoProcessor |
|
|
from qwen_vl_utils import process_vision_info |
|
|
import torch |
|
|
import base64 |
|
|
from PIL import Image, ImageDraw |
|
|
from io import BytesIO |
|
|
import re |
|
|
|
|
|
|
|
|
|
|
|
# Device that the model weights and the per-request input tensors are moved to.
device = "cuda"

# Port the web server binds to: the PORT environment variable, or 7860 when
# that variable is not set.
PORT = int(os.environ.get("PORT", "7860"))
|
|
|
|
|
|
|
|
|
|
|
# Model registry keyed by Hugging Face repo id. Each model is loaded at import
# time and moved to ``device``.
models = {
    repo_id: Qwen2VLForConditionalGeneration.from_pretrained(
        repo_id,
        dtype="auto",
        device_map=None,
    ).to(device)
    for repo_id in ("OS-Copilot/OS-Atlas-Base-7B",)
}

# Matching processors, keyed by the same repo ids as ``models``.
processors = {
    repo_id: AutoProcessor.from_pretrained(repo_id)
    for repo_id in ("OS-Copilot/OS-Atlas-Base-7B",)
}
|
|
|
|
|
|
|
|
def image_to_base64(image: Image.Image) -> str:
    """Serialise *image* as PNG and return those bytes as base64 text."""
    png_buffer = BytesIO()
    image.save(png_buffer, format="PNG")
    encoded = base64.b64encode(png_buffer.getvalue())
    return encoded.decode("utf-8")
|
|
|
|
|
def draw_bounding_boxes(image: Image.Image, bounding_boxes, outline_color="red", line_width=2):
    """Draw rectangle outlines onto *image* and return it.

    The image is modified in place; pass a copy to keep the original intact.

    Args:
        image: Image to draw on.
        bounding_boxes: Iterable of 4-item boxes ``(x0, y0, x1, y1)``.
            ``None`` or an empty iterable draws nothing.
        outline_color: Outline colour passed to ``ImageDraw.rectangle``.
        line_width: Outline width passed to ``ImageDraw.rectangle``.

    Returns:
        The same ``image`` object that was passed in.
    """
    draw = ImageDraw.Draw(image)
    for x0, y0, x1, y1 in bounding_boxes or []:
        # Pillow requires x1 >= x0 and y1 >= y0 and raises ValueError
        # otherwise, so put the corners in order before drawing. A box whose
        # corners arrive swapped is then still drawn instead of failing.
        left, right = sorted((x0, x1))
        top, bottom = sorted((y0, y1))
        draw.rectangle([left, top, right, bottom], outline=outline_color, width=line_width)
    return image
|
|
|
|
|
def rescale_bounding_boxes(bounding_boxes, original_width, original_height, scaled_width=1000, scaled_height=1000):
    """Map boxes from a ``scaled_width`` x ``scaled_height`` grid to image size.

    Args:
        bounding_boxes: Iterable of 4-item boxes ``(xmin, ymin, xmax, ymax)``.
        original_width: Target width the x coordinates are scaled to.
        original_height: Target height the y coordinates are scaled to.
        scaled_width: Width of the grid the incoming x coordinates use.
        scaled_height: Height of the grid the incoming y coordinates use.

    Returns:
        A new list of ``[xmin, ymin, xmax, ymax]`` lists, or ``[]`` when
        ``bounding_boxes`` is empty or ``None``.
    """
    if not bounding_boxes:
        return []

    x_factor = original_width / scaled_width
    y_factor = original_height / scaled_height

    rescaled = []
    for left, top, right, bottom in bounding_boxes:
        rescaled.append([left * x_factor, top * y_factor, right * x_factor, bottom * y_factor])
    return rescaled
|
|
|
|
|
|
|
|
# Spans the decoded model output uses to mark the element description and its
# bounding box.
_OBJECT_REF_RE = re.compile(r"<\|object_ref_start\|>(.*?)<\|object_ref_end\|>")
_BOX_RE = re.compile(r"<\|box_start\|>(.*?)<\|box_end\|>")

# Two leading "(x,y)" points separated by a comma. Whitespace is allowed around
# every token so "(1, 2), (3, 4)" parses the same as "(1,2),(3,4)".
_BOX_POINTS_RE = re.compile(
    r"\s*\(*\s*([+-]?\d+)\s*,\s*([+-]?\d+)\s*\)"
    r"\s*,\s*\(\s*([+-]?\d+)\s*,\s*([+-]?\d+)\s*\)"
)


def _parse_first_box(box_content):
    """Parse ``(x1,y1),(x2,y2)`` into ``[[x1, y1, x2, y2]]``, or ``[]`` if malformed."""
    match = _BOX_POINTS_RE.match(box_content or "")
    if match is None:
        return []
    return [[int(value) for value in match.groups()]]


def run_example(image, text_input, model_id="OS-Copilot/OS-Atlas-Base-7B"):
    """Ask the model where a UI element is and draw the answer on the image.

    Args:
        image: Screenshot to query; it must provide ``width``, ``height`` and
            ``copy()`` and be accepted by ``image_to_base64``.
        text_input: Command describing the element to locate.
        model_id: Key into the module-level ``models`` / ``processors``
            registries.

    Returns:
        A tuple ``(object_ref, scaled_boxes, annotated)``:
        ``object_ref`` is the text found between the object-ref markers
        (``""`` if absent), ``scaled_boxes`` is a list with at most one
        ``[xmin, ymin, xmax, ymax]`` box rescaled to the image size, and
        ``annotated`` is a copy of the image with the box drawn, or the
        input image itself when no box was parsed. If ``image`` is ``None``
        or ``text_input`` is empty/blank, returns ``("", [], image)`` without
        running the model.

    Raises:
        KeyError: If ``model_id`` is not present in the registries.
    """
    if image is None or text_input is None or str(text_input).strip() == "":
        return "", [], image

    model = models[model_id].eval()
    processor = processors[model_id]

    prompt = f'In this UI screenshot, what is the position of the element corresponding to the command "{text_input}" (with bbox)?'
    messages = [
        {
            "role": "user",
            "content": [
                {"type": "image", "image": f"data:image;base64,{image_to_base64(image)}"},
                {"type": "text", "text": prompt},
            ],
        }
    ]

    # Build the chat-formatted prompt and the vision inputs for the processor.
    chat_text = processor.apply_chat_template(messages, tokenize=False, add_generation_prompt=True)
    image_inputs, video_inputs = process_vision_info(messages)
    inputs = processor(
        text=[chat_text],
        images=image_inputs,
        videos=video_inputs,
        padding=True,
        return_tensors="pt",
    )

    # Move every tensor-like entry to the target device; leave the rest as-is.
    inputs = {k: (v.to(device) if hasattr(v, "to") else v) for k, v in inputs.items()}

    with torch.no_grad():
        generated_ids = model.generate(**inputs, max_new_tokens=128)

    # Drop the prompt tokens so only the newly generated tokens are decoded.
    generated_ids_trimmed = [out_ids[len(in_ids):] for in_ids, out_ids in zip(inputs["input_ids"], generated_ids)]
    # Special tokens are kept in the decoded text; the marker regexes above
    # search for them (presumably they are special tokens -- confirm).
    output_texts = processor.batch_decode(
        generated_ids_trimmed, skip_special_tokens=False, clean_up_tokenization_spaces=False
    )
    decoded = (output_texts[0] if output_texts else "") or ""

    object_match = _OBJECT_REF_RE.search(decoded)
    box_match = _BOX_RE.search(decoded)

    object_ref = object_match.group(1).strip() if object_match else ""
    box_content = box_match.group(1).strip() if box_match else ""

    # A missing or malformed box yields an empty list rather than an error.
    boxes = _parse_first_box(box_content)

    scaled_boxes = rescale_bounding_boxes(boxes, image.width, image.height) if boxes else []
    # Draw on a copy so the caller's image is left untouched.
    annotated = draw_bounding_boxes(image.copy(), scaled_boxes) if scaled_boxes else image

    return object_ref, scaled_boxes, annotated
|
|
|
|
|
|
|
|
# Stylesheet text for the page; its single rule targets the element with
# id "output".
css = """
#output {
    height: 500px;
    overflow: auto;
    border: 1px solid #ccc;
}
"""
|
|
|
|
|
with gr.Blocks() as demo:
    # The stylesheet is injected through an HTML component.
    gr.HTML(f"<style>{css}</style>")
    gr.Markdown("# Demo for OS-ATLAS: A Foundation Action Model For Generalist GUI Agents")

    with gr.Row():
        # First column: screenshot, model choice, prompt and submit button.
        with gr.Column():
            input_img = gr.Image(label="Input Image", type="pil")
            model_selector = gr.Dropdown(
                choices=list(models),
                label="Model",
                value="OS-Copilot/OS-Atlas-Base-7B",
            )
            text_input = gr.Textbox(label="User Prompt")
            submit_btn = gr.Button(value="Submit")
        # Second column: the three values returned by run_example.
        with gr.Column():
            model_output_text = gr.Textbox(label="Model Output Text")
            model_output_box = gr.Textbox(label="Model Output Box")
            annotated_image = gr.Image(label="Annotated Image")

    # Example rows fill the image and prompt inputs.
    gr.Examples(
        examples=[
            ["assets/web_6f93090a-81f6-489e-bb35-1a2838b18c01.png", "select search textfield"],
            ["assets/web_6f93090a-81f6-489e-bb35-1a2838b18c01.png", "switch to discussions"],
        ],
        inputs=[input_img, text_input],
    )

    # Submit runs run_example(image, prompt, model id) and routes its three
    # return values to the output components, in order.
    submit_btn.click(
        fn=run_example,
        inputs=[input_img, text_input, model_selector],
        outputs=[model_output_text, model_output_box, annotated_image],
    )
|
|
|
|
|
|
|
|
from fastapi import Request |
|
|
from starlette.responses import PlainTextResponse |
|
|
|
|
|
# Web application object exposed by the Blocks instance.
# NOTE(review): confirm that demo.launch() serves this same object; if it
# builds a fresh app, the handler registered below would not be active.
app = demo.app


@app.exception_handler(Exception)
async def _catch_all_exceptions(request: Request, exc: Exception):
    """Answer any exception routed to this handler with a plain-text HTTP 500.

    Neither the request nor the exception is inspected, so no detail of the
    failure is included in the response body.
    """
    return PlainTextResponse("Internal Server Error", status_code=500)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
# Enable the request queue, then start the server on all interfaces at PORT
# with error display and debug mode switched off.
demo.queue().launch(
    server_name="0.0.0.0",
    server_port=PORT,
    show_error=False,
    debug=False,
)