smcleod commited on
Commit
6a1108c
·
verified ·
1 Parent(s): b688ed5

Add files using upload-large-folder tool

Browse files
.gitattributes CHANGED
@@ -33,3 +33,9 @@ saved_model/**/* filter=lfs diff=lfs merge=lfs -text
33
  *.zip filter=lfs diff=lfs merge=lfs -text
34
  *.zst filter=lfs diff=lfs merge=lfs -text
35
  *tfevents* filter=lfs diff=lfs merge=lfs -text
 
 
 
 
 
 
 
33
  *.zip filter=lfs diff=lfs merge=lfs -text
34
  *.zst filter=lfs diff=lfs merge=lfs -text
35
  *tfevents* filter=lfs diff=lfs merge=lfs -text
36
+ encoder.onnx_data filter=lfs diff=lfs merge=lfs -text
37
+ decode_step_int8.onnx_data filter=lfs diff=lfs merge=lfs -text
38
+ prompt_encode.onnx_data filter=lfs diff=lfs merge=lfs -text
39
+ encoder_int8.onnx_data filter=lfs diff=lfs merge=lfs -text
40
+ prompt_encode_int8.onnx_data filter=lfs diff=lfs merge=lfs -text
41
+ decode_step.onnx_data filter=lfs diff=lfs merge=lfs -text
LICENSE ADDED
@@ -0,0 +1,202 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+
2
+ Apache License
3
+ Version 2.0, January 2004
4
+ http://www.apache.org/licenses/
5
+
6
+ TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION
7
+
8
+ 1. Definitions.
9
+
10
+ "License" shall mean the terms and conditions for use, reproduction,
11
+ and distribution as defined by Sections 1 through 9 of this document.
12
+
13
+ "Licensor" shall mean the copyright owner or entity authorized by
14
+ the copyright owner that is granting the License.
15
+
16
+ "Legal Entity" shall mean the union of the acting entity and all
17
+ other entities that control, are controlled by, or are under common
18
+ control with that entity. For the purposes of this definition,
19
+ "control" means (i) the power, direct or indirect, to cause the
20
+ direction or management of such entity, whether by contract or
21
+ otherwise, or (ii) ownership of fifty percent (50%) or more of the
22
+ outstanding shares, or (iii) beneficial ownership of such entity.
23
+
24
+ "You" (or "Your") shall mean an individual or Legal Entity
25
+ exercising permissions granted by this License.
26
+
27
+ "Source" form shall mean the preferred form for making modifications,
28
+ including but not limited to software source code, documentation
29
+ source, and configuration files.
30
+
31
+ "Object" form shall mean any form resulting from mechanical
32
+ transformation or translation of a Source form, including but
33
+ not limited to compiled object code, generated documentation,
34
+ and conversions to other media types.
35
+
36
+ "Work" shall mean the work of authorship, whether in Source or
37
+ Object form, made available under the License, as indicated by a
38
+ copyright notice that is included in or attached to the work
39
+ (an example is provided in the Appendix below).
40
+
41
+ "Derivative Works" shall mean any work, whether in Source or Object
42
+ form, that is based on (or derived from) the Work and for which the
43
+ editorial revisions, annotations, elaborations, or other modifications
44
+ represent, as a whole, an original work of authorship. For the purposes
45
+ of this License, Derivative Works shall not include works that remain
46
+ separable from, or merely link (or bind by name) to the interfaces of,
47
+ the Work and Derivative Works thereof.
48
+
49
+ "Contribution" shall mean any work of authorship, including
50
+ the original version of the Work and any modifications or additions
51
+ to that Work or Derivative Works thereof, that is intentionally
52
+ submitted to Licensor for inclusion in the Work by the copyright owner
53
+ or by an individual or Legal Entity authorized to submit on behalf of
54
+ the copyright owner. For the purposes of this definition, "submitted"
55
+ means any form of electronic, verbal, or written communication sent
56
+ to the Licensor or its representatives, including but not limited to
57
+ communication on electronic mailing lists, source code control systems,
58
+ and issue tracking systems that are managed by, or on behalf of, the
59
+ Licensor for the purpose of discussing and improving the Work, but
60
+ excluding communication that is conspicuously marked or otherwise
61
+ designated in writing by the copyright owner as "Not a Contribution."
62
+
63
+ "Contributor" shall mean Licensor and any individual or Legal Entity
64
+ on behalf of whom a Contribution has been received by Licensor and
65
+ subsequently incorporated within the Work.
66
+
67
+ 2. Grant of Copyright License. Subject to the terms and conditions of
68
+ this License, each Contributor hereby grants to You a perpetual,
69
+ worldwide, non-exclusive, no-charge, royalty-free, irrevocable
70
+ copyright license to reproduce, prepare Derivative Works of,
71
+ publicly display, publicly perform, sublicense, and distribute the
72
+ Work and such Derivative Works in Source or Object form.
73
+
74
+ 3. Grant of Patent License. Subject to the terms and conditions of
75
+ this License, each Contributor hereby grants to You a perpetual,
76
+ worldwide, non-exclusive, no-charge, royalty-free, irrevocable
77
+ (except as stated in this section) patent license to make, have made,
78
+ use, offer to sell, sell, import, and otherwise transfer the Work,
79
+ where such license applies only to those patent claims licensable
80
+ by such Contributor that are necessarily infringed by their
81
+ Contribution(s) alone or by combination of their Contribution(s)
82
+ with the Work to which such Contribution(s) was submitted. If You
83
+ institute patent litigation against any entity (including a
84
+ cross-claim or counterclaim in a lawsuit) alleging that the Work
85
+ or a Contribution incorporated within the Work constitutes direct
86
+ or contributory patent infringement, then any patent licenses
87
+ granted to You under this License for that Work shall terminate
88
+ as of the date such litigation is filed.
89
+
90
+ 4. Redistribution. You may reproduce and distribute copies of the
91
+ Work or Derivative Works thereof in any medium, with or without
92
+ modifications, and in Source or Object form, provided that You
93
+ meet the following conditions:
94
+
95
+ (a) You must give any other recipients of the Work or
96
+ Derivative Works a copy of this License; and
97
+
98
+ (b) You must cause any modified files to carry prominent notices
99
+ stating that You changed the files; and
100
+
101
+ (c) You must retain, in the Source form of any Derivative Works
102
+ that You distribute, all copyright, patent, trademark, and
103
+ attribution notices from the Source form of the Work,
104
+ excluding those notices that do not pertain to any part of
105
+ the Derivative Works; and
106
+
107
+ (d) If the Work includes a "NOTICE" text file as part of its
108
+ distribution, then any Derivative Works that You distribute must
109
+ include a readable copy of the attribution notices contained
110
+ within such NOTICE file, excluding those notices that do not
111
+ pertain to any part of the Derivative Works, in at least one
112
+ of the following places: within a NOTICE text file distributed
113
+ as part of the Derivative Works; within the Source form or
114
+ documentation, if provided along with the Derivative Works; or,
115
+ within a display generated by the Derivative Works, if and
116
+ wherever such third-party notices normally appear. The contents
117
+ of the NOTICE file are for informational purposes only and
118
+ do not modify the License. You may add Your own attribution
119
+ notices within Derivative Works that You distribute, alongside
120
+ or as an addendum to the NOTICE text from the Work, provided
121
+ that such additional attribution notices cannot be construed
122
+ as modifying the License.
123
+
124
+ You may add Your own copyright statement to Your modifications and
125
+ may provide additional or different license terms and conditions
126
+ for use, reproduction, or distribution of Your modifications, or
127
+ for any such Derivative Works as a whole, provided Your use,
128
+ reproduction, and distribution of the Work otherwise complies with
129
+ the conditions stated in this License.
130
+
131
+ 5. Submission of Contributions. Unless You explicitly state otherwise,
132
+ any Contribution intentionally submitted for inclusion in the Work
133
+ by You to the Licensor shall be under the terms and conditions of
134
+ this License, without any additional terms or conditions.
135
+ Notwithstanding the above, nothing herein shall supersede or modify
136
+ the terms of any separate license agreement you may have executed
137
+ with Licensor regarding such Contributions.
138
+
139
+ 6. Trademarks. This License does not grant permission to use the trade
140
+ names, trademarks, service marks, or product names of the Licensor,
141
+ except as required for reasonable and customary use in describing the
142
+ origin of the Work and reproducing the content of the NOTICE file.
143
+
144
+ 7. Disclaimer of Warranty. Unless required by applicable law or
145
+ agreed to in writing, Licensor provides the Work (and each
146
+ Contributor provides its Contributions) on an "AS IS" BASIS,
147
+ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
148
+ implied, including, without limitation, any warranties or conditions
149
+ of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A
150
+ PARTICULAR PURPOSE. You are solely responsible for determining the
151
+ appropriateness of using or redistributing the Work and assume any
152
+ risks associated with Your exercise of permissions under this License.
153
+
154
+ 8. Limitation of Liability. In no event and under no legal theory,
155
+ whether in tort (including negligence), contract, or otherwise,
156
+ unless required by applicable law (such as deliberate and grossly
157
+ negligent acts) or agreed to in writing, shall any Contributor be
158
+ liable to You for damages, including any direct, indirect, special,
159
+ incidental, or consequential damages of any character arising as a
160
+ result of this License or out of the use or inability to use the
161
+ Work (including but not limited to damages for loss of goodwill,
162
+ work stoppage, computer failure or malfunction, or any and all
163
+ other commercial damages or losses), even if such Contributor
164
+ has been advised of the possibility of such damages.
165
+
166
+ 9. Accepting Warranty or Additional Liability. While redistributing
167
+ the Work or Derivative Works thereof, You may choose to offer,
168
+ and charge a fee for, acceptance of support, warranty, indemnity,
169
+ or other liability obligations and/or rights consistent with this
170
+ License. However, in accepting such obligations, You may act only
171
+ on Your own behalf and on Your sole responsibility, not on behalf
172
+ of any other Contributor, and only if You agree to indemnify,
173
+ defend, and hold each Contributor harmless for any liability
174
+ incurred by, or claims asserted against, such Contributor by reason
175
+ of your accepting any such warranty or additional liability.
176
+
177
+ END OF TERMS AND CONDITIONS
178
+
179
+ APPENDIX: How to apply the Apache License to your work.
180
+
181
+ To apply the Apache License to your work, attach the following
182
+ boilerplate notice, with the fields enclosed by brackets "[]"
183
+ replaced with your own identifying information. (Don't include
184
+ the brackets!) The text should be enclosed in the appropriate
185
+ comment syntax for the file format. We also recommend that a
186
+ file or class name and description of purpose be included on the
187
+ same "printed page" as the copyright notice for easier
188
+ identification within third-party archives.
189
+
190
+ Copyright [yyyy] [name of copyright owner]
191
+
192
+ Licensed under the Apache License, Version 2.0 (the "License");
193
+ you may not use this file except in compliance with the License.
194
+ You may obtain a copy of the License at
195
+
196
+ http://www.apache.org/licenses/LICENSE-2.0
197
+
198
+ Unless required by applicable law or agreed to in writing, software
199
+ distributed under the License is distributed on an "AS IS" BASIS,
200
+ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
201
+ See the License for the specific language governing permissions and
202
+ limitations under the License.
README.md CHANGED
@@ -1,3 +1,109 @@
1
- ---
2
- license: apache-2.0
3
- ---
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # IBM Granite Speech 4.1 2b - ONNX export
2
+
3
+ ONNX export of [`ibm-granite/granite-speech-4.1-2b`](https://huggingface.co/ibm-granite/granite-speech-4.1-2b) produced by Sam McLeod
4
+ (<https://smcleod.net>). Repository: `smcleod/ibm-granite-speech-4.1-2b-onnx`. Both FP32 and INT8 weight-only
5
+ graphs are included. The graphs target opset 20, IR 10, `ai.onnx` operators
6
+ only - no `com.microsoft` ops - so they load under the `ort` 2.0-rc.x Rust
7
+ crate as well as standard `onnxruntime` 1.17 - 1.25.
8
+
9
+ > **Additional precision tiers in progress.** A statically-calibrated INT8 variant (better quality vs the dynamic INT8 already in this repo) and a half-precision encoder are in active development. The repo will be updated when those graphs pass the multi-clip parity gate.
10
+
11
+ Three graphs cooperate: `encoder.onnx` projects mel features to audio embeddings; `prompt_encode.onnx` runs the LLM forward over the full prompt (text tokens + projected audio embeds) and returns the first-token logits plus a 40-layer KV cache; `decode_step.onnx` consumes one token at a time plus the past KV cache and emits the next logits.
12
+
13
+ The audio placeholder token id is `100352`. Replace those positions in the prompt with the projector outputs from `encoder.onnx` before running `prompt_encode.onnx`.
14
+
15
+ ## Files
16
+
17
+ - `encoder.onnx` + `encoder.onnx_data` (FP32) and `encoder_int8.onnx` + `encoder_int8.onnx_data` (INT8 weight-only quantisation)
18
+ - `prompt_encode.onnx` + `prompt_encode.onnx_data` (FP32) and `prompt_encode_int8.onnx` + `prompt_encode_int8.onnx_data` (INT8 weight-only quantisation)
19
+ - `decode_step.onnx` + `decode_step.onnx_data` (FP32) and `decode_step_int8.onnx` + `decode_step_int8.onnx_data` (INT8 weight-only quantisation)
20
+ - Tokeniser / processor: `tokenizer.json`, `tokenizer_config.json`, `processor_config.json`, `chat_template.jinja`, `special_tokens_map.json`, `preprocessor_config.json`
21
+ - Export scripts: `export_speech_2b_ar.py`, `quantise.py`
22
+ - `granite_export_metadata.json` (graph IO, parity numbers, toolchain)
23
+ - `LICENSE` (Apache 2.0)
24
+
25
+ ## Parity
26
+
27
+ Parity is taken against the upstream PyTorch reference on a single LibriSpeech
28
+ clip (`10226_10111_000000.wav`, 8.43 seconds, 844 mel frames). FP32 graphs
29
+ match the reference within numeric tolerance; INT8 graphs are validated in
30
+ argmax-only mode (logit values shift but token argmax is preserved, so the
31
+ decoded transcript is unchanged).
32
+
33
+ Encoder (numeric output, no argmax decoding):
34
+
35
+ | precision | max-abs-err | mean-abs-err | p99-abs-err |
36
+ | --- | --- | --- | --- |
37
+ | FP32 | 4.48e-06 | 1.24e-07 | 6.46e-07 |
38
+ | INT8 | 0.169 | 0.0109 | 0.0447 |
39
+
40
+ LLM stages (argmax decoding; INT8 logit max-abs delta is large but argmax is preserved):
41
+
42
+ | graph | precision | max-abs-err | argmax mismatches | transcript match |
43
+ | --- | --- | --- | --- | --- |
44
+ | prompt_encode | FP32 | 0.000364 | 0/190 | Y |
45
+ | prompt_encode | INT8 | 10.1 | 58/190 | Y |
46
+ | decode_step | FP32 | n/a | 0/51 | Y |
47
+ | decode_step | INT8 | 5.76 | 0/51 | Y |
48
+
49
+ ### Multi-clip transcript parity
50
+
51
+ Three additional 16 kHz mono clips covering longer utterances (39 to 94 seconds), single and two-speaker conversational content. Word error rate (WER) and Levenshtein edit distance computed against the upstream PyTorch reference. Numbers measured end-to-end through the full ONNX pipeline (no PyTorch encoder fallback).
52
+
53
+ | Clip | Duration | FP32 byte-exact vs PT | INT8 byte-exact vs PT | INT8 WER vs PT | INT8 vs FP32 Lev |
54
+ | --- | ---: | :---: | :---: | ---: | ---: |
55
+ | is-it-more-wood | 46.9 s | Y | N | 1.4% | 2 |
56
+ | two-speakers-1 | 93.8 s | Y | N | 1.0% | 12 |
57
+ | two-speakers-2 | 38.8 s | Y | N | 23.5% | 26 |
58
+
59
+ Raw multi-clip data including full transcripts: see `granite_export_metadata.json` `multi_clip_parity` block.
60
+
61
+ Reference transcript:
62
+
63
+ > After his nap, Timothy lazily stretched, first one gray velvet foot, then another, strolled indolently to his plate, turning over the food, carefully selecting choice bits, nosing out that which he scorned upon the clean hearth
64
+
65
+ Both FP32 and INT8 paths reproduce this transcript exactly on the test clip.
66
+
67
+ ## Toolchain
68
+
69
+ - transformers 5.8.0
70
+ - torch 2.11.0
71
+ - onnx 1.21.0
72
+ - onnxruntime 1.25.1
73
+ - exporter: torch.onnx.export TorchScript path (dynamo=False)
74
+ - opset: 20 (`ai.onnx` only)
75
+ - IR version: 10
76
+ - external data layout: single `<stem>.onnx_data` sidecar per graph
77
+
78
+ ## Compatibility
79
+
80
+ Targeted at the [`ort`](https://crates.io/crates/ort) 2.0-rc.x Rust crate.
81
+ Compatible with `onnxruntime` Python 1.17 through 1.25. No `com.microsoft`
82
+ ops are used. Graphs were emitted via the TorchScript path
83
+ (`torch.onnx.export(..., dynamo=False)`); the dynamo exporter was deliberately
84
+ avoided because it injects `aten::*` ops `ort` does not understand.
85
+
86
+ ## Reproducing the export
87
+
88
+ The included scripts and `quantise.py` regenerate every artefact in this
89
+ bundle. From a checkout of <https://github.com/sammcj/granite-speech-4.1-onnx>:
90
+
91
+ ```bash
92
+ python export_speech_2b_ar.py \
93
+ --model-dir <path-to-ibm-granite/granite-speech-4.1-2b> \
94
+ --out-dir exports/granite-speech-4.1-2b
95
+ python quantise.py --input exports/granite-speech-4.1-2b/encoder.onnx --output exports/granite-speech-4.1-2b/encoder_int8.onnx
96
+ python quantise.py --input exports/granite-speech-4.1-2b/prompt_encode.onnx --output exports/granite-speech-4.1-2b/prompt_encode_int8.onnx
97
+ python quantise.py --input exports/granite-speech-4.1-2b/decode_step.onnx --output exports/granite-speech-4.1-2b/decode_step_int8.onnx
98
+ ```
99
+
100
+ Sandboxed environments may need:
101
+
102
+ ```bash
103
+ HF_HOME=$TMPDIR/hf_home HF_MODULES_CACHE=$TMPDIR/hf_modules <command above>
104
+ ```
105
+
106
+ ## Licence
107
+
108
+ Apache 2.0 for both the upstream IBM model and this ONNX export. See
109
+ [`LICENSE`](LICENSE) for the full text.
chat_template.jinja ADDED
@@ -0,0 +1,2 @@
 
 
 
1
+ {% for message in messages %}{% if message['role'] == 'user' %}USER: {{ message['content'] }}
2
+ ASSISTANT:{% elif message['role'] == 'assistant' %}{{ message['content'] }}{% endif %}{% endfor %}
decode_step.onnx ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:bafe9470dc446bafb9499566ca22328251d352b222f290e039a3bbc54aa1baf7
3
+ size 1849786
decode_step.onnx_data ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:87d924ecd71746694f43e653c9366827a9444ab9407e976f5cd9cc9dbde97608
3
+ size 6527008768
decode_step_int8.onnx ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:b6ddf694deba562408d2ed39d9a1744c8015c5610d5e0afce63908293a1eac45
3
+ size 6426226
decode_step_int8.onnx_data ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:70652d6a31cbae2d57c7e8cefb665f6c1ee503e495d191b951fff09ddb7f8608
3
+ size 1632249856
encoder.onnx ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:efe873c6d19468eda93d1751ba14615508e763312cac6112029914acec0f33a9
3
+ size 912937
encoder.onnx_data ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:a39db4121859fade3cc6fef05dc9f7abc0af068d389ef8af7b57997eca0d2f43
3
+ size 1903334768
encoder_int8.onnx ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:1167991ae3b81a0aa0a0c35f0bed619dc8c3d52b5da72de0564fb708b0547070
3
+ size 2608070
encoder_int8.onnx_data ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:dd1f606ff9b9145636849b7f2dbbc972e3f4964852cdbf950254da4fc45132d5
3
+ size 787117424
export_speech_2b_ar.py ADDED
@@ -0,0 +1,1239 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright 2026 Sam McLeod
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+ """Export Granite Speech 4.1 2b (autoregressive variants) to three ONNX graphs.
15
+
16
+ Covers both `granite-speech-4.1-2b` (base) and `granite-speech-4.1-2b-plus`.
17
+ The two share architecture - Conformer encoder + Blip2 Q-Former projector +
18
+ Granite-4.0 1B causal LM with `logits_scaling=8` - and only differ in weights
19
+ and chat template. Pass `--model-dir` and `--baseline` to select the variant.
20
+ The NAR variant has a different topology and is exported by the
21
+ `export_nar_*.py` scripts instead.
22
+
23
+ Produces, under the configured `--out-dir`:
24
+ - encoder.onnx : Conformer CTC encoder + Blip2 Q-Former projector.
25
+ Input: input_features float32 [B, T, 160]
26
+ Output: audio_embeds float32 [B, T_audio, 2048]
27
+ audio_embed_sizes int64 [B] (per-sample valid lengths)
28
+ - prompt_encode.onnx : LLM prefill over a fully spliced inputs_embeds.
29
+ Inputs : inputs_embeds float32 [B, N, 2048]
30
+ position_ids int64 [B, N]
31
+ attention_mask float32 [B, 1, N, N] (additive)
32
+ Outputs: logits float32 [B, N, V] (divided by 8)
33
+ present.<L>.{key,value} for L in 0..39
34
+ - decode_step.onnx : Single-token decode with KV cache.
35
+ Inputs : inputs_embeds float32 [B, 1, 2048]
36
+ position_ids int64 [B, 1]
37
+ attention_mask float32 [B, 1, 1, T_total] (additive)
38
+ past_key_values.<L>.{key,value} for L in 0..39
39
+ Outputs: logits float32 [B, 1, V] (divided by 8)
40
+ present.<L>.{key,value}
41
+
42
+ The base/plus projector is `Blip2QFormerModel`, not the NAR custom projector.
43
+ Q-Former self-attention is plain matmul-softmax already (Bert-style); only the
44
+ Conformer encoder's SDPA + `if remainder > 0` guard need rewriting for clean
45
+ tracing.
46
+
47
+ Both LLM graphs apply `logits / config.text_config.logits_scaling` (=8). This
48
+ matches `GraniteForCausalLM.forward`, which the reference autoregressive path
49
+ goes through. Without it, ONNX logits are 8x the PyTorch reference even though
50
+ argmax is preserved, which trips strict numeric parity bars.
51
+
52
+ Usage:
53
+ # Base 2b (defaults):
54
+ HF_HOME=$TMPDIR/hf_home HF_MODULES_CACHE=$TMPDIR/hf_modules \\
55
+ uv run python src/export_speech_2b_ar.py
56
+
57
+ # Plus 2b:
58
+ HF_HOME=$TMPDIR/hf_home HF_MODULES_CACHE=$TMPDIR/hf_modules \\
59
+ uv run python src/export_speech_2b_ar.py \\
60
+ --model-dir models/granite-speech-4.1-2b-plus \\
61
+ --baseline test_data/baselines/plus.json \\
62
+ --out-dir exports/granite-speech-4.1-2b-plus
63
+
64
+ # Just one stage:
65
+ uv run python src/export_speech_2b_ar.py --stages encoder
66
+ uv run python src/export_speech_2b_ar.py --stages prompt,decode --skip-export
67
+ """
68
+
69
+ from __future__ import annotations
70
+
71
+ import argparse
72
+ import json
73
+ import os
74
+ import tempfile
75
+ import time
76
+ from pathlib import Path
77
+ from typing import Any
78
+
79
+ import numpy as np
80
+ import soundfile as sf
81
+ import torch
82
+ import torch.nn as nn
83
+ import torch.nn.functional as F
84
+
85
+
86
+ # Resolve roots so the script works whether it lives at <repo>/src/<name>.py
87
+ # (project layout) or <bundle>/<name>.py (HF bundle layout). Defaults exist for
88
+ # the project layout; bundle users should pass explicit --audio / --baseline /
89
+ # --model-dir / --out-dir.
90
+ SCRIPT_DIR = Path(__file__).resolve().parent
91
+ REPO_ROOT = SCRIPT_DIR.parent if SCRIPT_DIR.name == "src" else SCRIPT_DIR
92
+ DEFAULT_AUDIO = REPO_ROOT / "test_data" / "10226_10111_000000.wav"
93
+ DEFAULT_BASELINE = REPO_ROOT / "test_data" / "baselines" / "base.json"
94
+ DEFAULT_MODEL_DIR = REPO_ROOT / "models" / "granite-speech-4.1-2b"
95
+ DEFAULT_OUT_DIR = REPO_ROOT / "exports" / "granite-speech-4.1-2b"
96
+
97
+ USER_PROMPT_TRANSCRIBE = (
98
+ "<|audio|>transcribe the speech with proper punctuation and capitalization."
99
+ )
100
+
101
+
102
+ # ---------------------------------------------------------------------------
103
+ # Utilities.
104
+ # ---------------------------------------------------------------------------
105
+
106
+
107
+ def load_audio(path: Path) -> np.ndarray:
108
+ waveform, sr = sf.read(str(path), dtype="float32")
109
+ if waveform.ndim > 1:
110
+ waveform = waveform.mean(axis=1)
111
+ assert sr == 16000, f"expected 16 kHz, got {sr}"
112
+ return waveform
113
+
114
+
115
+ def tensor_stats(t: torch.Tensor | np.ndarray | None) -> dict[str, Any] | None:
116
+ if t is None:
117
+ return None
118
+ if isinstance(t, torch.Tensor):
119
+ x = t.detach().float().cpu().numpy()
120
+ dtype_str = str(t.dtype).replace("torch.", "")
121
+ else:
122
+ x = np.asarray(t).astype(np.float32, copy=False)
123
+ dtype_str = str(t.dtype)
124
+ flat = x.flatten()
125
+ return {
126
+ "shape": list(x.shape),
127
+ "dtype": dtype_str,
128
+ "mean": float(flat.mean()) if flat.size else None,
129
+ "std": float(flat.std()) if flat.size else None,
130
+ "min": float(flat.min()) if flat.size else None,
131
+ "max": float(flat.max()) if flat.size else None,
132
+ "first10": [float(v) for v in flat[:10]],
133
+ }
134
+
135
+
136
+ def _resave_single_sidecar(scratch_path: Path, out_path: Path, ir_version: int) -> None:
137
+ """Stage 2 of every export: re-save with one external-data sidecar in the
138
+ final location so we end up with exactly two artefacts on disk."""
139
+ import onnx
140
+
141
+ print(" stage-2: re-saving with single .onnx_data sidecar + ir bump")
142
+ model_proto = onnx.load(str(scratch_path), load_external_data=True)
143
+ if model_proto.ir_version < ir_version:
144
+ model_proto.ir_version = ir_version
145
+
146
+ for tensor in model_proto.graph.initializer:
147
+ tensor.ClearField("data_location")
148
+ tensor.ClearField("external_data")
149
+
150
+ sidecar_name = out_path.name + "_data"
151
+ if (out_path.parent / sidecar_name).exists():
152
+ (out_path.parent / sidecar_name).unlink()
153
+ if out_path.exists():
154
+ out_path.unlink()
155
+
156
+ onnx.save_model(
157
+ model_proto,
158
+ str(out_path),
159
+ save_as_external_data=True,
160
+ all_tensors_to_one_file=True,
161
+ location=sidecar_name,
162
+ size_threshold=1024,
163
+ convert_attribute=False,
164
+ )
165
+ onnx.checker.check_model(str(out_path), full_check=False)
166
+ domains = sorted({n.domain for n in model_proto.graph.node})
167
+ print(f" saved {out_path} (+ {sidecar_name}) node-domains={domains}")
168
+
169
+
170
+ # ---------------------------------------------------------------------------
171
+ # Model loading (mirrors capture_baselines.py::capture_base_or_plus).
172
+ # ---------------------------------------------------------------------------
173
+
174
+
175
+ def load_base_model(model_dir: Path) -> tuple[nn.Module, Any]:
176
+ from transformers import AutoModelForSpeechSeq2Seq, AutoProcessor
177
+
178
+ print(f" loading processor from {model_dir}")
179
+ processor = AutoProcessor.from_pretrained(str(model_dir))
180
+
181
+ print(f" loading model from {model_dir} (eager, fp32)")
182
+ t0 = time.time()
183
+ # Blip2QFormerModel does not support SDPA in transformers 5.8; eager is mandatory.
184
+ model = AutoModelForSpeechSeq2Seq.from_pretrained(
185
+ str(model_dir), torch_dtype=torch.float32, attn_implementation="eager",
186
+ )
187
+ model.eval()
188
+ # The nested text_config / encoder_config / projector_config can carry
189
+ # `dtype: bfloat16`; force fp32 across the whole module tree.
190
+ model = model.to(torch.float32)
191
+ print(f" loaded in {time.time() - t0:.1f}s")
192
+ return model, processor
193
+
194
+
195
+ # ---------------------------------------------------------------------------
196
+ # Trace-friendly monkey-patches for the Conformer encoder.
197
+ # ---------------------------------------------------------------------------
198
+
199
+
200
+ def patch_conformer_for_tracing(model: nn.Module) -> None:
201
+ """Rewrite the in-tree GraniteSpeechConformerAttention.forward so it traces:
202
+ - SDPA -> plain matmul/softmax.
203
+ - `if remainder > 0` guard -> always-pad by `(-num_features) % context_size`.
204
+
205
+ Blip2QFormerMultiHeadAttention is already plain matmul-softmax (Bert-style),
206
+ so no rewrite is needed for the projector's self-attention path. The
207
+ projector's outer reshape/pad math is handled separately by
208
+ patch_projector_for_tracing because it bakes T_audio into the graph if
209
+ left as upstream's `math.ceil(seq_len / window_size)` pattern.
210
+ """
211
+ encoder = model.encoder
212
+ attn0 = encoder.layers[0].attn
213
+ attn_cls = type(attn0)
214
+
215
+ def attn_forward(self, hidden_states: torch.Tensor, attention_dists: torch.Tensor) -> torch.Tensor:
216
+ hidden_states = self.pre_norm(hidden_states)
217
+ bsz, num_features, _ = hidden_states.shape
218
+ # Always-pad: pad amount may be zero. Use modulo so the graph is valid
219
+ # for any T at runtime.
220
+ pad_amount = (-num_features) % self.context_size
221
+ num_blocks = (num_features + self.context_size - 1) // self.context_size
222
+ hidden_states = F.pad(hidden_states, (0, 0, 0, pad_amount))
223
+
224
+ query_states = self.to_q(hidden_states)
225
+ key_states, value_states = self.to_kv(hidden_states).chunk(2, dim=-1)
226
+
227
+ query_states = query_states.reshape(
228
+ bsz, num_blocks, self.context_size, self.num_heads, -1
229
+ ).transpose(2, 3)
230
+ key_states = key_states.reshape(
231
+ bsz, num_blocks, self.context_size, self.num_heads, -1
232
+ ).transpose(2, 3)
233
+ value_states = value_states.reshape(
234
+ bsz, num_blocks, self.context_size, self.num_heads, -1
235
+ ).transpose(2, 3)
236
+
237
+ # Shaw's relative positional embedding.
238
+ rel_pos_emb = self.rel_pos_emb(attention_dists)
239
+ # query_states: [B, M, H, C, D]; rel_pos_emb: [C, R, D]
240
+ # Output: [B, M, H, C, R]
241
+ pos_attn = torch.einsum(
242
+ "b m h c d, c r d -> b m h c r", query_states, rel_pos_emb
243
+ ) * self.scale
244
+
245
+ # Plain matmul attention with the additive `pos_attn` bias inside the
246
+ # softmax (matches the MATH SDPA backend numerically).
247
+ attn_logits = torch.matmul(query_states, key_states.transpose(-1, -2)) * self.scale
248
+ attn_logits = attn_logits + pos_attn
249
+ attn_weights = torch.softmax(attn_logits, dim=-1)
250
+ out = torch.matmul(attn_weights, value_states) # [B, M, H, C, D]
251
+
252
+ out = out.transpose(2, 3).reshape(bsz, hidden_states.shape[1], -1)
253
+ out = self.to_out(out[:, :num_features, :])
254
+ return self.dropout(out)
255
+
256
+ attn_cls.forward = attn_forward
257
+
258
+
259
+ def patch_projector_for_tracing(model: nn.Module) -> None:
260
+ """Rewrite GraniteSpeechEncoderProjector.forward so the output time
261
+ dimension (T_audio = nblocks * num_queries) stays dynamic in the exported
262
+ graph.
263
+
264
+ The upstream forward bakes T_audio because:
265
+ 1. `seq_len = hidden_states.size(1)` is a Python int under TorchScript trace
266
+ 2. `math.ceil(seq_len / self.window_size)` is Python int math, baked
267
+ 3. The intermediate `.view(batch * nblocks, window_size, dim)` and final
268
+ `.view(batch, nblocks * window_size // downsample_rate, -1)` both
269
+ emit Reshape ops with a constant shape vector
270
+
271
+ The rewrite uses `torch._shape_as_tensor` for dynamic shape access, an
272
+ over-pad-then-tensor-slice idiom for the F.pad step, and `-1` for the
273
+ intermediate batch*nblocks dim and the final T_audio dim. Batch is still
274
+ baked at trace value (1) because reshape's target shape is a constant
275
+ vector and we don't support multi-batch inference; T_audio is the
276
+ audio-length-dependent dim that needs to be dynamic.
277
+ """
278
+ projector = model.projector
279
+ projector_cls = type(projector)
280
+
281
+ def projector_forward_traceable(self, hidden_states: torch.Tensor) -> torch.Tensor:
282
+ batch_size = hidden_states.shape[0] # static (B=1 at trace)
283
+ dim = hidden_states.shape[2] # static encoder hidden_dim
284
+ window_size = self.window_size
285
+
286
+ # Dynamic seq_len via Shape op (emitted by torch._shape_as_tensor):
287
+ shape_t = torch._shape_as_tensor(hidden_states)
288
+ seq_len_t = shape_t[1] # 0-d int64 Tensor
289
+
290
+ # nblocks * window_size = the padded length we want.
291
+ nblocks_t = (seq_len_t + window_size - 1) // window_size
292
+ final_len_t = nblocks_t * window_size # 0-d Tensor
293
+
294
+ # Statically pad by (window_size - 1), the maximum pad ever needed,
295
+ # then dynamically slice down to final_len_t. Avoids needing F.pad
296
+ # with a tensor pad amount (which doesn't trace cleanly).
297
+ hidden_states = nn.functional.pad(
298
+ hidden_states, (0, 0, 0, window_size - 1), "constant", 0.0
299
+ )
300
+ hidden_states = hidden_states[:, :final_len_t, :]
301
+
302
+ # [B, nblocks*window_size, dim] -> [B*nblocks, window_size, dim].
303
+ # `-1` lets ONNX infer batch*nblocks from numel at runtime.
304
+ hidden_states = hidden_states.reshape(-1, window_size, dim)
305
+
306
+ # Build an explicit all-ones encoder_attention_mask. Without this, the
307
+ # QFormer auto-creates one via `torch.ones(encoder_hidden_states.size())`
308
+ # which under tracing bakes batch*nblocks at the trace input's value.
309
+ # `torch.ones_like` on a slice that drops the hidden dim keeps the
310
+ # mask shape dynamic ([batch*nblocks, window_size]).
311
+ encoder_attention_mask = torch.ones_like(hidden_states[..., 0])
312
+
313
+ query_output = self.qformer(
314
+ query_embeds=self.query,
315
+ encoder_hidden_states=hidden_states,
316
+ encoder_attention_mask=encoder_attention_mask,
317
+ return_dict=True,
318
+ )
319
+ # qf_out: [B*nblocks, num_queries, qf_hidden]
320
+ qf_out = query_output.last_hidden_state
321
+ qf_hidden = qf_out.shape[-1] # static qformer hidden
322
+
323
+ # [B*nblocks, num_queries, qf_hidden] -> [B, T_audio, qf_hidden].
324
+ # B is baked at trace (1); T_audio (= nblocks*num_queries) is inferred.
325
+ qf_out = qf_out.reshape(batch_size, -1, qf_hidden)
326
+
327
+ return self.linear(qf_out)
328
+
329
+ projector_cls.forward = projector_forward_traceable
330
+
331
+
332
+ # ---------------------------------------------------------------------------
333
+ # Encoder + projector wrapper.
334
+ # ---------------------------------------------------------------------------
335
+
336
+
337
+ class EncoderProjectorWrapper(nn.Module):
338
+ """Wrap encoder + projector into one ONNX graph.
339
+
340
+ Inputs:
341
+ input_features: float32 [B, T, 160]
342
+
343
+ Outputs:
344
+ audio_embeds: float32 [B, T_audio, 2048] = projector(encoder(input_features))
345
+ audio_embed_sizes: int64 [B] - count of valid audio tokens per sample,
346
+ replicating the feature_extractor's projection-length math
347
+ on the static input shape.
348
+
349
+ Notes:
350
+ - The Conformer encoder itself does not consume an attention mask; the
351
+ feature extractor supplies a Python-int per-sample length, which is what
352
+ `audio_embed_sizes` reproduces here from the static input shape T.
353
+ Downstream Rust glue should compute the same size from the raw audio
354
+ length and slice `audio_embeds[:, :size, :]` for the splice.
355
+ - The projector output size is `nblocks * (window_size / downsample_rate)`.
356
+ With T=844 (the reference clip), this gives `ceil(844/15) * 3 = 171`,
357
+ which matches the captured PyTorch reference.
358
+ """
359
+
360
+ def __init__(self, encoder: nn.Module, projector: nn.Module, window_size: int, downsample_rate: int):
361
+ super().__init__()
362
+ self.encoder = encoder
363
+ self.projector = projector
364
+ self.window_size = int(window_size)
365
+ self.downsample_rate = int(downsample_rate)
366
+
367
+ def forward(self, input_features: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]:
368
+ enc_out = self.encoder(input_features, return_dict=True)
369
+ audio_embeds = self.projector(enc_out.last_hidden_state)
370
+ # Compute audio_embed_sizes dynamically so the value tracks T at runtime.
371
+ # `torch._shape_as_tensor` emits an ONNX Shape op so seq_len_t is a 0-d
372
+ # int64 Tensor rather than a baked Python int. The result is an int64
373
+ # tensor of shape [B] (B baked at 1, the only mode we trace; T_audio
374
+ # tracks runtime input length).
375
+ shape_t = torch._shape_as_tensor(input_features)
376
+ seq_len_t = shape_t[1]
377
+ num_queries = self.window_size // self.downsample_rate
378
+ nblocks_t = (seq_len_t + self.window_size - 1) // self.window_size
379
+ size_per_t = nblocks_t * num_queries # 0-d int64 Tensor
380
+ audio_embed_sizes = size_per_t.unsqueeze(0) # [1] tensor
381
+ return audio_embeds, audio_embed_sizes
382
+
383
+
384
+ # ---------------------------------------------------------------------------
385
+ # LLM wrappers (prompt_encode + decode_step). Adapted from
386
+ # src/export_granite_llm_kv.py to take inputs_embeds instead of input_ids.
387
+ # ---------------------------------------------------------------------------
388
+
389
+
390
+ def _build_causal_mask_4d(
391
+ attention_mask_2d: torch.Tensor,
392
+ T_past: int,
393
+ dtype: torch.dtype,
394
+ ) -> torch.Tensor:
395
+ """Build a 4-D additive attention mask `[B, 1, T_q, T_k]` from a 2-D padding
396
+ mask `[B, T_k]`. Padding columns are -inf, and the trailing T_q query rows
397
+ have an upper-triangular causal mask added.
398
+
399
+ The Granite eager-mask path early-exits when handed a 4-D mask, so this
400
+ short-circuits the v5 mask-helper crash under TorchScript trace.
401
+ """
402
+ B, T_k = attention_mask_2d.shape
403
+ T_q = T_k - T_past
404
+ neg_inf = torch.finfo(dtype).min
405
+
406
+ pad = (attention_mask_2d == 0).to(dtype) * neg_inf # [B, T_k]
407
+ pad = pad.view(B, 1, 1, T_k).expand(B, 1, T_q, T_k)
408
+
409
+ q_idx = torch.arange(T_q, device=attention_mask_2d.device).view(1, 1, T_q, 1)
410
+ k_idx = torch.arange(T_k, device=attention_mask_2d.device).view(1, 1, 1, T_k)
411
+ allowed = k_idx <= (q_idx + T_past)
412
+ causal = torch.where(
413
+ allowed,
414
+ torch.zeros((), dtype=dtype, device=attention_mask_2d.device),
415
+ torch.full((), neg_inf, dtype=dtype, device=attention_mask_2d.device),
416
+ )
417
+ return pad + causal
418
+
419
+
420
+ class PromptEncodeWrapper(nn.Module):
421
+ """Prefill graph; consumes pre-spliced inputs_embeds.
422
+
423
+ Forward signature (positional):
424
+ inputs_embeds: float32 [B, N, H]
425
+ position_ids: int64 [B, N]
426
+ attention_mask: float32 [B, 1, N, N] additive 4-D causal+padding mask
427
+
428
+ Outputs:
429
+ logits: float32 [B, N, V] (divided by logits_scaling)
430
+ present.<L>.key, present.<L>.value for L in 0..n_layers-1
431
+ """
432
+
433
+ def __init__(
434
+ self, llm_model: nn.Module, lm_head: nn.Module, num_layers: int, logits_scaling: float
435
+ ) -> None:
436
+ super().__init__()
437
+ self.llm_model = llm_model
438
+ self.lm_head = lm_head
439
+ self.num_layers = num_layers
440
+ self.logits_scaling = logits_scaling
441
+
442
+ def forward(
443
+ self,
444
+ inputs_embeds: torch.Tensor,
445
+ position_ids: torch.Tensor,
446
+ attention_mask: torch.Tensor,
447
+ ) -> tuple[torch.Tensor, ...]:
448
+ from transformers import DynamicCache
449
+
450
+ cache = DynamicCache()
451
+ out = self.llm_model(
452
+ inputs_embeds=inputs_embeds,
453
+ attention_mask=attention_mask,
454
+ position_ids=position_ids,
455
+ use_cache=True,
456
+ past_key_values=cache,
457
+ )
458
+ logits = self.lm_head(out.last_hidden_state) / self.logits_scaling
459
+ present = out.past_key_values
460
+ flat: list[torch.Tensor] = [logits]
461
+ for layer in present.layers:
462
+ flat.append(layer.keys)
463
+ flat.append(layer.values)
464
+ return tuple(flat)
465
+
466
+
467
+ class DecodeStepWrapper(nn.Module):
468
+ """Single-token decode graph.
469
+
470
+ Forward signature (positional):
471
+ inputs_embeds: float32 [B, 1, H]
472
+ position_ids: int64 [B, 1]
473
+ attention_mask: float32 [B, 1, 1, T_total] additive 4-D mask
474
+ past_kv_flat: 2*n_layers tensors, each float32
475
+ [B, num_kv_heads, T_past, head_dim], in the order
476
+ (past.0.key, past.0.value, past.1.key, ..., past.<L-1>.value)
477
+ """
478
+
479
+ def __init__(
480
+ self, llm_model: nn.Module, lm_head: nn.Module, num_layers: int, logits_scaling: float
481
+ ) -> None:
482
+ super().__init__()
483
+ self.llm_model = llm_model
484
+ self.lm_head = lm_head
485
+ self.num_layers = num_layers
486
+ self.logits_scaling = logits_scaling
487
+
488
+ def forward(
489
+ self,
490
+ inputs_embeds: torch.Tensor,
491
+ position_ids: torch.Tensor,
492
+ attention_mask: torch.Tensor,
493
+ *past_kv_flat: torch.Tensor,
494
+ ) -> tuple[torch.Tensor, ...]:
495
+ from transformers import DynamicCache
496
+
497
+ if len(past_kv_flat) != 2 * self.num_layers:
498
+ raise ValueError(
499
+ f"expected {2 * self.num_layers} past_kv tensors, got {len(past_kv_flat)}"
500
+ )
501
+ layer_pairs = [
502
+ (past_kv_flat[2 * i], past_kv_flat[2 * i + 1])
503
+ for i in range(self.num_layers)
504
+ ]
505
+ cache = DynamicCache(ddp_cache_data=layer_pairs)
506
+
507
+ out = self.llm_model(
508
+ inputs_embeds=inputs_embeds,
509
+ attention_mask=attention_mask,
510
+ position_ids=position_ids,
511
+ use_cache=True,
512
+ past_key_values=cache,
513
+ )
514
+ logits = self.lm_head(out.last_hidden_state) / self.logits_scaling
515
+ present = out.past_key_values
516
+ flat: list[torch.Tensor] = [logits]
517
+ for layer in present.layers:
518
+ flat.append(layer.keys)
519
+ flat.append(layer.values)
520
+ return tuple(flat)
521
+
522
+
523
+ # ---------------------------------------------------------------------------
524
+ # Export functions.
525
+ # ---------------------------------------------------------------------------
526
+
527
+
528
+ def export_encoder(
529
+ wrapper: EncoderProjectorWrapper,
530
+ sample_input_features: torch.Tensor,
531
+ out_path: Path,
532
+ opset: int = 20,
533
+ ir_version: int = 10,
534
+ ) -> None:
535
+ out_path.parent.mkdir(parents=True, exist_ok=True)
536
+ print(f" exporting encoder to {out_path} (opset={opset}, ir_version={ir_version})")
537
+
538
+ dynamic_axes = {
539
+ "input_features": {0: "B", 1: "T"},
540
+ "audio_embeds": {0: "B", 1: "T_audio"},
541
+ "audio_embed_sizes": {0: "B"},
542
+ }
543
+
544
+ with tempfile.TemporaryDirectory(prefix="speech2b_ar_encoder_onnx_") as scratch_dir:
545
+ scratch_path = Path(scratch_dir) / "encoder.onnx"
546
+ t0 = time.time()
547
+ torch.onnx.export(
548
+ wrapper,
549
+ (sample_input_features,),
550
+ str(scratch_path),
551
+ input_names=["input_features"],
552
+ output_names=["audio_embeds", "audio_embed_sizes"],
553
+ dynamic_axes=dynamic_axes,
554
+ opset_version=opset,
555
+ do_constant_folding=True,
556
+ export_params=True,
557
+ dynamo=False,
558
+ )
559
+ print(f" stage-1 torch.onnx.export done in {time.time() - t0:.1f}s")
560
+ _resave_single_sidecar(scratch_path, out_path, ir_version)
561
+
562
+
563
+ def export_prompt_encode(
564
+ wrapper: PromptEncodeWrapper,
565
+ sample_inputs_embeds: torch.Tensor,
566
+ sample_position_ids: torch.Tensor,
567
+ sample_attention_mask: torch.Tensor,
568
+ out_path: Path,
569
+ num_layers: int,
570
+ opset: int = 20,
571
+ ir_version: int = 10,
572
+ ) -> None:
573
+ out_path.parent.mkdir(parents=True, exist_ok=True)
574
+ print(f" exporting prompt_encode to {out_path} (opset={opset}, ir_version={ir_version})")
575
+
576
+ output_names: list[str] = ["logits"]
577
+ for i in range(num_layers):
578
+ output_names.append(f"present.{i}.key")
579
+ output_names.append(f"present.{i}.value")
580
+
581
+ dynamic_axes: dict[str, dict[int, str]] = {
582
+ "inputs_embeds": {0: "B", 1: "N"},
583
+ "position_ids": {0: "B", 1: "N"},
584
+ "attention_mask": {0: "B", 2: "N", 3: "N"},
585
+ "logits": {0: "B", 1: "N"},
586
+ }
587
+ for i in range(num_layers):
588
+ dynamic_axes[f"present.{i}.key"] = {0: "B", 2: "N"}
589
+ dynamic_axes[f"present.{i}.value"] = {0: "B", 2: "N"}
590
+
591
+ with tempfile.TemporaryDirectory(prefix="speech2b_ar_prompt_onnx_") as scratch_dir:
592
+ scratch_path = Path(scratch_dir) / "prompt_encode.onnx"
593
+ t0 = time.time()
594
+ torch.onnx.export(
595
+ wrapper,
596
+ (sample_inputs_embeds, sample_position_ids, sample_attention_mask),
597
+ str(scratch_path),
598
+ input_names=["inputs_embeds", "position_ids", "attention_mask"],
599
+ output_names=output_names,
600
+ dynamic_axes=dynamic_axes,
601
+ opset_version=opset,
602
+ do_constant_folding=True,
603
+ export_params=True,
604
+ dynamo=False,
605
+ )
606
+ print(f" stage-1 torch.onnx.export done in {time.time() - t0:.1f}s")
607
+ _resave_single_sidecar(scratch_path, out_path, ir_version)
608
+
609
+
610
+ def export_decode_step(
611
+ wrapper: DecodeStepWrapper,
612
+ sample_inputs_embeds: torch.Tensor,
613
+ sample_position_ids: torch.Tensor,
614
+ sample_attention_mask: torch.Tensor,
615
+ sample_past_kv_flat: tuple[torch.Tensor, ...],
616
+ out_path: Path,
617
+ num_layers: int,
618
+ opset: int = 20,
619
+ ir_version: int = 10,
620
+ ) -> None:
621
+ out_path.parent.mkdir(parents=True, exist_ok=True)
622
+ print(f" exporting decode_step to {out_path} (opset={opset}, ir_version={ir_version})")
623
+
624
+ input_names: list[str] = ["inputs_embeds", "position_ids", "attention_mask"]
625
+ for i in range(num_layers):
626
+ input_names.append(f"past_key_values.{i}.key")
627
+ input_names.append(f"past_key_values.{i}.value")
628
+
629
+ output_names: list[str] = ["logits"]
630
+ for i in range(num_layers):
631
+ output_names.append(f"present.{i}.key")
632
+ output_names.append(f"present.{i}.value")
633
+
634
+ dynamic_axes: dict[str, dict[int, str]] = {
635
+ "inputs_embeds": {0: "B"},
636
+ "position_ids": {0: "B"},
637
+ "attention_mask": {0: "B", 3: "T_total"},
638
+ "logits": {0: "B"},
639
+ }
640
+ for i in range(num_layers):
641
+ dynamic_axes[f"past_key_values.{i}.key"] = {0: "B", 2: "T_past"}
642
+ dynamic_axes[f"past_key_values.{i}.value"] = {0: "B", 2: "T_past"}
643
+ dynamic_axes[f"present.{i}.key"] = {0: "B", 2: "T_total"}
644
+ dynamic_axes[f"present.{i}.value"] = {0: "B", 2: "T_total"}
645
+
646
+ with tempfile.TemporaryDirectory(prefix="speech2b_ar_decode_onnx_") as scratch_dir:
647
+ scratch_path = Path(scratch_dir) / "decode_step.onnx"
648
+ args = (sample_inputs_embeds, sample_position_ids, sample_attention_mask, *sample_past_kv_flat)
649
+ t0 = time.time()
650
+ torch.onnx.export(
651
+ wrapper,
652
+ args,
653
+ str(scratch_path),
654
+ input_names=input_names,
655
+ output_names=output_names,
656
+ dynamic_axes=dynamic_axes,
657
+ opset_version=opset,
658
+ do_constant_folding=True,
659
+ export_params=True,
660
+ dynamo=False,
661
+ )
662
+ print(f" stage-1 torch.onnx.export done in {time.time() - t0:.1f}s")
663
+ _resave_single_sidecar(scratch_path, out_path, ir_version)
664
+
665
+
666
+ # ---------------------------------------------------------------------------
667
+ # Parity helpers.
668
+ # ---------------------------------------------------------------------------
669
+
670
+
671
+ def encoder_parity(
672
+ wrapper: EncoderProjectorWrapper,
673
+ processor: Any,
674
+ waveform: np.ndarray,
675
+ onnx_path: Path,
676
+ abs_tol: float,
677
+ argmax_only: bool = False,
678
+ ) -> dict[str, Any]:
679
+ import onnxruntime as ort
680
+
681
+ print("\n=== encoder parity ===")
682
+ inputs = processor(USER_PROMPT_TRANSCRIBE, [waveform], sampling_rate=16000, return_tensors="pt")
683
+ input_features = inputs["input_features"].to(torch.float32)
684
+ print(f" input_features: {tuple(input_features.shape)}")
685
+
686
+ print(" PyTorch wrapper forward")
687
+ t0 = time.time()
688
+ with torch.inference_mode():
689
+ audio_pt, sizes_pt = wrapper(input_features)
690
+ print(f" pt: {time.time() - t0:.2f}s audio_embeds={tuple(audio_pt.shape)}")
691
+
692
+ print(f" ONNX inference: {onnx_path}")
693
+ sess = ort.InferenceSession(str(onnx_path), providers=["CPUExecutionProvider"])
694
+ t0 = time.time()
695
+ audio_ort, sizes_ort = sess.run(
696
+ ["audio_embeds", "audio_embed_sizes"],
697
+ {"input_features": input_features.numpy().astype(np.float32)},
698
+ )
699
+ print(f" ort: {time.time() - t0:.2f}s audio_embeds={tuple(audio_ort.shape)}")
700
+
701
+ pt_np = audio_pt.detach().float().cpu().numpy()
702
+ abs_err = np.abs(pt_np - audio_ort)
703
+ max_err = float(abs_err.max())
704
+ mean_err = float(abs_err.mean())
705
+ p99 = float(np.percentile(abs_err, 99))
706
+
707
+ sizes_pt_np = sizes_pt.detach().cpu().numpy().astype(np.int64)
708
+ sizes_ok = bool(np.array_equal(sizes_pt_np, sizes_ort.astype(np.int64)))
709
+
710
+ if argmax_only:
711
+ # The encoder's audio_embeds feed into the LLM, where the actual ship
712
+ # gate (transcript byte-exact, argmax stable) lives. The continuous
713
+ # audio_embeds delta is informational only in INT8 mode.
714
+ ok = sizes_ok
715
+ else:
716
+ ok = max_err <= abs_tol and sizes_ok
717
+
718
+ print(f" max_abs_err={max_err:.3e} mean={mean_err:.3e} p99={p99:.3e}")
719
+ print(f" audio_embed_sizes pt={sizes_pt_np.tolist()} ort={sizes_ort.tolist()} match={sizes_ok}")
720
+ print(f" encoder parity: {'PASS' if ok else 'FAIL'}{' (argmax-only)' if argmax_only else ''}")
721
+
722
+ return {
723
+ "ok": ok,
724
+ "abs_tol": abs_tol,
725
+ "argmax_only": argmax_only,
726
+ "max_abs_err": max_err,
727
+ "mean_abs_err": mean_err,
728
+ "p99_abs_err": p99,
729
+ "audio_embeds_shape_pt": list(pt_np.shape),
730
+ "audio_embeds_shape_ort": list(audio_ort.shape),
731
+ "audio_embed_sizes_pt": sizes_pt_np.tolist(),
732
+ "audio_embed_sizes_ort": sizes_ort.tolist(),
733
+ "audio_embed_sizes_match": sizes_ok,
734
+ "audio_embeds_stats_pt": tensor_stats(audio_pt),
735
+ "audio_embeds_stats_ort": tensor_stats(audio_ort),
736
+ }
737
+
738
+
739
+ def build_inputs_embeds(model: nn.Module, processor: Any, waveform: np.ndarray) -> tuple[torch.Tensor, torch.Tensor, dict]:
740
+ """Build the post-splice `inputs_embeds [1, N, 2048]` and `position_ids` for
741
+ parity, exactly mirroring the PyTorch path:
742
+ 1. Render the chat prompt with `<|audio|>` -> repeated audio token.
743
+ 2. Run encoder + projector to get audio embeds.
744
+ 3. masked_scatter audio embeds into the text embeddings at audio-token positions.
745
+ """
746
+ chat = [{"role": "user", "content": USER_PROMPT_TRANSCRIBE}]
747
+ rendered = processor.tokenizer.apply_chat_template(
748
+ chat, tokenize=False, add_generation_prompt=True
749
+ )
750
+ inputs = processor(rendered, [waveform], sampling_rate=16000, return_tensors="pt")
751
+ input_ids = inputs["input_ids"].to(torch.long)
752
+ input_features = inputs["input_features"].to(torch.float32)
753
+ input_features_mask = inputs["input_features_mask"]
754
+ print(f" prompt token ids shape={tuple(input_ids.shape)}")
755
+ print(f" input_features shape={tuple(input_features.shape)} input_features_mask shape={tuple(input_features_mask.shape)}")
756
+
757
+ with torch.inference_mode():
758
+ audio_outputs = model.get_audio_features(input_features, return_dict=True)
759
+ audio_embeds = audio_outputs.pooler_output
760
+ # The reference uses model.dtype; we forced fp32 at load.
761
+ inputs_embeds = model.get_merged_audio_embeddings(
762
+ input_ids=input_ids,
763
+ audio_features=audio_embeds,
764
+ input_features_mask=input_features_mask,
765
+ )
766
+ inputs_embeds = inputs_embeds.to(torch.float32)
767
+ N = inputs_embeds.shape[1]
768
+ position_ids = torch.arange(N, dtype=torch.long).unsqueeze(0).expand(1, N).contiguous()
769
+
770
+ info = {
771
+ "input_ids_shape": list(input_ids.shape),
772
+ "audio_embeds_shape": list(audio_embeds.shape),
773
+ "input_features_mask_shape": list(input_features_mask.shape),
774
+ "inputs_embeds_shape": list(inputs_embeds.shape),
775
+ "n_audio_tokens": int((input_ids == model.config.audio_token_id).sum().item()),
776
+ "input_features_mask_sum": int(input_features_mask.sum().item()),
777
+ }
778
+ print(f" inputs_embeds shape={tuple(inputs_embeds.shape)} audio_tokens={info['n_audio_tokens']}")
779
+ return inputs_embeds, position_ids, info
780
+
781
+
782
+ def llm_parity_e2e(
783
+ model: nn.Module,
784
+ processor: Any,
785
+ waveform: np.ndarray,
786
+ prompt_onnx: Path,
787
+ decode_onnx: Path,
788
+ baseline_json: Path,
789
+ max_new_tokens: int,
790
+ abs_tol: float,
791
+ argmax_only: bool = False,
792
+ ) -> dict[str, Any]:
793
+ """Greedy-decode end-to-end through the ONNX graphs and compare against
794
+ the captured PyTorch baseline transcript token-for-token.
795
+ """
796
+ import onnxruntime as ort
797
+
798
+ print("\n=== prompt_encode + decode_step end-to-end parity ===")
799
+ inputs_embeds, position_ids, embed_info = build_inputs_embeds(model, processor, waveform)
800
+ N = inputs_embeds.shape[1]
801
+
802
+ # Build the 4-D additive causal+pad mask Python-side.
803
+ attn_2d = torch.ones((1, N), dtype=torch.long)
804
+ attn_4d_prompt = _build_causal_mask_4d(attn_2d, T_past=0, dtype=torch.float32)
805
+
806
+ # Reference: PyTorch logits at the prompt's last position; expected first
807
+ # generated token == baseline new_token_ids[0].
808
+ print(" loading PyTorch reference path (lm_head / logits_scaling)")
809
+ with torch.inference_mode():
810
+ out = model.language_model(
811
+ inputs_embeds=inputs_embeds,
812
+ attention_mask=attn_4d_prompt,
813
+ position_ids=position_ids,
814
+ use_cache=True,
815
+ past_key_values=None,
816
+ )
817
+ # The language_model is a GraniteForCausalLM; out.logits is already
818
+ # divided by logits_scaling. Use it as the strict-parity reference.
819
+ pt_logits = out.logits.detach().float().cpu().numpy()
820
+ pt_past = out.past_key_values
821
+
822
+ print(f" pt prompt logits shape={pt_logits.shape} argmax_last={int(pt_logits[0, -1].argmax())}")
823
+
824
+ # ---- ONNX: prompt_encode ----
825
+ print(f" loading ONNX sessions")
826
+ so = ort.SessionOptions()
827
+ sess_prompt = ort.InferenceSession(str(prompt_onnx), so, providers=["CPUExecutionProvider"])
828
+ sess_decode = ort.InferenceSession(str(decode_onnx), so, providers=["CPUExecutionProvider"])
829
+
830
+ num_layers = len(model.language_model.model.layers)
831
+ feeds_prompt = {
832
+ "inputs_embeds": inputs_embeds.numpy().astype(np.float32),
833
+ "position_ids": position_ids.numpy().astype(np.int64),
834
+ "attention_mask": attn_4d_prompt.numpy().astype(np.float32),
835
+ }
836
+ print(" running prompt_encode.onnx")
837
+ t0 = time.time()
838
+ prompt_outs = sess_prompt.run(None, feeds_prompt)
839
+ print(f" forward: {time.time() - t0:.2f}s")
840
+ prompt_logits = prompt_outs[0]
841
+ past_kv_flat = list(prompt_outs[1:])
842
+ assert len(past_kv_flat) == 2 * num_layers
843
+
844
+ # Compare prompt-stage logits.
845
+ prompt_diff = np.abs(prompt_logits - pt_logits)
846
+ prompt_max_err = float(prompt_diff.max())
847
+ prompt_mean_err = float(prompt_diff.mean())
848
+ pt_argmax = pt_logits.argmax(-1)
849
+ ort_argmax = prompt_logits.argmax(-1)
850
+ prompt_argmax_mismatches = int((pt_argmax != ort_argmax).sum())
851
+ print(f" prompt logits max_abs_err={prompt_max_err:.3e} mean={prompt_mean_err:.3e} "
852
+ f"argmax_mismatches={prompt_argmax_mismatches}/{pt_argmax.size}")
853
+
854
+ # First generated token: argmax at the last prompt position (this is what
855
+ # GenerationMixin's greedy path does).
856
+ embed_tokens = model.language_model.model.embed_tokens
857
+ eos_id = int(model.config.text_config.eos_token_id)
858
+
859
+ onnx_new_tokens: list[int] = [int(prompt_logits[0, -1].argmax())]
860
+ onnx_step_logits: list[np.ndarray] = [prompt_logits[0, -1].astype(np.float32)]
861
+
862
+ print(f" greedy-decoding up to {max_new_tokens} new tokens through decode_step.onnx")
863
+ t0 = time.time()
864
+ for step in range(1, max_new_tokens):
865
+ prev_tok = onnx_new_tokens[-1]
866
+ if prev_tok == eos_id:
867
+ break
868
+ T_past = N + step - 1
869
+ T_total = T_past + 1
870
+
871
+ # Build the next inputs_embeds via the model's embed_tokens.
872
+ prev_id_tensor = torch.tensor([[prev_tok]], dtype=torch.long)
873
+ with torch.inference_mode():
874
+ next_embed = embed_tokens(prev_id_tensor).to(torch.float32)
875
+
876
+ # 4-D additive mask of zeros for unmasked positions; padding is irrelevant
877
+ # because attention_mask_2d is all-ones throughout the decode loop.
878
+ attn_2d_step = torch.ones((1, T_total), dtype=torch.long)
879
+ attn_4d_step = _build_causal_mask_4d(attn_2d_step, T_past=T_past, dtype=torch.float32)
880
+
881
+ feeds: dict[str, np.ndarray] = {
882
+ "inputs_embeds": next_embed.numpy().astype(np.float32),
883
+ "position_ids": np.array([[T_past]], dtype=np.int64),
884
+ "attention_mask": attn_4d_step.numpy().astype(np.float32),
885
+ }
886
+ for i in range(num_layers):
887
+ feeds[f"past_key_values.{i}.key"] = past_kv_flat[2 * i]
888
+ feeds[f"past_key_values.{i}.value"] = past_kv_flat[2 * i + 1]
889
+
890
+ outs = sess_decode.run(None, feeds)
891
+ step_logits = outs[0]
892
+ new_past = list(outs[1:])
893
+ assert len(new_past) == 2 * num_layers
894
+ past_kv_flat = new_past
895
+
896
+ nt = int(step_logits[0, 0].argmax())
897
+ onnx_step_logits.append(step_logits[0, 0].astype(np.float32))
898
+ onnx_new_tokens.append(nt)
899
+
900
+ print(f" {len(onnx_new_tokens) - 1} decode_step forwards: {time.time() - t0:.2f}s")
901
+ onnx_transcript = processor.tokenizer.decode(
902
+ [t for t in onnx_new_tokens if t != eos_id], skip_special_tokens=True
903
+ )
904
+ print(f" onnx new tokens: {onnx_new_tokens}")
905
+ print(f" onnx transcript: {onnx_transcript!r}")
906
+
907
+ baseline = json.loads(baseline_json.read_text())
908
+ baseline_tokens = baseline["new_token_ids"]
909
+ baseline_transcript = baseline["transcript"]
910
+ tokens_match = onnx_new_tokens == baseline_tokens
911
+ transcript_match = onnx_transcript == baseline_transcript
912
+ print(f" baseline transcript: {baseline_transcript!r}")
913
+ print(f" tokens match: {tokens_match} transcript match: {transcript_match}")
914
+
915
+ # Per-step parity vs PyTorch reference for the first 5 steps.
916
+ per_step_compare: list[dict[str, Any]] = []
917
+ pt_step_logits = None
918
+ if max_new_tokens >= 1:
919
+ # Recompute PyTorch reference logits per step via model.generate, to
920
+ # avoid having to maintain an alternate decode loop here.
921
+ with torch.inference_mode():
922
+ chat = [{"role": "user", "content": USER_PROMPT_TRANSCRIBE}]
923
+ rendered = processor.tokenizer.apply_chat_template(
924
+ chat, tokenize=False, add_generation_prompt=True
925
+ )
926
+ ref_inputs = processor(rendered, [waveform], sampling_rate=16000, return_tensors="pt")
927
+ gen = model.generate(
928
+ **ref_inputs,
929
+ max_new_tokens=max_new_tokens,
930
+ do_sample=False,
931
+ num_beams=1,
932
+ return_dict_in_generate=True,
933
+ output_scores=True,
934
+ )
935
+ pt_step_logits = [s[0].detach().float().cpu().numpy() for s in gen.scores]
936
+
937
+ n_compare = min(len(pt_step_logits), len(onnx_step_logits))
938
+ for i in range(n_compare):
939
+ ref = pt_step_logits[i].astype(np.float32)
940
+ ours = onnx_step_logits[i].astype(np.float32)
941
+ d = np.abs(ref - ours)
942
+ per_step_compare.append({
943
+ "step": i,
944
+ "ref_token": int(ref.argmax()),
945
+ "onnx_token": int(ours.argmax()),
946
+ "argmax_match": int(ref.argmax()) == int(ours.argmax()),
947
+ "max_abs_err": float(d.max()),
948
+ "mean_abs_err": float(d.mean()),
949
+ })
950
+
951
+ overall_max = max((s["max_abs_err"] for s in per_step_compare), default=0.0)
952
+ overall_argmax_mm = sum(0 if s["argmax_match"] else 1 for s in per_step_compare)
953
+
954
+ if argmax_only:
955
+ # INT8 ship gate: end-to-end transcript + decoded token IDs match the
956
+ # baseline exactly. Prompt-stage and per-step max-abs deltas vs FP32
957
+ # are recorded for reporting but not blocking.
958
+ ok = tokens_match and transcript_match
959
+ else:
960
+ ok = (
961
+ tokens_match
962
+ and transcript_match
963
+ and prompt_argmax_mismatches == 0
964
+ and overall_argmax_mm == 0
965
+ )
966
+
967
+ return {
968
+ "ok": ok,
969
+ "abs_tol": abs_tol,
970
+ "argmax_only": argmax_only,
971
+ "embed_info": embed_info,
972
+ "N_prompt": N,
973
+ "prompt_logits_max_abs_err": prompt_max_err,
974
+ "prompt_logits_mean_abs_err": prompt_mean_err,
975
+ "prompt_argmax_mismatches": prompt_argmax_mismatches,
976
+ "prompt_argmax_total": int(pt_argmax.size),
977
+ "onnx_new_tokens": onnx_new_tokens,
978
+ "baseline_new_tokens": baseline_tokens,
979
+ "tokens_match": tokens_match,
980
+ "onnx_transcript": onnx_transcript,
981
+ "baseline_transcript": baseline_transcript,
982
+ "transcript_match": transcript_match,
983
+ "per_step_compare": per_step_compare,
984
+ "overall_max_abs_err_step": overall_max,
985
+ "overall_argmax_mismatches_step": overall_argmax_mm,
986
+ }
987
+
988
+
989
+ # ---------------------------------------------------------------------------
990
+ # Main.
991
+ # ---------------------------------------------------------------------------
992
+
993
+
994
+ def main() -> None:
995
+ p = argparse.ArgumentParser()
996
+ p.add_argument("--audio", default=str(DEFAULT_AUDIO))
997
+ p.add_argument("--baseline", default=str(DEFAULT_BASELINE))
998
+ p.add_argument("--model-dir", default=str(DEFAULT_MODEL_DIR))
999
+ p.add_argument("--out-dir", default=str(DEFAULT_OUT_DIR))
1000
+ p.add_argument("--abs-tol", type=float, default=1e-3)
1001
+ p.add_argument(
1002
+ "--stages",
1003
+ default="encoder,prompt,decode",
1004
+ help="comma-separated subset of {encoder, prompt, decode}",
1005
+ )
1006
+ p.add_argument(
1007
+ "--skip-export", action="store_true", help="skip export, run parity on existing files"
1008
+ )
1009
+ p.add_argument("--max-new-tokens", type=int, default=80)
1010
+ p.add_argument(
1011
+ "--graph-suffix",
1012
+ default="",
1013
+ help="suffix appended to graph stems (e.g. '_int8') so parity runs against "
1014
+ "encoder<suffix>.onnx etc. Parity output goes to parity<suffix>.json. "
1015
+ "When set, --skip-export is implied.",
1016
+ )
1017
+ args = p.parse_args()
1018
+
1019
+ stages = {s.strip() for s in args.stages.split(",") if s.strip()}
1020
+ valid = {"encoder", "prompt", "decode"}
1021
+ bad = stages - valid
1022
+ if bad:
1023
+ raise SystemExit(f"unknown stage(s): {bad}; valid: {sorted(valid)}")
1024
+
1025
+ out_dir = Path(args.out_dir)
1026
+ suffix = args.graph_suffix
1027
+ if suffix and not args.skip_export:
1028
+ print(f" --graph-suffix={suffix!r} set; implying --skip-export")
1029
+ args.skip_export = True
1030
+ encoder_path = out_dir / f"encoder{suffix}.onnx"
1031
+ prompt_path = out_dir / f"prompt_encode{suffix}.onnx"
1032
+ decode_path = out_dir / f"decode_step{suffix}.onnx"
1033
+ parity_json = out_dir / f"parity{suffix}.json"
1034
+
1035
+ print(f"audio: {args.audio}")
1036
+ print(f"model_dir: {args.model_dir}")
1037
+ print(f"out_dir: {out_dir}")
1038
+ print(f"stages: {sorted(stages)}")
1039
+ waveform = load_audio(Path(args.audio))
1040
+ print(f" duration={waveform.shape[0] / 16000:.2f}s")
1041
+
1042
+ print("loading model...")
1043
+ model, processor = load_base_model(Path(args.model_dir))
1044
+
1045
+ print("patching conformer attention for tracing...")
1046
+ patch_conformer_for_tracing(model)
1047
+ print("patching projector for dynamic T_audio tracing...")
1048
+ patch_projector_for_tracing(model)
1049
+
1050
+ # Useful constants from the loaded config.
1051
+ text_cfg = model.config.text_config
1052
+ num_layers = int(text_cfg.num_hidden_layers)
1053
+ logits_scaling = float(text_cfg.logits_scaling)
1054
+ print(f" num_layers={num_layers} logits_scaling={logits_scaling}")
1055
+ print(f" audio_token_id={model.config.audio_token_id} hidden_size={text_cfg.hidden_size}")
1056
+
1057
+ # Sample inputs for tracing.
1058
+ sample_inputs = processor(
1059
+ USER_PROMPT_TRANSCRIBE, [waveform], sampling_rate=16000, return_tensors="pt"
1060
+ )
1061
+ sample_features = sample_inputs["input_features"].to(torch.float32)
1062
+
1063
+ parity_payload: dict[str, Any] = {
1064
+ "abs_tol": args.abs_tol,
1065
+ "stages_run": sorted(stages),
1066
+ "input_features_shape": list(sample_features.shape),
1067
+ }
1068
+
1069
+ # ----- Encoder export + parity -----
1070
+ if "encoder" in stages:
1071
+ wrapper = EncoderProjectorWrapper(
1072
+ encoder=model.encoder,
1073
+ projector=model.projector,
1074
+ window_size=int(model.config.window_size),
1075
+ downsample_rate=int(model.config.downsample_rate),
1076
+ ).eval()
1077
+
1078
+ if not args.skip_export:
1079
+ with torch.inference_mode():
1080
+ export_encoder(
1081
+ wrapper=wrapper,
1082
+ sample_input_features=sample_features,
1083
+ out_path=encoder_path,
1084
+ opset=20,
1085
+ ir_version=10,
1086
+ )
1087
+
1088
+ parity_payload["encoder"] = encoder_parity(
1089
+ wrapper=wrapper,
1090
+ processor=processor,
1091
+ waveform=waveform,
1092
+ onnx_path=encoder_path,
1093
+ abs_tol=args.abs_tol,
1094
+ argmax_only=bool(suffix),
1095
+ )
1096
+
1097
+ # ----- LLM (prompt + decode) export -----
1098
+ if {"prompt", "decode"} & stages and not args.skip_export:
1099
+ # Build a sample inputs_embeds + position_ids by running encoder + splice.
1100
+ print("\nbuilding sample inputs_embeds for LLM export trace...")
1101
+ sample_embeds, sample_pos_ids, _info = build_inputs_embeds(model, processor, waveform)
1102
+ N = sample_embeds.shape[1]
1103
+ sample_attn_4d = _build_causal_mask_4d(
1104
+ torch.ones((1, N), dtype=torch.long), T_past=0, dtype=torch.float32
1105
+ )
1106
+
1107
+ if "prompt" in stages:
1108
+ prompt_wrapper = PromptEncodeWrapper(
1109
+ llm_model=model.language_model.model,
1110
+ lm_head=model.language_model.lm_head,
1111
+ num_layers=num_layers,
1112
+ logits_scaling=logits_scaling,
1113
+ ).eval()
1114
+ with torch.inference_mode():
1115
+ export_prompt_encode(
1116
+ wrapper=prompt_wrapper,
1117
+ sample_inputs_embeds=sample_embeds,
1118
+ sample_position_ids=sample_pos_ids,
1119
+ sample_attention_mask=sample_attn_4d,
1120
+ out_path=prompt_path,
1121
+ num_layers=num_layers,
1122
+ opset=20,
1123
+ ir_version=10,
1124
+ )
1125
+
1126
+ if "decode" in stages:
1127
+ # We need a sample past_kv set for the decode_step trace; harvest by
1128
+ # running the prompt wrapper once.
1129
+ prompt_wrapper = PromptEncodeWrapper(
1130
+ llm_model=model.language_model.model,
1131
+ lm_head=model.language_model.lm_head,
1132
+ num_layers=num_layers,
1133
+ logits_scaling=logits_scaling,
1134
+ ).eval()
1135
+ with torch.inference_mode():
1136
+ p_outs = prompt_wrapper(sample_embeds, sample_pos_ids, sample_attn_4d)
1137
+ sample_past_kv_flat = tuple(t.detach().clone() for t in p_outs[1:])
1138
+ assert len(sample_past_kv_flat) == 2 * num_layers
1139
+
1140
+ embed_tokens = model.language_model.model.embed_tokens
1141
+ with torch.inference_mode():
1142
+ sample_step_embed = (
1143
+ embed_tokens(torch.tensor([[0]], dtype=torch.long)).to(torch.float32)
1144
+ )
1145
+ sample_step_pos = torch.tensor([[N]], dtype=torch.long)
1146
+ sample_step_attn_2d = torch.ones((1, N + 1), dtype=torch.long)
1147
+ sample_step_attn_4d = _build_causal_mask_4d(
1148
+ sample_step_attn_2d, T_past=N, dtype=torch.float32
1149
+ )
1150
+
1151
+ decode_wrapper = DecodeStepWrapper(
1152
+ llm_model=model.language_model.model,
1153
+ lm_head=model.language_model.lm_head,
1154
+ num_layers=num_layers,
1155
+ logits_scaling=logits_scaling,
1156
+ ).eval()
1157
+ with torch.inference_mode():
1158
+ export_decode_step(
1159
+ wrapper=decode_wrapper,
1160
+ sample_inputs_embeds=sample_step_embed,
1161
+ sample_position_ids=sample_step_pos,
1162
+ sample_attention_mask=sample_step_attn_4d,
1163
+ sample_past_kv_flat=sample_past_kv_flat,
1164
+ out_path=decode_path,
1165
+ num_layers=num_layers,
1166
+ opset=20,
1167
+ ir_version=10,
1168
+ )
1169
+
1170
+ # ----- end-to-end LLM parity -----
1171
+ if {"prompt", "decode"} <= stages and prompt_path.exists() and decode_path.exists():
1172
+ parity_payload["llm_e2e"] = llm_parity_e2e(
1173
+ model=model,
1174
+ processor=processor,
1175
+ waveform=waveform,
1176
+ prompt_onnx=prompt_path,
1177
+ decode_onnx=decode_path,
1178
+ baseline_json=Path(args.baseline),
1179
+ max_new_tokens=args.max_new_tokens,
1180
+ abs_tol=args.abs_tol,
1181
+ argmax_only=bool(suffix),
1182
+ )
1183
+
1184
+ # ----- Per-graph size + int8-vs-fp32 deltas (only when graph-suffix set) -----
1185
+ if suffix:
1186
+ parity_payload["graph_suffix"] = suffix
1187
+ parity_payload["graphs"] = {}
1188
+ for label, p in (
1189
+ ("encoder", encoder_path),
1190
+ ("prompt_encode", prompt_path),
1191
+ ("decode_step", decode_path),
1192
+ ):
1193
+ if not p.exists():
1194
+ continue
1195
+ data = p.with_name(p.name + "_data")
1196
+ entry = {
1197
+ "graph_path": str(p),
1198
+ "graph_size_bytes": int(p.stat().st_size),
1199
+ "sidecar_path": str(data) if data.exists() else None,
1200
+ "int8_size_bytes": int(data.stat().st_size) if data.exists() else None,
1201
+ }
1202
+ fp32 = p.with_name(p.name.replace(suffix, ""))
1203
+ fp32_data = fp32.with_name(fp32.name + "_data")
1204
+ if fp32.exists() and fp32_data.exists():
1205
+ entry["fp32_sidecar_path"] = str(fp32_data)
1206
+ entry["fp32_size_bytes"] = int(fp32_data.stat().st_size)
1207
+ if entry["int8_size_bytes"]:
1208
+ entry["size_ratio"] = entry["int8_size_bytes"] / entry["fp32_size_bytes"]
1209
+ parity_payload["graphs"][label] = entry
1210
+
1211
+ # ----- Write parity report -----
1212
+ parity_json.parent.mkdir(parents=True, exist_ok=True)
1213
+ parity_json.write_text(json.dumps(parity_payload, indent=2))
1214
+ print(f"\nwrote parity report -> {parity_json}")
1215
+
1216
+ # ----- Final summary -----
1217
+ failures = []
1218
+ print("\n--- summary ---")
1219
+ if "encoder" in parity_payload:
1220
+ e = parity_payload["encoder"]
1221
+ print(f" encoder: {'PASS' if e['ok'] else 'FAIL'} max_abs_err={e['max_abs_err']:.3e}")
1222
+ if not e["ok"]:
1223
+ failures.append("encoder")
1224
+ if "llm_e2e" in parity_payload:
1225
+ l = parity_payload["llm_e2e"]
1226
+ print(
1227
+ f" llm_e2e: {'PASS' if l['ok'] else 'FAIL'} "
1228
+ f"prompt_argmax_mm={l['prompt_argmax_mismatches']} "
1229
+ f"step_argmax_mm={l['overall_argmax_mismatches_step']} "
1230
+ f"transcript_match={l['transcript_match']}"
1231
+ )
1232
+ if not l["ok"]:
1233
+ failures.append("llm_e2e")
1234
+ if failures:
1235
+ raise SystemExit(f"failed: {failures}")
1236
+
1237
+
1238
+ if __name__ == "__main__":
1239
+ main()
granite_export_metadata.json ADDED
@@ -0,0 +1,454 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "variant": "base",
3
+ "upstream": {
4
+ "repo": "ibm-granite/granite-speech-4.1-2b",
5
+ "url": "https://huggingface.co/ibm-granite/granite-speech-4.1-2b",
6
+ "license": "Apache-2.0"
7
+ },
8
+ "topology": "encoder + prompt_encode + decode_step (autoregressive)",
9
+ "graphs": [
10
+ {
11
+ "name": "encoder.onnx",
12
+ "sidecar": "encoder.onnx_data",
13
+ "precision": "fp32",
14
+ "size_bytes": 912937,
15
+ "sidecar_size_bytes": 1903334768,
16
+ "opset": 20,
17
+ "ir_version": 10,
18
+ "ai_onnx_only": true,
19
+ "inputs": [
20
+ {
21
+ "name": "input_features",
22
+ "shape": [
23
+ "B",
24
+ "T",
25
+ 160
26
+ ],
27
+ "dtype": "float32"
28
+ }
29
+ ],
30
+ "outputs": [
31
+ {
32
+ "name": "audio_embeds",
33
+ "shape": [
34
+ "B",
35
+ "T_audio",
36
+ 2048
37
+ ],
38
+ "dtype": "float32"
39
+ },
40
+ {
41
+ "name": "audio_embed_sizes",
42
+ "shape": [
43
+ "B"
44
+ ],
45
+ "dtype": "int64"
46
+ }
47
+ ]
48
+ },
49
+ {
50
+ "name": "encoder_int8.onnx",
51
+ "sidecar": "encoder_int8.onnx_data",
52
+ "precision": "int8-weights-only",
53
+ "size_bytes": 2608070,
54
+ "sidecar_size_bytes": 787117424,
55
+ "opset": 20,
56
+ "ir_version": 10,
57
+ "ai_onnx_only": true,
58
+ "inputs": [
59
+ {
60
+ "name": "input_features",
61
+ "shape": [
62
+ "B",
63
+ "T",
64
+ 160
65
+ ],
66
+ "dtype": "float32"
67
+ }
68
+ ],
69
+ "outputs": [
70
+ {
71
+ "name": "audio_embeds",
72
+ "shape": [
73
+ "B",
74
+ "T_audio",
75
+ 2048
76
+ ],
77
+ "dtype": "float32"
78
+ },
79
+ {
80
+ "name": "audio_embed_sizes",
81
+ "shape": [
82
+ "B"
83
+ ],
84
+ "dtype": "int64"
85
+ }
86
+ ]
87
+ },
88
+ {
89
+ "name": "prompt_encode.onnx",
90
+ "sidecar": "prompt_encode.onnx_data",
91
+ "precision": "fp32",
92
+ "size_bytes": 1844341,
93
+ "sidecar_size_bytes": 6527008768,
94
+ "opset": 20,
95
+ "ir_version": 10,
96
+ "ai_onnx_only": true,
97
+ "inputs": [
98
+ {
99
+ "name": "inputs_embeds",
100
+ "shape": [
101
+ "B",
102
+ "N",
103
+ 2048
104
+ ],
105
+ "dtype": "float32"
106
+ },
107
+ {
108
+ "name": "position_ids",
109
+ "shape": [
110
+ "B",
111
+ "N"
112
+ ],
113
+ "dtype": "int64"
114
+ },
115
+ {
116
+ "name": "attention_mask",
117
+ "shape": [
118
+ "B",
119
+ 1,
120
+ "N",
121
+ "N"
122
+ ],
123
+ "dtype": "float32"
124
+ }
125
+ ],
126
+ "outputs": [
127
+ {
128
+ "name": "logits",
129
+ "shape": [
130
+ "B",
131
+ "N",
132
+ 100353
133
+ ],
134
+ "dtype": "float32"
135
+ },
136
+ {
137
+ "name": "present.{i}.{key,value}",
138
+ "shape": [
139
+ "B",
140
+ 4,
141
+ "N",
142
+ 128
143
+ ],
144
+ "dtype": "float32",
145
+ "note": "40 layers x 2 (key, value) = 80 KV-cache outputs"
146
+ }
147
+ ]
148
+ },
149
+ {
150
+ "name": "prompt_encode_int8.onnx",
151
+ "sidecar": "prompt_encode_int8.onnx_data",
152
+ "precision": "int8-weights-only",
153
+ "size_bytes": 6419491,
154
+ "sidecar_size_bytes": 1632249856,
155
+ "opset": 20,
156
+ "ir_version": 10,
157
+ "ai_onnx_only": true,
158
+ "inputs": [
159
+ {
160
+ "name": "inputs_embeds",
161
+ "shape": [
162
+ "B",
163
+ "N",
164
+ 2048
165
+ ],
166
+ "dtype": "float32"
167
+ },
168
+ {
169
+ "name": "position_ids",
170
+ "shape": [
171
+ "B",
172
+ "N"
173
+ ],
174
+ "dtype": "int64"
175
+ },
176
+ {
177
+ "name": "attention_mask",
178
+ "shape": [
179
+ "B",
180
+ 1,
181
+ "N",
182
+ "N"
183
+ ],
184
+ "dtype": "float32"
185
+ }
186
+ ],
187
+ "outputs": [
188
+ {
189
+ "name": "logits",
190
+ "shape": [
191
+ "B",
192
+ "N",
193
+ 100353
194
+ ],
195
+ "dtype": "float32"
196
+ },
197
+ {
198
+ "name": "present.{i}.{key,value}",
199
+ "shape": [
200
+ "B",
201
+ 4,
202
+ "N",
203
+ 128
204
+ ],
205
+ "dtype": "float32",
206
+ "note": "40 layers x 2 (key, value) = 80 KV-cache outputs"
207
+ }
208
+ ]
209
+ },
210
+ {
211
+ "name": "decode_step.onnx",
212
+ "sidecar": "decode_step.onnx_data",
213
+ "precision": "fp32",
214
+ "size_bytes": 1849786,
215
+ "sidecar_size_bytes": 6527008768,
216
+ "opset": 20,
217
+ "ir_version": 10,
218
+ "ai_onnx_only": true,
219
+ "inputs": [
220
+ {
221
+ "name": "inputs_embeds",
222
+ "shape": [
223
+ "B",
224
+ 1,
225
+ 2048
226
+ ],
227
+ "dtype": "float32"
228
+ },
229
+ {
230
+ "name": "position_ids",
231
+ "shape": [
232
+ "B",
233
+ 1
234
+ ],
235
+ "dtype": "int64"
236
+ },
237
+ {
238
+ "name": "attention_mask",
239
+ "shape": [
240
+ "B",
241
+ 1,
242
+ 1,
243
+ "T_total"
244
+ ],
245
+ "dtype": "float32"
246
+ },
247
+ {
248
+ "name": "past_key_values.{i}.{key,value}",
249
+ "shape": [
250
+ "B",
251
+ 4,
252
+ "T_past",
253
+ 128
254
+ ],
255
+ "dtype": "float32",
256
+ "note": "40 layers x 2 = 80 KV-cache inputs"
257
+ }
258
+ ],
259
+ "outputs": [
260
+ {
261
+ "name": "logits",
262
+ "shape": [
263
+ "B",
264
+ 1,
265
+ 100353
266
+ ],
267
+ "dtype": "float32"
268
+ },
269
+ {
270
+ "name": "present.{i}.{key,value}",
271
+ "shape": [
272
+ "B",
273
+ 4,
274
+ "T_total",
275
+ 128
276
+ ],
277
+ "dtype": "float32",
278
+ "note": "40 layers x 2 = 80 KV-cache outputs"
279
+ }
280
+ ]
281
+ },
282
+ {
283
+ "name": "decode_step_int8.onnx",
284
+ "sidecar": "decode_step_int8.onnx_data",
285
+ "precision": "int8-weights-only",
286
+ "size_bytes": 6426226,
287
+ "sidecar_size_bytes": 1632249856,
288
+ "opset": 20,
289
+ "ir_version": 10,
290
+ "ai_onnx_only": true,
291
+ "inputs": [
292
+ {
293
+ "name": "inputs_embeds",
294
+ "shape": [
295
+ "B",
296
+ 1,
297
+ 2048
298
+ ],
299
+ "dtype": "float32"
300
+ },
301
+ {
302
+ "name": "position_ids",
303
+ "shape": [
304
+ "B",
305
+ 1
306
+ ],
307
+ "dtype": "int64"
308
+ },
309
+ {
310
+ "name": "attention_mask",
311
+ "shape": [
312
+ "B",
313
+ 1,
314
+ 1,
315
+ "T_total"
316
+ ],
317
+ "dtype": "float32"
318
+ },
319
+ {
320
+ "name": "past_key_values.{i}.{key,value}",
321
+ "shape": [
322
+ "B",
323
+ 4,
324
+ "T_past",
325
+ 128
326
+ ],
327
+ "dtype": "float32",
328
+ "note": "40 layers x 2 = 80 KV-cache inputs"
329
+ }
330
+ ],
331
+ "outputs": [
332
+ {
333
+ "name": "logits",
334
+ "shape": [
335
+ "B",
336
+ 1,
337
+ 100353
338
+ ],
339
+ "dtype": "float32"
340
+ },
341
+ {
342
+ "name": "present.{i}.{key,value}",
343
+ "shape": [
344
+ "B",
345
+ 4,
346
+ "T_total",
347
+ 128
348
+ ],
349
+ "dtype": "float32",
350
+ "note": "40 layers x 2 = 80 KV-cache outputs"
351
+ }
352
+ ]
353
+ }
354
+ ],
355
+ "parity": {
356
+ "fp32": {
357
+ "encoder": {
358
+ "argmax_only": false,
359
+ "max_abs_err": 4.481524229049683e-06,
360
+ "mean_abs_err": 1.243776637238625e-07,
361
+ "p99_abs_err": 6.463378667831421e-07,
362
+ "audio_embed_sizes_match": true,
363
+ "input_features_shape": [
364
+ 1,
365
+ 844,
366
+ 160
367
+ ],
368
+ "audio_embeds_shape": [
369
+ 1,
370
+ 171,
371
+ 2048
372
+ ]
373
+ },
374
+ "llm_e2e": {
375
+ "argmax_only": false,
376
+ "prompt_argmax_mismatches": 0,
377
+ "prompt_argmax_total": 190,
378
+ "prompt_logits_max_abs_err": 0.000364,
379
+ "decode_steps": 51,
380
+ "decode_argmax_mismatches": 0,
381
+ "decode_max_abs_err_step": null,
382
+ "tokens_match": true,
383
+ "transcript_match": true,
384
+ "source_note": "from task-11 record (parity.json was overwritten by a later encoder-only re-run; see dev-plan.md)"
385
+ }
386
+ },
387
+ "int8": {
388
+ "encoder": {
389
+ "argmax_only": true,
390
+ "max_abs_err": 0.16911625862121582,
391
+ "mean_abs_err": 0.010853650979697704,
392
+ "p99_abs_err": 0.044730618596076965
393
+ },
394
+ "llm_e2e": {
395
+ "argmax_only": true,
396
+ "prompt_argmax_mismatches": 58,
397
+ "prompt_argmax_total": 190,
398
+ "prompt_logits_max_abs_err": 10.136197090148926,
399
+ "prompt_logits_mean_abs_err": 0.778969943523407,
400
+ "decode_steps": 51,
401
+ "decode_argmax_mismatches": 0,
402
+ "decode_max_abs_err_step": 5.762608528137207,
403
+ "tokens_match": true,
404
+ "transcript_match": true
405
+ }
406
+ }
407
+ },
408
+ "multi_clip_parity": {
409
+ "rows": [
410
+ {
411
+ "name": "is-it-more-wood",
412
+ "duration_s": 46.9,
413
+ "fp32_byte_exact_vs_pt": true,
414
+ "int8_byte_exact_vs_pt": false,
415
+ "int8_wer_vs_pt": 0.0144,
416
+ "int8_vs_fp32_lev": 2,
417
+ "fp32_transcript": "Well, hello, Sam. Guess who? Yeah, it's Robert Clotworthy, the narrator of your favorite television show, \"The Curse of Oak Island.\" Yes, I'm the. Is it possible? Could it be? And what else do we say in Oak Island? A couple of words. They're not coming to me. Oh, yeah. More wood. But let's not forget. It is an island named after a tree. Well, here's the question. Why am I reaching out to you? Is it possible that I'm reaching out to you because it's your birthday? Could it be that Emma let the cat out of the bag? Well, the answer to those questions is yes. And she said, well, she contacted me. She said, Robert, you know, Sam is an amazing boyfriend. In fact, she used the word great. She said he is a great boyfriend.",
418
+ "int8_transcript": "Well, hello, Sam. Guess who? Yeah, it's Robert Clotworthy, the narrator of your favorite television show, \"The Curse of Oak Island.\" Yes, I'm the. Is it possible? Could it be? And what else do we say in Oak Island? A couple of words. They're not coming to me. Oh, yeah. More wood. But let's not forget. It is an island named after a tree. Well, here's the question. Why am I reaching out to you? Is it possible that I'm reaching out to you because it's your birthday? Could it be that Emma let the cat out of the bag? Well, the answer to those questions is yes. And she said, well, she contacted me. She said, Robert. You know, Sam is an amazing boyfriend. In fact, she used the word great. She said he is a great boyfriend.",
419
+ "pt_transcript": "Well, hello, Sam. Guess who? Yeah, it's Robert Clotworthy, the narrator of your favorite television show, \"The Curse of Oak Island.\" Yes, I'm the. Is it possible? Could it be? And what else do we say in Oak Island? A couple of words. They're not coming to me. Oh, yeah. More wood. But let's not forget. It is an island named after a tree. Well, here's the question. Why am I reaching out to you? Is it possible that I'm reaching out to you because it's your birthday? Could it be that Emma let the cat out of the bag? Well, the answer to those questions is yes. And she said, well, she contacted me. She said, Robert, you know, Sam is an amazing boyfriend. In fact, she used the word great. She said he is a great boyfriend."
420
+ },
421
+ {
422
+ "name": "two-speakers-1",
423
+ "duration_s": 93.8,
424
+ "fp32_byte_exact_vs_pt": true,
425
+ "int8_byte_exact_vs_pt": false,
426
+ "int8_wer_vs_pt": 0.0104,
427
+ "int8_vs_fp32_lev": 12,
428
+ "fp32_transcript": "Today it is a true honor to speak with Demis Asavis, who is the CEO of DeepMind. Demis, welcome to the podcast. Thanks for having me. First question, given your neuroscience background, how do you think about intelligence? Specifically, do you think it's like one higher level general reasoning circuit, or do you think it's thousands of independent subskills and heuristics? Well, it's interesting because intelligence is so broad and, you know, what we use it for is so sort of generally applicable. I think that suggests that, you know, there must be some sort of high-level common things in, you know, common kind of algorithmic themes, I think, around how the brain processes the world around us. So, of course, then there are specialized parts of the brain that do specific things, but I think there are probably some underlying principles that underpin all of that. Yeah. How do you make sense of the fact that in these LLMs, though, when you give them a lot of data in any specific domain, they tend to get asymmetrically better in that domain? Wouldn't we expect a sort of like general improvement across all the different areas? Well, I think you, first of all, I think you do actually sometimes get surprising improvement in other domains when you improve in a specific domain. So, for example, when these large models sort of improve at coding, that can actually improve their general reasoning. So there is some evidence of some transfer, although I think we would like a lot more evidence of that. But also, you know, that's how the human brain learns, too, is if we experience and practice a lot of things like chess or, you know, writing.",
429
+ "int8_transcript": "Today it is a true honor to speak with Demis Savas, who is the CEO of DeepMind. Demis, welcome to the podcast. Thanks for having me. First question, given your neuroscience background, how do you think about intelligence? Specifically, do you think it's like one higher level general reasoning circuit, or do you think it's thousands of independent subskills and heuristics? Well, it's interesting because intelligence is so broad and, you know, what we use it for is so sort of generally applicable. I think that suggests that, you know, there must be some sort of high-level common things in, you know, common kind of algorithmic themes, I think, around how the brain processes the world around us. So, of course, then there are specialized parts of the brain that do specific things, but I think there are probably some underlying principles that underpin all of that. Yeah. How do you make sense of the fact that in these LLMs, though, when you give them a lot of data in any specific domain, they tend to get asymmetrically better in that domain? Wouldn't we expect a sort of like general improvement across all the, all the different areas? Well, I think you, first of all, I think you do actually sometimes get surprising improvement in other domains when you improve in a specific domain. So, for example, when these large models sort of improve at coding, that can actually improve their general reasoning. So there is some evidence of some transfer, although I think we would like a lot more evidence of that. But also, you know, that's how the human brain learns, too, is if we experience and practice a lot of things like chess or, you know, writing.",
430
+ "pt_transcript": "Today it is a true honor to speak with Demis Asavis, who is the CEO of DeepMind. Demis, welcome to the podcast. Thanks for having me. First question, given your neuroscience background, how do you think about intelligence? Specifically, do you think it's like one higher level general reasoning circuit, or do you think it's thousands of independent subskills and heuristics? Well, it's interesting because intelligence is so broad and, you know, what we use it for is so sort of generally applicable. I think that suggests that, you know, there must be some sort of high-level common things in, you know, common kind of algorithmic themes, I think, around how the brain processes the world around us. So, of course, then there are specialized parts of the brain that do specific things, but I think there are probably some underlying principles that underpin all of that. Yeah. How do you make sense of the fact that in these LLMs, though, when you give them a lot of data in any specific domain, they tend to get asymmetrically better in that domain? Wouldn't we expect a sort of like general improvement across all the different areas? Well, I think you, first of all, I think you do actually sometimes get surprising improvement in other domains when you improve in a specific domain. So, for example, when these large models sort of improve at coding, that can actually improve their general reasoning. So there is some evidence of some transfer, although I think we would like a lot more evidence of that. But also, you know, that's how the human brain learns, too, is if we experience and practice a lot of things like chess or, you know, writing."
431
+ },
432
+ {
433
+ "name": "two-speakers-2",
434
+ "duration_s": 38.8,
435
+ "fp32_byte_exact_vs_pt": true,
436
+ "int8_byte_exact_vs_pt": false,
437
+ "int8_wer_vs_pt": 0.2347,
438
+ "int8_vs_fp32_lev": 26,
439
+ "fp32_transcript": "For the first time ever, we may have things more intelligent than us. You believe they can understand. Yes. You believe they are intelligent. Yes. You believe these systems have experiences of their own and can make decisions based on those experiences. In the same sense as people do, yes. Are they conscious? I think they probably don't have much self-awareness at present. So in that sense, I don't think they're conscious. Will they have self-awareness? Oh, yes. I think they will in time. And so human beings will be the second most intelligent beings on the planet.",
440
+ "int8_transcript": "for the first time ever we may have things more intelligent than us. You believe they can understand yes you believe they are intelligent yes you believe these systems have experiences of their own and can make decisions based on those experiences in the same sense as people do yes are they conscious I think they probably don't have much self-awareness at present so in that sense I don't think they're conscious. will they have self-awareness oh yes I think they will in time and so human beings will be the second most intelligent beings on the planet.",
441
+ "pt_transcript": "For the first time ever, we may have things more intelligent than us. You believe they can understand. Yes. You believe they are intelligent. Yes. You believe these systems have experiences of their own and can make decisions based on those experiences. In the same sense as people do, yes. Are they conscious? I think they probably don't have much self-awareness at present. So in that sense, I don't think they're conscious. Will they have self-awareness? Oh, yes. I think they will in time. And so human beings will be the second most intelligent beings on the planet."
442
+ }
443
+ ]
444
+ },
445
+ "toolchain": {
446
+ "transformers": "5.8.0",
447
+ "torch": "2.11.0",
448
+ "onnx": "1.21.0",
449
+ "onnxruntime": "1.25.1",
450
+ "exporter": "torch.onnx.export TorchScript path (dynamo=False)"
451
+ },
452
+ "ort_compatibility": "ort 2.0-rc.x (Rust crate); validated against onnxruntime 1.17 - 1.25",
453
+ "audio_token_id": 100352
454
+ }
preprocessor_config.json ADDED
@@ -0,0 +1,14 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "feature_extractor_type": "GraniteSpeechFeatureExtractor",
3
+ "melspec_kwargs": {
4
+ "hop_length": 160,
5
+ "n_fft": 512,
6
+ "n_mels": 80,
7
+ "sample_rate": 16000,
8
+ "win_length": 400
9
+ },
10
+ "processor_class": "GraniteSpeechProcessor",
11
+ "projector_downsample_rate": 5,
12
+ "projector_window_size": 15,
13
+ "sampling_rate": 16000
14
+ }
processor_config.json ADDED
@@ -0,0 +1,4 @@
 
 
 
 
 
1
+ {
2
+ "audio_token": "<|audio|>",
3
+ "processor_class": "GraniteSpeechProcessor"
4
+ }
prompt_encode.onnx ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:59adb1f8dca67a1e910a74adf94852cb4cd85fded0e0ca65d522d97779073b07
3
+ size 1844341
prompt_encode.onnx_data ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:87d924ecd71746694f43e653c9366827a9444ab9407e976f5cd9cc9dbde97608
3
+ size 6527008768
prompt_encode_int8.onnx ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:3942d5a402302e5505dee6d4c79cd80f1487546656f70106f4384dbc8cd82982
3
+ size 6419491
prompt_encode_int8.onnx_data ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:70652d6a31cbae2d57c7e8cefb665f6c1ee503e495d191b951fff09ddb7f8608
3
+ size 1632249856
quantise.py ADDED
@@ -0,0 +1,299 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright 2026 Sam McLeod
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+ """Dynamic INT8 (weights-only) quantiser for the Granite Speech 4.1 ONNX
15
+ exports.
16
+
17
+ Wraps `onnxruntime.quantization.quantize_dynamic` with the conventions used
18
+ by the Granite Speech ONNX bundles:
19
+
20
+ - Single external-data sidecar per graph (mirrors the FP32 export layout).
21
+ - Pure `ai.onnx` opset 20 / IR 10. The default operator set is restricted
22
+ to `MatMul` so the dynamic quantiser emits `MatMulInteger` (standard
23
+ `ai.onnx`) rather than the `com.microsoft.Attention` /
24
+ `com.microsoft.EmbedLayerNormalization` quantised variants. Override at
25
+ your own risk - those domain ops are forbidden by the parakeet-rs
26
+ consumer contract.
27
+ - `per_channel=True` and `weight_type=QInt8` by default (better accuracy
28
+ on the LLM weight tensors with no measurable speed cost on
29
+ arm64 / x86 CPU EP).
30
+
31
+ The script is self-contained (no project-internal imports) so it ships
32
+ inside each Hugging Face bundle alongside the export script.
33
+
34
+ Usage:
35
+ python quantise.py --input PATH --output PATH \\
36
+ [--per-channel | --no-per-channel] \\
37
+ [--reduce-range] \\
38
+ [--weight-type qint8|quint8] \\
39
+ [--op-types MatMul,Gemm] \\
40
+ [--exclude-pattern REGEX] \\
41
+ [--exclude-nodes NODE1,NODE2]
42
+
43
+ Examples:
44
+ # Quantise the NAR editor with defaults.
45
+ python quantise.py \\
46
+ --input exports/granite-speech-4.1-2b-nar/editor.onnx \\
47
+ --output exports/granite-speech-4.1-2b-nar/editor_int8.onnx
48
+
49
+ # Skip the lm_head MatMul if it hurts parity.
50
+ python quantise.py \\
51
+ --input exports/granite-speech-4.1-2b-nar/editor.onnx \\
52
+ --output exports/granite-speech-4.1-2b-nar/editor_int8.onnx \\
53
+ --exclude-nodes /lm_head/MatMul
54
+ """
55
+
56
+ from __future__ import annotations
57
+
58
+ import argparse
59
+ import re
60
+ import sys
61
+ import tempfile
62
+ import time
63
+ from pathlib import Path
64
+
65
+ import onnx
66
+ from onnxruntime.quantization import QuantType, quantize_dynamic
67
+
68
+
69
+ WEIGHT_TYPE_MAP = {
70
+ "qint8": QuantType.QInt8,
71
+ "quint8": QuantType.QUInt8,
72
+ }
73
+
74
+
75
+ def parse_args(argv: list[str] | None = None) -> argparse.Namespace:
76
+ p = argparse.ArgumentParser(
77
+ description="Dynamic INT8 (weights-only) ONNX quantiser for Granite Speech 4.1 graphs.",
78
+ )
79
+ p.add_argument(
80
+ "--input",
81
+ required=True,
82
+ type=Path,
83
+ help="Path to the FP32 .onnx graph (external sidecar must sit alongside it).",
84
+ )
85
+ p.add_argument(
86
+ "--output",
87
+ required=True,
88
+ type=Path,
89
+ help="Destination .onnx path. A single sidecar named <output>_data is written next to it.",
90
+ )
91
+ p.add_argument(
92
+ "--per-channel",
93
+ dest="per_channel",
94
+ action="store_true",
95
+ default=True,
96
+ help="Quantise weights per output channel (default: on).",
97
+ )
98
+ p.add_argument(
99
+ "--no-per-channel",
100
+ dest="per_channel",
101
+ action="store_false",
102
+ help="Disable per-channel quantisation.",
103
+ )
104
+ p.add_argument(
105
+ "--reduce-range",
106
+ action="store_true",
107
+ default=False,
108
+ help="Quantise to 7 bits instead of 8. Improves accuracy on non-VNNI hardware "
109
+ "but reduces the quantisation gain. Off by default.",
110
+ )
111
+ p.add_argument(
112
+ "--weight-type",
113
+ choices=sorted(WEIGHT_TYPE_MAP.keys()),
114
+ default="qint8",
115
+ help="Weight quantisation dtype (default: qint8).",
116
+ )
117
+ p.add_argument(
118
+ "--op-types",
119
+ default="MatMul",
120
+ help=(
121
+ "Comma-separated op types to quantise. Default: 'MatMul' (emits "
122
+ "MatMulInteger only, all ai.onnx). Adding 'Conv' enables ConvInteger "
123
+ "for the Conformer encoder's depthwise convolutions; this shrinks the "
124
+ "encoder INT8 sidecar by ~40 percent but on this model family feeds "
125
+ "enough weight-quant noise into the LLM head that it flips "
126
+ "capitalisation and drops sentence-final punctuation on short clips - "
127
+ "see task 17 in dev-plan.md. MatMul-only is the validated default. "
128
+ "Adding 'Attention' or 'EmbedLayerNormalization' would introduce "
129
+ "com.microsoft domain ops, which are forbidden by the parakeet-rs "
130
+ "contract."
131
+ ),
132
+ )
133
+ p.add_argument(
134
+ "--exclude-pattern",
135
+ default=None,
136
+ help="Regex applied to ONNX node names. Matching nodes are excluded from "
137
+ "quantisation. Useful for skipping e.g. lm_head if its quantisation "
138
+ "breaks parity.",
139
+ )
140
+ p.add_argument(
141
+ "--exclude-nodes",
142
+ default="",
143
+ help="Explicit comma-separated list of node names to exclude from quantisation.",
144
+ )
145
+ p.add_argument(
146
+ "--ir-version",
147
+ type=int,
148
+ default=10,
149
+ help="ONNX IR version to write (default: 10, matches the FP32 exports).",
150
+ )
151
+ return p.parse_args(argv)
152
+
153
+
154
+ def collect_excluded_nodes(
155
+ input_path: Path,
156
+ exclude_pattern: str | None,
157
+ exclude_nodes: list[str],
158
+ ) -> list[str]:
159
+ """Resolve --exclude-pattern against the FP32 graph's node names and merge
160
+ with the explicit --exclude-nodes list. Loaded without external data so we
161
+ only touch the small graph proto.
162
+ """
163
+ excluded = set(n for n in exclude_nodes if n)
164
+ if exclude_pattern:
165
+ rx = re.compile(exclude_pattern)
166
+ proto = onnx.load(str(input_path), load_external_data=False)
167
+ for node in proto.graph.node:
168
+ if node.name and rx.search(node.name):
169
+ excluded.add(node.name)
170
+ return sorted(excluded)
171
+
172
+
173
+ def assert_pure_ai_onnx(model_path: Path) -> list[str]:
174
+ """Reload the produced graph and verify only `ai.onnx` nodes are present.
175
+ Returns the sorted list of domains for reporting.
176
+ """
177
+ proto = onnx.load(str(model_path), load_external_data=False)
178
+ domains = sorted({(n.domain or "ai.onnx") for n in proto.graph.node})
179
+ forbidden = [d for d in domains if d not in ("ai.onnx", "")]
180
+ if forbidden:
181
+ raise RuntimeError(
182
+ f"Quantised graph contains forbidden op domains {forbidden}. "
183
+ "Re-run with a narrower --op-types list."
184
+ )
185
+ return domains
186
+
187
+
188
+ def consolidate_single_sidecar(
189
+ quantised_in: Path,
190
+ final_out: Path,
191
+ ir_version: int,
192
+ ) -> None:
193
+ """The dynamic quantiser may scatter weights across multiple external-data
194
+ files. Reload + resave through a tempdir to land on the single-sidecar
195
+ layout that matches the FP32 exports.
196
+ """
197
+ print(" consolidating to single .onnx_data sidecar")
198
+ proto = onnx.load(str(quantised_in), load_external_data=True)
199
+ if proto.ir_version < ir_version:
200
+ proto.ir_version = ir_version
201
+
202
+ for tensor in proto.graph.initializer:
203
+ tensor.ClearField("data_location")
204
+ tensor.ClearField("external_data")
205
+
206
+ sidecar_name = final_out.name + "_data"
207
+ if (final_out.parent / sidecar_name).exists():
208
+ (final_out.parent / sidecar_name).unlink()
209
+ if final_out.exists():
210
+ final_out.unlink()
211
+ final_out.parent.mkdir(parents=True, exist_ok=True)
212
+
213
+ onnx.save_model(
214
+ proto,
215
+ str(final_out),
216
+ save_as_external_data=True,
217
+ all_tensors_to_one_file=True,
218
+ location=sidecar_name,
219
+ size_threshold=1024,
220
+ convert_attribute=False,
221
+ )
222
+ onnx.checker.check_model(str(final_out), full_check=False)
223
+
224
+
225
+ def quantise_graph(args: argparse.Namespace) -> None:
226
+ input_path: Path = args.input.resolve()
227
+ output_path: Path = args.output.resolve()
228
+ if not input_path.exists():
229
+ raise SystemExit(f"input not found: {input_path}")
230
+ op_types = [s.strip() for s in args.op_types.split(",") if s.strip()]
231
+ explicit_excludes = [s.strip() for s in args.exclude_nodes.split(",") if s.strip()]
232
+
233
+ excluded = collect_excluded_nodes(input_path, args.exclude_pattern, explicit_excludes)
234
+ weight_type = WEIGHT_TYPE_MAP[args.weight_type]
235
+
236
+ print(f"input: {input_path}")
237
+ print(f"output: {output_path}")
238
+ print(f"op_types: {op_types}")
239
+ print(f"per_channel: {args.per_channel}")
240
+ print(f"reduce_range: {args.reduce_range}")
241
+ print(f"weight_type: {args.weight_type}")
242
+ if excluded:
243
+ print(f"excluded nodes ({len(excluded)}): {excluded}")
244
+ else:
245
+ print("excluded nodes: (none)")
246
+
247
+ fp32_size = input_path.stat().st_size
248
+ sidecar = input_path.with_name(input_path.name + "_data")
249
+ fp32_data_size = sidecar.stat().st_size if sidecar.exists() else 0
250
+ print(
251
+ f" fp32 graph={fp32_size / 1e6:.2f} MB "
252
+ f"sidecar={fp32_data_size / 1e9:.2f} GB"
253
+ )
254
+
255
+ with tempfile.TemporaryDirectory(prefix="quantise_int8_") as scratch_dir:
256
+ scratch_path = Path(scratch_dir) / output_path.name
257
+ t0 = time.time()
258
+ quantize_dynamic(
259
+ model_input=input_path,
260
+ model_output=scratch_path,
261
+ op_types_to_quantize=op_types,
262
+ per_channel=args.per_channel,
263
+ reduce_range=args.reduce_range,
264
+ weight_type=weight_type,
265
+ nodes_to_exclude=excluded or None,
266
+ use_external_data_format=True,
267
+ )
268
+ print(f" quantize_dynamic done in {time.time() - t0:.1f}s")
269
+
270
+ # Stage 2: consolidate any scattered external-data files into a single
271
+ # sidecar at the final destination.
272
+ consolidate_single_sidecar(scratch_path, output_path, args.ir_version)
273
+
274
+ # Verify pure ai.onnx after the move.
275
+ domains = assert_pure_ai_onnx(output_path)
276
+ int8_size = output_path.stat().st_size
277
+ int8_data = output_path.with_name(output_path.name + "_data")
278
+ int8_data_size = int8_data.stat().st_size if int8_data.exists() else 0
279
+ print(
280
+ f" saved {output_path} (+ {int8_data.name}) "
281
+ f"graph={int8_size / 1e6:.2f} MB sidecar={int8_data_size / 1e9:.2f} GB"
282
+ )
283
+ print(f" node-domains={domains}")
284
+ if fp32_data_size > 0:
285
+ ratio = int8_data_size / fp32_data_size
286
+ print(f" sidecar size ratio (int8 / fp32) = {ratio:.3f}")
287
+
288
+
289
+ def main(argv: list[str] | None = None) -> None:
290
+ args = parse_args(argv)
291
+ try:
292
+ quantise_graph(args)
293
+ except RuntimeError as exc:
294
+ print(f"FAIL: {exc}", file=sys.stderr)
295
+ raise SystemExit(2) from exc
296
+
297
+
298
+ if __name__ == "__main__":
299
+ main()
special_tokens_map.json ADDED
@@ -0,0 +1,30 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "bos_token": {
3
+ "content": "<|end_of_text|>",
4
+ "lstrip": false,
5
+ "normalized": false,
6
+ "rstrip": false,
7
+ "single_word": false
8
+ },
9
+ "eos_token": {
10
+ "content": "<|end_of_text|>",
11
+ "lstrip": false,
12
+ "normalized": false,
13
+ "rstrip": false,
14
+ "single_word": false
15
+ },
16
+ "pad_token": {
17
+ "content": "<|pad|>",
18
+ "lstrip": false,
19
+ "normalized": false,
20
+ "rstrip": false,
21
+ "single_word": false
22
+ },
23
+ "unk_token": {
24
+ "content": "<|unk|>",
25
+ "lstrip": false,
26
+ "normalized": false,
27
+ "rstrip": false,
28
+ "single_word": false
29
+ }
30
+ }
tokenizer.json ADDED
The diff for this file is too large to render. See raw diff
 
tokenizer_config.json ADDED
@@ -0,0 +1,792 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "add_bos_token": false,
3
+ "add_prefix_space": false,
4
+ "added_tokens_decoder": {
5
+ "100256": {
6
+ "content": "<|pad|>",
7
+ "lstrip": false,
8
+ "normalized": false,
9
+ "rstrip": false,
10
+ "single_word": false,
11
+ "special": true
12
+ },
13
+ "100257": {
14
+ "content": "<|end_of_text|>",
15
+ "lstrip": false,
16
+ "normalized": false,
17
+ "rstrip": false,
18
+ "single_word": false,
19
+ "special": true
20
+ },
21
+ "100258": {
22
+ "content": "<|fim_prefix|>",
23
+ "lstrip": false,
24
+ "normalized": false,
25
+ "rstrip": false,
26
+ "single_word": false,
27
+ "special": false
28
+ },
29
+ "100259": {
30
+ "content": "<|fim_middle|>",
31
+ "lstrip": false,
32
+ "normalized": false,
33
+ "rstrip": false,
34
+ "single_word": false,
35
+ "special": false
36
+ },
37
+ "100260": {
38
+ "content": "<|fim_suffix|>",
39
+ "lstrip": false,
40
+ "normalized": false,
41
+ "rstrip": false,
42
+ "single_word": false,
43
+ "special": false
44
+ },
45
+ "100261": {
46
+ "content": "<|fim_pad|>",
47
+ "lstrip": false,
48
+ "normalized": false,
49
+ "rstrip": false,
50
+ "single_word": false,
51
+ "special": false
52
+ },
53
+ "100262": {
54
+ "content": "<|filename|>",
55
+ "lstrip": false,
56
+ "normalized": false,
57
+ "rstrip": false,
58
+ "single_word": false,
59
+ "special": false
60
+ },
61
+ "100263": {
62
+ "content": "<|reponame|>",
63
+ "lstrip": false,
64
+ "normalized": false,
65
+ "rstrip": false,
66
+ "single_word": false,
67
+ "special": false
68
+ },
69
+ "100264": {
70
+ "content": "<|start_of_role|>",
71
+ "lstrip": false,
72
+ "normalized": false,
73
+ "rstrip": false,
74
+ "single_word": false,
75
+ "special": true
76
+ },
77
+ "100265": {
78
+ "content": "<|end_of_role|>",
79
+ "lstrip": false,
80
+ "normalized": false,
81
+ "rstrip": false,
82
+ "single_word": false,
83
+ "special": true
84
+ },
85
+ "100266": {
86
+ "content": "<|unused_1|>",
87
+ "lstrip": false,
88
+ "normalized": false,
89
+ "rstrip": false,
90
+ "single_word": false,
91
+ "special": true
92
+ },
93
+ "100267": {
94
+ "content": "<|start_of_plugin|>",
95
+ "lstrip": false,
96
+ "normalized": false,
97
+ "rstrip": false,
98
+ "single_word": false,
99
+ "special": true
100
+ },
101
+ "100268": {
102
+ "content": "<|end_of_plugin|>",
103
+ "lstrip": false,
104
+ "normalized": false,
105
+ "rstrip": false,
106
+ "single_word": false,
107
+ "special": true
108
+ },
109
+ "100269": {
110
+ "content": "<|unk|>",
111
+ "lstrip": false,
112
+ "normalized": false,
113
+ "rstrip": false,
114
+ "single_word": false,
115
+ "special": true
116
+ },
117
+ "100270": {
118
+ "content": "<tool_call>",
119
+ "lstrip": false,
120
+ "normalized": false,
121
+ "rstrip": false,
122
+ "single_word": false,
123
+ "special": false
124
+ },
125
+ "100271": {
126
+ "content": "</tool_call>",
127
+ "lstrip": false,
128
+ "normalized": false,
129
+ "rstrip": false,
130
+ "single_word": false,
131
+ "special": false
132
+ },
133
+ "100272": {
134
+ "content": "<tool_response>",
135
+ "lstrip": false,
136
+ "normalized": false,
137
+ "rstrip": false,
138
+ "single_word": false,
139
+ "special": false
140
+ },
141
+ "100273": {
142
+ "content": "</tool_response>",
143
+ "lstrip": false,
144
+ "normalized": false,
145
+ "rstrip": false,
146
+ "single_word": false,
147
+ "special": false
148
+ },
149
+ "100274": {
150
+ "content": "<think>",
151
+ "lstrip": false,
152
+ "normalized": false,
153
+ "rstrip": false,
154
+ "single_word": false,
155
+ "special": false
156
+ },
157
+ "100275": {
158
+ "content": "</think>",
159
+ "lstrip": false,
160
+ "normalized": false,
161
+ "rstrip": false,
162
+ "single_word": false,
163
+ "special": false
164
+ },
165
+ "100276": {
166
+ "content": "<think_on>",
167
+ "lstrip": false,
168
+ "normalized": false,
169
+ "rstrip": false,
170
+ "single_word": false,
171
+ "special": true
172
+ },
173
+ "100277": {
174
+ "content": "<think_off>",
175
+ "lstrip": false,
176
+ "normalized": false,
177
+ "rstrip": false,
178
+ "single_word": false,
179
+ "special": true
180
+ },
181
+ "100278": {
182
+ "content": "<schema>",
183
+ "lstrip": false,
184
+ "normalized": false,
185
+ "rstrip": false,
186
+ "single_word": false,
187
+ "special": true
188
+ },
189
+ "100279": {
190
+ "content": "</schema>",
191
+ "lstrip": false,
192
+ "normalized": false,
193
+ "rstrip": false,
194
+ "single_word": false,
195
+ "special": true
196
+ },
197
+ "100280": {
198
+ "content": "<tools>",
199
+ "lstrip": false,
200
+ "normalized": false,
201
+ "rstrip": false,
202
+ "single_word": false,
203
+ "special": true
204
+ },
205
+ "100281": {
206
+ "content": "</tools>",
207
+ "lstrip": false,
208
+ "normalized": false,
209
+ "rstrip": false,
210
+ "single_word": false,
211
+ "special": true
212
+ },
213
+ "100282": {
214
+ "content": "<documents>",
215
+ "lstrip": false,
216
+ "normalized": false,
217
+ "rstrip": false,
218
+ "single_word": false,
219
+ "special": true
220
+ },
221
+ "100283": {
222
+ "content": "</documents>",
223
+ "lstrip": false,
224
+ "normalized": false,
225
+ "rstrip": false,
226
+ "single_word": false,
227
+ "special": true
228
+ },
229
+ "100284": {
230
+ "content": "<|unused_15|>",
231
+ "lstrip": false,
232
+ "normalized": false,
233
+ "rstrip": false,
234
+ "single_word": false,
235
+ "special": true
236
+ },
237
+ "100285": {
238
+ "content": "<|unused_16|>",
239
+ "lstrip": false,
240
+ "normalized": false,
241
+ "rstrip": false,
242
+ "single_word": false,
243
+ "special": true
244
+ },
245
+ "100286": {
246
+ "content": "<|unused_17|>",
247
+ "lstrip": false,
248
+ "normalized": false,
249
+ "rstrip": false,
250
+ "single_word": false,
251
+ "special": true
252
+ },
253
+ "100287": {
254
+ "content": "<|unused_18|>",
255
+ "lstrip": false,
256
+ "normalized": false,
257
+ "rstrip": false,
258
+ "single_word": false,
259
+ "special": true
260
+ },
261
+ "100288": {
262
+ "content": "<|unused_19|>",
263
+ "lstrip": false,
264
+ "normalized": false,
265
+ "rstrip": false,
266
+ "single_word": false,
267
+ "special": true
268
+ },
269
+ "100289": {
270
+ "content": "<|unused_20|>",
271
+ "lstrip": false,
272
+ "normalized": false,
273
+ "rstrip": false,
274
+ "single_word": false,
275
+ "special": true
276
+ },
277
+ "100290": {
278
+ "content": "<|unused_21|>",
279
+ "lstrip": false,
280
+ "normalized": false,
281
+ "rstrip": false,
282
+ "single_word": false,
283
+ "special": true
284
+ },
285
+ "100291": {
286
+ "content": "<|unused_22|>",
287
+ "lstrip": false,
288
+ "normalized": false,
289
+ "rstrip": false,
290
+ "single_word": false,
291
+ "special": true
292
+ },
293
+ "100292": {
294
+ "content": "<|unused_23|>",
295
+ "lstrip": false,
296
+ "normalized": false,
297
+ "rstrip": false,
298
+ "single_word": false,
299
+ "special": true
300
+ },
301
+ "100293": {
302
+ "content": "<|unused_24|>",
303
+ "lstrip": false,
304
+ "normalized": false,
305
+ "rstrip": false,
306
+ "single_word": false,
307
+ "special": true
308
+ },
309
+ "100294": {
310
+ "content": "<|unused_25|>",
311
+ "lstrip": false,
312
+ "normalized": false,
313
+ "rstrip": false,
314
+ "single_word": false,
315
+ "special": true
316
+ },
317
+ "100295": {
318
+ "content": "<|unused_26|>",
319
+ "lstrip": false,
320
+ "normalized": false,
321
+ "rstrip": false,
322
+ "single_word": false,
323
+ "special": true
324
+ },
325
+ "100296": {
326
+ "content": "<|unused_27|>",
327
+ "lstrip": false,
328
+ "normalized": false,
329
+ "rstrip": false,
330
+ "single_word": false,
331
+ "special": true
332
+ },
333
+ "100297": {
334
+ "content": "<|unused_28|>",
335
+ "lstrip": false,
336
+ "normalized": false,
337
+ "rstrip": false,
338
+ "single_word": false,
339
+ "special": true
340
+ },
341
+ "100298": {
342
+ "content": "<|unused_29|>",
343
+ "lstrip": false,
344
+ "normalized": false,
345
+ "rstrip": false,
346
+ "single_word": false,
347
+ "special": true
348
+ },
349
+ "100299": {
350
+ "content": "<|unused_30|>",
351
+ "lstrip": false,
352
+ "normalized": false,
353
+ "rstrip": false,
354
+ "single_word": false,
355
+ "special": true
356
+ },
357
+ "100300": {
358
+ "content": "<|unused_31|>",
359
+ "lstrip": false,
360
+ "normalized": false,
361
+ "rstrip": false,
362
+ "single_word": false,
363
+ "special": true
364
+ },
365
+ "100301": {
366
+ "content": "<|unused_32|>",
367
+ "lstrip": false,
368
+ "normalized": false,
369
+ "rstrip": false,
370
+ "single_word": false,
371
+ "special": true
372
+ },
373
+ "100302": {
374
+ "content": "<|unused_33|>",
375
+ "lstrip": false,
376
+ "normalized": false,
377
+ "rstrip": false,
378
+ "single_word": false,
379
+ "special": true
380
+ },
381
+ "100303": {
382
+ "content": "<|unused_34|>",
383
+ "lstrip": false,
384
+ "normalized": false,
385
+ "rstrip": false,
386
+ "single_word": false,
387
+ "special": true
388
+ },
389
+ "100304": {
390
+ "content": "<|unused_35|>",
391
+ "lstrip": false,
392
+ "normalized": false,
393
+ "rstrip": false,
394
+ "single_word": false,
395
+ "special": true
396
+ },
397
+ "100305": {
398
+ "content": "<|unused_36|>",
399
+ "lstrip": false,
400
+ "normalized": false,
401
+ "rstrip": false,
402
+ "single_word": false,
403
+ "special": true
404
+ },
405
+ "100306": {
406
+ "content": "<|unused_37|>",
407
+ "lstrip": false,
408
+ "normalized": false,
409
+ "rstrip": false,
410
+ "single_word": false,
411
+ "special": true
412
+ },
413
+ "100307": {
414
+ "content": "<|unused_38|>",
415
+ "lstrip": false,
416
+ "normalized": false,
417
+ "rstrip": false,
418
+ "single_word": false,
419
+ "special": true
420
+ },
421
+ "100308": {
422
+ "content": "<|unused_39|>",
423
+ "lstrip": false,
424
+ "normalized": false,
425
+ "rstrip": false,
426
+ "single_word": false,
427
+ "special": true
428
+ },
429
+ "100309": {
430
+ "content": "<|unused_40|>",
431
+ "lstrip": false,
432
+ "normalized": false,
433
+ "rstrip": false,
434
+ "single_word": false,
435
+ "special": true
436
+ },
437
+ "100310": {
438
+ "content": "<|unused_41|>",
439
+ "lstrip": false,
440
+ "normalized": false,
441
+ "rstrip": false,
442
+ "single_word": false,
443
+ "special": true
444
+ },
445
+ "100311": {
446
+ "content": "<|unused_42|>",
447
+ "lstrip": false,
448
+ "normalized": false,
449
+ "rstrip": false,
450
+ "single_word": false,
451
+ "special": true
452
+ },
453
+ "100312": {
454
+ "content": "<|unused_43|>",
455
+ "lstrip": false,
456
+ "normalized": false,
457
+ "rstrip": false,
458
+ "single_word": false,
459
+ "special": true
460
+ },
461
+ "100313": {
462
+ "content": "<|unused_44|>",
463
+ "lstrip": false,
464
+ "normalized": false,
465
+ "rstrip": false,
466
+ "single_word": false,
467
+ "special": true
468
+ },
469
+ "100314": {
470
+ "content": "<|unused_45|>",
471
+ "lstrip": false,
472
+ "normalized": false,
473
+ "rstrip": false,
474
+ "single_word": false,
475
+ "special": true
476
+ },
477
+ "100315": {
478
+ "content": "<|unused_46|>",
479
+ "lstrip": false,
480
+ "normalized": false,
481
+ "rstrip": false,
482
+ "single_word": false,
483
+ "special": true
484
+ },
485
+ "100316": {
486
+ "content": "<|unused_47|>",
487
+ "lstrip": false,
488
+ "normalized": false,
489
+ "rstrip": false,
490
+ "single_word": false,
491
+ "special": true
492
+ },
493
+ "100317": {
494
+ "content": "<|unused_48|>",
495
+ "lstrip": false,
496
+ "normalized": false,
497
+ "rstrip": false,
498
+ "single_word": false,
499
+ "special": true
500
+ },
501
+ "100318": {
502
+ "content": "<|unused_49|>",
503
+ "lstrip": false,
504
+ "normalized": false,
505
+ "rstrip": false,
506
+ "single_word": false,
507
+ "special": true
508
+ },
509
+ "100319": {
510
+ "content": "<|unused_50|>",
511
+ "lstrip": false,
512
+ "normalized": false,
513
+ "rstrip": false,
514
+ "single_word": false,
515
+ "special": true
516
+ },
517
+ "100320": {
518
+ "content": "<|unused_51|>",
519
+ "lstrip": false,
520
+ "normalized": false,
521
+ "rstrip": false,
522
+ "single_word": false,
523
+ "special": true
524
+ },
525
+ "100321": {
526
+ "content": "<|unused_52|>",
527
+ "lstrip": false,
528
+ "normalized": false,
529
+ "rstrip": false,
530
+ "single_word": false,
531
+ "special": true
532
+ },
533
+ "100322": {
534
+ "content": "<|unused_53|>",
535
+ "lstrip": false,
536
+ "normalized": false,
537
+ "rstrip": false,
538
+ "single_word": false,
539
+ "special": true
540
+ },
541
+ "100323": {
542
+ "content": "<|unused_54|>",
543
+ "lstrip": false,
544
+ "normalized": false,
545
+ "rstrip": false,
546
+ "single_word": false,
547
+ "special": true
548
+ },
549
+ "100324": {
550
+ "content": "<|unused_55|>",
551
+ "lstrip": false,
552
+ "normalized": false,
553
+ "rstrip": false,
554
+ "single_word": false,
555
+ "special": true
556
+ },
557
+ "100325": {
558
+ "content": "<|unused_56|>",
559
+ "lstrip": false,
560
+ "normalized": false,
561
+ "rstrip": false,
562
+ "single_word": false,
563
+ "special": true
564
+ },
565
+ "100326": {
566
+ "content": "<|unused_57|>",
567
+ "lstrip": false,
568
+ "normalized": false,
569
+ "rstrip": false,
570
+ "single_word": false,
571
+ "special": true
572
+ },
573
+ "100327": {
574
+ "content": "<|unused_58|>",
575
+ "lstrip": false,
576
+ "normalized": false,
577
+ "rstrip": false,
578
+ "single_word": false,
579
+ "special": true
580
+ },
581
+ "100328": {
582
+ "content": "<|unused_59|>",
583
+ "lstrip": false,
584
+ "normalized": false,
585
+ "rstrip": false,
586
+ "single_word": false,
587
+ "special": true
588
+ },
589
+ "100329": {
590
+ "content": "<|unused_60|>",
591
+ "lstrip": false,
592
+ "normalized": false,
593
+ "rstrip": false,
594
+ "single_word": false,
595
+ "special": true
596
+ },
597
+ "100330": {
598
+ "content": "<|unused_61|>",
599
+ "lstrip": false,
600
+ "normalized": false,
601
+ "rstrip": false,
602
+ "single_word": false,
603
+ "special": true
604
+ },
605
+ "100331": {
606
+ "content": "<|unused_62|>",
607
+ "lstrip": false,
608
+ "normalized": false,
609
+ "rstrip": false,
610
+ "single_word": false,
611
+ "special": true
612
+ },
613
+ "100332": {
614
+ "content": "<|unused_63|>",
615
+ "lstrip": false,
616
+ "normalized": false,
617
+ "rstrip": false,
618
+ "single_word": false,
619
+ "special": true
620
+ },
621
+ "100333": {
622
+ "content": "<|unused_64|>",
623
+ "lstrip": false,
624
+ "normalized": false,
625
+ "rstrip": false,
626
+ "single_word": false,
627
+ "special": true
628
+ },
629
+ "100334": {
630
+ "content": "<|unused_65|>",
631
+ "lstrip": false,
632
+ "normalized": false,
633
+ "rstrip": false,
634
+ "single_word": false,
635
+ "special": true
636
+ },
637
+ "100335": {
638
+ "content": "<|unused_66|>",
639
+ "lstrip": false,
640
+ "normalized": false,
641
+ "rstrip": false,
642
+ "single_word": false,
643
+ "special": true
644
+ },
645
+ "100336": {
646
+ "content": "<|unused_67|>",
647
+ "lstrip": false,
648
+ "normalized": false,
649
+ "rstrip": false,
650
+ "single_word": false,
651
+ "special": true
652
+ },
653
+ "100337": {
654
+ "content": "<|unused_68|>",
655
+ "lstrip": false,
656
+ "normalized": false,
657
+ "rstrip": false,
658
+ "single_word": false,
659
+ "special": true
660
+ },
661
+ "100338": {
662
+ "content": "<|unused_69|>",
663
+ "lstrip": false,
664
+ "normalized": false,
665
+ "rstrip": false,
666
+ "single_word": false,
667
+ "special": true
668
+ },
669
+ "100339": {
670
+ "content": "<|unused_70|>",
671
+ "lstrip": false,
672
+ "normalized": false,
673
+ "rstrip": false,
674
+ "single_word": false,
675
+ "special": true
676
+ },
677
+ "100340": {
678
+ "content": "<|unused_71|>",
679
+ "lstrip": false,
680
+ "normalized": false,
681
+ "rstrip": false,
682
+ "single_word": false,
683
+ "special": true
684
+ },
685
+ "100341": {
686
+ "content": "<|unused_72|>",
687
+ "lstrip": false,
688
+ "normalized": false,
689
+ "rstrip": false,
690
+ "single_word": false,
691
+ "special": true
692
+ },
693
+ "100342": {
694
+ "content": "<|unused_73|>",
695
+ "lstrip": false,
696
+ "normalized": false,
697
+ "rstrip": false,
698
+ "single_word": false,
699
+ "special": true
700
+ },
701
+ "100343": {
702
+ "content": "<|unused_74|>",
703
+ "lstrip": false,
704
+ "normalized": false,
705
+ "rstrip": false,
706
+ "single_word": false,
707
+ "special": true
708
+ },
709
+ "100344": {
710
+ "content": "<|unused_75|>",
711
+ "lstrip": false,
712
+ "normalized": false,
713
+ "rstrip": false,
714
+ "single_word": false,
715
+ "special": true
716
+ },
717
+ "100345": {
718
+ "content": "<|unused_76|>",
719
+ "lstrip": false,
720
+ "normalized": false,
721
+ "rstrip": false,
722
+ "single_word": false,
723
+ "special": true
724
+ },
725
+ "100346": {
726
+ "content": "<|unused_77|>",
727
+ "lstrip": false,
728
+ "normalized": false,
729
+ "rstrip": false,
730
+ "single_word": false,
731
+ "special": true
732
+ },
733
+ "100347": {
734
+ "content": "<|unused_78|>",
735
+ "lstrip": false,
736
+ "normalized": false,
737
+ "rstrip": false,
738
+ "single_word": false,
739
+ "special": true
740
+ },
741
+ "100348": {
742
+ "content": "<|unused_79|>",
743
+ "lstrip": false,
744
+ "normalized": false,
745
+ "rstrip": false,
746
+ "single_word": false,
747
+ "special": true
748
+ },
749
+ "100349": {
750
+ "content": "<|unused_80|>",
751
+ "lstrip": false,
752
+ "normalized": false,
753
+ "rstrip": false,
754
+ "single_word": false,
755
+ "special": true
756
+ },
757
+ "100350": {
758
+ "content": "<|unused_81|>",
759
+ "lstrip": false,
760
+ "normalized": false,
761
+ "rstrip": false,
762
+ "single_word": false,
763
+ "special": true
764
+ },
765
+ "100351": {
766
+ "content": "<|unused_82|>",
767
+ "lstrip": false,
768
+ "normalized": false,
769
+ "rstrip": false,
770
+ "single_word": false,
771
+ "special": true
772
+ },
773
+ "100352": {
774
+ "content": "<|audio|>",
775
+ "lstrip": false,
776
+ "normalized": false,
777
+ "rstrip": false,
778
+ "single_word": false,
779
+ "special": true
780
+ }
781
+ },
782
+ "bos_token": "<|end_of_text|>",
783
+ "clean_up_tokenization_spaces": false,
784
+ "eos_token": "<|end_of_text|>",
785
+ "extra_special_tokens": {},
786
+ "model_max_length": 1000000000000000019884624838656,
787
+ "pad_token": "<|pad|>",
788
+ "padding_side": "left",
789
+ "processor_class": "GraniteSpeechProcessor",
790
+ "tokenizer_class": "GPT2Tokenizer",
791
+ "unk_token": "<|unk|>"
792
+ }