"""Model loading and inference for LEGATO OMR.""" import os import spaces import torch from legato.models import LegatoModel from transformers import AutoProcessor, GenerationConfig from image_utils import pad_to_portrait_letter hf_token = os.getenv("HF_TOKEN") MODEL_ID = "guangyangmusic/legato" device = "cuda" if torch.cuda.is_available() else "cpu" processor = AutoProcessor.from_pretrained(MODEL_ID, token=hf_token) model = LegatoModel.from_pretrained(MODEL_ID, token=hf_token, trust_remote_code=True).to(device) if device == "cuda": model = model.half() gen_config = GenerationConfig(max_length=2048, num_beams=10, repetition_penalty=1.1) @spaces.GPU def inference(image): if not image: return "" image = pad_to_portrait_letter(image) inputs = processor(images=[image], truncation=True, return_tensors="pt").to(device) with torch.no_grad(): outputs = model.generate( **inputs, generation_config=gen_config, use_model_defaults=False ) return processor.batch_decode(outputs, skip_special_tokens=True)[0].replace( "<|text|>", "text" )