import os

# Set before torch is imported (below) so the allocator option is already in
# the environment when CUDA is first touched.
os.environ.setdefault("PYTORCH_CUDA_ALLOC_CONF", "expandable_segments:True")

import re
from typing import Optional, Tuple

import spaces  # MUST come before torch / any CUDA-touching import
import torch
import gradio as gr
from PIL import Image, ImageDraw
from transformers import AutoModelForImageTextToText, AutoProcessor

MODEL_ID = "inclusionAI/VISTA-9B"

# The processor and model are loaded once at import time and shared by every
# request handled by `predict`.
processor = AutoProcessor.from_pretrained(MODEL_ID, trust_remote_code=True)
model = AutoModelForImageTextToText.from_pretrained(
    MODEL_ID,
    torch_dtype=torch.bfloat16,
    trust_remote_code=True,
).to("cuda").eval()

COORD_SCALE = 1000  # VISTA outputs normalized coordinates in a [0, 1000] frame

# Matches "[x,y]" with non-negative integers; whitespace is tolerated around
# the comma and just inside the brackets (e.g. "[ 512, 300 ]").
_COORD_PATTERN = re.compile(r"\[\s*(\d+)\s*,\s*(\d+)\s*\]")


def parse_coordinate(response_text: str) -> Optional[Tuple[int, int]]:
    """Extract (x, y) normalized coordinates from the model's response.

    The model outputs in the format [x,y] where x, y are in [0, 1000].
    Only the first bracketed integer pair in the text is used.

    Args:
        response_text: Decoded text produced by the model.

    Returns:
        An ``(x, y)`` tuple of ints, or ``None`` when the text is empty or
        contains no ``[x,y]`` pair.
    """
    if not response_text:
        return None
    match = _COORD_PATTERN.search(response_text)
    if match is None:
        return None
    return int(match.group(1)), int(match.group(2))


def draw_marker(image: Image.Image, x: int, y: int) -> Image.Image:
    """Draw a crosshair marker and circle at the predicted pixel coordinate.

    Args:
        image: Image to annotate; it is copied, not modified in place.
        x: Horizontal pixel coordinate of the marker centre.
        y: Vertical pixel coordinate of the marker centre.

    Returns:
        A copy of ``image`` with a red circle outline and crosshair drawn on it.
    """
    annotated = image.copy()
    draw = ImageDraw.Draw(annotated)
    # Scale the marker with the image so it stays visible on large screenshots,
    # but never let it shrink below 10 px.
    radius = max(10, min(image.width, image.height) // 50)

    # Circle outline (not filled, so the target underneath stays visible).
    draw.ellipse(
        [x - radius, y - radius, x + radius, y + radius],
        outline="red",
        width=4,
    )

    # Crosshair lines extend 10 px beyond the circle on each side.
    line_len = radius + 10
    draw.line([(x - line_len, y), (x + line_len, y)], fill="red", width=3)
    draw.line([(x, y - line_len), (x, y + line_len)], fill="red", width=3)
    return annotated


@spaces.GPU(duration=120)
def predict(image: Image.Image, instruction: str):
    """Predict the click coordinate for a target element on a GUI screenshot.

    Args:
        image: A GUI screenshot (PNG/JPG).
        instruction: Natural-language description of the target element to click.

    Returns:
        A tuple of (annotated image, coordinate string, raw model output).
    """
    if image is None:
        return None, "Please upload an image.", ""
    # Treat a missing (None) instruction like an empty one instead of crashing
    # on `.strip()`, and drop surrounding whitespace/newlines so they do not
    # end up in the middle of the prompt sentence.
    instruction = (instruction or "").strip()
    if not instruction:
        return None, "Please enter an instruction.", ""

    image = image.convert("RGB")
    w, h = image.size

    prompt = (
        "Output the center point of the position corresponding to the instruction: "
        f"{instruction}. The output should just be the coordinates of a point, "
        "in the format [x,y]."
    )

    messages = [
        {
            "role": "user",
            "content": [
                {"type": "image", "image": image},
                {"type": "text", "text": prompt},
            ],
        }
    ]

    text = processor.apply_chat_template(
        messages, tokenize=False, add_generation_prompt=True
    )
    inputs = processor(
        text=[text],
        images=[image],
        padding=True,
        return_tensors="pt",
    ).to("cuda")

    with torch.no_grad():
        generated_ids = model.generate(
            **inputs,
            max_new_tokens=32,
            do_sample=False,
        )

    # Each generated sequence starts with the prompt tokens; slice them off so
    # only the newly generated tokens are decoded.
    generated_ids_trimmed = [
        out_ids[len(in_ids):]
        for in_ids, out_ids in zip(inputs.input_ids, generated_ids)
    ]
    output_text = processor.batch_decode(
        generated_ids_trimmed,
        skip_special_tokens=True,
    )[0].strip()

    coord = parse_coordinate(output_text)
    if coord is None:
        return image, "Could not parse coordinates from model output.", output_text

    norm_x, norm_y = coord
    # Scale from [0, 1000] to actual image dimensions
    px = int(norm_x / COORD_SCALE * w)
    py = int(norm_y / COORD_SCALE * h)
    # Clamp to image bounds (a normalized value of 1000 would otherwise land
    # one pixel past the last row/column).
    px = max(0, min(px, w - 1))
    py = max(0, min(py, h - 1))

    annotated = draw_marker(image, px, py)
    coord_str = f"({px}, {py})  [normalized: [{norm_x},{norm_y}]]"
    return annotated, coord_str, output_text


CSS = """
#col-container { max-width: 1100px; margin: 0 auto; }
.dark .gradio-container { color: var(--body-text-color); }
"""

with gr.Blocks() as demo:
    gr.Markdown(
        "# 🎯 VISTA: GUI Grounding\n"
        "Upload a GUI screenshot and describe the element you want to click. "
        "VISTA-9B predicts the click coordinate and marks it on the image.\n\n"
        "Based on [VISTA: View-Consistent Self-Verified Training for GUI Grounding](https://arxiv.org/abs/2606.14579) | "
        "[Model Card](https://huggingface.co/inclusionAI/VISTA-9B) | [GitHub](https://github.com/ZJUSCL/VISTA)"
    )
    with gr.Row(elem_id="col-container"):
        with gr.Column(scale=1):
            input_img = gr.Image(label="GUI Screenshot", type="pil")
            instruction = gr.Textbox(
                label="Instruction",
                placeholder="e.g. click the search button",
                lines=2,
            )
            run_btn = gr.Button("Predict Coordinate", variant="primary")
        with gr.Column(scale=1):
            output_img = gr.Image(label="Predicted Click Location")
            coord_text = gr.Textbox(label="Predicted Coordinate (x, y)", interactive=False)
            with gr.Accordion("Raw Model Output", open=False):
                raw_output = gr.Textbox(label="Model Response", lines=4, interactive=False)

    # Example rows run through `predict`; with lazy caching each result is
    # computed the first time the example is used rather than at startup.
    gr.Examples(
        examples=[
            ["assets/web_6f93090a-81f6-489e-bb35-1a2838b18c01.png", "click the search box"],
            ["assets/web_6f93090a-81f6-489e-bb35-1a2838b18c01.png", "switch to discussions"],
        ],
        inputs=[input_img, instruction],
        outputs=[output_img, coord_text, raw_output],
        fn=predict,
        cache_examples=True,
        cache_mode="lazy",
    )

    run_btn.click(
        fn=predict,
        inputs=[input_img, instruction],
        outputs=[output_img, coord_text, raw_output],
    )

if __name__ == "__main__":
    demo.launch(mcp_server=True, theme=gr.themes.Citrus(), css=CSS)