guangyangmusic commited on
Commit
aa151d2
·
1 Parent(s): cd17632

chore: reorganize codebase

Browse files
Files changed (5) hide show
  1. abc_utils.py +68 -0
  2. app.py +10 -126
  3. config.py +19 -0
  4. image_utils.py +16 -0
  5. inference.py +35 -0
abc_utils.py ADDED
@@ -0,0 +1,68 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """ABC notation utilities: MusicXML conversion and HTML visualization."""
2
+ import html as html_module
3
+ import json
4
+ import os
5
+ import subprocess
6
+ import tempfile
7
+
8
+ from config import ABC2XML_PATH, APP_DIR
9
+
10
+ def abc_to_musicxml_file(abc: str):
11
+ """Convert ABC to MusicXML using abc2xml.py; return file path for download or None."""
12
+ if not (abc or "").strip():
13
+ return None
14
+ try:
15
+ result = subprocess.run(
16
+ [os.environ.get("PYTHON", "python"), ABC2XML_PATH, "-"],
17
+ input=(abc or "").strip().encode("utf-8"),
18
+ capture_output=True,
19
+ cwd=APP_DIR,
20
+ timeout=30,
21
+ )
22
+ if result.returncode != 0 or not result.stdout:
23
+ return None
24
+ xml_bytes = result.stdout
25
+ if isinstance(xml_bytes, bytes):
26
+ xml_str = xml_bytes.decode("utf-8", errors="replace")
27
+ else:
28
+ xml_str = xml_bytes
29
+ tmpdir = tempfile.mkdtemp(prefix="musicxml_")
30
+ score_path = os.path.join(tmpdir, "score.musicxml")
31
+ try:
32
+ with open(score_path, "w", encoding="utf-8") as f:
33
+ f.write(xml_str)
34
+ return score_path
35
+ except Exception:
36
+ try:
37
+ os.unlink(score_path)
38
+ except Exception:
39
+ pass
40
+ try:
41
+ os.rmdir(tmpdir)
42
+ except Exception:
43
+ pass
44
+ return None
45
+ except Exception:
46
+ return None
47
+
48
+
49
+ def abc_viz_html(abc: str) -> str:
50
+ """Generate HTML with ABCJS for rendering ABC notation in Gradio."""
51
+ viz_abc = abc or ""
52
+ data_attr = html_module.escape(json.dumps(viz_abc), quote=True)
53
+ # Gradio strips <script> in gr.HTML; use iframe srcdoc so ABCJS runs inside the frame.
54
+ inner = (
55
+ '<!DOCTYPE html><html><head><meta charset="utf-8">'
56
+ '<style>body{overflow:auto;margin:0;} #abc-viz{width:100%;}</style></head><body>'
57
+ '<div id="abc-viz" data-abc="' + data_attr + '"></div>'
58
+ '<script src="https://cdnjs.cloudflare.com/ajax/libs/abcjs/6.4.0/abcjs-basic-min.js"><\x2fscript>'
59
+ '<script>'
60
+ '(function(){ var el=document.getElementById("abc-viz"); if(!el) return; '
61
+ 'var run=function(){ try { var abc=JSON.parse(el.getAttribute("data-abc")); '
62
+ 'if(typeof ABCJS!=="undefined"&&abc) ABCJS.renderAbc("abc-viz",abc,{responsive:"resize"}); } catch(e){ el.innerHTML="<span>Invalid ABC</span>"; } }; '
63
+ 'if(typeof ABCJS!=="undefined") run(); else { var s=document.createElement("script"); '
64
+ 's.src="https://cdnjs.cloudflare.com/ajax/libs/abcjs/6.4.0/abcjs-basic-min.js"; s.onload=run; document.head.appendChild(s); } })();'
65
+ '<\x2fscript></body></html>'
66
+ )
67
+ srcdoc_escaped = inner.replace("&", "&amp;").replace('"', "&quot;")
68
+ return '<iframe sandbox="allow-scripts" title="ABC notation" style="width:100%;height:60vh;max-height:400px;display:block;" srcdoc="' + srcdoc_escaped + '"></iframe>'
app.py CHANGED
@@ -1,125 +1,9 @@
1
- import spaces
2
  import gradio as gr
3
- from legato.models import *
4
- from transformers import AutoProcessor, GenerationConfig
5
- import torch
6
- import os
7
- import html as html_module
8
- import json
9
- import subprocess
10
- import tempfile
11
- from PIL import Image
12
-
13
- _APP_DIR = os.path.dirname(os.path.abspath(__file__))
14
- _ABC2XML = os.path.join(_APP_DIR, "abc2xml.py")
15
-
16
- BIBTEX = """@misc{yang2025legatolargescaleendtoendgeneralizable,
17
- title={LEGATO: Large-scale End-to-end Generalizable Approach to Typeset OMR},
18
- author={Guang Yang and Victoria Ebert and Nazif Tamer and Brian Siyuan Zheng and Luiza Pozzobon and Noah A. Smith},
19
- year={2025},
20
- eprint={2506.19065},
21
- archivePrefix={arXiv},
22
- primaryClass={cs.CV},
23
- url={https://arxiv.org/abs/2506.19065},
24
- }"""
25
- # Portrait letter aspect: 8.5" × 11" → width/height
26
- LETTER_ASPECT = 8.5 / 11
27
-
28
- def _pad_to_portrait_letter(pil_image: Image.Image) -> Image.Image:
29
- """If aspect ratio is narrower than letter, pad at the bottom to match letter aspect."""
30
- w, h = pil_image.size
31
- if w / h < LETTER_ASPECT:
32
- return pil_image
33
- new_h = int(round(w / LETTER_ASPECT))
34
- canvas = Image.new("RGB", (w, new_h), (255, 255, 255))
35
- if pil_image.mode != "RGB":
36
- pil_image = pil_image.convert("RGB")
37
- canvas.paste(pil_image, (0, 0))
38
- return canvas
39
-
40
- hf_token = os.getenv("HF_TOKEN")
41
-
42
- model_id = "guangyangmusic/legato"
43
- device = "cuda" if torch.cuda.is_available() else "cpu"
44
-
45
- processor = AutoProcessor.from_pretrained(model_id, token=hf_token)
46
- model = LegatoModel.from_pretrained(model_id, token=hf_token, trust_remote_code=True).to(device)
47
-
48
- if device == "cuda":
49
- model = model.half()
50
-
51
- gen_config = GenerationConfig(max_length=2048, num_beams=10, repetition_penalty=1.1)
52
-
53
-
54
- def _abc_to_musicxml_file(abc: str):
55
- """Convert ABC to MusicXML using abc2xml.py -; return file path for download or None."""
56
- if not (abc or "").strip():
57
- return None
58
- try:
59
- result = subprocess.run(
60
- [os.environ.get("PYTHON", "python"), _ABC2XML, "-"],
61
- input=(abc or "").strip().encode("utf-8"),
62
- capture_output=True,
63
- cwd=_APP_DIR,
64
- timeout=30,
65
- )
66
- if result.returncode != 0 or not result.stdout:
67
- return None
68
- xml_bytes = result.stdout
69
- if isinstance(xml_bytes, bytes):
70
- xml_str = xml_bytes.decode("utf-8", errors="replace")
71
- else:
72
- xml_str = xml_bytes
73
- tmpdir = tempfile.mkdtemp(prefix="musicxml_")
74
- score_path = os.path.join(tmpdir, "score.musicxml")
75
- try:
76
- with open(score_path, "w", encoding="utf-8") as f:
77
- f.write(xml_str)
78
- return score_path
79
- except Exception:
80
- try:
81
- os.unlink(score_path)
82
- except Exception:
83
- pass
84
- try:
85
- os.rmdir(tmpdir)
86
- except Exception:
87
- pass
88
- return None
89
- except Exception:
90
- return None
91
-
92
-
93
- def _abc_viz_html(abc: str) -> str:
94
- viz_abc = abc or ""
95
- data_attr = html_module.escape(json.dumps(viz_abc), quote=True)
96
- # Gradio strips <script> in gr.HTML; use iframe srcdoc so ABCJS runs inside the frame.
97
- inner = (
98
- '<!DOCTYPE html><html><head><meta charset="utf-8">'
99
- '<style>body{overflow:auto;margin:0;} #abc-viz{width:100%;}</style></head><body>'
100
- '<div id="abc-viz" data-abc="' + data_attr + '"></div>'
101
- '<script src="https://cdnjs.cloudflare.com/ajax/libs/abcjs/6.4.0/abcjs-basic-min.js"><\x2fscript>'
102
- '<script>'
103
- '(function(){ var el=document.getElementById("abc-viz"); if(!el) return; '
104
- 'var run=function(){ try { var abc=JSON.parse(el.getAttribute("data-abc")); '
105
- 'if(typeof ABCJS!=="undefined"&&abc) ABCJS.renderAbc("abc-viz",abc,{responsive:"resize"}); } catch(e){ el.innerHTML="<span>Invalid ABC</span>"; } }; '
106
- 'if(typeof ABCJS!=="undefined") run(); else { var s=document.createElement("script"); '
107
- 's.src="https://cdnjs.cloudflare.com/ajax/libs/abcjs/6.4.0/abcjs-basic-min.js"; s.onload=run; document.head.appendChild(s); } })();'
108
- '<\x2fscript></body></html>'
109
- )
110
- srcdoc_escaped = inner.replace("&", "&amp;").replace('"', "&quot;")
111
- return '<iframe sandbox="allow-scripts" title="ABC notation" style="width:100%;height:60vh;max-height:400px;display:block;" srcdoc="' + srcdoc_escaped + '"></iframe>'
112
-
113
-
114
- @spaces.GPU
115
- def inference(image):
116
- if not image: return ""
117
- image = _pad_to_portrait_letter(image)
118
- inputs = processor(images=[image], truncation=True, return_tensors='pt').to(device)
119
- with torch.no_grad():
120
- outputs = model.generate(**inputs, generation_config=gen_config, use_model_defaults=False)
121
- return processor.batch_decode(outputs, skip_special_tokens=True)[0].replace("<|text|>", "text")
122
 
 
 
 
123
 
124
  with gr.Blocks(theme=gr.themes.Soft(), title="LEGATO OMR Demo") as demo:
125
  gr.Markdown("""
@@ -152,25 +36,25 @@ with gr.Blocks(theme=gr.themes.Soft(), title="LEGATO OMR Demo") as demo:
152
  with gr.Row():
153
  out = gr.Textbox(label="📝 ABC transcription", lines=10, buttons=["copy"])
154
  with gr.Accordion("🎵 Rendered ABC notation", open=True):
155
- html_viz = gr.HTML(label=None, value=_abc_viz_html(""))
156
  with gr.Row():
157
  btn = gr.Button("▶️ Run LEGATO")
158
  dl_musicxml = gr.DownloadButton("⬇️ Download MusicXML", variant="secondary")
159
  btn.click(inference, inp, [out])
160
- out.change(lambda x: _abc_viz_html(x or ""), inputs=[out], outputs=[html_viz])
161
  dl_musicxml.click(
162
- _abc_to_musicxml_file,
163
  inputs=[out],
164
  outputs=[dl_musicxml],
165
  )
166
 
167
  gr.Markdown("---")
168
  gr.Textbox(
169
- value=BIBTEX,
170
  label="Citation (BibTeX)",
171
- lines=9,
172
  interactive=False,
173
  buttons=["copy"],
174
  )
175
 
176
- demo.launch()
 
1
+ import spaces # Must be before any CUDA/torch imports
2
  import gradio as gr
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
3
 
4
+ import abc_utils
5
+ import config
6
+ from inference import inference
7
 
8
  with gr.Blocks(theme=gr.themes.Soft(), title="LEGATO OMR Demo") as demo:
9
  gr.Markdown("""
 
36
  with gr.Row():
37
  out = gr.Textbox(label="📝 ABC transcription", lines=10, buttons=["copy"])
38
  with gr.Accordion("🎵 Rendered ABC notation", open=True):
39
+ html_viz = gr.HTML(label=None, value=abc_utils.abc_viz_html(""))
40
  with gr.Row():
41
  btn = gr.Button("▶️ Run LEGATO")
42
  dl_musicxml = gr.DownloadButton("⬇️ Download MusicXML", variant="secondary")
43
  btn.click(inference, inp, [out])
44
+ out.change(lambda x: abc_utils.abc_viz_html(x or ""), inputs=[out], outputs=[html_viz])
45
  dl_musicxml.click(
46
+ abc_utils.abc_to_musicxml_file,
47
  inputs=[out],
48
  outputs=[dl_musicxml],
49
  )
50
 
51
  gr.Markdown("---")
52
  gr.Textbox(
53
+ value=config.BIBTEX,
54
  label="Citation (BibTeX)",
55
+ lines=8,
56
  interactive=False,
57
  buttons=["copy"],
58
  )
59
 
60
+ demo.launch()
config.py ADDED
@@ -0,0 +1,19 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """Shared configuration and constants for the LEGATO OMR app."""
2
+ import os
3
+ import torch
4
+
5
+ APP_DIR = os.path.dirname(os.path.abspath(__file__))
6
+ ABC2XML_PATH = os.path.join(APP_DIR, "abc2xml.py")
7
+
8
+ BIBTEX = """@misc{yang2025legatolargescaleendtoendgeneralizable,
9
+ title={LEGATO: Large-scale End-to-end Generalizable Approach to Typeset OMR},
10
+ author={Guang Yang and Victoria Ebert and Nazif Tamer and Brian Siyuan Zheng and Luiza Pozzobon and Noah A. Smith},
11
+ year={2025},
12
+ eprint={2506.19065},
13
+ archivePrefix={arXiv},
14
+ primaryClass={cs.CV},
15
+ url={https://arxiv.org/abs/2506.19065},
16
+ }"""
17
+
18
+ # Portrait letter aspect: 8.5" × 11" → width/height
19
+ LETTER_ASPECT = 8.5 / 11
image_utils.py ADDED
@@ -0,0 +1,16 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """Image preprocessing utilities for LEGATO OMR."""
2
+ from PIL import Image
3
+ from config import LETTER_ASPECT
4
+
5
+
6
+ def pad_to_portrait_letter(pil_image: Image.Image) -> Image.Image:
7
+ """If aspect ratio is wider than letter, pad at the bottom to match letter aspect."""
8
+ w, h = pil_image.size
9
+ if w / h < LETTER_ASPECT:
10
+ return pil_image
11
+ new_h = int(round(w / LETTER_ASPECT))
12
+ canvas = Image.new("RGB", (w, new_h), (255, 255, 255))
13
+ if pil_image.mode != "RGB":
14
+ pil_image = pil_image.convert("RGB")
15
+ canvas.paste(pil_image, (0, 0))
16
+ return canvas
inference.py ADDED
@@ -0,0 +1,35 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """Model loading and inference for LEGATO OMR."""
2
+ import os
3
+ import spaces
4
+ import torch
5
+ from legato.models import LegatoModel
6
+ from transformers import AutoProcessor, GenerationConfig
7
+
8
+ from image_utils import pad_to_portrait_letter
9
+
10
+ hf_token = os.getenv("HF_TOKEN")
11
+ MODEL_ID = "guangyangmusic/legato"
12
+ device = "cuda" if torch.cuda.is_available() else "cpu"
13
+
14
+ processor = AutoProcessor.from_pretrained(MODEL_ID, token=hf_token)
15
+ model = LegatoModel.from_pretrained(MODEL_ID, token=hf_token, trust_remote_code=True).to(device)
16
+
17
+ if device == "cuda":
18
+ model = model.half()
19
+
20
+ gen_config = GenerationConfig(max_length=2048, num_beams=10, repetition_penalty=1.1)
21
+
22
+
23
+ @spaces.GPU
24
+ def inference(image):
25
+ if not image:
26
+ return ""
27
+ image = pad_to_portrait_letter(image)
28
+ inputs = processor(images=[image], truncation=True, return_tensors="pt").to(device)
29
+ with torch.no_grad():
30
+ outputs = model.generate(
31
+ **inputs, generation_config=gen_config, use_model_defaults=False
32
+ )
33
+ return processor.batch_decode(outputs, skip_special_tokens=True)[0].replace(
34
+ "<|text|>", "text"
35
+ )