File size: 10,046 Bytes
f165d3d
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
"""MLX port of LiquidAI's LFM2.5 *bidirectional* (encoder) backbone + retrieval heads.

This is the encoder variant used by:
  - LFM2.5-Embedding-350M  (CLS pooling -> 1024-d sentence vector, cosine sim)
  - LFM2.5-ColBERT-350M    (Dense 1024->128 per-token vectors, MaxSim)

It is the LFM2.5-350M-Base hybrid backbone (short-conv + GQA attention layers,
SwiGLU MLP, RMSNorm) with three encoder patches relative to the causal LFM2:
  1. attention is bidirectional (no causal mask; pad-only mask),
  2. the short conv is non-causal / centered (symmetric padding = kernel//2),
  3. no LM head; a pooling/projection head is used instead.

Ported from mlx-lm's `models/lfm2.py` (causal) — kept dependency-free so it can
be dropped into any MLX project.
"""

from dataclasses import dataclass
from typing import List, Optional

import mlx.core as mx
import mlx.nn as nn


@dataclass
class ModelArgs:
    """Hyper-parameters of the LFM2 backbone.

    Instances are normally built with :meth:`from_dict` from a parsed config
    dict, which fills in defaults for the optional keys.
    """

    vocab_size: int  # rows of the token embedding table
    hidden_size: int  # model width
    num_hidden_layers: int  # number of stacked layers
    num_attention_heads: int  # query heads in attention layers
    num_key_value_heads: int  # key/value heads in attention layers (GQA)
    norm_eps: float  # epsilon passed to every RMSNorm
    conv_bias: bool  # bias flag for the short conv and its in/out projections
    conv_L_cache: int  # kernel size of the short conv
    block_ff_dim: int  # nominal MLP width, before the optional adjustment
    block_multiple_of: int  # adjusted MLP width is rounded up to a multiple of this
    block_ffn_dim_multiplier: float  # scale applied to the adjusted MLP width
    block_auto_adjust_ff_dim: bool  # whether the MLP width adjustment is applied
    rope_theta: float  # RoPE base frequency
    layer_types: List[str]  # per-layer kind; "full_attention" selects attention
    model_type: str = "lfm2"

    @classmethod
    def from_dict(cls, d: dict) -> "ModelArgs":
        """Build a :class:`ModelArgs` from a config dictionary.

        ``rope_theta`` is read from the top level first and then from the
        nested ``rope_parameters`` mapping; if neither provides a value
        (missing key, or an explicit ``None``/null) it defaults to 1e6.

        Args:
            d: Parsed configuration. ``vocab_size``, ``hidden_size``,
                ``num_hidden_layers``, ``num_attention_heads`` and
                ``layer_types`` are required; every other key is optional.

        Returns:
            The populated ``ModelArgs``.

        Raises:
            KeyError: If one of the required keys is missing.
        """
        theta = d.get("rope_theta")
        if theta is None:
            # `rope_parameters` may be absent *or* present as an explicit null;
            # `or {}` covers both so the nested lookup never hits None.
            rope_params = d.get("rope_parameters") or {}
            theta = rope_params.get("rope_theta")
        if theta is None:
            theta = 1000000.0
        return cls(
            vocab_size=d["vocab_size"],
            hidden_size=d["hidden_size"],
            num_hidden_layers=d["num_hidden_layers"],
            num_attention_heads=d["num_attention_heads"],
            # no explicit KV head count -> plain multi-head attention
            num_key_value_heads=d.get("num_key_value_heads", d["num_attention_heads"]),
            norm_eps=d.get("norm_eps", d.get("block_norm_eps", 1e-5)),
            conv_bias=d.get("conv_bias", False),
            conv_L_cache=d.get("conv_L_cache", 3),
            # stays None when neither key is present
            block_ff_dim=d.get("block_ff_dim", d.get("intermediate_size")),
            block_multiple_of=d.get("block_multiple_of", 256),
            block_ffn_dim_multiplier=d.get("block_ffn_dim_multiplier", 1.0),
            block_auto_adjust_ff_dim=d.get("block_auto_adjust_ff_dim", True),
            rope_theta=theta,
            layer_types=d["layer_types"],
            model_type=d.get("model_type", "lfm2"),
        )

    @property
    def attn_layer_idxs(self) -> List[int]:
        """Indices of the layers whose type is ``"full_attention"``."""
        return [i for i, t in enumerate(self.layer_types) if t == "full_attention"]


class Attention(nn.Module):
    """Grouped-query attention without a causal mask.

    Queries and keys are RMS-normalised per head and rotated with RoPE before
    scaled dot-product attention. Only the mask handed in by the caller is
    applied; nothing causal is constructed here.
    """

    def __init__(self, args: ModelArgs):
        super().__init__()
        hidden = args.hidden_size
        self.n_heads = args.num_attention_heads
        self.n_kv_heads = args.num_key_value_heads
        self.head_dim = hidden // self.n_heads
        self.scale = self.head_dim**-0.5

        q_width = self.n_heads * self.head_dim
        kv_width = self.n_kv_heads * self.head_dim
        self.q_proj = nn.Linear(hidden, q_width, bias=False)
        self.k_proj = nn.Linear(hidden, kv_width, bias=False)
        self.v_proj = nn.Linear(hidden, kv_width, bias=False)
        self.out_proj = nn.Linear(q_width, hidden, bias=False)

        self.q_layernorm = nn.RMSNorm(self.head_dim, eps=args.norm_eps)
        self.k_layernorm = nn.RMSNorm(self.head_dim, eps=args.norm_eps)
        self.rope = nn.RoPE(self.head_dim, base=args.rope_theta, traditional=False)

    def __call__(self, x: mx.array, mask: Optional[mx.array] = None) -> mx.array:
        """Attend over ``x`` (batch, length, hidden) and return the same shape.

        ``mask`` is forwarded unchanged to the fused attention kernel.
        """
        batch, seq_len, _ = x.shape

        def to_heads(t: mx.array, n_heads: int) -> mx.array:
            # (batch, length, n_heads * head_dim) -> (batch, length, n_heads, head_dim)
            return t.reshape(batch, seq_len, n_heads, -1)

        # q/k are normalised over the per-head feature axis; v is left as is.
        queries = self.q_layernorm(to_heads(self.q_proj(x), self.n_heads))
        keys = self.k_layernorm(to_heads(self.k_proj(x), self.n_kv_heads))
        values = to_heads(self.v_proj(x), self.n_kv_heads)

        # Move heads in front of the sequence axis, then rotate q and k.
        queries = self.rope(queries.transpose(0, 2, 1, 3))
        keys = self.rope(keys.transpose(0, 2, 1, 3))
        values = values.transpose(0, 2, 1, 3)

        attended = mx.fast.scaled_dot_product_attention(
            queries, keys, values, scale=self.scale, mask=mask
        )
        # Back to (batch, length, heads * head_dim) for the output projection.
        merged = attended.transpose(0, 2, 1, 3).reshape(batch, seq_len, -1)
        return self.out_proj(merged)


class ShortConv(nn.Module):
    """Gated short convolution with symmetric (centered) padding.

    The input is projected to three equal chunks: an input gate, an output
    gate and the values. ``input_gate * values`` goes through a per-channel
    1-D convolution and the result is gated again before the output
    projection.
    """

    def __init__(self, args: ModelArgs):
        super().__init__()
        self.L_cache = args.conv_L_cache
        use_bias = args.conv_bias
        width = args.hidden_size
        # groups == channels: every channel has its own kernel.
        # padding = kernel // 2 on both sides centers the window (non-causal).
        self.conv = nn.Conv1d(
            in_channels=width,
            out_channels=width,
            kernel_size=self.L_cache,
            groups=width,
            padding=self.L_cache // 2,
            bias=use_bias,
        )
        self.in_proj = nn.Linear(width, 3 * width, bias=use_bias)
        self.out_proj = nn.Linear(width, width, bias=use_bias)

    def __call__(self, x: mx.array, keep: Optional[mx.array] = None) -> mx.array:
        """Apply the gated convolution.

        ``keep``, when given, is multiplied into the conv input (broadcast
        over the channel axis) so that zeroed positions contribute nothing to
        their neighbours.
        """
        in_gate, out_gate, values = mx.split(self.in_proj(x), 3, axis=-1)
        gated = in_gate * values
        if keep is not None:
            gated = gated * keep[..., None]

        mixed = self.conv(gated)
        # An even kernel with this padding produces one extra step; trim back
        # to the input length. Odd kernels already match.
        seq_len = gated.shape[1]
        if mixed.shape[1] != seq_len:
            mixed = mixed[:, :seq_len, :]
        return self.out_proj(out_gate * mixed)


class MLP(nn.Module):
    """Bias-free gated feed-forward block: ``w2(silu(w1(x)) * w3(x))``."""

    def __init__(self, args: ModelArgs):
        super().__init__()
        inner = args.block_ff_dim
        if args.block_auto_adjust_ff_dim:
            # Shrink to 2/3, optionally rescale, then round up to the next
            # multiple of `block_multiple_of`.
            inner = int(2 * inner / 3)
            multiplier = args.block_ffn_dim_multiplier
            if multiplier is not None:
                inner = int(multiplier * inner)
            step = args.block_multiple_of
            inner = step * ((inner + step - 1) // step)

        width = args.hidden_size
        self.w1 = nn.Linear(width, inner, bias=False)
        self.w3 = nn.Linear(width, inner, bias=False)
        self.w2 = nn.Linear(inner, width, bias=False)

    def __call__(self, x: mx.array) -> mx.array:
        """Return the feed-forward output for ``x``."""
        gate = nn.silu(self.w1(x))
        up = self.w3(x)
        return self.w2(gate * up)


class DecoderLayer(nn.Module):
    """One pre-norm residual layer: a token mixer followed by an MLP.

    The mixer is ``Attention`` when ``layer_idx`` is listed in
    ``args.attn_layer_idxs`` and ``ShortConv`` otherwise.
    """

    def __init__(self, args: ModelArgs, layer_idx: int):
        super().__init__()
        self.is_attention = layer_idx in args.attn_layer_idxs
        if self.is_attention:
            self.self_attn = Attention(args)
        else:
            self.conv = ShortConv(args)
        self.feed_forward = MLP(args)
        self.operator_norm = nn.RMSNorm(args.hidden_size, eps=args.norm_eps)
        self.ffn_norm = nn.RMSNorm(args.hidden_size, eps=args.norm_eps)

    def __call__(self, x, attn_mask=None, keep=None):
        """Run the layer.

        ``attn_mask`` is only consumed by attention layers and ``keep`` only
        by conv layers; the other argument is ignored.
        """
        normed = self.operator_norm(x)
        if self.is_attention:
            mixed = self.self_attn(normed, mask=attn_mask)
        else:
            mixed = self.conv(normed, keep=keep)

        hidden = x + mixed
        ffn_out = self.feed_forward(self.ffn_norm(hidden))
        return hidden + ffn_out


class Lfm2Backbone(nn.Module):
    """Map token ids to the final hidden states (after ``embedding_norm``)."""

    def __init__(self, args: ModelArgs):
        super().__init__()
        self.args = args
        self.embed_tokens = nn.Embedding(args.vocab_size, args.hidden_size)
        self.layers = [DecoderLayer(args, i) for i in range(args.num_hidden_layers)]
        self.embedding_norm = nn.RMSNorm(args.hidden_size, eps=args.norm_eps)

    def __call__(self, input_ids: mx.array, attention_mask: Optional[mx.array] = None) -> mx.array:
        """Encode ``input_ids``.

        When ``attention_mask`` is given (positive = real token, otherwise
        padding) it is turned into two things: a multiplicative ``keep`` mask
        for the conv layers and an additive key mask for the attention layers.
        Without it, no masking is applied at all.
        """
        hidden = self.embed_tokens(input_ids)
        dtype = hidden.dtype

        attn_mask = None
        keep = None
        if attention_mask is not None:
            # (B, L) multiplicative mask in the activation dtype.
            keep = attention_mask.astype(dtype)
            # (B, 1, 1, L) additive mask: 0 on real keys, large negative on
            # padded keys; broadcasts over heads and query positions.
            # NOTE(review): -1e9 is outside the float16 range; confirm the
            # resulting fill value is acceptable if run in float16.
            is_real = attention_mask[:, None, None, :] > 0
            masked_fill = mx.array(-1e9, dtype=dtype)
            attn_mask = mx.where(is_real, mx.array(0, dtype), masked_fill)

        for block in self.layers:
            hidden = block(hidden, attn_mask=attn_mask, keep=keep)
        return self.embedding_norm(hidden)


def _l2_normalize(x: mx.array, axis: int = -1, eps: float = 1e-12) -> mx.array:
    """Scale ``x`` to unit L2 norm along ``axis``.

    The norm is clamped from below by ``eps`` before dividing.
    """
    norm = mx.linalg.norm(x, axis=axis, keepdims=True)
    safe_norm = mx.maximum(norm, eps)
    return x / safe_norm


class EmbeddingModel(nn.Module):
    """LFM2.5-Embedding-350M: sentence vector taken from the first token."""

    pooling = "cls"

    def __init__(self, args: ModelArgs):
        super().__init__()
        self.args = args
        self.model = Lfm2Backbone(args)

    def __call__(self, input_ids, attention_mask=None) -> mx.array:
        """Return the backbone's per-token hidden states."""
        hidden_states = self.model(input_ids, attention_mask)
        return hidden_states

    def encode(self, input_ids, attention_mask=None, normalize: bool = True) -> mx.array:
        """Return one vector per sequence, L2-normalised unless ``normalize`` is False."""
        hidden_states = self.model(input_ids, attention_mask)
        # Position 0 is used as the CLS slot; this assumes the tokenizer
        # prepends BOS (add_bos_token=True).
        cls_vector = hidden_states[:, 0, :]
        if normalize:
            return _l2_normalize(cls_vector)
        return cls_vector


class ColbertModel(nn.Module):
    """LFM2.5-ColBERT-350M: bias-free linear projection of every token state."""

    def __init__(self, args: ModelArgs, proj_dim: int = 128):
        super().__init__()
        self.args = args
        self.model = Lfm2Backbone(args)
        self.dense = nn.Linear(args.hidden_size, proj_dim, bias=False)

    def __call__(self, input_ids, attention_mask=None) -> mx.array:
        """Return the projected per-token vectors, unnormalised and unmasked."""
        hidden_states = self.model(input_ids, attention_mask)
        return self.dense(hidden_states)

    def encode(self, input_ids, attention_mask=None, normalize: bool = True) -> mx.array:
        """Return per-token vectors of shape (B, L, proj_dim).

        Vectors are L2-normalised along the last axis when ``normalize`` is
        true. If ``attention_mask`` is given, it is multiplied in afterwards,
        which zeroes the vectors at masked-out positions.
        """
        hidden_states = self.model(input_ids, attention_mask)
        token_vecs = self.dense(hidden_states)
        if normalize:
            token_vecs = _l2_normalize(token_vecs, axis=-1)
        if attention_mask is None:
            return token_vecs
        pad_gate = attention_mask[..., None].astype(token_vecs.dtype)
        return token_vecs * pad_gate


def sanitize(weights: dict) -> dict:
    """Return a copy of ``weights`` with conv kernels in MLX Conv1d layout.

    Entries whose name ends in ``conv.conv.weight`` are transposed from
    (O,1,K) to (O,K,1) unless their last axis is already shorter than axis 1,
    which is taken to mean they are in (O,K,1) form. All other entries are
    passed through untouched.
    """

    def _needs_transpose(name: str, tensor) -> bool:
        # Short-circuits on the name so non-conv entries are never inspected.
        return name.endswith("conv.conv.weight") and tensor.shape[-1] >= tensor.shape[1]

    return {
        name: tensor.transpose(0, 2, 1) if _needs_transpose(name, tensor) else tensor
        for name, tensor in weights.items()
    }