LFM2.5-Embedding-350M-bf16 / lfm2_bidirectional.py
ronaldmannak's picture
Upload LFM2.5 MLX checkpoint (bidirectional encoder)
f165d3d verified
Raw
History Blame Contribute Delete
10 kB
"""MLX port of LiquidAI's LFM2.5 *bidirectional* (encoder) backbone + retrieval heads.
This is the encoder variant used by:
- LFM2.5-Embedding-350M (CLS pooling -> 1024-d sentence vector, cosine sim)
- LFM2.5-ColBERT-350M (Dense 1024->128 per-token vectors, MaxSim)
It is the LFM2.5-350M-Base hybrid backbone (short-conv + GQA attention layers,
SwiGLU MLP, RMSNorm) with three encoder patches relative to the causal LFM2:
1. attention is bidirectional (no causal mask; pad-only mask),
2. the short conv is non-causal / centered (symmetric padding = kernel//2),
3. no LM head; a pooling/projection head is used instead.
Ported from mlx-lm's `models/lfm2.py` (causal) — kept dependency-free so it can
be dropped into any MLX project.
"""
from dataclasses import dataclass
from typing import List, Optional
import mlx.core as mx
import mlx.nn as nn
@dataclass
class ModelArgs:
    """Hyper-parameters of the LFM2 hybrid backbone.

    Build it from a Hugging Face ``config.json`` dict with :meth:`from_dict`.
    """

    vocab_size: int
    hidden_size: int
    num_hidden_layers: int
    num_attention_heads: int
    num_key_value_heads: int
    norm_eps: float
    conv_bias: bool
    conv_L_cache: int  # kernel size of the depthwise short conv
    block_ff_dim: int
    block_multiple_of: int
    block_ffn_dim_multiplier: float
    block_auto_adjust_ff_dim: bool
    rope_theta: float
    layer_types: List[str]  # one entry per layer; "full_attention" marks attention layers
    model_type: str = "lfm2"

    @classmethod
    def from_dict(cls, d: dict) -> "ModelArgs":
        """Create ``ModelArgs`` from a config dict, filling in defaults.

        ``rope_theta`` is read from the top level, falling back to
        ``rope_parameters["rope_theta"]`` and then to 1e6. ``rope_parameters``
        may be absent or null.

        The layer layout comes from ``layer_types``. When that key is absent or
        null, it is derived from a legacy ``full_attn_idxs`` list: listed
        indices become "full_attention" and all other layers become "conv".

        Raises:
            KeyError: if a required key is missing, including when neither
                ``layer_types`` nor ``full_attn_idxs`` is present.
        """
        rope_theta = d.get("rope_theta")
        if rope_theta is None:
            # `or {}` also covers an explicit `"rope_parameters": null`.
            rope_theta = (d.get("rope_parameters") or {}).get("rope_theta", 1000000.0)

        layer_types = d.get("layer_types")
        if layer_types is None:
            if "full_attn_idxs" not in d:
                raise KeyError("layer_types")
            attn_idxs = set(d["full_attn_idxs"])
            layer_types = [
                "full_attention" if i in attn_idxs else "conv"
                for i in range(d["num_hidden_layers"])
            ]

        return cls(
            vocab_size=d["vocab_size"],
            hidden_size=d["hidden_size"],
            num_hidden_layers=d["num_hidden_layers"],
            num_attention_heads=d["num_attention_heads"],
            num_key_value_heads=d.get("num_key_value_heads", d["num_attention_heads"]),
            norm_eps=d.get("norm_eps", d.get("block_norm_eps", 1e-5)),
            conv_bias=d.get("conv_bias", False),
            conv_L_cache=d.get("conv_L_cache", 3),
            block_ff_dim=d.get("block_ff_dim", d.get("intermediate_size")),
            block_multiple_of=d.get("block_multiple_of", 256),
            block_ffn_dim_multiplier=d.get("block_ffn_dim_multiplier", 1.0),
            block_auto_adjust_ff_dim=d.get("block_auto_adjust_ff_dim", True),
            rope_theta=rope_theta,
            layer_types=layer_types,
            model_type=d.get("model_type", "lfm2"),
        )

    @property
    def attn_layer_idxs(self) -> List[int]:
        """Indices of the layers whose type is ``"full_attention"``."""
        return [i for i, t in enumerate(self.layer_types) if t == "full_attention"]
class Attention(nn.Module):
    """Bidirectional grouped-query attention.

    Queries and keys are RMS-normalised per head and rotated with RoPE. No
    causal mask is built; only the optional ``mask`` passed by the caller is
    applied.
    """

    def __init__(self, args: ModelArgs):
        super().__init__()
        hidden = args.hidden_size
        self.n_heads = args.num_attention_heads
        self.n_kv_heads = args.num_key_value_heads
        self.head_dim = hidden // self.n_heads
        self.scale = self.head_dim**-0.5

        q_width = self.n_heads * self.head_dim
        kv_width = self.n_kv_heads * self.head_dim
        # Attribute names are checkpoint weight keys -- do not rename.
        self.q_proj = nn.Linear(hidden, q_width, bias=False)
        self.k_proj = nn.Linear(hidden, kv_width, bias=False)
        self.v_proj = nn.Linear(hidden, kv_width, bias=False)
        self.out_proj = nn.Linear(q_width, hidden, bias=False)
        self.q_layernorm = nn.RMSNorm(self.head_dim, eps=args.norm_eps)
        self.k_layernorm = nn.RMSNorm(self.head_dim, eps=args.norm_eps)
        self.rope = nn.RoPE(self.head_dim, base=args.rope_theta, traditional=False)

    def __call__(self, x: mx.array, mask: Optional[mx.array] = None) -> mx.array:
        batch, seq_len, _ = x.shape

        def split_heads(t, n):
            # (B, L, n * head_dim) -> (B, L, n, head_dim)
            return t.reshape(batch, seq_len, n, -1)

        queries = self.q_layernorm(split_heads(self.q_proj(x), self.n_heads))
        keys = self.k_layernorm(split_heads(self.k_proj(x), self.n_kv_heads))
        values = split_heads(self.v_proj(x), self.n_kv_heads)

        # Put the head axis ahead of the sequence axis: (B, n, L, head_dim).
        queries, keys, values = (t.transpose(0, 2, 1, 3) for t in (queries, keys, values))
        queries = self.rope(queries)
        keys = self.rope(keys)

        attended = mx.fast.scaled_dot_product_attention(
            queries, keys, values, scale=self.scale, mask=mask
        )
        merged = attended.transpose(0, 2, 1, 3).reshape(batch, seq_len, -1)
        return self.out_proj(merged)
class ShortConv(nn.Module):
    """Gated depthwise short convolution, centred (non-causal) for the encoder.

    ``in_proj`` yields three equal-width chunks. The first two are multiplied
    and convolved over the sequence. The result is gated by the remaining
    chunk and sent through ``out_proj``.
    """

    def __init__(self, args: ModelArgs):
        super().__init__()
        self.L_cache = args.conv_L_cache
        width = args.hidden_size
        use_bias = args.conv_bias
        self.conv = nn.Conv1d(
            in_channels=width,
            out_channels=width,
            kernel_size=self.L_cache,
            groups=width,  # one filter per channel (depthwise)
            padding=self.L_cache // 2,  # symmetric padding -> centred window
            bias=use_bias,
        )
        self.in_proj = nn.Linear(width, 3 * width, bias=use_bias)
        self.out_proj = nn.Linear(width, width, bias=use_bias)

    def __call__(self, x: mx.array, keep: Optional[mx.array] = None) -> mx.array:
        gate_b, gate_c, value = mx.split(self.in_proj(x), 3, axis=-1)
        conv_in = gate_b * value
        if keep is not None:
            # Zero padded positions so they cannot leak into neighbouring windows.
            conv_in = conv_in * keep[..., None]
        seq_len = conv_in.shape[1]
        conv_out = self.conv(conv_in)
        # An odd kernel keeps the length at L. An even kernel yields L + 1;
        # the trailing step is dropped.
        if conv_out.shape[1] != seq_len:
            conv_out = conv_out[:, :seq_len, :]
        return self.out_proj(gate_c * conv_out)
class MLP(nn.Module):
    """SwiGLU feed-forward block: ``w2(silu(w1(x)) * w3(x))``.

    Hidden-width sizing:

    * With ``block_auto_adjust_ff_dim``, ``block_ff_dim`` is reduced to 2/3,
      optionally scaled by ``block_ffn_dim_multiplier``, and rounded up to a
      multiple of ``block_multiple_of``.
    * Without it, ``block_ff_dim`` is used unchanged.

    This is intended to match the causal LFM2 MLP this file was ported from
    (TODO confirm against mlx-lm ``models/lfm2.py``).

    Raises:
        ValueError: if the config supplied no feed-forward width.
    """

    def __init__(self, args: ModelArgs):
        super().__init__()
        ff_dim = args.block_ff_dim
        if ff_dim is None:
            raise ValueError(
                "MLP needs `block_ff_dim` (or `intermediate_size`) in the config"
            )
        if args.block_auto_adjust_ff_dim:
            ff_dim = int(2 * ff_dim / 3)
            if args.block_ffn_dim_multiplier is not None:
                ff_dim = int(args.block_ffn_dim_multiplier * ff_dim)
            # Round up to the configured multiple (only when auto-adjusting).
            m = args.block_multiple_of
            ff_dim = m * ((ff_dim + m - 1) // m)
        dim = args.hidden_size
        self.w1 = nn.Linear(dim, ff_dim, bias=False)
        self.w3 = nn.Linear(dim, ff_dim, bias=False)
        self.w2 = nn.Linear(ff_dim, dim, bias=False)

    def __call__(self, x: mx.array) -> mx.array:
        return self.w2(nn.silu(self.w1(x)) * self.w3(x))
class DecoderLayer(nn.Module):
    """One pre-norm residual block.

    The token mixer is either attention or the short conv, chosen from
    ``args.attn_layer_idxs``. It is followed by the feed-forward MLP, and each
    sub-block is wrapped in its own residual connection.
    """

    def __init__(self, args: ModelArgs, layer_idx: int):
        super().__init__()
        self.is_attention = layer_idx in args.attn_layer_idxs
        if not self.is_attention:
            self.conv = ShortConv(args)
        else:
            self.self_attn = Attention(args)
        self.feed_forward = MLP(args)
        self.operator_norm = nn.RMSNorm(args.hidden_size, eps=args.norm_eps)
        self.ffn_norm = nn.RMSNorm(args.hidden_size, eps=args.norm_eps)

    def __call__(self, x, attn_mask=None, keep=None):
        normed = self.operator_norm(x)
        mixed = (
            self.self_attn(normed, mask=attn_mask)
            if self.is_attention
            else self.conv(normed, keep=keep)
        )
        hidden = x + mixed
        return hidden + self.feed_forward(self.ffn_norm(hidden))
class Lfm2Backbone(nn.Module):
    """Encoder trunk mapping token ids to the final (``embedding_norm``-ed) hidden states."""

    def __init__(self, args: ModelArgs):
        super().__init__()
        self.args = args
        self.embed_tokens = nn.Embedding(args.vocab_size, args.hidden_size)
        self.layers = [DecoderLayer(args, idx) for idx in range(args.num_hidden_layers)]
        self.embedding_norm = nn.RMSNorm(args.hidden_size, eps=args.norm_eps)

    def __call__(self, input_ids: mx.array, attention_mask: Optional[mx.array] = None) -> mx.array:
        hidden = self.embed_tokens(input_ids)

        if attention_mask is None:
            attn_mask, keep = None, None
        else:
            dtype = hidden.dtype
            # Per-position multiplier for the conv layers: 1 for real tokens, 0 for pads.
            keep = attention_mask.astype(dtype)
            # Additive key mask shaped (B, 1, 1, L): 0 on real keys, -1e9 on pads.
            # NOTE(review): -1e9 rounds to -inf in float16; presumably harmless
            # while each row has at least one real key.
            is_real = attention_mask[:, None, None, :] > 0
            attn_mask = mx.where(is_real, mx.array(0, dtype), mx.array(-1e9, dtype=dtype))

        for layer in self.layers:
            hidden = layer(hidden, attn_mask=attn_mask, keep=keep)
        return self.embedding_norm(hidden)
def _l2_normalize(x: mx.array, axis: int = -1, eps: float = 1e-12) -> mx.array:
    """Scale ``x`` to unit L2 norm along ``axis``.

    Norms smaller than ``eps`` are clamped to ``eps`` to avoid dividing by zero.
    """
    magnitude = mx.linalg.norm(x, axis=axis, keepdims=True)
    safe_magnitude = mx.maximum(magnitude, eps)
    return x / safe_magnitude
class EmbeddingModel(nn.Module):
    """LFM2.5-Embedding-350M: CLS-token pooling -> 1024-d sentence vector."""

    pooling = "cls"

    def __init__(self, args: ModelArgs):
        super().__init__()
        self.args = args
        self.model = Lfm2Backbone(args)

    def __call__(self, input_ids, attention_mask=None) -> mx.array:
        """Return the backbone's last hidden state (no pooling)."""
        return self.model(input_ids, attention_mask)

    def encode(self, input_ids, attention_mask=None, normalize: bool = True) -> mx.array:
        """Pool each sequence to a single vector via its CLS (BOS) token.

        Without ``attention_mask``, position 0 is pooled. With a mask, the
        first real (mask > 0) position of each row is pooled. That is position
        0 for right-padded input, but it also works for left-padded batches,
        where index 0 would be a pad. A row with no real tokens falls back to
        position 0.

        Args:
            input_ids: token ids, as accepted by the backbone.
            attention_mask: optional mask with 1 for real tokens, 0 for pads.
            normalize: L2-normalise the pooled vectors (for cosine similarity).

        Returns:
            One vector per input row.
        """
        lhs = self.model(input_ids, attention_mask)
        if attention_mask is None:
            pooled = lhs[:, 0, :]  # CLS == BOS at position 0 (add_bos_token=True)
        else:
            seq_len = lhs.shape[1]
            real = (attention_mask > 0).astype(mx.int32)
            # Count of leading pad positions == index of the first real token.
            first = (mx.cumsum(real, axis=1) == 0).astype(mx.int32).sum(axis=1)
            # All-padding rows would give seq_len; keep the old position-0 behaviour.
            first = mx.where(first < seq_len, first, 0)
            pooled = lhs[mx.arange(lhs.shape[0]), first]
        return _l2_normalize(pooled) if normalize else pooled
class ColbertModel(nn.Module):
    """LFM2.5-ColBERT-350M: per-token linear projection of the backbone output.

    Each token is projected to ``proj_dim`` (default 128) for MaxSim late
    interaction.
    """

    def __init__(self, args: ModelArgs, proj_dim: int = 128):
        super().__init__()
        self.args = args
        self.model = Lfm2Backbone(args)
        self.dense = nn.Linear(args.hidden_size, proj_dim, bias=False)

    def __call__(self, input_ids, attention_mask=None) -> mx.array:
        """Return un-normalised per-token projections."""
        hidden = self.model(input_ids, attention_mask)
        return self.dense(hidden)

    def encode(self, input_ids, attention_mask=None, normalize: bool = True) -> mx.array:
        """Return per-token vectors, optionally L2-normalised.

        Positions where ``attention_mask`` is 0 are zeroed.
        """
        vectors = self(input_ids, attention_mask)
        if normalize:
            vectors = _l2_normalize(vectors, axis=-1)
        if attention_mask is None:
            return vectors
        return vectors * attention_mask[..., None].astype(vectors.dtype)
def sanitize(weights: dict) -> dict:
    """Convert HF depthwise conv weights (O, 1, K) to MLX Conv1d layout (O, K, 1).

    Only keys ending in ``conv.conv.weight`` are considered. Such a tensor is
    left alone when its last dim is already smaller than dim 1, i.e. it is
    already in (O, K, 1) layout. Every other entry is passed through unchanged,
    and key order is preserved.
    """

    def _convert(name, tensor):
        if not name.endswith("conv.conv.weight"):
            return tensor
        if tensor.shape[-1] < tensor.shape[1]:
            return tensor  # already (O, K, 1)
        return tensor.transpose(0, 2, 1)  # (O, 1, K) -> (O, K, 1)

    return {name: _convert(name, tensor) for name, tensor in weights.items()}