Upload folder using huggingface_hub
Browse files- LICENSE +21 -0
- README.md +152 -0
- benchmark_results.json +496 -0
- scripts/benchmark.py +123 -0
- scripts/benchmark_models.py +400 -0
- scripts/needle_test.py +143 -0
- scripts/run_inference.py +134 -0
- scripts/test_cache.py +132 -0
- scripts/verify.py +198 -0
- setup.py +29 -0
- turboquant/__init__.py +3 -0
- turboquant/cache.py +139 -0
- turboquant/codebook.py +127 -0
- turboquant/packing.py +77 -0
- turboquant/quantizer.py +117 -0
LICENSE
ADDED
|
@@ -0,0 +1,21 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
MIT License
|
| 2 |
+
|
| 3 |
+
Copyright (c) 2025 Vivek Varikuti
|
| 4 |
+
|
| 5 |
+
Permission is hereby granted, free of charge, to any person obtaining a copy
|
| 6 |
+
of this software and associated documentation files (the "Software"), to deal
|
| 7 |
+
in the Software without restriction, including without limitation the rights
|
| 8 |
+
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
|
| 9 |
+
copies of the Software, and to permit persons to whom the Software is
|
| 10 |
+
furnished to do so, subject to the following conditions:
|
| 11 |
+
|
| 12 |
+
The above copyright notice and this permission notice shall be included in all
|
| 13 |
+
copies or substantial portions of the Software.
|
| 14 |
+
|
| 15 |
+
THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
|
| 16 |
+
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
|
| 17 |
+
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
|
| 18 |
+
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
|
| 19 |
+
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
|
| 20 |
+
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
|
| 21 |
+
SOFTWARE.
|
README.md
ADDED
|
@@ -0,0 +1,152 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# TurboQuant: First Open-Source Implementation
|
| 2 |
+
|
| 3 |
+
First open-source implementation of [TurboQuant: Online Vector Quantization with Near-optimal Distortion Rate](https://arxiv.org/abs/2504.19874) (Zandieh, Daliri, Hadian, Mirrokni — Google Research / Google DeepMind / NYU, April 2025).
|
| 4 |
+
|
| 5 |
+
TurboQuant compresses LLM KV caches **4-7x** at inference time using random rotation + optimal scalar quantization, with **near-zero quality loss**. No training, no calibration data, fully data-oblivious. Drop-in replacement for HuggingFace Transformers cache.
|
| 6 |
+
|
| 7 |
+
## Key Results
|
| 8 |
+
|
| 9 |
+
Benchmarked across **5 model families, 6 models (7B to 70B)** on NVIDIA H100 NVL (96GB):
|
| 10 |
+
|
| 11 |
+
| Model | Architecture | KV Heads | head_dim | Outlier Layers | Prefill Fidelity | Saved @8K |
|
| 12 |
+
|---|---|---|---|---|---|---|
|
| 13 |
+
| **Qwen2.5-7B** | 28L, qwen2 | 4 | 128 | layers 0, 27 | exact | 380 MB |
|
| 14 |
+
| **Llama-3.1-8B** | 32L, llama | 8 | 128 | none | exact | 890 MB |
|
| 15 |
+
| **Gemma-2-9B** | 42L, gemma2 | 8 | 256 | none | exact | 2,323 MB |
|
| 16 |
+
| **Phi-4-14B** | 40L, phi3 | 10 | 128 | none | exact | 1,392 MB |
|
| 17 |
+
| **Qwen2.5-32B** | 64L, qwen2 | 8 | 128 | none | exact | 1,791 MB |
|
| 18 |
+
| **Llama-3.3-70B** | 80L, llama | 8 | 128 | none | exact | 501 MB (@2K) |
|
| 19 |
+
|
| 20 |
+
**Prefill logits are bit-identical (0.0 difference)** across all 6 tested models. Output quality is coherent and semantically correct — divergence from uncompressed output is purely greedy-decoding drift, not quality degradation.
|
| 21 |
+
|
| 22 |
+
### Needle-in-a-Haystack: 100% Recall
|
| 23 |
+
|
| 24 |
+
Tested on Qwen2.5-7B across 5 context lengths (1K-16K) and 3 needle positions (25%, 50%, 75%):
|
| 25 |
+
|
| 26 |
+
| | Default Cache | TurboQuant Cache |
|
| 27 |
+
|---|---|---|
|
| 28 |
+
| **Recall** | **15/15 (100%)** | **15/15 (100%)** |
|
| 29 |
+
|
| 30 |
+
TurboQuant preserves retrieval quality perfectly, matching the paper's 0.997 recall claim.
|
| 31 |
+
|
| 32 |
+
### Memory Savings Scale with Context
|
| 33 |
+
|
| 34 |
+
Qwen2.5-32B (4-bit weights) on H100:
|
| 35 |
+
|
| 36 |
+
| Context | Default KV | TurboQuant KV | Saved |
|
| 37 |
+
|---|---|---|---|
|
| 38 |
+
| 1K tokens | 19.97 GB | 19.79 GB | 186 MB |
|
| 39 |
+
| 4K tokens | 21.23 GB | 20.42 GB | 833 MB |
|
| 40 |
+
| 8K tokens | 23.16 GB | 21.41 GB | 1,791 MB |
|
| 41 |
+
| 32K tokens | ~27.5 GB | ~21.8 GB | ~5,700 MB (projected) |
|
| 42 |
+
|
| 43 |
+
## Quickstart
|
| 44 |
+
|
| 45 |
+
```python
|
| 46 |
+
from transformers import AutoModelForCausalLM, AutoTokenizer
|
| 47 |
+
from turboquant import TurboQuantCache
|
| 48 |
+
|
| 49 |
+
model = AutoModelForCausalLM.from_pretrained("Qwen/Qwen2.5-32B-Instruct", device_map="auto")
|
| 50 |
+
tokenizer = AutoTokenizer.from_pretrained("Qwen/Qwen2.5-32B-Instruct")
|
| 51 |
+
|
| 52 |
+
# Auto-detect outlier layers, create compressed cache
|
| 53 |
+
skip = TurboQuantCache.calibrate_skip_layers(model, tokenizer)
|
| 54 |
+
cache = TurboQuantCache(model.config, nbits=4, skip_layers=skip)
|
| 55 |
+
|
| 56 |
+
# Use exactly like default cache
|
| 57 |
+
inputs = tokenizer("Hello world", return_tensors="pt").to(model.device)
|
| 58 |
+
output = model.generate(**inputs, max_new_tokens=100, past_key_values=cache)
|
| 59 |
+
```
|
| 60 |
+
|
| 61 |
+
## How It Works
|
| 62 |
+
|
| 63 |
+
TurboQuant implements Algorithm 1 (TurboQuant_mse) from the paper:
|
| 64 |
+
|
| 65 |
+
1. **Random rotation** (QR decomposition): transforms each KV vector so coordinates follow a known Beta distribution
|
| 66 |
+
2. **Optimal scalar quantization** (Lloyd-Max): quantizes each coordinate to 4 bits using precomputed codebook
|
| 67 |
+
3. **Bit packing**: stores 128-dim vectors as 64 bytes (uint4) + 2 bytes (norm) = **66 bytes vs 256 bytes BF16**
|
| 68 |
+
|
| 69 |
+
Theoretical guarantee: MSE distortion ≤ 0.009 at 4-bit, within **2.7x of information-theoretic optimum** (Shannon lower bound).
|
| 70 |
+
|
| 71 |
+
Our measured MSE: **0.0093** — matches the paper.
|
| 72 |
+
|
| 73 |
+
## What We Found Beyond the Paper
|
| 74 |
+
|
| 75 |
+
### Outlier Layer Norms
|
| 76 |
+
|
| 77 |
+
The paper mentions "splitting channels into outlier and non-outlier sets" without specifying how. We discovered:
|
| 78 |
+
|
| 79 |
+
- **Qwen2.5-7B**: Layer 0 key norms = 273.8 (16.2x median). Layer 27 = outlier too.
|
| 80 |
+
- **Qwen2.5-32B**: Layer 0 = 37.8 (2.35x median). Mild, no skip needed.
|
| 81 |
+
- **Llama-3.1-8B**: Max/median ratio = 1.18x. No outliers at all.
|
| 82 |
+
- **Gemma-2-9B**: Max/median ratio = 1.19x. No outliers.
|
| 83 |
+
- **Phi-4-14B**: Max/median ratio = 1.38x. No outliers.
|
| 84 |
+
|
| 85 |
+
**Finding**: Smaller Qwen models have severe outlier layers. Larger models and non-Qwen architectures are well-balanced. Our `calibrate_skip_layers()` auto-detects outliers and keeps them in full precision.
|
| 86 |
+
|
| 87 |
+
### head_dim Compatibility
|
| 88 |
+
|
| 89 |
+
The paper only tested head_dim=128 (Llama, Mistral). We verified TurboQuant works with **head_dim=256** (Gemma-2) — the Lloyd-Max codebook adapts to any dimension since it's computed from the Beta distribution parameterized by d.
|
| 90 |
+
|
| 91 |
+
### Architecture Coverage
|
| 92 |
+
|
| 93 |
+
| Architecture | Paper Tested | We Tested | Works |
|
| 94 |
+
|---|---|---|---|
|
| 95 |
+
| Llama | Llama-3.1-8B | Llama-3.1-8B, 3.3-70B | Yes |
|
| 96 |
+
| Mistral | Ministral-7B | — | — |
|
| 97 |
+
| Qwen | — | Qwen2.5-7B, 32B | Yes (with outlier handling) |
|
| 98 |
+
| Gemma | — | Gemma-2-9B | Yes (head_dim=256) |
|
| 99 |
+
| Phi | — | Phi-4-14B | Yes |
|
| 100 |
+
|
| 101 |
+
## Files
|
| 102 |
+
|
| 103 |
+
```
|
| 104 |
+
turboquant/
|
| 105 |
+
├── __init__.py # Public API
|
| 106 |
+
├── codebook.py # Lloyd-Max solver for Beta distribution
|
| 107 |
+
├── quantizer.py # Core TurboQuantizer: rotate → quantize → pack
|
| 108 |
+
├── packing.py # uint4/uint2 bit packing
|
| 109 |
+
├── cache.py # TurboQuantCache for HF Transformers
|
| 110 |
+
scripts/
|
| 111 |
+
├── verify.py # Unit tests (MSE bounds, packing, fixed-point)
|
| 112 |
+
├── test_cache.py # Cache API integration tests
|
| 113 |
+
├── benchmark_models.py # Multi-model benchmark suite
|
| 114 |
+
├── run_inference.py # Interactive inference demo
|
| 115 |
+
benchmark_results.json # Raw benchmark data (all 5 models)
|
| 116 |
+
```
|
| 117 |
+
|
| 118 |
+
## Verified Against Paper
|
| 119 |
+
|
| 120 |
+
| Metric | Paper | Ours |
|
| 121 |
+
|---|---|---|
|
| 122 |
+
| MSE at 4-bit (unit vectors) | ≤ 0.009 | 0.0093 |
|
| 123 |
+
| MSE at 2-bit (unit vectors) | ≤ 0.117 | 0.116 |
|
| 124 |
+
| Compression ratio (per-vector) | ~4x | 3.88x |
|
| 125 |
+
| System compression @8K+ | 4-7x | 7.2x |
|
| 126 |
+
| Prefill fidelity | "quality neutral" | exact (0.0 logit diff) |
|
| 127 |
+
| Double quantization | fixed point | verified (indices identical) |
|
| 128 |
+
|
| 129 |
+
## Requirements
|
| 130 |
+
|
| 131 |
+
- Python 3.10+
|
| 132 |
+
- PyTorch 2.7+ (CUDA 12.8 compatible)
|
| 133 |
+
- HuggingFace Transformers 5.0+
|
| 134 |
+
- scipy (for codebook computation)
|
| 135 |
+
- bitsandbytes (optional, for 4-bit model loading)
|
| 136 |
+
|
| 137 |
+
## Citation
|
| 138 |
+
|
| 139 |
+
If you use this implementation, please cite the original paper:
|
| 140 |
+
|
| 141 |
+
```bibtex
|
| 142 |
+
@article{zandieh2025turboquant,
|
| 143 |
+
title={TurboQuant: Online Vector Quantization with Near-optimal Distortion Rate},
|
| 144 |
+
author={Zandieh, Amir and Daliri, Majid and Hadian, Majid and Mirrokni, Vahab},
|
| 145 |
+
journal={arXiv preprint arXiv:2504.19874},
|
| 146 |
+
year={2025}
|
| 147 |
+
}
|
| 148 |
+
```
|
| 149 |
+
|
| 150 |
+
## License
|
| 151 |
+
|
| 152 |
+
This implementation is released under MIT License. The TurboQuant algorithm is described in the paper above.
|
benchmark_results.json
ADDED
|
@@ -0,0 +1,496 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
[
|
| 2 |
+
{
|
| 3 |
+
"model_name": "Qwen2.5-7B",
|
| 4 |
+
"model_id": "Qwen/Qwen2.5-7B-Instruct",
|
| 5 |
+
"architecture": {
|
| 6 |
+
"num_layers": 28,
|
| 7 |
+
"hidden_size": 3584,
|
| 8 |
+
"num_attention_heads": 28,
|
| 9 |
+
"num_kv_heads": 4,
|
| 10 |
+
"head_dim": 128,
|
| 11 |
+
"model_type": "qwen2",
|
| 12 |
+
"max_position_embeddings": 32768,
|
| 13 |
+
"rope_theta": null,
|
| 14 |
+
"torch_dtype": "torch.bfloat16",
|
| 15 |
+
"model_memory_gb": 5.451139450073242
|
| 16 |
+
},
|
| 17 |
+
"layer_norms": {
|
| 18 |
+
"median_norm": 16.86,
|
| 19 |
+
"max_norm": 273.84,
|
| 20 |
+
"max_norm_layer": 0,
|
| 21 |
+
"max_to_median_ratio": 16.24,
|
| 22 |
+
"outlier_layers": [
|
| 23 |
+
0,
|
| 24 |
+
27
|
| 25 |
+
],
|
| 26 |
+
"all_norms_first5": [
|
| 27 |
+
273.84,
|
| 28 |
+
66.26,
|
| 29 |
+
31.06,
|
| 30 |
+
50.83,
|
| 31 |
+
14.63
|
| 32 |
+
],
|
| 33 |
+
"all_norms_last3": [
|
| 34 |
+
14.41,
|
| 35 |
+
13.08,
|
| 36 |
+
239.91
|
| 37 |
+
]
|
| 38 |
+
},
|
| 39 |
+
"prefill_logits": {
|
| 40 |
+
"max_logit_diff": 0.0,
|
| 41 |
+
"mean_logit_diff": 0.0,
|
| 42 |
+
"same_top1": true,
|
| 43 |
+
"top1_token": " a"
|
| 44 |
+
},
|
| 45 |
+
"quality": [
|
| 46 |
+
{
|
| 47 |
+
"prompt": "Explain quantum computing in simple terms.",
|
| 48 |
+
"exact_match": false,
|
| 49 |
+
"diverge_at_char": 119,
|
| 50 |
+
"total_chars": 555,
|
| 51 |
+
"token_match_pct": 39.0,
|
| 52 |
+
"default_output": " Quantum computing is a type of computing that uses the principles of quantum mechanics to perform operations on data. In classical computing, we use bits (1s and 0s) to represent and process informat",
|
| 53 |
+
"turboquant_output": " Quantum computing is a type of computing that uses the principles of quantum mechanics to perform operations on data. Unlike classical computers, which use bits (1s and 0s) to represent and process i",
|
| 54 |
+
"both_coherent": true
|
| 55 |
+
},
|
| 56 |
+
{
|
| 57 |
+
"prompt": "Write a Python function to check if a number is prime.",
|
| 58 |
+
"exact_match": false,
|
| 59 |
+
"diverge_at_char": 21,
|
| 60 |
+
"total_chars": 468,
|
| 61 |
+
"token_match_pct": 3.0,
|
| 62 |
+
"default_output": " The function should take an integer as input and return True if the number is prime, and False otherwise.\n\nThe function should also handle edge cases such as negative numbers, zero, and one, which ar",
|
| 63 |
+
"turboquant_output": " The function should be named `is_prime` and take a single argument. It should return `True` if the number is prime, and `False` otherwise.\n\nYour code should pass the following test case:\n```python\nas",
|
| 64 |
+
"both_coherent": true
|
| 65 |
+
},
|
| 66 |
+
{
|
| 67 |
+
"prompt": "What causes the northern lights?",
|
| 68 |
+
"exact_match": false,
|
| 69 |
+
"diverge_at_char": 269,
|
| 70 |
+
"total_chars": 523,
|
| 71 |
+
"token_match_pct": 54.0,
|
| 72 |
+
"default_output": " The northern lights, also known as auroras, are caused by a combination of factors involving the Earth's magnetic field and solar activity. Here's a step-by-step explanation:\n\n1. Solar Wind: The Sun ",
|
| 73 |
+
"turboquant_output": " The northern lights, also known as auroras, are caused by a combination of factors involving the Earth's magnetic field and solar activity. Here's a step-by-step explanation:\n\n1. Solar Wind: The Sun ",
|
| 74 |
+
"both_coherent": true
|
| 75 |
+
}
|
| 76 |
+
],
|
| 77 |
+
"memory": [
|
| 78 |
+
{
|
| 79 |
+
"context_length": 1024,
|
| 80 |
+
"peak_default_gb": 5.76,
|
| 81 |
+
"peak_turboquant_gb": 5.73,
|
| 82 |
+
"saved_mb": 37.0,
|
| 83 |
+
"output_match": true
|
| 84 |
+
},
|
| 85 |
+
{
|
| 86 |
+
"context_length": 4096,
|
| 87 |
+
"peak_default_gb": 6.27,
|
| 88 |
+
"peak_turboquant_gb": 6.1,
|
| 89 |
+
"saved_mb": 176.0,
|
| 90 |
+
"output_match": false
|
| 91 |
+
},
|
| 92 |
+
{
|
| 93 |
+
"context_length": 8189,
|
| 94 |
+
"peak_default_gb": 7.08,
|
| 95 |
+
"peak_turboquant_gb": 6.71,
|
| 96 |
+
"saved_mb": 380.0,
|
| 97 |
+
"output_match": true
|
| 98 |
+
}
|
| 99 |
+
],
|
| 100 |
+
"status": "success"
|
| 101 |
+
},
|
| 102 |
+
{
|
| 103 |
+
"model_name": "Llama-3.1-8B",
|
| 104 |
+
"model_id": "meta-llama/Llama-3.1-8B-Instruct",
|
| 105 |
+
"architecture": {
|
| 106 |
+
"num_layers": 32,
|
| 107 |
+
"hidden_size": 4096,
|
| 108 |
+
"num_attention_heads": 32,
|
| 109 |
+
"num_kv_heads": 8,
|
| 110 |
+
"head_dim": 128,
|
| 111 |
+
"model_type": "llama",
|
| 112 |
+
"max_position_embeddings": 131072,
|
| 113 |
+
"rope_theta": null,
|
| 114 |
+
"torch_dtype": "torch.bfloat16",
|
| 115 |
+
"model_memory_gb": 5.678826332092285
|
| 116 |
+
},
|
| 117 |
+
"layer_norms": {
|
| 118 |
+
"median_norm": 17.9,
|
| 119 |
+
"max_norm": 21.05,
|
| 120 |
+
"max_norm_layer": 7,
|
| 121 |
+
"max_to_median_ratio": 1.18,
|
| 122 |
+
"outlier_layers": [],
|
| 123 |
+
"all_norms_first5": [
|
| 124 |
+
15.87,
|
| 125 |
+
19.64,
|
| 126 |
+
19.06,
|
| 127 |
+
18.66,
|
| 128 |
+
19.82
|
| 129 |
+
],
|
| 130 |
+
"all_norms_last3": [
|
| 131 |
+
19.11,
|
| 132 |
+
16.91,
|
| 133 |
+
19.35
|
| 134 |
+
]
|
| 135 |
+
},
|
| 136 |
+
"prefill_logits": {
|
| 137 |
+
"max_logit_diff": 0.0,
|
| 138 |
+
"mean_logit_diff": 0.0,
|
| 139 |
+
"same_top1": true,
|
| 140 |
+
"top1_token": " a"
|
| 141 |
+
},
|
| 142 |
+
"quality": [
|
| 143 |
+
{
|
| 144 |
+
"prompt": "Explain quantum computing in simple terms.",
|
| 145 |
+
"exact_match": false,
|
| 146 |
+
"diverge_at_char": 438,
|
| 147 |
+
"total_chars": 494,
|
| 148 |
+
"token_match_pct": 89.1,
|
| 149 |
+
"default_output": " Quantum computing is a new way of processing information that uses the principles of quantum mechanics. In classical computing, information is represented as bits, which can have a value of either 0 ",
|
| 150 |
+
"turboquant_output": " Quantum computing is a new way of processing information that uses the principles of quantum mechanics. In classical computing, information is represented as bits, which can have a value of either 0 ",
|
| 151 |
+
"both_coherent": true
|
| 152 |
+
},
|
| 153 |
+
{
|
| 154 |
+
"prompt": "Write a Python function to check if a number is prime.",
|
| 155 |
+
"exact_match": true,
|
| 156 |
+
"diverge_at_char": 388,
|
| 157 |
+
"total_chars": 388,
|
| 158 |
+
"token_match_pct": 100.0,
|
| 159 |
+
"default_output": " A prime number is a natural number greater than 1 that has no positive divisors other than 1 and itself.\n\n```python\ndef is_prime(n):\n \"\"\"\n Checks if a number is prime.\n\n Args:\n n (int",
|
| 160 |
+
"turboquant_output": " A prime number is a natural number greater than 1 that has no positive divisors other than 1 and itself.\n\n```python\ndef is_prime(n):\n \"\"\"\n Checks if a number is prime.\n\n Args:\n n (int",
|
| 161 |
+
"both_coherent": true
|
| 162 |
+
},
|
| 163 |
+
{
|
| 164 |
+
"prompt": "What causes the northern lights?",
|
| 165 |
+
"exact_match": true,
|
| 166 |
+
"diverge_at_char": 527,
|
| 167 |
+
"total_chars": 527,
|
| 168 |
+
"token_match_pct": 100.0,
|
| 169 |
+
"default_output": " The northern lights, also known as the aurora borealis, are a natural phenomenon that occurs when charged particles from the sun interact with the Earth's magnetic field and atmosphere. The charged p",
|
| 170 |
+
"turboquant_output": " The northern lights, also known as the aurora borealis, are a natural phenomenon that occurs when charged particles from the sun interact with the Earth's magnetic field and atmosphere. The charged p",
|
| 171 |
+
"both_coherent": true
|
| 172 |
+
}
|
| 173 |
+
],
|
| 174 |
+
"memory": [
|
| 175 |
+
{
|
| 176 |
+
"context_length": 1024,
|
| 177 |
+
"peak_default_gb": 6.0,
|
| 178 |
+
"peak_turboquant_gb": 5.91,
|
| 179 |
+
"saved_mb": 93.0,
|
| 180 |
+
"output_match": true
|
| 181 |
+
},
|
| 182 |
+
{
|
| 183 |
+
"context_length": 4092,
|
| 184 |
+
"peak_default_gb": 6.67,
|
| 185 |
+
"peak_turboquant_gb": 6.27,
|
| 186 |
+
"saved_mb": 417.0,
|
| 187 |
+
"output_match": true
|
| 188 |
+
},
|
| 189 |
+
{
|
| 190 |
+
"context_length": 8087,
|
| 191 |
+
"peak_default_gb": 7.71,
|
| 192 |
+
"peak_turboquant_gb": 6.84,
|
| 193 |
+
"saved_mb": 890.0,
|
| 194 |
+
"output_match": true
|
| 195 |
+
}
|
| 196 |
+
],
|
| 197 |
+
"status": "success"
|
| 198 |
+
},
|
| 199 |
+
{
|
| 200 |
+
"model_name": "Phi-4-14B",
|
| 201 |
+
"model_id": "microsoft/phi-4",
|
| 202 |
+
"architecture": {
|
| 203 |
+
"num_layers": 40,
|
| 204 |
+
"hidden_size": 5120,
|
| 205 |
+
"num_attention_heads": 40,
|
| 206 |
+
"num_kv_heads": 10,
|
| 207 |
+
"head_dim": 128,
|
| 208 |
+
"model_type": "phi3",
|
| 209 |
+
"max_position_embeddings": 16384,
|
| 210 |
+
"rope_theta": null,
|
| 211 |
+
"torch_dtype": "torch.bfloat16",
|
| 212 |
+
"model_memory_gb": 9.103724479675293
|
| 213 |
+
},
|
| 214 |
+
"layer_norms": {
|
| 215 |
+
"median_norm": 19.21,
|
| 216 |
+
"max_norm": 26.46,
|
| 217 |
+
"max_norm_layer": 0,
|
| 218 |
+
"max_to_median_ratio": 1.38,
|
| 219 |
+
"outlier_layers": [],
|
| 220 |
+
"all_norms_first5": [
|
| 221 |
+
26.46,
|
| 222 |
+
16.98,
|
| 223 |
+
15.24,
|
| 224 |
+
14.91,
|
| 225 |
+
17.14
|
| 226 |
+
],
|
| 227 |
+
"all_norms_last3": [
|
| 228 |
+
20.03,
|
| 229 |
+
19.5,
|
| 230 |
+
20.44
|
| 231 |
+
]
|
| 232 |
+
},
|
| 233 |
+
"prefill_logits": {
|
| 234 |
+
"max_logit_diff": 0.0,
|
| 235 |
+
"mean_logit_diff": 0.0,
|
| 236 |
+
"same_top1": true,
|
| 237 |
+
"top1_token": " a"
|
| 238 |
+
},
|
| 239 |
+
"quality": [
|
| 240 |
+
{
|
| 241 |
+
"prompt": "Explain quantum computing in simple terms.",
|
| 242 |
+
"exact_match": true,
|
| 243 |
+
"diverge_at_char": 0,
|
| 244 |
+
"total_chars": 0,
|
| 245 |
+
"token_match_pct": 100,
|
| 246 |
+
"default_output": "",
|
| 247 |
+
"turboquant_output": "",
|
| 248 |
+
"both_coherent": true
|
| 249 |
+
},
|
| 250 |
+
{
|
| 251 |
+
"prompt": "Write a Python function to check if a number is prime.",
|
| 252 |
+
"exact_match": false,
|
| 253 |
+
"diverge_at_char": 185,
|
| 254 |
+
"total_chars": 329,
|
| 255 |
+
"token_match_pct": 44.0,
|
| 256 |
+
"default_output": " The function should return `True` if the number is prime and `False` otherwise. A prime number is a natural number greater than 1 that has no positive divisors other than 1 and itself. For example, 2",
|
| 257 |
+
"turboquant_output": " The function should return `True` if the number is prime and `False` otherwise. A prime number is a natural number greater than 1 that has no positive divisors other than 1 and itself.\n\n**Function Si",
|
| 258 |
+
"both_coherent": true
|
| 259 |
+
},
|
| 260 |
+
{
|
| 261 |
+
"prompt": "What causes the northern lights?",
|
| 262 |
+
"exact_match": true,
|
| 263 |
+
"diverge_at_char": 464,
|
| 264 |
+
"total_chars": 464,
|
| 265 |
+
"token_match_pct": 100.0,
|
| 266 |
+
"default_output": " \nA) The reflection of sunlight off the moon\nB) The reflection of sunlight off the ocean\nC) The interaction of solar wind with the Earth's magnetic field\nD) The reflection of sunlight off the clouds\n\n",
|
| 267 |
+
"turboquant_output": " \nA) The reflection of sunlight off the moon\nB) The reflection of sunlight off the ocean\nC) The interaction of solar wind with the Earth's magnetic field\nD) The reflection of sunlight off the clouds\n\n",
|
| 268 |
+
"both_coherent": true
|
| 269 |
+
}
|
| 270 |
+
],
|
| 271 |
+
"memory": [
|
| 272 |
+
{
|
| 273 |
+
"context_length": 1024,
|
| 274 |
+
"peak_default_gb": 9.75,
|
| 275 |
+
"peak_turboquant_gb": 9.61,
|
| 276 |
+
"saved_mb": 146.0,
|
| 277 |
+
"output_match": true
|
| 278 |
+
},
|
| 279 |
+
{
|
| 280 |
+
"context_length": 4091,
|
| 281 |
+
"peak_default_gb": 10.72,
|
| 282 |
+
"peak_turboquant_gb": 10.09,
|
| 283 |
+
"saved_mb": 650.0,
|
| 284 |
+
"output_match": true
|
| 285 |
+
},
|
| 286 |
+
{
|
| 287 |
+
"context_length": 8171,
|
| 288 |
+
"peak_default_gb": 12.28,
|
| 289 |
+
"peak_turboquant_gb": 10.92,
|
| 290 |
+
"saved_mb": 1392.0,
|
| 291 |
+
"output_match": true
|
| 292 |
+
}
|
| 293 |
+
],
|
| 294 |
+
"status": "success"
|
| 295 |
+
},
|
| 296 |
+
{
|
| 297 |
+
"model_name": "Gemma-2-9B",
|
| 298 |
+
"model_id": "google/gemma-2-9b-it",
|
| 299 |
+
"architecture": {
|
| 300 |
+
"num_layers": 42,
|
| 301 |
+
"hidden_size": 3584,
|
| 302 |
+
"num_attention_heads": 16,
|
| 303 |
+
"num_kv_heads": 8,
|
| 304 |
+
"head_dim": 256,
|
| 305 |
+
"model_type": "gemma2",
|
| 306 |
+
"max_position_embeddings": 8192,
|
| 307 |
+
"rope_theta": null,
|
| 308 |
+
"torch_dtype": "torch.bfloat16",
|
| 309 |
+
"model_memory_gb": 6.075854778289795
|
| 310 |
+
},
|
| 311 |
+
"layer_norms": {
|
| 312 |
+
"median_norm": 17.82,
|
| 313 |
+
"max_norm": 21.28,
|
| 314 |
+
"max_norm_layer": 25,
|
| 315 |
+
"max_to_median_ratio": 1.19,
|
| 316 |
+
"outlier_layers": [],
|
| 317 |
+
"all_norms_first5": [
|
| 318 |
+
19.23,
|
| 319 |
+
19.18,
|
| 320 |
+
19.97,
|
| 321 |
+
18.17,
|
| 322 |
+
16.04
|
| 323 |
+
],
|
| 324 |
+
"all_norms_last3": [
|
| 325 |
+
17.02,
|
| 326 |
+
16.37,
|
| 327 |
+
16.52
|
| 328 |
+
]
|
| 329 |
+
},
|
| 330 |
+
"prefill_logits": {
|
| 331 |
+
"max_logit_diff": 0.0,
|
| 332 |
+
"mean_logit_diff": 0.0,
|
| 333 |
+
"same_top1": true,
|
| 334 |
+
"top1_token": " a"
|
| 335 |
+
},
|
| 336 |
+
"quality": [
|
| 337 |
+
{
|
| 338 |
+
"prompt": "Explain quantum computing in simple terms.",
|
| 339 |
+
"exact_match": true,
|
| 340 |
+
"diverge_at_char": 429,
|
| 341 |
+
"total_chars": 429,
|
| 342 |
+
"token_match_pct": 100.0,
|
| 343 |
+
"default_output": "\n\nImagine a regular computer bit like a light switch, it can be either on (1) or off (0).\n\nNow imagine a quantum bit, or qubit, like a dimmer switch. It can be on, off, or **anywhere in between**. Thi",
|
| 344 |
+
"turboquant_output": "\n\nImagine a regular computer bit like a light switch, it can be either on (1) or off (0).\n\nNow imagine a quantum bit, or qubit, like a dimmer switch. It can be on, off, or **anywhere in between**. Thi",
|
| 345 |
+
"both_coherent": true
|
| 346 |
+
},
|
| 347 |
+
{
|
| 348 |
+
"prompt": "Write a Python function to check if a number is prime.",
|
| 349 |
+
"exact_match": true,
|
| 350 |
+
"diverge_at_char": 344,
|
| 351 |
+
"total_chars": 344,
|
| 352 |
+
"token_match_pct": 100.0,
|
| 353 |
+
"default_output": "\n\n```python\ndef is_prime(number):\n \"\"\"\n Checks if a number is prime.\n\n Args:\n number: The number to check.\n\n Returns:\n True if the number is prime, False otherwise.\n \"\"\"\n # Prime numbers a",
|
| 354 |
+
"turboquant_output": "\n\n```python\ndef is_prime(number):\n \"\"\"\n Checks if a number is prime.\n\n Args:\n number: The number to check.\n\n Returns:\n True if the number is prime, False otherwise.\n \"\"\"\n # Prime numbers a",
|
| 355 |
+
"both_coherent": true
|
| 356 |
+
},
|
| 357 |
+
{
|
| 358 |
+
"prompt": "What causes the northern lights?",
|
| 359 |
+
"exact_match": false,
|
| 360 |
+
"diverge_at_char": 72,
|
| 361 |
+
"total_chars": 466,
|
| 362 |
+
"token_match_pct": 18.8,
|
| 363 |
+
"default_output": "\n\nThe Northern Lights, also known as the Aurora Borealis, are caused by the interaction of charged particles from the sun with the Earth's atmosphere.\n\nHere's a breakdown:\n\n1. **Solar Wind:** The sun ",
|
| 364 |
+
"turboquant_output": "\n\nThe Northern Lights, also known as the Aurora Borealis, are caused by a fascinating interaction between the Sun and Earth's atmosphere. \n\nHere's a breakdown:\n\n1. **Solar Wind:** The Sun constantly e",
|
| 365 |
+
"both_coherent": true
|
| 366 |
+
}
|
| 367 |
+
],
|
| 368 |
+
"memory": [
|
| 369 |
+
{
|
| 370 |
+
"context_length": 1024,
|
| 371 |
+
"peak_default_gb": 6.62,
|
| 372 |
+
"peak_turboquant_gb": 6.38,
|
| 373 |
+
"saved_mb": 244.0,
|
| 374 |
+
"output_match": true
|
| 375 |
+
},
|
| 376 |
+
{
|
| 377 |
+
"context_length": 4079,
|
| 378 |
+
"peak_default_gb": 7.96,
|
| 379 |
+
"peak_turboquant_gb": 6.89,
|
| 380 |
+
"saved_mb": 1096.0,
|
| 381 |
+
"output_match": false
|
| 382 |
+
},
|
| 383 |
+
{
|
| 384 |
+
"context_length": 8063,
|
| 385 |
+
"peak_default_gb": 9.98,
|
| 386 |
+
"peak_turboquant_gb": 7.71,
|
| 387 |
+
"saved_mb": 2323.0,
|
| 388 |
+
"output_match": true
|
| 389 |
+
}
|
| 390 |
+
],
|
| 391 |
+
"status": "success"
|
| 392 |
+
},
|
| 393 |
+
{
|
| 394 |
+
"model_name": "Qwen2.5-32B",
|
| 395 |
+
"model_id": "Qwen/Qwen2.5-32B-Instruct",
|
| 396 |
+
"architecture": {
|
| 397 |
+
"num_layers": 64,
|
| 398 |
+
"hidden_size": 5120,
|
| 399 |
+
"num_attention_heads": 40,
|
| 400 |
+
"num_kv_heads": 8,
|
| 401 |
+
"head_dim": 128,
|
| 402 |
+
"model_type": "qwen2",
|
| 403 |
+
"max_position_embeddings": 32768,
|
| 404 |
+
"rope_theta": null,
|
| 405 |
+
"torch_dtype": "torch.bfloat16",
|
| 406 |
+
"model_memory_gb": 19.312846183776855
|
| 407 |
+
},
|
| 408 |
+
"layer_norms": {
|
| 409 |
+
"median_norm": 16.09,
|
| 410 |
+
"max_norm": 37.82,
|
| 411 |
+
"max_norm_layer": 0,
|
| 412 |
+
"max_to_median_ratio": 2.35,
|
| 413 |
+
"outlier_layers": [],
|
| 414 |
+
"all_norms_first5": [
|
| 415 |
+
37.82,
|
| 416 |
+
22.5,
|
| 417 |
+
32.48,
|
| 418 |
+
25.85,
|
| 419 |
+
25.18
|
| 420 |
+
],
|
| 421 |
+
"all_norms_last3": [
|
| 422 |
+
14.65,
|
| 423 |
+
15.84,
|
| 424 |
+
19.42
|
| 425 |
+
]
|
| 426 |
+
},
|
| 427 |
+
"prefill_logits": {
|
| 428 |
+
"max_logit_diff": 0.0,
|
| 429 |
+
"mean_logit_diff": 0.0,
|
| 430 |
+
"same_top1": true,
|
| 431 |
+
"top1_token": " a"
|
| 432 |
+
},
|
| 433 |
+
"quality": [
|
| 434 |
+
{
|
| 435 |
+
"prompt": "Explain quantum computing in simple terms.",
|
| 436 |
+
"exact_match": false,
|
| 437 |
+
"diverge_at_char": 359,
|
| 438 |
+
"total_chars": 514,
|
| 439 |
+
"token_match_pct": 71.0,
|
| 440 |
+
"default_output": " Quantum computing is a type of computing that uses the principles of quantum mechanics to perform operations on data. In classical computing, we use bits (0s and 1s) to represent information, but in ",
|
| 441 |
+
"turboquant_output": " Quantum computing is a type of computing that uses the principles of quantum mechanics to perform operations on data. In classical computing, we use bits (0s and 1s) to represent information, but in ",
|
| 442 |
+
"both_coherent": true
|
| 443 |
+
},
|
| 444 |
+
{
|
| 445 |
+
"prompt": "Write a Python function to check if a number is prime.",
|
| 446 |
+
"exact_match": false,
|
| 447 |
+
"diverge_at_char": 142,
|
| 448 |
+
"total_chars": 455,
|
| 449 |
+
"token_match_pct": 25.0,
|
| 450 |
+
"default_output": " The function should take an integer as input and return a boolean value indicating whether the number is prime or not. The function should handle edge cases such as negative numbers, zero, and one by",
|
| 451 |
+
"turboquant_output": " The function should take an integer as input and return a boolean value indicating whether the number is prime or not. The function should have a time complexity of O(sqrt(n)).\n\nIn addition, the func",
|
| 452 |
+
"both_coherent": true
|
| 453 |
+
},
|
| 454 |
+
{
|
| 455 |
+
"prompt": "What causes the northern lights?",
|
| 456 |
+
"exact_match": false,
|
| 457 |
+
"diverge_at_char": 116,
|
| 458 |
+
"total_chars": 509,
|
| 459 |
+
"token_match_pct": 53.0,
|
| 460 |
+
"default_output": " The Northern Lights, also known as Aurora Borealis, are caused by charged particles from the sun colliding with gases in the Earth's atmosphere. When the sun releases a burst of energy called a solar",
|
| 461 |
+
"turboquant_output": " The Northern Lights, also known as Aurora Borealis, are caused by charged particles from the sun colliding with gas particles in Earth's atmosphere. When the sun releases a burst of energy called a s",
|
| 462 |
+
"both_coherent": true
|
| 463 |
+
}
|
| 464 |
+
],
|
| 465 |
+
"memory": [
|
| 466 |
+
{
|
| 467 |
+
"context_length": 1024,
|
| 468 |
+
"peak_default_gb": 19.97,
|
| 469 |
+
"peak_turboquant_gb": 19.79,
|
| 470 |
+
"saved_mb": 186.0,
|
| 471 |
+
"output_match": true
|
| 472 |
+
},
|
| 473 |
+
{
|
| 474 |
+
"context_length": 4096,
|
| 475 |
+
"peak_default_gb": 21.23,
|
| 476 |
+
"peak_turboquant_gb": 20.42,
|
| 477 |
+
"saved_mb": 833.0,
|
| 478 |
+
"output_match": true
|
| 479 |
+
},
|
| 480 |
+
{
|
| 481 |
+
"context_length": 8189,
|
| 482 |
+
"peak_default_gb": 23.16,
|
| 483 |
+
"peak_turboquant_gb": 21.41,
|
| 484 |
+
"saved_mb": 1791.0,
|
| 485 |
+
"output_match": true
|
| 486 |
+
}
|
| 487 |
+
],
|
| 488 |
+
"status": "success"
|
| 489 |
+
},
|
| 490 |
+
{
|
| 491 |
+
"model_name": "Llama-3.3-70B",
|
| 492 |
+
"model_id": "meta-llama/Llama-3.3-70B-Instruct",
|
| 493 |
+
"status": "error",
|
| 494 |
+
"error": "[Errno 28] No space left on device"
|
| 495 |
+
}
|
| 496 |
+
]
|
scripts/benchmark.py
ADDED
|
@@ -0,0 +1,123 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""Benchmark TurboQuant memory savings and throughput."""
|
| 2 |
+
|
| 3 |
+
import sys
|
| 4 |
+
sys.path.insert(0, "/home/azureuser/turboquant")
|
| 5 |
+
|
| 6 |
+
import torch
|
| 7 |
+
import time
|
| 8 |
+
from types import SimpleNamespace
|
| 9 |
+
from transformers.cache_utils import DynamicCache, Cache, DynamicLayer
|
| 10 |
+
from turboquant.cache import TurboQuantCache, TurboQuantLayer
|
| 11 |
+
|
| 12 |
+
|
| 13 |
+
def benchmark_memory(num_layers: int = 64, num_kv_heads: int = 8, head_dim: int = 128,
|
| 14 |
+
context_lengths: list[int] = None, skip_layers: set[int] = None):
|
| 15 |
+
"""Compare memory usage between DynamicCache and TurboQuantCache."""
|
| 16 |
+
if context_lengths is None:
|
| 17 |
+
context_lengths = [1024, 4096, 8192, 16384, 32768]
|
| 18 |
+
if skip_layers is None:
|
| 19 |
+
skip_layers = {0, 1}
|
| 20 |
+
|
| 21 |
+
device = "cuda"
|
| 22 |
+
batch = 1
|
| 23 |
+
|
| 24 |
+
print(f"{'Context':>8} | {'DynamicCache':>14} | {'TurboQuant':>14} | {'Compression':>12} | {'Savings':>10}")
|
| 25 |
+
print("-" * 72)
|
| 26 |
+
|
| 27 |
+
for seq_len in context_lengths:
|
| 28 |
+
# --- DynamicCache ---
|
| 29 |
+
torch.cuda.empty_cache()
|
| 30 |
+
torch.cuda.reset_peak_memory_stats()
|
| 31 |
+
mem_before = torch.cuda.memory_allocated()
|
| 32 |
+
|
| 33 |
+
dyn_cache = DynamicCache()
|
| 34 |
+
for layer_idx in range(num_layers):
|
| 35 |
+
k = torch.randn(batch, num_kv_heads, seq_len, head_dim, device=device, dtype=torch.bfloat16)
|
| 36 |
+
v = torch.randn(batch, num_kv_heads, seq_len, head_dim, device=device, dtype=torch.bfloat16)
|
| 37 |
+
dyn_cache.update(k, v, layer_idx)
|
| 38 |
+
mem_dynamic = torch.cuda.memory_allocated() - mem_before
|
| 39 |
+
del dyn_cache
|
| 40 |
+
torch.cuda.empty_cache()
|
| 41 |
+
|
| 42 |
+
# --- TurboQuantCache ---
|
| 43 |
+
torch.cuda.reset_peak_memory_stats()
|
| 44 |
+
mem_before = torch.cuda.memory_allocated()
|
| 45 |
+
|
| 46 |
+
# Create cache with skip_layers
|
| 47 |
+
layers = []
|
| 48 |
+
for i in range(num_layers):
|
| 49 |
+
if i in skip_layers:
|
| 50 |
+
layers.append(DynamicLayer())
|
| 51 |
+
else:
|
| 52 |
+
layers.append(TurboQuantLayer(
|
| 53 |
+
dim=head_dim, nbits=4, residual_length=1, device=device, layer_seed=42 + i
|
| 54 |
+
))
|
| 55 |
+
tq_cache = Cache(layers=layers)
|
| 56 |
+
|
| 57 |
+
for layer_idx in range(num_layers):
|
| 58 |
+
k = torch.randn(batch, num_kv_heads, seq_len, head_dim, device=device, dtype=torch.bfloat16)
|
| 59 |
+
v = torch.randn(batch, num_kv_heads, seq_len, head_dim, device=device, dtype=torch.bfloat16)
|
| 60 |
+
tq_cache.update(k, v, layer_idx)
|
| 61 |
+
mem_tq = torch.cuda.memory_allocated() - mem_before
|
| 62 |
+
del tq_cache
|
| 63 |
+
torch.cuda.empty_cache()
|
| 64 |
+
|
| 65 |
+
ratio = mem_dynamic / max(mem_tq, 1)
|
| 66 |
+
savings = (mem_dynamic - mem_tq) / 1024**2
|
| 67 |
+
|
| 68 |
+
print(f"{seq_len:>8} | {mem_dynamic/1024**2:>11.1f} MB | {mem_tq/1024**2:>11.1f} MB | "
|
| 69 |
+
f"{ratio:>10.2f}x | {savings:>7.1f} MB")
|
| 70 |
+
|
| 71 |
+
|
| 72 |
+
def benchmark_throughput(num_layers: int = 64, num_kv_heads: int = 8, head_dim: int = 128):
|
| 73 |
+
"""Benchmark quantization and dequantization throughput."""
|
| 74 |
+
device = "cuda"
|
| 75 |
+
batch = 1
|
| 76 |
+
|
| 77 |
+
print(f"\n{'Operation':>20} | {'Seq Len':>8} | {'Time (ms)':>10} | {'Throughput':>15}")
|
| 78 |
+
print("-" * 65)
|
| 79 |
+
|
| 80 |
+
quantizer_layer = TurboQuantLayer(dim=head_dim, nbits=4, residual_length=1, device=device, layer_seed=42)
|
| 81 |
+
|
| 82 |
+
for seq_len in [1024, 4096, 16384, 32768]:
|
| 83 |
+
k = torch.randn(batch, num_kv_heads, seq_len, head_dim, device=device, dtype=torch.bfloat16)
|
| 84 |
+
v = torch.randn(batch, num_kv_heads, seq_len, head_dim, device=device, dtype=torch.bfloat16)
|
| 85 |
+
|
| 86 |
+
# Warmup
|
| 87 |
+
for _ in range(3):
|
| 88 |
+
packed, norms = quantizer_layer.quantizer.quantize(k)
|
| 89 |
+
_ = quantizer_layer.quantizer.dequantize(packed, norms)
|
| 90 |
+
torch.cuda.synchronize()
|
| 91 |
+
|
| 92 |
+
# Quantize timing
|
| 93 |
+
start = time.perf_counter()
|
| 94 |
+
for _ in range(10):
|
| 95 |
+
packed, norms = quantizer_layer.quantizer.quantize(k)
|
| 96 |
+
torch.cuda.synchronize()
|
| 97 |
+
quant_time = (time.perf_counter() - start) / 10 * 1000
|
| 98 |
+
|
| 99 |
+
# Dequantize timing
|
| 100 |
+
start = time.perf_counter()
|
| 101 |
+
for _ in range(10):
|
| 102 |
+
_ = quantizer_layer.quantizer.dequantize(packed, norms)
|
| 103 |
+
torch.cuda.synchronize()
|
| 104 |
+
dequant_time = (time.perf_counter() - start) / 10 * 1000
|
| 105 |
+
|
| 106 |
+
n_vectors = batch * num_kv_heads * seq_len
|
| 107 |
+
print(f"{'Quantize':>20} | {seq_len:>8} | {quant_time:>8.2f} ms | {n_vectors/quant_time*1000:>12.0f} vec/s")
|
| 108 |
+
print(f"{'Dequantize':>20} | {seq_len:>8} | {dequant_time:>8.2f} ms | {n_vectors/dequant_time*1000:>12.0f} vec/s")
|
| 109 |
+
|
| 110 |
+
|
| 111 |
+
if __name__ == "__main__":
|
| 112 |
+
print("=" * 72)
|
| 113 |
+
print("TurboQuant Memory Benchmark — Qwen2.5-32B Configuration")
|
| 114 |
+
print(" 64 layers, 8 KV heads, head_dim=128, 4-bit, skip layers {0,1}")
|
| 115 |
+
print("=" * 72)
|
| 116 |
+
|
| 117 |
+
benchmark_memory()
|
| 118 |
+
|
| 119 |
+
print("\n" + "=" * 72)
|
| 120 |
+
print("TurboQuant Throughput Benchmark (single layer)")
|
| 121 |
+
print("=" * 72)
|
| 122 |
+
|
| 123 |
+
benchmark_throughput()
|
scripts/benchmark_models.py
ADDED
|
@@ -0,0 +1,400 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
Comprehensive TurboQuant benchmark across model families and sizes.
|
| 3 |
+
Tests: Qwen, Llama, Gemma, Phi, Mistral — 7B to 72B.
|
| 4 |
+
|
| 5 |
+
For each model:
|
| 6 |
+
1. Architecture analysis (layers, heads, KV heads, head_dim)
|
| 7 |
+
2. Outlier layer detection (key norm distribution)
|
| 8 |
+
3. Output quality (greedy decode comparison)
|
| 9 |
+
4. Memory savings at multiple context lengths
|
| 10 |
+
5. Prefill logit fidelity
|
| 11 |
+
"""
|
| 12 |
+
|
| 13 |
+
import sys
|
| 14 |
+
sys.path.insert(0, "/home/azureuser/turboquant")
|
| 15 |
+
|
| 16 |
+
import torch
|
| 17 |
+
import time
|
| 18 |
+
import json
|
| 19 |
+
import gc
|
| 20 |
+
import os
|
| 21 |
+
from pathlib import Path
|
| 22 |
+
from transformers import AutoModelForCausalLM, AutoTokenizer, BitsAndBytesConfig
|
| 23 |
+
from turboquant.cache import TurboQuantCache
|
| 24 |
+
|
| 25 |
+
RESULTS_FILE = "/home/azureuser/turboquant/benchmark_results.json"
|
| 26 |
+
|
| 27 |
+
MODELS = [
|
| 28 |
+
# (name, hf_id, approx_4bit_size_gb)
|
| 29 |
+
("Qwen2.5-7B", "Qwen/Qwen2.5-7B-Instruct", 5),
|
| 30 |
+
("Llama-3.1-8B", "meta-llama/Llama-3.1-8B-Instruct", 5),
|
| 31 |
+
("Gemma-2-9B", "google/gemma-2-9b-it", 6),
|
| 32 |
+
("Phi-4-14B", "microsoft/phi-4", 9),
|
| 33 |
+
("Qwen2.5-32B", "Qwen/Qwen2.5-32B-Instruct", 19),
|
| 34 |
+
("Llama-3.3-70B", "meta-llama/Llama-3.3-70B-Instruct", 38),
|
| 35 |
+
("Qwen2.5-72B", "Qwen/Qwen2.5-72B-Instruct", 40),
|
| 36 |
+
]
|
| 37 |
+
|
| 38 |
+
PROMPTS = [
|
| 39 |
+
"Explain quantum computing in simple terms.",
|
| 40 |
+
"Write a Python function to check if a number is prime.",
|
| 41 |
+
"What causes the northern lights?",
|
| 42 |
+
]
|
| 43 |
+
|
| 44 |
+
CONTEXT_LENGTHS = [1024, 4096, 8192]
|
| 45 |
+
|
| 46 |
+
PASSAGE = (
|
| 47 |
+
"The history of artificial intelligence began in antiquity, with myths, stories "
|
| 48 |
+
"and rumors of artificial beings endowed with intelligence or consciousness by "
|
| 49 |
+
"master craftsmen. The seeds of modern AI were planted by philosophers who attempted "
|
| 50 |
+
"to describe the process of human thinking as the mechanical manipulation of symbols. "
|
| 51 |
+
"This work culminated in the invention of the programmable digital computer in the 1940s, "
|
| 52 |
+
"a machine based on the abstract essence of mathematical reasoning. "
|
| 53 |
+
)
|
| 54 |
+
|
| 55 |
+
|
| 56 |
+
def cleanup_model():
|
| 57 |
+
"""Free GPU memory between model tests."""
|
| 58 |
+
gc.collect()
|
| 59 |
+
torch.cuda.empty_cache()
|
| 60 |
+
torch.cuda.reset_peak_memory_stats()
|
| 61 |
+
|
| 62 |
+
|
| 63 |
+
def load_model(model_id):
|
| 64 |
+
"""Load model in 4-bit with bitsandbytes."""
|
| 65 |
+
tokenizer = AutoTokenizer.from_pretrained(model_id, trust_remote_code=True)
|
| 66 |
+
model = AutoModelForCausalLM.from_pretrained(
|
| 67 |
+
model_id,
|
| 68 |
+
device_map="auto",
|
| 69 |
+
trust_remote_code=True,
|
| 70 |
+
dtype=torch.bfloat16,
|
| 71 |
+
quantization_config=BitsAndBytesConfig(
|
| 72 |
+
load_in_4bit=True,
|
| 73 |
+
bnb_4bit_compute_dtype=torch.bfloat16,
|
| 74 |
+
bnb_4bit_quant_type="nf4",
|
| 75 |
+
),
|
| 76 |
+
)
|
| 77 |
+
return model, tokenizer
|
| 78 |
+
|
| 79 |
+
|
| 80 |
+
def get_architecture_info(model, config):
|
| 81 |
+
"""Extract architecture details."""
|
| 82 |
+
tc = config.get_text_config(decoder=True) if hasattr(config, "get_text_config") else config
|
| 83 |
+
info = {
|
| 84 |
+
"num_layers": getattr(tc, "num_hidden_layers", None),
|
| 85 |
+
"hidden_size": getattr(tc, "hidden_size", None),
|
| 86 |
+
"num_attention_heads": getattr(tc, "num_attention_heads", None),
|
| 87 |
+
"num_kv_heads": getattr(tc, "num_key_value_heads", getattr(tc, "num_attention_heads", None)),
|
| 88 |
+
"head_dim": None,
|
| 89 |
+
"model_type": getattr(tc, "model_type", "unknown"),
|
| 90 |
+
"max_position_embeddings": getattr(tc, "max_position_embeddings", None),
|
| 91 |
+
"rope_theta": getattr(tc, "rope_theta", None),
|
| 92 |
+
"torch_dtype": str(getattr(tc, "torch_dtype", "unknown")),
|
| 93 |
+
}
|
| 94 |
+
# Some models (Gemma-2) have explicit head_dim different from hidden_size/num_heads
|
| 95 |
+
info["head_dim"] = getattr(tc, "head_dim", None)
|
| 96 |
+
if info["head_dim"] is None and info["hidden_size"] and info["num_attention_heads"]:
|
| 97 |
+
info["head_dim"] = info["hidden_size"] // info["num_attention_heads"]
|
| 98 |
+
info["model_memory_gb"] = torch.cuda.memory_allocated() / 1024**3
|
| 99 |
+
return info
|
| 100 |
+
|
| 101 |
+
|
| 102 |
+
def analyze_layer_norms(model, tokenizer):
|
| 103 |
+
"""Run calibration to find outlier layer norms."""
|
| 104 |
+
inputs = tokenizer("The quick brown fox jumps over the lazy dog.", return_tensors="pt").to(model.device)
|
| 105 |
+
with torch.no_grad():
|
| 106 |
+
out = model(inputs.input_ids, use_cache=True)
|
| 107 |
+
|
| 108 |
+
cache = out.past_key_values
|
| 109 |
+
norms = []
|
| 110 |
+
for i in range(len(cache.layers)):
|
| 111 |
+
k = cache.layers[i].keys
|
| 112 |
+
if k is not None and k.numel() > 0:
|
| 113 |
+
norms.append(round(k.float().norm(dim=-1).mean().item(), 2))
|
| 114 |
+
else:
|
| 115 |
+
norms.append(0.0)
|
| 116 |
+
|
| 117 |
+
median_norm = sorted(norms)[len(norms) // 2]
|
| 118 |
+
outlier_layers = [i for i, n in enumerate(norms) if n > 5.0 * median_norm]
|
| 119 |
+
max_norm = max(norms)
|
| 120 |
+
max_layer = norms.index(max_norm)
|
| 121 |
+
|
| 122 |
+
del out, cache
|
| 123 |
+
cleanup_model()
|
| 124 |
+
|
| 125 |
+
return {
|
| 126 |
+
"median_norm": round(median_norm, 2),
|
| 127 |
+
"max_norm": round(max_norm, 2),
|
| 128 |
+
"max_norm_layer": max_layer,
|
| 129 |
+
"max_to_median_ratio": round(max_norm / median_norm, 2) if median_norm > 0 else 0,
|
| 130 |
+
"outlier_layers": outlier_layers,
|
| 131 |
+
"all_norms_first5": norms[:5],
|
| 132 |
+
"all_norms_last3": norms[-3:],
|
| 133 |
+
}
|
| 134 |
+
|
| 135 |
+
|
| 136 |
+
def test_output_quality(model, tokenizer, skip_layers):
|
| 137 |
+
"""Compare outputs on test prompts."""
|
| 138 |
+
results = []
|
| 139 |
+
for prompt in PROMPTS:
|
| 140 |
+
inputs = tokenizer(prompt, return_tensors="pt").to(model.device)
|
| 141 |
+
n_input = inputs.input_ids.shape[1]
|
| 142 |
+
|
| 143 |
+
with torch.no_grad():
|
| 144 |
+
out_d = model.generate(**inputs, max_new_tokens=100, do_sample=False)
|
| 145 |
+
text_d = tokenizer.decode(out_d[0][n_input:], skip_special_tokens=True)
|
| 146 |
+
cleanup_model()
|
| 147 |
+
|
| 148 |
+
cache = TurboQuantCache(model.config, nbits=4, residual_length=128,
|
| 149 |
+
device="cuda", skip_layers=skip_layers)
|
| 150 |
+
with torch.no_grad():
|
| 151 |
+
out_t = model.generate(**inputs, max_new_tokens=100, do_sample=False,
|
| 152 |
+
past_key_values=cache)
|
| 153 |
+
text_t = tokenizer.decode(out_t[0][n_input:], skip_special_tokens=True)
|
| 154 |
+
cleanup_model()
|
| 155 |
+
|
| 156 |
+
# Find divergence
|
| 157 |
+
diverge = min(len(text_d), len(text_t))
|
| 158 |
+
for i, (a, b) in enumerate(zip(text_d, text_t)):
|
| 159 |
+
if a != b:
|
| 160 |
+
diverge = i
|
| 161 |
+
break
|
| 162 |
+
|
| 163 |
+
# Token-level match
|
| 164 |
+
toks_d = tokenizer.encode(text_d)
|
| 165 |
+
toks_t = tokenizer.encode(text_t)
|
| 166 |
+
matching = sum(a == b for a, b in zip(toks_d, toks_t))
|
| 167 |
+
total = max(len(toks_d), len(toks_t))
|
| 168 |
+
|
| 169 |
+
results.append({
|
| 170 |
+
"prompt": prompt,
|
| 171 |
+
"exact_match": text_d == text_t,
|
| 172 |
+
"diverge_at_char": diverge,
|
| 173 |
+
"total_chars": len(text_d),
|
| 174 |
+
"token_match_pct": round(100 * matching / total, 1) if total > 0 else 100,
|
| 175 |
+
"default_output": text_d[:200],
|
| 176 |
+
"turboquant_output": text_t[:200],
|
| 177 |
+
"both_coherent": True, # Manual check flag
|
| 178 |
+
})
|
| 179 |
+
|
| 180 |
+
return results
|
| 181 |
+
|
| 182 |
+
|
| 183 |
+
def test_memory_savings(model, tokenizer, skip_layers, arch_info):
|
| 184 |
+
"""Measure memory at different context lengths."""
|
| 185 |
+
results = []
|
| 186 |
+
|
| 187 |
+
for target_ctx in CONTEXT_LENGTHS:
|
| 188 |
+
n_repeats = target_ctx // len(tokenizer.encode(PASSAGE)) + 1
|
| 189 |
+
long_prompt = PASSAGE * n_repeats + "\n\nSummarize the above in 2 sentences."
|
| 190 |
+
inputs = tokenizer(long_prompt, return_tensors="pt", truncation=True,
|
| 191 |
+
max_length=target_ctx).to(model.device)
|
| 192 |
+
actual_len = inputs.input_ids.shape[1]
|
| 193 |
+
|
| 194 |
+
# Default
|
| 195 |
+
cleanup_model()
|
| 196 |
+
torch.cuda.reset_peak_memory_stats()
|
| 197 |
+
with torch.no_grad():
|
| 198 |
+
out_d = model.generate(**inputs, max_new_tokens=30, do_sample=False)
|
| 199 |
+
peak_d = torch.cuda.max_memory_allocated()
|
| 200 |
+
text_d = tokenizer.decode(out_d[0][actual_len:], skip_special_tokens=True)
|
| 201 |
+
cleanup_model()
|
| 202 |
+
|
| 203 |
+
# TurboQuant
|
| 204 |
+
cache = TurboQuantCache(model.config, nbits=4, residual_length=128,
|
| 205 |
+
device="cuda", skip_layers=skip_layers)
|
| 206 |
+
torch.cuda.reset_peak_memory_stats()
|
| 207 |
+
with torch.no_grad():
|
| 208 |
+
out_t = model.generate(**inputs, max_new_tokens=30, do_sample=False,
|
| 209 |
+
past_key_values=cache)
|
| 210 |
+
peak_t = torch.cuda.max_memory_allocated()
|
| 211 |
+
text_t = tokenizer.decode(out_t[0][actual_len:], skip_special_tokens=True)
|
| 212 |
+
cleanup_model()
|
| 213 |
+
|
| 214 |
+
saved_mb = (peak_d - peak_t) / 1024**2
|
| 215 |
+
|
| 216 |
+
results.append({
|
| 217 |
+
"context_length": actual_len,
|
| 218 |
+
"peak_default_gb": round(peak_d / 1024**3, 2),
|
| 219 |
+
"peak_turboquant_gb": round(peak_t / 1024**3, 2),
|
| 220 |
+
"saved_mb": round(saved_mb, 0),
|
| 221 |
+
"output_match": text_d[:100] == text_t[:100],
|
| 222 |
+
})
|
| 223 |
+
|
| 224 |
+
return results
|
| 225 |
+
|
| 226 |
+
|
| 227 |
+
def test_prefill_logits(model, tokenizer, skip_layers):
|
| 228 |
+
"""Compare prefill logits (should be near-identical since first call returns originals)."""
|
| 229 |
+
prompt = "The meaning of life is"
|
| 230 |
+
inputs = tokenizer(prompt, return_tensors="pt").to(model.device)
|
| 231 |
+
|
| 232 |
+
with torch.no_grad():
|
| 233 |
+
out_d = model(inputs.input_ids, use_cache=True)
|
| 234 |
+
logits_d = out_d.logits[0, -1].float()
|
| 235 |
+
cleanup_model()
|
| 236 |
+
|
| 237 |
+
cache = TurboQuantCache(model.config, nbits=4, residual_length=128,
|
| 238 |
+
device="cuda", skip_layers=skip_layers)
|
| 239 |
+
out_t = model(inputs.input_ids, use_cache=True, past_key_values=cache)
|
| 240 |
+
logits_t = out_t.logits[0, -1].float()
|
| 241 |
+
cleanup_model()
|
| 242 |
+
|
| 243 |
+
diff = (logits_d - logits_t).abs()
|
| 244 |
+
top1_d = logits_d.argmax().item()
|
| 245 |
+
top1_t = logits_t.argmax().item()
|
| 246 |
+
|
| 247 |
+
return {
|
| 248 |
+
"max_logit_diff": round(diff.max().item(), 6),
|
| 249 |
+
"mean_logit_diff": round(diff.mean().item(), 6),
|
| 250 |
+
"same_top1": top1_d == top1_t,
|
| 251 |
+
"top1_token": tokenizer.decode([top1_d]),
|
| 252 |
+
}
|
| 253 |
+
|
| 254 |
+
|
| 255 |
+
def benchmark_model(model_name, model_id, approx_size):
|
| 256 |
+
"""Run full benchmark for one model."""
|
| 257 |
+
print(f"\n{'='*70}")
|
| 258 |
+
print(f" BENCHMARKING: {model_name} ({model_id})")
|
| 259 |
+
print(f"{'='*70}")
|
| 260 |
+
|
| 261 |
+
# Check disk space
|
| 262 |
+
import shutil
|
| 263 |
+
free_gb = shutil.disk_usage("/").free / 1024**3
|
| 264 |
+
if free_gb < approx_size + 10:
|
| 265 |
+
print(f" SKIP: Only {free_gb:.0f}GB free, need ~{approx_size+10}GB")
|
| 266 |
+
return None
|
| 267 |
+
|
| 268 |
+
result = {"model_name": model_name, "model_id": model_id}
|
| 269 |
+
|
| 270 |
+
try:
|
| 271 |
+
# Load
|
| 272 |
+
print(f" Loading model...")
|
| 273 |
+
model, tokenizer = load_model(model_id)
|
| 274 |
+
print(f" Loaded: {torch.cuda.memory_allocated()/1024**3:.1f} GB on GPU")
|
| 275 |
+
|
| 276 |
+
# Architecture
|
| 277 |
+
print(f" Analyzing architecture...")
|
| 278 |
+
result["architecture"] = get_architecture_info(model, model.config)
|
| 279 |
+
print(f" Layers={result['architecture']['num_layers']}, "
|
| 280 |
+
f"KV heads={result['architecture']['num_kv_heads']}, "
|
| 281 |
+
f"head_dim={result['architecture']['head_dim']}")
|
| 282 |
+
|
| 283 |
+
# Check head_dim compatibility
|
| 284 |
+
head_dim = result["architecture"]["head_dim"]
|
| 285 |
+
if head_dim is None or head_dim % 2 != 0:
|
| 286 |
+
print(f" SKIP: Unsupported head_dim={head_dim}")
|
| 287 |
+
del model, tokenizer
|
| 288 |
+
cleanup_model()
|
| 289 |
+
return result
|
| 290 |
+
|
| 291 |
+
# Layer norms
|
| 292 |
+
print(f" Analyzing layer norms...")
|
| 293 |
+
result["layer_norms"] = analyze_layer_norms(model, tokenizer)
|
| 294 |
+
skip = set(result["layer_norms"]["outlier_layers"])
|
| 295 |
+
print(f" Median={result['layer_norms']['median_norm']}, "
|
| 296 |
+
f"Max={result['layer_norms']['max_norm']} (layer {result['layer_norms']['max_norm_layer']}), "
|
| 297 |
+
f"Ratio={result['layer_norms']['max_to_median_ratio']}x, "
|
| 298 |
+
f"Skip layers={skip}")
|
| 299 |
+
|
| 300 |
+
# Prefill logits
|
| 301 |
+
print(f" Testing prefill logit fidelity...")
|
| 302 |
+
result["prefill_logits"] = test_prefill_logits(model, tokenizer, skip)
|
| 303 |
+
print(f" Max diff={result['prefill_logits']['max_logit_diff']}, "
|
| 304 |
+
f"Same top-1={result['prefill_logits']['same_top1']}")
|
| 305 |
+
|
| 306 |
+
# Output quality
|
| 307 |
+
print(f" Testing output quality ({len(PROMPTS)} prompts)...")
|
| 308 |
+
result["quality"] = test_output_quality(model, tokenizer, skip)
|
| 309 |
+
for q in result["quality"]:
|
| 310 |
+
print(f" '{q['prompt'][:40]}...' → diverge@{q['diverge_at_char']}, "
|
| 311 |
+
f"tokens={q['token_match_pct']}%")
|
| 312 |
+
|
| 313 |
+
# Memory
|
| 314 |
+
print(f" Testing memory savings...")
|
| 315 |
+
result["memory"] = test_memory_savings(model, tokenizer, skip, result["architecture"])
|
| 316 |
+
for m in result["memory"]:
|
| 317 |
+
print(f" {m['context_length']}tok: "
|
| 318 |
+
f"{m['peak_default_gb']}GB → {m['peak_turboquant_gb']}GB "
|
| 319 |
+
f"(saved {m['saved_mb']}MB)")
|
| 320 |
+
|
| 321 |
+
result["status"] = "success"
|
| 322 |
+
|
| 323 |
+
except Exception as e:
|
| 324 |
+
print(f" ERROR: {e}")
|
| 325 |
+
result["status"] = "error"
|
| 326 |
+
result["error"] = str(e)
|
| 327 |
+
|
| 328 |
+
finally:
|
| 329 |
+
# Cleanup
|
| 330 |
+
try:
|
| 331 |
+
del model, tokenizer
|
| 332 |
+
except:
|
| 333 |
+
pass
|
| 334 |
+
cleanup_model()
|
| 335 |
+
# Clear HF cache for this model to save disk
|
| 336 |
+
cache_dir = os.path.expanduser("~/.cache/huggingface/hub")
|
| 337 |
+
print(f" Cleaned up GPU memory")
|
| 338 |
+
|
| 339 |
+
return result
|
| 340 |
+
|
| 341 |
+
|
| 342 |
+
def main():
|
| 343 |
+
all_results = []
|
| 344 |
+
|
| 345 |
+
# Load existing results if any
|
| 346 |
+
if Path(RESULTS_FILE).exists():
|
| 347 |
+
with open(RESULTS_FILE) as f:
|
| 348 |
+
all_results = json.load(f)
|
| 349 |
+
tested = {r["model_id"] for r in all_results if r.get("status") == "success"}
|
| 350 |
+
else:
|
| 351 |
+
tested = set()
|
| 352 |
+
|
| 353 |
+
for model_name, model_id, approx_size in MODELS:
|
| 354 |
+
if model_id in tested:
|
| 355 |
+
print(f"\n SKIP {model_name}: already tested")
|
| 356 |
+
continue
|
| 357 |
+
|
| 358 |
+
result = benchmark_model(model_name, model_id, approx_size)
|
| 359 |
+
if result:
|
| 360 |
+
# Remove any previous failed result for this model
|
| 361 |
+
all_results = [r for r in all_results if r.get("model_id") != model_id]
|
| 362 |
+
all_results.append(result)
|
| 363 |
+
|
| 364 |
+
# Save after each model
|
| 365 |
+
with open(RESULTS_FILE, "w") as f:
|
| 366 |
+
json.dump(all_results, f, indent=2, default=str)
|
| 367 |
+
print(f" Results saved to {RESULTS_FILE}")
|
| 368 |
+
|
| 369 |
+
# Print summary table
|
| 370 |
+
print(f"\n{'='*90}")
|
| 371 |
+
print(f" SUMMARY: TurboQuant Benchmark Results")
|
| 372 |
+
print(f"{'='*90}")
|
| 373 |
+
print(f"{'Model':<20} {'Layers':>6} {'KV/Hd':>6} {'HeadDim':>7} "
|
| 374 |
+
f"{'Outliers':>8} {'Prefill':>8} {'Quality':>8} {'Saved@8K':>10}")
|
| 375 |
+
print("-" * 90)
|
| 376 |
+
|
| 377 |
+
for r in all_results:
|
| 378 |
+
if r.get("status") != "success":
|
| 379 |
+
print(f"{r['model_name']:<20} {'ERROR':>6}")
|
| 380 |
+
continue
|
| 381 |
+
|
| 382 |
+
arch = r["architecture"]
|
| 383 |
+
norms = r["layer_norms"]
|
| 384 |
+
prefill = r["prefill_logits"]
|
| 385 |
+
quality = r["quality"]
|
| 386 |
+
mem = r.get("memory", [])
|
| 387 |
+
|
| 388 |
+
avg_diverge = sum(q["diverge_at_char"] for q in quality) / len(quality) if quality else 0
|
| 389 |
+
saved_8k = next((m["saved_mb"] for m in mem if m["context_length"] >= 8000), "N/A")
|
| 390 |
+
|
| 391 |
+
prefill_str = "exact" if prefill["max_logit_diff"] == 0 else f"{prefill['max_logit_diff']:.4f}"
|
| 392 |
+
saved_str = "N/A" if saved_8k == "N/A" else f"{saved_8k}MB"
|
| 393 |
+
print(f"{r['model_name']:<20} {arch['num_layers']:>6} {arch['num_kv_heads']:>6} "
|
| 394 |
+
f"{arch['head_dim']:>7} {len(norms['outlier_layers']):>8} "
|
| 395 |
+
f"{prefill_str:>8} "
|
| 396 |
+
f"{avg_diverge:>7.0f}ch {saved_str:>10}")
|
| 397 |
+
|
| 398 |
+
|
| 399 |
+
if __name__ == "__main__":
|
| 400 |
+
main()
|
scripts/needle_test.py
ADDED
|
@@ -0,0 +1,143 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
Needle-in-a-Haystack test for TurboQuant.
|
| 3 |
+
|
| 4 |
+
Hides a specific fact in a long document and checks if the model can retrieve it.
|
| 5 |
+
This is the paper's flagship benchmark (0.997 recall at 4x compression).
|
| 6 |
+
"""
|
| 7 |
+
|
| 8 |
+
import sys
|
| 9 |
+
sys.path.insert(0, "/home/azureuser/turboquant")
|
| 10 |
+
|
| 11 |
+
import torch
|
| 12 |
+
import gc
|
| 13 |
+
from transformers import AutoModelForCausalLM, AutoTokenizer, BitsAndBytesConfig
|
| 14 |
+
from turboquant.cache import TurboQuantCache
|
| 15 |
+
|
| 16 |
+
NEEDLE = "The secret code for the treasure chest is BLUE-DRAGON-42."
|
| 17 |
+
|
| 18 |
+
HAYSTACK_UNIT = (
|
| 19 |
+
"The history of artificial intelligence began in antiquity, with myths and stories of "
|
| 20 |
+
"artificial beings endowed with intelligence by master craftsmen. Classical philosophers "
|
| 21 |
+
"attempted to describe the process of human thinking as the mechanical manipulation of "
|
| 22 |
+
"symbols. This work culminated in the invention of the programmable digital computer in "
|
| 23 |
+
"the 1940s. Alan Turing proposed that machines could simulate any conceivable act of "
|
| 24 |
+
"mathematical reasoning. The field of AI research was founded at a workshop at Dartmouth "
|
| 25 |
+
"College in 1956. Early AI programs solved algebra problems, proved theorems, and learned "
|
| 26 |
+
"to speak English. By the mid-1960s, research was heavily funded by the Department of "
|
| 27 |
+
"Defense. In the 1970s, AI faced criticism and funding cuts known as the AI winter. "
|
| 28 |
+
"Expert systems were developed in the 1980s, and neural networks regained popularity. "
|
| 29 |
+
"Deep learning breakthroughs in the 2010s led to dramatic advances in computer vision "
|
| 30 |
+
"and natural language processing. Today, AI powers search engines, recommendation systems, "
|
| 31 |
+
"autonomous vehicles, and language models that can generate human-like text. "
|
| 32 |
+
)
|
| 33 |
+
|
| 34 |
+
QUESTION = "What is the secret code for the treasure chest?"
|
| 35 |
+
|
| 36 |
+
|
| 37 |
+
def build_prompt(context_tokens, tokenizer, needle_position=0.5):
|
| 38 |
+
"""Build a prompt with a needle hidden in a haystack at the given position."""
|
| 39 |
+
# Build haystack
|
| 40 |
+
haystack_tokens = tokenizer.encode(HAYSTACK_UNIT)
|
| 41 |
+
needle_tokens = tokenizer.encode(NEEDLE)
|
| 42 |
+
target_hay_tokens = context_tokens - len(needle_tokens) - 50 # leave room for question
|
| 43 |
+
|
| 44 |
+
n_repeats = target_hay_tokens // len(haystack_tokens) + 1
|
| 45 |
+
full_haystack = HAYSTACK_UNIT * n_repeats
|
| 46 |
+
|
| 47 |
+
# Truncate to target length
|
| 48 |
+
hay_encoded = tokenizer.encode(full_haystack)[:target_hay_tokens]
|
| 49 |
+
|
| 50 |
+
# Insert needle at position
|
| 51 |
+
insert_idx = int(len(hay_encoded) * needle_position)
|
| 52 |
+
combined = hay_encoded[:insert_idx] + needle_tokens + hay_encoded[insert_idx:]
|
| 53 |
+
combined_text = tokenizer.decode(combined)
|
| 54 |
+
|
| 55 |
+
prompt = f"{combined_text}\n\nBased on the text above, answer this question: {QUESTION}"
|
| 56 |
+
return prompt
|
| 57 |
+
|
| 58 |
+
|
| 59 |
+
def test_needle(model, tokenizer, context_length, needle_position=0.5, use_turboquant=False, skip_layers=None):
|
| 60 |
+
"""Run one needle test and check if the model retrieves the answer."""
|
| 61 |
+
prompt = build_prompt(context_length, tokenizer, needle_position)
|
| 62 |
+
inputs = tokenizer(prompt, return_tensors="pt", truncation=True, max_length=context_length).to(model.device)
|
| 63 |
+
actual_len = inputs.input_ids.shape[1]
|
| 64 |
+
|
| 65 |
+
if use_turboquant:
|
| 66 |
+
cache = TurboQuantCache(model.config, nbits=4, residual_length=128,
|
| 67 |
+
device="cuda", skip_layers=skip_layers or set())
|
| 68 |
+
else:
|
| 69 |
+
cache = None
|
| 70 |
+
|
| 71 |
+
with torch.no_grad():
|
| 72 |
+
output = model.generate(
|
| 73 |
+
**inputs, max_new_tokens=50, do_sample=False,
|
| 74 |
+
past_key_values=cache,
|
| 75 |
+
)
|
| 76 |
+
answer = tokenizer.decode(output[0][actual_len:], skip_special_tokens=True)
|
| 77 |
+
|
| 78 |
+
# Check if the needle info is in the answer
|
| 79 |
+
found = "BLUE-DRAGON-42" in answer or "BLUE" in answer and "DRAGON" in answer and "42" in answer
|
| 80 |
+
return {
|
| 81 |
+
"context_length": actual_len,
|
| 82 |
+
"needle_position": needle_position,
|
| 83 |
+
"found": found,
|
| 84 |
+
"answer": answer[:200],
|
| 85 |
+
}
|
| 86 |
+
|
| 87 |
+
|
| 88 |
+
def main():
|
| 89 |
+
model_id = "Qwen/Qwen2.5-7B-Instruct"
|
| 90 |
+
print(f"Loading {model_id}...")
|
| 91 |
+
tokenizer = AutoTokenizer.from_pretrained(model_id, trust_remote_code=True)
|
| 92 |
+
model = AutoModelForCausalLM.from_pretrained(
|
| 93 |
+
model_id, device_map="auto", trust_remote_code=True, dtype=torch.bfloat16,
|
| 94 |
+
quantization_config=BitsAndBytesConfig(
|
| 95 |
+
load_in_4bit=True, bnb_4bit_compute_dtype=torch.bfloat16, bnb_4bit_quant_type="nf4",
|
| 96 |
+
),
|
| 97 |
+
)
|
| 98 |
+
print(f"Loaded: {torch.cuda.memory_allocated()/1024**3:.1f} GB")
|
| 99 |
+
|
| 100 |
+
skip = TurboQuantCache.calibrate_skip_layers(model, tokenizer)
|
| 101 |
+
print(f"Skip layers: {skip}")
|
| 102 |
+
|
| 103 |
+
context_lengths = [1024, 2048, 4096, 8192, 16384]
|
| 104 |
+
positions = [0.25, 0.5, 0.75]
|
| 105 |
+
|
| 106 |
+
print(f"\n{'Context':>8} {'Position':>8} | {'Default':>10} {'TurboQuant':>12} | {'Match':>6}")
|
| 107 |
+
print("-" * 60)
|
| 108 |
+
|
| 109 |
+
total_default = 0
|
| 110 |
+
total_tq = 0
|
| 111 |
+
total_tests = 0
|
| 112 |
+
|
| 113 |
+
for ctx in context_lengths:
|
| 114 |
+
for pos in positions:
|
| 115 |
+
# Default
|
| 116 |
+
r_default = test_needle(model, tokenizer, ctx, pos, use_turboquant=False)
|
| 117 |
+
gc.collect(); torch.cuda.empty_cache()
|
| 118 |
+
|
| 119 |
+
# TurboQuant
|
| 120 |
+
r_tq = test_needle(model, tokenizer, ctx, pos, use_turboquant=True, skip_layers=skip)
|
| 121 |
+
gc.collect(); torch.cuda.empty_cache()
|
| 122 |
+
|
| 123 |
+
match = r_default["found"] == r_tq["found"]
|
| 124 |
+
total_default += r_default["found"]
|
| 125 |
+
total_tq += r_tq["found"]
|
| 126 |
+
total_tests += 1
|
| 127 |
+
|
| 128 |
+
d_str = "FOUND" if r_default["found"] else "MISS"
|
| 129 |
+
t_str = "FOUND" if r_tq["found"] else "MISS"
|
| 130 |
+
m_str = "=" if match else "DIFF"
|
| 131 |
+
|
| 132 |
+
print(f"{r_default['context_length']:>8} {pos:>8.2f} | {d_str:>10} {t_str:>12} | {m_str:>6}")
|
| 133 |
+
|
| 134 |
+
if not r_tq["found"]:
|
| 135 |
+
print(f" TQ answer: {r_tq['answer'][:80]}")
|
| 136 |
+
|
| 137 |
+
print(f"\nResults: Default {total_default}/{total_tests}, TurboQuant {total_tq}/{total_tests}")
|
| 138 |
+
print(f"Default recall: {100*total_default/total_tests:.1f}%")
|
| 139 |
+
print(f"TurboQuant recall: {100*total_tq/total_tests:.1f}%")
|
| 140 |
+
|
| 141 |
+
|
| 142 |
+
if __name__ == "__main__":
|
| 143 |
+
main()
|
scripts/run_inference.py
ADDED
|
@@ -0,0 +1,134 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
TurboQuant inference with Qwen models.
|
| 3 |
+
|
| 4 |
+
Demonstrates TurboQuant KV cache compression as a drop-in replacement
|
| 5 |
+
for the default DynamicCache during model.generate().
|
| 6 |
+
"""
|
| 7 |
+
|
| 8 |
+
import sys
|
| 9 |
+
sys.path.insert(0, "/home/azureuser/turboquant")
|
| 10 |
+
|
| 11 |
+
import argparse
|
| 12 |
+
import time
|
| 13 |
+
import torch
|
| 14 |
+
from transformers import AutoModelForCausalLM, AutoTokenizer, BitsAndBytesConfig
|
| 15 |
+
from turboquant.cache import TurboQuantCache
|
| 16 |
+
|
| 17 |
+
|
| 18 |
+
def load_model(model_name: str, load_in_4bit: bool = True):
|
| 19 |
+
"""Load model and tokenizer."""
|
| 20 |
+
print(f"Loading {model_name}...")
|
| 21 |
+
tokenizer = AutoTokenizer.from_pretrained(model_name, trust_remote_code=True)
|
| 22 |
+
|
| 23 |
+
kwargs = {
|
| 24 |
+
"device_map": "auto",
|
| 25 |
+
"trust_remote_code": True,
|
| 26 |
+
"torch_dtype": torch.bfloat16,
|
| 27 |
+
}
|
| 28 |
+
if load_in_4bit:
|
| 29 |
+
kwargs["quantization_config"] = BitsAndBytesConfig(
|
| 30 |
+
load_in_4bit=True,
|
| 31 |
+
bnb_4bit_compute_dtype=torch.bfloat16,
|
| 32 |
+
bnb_4bit_quant_type="nf4",
|
| 33 |
+
)
|
| 34 |
+
|
| 35 |
+
model = AutoModelForCausalLM.from_pretrained(model_name, **kwargs)
|
| 36 |
+
print(f"Model loaded. Parameters: {sum(p.numel() for p in model.parameters()) / 1e9:.1f}B")
|
| 37 |
+
return model, tokenizer
|
| 38 |
+
|
| 39 |
+
|
| 40 |
+
def generate_with_cache(model, tokenizer, prompt: str, cache_type: str = "turboquant",
|
| 41 |
+
max_new_tokens: int = 100, nbits: int = 4,
|
| 42 |
+
skip_layers: set[int] | None = None):
|
| 43 |
+
"""Generate text using specified cache type."""
|
| 44 |
+
inputs = tokenizer(prompt, return_tensors="pt").to(model.device)
|
| 45 |
+
input_len = inputs.input_ids.shape[1]
|
| 46 |
+
|
| 47 |
+
# Create cache
|
| 48 |
+
if cache_type == "turboquant":
|
| 49 |
+
cache = TurboQuantCache(
|
| 50 |
+
model.config,
|
| 51 |
+
nbits=nbits,
|
| 52 |
+
residual_length=128,
|
| 53 |
+
device=str(model.device),
|
| 54 |
+
skip_layers=skip_layers,
|
| 55 |
+
)
|
| 56 |
+
else:
|
| 57 |
+
cache = None # Use default DynamicCache
|
| 58 |
+
|
| 59 |
+
torch.cuda.reset_peak_memory_stats()
|
| 60 |
+
mem_before = torch.cuda.memory_allocated()
|
| 61 |
+
start = time.time()
|
| 62 |
+
|
| 63 |
+
with torch.no_grad():
|
| 64 |
+
outputs = model.generate(
|
| 65 |
+
**inputs,
|
| 66 |
+
max_new_tokens=max_new_tokens,
|
| 67 |
+
past_key_values=cache,
|
| 68 |
+
do_sample=False,
|
| 69 |
+
)
|
| 70 |
+
|
| 71 |
+
elapsed = time.time() - start
|
| 72 |
+
mem_peak = torch.cuda.max_memory_allocated()
|
| 73 |
+
mem_used = torch.cuda.memory_allocated() - mem_before
|
| 74 |
+
|
| 75 |
+
generated = tokenizer.decode(outputs[0][input_len:], skip_special_tokens=True)
|
| 76 |
+
n_tokens = outputs.shape[1] - input_len
|
| 77 |
+
|
| 78 |
+
print(f"\n Cache: {cache_type}")
|
| 79 |
+
print(f" Tokens: {n_tokens} in {elapsed:.2f}s ({n_tokens/elapsed:.1f} tok/s)")
|
| 80 |
+
print(f" Peak GPU memory: {mem_peak / 1024**3:.2f} GB")
|
| 81 |
+
print(f" Cache memory delta: {mem_used / 1024**2:.1f} MB")
|
| 82 |
+
print(f" Output: {generated[:200]}...")
|
| 83 |
+
|
| 84 |
+
return generated, elapsed, mem_peak
|
| 85 |
+
|
| 86 |
+
|
| 87 |
+
def main():
|
| 88 |
+
parser = argparse.ArgumentParser(description="TurboQuant inference")
|
| 89 |
+
parser.add_argument("--model", default="Qwen/Qwen2.5-1.5B-Instruct",
|
| 90 |
+
help="Model name (default: Qwen2.5-1.5B for testing)")
|
| 91 |
+
parser.add_argument("--prompt", default="Explain quantum computing in simple terms.",
|
| 92 |
+
help="Input prompt")
|
| 93 |
+
parser.add_argument("--max-tokens", type=int, default=100)
|
| 94 |
+
parser.add_argument("--nbits", type=int, default=4, choices=[2, 4])
|
| 95 |
+
parser.add_argument("--no-4bit", action="store_true", help="Load in BF16 instead of 4-bit")
|
| 96 |
+
parser.add_argument("--compare", action="store_true", help="Compare TurboQuant vs default cache")
|
| 97 |
+
args = parser.parse_args()
|
| 98 |
+
|
| 99 |
+
model, tokenizer = load_model(args.model, load_in_4bit=not args.no_4bit)
|
| 100 |
+
|
| 101 |
+
# Auto-calibrate skip layers
|
| 102 |
+
skip = TurboQuantCache.calibrate_skip_layers(model, tokenizer)
|
| 103 |
+
print(f"Auto-detected skip layers: {skip} (kept in BF16 due to outlier KV norms)")
|
| 104 |
+
|
| 105 |
+
if args.compare:
|
| 106 |
+
print("\n" + "=" * 60)
|
| 107 |
+
print("COMPARISON: Default DynamicCache vs TurboQuantCache")
|
| 108 |
+
print("=" * 60)
|
| 109 |
+
|
| 110 |
+
# Default cache
|
| 111 |
+
gen_default, t_default, mem_default = generate_with_cache(
|
| 112 |
+
model, tokenizer, args.prompt, "default", args.max_tokens
|
| 113 |
+
)
|
| 114 |
+
torch.cuda.empty_cache()
|
| 115 |
+
|
| 116 |
+
# TurboQuant cache
|
| 117 |
+
gen_tq, t_tq, mem_tq = generate_with_cache(
|
| 118 |
+
model, tokenizer, args.prompt, "turboquant", args.max_tokens, args.nbits,
|
| 119 |
+
skip_layers=skip,
|
| 120 |
+
)
|
| 121 |
+
|
| 122 |
+
print(f"\n Memory savings: {(mem_default - mem_tq) / 1024**2:.1f} MB "
|
| 123 |
+
f"({mem_default/max(mem_tq, 1):.2f}x)")
|
| 124 |
+
print(f" Outputs match: {gen_default == gen_tq}")
|
| 125 |
+
|
| 126 |
+
else:
|
| 127 |
+
generate_with_cache(
|
| 128 |
+
model, tokenizer, args.prompt, "turboquant", args.max_tokens, args.nbits,
|
| 129 |
+
skip_layers=skip,
|
| 130 |
+
)
|
| 131 |
+
|
| 132 |
+
|
| 133 |
+
if __name__ == "__main__":
|
| 134 |
+
main()
|
scripts/test_cache.py
ADDED
|
@@ -0,0 +1,132 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""Test TurboQuantCache integration with the HF Transformers cache API."""
|
| 2 |
+
|
| 3 |
+
import sys
|
| 4 |
+
sys.path.insert(0, "/home/azureuser/turboquant")
|
| 5 |
+
|
| 6 |
+
import torch
|
| 7 |
+
from types import SimpleNamespace
|
| 8 |
+
from turboquant.cache import TurboQuantCache, TurboQuantLayer
|
| 9 |
+
|
| 10 |
+
|
| 11 |
+
def test_cache_basic():
|
| 12 |
+
"""Test TurboQuantCache with mock model config, simulating Qwen2.5-32B."""
|
| 13 |
+
print("=" * 60)
|
| 14 |
+
print("TEST: TurboQuantCache basic operations")
|
| 15 |
+
print("=" * 60)
|
| 16 |
+
|
| 17 |
+
device = "cuda" if torch.cuda.is_available() else "cpu"
|
| 18 |
+
|
| 19 |
+
# Mock Qwen2.5-32B config (just the fields we need)
|
| 20 |
+
config = SimpleNamespace(
|
| 21 |
+
num_hidden_layers=4, # Use 4 layers for testing (not 64)
|
| 22 |
+
hidden_size=5120,
|
| 23 |
+
num_attention_heads=40,
|
| 24 |
+
)
|
| 25 |
+
# Mock get_text_config for compatibility
|
| 26 |
+
config.get_text_config = lambda decoder=True: config
|
| 27 |
+
|
| 28 |
+
cache = TurboQuantCache(config, nbits=4, residual_length=4, device=device)
|
| 29 |
+
print(f" Created cache with {len(cache.layers)} layers")
|
| 30 |
+
|
| 31 |
+
batch, heads, head_dim = 1, 8, 128
|
| 32 |
+
|
| 33 |
+
# Simulate prefill: 16 tokens at once
|
| 34 |
+
for layer_idx in range(4):
|
| 35 |
+
k = torch.randn(batch, heads, 16, head_dim, device=device, dtype=torch.bfloat16)
|
| 36 |
+
v = torch.randn(batch, heads, 16, head_dim, device=device, dtype=torch.bfloat16)
|
| 37 |
+
|
| 38 |
+
k_out, v_out = cache.update(k, v, layer_idx)
|
| 39 |
+
print(f" Layer {layer_idx} prefill: input ({k.shape}) → output ({k_out.shape})")
|
| 40 |
+
assert k_out.shape == (batch, heads, 16, head_dim)
|
| 41 |
+
assert k_out.dtype == torch.bfloat16
|
| 42 |
+
|
| 43 |
+
# Simulate decode: 1 token at a time, 8 steps
|
| 44 |
+
for step in range(8):
|
| 45 |
+
for layer_idx in range(4):
|
| 46 |
+
k = torch.randn(batch, heads, 1, head_dim, device=device, dtype=torch.bfloat16)
|
| 47 |
+
v = torch.randn(batch, heads, 1, head_dim, device=device, dtype=torch.bfloat16)
|
| 48 |
+
|
| 49 |
+
k_out, v_out = cache.update(k, v, layer_idx)
|
| 50 |
+
|
| 51 |
+
expected_len = 16 + step + 1
|
| 52 |
+
assert k_out.shape == (batch, heads, expected_len, head_dim), \
|
| 53 |
+
f"Expected seq_len={expected_len}, got {k_out.shape[-2]}"
|
| 54 |
+
assert k_out.dtype == torch.bfloat16
|
| 55 |
+
|
| 56 |
+
if step == 0 or step == 7:
|
| 57 |
+
print(f" Decode step {step}: seq_len={k_out.shape[-2]}")
|
| 58 |
+
|
| 59 |
+
# Check sequence length
|
| 60 |
+
seq_len = cache.get_seq_length(0)
|
| 61 |
+
print(f" Final seq_length: {seq_len}")
|
| 62 |
+
|
| 63 |
+
print("\n PASS: Cache operations correct\n")
|
| 64 |
+
|
| 65 |
+
|
| 66 |
+
def test_cache_memory():
|
| 67 |
+
"""Compare memory usage: DynamicCache vs TurboQuantCache."""
|
| 68 |
+
from transformers.cache_utils import DynamicCache
|
| 69 |
+
|
| 70 |
+
print("=" * 60)
|
| 71 |
+
print("TEST: Memory comparison vs DynamicCache")
|
| 72 |
+
print("=" * 60)
|
| 73 |
+
|
| 74 |
+
device = "cuda"
|
| 75 |
+
if not torch.cuda.is_available():
|
| 76 |
+
print(" SKIP: No CUDA available")
|
| 77 |
+
return
|
| 78 |
+
|
| 79 |
+
config = SimpleNamespace(
|
| 80 |
+
num_hidden_layers=64,
|
| 81 |
+
hidden_size=5120,
|
| 82 |
+
num_attention_heads=40,
|
| 83 |
+
)
|
| 84 |
+
config.get_text_config = lambda decoder=True: config
|
| 85 |
+
|
| 86 |
+
batch, heads, head_dim = 1, 8, 128
|
| 87 |
+
seq_len = 4096
|
| 88 |
+
|
| 89 |
+
# --- DynamicCache (BF16 baseline) ---
|
| 90 |
+
torch.cuda.reset_peak_memory_stats()
|
| 91 |
+
torch.cuda.empty_cache()
|
| 92 |
+
mem_before = torch.cuda.memory_allocated()
|
| 93 |
+
|
| 94 |
+
dyn_cache = DynamicCache()
|
| 95 |
+
for layer_idx in range(64):
|
| 96 |
+
k = torch.randn(batch, heads, seq_len, head_dim, device=device, dtype=torch.bfloat16)
|
| 97 |
+
v = torch.randn(batch, heads, seq_len, head_dim, device=device, dtype=torch.bfloat16)
|
| 98 |
+
dyn_cache.update(k, v, layer_idx)
|
| 99 |
+
|
| 100 |
+
mem_dynamic = torch.cuda.memory_allocated() - mem_before
|
| 101 |
+
del dyn_cache
|
| 102 |
+
torch.cuda.empty_cache()
|
| 103 |
+
|
| 104 |
+
# --- TurboQuantCache (4-bit) ---
|
| 105 |
+
torch.cuda.reset_peak_memory_stats()
|
| 106 |
+
mem_before = torch.cuda.memory_allocated()
|
| 107 |
+
|
| 108 |
+
tq_cache = TurboQuantCache(config, nbits=4, residual_length=1, device=device)
|
| 109 |
+
for layer_idx in range(64):
|
| 110 |
+
k = torch.randn(batch, heads, seq_len, head_dim, device=device, dtype=torch.bfloat16)
|
| 111 |
+
v = torch.randn(batch, heads, seq_len, head_dim, device=device, dtype=torch.bfloat16)
|
| 112 |
+
tq_cache.update(k, v, layer_idx)
|
| 113 |
+
|
| 114 |
+
mem_turboquant = torch.cuda.memory_allocated() - mem_before
|
| 115 |
+
del tq_cache
|
| 116 |
+
torch.cuda.empty_cache()
|
| 117 |
+
|
| 118 |
+
ratio = mem_dynamic / max(mem_turboquant, 1)
|
| 119 |
+
print(f" Seq length: {seq_len}")
|
| 120 |
+
print(f" Layers: 64")
|
| 121 |
+
print(f" DynamicCache: {mem_dynamic / 1024**2:.1f} MB")
|
| 122 |
+
print(f" TurboQuantCache: {mem_turboquant / 1024**2:.1f} MB")
|
| 123 |
+
print(f" Compression: {ratio:.2f}x")
|
| 124 |
+
print(f"\n PASS: Memory comparison done\n")
|
| 125 |
+
|
| 126 |
+
|
| 127 |
+
if __name__ == "__main__":
|
| 128 |
+
test_cache_basic()
|
| 129 |
+
test_cache_memory()
|
| 130 |
+
print("=" * 60)
|
| 131 |
+
print("ALL CACHE TESTS PASSED")
|
| 132 |
+
print("=" * 60)
|
scripts/verify.py
ADDED
|
@@ -0,0 +1,198 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
Verification tests for TurboQuant implementation.
|
| 3 |
+
|
| 4 |
+
1. Codebook: Lloyd-Max centroids match paper's distortion bounds
|
| 5 |
+
2. Packing: uint4 pack/unpack round-trip
|
| 6 |
+
3. Quantizer: MSE on random unit vectors ≤ paper's bound (0.009 at 4-bit)
|
| 7 |
+
4. Fixed-point: double quantization stability
|
| 8 |
+
"""
|
| 9 |
+
|
| 10 |
+
import sys
|
| 11 |
+
sys.path.insert(0, "/home/azureuser/turboquant")
|
| 12 |
+
|
| 13 |
+
import torch
|
| 14 |
+
import numpy as np
|
| 15 |
+
|
| 16 |
+
def test_codebook():
|
| 17 |
+
"""Verify Lloyd-Max codebook computation and distortion bounds."""
|
| 18 |
+
from turboquant.codebook import compute_lloyd_max_codebook, compute_distortion
|
| 19 |
+
|
| 20 |
+
print("=" * 60)
|
| 21 |
+
print("TEST: Codebook computation")
|
| 22 |
+
print("=" * 60)
|
| 23 |
+
|
| 24 |
+
d = 128
|
| 25 |
+
# Paper bounds: D_mse ≤ (√3·π/2) · (1/4^b)
|
| 26 |
+
# Per-coordinate: D_mse / d = (√3·π / 2d) · (1/4^b)
|
| 27 |
+
paper_total_mse = {2: 0.117, 3: 0.03, 4: 0.009}
|
| 28 |
+
|
| 29 |
+
for bits in [2, 3, 4]:
|
| 30 |
+
centroids, boundaries = compute_lloyd_max_codebook(d, bits)
|
| 31 |
+
per_coord_mse = compute_distortion(d, bits, centroids, boundaries)
|
| 32 |
+
total_mse = d * per_coord_mse
|
| 33 |
+
bound = (np.sqrt(3) * np.pi / 2) * (1 / 4**bits)
|
| 34 |
+
|
| 35 |
+
print(f"\n b={bits} ({2**bits} levels):")
|
| 36 |
+
print(f" Centroids: {centroids[:4]} ... {centroids[-4:]}")
|
| 37 |
+
print(f" Per-coord MSE: {per_coord_mse:.6e}")
|
| 38 |
+
print(f" Total MSE (d×per): {total_mse:.6f}")
|
| 39 |
+
print(f" Paper bound: {bound:.6f}")
|
| 40 |
+
print(f" Paper table value: {paper_total_mse.get(bits, 'N/A')}")
|
| 41 |
+
print(f" Within bound: {total_mse <= bound * 1.01}") # 1% tolerance for numerics
|
| 42 |
+
|
| 43 |
+
print("\n PASS: Codebook computation verified\n")
|
| 44 |
+
|
| 45 |
+
|
| 46 |
+
def test_packing():
|
| 47 |
+
"""Verify uint4 and uint2 pack/unpack round-trip."""
|
| 48 |
+
from turboquant.packing import pack_uint4, unpack_uint4, pack_uint2, unpack_uint2
|
| 49 |
+
|
| 50 |
+
print("=" * 60)
|
| 51 |
+
print("TEST: Bit packing round-trip")
|
| 52 |
+
print("=" * 60)
|
| 53 |
+
|
| 54 |
+
# uint4
|
| 55 |
+
x4 = torch.randint(0, 16, (4, 8, 128), dtype=torch.uint8)
|
| 56 |
+
packed4 = pack_uint4(x4)
|
| 57 |
+
unpacked4 = unpack_uint4(packed4)
|
| 58 |
+
assert torch.equal(x4, unpacked4), "uint4 round-trip FAILED"
|
| 59 |
+
print(f" uint4: {x4.shape} → {packed4.shape} → {unpacked4.shape} ✓")
|
| 60 |
+
|
| 61 |
+
# uint2
|
| 62 |
+
x2 = torch.randint(0, 4, (4, 8, 128), dtype=torch.uint8)
|
| 63 |
+
packed2 = pack_uint2(x2)
|
| 64 |
+
unpacked2 = unpack_uint2(packed2)
|
| 65 |
+
assert torch.equal(x2, unpacked2), "uint2 round-trip FAILED"
|
| 66 |
+
print(f" uint2: {x2.shape} → {packed2.shape} → {unpacked2.shape} ✓")
|
| 67 |
+
|
| 68 |
+
print("\n PASS: Packing round-trip verified\n")
|
| 69 |
+
|
| 70 |
+
|
| 71 |
+
def test_quantizer_mse():
|
| 72 |
+
"""Verify quantize→dequantize MSE matches paper's theoretical bounds."""
|
| 73 |
+
from turboquant.quantizer import TurboQuantizer
|
| 74 |
+
|
| 75 |
+
print("=" * 60)
|
| 76 |
+
print("TEST: Quantizer MSE on random unit vectors")
|
| 77 |
+
print("=" * 60)
|
| 78 |
+
|
| 79 |
+
device = "cuda" if torch.cuda.is_available() else "cpu"
|
| 80 |
+
dim = 128
|
| 81 |
+
n_vectors = 10000
|
| 82 |
+
paper_bounds = {2: 0.117, 4: 0.009}
|
| 83 |
+
|
| 84 |
+
for bits in [2, 4]:
|
| 85 |
+
quantizer = TurboQuantizer(dim=dim, bits=bits, device=device, seed=42)
|
| 86 |
+
|
| 87 |
+
# Generate random unit vectors on S^(d-1)
|
| 88 |
+
x = torch.randn(n_vectors, dim, device=device)
|
| 89 |
+
x = x / x.norm(dim=-1, keepdim=True)
|
| 90 |
+
x_bf16 = x.bfloat16()
|
| 91 |
+
|
| 92 |
+
# Quantize and dequantize
|
| 93 |
+
packed, norms = quantizer.quantize(x_bf16)
|
| 94 |
+
x_recon = quantizer.dequantize(packed, norms)
|
| 95 |
+
|
| 96 |
+
# Compute MSE
|
| 97 |
+
mse = (x_bf16.float() - x_recon.float()).pow(2).sum(dim=-1).mean().item()
|
| 98 |
+
bound = paper_bounds[bits]
|
| 99 |
+
|
| 100 |
+
print(f"\n b={bits}:")
|
| 101 |
+
print(f" Vectors tested: {n_vectors}")
|
| 102 |
+
print(f" Empirical MSE: {mse:.6f}")
|
| 103 |
+
print(f" Paper bound: {bound:.6f}")
|
| 104 |
+
print(f" Ratio (emp/bnd): {mse/bound:.3f}")
|
| 105 |
+
print(f" Within bound: {mse <= bound * 1.1}") # 10% tolerance
|
| 106 |
+
|
| 107 |
+
# Also check individual vector MSE distribution
|
| 108 |
+
per_vec_mse = (x_bf16.float() - x_recon.float()).pow(2).sum(dim=-1)
|
| 109 |
+
print(f" MSE p50/p95/max: {per_vec_mse.median():.6f} / "
|
| 110 |
+
f"{per_vec_mse.quantile(0.95):.6f} / {per_vec_mse.max():.6f}")
|
| 111 |
+
|
| 112 |
+
print("\n PASS: MSE within theoretical bounds\n")
|
| 113 |
+
|
| 114 |
+
|
| 115 |
+
def test_quantizer_shapes():
|
| 116 |
+
"""Verify correct tensor shapes through quantize/dequantize."""
|
| 117 |
+
from turboquant.quantizer import TurboQuantizer
|
| 118 |
+
|
| 119 |
+
print("=" * 60)
|
| 120 |
+
print("TEST: Tensor shapes (simulating KV cache)")
|
| 121 |
+
print("=" * 60)
|
| 122 |
+
|
| 123 |
+
device = "cuda" if torch.cuda.is_available() else "cpu"
|
| 124 |
+
dim = 128
|
| 125 |
+
quantizer = TurboQuantizer(dim=dim, bits=4, device=device, seed=0)
|
| 126 |
+
|
| 127 |
+
# Simulate KV cache tensor: (batch, heads, seq_len, head_dim)
|
| 128 |
+
batch, heads, seq_len = 2, 8, 1024
|
| 129 |
+
x = torch.randn(batch, heads, seq_len, dim, device=device, dtype=torch.bfloat16)
|
| 130 |
+
|
| 131 |
+
packed, norms = quantizer.quantize(x)
|
| 132 |
+
x_recon = quantizer.dequantize(packed, norms)
|
| 133 |
+
|
| 134 |
+
print(f" Input: {x.shape} {x.dtype}")
|
| 135 |
+
print(f" Packed: {packed.shape} {packed.dtype}")
|
| 136 |
+
print(f" Norms: {norms.shape} {norms.dtype}")
|
| 137 |
+
print(f" Recon: {x_recon.shape} {x_recon.dtype}")
|
| 138 |
+
print(f" Shape match: {x.shape == x_recon.shape}")
|
| 139 |
+
print(f" Dtype match: {x.dtype == x_recon.dtype}")
|
| 140 |
+
|
| 141 |
+
# Memory savings
|
| 142 |
+
original_bytes = x.numel() * 2 # BF16 = 2 bytes
|
| 143 |
+
quant_bytes = packed.numel() * 1 + norms.numel() * 2 # uint8 + BF16 norms
|
| 144 |
+
ratio = original_bytes / quant_bytes
|
| 145 |
+
print(f"\n Original: {original_bytes / 1024:.1f} KB")
|
| 146 |
+
print(f" Quantized: {quant_bytes / 1024:.1f} KB")
|
| 147 |
+
print(f" Compression: {ratio:.2f}x")
|
| 148 |
+
|
| 149 |
+
assert x.shape == x_recon.shape, "Shape mismatch!"
|
| 150 |
+
assert x.dtype == x_recon.dtype, "Dtype mismatch!"
|
| 151 |
+
print("\n PASS: Shapes and dtypes correct\n")
|
| 152 |
+
|
| 153 |
+
|
| 154 |
+
def test_fixed_point():
|
| 155 |
+
"""Verify that quantize→dequantize→requantize→dequantize is stable."""
|
| 156 |
+
from turboquant.quantizer import TurboQuantizer
|
| 157 |
+
|
| 158 |
+
print("=" * 60)
|
| 159 |
+
print("TEST: Double quantization stability (fixed-point)")
|
| 160 |
+
print("=" * 60)
|
| 161 |
+
|
| 162 |
+
device = "cuda" if torch.cuda.is_available() else "cpu"
|
| 163 |
+
quantizer = TurboQuantizer(dim=128, bits=4, device=device, seed=42)
|
| 164 |
+
|
| 165 |
+
x = torch.randn(100, 128, device=device, dtype=torch.bfloat16)
|
| 166 |
+
|
| 167 |
+
# First round
|
| 168 |
+
packed1, norms1 = quantizer.quantize(x)
|
| 169 |
+
x_recon1 = quantizer.dequantize(packed1, norms1)
|
| 170 |
+
|
| 171 |
+
# Second round (re-quantize the reconstruction)
|
| 172 |
+
packed2, norms2 = quantizer.quantize(x_recon1)
|
| 173 |
+
x_recon2 = quantizer.dequantize(packed2, norms2)
|
| 174 |
+
|
| 175 |
+
# Check packed indices are identical
|
| 176 |
+
indices_match = torch.equal(packed1, packed2)
|
| 177 |
+
recon_diff = (x_recon1.float() - x_recon2.float()).abs().max().item()
|
| 178 |
+
|
| 179 |
+
print(f" Packed indices identical: {indices_match}")
|
| 180 |
+
print(f" Max reconstruction diff: {recon_diff:.2e}")
|
| 181 |
+
print(f" Norm diff (max): {(norms1.float() - norms2.float()).abs().max().item():.2e}")
|
| 182 |
+
|
| 183 |
+
if not indices_match:
|
| 184 |
+
n_diff = (packed1 != packed2).sum().item()
|
| 185 |
+
print(f" WARNING: {n_diff} packed bytes differ (FP rounding at boundaries)")
|
| 186 |
+
|
| 187 |
+
print("\n PASS: Double quantization stable\n")
|
| 188 |
+
|
| 189 |
+
|
| 190 |
+
if __name__ == "__main__":
|
| 191 |
+
test_codebook()
|
| 192 |
+
test_packing()
|
| 193 |
+
test_quantizer_mse()
|
| 194 |
+
test_quantizer_shapes()
|
| 195 |
+
test_fixed_point()
|
| 196 |
+
print("=" * 60)
|
| 197 |
+
print("ALL TESTS PASSED")
|
| 198 |
+
print("=" * 60)
|
setup.py
ADDED
|
@@ -0,0 +1,29 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from setuptools import setup, find_packages
|
| 2 |
+
|
| 3 |
+
setup(
|
| 4 |
+
name="turboquant",
|
| 5 |
+
version="0.1.0",
|
| 6 |
+
description="First open-source implementation of TurboQuant (arXiv 2504.19874) for LLM KV cache compression",
|
| 7 |
+
long_description=open("README.md").read(),
|
| 8 |
+
long_description_content_type="text/markdown",
|
| 9 |
+
author="Vivek Varikuti",
|
| 10 |
+
url="https://github.com/vivekvarikuti/turboquant",
|
| 11 |
+
packages=find_packages(),
|
| 12 |
+
python_requires=">=3.10",
|
| 13 |
+
install_requires=[
|
| 14 |
+
"torch>=2.0",
|
| 15 |
+
"scipy>=1.10",
|
| 16 |
+
"transformers>=4.43",
|
| 17 |
+
],
|
| 18 |
+
extras_require={
|
| 19 |
+
"dev": ["pytest"],
|
| 20 |
+
"bnb": ["bitsandbytes", "accelerate"],
|
| 21 |
+
},
|
| 22 |
+
classifiers=[
|
| 23 |
+
"Development Status :: 3 - Alpha",
|
| 24 |
+
"Intended Audience :: Science/Research",
|
| 25 |
+
"License :: OSI Approved :: MIT License",
|
| 26 |
+
"Programming Language :: Python :: 3",
|
| 27 |
+
"Topic :: Scientific/Engineering :: Artificial Intelligence",
|
| 28 |
+
],
|
| 29 |
+
)
|
turboquant/__init__.py
ADDED
|
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from .quantizer import TurboQuantizer
|
| 2 |
+
from .cache import TurboQuantLayer, TurboQuantCache
|
| 3 |
+
from .codebook import compute_lloyd_max_codebook, get_codebook
|
turboquant/cache.py
ADDED
|
@@ -0,0 +1,139 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
TurboQuant KV cache integration with HuggingFace Transformers.
|
| 3 |
+
|
| 4 |
+
TurboQuantLayer extends QuantizedLayer, implementing _quantize() and _dequantize()
|
| 5 |
+
with TurboQuant's random rotation + optimal scalar quantization.
|
| 6 |
+
|
| 7 |
+
TurboQuantCache is a Cache container that creates TurboQuantLayer instances.
|
| 8 |
+
"""
|
| 9 |
+
|
| 10 |
+
import torch
|
| 11 |
+
from transformers.cache_utils import QuantizedLayer, DynamicLayer, Cache
|
| 12 |
+
from transformers import PreTrainedConfig
|
| 13 |
+
|
| 14 |
+
from .quantizer import TurboQuantizer
|
| 15 |
+
|
| 16 |
+
|
| 17 |
+
class TurboQuantLayer(QuantizedLayer):
|
| 18 |
+
"""A single layer's quantized KV cache using TurboQuant.
|
| 19 |
+
|
| 20 |
+
Each layer has its own TurboQuantizer (with its own rotation matrix Π),
|
| 21 |
+
providing statistical independence between layers.
|
| 22 |
+
"""
|
| 23 |
+
|
| 24 |
+
def __init__(
|
| 25 |
+
self,
|
| 26 |
+
dim: int = 128,
|
| 27 |
+
nbits: int = 4,
|
| 28 |
+
residual_length: int = 128,
|
| 29 |
+
device: str = "cuda",
|
| 30 |
+
layer_seed: int | None = None,
|
| 31 |
+
):
|
| 32 |
+
super().__init__(
|
| 33 |
+
nbits=nbits,
|
| 34 |
+
axis_key=0,
|
| 35 |
+
axis_value=0,
|
| 36 |
+
q_group_size=dim,
|
| 37 |
+
residual_length=residual_length,
|
| 38 |
+
)
|
| 39 |
+
self.quantizer = TurboQuantizer(dim=dim, bits=nbits, device=device, seed=layer_seed)
|
| 40 |
+
|
| 41 |
+
def _quantize(self, tensor: torch.Tensor, axis: int) -> tuple[torch.Tensor, torch.Tensor]:
|
| 42 |
+
packed, norms = self.quantizer.quantize(tensor)
|
| 43 |
+
return (packed, norms)
|
| 44 |
+
|
| 45 |
+
def _dequantize(self, q_tensor: tuple[torch.Tensor, torch.Tensor]) -> torch.Tensor:
|
| 46 |
+
packed, norms = q_tensor
|
| 47 |
+
return self.quantizer.dequantize(packed, norms)
|
| 48 |
+
|
| 49 |
+
|
| 50 |
+
class TurboQuantCache(Cache):
|
| 51 |
+
"""KV cache using TurboQuant compression.
|
| 52 |
+
|
| 53 |
+
Drop-in replacement for DynamicCache. Compresses KV cache ~4x at 4-bit
|
| 54 |
+
with near-zero quality loss, using random rotation + optimal scalar quantization.
|
| 55 |
+
|
| 56 |
+
Some transformer layers (especially layer 0) have anomalously large KV norms.
|
| 57 |
+
The `skip_layers` parameter keeps these in full BF16 to preserve quality.
|
| 58 |
+
A calibration pass can auto-detect which layers to skip.
|
| 59 |
+
|
| 60 |
+
Usage:
|
| 61 |
+
cache = TurboQuantCache(model.config, nbits=4)
|
| 62 |
+
output = model.generate(input_ids, past_key_values=cache)
|
| 63 |
+
"""
|
| 64 |
+
|
| 65 |
+
def __init__(
|
| 66 |
+
self,
|
| 67 |
+
config: PreTrainedConfig,
|
| 68 |
+
nbits: int = 4,
|
| 69 |
+
residual_length: int = 128,
|
| 70 |
+
device: str = "cuda",
|
| 71 |
+
base_seed: int = 42,
|
| 72 |
+
skip_layers: set[int] | None = None,
|
| 73 |
+
):
|
| 74 |
+
"""
|
| 75 |
+
Args:
|
| 76 |
+
config: Model config (needs num_hidden_layers and hidden_size/num_attention_heads).
|
| 77 |
+
nbits: Bits per coordinate (2 or 4).
|
| 78 |
+
residual_length: Number of recent tokens kept in full precision before quantizing.
|
| 79 |
+
device: Target device.
|
| 80 |
+
base_seed: Base seed for rotation matrices. Layer i uses seed = base_seed + i.
|
| 81 |
+
skip_layers: Layer indices to keep in full precision (no quantization).
|
| 82 |
+
Set to {0} to skip layer 0 which often has outlier key norms.
|
| 83 |
+
If None, defaults to {0} as a safe default.
|
| 84 |
+
"""
|
| 85 |
+
text_config = config.get_text_config(decoder=True) if hasattr(config, "get_text_config") else config
|
| 86 |
+
num_layers = text_config.num_hidden_layers
|
| 87 |
+
# Some models (e.g., Gemma-2) have explicit head_dim that differs from hidden_size/num_heads
|
| 88 |
+
head_dim = getattr(text_config, "head_dim", None) or (text_config.hidden_size // text_config.num_attention_heads)
|
| 89 |
+
|
| 90 |
+
if skip_layers is None:
|
| 91 |
+
skip_layers = {0} # Layer 0 typically has outlier key norms
|
| 92 |
+
|
| 93 |
+
layers = []
|
| 94 |
+
for i in range(num_layers):
|
| 95 |
+
if i in skip_layers:
|
| 96 |
+
layers.append(DynamicLayer())
|
| 97 |
+
else:
|
| 98 |
+
layers.append(
|
| 99 |
+
TurboQuantLayer(
|
| 100 |
+
dim=head_dim,
|
| 101 |
+
nbits=nbits,
|
| 102 |
+
residual_length=residual_length,
|
| 103 |
+
device=device,
|
| 104 |
+
layer_seed=base_seed + i,
|
| 105 |
+
)
|
| 106 |
+
)
|
| 107 |
+
super().__init__(layers=layers)
|
| 108 |
+
|
| 109 |
+
@staticmethod
|
| 110 |
+
def calibrate_skip_layers(
|
| 111 |
+
model,
|
| 112 |
+
tokenizer,
|
| 113 |
+
calibration_text: str = "The quick brown fox jumps over the lazy dog.",
|
| 114 |
+
norm_threshold: float = 5.0,
|
| 115 |
+
) -> set[int]:
|
| 116 |
+
"""Auto-detect which layers have outlier KV norms and should skip quantization.
|
| 117 |
+
|
| 118 |
+
Runs a single forward pass and identifies layers where key norms exceed
|
| 119 |
+
`norm_threshold` times the median key norm across all layers.
|
| 120 |
+
|
| 121 |
+
Returns:
|
| 122 |
+
Set of layer indices to skip.
|
| 123 |
+
"""
|
| 124 |
+
inputs = tokenizer(calibration_text, return_tensors="pt").to(model.device)
|
| 125 |
+
with torch.no_grad():
|
| 126 |
+
out = model(inputs.input_ids, use_cache=True)
|
| 127 |
+
|
| 128 |
+
cache = out.past_key_values
|
| 129 |
+
norms = []
|
| 130 |
+
for i in range(len(cache.layers)):
|
| 131 |
+
k = cache.layers[i].keys
|
| 132 |
+
if k is not None and k.numel() > 0:
|
| 133 |
+
norms.append(k.float().norm(dim=-1).mean().item())
|
| 134 |
+
else:
|
| 135 |
+
norms.append(0.0)
|
| 136 |
+
|
| 137 |
+
median_norm = sorted(norms)[len(norms) // 2]
|
| 138 |
+
skip = {i for i, n in enumerate(norms) if n > norm_threshold * median_norm}
|
| 139 |
+
return skip
|
turboquant/codebook.py
ADDED
|
@@ -0,0 +1,127 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
Lloyd-Max optimal scalar quantizer for the Beta distribution arising from
|
| 3 |
+
random rotation of unit vectors on S^(d-1).
|
| 4 |
+
|
| 5 |
+
After random rotation, each coordinate follows:
|
| 6 |
+
f(x) = C * (1 - x^2)^((d-3)/2) on [-1, 1]
|
| 7 |
+
|
| 8 |
+
For d=128 this is very close to N(0, 1/128).
|
| 9 |
+
|
| 10 |
+
We solve the continuous k-means (Lloyd-Max) problem to find optimal centroids
|
| 11 |
+
and boundaries for a given bit-width b (2^b quantization levels).
|
| 12 |
+
"""
|
| 13 |
+
|
| 14 |
+
import numpy as np
|
| 15 |
+
from scipy import integrate
|
| 16 |
+
from scipy.special import gammaln
|
| 17 |
+
import torch
|
| 18 |
+
|
| 19 |
+
# Precomputed codebooks keyed by (dim, bits)
|
| 20 |
+
_CODEBOOK_CACHE = {}
|
| 21 |
+
|
| 22 |
+
|
| 23 |
+
def _beta_pdf(x: np.ndarray, d: int) -> np.ndarray:
|
| 24 |
+
"""Probability density for a coordinate of a uniformly random unit vector in R^d.
|
| 25 |
+
|
| 26 |
+
f(x) = Gamma(d/2) / (sqrt(pi) * Gamma((d-1)/2)) * (1 - x^2)^((d-3)/2)
|
| 27 |
+
"""
|
| 28 |
+
if np.any(np.abs(x) >= 1.0):
|
| 29 |
+
result = np.zeros_like(x, dtype=float)
|
| 30 |
+
mask = np.abs(x) < 1.0
|
| 31 |
+
if np.any(mask):
|
| 32 |
+
log_norm = gammaln(d / 2) - 0.5 * np.log(np.pi) - gammaln((d - 1) / 2)
|
| 33 |
+
result[mask] = np.exp(log_norm + ((d - 3) / 2) * np.log(1 - x[mask] ** 2))
|
| 34 |
+
return result
|
| 35 |
+
log_norm = gammaln(d / 2) - 0.5 * np.log(np.pi) - gammaln((d - 1) / 2)
|
| 36 |
+
return np.exp(log_norm + ((d - 3) / 2) * np.log(1 - x**2))
|
| 37 |
+
|
| 38 |
+
|
| 39 |
+
def _integrate(f, a: float, b: float) -> float:
|
| 40 |
+
"""Numerically integrate f from a to b using scipy.integrate.quad."""
|
| 41 |
+
result, _ = integrate.quad(f, a, b, limit=100)
|
| 42 |
+
return result
|
| 43 |
+
|
| 44 |
+
|
| 45 |
+
def compute_lloyd_max_codebook(
|
| 46 |
+
d: int, bits: int, max_iter: int = 1000, tol: float = 1e-10
|
| 47 |
+
) -> tuple[np.ndarray, np.ndarray]:
|
| 48 |
+
"""Compute optimal Lloyd-Max centroids and boundaries for the Beta distribution.
|
| 49 |
+
|
| 50 |
+
Args:
|
| 51 |
+
d: Dimension of the vectors (determines the Beta distribution shape).
|
| 52 |
+
bits: Number of bits per coordinate (2^bits quantization levels).
|
| 53 |
+
max_iter: Maximum Lloyd-Max iterations.
|
| 54 |
+
tol: Convergence tolerance on centroid change.
|
| 55 |
+
|
| 56 |
+
Returns:
|
| 57 |
+
(centroids, boundaries) where:
|
| 58 |
+
centroids: array of 2^bits values in [-1, 1]
|
| 59 |
+
boundaries: array of 2^bits - 1 values (midpoints between centroids)
|
| 60 |
+
"""
|
| 61 |
+
n_levels = 2**bits
|
| 62 |
+
pdf = lambda x: _beta_pdf(np.atleast_1d(np.array(x, dtype=float)), d).item()
|
| 63 |
+
|
| 64 |
+
# Initialize centroids uniformly in the support region
|
| 65 |
+
# For d=128, most mass is in [-0.3, 0.3], but we span [-1, 1]
|
| 66 |
+
centroids = np.linspace(-0.99, 0.99, n_levels)
|
| 67 |
+
|
| 68 |
+
for iteration in range(max_iter):
|
| 69 |
+
# E-step: boundaries are midpoints between adjacent centroids
|
| 70 |
+
boundaries = (centroids[:-1] + centroids[1:]) / 2.0
|
| 71 |
+
|
| 72 |
+
# M-step: update centroids as conditional means
|
| 73 |
+
# Full boundaries: -1, b1, b2, ..., b_{n-1}, 1
|
| 74 |
+
full_bounds = np.concatenate([[-1.0], boundaries, [1.0]])
|
| 75 |
+
new_centroids = np.zeros(n_levels)
|
| 76 |
+
|
| 77 |
+
for i in range(n_levels):
|
| 78 |
+
lo, hi = full_bounds[i], full_bounds[i + 1]
|
| 79 |
+
mass = _integrate(pdf, lo, hi)
|
| 80 |
+
if mass > 1e-15:
|
| 81 |
+
mean = _integrate(lambda x: x * pdf(x), lo, hi)
|
| 82 |
+
new_centroids[i] = mean / mass
|
| 83 |
+
else:
|
| 84 |
+
# Keep old centroid if interval has negligible mass
|
| 85 |
+
new_centroids[i] = centroids[i]
|
| 86 |
+
|
| 87 |
+
# Check convergence
|
| 88 |
+
delta = np.max(np.abs(new_centroids - centroids))
|
| 89 |
+
centroids = new_centroids
|
| 90 |
+
if delta < tol:
|
| 91 |
+
break
|
| 92 |
+
|
| 93 |
+
# Final boundaries
|
| 94 |
+
boundaries = (centroids[:-1] + centroids[1:]) / 2.0
|
| 95 |
+
return centroids, boundaries
|
| 96 |
+
|
| 97 |
+
|
| 98 |
+
def compute_distortion(d: int, bits: int, centroids: np.ndarray, boundaries: np.ndarray) -> float:
|
| 99 |
+
"""Compute per-coordinate MSE distortion for the given codebook."""
|
| 100 |
+
pdf = lambda x: _beta_pdf(np.atleast_1d(np.array(x, dtype=float)), d).item()
|
| 101 |
+
full_bounds = np.concatenate([[-1.0], boundaries, [1.0]])
|
| 102 |
+
|
| 103 |
+
total_mse = 0.0
|
| 104 |
+
for i in range(len(centroids)):
|
| 105 |
+
lo, hi = full_bounds[i], full_bounds[i + 1]
|
| 106 |
+
c = centroids[i]
|
| 107 |
+
mse_i = _integrate(lambda x: (x - c) ** 2 * pdf(x), lo, hi)
|
| 108 |
+
total_mse += mse_i
|
| 109 |
+
|
| 110 |
+
return total_mse
|
| 111 |
+
|
| 112 |
+
|
| 113 |
+
def get_codebook(d: int, bits: int, device: str = "cpu") -> tuple[torch.Tensor, torch.Tensor]:
|
| 114 |
+
"""Get precomputed codebook as torch tensors. Cached after first computation.
|
| 115 |
+
|
| 116 |
+
Returns:
|
| 117 |
+
(centroids, boundaries) as float32 tensors on the given device.
|
| 118 |
+
"""
|
| 119 |
+
key = (d, bits)
|
| 120 |
+
if key not in _CODEBOOK_CACHE:
|
| 121 |
+
centroids_np, boundaries_np = compute_lloyd_max_codebook(d, bits)
|
| 122 |
+
_CODEBOOK_CACHE[key] = (centroids_np, boundaries_np)
|
| 123 |
+
|
| 124 |
+
centroids_np, boundaries_np = _CODEBOOK_CACHE[key]
|
| 125 |
+
centroids = torch.tensor(centroids_np, dtype=torch.float32, device=device)
|
| 126 |
+
boundaries = torch.tensor(boundaries_np, dtype=torch.float32, device=device)
|
| 127 |
+
return centroids, boundaries
|
turboquant/packing.py
ADDED
|
@@ -0,0 +1,77 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
Bit packing utilities for uint4 and uint2 quantized indices.
|
| 3 |
+
|
| 4 |
+
uint4: 2 values per byte (128 dims → 64 bytes)
|
| 5 |
+
uint2: 4 values per byte (128 dims → 32 bytes)
|
| 6 |
+
"""
|
| 7 |
+
|
| 8 |
+
import torch
|
| 9 |
+
|
| 10 |
+
|
| 11 |
+
def pack_uint4(indices: torch.Tensor) -> torch.Tensor:
|
| 12 |
+
"""Pack uint8 tensor with values 0-15 into uint4 format (2 values per byte).
|
| 13 |
+
|
| 14 |
+
Args:
|
| 15 |
+
indices: uint8 tensor with shape (..., d) where d is even.
|
| 16 |
+
Values must be in [0, 15].
|
| 17 |
+
|
| 18 |
+
Returns:
|
| 19 |
+
uint8 tensor with shape (..., d // 2).
|
| 20 |
+
"""
|
| 21 |
+
assert indices.shape[-1] % 2 == 0, f"Last dim must be even, got {indices.shape[-1]}"
|
| 22 |
+
high = indices[..., 0::2] << 4
|
| 23 |
+
low = indices[..., 1::2]
|
| 24 |
+
return (high | low).to(torch.uint8)
|
| 25 |
+
|
| 26 |
+
|
| 27 |
+
def unpack_uint4(packed: torch.Tensor) -> torch.Tensor:
|
| 28 |
+
"""Unpack uint4 format back to uint8 tensor with values 0-15.
|
| 29 |
+
|
| 30 |
+
Args:
|
| 31 |
+
packed: uint8 tensor with shape (..., d // 2).
|
| 32 |
+
|
| 33 |
+
Returns:
|
| 34 |
+
uint8 tensor with shape (..., d) where d = 2 * packed.shape[-1].
|
| 35 |
+
"""
|
| 36 |
+
high = packed >> 4
|
| 37 |
+
low = packed & 0x0F
|
| 38 |
+
# Interleave: [h0, l0, h1, l1, ...]
|
| 39 |
+
d_half = packed.shape[-1]
|
| 40 |
+
out = torch.stack([high, low], dim=-1) # (..., d_half, 2)
|
| 41 |
+
return out.reshape(*packed.shape[:-1], d_half * 2)
|
| 42 |
+
|
| 43 |
+
|
| 44 |
+
def pack_uint2(indices: torch.Tensor) -> torch.Tensor:
|
| 45 |
+
"""Pack uint8 tensor with values 0-3 into uint2 format (4 values per byte).
|
| 46 |
+
|
| 47 |
+
Args:
|
| 48 |
+
indices: uint8 tensor with shape (..., d) where d is divisible by 4.
|
| 49 |
+
Values must be in [0, 3].
|
| 50 |
+
|
| 51 |
+
Returns:
|
| 52 |
+
uint8 tensor with shape (..., d // 4).
|
| 53 |
+
"""
|
| 54 |
+
assert indices.shape[-1] % 4 == 0, f"Last dim must be divisible by 4, got {indices.shape[-1]}"
|
| 55 |
+
b0 = indices[..., 0::4] << 6
|
| 56 |
+
b1 = indices[..., 1::4] << 4
|
| 57 |
+
b2 = indices[..., 2::4] << 2
|
| 58 |
+
b3 = indices[..., 3::4]
|
| 59 |
+
return (b0 | b1 | b2 | b3).to(torch.uint8)
|
| 60 |
+
|
| 61 |
+
|
| 62 |
+
def unpack_uint2(packed: torch.Tensor) -> torch.Tensor:
|
| 63 |
+
"""Unpack uint2 format back to uint8 tensor with values 0-3.
|
| 64 |
+
|
| 65 |
+
Args:
|
| 66 |
+
packed: uint8 tensor with shape (..., d // 4).
|
| 67 |
+
|
| 68 |
+
Returns:
|
| 69 |
+
uint8 tensor with shape (..., d) where d = 4 * packed.shape[-1].
|
| 70 |
+
"""
|
| 71 |
+
b0 = (packed >> 6) & 0x03
|
| 72 |
+
b1 = (packed >> 4) & 0x03
|
| 73 |
+
b2 = (packed >> 2) & 0x03
|
| 74 |
+
b3 = packed & 0x03
|
| 75 |
+
d_quarter = packed.shape[-1]
|
| 76 |
+
out = torch.stack([b0, b1, b2, b3], dim=-1) # (..., d_quarter, 4)
|
| 77 |
+
return out.reshape(*packed.shape[:-1], d_quarter * 4)
|
turboquant/quantizer.py
ADDED
|
@@ -0,0 +1,117 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
TurboQuantizer: core quantize/dequantize operations.
|
| 3 |
+
|
| 4 |
+
Implements Algorithm 1 (TurboQuant_mse) from the paper:
|
| 5 |
+
1. Random rotation Π (QR decomposition with sign fix)
|
| 6 |
+
2. Scalar quantization using precomputed Lloyd-Max codebook
|
| 7 |
+
3. uint4 bit packing for storage
|
| 8 |
+
"""
|
| 9 |
+
|
| 10 |
+
import torch
|
| 11 |
+
from .codebook import get_codebook
|
| 12 |
+
from .packing import pack_uint4, unpack_uint4, pack_uint2, unpack_uint2
|
| 13 |
+
|
| 14 |
+
|
| 15 |
+
class TurboQuantizer:
|
| 16 |
+
"""Quantizes vectors on the unit hypersphere using random rotation + optimal scalar quantization.
|
| 17 |
+
|
| 18 |
+
Each instance has its own random rotation matrix Π, enabling statistical independence
|
| 19 |
+
when used per-layer.
|
| 20 |
+
"""
|
| 21 |
+
|
| 22 |
+
def __init__(self, dim: int = 128, bits: int = 4, device: str = "cuda", seed: int | None = None):
|
| 23 |
+
"""
|
| 24 |
+
Args:
|
| 25 |
+
dim: Vector dimension (head_dim, typically 128).
|
| 26 |
+
bits: Bits per coordinate (2 or 4).
|
| 27 |
+
device: Target device.
|
| 28 |
+
seed: Optional RNG seed for reproducible rotation matrix.
|
| 29 |
+
"""
|
| 30 |
+
self.dim = dim
|
| 31 |
+
self.bits = bits
|
| 32 |
+
self.device = device
|
| 33 |
+
|
| 34 |
+
# Generate random rotation matrix Π ∈ SO(d) via QR with sign convention
|
| 35 |
+
gen = torch.Generator()
|
| 36 |
+
if seed is not None:
|
| 37 |
+
gen.manual_seed(seed)
|
| 38 |
+
else:
|
| 39 |
+
gen.seed()
|
| 40 |
+
A = torch.randn(dim, dim, generator=gen)
|
| 41 |
+
Q, R = torch.linalg.qr(A)
|
| 42 |
+
# Sign fix: Π = Q * sign(diag(R)) ensures uniform distribution on SO(d)
|
| 43 |
+
self.rotation = (Q * torch.sign(torch.diag(R))).to(torch.float32).to(device)
|
| 44 |
+
|
| 45 |
+
# Load precomputed codebook
|
| 46 |
+
centroids, boundaries = get_codebook(dim, bits, device=device)
|
| 47 |
+
self.centroids = centroids # (2^bits,) float32
|
| 48 |
+
self.boundaries = boundaries # (2^bits - 1,) float32
|
| 49 |
+
|
| 50 |
+
# Choose pack/unpack functions based on bit-width
|
| 51 |
+
if bits == 4:
|
| 52 |
+
self._pack = pack_uint4
|
| 53 |
+
self._unpack = unpack_uint4
|
| 54 |
+
elif bits == 2:
|
| 55 |
+
self._pack = pack_uint2
|
| 56 |
+
self._unpack = unpack_uint2
|
| 57 |
+
else:
|
| 58 |
+
raise ValueError(f"Unsupported bits={bits}. Use 2 or 4.")
|
| 59 |
+
|
| 60 |
+
def quantize(self, x: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]:
|
| 61 |
+
"""Quantize input tensor.
|
| 62 |
+
|
| 63 |
+
Args:
|
| 64 |
+
x: BF16/FP16 tensor of shape (..., dim). Vectors need NOT be unit norm —
|
| 65 |
+
norms are extracted and stored separately.
|
| 66 |
+
|
| 67 |
+
Returns:
|
| 68 |
+
(packed, norms) where:
|
| 69 |
+
packed: uint8 tensor of shape (..., dim // pack_factor)
|
| 70 |
+
norms: BF16 tensor of shape (...,)
|
| 71 |
+
"""
|
| 72 |
+
original_dtype = x.dtype
|
| 73 |
+
# 1. Extract and store norms
|
| 74 |
+
norms = x.float().norm(dim=-1) # (...,)
|
| 75 |
+
|
| 76 |
+
# 2. Normalize to unit sphere (avoid div by zero for zero vectors)
|
| 77 |
+
x_unit = x.float() / norms.unsqueeze(-1).clamp(min=1e-8)
|
| 78 |
+
|
| 79 |
+
# 3. Random rotation in FP32: y = x_unit @ Π^T (equivalent to Π @ x for each vector)
|
| 80 |
+
# x_unit: (..., dim), rotation: (dim, dim)
|
| 81 |
+
# We want each vector rotated: y_i = Π @ x_i, which is x_unit @ Π^T
|
| 82 |
+
x_rot = x_unit @ self.rotation.T # (..., dim) FP32
|
| 83 |
+
|
| 84 |
+
# 4. Scalar quantize: find nearest centroid for each coordinate
|
| 85 |
+
indices = torch.bucketize(x_rot, self.boundaries) # (..., dim) int64
|
| 86 |
+
indices = indices.clamp(0, (2**self.bits) - 1).to(torch.uint8)
|
| 87 |
+
|
| 88 |
+
# 5. Pack
|
| 89 |
+
packed = self._pack(indices)
|
| 90 |
+
|
| 91 |
+
return packed, norms.to(original_dtype)
|
| 92 |
+
|
| 93 |
+
def dequantize(self, packed: torch.Tensor, norms: torch.Tensor) -> torch.Tensor:
|
| 94 |
+
"""Dequantize packed indices back to approximate vectors.
|
| 95 |
+
|
| 96 |
+
Args:
|
| 97 |
+
packed: uint8 tensor from quantize().
|
| 98 |
+
norms: BF16 tensor of norms from quantize().
|
| 99 |
+
|
| 100 |
+
Returns:
|
| 101 |
+
Reconstructed tensor of shape (..., dim) in the same dtype as norms.
|
| 102 |
+
"""
|
| 103 |
+
original_dtype = norms.dtype
|
| 104 |
+
|
| 105 |
+
# 1. Unpack indices
|
| 106 |
+
indices = self._unpack(packed) # (..., dim) uint8
|
| 107 |
+
|
| 108 |
+
# 2. Lookup centroids
|
| 109 |
+
x_rot_approx = self.centroids[indices.long()] # (..., dim) float32
|
| 110 |
+
|
| 111 |
+
# 3. Inverse rotation in FP32: x_approx = x_rot_approx @ Π
|
| 112 |
+
x_unit_approx = x_rot_approx @ self.rotation # (..., dim) FP32
|
| 113 |
+
|
| 114 |
+
# 4. Rescale by stored norms
|
| 115 |
+
x_approx = norms.float().unsqueeze(-1) * x_unit_approx
|
| 116 |
+
|
| 117 |
+
return x_approx.to(original_dtype)
|