DiT-MoE-diffusers

Diffusers implementation of DiT-MoE (Diffusion Transformer with Mixture of Experts) for class-conditional ImageNet generation. Each variant folder is self-contained:

  • pipeline.py: DiTMoEPipeline
  • scheduler/scheduler_config.json: DDIMScheduler (S/B) or DiTMoEFlowMatchScheduler (XL/G)
  • transformer/transformer_dit_moe.py: DiTMoETransformer2DModel
  • vae/: AutoencoderKL (stabilityai/sd-vae-ft-mse)

ImageNet class labels

Each variant keeps an English id2label map directly in its own model_index.json (DiT-style).

  • pipe.id2label: inspect the id → English label correspondence
  • pipe.labels: reverse map (English synonym → id), sorted for browsing
  • pipe.get_label_ids("golden retriever")
  • pipe(class_labels="golden retriever", ...): string labels are resolved automatically
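The lookups listed above can be pictured with a minimal sketch. It assumes that id2label maps integer ids to comma-separated English synonyms (the convention used by the original DiT model cards) and that matching ignores case. The toy dictionary and the helper names build_label_map and resolve_label_ids are hypothetical and are not the pipeline's actual code.

```python
from typing import Dict, List, Sequence, Union

# Toy excerpt for illustration; the real map lives in each variant's model_index.json.
ID2LABEL: Dict[int, str] = {
    207: "golden retriever",
    208: "Labrador retriever",
    281: "tabby, tabby cat",
}


def build_label_map(id2label: Dict[int, str]) -> Dict[str, int]:
    """Build a reverse map (synonym -> id), sorted by synonym for browsing."""
    labels: Dict[str, int] = {}
    for class_id, names in id2label.items():
        # One entry may hold several comma-separated synonyms.
        for name in names.split(","):
            labels[name.strip().lower()] = int(class_id)
    return dict(sorted(labels.items()))


def resolve_label_ids(
    label: Union[str, Sequence[str]], label_map: Dict[str, int]
) -> List[int]:
    """Resolve one label or a sequence of labels to class ids."""
    names = [label] if isinstance(label, str) else list(label)
    ids: List[int] = []
    for name in names:
        key = name.strip().lower()  # case-insensitive matching is an assumption
        if key not in label_map:
            raise ValueError(f"Unknown ImageNet label: {name!r}")
        ids.append(label_map[key])
    return ids


LABELS = build_label_map(ID2LABEL)
print(resolve_label_ids("golden retriever", LABELS))
print(resolve_label_ids(["tabby cat", "Labrador retriever"], LABELS))
```

Raising on an unknown label, rather than silently falling back to some default class, keeps a typo from producing an image of the wrong class.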

Available checkpoints

Checkpoint         Path                                              Resolution  Sampler
DiT-MoE-S/2-8E2A   ./DiT-MoE-S-8E2A                                  256×256     DDIM
DiT-MoE-B/2-8E2A   ./DiT-MoE-B-8E2A                                  256×256     DDIM
DiT-MoE-XL/2-8E2A  ./DiT-MoE-XL-8E2A                                 256×256     RF
DiT-MoE-G/2-16E2A  (convert with --rectified-flow --num-experts 16)  256×256     RF

Convert from official weights

conda activate rsgen
cd libs/DiT-MoE-diffusers

python scripts/convert_dit_moe_to_diffusers.py \
  --checkpoint ../../models/feizhengcong/DiT-MoE/dit_moe_s_8E2A.pt \
  --output ../../models/BiliSakura/DiT-MoE-diffusers/DiT-MoE-S-8E2A \
  --model DiT-S/2 \
  --num-experts 8 \
  --num-experts-per-tok 2 \
  --copy-vae ../../models/feizhengcong/DiT-MoE/sd-vae-ft-mse \
  --check-load
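When converting several variants, it can help to assemble the command programmatically. The sketch below uses only the flags that appear in this README. The helper name build_convert_command, the placeholder paths, and the model name DiT-G/2 are assumptions made for illustration; check the converter's own help output for the accepted values.

```python
import shlex
from typing import List, Optional


def build_convert_command(
    checkpoint: str,
    output: str,
    model: str,
    num_experts: int,
    rectified_flow: bool = False,
    copy_vae: Optional[str] = None,
) -> List[str]:
    """Assemble the converter invocation as an argument list."""
    cmd = [
        "python", "scripts/convert_dit_moe_to_diffusers.py",
        "--checkpoint", checkpoint,
        "--output", output,
        "--model", model,
        "--num-experts", str(num_experts),
        "--num-experts-per-tok", "2",
    ]
    if rectified_flow:
        # Listed in the checkpoint table as a conversion flag for the G variant.
        cmd.append("--rectified-flow")
    if copy_vae is not None:
        cmd += ["--copy-vae", copy_vae]
    cmd.append("--check-load")
    return cmd


# Hypothetical paths and model name for the G variant; adjust to your layout.
g_cmd = build_convert_command(
    checkpoint="path/to/dit_moe_g_16E2A.pt",
    output="path/to/DiT-MoE-G-16E2A",
    model="DiT-G/2",
    num_experts=16,
    rectified_flow=True,
)
print(shlex.join(g_cmd))
```

The resulting list can be passed to subprocess.run(g_cmd, check=True) from the libs/DiT-MoE-diffusers directory.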

Inference

Use torch.bfloat16 on Ampere or newer GPUs. It is the default in the examples and in sample_dit_moe.py.

from pathlib import Path
import torch
from diffusers import DiffusionPipeline

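# Resolve the variant folder to an absolute path so pipeline.py can be located.
model_dir = Path("./DiT-MoE-S-8E2A").resolve()
pipe = DiffusionPipeline.from_pretrained(
    str(model_dir),
    local_files_only=True,  # load from disk only, no Hub download
    custom_pipeline=str(model_dir / "pipeline.py"),  # DiTMoEPipeline shipped with the variant
    trust_remote_code=True,
    torch_dtype=torch.bfloat16,
)
pipe.to("cuda")

# Look up the English label for a class id, then the id(s) for a label.
print(pipe.id2label[207])
print(pipe.get_label_ids("golden retriever"))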

generator = torch.Generator(device="cuda").manual_seed(42)
image = pipe(
    class_labels="golden retriever",
    height=256,
    width=256,
    num_inference_steps=50,
    guidance_scale=4.0,
    generator=generator,
).images[0]
image.save("demo.png")
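The bfloat16 recommendation above can be turned into a small dtype check. Ampere GPUs report CUDA compute capability 8.x, so the sketch keys on the major version. Falling back to float16 on older GPUs is an assumption of this sketch, not something the README specifies, and pick_dtype_name is a hypothetical helper.

```python
def pick_dtype_name(compute_capability_major: int) -> str:
    """Return "bfloat16" for Ampere (8.x) or newer, else "float16" (assumed fallback)."""
    return "bfloat16" if compute_capability_major >= 8 else "float16"


# With PyTorch and a CUDA device available, this could be used as:
#   import torch
#   major, _minor = torch.cuda.get_device_capability()
#   dtype = getattr(torch, pick_dtype_name(major))
#   pipe = DiffusionPipeline.from_pretrained(..., torch_dtype=dtype)
print(pick_dtype_name(8))  # Ampere, e.g. A100
print(pick_dtype_name(7))  # Volta or Turing, e.g. V100 or T4
```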

Citation

@article{FeiDiTMoE2024,
  title={Scaling Diffusion Transformers to 16 Billion Parameters},
  author={Zhengcong Fei and Mingyuan Fan and Changqian Yu and Debang Li and Junshi Huang},
  year={2024},
  journal={arXiv preprint arXiv:2407.11633},
}