Text Generation
Transformers
Safetensors
PyTorch
nvidia
two-tower
diffusion
mamba
fitsumreda commited on
Commit
b348e21
·
1 Parent(s): da2556e

faster inferneece

Browse files
inference.py CHANGED
@@ -17,6 +17,7 @@ Usage:
17
  """
18
  import argparse
19
  import inspect
 
20
  import torch
21
  import random
22
  import numpy as np
@@ -77,6 +78,7 @@ inputs = tokenizer(prompt, return_tensors="pt").to(
77
  next(model.context_tower.parameters()).device
78
  )
79
 
 
80
  if args.mode == "ar":
81
  outputs = model.generate(**inputs, max_new_tokens=args.max_new_tokens, do_sample=False)
82
  elif args.mode == "mock_ar":
@@ -120,5 +122,37 @@ else:
120
  generate_kwargs["step_callback"] = step_callback
121
  outputs = model.generate_mask_diffusion(inputs["input_ids"], **generate_kwargs)
122
 
123
- text = tokenizer.decode(outputs[0][inputs["input_ids"].shape[1]:], skip_special_tokens=True)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
124
  print(text)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
17
  """
18
  import argparse
19
  import inspect
20
+ import time
21
  import torch
22
  import random
23
  import numpy as np
 
78
  next(model.context_tower.parameters()).device
79
  )
80
 
81
+ t0 = time.perf_counter()
82
  if args.mode == "ar":
83
  outputs = model.generate(**inputs, max_new_tokens=args.max_new_tokens, do_sample=False)
84
  elif args.mode == "mock_ar":
 
122
  generate_kwargs["step_callback"] = step_callback
123
  outputs = model.generate_mask_diffusion(inputs["input_ids"], **generate_kwargs)
124
 
125
+ if torch.cuda.is_available():
126
+ torch.cuda.synchronize()
127
+ elapsed = max(time.perf_counter() - t0, 1e-9)
128
+
129
+ prompt_len = inputs["input_ids"].shape[1]
130
+ gen_ids = outputs[0][prompt_len:]
131
+ n_new = int(gen_ids.shape[0])
132
+ text = tokenizer.decode(gen_ids, skip_special_tokens=True)
133
+ nfe = getattr(model, "_last_nfe", None)
134
+
135
+ print("\n" + "=" * 70)
136
+ print("--- Request 1/1 ---")
137
+ print(f"Prompt: {prompt}")
138
+ _nfe_str = f"{nfe} NFE, " if (args.mode == "mask_diffusion" and nfe is not None) else ""
139
+ print(f"Generated ({_nfe_str}{n_new} tokens, {elapsed:.2f}s, {n_new / elapsed:.1f} tok/s):")
140
  print(text)
141
+ print("=" * 70)
142
+ if args.mode == "mask_diffusion":
143
+ print("Two-Tower mask-diffusion generation complete")
144
+ print("=" * 70)
145
+ print(f" mode: {args.mode}")
146
+ print(f" block_size: {args.block_size}")
147
+ print(f" steps_per_block: {args.steps_per_block}")
148
+ print(f" max_new_tokens: {args.max_new_tokens}")
149
+ print(f" num_blocks: {args.max_new_tokens // args.block_size}")
150
+ print(f" temperature: {args.temperature}")
151
+ print(f" top_k: {args.top_k}")
152
+ print(f" confidence_threshold: {args.confidence_threshold}")
153
+ print(f" mask_token_id: {args.mask_token_id}")
154
+ print(f" NFE: {nfe}")
155
+ print(f" wall_clock: {elapsed:.2f}s")
156
+ print(f" throughput: {n_new / elapsed:.1f} tokens/s")
157
+ print(f" model: {args.model}")
158
+ print("=" * 70)
modeling_nemotron_h.py CHANGED
@@ -853,8 +853,12 @@ class NemotronHMOE(nn.Module):
853
  expert_output = expert(expert_input)
854
  weighted_output = expert_output * expert_weights.unsqueeze(-1)
855
  final_hidden_states.index_add_(0, token_indices, weighted_output)
856
- else:
857
- # Local empty expert: no-op compute that still marks params as used.
 
 
 
 
858
  expert_dtype = expert.down_proj.weight.dtype
859
  dummy_out = expert(torch.zeros_like(hidden_states[0]).unsqueeze(0).to(expert_dtype))
860
  final_hidden_states = final_hidden_states + dummy_out
 
853
  expert_output = expert(expert_input)
854
  weighted_output = expert_output * expert_weights.unsqueeze(-1)
855
  final_hidden_states.index_add_(0, token_indices, weighted_output)
856
+ elif self.training:
857
+ # Training only: no-op forward on a zero token so DDP/grad hooks
858
+ # mark every expert's params as "used". It adds exactly 0 (no
859
+ # biases: relu2(0)=0, down_proj(0)=0), so it's numerically inert.
860
+ # Skipped at inference, where it would otherwise cost ~100+
861
+ # pointless expert MLP calls per MoE layer per step.
862
  expert_dtype = expert.down_proj.weight.dtype
863
  dummy_out = expert(torch.zeros_like(hidden_states[0]).unsqueeze(0).to(expert_dtype))
864
  final_hidden_states = final_hidden_states + dummy_out
modeling_nemotron_twotower.py CHANGED
@@ -564,7 +564,7 @@ class NemotronHTwoTowerForCausalLM(NemotronHPreTrainedModel, GenerationMixin):
564
  new_conv = comb[:, -d_conv:, :].transpose(1, 2).contiguous()
565
  return out, new_conv, new_ssm
566
 
567
- def _run_denoiser_step_diffusion(self, block_ids, cache_state, t=None):
568
  """Diffusion denoiser forward over the FULL block (B, L) in one pass.
569
 
570
  Parity with mcore `_run_denoiser_step`:
@@ -601,8 +601,13 @@ class NemotronHTwoTowerForCausalLM(NemotronHPreTrainedModel, GenerationMixin):
601
  t_repr = self.t_embedder(t_dev)
602
  t_emb = self.t_block(t_repr)
603
 
604
- # Fresh denoiser cache seeded from context: Mamba S-1 state + full context KV.
605
- den_cache = self._build_denoiser_cache_diffusion(cache_state, den_device)
 
 
 
 
 
606
 
607
  hidden = tower.embeddings(den_input)
608
 
@@ -758,8 +763,14 @@ class NemotronHTwoTowerForCausalLM(NemotronHPreTrainedModel, GenerationMixin):
758
 
759
  cache_state = self._build_context_cache(input_ids)
760
  context_ids = input_ids.clone()
 
761
 
 
762
  for block_idx in range(num_blocks):
 
 
 
 
763
  # Initialize fully masked block
764
  xt = torch.full((B, block_size), mask_token_id, dtype=torch.long,
765
  device=device)
@@ -777,7 +788,8 @@ class NemotronHTwoTowerForCausalLM(NemotronHPreTrainedModel, GenerationMixin):
777
  t_vec = t_model.expand(B).to(device)
778
 
779
  # Denoiser forward (logits come back on denoiser device, move to xt's device)
780
- logits = self._run_denoiser_step_diffusion(xt, cache_state, t=t_vec)
 
781
  logits = logits.to(device)
782
 
783
  # p(x0|xt) with constraints
@@ -848,6 +860,8 @@ class NemotronHTwoTowerForCausalLM(NemotronHPreTrainedModel, GenerationMixin):
848
  if eos_token_id is not None and (xt == eos_token_id).any():
849
  break
850
 
 
 
851
  return context_ids
852
 
853
  # ------------------------------------------------------------------
 
564
  new_conv = comb[:, -d_conv:, :].transpose(1, 2).contiguous()
565
  return out, new_conv, new_ssm
566
 
567
+ def _run_denoiser_step_diffusion(self, block_ids, cache_state, t=None, den_cache=None):
568
  """Diffusion denoiser forward over the FULL block (B, L) in one pass.
569
 
570
  Parity with mcore `_run_denoiser_step`:
 
601
  t_repr = self.t_embedder(t_dev)
602
  t_emb = self.t_block(t_repr)
603
 
604
+ # Denoiser cache (context Mamba S-1 state + full context KV). It is
605
+ # READ-ONLY here and identical for every step within a block, so the
606
+ # caller should build it once per block and pass it in (avoids cloning +
607
+ # cuda:0->cuda:1 copying the whole context cache on every NFE). Fall back
608
+ # to building it if not provided.
609
+ if den_cache is None:
610
+ den_cache = self._build_denoiser_cache_diffusion(cache_state, den_device)
611
 
612
  hidden = tower.embeddings(den_input)
613
 
 
763
 
764
  cache_state = self._build_context_cache(input_ids)
765
  context_ids = input_ids.clone()
766
+ nfe = 0 # number of denoiser forward passes (network function evaluations)
767
 
768
+ den_device = next(self.denoiser_tower.parameters()).device
769
  for block_idx in range(num_blocks):
770
+ # Build the denoiser cache ONCE per block (context is fixed within a
771
+ # block); reused by every denoising step to avoid per-NFE clone+copy.
772
+ den_cache = self._build_denoiser_cache_diffusion(cache_state, den_device)
773
+
774
  # Initialize fully masked block
775
  xt = torch.full((B, block_size), mask_token_id, dtype=torch.long,
776
  device=device)
 
788
  t_vec = t_model.expand(B).to(device)
789
 
790
  # Denoiser forward (logits come back on denoiser device, move to xt's device)
791
+ logits = self._run_denoiser_step_diffusion(xt, cache_state, t=t_vec, den_cache=den_cache)
792
+ nfe += 1
793
  logits = logits.to(device)
794
 
795
  # p(x0|xt) with constraints
 
860
  if eos_token_id is not None and (xt == eos_token_id).any():
861
  break
862
 
863
+ # Expose NFE (denoiser forward passes) for reporting, e.g. inference.py.
864
+ self._last_nfe = nfe
865
  return context_ids
866
 
867
  # ------------------------------------------------------------------