---
library_name: kernels
license: bsd-3-clause
tags:
- attention
- flash-attention
- flash-attn-4
- sm120
- sm121
- blackwell
- rtx5090
- rtx-pro-6000
- dgx-spark
- cute-dsl
---

# flash-attn-4-sm120

**Flash Attention 4 (CuTe DSL) for SM120 / SM121 consumer Blackwell GPUs.**

This is a downstream distribution of [Dao-AILab/flash-attention](https://github.com/Dao-AILab/flash-attention) that bundles **six open upstream PRs** targeting consumer Blackwell hardware (RTX 5090, RTX PRO 6000, DGX Spark GB10, SM121a). Once these PRs merge upstream, prefer the upstream `flash-attn` package; this bundle exists so SM120 users can use the improvements today.

## Why this exists

`flash-attn-4`'s CuTe DSL kernels work great on Hopper (SM90) and datacenter Blackwell (SM100). But SM120 (consumer Blackwell) is genuinely different hardware:

- No `tcgen05` / TMEM (so FA4's primary speed path doesn't apply)
- No WGMMA (so the SM90 epilogue path doesn't apply)
- 99 KB of shared memory per thread block (vs 163 KB on SM80)
- TMA is available, but only the single-CTA flavor
- The same SM80-era `mma.sync.aligned.m16n8k16` instruction for FP16/BF16 MMA

The PRs bundled here adapt FA4's kernels to these constraints: correct runtime dispatch, SMEM-budget-aware tiling, a paged KV implementation that fits in 99 KB, a warp-specialized TMA forward kernel, and a couple of fixes for crashes that otherwise block dispatch entirely.
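The 99 KB limit is easiest to see with a little arithmetic. The sketch below is a rough, hypothetical estimate: `smem_bytes`, `fits_sm120` and the `kv_stages` parameter are illustrative names, not part of this package. It assumes one Q tile plus one K and one V tile per pipeline stage, 2-byte fp16/bf16 elements, and it ignores everything else a real kernel keeps in shared memory.

```python
# Rough shared-memory estimate for one forward-attention CTA.
# Simplifying assumptions (NOT the kernel's actual allocator): one Q tile,
# `kv_stages` buffered K and V tiles, 2-byte elements (fp16/bf16), and no
# extra SMEM for the output tile, LSE, or pipeline barriers.

SM120_SMEM_LIMIT = 99 * 1024  # 99 KB per-block budget on SM120
BYTES_PER_ELEM = 2            # fp16 / bf16


def smem_bytes(tile_m: int, tile_n: int, head_dim: int, kv_stages: int = 1) -> int:
    """Bytes for the Q tile plus `kv_stages` copies of the K and V tiles."""
    q_tile = tile_m * head_dim * BYTES_PER_ELEM
    kv_tile = tile_n * head_dim * BYTES_PER_ELEM  # one K tile; a V tile is the same size
    return q_tile + 2 * kv_stages * kv_tile


def fits_sm120(tile_m: int, tile_n: int, head_dim: int, kv_stages: int = 1) -> bool:
    """True if the simplified estimate stays within the 99 KB budget."""
    return smem_bytes(tile_m, tile_n, head_dim, kv_stages) <= SM120_SMEM_LIMIT


if __name__ == "__main__":
    for d in (64, 128, 256):
        used = smem_bytes(128, 64, d)
        print(f"D={d:3d}: {used // 1024:3d} KB  fits={fits_sm120(128, 64, d)}")
```

Under these assumptions a `(tile_m, tile_n) = (128, 64)` tile needs 64 KB at `D = 128` but 128 KB at `D = 256`, which already exceeds the budget before any epilogue or extra pipeline buffers are counted. Treat the figures as a lower bound; the real tile choices are made by the bundle's dispatch logic.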
## Bundled PRs

| PR | Title |
|---|---|
| [#2336](https://github.com/Dao-AILab/flash-attention/pull/2336) | SM120 split-KV (FlashDecoding) with FP32 partial outputs |
| [#2348](https://github.com/Dao-AILab/flash-attention/pull/2348) | SM120 kernel-level paged KV cache support |
| [#2349](https://github.com/Dao-AILab/flash-attention/pull/2349) | SM120 TMA forward kernel with warp specialization |
| [#2389](https://github.com/Dao-AILab/flash-attention/pull/2389) | SM80 / SM120 block-sparse forward attention support |
| [#2439](https://github.com/Dao-AILab/flash-attention/pull/2439) | FA4 dropout (Philox, per-element, all arches) |
| [#2484](https://github.com/Dao-AILab/flash-attention/pull/2484) | SM120 init-time runtime fix + GQA `pack_gqa` workaround |

## Setup

### Hardware

- NVIDIA SM120 / SM121 / SM121a (RTX 5090, RTX PRO 6000 Blackwell, DGX Spark GB10)
- Should also work on SM80 / SM90 / SM100, since the bundle inherits those paths from upstream `flash-attn-4`, but they are not the primary target

### Software

- CUDA Toolkit **12.8 or newer** (FA4 baseline requirement)
- PyTorch with CUDA support
- `nvidia-cutlass-dsl >= 4.4.1` (auto-installed by `kernels`)
- `einops`, `apache-tvm-ffi` (auto-installed)

### Installation via the `kernels` library (recommended)

```bash
pip install -U kernels
```

```python
from kernels import get_kernel

flash_attn_4 = get_kernel("SecondNatureComputing/flash-attn-4-sm120")
```

`kernels` downloads this repository, resolves dependencies, and makes the package importable without any manual build step.
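Before loading the kernel it can help to confirm that the GPU and CUDA build match the Hardware and Software requirements above. A minimal sketch, assuming PyTorch is installed: `is_sm12x`, `cuda_at_least` and `preflight` are hypothetical helpers, not part of `kernels` or this bundle; only `torch.cuda.is_available`, `torch.cuda.get_device_capability` and `torch.version.cuda` come from PyTorch. Note that `torch.version.cuda` is the CUDA version PyTorch was built against, which is a proxy for, not the same thing as, the installed toolkit.

```python
# Hypothetical pre-flight helpers; not part of the kernels or flash-attn-4 APIs.
from typing import Optional, Tuple


def is_sm12x(capability: Tuple[int, int]) -> bool:
    """SM120 reports compute capability (12, 0); SM121 / SM121a report (12, 1)."""
    major, _minor = capability
    return major == 12


def cuda_at_least(version: Optional[str], minimum: Tuple[int, int] = (12, 8)) -> bool:
    """Compare a CUDA version string such as '12.8' against `minimum`."""
    if version is None:  # CPU-only (or non-CUDA) PyTorch build
        return False
    major, minor = (int(part) for part in version.split(".")[:2])
    return (major, minor) >= minimum


def preflight() -> None:
    """Raise if the environment cannot run FA4; print a note for non-SM12x GPUs."""
    import torch  # imported lazily so the helpers above work without a GPU

    if not torch.cuda.is_available():
        raise RuntimeError("no CUDA device visible to PyTorch")
    cap = torch.cuda.get_device_capability()
    if not is_sm12x(cap):
        print(f"note: sm_{cap[0]}{cap[1]} is not the primary target of this bundle")
    if not cuda_at_least(torch.version.cuda):
        raise RuntimeError(f"PyTorch CUDA {torch.version.cuda} < 12.8; FA4 needs 12.8+")
```

Call `preflight()` before `get_kernel(...)`. On SM80 / SM90 / SM100 it only prints a note rather than failing, since those paths are inherited from upstream.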
### Direct use (alternative)

If you prefer not to use the `kernels` library, you can clone the repo and import the package directly:

```bash
git clone https://huggingface.co/SecondNatureComputing/flash-attn-4-sm120
```

```python
import importlib
import sys

sys.path.insert(0, "flash-attn-4-sm120/build/torch-cuda")
flash_attn_4 = importlib.import_module("flash_attn_4_sm120")  # or whatever you alias the dir to
```

The `kernels.get_kernel(...)` path is recommended because it handles caching and dependency resolution automatically.

## Usage

### Basic: non-causal MHA

```python
import torch
from kernels import get_kernel

flash_attn_4 = get_kernel("SecondNatureComputing/flash-attn-4-sm120")

B, S, H, D = 1, 1024, 16, 128
q = torch.randn(B, S, H, D, device="cuda", dtype=torch.bfloat16)
k = torch.randn(B, S, H, D, device="cuda", dtype=torch.bfloat16)
v = torch.randn(B, S, H, D, device="cuda", dtype=torch.bfloat16)

out, _ = flash_attn_4.flash_attn_func(q, k, v, causal=False)
```

### Causal GQA: Qwen / LLaMA family models

```python
# Continues from the block above (torch and flash_attn_4 already loaded)
B, S, Hq, Hkv, D = 1, 2048, 16, 8, 128  # Qwen3-style GQA: Hq=16, Hkv=8
q = torch.randn(B, S, Hq, D, device="cuda", dtype=torch.bfloat16)
k = torch.randn(B, S, Hkv, D, device="cuda", dtype=torch.bfloat16)
v = torch.randn(B, S, Hkv, D, device="cuda", dtype=torch.bfloat16)

out, _ = flash_attn_4.flash_attn_func(q, k, v, causal=True)
```

### Variable-length (production batched serving)

```python
import itertools

# Pack a batch of sequences with different lengths into a single flat tensor.
# Reuses Hq, Hkv, D from the GQA example above.
seq_lens = [128, 256, 512]
total = sum(seq_lens)
# Cumulative start offsets of each sequence: [0, 128, 384, 896]
cu_seqlens = torch.tensor([0, *itertools.accumulate(seq_lens)], dtype=torch.int32, device="cuda")

q = torch.randn(total, Hq, D, device="cuda", dtype=torch.bfloat16)
k = torch.randn(total, Hkv, D, device="cuda", dtype=torch.bfloat16)
v = torch.randn(total, Hkv, D, device="cuda", dtype=torch.bfloat16)

out, _ = flash_attn_4.flash_attn_varlen_func(
    q, k, v,
    cu_seqlens_q=cu_seqlens,
    cu_seqlens_k=cu_seqlens,
    max_seqlen_q=max(seq_lens),
    max_seqlen_k=max(seq_lens),
    causal=True,
)
```

### Paged KV (vLLM / SGLang serving pattern)

```python
# k_paged / v_paged, page_table, actual_seq_lens and max_kv_len come from your
# serving engine's KV-cache manager; they are placeholders, not built here.
out, _ = flash_attn_4.flash_attn_func(
    q, k_paged, v_paged,
    page_table=page_table,
    seqused_k=actual_seq_lens,
    max_seqlen_k=max_kv_len,
    causal=True,
)
```

## API

Two entry points are exposed at the package root:

- `flash_attn_func(q, k, v, ...)`: standard attention, fixed length within a batch
- `flash_attn_varlen_func(q, k, v, cu_seqlens_q, cu_seqlens_k, ...)`: variable-length

Parameters supported beyond upstream main:

| Parameter | What it enables | PR |
|---|---|---|
| `page_table=...` | Paged KV cache | #2348 |
| `num_splits=...` | Split-KV / FlashDecoding | #2336 |
| `block_sparse_tensors=...` | Block-sparse attention | #2389 |
| `dropout_p=..., dropout_seed=...` | Per-element dropout | #2439 |
| (automatic) | TMA forward dispatch when viable | #2349 |

Tensor layout: `(batch, seqlen, num_heads, head_dim)` for `flash_attn_func` and `(total_tokens, num_heads, head_dim)` for `flash_attn_varlen_func`; the last dim must be contiguous and 16-byte aligned.

## Validation

End-to-end on SM121a (DGX Spark GB10), bf16 + fp16, causal + non-causal, dense + varlen:

| Shape category | Configurations |
|---|---|
| MHA (`Hq = Hkv`) | `D ∈ {64, 128}`, `S ∈ {128, 256, 512, 1024}` |
| GQA Qwen3-style | `Hq=16, Hkv=8, D=128` |
| GQA LLaMA3-style | `Hq=32, Hkv=8, D=128` |
| MQA | `Hq=4, Hkv=1, D=128` |
| Batched | `B = 2` |

- **Forward**: 64 / 64 configurations pass (max diff ≤ 0.0156 vs a PyTorch f32 reference)
- **Backward**: 40 / 40 configurations pass (dq, dk, dv all within 0.05 of a PyTorch f32 reference)
- **Standalone install**: validated via `kernels.get_kernel(...)` from a clean Python venv with only `kernels`, `torch`, `nvidia-cutlass-dsl`, `apache-tvm-ffi`, `einops`, and `quack-kernels` installed; no `flash-attn` dependency is required.
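The forward numbers above compare kernel output against an f32 reference. The sketch below spells out what such a reference computes for a single sequence, including the GQA / MQA head mapping (query head `h` reads KV head `h // (Hq // Hkv)`) and the causal mask for equal query and key lengths. It is a hypothetical pure-Python stand-in written for readability on tiny shapes; `reference_attention` and `max_abs_diff` are illustrative names, and this is not the harness that produced the table (in practice you would compute the same thing with PyTorch on the GPU).

```python
# Hypothetical float64 reference for single-sequence attention with GQA and an
# optional causal mask, plus a max-abs-diff helper. Tiny shapes only.
import math


def reference_attention(q, k, v, causal=False):
    """q: [S][Hq][D], k/v: [S][Hkv][D] nested lists -> out: [S][Hq][D]."""
    seqlen, hq, d = len(q), len(q[0]), len(q[0][0])
    hkv = len(k[0])
    group = hq // hkv                        # query heads sharing one KV head
    scale = 1.0 / math.sqrt(d)
    out = [[[0.0] * d for _ in range(hq)] for _ in range(seqlen)]
    for i in range(seqlen):
        last = i + 1 if causal else seqlen   # causal: query i sees keys 0..i
        for h in range(hq):
            kv_h = h // group                # GQA / MQA head mapping
            scores = [scale * sum(a * b for a, b in zip(q[i][h], k[j][kv_h]))
                      for j in range(last)]
            m = max(scores)                  # subtract the max for a stable softmax
            weights = [math.exp(s - m) for s in scores]
            total = sum(weights)
            for j, w in enumerate(weights):
                for c in range(d):
                    out[i][h][c] += (w / total) * v[j][kv_h][c]
    return out


def max_abs_diff(a, b):
    """Largest elementwise |a - b| over two equally shaped [S][H][D] lists."""
    return max(abs(x - y)
               for row_a, row_b in zip(a, b)
               for head_a, head_b in zip(row_a, row_b)
               for x, y in zip(head_a, head_b))
```

A kernel output copied to the host as nested lists (for example via `tensor.tolist()` on one batch element) can then be checked with `max_abs_diff(kernel_out, reference_attention(q, k, v, causal=True))` against a chosen tolerance.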
## Performance

Patched HF FA4 vs vLLM's FA2 baseline on SM121a (DGX Spark), bf16, causal, Qwen3-style GQA `Hq=16, Hkv=8, D=128`; each time is the median of 30 iterations after 5 warmup iterations:

| Shape (B, S, Hq, Hkv, D) | HF FA4 (ms) | vLLM FA2 (ms) | FA4 / FA2 |
| --- | --- | --- | --- |
| (1, 128, 16, 8, 128) | 0.036 | 0.021 | 1.71x |
| (1, 512, 16, 8, 128) | 0.053 | 0.049 | 1.07x |
| (1, 1024, 16, 8, 128) | 0.106 | 0.102 | 1.04x |
| (1, 2048, 16, 8, 128) | 0.289 | 0.278 | 1.04x |
| (1, 4096, 16, 8, 128) | 0.976 | 0.886 | 1.10x |
| (2, 512, 16, 8, 128) | 0.075 | 0.069 | 1.09x |
| (4, 256, 16, 8, 128) | 0.059 | 0.049 | 1.19x |
| (8, 256, 16, 8, 128) | 0.109 | 0.104 | 1.05x |

At very short sequences (S = 128), FA4's dispatch overhead dominates (about 70% slower than FA2). At realistic Qwen3 prefill lengths (S = 512 to 4096), FA4 is **4 to 10 percent slower than FA2**. This is consistent with the SM120 hardware: without `tcgen05` / TMEM, FA4's primary speed path doesn't apply, so it compiles down to roughly the same SM80-era `mma.sync` compute as FA2, plus a small dispatch overhead. Use this kernel for the FA4-only features (paged KV, score_mod, block-sparse, dropout); use FA2 if pure attention throughput is the only goal.

## Known limitations

- **GQA dispatches through the non-packed path on SM120** (PR #2484 workaround). Functionally correct on every GQA / MQA shape we tested. Throughput is within roughly 10% of fmha_v2 on the GQA shapes measured. Tracked upstream.
- **`head_dim > 128` is not supported on SM120**: the 99 KB SMEM budget cannot hold the Q tile. This affects models like Qwen3.5-9B (D=256) and Qwen3-Coder-Next (D=256). vLLM's existing `fa_utils.py` gate already routes `head_size > 128` to FA2 on Blackwell; this kernel maintains that boundary.
- **Split-KV is not supported on SM120 in this kernel variant.** PR #2336 implements it, but the bundle's `interface.py` clamps `num_splits` to 1 on SM12x.
  Decode workloads use a single split, which is consistent with how vLLM and SGLang configure SM120 today.
- **Dropout** runs but spills registers at `tile_m=128, tile_n=128` non-causal; when `dropout_p > 0`, the bundle's `interface.py` falls back to `tile_m=128, tile_n=64` (or `tile_m=64, tile_n=64` for `D > 64`), which fixes the spill at a small throughput cost.

## Hardware support outside SM120

The bundle inherits upstream `flash-attn-4`'s SM80 / SM90 / SM100 dispatch paths. Those should behave the same as upstream main; the bundled PRs target SM120 specifically. We do not test SM80 / SM90 / SM100, so please open an issue if you find regressions.

## License

BSD-3-Clause, inherited from [Dao-AILab/flash-attention](https://github.com/Dao-AILab/flash-attention).

## Credits

- **Upstream**: [Dao-AILab/flash-attention](https://github.com/Dao-AILab/flash-attention): Tri Dao, Jay Shah, Ganesh Bikshandi, Ying Zhang, Vijay Thakkar, Pradeep Ramani, and contributors
- **SM120 PRs and bundle packaging**: Blake Ledden, [Second Nature Computing](https://joinsecondnature.com)
- **Hub packaging template**: [kernels-community/flash-attn4](https://huggingface.co/kernels-community/flash-attn4)

## Issues

For bundle-specific issues (dispatch logic, validation gaps, packaging), open an issue on this HF repo. For kernel-level issues that also exist upstream, file them against [Dao-AILab/flash-attention](https://github.com/Dao-AILab/flash-attention) directly.

## See also

- [`CONFLICTS_LOG.md`](https://huggingface.co/SecondNatureComputing/flash-attn-4-sm120/blob/main/CONFLICTS_LOG.md): detailed log of every conflict encountered while stacking the six PRs, with resolutions and per-PR backport guidance