Text Generation
Transformers
Safetensors
PyTorch
nvidia
two-tower
diffusion
mamba
fitsumreda commited on
Commit
947a10f
·
verified ·
1 Parent(s): fbdc5a6

Upload folder using huggingface_hub

Browse files
.gitattributes CHANGED
@@ -33,3 +33,4 @@ saved_model/**/* filter=lfs diff=lfs merge=lfs -text
33
  *.zip filter=lfs diff=lfs merge=lfs -text
34
  *.zst filter=lfs diff=lfs merge=lfs -text
35
  *tfevents* filter=lfs diff=lfs merge=lfs -text
 
 
33
  *.zip filter=lfs diff=lfs merge=lfs -text
34
  *.zst filter=lfs diff=lfs merge=lfs -text
35
  *tfevents* filter=lfs diff=lfs merge=lfs -text
36
+ tokenizer.json filter=lfs diff=lfs merge=lfs -text
config.json ADDED
@@ -0,0 +1,68 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "architectures": [
3
+ "NemotronHTwoTowerForCausalLM"
4
+ ],
5
+ "attention_bias": false,
6
+ "attention_dropout": 0.0,
7
+ "auto_map": {
8
+ "AutoConfig": "configuration_nemotron_h.NemotronHConfig",
9
+ "AutoModelForCausalLM": "modeling_nemotron_twotower.NemotronHTwoTowerForCausalLM"
10
+ },
11
+ "bos_token_id": 1,
12
+ "chunk_size": 128,
13
+ "conv_kernel": 4,
14
+ "dtype": "bfloat16",
15
+ "eos_token_id": 2,
16
+ "expand": 2,
17
+ "head_dim": 128,
18
+ "hidden_dropout": 0.0,
19
+ "hidden_size": 2688,
20
+ "hybrid_override_pattern": "MEMEM*EMEMEM*EMEMEM*EMEMEM*EMEMEM*EMEMEMEM*EMEMEMEME",
21
+ "initializer_range": 0.02,
22
+ "intermediate_size": 1856,
23
+ "layer_norm_epsilon": 1e-05,
24
+ "mamba_head_dim": 64,
25
+ "mamba_hidden_act": "silu",
26
+ "mamba_num_heads": 64,
27
+ "mamba_proj_bias": false,
28
+ "max_position_embeddings": 262144,
29
+ "mlp_bias": false,
30
+ "mlp_hidden_act": "relu2",
31
+ "model_type": "nemotron_h",
32
+ "moe_intermediate_size": 1856,
33
+ "moe_shared_expert_intermediate_size": 3712,
34
+ "n_group": 1,
35
+ "n_groups": 8,
36
+ "n_routed_experts": 128,
37
+ "n_shared_experts": 1,
38
+ "norm_eps": 1e-05,
39
+ "norm_topk_prob": true,
40
+ "num_attention_heads": 32,
41
+ "num_experts_per_tok": 6,
42
+ "num_hidden_layers": 52,
43
+ "num_key_value_heads": 2,
44
+ "num_logits_to_keep": 1,
45
+ "pad_token_id": 0,
46
+ "partial_rotary_factor": 1.0,
47
+ "rescale_prenorm_residual": true,
48
+ "residual_in_fp32": false,
49
+ "rope_theta": 10000,
50
+ "routed_scaling_factor": 2.5,
51
+ "sliding_window": null,
52
+ "ssm_state_size": 128,
53
+ "tie_word_embeddings": false,
54
+ "time_step_floor": 0.0001,
55
+ "time_step_limit": [
56
+ 0.0,
57
+ "Infinity"
58
+ ],
59
+ "time_step_max": 0.1,
60
+ "time_step_min": 0.001,
61
+ "topk_group": 1,
62
+ "transformers_version": "4.57.1",
63
+ "use_bias": false,
64
+ "use_cache": true,
65
+ "use_conv_bias": true,
66
+ "use_mamba_kernels": true,
67
+ "vocab_size": 131072
68
+ }
configuration_nemotron_h.py ADDED
@@ -0,0 +1,262 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # coding=utf-8
2
+ # Copyright 2024 AI21 Labs Ltd. and the HuggingFace Inc. team. All rights reserved.
3
+ # Copyright (c) 2025, NVIDIA CORPORATION. All rights reserved.
4
+ #
5
+ # Licensed under the Apache License, Version 2.0 (the "License");
6
+ # you may not use this file except in compliance with the License.
7
+ # You may obtain a copy of the License at
8
+ #
9
+ # http://www.apache.org/licenses/LICENSE-2.0
10
+ #
11
+ # Unless required by applicable law or agreed to in writing, software
12
+ # distributed under the License is distributed on an "AS IS" BASIS,
13
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14
+ # See the License for the specific language governing permissions and
15
+ # limitations under the License.
16
+ """NemotronH model configuration"""
17
+
18
+ import re
19
+
20
+ from transformers.configuration_utils import PretrainedConfig
21
+ from transformers.utils import logging
22
+
23
+
24
+ logger = logging.get_logger(__name__)
25
+
26
+
27
+ class NemotronHConfig(PretrainedConfig):
28
+ r"""
29
+ This is the configuration class to store the configuration of a [`NemotronHModel`]. It is used to instantiate a
30
+ NemotronH model according to the specified arguments, defining the model architecture. Instantiating a configuration
31
+ with the defaults will yield a similar configuration to that of the NemotronH-v0.1 model.
32
+
33
+ [todo](todo)
34
+
35
+ Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs. Read the
36
+ documentation from [`PretrainedConfig`] for more information.
37
+
38
+
39
+ Args:
40
+ vocab_size (`int`, *optional*, defaults to 131072):
41
+ Vocabulary size of the NemotronH model. Defines the number of different tokens that can be represented by the
42
+ `inputs_ids` passed when calling [`NemotronHModel`]
43
+ tie_word_embeddings (`bool`, *optional*, defaults to `False`):
44
+ Whether the model's input and output word embeddings should be tied. Note that this is only relevant if the
45
+ model has a output word embedding layer.
46
+ hidden_size (`int`, *optional*, defaults to 4096):
47
+ Dimension of the hidden representations.
48
+ intermediate_size (`int`, *optional*, defaults to 21504):
49
+ Dimension of the MLP representations.
50
+ num_hidden_layers (`int`, *optional*, defaults to 52):
51
+ Number of hidden layers in the Transformer encoder.
52
+ hybrid_override_pattern (`str`, *optional*, defaults to `"M-M-M-M*-M-M-M-M-M*-M-M-M-M-M*-M-M-M-M-M*-M-M-M-M-M-"`):
53
+ The pattern of the hybrid model. The pattern is a string of characters where each character represents M: Mamba2, *: Attention, -: MLP
54
+ num_attention_heads (`int`, *optional*, defaults to 32):
55
+ Number of attention heads for each attention layer in the Transformer encoder.
56
+ head_dim (`int`, *optional*, defaults to 128):
57
+ Dimension of each attention head.
58
+ num_key_value_heads (`int`, *optional*, defaults to 8):
59
+ This is the number of key_value heads that should be used to implement Grouped Query Attention. If
60
+ `num_key_value_heads=num_attention_heads`, the model will use Multi Head Attention (MHA), if
61
+ `num_key_value_heads=1` the model will use Multi Query Attention (MQA) otherwise GQA is used.
62
+ mlp_hidden_act (`str`, *optional*, defaults to "relu2"):
63
+ The non-linear activation function in the MLP layers.
64
+ attention_bias (`bool`, *optional*, defaults to `False`):
65
+ Whether to use bias in attention layers.
66
+ mlp_bias (`bool`, *optional*, defaults to `False`):
67
+ Whether to use bias in MLP layers.
68
+ use_bias (`bool`, *optional*, defaults to `False`):
69
+ Whether to use bias in the model.
70
+ initializer_range (`float`, *optional*, defaults to 0.02):
71
+ The standard deviation of the truncated_normal_initializer for initializing all weight matrices.
72
+ layer_norm_epsilon (`float`, *optional*, defaults to 1e-5):
73
+ The epsilon used by the layer normalization layers.
74
+ residual_in_fp32 (`bool`, *optional*, defaults to `False`):
75
+ Whether or not residuals should be in `float32`. If set to `False` residuals will keep the same `dtype` as the rest of the model.
76
+ use_cache (`bool`, *optional*, defaults to `True`):
77
+ Whether or not the model should return the last key/values attentions (not used by all models). Only
78
+ relevant if `config.is_decoder=True`.
79
+ num_logits_to_keep (`int` or `None`, *optional*, defaults to 1):
80
+ Number of prompt logits to calculate during generation. If `None`, all logits will be calculated. If an
81
+ integer value, only last `num_logits_to_keep` logits will be calculated.
82
+ pad_token_id (`int`, *optional*, defaults to 0):
83
+ The id of the padding token.
84
+ bos_token_id (`int`, *optional*, defaults to 1):
85
+ The id of the "beginning-of-sequence" token.
86
+ eos_token_id (`int`, *optional*, defaults to 2):
87
+ The id of the "end-of-sequence" token.
88
+ sliding_window (`int`, *optional*, defaults to None):
89
+ Sliding window attention window size.
90
+ max_position_embeddings (`int`, *optional*, defaults to 4096):
91
+ The maximum sequence length that this model might ever be used with.
92
+ attention_dropout (`float`, *optional*, defaults to 0.0):
93
+ The dropout ratio for the attention probabilities.
94
+ hidden_dropout (`float`, *optional*, defaults to 0.0):
95
+ The dropout ratio for the hidden states.
96
+ use_mamba_kernels (`bool`, *optional*, defaults to `True`):
97
+ Flag indicating whether or not to use the fast mamba kernels. These are available only if `mamba-ssm` and
98
+ `causal-conv1d` are installed, and the mamba modules are running on a CUDA device.
99
+ ssm_state_size (`int`, *optional*, defaults to 128):
100
+ The dimension of the mamba state space latents.
101
+ mamba_num_heads (`int`, *optional*, defaults to 128):
102
+ Number of heads in Mamba layers.
103
+ mamba_n_groups (`int`, *optional*, defaults to 8):
104
+ Number of groups in Mamba layers.
105
+ mamba_head_dim (`int`, *optional*, defaults to 64):
106
+ Dimension of each Mamba head.
107
+ mamba_d_conv (`int`, *optional*, defaults to 4):
108
+ The size of the mamba convolution kernel.
109
+ mamba_expand (`int`, *optional*, defaults to 2):
110
+ Expanding factor used to determine the mamba intermediate size.
111
+ mamba_hidden_act (`str`, *optional*, defaults to "silu"):
112
+ The non-linear activation function in the Mamba layers.
113
+ mamba_dt_min (`float`, *optional*, defaults to 0.001):
114
+ Minimum value for the time step in Mamba.
115
+ mamba_dt_max (`float`, *optional*, defaults to 0.1):
116
+ Maximum value for the time step in Mamba.
117
+ mamba_dt_limit (`tuple`, *optional*, defaults to (0.0, float("inf"))):
118
+ Limits for the time step in Mamba.
119
+ mamba_dt_init_floor (`float`, *optional*, defaults to 1e-4):
120
+ Floor value for time step initialization in Mamba.
121
+ mamba_conv_bias (`bool`, *optional*, defaults to `True`):
122
+ Whether to use bias in the convolution layer of the mamba mixer block.
123
+ mamba_proj_bias (`bool`, *optional*, defaults to `False`):
124
+ Whether to use bias in the input and output projections of the mamba mixer block.
125
+ mamba_chunk_size (`int`, *optional*, defaults to 256):
126
+ Size of chunks for Mamba processing.
127
+ rescale_prenorm_residual (`bool`, *optional*, defaults to `True`):
128
+ Whether to rescale the pre-normalization residual connections.
129
+ """
130
+
131
+ model_type = "nemotron_h"
132
+ keys_to_ignore_at_inference = ["past_key_values"]
133
+
134
+ def __init__(
135
+ self,
136
+ vocab_size=131072,
137
+ tie_word_embeddings=False,
138
+ hidden_size=4096,
139
+ intermediate_size=21504,
140
+ num_hidden_layers=52,
141
+ hybrid_override_pattern="M-M-M-M*-M-M-M-M-M*-M-M-M-M-M*-M-M-M-M-M*-M-M-M-M-M-",
142
+ num_attention_heads=32,
143
+ head_dim=128,
144
+ num_key_value_heads=8, # nemo: num_query_groups
145
+ mlp_hidden_act="relu2",
146
+ attention_bias=False,
147
+ mlp_bias=False,
148
+ use_bias=False,
149
+ initializer_range=0.02, # nemo: init_method_std
150
+ layer_norm_epsilon=1e-5, # nemo: layernorm_epsilon
151
+ residual_in_fp32=False, # Megatron Core default value
152
+ use_cache=True,
153
+ num_logits_to_keep=1,
154
+ pad_token_id=0,
155
+ bos_token_id=1,
156
+ eos_token_id=2,
157
+ sliding_window=None,
158
+ max_position_embeddings=4096,
159
+ attention_dropout=0.0,
160
+ hidden_dropout=0.0, # * ADDED
161
+ use_mamba_kernels=True,
162
+ ssm_state_size=128, # mamba_state_size
163
+ mamba_num_heads=128,
164
+ mamba_n_groups=8, # nemo: mamba_ssm_ngroups = num_heads
165
+ mamba_head_dim=64,
166
+ mamba_d_conv=4,
167
+ mamba_expand=2,
168
+ mamba_hidden_act="silu",
169
+ mamba_dt_min=0.001,
170
+ mamba_dt_max=0.1,
171
+ mamba_dt_limit=(0.0, float("inf")),
172
+ mamba_dt_init_floor=1e-4,
173
+ mamba_conv_bias=True,
174
+ mamba_proj_bias=False,
175
+ mamba_chunk_size=128,
176
+ rescale_prenorm_residual=True,
177
+ n_routed_experts=8,
178
+ n_shared_experts=1,
179
+ moe_intermediate_size=7688,
180
+ moe_shared_expert_intermediate_size=7688,
181
+ num_experts_per_tok=2,
182
+ routed_scaling_factor=1.0,
183
+ n_group=1,
184
+ topk_group=1,
185
+ norm_topk_prob=True,
186
+ **kwargs,
187
+ ):
188
+ self.vocab_size = vocab_size
189
+ self.tie_word_embeddings = tie_word_embeddings
190
+ self.hidden_size = hidden_size
191
+ self.intermediate_size = intermediate_size
192
+ self.num_hidden_layers = num_hidden_layers
193
+ self.hybrid_override_pattern = hybrid_override_pattern
194
+ self.num_attention_heads = num_attention_heads
195
+ self.head_dim = head_dim
196
+ self.sliding_window = sliding_window
197
+ self.max_position_embeddings = max_position_embeddings
198
+ self.attention_dropout = attention_dropout
199
+ self.hidden_dropout = hidden_dropout
200
+
201
+ # Validate hybrid_override_pattern
202
+ # M: Mamba2, *: Attention, -: MLP
203
+ assert len(self.hybrid_override_pattern) == self.num_hidden_layers, "hybrid_override_pattern must have the same length as num_hidden_layers"
204
+ assert re.match(r"^[*-M]+$", self.hybrid_override_pattern), "hybrid_override_pattern must only contain characters 'M', '*', or '-'"
205
+
206
+ # for backward compatibility
207
+ if num_key_value_heads is None:
208
+ num_key_value_heads = num_attention_heads
209
+
210
+ self.num_key_value_heads = num_key_value_heads
211
+ self.mlp_hidden_act = mlp_hidden_act
212
+ self.attention_bias = attention_bias
213
+ self.mlp_bias = mlp_bias
214
+ self.use_bias = use_bias
215
+ self.initializer_range = initializer_range
216
+ self.layer_norm_epsilon = layer_norm_epsilon
217
+ self.residual_in_fp32 = residual_in_fp32
218
+
219
+ self.use_cache = use_cache
220
+ self.num_logits_to_keep = num_logits_to_keep
221
+
222
+ self.use_mamba_kernels = use_mamba_kernels
223
+ self.n_groups = mamba_n_groups
224
+ self.mamba_head_dim = mamba_head_dim
225
+ self.ssm_state_size = ssm_state_size
226
+ self.mamba_num_heads = mamba_num_heads
227
+ self.conv_kernel = mamba_d_conv
228
+ self.expand = mamba_expand
229
+ self.mamba_hidden_act = mamba_hidden_act
230
+ self.time_step_min = mamba_dt_min
231
+ self.time_step_max = mamba_dt_max
232
+ self.time_step_limit = mamba_dt_limit
233
+ self.time_step_floor = mamba_dt_init_floor
234
+ self.use_conv_bias = mamba_conv_bias
235
+ self.mamba_proj_bias = mamba_proj_bias
236
+ self.chunk_size = mamba_chunk_size
237
+ self.rescale_prenorm_residual = rescale_prenorm_residual
238
+ self.n_routed_experts = n_routed_experts
239
+ self.n_shared_experts = n_shared_experts
240
+ self.moe_intermediate_size = moe_intermediate_size
241
+ self.moe_shared_expert_intermediate_size = moe_shared_expert_intermediate_size
242
+ self.num_experts_per_tok = num_experts_per_tok
243
+ self.routed_scaling_factor = routed_scaling_factor
244
+ self.n_group = n_group
245
+ self.topk_group = topk_group
246
+ self.norm_topk_prob = norm_topk_prob
247
+
248
+ super().__init__(
249
+ pad_token_id=pad_token_id,
250
+ bos_token_id=bos_token_id,
251
+ eos_token_id=eos_token_id,
252
+ tie_word_embeddings=tie_word_embeddings,
253
+ **kwargs,
254
+ )
255
+
256
+ @property
257
+ def layers_block_type(self):
258
+ return [
259
+ "mamba" if self.hybrid_override_pattern[i] == "M" else
260
+ "attention" if self.hybrid_override_pattern[i] == "*" else
261
+ "mlp" if self.hybrid_override_pattern[i] == "-" else "moe"
262
+ for i in range(self.num_hidden_layers)]
generation_config.json ADDED
@@ -0,0 +1,10 @@
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "_from_model_config": true,
3
+ "bos_token_id": 1,
4
+ "eos_token_id": [
5
+ 2,
6
+ 11
7
+ ],
8
+ "pad_token_id": 0,
9
+ "transformers_version": "4.57.1"
10
+ }
inference.py ADDED
@@ -0,0 +1,53 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #!/usr/bin/env python3
2
+ """
3
+ Two-tower NemotronH inference example.
4
+
5
+ Requires 2 GPUs (118GB total) for full two-tower inference.
6
+ Single GPU works for AR-only mode (context tower only, ~59GB).
7
+
8
+ Usage:
9
+ # Mock-AR (two-tower, 2 GPUs):
10
+ CUDA_VISIBLE_DEVICES=0,1 python inference.py --mode mock_ar
11
+
12
+ # AR (context tower only, 1 GPU):
13
+ python inference.py --mode ar
14
+ """
15
+ import argparse
16
+ import torch
17
+ from pathlib import Path
18
+ from transformers import AutoTokenizer
19
+ from modeling_nemotron_twotower import NemotronHTwoTowerForCausalLM
20
+
21
+ parser = argparse.ArgumentParser()
22
+ parser.add_argument("--prompt", default="France is a country ")
23
+ parser.add_argument("--model", default=str(Path(__file__).resolve().parent))
24
+ parser.add_argument("--max-new-tokens", type=int, default=128)
25
+ parser.add_argument("--mode", choices=["ar", "mock_ar"], default="mock_ar")
26
+ args = parser.parse_args()
27
+
28
+ tokenizer = AutoTokenizer.from_pretrained(args.model)
29
+ model = NemotronHTwoTowerForCausalLM.from_pretrained(
30
+ args.model, torch_dtype=torch.bfloat16, trust_remote_code=True,
31
+ )
32
+
33
+ num_gpus = torch.cuda.device_count()
34
+ if args.mode == "mock_ar" and num_gpus >= 2:
35
+ model.place_towers_on_devices("cuda:0", "cuda:1")
36
+ else:
37
+ model.cuda()
38
+
39
+ model.eval()
40
+ inputs = tokenizer(args.prompt, return_tensors="pt").to(
41
+ next(model.context_tower.parameters()).device
42
+ )
43
+
44
+ if args.mode == "ar":
45
+ outputs = model.generate(**inputs, max_new_tokens=args.max_new_tokens, do_sample=False)
46
+ else:
47
+ outputs = model.generate_mock_ar(
48
+ inputs["input_ids"], max_new_tokens=args.max_new_tokens,
49
+ temperature=0.0, eos_token_id=tokenizer.eos_token_id,
50
+ )
51
+
52
+ text = tokenizer.decode(outputs[0][inputs["input_ids"].shape[1]:], skip_special_tokens=True)
53
+ print(text)
model-00001-of-00024.safetensors ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:eff060c55c2f20867acb5d412e63feec6175340227b9b38270ba816334333594
3
+ size 5365738984
model-00002-of-00024.safetensors ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:a582b4e20845bfbeeb8b80c91624e7f259f65341d97c65ee13622c54764b4071
3
+ size 5350671792
model-00003-of-00024.safetensors ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:cbc86f6550410286924d76898c2c13b395832d2bc3b22eaea1aa4bd3e2e5d382
3
+ size 5363394184
model-00004-of-00024.safetensors ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:631cd0694e725599fb45edf52a330a18e8ffc643643cf6bdac4711372c48b515
3
+ size 5363727128
model-00005-of-00024.safetensors ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:f07b85ed9b03e29beda3afe00a1c479149f4b21f02ecbc7c11456d0641047a88
3
+ size 5360635856
model-00006-of-00024.safetensors ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:fc40b2c8a075dc35e94ecf3cfd47eb94f74ef133bd748cae4b35363c1892aede
3
+ size 5360635856
model-00007-of-00024.safetensors ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:fc3ffecd9f6740c5e0bc746234681cba2db785b4f662fd5614058614b0bf5d6b
3
+ size 5363727128
model-00008-of-00024.safetensors ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:08aea4ff8ce89227dac514fdd4ae56ad9e3d172d8b1a5fd562f45bfae21da9be
3
+ size 5360635856
model-00009-of-00024.safetensors ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:75b7c2051ffd617e36bb6de6d76c7fa0604c1bcead866285e5d717f3ce8a5b25
3
+ size 5363727128
model-00010-of-00024.safetensors ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:07a70b80f104b580b352bc5cf78d3e91cb7fc3eb007d00575ce101db3c4bce01
3
+ size 5360635856
model-00011-of-00024.safetensors ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:0968dae02686d5cbbbe9974cf7a6716a8d467458b844032e9b43aa063c288da7
3
+ size 5363727128
model-00012-of-00024.safetensors ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:540f96c3e2c9efb09663b82e43849676f9152437baa382086367be8a0f3123e1
3
+ size 5365759168
model-00013-of-00024.safetensors ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:292e20177704694fd461f3cc598d197589ab59c7ca2bd3de64044257c256a762
3
+ size 5367213064
model-00014-of-00024.safetensors ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:01671ee86d60241f4ce522201c91ff41c45378912594fed787656e5993c423f3
3
+ size 5350657752
model-00015-of-00024.safetensors ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:b80332393f7eb7c3f9b2f1a090c6c14d60489fb7f35dd6752a2a8f383e3f9012
3
+ size 5363409592
model-00016-of-00024.safetensors ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:662d21b42e92e1441a9860db8bf4b5ea8e4e08278c2fc1dafbbe390dcef23863
3
+ size 5360636392
model-00017-of-00024.safetensors ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:12242225b337d2022fde58ecde7357bc6ba027290d059d492916da284359fd45
3
+ size 5363727680
model-00018-of-00024.safetensors ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:2ac8a06a4e94e644084e925244d78cc92b99c348e089f5a66009443ed04463e5
3
+ size 5360636400
model-00019-of-00024.safetensors ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:79456c05a8afdf71041f6013cac5b32ed168ab4c3398fa4c49306582d02f13d4
3
+ size 5360636400
model-00020-of-00024.safetensors ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:91b0ed65cc415d6c04eeef687a2905c6310cf792290e58ab1427f5043853f6bf
3
+ size 5363727680
model-00021-of-00024.safetensors ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:1fde78a6d35ff5439273c6c20714bc9770759dd6d43827f1fcfe50bc09248278
3
+ size 5360636400
model-00022-of-00024.safetensors ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:8aec2046033d13504c32c8e06d0f51441b10e07f38bf05a782ae296e00920042
3
+ size 5363727680
model-00023-of-00024.safetensors ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:5ff4d9a4669cf0ae8efaace49a6adba5eea09fd0121d06b8bd806e6d959baf64
3
+ size 5363727680
model-00024-of-00024.safetensors ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:aaff6c43ac9215f8771656a92b4e5bf4469415eb1666520bd2740b037b17c196
3
+ size 2991704808
model.safetensors.index.json ADDED
The diff for this file is too large to render. See raw diff
 
modeling_nemotron_h.py ADDED
@@ -0,0 +1,1739 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # coding=utf-8
2
+ # Copyright 2024 HuggingFace Inc. team.
3
+ # Copyright (c) 2025, NVIDIA CORPORATION. All rights reserved.
4
+ #
5
+ # Licensed under the Apache License, Version 2.0 (the "License");
6
+ # you may not use this file except in compliance with the License.
7
+ # You may obtain a copy of the License at
8
+ #
9
+ # http://www.apache.org/licenses/LICENSE-2.0
10
+ #
11
+ # Unless required by applicable law or agreed to in writing, software
12
+ # distributed under the License is distributed on an "AS IS" BASIS,
13
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14
+ # See the License for the specific language governing permissions and
15
+ # limitations under the License.
16
+ """PyTorch NemotronH model."""
17
+
18
+ import math
19
+ from dataclasses import dataclass
20
+ from typing import Any, Dict, Optional, Tuple, Union
21
+
22
+ import torch
23
+ import torch.utils.checkpoint
24
+ from torch import nn
25
+ from torch.nn import CrossEntropyLoss
26
+ import torch.nn.functional as F
27
+
28
+ from transformers.activations import ACT2FN
29
+ from transformers.cache_utils import DynamicCache # we need __iter__ and __len__ of pkv
30
+ from transformers.generation import GenerationMixin
31
+ from transformers.modeling_attn_mask_utils import (
32
+ AttentionMaskConverter,
33
+ )
34
+ from transformers.modeling_utils import PreTrainedModel
35
+ from transformers.utils import (
36
+ ModelOutput,
37
+ add_code_sample_docstrings,
38
+ add_start_docstrings,
39
+ add_start_docstrings_to_model_forward,
40
+ logging,
41
+ )
42
+ from transformers.utils.import_utils import (
43
+ is_causal_conv1d_available,
44
+ is_flash_attn_2_available,
45
+ is_flash_attn_greater_or_equal_2_10,
46
+ is_mamba_2_ssm_available,
47
+ )
48
+ try:
49
+ from .configuration_nemotron_h import NemotronHConfig
50
+ except ImportError:
51
+ from configuration_nemotron_h import NemotronHConfig
52
+
53
+
54
+ logger = logging.get_logger(__name__)
55
+
56
+
57
+ # Copied from transformers.models.mamba.modeling_mamba2.modeling_mamba2.py with MAMBA2->NEMOTRONH,Mamba2->NemotronH
58
+ # For Mamba2 components Mamba2->NemotronHMamba2
59
+ if is_mamba_2_ssm_available():
60
+ from mamba_ssm.ops.triton.selective_state_update import selective_state_update
61
+ from mamba_ssm.ops.triton.ssd_combined import mamba_chunk_scan_combined, mamba_split_conv1d_scan_combined
62
+ else:
63
+ mamba_chunk_scan_combined, mamba_split_conv1d_scan_combined, selective_state_update = None, None, None
64
+
65
+ try:
66
+ #from mamba_ssm.ops.triton.layernorm_gated import RMSNorm as RMSNormGated
67
+ from mamba_ssm.ops.triton.layernorm_gated import rmsnorm_fn
68
+ except ImportError:
69
+ raise ImportError("mamba-ssm is required by the Mamba model but cannot be imported")
70
+
71
+ if is_causal_conv1d_available():
72
+ from causal_conv1d import causal_conv1d_fn, causal_conv1d_update
73
+ else:
74
+ causal_conv1d_update, causal_conv1d_fn = None, None
75
+
76
+ if is_flash_attn_2_available():
77
+ from transformers.modeling_flash_attention_utils import _flash_attention_forward
78
+
79
+ is_fast_path_available = all(
80
+ (
81
+ selective_state_update,
82
+ mamba_chunk_scan_combined,
83
+ mamba_split_conv1d_scan_combined,
84
+ causal_conv1d_fn,
85
+ causal_conv1d_update,
86
+ )
87
+ )
88
+
89
+
90
+ _CHECKPOINT_FOR_DOC = "nvidia/Nemotron-H-56B-Base-8K"
91
+ _CONFIG_FOR_DOC = "NemotronHConfig"
92
+
93
+
94
+ # Helper methods for segment sum computation
95
+
96
+
97
+ def pad_tensor_by_size(input_tensor: torch.Tensor, pad_size: int):
98
+ """
99
+ Padding x tensor with `pad_size` on the seq_len dim (dim=1)
100
+
101
+ Assumes that we only have tensors of either size 4 or 3
102
+ """
103
+ pad_shape = (0, 0, 0, 0, 0, pad_size, 0, 0) if len(input_tensor.shape) == 4 else (0, 0, 0, pad_size, 0, 0)
104
+
105
+ return torch.nn.functional.pad(input_tensor, pad_shape, mode="constant", value=0)
106
+
107
+
108
+ def reshape_into_chunks(input_tensor, pad_size, chunk_size):
109
+ """
110
+ Padding input_tensor with `pad_size` on the seq_len dim (dim=1) and
111
+ simultaneously splitting it into chunk sequences.
112
+
113
+ Assumes that we only have tensors of either size 4 or 3
114
+ """
115
+ # [bsz, seq_len, ...] -> [bsz, seq_len multiple of chunk_size, ...]
116
+ input_tensor = pad_tensor_by_size(input_tensor, pad_size)
117
+
118
+ if len(input_tensor.shape) == 3:
119
+ # [bsz, seq_len multiple of chunk_size, num_heads] -> [bsz, -1, chunk_size, num_heads]
120
+ return input_tensor.reshape(input_tensor.shape[0], -1, chunk_size, input_tensor.shape[2])
121
+ else:
122
+ # [bsz, seq_len multiple of chunk_size, num_heads, head_dim or state_size] -> [bsz, -1, chunk_size, num_heads, head_dim or state_size]
123
+ return input_tensor.reshape(
124
+ input_tensor.shape[0], -1, chunk_size, input_tensor.shape[2], input_tensor.shape[3]
125
+ )
126
+
127
+
128
+ def segment_sum(input_tensor):
129
+ """
130
+ More stable segment sum calculation. Uses cumulative sums and masking instead of direct subtractions.
131
+ """
132
+ chunk_size = input_tensor.size(-1)
133
+ # 1. expand input tensor to have an additional dimension and repeat along that dimension
134
+ # [..., chunk_size] -> [..., chunk_size, chunk_size]
135
+ input_tensor = input_tensor[..., None].expand(*input_tensor.size(), chunk_size)
136
+ # 2. create a lower triangular mask with the diagonal set to 0 to 0 out elements above diag
137
+ mask = torch.tril(torch.ones(chunk_size, chunk_size, device=input_tensor.device, dtype=torch.bool), diagonal=-1)
138
+ input_tensor = input_tensor.masked_fill(~mask, 0)
139
+ # 3. compute actual cumsum
140
+ tensor_segsum = torch.cumsum(input_tensor, dim=-2)
141
+
142
+ # 4. apply mask to keep only the lower triangular part of the cumulative sum result (incl diagonal this time)
143
+ mask = torch.tril(torch.ones(chunk_size, chunk_size, device=input_tensor.device, dtype=torch.bool), diagonal=0)
144
+ tensor_segsum = tensor_segsum.masked_fill(~mask, -torch.inf)
145
+ return tensor_segsum
146
+
147
+
148
+ def apply_mask_to_padding_states(hidden_states, attention_mask):
149
+ """
150
+ Tunes out the hidden states for padding tokens, see https://github.com/state-spaces/mamba/issues/66
151
+ """
152
+ if attention_mask is not None and attention_mask.shape[1] > 1 and attention_mask.shape[0] > 1:
153
+ dtype = hidden_states.dtype
154
+ hidden_states = (hidden_states * attention_mask[:, :, None]).to(dtype)
155
+
156
+ return hidden_states
157
+
158
+ # Copied from https://github.com/huggingface/transformers/blob/main/src/transformers/models/jamba/modeling_jamba.py
159
+ class HybridMambaAttentionDynamicCache(DynamicCache):
160
+ """
161
+ A dynamic cache that can handle both the attention cache (which has a seq_len dimension) and the mamba cache
162
+ (which has a constant shape regardless of seq_len).
163
+
164
+ This cache has two sets of lists of tensors: `key_cache` and `value_cache` for attention cache and `conv_states`
165
+ and `ssm_states` for mamba cache. Each of these lists has `num_layers` tensors. The expected shape for each tensor
166
+ For attention layers, `key_cache` and `value_cache` have a shape of `(batch_size, num_heads, seq_len, head_dim)`,
167
+ while `conv_states` and `ssm_states` have a shape of `(batch_size, 0)` (empty tensors).
168
+ For mamba layers, `key_cache` and `value_cache` have a shape of `(batch_size, 0)` (empty tensors),
169
+ while `conv_states` represents the convolution state and has a shape of `(batch_size, d_inner, d_conv)`,
170
+ and `ssm_states` represents the ssm state and has a shape of `(batch_size, d_inner, d_state)`.
171
+ """
172
+
173
+ def __init__(self, config, batch_size, dtype=torch.float16, device=None):
174
+ super().__init__()
175
+ self.dtype = dtype
176
+ self.hybrid_override_pattern = config.hybrid_override_pattern
177
+ self.has_previous_state = False # only used by mamba
178
+ intermediate_size = config.mamba_num_heads * config.mamba_head_dim
179
+ ssm_state_size = config.ssm_state_size
180
+ conv_kernel_size = config.conv_kernel
181
+ self.conv_states = []
182
+ self.ssm_states = []
183
+ self.transformer_layers = []
184
+ for i in range(config.num_hidden_layers):
185
+ if self.hybrid_override_pattern[i] == "M":
186
+ # Mamba layer
187
+ self.conv_states += [
188
+ torch.zeros(batch_size, intermediate_size, conv_kernel_size, device=device, dtype=dtype)
189
+ ]
190
+ self.ssm_states += [
191
+ torch.zeros(batch_size, intermediate_size, ssm_state_size, device=device, dtype=dtype)
192
+ ]
193
+ else:
194
+ # Attention or MLP layer
195
+ self.conv_states += [torch.tensor([[]] * batch_size, device=device)]
196
+ self.ssm_states += [torch.tensor([[]] * batch_size, device=device)]
197
+ self.transformer_layers.append(i)
198
+
199
+ self.key_cache = [torch.tensor([[]] * batch_size, device=device) for _ in range(config.num_hidden_layers)]
200
+ self.value_cache = [torch.tensor([[]] * batch_size, device=device) for _ in range(config.num_hidden_layers)]
201
+
202
+ def update(
203
+ self,
204
+ key_states: torch.Tensor,
205
+ value_states: torch.Tensor,
206
+ layer_idx: int,
207
+ cache_kwargs: Optional[Dict[str, Any]] = None,
208
+ ) -> Tuple[torch.Tensor, torch.Tensor]:
209
+ # Update the cache
210
+ if self.key_cache[layer_idx].shape[-1] == 0:
211
+ self.key_cache[layer_idx] = key_states
212
+ self.value_cache[layer_idx] = value_states
213
+ else:
214
+ self.key_cache[layer_idx] = torch.cat([self.key_cache[layer_idx], key_states], dim=2)
215
+ self.value_cache[layer_idx] = torch.cat([self.value_cache[layer_idx], value_states], dim=2)
216
+
217
+ return self.key_cache[layer_idx], self.value_cache[layer_idx]
218
+
219
+ def reorder_cache(self, beam_idx: torch.LongTensor):
220
+ """Reorders the cache for beam search, given the selected beam indices."""
221
+ for layer_idx in range(len(self.key_cache)):
222
+ device = self.key_cache[layer_idx].device
223
+ self.key_cache[layer_idx] = self.key_cache[layer_idx].index_select(0, beam_idx.to(device))
224
+ device = self.value_cache[layer_idx].device
225
+ self.value_cache[layer_idx] = self.value_cache[layer_idx].index_select(0, beam_idx.to(device))
226
+
227
+ device = self.conv_states[layer_idx].device
228
+ self.conv_states[layer_idx] = self.conv_states[layer_idx].index_select(0, beam_idx.to(device))
229
+ device = self.ssm_states[layer_idx].device
230
+ self.ssm_states[layer_idx] = self.ssm_states[layer_idx].index_select(0, beam_idx.to(device))
231
+
232
+ def get_seq_length(self, layer_idx: Optional[int] = 0) -> int:
233
+ """Returns the sequence length of the cached states. A layer index can be optionally passed."""
234
+ # take any layer that contains cache and not empty tensor
235
+ layer_idx = self.transformer_layers[0] if layer_idx not in self.transformer_layers else layer_idx
236
+ if len(self.key_cache) <= layer_idx:
237
+ return 0
238
+ return self.key_cache[layer_idx].shape[-2]
239
+
240
+ def to_legacy_cache(self) -> Tuple[Tuple[torch.Tensor], Tuple[torch.Tensor]]:
241
+ raise NotImplementedError("HybridMambaAttentionDynamicCache does not have a legacy cache equivalent.")
242
+
243
+ @classmethod
244
+ def from_legacy_cache(cls, past_key_values: Optional[Tuple[Tuple[torch.FloatTensor]]] = None) -> "DynamicCache":
245
+ raise NotImplementedError("HybridMambaAttentionDynamicCache does not have a legacy cache equivalent.")
246
+
247
+ # Copied from modeling_mamba2.py
248
+ def update_conv_state(
249
+ self, layer_idx: int, new_conv_state: torch.Tensor, cache_init: bool = False
250
+ ) -> torch.Tensor:
251
+ if cache_init:
252
+ self.conv_states[layer_idx] = new_conv_state.to(self.conv_states.device)
253
+ else:
254
+ self.conv_states[layer_idx] = self.conv_states[layer_idx].roll(shifts=-1, dims=-1)
255
+ self.conv_states[layer_idx][:, :, -1] = new_conv_state[:, 0, :].to(self.conv_states.device)
256
+ return self.conv_states[layer_idx]
257
+
258
+ def update_ssm_state(self, layer_idx: int, new_ssm_state: torch.Tensor):
259
+ self.ssm_states[layer_idx] = new_ssm_state.to(self.ssm_states.device)
260
+ return self.ssm_states[layer_idx]
261
+
262
+ def reset(self):
263
+ self.conv_states.zero_()
264
+ self.ssm_states.zero_()
265
+
266
+ class MambaRMSNormGated(torch.nn.Module):
267
+ def __init__(self, hidden_size, group_size, eps=1e-5):
268
+ super().__init__()
269
+ self.weight = nn.Parameter(torch.ones(hidden_size))
270
+ self.variance_epsilon = eps
271
+ self.group_size = group_size
272
+
273
+ # jan28b version
274
+ def forward(self, hidden_states, gate=None):
275
+ return rmsnorm_fn(x=hidden_states,
276
+ weight=self.weight,
277
+ bias=None, # No bias
278
+ z=gate,
279
+ eps=self.variance_epsilon,
280
+ group_size=self.group_size,
281
+ norm_before_gate=False
282
+ )
283
+
284
+ class NemotronHMamba2Mixer(nn.Module):
285
+ """
286
+ Compute ∆, A, B, C, and D the state space parameters and compute the `contextualized_states`.
287
+ A, D are input independent (see Mamba paper [1] Section 3.5.2 "Interpretation of A" for why A isn't selective)
288
+ ∆, B, C are input-dependent (this is a key difference between Mamba and the linear time invariant S4,
289
+ and is why Mamba is called **selective** state spaces)
290
+ """
291
+
292
+ def __init__(self, config: NemotronHConfig, layer_idx: int):
293
+ super().__init__()
294
+ self.num_heads = config.mamba_num_heads
295
+ self.hidden_size = config.hidden_size
296
+ self.ssm_state_size = config.ssm_state_size
297
+ self.conv_kernel_size = config.conv_kernel
298
+ self.intermediate_size = config.mamba_num_heads * config.mamba_head_dim
299
+ self.layer_idx = layer_idx
300
+ self.use_conv_bias = config.use_conv_bias
301
+ self.activation = config.mamba_hidden_act
302
+ self.act = ACT2FN[config.mamba_hidden_act]
303
+
304
+ self.layer_norm_epsilon = config.layer_norm_epsilon
305
+
306
+ self.n_groups = config.n_groups
307
+ self.head_dim = config.mamba_head_dim
308
+ self.chunk_size = config.chunk_size
309
+
310
+ self.time_step_limit = config.time_step_limit
311
+ self.time_step_min = config.time_step_min
312
+ self.time_step_max = config.time_step_max
313
+
314
+ self.conv_dim = self.intermediate_size + 2 * self.n_groups * self.ssm_state_size
315
+ self.conv1d = nn.Conv1d(
316
+ in_channels=self.conv_dim,
317
+ out_channels=self.conv_dim,
318
+ bias=config.use_conv_bias,
319
+ kernel_size=config.conv_kernel,
320
+ groups=self.conv_dim,
321
+ padding=config.conv_kernel - 1,
322
+ )
323
+
324
+ # projection of the input hidden states
325
+ projection_size = self.intermediate_size + self.conv_dim + self.num_heads
326
+ self.in_proj = nn.Linear(
327
+ self.hidden_size,
328
+ projection_size,
329
+ bias=config.use_bias,
330
+ )
331
+ # selective projection used to make dt, B and C input dependant
332
+
333
+ # time step projection (discretization)
334
+ # instantiate once and copy inv_dt in init_weights of PretrainedModel
335
+ self.dt_bias = nn.Parameter(torch.ones(self.num_heads))
336
+
337
+ # S4D real initialization. These are not discretized!
338
+ # The core is to load them, compute the discrete states, then write the updated state. Keeps the memory bounded
339
+ A = torch.arange(1, self.num_heads + 1)
340
+ self.A_log = nn.Parameter(torch.log(A))
341
+ self.A_log._no_weight_decay = True
342
+ self.norm = MambaRMSNormGated(self.intermediate_size, eps=self.layer_norm_epsilon, group_size=self.intermediate_size // self.n_groups)
343
+ self.D = nn.Parameter(torch.ones(self.num_heads))
344
+ self.D._no_weight_decay = True
345
+
346
+ self.out_proj = nn.Linear(self.intermediate_size, self.hidden_size, bias=config.use_bias)
347
+ self.use_bias = config.use_bias
348
+
349
+ if not is_fast_path_available:
350
+ logger.warning_once(
351
+ "The fast path is not available because on of `(selective_state_update, causal_conv1d_fn, causal_conv1d_update)`"
352
+ " is None. Falling back to the naive implementation. To install follow https://github.com/state-spaces/mamba/#installation and"
353
+ " https://github.com/Dao-AILab/causal-conv1d"
354
+ )
355
+
356
+ def cuda_kernels_forward(
357
+ self,
358
+ hidden_states: torch.Tensor,
359
+ cache_params: Optional[HybridMambaAttentionDynamicCache] = None,
360
+ cache_position: Optional[torch.LongTensor] = None,
361
+ attention_mask: Optional[torch.Tensor] = None,
362
+ ):
363
+ # 1. Gated MLP's linear projection
364
+ hidden_states = apply_mask_to_padding_states(hidden_states, attention_mask)
365
+ projected_states = self.in_proj(hidden_states)
366
+
367
+ # Set up dimensions for reshapes later
368
+ batch_size, seq_len, _ = hidden_states.shape
369
+ groups_time_state_size = self.n_groups * self.ssm_state_size
370
+ d_mlp = (
371
+ projected_states.shape[-1]
372
+ - 2 * self.intermediate_size
373
+ - 2 * self.n_groups * self.ssm_state_size
374
+ - self.num_heads
375
+ ) // 2
376
+
377
+ # Single step calculations via cache
378
+ if cache_params is not None and cache_position is not None and cache_position[0] > 0:
379
+ _, _, gate, hidden_states_B_C, dt = projected_states.squeeze(1).split(
380
+ [d_mlp, d_mlp, self.intermediate_size, self.conv_dim, self.num_heads], dim=-1
381
+ )
382
+
383
+ # 2. Convolution sequence transformation
384
+ hidden_states_B_C = causal_conv1d_update(
385
+ hidden_states_B_C,
386
+ cache_params.conv_states[self.layer_idx],
387
+ self.conv1d.weight.squeeze(1),
388
+ self.conv1d.bias,
389
+ self.activation,
390
+ )
391
+
392
+ hidden_states, B, C = torch.split(
393
+ hidden_states_B_C,
394
+ [self.intermediate_size, groups_time_state_size, groups_time_state_size],
395
+ dim=-1,
396
+ )
397
+
398
+ # 3. SSM transformation
399
+ A = -torch.exp(self.A_log.float()) # (nheads,)
400
+ A = A[:, None, ...][:, :, None].expand(-1, self.head_dim, self.ssm_state_size).to(dtype=torch.float32)
401
+ dt = dt[:, :, None].expand(-1, -1, self.head_dim)
402
+ dt_bias = self.dt_bias[:, None, ...].expand(-1, self.head_dim)
403
+ D = self.D[:, None, ...].expand(-1, self.head_dim)
404
+ B = B.view(batch_size, self.n_groups, B.shape[1] // self.n_groups)
405
+ C = C.view(batch_size, self.n_groups, C.shape[1] // self.n_groups)
406
+ hidden_states_reshaped = hidden_states.view(batch_size, self.num_heads, self.head_dim)
407
+ hidden_states = selective_state_update(
408
+ cache_params.ssm_states[self.layer_idx],
409
+ hidden_states_reshaped,
410
+ dt,
411
+ A,
412
+ B,
413
+ C,
414
+ D,
415
+ z=None,
416
+ dt_bias=dt_bias,
417
+ dt_softplus=True,
418
+ )
419
+ hidden_states = hidden_states.view(batch_size, self.num_heads * self.head_dim)
420
+ hidden_states = self.norm(hidden_states, gate)
421
+
422
+ # 4. Final linear projection
423
+ out = self.out_proj(hidden_states)[:, None, ...]
424
+
425
+ # Fused calculations or step by step if no initialized cache is found
426
+ else:
427
+ A = -torch.exp(self.A_log.float()) # (num_heads) or (intermediate_size, state_size)
428
+ dt_limit_kwargs = {} if self.time_step_limit == (0.0, float("inf")) else {"dt_limit": self.time_step_limit}
429
+
430
+ # 2-4. Fused kernel for conv1d, SSM, and the final projection
431
+ if self.training and cache_params is None:
432
+ out = mamba_split_conv1d_scan_combined(
433
+ projected_states,
434
+ self.conv1d.weight.squeeze(1),
435
+ self.conv1d.bias,
436
+ self.dt_bias,
437
+ A,
438
+ D=self.D,
439
+ chunk_size=self.chunk_size,
440
+ seq_idx=None, # was seq_idx
441
+ activation=self.activation,
442
+ rmsnorm_weight=self.norm.weight,
443
+ rmsnorm_eps=self.norm.variance_epsilon,
444
+ outproj_weight=self.out_proj.weight,
445
+ outproj_bias=self.out_proj.bias,
446
+ headdim=self.head_dim,
447
+ ngroups=self.n_groups,
448
+ norm_before_gate=False,
449
+ return_final_states=False,
450
+ **dt_limit_kwargs,
451
+ )
452
+
453
+ else:
454
+ _, _, gate, hidden_states_B_C, dt = projected_states.split(
455
+ [d_mlp, d_mlp, self.intermediate_size, self.conv_dim, self.num_heads], dim=-1
456
+ )
457
+
458
+ # 2. Convolution sequence transformation
459
+ # Init cache
460
+ if cache_params is not None:
461
+ hidden_states_B_C_transposed = hidden_states_B_C.transpose(1, 2)
462
+ conv_states = nn.functional.pad(
463
+ hidden_states_B_C_transposed,
464
+ (cache_params.conv_kernel_size - hidden_states_B_C_transposed.shape[-1], 0),
465
+ )
466
+ cache_params.update_conv_state(
467
+ layer_idx=self.layer_idx, new_conv_state=conv_states, cache_init=True
468
+ )
469
+
470
+ if self.activation not in ["silu", "swish"]:
471
+ hidden_states_B_C = self.act(
472
+ self.conv1d(hidden_states_B_C.transpose(1, 2))[..., :seq_len].transpose(1, 2)
473
+ )
474
+ else:
475
+ hidden_states_B_C = causal_conv1d_fn(
476
+ x=hidden_states_B_C.transpose(1, 2),
477
+ weight=self.conv1d.weight.squeeze(1),
478
+ bias=self.conv1d.bias,
479
+ activation=self.activation,
480
+ ).transpose(1, 2)
481
+ hidden_states_B_C = apply_mask_to_padding_states(hidden_states_B_C, attention_mask)
482
+ hidden_states, B, C = torch.split(
483
+ hidden_states_B_C,
484
+ [self.intermediate_size, groups_time_state_size, groups_time_state_size],
485
+ dim=-1,
486
+ )
487
+
488
+ # 3. SSM transformation
489
+ scan_output, ssm_state = mamba_chunk_scan_combined(
490
+ hidden_states.view(batch_size, seq_len, -1, self.head_dim),
491
+ dt,
492
+ A,
493
+ B.view(batch_size, seq_len, self.n_groups, -1),
494
+ C.view(batch_size, seq_len, self.n_groups, -1),
495
+ chunk_size=self.chunk_size,
496
+ D=self.D,
497
+ z=None,
498
+ seq_idx=None,
499
+ return_final_states=True,
500
+ dt_bias=self.dt_bias,
501
+ dt_softplus=True,
502
+ **dt_limit_kwargs,
503
+ )
504
+
505
+ # Init cache
506
+ if ssm_state is not None and cache_params is not None:
507
+ cache_params.update_ssm_state(layer_idx=self.layer_idx, new_ssm_state=ssm_state)
508
+
509
+ scan_output = scan_output.view(batch_size, seq_len, -1)
510
+
511
+ # Multiply "gate" branch and apply extra normalization layer
512
+ scan_output = self.norm(scan_output, gate)
513
+
514
+ # 4. Final linear projection
515
+ out = self.out_proj(scan_output)
516
+ return out
517
+
518
+ # fmt: off
519
+ def torch_forward(self, input_states, cache_params: Optional[HybridMambaAttentionDynamicCache]=None, cache_position:Optional[torch.LongTensor]=None, attention_mask: Optional[torch.Tensor]=None):
520
+ batch_size, seq_len, _ = input_states.shape
521
+ dtype = input_states.dtype
522
+
523
+ # 1. Gated MLP's linear projection
524
+ input_states = apply_mask_to_padding_states(input_states, attention_mask)
525
+ projected_states = self.in_proj(input_states)
526
+ d_mlp = (projected_states.shape[-1] - 2 * self.intermediate_size - 2 * self.n_groups * self.ssm_state_size-self.num_heads) // 2
527
+ _, _, gate, hidden_states_B_C, dt = projected_states.split(
528
+ [d_mlp, d_mlp, self.intermediate_size, self.conv_dim, self.num_heads], dim=-1
529
+ )
530
+
531
+ # 2. Convolution sequence transformation
532
+ if cache_params is not None and cache_position is not None and cache_position[0] > 0:
533
+ cache_params.update_conv_state(layer_idx=self.layer_idx, new_conv_state=hidden_states_B_C, cache_init=False)
534
+
535
+ # We need to guarantee that anything regarding the cache is on the same device
536
+ conv_states = cache_params.conv_states[self.layer_idx].to(device=self.conv1d.weight.device)
537
+
538
+ hidden_states_B_C = torch.sum(
539
+ conv_states * self.conv1d.weight.squeeze(1), dim=-1
540
+ )
541
+ if self.use_conv_bias:
542
+ hidden_states_B_C = hidden_states_B_C + self.conv1d.bias
543
+ hidden_states_B_C = self.act(hidden_states_B_C)
544
+ else:
545
+ # Init cache
546
+ if cache_params is not None:
547
+ hidden_states_B_C_transposed = hidden_states_B_C.transpose(1, 2)
548
+ conv_states = nn.functional.pad(
549
+ hidden_states_B_C_transposed, (cache_params.conv_kernel_size - hidden_states_B_C_transposed.shape[-1], 0)
550
+ )
551
+ cache_params.update_conv_state(layer_idx=self.layer_idx, new_conv_state=conv_states, cache_init=True)
552
+
553
+ hidden_states_B_C = self.act(self.conv1d(hidden_states_B_C.transpose(1, 2))[..., :seq_len].transpose(1, 2))
554
+
555
+ hidden_states_B_C = apply_mask_to_padding_states(hidden_states_B_C, attention_mask)
556
+ hidden_states, B, C = torch.split(
557
+ hidden_states_B_C,
558
+ [self.intermediate_size, self.n_groups * self.ssm_state_size, self.n_groups * self.ssm_state_size],
559
+ dim=-1
560
+ )
561
+
562
+ # 3. SSM transformation
563
+ A = -torch.exp(self.A_log.float()) # [num_heads]
564
+ if cache_params is not None and cache_position is not None and cache_position[0] > 0:
565
+ # We need to guarantee that anything regarding the cache is on the same device
566
+ cache_device = cache_params.ssm_states.device
567
+
568
+ # Note: there is no need to pad parameter matrices here, as there is just one new token
569
+ # for batched generation
570
+ dt = dt[:, 0, :][:, None, ...]
571
+ dt = dt.transpose(1, 2).expand(batch_size, dt.shape[-1], self.head_dim)
572
+ # [num_heads] -> [num_heads, head_dim]
573
+ dt_bias = self.dt_bias[..., None].expand(self.dt_bias.shape[0], self.head_dim)
574
+
575
+ dt = torch.nn.functional.softplus(dt + dt_bias.to(dt.dtype))
576
+ dt = torch.clamp(dt, self.time_step_limit[0], self.time_step_limit[1])
577
+ A = A[..., None, None].expand(self.num_heads, self.head_dim, self.ssm_state_size).to(dtype=torch.float32)
578
+ # [bsz, num_heads, head_dim, state_size]
579
+ dA = (torch.exp(dt[..., None] * A)).to(device=cache_device)
580
+
581
+ # Discretize B
582
+ # [bsz, n_groups * state_size] -> [bsz, n_groups, 1, state_size] ->
583
+ # -> [bsz, n_groups, group to head repetition factor, state_size] -> [bsz, num_heads, state_size]
584
+ B = B.reshape(batch_size, self.n_groups, -1)[..., None, :]
585
+ B = B.expand(batch_size, self.n_groups, self.num_heads // self.n_groups, B.shape[-1]).contiguous()
586
+ B = B.reshape(batch_size, -1, B.shape[-1])
587
+ # [bsz, num_heads, head_dim, state_size]
588
+ dB = dt[..., None] * B[..., None, :]
589
+
590
+ # Discretize x into dB
591
+ # [bsz, intermediate_size] -> [bsz, num_heads, head_dim]
592
+ hidden_states = hidden_states.reshape(batch_size, -1, self.head_dim)
593
+ dBx = (dB * hidden_states[..., None]).to(device=cache_device)
594
+
595
+ # State calculation
596
+ cache_params.update_ssm_state(
597
+ layer_idx=self.layer_idx,
598
+ new_ssm_state=cache_params.ssm_states[self.layer_idx] * dA + dBx
599
+ )
600
+
601
+ # Subsequent output
602
+ # [bsz, n_groups * state_size] -> [bsz, num_heads, state_size]
603
+ C = C.reshape(batch_size, self.n_groups, -1)[..., None, :]
604
+ C = C.expand(batch_size, self.n_groups, self.num_heads // self.n_groups, C.shape[-1]).contiguous()
605
+ C = C.reshape(batch_size, -1, C.shape[-1])
606
+ # [bsz, num_heads, head_dim]
607
+
608
+ ssm_states = cache_params.ssm_states[self.layer_idx].to(device=C.device, dtype=C.dtype) # Shape: [b, h, d, n]
609
+ # Reshape ssm_states to merge the first two dimensions
610
+ ssm_states_reshaped = ssm_states.view(batch_size * self.num_heads, self.head_dim, self.ssm_state_size) # Shape: [b*h, d, n]
611
+ C_reshaped = C.view(batch_size * self.num_heads, self.ssm_state_size, 1) # Shape: [b*h, n, 1]
612
+ y = torch.bmm(ssm_states_reshaped, C_reshaped)
613
+ y = y.view(batch_size, self.num_heads, self.head_dim)
614
+
615
+ # D skip connection
616
+ # [num_heads] -> [num_heads, head_dim]
617
+ D = self.D[..., None].expand(self.D.shape[0], self.head_dim)
618
+ y = (y + hidden_states * D).to(y.dtype)
619
+
620
+ # [bsz, num_heads, head_dim] -> [bsz, 1, intermediate_size]
621
+ y = y.reshape(batch_size, -1)[:, None, ...]
622
+ else:
623
+ # begin ssd naive implementation without einsums
624
+ dt = nn.functional.softplus(dt + self.dt_bias)
625
+ dt = torch.clamp(dt, self.time_step_limit[0], self.time_step_limit[1])
626
+ hidden_states = hidden_states.reshape(batch_size, seq_len, -1, self.head_dim).float()
627
+ B = B.reshape(batch_size, seq_len, -1, self.ssm_state_size).float()
628
+ C = C.reshape(batch_size, seq_len, -1, self.ssm_state_size).float()
629
+ B = B.repeat_interleave(self.num_heads // self.n_groups, dim=2, output_size=self.num_heads)
630
+ C = C.repeat_interleave(self.num_heads // self.n_groups, dim=2, output_size=self.num_heads)
631
+ pad_size = (self.chunk_size - seq_len % self.chunk_size) % self.chunk_size
632
+
633
+ D_residual = self.D[..., None] * pad_tensor_by_size(hidden_states, pad_size)
634
+
635
+ # Discretize x and A
636
+ hidden_states = hidden_states * dt[..., None]
637
+ A = A.to(hidden_states.dtype) * dt
638
+
639
+ # Rearrange into blocks/chunks
640
+ hidden_states, A, B, C = [reshape_into_chunks(t, pad_size, self.chunk_size) for t in (hidden_states, A, B, C)]
641
+
642
+ # [bsz, -1, chunk_size, num_heads] -> [bsz, num_heads, -1, chunk_size]
643
+ A = A.permute(0, 3, 1, 2)
644
+ A_cumsum = torch.cumsum(A, dim=-1)
645
+
646
+ # 1. Compute the output for each intra-chunk (diagonal blocks)
647
+ # This is the analog of a causal mask
648
+ L = torch.exp(segment_sum(A))
649
+
650
+ # Contraction of C and B to get G (attention-weights like)
651
+ G_intermediate = C[:, :, :, None, :, :] * B[:, :, None, :, :, :] # shape: (b, c, l, s, h, n)
652
+ G = G_intermediate.sum(dim=-1) # shape: (b, c, l, s, h)
653
+
654
+ # Compute M, equivalent to applying attention mask to weights
655
+ M_intermediate = G[..., None] * L.permute(0, 2, 3, 4, 1)[..., None]
656
+ M = M_intermediate.sum(dim=-1)
657
+
658
+ # Compute Y_diag (apply to values)
659
+ Y_diag = (M[..., None] * hidden_states[:, :, None]).sum(dim=3)
660
+
661
+ # 2. Compute the state for each intra-chunk
662
+ # (right term of low-rank factorization of off-diagonal blocks; B terms)
663
+ decay_states = torch.exp((A_cumsum[:, :, :, -1:] - A_cumsum))
664
+ B_decay = B * decay_states.permute(0, -2, -1, 1)[..., None]
665
+ states = (B_decay[..., None, :] * hidden_states[..., None]).sum(dim=2)
666
+
667
+ # 3. Compute the inter-chunk SSM recurrence; produces correct SSM states at chunk boundaries
668
+ # (middle term of factorization of off-diag blocks; A terms)
669
+ if cache_params is not None and cache_position is not None and cache_position[0] > 0:
670
+ previous_states = cache_params.ssm_states[self.layer_idx][:, None, ...].to(device=states.device)
671
+ else:
672
+ previous_states = torch.zeros_like(states[:, :1])
673
+ states = torch.cat([previous_states, states], dim=1)
674
+ decay_chunk = torch.exp(segment_sum(nn.functional.pad(A_cumsum[:, :, :, -1], (1, 0))))
675
+ decay_chunk = decay_chunk.transpose(1, 3)
676
+ new_states = (decay_chunk[..., None, None] * states[:, :, None, ...]).sum(dim=1)
677
+ states, ssm_state = new_states[:, :-1], new_states[:, -1]
678
+
679
+ # 4. Compute state -> output conversion per chunk
680
+ # (left term of low-rank factorization of off-diagonal blocks; C terms)
681
+ state_decay_out = torch.exp(A_cumsum)
682
+ C_times_states = (C[..., None, :] * states[:, :, None, ...])
683
+ state_decay_out_permuted = state_decay_out.permute(0, 2, 3, 1)
684
+ Y_off = (C_times_states.sum(-1) * state_decay_out_permuted[..., None])
685
+
686
+ # Add output of intra-chunk and inter-chunk terms (diagonal and off-diagonal blocks)
687
+ y = Y_diag + Y_off
688
+ # [bsz, -1, self.chunk_size, num_heads, head_dim] -> [bsz, (padded) seq_len, num_heads, head_dim]
689
+ y = y.reshape(batch_size, -1, self.num_heads, self.head_dim)
690
+
691
+ y = y + D_residual
692
+ # Cutting off padded chunks
693
+ if pad_size > 0:
694
+ y = y[:, :seq_len, :, :]
695
+ y = y.reshape(batch_size, seq_len, -1)
696
+
697
+ # Init cache
698
+ if ssm_state is not None and cache_params is not None:
699
+ cache_params.update_ssm_state(layer_idx=self.layer_idx, new_ssm_state=ssm_state)
700
+
701
+ scan_output = self.norm(y, gate)
702
+
703
+ # end ssd naive
704
+
705
+ # 4. Final linear projection
706
+ contextualized_states = self.out_proj(scan_output.to(dtype)) # [batch, seq_len, hidden_size]
707
+ return contextualized_states
708
+ # fmt: on
709
+
710
+ def forward(
711
+ self,
712
+ hidden_states,
713
+ cache_params: Optional[HybridMambaAttentionDynamicCache] = None,
714
+ cache_position: Optional[torch.LongTensor] = None,
715
+ attention_mask: Optional[torch.Tensor] = None,
716
+ ):
717
+ if is_fast_path_available and "cuda" in self.in_proj.weight.device.type:
718
+ return self.cuda_kernels_forward(hidden_states, cache_params, cache_position, attention_mask)
719
+ dtype = hidden_states.dtype
720
+ if attention_mask is not None and attention_mask.shape[1] > 1 and attention_mask.shape[0] > 1:
721
+ # tune out hidden states for pad tokens, see https://github.com/state-spaces/mamba/issues/66
722
+ hidden_states = (hidden_states * attention_mask[:, :, None]).to(dtype)
723
+
724
+ return self.torch_forward(hidden_states, cache_params, cache_position, attention_mask)
725
+
726
+
727
+ class NemotronHRMSNorm(nn.Module):
728
+ def __init__(self, hidden_size, eps=1e-6):
729
+ """
730
+ NemotronHRMSNorm is equivalent to T5LayerNorm and LlamaRMSNorm
731
+ """
732
+ super().__init__()
733
+ self.weight = nn.Parameter(torch.ones(hidden_size))
734
+ self.variance_epsilon = eps
735
+
736
+ def forward(self, hidden_states):
737
+ input_dtype = hidden_states.dtype
738
+ hidden_states = hidden_states.to(torch.float32)
739
+ variance = hidden_states.pow(2).mean(-1, keepdim=True)
740
+ hidden_states = hidden_states * torch.rsqrt(variance + self.variance_epsilon)
741
+ # Weights are in float32
742
+ return (self.weight.to(torch.float32) * hidden_states).to(input_dtype)
743
+
744
+ class NemotronHBlock(nn.Module):
745
+ def __init__(self, config, layer_idx):
746
+ super().__init__()
747
+ self.config = config
748
+ self.layer_idx = layer_idx
749
+ self.residual_in_fp32 = config.residual_in_fp32
750
+ self.norm = NemotronHRMSNorm(config.hidden_size, eps=config.layer_norm_epsilon)
751
+
752
+ # M: Mamba2, *: Attention, -: MLP
753
+ self.block_type = config.layers_block_type[layer_idx]
754
+ if self.block_type == "mamba":
755
+ self.mixer = NemotronHMamba2Mixer(config, layer_idx=layer_idx)
756
+ elif self.block_type == "attention":
757
+ self.mixer = NEMOTRONH_ATTENTION_CLASSES[config._attn_implementation](config, layer_idx=layer_idx)
758
+ elif self.block_type == "mlp":
759
+ self.mixer = NemotronHMLP(config, layer_idx=layer_idx)
760
+ elif self.block_type == "moe":
761
+ self.mixer = NemotronHMOE(config, layer_idx=layer_idx)
762
+ else:
763
+ raise ValueError(f"Invalid layer pattern {config.hybrid_override_pattern[layer_idx]}")
764
+
765
+ def forward(
766
+ self,
767
+ hidden_states,
768
+ cache_params: Optional[HybridMambaAttentionDynamicCache] = None,
769
+ cache_position: Optional[torch.LongTensor] = None,
770
+ attention_mask: Optional[torch.Tensor] = None,
771
+ ):
772
+ with torch.cuda.stream(torch.cuda.default_stream(hidden_states.device)):
773
+ # * Use torch.cuda.stream() to avoid NaN issues when using multiple GPUs
774
+ residual = hidden_states
775
+ hidden_states = self.norm(hidden_states.to(dtype=self.norm.weight.dtype))
776
+ if self.residual_in_fp32:
777
+ residual = residual.to(torch.float32)
778
+
779
+ if self.block_type == "mamba":
780
+ hidden_states = self.mixer(
781
+ hidden_states, cache_params=cache_params, cache_position=cache_position
782
+ )
783
+ elif self.block_type == "attention":
784
+ hidden_states = self.mixer(
785
+ hidden_states, cache_position=cache_position
786
+ )
787
+ hidden_states = hidden_states[0]
788
+ elif self.block_type in ["mlp", "moe"]:
789
+ hidden_states = self.mixer(
790
+ hidden_states
791
+ )
792
+ else:
793
+ raise ValueError(f"Invalid block_type: {self.block_type}")
794
+
795
+ hidden_states = residual + hidden_states
796
+ return hidden_states
797
+
798
+
799
+ # Copied from transformers.models.nemotron.modeling_nemotron Nemotron->NemotronH
800
+ class NemotronHMLP(nn.Module):
801
+ def __init__(self, config, intermediate_size=None, layer_idx: Optional[int] = None):
802
+ super().__init__()
803
+ self.config = config
804
+ self.layer_idx = layer_idx
805
+ if layer_idx is None:
806
+ logger.warning_once(
807
+ f"Instantiating {self.__class__.__name__} without passing a `layer_idx` is not recommended and will "
808
+ "lead to errors during the forward call if caching is used. Please make sure to provide a `layer_idx` "
809
+ "when creating this class."
810
+ )
811
+ self.hidden_size = config.hidden_size
812
+ self.intermediate_size = intermediate_size or config.intermediate_size
813
+ self.up_proj = nn.Linear(self.hidden_size, self.intermediate_size, bias=config.mlp_bias)
814
+ self.down_proj = nn.Linear(self.intermediate_size, self.hidden_size, bias=config.mlp_bias)
815
+ self.act_fn = ACT2FN[config.mlp_hidden_act]
816
+
817
+ def forward(self, x):
818
+ return self.down_proj(self.act_fn(self.up_proj(x)))
819
+
820
+
821
+ class NemotronHMOE(nn.Module):
822
+ def __init__(self, config, layer_idx: Optional[int] = None):
823
+ super().__init__()
824
+ self.config = config
825
+ self.experts = nn.ModuleList(
826
+ [
827
+ NemotronHMLP(config, intermediate_size=config.moe_intermediate_size, layer_idx=layer_idx)
828
+ for _ in range(config.n_routed_experts)
829
+ ]
830
+ )
831
+ self.gate = NemotronHTopkRouter(config)
832
+ self.shared_experts = NemotronHMLP(
833
+ config=config, intermediate_size=config.moe_shared_expert_intermediate_size, layer_idx=layer_idx
834
+ )
835
+
836
+ def moe(self, hidden_states: torch.Tensor, topk_indices: torch.Tensor, topk_weights: torch.Tensor):
837
+ r"""
838
+ CALL FOR CONTRIBUTION! I don't have time to optimise this right now, but expert weights need to be fused
839
+ to not have to do a loop here (deepseek has 256 experts soooo yeah).
840
+ """
841
+ final_hidden_states = torch.zeros_like(hidden_states, dtype=topk_weights.dtype)
842
+ expert_mask = torch.nn.functional.one_hot(topk_indices, num_classes=len(self.experts))
843
+ expert_mask = expert_mask.permute(2, 0, 1)
844
+
845
+ for expert_idx in range(len(self.experts)):
846
+ expert = self.experts[expert_idx]
847
+ mask = expert_mask[expert_idx]
848
+ token_indices, weight_indices = torch.where(mask)
849
+
850
+ if token_indices.numel() > 0:
851
+ expert_weights = topk_weights[token_indices, weight_indices]
852
+ expert_input = hidden_states[token_indices]
853
+ expert_output = expert(expert_input)
854
+ weighted_output = expert_output * expert_weights.unsqueeze(-1)
855
+ final_hidden_states.index_add_(0, token_indices, weighted_output)
856
+ else:
857
+ # Local empty expert: no-op compute that still marks params as used.
858
+ expert_dtype = expert.down_proj.weight.dtype
859
+ dummy_out = expert(torch.zeros_like(hidden_states[0]).unsqueeze(0).to(expert_dtype))
860
+ final_hidden_states = final_hidden_states + dummy_out
861
+
862
+ # in original deepseek, the output of the experts are gathered once we leave this module
863
+ # thus the moe module is itelsf an IsolatedParallel module
864
+ # and all expert are "local" meaning we shard but we don't gather
865
+ return final_hidden_states.type(hidden_states.dtype)
866
+
867
+ def forward(self, hidden_states):
868
+ residuals = hidden_states
869
+ orig_shape = hidden_states.shape
870
+ topk_indices, topk_weights = self.gate(hidden_states)
871
+ hidden_states = hidden_states.view(-1, hidden_states.shape[-1])
872
+ hidden_states = self.moe(hidden_states, topk_indices, topk_weights).view(*orig_shape)
873
+ hidden_states = hidden_states + self.shared_experts(residuals)
874
+ return hidden_states
875
+
876
+
877
+ class NemotronHTopkRouter(nn.Module):
878
+ def __init__(self, config):
879
+ super().__init__()
880
+ self.config = config
881
+ self.top_k = config.num_experts_per_tok
882
+ self.n_routed_experts = config.n_routed_experts
883
+ self.routed_scaling_factor = config.routed_scaling_factor
884
+ self.n_group = config.n_group
885
+ self.topk_group = config.topk_group
886
+ self.norm_topk_prob = config.norm_topk_prob
887
+
888
+ self.weight = nn.Parameter(torch.empty((self.n_routed_experts, config.hidden_size), dtype=torch.float32))
889
+ self.register_buffer("e_score_correction_bias", torch.zeros(self.n_routed_experts, dtype=torch.float32))
890
+
891
+ @torch.no_grad()
892
+ def get_topk_indices(self, scores):
893
+ scores_for_choice = scores.view(-1, self.n_routed_experts) + self.e_score_correction_bias.unsqueeze(0)
894
+ group_scores = (
895
+ scores_for_choice.view(-1, self.n_group, self.n_routed_experts // self.n_group)
896
+ .topk(2, dim=-1)[0]
897
+ .sum(dim=-1)
898
+ )
899
+ group_idx = torch.topk(group_scores, k=self.topk_group, dim=-1, sorted=False)[1]
900
+ group_mask = torch.zeros_like(group_scores)
901
+ group_mask.scatter_(1, group_idx, 1)
902
+ score_mask = (
903
+ group_mask.unsqueeze(-1)
904
+ .expand(-1, self.n_group, self.n_routed_experts // self.n_group)
905
+ .reshape(-1, self.n_routed_experts)
906
+ )
907
+ scores_for_choice = scores_for_choice.masked_fill(~score_mask.bool(), 0.0)
908
+ topk_indices = torch.topk(scores_for_choice, k=self.top_k, dim=-1, sorted=False)[1]
909
+ return topk_indices
910
+
911
+ def forward(self, hidden_states):
912
+ hidden_states = hidden_states.view(-1, self.config.hidden_size)
913
+ router_logits = F.linear(hidden_states.type(torch.float32), self.weight.type(torch.float32))
914
+ scores = router_logits.sigmoid()
915
+ topk_indices = self.get_topk_indices(scores)
916
+ topk_weights = scores.gather(1, topk_indices)
917
+ if self.norm_topk_prob:
918
+ denominator = topk_weights.sum(dim=-1, keepdim=True) + 1e-20
919
+ topk_weights /= denominator
920
+ topk_weights = topk_weights * self.routed_scaling_factor
921
+ return topk_indices, topk_weights
922
+
923
+ # Copied from transformers.models.llama.modeling_llama.repeat_kv
924
+ def repeat_kv(hidden_states: torch.Tensor, n_rep: int) -> torch.Tensor:
925
+ """
926
+ This is the equivalent of torch.repeat_interleave(x, dim=1, repeats=n_rep). The hidden states go from (batch,
927
+ num_key_value_heads, seqlen, head_dim) to (batch, num_attention_heads, seqlen, head_dim)
928
+ """
929
+ batch, num_key_value_heads, slen, head_dim = hidden_states.shape
930
+ if n_rep == 1:
931
+ return hidden_states
932
+ hidden_states = hidden_states[:, :, None, :, :].expand(batch, num_key_value_heads, n_rep, slen, head_dim)
933
+ return hidden_states.reshape(batch, num_key_value_heads * n_rep, slen, head_dim)
934
+
935
+
936
+ class NemotronHAttention(nn.Module):
937
+ """Multi-headed attention from 'Attention Is All You Need' paper"""
938
+
939
+ def __init__(self, config: NemotronHConfig, layer_idx: Optional[int] = None):
940
+ super().__init__()
941
+ self.config = config
942
+ self.layer_idx = layer_idx
943
+ if layer_idx is None:
944
+ logger.warning_once(
945
+ f"Instantiating {self.__class__.__name__} without passing a `layer_idx` is not recommended and will "
946
+ "lead to errors during the forward call if caching is used. Please make sure to provide a `layer_idx` "
947
+ "when creating this class."
948
+ )
949
+
950
+ self.attention_dropout = config.attention_dropout
951
+ self.hidden_size = config.hidden_size
952
+ self.num_heads = config.num_attention_heads
953
+ if hasattr(config, "head_dim") and config.head_dim is not None:
954
+ self.head_dim = config.head_dim
955
+ else:
956
+ self.head_dim = config.hidden_size // self.num_attention_heads
957
+ self.num_key_value_heads = config.num_key_value_heads
958
+ self.num_key_value_groups = self.num_heads // self.num_key_value_heads
959
+ self.max_position_embeddings = config.max_position_embeddings
960
+ self.is_causal = True
961
+
962
+ self.q_proj = nn.Linear(self.hidden_size, self.num_heads * self.head_dim, bias=config.attention_bias)
963
+ self.k_proj = nn.Linear(self.hidden_size, self.num_key_value_heads * self.head_dim, bias=config.attention_bias)
964
+ self.v_proj = nn.Linear(self.hidden_size, self.num_key_value_heads * self.head_dim, bias=config.attention_bias)
965
+ self.o_proj = nn.Linear(self.head_dim * self.num_heads, self.hidden_size, bias=config.attention_bias)
966
+
967
+ def forward(
968
+ self,
969
+ hidden_states: torch.Tensor,
970
+ # position_embeddings: Tuple[torch.Tensor, torch.Tensor], #TODO
971
+ attention_mask: Optional[torch.Tensor] = None,
972
+ position_ids: Optional[torch.LongTensor] = None,
973
+ past_key_value: Optional[HybridMambaAttentionDynamicCache] = None,
974
+ output_attentions: bool = False,
975
+ use_cache: bool = False,
976
+ cache_position: Optional[torch.LongTensor] = None,
977
+ ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
978
+ bsz, q_len, _ = hidden_states.size()
979
+
980
+ query_states = self.q_proj(hidden_states)
981
+ key_states = self.k_proj(hidden_states)
982
+ value_states = self.v_proj(hidden_states)
983
+
984
+ query_states = query_states.view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2)
985
+ key_states = key_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)
986
+ value_states = value_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)
987
+
988
+ if past_key_value is not None:
989
+ key_states, value_states = past_key_value.update(key_states, value_states, self.layer_idx)
990
+
991
+ key_states = repeat_kv(key_states, self.num_key_value_groups)
992
+ value_states = repeat_kv(value_states, self.num_key_value_groups)
993
+
994
+ causal_mask = attention_mask
995
+ if attention_mask is not None: # no matter the length, we just slice it
996
+ causal_mask = attention_mask[:, :, :, : key_states.shape[-2]]
997
+
998
+ if query_states.device.type == "cuda" and attention_mask is not None:
999
+ query_states = query_states.contiguous()
1000
+ key_states = key_states.contiguous()
1001
+ value_states = value_states.contiguous()
1002
+
1003
+ is_causal = True if causal_mask is None and q_len > 1 else False
1004
+
1005
+ attn_output = torch.nn.functional.scaled_dot_product_attention(
1006
+ query_states,
1007
+ key_states,
1008
+ value_states,
1009
+ attn_mask=causal_mask,
1010
+ dropout_p=self.attention_dropout if self.training else 0.0,
1011
+ is_causal=is_causal,
1012
+ )
1013
+ attn_output = attn_output.transpose(1, 2).contiguous()
1014
+ #attn_output = attn_output.view(bsz, q_len, self.hidden_size)
1015
+ attn_output = attn_output.view(bsz, q_len, self.num_heads * self.head_dim)
1016
+
1017
+ attn_output = self.o_proj(attn_output)
1018
+
1019
+ return attn_output, None, past_key_value
1020
+
1021
+
1022
+ # Adapted from transformers.models.mistral.modeling_mistral.MistralFlashAttention2 with Mistral->Jamba
1023
+ #class JambaFlashAttention2(JambaAttention):
1024
+ class NemotronHFlashAttention2(NemotronHAttention):
1025
+ """
1026
+ Jamba flash attention module. This module inherits from `JambaAttention` as the weights of the module stays
1027
+ untouched. The only required change would be on the forward pass where it needs to correctly call the public API of
1028
+ flash attention and deal with padding tokens in case the input contains any of them.
1029
+ """
1030
+ def __init__(self, *args, **kwargs):
1031
+ super().__init__(*args, **kwargs)
1032
+
1033
+ # TODO: Should be removed once Flash Attention for RoCm is bumped to 2.1.
1034
+ # flash_attn<2.1 generates top-left aligned causal mask, while what is needed here is bottom-right alignement, that was made default for flash_attn>=2.1. This attribute is used to handle this difference. Reference: https://github.com/Dao-AILab/flash-attention/releases/tag/v2.1.0.
1035
+ # Beware that with flash_attn<2.1, using q_seqlen != k_seqlen (except for the case q_seqlen == 1) produces a wrong mask (top-left).
1036
+ self._flash_attn_uses_top_left_mask = not is_flash_attn_greater_or_equal_2_10()
1037
+
1038
+ def forward(
1039
+ self,
1040
+ hidden_states: torch.Tensor,
1041
+ attention_mask: Optional[torch.Tensor] = None,
1042
+ position_ids: Optional[torch.LongTensor] = None,
1043
+ past_key_value: Optional[HybridMambaAttentionDynamicCache] = None,
1044
+ output_attentions: bool = False,
1045
+ use_cache: bool = False,
1046
+ cache_position: Optional[torch.LongTensor] = None,
1047
+ **kwargs,
1048
+ ):
1049
+ bsz, q_len, _ = hidden_states.size()
1050
+
1051
+ query_states = self.q_proj(hidden_states)
1052
+ key_states = self.k_proj(hidden_states)
1053
+ value_states = self.v_proj(hidden_states)
1054
+
1055
+ # Flash attention requires the input to have the shape
1056
+ # batch_size x seq_length x head_dim x hidden_dim
1057
+ # therefore we just need to keep the original shape
1058
+ query_states = query_states.view(bsz, q_len, self.num_heads, self.head_dim)
1059
+ key_states = key_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)
1060
+ value_states = value_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)
1061
+
1062
+ if past_key_value is not None:
1063
+ key_states, value_states = past_key_value.update(key_states, value_states, self.layer_idx)
1064
+
1065
+ # repeat k/v heads if n_kv_heads < n_heads
1066
+ key_states = repeat_kv(key_states, self.num_key_value_groups)
1067
+ value_states = repeat_kv(value_states, self.num_key_value_groups)
1068
+ dropout_rate = 0.0 if not self.training else self.attention_dropout
1069
+
1070
+ # In PEFT, usually we cast the layer norms in float32 for training stability reasons
1071
+ # therefore the input hidden states gets silently casted in float32. Hence, we need
1072
+ # cast them back in float16 just to be sure everything works as expected.
1073
+ input_dtype = query_states.dtype
1074
+ if input_dtype == torch.float32:
1075
+ if torch.is_autocast_enabled():
1076
+ target_dtype = torch.get_autocast_gpu_dtype()
1077
+ # Handle the case where the model is quantized
1078
+ elif hasattr(self.config, "_pre_quantization_dtype"):
1079
+ target_dtype = self.config._pre_quantization_dtype
1080
+ else:
1081
+ target_dtype = self.q_proj.weight.dtype
1082
+
1083
+ logger.warning_once(
1084
+ f"The input hidden states seems to be silently casted in float32, this might be related to"
1085
+ f" the fact you have upcasted embedding or layer norm layers in float32. We will cast back the input in"
1086
+ f" {target_dtype}."
1087
+ )
1088
+
1089
+ query_states = query_states.to(target_dtype)
1090
+ key_states = key_states.to(target_dtype)
1091
+ value_states = value_states.to(target_dtype)
1092
+
1093
+ # Reashape to the expected shape for Flash Attention
1094
+ key_states = key_states.transpose(1, 2)
1095
+ value_states = value_states.transpose(1, 2)
1096
+
1097
+ attn_output = _flash_attention_forward(
1098
+ query_states,
1099
+ key_states,
1100
+ value_states,
1101
+ attention_mask,
1102
+ q_len,
1103
+ dropout=dropout_rate,
1104
+ sliding_window=getattr(self.config, "sliding_window", None),
1105
+ is_causal=self.is_causal,
1106
+ use_top_left_mask=self._flash_attn_uses_top_left_mask,
1107
+ )
1108
+
1109
+ #attn_output = attn_output.reshape(bsz, q_len, self.hidden_size).contiguous()
1110
+ attn_output = attn_output.reshape(bsz, q_len, self.num_heads * self.head_dim).contiguous()
1111
+ attn_output = self.o_proj(attn_output)
1112
+
1113
+ if not output_attentions:
1114
+ attn_weights = None
1115
+
1116
+ return attn_output, attn_weights, past_key_value
1117
+
1118
+
1119
+ # Adapted from transformers.models.mistral.modeling_mistral.MistralSdpaAttention with Mistral->Jamba
1120
+ #class JambaSdpaAttention(JambaAttention):
1121
+ class NemotronHSdpaAttention(NemotronHAttention):
1122
+ """
1123
+ Jamba attention module using torch.nn.functional.scaled_dot_product_attention. This module inherits from
1124
+ `JambaAttention` as the weights of the module stays untouched. The only changes are on the forward pass to adapt to
1125
+ SDPA API.
1126
+ """
1127
+
1128
+ # Adapted from NemotronHAttention.forward
1129
+ def forward(
1130
+ self,
1131
+ hidden_states: torch.Tensor,
1132
+ attention_mask: Optional[torch.Tensor] = None,
1133
+ position_ids: Optional[torch.LongTensor] = None,
1134
+ past_key_value: Optional[HybridMambaAttentionDynamicCache] = None,
1135
+ output_attentions: bool = False,
1136
+ use_cache: bool = False,
1137
+ cache_position: Optional[torch.LongTensor] = None,
1138
+ ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
1139
+ if output_attentions:
1140
+ # TODO: Improve this warning with e.g. `model.config.attn_implementation = "manual"` once this is implemented.
1141
+ logger.warning_once(
1142
+ "NemotronHModel is using NemotronHSdpaAttention, but `torch.nn.functional.scaled_dot_product_attention` does not support `output_attentions=True`. Falling back to the manual attention implementation, "
1143
+ 'but specifying the manual implementation will be required from Transformers version v5.0.0 onwards. This warning can be removed using the argument `attn_implementation="eager"` when loading the model.'
1144
+ )
1145
+ return super().forward(
1146
+ hidden_states=hidden_states,
1147
+ attention_mask=attention_mask,
1148
+ position_ids=position_ids,
1149
+ past_key_value=past_key_value,
1150
+ output_attentions=output_attentions,
1151
+ use_cache=use_cache,
1152
+ )
1153
+
1154
+ bsz, q_len, _ = hidden_states.size()
1155
+
1156
+ query_states = self.q_proj(hidden_states)
1157
+ key_states = self.k_proj(hidden_states)
1158
+ value_states = self.v_proj(hidden_states)
1159
+
1160
+ query_states = query_states.view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2)
1161
+ key_states = key_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)
1162
+ value_states = value_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)
1163
+
1164
+ if past_key_value is not None:
1165
+ key_states, value_states = past_key_value.update(key_states, value_states, self.layer_idx)
1166
+
1167
+ key_states = repeat_kv(key_states, self.num_key_value_groups)
1168
+ value_states = repeat_kv(value_states, self.num_key_value_groups)
1169
+
1170
+ causal_mask = attention_mask
1171
+ if attention_mask is not None:
1172
+ causal_mask = causal_mask[:, :, :, : key_states.shape[-2]]
1173
+
1174
+ # SDPA with memory-efficient backend is currently (torch==2.1.2) bugged with non-contiguous inputs with custom attn_mask,
1175
+ # Reference: https://github.com/pytorch/pytorch/issues/112577.
1176
+ if query_states.device.type == "cuda" and attention_mask is not None:
1177
+ query_states = query_states.contiguous()
1178
+ key_states = key_states.contiguous()
1179
+ value_states = value_states.contiguous()
1180
+
1181
+ # We dispatch to SDPA's Flash Attention or Efficient kernels via this `is_causal` if statement instead of an inline conditional assignment
1182
+ # in SDPA to support both torch.compile's dynamic shapes and full graph options. An inline conditional prevents dynamic shapes from compiling.
1183
+ # The q_len > 1 is necessary to match with AttentionMaskConverter.to_causal_4d that does not create a causal mask in case q_len == 1.
1184
+ is_causal = True if self.is_causal and causal_mask is None and q_len > 1 else False
1185
+
1186
+ attn_output = torch.nn.functional.scaled_dot_product_attention(
1187
+ query_states,
1188
+ key_states,
1189
+ value_states,
1190
+ attn_mask=causal_mask,
1191
+ dropout_p=self.attention_dropout if self.training else 0.0,
1192
+ is_causal=is_causal,
1193
+ )
1194
+
1195
+ attn_output = attn_output.transpose(1, 2).contiguous()
1196
+ attn_output = attn_output.view(bsz, q_len, self.hidden_size)
1197
+
1198
+ attn_output = self.o_proj(attn_output)
1199
+
1200
+ return attn_output, None, past_key_value
1201
+
1202
+
1203
+ NEMOTRONH_ATTENTION_CLASSES = {
1204
+ "eager": NemotronHAttention,
1205
+ "flash_attention_2": NemotronHFlashAttention2,
1206
+ "sdpa": NemotronHSdpaAttention,
1207
+ }
1208
+
1209
+ # Copied from transformers.models.mamba.modeling_mamba2.Mamba2PreTrainedModel
1210
+ class NemotronHPreTrainedModel(PreTrainedModel):
1211
+ """
1212
+ An abstract class to handle weights initialization and a simple interface for downloading and loading pretrained
1213
+ models.
1214
+ """
1215
+
1216
+ config_class = NemotronHConfig
1217
+ base_model_prefix = "backbone"
1218
+ _no_split_modules = ["NemotronHBlock"]
1219
+ supports_gradient_checkpointing = True
1220
+ _is_stateful = True
1221
+
1222
+ def _init_weights(self, module):
1223
+ """Initialize the weights."""
1224
+ if isinstance(module, NemotronHMamba2Mixer):
1225
+ module.A_log._no_weight_decay = True
1226
+ module.D._no_weight_decay = True
1227
+
1228
+ dt = torch.exp(
1229
+ torch.rand(self.config.mamba_num_heads)
1230
+ * (math.log(self.config.time_step_max) - math.log(self.config.time_step_min))
1231
+ + math.log(self.config.time_step_min)
1232
+ ).clamp(min=self.config.time_step_floor)
1233
+
1234
+ # # Inverse of softplus: https://github.com/pytorch/pytorch/issues/72759
1235
+ inv_dt = dt + torch.log(-torch.expm1(-dt))
1236
+ with torch.no_grad():
1237
+ module.dt_bias.copy_(inv_dt)
1238
+ module.dt_bias._no_reinit = True
1239
+
1240
+ if isinstance(module, nn.Linear):
1241
+ if module.bias is not None:
1242
+ if not getattr(module.bias, "_no_reinit", False):
1243
+ nn.init.zeros_(module.bias)
1244
+ elif isinstance(module, nn.Embedding):
1245
+ nn.init.normal_(module.weight, std=self.config.initializer_range)
1246
+
1247
+ # TODO: Check
1248
+ if self.config.rescale_prenorm_residual:
1249
+ # Reinitialize selected weights subject to the OpenAI GPT-2 Paper Scheme:
1250
+ # > A modified initialization which accounts for the accumulation on the residual path with model depth. Scale
1251
+ # > the weights of residual layers at initialization by a factor of 1/√N where N is the # of residual layers.
1252
+ # > -- GPT-2 :: https://openai.com/blog/better-language-models/
1253
+ #
1254
+ # Reference (Megatron-LM): https://github.com/NVIDIA/Megatron-LM/blob/main/megatron/model/gpt_model.py
1255
+ for name, p in module.named_parameters():
1256
+ if name in ["out_proj.weight"]:
1257
+ # Special Scaled Initialization --> There are 2 Layer Norms per Transformer Block
1258
+ # Following Pytorch init, except scale by 1/sqrt(2 * n_layer)
1259
+ # We need to reinit p since this code could be called multiple times
1260
+ # Having just p *= scale would repeatedly scale it down
1261
+ nn.init.kaiming_uniform_(p, a=math.sqrt(5))
1262
+ with torch.no_grad():
1263
+ p /= math.sqrt(self.config.num_hidden_layers)
1264
+
1265
+
1266
+ @dataclass
1267
+ # Copied from transformers.models.mamba.modeling_mamba2.Mamba2Output with MAMBA2->NemotronH,Mamba2->NemotronH
1268
+ class NemotronHOutput(ModelOutput):
1269
+ """
1270
+ Class for the NemotronH model outputs.
1271
+
1272
+ Args:
1273
+ last_hidden_state (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`):
1274
+ Sequence of hidden-states at the output of the last layer of the model.
1275
+ cache_params (`HybridMambaAttentionDynamicCache`):
1276
+ The state of the model at the last time step. Can be used in a forward method with the next `input_ids` to
1277
+ avoid providing the old `input_ids`.
1278
+
1279
+ Includes both the State space model state matrices after the selective scan, and the Convolutional states
1280
+ hidden_states (`tuple(torch.FloatTensor)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`):
1281
+ Tuple of `torch.FloatTensor` (one for the output of the embeddings, if the model has an embedding layer, +
1282
+ one for the output of each layer) of shape `(batch_size, sequence_length, hidden_size)`.
1283
+
1284
+ Hidden-states of the model at the output of each layer plus the optional initial embedding outputs.
1285
+ """
1286
+
1287
+ last_hidden_state: Optional[torch.FloatTensor] = None
1288
+ cache_params: Optional[HybridMambaAttentionDynamicCache] = None
1289
+ hidden_states: Optional[Tuple[torch.FloatTensor]] = None
1290
+ attentions: Optional[Tuple[torch.FloatTensor]] = None
1291
+
1292
+
1293
+ @dataclass
1294
+ # Copied from transformers.models.mamba2.modeling_mamba2.MambaCausalLMOutput with Mamba2->NemotronH
1295
+ class NemotronHCausalLMOutput(ModelOutput):
1296
+ """
1297
+ Base class for causal language model (or autoregressive) outputs.
1298
+
1299
+ Args:
1300
+ loss (`torch.FloatTensor` of shape `(1,)`, *optional*, returned when `labels` is provided):
1301
+ Language modeling loss (for next-token prediction).
1302
+ logits (`torch.FloatTensor` of shape `(batch_size, sequence_length, config.vocab_size)`):
1303
+ Prediction scores of the language modeling head (scores for each vocabulary token before SoftMax).
1304
+ cache_params (`HybridMambaAttentionDynamicCache`):
1305
+ The state of the model at the last time step. Can be used in a forward method with the next `input_ids` to
1306
+ avoid providing the old `input_ids`.
1307
+
1308
+ Includes both the State space model state matrices after the selective scan, and the Convolutional states
1309
+ hidden_states (`tuple(torch.FloatTensor)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`):
1310
+ Tuple of `torch.FloatTensor` (one for the output of the embeddings, if the model has an embedding layer, +
1311
+ one for the output of each layer) of shape `(batch_size, sequence_length, hidden_size)`.
1312
+
1313
+ Hidden-states of the model at the output of each layer plus the optional initial embedding outputs.
1314
+ """
1315
+
1316
+ loss: Optional[torch.FloatTensor] = None
1317
+ logits: Optional[torch.FloatTensor] = None
1318
+ cache_params: Optional[HybridMambaAttentionDynamicCache] = None
1319
+ hidden_states: Optional[Tuple[torch.FloatTensor]] = None
1320
+ attentions: Optional[Tuple[torch.FloatTensor]] = None
1321
+
1322
+
1323
+ NEMOTRONH_START_DOCSTRING = r"""
1324
+
1325
+ This model inherits from [`PreTrainedModel`]. Check the superclass documentation for the generic methods the
1326
+ library implements for all its model (such as downloading or saving, resizing the input embeddings, pruning heads
1327
+ etc.)
1328
+
1329
+ This model is also a PyTorch [torch.nn.Module](https://pytorch.org/docs/stable/nn.html#torch.nn.Module) subclass.
1330
+ Use it as a regular PyTorch Module and refer to the PyTorch documentation for all matter related to general usage
1331
+ and behavior.
1332
+
1333
+ Parameters:
1334
+ config ([`NemotronHConfig`]): Model configuration class with all the parameters of the model.
1335
+ Initializing with a config file does not load the weights associated with the model, only the
1336
+ configuration. Check out the [`~PreTrainedModel.from_pretrained`] method to load the model weights.
1337
+ """
1338
+
1339
+ NEMOTRONH_INPUTS_DOCSTRING = r"""
1340
+ Args:
1341
+ input_ids (`torch.LongTensor` of shape `(batch_size, input_ids_length)`, *optional*):
1342
+ Indices of input sequence tokens in the vocabulary.
1343
+
1344
+ If `cache_params.seqlen_offset>0`, only `input_ids` that do not have their past calculated should be passed as
1345
+ `input_ids`.
1346
+
1347
+ Indices can be obtained using [`AutoTokenizer`]. See [`PreTrainedTokenizer.encode`] and
1348
+ [`PreTrainedTokenizer.__call__`] for details.
1349
+
1350
+ [What are input IDs?](../glossary#input-ids)
1351
+ inputs_embeds (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`, *optional*):
1352
+ Optionally, instead of passing `input_ids` you can choose to directly pass an embedded representation. This
1353
+ is useful if you want more control over how to convert `input_ids` indices into associated vectors than the
1354
+ model's internal embedding lookup matrix.
1355
+ position_ids (`torch.LongTensor` of shape `(batch_size,)`, *optional*):
1356
+ Indices of positions of each input sequence tokens in the position embeddings.
1357
+ cache_params (`HybridMambaAttentionDynamicCache`, *optional*):
1358
+ If passed along, the model uses the previous state in all the blocks (which will give the output for the
1359
+ `input_ids` provided as if the model add `state_input_ids + input_ids` as context).
1360
+ use_cache (`bool`, *optional*):
1361
+ If set to `True`, the `cache_params` is returned and can be used to quickly generate the next logits.
1362
+ output_attentions (`bool`, *optional*):
1363
+ Whether or not to return the attentions tensors of all attention layers.
1364
+ output_hidden_states (`bool`, *optional*):
1365
+ Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors for
1366
+ more detail.
1367
+ return_dict (`bool`, *optional*):
1368
+ Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple.
1369
+ cache_position (`torch.LongTensor` of shape `(batch_size,)`, *optional*):
1370
+ The position of the current input in the cache. This is used to ensure that the cache is correctly updated.
1371
+ If `cache_params` is passed, `cache_position` should also be passed.
1372
+ attention_mask (`torch.FloatTensor` of shape `(batch_size, sequence_length)`, *optional*):
1373
+ Mask to avoid performing attention on padding token indices. Mask values selected in `[0, 1]`:
1374
+
1375
+ - 1 for tokens that are **not masked**,
1376
+ - 0 for tokens that are **masked**.
1377
+
1378
+ [What are attention masks?](../glossary#attention-mask)
1379
+ """
1380
+
1381
+
1382
+ @add_start_docstrings(
1383
+ "The bare NemotronH Model transformer outputting raw hidden-states without any specific head on top.",
1384
+ NEMOTRONH_START_DOCSTRING,
1385
+ )
1386
+ class NemotronHModel(NemotronHPreTrainedModel):
1387
+ def __init__(self, config):
1388
+ super().__init__(config)
1389
+
1390
+ self.embeddings = nn.Embedding(config.vocab_size, config.hidden_size)
1391
+ self.layers = nn.ModuleList([NemotronHBlock(config, layer_idx=idx) for idx in range(config.num_hidden_layers)])
1392
+
1393
+ self.gradient_checkpointing = False
1394
+ self.norm_f = NemotronHRMSNorm(config.hidden_size, eps=config.layer_norm_epsilon)
1395
+ # Initialize weights and apply final processing
1396
+ self._register_load_state_dict_pre_hook(self.load_hook)
1397
+ self.post_init()
1398
+
1399
+ def load_hook(self, state_dict, prefix, *args):
1400
+ for k in state_dict:
1401
+ if "embedding." in k:
1402
+ state_dict[k.replace("embedding.", "embeddings.")] = state_dict.pop(k)
1403
+ break
1404
+
1405
+ def get_input_embeddings(self):
1406
+ return self.embeddings
1407
+
1408
+ def set_input_embeddings(self, new_embeddings):
1409
+ self.embeddings = new_embeddings
1410
+
1411
+ @add_start_docstrings_to_model_forward(NEMOTRONH_INPUTS_DOCSTRING)
1412
+ @add_code_sample_docstrings(
1413
+ checkpoint=_CHECKPOINT_FOR_DOC,
1414
+ output_type=NemotronHOutput,
1415
+ config_class=_CONFIG_FOR_DOC,
1416
+ )
1417
+ def forward(
1418
+ self,
1419
+ input_ids: Optional[torch.LongTensor] = None,
1420
+ inputs_embeds: Optional[torch.LongTensor] = None,
1421
+ position_ids: Optional[torch.LongTensor] = None,
1422
+ cache_params: Optional[HybridMambaAttentionDynamicCache] = None,
1423
+ use_cache: Optional[bool] = None,
1424
+ output_attentions: Optional[bool] = None,
1425
+ output_hidden_states: Optional[bool] = None,
1426
+ return_dict: Optional[bool] = None,
1427
+ cache_position: Optional[torch.LongTensor] = None,
1428
+ attention_mask: Optional[torch.Tensor] = None,
1429
+ **kwargs,
1430
+ ) -> Union[Tuple, NemotronHOutput]:
1431
+ output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
1432
+ output_hidden_states = (
1433
+ output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
1434
+ )
1435
+ # use_cache = use_cache if use_cache is not None else self.config.use_cache
1436
+ use_cache = use_cache if use_cache is not None else (self.config.use_cache if not self.training else False)
1437
+
1438
+ return_dict = return_dict if return_dict is not None else self.config.use_return_dict
1439
+
1440
+ if (input_ids is None) ^ (inputs_embeds is not None): # ^ is python for xor
1441
+ raise ValueError("You must specify exactly one of input_ids or inputs_embeds")
1442
+
1443
+ if inputs_embeds is None:
1444
+ inputs_embeds = self.embeddings(input_ids)
1445
+
1446
+ if self.gradient_checkpointing and self.training and use_cache:
1447
+ logger.warning_once(
1448
+ "`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`."
1449
+ )
1450
+ use_cache = False
1451
+
1452
+ # From zamba_modeling.py
1453
+ if use_cache and cache_params is None:
1454
+ logger.warning_once(
1455
+ "NemotronH requires an initialized `NemotronHHybridDynamicCache` to return a cache. None was "
1456
+ "provided, so no cache will be returned."
1457
+ )
1458
+
1459
+ hidden_states = inputs_embeds
1460
+
1461
+ if cache_position is None:
1462
+ cache_position = torch.arange(hidden_states.shape[1], device=hidden_states.device)
1463
+ if position_ids is None:
1464
+ position_ids = cache_position.unsqueeze(0)
1465
+
1466
+ causal_mask = self._update_causal_mask(attention_mask, inputs_embeds, cache_position)
1467
+ mamba_mask = self._update_mamba_mask(attention_mask, cache_position)
1468
+
1469
+ all_hidden_states = () if output_hidden_states else None
1470
+ all_self_attns = () if output_attentions else None
1471
+ # Until HERE
1472
+
1473
+ for layer_idx, mixer_block in enumerate(self.layers):
1474
+ # Depending on the layer type we opt for 2D base attention mask (Mamba) or 4D causal mask (Attention)
1475
+ if mixer_block.block_type == "mamba":
1476
+ layer_mask = mamba_mask
1477
+ elif mixer_block.block_type == "attention":
1478
+ layer_mask = causal_mask
1479
+ elif mixer_block.block_type in ["mlp", "moe"]:
1480
+ layer_mask = None
1481
+ else:
1482
+ raise ValueError(f"Invalid block_type: {self.block_type}")
1483
+
1484
+ if output_hidden_states:
1485
+ all_hidden_states += (hidden_states,)
1486
+
1487
+ if self.gradient_checkpointing and self.training:
1488
+ hidden_states = self._gradient_checkpointing_func(
1489
+ mixer_block.__call__, hidden_states, cache_params, cache_position, layer_mask
1490
+ )
1491
+ else:
1492
+ hidden_states = mixer_block(
1493
+ hidden_states,
1494
+ cache_params=cache_params,
1495
+ cache_position=cache_position,
1496
+ attention_mask=layer_mask,
1497
+ )
1498
+
1499
+ # TODO: Store attentions
1500
+ # if output_attentions:
1501
+ # if layer_outputs[1] is not None:
1502
+ # # append attentions only of attention layers. Mamba layers return `None` as the attention weights
1503
+ # all_self_attns += (layer_outputs[1],)
1504
+
1505
+ # TODO (Check): should it happen before the forward pass?
1506
+ # if output_hidden_states:
1507
+ # all_hidden_states = all_hidden_states + (hidden_states,)
1508
+
1509
+ hidden_states = self.norm_f(hidden_states)
1510
+
1511
+ if output_hidden_states:
1512
+ all_hidden_states = all_hidden_states + (hidden_states,)
1513
+
1514
+ if not return_dict:
1515
+ return tuple(v for v in [hidden_states, cache_params, all_hidden_states] if v is not None)
1516
+
1517
+ return NemotronHOutput(
1518
+ last_hidden_state=hidden_states,
1519
+ cache_params=cache_params if use_cache else None,
1520
+ hidden_states=all_hidden_states,
1521
+ attentions=all_self_attns,
1522
+ )
1523
+
1524
+ # Copied from transformers.models.jamba.modeling_jamba.JambaModel._update_causal_mask
1525
+ def _update_causal_mask(self, attention_mask, input_tensor, cache_position):
1526
+ if self.config._attn_implementation == "flash_attention_2":
1527
+ if attention_mask is not None and 0.0 in attention_mask:
1528
+ return attention_mask
1529
+ return None
1530
+
1531
+ dtype, device = input_tensor.dtype, input_tensor.device
1532
+ min_dtype = torch.finfo(dtype).min
1533
+ sequence_length = input_tensor.shape[1]
1534
+ target_length = cache_position[-1] + 1
1535
+
1536
+ causal_mask = torch.full((sequence_length, target_length), fill_value=min_dtype, dtype=dtype, device=device)
1537
+ if sequence_length != 1:
1538
+ causal_mask = torch.triu(causal_mask, diagonal=1)
1539
+ causal_mask *= torch.arange(target_length, device=device) > cache_position.reshape(-1, 1)
1540
+ causal_mask = causal_mask[None, None, :, :].expand(input_tensor.shape[0], 1, -1, -1)
1541
+ if attention_mask is not None:
1542
+ causal_mask = causal_mask.clone() # copy to contiguous memory for in-place edit
1543
+ if attention_mask.dim() == 2:
1544
+ mask_length = attention_mask.shape[-1]
1545
+ padding_mask = causal_mask[..., :mask_length].eq(0.0) * attention_mask[:, None, None, :].eq(0.0)
1546
+ causal_mask[..., :mask_length] = causal_mask[..., :mask_length].masked_fill(padding_mask, min_dtype)
1547
+
1548
+ if (
1549
+ self.config._attn_implementation == "sdpa"
1550
+ and attention_mask is not None
1551
+ and attention_mask.device.type == "cuda"
1552
+ ):
1553
+ # Attend to all tokens in fully masked rows in the causal_mask, for example the relevant first rows when
1554
+ # using left padding. This is required by F.scaled_dot_product_attention memory-efficient attention path.
1555
+ # Details: https://github.com/pytorch/pytorch/issues/110213
1556
+ causal_mask = AttentionMaskConverter._unmask_unattended(causal_mask, min_dtype)
1557
+
1558
+ return causal_mask
1559
+
1560
+ def _update_mamba_mask(self, attention_mask, cache_position):
1561
+ """
1562
+ No need for zeroing states when
1563
+ 1. Cached forward
1564
+ 2. Attending to all inputs
1565
+ """
1566
+ mamba_mask = attention_mask
1567
+ if cache_position[0] > 0 or (attention_mask is not None and torch.all(attention_mask == 1)):
1568
+ mamba_mask = None
1569
+ return mamba_mask
1570
+
1571
+
1572
+ @add_start_docstrings(
1573
+ """
1574
+ The NEMOTRONH Model transformer with a language modeling head on top (linear layer with weights not tied to the input
1575
+ embeddings).
1576
+ """,
1577
+ NEMOTRONH_START_DOCSTRING,
1578
+ )
1579
+ class NemotronHForCausalLM(NemotronHPreTrainedModel, GenerationMixin):
1580
+ _tied_weights_keys = ["lm_head.weight"]
1581
+
1582
+ def __init__(self, config):
1583
+ super().__init__(config)
1584
+ self.backbone = NemotronHModel(config)
1585
+ self.vocab_size = config.vocab_size
1586
+ self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False)
1587
+
1588
+ # Initialize weights and apply final processing
1589
+ self.post_init()
1590
+
1591
+ def get_input_embeddings(self):
1592
+ return self.backbone.get_input_embeddings()
1593
+
1594
+ def set_input_embeddings(self, new_embeddings):
1595
+ return self.backbone.set_input_embeddings(new_embeddings)
1596
+
1597
+ def get_output_embeddings(self):
1598
+ return self.lm_head
1599
+
1600
+ def set_output_embeddings(self, new_embeddings):
1601
+ self.lm_head = new_embeddings
1602
+
1603
+ def get_decoder(self):
1604
+ return self.model
1605
+
1606
+ def set_decoder(self, decoder):
1607
+ self.model = decoder
1608
+
1609
+ def prepare_inputs_for_generation(
1610
+ self,
1611
+ input_ids,
1612
+ past_key_values=None,
1613
+ attention_mask=None,
1614
+ inputs_embeds=None,
1615
+ cache_position=None,
1616
+ position_ids=None,
1617
+ use_cache=True,
1618
+ **kwargs,
1619
+ ):
1620
+ # Copy from https://github.com/huggingface/transformers/blob/main/src/transformers/models/jamba/modeling_jamba.py
1621
+ # Overwitten -- uses `cache_params` as opposed to `past_key_values`
1622
+ empty_past_kv = past_key_values is None
1623
+
1624
+ # If we have cache: let's slice `input_ids` through `cache_position`, to keep only the unprocessed tokens
1625
+ # Exception 1: when passing input_embeds, input_ids may be missing entries
1626
+ # Exception 2: some generation methods do special slicing of input_ids, so we don't need to do it here
1627
+ # Exception 3: with synced GPUs cache_position may go out of bounds, but we only want dummy token in that case.
1628
+ # (we can't check exception 3 while compiling)
1629
+ if not empty_past_kv:
1630
+ if (
1631
+ inputs_embeds is not None # Exception 1
1632
+ or cache_position[-1] >= input_ids.shape[1] # Exception 3
1633
+ ):
1634
+ input_ids = input_ids[:, -cache_position.shape[0] :]
1635
+ elif input_ids.shape[1] != cache_position.shape[0]: # Default case (the "else", a no op, is Exception 2)
1636
+ input_ids = input_ids[:, cache_position]
1637
+ else:
1638
+ past_key_values = HybridMambaAttentionDynamicCache(
1639
+ self.config, input_ids.shape[0], self.dtype, device=self.device
1640
+ )
1641
+
1642
+ if attention_mask is not None and position_ids is None:
1643
+ # create position_ids on the fly for batch generation
1644
+ position_ids = attention_mask.long().cumsum(-1) - 1
1645
+ position_ids.masked_fill_(attention_mask == 0, 1)
1646
+ if not empty_past_kv:
1647
+ position_ids = position_ids[:, -input_ids.shape[1] :]
1648
+
1649
+ # if `inputs_embeds` are passed, we only want to use them in the 1st generation step
1650
+ if inputs_embeds is not None and empty_past_kv:
1651
+ model_inputs = {"inputs_embeds": inputs_embeds}
1652
+ else:
1653
+ model_inputs = {"input_ids": input_ids.contiguous()} # `contiguous()` needed for compilation use cases
1654
+
1655
+ model_inputs.update(
1656
+ {
1657
+ "position_ids": position_ids,
1658
+ "past_key_values": past_key_values,
1659
+ "use_cache": use_cache,
1660
+ "attention_mask": attention_mask,
1661
+ "logits_to_keep": self.config.num_logits_to_keep,
1662
+ "cache_position": cache_position,
1663
+ }
1664
+ )
1665
+ return model_inputs
1666
+
1667
+ @add_start_docstrings_to_model_forward(NEMOTRONH_INPUTS_DOCSTRING)
1668
+ @add_code_sample_docstrings(
1669
+ checkpoint=_CHECKPOINT_FOR_DOC,
1670
+ output_type=NemotronHCausalLMOutput,
1671
+ config_class=_CONFIG_FOR_DOC,
1672
+ )
1673
+ def forward(
1674
+ self,
1675
+ input_ids: Optional[torch.LongTensor] = None,
1676
+ inputs_embeds: Optional[torch.FloatTensor] = None,
1677
+ position_ids: Optional[torch.LongTensor] = None,
1678
+ cache_params: Optional[HybridMambaAttentionDynamicCache] = None,
1679
+ labels: Optional[torch.LongTensor] = None,
1680
+ output_attentions: Optional[bool] = None,
1681
+ output_hidden_states: Optional[bool] = None,
1682
+ return_dict: Optional[bool] = None,
1683
+ use_cache: Optional[bool] = None,
1684
+ cache_position: Optional[torch.Tensor] = None,
1685
+ attention_mask: Optional[torch.Tensor] = None,
1686
+ **kwargs, # for now we need this for generation
1687
+ ) -> Union[Tuple, NemotronHCausalLMOutput]:
1688
+ r"""
1689
+ labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
1690
+ Labels for language modeling. Note that the labels **are shifted** inside the model, i.e. you can set
1691
+ `labels = input_ids` Indices are selected in `[-100, 0, ..., config.vocab_size]` All labels set to `-100`
1692
+ are ignored (masked), the loss is only computed for labels in `[0, ..., config.vocab_size]`
1693
+ """
1694
+ output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
1695
+
1696
+ output_hidden_states = (
1697
+ output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
1698
+ )
1699
+ return_dict = return_dict if return_dict is not None else self.config.use_return_dict
1700
+
1701
+ nemotron_h_outputs = self.backbone(
1702
+ input_ids,
1703
+ cache_params=cache_params,
1704
+ inputs_embeds=inputs_embeds,
1705
+ output_attentions=output_attentions,
1706
+ output_hidden_states=output_hidden_states,
1707
+ return_dict=return_dict,
1708
+ use_cache=use_cache,
1709
+ cache_position=cache_position,
1710
+ attention_mask=attention_mask,
1711
+ )
1712
+ hidden_states = nemotron_h_outputs[0]
1713
+
1714
+ # TODO: Check zamba_modeling.py: https://github.com/huggingface/transformers/blob/d7188ba600e36d3fd191b12e19f1b3bb81a8404f/src/transformers/models/zamba/modeling_zamba.py#L1284C1-L1286C2
1715
+ #logits = self.lm_head(hidden_states.to(self.lm_head.weight.dtype)).float()
1716
+ logits = self.lm_head(hidden_states.to(self.lm_head.weight.dtype)).float()
1717
+
1718
+ loss = None
1719
+ if labels is not None:
1720
+ # move labels to correct device to enable model parallelism
1721
+ labels = labels.to(logits.device)
1722
+ # Shift so that tokens < n predict n
1723
+ shift_logits = logits[..., :-1, :].contiguous()
1724
+ shift_labels = labels[..., 1:].contiguous()
1725
+ # Flatten the tokens
1726
+ loss_fct = CrossEntropyLoss()
1727
+ loss = loss_fct(shift_logits.view(-1, shift_logits.size(-1)), shift_labels.view(-1))
1728
+
1729
+ if not return_dict:
1730
+ output = (logits,) + nemotron_h_outputs[1:]
1731
+ return ((loss,) + output) if loss is not None else output
1732
+
1733
+ return NemotronHCausalLMOutput(
1734
+ loss=loss,
1735
+ logits=logits,
1736
+ cache_params=nemotron_h_outputs.cache_params,
1737
+ hidden_states=nemotron_h_outputs.hidden_states,
1738
+ attentions=nemotron_h_outputs.attentions,
1739
+ )
modeling_nemotron_twotower.py ADDED
@@ -0,0 +1,679 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # coding=utf-8
2
+ # Copyright (c) 2025, NVIDIA CORPORATION. All rights reserved.
3
+ #
4
+ # Two-tower NemotronH for HuggingFace — real separate context + denoiser weights.
5
+ #
6
+ # Checkpoint key layout (from converted safetensors):
7
+ # context_tower.* — context backbone (NemotronHModel)
8
+ # context_lm_head.weight — context output head
9
+ # denoiser_tower.* — denoiser backbone (NemotronHModel)
10
+ # lm_head.weight — denoiser output head
11
+ # t_embedder.* — timestep embedder (optional, for mask_diffusion)
12
+ # t_block.* — timestep MLP (optional)
13
+ # scale_shift_tables.* — per-layer modulation bias (optional)
14
+ #
15
+ # Modes:
16
+ # AR: forward() + generate() — context_tower only
17
+ # Mock-AR: generate_mock_ar() — two-tower, S-2/KV[:-1] semantics
18
+ # Mask-Diffusion: generate_mask_diffusion() — block-wise iterative denoising
19
+
20
+ import math
21
+ import torch
22
+ import torch.nn as nn
23
+ import torch.nn.functional as F
24
+ from typing import Any, Callable, Dict, List, Optional, Tuple, Union
25
+
26
+ try:
27
+ from .modeling_nemotron_h import (
28
+ HybridMambaAttentionDynamicCache,
29
+ NemotronHCausalLMOutput,
30
+ NemotronHForCausalLM,
31
+ NemotronHModel,
32
+ NemotronHPreTrainedModel,
33
+ )
34
+ from .configuration_nemotron_h import NemotronHConfig
35
+ except ImportError:
36
+ from modeling_nemotron_h import (
37
+ HybridMambaAttentionDynamicCache,
38
+ NemotronHCausalLMOutput,
39
+ NemotronHForCausalLM,
40
+ NemotronHModel,
41
+ NemotronHPreTrainedModel,
42
+ )
43
+ from configuration_nemotron_h import NemotronHConfig
44
+
45
+ from transformers.generation import GenerationMixin
46
+
47
+
48
+ # ---------------------------------------------------------------------------
49
+ # Time conditioning (PixArt-alpha adaLN-single style)
50
+ # ---------------------------------------------------------------------------
51
+
52
+ class TimestepEmbedder(nn.Module):
53
+ """Sinusoidal + MLP embedder for scalar timesteps in [0,1]."""
54
+
55
+ def __init__(self, hidden_size: int, frequency_embedding_size: int = 256,
56
+ max_period: int = 1000):
57
+ super().__init__()
58
+ self.frequency_embedding_size = frequency_embedding_size
59
+ self.max_period = max_period
60
+ self.mlp = nn.Sequential(
61
+ nn.Linear(frequency_embedding_size, hidden_size, bias=True),
62
+ nn.SiLU(),
63
+ nn.Linear(hidden_size, hidden_size, bias=True),
64
+ )
65
+
66
+ @staticmethod
67
+ def timestep_embedding(t, dim, max_period=10000):
68
+ half = dim // 2
69
+ freqs = torch.exp(
70
+ -math.log(max_period) * torch.arange(half, device=t.device, dtype=torch.float32) / half
71
+ )
72
+ args = t[:, None].float() * freqs[None]
73
+ embedding = torch.cat([torch.cos(args), torch.sin(args)], dim=-1)
74
+ if dim % 2:
75
+ embedding = torch.cat([embedding, torch.zeros_like(embedding[:, :1])], dim=-1)
76
+ return embedding.to(t.dtype)
77
+
78
+ def forward(self, t):
79
+ t_scaled = t * self.max_period
80
+ t_freq = self.timestep_embedding(t_scaled, self.frequency_embedding_size)
81
+ return self.mlp(t_freq)
82
+
83
+
84
+ def _modulate(x, shift, scale):
85
+ """Adaptive LN: x * (1 + scale) + shift. Broadcasts for (B,L,D) input."""
86
+ return x * (1.0 + scale.unsqueeze(1)) + shift.unsqueeze(1)
87
+
88
+
89
+ def _get_mod_params(t_emb, table):
90
+ """(B, 3*D) + (3, D) -> (shift, scale, gate) each (B, D)."""
91
+ B, D = t_emb.shape[0], table.shape[1]
92
+ combined = table[None] + t_emb.reshape(B, 3, D)
93
+ shift, scale, gate = combined.chunk(3, dim=1)
94
+ return shift.squeeze(1), scale.squeeze(1), gate.squeeze(1)
95
+
96
+
97
+ # ---------------------------------------------------------------------------
98
+ # Bug-fixed cache
99
+ # ---------------------------------------------------------------------------
100
+
101
+ class FixedHybridCache(HybridMambaAttentionDynamicCache):
102
+ def __init__(self, config, batch_size, dtype=torch.float16, device=None):
103
+ super().__init__(config, batch_size, dtype, device)
104
+ self.conv_kernel_size = config.conv_kernel
105
+
106
+ def update_conv_state(self, layer_idx, new_conv_state, cache_init=False):
107
+ if cache_init:
108
+ self.conv_states[layer_idx] = new_conv_state.to(self.conv_states[layer_idx].device)
109
+ else:
110
+ self.conv_states[layer_idx] = self.conv_states[layer_idx].roll(shifts=-1, dims=-1)
111
+ self.conv_states[layer_idx][:, :, -1] = new_conv_state[:, 0, :].to(
112
+ self.conv_states[layer_idx].device
113
+ )
114
+ return self.conv_states[layer_idx]
115
+
116
+ def update_ssm_state(self, layer_idx, new_ssm_state):
117
+ self.ssm_states[layer_idx] = new_ssm_state.to(self.ssm_states[layer_idx].device)
118
+ return self.ssm_states[layer_idx]
119
+
120
+
121
+ # ---------------------------------------------------------------------------
122
+ # Two-Tower CausalLM
123
+ # ---------------------------------------------------------------------------
124
+
125
+ class NemotronHTwoTowerForCausalLM(NemotronHPreTrainedModel, GenerationMixin):
126
+ """Two-tower NemotronH with real separate context and denoiser weights.
127
+
128
+ Modes:
129
+ AR: forward() + generate() — context_tower only
130
+ Mock-AR: generate_mock_ar() — S-2/KV[:-1] semantics
131
+ Mask-Diffusion: generate_mask_diffusion() — block-wise confidence_unmasking
132
+ """
133
+
134
+ _tied_weights_keys = []
135
+
136
+ def __init__(self, config: NemotronHConfig):
137
+ super().__init__(config)
138
+ self.context_tower = NemotronHModel(config)
139
+ self.context_lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False)
140
+ self.denoiser_tower = NemotronHModel(config)
141
+ self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False)
142
+ self.vocab_size = config.vocab_size
143
+
144
+ # Time conditioning (created unconditionally; weights loaded if present)
145
+ H = config.hidden_size
146
+ N = config.num_hidden_layers
147
+ self.t_embedder = TimestepEmbedder(H)
148
+ self.t_block = nn.Sequential(nn.SiLU(), nn.Linear(H, 3 * H, bias=True))
149
+ self.scale_shift_tables = nn.ParameterList([
150
+ nn.Parameter(torch.randn(3, H) / (H ** 0.5)) for _ in range(N)
151
+ ])
152
+
153
+ self.post_init()
154
+
155
+ # ------------------------------------------------------------------
156
+ # HF interface
157
+ # ------------------------------------------------------------------
158
+
159
+ def get_input_embeddings(self):
160
+ return self.context_tower.get_input_embeddings()
161
+
162
+ def set_input_embeddings(self, new_embeddings):
163
+ return self.context_tower.set_input_embeddings(new_embeddings)
164
+
165
+ def get_output_embeddings(self):
166
+ return self.context_lm_head
167
+
168
+ def set_output_embeddings(self, new_embeddings):
169
+ self.context_lm_head = new_embeddings
170
+
171
+ def prepare_inputs_for_generation(
172
+ self, input_ids, past_key_values=None, attention_mask=None,
173
+ inputs_embeds=None, cache_position=None, position_ids=None,
174
+ use_cache=True, **kwargs,
175
+ ):
176
+ empty_past_kv = past_key_values is None
177
+ if not empty_past_kv:
178
+ if inputs_embeds is not None or cache_position[-1] >= input_ids.shape[1]:
179
+ input_ids = input_ids[:, -cache_position.shape[0]:]
180
+ elif input_ids.shape[1] != cache_position.shape[0]:
181
+ input_ids = input_ids[:, cache_position]
182
+ else:
183
+ past_key_values = HybridMambaAttentionDynamicCache(
184
+ self.config, input_ids.shape[0], self.dtype, device=self.device
185
+ )
186
+ if attention_mask is not None and position_ids is None:
187
+ position_ids = attention_mask.long().cumsum(-1) - 1
188
+ position_ids.masked_fill_(attention_mask == 0, 1)
189
+ if not empty_past_kv:
190
+ position_ids = position_ids[:, -input_ids.shape[1]:]
191
+ if inputs_embeds is not None and empty_past_kv:
192
+ model_inputs = {"inputs_embeds": inputs_embeds}
193
+ else:
194
+ model_inputs = {"input_ids": input_ids.contiguous()}
195
+ model_inputs.update({
196
+ "position_ids": position_ids, "past_key_values": past_key_values,
197
+ "use_cache": use_cache, "attention_mask": attention_mask,
198
+ "logits_to_keep": self.config.num_logits_to_keep,
199
+ "cache_position": cache_position,
200
+ })
201
+ return model_inputs
202
+
203
+ # ------------------------------------------------------------------
204
+ # Forward (context tower only, for HF generate)
205
+ # ------------------------------------------------------------------
206
+
207
+ def forward(
208
+ self, input_ids=None, inputs_embeds=None, position_ids=None,
209
+ cache_params=None, labels=None, output_attentions=None,
210
+ output_hidden_states=None, return_dict=None, use_cache=None,
211
+ cache_position=None, attention_mask=None, **kwargs,
212
+ ) -> Union[Tuple, NemotronHCausalLMOutput]:
213
+ past_key_values = kwargs.pop("past_key_values", None)
214
+ if past_key_values is not None and cache_params is None:
215
+ cache_params = past_key_values
216
+ output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
217
+ output_hidden_states = output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
218
+ return_dict = return_dict if return_dict is not None else self.config.use_return_dict
219
+
220
+ outputs = self.context_tower(
221
+ input_ids, cache_params=cache_params, inputs_embeds=inputs_embeds,
222
+ output_attentions=output_attentions, output_hidden_states=output_hidden_states,
223
+ return_dict=return_dict, use_cache=use_cache,
224
+ cache_position=cache_position, attention_mask=attention_mask,
225
+ )
226
+ hidden_states = outputs[0]
227
+ logits = self.context_lm_head(hidden_states.to(self.context_lm_head.weight.dtype)).float()
228
+
229
+ loss = None
230
+ if labels is not None:
231
+ labels = labels.to(logits.device)
232
+ shift_logits = logits[..., :-1, :].contiguous()
233
+ shift_labels = labels[..., 1:].contiguous()
234
+ loss = nn.CrossEntropyLoss()(
235
+ shift_logits.view(-1, shift_logits.size(-1)), shift_labels.view(-1)
236
+ )
237
+ if not return_dict:
238
+ output = (logits,) + outputs[1:]
239
+ return ((loss,) + output) if loss is not None else output
240
+ return NemotronHCausalLMOutput(
241
+ loss=loss, logits=logits, cache_params=outputs.cache_params,
242
+ hidden_states=outputs.hidden_states, attentions=outputs.attentions,
243
+ )
244
+
245
+ # ------------------------------------------------------------------
246
+ # Layer-by-layer forward with cache + optional time conditioning
247
+ # ------------------------------------------------------------------
248
+
249
+ def _forward_tower_with_cache(self, tower, lm_head, input_ids, cache,
250
+ cache_position, t_emb=None):
251
+ """Forward through tower with KV cache. If t_emb is provided, applies
252
+ PixArt-style adaLN modulation (shift/scale after norm, gate on output)."""
253
+ hidden = tower.embeddings(input_ids)
254
+ causal_mask = tower._update_causal_mask(None, hidden, cache_position)
255
+
256
+ for layer_idx, block in enumerate(tower.layers):
257
+ residual = hidden
258
+ hidden = block.norm(hidden.to(dtype=block.norm.weight.dtype))
259
+ if block.residual_in_fp32:
260
+ residual = residual.to(torch.float32)
261
+
262
+ mod = None
263
+ if t_emb is not None:
264
+ mod = _get_mod_params(t_emb, self.scale_shift_tables[layer_idx])
265
+ shift, scale, gate = mod
266
+ hidden = _modulate(hidden, shift, scale)
267
+
268
+ if block.block_type == "mamba":
269
+ hidden = block.mixer(
270
+ hidden, cache_params=cache, cache_position=cache_position,
271
+ )
272
+ elif block.block_type == "attention":
273
+ hidden, _, _ = block.mixer(
274
+ hidden, attention_mask=causal_mask,
275
+ past_key_value=cache, cache_position=cache_position,
276
+ )
277
+ elif block.block_type in ["mlp", "moe"]:
278
+ hidden = block.mixer(hidden)
279
+ else:
280
+ raise ValueError(f"Unknown block_type: {block.block_type}")
281
+
282
+ if mod is not None:
283
+ hidden = gate.unsqueeze(1) * hidden
284
+
285
+ hidden = residual + hidden
286
+
287
+ hidden = tower.norm_f(hidden)
288
+ logits = lm_head(hidden.to(lm_head.weight.dtype)).float()
289
+ return logits
290
+
291
+ # ------------------------------------------------------------------
292
+ # Cache management
293
+ # ------------------------------------------------------------------
294
+
295
+ def _make_cache(self, config, batch_size, dtype, device):
296
+ return FixedHybridCache(config, batch_size, dtype, device)
297
+
298
+ def _build_context_cache(self, prompt_ids):
299
+ """Two-pass context prefill: S-2 and S-1 Mamba states + full KV."""
300
+ B, S = prompt_ids.shape
301
+ device = prompt_ids.device
302
+ tower = self.context_tower
303
+ pattern = self.config.hybrid_override_pattern
304
+
305
+ cache_p1 = self._make_cache(self.config, B, self.dtype, device)
306
+ cp_p1 = torch.arange(S - 1, device=device)
307
+ self._forward_tower_with_cache(tower, self.context_lm_head,
308
+ prompt_ids[:, :-1], cache_p1, cp_p1)
309
+
310
+ mamba_s2 = {}
311
+ for i in range(self.config.num_hidden_layers):
312
+ if pattern[i] == "M":
313
+ mamba_s2[i] = (cache_p1.conv_states[i].clone(),
314
+ cache_p1.ssm_states[i].clone())
315
+
316
+ cache_p2 = self._make_cache(self.config, B, self.dtype, device)
317
+ for i in range(self.config.num_hidden_layers):
318
+ if pattern[i] == "M":
319
+ cache_p2.conv_states[i] = cache_p1.conv_states[i].clone()
320
+ cache_p2.ssm_states[i] = cache_p1.ssm_states[i].clone()
321
+ elif pattern[i] == "*":
322
+ cache_p2.key_cache[i] = cache_p1.key_cache[i].clone()
323
+ cache_p2.value_cache[i] = cache_p1.value_cache[i].clone()
324
+
325
+ cache_p2.has_previous_state = True
326
+ cp_p2 = torch.arange(S - 1, S, device=device)
327
+ self._forward_tower_with_cache(tower, self.context_lm_head,
328
+ prompt_ids[:, -1:], cache_p2, cp_p2)
329
+
330
+ return {"ctx_cache": cache_p2, "mamba_s2": mamba_s2, "ctx_len": S}
331
+
332
+ def _extend_context_cache(self, new_tokens, cache_state):
333
+ """Extend context cache by new_tokens (B, L). Old S-1 -> new S-2.
334
+
335
+ Processes tokens one at a time so HF Mamba can use its single-step
336
+ cached path (seq_len=1, cache_position[0] > 0).
337
+ """
338
+ ctx_cache = cache_state["ctx_cache"]
339
+ pattern = self.config.hybrid_override_pattern
340
+ ctx_len = cache_state["ctx_len"]
341
+ ctx_device = next(self.context_tower.parameters()).device
342
+ L = new_tokens.shape[1]
343
+ tokens_on_device = new_tokens.to(ctx_device)
344
+
345
+ new_s2 = {}
346
+ for i in range(self.config.num_hidden_layers):
347
+ if pattern[i] == "M":
348
+ new_s2[i] = (ctx_cache.conv_states[i].clone(),
349
+ ctx_cache.ssm_states[i].clone())
350
+ cache_state["mamba_s2"] = new_s2
351
+
352
+ ctx_cache.has_previous_state = True
353
+ for j in range(L):
354
+ cp = torch.tensor([ctx_len + j], device=ctx_device)
355
+ self._forward_tower_with_cache(
356
+ self.context_tower, self.context_lm_head,
357
+ tokens_on_device[:, j:j+1], ctx_cache, cp,
358
+ )
359
+
360
+ cache_state["ctx_len"] = ctx_len + L
361
+ return cache_state
362
+
363
+ def _build_denoiser_cache_mock_ar(self, cache_state, device):
364
+ """Mock-AR denoiser cache: Mamba S-2, Attention KV[:-1]."""
365
+ ctx_cache = cache_state["ctx_cache"]
366
+ mamba_s2 = cache_state["mamba_s2"]
367
+ pattern = self.config.hybrid_override_pattern
368
+ B = ctx_cache.conv_states[0].shape[0] if pattern[0] == "M" else ctx_cache.key_cache[0].shape[0]
369
+
370
+ den = self._make_cache(self.config, B, self.dtype, device)
371
+ for i in range(self.config.num_hidden_layers):
372
+ if pattern[i] == "M":
373
+ conv_s2, ssm_s2 = mamba_s2[i]
374
+ den.conv_states[i] = conv_s2.to(device).clone()
375
+ den.ssm_states[i] = ssm_s2.to(device).clone()
376
+ elif pattern[i] == "*":
377
+ k, v = ctx_cache.key_cache[i], ctx_cache.value_cache[i]
378
+ if k.dim() == 4 and k.shape[2] > 0:
379
+ den.key_cache[i] = k[:, :, :-1, :].to(device).clone()
380
+ den.value_cache[i] = v[:, :, :-1, :].to(device).clone()
381
+ den.has_previous_state = True
382
+ return den
383
+
384
+ def _build_denoiser_cache_diffusion(self, cache_state, device):
385
+ """Diffusion denoiser cache: Mamba S-1 (latest), full Attention KV."""
386
+ ctx_cache = cache_state["ctx_cache"]
387
+ pattern = self.config.hybrid_override_pattern
388
+ B = ctx_cache.conv_states[0].shape[0] if pattern[0] == "M" else ctx_cache.key_cache[0].shape[0]
389
+
390
+ den = self._make_cache(self.config, B, self.dtype, device)
391
+ for i in range(self.config.num_hidden_layers):
392
+ if pattern[i] == "M":
393
+ den.conv_states[i] = ctx_cache.conv_states[i].to(device).clone()
394
+ den.ssm_states[i] = ctx_cache.ssm_states[i].to(device).clone()
395
+ elif pattern[i] == "*":
396
+ k, v = ctx_cache.key_cache[i], ctx_cache.value_cache[i]
397
+ if k.dim() == 4 and k.shape[2] > 0:
398
+ den.key_cache[i] = k.to(device).clone()
399
+ den.value_cache[i] = v.to(device).clone()
400
+ den.has_previous_state = True
401
+ return den
402
+
403
+ # ------------------------------------------------------------------
404
+ # Denoiser step (shared by mock-AR and diffusion)
405
+ # ------------------------------------------------------------------
406
+
407
+ def _run_denoiser_step_mock_ar(self, input_ids, cache_state):
408
+ """Mock-AR denoiser: pos=ctx_len-1, KV[:-1], Mamba S-2."""
409
+ ctx_len = cache_state["ctx_len"]
410
+ den_device = next(self.denoiser_tower.parameters()).device
411
+ den_input = input_ids.to(den_device)
412
+ den_cache = self._build_denoiser_cache_mock_ar(cache_state, den_device)
413
+ cp = torch.tensor([ctx_len - 1], device=den_device)
414
+ return self._forward_tower_with_cache(
415
+ self.denoiser_tower, self.lm_head, den_input, den_cache, cp,
416
+ )
417
+
418
+ def _run_denoiser_step_diffusion(self, block_ids, cache_state, t=None):
419
+ """Diffusion denoiser: pos=ctx_len..ctx_len+L-1, full KV, Mamba S-1.
420
+
421
+ Processes the block token-by-token so the HF Mamba mixer can use its
422
+ single-step cached path (seq_len=1 with cache_position[0] > 0).
423
+ This is mathematically equivalent to full-block processing since all
424
+ layers are causal, and it properly propagates Mamba states from context.
425
+
426
+ Args:
427
+ block_ids: (B, L) tokens to denoise
428
+ cache_state: context cache state
429
+ t: (B,) timestep in [0,1], or None
430
+
431
+ Returns: logits (B, L, V)
432
+ """
433
+ ctx_len = cache_state["ctx_len"]
434
+ den_device = next(self.denoiser_tower.parameters()).device
435
+ den_input = block_ids.to(den_device)
436
+ L = den_input.shape[1]
437
+
438
+ t_emb = None
439
+ if t is not None:
440
+ t_dev = t.to(device=den_device, dtype=self.dtype)
441
+ t_repr = self.t_embedder(t_dev)
442
+ t_emb = self.t_block(t_repr)
443
+
444
+ den_cache = self._build_denoiser_cache_diffusion(cache_state, den_device)
445
+
446
+ all_logits = []
447
+ for i in range(L):
448
+ cp = torch.tensor([ctx_len + i], device=den_device)
449
+ logits_i = self._forward_tower_with_cache(
450
+ self.denoiser_tower, self.lm_head, den_input[:, i:i+1],
451
+ den_cache, cp, t_emb=t_emb,
452
+ )
453
+ all_logits.append(logits_i)
454
+
455
+ return torch.cat(all_logits, dim=1)
456
+
457
+ # ------------------------------------------------------------------
458
+ # Mock-AR generation (unchanged)
459
+ # ------------------------------------------------------------------
460
+
461
+ @torch.no_grad()
462
+ def generate_mock_ar(self, input_ids, max_new_tokens=128, temperature=0.0,
463
+ top_k=None, top_p=None, eos_token_id=None):
464
+ """Two-tower mock-AR: S-2/KV[:-1] cache, 1 token/step."""
465
+ B = input_ids.shape[0]
466
+ generated: List[torch.Tensor] = []
467
+ cache_state = self._build_context_cache(input_ids)
468
+
469
+ for step in range(max_new_tokens):
470
+ last_token = input_ids[:, -1:] if step == 0 else generated[-1]
471
+ logits = self._run_denoiser_step_mock_ar(last_token, cache_state)
472
+ logits = logits[:, -1, :].float()
473
+ tok = self._sample_token(logits, temperature, top_k, top_p)
474
+ generated.append(tok)
475
+ if eos_token_id is not None and (tok == eos_token_id).any():
476
+ break
477
+ cache_state = self._extend_context_cache(tok, cache_state)
478
+
479
+ return torch.cat([input_ids] + [g.to(input_ids.device) for g in generated], dim=1)
480
+
481
+ # ------------------------------------------------------------------
482
+ # Mask-Diffusion generation
483
+ # ------------------------------------------------------------------
484
+
485
+ @staticmethod
486
+ def _mdlm_forward(logits, xt, mask_token_id):
487
+ """Constrain logits -> p(x0|xt): mask token gets -inf, decoded tokens
488
+ get delta on their current value."""
489
+ logits = logits.clone()
490
+ logits[..., mask_token_id] = -1e12
491
+ log_probs = logits - torch.logsumexp(logits, dim=-1, keepdim=True)
492
+ # Fix unmasked positions: they must predict themselves with prob 1
493
+ unmasked = (xt != mask_token_id)
494
+ if unmasked.any():
495
+ log_probs[unmasked] = -1e12
496
+ log_probs[unmasked, :].scatter_(-1, xt[unmasked].unsqueeze(-1), 0.0)
497
+ return log_probs
498
+
499
+ @staticmethod
500
+ def _gumbel_sample(log_probs):
501
+ """Gumbel-max sampling from log probabilities."""
502
+ gumbel_noise = -torch.log(-torch.log(
503
+ torch.rand_like(log_probs).clamp(min=1e-10)
504
+ ))
505
+ return (log_probs + gumbel_noise).argmax(dim=-1)
506
+
507
+ @torch.no_grad()
508
+ def generate_mask_diffusion(
509
+ self,
510
+ input_ids,
511
+ max_new_tokens=128,
512
+ block_size=16,
513
+ steps_per_block=16,
514
+ mask_token_id=3,
515
+ temperature=0.0,
516
+ top_k=None,
517
+ confidence_threshold=0.9,
518
+ eos_token_id=None,
519
+ ):
520
+ """Block-wise mask diffusion with confidence_unmasking.
521
+
522
+ Algorithm:
523
+ 1. Build context cache from prompt
524
+ 2. For each block:
525
+ a. Init block_ids = all mask tokens
526
+ b. For each denoising step:
527
+ - Compute t_model = fraction of masked positions
528
+ - Denoiser forward -> logits -> p(x0|xt) via _mdlm_forward
529
+ - Predict tokens (greedy or gumbel)
530
+ - Confidence = p(predicted|xt) from unscaled probs
531
+ - Commit high-confidence predictions, remask low-confidence
532
+ c. Extend context cache with final block
533
+ 3. Return full sequence
534
+
535
+ Args:
536
+ input_ids: (B, S) prompt
537
+ max_new_tokens: total tokens to generate (must be divisible by block_size)
538
+ block_size: tokens per diffusion block
539
+ steps_per_block: denoising iterations per block
540
+ mask_token_id: ID of the [MASK] token
541
+ temperature: 0 = greedy argmax, >0 = gumbel sampling
542
+ top_k: unused currently (kept for API compat)
543
+ confidence_threshold: commit tokens above this confidence
544
+ eos_token_id: stop on EOS
545
+
546
+ Returns: (B, S + generated) full token sequence
547
+ """
548
+ B = input_ids.shape[0]
549
+ device = input_ids.device
550
+ assert max_new_tokens % block_size == 0, \
551
+ f"max_new_tokens ({max_new_tokens}) must be divisible by block_size ({block_size})"
552
+ num_blocks = max_new_tokens // block_size
553
+
554
+ cache_state = self._build_context_cache(input_ids)
555
+ context_ids = input_ids.clone()
556
+
557
+ for block_idx in range(num_blocks):
558
+ # Initialize fully masked block
559
+ xt = torch.full((B, block_size), mask_token_id, dtype=torch.long,
560
+ device=device)
561
+
562
+ for step_idx in range(steps_per_block):
563
+ # t_model = current mask fraction
564
+ is_masked = (xt == mask_token_id)
565
+ n_masked = is_masked.float().sum(-1).mean().item()
566
+ if n_masked == 0:
567
+ break
568
+ t_model = is_masked.float().mean()
569
+ t_vec = t_model.expand(B).to(device)
570
+
571
+ # Denoiser forward (logits come back on denoiser device, move to xt's device)
572
+ logits = self._run_denoiser_step_diffusion(xt, cache_state, t=t_vec)
573
+ logits = logits.to(device)
574
+
575
+ # p(x0|xt) with constraints
576
+ log_x_theta = self._mdlm_forward(logits, xt, mask_token_id)
577
+ x_theta = log_x_theta.exp()
578
+
579
+ # Predict: greedy or gumbel
580
+ if temperature <= 0:
581
+ predicted = log_x_theta.argmax(dim=-1)
582
+ else:
583
+ scaled_logits = logits.clone()
584
+ scaled_logits[..., mask_token_id] = -1e12
585
+ scaled_log = scaled_logits / temperature - torch.logsumexp(
586
+ scaled_logits / temperature, dim=-1, keepdim=True)
587
+ unmasked = (xt != mask_token_id)
588
+ if unmasked.any():
589
+ scaled_log[unmasked] = -1e12
590
+ scaled_log[unmasked, :].scatter_(-1, xt[unmasked].unsqueeze(-1), 0.0)
591
+ predicted = self._gumbel_sample(scaled_log)
592
+
593
+ # Confidence from unscaled x_theta
594
+ confidence = x_theta.gather(-1, predicted.unsqueeze(-1)).squeeze(-1)
595
+ confidence[~is_masked] = float('inf')
596
+
597
+ # Determine how many to commit
598
+ is_last_step = (step_idx == steps_per_block - 1)
599
+ n_masked_int = is_masked.sum(-1) # (B,)
600
+
601
+ if is_last_step:
602
+ tokens_to_commit = n_masked_int
603
+ else:
604
+ # Per-batch commitment logic (simplified for B=1 common case)
605
+ remaining_steps = max(1, steps_per_block - step_idx)
606
+ num_above = ((confidence > confidence_threshold) & is_masked).sum(-1)
607
+ tokens_to_commit = torch.where(
608
+ num_above > 0, num_above,
609
+ torch.ones_like(num_above),
610
+ )
611
+ min_commit = (n_masked_int.float() / remaining_steps).ceil().long()
612
+ tokens_to_commit = torch.clamp(
613
+ torch.max(tokens_to_commit, min_commit),
614
+ max=n_masked_int,
615
+ )
616
+
617
+ # Apply predictions then remask low-confidence
618
+ output = torch.where(is_masked, predicted, xt)
619
+ num_to_remask = n_masked_int - tokens_to_commit # (B,)
620
+
621
+ for b in range(B):
622
+ if num_to_remask[b] > 0:
623
+ masked_indices = is_masked[b].nonzero(as_tuple=True)[0]
624
+ masked_conf = confidence[b, masked_indices]
625
+ _, sort_idx = masked_conf.sort()
626
+ remask_idx = masked_indices[sort_idx[:num_to_remask[b]]]
627
+ output[b, remask_idx] = mask_token_id
628
+
629
+ xt = output
630
+
631
+ # Block complete — extend context
632
+ context_ids = torch.cat([context_ids, xt], dim=1)
633
+ cache_state = self._extend_context_cache(xt, cache_state)
634
+
635
+ if eos_token_id is not None and (xt == eos_token_id).any():
636
+ break
637
+
638
+ return context_ids
639
+
640
+ # ------------------------------------------------------------------
641
+ # Sampling helper
642
+ # ------------------------------------------------------------------
643
+
644
+ @staticmethod
645
+ def _sample_token(logits, temperature, top_k, top_p):
646
+ if temperature is None or temperature <= 0:
647
+ return logits.argmax(dim=-1, keepdim=True)
648
+ probs = F.softmax(logits / temperature, dim=-1)
649
+ if top_k is not None and top_k > 0:
650
+ kth = torch.topk(probs, min(top_k, probs.size(-1)), dim=-1).values[..., -1:]
651
+ probs = torch.where(probs >= kth, probs, torch.zeros_like(probs))
652
+ probs = probs / probs.sum(dim=-1, keepdim=True).clamp(min=1e-12)
653
+ if top_p is not None and 0.0 < top_p < 1.0:
654
+ sorted_p, idx = torch.sort(probs, descending=True, dim=-1)
655
+ cum = sorted_p.cumsum(dim=-1)
656
+ remove = torch.cat(
657
+ [torch.zeros_like(cum[..., :1]), (cum > top_p)[..., :-1]], dim=-1,
658
+ )
659
+ sorted_p = sorted_p.masked_fill(remove.bool(), 0.0)
660
+ probs = torch.zeros_like(probs).scatter_(-1, idx, sorted_p)
661
+ probs = probs / probs.sum(dim=-1, keepdim=True).clamp(min=1e-12)
662
+ return torch.multinomial(probs, num_samples=1)
663
+
664
+ # ------------------------------------------------------------------
665
+ # Multi-GPU placement
666
+ # ------------------------------------------------------------------
667
+
668
+ def place_towers_on_devices(self, ctx_device="cuda:0", den_device="cuda:1"):
669
+ """Manual tower placement. Time conditioning goes with denoiser."""
670
+ self.context_tower = self.context_tower.to(ctx_device)
671
+ self.context_lm_head = self.context_lm_head.to(ctx_device)
672
+ self.denoiser_tower = self.denoiser_tower.to(den_device)
673
+ self.lm_head = self.lm_head.to(den_device)
674
+ self.t_embedder = self.t_embedder.to(den_device)
675
+ self.t_block = self.t_block.to(den_device)
676
+ self.scale_shift_tables = nn.ParameterList([
677
+ nn.Parameter(p.to(den_device)) for p in self.scale_shift_tables
678
+ ])
679
+ return self
special_tokens_map.json ADDED
@@ -0,0 +1,23 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "bos_token": {
3
+ "content": "<s>",
4
+ "lstrip": false,
5
+ "normalized": false,
6
+ "rstrip": false,
7
+ "single_word": false
8
+ },
9
+ "eos_token": {
10
+ "content": "<|im_end|>",
11
+ "lstrip": false,
12
+ "normalized": false,
13
+ "rstrip": false,
14
+ "single_word": false
15
+ },
16
+ "unk_token": {
17
+ "content": "<unk>",
18
+ "lstrip": false,
19
+ "normalized": false,
20
+ "rstrip": false,
21
+ "single_word": false
22
+ }
23
+ }
tokenizer.json ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:623c34567aebb18582765289fbe23d901c62704d6518d71866e0e58db892b5b7
3
+ size 17077484
tokenizer_config.json ADDED
The diff for this file is too large to render. See raw diff