# W1-4B dLLM Base
Minimal, torch-only inference package for the **W1-4B diffusion language model**.
- Inference-only -- no training code
- Pure PyTorch -- no `transformer_engine` dependency
- Ships with `gidd`, `jump`, and `standard` samplers
- Uses a single preconverted `.safetensors` checkpoint
## Quick Start
```bash
# 1. Clone the repo (Git LFS is needed to fetch the checkpoint)
git lfs install
git clone https://huggingface.co/WhaletechAI/W1-4B-dLLM-Base
cd W1-4B-dLLM-Base
# 2. Install dependencies
pip install -r requirements.txt
# 3. Run inference
python -m whale4b.sample \
--checkpoint whale3.7Bdiffusion.safetensors \
--prompt "The future of AI is" \
--sampler gidd \
--steps 64 \
--max-new-tokens 128 \
--device cuda \
--dtype bf16
```
## Install
```bash
pip install -r requirements.txt
```
## Get Weights
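The checkpoint is downloaded when you `git clone` the repo, as long as Git LFS is installed. Without Git LFS, the `.safetensors` file in your clone is only a small pointer file, not the weights.
To download only the weights without cloning: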
```bash
pip install huggingface_hub
huggingface-cli download WhaletechAI/W1-4B-dLLM-Base \
whale3.7Bdiffusion.safetensors \
--repo-type model \
--local-dir .
```
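Whichever route you use, it can help to confirm that the file on disk holds real weights rather than a Git LFS pointer. The sketch below is a stdlib-only check built on two documented formats: Git LFS pointer files start with `version https://git-lfs.github.com/spec/v1`, and a `.safetensors` file starts with an 8-byte little-endian header length followed by a JSON header. The helpers `is_lfs_pointer`, `read_safetensors_header`, and `check_checkpoint` are hypothetical and are not part of `whale4b`.

```python
import json
import struct
from pathlib import Path

# Git LFS pointer files begin with this line (per the Git LFS spec).
LFS_POINTER_PREFIX = b"version https://git-lfs.github.com/spec/v1"


def is_lfs_pointer(path: Path) -> bool:
    """Return True if `path` is a Git LFS pointer instead of the real file."""
    with open(path, "rb") as f:
        return f.read(len(LFS_POINTER_PREFIX)) == LFS_POINTER_PREFIX


def read_safetensors_header(path: Path) -> dict:
    """Parse the JSON header of a .safetensors file without loading tensors.

    The file starts with an unsigned 64-bit little-endian length N,
    followed by N bytes of JSON describing each tensor.
    """
    with open(path, "rb") as f:
        (header_len,) = struct.unpack("<Q", f.read(8))
        return json.loads(f.read(header_len))


def check_checkpoint(path: str) -> int:
    """Raise if the checkpoint looks unusable; otherwise return its tensor count."""
    p = Path(path)
    if p.suffix != ".safetensors":
        raise ValueError(f"{p} is not a .safetensors file")
    if is_lfs_pointer(p):
        raise RuntimeError(f"{p} is a Git LFS pointer; run `git lfs pull` to fetch the weights")
    header = read_safetensors_header(p)
    # "__metadata__" is an optional free-form entry, not a tensor.
    return sum(1 for name in header if name != "__metadata__")
```

For example, `check_checkpoint("whale3.7Bdiffusion.safetensors")` should return a tensor count for a complete download and raise for a pointer file. Only the header is read, so the check is fast even for multi-gigabyte checkpoints.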
## Run
```bash
python -m whale4b.sample \
--checkpoint whale3.7Bdiffusion.safetensors \
--prompt "The future of AI is" \
--sampler gidd \
--steps 64 \
--max-new-tokens 128 \
--device cuda \
--dtype bf16
```
To use the `jump` sampler, change only `--sampler`:
```bash
python -m whale4b.sample \
--checkpoint whale3.7Bdiffusion.safetensors \
--prompt "The future of AI is" \
--sampler jump \
--steps 64 \
--max-new-tokens 128 \
--device cuda \
--dtype bf16
```
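Since the commands above differ only in `--sampler`, trying all three samplers on one prompt is easy to script. This is a minimal sketch, assuming `whale4b` is importable from the current Python environment (for example, when run from the repo root after installing the requirements); `sample_command` and `run_all` are hypothetical helpers that use only the flags documented in this README.

```python
import shlex
import subprocess
import sys

SAMPLERS = ("gidd", "jump", "standard")


def sample_command(sampler: str, prompt: str, *, steps: int = 64,
                   max_new_tokens: int = 128, device: str = "cuda",
                   dtype: str = "bf16",
                   checkpoint: str = "whale3.7Bdiffusion.safetensors") -> list[str]:
    """Build the `python -m whale4b.sample` argument list for one sampler."""
    if sampler not in SAMPLERS:
        raise ValueError(f"unknown sampler {sampler!r}; expected one of {SAMPLERS}")
    return [
        sys.executable, "-m", "whale4b.sample",
        "--checkpoint", checkpoint,
        "--prompt", prompt,
        "--sampler", sampler,
        "--steps", str(steps),
        "--max-new-tokens", str(max_new_tokens),
        "--device", device,
        "--dtype", dtype,
    ]


def run_all(prompt: str, **kwargs) -> None:
    """Run every sampler in turn, echoing each command before it runs."""
    for name in SAMPLERS:
        cmd = sample_command(name, prompt, **kwargs)
        print("$", shlex.join(cmd))
        subprocess.run(cmd, check=True)
```

Calling `run_all("The future of AI is")` runs the three samplers one after another with the same settings. Passing arguments as a list (rather than one shell string) avoids quoting problems with prompts that contain spaces or quotes.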
## Python API
```python
from whale4b import generate
text = generate(
checkpoint="whale3.7Bdiffusion.safetensors",
prompt="The future of AI is",
sampler="gidd",
steps=64,
max_new_tokens=128,
device="cuda",
dtype="bf16",
)
print(text)
```
## Parameters Guide
### Device & Precision
| Flag | Recommended | Notes |
|------|-------------|-------|
| `--device cuda` | Default if you have a GPU | Strongly recommended |
| `--device cpu` | Fallback | Works, but very slow |
| `--dtype bf16` | Default for CUDA | Best speed/quality tradeoff |
| `--dtype fp16` | CUDA fallback | Use if your GPU does not support bf16 |
| `--dtype fp32` | CPU or debug | Full precision, slowest |
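The table above can be turned into a small selection rule. The sketch below is a hypothetical helper (`pick_device_and_dtype` is not part of `whale4b`) that uses the standard PyTorch calls `torch.cuda.is_available()` and `torch.cuda.is_bf16_supported()` and returns strings in the form the `--device`/`--dtype` flags and the `device`/`dtype` arguments of `generate` accept.

```python
try:
    import torch
except ImportError:  # without PyTorch, fall through to the CPU choice
    torch = None


def pick_device_and_dtype() -> tuple[str, str]:
    """Choose device/dtype values following the Device & Precision table."""
    if torch is not None and torch.cuda.is_available():
        # Prefer bf16 on CUDA; fall back to fp16 on GPUs without bf16 support.
        if torch.cuda.is_bf16_supported():
            return ("cuda", "bf16")
        return ("cuda", "fp16")
    # No usable GPU: run on CPU in full precision.
    return ("cpu", "fp32")
```

For example, `device, dtype = pick_device_and_dtype()` and then `generate(..., device=device, dtype=dtype)`.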
### Key Arguments
| Argument | Default | Description |
|----------|---------|-------------|
| `--sampler` | `standard` | Sampling algorithm: `gidd`, `jump`, or `standard` |
| `--steps` | `64` | Number of diffusion steps; more steps are slower and typically give higher quality |
| `--max-new-tokens` | `256` | Maximum tokens to generate |
| `--temperature` | `0.0` | Sampling temperature (`0.0` = greedy) |
| `--top-k` | `0` | Top-k filtering (`0` = disabled) |
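To see the `--steps` speed/quality tradeoff on your own hardware, you can time the Python API at several step counts. This is a minimal sketch: `sweep_steps` is a hypothetical helper that takes the generation function as an argument (pass `whale4b.generate`), and the step values are arbitrary examples, not recommendations.

```python
import time
from typing import Callable


def sweep_steps(generate_fn: Callable[..., str], prompt: str,
                steps_values=(16, 32, 64, 128), **kwargs) -> list[tuple[int, float, str]]:
    """Call generate_fn once per step count; return (steps, seconds, text) rows."""
    rows = []
    for steps in steps_values:
        start = time.perf_counter()
        text = generate_fn(prompt=prompt, steps=steps, **kwargs)
        rows.append((steps, time.perf_counter() - start, text))
    return rows


# Example usage (requires the checkpoint and a GPU):
# from whale4b import generate
# for steps, secs, text in sweep_steps(
#         generate, "The future of AI is",
#         checkpoint="whale3.7Bdiffusion.safetensors", sampler="gidd",
#         max_new_tokens=128, device="cuda", dtype="bf16"):
#     print(f"{steps:>4} steps  {secs:6.2f}s  {text[:60]!r}")
```

Taking the function as a parameter keeps the helper independent of `whale4b`, so the same loop works for any sampler or for a stub during testing. The first call may include one-time warm-up cost, so consider discarding it when comparing timings.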
### Limits
- Maximum context length: **4096 tokens** (`max_seq_len` in config)
- Only single-file `.safetensors` checkpoints are supported; sharded checkpoints are not supported yet
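If long prompts are a concern, you can budget `--max-new-tokens` against the 4096-token limit before generating. This sketch assumes the prompt and the generated tokens share a single `max_seq_len` window, and that you already have the prompt's token count from the bundled `whale-tokenizer/` (how to load it is not shown here). `max_new_tokens_for` is a hypothetical helper.

```python
MAX_SEQ_LEN = 4096  # `max_seq_len` from the model config


def max_new_tokens_for(prompt_tokens: int, requested: int,
                       max_seq_len: int = MAX_SEQ_LEN) -> int:
    """Clamp `requested` so prompt + generated tokens fit in max_seq_len.

    Assumes the prompt and the generated tokens share one context window.
    """
    if prompt_tokens >= max_seq_len:
        raise ValueError(f"prompt is {prompt_tokens} tokens; limit is {max_seq_len}")
    return min(requested, max_seq_len - prompt_tokens)
```

For example, a 4000-token prompt leaves room for at most 96 new tokens, so a request for 256 is clamped to 96.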
## Notes
- `whale-tokenizer/` is bundled in this repo and required for generation.
- The model code intentionally excludes training code and `transformer_engine` (TE)-specific code paths.
## License
Apache 2.0 -- see [LICENSE](LICENSE).