Two-tower mask diffusion: fix denoiser (adaLN norm order, bidirectional in-block attention, block-wise chunk-scan Mamba) + fp64 router; refresh README
Browse files
- README.md +41 -21
- config.json +1 -1
- inference.py +76 -5
- modeling_nemotron_h.py +4 -2
- modeling_nemotron_twotower.py +242 -29
README.md
CHANGED
|
@@ -77,19 +77,31 @@ Both towers share the same architecture (52 layers, `MEMEM*EMEMEM*...` hybrid pa
|
|
| 77 |
|
| 78 |
### Two-Tower Generation Modes
|
| 79 |
|
| 80 |
-
| Mode | Description | Tokens/step |
|
| 81 |
-
|------|-------------|-------------|
|
| 82 |
-
| **
|
| 83 |
-
| **Mock-AR** | Two-tower autoregressive. Context tower builds cache, denoiser predicts next token. | 1 |
|
| 84 |
-
| **
|
| 85 |
|
| 86 |
### What is Two-Tower?
|
| 87 |
|
| 88 |
-
The two-tower architecture decouples
|
| 89 |
|
| 90 |
-
|
| 91 |
-
|
| 92 |
-
|
| 93 |
|
| 94 |
This model is ready for commercial use.
|
| 95 |
|
|
@@ -142,7 +154,7 @@ Software used for training: [Megatron-LM](https://github.com/NVIDIA/Megatron-LM)
|
|
| 142 |
|
| 143 |
### Use it with Transformers
|
| 144 |
|
| 145 |
-
The snippet below shows how to use this model with HuggingFace Transformers. **Two-tower inference requires 2 GPUs** (~59GB per GPU for bf16 weights).
|
| 146 |
|
| 147 |
```python
|
| 148 |
import torch
|
|
@@ -156,39 +168,47 @@ model = AutoModelForCausalLM.from_pretrained(
|
|
| 156 |
trust_remote_code=True,
|
| 157 |
)
|
| 158 |
|
| 159 |
-
#
|
| 160 |
model.place_towers_on_devices("cuda:0", "cuda:1")
|
| 161 |
model.eval()
|
| 162 |
|
| 163 |
prompt = "France is a country "
|
| 164 |
inputs = tokenizer(prompt, return_tensors="pt").to("cuda:0")
|
| 165 |
|
| 166 |
-
#
|
| 167 |
-
outputs = model.
|
| 168 |
inputs["input_ids"],
|
| 169 |
max_new_tokens=128,
|
| 170 |
-
|
| 171 |
eos_token_id=tokenizer.eos_token_id,
|
| 172 |
)
|
| 173 |
|
| 174 |
print(tokenizer.decode(outputs[0], skip_special_tokens=True))
|
| 175 |
```
|
| 176 |
|
| 177 |
-
|
| 178 |
|
| 179 |
```python
|
| 180 |
-
model = AutoModelForCausalLM.from_pretrained(
|
| 181 |
-
model_name,
|
| 182 |
-
torch_dtype=torch.bfloat16,
|
| 183 |
-
trust_remote_code=True,
|
| 184 |
-
).cuda()
|
| 185 |
-
|
| 186 |
outputs = model.generate(**inputs, max_new_tokens=128, do_sample=False)
|
| 187 |
print(tokenizer.decode(outputs[0], skip_special_tokens=True))
|
| 188 |
```
|
| 189 |
|
| 190 |
## Model Version(s)
|
| 191 |
|
|
|
|
| 192 |
- v1.0 — Two-tower AR (mock-AR) checkpoint
|
| 193 |
|
| 194 |
# Training, Testing, and Evaluation Datasets
|
|
|
|
| 77 |
|
| 78 |
### Two-Tower Generation Modes
|
| 79 |
|
| 80 |
+
| Mode | Description | Tokens/step | API |
|
| 81 |
+
|------|-------------|-------------|-----|
|
| 82 |
+
| **Mask Diffusion** | Block-wise iterative denoising with confidence-based unmasking (flagship two-tower mode). | up to `block_size` | `generate_mask_diffusion()` |
|
| 83 |
+
| **Mock-AR** | Two-tower autoregressive. Context tower builds cache, denoiser predicts next token. | 1 | `generate_mock_ar()` |
|
| 84 |
+
| **AR** | Standard autoregressive via `generate()`. Uses context tower only (single GPU). | 1 | `generate()` |
|
| 85 |
|
| 86 |
### What is Two-Tower?
|
| 87 |
|
| 88 |
+
The two-tower architecture decouples "understanding context" from "generating tokens" into separate networks:
|
| 89 |
|
| 90 |
+
- **Context Tower** runs causally over the prompt and all previously committed tokens, producing the layer-aligned KV cache (attention) and Mamba states that the denoiser conditions on.
|
| 91 |
+
- **Denoiser Tower** generates a *block* of tokens at once. Within a block it is **bidirectional** (every position attends to the whole noisy block + the full causal context); across blocks it is causal via the context cache.
|
| 92 |
+
|
| 93 |
+
This enables **block-wise parallel generation**: the denoiser fills `block_size` masked positions per block and commits the most confident ones each step, so a block resolves in at most `steps_per_block` denoising steps, and in fewer when many positions clear `confidence_threshold` early, rather than always taking `block_size` autoregressive steps.
|
| 94 |
+
|
| 95 |
+
### Mask Diffusion: how it works
|
| 96 |
+
|
| 97 |
+
Generation proceeds block by block. For each new block of `block_size` positions:
|
| 98 |
+
|
| 99 |
+
1. Initialize the block as all `[MASK]` tokens (`mask_token_id`).
|
| 100 |
+
2. For `steps_per_block` iterations:
|
| 101 |
+
- Compute the diffusion timestep `t` = current masked fraction of the block, and feed it to the **time-conditioned denoiser** (PixArt-α adaLN-single modulation on every denoiser layer).
|
| 102 |
+
- Run the denoiser over the whole block (bidirectional self-attention + cross-attention to the context cache; Mamba chunk-scan seeded from the context state).
|
| 103 |
+
- Constrain to `p(x₀ | xₜ)` (mask token forbidden; already-decoded positions fixed), then **commit** the highest-confidence positions (all above `confidence_threshold`, with a floor that guarantees completion in `steps_per_block`) and re-mask the rest.
|
| 104 |
+
3. Append the resolved block to the context, extend the context cache, and continue.
|
| 105 |
|
| 106 |
This model is ready for commercial use.
|
| 107 |
|
|
|
|
| 154 |
|
| 155 |
### Use it with Transformers
|
| 156 |
|
| 157 |
+
The snippet below shows how to use this model with HuggingFace Transformers. **Two-tower inference requires 2 GPUs** (~59GB per GPU for bf16 weights); the towers are placed on separate devices.
|
| 158 |
|
| 159 |
```python
|
| 160 |
import torch
|
|
|
|
| 168 |
trust_remote_code=True,
|
| 169 |
)
|
| 170 |
|
| 171 |
+
# Context tower -> GPU 0, denoiser tower -> GPU 1
|
| 172 |
model.place_towers_on_devices("cuda:0", "cuda:1")
|
| 173 |
model.eval()
|
| 174 |
|
| 175 |
prompt = "France is a country "
|
| 176 |
inputs = tokenizer(prompt, return_tensors="pt").to("cuda:0")
|
| 177 |
|
| 178 |
+
# Flagship mode: block-wise mask diffusion
|
| 179 |
+
outputs = model.generate_mask_diffusion(
|
| 180 |
inputs["input_ids"],
|
| 181 |
max_new_tokens=128,
|
| 182 |
+
block_size=16, # tokens generated per block
|
| 183 |
+
steps_per_block=16, # denoising iterations per block
|
| 184 |
+
mask_token_id=3, # <mask>
|
| 185 |
+
temperature=0.1,
|
| 186 |
+
confidence_threshold=0.8, # commit positions above this confidence each step
|
| 187 |
eos_token_id=tokenizer.eos_token_id,
|
| 188 |
)
|
| 189 |
+
print(tokenizer.decode(outputs[0][inputs["input_ids"].shape[1]:], skip_special_tokens=True))
|
| 190 |
+
```
|
| 191 |
|
| 192 |
+
**Mock-AR** (two-tower, one token per step):
|
| 193 |
+
|
| 194 |
+
```python
|
| 195 |
+
outputs = model.generate_mock_ar(
|
| 196 |
+
inputs["input_ids"], max_new_tokens=128, temperature=0.0,
|
| 197 |
+
eos_token_id=tokenizer.eos_token_id,
|
| 198 |
+
)
|
| 199 |
print(tokenizer.decode(outputs[0], skip_special_tokens=True))
|
| 200 |
```
|
| 201 |
|
| 202 |
+
**AR-only** (single GPU, context tower only — load with `.cuda()` instead of `place_towers_on_devices`):
|
| 203 |
|
| 204 |
```python
|
| 205 |
outputs = model.generate(**inputs, max_new_tokens=128, do_sample=False)
|
| 206 |
print(tokenizer.decode(outputs[0], skip_special_tokens=True))
|
| 207 |
```
|
| 208 |
|
| 209 |
## Model Version(s)
|
| 210 |
|
| 211 |
+
- v1.1 — Block-wise **mask-diffusion** generation enabled (time-conditioned denoiser, bidirectional in-block attention, chunk-scan Mamba); AR and mock-AR also supported.
|
| 212 |
- v1.0 — Two-tower AR (mock-AR) checkpoint
|
| 213 |
|
| 214 |
# Training, Testing, and Evaluation Datasets
|
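The block-wise loop described in the README's "Mask Diffusion: how it works" steps can be sketched in plain Python. This is a toy illustration under stated assumptions, not the repository's implementation: `toy_denoiser`, `denoise_block`, the `MASK` sentinel, and the exact form of the completion floor (`ceil(masked / remaining_steps)`) are hypothetical names and choices made for the example, and random confidences stand in for the real time-conditioned denoiser.

```python
import math
import random

MASK = -1  # hypothetical sentinel for this toy; the README example passes mask_token_id=3


def toy_denoiser(block, t, rng):
    """Stand-in for the denoiser: one (token, confidence) pair per position."""
    return [(rng.randrange(100), rng.random()) for _ in block]


def denoise_block(block_size, steps_per_block, confidence_threshold, seed=0):
    rng = random.Random(seed)
    block = [MASK] * block_size  # step 1: start fully masked
    for step in range(steps_per_block):
        masked = [i for i, tok in enumerate(block) if tok == MASK]
        if not masked:
            break
        t = len(masked) / block_size  # timestep = current masked fraction
        preds = toy_denoiser(block, t, rng)
        # Assumed floor: commit enough positions to finish in the remaining steps.
        floor = math.ceil(len(masked) / (steps_per_block - step))
        ranked = sorted(masked, key=lambda i: preds[i][1], reverse=True)
        above = [i for i in ranked if preds[i][1] > confidence_threshold]
        # Commit everything above the threshold, but never fewer than the floor;
        # positions that are not committed simply stay masked for the next step.
        for i in ranked[:max(floor, len(above))]:
            block[i] = preds[i][0]
    return block


resolved = denoise_block(block_size=16, steps_per_block=4, confidence_threshold=0.8)
```

Because the floor on the last step equals the number of positions still masked, the block is always fully resolved after `steps_per_block` iterations, even when no prediction ever clears the threshold.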
config.json
CHANGED
|
@@ -54,7 +54,7 @@
|
|
| 54 |
"time_step_floor": 0.0001,
|
| 55 |
"time_step_limit": [
|
| 56 |
0.0,
|
| 57 |
-
|
| 58 |
],
|
| 59 |
"time_step_max": 0.1,
|
| 60 |
"time_step_min": 0.001,
|
|
|
|
| 54 |
"time_step_floor": 0.0001,
|
| 55 |
"time_step_limit": [
|
| 56 |
0.0,
|
| 57 |
+
Infinity
|
| 58 |
],
|
| 59 |
"time_step_max": 0.1,
|
| 60 |
"time_step_min": 0.001,
|
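A note on the `Infinity` literal added to `time_step_limit`: it is not valid strict JSON, but Python's standard `json` module accepts it by default and yields a float infinity. The snippet below is a minimal check using a hand-written fragment that mirrors only the changed field; the `fragment` string is an assumption for illustration, not the full `config.json`.

```python
import json

# Hand-written fragment mirroring only the changed field of config.json.
fragment = '{"time_step_limit": [0.0, Infinity]}'

limits = json.loads(fragment)["time_step_limit"]
lower, upper = limits

# json.dumps writes the same literal back unless allow_nan=False is passed.
round_trip = json.dumps({"time_step_limit": limits})
```

Tools that enforce strict JSON may reject the file, so anything that reads this config outside Python is worth checking.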
inference.py
CHANGED
|
@@ -11,19 +11,46 @@ Usage:
|
|
| 11 |
|
| 12 |
# AR (context tower only, 1 GPU):
|
| 13 |
python inference.py --mode ar
|
| 14 |
"""
|
| 15 |
import argparse
|
|
|
|
| 16 |
import torch
|
|
|
|
|
|
|
| 17 |
from pathlib import Path
|
| 18 |
from transformers import AutoTokenizer
|
| 19 |
from modeling_nemotron_twotower import NemotronHTwoTowerForCausalLM
|
| 20 |
|
| 21 |
parser = argparse.ArgumentParser()
|
| 22 |
-
parser.add_argument("
|
|
|
|
| 23 |
parser.add_argument("--model", default=str(Path(__file__).resolve().parent))
|
| 24 |
parser.add_argument("--max-new-tokens", type=int, default=128)
|
| 25 |
-
parser.add_argument("--mode", choices=["ar", "mock_ar"], default="mock_ar")
|
| 26 |
args = parser.parse_args()
|
| 27 |
|
| 28 |
tokenizer = AutoTokenizer.from_pretrained(args.model)
|
| 29 |
model = NemotronHTwoTowerForCausalLM.from_pretrained(
|
|
@@ -31,23 +58,67 @@ model = NemotronHTwoTowerForCausalLM.from_pretrained(
|
|
| 31 |
)
|
| 32 |
|
| 33 |
num_gpus = torch.cuda.device_count()
|
| 34 |
-
if
|
|
|
|
|
|
|
| 35 |
model.place_towers_on_devices("cuda:0", "cuda:1")
|
| 36 |
else:
|
| 37 |
model.cuda()
|
| 38 |
|
| 39 |
model.eval()
|
| 40 |
-
|
|
|
|
|
|
|
| 41 |
next(model.context_tower.parameters()).device
|
| 42 |
)
|
| 43 |
|
| 44 |
if args.mode == "ar":
|
| 45 |
outputs = model.generate(**inputs, max_new_tokens=args.max_new_tokens, do_sample=False)
|
| 46 |
-
|
| 47 |
outputs = model.generate_mock_ar(
|
| 48 |
inputs["input_ids"], max_new_tokens=args.max_new_tokens,
|
| 49 |
temperature=0.0, eos_token_id=tokenizer.eos_token_id,
|
| 50 |
)
|
| 51 |
|
| 52 |
text = tokenizer.decode(outputs[0][inputs["input_ids"].shape[1]:], skip_special_tokens=True)
|
| 53 |
print(text)
|
|
|
|
| 11 |
|
| 12 |
# AR (context tower only, 1 GPU):
|
| 13 |
python inference.py --mode ar
|
| 14 |
+
|
| 15 |
+
# Mask diffusion (two-tower, 2 GPUs):
|
| 16 |
+
python inference.py --mode mask_diffusion --model /path/to/diffusion_hf_out
|
| 17 |
"""
|
| 18 |
import argparse
|
| 19 |
+
import inspect
|
| 20 |
import torch
|
| 21 |
+
import random
|
| 22 |
+
import numpy as np
|
| 23 |
from pathlib import Path
|
| 24 |
from transformers import AutoTokenizer
|
| 25 |
from modeling_nemotron_twotower import NemotronHTwoTowerForCausalLM
|
| 26 |
|
| 27 |
parser = argparse.ArgumentParser()
|
| 28 |
+
parser.add_argument("prompt_arg", nargs="?", default=None)
|
| 29 |
+
parser.add_argument("--prompt", default=None)
|
| 30 |
parser.add_argument("--model", default=str(Path(__file__).resolve().parent))
|
| 31 |
parser.add_argument("--max-new-tokens", type=int, default=128)
|
| 32 |
+
parser.add_argument("--mode", choices=["ar", "mock_ar", "mask_diffusion"], default="mock_ar")
|
| 33 |
+
parser.add_argument("--block-size", type=int, default=16)
|
| 34 |
+
parser.add_argument("--steps-per-block", type=int, default=16)
|
| 35 |
+
parser.add_argument("--mask-token-id", type=int, default=3)
|
| 36 |
+
parser.add_argument("--temperature", type=float, default=0.0)
|
| 37 |
+
parser.add_argument("--top-k", "--top_k", dest="top_k", type=int, default=None)
|
| 38 |
+
parser.add_argument("--confidence-threshold", type=float, default=0.9)
|
| 39 |
+
parser.add_argument("--deterministic", action="store_true")
|
| 40 |
+
parser.add_argument("--seed", type=int, default=42)
|
| 41 |
+
parser.add_argument("--print-diffusion-steps", action="store_true")
|
| 42 |
+
parser.add_argument("--trace-context-layers", action="store_true")
|
| 43 |
+
parser.add_argument("--trace-denoiser-layers", action="store_true")
|
| 44 |
args = parser.parse_args()
|
| 45 |
+
prompt = args.prompt if args.prompt is not None else (args.prompt_arg or "France is a country ")
|
| 46 |
+
|
| 47 |
+
if args.deterministic:
|
| 48 |
+
random.seed(args.seed)
|
| 49 |
+
np.random.seed(args.seed)
|
| 50 |
+
torch.manual_seed(args.seed)
|
| 51 |
+
torch.cuda.manual_seed_all(args.seed)
|
| 52 |
+
torch.backends.cudnn.deterministic = True
|
| 53 |
+
torch.backends.cudnn.benchmark = False
|
| 54 |
|
| 55 |
tokenizer = AutoTokenizer.from_pretrained(args.model)
|
| 56 |
model = NemotronHTwoTowerForCausalLM.from_pretrained(
|
|
|
|
| 58 |
)
|
| 59 |
|
| 60 |
num_gpus = torch.cuda.device_count()
|
| 61 |
+
if num_gpus >= 2:
|
| 62 |
+
# Split towers across GPUs (the two towers together do not fit on one 80GB card).
|
| 63 |
+
# AR mode only uses the context tower (cuda:0), but placing both is fine.
|
| 64 |
model.place_towers_on_devices("cuda:0", "cuda:1")
|
| 65 |
+
elif args.mode == "ar":
|
| 66 |
+
# AR uses only the context tower + context head; keep the denoiser tower
|
| 67 |
+
# off the GPU so a single card suffices.
|
| 68 |
+
model.context_tower = model.context_tower.cuda()
|
| 69 |
+
model.context_lm_head = model.context_lm_head.cuda()
|
| 70 |
else:
|
| 71 |
model.cuda()
|
| 72 |
|
| 73 |
model.eval()
|
| 74 |
+
model.trace_context_layers = args.trace_context_layers
|
| 75 |
+
model.trace_denoiser_layers = args.trace_denoiser_layers
|
| 76 |
+
inputs = tokenizer(prompt, return_tensors="pt").to(
|
| 77 |
next(model.context_tower.parameters()).device
|
| 78 |
)
|
| 79 |
|
| 80 |
if args.mode == "ar":
|
| 81 |
outputs = model.generate(**inputs, max_new_tokens=args.max_new_tokens, do_sample=False)
|
| 82 |
+
elif args.mode == "mock_ar":
|
| 83 |
outputs = model.generate_mock_ar(
|
| 84 |
inputs["input_ids"], max_new_tokens=args.max_new_tokens,
|
| 85 |
temperature=0.0, eos_token_id=tokenizer.eos_token_id,
|
| 86 |
)
|
| 87 |
+
else:
|
| 88 |
+
def step_callback(step_idx, total_steps, tokens, t=None, logits=None, block_idx=0):
|
| 89 |
+
if not args.print_diffusion_steps:
|
| 90 |
+
return
|
| 91 |
+
if logits is None:
|
| 92 |
+
print(f"\n--- Block {block_idx} Step {step_idx}/{total_steps} | init ---")
|
| 93 |
+
print("xt:", tokenizer.decode(tokens[0], skip_special_tokens=False))
|
| 94 |
+
return
|
| 95 |
+
log_x = model._mdlm_forward(logits, tokens.to(logits.device), args.mask_token_id)
|
| 96 |
+
probs = log_x.exp()[0]
|
| 97 |
+
top2_probs, top2_ids = probs.topk(2, dim=-1)
|
| 98 |
+
n_masked = int((tokens == args.mask_token_id).sum().item())
|
| 99 |
+
print(f"\n--- Block {block_idx} Step {step_idx}/{total_steps} | masked={n_masked}/{tokens.shape[1]} | t={t:.4f} ---")
|
| 100 |
+
print("xt: " + repr(tokenizer.decode(tokens[0], skip_special_tokens=False)))
|
| 101 |
+
print("top1: " + "|".join(tokenizer.decode([tid.item()])[:9].rjust(9) for tid in top2_ids[:, 0]))
|
| 102 |
+
print("prb1: " + "|".join(f"{p.item():.3f}".rjust(9) for p in top2_probs[:, 0]))
|
| 103 |
+
print("top2: " + "|".join(tokenizer.decode([tid.item()])[:9].rjust(9) for tid in top2_ids[:, 1]))
|
| 104 |
+
print("prb2: " + "|".join(f"{p.item():.3f}".rjust(9) for p in top2_probs[:, 1]))
|
| 105 |
+
|
| 106 |
+
generate_kwargs = dict(
|
| 107 |
+
max_new_tokens=args.max_new_tokens,
|
| 108 |
+
block_size=args.block_size,
|
| 109 |
+
steps_per_block=args.steps_per_block,
|
| 110 |
+
mask_token_id=args.mask_token_id,
|
| 111 |
+
temperature=args.temperature,
|
| 112 |
+
top_k=args.top_k,
|
| 113 |
+
confidence_threshold=args.confidence_threshold,
|
| 114 |
+
eos_token_id=tokenizer.eos_token_id,
|
| 115 |
+
)
|
| 116 |
+
if (
|
| 117 |
+
args.print_diffusion_steps
|
| 118 |
+
and "step_callback" in inspect.signature(model.generate_mask_diffusion).parameters
|
| 119 |
+
):
|
| 120 |
+
generate_kwargs["step_callback"] = step_callback
|
| 121 |
+
outputs = model.generate_mask_diffusion(inputs["input_ids"], **generate_kwargs)
|
| 122 |
|
| 123 |
text = tokenizer.decode(outputs[0][inputs["input_ids"].shape[1]:], skip_special_tokens=True)
|
| 124 |
print(text)
|
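The prompt handling added to `inference.py` accepts both a positional prompt and a `--prompt` flag. The sketch below isolates that precedence rule so it can be checked without loading the model; `resolve_prompt` is a hypothetical helper introduced for this example, while the two `add_argument` calls and the fallback expression are copied from the diff.

```python
import argparse


def resolve_prompt(argv, default="France is a country "):
    """Return the prompt the script would use for a given argument list."""
    parser = argparse.ArgumentParser()
    parser.add_argument("prompt_arg", nargs="?", default=None)
    parser.add_argument("--prompt", default=None)
    args = parser.parse_args(argv)
    # --prompt wins whenever it is given; otherwise fall back to the
    # positional argument, and finally to the built-in default.
    return args.prompt if args.prompt is not None else (args.prompt_arg or default)


from_default = resolve_prompt([])
from_positional = resolve_prompt(["Hello"])
from_flag = resolve_prompt(["Hello", "--prompt", "Bonjour"])
```

One consequence of the `is not None` test is that an explicitly empty `--prompt ""` is kept as an empty prompt, whereas an empty positional argument falls through to the default.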
modeling_nemotron_h.py
CHANGED
|
@@ -910,7 +910,9 @@ class NemotronHTopkRouter(nn.Module):
|
|
| 910 |
|
| 911 |
def forward(self, hidden_states):
|
| 912 |
hidden_states = hidden_states.view(-1, self.config.hidden_size)
|
| 913 |
-
|
|
|
|
|
|
|
| 914 |
scores = router_logits.sigmoid()
|
| 915 |
topk_indices = self.get_topk_indices(scores)
|
| 916 |
topk_weights = scores.gather(1, topk_indices)
|
|
@@ -918,7 +920,7 @@ class NemotronHTopkRouter(nn.Module):
|
|
| 918 |
denominator = topk_weights.sum(dim=-1, keepdim=True) + 1e-20
|
| 919 |
topk_weights /= denominator
|
| 920 |
topk_weights = topk_weights * self.routed_scaling_factor
|
| 921 |
-
return topk_indices, topk_weights
|
| 922 |
|
| 923 |
# Copied from transformers.models.llama.modeling_llama.repeat_kv
|
| 924 |
def repeat_kv(hidden_states: torch.Tensor, n_rep: int) -> torch.Tensor:
|
|
|
|
| 910 |
|
| 911 |
def forward(self, hidden_states):
|
| 912 |
hidden_states = hidden_states.view(-1, self.config.hidden_size)
|
| 913 |
+
# mcore runs the MoE router in fp64 (--moe-router-dtype fp64); match it so
|
| 914 |
+
# top-k expert selection is bit-identical at borderline scores.
|
| 915 |
+
router_logits = F.linear(hidden_states.type(torch.float64), self.weight.type(torch.float64))
|
| 916 |
scores = router_logits.sigmoid()
|
| 917 |
topk_indices = self.get_topk_indices(scores)
|
| 918 |
topk_weights = scores.gather(1, topk_indices)
|
|
|
|
| 920 |
denominator = topk_weights.sum(dim=-1, keepdim=True) + 1e-20
|
| 921 |
topk_weights /= denominator
|
| 922 |
topk_weights = topk_weights * self.routed_scaling_factor
|
| 923 |
+
return topk_indices, topk_weights.type(torch.float32)
|
| 924 |
|
| 925 |
# Copied from transformers.models.llama.modeling_llama.repeat_kv
|
| 926 |
def repeat_kv(hidden_states: torch.Tensor, n_rep: int) -> torch.Tensor:
|
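The router change above targets expert selection at borderline scores. The toy below illustrates the general effect with plain Python floats. It is only a sketch under assumptions: `to_float32`, `sigmoid`, and `top1` are hypothetical helpers, the two logits are made up, rounding the logits to float32 stands in for lower-precision routing, and ties are broken by lowest index, which need not match any particular kernel.

```python
import math
import struct


def to_float32(x):
    """Round a Python float (float64) to the nearest float32 value."""
    return struct.unpack("f", struct.pack("f", x))[0]


def sigmoid(x):
    return 1.0 / (1.0 + math.exp(-x))


def top1(logits):
    """Index of the highest sigmoid score; ties go to the lowest index."""
    scores = [sigmoid(v) for v in logits]
    return max(range(len(scores)), key=scores.__getitem__)


# Two made-up expert logits that differ by less than float32 resolution near 0.5.
logits64 = [0.5, 0.5 + 1e-9]
logits32 = [to_float32(v) for v in logits64]

pick64 = top1(logits64)  # the gap survives in float64
pick32 = top1(logits32)  # the gap is rounded away, so the tie-break decides
```

In float64 the second expert wins; after rounding to float32 both logits are identical and the first expert is picked instead. This is the kind of divergence that matching the router dtype is meant to rule out.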
modeling_nemotron_twotower.py
CHANGED
|
@@ -30,6 +30,7 @@ try:
|
|
| 30 |
NemotronHForCausalLM,
|
| 31 |
NemotronHModel,
|
| 32 |
NemotronHPreTrainedModel,
|
|
|
|
| 33 |
)
|
| 34 |
from .configuration_nemotron_h import NemotronHConfig
|
| 35 |
except ImportError:
|
|
@@ -39,6 +40,7 @@ except ImportError:
|
|
| 39 |
NemotronHForCausalLM,
|
| 40 |
NemotronHModel,
|
| 41 |
NemotronHPreTrainedModel,
|
|
|
|
| 42 |
)
|
| 43 |
from configuration_nemotron_h import NemotronHConfig
|
| 44 |
|
|
@@ -180,8 +182,11 @@ class NemotronHTwoTowerForCausalLM(NemotronHPreTrainedModel, GenerationMixin):
|
|
| 180 |
elif input_ids.shape[1] != cache_position.shape[0]:
|
| 181 |
input_ids = input_ids[:, cache_position]
|
| 182 |
else:
|
| 183 |
-
|
| 184 |
-
|
|
|
|
|
|
|
|
|
|
| 185 |
)
|
| 186 |
if attention_mask is not None and position_ids is None:
|
| 187 |
position_ids = attention_mask.long().cumsum(-1) - 1
|
|
@@ -330,18 +335,20 @@ class NemotronHTwoTowerForCausalLM(NemotronHPreTrainedModel, GenerationMixin):
|
|
| 330 |
return {"ctx_cache": cache_p2, "mamba_s2": mamba_s2, "ctx_len": S}
|
| 331 |
|
| 332 |
def _extend_context_cache(self, new_tokens, cache_state):
|
| 333 |
-
"""Extend context cache by new_tokens (B, L)
|
| 334 |
|
| 335 |
-
|
| 336 |
-
|
| 337 |
"""
|
| 338 |
ctx_cache = cache_state["ctx_cache"]
|
| 339 |
pattern = self.config.hybrid_override_pattern
|
| 340 |
ctx_len = cache_state["ctx_len"]
|
| 341 |
-
|
|
|
|
| 342 |
L = new_tokens.shape[1]
|
| 343 |
-
|
| 344 |
|
|
|
|
| 345 |
new_s2 = {}
|
| 346 |
for i in range(self.config.num_hidden_layers):
|
| 347 |
if pattern[i] == "M":
|
|
@@ -350,12 +357,37 @@ class NemotronHTwoTowerForCausalLM(NemotronHPreTrainedModel, GenerationMixin):
|
|
| 350 |
cache_state["mamba_s2"] = new_s2
|
| 351 |
|
| 352 |
ctx_cache.has_previous_state = True
|
| 353 |
-
|
| 354 |
-
|
| 355 |
-
|
| 356 |
-
|
| 357 |
-
|
| 358 |
-
|
| 359 |
|
| 360 |
cache_state["ctx_len"] = ctx_len + L
|
| 361 |
return cache_state
|
|
@@ -415,13 +447,139 @@ class NemotronHTwoTowerForCausalLM(NemotronHPreTrainedModel, GenerationMixin):
|
|
| 415 |
self.denoiser_tower, self.lm_head, den_input, den_cache, cp,
|
| 416 |
)
|
| 417 |
|
| 418 |
-
def
|
| 419 |
-
"""
|
| 420 |
|
| 421 |
-
|
| 422 |
-
|
| 423 |
-
|
| 424 |
-
|
| 425 |
|
| 426 |
Args:
|
| 427 |
block_ids: (B, L) tokens to denoise
|
|
@@ -431,28 +589,74 @@ class NemotronHTwoTowerForCausalLM(NemotronHPreTrainedModel, GenerationMixin):
|
|
| 431 |
Returns: logits (B, L, V)
|
| 432 |
"""
|
| 433 |
ctx_len = cache_state["ctx_len"]
|
| 434 |
-
|
|
|
|
| 435 |
den_input = block_ids.to(den_device)
|
| 436 |
L = den_input.shape[1]
|
| 437 |
|
|
|
|
| 438 |
t_emb = None
|
| 439 |
if t is not None:
|
| 440 |
t_dev = t.to(device=den_device, dtype=self.dtype)
|
| 441 |
t_repr = self.t_embedder(t_dev)
|
| 442 |
t_emb = self.t_block(t_repr)
|
| 443 |
|
|
|
|
| 444 |
den_cache = self._build_denoiser_cache_diffusion(cache_state, den_device)
|
| 445 |
|
| 446 |
-
|
| 447 |
-
for i in range(L):
|
| 448 |
-
cp = torch.tensor([ctx_len + i], device=den_device)
|
| 449 |
-
logits_i = self._forward_tower_with_cache(
|
| 450 |
-
self.denoiser_tower, self.lm_head, den_input[:, i:i+1],
|
| 451 |
-
den_cache, cp, t_emb=t_emb,
|
| 452 |
-
)
|
| 453 |
-
all_logits.append(logits_i)
|
| 454 |
|
| 455 |
-
|
| 456 |
|
| 457 |
# ------------------------------------------------------------------
|
| 458 |
# Mock-AR generation (unchanged)
|
|
@@ -516,6 +720,7 @@ class NemotronHTwoTowerForCausalLM(NemotronHPreTrainedModel, GenerationMixin):
|
|
| 516 |
top_k=None,
|
| 517 |
confidence_threshold=0.9,
|
| 518 |
eos_token_id=None,
|
|
|
|
| 519 |
):
|
| 520 |
"""Block-wise mask diffusion with confidence_unmasking.
|
| 521 |
|
|
@@ -558,6 +763,9 @@ class NemotronHTwoTowerForCausalLM(NemotronHPreTrainedModel, GenerationMixin):
|
|
| 558 |
# Initialize fully masked block
|
| 559 |
xt = torch.full((B, block_size), mask_token_id, dtype=torch.long,
|
| 560 |
device=device)
|
|
|
|
|
|
|
|
|
|
| 561 |
|
| 562 |
for step_idx in range(steps_per_block):
|
| 563 |
# t_model = current mask fraction
|
|
@@ -626,6 +834,11 @@ class NemotronHTwoTowerForCausalLM(NemotronHPreTrainedModel, GenerationMixin):
|
|
| 626 |
remask_idx = masked_indices[sort_idx[:num_to_remask[b]]]
|
| 627 |
output[b, remask_idx] = mask_token_id
|
| 628 |
|
| 629 |
xt = output
|
| 630 |
|
| 631 |
# Block complete — extend context
|
|
|
|
| 30 |
NemotronHForCausalLM,
|
| 31 |
NemotronHModel,
|
| 32 |
NemotronHPreTrainedModel,
|
| 33 |
+
repeat_kv,
|
| 34 |
)
|
| 35 |
from .configuration_nemotron_h import NemotronHConfig
|
| 36 |
except ImportError:
|
|
|
|
| 40 |
NemotronHForCausalLM,
|
| 41 |
NemotronHModel,
|
| 42 |
NemotronHPreTrainedModel,
|
| 43 |
+
repeat_kv,
|
| 44 |
)
|
| 45 |
from configuration_nemotron_h import NemotronHConfig
|
| 46 |
|
|
|
|
| 182 |
elif input_ids.shape[1] != cache_position.shape[0]:
|
| 183 |
input_ids = input_ids[:, cache_position]
|
| 184 |
else:
|
| 185 |
+
# FixedHybridCache (not the base class) so the Mamba mixer finds
|
| 186 |
+
# conv_kernel_size during the cached forward (needed for AR generate).
|
| 187 |
+
past_key_values = FixedHybridCache(
|
| 188 |
+
self.config, input_ids.shape[0], self.dtype,
|
| 189 |
+
device=next(self.context_tower.parameters()).device,
|
| 190 |
)
|
| 191 |
if attention_mask is not None and position_ids is None:
|
| 192 |
position_ids = attention_mask.long().cumsum(-1) - 1
|
|
|
|
| 335 |
return {"ctx_cache": cache_p2, "mamba_s2": mamba_s2, "ctx_len": S}
|
| 336 |
|
| 337 |
def _extend_context_cache(self, new_tokens, cache_state):
|
| 338 |
+
"""Extend context cache by new_tokens (B, L), block-wise (matches mcore).
|
| 339 |
|
| 340 |
+
Mamba layers advance via the block chunk-scan from the current state;
|
| 341 |
+
attention layers append the block KV (causal within block); MoE is plain.
|
| 342 |
"""
|
| 343 |
ctx_cache = cache_state["ctx_cache"]
|
| 344 |
pattern = self.config.hybrid_override_pattern
|
| 345 |
ctx_len = cache_state["ctx_len"]
|
| 346 |
+
tower = self.context_tower
|
| 347 |
+
ctx_device = next(tower.parameters()).device
|
| 348 |
L = new_tokens.shape[1]
|
| 349 |
+
tokens = new_tokens.to(ctx_device)
|
| 350 |
|
| 351 |
+
# Snapshot pre-extension Mamba states as the new S-2 (used by mock-AR).
|
| 352 |
new_s2 = {}
|
| 353 |
for i in range(self.config.num_hidden_layers):
|
| 354 |
if pattern[i] == "M":
|
|
|
|
| 357 |
cache_state["mamba_s2"] = new_s2
|
| 358 |
|
| 359 |
ctx_cache.has_previous_state = True
|
| 360 |
+
cache_position = torch.arange(ctx_len, ctx_len + L, device=ctx_device)
|
| 361 |
+
hidden = tower.embeddings(tokens)
|
| 362 |
+
causal_mask = tower._update_causal_mask(None, hidden, cache_position)
|
| 363 |
+
|
| 364 |
+
for layer_idx, block in enumerate(tower.layers):
|
| 365 |
+
residual = hidden
|
| 366 |
+
h = block.norm(hidden.to(dtype=block.norm.weight.dtype))
|
| 367 |
+
if block.residual_in_fp32:
|
| 368 |
+
residual = residual.to(torch.float32)
|
| 369 |
+
|
| 370 |
+
if block.block_type == "mamba":
|
| 371 |
+
d_conv = block.mixer.conv_kernel_size
|
| 372 |
+
init_conv = ctx_cache.conv_states[layer_idx][..., -(d_conv - 1):]
|
| 373 |
+
init_ssm = ctx_cache.ssm_states[layer_idx].contiguous()
|
| 374 |
+
h, new_conv, new_ssm = self._denoiser_block_mamba(
|
| 375 |
+
block.mixer, h, init_conv, init_ssm, return_states=True,
|
| 376 |
+
)
|
| 377 |
+
ctx_cache.conv_states[layer_idx] = new_conv
|
| 378 |
+
ctx_cache.ssm_states[layer_idx] = new_ssm
|
| 379 |
+
elif block.block_type == "attention":
|
| 380 |
+
# Standard cached attention appends block KV (causal within block).
|
| 381 |
+
h, _, _ = block.mixer(
|
| 382 |
+
h, attention_mask=causal_mask,
|
| 383 |
+
past_key_value=ctx_cache, cache_position=cache_position,
|
| 384 |
+
)
|
| 385 |
+
elif block.block_type in ["mlp", "moe"]:
|
| 386 |
+
h = block.mixer(h)
|
| 387 |
+
else:
|
| 388 |
+
raise ValueError(f"Unknown block_type: {block.block_type}")
|
| 389 |
+
|
| 390 |
+
hidden = residual + h
|
| 391 |
|
| 392 |
cache_state["ctx_len"] = ctx_len + L
|
| 393 |
return cache_state
|
|
|
|
| 447 |
self.denoiser_tower, self.lm_head, den_input, den_cache, cp,
|
| 448 |
)
|
| 449 |
|
| 450 |
+
def _denoiser_block_attention(self, mixer, hidden, ctx_k, ctx_v):
|
| 451 |
+
"""Bidirectional denoiser self-attention over [context_KV | block_KV].
|
| 452 |
|
| 453 |
+
Mirrors the mcore `_forward_attn_with_past` (is_causal=False, no mask):
|
| 454 |
+
every block position attends to ALL context positions and ALL block
|
| 455 |
+
positions (the noisy block is processed bidirectionally within itself).
|
| 456 |
+
|
| 457 |
+
Args:
|
| 458 |
+
mixer: NemotronHAttention module (provides q/k/v/o projections)
|
| 459 |
+
hidden: (B, L, D) post-norm (and post-modulation) block hidden states
|
| 460 |
+
ctx_k, ctx_v: context KV, each (B, num_kv_heads, ctx_len, head_dim)
|
| 461 |
+
|
| 462 |
+
Returns: (B, L, D) attention output (before residual add)
|
| 463 |
+
"""
|
| 464 |
+
bsz, q_len, _ = hidden.shape
|
| 465 |
+
q = mixer.q_proj(hidden).view(bsz, q_len, mixer.num_heads, mixer.head_dim).transpose(1, 2)
|
| 466 |
+
k = mixer.k_proj(hidden).view(bsz, q_len, mixer.num_key_value_heads, mixer.head_dim).transpose(1, 2)
|
| 467 |
+
v = mixer.v_proj(hidden).view(bsz, q_len, mixer.num_key_value_heads, mixer.head_dim).transpose(1, 2)
|
| 468 |
+
|
| 469 |
+
# Concatenate context KV (past) with current block KV on the sequence dim.
|
| 470 |
+
k = torch.cat([ctx_k.to(k.dtype), k], dim=2)
|
| 471 |
+
v = torch.cat([ctx_v.to(v.dtype), v], dim=2)
|
| 472 |
+
|
| 473 |
+
# GQA: expand KV heads to match query heads.
|
| 474 |
+
k = repeat_kv(k, mixer.num_key_value_groups)
|
| 475 |
+
v = repeat_kv(v, mixer.num_key_value_groups)
|
| 476 |
+
|
| 477 |
+
# Full (non-causal) attention: block sees all context + whole block.
|
| 478 |
+
attn_output = F.scaled_dot_product_attention(
|
| 479 |
+
q, k, v, attn_mask=None, dropout_p=0.0, is_causal=False,
|
| 480 |
+
)
|
| 481 |
+
attn_output = attn_output.transpose(1, 2).contiguous().view(
|
| 482 |
+
bsz, q_len, mixer.num_heads * mixer.head_dim
|
| 483 |
+
)
|
| 484 |
+
return mixer.o_proj(attn_output)
+
+    def _denoiser_block_mamba(self, mixer, hidden, init_conv, init_ssm, return_states=False):
+        """Chunk-scan the whole block through the Mamba mixer, seeded from the
+        context state — mirrors mcore `forward_mamba_layer_with_states`
+        (non-bidirectional). Uses the same mamba_ssm/causal_conv1d kernels as
+        mcore, instead of HF's token-by-token single-step path (which is both a
+        numerical mismatch and crashes in this env's causal_conv1d_update).
+
+        Args:
+            mixer: NemotronHMamba2Mixer
+            hidden: (B, L, D) post-norm (and post-modulation) block hidden states
+            init_conv: (B, conv_dim, d_conv-1) context conv state, or None
+            init_ssm: (B, nheads, headdim, d_state) context SSM state, or None
+            return_states: also return the updated (conv_state[width d_conv], ssm_state)
+                so the caller can advance a KV/Mamba cache (used by context extend).
+
+        Returns: (B, L, D) mixer output (before adaLN gate / residual);
+            or (output, new_conv_state, new_ssm_state) if return_states.
+        """
+        from einops import rearrange
+        from mamba_ssm.ops.triton.ssd_combined import mamba_chunk_scan_combined
+        from causal_conv1d import causal_conv1d_fn
+
+        d_inner = mixer.intermediate_size
+        ngroups = mixer.n_groups
+        d_state = mixer.ssm_state_size
+        headdim = mixer.head_dim
+        conv_dim = mixer.conv_dim
+        d_conv = mixer.conv_kernel_size
+
+        proj = mixer.in_proj(hidden)  # (B, L, d_inner+conv_dim+nheads)
+        z, xBC, dt = torch.split(proj, [d_inner, conv_dim, mixer.num_heads], dim=-1)
+
+        # causal_conv1d_fn with initial_states requires channel-last layout:
+        #   - input (B, conv_dim, L): use the transpose VIEW (stride(1)==1), no .contiguous()
+        #   - initial_states (B, conv_dim, d_conv-1): force channel-last via the
+        #     transpose->contiguous->transpose trick (mcore _run_denoiser_step).
+        if init_conv is not None:
+            init_conv = init_conv.transpose(-1, -2).contiguous().transpose(-1, -2)
+        xBC_conv = causal_conv1d_fn(
+            xBC.transpose(1, 2),  # (B, conv_dim, L) channel-last view
+            mixer.conv1d.weight.squeeze(1),
+            mixer.conv1d.bias,
+            activation=mixer.activation,
+            initial_states=init_conv,
+        ).transpose(1, 2)  # (B, L, conv_dim)
+
+        x, B_proj, C_proj = torch.split(
+            xBC_conv, [d_inner, ngroups * d_state, ngroups * d_state], dim=-1
+        )
+        x = rearrange(x, "b s (h p) -> b s h p", p=headdim).contiguous()
+        B_proj = rearrange(B_proj, "b s (g n) -> b s g n", n=d_state).contiguous()
+        C_proj = rearrange(C_proj, "b s (g n) -> b s g n", n=d_state).contiguous()
+
+        A = -torch.exp(mixer.A_log.float())
+        scan = mamba_chunk_scan_combined(
+            x, dt.contiguous(), A, B_proj, C_proj, mixer.chunk_size,
+            D=mixer.D, z=None,
+            dt_bias=mixer.dt_bias.float(), dt_softplus=True,
+            initial_states=init_ssm,
+            return_final_states=return_states,
+        )
+        if return_states:
+            y, new_ssm = scan
+        else:
+            y = scan
+        y = rearrange(y, "b s h p -> b s (h p)")
+        y = mixer.norm(y, z)  # Mamba2 z-gated RMSNorm
+        out = mixer.out_proj(y)
+        if not return_states:
+            return out
+        # New conv state: HF cache stores the last d_conv raw xBC inputs (width
+        # d_conv), most-recent at index -1. block_size >= d_conv here.
+        L = xBC.shape[1]
+        if L >= d_conv:
+            new_conv = xBC[:, -d_conv:, :].transpose(1, 2).contiguous()
+        else:
+            hist = init_conv if init_conv is not None else xBC.new_zeros(xBC.shape[0], conv_dim, d_conv - 1)
+            comb = torch.cat([hist.transpose(1, 2), xBC], dim=1)
+            new_conv = comb[:, -d_conv:, :].transpose(1, 2).contiguous()
+        return out, new_conv, new_ssm
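Reviewer sketch (not part of the commit): the conv-state bookkeeping at the end of `_denoiser_block_mamba` reduces to a rolling window over one channel. The helpers below are hypothetical and use plain lists instead of `(B, conv_dim, width)` tensors; they only mirror the two branches and the `d_conv - 1` slice visible in the diff.

```python
def next_conv_state(block_inputs, history, d_conv):
    """Return the last d_conv raw inputs (oldest first) after consuming a block.

    block_inputs: per-position inputs for one channel, length L.
    history: the d_conv - 1 most recent context inputs, or None for zeros.
    """
    if len(block_inputs) >= d_conv:
        # Long block: the window is filled entirely by the block itself.
        return list(block_inputs[-d_conv:])
    if history is None:
        history = [0.0] * (d_conv - 1)
    # Short block: prepend the context history, then keep the tail.
    combined = list(history) + list(block_inputs)
    return combined[-d_conv:]

def conv_init_from_cache(cached_state, d_conv):
    """Slice a width-d_conv cached state down to the d_conv - 1 newest columns.

    Assumes d_conv >= 2; with d_conv == 1 the slice [-0:] would keep everything.
    """
    return list(cached_state[-(d_conv - 1):])

d_conv = 4
state_long = next_conv_state([1, 2, 3, 4, 5, 6], [9, 9, 9], d_conv)
state_short = next_conv_state([7, 8], [1, 2, 3], d_conv)
state_cold = next_conv_state([7], None, d_conv)
next_init = conv_init_from_cache(state_long, d_conv)
```

The returned window has width `d_conv`, while the value handed back to the convolution as its initial state is one column narrower. That is why the caller slices with `-(d_conv - 1):` before the next block.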
+
+    def _run_denoiser_step_diffusion(self, block_ids, cache_state, t=None):
+        """Diffusion denoiser forward over the FULL block (B, L) in one pass.
+
+        Parity with mcore `_run_denoiser_step`:
+          - Attention layers run BIDIRECTIONALLY within the block, attending to
+            the full context KV cache + the whole noisy block (is_causal=False).
+            A token-by-token causal pass would hide later block positions from
+            earlier ones.
+          - Mamba layers are causal/forward-only (bidirectional_mamba=False) and
+            are chunk-scanned over the whole block from the context state (S-1),
+            matching mcore's `forward_mamba_layer_with_states`.
+          - Time conditioning (adaLN-single) is applied per layer. The modulate/norm
+            ORDER depends on where mcore's norm lives: mamba & attention norms are
+            FUSED into in_proj/linear_qkv (applied AFTER modulate) -> modulate THEN
+            norm; MoE uses a separate pre_mlp_layernorm -> norm THEN modulate.
+            Gate is applied to the mixer output in all cases.

         Args:
             block_ids: (B, L) tokens to denoise
[...]
         Returns: logits (B, L, V)
         """
         ctx_len = cache_state["ctx_len"]
+        tower = self.denoiser_tower
+        den_device = next(tower.parameters()).device
         den_input = block_ids.to(den_device)
         L = den_input.shape[1]

+        # Time embedding -> per-layer modulation params (shift, scale, gate).
         t_emb = None
         if t is not None:
             t_dev = t.to(device=den_device, dtype=self.dtype)
             t_repr = self.t_embedder(t_dev)
             t_emb = self.t_block(t_repr)

+        # Fresh denoiser cache seeded from context: Mamba S-1 state + full context KV.
         den_cache = self._build_denoiser_cache_diffusion(cache_state, den_device)

+        hidden = tower.embeddings(den_input)

+        for layer_idx, block in enumerate(tower.layers):
+            residual = hidden
+            if block.residual_in_fp32:
+                residual = residual.to(torch.float32)
+
+            mod = None
+            if t_emb is not None:
+                mod = _get_mod_params(t_emb, self.scale_shift_tables[layer_idx])
+                shift, scale, gate = mod
+
+            # adaLN modulate vs norm ORDER depends on where mcore's norm lives:
+            #   - mamba/attention: norm is FUSED into in_proj/linear_qkv and is
+            #     applied AFTER the explicit modulate -> modulate THEN norm.
+            #   - moe/mlp: separate pre_mlp_layernorm applied BEFORE modulate
+            #     -> norm THEN modulate.
+            if block.block_type in ("mamba", "attention"):
+                h = hidden
+                if mod is not None:
+                    h = _modulate(h, shift, scale)
+                h = block.norm(h.to(dtype=block.norm.weight.dtype))
+            else:  # mlp / moe
+                h = block.norm(hidden.to(dtype=block.norm.weight.dtype))
+                if mod is not None:
+                    h = _modulate(h, shift, scale)
+
+            if block.block_type == "mamba":
+                # Chunk-scan the whole block in one kernel launch, seeded from the
+                # context Mamba state (matches mcore forward_mamba_layer_with_states).
+                # HF conv_states are width d_conv; causal_conv1d_fn's initial_states
+                # wants the d_conv-1 most-recent columns.
+                d_conv = block.mixer.conv_kernel_size
+                init_conv = den_cache.conv_states[layer_idx][..., -(d_conv - 1):]
+                init_ssm = den_cache.ssm_states[layer_idx].contiguous()
+                h = self._denoiser_block_mamba(block.mixer, h, init_conv, init_ssm)
+            elif block.block_type == "attention":
+                ctx_k = den_cache.key_cache[layer_idx]
+                ctx_v = den_cache.value_cache[layer_idx]
+                h = self._denoiser_block_attention(block.mixer, h, ctx_k, ctx_v)
+            elif block.block_type in ["mlp", "moe"]:
+                h = block.mixer(h)
+            else:
+                raise ValueError(f"Unknown block_type: {block.block_type}")
+
+            if mod is not None:
+                h = gate.unsqueeze(1) * h
+
+            hidden = residual + h
+
+        hidden = tower.norm_f(hidden)
+        logits = self.lm_head(hidden.to(self.lm_head.weight.dtype)).float()
+        return logits
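Reviewer sketch (not part of the commit): the modulate/norm ordering matters because RMSNorm removes any uniform rescaling applied before it. The snippet assumes the common DiT-style form `x * (1 + scale) + shift` for `_modulate` and a weightless RMSNorm; neither definition appears in this diff, so both are stand-ins, as is the `premix` helper.

```python
import math

def rms_norm(x, eps=1e-6):
    # Weightless RMSNorm over one feature vector (assumed stand-in for block.norm).
    mean_sq = sum(v * v for v in x) / len(x)
    return [v / math.sqrt(mean_sq + eps) for v in x]

def modulate(x, shift, scale):
    # ASSUMPTION: DiT-style modulation; the real _modulate is not shown in the diff.
    return [v * (1.0 + s) + b for v, s, b in zip(x, scale, shift)]

def premix(x, shift, scale, block_type):
    """Apply the per-block-type ordering described in the diff's comments."""
    if block_type in ("mamba", "attention"):
        return rms_norm(modulate(x, shift, scale))   # modulate THEN norm
    return modulate(rms_norm(x), shift, scale)       # norm THEN modulate

x = [1.0, 2.0, 3.0, 4.0]
shift = [0.0, 0.0, 0.0, 0.0]
scale = [1.0, 1.0, 1.0, 1.0]   # a uniform doubling, chosen to expose the difference

mixer_in = premix(x, shift, scale, "mamba")
moe_in = premix(x, shift, scale, "moe")
baseline = rms_norm(x)
```

With this uniform `scale`, the mamba/attention ordering returns essentially the unmodulated normalized input, while the MoE ordering returns it doubled. Swapping the two orders would therefore change the activations even with identical weights.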

     # ------------------------------------------------------------------
     # Mock-AR generation (unchanged)
[...]
         top_k=None,
         confidence_threshold=0.9,
         eos_token_id=None,
+        step_callback=None,
     ):
         """Block-wise mask diffusion with confidence_unmasking.

[...]
             # Initialize fully masked block
             xt = torch.full((B, block_size), mask_token_id, dtype=torch.long,
                             device=device)
+            if step_callback is not None:
+                step_callback(0, steps_per_block, xt, t=1.0, logits=None,
+                              block_idx=block_idx)

             for step_idx in range(steps_per_block):
                 # t_model = current mask fraction
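Reviewer sketch (not part of the commit): a callable matching the `step_callback` call sites in this diff could look like the recorder below. The class name and the simulated calls are hypothetical; only the argument order `(step_idx, steps_per_block, xt)` and the keywords `t`, `logits`, `block_idx` are taken from the diff. Judging by the visible hunks, the hook fires once before the loop with index 0 and `logits=None`, then once per step, so index 0 may be reported twice per block.

```python
class StepRecorder:
    """Collects one lightweight record per step_callback invocation."""

    def __init__(self):
        self.events = []

    def __call__(self, step_idx, steps_per_block, xt, t=None, logits=None,
                 block_idx=None):
        self.events.append({
            "block": block_idx,
            "step": step_idx,
            "of": steps_per_block,
            "t": t,
            # The pre-loop call passes logits=None, which tells it apart
            # from the first in-loop call that also reports step 0.
            "has_logits": logits is not None,
        })

recorder = StepRecorder()

# Simulated invocations in the order the diff suggests (placeholder payloads;
# the real xt and logits would be tensors).
recorder(0, 2, [0, 0], t=1.0, logits=None, block_idx=0)
recorder(0, 2, [0, 0], t=1.0, logits=[[0.1]], block_idx=0)
recorder(1, 2, [5, 0], t=0.5, logits=[[0.2]], block_idx=0)

steps_seen = [event["step"] for event in recorder.events]
```

In real use the recorder would be passed as `step_callback=recorder` to the generation method whose signature gains that parameter in this commit.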
[...]
                     remask_idx = masked_indices[sort_idx[:num_to_remask[b]]]
                     output[b, remask_idx] = mask_token_id

+                if step_callback is not None:
+                    step_callback(step_idx, steps_per_block, xt,
+                                  t=float(t_model.detach().cpu()), logits=logits,
+                                  block_idx=block_idx)
+
                 xt = output

             # Block complete — extend context
|