Sentence Similarity
MLX
Safetensors
lfm2
lfm2.5
ColBERT
late-interaction
feature-extraction
retrieval
custom_code
Instructions to use ronaldmannak/LFM2.5-ColBERT-350M-bf16 with libraries, inference providers, notebooks, and local apps. Follow these links to get started.
- Libraries
- MLX
How to use ronaldmannak/LFM2.5-ColBERT-350M-bf16 with MLX:
# Download the model from the Hub
pip install huggingface_hub[hf_xet]
huggingface-cli download --local-dir LFM2.5-ColBERT-350M-bf16 ronaldmannak/LFM2.5-ColBERT-350M-bf16
- Notebooks
- Google Colab
- Kaggle
- Local Apps Settings
- LM Studio
"""MLX port of LiquidAI's LFM2.5 *bidirectional* (encoder) backbone + retrieval heads.

This is the encoder variant used by:
  - LFM2.5-Embedding-350M (CLS pooling -> 1024-d sentence vector, cosine sim)
  - LFM2.5-ColBERT-350M (Dense 1024->128 per-token vectors, MaxSim)

It is the LFM2.5-350M-Base hybrid backbone (short-conv + GQA attention layers,
SwiGLU MLP, RMSNorm) with three encoder patches relative to the causal LFM2:
  1. attention is bidirectional (no causal mask; pad-only mask),
  2. the short conv is non-causal / centered (symmetric padding = kernel//2),
  3. no LM head; a pooling/projection head is used instead.

Ported from mlx-lm's `models/lfm2.py` (causal) — kept dependency-free so it can
be dropped into any MLX project.
"""
| from dataclasses import dataclass | |
| from typing import List, Optional | |
| import mlx.core as mx | |
| import mlx.nn as nn | |
@dataclass
class ModelArgs:
    """Hyper-parameters of the LFM2 encoder backbone.

    Instances are normally created with :meth:`from_dict` from a parsed
    model configuration; the generated dataclass ``__init__`` accepts the
    same values as keyword arguments.
    """

    vocab_size: int  # rows of the token embedding table
    hidden_size: int  # model width
    num_hidden_layers: int  # number of backbone layers
    num_attention_heads: int  # query heads per attention layer
    num_key_value_heads: int  # key/value heads per attention layer (GQA)
    norm_eps: float  # epsilon passed to every RMSNorm
    conv_bias: bool  # whether short-conv layers (and their projections) use a bias
    conv_L_cache: int  # kernel size of the short convolution
    block_ff_dim: int  # nominal feed-forward width before auto-adjustment
    block_multiple_of: int  # feed-forward width is rounded up to a multiple of this
    block_ffn_dim_multiplier: float  # extra scale applied during auto-adjustment
    block_auto_adjust_ff_dim: bool  # enable the 2/3 * multiplier * round-up adjustment
    rope_theta: float  # RoPE base frequency
    layer_types: List[str]  # per-layer operator kind; "full_attention" marks attention
    model_type: str = "lfm2"

    @classmethod
    def from_dict(cls, d: dict) -> "ModelArgs":
        """Build a :class:`ModelArgs` from a configuration dictionary.

        Required keys are ``vocab_size``, ``hidden_size``,
        ``num_hidden_layers``, ``num_attention_heads`` and ``layer_types``;
        a missing one raises ``KeyError``. Everything else falls back to a
        default. ``rope_theta`` is looked up at the top level first, then
        inside ``rope_parameters``, and finally defaults to ``1000000.0``.
        """
        theta = d.get("rope_theta")
        if theta is None:
            # ``rope_parameters`` may be absent or explicitly null in the config.
            theta = (d.get("rope_parameters") or {}).get("rope_theta")
        if theta is None:
            theta = 1000000.0

        n_heads = d["num_attention_heads"]
        return cls(
            vocab_size=d["vocab_size"],
            hidden_size=d["hidden_size"],
            num_hidden_layers=d["num_hidden_layers"],
            num_attention_heads=n_heads,
            # Without an explicit KV head count, use one KV head per query head.
            num_key_value_heads=d.get("num_key_value_heads", n_heads),
            norm_eps=d.get("norm_eps", d.get("block_norm_eps", 1e-5)),
            conv_bias=d.get("conv_bias", False),
            conv_L_cache=d.get("conv_L_cache", 3),
            block_ff_dim=d.get("block_ff_dim", d.get("intermediate_size")),
            block_multiple_of=d.get("block_multiple_of", 256),
            block_ffn_dim_multiplier=d.get("block_ffn_dim_multiplier", 1.0),
            block_auto_adjust_ff_dim=d.get("block_auto_adjust_ff_dim", True),
            rope_theta=theta,
            layer_types=d["layer_types"],
            model_type=d.get("model_type", "lfm2"),
        )

    @property
    def attn_layer_idxs(self) -> List[int]:
        """Indices of the layers whose type is ``"full_attention"``."""
        return [i for i, t in enumerate(self.layer_types) if t == "full_attention"]
class Attention(nn.Module):
    """Grouped-query self-attention with per-head q/k RMSNorm and RoPE (no causal mask)."""

    def __init__(self, args: ModelArgs):
        super().__init__()
        hidden = args.hidden_size
        self.n_heads = args.num_attention_heads
        self.n_kv_heads = args.num_key_value_heads
        self.head_dim = hidden // self.n_heads
        self.scale = self.head_dim**-0.5

        q_width = self.n_heads * self.head_dim
        kv_width = self.n_kv_heads * self.head_dim
        self.q_proj = nn.Linear(hidden, q_width, bias=False)
        self.k_proj = nn.Linear(hidden, kv_width, bias=False)
        self.v_proj = nn.Linear(hidden, kv_width, bias=False)
        self.out_proj = nn.Linear(q_width, hidden, bias=False)

        # q and k are normalised per head, over the head dimension.
        self.q_layernorm = nn.RMSNorm(self.head_dim, eps=args.norm_eps)
        self.k_layernorm = nn.RMSNorm(self.head_dim, eps=args.norm_eps)
        self.rope = nn.RoPE(self.head_dim, base=args.rope_theta, traditional=False)

    def __call__(self, x: mx.array, mask: Optional[mx.array] = None) -> mx.array:
        """Attend over a rank-3 input ``x``; ``mask`` is forwarded to the fused kernel."""
        batch, seq_len, _ = x.shape

        def split_heads(t: mx.array, n: int) -> mx.array:
            # (batch, seq, n * head_dim) -> (batch, seq, n, head_dim)
            return t.reshape(batch, seq_len, n, -1)

        queries = self.q_layernorm(split_heads(self.q_proj(x), self.n_heads))
        keys = self.k_layernorm(split_heads(self.k_proj(x), self.n_kv_heads))
        values = split_heads(self.v_proj(x), self.n_kv_heads)

        # Move the head axis in front of the sequence axis before RoPE / attention.
        queries = self.rope(queries.transpose(0, 2, 1, 3))
        keys = self.rope(keys.transpose(0, 2, 1, 3))
        values = values.transpose(0, 2, 1, 3)

        attended = mx.fast.scaled_dot_product_attention(
            queries, keys, values, scale=self.scale, mask=mask
        )
        merged = attended.transpose(0, 2, 1, 3).reshape(batch, seq_len, -1)
        return self.out_proj(merged)
class ShortConv(nn.Module):
    """Gated depthwise short convolution, centered (symmetric padding) rather than causal."""

    def __init__(self, args: ModelArgs):
        super().__init__()
        width = args.hidden_size
        use_bias = args.conv_bias
        self.L_cache = args.conv_L_cache
        self.conv = nn.Conv1d(
            in_channels=width,
            out_channels=width,
            kernel_size=self.L_cache,
            groups=width,  # one filter per channel
            padding=self.L_cache // 2,  # symmetric => centered / non-causal
            bias=use_bias,
        )
        self.in_proj = nn.Linear(width, 3 * width, bias=use_bias)
        self.out_proj = nn.Linear(width, width, bias=use_bias)

    def __call__(self, x: mx.array, keep: Optional[mx.array] = None) -> mx.array:
        """Apply the gated conv to ``x``; ``keep`` (if given) multiplies the conv input per position."""
        # A single projection yields the input gate, the output gate and the signal.
        gate_in, gate_out, signal = mx.split(self.in_proj(x), 3, axis=-1)
        gated = gate_in * signal
        if keep is not None:
            # Zero padded positions so they don't leak into the conv.
            gated = gated * keep[..., None]

        mixed = self.conv(gated)
        seq_len = gated.shape[1]
        if mixed.shape[1] != seq_len:
            # Odd kernel + symmetric padding keeps the length, but guard anyway.
            mixed = mixed[:, :seq_len, :]
        return self.out_proj(gate_out * mixed)
class MLP(nn.Module):
    """Gated feed-forward block: ``w2(silu(w1(x)) * w3(x))``."""

    def __init__(self, args: ModelArgs):
        super().__init__()
        inner = args.block_ff_dim
        if args.block_auto_adjust_ff_dim:
            # Shrink to 2/3, optionally rescale, then round up to the block multiple.
            inner = int(2 * inner / 3)
            multiplier = args.block_ffn_dim_multiplier
            if multiplier is not None:
                inner = int(multiplier * inner)
            multiple = args.block_multiple_of
            inner = multiple * ((inner + multiple - 1) // multiple)

        model_dim = args.hidden_size
        self.w1 = nn.Linear(model_dim, inner, bias=False)
        self.w3 = nn.Linear(model_dim, inner, bias=False)
        self.w2 = nn.Linear(inner, model_dim, bias=False)

    def __call__(self, x: mx.array) -> mx.array:
        """Return the gated feed-forward transform of ``x``."""
        gate = nn.silu(self.w1(x))
        value = self.w3(x)
        return self.w2(gate * value)
class DecoderLayer(nn.Module):
    """One backbone layer: an attention or short-conv operator, then an MLP.

    Both sub-blocks are pre-normalised with RMSNorm and wrapped in a
    residual connection.
    """

    def __init__(self, args: ModelArgs, layer_idx: int):
        super().__init__()
        # Decide the operator straight from ``layer_types`` instead of
        # rebuilding the list of attention indices for every layer. An index
        # outside ``layer_types`` is treated as a conv layer.
        layer_types = args.layer_types
        self.is_attention = (
            0 <= layer_idx < len(layer_types)
            and layer_types[layer_idx] == "full_attention"
        )
        if self.is_attention:
            self.self_attn = Attention(args)
        else:
            self.conv = ShortConv(args)
        self.feed_forward = MLP(args)
        self.operator_norm = nn.RMSNorm(args.hidden_size, eps=args.norm_eps)
        self.ffn_norm = nn.RMSNorm(args.hidden_size, eps=args.norm_eps)

    def __call__(self, x, attn_mask=None, keep=None):
        """Run the layer on ``x``.

        ``attn_mask`` is only used by attention layers and ``keep`` only by
        conv layers; the unused one is ignored.
        """
        normed = self.operator_norm(x)
        if self.is_attention:
            r = self.self_attn(normed, mask=attn_mask)
        else:
            r = self.conv(normed, keep=keep)
        h = x + r
        return h + self.feed_forward(self.ffn_norm(h))
class Lfm2Backbone(nn.Module):
    """Maps token ids to the final hidden states, after ``embedding_norm``."""

    def __init__(self, args: ModelArgs):
        super().__init__()
        self.args = args
        self.embed_tokens = nn.Embedding(args.vocab_size, args.hidden_size)
        self.layers = [
            DecoderLayer(args, idx) for idx in range(args.num_hidden_layers)
        ]
        self.embedding_norm = nn.RMSNorm(args.hidden_size, eps=args.norm_eps)

    def __call__(self, input_ids: mx.array, attention_mask: Optional[mx.array] = None) -> mx.array:
        """Encode ``input_ids``; ``attention_mask`` is (B, L) with 1=real, 0=pad."""
        hidden = self.embed_tokens(input_ids)

        if attention_mask is None:
            keep = None
            attn_mask = None
        else:
            # Multiplicative mask handed to the conv layers.
            keep = attention_mask.astype(hidden.dtype)
            # Additive bidirectional pad mask for attention: (B, 1, 1, L).
            is_real = attention_mask[:, None, None, :] > 0
            blocked = mx.array(-1e9, dtype=hidden.dtype)
            attn_mask = mx.where(is_real, mx.array(0, hidden.dtype), blocked)

        for block in self.layers:
            hidden = block(hidden, attn_mask=attn_mask, keep=keep)
        return self.embedding_norm(hidden)
def _l2_normalize(x: mx.array, axis: int = -1, eps: float = 1e-12) -> mx.array:
    """Divide ``x`` by its norm along ``axis``, with the norm floored at ``eps``."""
    norms = mx.linalg.norm(x, axis=axis, keepdims=True)
    # Flooring the divisor avoids dividing by zero for all-zero vectors.
    safe_norms = mx.maximum(norms, eps)
    return x / safe_norms
class EmbeddingModel(nn.Module):
    """LFM2.5-Embedding-350M head: the position-0 (CLS) hidden state is the sentence vector."""

    pooling = "cls"

    def __init__(self, args: ModelArgs):
        super().__init__()
        self.args = args
        self.model = Lfm2Backbone(args)

    def __call__(self, input_ids, attention_mask=None) -> mx.array:
        """Return the backbone's per-token hidden states, without pooling."""
        hidden_states = self.model(input_ids, attention_mask)
        return hidden_states

    def encode(self, input_ids, attention_mask=None, normalize: bool = True) -> mx.array:
        """Return one vector per sequence, L2-normalised unless ``normalize`` is false."""
        hidden_states = self.model(input_ids, attention_mask)
        # CLS == BOS at position 0 (add_bos_token=True).
        cls_vector = hidden_states[:, 0, :]
        if normalize:
            return _l2_normalize(cls_vector)
        return cls_vector
class ColbertModel(nn.Module):
    """LFM2.5-ColBERT-350M head: a bias-free Dense projection of every token (for MaxSim)."""

    def __init__(self, args: ModelArgs, proj_dim: int = 128):
        super().__init__()
        self.args = args
        self.model = Lfm2Backbone(args)
        self.dense = nn.Linear(args.hidden_size, proj_dim, bias=False)

    def __call__(self, input_ids, attention_mask=None) -> mx.array:
        """Return the projected per-token vectors, neither normalised nor masked."""
        hidden_states = self.model(input_ids, attention_mask)
        return self.dense(hidden_states)

    def encode(self, input_ids, attention_mask=None, normalize: bool = True) -> mx.array:
        """Return per-token vectors, optionally L2-normalised, zeroed where the mask is 0."""
        hidden_states = self.model(input_ids, attention_mask)
        token_vecs = self.dense(hidden_states)  # (B, L, 128)
        if normalize:
            token_vecs = _l2_normalize(token_vecs, axis=-1)
        if attention_mask is not None:
            # Multiply by the mask so positions where it is 0 come out as zero vectors.
            pad_mask = attention_mask[..., None].astype(token_vecs.dtype)
            token_vecs = token_vecs * pad_mask
        return token_vecs
def sanitize(weights: dict) -> dict:
    """Return a copy of ``weights`` with HF depthwise conv kernels in MLX layout.

    Entries whose key ends in ``conv.conv.weight`` are transposed from
    (O,1,K) to (O,K,1) unless their last axis is already smaller than
    axis 1, i.e. they are already (O,K,1). All other entries are passed
    through untouched.
    """
    conv_suffix = "conv.conv.weight"
    converted = {}
    for name, tensor in weights.items():
        needs_transpose = (
            name.endswith(conv_suffix) and tensor.shape[-1] >= tensor.shape[1]
        )
        # (O,1,K) -> (O,K,1) for HF-layout kernels; everything else as is.
        converted[name] = tensor.transpose(0, 2, 1) if needs_transpose else tensor
    return converted