fitsumreda committed · Commit a203471 (verified) · 1 Parent(s): 8d7e74f

Two-tower mask diffusion: fix denoiser (adaLN norm order, bidirectional in-block attention, block-wise chunk-scan Mamba) + fp64 router; refresh README

Files changed (5)
  1. README.md +41 -21
  2. config.json +1 -1
  3. inference.py +76 -5
  4. modeling_nemotron_h.py +4 -2
  5. modeling_nemotron_twotower.py +242 -29
README.md CHANGED
@@ -77,19 +77,31 @@ Both towers share the same architecture (52 layers, `MEMEM*EMEMEM*...` hybrid pa
77
 
78
  ### Two-Tower Generation Modes
79
 
80
- | Mode | Description | Tokens/step |
81
- |------|-------------|-------------|
82
- | **AR** | Standard autoregressive via `generate()`. Uses context tower only. | 1 |
83
- | **Mock-AR** | Two-tower autoregressive. Context tower builds cache, denoiser predicts next token. | 1 |
84
- | **Mask Diffusion** | Block-wise iterative denoising with confidence-based unmasking. *(Coming soon)* | block_size |
85
 
86
  ### What is Two-Tower?
87
 
88
- The two-tower architecture decouples the "understanding context" and "generating tokens" responsibilities into separate networks. This enables:
89
 
90
- 1. **Block-wise parallel generation** the denoiser can generate multiple tokens simultaneously via iterative diffusion
91
- 2. **Architectural flexibility** context and denoiser can be optimized independently
92
- 3. **Speculative decoding** — the denoiser can be a smaller/faster model
93
 
94
  This model is ready for commercial use.
95
 
@@ -142,7 +154,7 @@ Software used for training: [Megatron-LM](https://github.com/NVIDIA/Megatron-LM)
142
 
143
  ### Use it with Transformers
144
 
145
- The snippet below shows how to use this model with HuggingFace Transformers. **Two-tower inference requires 2 GPUs** (~59GB per GPU for bf16 weights).
146
 
147
  ```python
148
  import torch
@@ -156,39 +168,47 @@ model = AutoModelForCausalLM.from_pretrained(
156
  trust_remote_code=True,
157
  )
158
 
159
- # Place context tower on GPU 0, denoiser tower on GPU 1
160
  model.place_towers_on_devices("cuda:0", "cuda:1")
161
  model.eval()
162
 
163
  prompt = "France is a country "
164
  inputs = tokenizer(prompt, return_tensors="pt").to("cuda:0")
165
 
166
- # Two-tower mock-AR generation
167
- outputs = model.generate_mock_ar(
168
  inputs["input_ids"],
169
  max_new_tokens=128,
170
- temperature=0.0,
171
  eos_token_id=tokenizer.eos_token_id,
172
  )
 
 
173
 
174
  print(tokenizer.decode(outputs[0], skip_special_tokens=True))
175
  ```
176
 
177
- For **AR-only mode** (single GPU, context tower only):
178
 
179
  ```python
180
- model = AutoModelForCausalLM.from_pretrained(
181
- model_name,
182
- torch_dtype=torch.bfloat16,
183
- trust_remote_code=True,
184
- ).cuda()
185
-
186
  outputs = model.generate(**inputs, max_new_tokens=128, do_sample=False)
187
  print(tokenizer.decode(outputs[0], skip_special_tokens=True))
188
  ```
189
 
190
  ## Model Version(s)
191
 
 
192
  - v1.0 — Two-tower AR (mock-AR) checkpoint
193
 
194
  # Training, Testing, and Evaluation Datasets
 
77
 
78
  ### Two-Tower Generation Modes
79
 
80
+ | Mode | Description | Tokens/step | API |
81
+ |------|-------------|-------------|-----|
82
+ | **Mask Diffusion** | Block-wise iterative denoising with confidence-based unmasking (flagship two-tower mode). | up to `block_size` | `generate_mask_diffusion()` |
83
+ | **Mock-AR** | Two-tower autoregressive. Context tower builds cache, denoiser predicts next token. | 1 | `generate_mock_ar()` |
84
+ | **AR** | Standard autoregressive via `generate()`. Uses context tower only (single GPU). | 1 | `generate()` |
85
 
86
  ### What is Two-Tower?
87
 
88
+ The two-tower architecture decouples "understanding context" from "generating tokens" into separate networks:
89
 
90
+ - **Context Tower** runs causally over the prompt and all previously committed tokens, producing the layer-aligned KV cache (attention) and Mamba states that the denoiser conditions on.
91
+ - **Denoiser Tower** generates a *block* of tokens at once. Within a block it is **bidirectional** (every position attends to the whole noisy block + the full causal context); across blocks it is causal via the context cache.
92
+
93
+ This enables **block-wise parallel generation** — the denoiser fills `block_size` masked positions per block and commits the most confident ones each step, so a block resolves in a handful of denoising steps rather than `block_size` autoregressive steps.
94
+
95
+ ### Mask Diffusion: how it works
96
+
97
+ Generation proceeds block by block. For each new block of `block_size` positions:
98
+
99
+ 1. Initialize the block as all mask tokens (`mask_token_id`).
100
+ 2. For `steps_per_block` iterations:
101
+ - Compute the diffusion timestep `t` = current masked fraction of the block, and feed it to the **time-conditioned denoiser** (PixArt-α adaLN-single modulation on every denoiser layer).
102
+ - Run the denoiser over the whole block (bidirectional self-attention + cross-attention to the context cache; Mamba chunk-scan seeded from the context state).
103
+ - Constrain to `p(x₀ | xₜ)` (mask token forbidden; already-decoded positions fixed), then **commit** the highest-confidence positions (all above `confidence_threshold`, with a floor that guarantees completion in `steps_per_block`) and re-mask the rest.
104
+ 3. Append the resolved block to the context, extend the context cache, and continue.
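
The per-block loop described above can be sketched in plain Python. This is a toy illustration, not the model's implementation: `toy_denoiser`, `denoise_block`, the `MASK` value, and the exact form of the completion floor (commit at least `ceil(masked / steps_remaining)` positions per step) are assumptions made for the example.

```python
import math
import random

MASK = -1  # hypothetical mask id for this toy example


def toy_denoiser(block, rng):
    """Stand-in for the denoiser: one (token, confidence) pair per position."""
    return [(rng.randrange(100), rng.random()) for _ in block]


def denoise_block(block_size, steps_per_block, confidence_threshold, seed=0):
    rng = random.Random(seed)
    block = [MASK] * block_size  # step 1: fully masked block
    trace = []
    for step in range(steps_per_block):
        masked = [i for i, tok in enumerate(block) if tok == MASK]
        if not masked:
            break
        t = len(masked) / block_size  # timestep = current masked fraction
        preds = toy_denoiser(block, rng)
        # Rank still-masked positions by confidence, highest first.
        ranked = sorted(masked, key=lambda i: preds[i][1], reverse=True)
        above = [i for i in ranked if preds[i][1] >= confidence_threshold]
        # Assumed floor: commit enough positions to finish in the steps left.
        floor = math.ceil(len(masked) / (steps_per_block - step))
        n_commit = max(len(above), floor)
        for i in ranked[:n_commit]:
            block[i] = preds[i][0]  # commit; the rest stay masked
        trace.append((t, n_commit))
    return block, trace


block, trace = denoise_block(block_size=16, steps_per_block=4, confidence_threshold=0.8)
for t, n_commit in trace:
    print(f"t={t:.3f} committed={n_commit}")
```

Because the floor equals the number of remaining masked positions on the last step, the block always resolves within `steps_per_block` iterations, and a high-confidence step can finish it sooner.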
105
 
106
  This model is ready for commercial use.
107
 
 
154
 
155
  ### Use it with Transformers
156
 
157
+ The snippet below shows how to use this model with Hugging Face Transformers. **Two-tower inference requires 2 GPUs** (~59 GB per GPU for bf16 weights), with the context tower and the denoiser tower placed on separate devices.
158
 
159
  ```python
160
  import torch
 
168
  trust_remote_code=True,
169
  )
170
 
171
+ # Context tower -> GPU 0, denoiser tower -> GPU 1
172
  model.place_towers_on_devices("cuda:0", "cuda:1")
173
  model.eval()
174
 
175
  prompt = "France is a country "
176
  inputs = tokenizer(prompt, return_tensors="pt").to("cuda:0")
177
 
178
+ # Flagship mode: block-wise mask diffusion
179
+ outputs = model.generate_mask_diffusion(
180
  inputs["input_ids"],
181
  max_new_tokens=128,
182
+ block_size=16, # tokens generated per block
183
+ steps_per_block=16, # denoising iterations per block
184
+ mask_token_id=3, # <mask>
185
+ temperature=0.1,
186
+ confidence_threshold=0.8, # commit positions above this confidence each step
187
  eos_token_id=tokenizer.eos_token_id,
188
  )
189
+ print(tokenizer.decode(outputs[0][inputs["input_ids"].shape[1]:], skip_special_tokens=True))
190
+ ```
191
 
192
+ **Mock-AR** (two-tower, one token per step):
193
+
194
+ ```python
195
+ outputs = model.generate_mock_ar(
196
+ inputs["input_ids"], max_new_tokens=128, temperature=0.0,
197
+ eos_token_id=tokenizer.eos_token_id,
198
+ )
199
  print(tokenizer.decode(outputs[0], skip_special_tokens=True))
200
  ```
201
 
202
+ **AR-only** (single GPU, context tower only — load with `.cuda()` instead of `place_towers_on_devices`):
203
 
204
  ```python
205
  outputs = model.generate(**inputs, max_new_tokens=128, do_sample=False)
206
  print(tokenizer.decode(outputs[0], skip_special_tokens=True))
207
  ```
208
 
209
  ## Model Version(s)
210
 
211
+ - v1.1 — Block-wise **mask-diffusion** generation enabled (time-conditioned denoiser, bidirectional in-block attention, chunk-scan Mamba); AR and mock-AR also supported.
212
  - v1.0 — Two-tower AR (mock-AR) checkpoint
213
 
214
  # Training, Testing, and Evaluation Datasets
config.json CHANGED
@@ -54,7 +54,7 @@
  "time_step_floor": 0.0001,
  "time_step_limit": [
    0.0,
-   "Infinity"
+   Infinity
  ],
  "time_step_max": 0.1,
  "time_step_min": 0.001,
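
A short sketch of what the quoting change means for a Python consumer of this file. It uses only the standard-library `json` module; the one-key document is a cut-down stand-in for `config.json`. Bare `Infinity` is not part of strict JSON, so parsers other than Python's may reject it, which is an interoperability caveat rather than something this commit states.

```python
import json
import math

# Bare Infinity (new side of the diff): Python's json module accepts it by
# default and yields a float.
bare = json.loads('{"time_step_limit": [0.0, Infinity]}')
# Quoted "Infinity" (old side): parsed as an ordinary string.
quoted = json.loads('{"time_step_limit": [0.0, "Infinity"]}')

upper_bare = bare["time_step_limit"][1]
upper_quoted = quoted["time_step_limit"][1]

print(type(upper_bare).__name__, math.isinf(upper_bare))
print(type(upper_quoted).__name__)
```

With the quoted form, any code that compares or clamps against the upper limit would need an explicit `float(...)` conversion first.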
inference.py CHANGED
@@ -11,19 +11,46 @@ Usage:
11
 
12
  # AR (context tower only, 1 GPU):
13
  python inference.py --mode ar
14
  """
15
  import argparse
 
16
  import torch
 
 
17
  from pathlib import Path
18
  from transformers import AutoTokenizer
19
  from modeling_nemotron_twotower import NemotronHTwoTowerForCausalLM
20
 
21
  parser = argparse.ArgumentParser()
22
- parser.add_argument("--prompt", default="France is a country ")
 
23
  parser.add_argument("--model", default=str(Path(__file__).resolve().parent))
24
  parser.add_argument("--max-new-tokens", type=int, default=128)
25
- parser.add_argument("--mode", choices=["ar", "mock_ar"], default="mock_ar")
26
  args = parser.parse_args()
27
 
28
  tokenizer = AutoTokenizer.from_pretrained(args.model)
29
  model = NemotronHTwoTowerForCausalLM.from_pretrained(
@@ -31,23 +58,67 @@ model = NemotronHTwoTowerForCausalLM.from_pretrained(
31
  )
32
 
33
  num_gpus = torch.cuda.device_count()
34
- if args.mode == "mock_ar" and num_gpus >= 2:
 
 
35
  model.place_towers_on_devices("cuda:0", "cuda:1")
36
  else:
37
  model.cuda()
38
 
39
  model.eval()
40
- inputs = tokenizer(args.prompt, return_tensors="pt").to(
 
 
41
  next(model.context_tower.parameters()).device
42
  )
43
 
44
  if args.mode == "ar":
45
  outputs = model.generate(**inputs, max_new_tokens=args.max_new_tokens, do_sample=False)
46
- else:
47
  outputs = model.generate_mock_ar(
48
  inputs["input_ids"], max_new_tokens=args.max_new_tokens,
49
  temperature=0.0, eos_token_id=tokenizer.eos_token_id,
50
  )
51
 
52
  text = tokenizer.decode(outputs[0][inputs["input_ids"].shape[1]:], skip_special_tokens=True)
53
  print(text)
 
11
 
12
  # AR (context tower only, 1 GPU):
13
  python inference.py --mode ar
14
+
15
+ # Mask diffusion (two-tower, 2 GPUs):
16
+ python inference.py --mode mask_diffusion --model /path/to/diffusion_hf_out
17
  """
18
  import argparse
19
+ import inspect
20
  import torch
21
+ import random
22
+ import numpy as np
23
  from pathlib import Path
24
  from transformers import AutoTokenizer
25
  from modeling_nemotron_twotower import NemotronHTwoTowerForCausalLM
26
 
27
  parser = argparse.ArgumentParser()
28
+ parser.add_argument("prompt_arg", nargs="?", default=None)
29
+ parser.add_argument("--prompt", default=None)
30
  parser.add_argument("--model", default=str(Path(__file__).resolve().parent))
31
  parser.add_argument("--max-new-tokens", type=int, default=128)
32
+ parser.add_argument("--mode", choices=["ar", "mock_ar", "mask_diffusion"], default="mock_ar")
33
+ parser.add_argument("--block-size", type=int, default=16)
34
+ parser.add_argument("--steps-per-block", type=int, default=16)
35
+ parser.add_argument("--mask-token-id", type=int, default=3)
36
+ parser.add_argument("--temperature", type=float, default=0.0)
37
+ parser.add_argument("--top-k", "--top_k", dest="top_k", type=int, default=None)
38
+ parser.add_argument("--confidence-threshold", type=float, default=0.9)
39
+ parser.add_argument("--deterministic", action="store_true")
40
+ parser.add_argument("--seed", type=int, default=42)
41
+ parser.add_argument("--print-diffusion-steps", action="store_true")
42
+ parser.add_argument("--trace-context-layers", action="store_true")
43
+ parser.add_argument("--trace-denoiser-layers", action="store_true")
44
  args = parser.parse_args()
45
+ prompt = args.prompt if args.prompt is not None else (args.prompt_arg or "France is a country ")
46
+
47
+ if args.deterministic:
48
+ random.seed(args.seed)
49
+ np.random.seed(args.seed)
50
+ torch.manual_seed(args.seed)
51
+ torch.cuda.manual_seed_all(args.seed)
52
+ torch.backends.cudnn.deterministic = True
53
+ torch.backends.cudnn.benchmark = False
54
 
55
  tokenizer = AutoTokenizer.from_pretrained(args.model)
56
  model = NemotronHTwoTowerForCausalLM.from_pretrained(
 
58
  )
59
 
60
  num_gpus = torch.cuda.device_count()
61
+ if num_gpus >= 2:
62
+ # Split towers across GPUs (both towers don't fit on one 80GB card).
63
+ # AR mode only uses the context tower (cuda:0), but placing both is fine.
64
  model.place_towers_on_devices("cuda:0", "cuda:1")
65
+ elif args.mode == "ar":
66
+ # AR uses only the context tower + context head; keep the denoiser tower
67
+ # off the GPU so a single card suffices.
68
+ model.context_tower = model.context_tower.cuda()
69
+ model.context_lm_head = model.context_lm_head.cuda()
70
  else:
71
  model.cuda()
72
 
73
  model.eval()
74
+ model.trace_context_layers = args.trace_context_layers
75
+ model.trace_denoiser_layers = args.trace_denoiser_layers
76
+ inputs = tokenizer(prompt, return_tensors="pt").to(
77
  next(model.context_tower.parameters()).device
78
  )
79
 
80
  if args.mode == "ar":
81
  outputs = model.generate(**inputs, max_new_tokens=args.max_new_tokens, do_sample=False)
82
+ elif args.mode == "mock_ar":
83
  outputs = model.generate_mock_ar(
84
  inputs["input_ids"], max_new_tokens=args.max_new_tokens,
85
  temperature=0.0, eos_token_id=tokenizer.eos_token_id,
86
  )
87
+ else:
88
+ def step_callback(step_idx, total_steps, tokens, t=None, logits=None, block_idx=0):
89
+ if not args.print_diffusion_steps:
90
+ return
91
+ if logits is None:
92
+ print(f"\n--- Block {block_idx} Step {step_idx}/{total_steps} | init ---")
93
+ print("xt:", tokenizer.decode(tokens[0], skip_special_tokens=False))
94
+ return
95
+ log_x = model._mdlm_forward(logits, tokens.to(logits.device), args.mask_token_id)
96
+ probs = log_x.exp()[0]
97
+ top2_probs, top2_ids = probs.topk(2, dim=-1)
98
+ n_masked = int((tokens == args.mask_token_id).sum().item())
99
+ print(f"\n--- Block {block_idx} Step {step_idx}/{total_steps} | masked={n_masked}/{tokens.shape[1]} | t={t:.4f} ---")
100
+ print("xt: " + repr(tokenizer.decode(tokens[0], skip_special_tokens=False)))
101
+ print("top1: " + "|".join(tokenizer.decode([tid.item()])[:9].rjust(9) for tid in top2_ids[:, 0]))
102
+ print("prb1: " + "|".join(f"{p.item():.3f}".rjust(9) for p in top2_probs[:, 0]))
103
+ print("top2: " + "|".join(tokenizer.decode([tid.item()])[:9].rjust(9) for tid in top2_ids[:, 1]))
104
+ print("prb2: " + "|".join(f"{p.item():.3f}".rjust(9) for p in top2_probs[:, 1]))
105
+
106
+ generate_kwargs = dict(
107
+ max_new_tokens=args.max_new_tokens,
108
+ block_size=args.block_size,
109
+ steps_per_block=args.steps_per_block,
110
+ mask_token_id=args.mask_token_id,
111
+ temperature=args.temperature,
112
+ top_k=args.top_k,
113
+ confidence_threshold=args.confidence_threshold,
114
+ eos_token_id=tokenizer.eos_token_id,
115
+ )
116
+ if (
117
+ args.print_diffusion_steps
118
+ and "step_callback" in inspect.signature(model.generate_mask_diffusion).parameters
119
+ ):
120
+ generate_kwargs["step_callback"] = step_callback
121
+ outputs = model.generate_mask_diffusion(inputs["input_ids"], **generate_kwargs)
122
 
123
  text = tokenizer.decode(outputs[0][inputs["input_ids"].shape[1]:], skip_special_tokens=True)
124
  print(text)
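
The script only passes `step_callback` when `generate_mask_diffusion` declares that parameter, so it keeps working against a modeling file that predates the callback. A self-contained sketch of that feature-detection pattern follows; `old_generate`, `new_generate`, and `call_generate` are hypothetical stand-ins, not functions from this repository.

```python
import inspect


def accepts_kwarg(func, name):
    """True if `func` declares a parameter called `name`."""
    return name in inspect.signature(func).parameters


def old_generate(input_ids, max_new_tokens=8):
    # Hypothetical older API: no callback support.
    return list(input_ids) + [0] * max_new_tokens


def new_generate(input_ids, max_new_tokens=8, step_callback=None):
    # Hypothetical newer API: reports progress through an optional callback.
    out = list(input_ids)
    for step in range(max_new_tokens):
        out.append(0)
        if step_callback is not None:
            step_callback(step, max_new_tokens)
    return out


def call_generate(generate, input_ids, callback=None, **kwargs):
    # Forward the callback only when the target function can receive it.
    if callback is not None and accepts_kwarg(generate, "step_callback"):
        kwargs["step_callback"] = callback
    return generate(input_ids, **kwargs)


seen = []
print(call_generate(new_generate, [1, 2], callback=lambda s, n: seen.append(s), max_new_tokens=3))
print(call_generate(old_generate, [1, 2], callback=lambda s, n: seen.append(s), max_new_tokens=3))
print(seen)
```

The trade-off is that the callback is silently dropped on the older API; a caller that requires tracing would want to raise instead.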
modeling_nemotron_h.py CHANGED
@@ -910,7 +910,9 @@ class NemotronHTopkRouter(nn.Module):
 
  def forward(self, hidden_states):
  hidden_states = hidden_states.view(-1, self.config.hidden_size)
- router_logits = F.linear(hidden_states.type(torch.float32), self.weight.type(torch.float32))
+ # mcore runs the MoE router in fp64 (--moe-router-dtype fp64); match it so
+ # top-k expert selection is bit-identical at borderline scores.
+ router_logits = F.linear(hidden_states.type(torch.float64), self.weight.type(torch.float64))
  scores = router_logits.sigmoid()
  topk_indices = self.get_topk_indices(scores)
  topk_weights = scores.gather(1, topk_indices)
@@ -918,7 +920,7 @@ class NemotronHTopkRouter(nn.Module):
  denominator = topk_weights.sum(dim=-1, keepdim=True) + 1e-20
  topk_weights /= denominator
  topk_weights = topk_weights * self.routed_scaling_factor
- return topk_indices, topk_weights
+ return topk_indices, topk_weights.type(torch.float32)
 
  # Copied from transformers.models.llama.modeling_llama.repeat_kv
  def repeat_kv(hidden_states: torch.Tensor, n_rep: int) -> torch.Tensor:
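
To illustrate why router precision can change expert selection at borderline scores, here is a minimal standard-library sketch. It is not the router itself: the two logits are made up, there is no sigmoid, and `top1` breaks ties by lowest index, whereas a real top-k kernel's tie-breaking is implementation-specific.

```python
import struct


def to_f32(x):
    """Round a Python float (fp64) to the nearest fp32 value."""
    return struct.unpack("f", struct.pack("f", x))[0]


def top1(scores):
    """Index of the largest score; ties resolve to the lowest index."""
    return max(range(len(scores)), key=lambda i: scores[i])


# Two experts whose logits differ by less than fp32 can resolve near 0.5.
logits_fp64 = [0.5, 0.5 + 1e-9]
logits_fp32 = [to_f32(v) for v in logits_fp64]

winner_fp64 = top1(logits_fp64)  # expert 1 is strictly larger in fp64
winner_fp32 = top1(logits_fp32)  # both round to 0.5, so the tie-break decides

print(winner_fp64, winner_fp32)
```

In fp64 the ordering is unambiguous, while in fp32 the two scores collapse to the same value and the selected expert depends on tie-breaking.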
modeling_nemotron_twotower.py CHANGED
@@ -30,6 +30,7 @@ try:
30
  NemotronHForCausalLM,
31
  NemotronHModel,
32
  NemotronHPreTrainedModel,
 
33
  )
34
  from .configuration_nemotron_h import NemotronHConfig
35
  except ImportError:
@@ -39,6 +40,7 @@ except ImportError:
39
  NemotronHForCausalLM,
40
  NemotronHModel,
41
  NemotronHPreTrainedModel,
 
42
  )
43
  from configuration_nemotron_h import NemotronHConfig
44
 
@@ -180,8 +182,11 @@ class NemotronHTwoTowerForCausalLM(NemotronHPreTrainedModel, GenerationMixin):
180
  elif input_ids.shape[1] != cache_position.shape[0]:
181
  input_ids = input_ids[:, cache_position]
182
  else:
183
- past_key_values = HybridMambaAttentionDynamicCache(
184
- self.config, input_ids.shape[0], self.dtype, device=self.device
185
  )
186
  if attention_mask is not None and position_ids is None:
187
  position_ids = attention_mask.long().cumsum(-1) - 1
@@ -330,18 +335,20 @@ class NemotronHTwoTowerForCausalLM(NemotronHPreTrainedModel, GenerationMixin):
330
  return {"ctx_cache": cache_p2, "mamba_s2": mamba_s2, "ctx_len": S}
331
 
332
  def _extend_context_cache(self, new_tokens, cache_state):
333
- """Extend context cache by new_tokens (B, L). Old S-1 -> new S-2.
334
 
335
- Processes tokens one at a time so HF Mamba can use its single-step
336
- cached path (seq_len=1, cache_position[0] > 0).
337
  """
338
  ctx_cache = cache_state["ctx_cache"]
339
  pattern = self.config.hybrid_override_pattern
340
  ctx_len = cache_state["ctx_len"]
341
- ctx_device = next(self.context_tower.parameters()).device
 
342
  L = new_tokens.shape[1]
343
- tokens_on_device = new_tokens.to(ctx_device)
344
 
 
345
  new_s2 = {}
346
  for i in range(self.config.num_hidden_layers):
347
  if pattern[i] == "M":
@@ -350,12 +357,37 @@ class NemotronHTwoTowerForCausalLM(NemotronHPreTrainedModel, GenerationMixin):
350
  cache_state["mamba_s2"] = new_s2
351
 
352
  ctx_cache.has_previous_state = True
353
- for j in range(L):
354
- cp = torch.tensor([ctx_len + j], device=ctx_device)
355
- self._forward_tower_with_cache(
356
- self.context_tower, self.context_lm_head,
357
- tokens_on_device[:, j:j+1], ctx_cache, cp,
358
- )
359
 
360
  cache_state["ctx_len"] = ctx_len + L
361
  return cache_state
@@ -415,13 +447,139 @@ class NemotronHTwoTowerForCausalLM(NemotronHPreTrainedModel, GenerationMixin):
415
  self.denoiser_tower, self.lm_head, den_input, den_cache, cp,
416
  )
417
 
418
- def _run_denoiser_step_diffusion(self, block_ids, cache_state, t=None):
419
- """Diffusion denoiser: pos=ctx_len..ctx_len+L-1, full KV, Mamba S-1.
420
 
421
- Processes the block token-by-token so the HF Mamba mixer can use its
422
- single-step cached path (seq_len=1 with cache_position[0] > 0).
423
- This is mathematically equivalent to full-block processing since all
424
- layers are causal, and it properly propagates Mamba states from context.
425
 
426
  Args:
427
  block_ids: (B, L) tokens to denoise
@@ -431,28 +589,74 @@ class NemotronHTwoTowerForCausalLM(NemotronHPreTrainedModel, GenerationMixin):
431
  Returns: logits (B, L, V)
432
  """
433
  ctx_len = cache_state["ctx_len"]
434
- den_device = next(self.denoiser_tower.parameters()).device
 
435
  den_input = block_ids.to(den_device)
436
  L = den_input.shape[1]
437
 
 
438
  t_emb = None
439
  if t is not None:
440
  t_dev = t.to(device=den_device, dtype=self.dtype)
441
  t_repr = self.t_embedder(t_dev)
442
  t_emb = self.t_block(t_repr)
443
 
 
444
  den_cache = self._build_denoiser_cache_diffusion(cache_state, den_device)
445
 
446
- all_logits = []
447
- for i in range(L):
448
- cp = torch.tensor([ctx_len + i], device=den_device)
449
- logits_i = self._forward_tower_with_cache(
450
- self.denoiser_tower, self.lm_head, den_input[:, i:i+1],
451
- den_cache, cp, t_emb=t_emb,
452
- )
453
- all_logits.append(logits_i)
454
 
455
- return torch.cat(all_logits, dim=1)
456
 
457
  # ------------------------------------------------------------------
458
  # Mock-AR generation (unchanged)
@@ -516,6 +720,7 @@ class NemotronHTwoTowerForCausalLM(NemotronHPreTrainedModel, GenerationMixin):
516
  top_k=None,
517
  confidence_threshold=0.9,
518
  eos_token_id=None,
 
519
  ):
520
  """Block-wise mask diffusion with confidence_unmasking.
521
 
@@ -558,6 +763,9 @@ class NemotronHTwoTowerForCausalLM(NemotronHPreTrainedModel, GenerationMixin):
558
  # Initialize fully masked block
559
  xt = torch.full((B, block_size), mask_token_id, dtype=torch.long,
560
  device=device)
561
 
562
  for step_idx in range(steps_per_block):
563
  # t_model = current mask fraction
@@ -626,6 +834,11 @@ class NemotronHTwoTowerForCausalLM(NemotronHPreTrainedModel, GenerationMixin):
626
  remask_idx = masked_indices[sort_idx[:num_to_remask[b]]]
627
  output[b, remask_idx] = mask_token_id
628
 
629
  xt = output
630
 
631
  # Block complete — extend context
 
30
  NemotronHForCausalLM,
31
  NemotronHModel,
32
  NemotronHPreTrainedModel,
33
+ repeat_kv,
34
  )
35
  from .configuration_nemotron_h import NemotronHConfig
36
  except ImportError:
 
40
  NemotronHForCausalLM,
41
  NemotronHModel,
42
  NemotronHPreTrainedModel,
43
+ repeat_kv,
44
  )
45
  from configuration_nemotron_h import NemotronHConfig
46
 
 
182
  elif input_ids.shape[1] != cache_position.shape[0]:
183
  input_ids = input_ids[:, cache_position]
184
  else:
185
+ # FixedHybridCache (not the base class) so the Mamba mixer finds
186
+ # conv_kernel_size during the cached forward (needed for AR generate).
187
+ past_key_values = FixedHybridCache(
188
+ self.config, input_ids.shape[0], self.dtype,
189
+ device=next(self.context_tower.parameters()).device,
190
  )
191
  if attention_mask is not None and position_ids is None:
192
  position_ids = attention_mask.long().cumsum(-1) - 1
 
335
  return {"ctx_cache": cache_p2, "mamba_s2": mamba_s2, "ctx_len": S}
336
 
337
  def _extend_context_cache(self, new_tokens, cache_state):
338
+ """Extend context cache by new_tokens (B, L), block-wise (matches mcore).
339
 
340
+ Mamba layers advance via the block chunk-scan from the current state;
341
+ attention layers append the block KV (causal within block); MoE is plain.
342
  """
343
  ctx_cache = cache_state["ctx_cache"]
344
  pattern = self.config.hybrid_override_pattern
345
  ctx_len = cache_state["ctx_len"]
346
+ tower = self.context_tower
347
+ ctx_device = next(tower.parameters()).device
348
  L = new_tokens.shape[1]
349
+ tokens = new_tokens.to(ctx_device)
350
 
351
+ # Snapshot pre-extension Mamba states as the new S-2 (used by mock-AR).
352
  new_s2 = {}
353
  for i in range(self.config.num_hidden_layers):
354
  if pattern[i] == "M":
 
357
  cache_state["mamba_s2"] = new_s2
358
 
359
  ctx_cache.has_previous_state = True
360
+ cache_position = torch.arange(ctx_len, ctx_len + L, device=ctx_device)
361
+ hidden = tower.embeddings(tokens)
362
+ causal_mask = tower._update_causal_mask(None, hidden, cache_position)
363
+
364
+ for layer_idx, block in enumerate(tower.layers):
365
+ residual = hidden
366
+ h = block.norm(hidden.to(dtype=block.norm.weight.dtype))
367
+ if block.residual_in_fp32:
368
+ residual = residual.to(torch.float32)
369
+
370
+ if block.block_type == "mamba":
371
+ d_conv = block.mixer.conv_kernel_size
372
+ init_conv = ctx_cache.conv_states[layer_idx][..., -(d_conv - 1):]
373
+ init_ssm = ctx_cache.ssm_states[layer_idx].contiguous()
374
+ h, new_conv, new_ssm = self._denoiser_block_mamba(
375
+ block.mixer, h, init_conv, init_ssm, return_states=True,
376
+ )
377
+ ctx_cache.conv_states[layer_idx] = new_conv
378
+ ctx_cache.ssm_states[layer_idx] = new_ssm
379
+ elif block.block_type == "attention":
380
+ # Standard cached attention appends block KV (causal within block).
381
+ h, _, _ = block.mixer(
382
+ h, attention_mask=causal_mask,
383
+ past_key_value=ctx_cache, cache_position=cache_position,
384
+ )
385
+ elif block.block_type in ["mlp", "moe"]:
386
+ h = block.mixer(h)
387
+ else:
388
+ raise ValueError(f"Unknown block_type: {block.block_type}")
389
+
390
+ hidden = residual + h
391
 
392
  cache_state["ctx_len"] = ctx_len + L
393
  return cache_state
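
Block-wise extension relies on a property of linear recurrences: running a whole block from a saved state reaches the same final state as stepping one token at a time. The scalar recurrence below is a toy stand-in for the Mamba state update (`scan`, `a`, and `b` are illustrative names, not the model's parameters). It matches exactly here because both paths do the same arithmetic in the same order; fused GPU kernels may agree only up to floating-point rounding.

```python
def scan(xs, a, b, state=0.0):
    """Run h_t = a * h_{t-1} + b * x_t over xs; return outputs and final state."""
    outputs = []
    for x in xs:
        state = a * state + b * x
        outputs.append(state)
    return outputs, state


a, b = 0.9, 0.1
prompt = [1.0, -2.0, 0.5]
block = [3.0, 0.25, -1.0, 2.0]

# Context state after the prompt.
_, ctx_state = scan(prompt, a, b)

# Block-wise: one call over the whole block, seeded from the context state.
block_out, block_state = scan(block, a, b, state=ctx_state)

# Token-by-token: one call per token, carrying the state forward.
step_out, step_state = [], ctx_state
for x in block:
    out, step_state = scan([x], a, b, state=step_state)
    step_out.extend(out)

print(block_out == step_out, block_state == step_state)
```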
 
447
  self.denoiser_tower, self.lm_head, den_input, den_cache, cp,
448
  )
449
 
450
+ def _denoiser_block_attention(self, mixer, hidden, ctx_k, ctx_v):
451
+ """Bidirectional denoiser self-attention over [context_KV | block_KV].
452
 
453
+ Mirrors the mcore `_forward_attn_with_past` (is_causal=False, no mask):
454
+ every block position attends to ALL context positions and ALL block
455
+ positions (the noisy block is processed bidirectionally within itself).
456
+
457
+ Args:
458
+ mixer: NemotronHAttention module (provides q/k/v/o projections)
459
+ hidden: (B, L, D) post-norm (and post-modulation) block hidden states
460
+ ctx_k, ctx_v: context KV, each (B, num_kv_heads, ctx_len, head_dim)
461
+
462
+ Returns: (B, L, D) attention output (before residual add)
463
+ """
464
+ bsz, q_len, _ = hidden.shape
465
+ q = mixer.q_proj(hidden).view(bsz, q_len, mixer.num_heads, mixer.head_dim).transpose(1, 2)
466
+ k = mixer.k_proj(hidden).view(bsz, q_len, mixer.num_key_value_heads, mixer.head_dim).transpose(1, 2)
467
+ v = mixer.v_proj(hidden).view(bsz, q_len, mixer.num_key_value_heads, mixer.head_dim).transpose(1, 2)
468
+
469
+ # Concatenate context KV (past) with current block KV on the sequence dim.
470
+ k = torch.cat([ctx_k.to(k.dtype), k], dim=2)
471
+ v = torch.cat([ctx_v.to(v.dtype), v], dim=2)
472
+
473
+ # GQA: expand KV heads to match query heads.
474
+ k = repeat_kv(k, mixer.num_key_value_groups)
475
+ v = repeat_kv(v, mixer.num_key_value_groups)
476
+
477
+ # Full (non-causal) attention: block sees all context + whole block.
478
+ attn_output = F.scaled_dot_product_attention(
479
+ q, k, v, attn_mask=None, dropout_p=0.0, is_causal=False,
480
+ )
481
+ attn_output = attn_output.transpose(1, 2).contiguous().view(
482
+ bsz, q_len, mixer.num_heads * mixer.head_dim
483
+ )
484
+ return mixer.o_proj(attn_output)
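
The effect of `is_causal=False` over `[context_KV | block_KV]` can be seen in a pure-Python, single-head sketch. The `attend` helper and the toy vectors are assumptions for illustration; projections, multiple heads, and GQA expansion are left out.

```python
import math


def attend(queries, keys, values, causal_offset=None):
    """Single-head scaled dot-product attention on plain lists.

    causal_offset=None: every query sees every key (bidirectional).
    causal_offset=c:    query i sees keys[: c + i + 1] (causal within the block).
    """
    d = len(queries[0])
    out = []
    for i, q in enumerate(queries):
        n_visible = len(keys) if causal_offset is None else causal_offset + i + 1
        scores = [sum(x * y for x, y in zip(q, k)) / math.sqrt(d) for k in keys[:n_visible]]
        m = max(scores)
        weights = [math.exp(s - m) for s in scores]
        z = sum(weights)
        # zip stops at len(weights), so hidden values are never read.
        out.append([sum(w / z * v[c] for w, v in zip(weights, values)) for c in range(len(values[0]))])
    return out


ctx_kv = [[1.0, 0.0], [0.0, 1.0]]                # two context positions
block = [[1.0, 1.0], [0.5, -0.5], [-1.0, 2.0]]   # three noisy block positions
keys = ctx_kv + block                            # [context | block] on the sequence axis

bidirectional = attend(block, keys, keys)
causal = attend(block, keys, keys, causal_offset=len(ctx_kv))

print(bidirectional[0], causal[0])
```

The first block position gets a different output in the two modes because only the bidirectional pass lets it read later block positions. The last position sees every key either way, so its output is identical.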
485
+
486
+ def _denoiser_block_mamba(self, mixer, hidden, init_conv, init_ssm, return_states=False):
487
+ """Chunk-scan the whole block through the Mamba mixer, seeded from the
488
+ context state — mirrors mcore `forward_mamba_layer_with_states`
489
+ (non-bidirectional). Uses the same mamba_ssm/causal_conv1d kernels as
490
+ mcore, instead of HF's token-by-token single-step path (which is both a
491
+ numerical mismatch and crashes in this env's causal_conv1d_update).
492
+
493
+ Args:
494
+ mixer: NemotronHMamba2Mixer
495
+ hidden: (B, L, D) post-norm (and post-modulation) block hidden states
496
+ init_conv: (B, conv_dim, d_conv-1) context conv state, or None
497
+ init_ssm: (B, nheads, headdim, d_state) context SSM state, or None
498
+ return_states: also return the updated (conv_state[width d_conv], ssm_state)
499
+ so the caller can advance a KV/Mamba cache (used by context extend).
500
+
501
+ Returns: (B, L, D) mixer output (before adaLN gate / residual);
502
+ or (output, new_conv_state, new_ssm_state) if return_states.
503
+ """
504
+ from einops import rearrange
505
+ from mamba_ssm.ops.triton.ssd_combined import mamba_chunk_scan_combined
506
+ from causal_conv1d import causal_conv1d_fn
507
+
508
+ d_inner = mixer.intermediate_size
509
+ ngroups = mixer.n_groups
510
+ d_state = mixer.ssm_state_size
511
+ headdim = mixer.head_dim
512
+ conv_dim = mixer.conv_dim
513
+ d_conv = mixer.conv_kernel_size
514
+
515
+ proj = mixer.in_proj(hidden) # (B, L, d_inner+conv_dim+nheads)
516
+ z, xBC, dt = torch.split(proj, [d_inner, conv_dim, mixer.num_heads], dim=-1)
517
+
518
+ # causal_conv1d_fn with initial_states requires channel-last layout:
519
+ # - input (B, conv_dim, L): use the transpose VIEW (stride(1)==1), no .contiguous()
520
+ # - initial_states (B, conv_dim, d_conv-1): force channel-last via the
521
+ # transpose->contiguous->transpose trick (mcore _run_denoiser_step).
522
+ if init_conv is not None:
523
+ init_conv = init_conv.transpose(-1, -2).contiguous().transpose(-1, -2)
524
+ xBC_conv = causal_conv1d_fn(
525
+ xBC.transpose(1, 2), # (B, conv_dim, L) channel-last view
526
+ mixer.conv1d.weight.squeeze(1),
527
+ mixer.conv1d.bias,
528
+ activation=mixer.activation,
529
+ initial_states=init_conv,
530
+ ).transpose(1, 2) # (B, L, conv_dim)
531
+
532
+ x, B_proj, C_proj = torch.split(
533
+ xBC_conv, [d_inner, ngroups * d_state, ngroups * d_state], dim=-1
534
+ )
535
+ x = rearrange(x, "b s (h p) -> b s h p", p=headdim).contiguous()
536
+ B_proj = rearrange(B_proj, "b s (g n) -> b s g n", n=d_state).contiguous()
537
+ C_proj = rearrange(C_proj, "b s (g n) -> b s g n", n=d_state).contiguous()
538
+
539
+ A = -torch.exp(mixer.A_log.float())
540
+ scan = mamba_chunk_scan_combined(
541
+ x, dt.contiguous(), A, B_proj, C_proj, mixer.chunk_size,
542
+ D=mixer.D, z=None,
543
+ dt_bias=mixer.dt_bias.float(), dt_softplus=True,
544
+ initial_states=init_ssm,
545
+ return_final_states=return_states,
546
+ )
547
+ if return_states:
548
+ y, new_ssm = scan
549
+ else:
550
+ y = scan
551
+ y = rearrange(y, "b s h p -> b s (h p)")
552
+ y = mixer.norm(y, z) # Mamba2 z-gated RMSNorm
553
+ out = mixer.out_proj(y)
554
+ if not return_states:
555
+ return out
556
+ # New conv state: HF cache stores the last d_conv raw xBC inputs (width
557
+ # d_conv), most-recent at index -1. block_size >= d_conv here.
558
+ L = xBC.shape[1]
559
+ if L >= d_conv:
560
+ new_conv = xBC[:, -d_conv:, :].transpose(1, 2).contiguous()
561
+ else:
562
+ hist = init_conv if init_conv is not None else xBC.new_zeros(xBC.shape[0], conv_dim, d_conv - 1)
563
+ comb = torch.cat([hist.transpose(1, 2), xBC], dim=1)
564
+ new_conv = comb[:, -d_conv:, :].transpose(1, 2).contiguous()
565
+ return out, new_conv, new_ssm
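
The conv-state update at the end of this function amounts to a rolling window: keep the last `d_conv` raw inputs, with the most recent at index -1. Below is a one-channel sketch using lists; `update_conv_state` is a hypothetical helper, and the real tensors are laid out as `(B, conv_dim, d_conv)`.

```python
def update_conv_state(history, new_inputs, d_conv):
    """Return the last d_conv inputs after appending new_inputs.

    history: the previous d_conv - 1 inputs, oldest first, or None for zeros.
    new_inputs: the block's raw inputs for this channel (at least one).
    """
    if history is None:
        history = [0.0] * (d_conv - 1)
    combined = list(history) + list(new_inputs)
    # When the block is at least d_conv long, only block inputs survive.
    return combined[-d_conv:]


print(update_conv_state([1.0, 2.0, 3.0], [4.0, 5.0], d_conv=4))
print(update_conv_state([1.0, 2.0, 3.0], [4.0, 5.0, 6.0, 7.0, 8.0], d_conv=4))
print(update_conv_state(None, [9.0], d_conv=4))
```

One slice covers both branches of the original code: for a block of at least `d_conv` tokens the history falls out of the window, and for a shorter block it fills the front.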
566
+
567
+ def _run_denoiser_step_diffusion(self, block_ids, cache_state, t=None):
568
+ """Diffusion denoiser forward over the FULL block (B, L) in one pass.
569
+
570
+ Parity with mcore `_run_denoiser_step`:
571
+ - Attention layers run BIDIRECTIONALLY within the block, attending to
572
+ the full context KV cache + the whole noisy block (is_causal=False).
573
+ A token-by-token causal pass would hide later block positions from
574
+ earlier ones.
575
+ - Mamba layers are causal/forward-only (bidirectional_mamba=False) and
576
+ are chunk-scanned over the whole block from the context state (S-1),
577
+ matching mcore's `forward_mamba_layer_with_states`.
578
+ - Time conditioning (adaLN-single) is applied per layer. The modulate/norm
579
+ ORDER depends on where mcore's norm lives: mamba & attention norms are
580
+ FUSED into in_proj/linear_qkv (applied AFTER modulate) -> modulate THEN
581
+ norm; MoE uses a separate pre_mlp_layernorm -> norm THEN modulate.
582
+ Gate is applied to the mixer output in all cases.
583
 
584
  Args:
585
  block_ids: (B, L) tokens to denoise
 
589
  Returns: logits (B, L, V)
590
  """
591
  ctx_len = cache_state["ctx_len"]
592
+ tower = self.denoiser_tower
593
+ den_device = next(tower.parameters()).device
594
  den_input = block_ids.to(den_device)
595
  L = den_input.shape[1]
596
 
597
+ # Time embedding -> per-layer modulation params (shift, scale, gate).
598
  t_emb = None
599
  if t is not None:
600
  t_dev = t.to(device=den_device, dtype=self.dtype)
601
  t_repr = self.t_embedder(t_dev)
602
  t_emb = self.t_block(t_repr)
603
 
604
+ # Fresh denoiser cache seeded from context: Mamba S-1 state + full context KV.
605
  den_cache = self._build_denoiser_cache_diffusion(cache_state, den_device)
606
 
607
+ hidden = tower.embeddings(den_input)
608
 
609
+ for layer_idx, block in enumerate(tower.layers):
610
+ residual = hidden
611
+ if block.residual_in_fp32:
612
+ residual = residual.to(torch.float32)
613
+
614
+ mod = None
615
+ if t_emb is not None:
616
+ mod = _get_mod_params(t_emb, self.scale_shift_tables[layer_idx])
617
+ shift, scale, gate = mod
618
+
619
+ # adaLN modulate vs norm ORDER depends on where mcore's norm lives:
620
+ # - mamba/attention: norm is FUSED into in_proj/linear_qkv and is
621
+ # applied AFTER the explicit modulate -> modulate THEN norm.
622
+ # - moe/mlp: separate pre_mlp_layernorm applied BEFORE modulate
623
+ # -> norm THEN modulate.
624
+ if block.block_type in ("mamba", "attention"):
625
+ h = hidden
626
+ if mod is not None:
627
+ h = _modulate(h, shift, scale)
628
+ h = block.norm(h.to(dtype=block.norm.weight.dtype))
629
+ else: # mlp / moe
630
+ h = block.norm(hidden.to(dtype=block.norm.weight.dtype))
631
+ if mod is not None:
632
+ h = _modulate(h, shift, scale)
633
+
634
+ if block.block_type == "mamba":
635
+ # Chunk-scan the whole block in one kernel launch, seeded from the
636
+ # context Mamba state (matches mcore forward_mamba_layer_with_states).
637
+ # HF conv_states are width d_conv; causal_conv1d_fn's initial_states
638
+ # wants the d_conv-1 most-recent columns.
639
+ d_conv = block.mixer.conv_kernel_size
640
+ init_conv = den_cache.conv_states[layer_idx][..., -(d_conv - 1):]
641
+ init_ssm = den_cache.ssm_states[layer_idx].contiguous()
642
+ h = self._denoiser_block_mamba(block.mixer, h, init_conv, init_ssm)
643
+ elif block.block_type == "attention":
644
+ ctx_k = den_cache.key_cache[layer_idx]
645
+ ctx_v = den_cache.value_cache[layer_idx]
646
+ h = self._denoiser_block_attention(block.mixer, h, ctx_k, ctx_v)
647
+ elif block.block_type in ["mlp", "moe"]:
648
+ h = block.mixer(h)
649
+ else:
650
+ raise ValueError(f"Unknown block_type: {block.block_type}")
651
+
652
+ if mod is not None:
653
+ h = gate.unsqueeze(1) * h
654
+
655
+ hidden = residual + h
656
+
657
+ hidden = tower.norm_f(hidden)
658
+ logits = self.lm_head(hidden.to(self.lm_head.weight.dtype)).float()
659
+ return logits
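
The modulate/norm ordering matters because normalization and an affine modulation do not commute. A small pure-Python sketch follows, with assumptions stated: `_modulate` is not shown in this diff, so the PixArt-style form `x * (1 + scale) + shift` is assumed; `rms_norm` omits the learned weight; and shift, scale, and gate are scalars here rather than per-channel vectors.

```python
import math


def rms_norm(x, eps=1e-5):
    """Unweighted RMSNorm over one hidden vector (learned weight omitted)."""
    rms = math.sqrt(sum(v * v for v in x) / len(x) + eps)
    return [v / rms for v in x]


def modulate(x, shift, scale):
    """Assumed PixArt-style modulation: x * (1 + scale) + shift."""
    return [v * (1.0 + scale) + shift for v in x]


def pre_mixer(hidden, shift, scale, block_type):
    if block_type in ("mamba", "attention"):
        # Norm is fused after the modulate: modulate THEN norm.
        return rms_norm(modulate(hidden, shift, scale))
    # mlp / moe: separate pre-norm: norm THEN modulate.
    return modulate(rms_norm(hidden), shift, scale)


def layer(hidden, shift, scale, gate, block_type, mixer=lambda h: h):
    """Gate the mixer output, then add the residual (identity mixer by default)."""
    h = mixer(pre_mixer(hidden, shift, scale, block_type))
    return [r + gate * v for r, v in zip(hidden, h)]


hidden = [1.0, 2.0, 3.0]
attn_in = pre_mixer(hidden, shift=0.5, scale=0.2, block_type="attention")
moe_in = pre_mixer(hidden, shift=0.5, scale=0.2, block_type="moe")
print(attn_in)
print(moe_in)
```

With modulate-then-norm the mixer input comes out at roughly unit RMS regardless of the timestep, while with norm-then-modulate the shift and scale reach the mixer unnormalized. Swapping the two therefore changes the numerics even with identical weights.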
660
 
661
  # ------------------------------------------------------------------
662
  # Mock-AR generation (unchanged)
 
720
  top_k=None,
721
  confidence_threshold=0.9,
722
  eos_token_id=None,
723
+ step_callback=None,
724
  ):
725
  """Block-wise mask diffusion with confidence_unmasking.
726
 
 
763
  # Initialize fully masked block
764
  xt = torch.full((B, block_size), mask_token_id, dtype=torch.long,
765
  device=device)
766
+ if step_callback is not None:
767
+ step_callback(0, steps_per_block, xt, t=1.0, logits=None,
768
+ block_idx=block_idx)
769
 
770
  for step_idx in range(steps_per_block):
771
  # t_model = current mask fraction
 
834
  remask_idx = masked_indices[sort_idx[:num_to_remask[b]]]
835
  output[b, remask_idx] = mask_token_id
836
 
837
+ if step_callback is not None:
838
+ step_callback(step_idx, steps_per_block, xt,
839
+ t=float(t_model.detach().cpu()), logits=logits,
840
+ block_idx=block_idx)
841
+
842
  xt = output
843
 
844
  # Block complete — extend context