Text Generation
Transformers
Safetensors
PyTorch
nvidia
two-tower
diffusion
mamba
fitsumreda Claude Opus 4.8 commited on
Commit
67bf233
·
1 Parent(s): c739325

Fix NaN corruption in long-context diffusion (fp32 denoiser SSM scan) + multi-request inference

Browse files

Denoiser Mamba chunk-scan ran in bf16. With a long context the seeded SSM
state grows large (e.g. ~5e3 at L00 for a 1042-token prompt) and the bf16
scan overflows to NaN. Because the Triton kernel's reductions are not
bit-deterministic this struck nondeterministically: a NaN on a block's
all-masked first step makes every confidence NaN, so `NaN > threshold` is
False, the fallback commits 1 token, and sorting NaN confidences force-commits
an arbitrary garbage token (e.g. "katalog"/"hips"), wrecking the answer.

Fix: run the SSM scan in fp32 (x/dt/B/C/D upcast; init_ssm too; cast back
before the gated norm). The scan spans one <=16-token block so cost is
negligible. Also covers the block-to-block context-extend path (same helper).
NOTE: this is broader than mcore (which keeps x/B/C/dt in bf16, fp32 only for
A/D/dt_bias/state); kept as a stability safety net pending the state-magnitude
parity investigation vs mcore.

inference.py: add --prompt-file (jsonl, mcore format) -> Request i/N loop.

Co-Authored-By: Claude Opus 4.8 <noreply@anthropic.com>

Files changed (2) hide show
  1. inference.py +95 -70
  2. modeling_nemotron_twotower.py +13 -4
inference.py CHANGED
@@ -28,6 +28,9 @@ from modeling_nemotron_twotower import NemotronHTwoTowerForCausalLM
28
  parser = argparse.ArgumentParser()
29
  parser.add_argument("prompt_arg", nargs="?", default=None)
30
  parser.add_argument("--prompt", default=None)
 
 
 
31
  parser.add_argument("--model", default=str(Path(__file__).resolve().parent))
32
  parser.add_argument("--max-new-tokens", type=int, default=128)
33
  parser.add_argument("--mode", choices=["ar", "mock_ar", "mask_diffusion"], default="mock_ar")
@@ -74,76 +77,97 @@ else:
74
  model.eval()
75
  model.trace_context_layers = args.trace_context_layers
76
  model.trace_denoiser_layers = args.trace_denoiser_layers
77
- inputs = tokenizer(prompt, return_tensors="pt").to(
78
- next(model.context_tower.parameters()).device
79
- )
80
-
81
- t0 = time.perf_counter()
82
- if args.mode == "ar":
83
- # Context-tower-only AR via our cached single-step path (the fair ST-AR
84
- # baseline). Avoids HF generate()'s cache path that crashes on this env.
85
- outputs = model.generate_ar(
86
- inputs["input_ids"], max_new_tokens=args.max_new_tokens,
87
- temperature=0.0, eos_token_id=tokenizer.eos_token_id,
88
- )
89
- elif args.mode == "mock_ar":
90
- outputs = model.generate_mock_ar(
91
- inputs["input_ids"], max_new_tokens=args.max_new_tokens,
92
- temperature=0.0, eos_token_id=tokenizer.eos_token_id,
93
- )
94
  else:
95
- def step_callback(step_idx, total_steps, tokens, t=None, logits=None, block_idx=0):
96
- if not args.print_diffusion_steps:
97
- return
98
- if logits is None:
99
- print(f"\n--- Block {block_idx} Step {step_idx}/{total_steps} | init ---")
100
- print("xt:", tokenizer.decode(tokens[0], skip_special_tokens=False))
101
- return
102
- log_x = model._mdlm_forward(logits, tokens.to(logits.device), args.mask_token_id)
103
- probs = log_x.exp()[0]
104
- top2_probs, top2_ids = probs.topk(2, dim=-1)
105
- n_masked = int((tokens == args.mask_token_id).sum().item())
106
- print(f"\n--- Block {block_idx} Step {step_idx}/{total_steps} | masked={n_masked}/{tokens.shape[1]} | t={t:.4f} ---")
107
- print("xt: " + repr(tokenizer.decode(tokens[0], skip_special_tokens=False)))
108
- print("top1: " + "|".join(tokenizer.decode([tid.item()])[:9].rjust(9) for tid in top2_ids[:, 0]))
109
- print("prb1: " + "|".join(f"{p.item():.3f}".rjust(9) for p in top2_probs[:, 0]))
110
- print("top2: " + "|".join(tokenizer.decode([tid.item()])[:9].rjust(9) for tid in top2_ids[:, 1]))
111
- print("prb2: " + "|".join(f"{p.item():.3f}".rjust(9) for p in top2_probs[:, 1]))
112
-
113
- generate_kwargs = dict(
114
- max_new_tokens=args.max_new_tokens,
115
- block_size=args.block_size,
116
- steps_per_block=args.steps_per_block,
117
- mask_token_id=args.mask_token_id,
118
- temperature=args.temperature,
119
- top_k=args.top_k,
120
- confidence_threshold=args.confidence_threshold,
121
- eos_token_id=tokenizer.eos_token_id,
122
- )
123
- if (
124
- args.print_diffusion_steps
125
- and "step_callback" in inspect.signature(model.generate_mask_diffusion).parameters
126
- ):
127
- generate_kwargs["step_callback"] = step_callback
128
- outputs = model.generate_mask_diffusion(inputs["input_ids"], **generate_kwargs)
129
-
130
- if torch.cuda.is_available():
131
- torch.cuda.synchronize()
132
- elapsed = max(time.perf_counter() - t0, 1e-9)
133
-
134
- prompt_len = inputs["input_ids"].shape[1]
135
- gen_ids = outputs[0][prompt_len:]
136
- n_new = int(gen_ids.shape[0])
137
- text = tokenizer.decode(gen_ids, skip_special_tokens=True)
138
- nfe = getattr(model, "_last_nfe", None)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
139
 
140
  print("\n" + "=" * 70)
141
- print("--- Request 1/1 ---")
142
- print(f"Prompt: {prompt}")
143
- _nfe_str = f"{nfe} NFE, " if (args.mode == "mask_diffusion" and nfe is not None) else ""
144
- print(f"Generated ({_nfe_str}{n_new} tokens, {elapsed:.2f}s, {n_new / elapsed:.1f} tok/s):")
145
- print(text)
146
- print("=" * 70)
147
  if args.mode == "mask_diffusion":
148
  print("Two-Tower mask-diffusion generation complete")
149
  print("=" * 70)
@@ -156,8 +180,9 @@ if args.mode == "mask_diffusion":
156
  print(f" top_k: {args.top_k}")
157
  print(f" confidence_threshold: {args.confidence_threshold}")
158
  print(f" mask_token_id: {args.mask_token_id}")
159
- print(f" NFE: {nfe}")
160
- print(f" wall_clock: {elapsed:.2f}s")
161
- print(f" throughput: {n_new / elapsed:.1f} tokens/s")
162
  print(f" model: {args.model}")
163
  print("=" * 70)
 
 
 
 
28
  parser = argparse.ArgumentParser()
29
  parser.add_argument("prompt_arg", nargs="?", default=None)
30
  parser.add_argument("--prompt", default=None)
31
+ parser.add_argument("--prompt-file", dest="prompt_file", default=None,
32
+ help="jsonl of {\"text\": ...} per line (same format as mcore "
33
+ "--prompt-file); each line is run as its own Request i/N.")
34
  parser.add_argument("--model", default=str(Path(__file__).resolve().parent))
35
  parser.add_argument("--max-new-tokens", type=int, default=128)
36
  parser.add_argument("--mode", choices=["ar", "mock_ar", "mask_diffusion"], default="mock_ar")
 
77
  model.eval()
78
  model.trace_context_layers = args.trace_context_layers
79
  model.trace_denoiser_layers = args.trace_denoiser_layers
80
+ # Build the request list. A --prompt-file (jsonl, one {"text": ...} per line,
81
+ # same format mcore consumes) runs as multiple Requests i/N; otherwise the
82
+ # single positional/--prompt is the lone request.
83
+ if args.prompt_file:
84
+ import json
85
+ prompts = []
86
+ with open(args.prompt_file) as f:
87
+ for line in f:
88
+ line = line.strip()
89
+ if line:
90
+ prompts.append(json.loads(line)["text"])
91
+ if not prompts:
92
+ raise ValueError(f"No prompts found in {args.prompt_file}")
 
 
 
 
93
  else:
94
+ prompts = [prompt]
95
+
96
+
97
+ def step_callback(step_idx, total_steps, tokens, t=None, logits=None, block_idx=0):
98
+ if not args.print_diffusion_steps:
99
+ return
100
+ if logits is None:
101
+ print(f"\n--- Block {block_idx} Step {step_idx}/{total_steps} | init ---")
102
+ print("xt:", tokenizer.decode(tokens[0], skip_special_tokens=False))
103
+ return
104
+ log_x = model._mdlm_forward(logits, tokens.to(logits.device), args.mask_token_id)
105
+ probs = log_x.exp()[0]
106
+ top2_probs, top2_ids = probs.topk(2, dim=-1)
107
+ n_masked = int((tokens == args.mask_token_id).sum().item())
108
+ print(f"\n--- Block {block_idx} Step {step_idx}/{total_steps} | masked={n_masked}/{tokens.shape[1]} | t={t:.4f} ---")
109
+ print("xt: " + repr(tokenizer.decode(tokens[0], skip_special_tokens=False)))
110
+ print("top1: " + "|".join(tokenizer.decode([tid.item()])[:9].rjust(9) for tid in top2_ids[:, 0]))
111
+ print("prb1: " + "|".join(f"{p.item():.3f}".rjust(9) for p in top2_probs[:, 0]))
112
+ print("top2: " + "|".join(tokenizer.decode([tid.item()])[:9].rjust(9) for tid in top2_ids[:, 1]))
113
+ print("prb2: " + "|".join(f"{p.item():.3f}".rjust(9) for p in top2_probs[:, 1]))
114
+
115
+
116
+ ctx_device = next(model.context_tower.parameters()).device
117
+ n_requests = len(prompts)
118
+ for ridx, prompt in enumerate(prompts):
119
+ inputs = tokenizer(prompt, return_tensors="pt").to(ctx_device)
120
+ if args.print_diffusion_steps and args.mode == "mask_diffusion":
121
+ print(f"\n--- Diffusion steps for request {ridx + 1} ---")
122
+
123
+ t0 = time.perf_counter()
124
+ if args.mode == "ar":
125
+ # Context-tower-only AR via our cached single-step path (the fair ST-AR
126
+ # baseline). Avoids HF generate()'s cache path that crashes on this env.
127
+ outputs = model.generate_ar(
128
+ inputs["input_ids"], max_new_tokens=args.max_new_tokens,
129
+ temperature=0.0, eos_token_id=tokenizer.eos_token_id,
130
+ )
131
+ elif args.mode == "mock_ar":
132
+ outputs = model.generate_mock_ar(
133
+ inputs["input_ids"], max_new_tokens=args.max_new_tokens,
134
+ temperature=0.0, eos_token_id=tokenizer.eos_token_id,
135
+ )
136
+ else:
137
+ generate_kwargs = dict(
138
+ max_new_tokens=args.max_new_tokens,
139
+ block_size=args.block_size,
140
+ steps_per_block=args.steps_per_block,
141
+ mask_token_id=args.mask_token_id,
142
+ temperature=args.temperature,
143
+ top_k=args.top_k,
144
+ confidence_threshold=args.confidence_threshold,
145
+ eos_token_id=tokenizer.eos_token_id,
146
+ )
147
+ if (
148
+ args.print_diffusion_steps
149
+ and "step_callback" in inspect.signature(model.generate_mask_diffusion).parameters
150
+ ):
151
+ generate_kwargs["step_callback"] = step_callback
152
+ outputs = model.generate_mask_diffusion(inputs["input_ids"], **generate_kwargs)
153
+
154
+ if torch.cuda.is_available():
155
+ torch.cuda.synchronize()
156
+ elapsed = max(time.perf_counter() - t0, 1e-9)
157
+
158
+ prompt_len = inputs["input_ids"].shape[1]
159
+ gen_ids = outputs[0][prompt_len:]
160
+ n_new = int(gen_ids.shape[0])
161
+ text = tokenizer.decode(gen_ids, skip_special_tokens=True)
162
+ nfe = getattr(model, "_last_nfe", None)
163
+
164
+ print(f"\n--- Request {ridx + 1}/{n_requests} ---")
165
+ print(f"Prompt: {prompt}")
166
+ _nfe_str = f"{nfe} NFE, " if (args.mode == "mask_diffusion" and nfe is not None) else ""
167
+ print(f"Generated ({_nfe_str}{n_new} tokens, {elapsed:.2f}s, {n_new / elapsed:.1f} tok/s):")
168
+ print(text)
169
 
170
  print("\n" + "=" * 70)
 
 
 
 
 
 
171
  if args.mode == "mask_diffusion":
172
  print("Two-Tower mask-diffusion generation complete")
173
  print("=" * 70)
 
180
  print(f" top_k: {args.top_k}")
181
  print(f" confidence_threshold: {args.confidence_threshold}")
182
  print(f" mask_token_id: {args.mask_token_id}")
183
+ print(f" num_requests: {n_requests}")
 
 
184
  print(f" model: {args.model}")
185
  print("=" * 70)
186
+ else:
187
+ print("Two-tower generation complete")
188
+ print("=" * 70)
modeling_nemotron_twotower.py CHANGED
@@ -554,19 +554,28 @@ class NemotronHTwoTowerForCausalLM(NemotronHPreTrainedModel, GenerationMixin):
554
  B_proj = rearrange(B_proj, "b s (g n) -> b s g n", n=d_state).contiguous()
555
  C_proj = rearrange(C_proj, "b s (g n) -> b s g n", n=d_state).contiguous()
556
 
 
 
 
 
 
 
 
 
557
  A = -torch.exp(mixer.A_log.float())
558
  scan = mamba_chunk_scan_combined(
559
- x, dt.contiguous(), A, B_proj, C_proj, mixer.chunk_size,
560
- D=mixer.D, z=None,
 
561
  dt_bias=mixer.dt_bias.float(), dt_softplus=True,
562
- initial_states=init_ssm,
563
  return_final_states=return_states,
564
  )
565
  if return_states:
566
  y, new_ssm = scan
567
  else:
568
  y = scan
569
- y = rearrange(y, "b s h p -> b s (h p)")
570
  y = mixer.norm(y, z) # Mamba2 z-gated RMSNorm
571
  out = mixer.out_proj(y)
572
  if not return_states:
 
554
  B_proj = rearrange(B_proj, "b s (g n) -> b s g n", n=d_state).contiguous()
555
  C_proj = rearrange(C_proj, "b s (g n) -> b s g n", n=d_state).contiguous()
556
 
557
+ # Run the SSM scan in fp32. With a long context the seeded SSM state gets
558
+ # large (O(1e3)+); the bf16 chunk-scan then overflows to NaN, and because
559
+ # the Triton kernel's reductions are not bit-deterministic this strikes
560
+ # nondeterministically (a NaN on a block's first/all-masked step force-
561
+ # commits a garbage token, e.g. "katalog"/"hips", and wrecks the answer).
562
+ # The scan spans only one block (<=16 tokens) so fp32 is essentially free,
563
+ # and it is strictly more accurate. Cast back before the gated norm.
564
+ _y_dtype = z.dtype
565
  A = -torch.exp(mixer.A_log.float())
566
  scan = mamba_chunk_scan_combined(
567
+ x.float(), dt.float().contiguous(), A, B_proj.float(), C_proj.float(),
568
+ mixer.chunk_size,
569
+ D=mixer.D.float(), z=None,
570
  dt_bias=mixer.dt_bias.float(), dt_softplus=True,
571
+ initial_states=(init_ssm.float() if init_ssm is not None else None),
572
  return_final_states=return_states,
573
  )
574
  if return_states:
575
  y, new_ssm = scan
576
  else:
577
  y = scan
578
+ y = rearrange(y, "b s h p -> b s (h p)").to(_y_dtype)
579
  y = mixer.norm(y, z) # Mamba2 z-gated RMSNorm
580
  out = mixer.out_proj(y)
581
  if not return_states: