# Matches the NeMo state_dict names of Omi's ``medical_v1d_rank128`` adapter.
# Groups: (1) encoder layer index, (2) index inside the adapter's ``module``
# sequence (only 0, 1 and 3 are matched), (3) ``weight`` or ``bias``.
_OMI_ADAPTER_RE = re.compile(
    r"^encoder\.layers\.(\d+)\.adapter_layer\.medical_v1d_rank128"
    r"\.module\.(0|1|3)\.(weight|bias)$"
)

# (module index, suffix) -> compact leaf name. Combinations that are not
# listed here (the biases of modules 1 and 3) are deliberately left unrenamed.
_OMI_ADAPTER_PARTS = {
    ("0", "weight"): "norm.weight",
    ("0", "bias"): "norm.bias",
    ("1", "weight"): "down.weight",
    ("3", "weight"): "up.weight",
}


def compact_omi_adapter_name(name):
    """Return the GGUF tensor name to use for the state_dict entry *name*.

    NeMo stores Omi adapter tensors under names like
    ``encoder.layers.0.adapter_layer.medical_v1d_rank128.module.0.weight``.
    Those are rewritten to ``encoder.layers.<N>.omi_adapter.<part>`` where
    ``<part>`` is ``norm.weight``, ``norm.bias``, ``down.weight`` or
    ``up.weight``.

    Any other name is returned unchanged. That includes a ``bias`` on
    module 1 or 3, which is not part of the compact schema.
    """
    m = _OMI_ADAPTER_RE.match(name)
    if m is None:
        return name
    layer, module, suffix = m.groups()
    part = _OMI_ADAPTER_PARTS.get((module, suffix))
    if part is None:
        return name
    return f"encoder.layers.{layer}.omi_adapter.{part}"


def _omi_adapter_rank(t):
    """Return ``int(t.shape[0])``, or 0 when *t* has no usable shape."""
    shape = getattr(t, "shape", None)
    if shape is None or len(shape) < 1:
        return 0
    return int(shape[0])


def detect_omi_adapter(sd):
    """Detect the Omi adapter in a state_dict and infer its rank.

    Args:
        sd: mapping from tensor name to a tensor-like object exposing
            ``.shape``.

    Returns:
        ``(present, rank)``. ``present`` is True when any key contains
        ``.adapter_layer.medical_v1d_rank128.``. ``rank`` is the leading
        dimension of a down-projection (``module.1.weight``) tensor, or 0
        when no such tensor with a non-empty shape is found.
    """
    if not any(".adapter_layer.medical_v1d_rank128." in name for name in sd):
        return False, 0

    # Preferred source: the layer-0 down projection.
    down_key = "encoder.layers.0.adapter_layer.medical_v1d_rank128.module.1.weight"
    rank = _omi_adapter_rank(sd.get(down_key))
    if rank > 0:
        return True, rank

    # Fallback: layer 0 may lack the tensor, so take the rank from the first
    # down projection found on any layer instead of reporting rank 0.
    for name, t in sd.items():
        m = _OMI_ADAPTER_RE.match(name)
        if m is not None and m.group(2) == "1" and m.group(3) == "weight":
            rank = _omi_adapter_rank(t)
            if rank > 0:
                return True, rank
    return True, 0
+ adapter_rank = int(_get(enc, "medical_adapter_rank", 0) or 0) + has_omi_adapter, inferred_adapter_rank = detect_omi_adapter(sd) + if adapter_rank <= 0: + adapter_rank = inferred_adapter_rank + if has_omi_adapter and adapter_rank > 0: + w.add_bool("parakeet.omi_med_adapter.enabled", True) + w.add_uint32("parakeet.omi_med_adapter.rank", adapter_rank) + w.add_string("parakeet.omi_med_adapter.name", "omi_adapter") + # --- Cache-aware streaming / causal config (Phase 5) --------------------- # These KVs describe the chunked-limited attention + causal conv that the # streaming FastConformer (e.g. parakeet_realtime_eou_120m-v1) uses. They are @@ -270,10 +327,10 @@ def main(): ) w.add_array("parakeet.tdt.durations", [int(d) for d in durs]) - # tensors: verbatim names. Allowlisted linear weights are quantized per - # --dtype (ggml dequantizes them on the fly inside ggml_mul_mat); everything - # else stays f32. Include featurizer buffers explicitly. - sd = m.state_dict() + # tensors: verbatim names except Omi adapter compact aliases. Allowlisted + # linear weights are quantized per --dtype (ggml dequantizes them on the fly + # inside ggml_mul_mat); everything else stays f32. Include featurizer buffers + # explicitly. written = 0 quantized = 0 keep_buffers = {"preprocessor.featurizer.fb", "preprocessor.featurizer.window"} @@ -289,14 +346,15 @@ def main(): # ggml ne is the reverse of the numpy/torch shape; ne[0] is the leading # (contraction) axis q8_0 blocks along. ggml_ne = list(arr.shape[::-1]) + out_name = compact_omi_adapter_name(name) qtype = should_quantize(name, ggml_ne, args.dtype) if qtype is None: - w.add_tensor(name, arr) + w.add_tensor(out_name, arr) else: raw = gguf.quantize(arr, qtype) # gguf expects raw_shape to be the *byte* shape of the quantized # buffer; it derives the element shape from it via raw_dtype. 
- w.add_tensor(name, raw, raw_shape=raw.shape, raw_dtype=qtype) + w.add_tensor(out_name, raw, raw_shape=raw.shape, raw_dtype=qtype) quantized += 1 written += 1 diff --git a/src/conformer.cpp b/src/conformer.cpp index 8ef6645..8e80a15 100644 --- a/src/conformer.cpp +++ b/src/conformer.cpp @@ -276,6 +276,8 @@ ConformerLayer::ConformerLayer(const ModelLoader& ml, int layer_idx) conv_kernel_ = (int)ml.config().conv_kernel; conv_norm_type_ = ml.config().conv_norm_type; conv_causal_ = ml.config().conv_causal; + omi_med_adapter_ = ml.config().omi_med_adapter; + omi_med_adapter_name_ = ml.config().omi_med_adapter_name; assert((conv_norm_type_ == "batch_norm" || conv_norm_type_ == "layer_norm") && "ConformerLayer supports conv_norm_type in {batch_norm, layer_norm}"); assert(n_heads_ > 0 && d_model_ % n_heads_ == 0); @@ -322,6 +324,21 @@ ggml_tensor* ConformerLayer::build_graph_batched(ggml_context* ctx, h = linear(h, ff + ".linear2", /*bias*/true); // [D, T, B] return h; }; + auto omi_med_adapter = [&](ggml_tensor* in) { + if (!omi_med_adapter_) return in; + const std::string ap = pre + "omi_adapter."; + ggml_tensor* g = clone_weight(ctx, ml, ap + "norm.weight"); + ggml_tensor* b = clone_weight(ctx, ml, ap + "norm.bias"); + ggml_tensor* w_down = clone_weight(ctx, ml, ap + "down.weight"); + ggml_tensor* w_up = clone_weight(ctx, ml, ap + "up.weight"); + ggml_tensor* y = ggml_norm(ctx, in, ln_eps); + y = ggml_mul(ctx, y, g); + y = ggml_add(ctx, y, b); + y = ggml_mul_mat(ctx, w_down, y); + y = ggml_silu(ctx, y); + y = ggml_mul_mat(ctx, w_up, y); + return ggml_add(ctx, in, y); + }; // === Stage A: r = x + 0.5 * FFN1(norm_ff1(x)). 
=== ggml_tensor* h1 = layer_norm(xt, "norm_feed_forward1"); @@ -349,6 +366,7 @@ ggml_tensor* ConformerLayer::build_graph_batched(ggml_context* ctx, h2 = ggml_scale(ctx, h2, 0.5f); r = ggml_add(ctx, r, h2); r = layer_norm(r, "norm_out"); + r = omi_med_adapter(r); return r; // [D, T, B] -> per item row-major [T, D] } @@ -394,6 +412,21 @@ ggml_tensor* ConformerLayer::build_graph(ggml_context* ctx, ggml_tensor* xt, h = linear(h, ff + ".linear2", /*bias*/true); // [D, T] return h; }; + auto omi_med_adapter = [&](ggml_tensor* in) { + if (!omi_med_adapter_) return in; + const std::string ap = pre + "omi_adapter."; + ggml_tensor* g = clone_weight(ctx, ml, ap + "norm.weight"); + ggml_tensor* b = clone_weight(ctx, ml, ap + "norm.bias"); + ggml_tensor* w_down = clone_weight(ctx, ml, ap + "down.weight"); + ggml_tensor* w_up = clone_weight(ctx, ml, ap + "up.weight"); + ggml_tensor* y = ggml_norm(ctx, in, ln_eps); + y = ggml_mul(ctx, y, g); + y = ggml_add(ctx, y, b); + y = ggml_mul_mat(ctx, w_down, y); + y = ggml_silu(ctx, y); + y = ggml_mul_mat(ctx, w_up, y); + return ggml_add(ctx, in, y); + }; // === Stage A: r = x + 0.5 * FFN1(norm_ff1(x)). 
=== ggml_tensor* h1 = layer_norm(xt, "norm_feed_forward1"); @@ -420,6 +453,7 @@ ggml_tensor* ConformerLayer::build_graph(ggml_context* ctx, ggml_tensor* xt, h2 = ggml_scale(ctx, h2, 0.5f); r = ggml_add(ctx, r, h2); r = layer_norm(r, "norm_out"); + r = omi_med_adapter(r); return r; // [D, T] -> row-major [T, D] } diff --git a/src/conformer.hpp b/src/conformer.hpp index 23402fb..d6a07f6 100644 --- a/src/conformer.hpp +++ b/src/conformer.hpp @@ -99,6 +99,8 @@ private: int conv_kernel_; std::string conv_norm_type_; // "batch_norm" (offline) or "layer_norm" (streaming) bool conv_causal_ = false; // causal depthwise conv pad (left k-1, right 0) + bool omi_med_adapter_ = false; + std::string omi_med_adapter_name_; }; } // namespace pk diff --git a/src/model_loader.cpp b/src/model_loader.cpp index 218dc91..3ad5df5 100644 --- a/src/model_loader.cpp +++ b/src/model_loader.cpp @@ -112,6 +112,10 @@ bool ModelLoader::load(const std::string& path){ cfg_.subsampling_conv_channels = kv_u32(gguf_, "parakeet.encoder.subsampling_conv_channels"); cfg_.xscaling = kv_bool(gguf_, "parakeet.encoder.xscaling", true); cfg_.pos_emb_max_len = kv_u32(gguf_, "parakeet.encoder.pos_emb_max_len", 5000); + cfg_.omi_med_adapter = kv_bool(gguf_, "parakeet.omi_med_adapter.enabled", false); + cfg_.omi_med_adapter_rank = kv_u32(gguf_, "parakeet.omi_med_adapter.rank", 0); + cfg_.omi_med_adapter_name = kv_str( + gguf_, "parakeet.omi_med_adapter.name", "medical_v1d_rank128"); // cache-aware streaming / causal config (Phase 5). Absent for offline models // -> offline-safe defaults (regular style, no causal, streaming.present=false). 
cfg_.att_context_left = kv_i32(gguf_, "parakeet.encoder.att_context_left", -1); diff --git a/src/model_loader.hpp b/src/model_loader.hpp index 9947bd1..87be483 100644 --- a/src/model_loader.hpp +++ b/src/model_loader.hpp @@ -31,6 +31,12 @@ struct ParakeetConfig { std::string conv_norm_type; uint32_t subsampling_factor=0, subsampling_conv_channels=0, pos_emb_max_len=5000; bool xscaling=true; + // Optional Omi Med STT post-Conformer adapter. Absent/false for upstream + // Parakeet checkpoints; enabled only when the GGUF declares it and carries + // the adapter tensors. + bool omi_med_adapter=false; + uint32_t omi_med_adapter_rank=0; + std::string omi_med_adapter_name; // cache-aware streaming / causal config (Phase 5; offline-safe defaults) int32_t att_context_left=-1, att_context_right=-1; // [-1,-1] = full context std::string att_context_style="regular"; // or "chunked_limited"