"""MLX port of LiquidAI's LFM2.5 *bidirectional* (encoder) backbone + retrieval heads.

This is the encoder variant used by:

- LFM2.5-Embedding-350M (CLS pooling -> 1024-d sentence vector, cosine sim)
- LFM2.5-ColBERT-350M   (Dense 1024->128 per-token vectors, MaxSim)

It is the LFM2.5-350M-Base hybrid backbone (short-conv + GQA attention layers,
SwiGLU MLP, RMSNorm) with three encoder patches relative to the causal LFM2:

1. attention is bidirectional (no causal mask; pad-only mask),
2. the short conv is non-causal / centered (symmetric padding = kernel // 2),
3. there is no LM head; a pooling/projection head is used instead.

Ported from mlx-lm's `models/lfm2.py` (causal) and kept dependency-free so it
can be dropped into any MLX project.
"""

from dataclasses import dataclass
from typing import List, Optional

import mlx.core as mx
import mlx.nn as nn


@dataclass
class ModelArgs:
    """Backbone hyper-parameters, normally built via :meth:`from_dict`."""

    vocab_size: int
    hidden_size: int
    num_hidden_layers: int
    num_attention_heads: int
    num_key_value_heads: int
    norm_eps: float
    conv_bias: bool
    conv_L_cache: int  # kernel size of the depthwise short conv
    block_ff_dim: int
    block_multiple_of: int
    block_ffn_dim_multiplier: float
    block_auto_adjust_ff_dim: bool
    rope_theta: float
    layer_types: List[str]  # one entry per layer; "full_attention" => attention layer
    model_type: str = "lfm2"

    @classmethod
    def from_dict(cls, d: dict) -> "ModelArgs":
        """Build args from a config dict, tolerating known key-name variants.

        - ``rope_theta`` is read from the top level, else from
          ``rope_parameters["rope_theta"]`` (``rope_parameters`` may also be
          missing or ``null``), else defaults to 1e6.
        - ``layer_types`` is required unless the config instead provides
          ``full_attn_idxs``; in that case the listed layers become
          ``"full_attention"`` and all others ``"conv"``.

        Raises:
            KeyError: if a required key (e.g. ``vocab_size``) is missing, or if
                neither ``layer_types`` nor ``full_attn_idxs`` is present.
        """
        theta = d.get("rope_theta")
        if theta is None:
            # `or {}` also covers an explicit `"rope_parameters": null`.
            theta = (d.get("rope_parameters") or {}).get("rope_theta", 1000000.0)

        if "layer_types" not in d and "full_attn_idxs" in d:
            attn_idxs = set(d["full_attn_idxs"])
            layer_types = [
                "full_attention" if i in attn_idxs else "conv"
                for i in range(d["num_hidden_layers"])
            ]
        else:
            layer_types = d["layer_types"]  # KeyError if neither key is present

        return cls(
            vocab_size=d["vocab_size"],
            hidden_size=d["hidden_size"],
            num_hidden_layers=d["num_hidden_layers"],
            num_attention_heads=d["num_attention_heads"],
            num_key_value_heads=d.get("num_key_value_heads", d["num_attention_heads"]),
            norm_eps=d.get("norm_eps", d.get("block_norm_eps", 1e-5)),
            conv_bias=d.get("conv_bias", False),
            conv_L_cache=d.get("conv_L_cache", 3),
            block_ff_dim=d.get("block_ff_dim", d.get("intermediate_size")),
            block_multiple_of=d.get("block_multiple_of", 256),
            block_ffn_dim_multiplier=d.get("block_ffn_dim_multiplier", 1.0),
            block_auto_adjust_ff_dim=d.get("block_auto_adjust_ff_dim", True),
            rope_theta=theta,
            layer_types=layer_types,
            model_type=d.get("model_type", "lfm2"),
        )

    @property
    def attn_layer_idxs(self) -> List[int]:
        """Indices of layers whose type is ``"full_attention"``."""
        return [i for i, t in enumerate(self.layer_types) if t == "full_attention"]


class Attention(nn.Module):
    """GQA attention with per-head q/k RMSNorm and RoPE. Non-causal."""

    def __init__(self, args: ModelArgs):
        super().__init__()
        dim = args.hidden_size
        self.n_heads = args.num_attention_heads
        self.n_kv_heads = args.num_key_value_heads
        self.head_dim = dim // self.n_heads
        self.scale = self.head_dim**-0.5

        self.q_proj = nn.Linear(dim, self.n_heads * self.head_dim, bias=False)
        self.k_proj = nn.Linear(dim, self.n_kv_heads * self.head_dim, bias=False)
        self.v_proj = nn.Linear(dim, self.n_kv_heads * self.head_dim, bias=False)
        self.out_proj = nn.Linear(self.n_heads * self.head_dim, dim, bias=False)

        self.q_layernorm = nn.RMSNorm(self.head_dim, eps=args.norm_eps)
        self.k_layernorm = nn.RMSNorm(self.head_dim, eps=args.norm_eps)
        self.rope = nn.RoPE(self.head_dim, base=args.rope_theta, traditional=False)

    def __call__(self, x: mx.array, mask: Optional[mx.array] = None) -> mx.array:
        """Attend over all positions of ``x`` (B, L, D); ``mask`` is added to scores.

        No causal mask is applied, so every token sees every (unmasked) token.
        """
        B, L, _ = x.shape
        # (B, L, D) -> (B, heads, L, head_dim), with per-head RMSNorm on q and k.
        q = self.q_layernorm(self.q_proj(x).reshape(B, L, self.n_heads, -1)).transpose(0, 2, 1, 3)
        k = self.k_layernorm(self.k_proj(x).reshape(B, L, self.n_kv_heads, -1)).transpose(0, 2, 1, 3)
        v = self.v_proj(x).reshape(B, L, self.n_kv_heads, -1).transpose(0, 2, 1, 3)

        q = self.rope(q)
        k = self.rope(k)

        out = mx.fast.scaled_dot_product_attention(q, k, v, scale=self.scale, mask=mask)
        out = out.transpose(0, 2, 1, 3).reshape(B, L, -1)
        return self.out_proj(out)


class ShortConv(nn.Module):
    """Non-causal gated short convolution (centered, symmetric padding)."""

    def __init__(self, args: ModelArgs):
        super().__init__()
        self.L_cache = args.conv_L_cache
        bias = args.conv_bias
        # Depthwise conv (groups == channels). Padding kernel // 2 on both sides
        # centers the window instead of the causal (kernel - 1, left-only) pad.
        self.conv = nn.Conv1d(
            in_channels=args.hidden_size,
            out_channels=args.hidden_size,
            kernel_size=self.L_cache,
            groups=args.hidden_size,
            padding=self.L_cache // 2,  # centered / non-causal
            bias=bias,
        )
        self.in_proj = nn.Linear(args.hidden_size, 3 * args.hidden_size, bias=bias)
        self.out_proj = nn.Linear(args.hidden_size, args.hidden_size, bias=bias)

    def __call__(self, x: mx.array, keep: Optional[mx.array] = None) -> mx.array:
        """Apply ``out_proj(C * conv(B * v))``, where ``B, C, v`` come from ``in_proj(x)``.

        Args:
            x: hidden states, (B, L, D).
            keep: optional (B, L) float mask, 1 for real tokens and 0 for padding.
        """
        b_gate, c_gate, v = mx.split(self.in_proj(x), 3, axis=-1)
        bx = b_gate * v
        if keep is not None:
            # Zero padded positions so they don't leak into neighbours via the conv.
            bx = bx * keep[..., None]
        conv_out = self.conv(bx)
        # An odd kernel with symmetric padding keeps length == L. An even kernel
        # yields L + 1 outputs, so trim back to L.
        seq_len = bx.shape[1]
        if conv_out.shape[1] != seq_len:
            conv_out = conv_out[:, :seq_len, :]
        return self.out_proj(c_gate * conv_out)


class MLP(nn.Module):
    """SwiGLU feed-forward: ``w2(silu(w1(x)) * w3(x))``."""

    def __init__(self, args: ModelArgs):
        super().__init__()
        ff_dim = args.block_ff_dim
        # Same hidden-size derivation as the reference LFM2 implementations:
        # the multiplier and the rounding up to `block_multiple_of` apply only
        # when auto-adjust is enabled.
        if args.block_auto_adjust_ff_dim:
            ff_dim = int(2 * ff_dim / 3)
            if args.block_ffn_dim_multiplier is not None:
                ff_dim = int(args.block_ffn_dim_multiplier * ff_dim)
            m = args.block_multiple_of
            ff_dim = m * ((ff_dim + m - 1) // m)

        dim = args.hidden_size
        self.w1 = nn.Linear(dim, ff_dim, bias=False)
        self.w3 = nn.Linear(dim, ff_dim, bias=False)
        self.w2 = nn.Linear(ff_dim, dim, bias=False)

    def __call__(self, x: mx.array) -> mx.array:
        return self.w2(nn.silu(self.w1(x)) * self.w3(x))


class DecoderLayer(nn.Module):
    """Pre-norm block: (attention | short conv) residual, then MLP residual."""

    def __init__(self, args: ModelArgs, layer_idx: int):
        super().__init__()
        self.is_attention = layer_idx in args.attn_layer_idxs
        if self.is_attention:
            self.self_attn = Attention(args)
        else:
            self.conv = ShortConv(args)
        self.feed_forward = MLP(args)
        self.operator_norm = nn.RMSNorm(args.hidden_size, eps=args.norm_eps)
        self.ffn_norm = nn.RMSNorm(args.hidden_size, eps=args.norm_eps)

    def __call__(self, x, attn_mask=None, keep=None):
        """``attn_mask`` is used by attention layers, ``keep`` by conv layers."""
        if self.is_attention:
            r = self.self_attn(self.operator_norm(x), mask=attn_mask)
        else:
            r = self.conv(self.operator_norm(x), keep=keep)
        h = x + r
        return h + self.feed_forward(self.ffn_norm(h))


class Lfm2Backbone(nn.Module):
    """Token ids -> last_hidden_state (post embedding_norm)."""

    def __init__(self, args: ModelArgs):
        super().__init__()
        self.args = args
        self.embed_tokens = nn.Embedding(args.vocab_size, args.hidden_size)
        self.layers = [DecoderLayer(args, i) for i in range(args.num_hidden_layers)]
        self.embedding_norm = nn.RMSNorm(args.hidden_size, eps=args.norm_eps)

    def __call__(self, input_ids: mx.array, attention_mask: Optional[mx.array] = None) -> mx.array:
        """Encode ``input_ids`` (B, L) into hidden states (B, L, D).

        Args:
            input_ids: token ids, (B, L).
            attention_mask: optional (B, L), with values > 0 for real tokens and
                0 for padding. Padding is excluded as an attention *key* and is
                zeroed before each short conv.
        """
        h = self.embed_tokens(input_ids)

        attn_mask = None
        keep = None
        if attention_mask is not None:
            keep = attention_mask.astype(h.dtype)  # (B, L) 1=real 0=pad
            # Additive bidirectional pad mask, shape (B, 1, 1, L).
            # NOTE(review): -1e9 is not representable in float16 and becomes
            # -inf there. Fine for rows with at least one real token, but an
            # all-padding row would then produce NaNs. Confirm whether that matters.
            neg = mx.array(-1e9, dtype=h.dtype)
            attn_mask = mx.where(attention_mask[:, None, None, :] > 0, mx.array(0, h.dtype), neg)

        for layer in self.layers:
            h = layer(h, attn_mask=attn_mask, keep=keep)
        return self.embedding_norm(h)


def _l2_normalize(x: mx.array, axis: int = -1, eps: float = 1e-12) -> mx.array:
    """Divide ``x`` by its L2 norm along ``axis`` (norm clamped below by ``eps``)."""
    return x / mx.maximum(mx.linalg.norm(x, axis=axis, keepdims=True), eps)


class EmbeddingModel(nn.Module):
    """LFM2.5-Embedding-350M: CLS-token pooling -> 1024-d sentence vector."""

    pooling = "cls"

    def __init__(self, args: ModelArgs):
        super().__init__()
        self.args = args
        self.model = Lfm2Backbone(args)

    def __call__(self, input_ids, attention_mask=None) -> mx.array:
        """Return the backbone's last hidden state, (B, L, D)."""
        return self.model(input_ids, attention_mask)

    def encode(self, input_ids, attention_mask=None, normalize: bool = True) -> mx.array:
        """Return one (B, D) vector per sequence: the hidden state at position 0.

        NOTE(review): position 0 is the CLS/BOS token only with right padding
        (or no padding). Left-padded batches would pool a pad token.
        """
        lhs = self.model(input_ids, attention_mask)
        pooled = lhs[:, 0, :]  # CLS == BOS at position 0 (add_bos_token=True)
        return _l2_normalize(pooled) if normalize else pooled


class ColbertModel(nn.Module):
    """LFM2.5-ColBERT-350M: per-token Dense 1024->128 projection (MaxSim)."""

    def __init__(self, args: ModelArgs, proj_dim: int = 128):
        super().__init__()
        self.args = args
        self.model = Lfm2Backbone(args)
        self.dense = nn.Linear(args.hidden_size, proj_dim, bias=False)

    def __call__(self, input_ids, attention_mask=None) -> mx.array:
        """Return projected per-token vectors, (B, L, proj_dim), unnormalized."""
        return self.dense(self.model(input_ids, attention_mask))

    def encode(self, input_ids, attention_mask=None, normalize: bool = True) -> mx.array:
        """Return per-token vectors (B, L, proj_dim), with padded positions zeroed.

        Zeroing happens after normalization, so pad rows stay exactly 0 and
        contribute nothing to a MaxSim score.
        """
        tok = self.dense(self.model(input_ids, attention_mask))  # (B, L, proj_dim)
        if normalize:
            tok = _l2_normalize(tok, axis=-1)
        if attention_mask is not None:
            tok = tok * attention_mask[..., None].astype(tok.dtype)
        return tok


def sanitize(weights: dict) -> dict:
    """Transpose HF depthwise conv weights (O,1,K) -> MLX Conv1d (O,K,1).

    Conv weights already in (O,K,1) layout (last dim smaller than dim 1) are
    left untouched, so the function is safe to call twice. All other entries
    are copied through unchanged into a new dict.
    """
    out = {}
    for k, v in weights.items():
        if k.endswith("conv.conv.weight") and v.shape[-1] >= v.shape[1]:
            v = v.transpose(0, 2, 1)  # (O,1,K) -> (O,K,1)
        out[k] = v
    return out