madebyollin commited on
Commit
c6edc9d
·
verified ·
1 Parent(s): 159713f

Update README.md

Browse files
Files changed (1) hide show
  1. README.md +73 -3
README.md CHANGED
@@ -1,3 +1,73 @@
1
- ---
2
- license: mit
3
- ---
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ ---
2
+ license: mit
3
+ ---
4
+
5
+ # 🍰 Tiny AutoEncoder for FLUX.2
6
+
7
+ [TAEF2](https://github.com/madebyollin/taesd) is very tiny autoencoder which uses the same "latent API" as FLUX.2's VAE.
8
+ FLUX.2 is useful for real-time previewing of the FLUX.2 generation process, as well as general resource-constrained encoding/decoding.
9
+
10
+ This repo contains `.safetensors` versions of the TAEF2 weights.
11
+
12
+ ## Using in 🧨 diffusers
13
+
14
+ **NOTE**: Unlike TAEF1, TAEF2 isn't officially integrated into Diffusers yet. So for now you'll want some wrapper code:
15
+
16
+ ```bash
17
+ !wget -nc -nv https://raw.githubusercontent.com/madebyollin/taesd/refs/heads/main/taesd.py -O taesd.py
18
+ !wget -nc -nv https://huggingface.co/madebyollin/taef2/resolve/main/diffusion_pytorch_model.safetensors -O taef2.safetensors
19
+ ```
20
+
21
+ ```python
22
+ # Construction
23
+ from taesd import TAESD
24
+ import torch
25
+ import safetensors.torch as stt
26
+ from diffusers.utils.accelerate_utils import apply_forward_hook
27
+
28
+ class DotDict(dict):
29
+ __getattr__ = dict.__getitem__
30
+ __setattr__ = dict.__setitem__
31
+
32
+ class DiffusersTAEF2Wrapper(torch.nn.Module):
33
+ def __init__(self, taesd):
34
+ super().__init__()
35
+ self.dtype = torch.bfloat16
36
+ self.taesd = TAESD(encoder_path=None, decoder_path=None, latent_channels=32, arch_variant="flux_2").to(self.dtype)
37
+ self.taesd.load_state_dict(stt.load_file("taef2.safetensors"))
38
+ self.bn = torch.nn.BatchNorm2d(128, affine=False, eps=0.0) # default bn
39
+ self.config = DotDict(batch_norm_eps=self.bn.eps)
40
+
41
+ @apply_forward_hook
42
+ def encode(self, x):
43
+ return self.taesd.encoder(x.to(self.dtype).mul(0.5).add_(0.5)).to(x.dtype)
44
+
45
+ @apply_forward_hook
46
+ def decode(self, x, return_dict=True):
47
+ x = self.taesd.decoder(x.to(self.dtype)).mul(2).sub_(1).clamp_(-1, 1).to(x.dtype)
48
+ return dict(sample=x) if return_dict else x,
49
+
50
+ taef2_diffusers = DiffusersTAEF2Wrapper().eval().requires_grad_(False)
51
+
52
+ # Usage
53
+ from diffusers import Flux2KleinPipeline
54
+
55
+ device = "cuda"
56
+ dtype = torch.bfloat16
57
+
58
+ pipe = Flux2KleinPipeline.from_pretrained("black-forest-labs/FLUX.2-klein-4B", torch_dtype=dtype)
59
+ pipe.vae = taef2_diffusers
60
+ pipe.enable_sequential_cpu_offload() # pipe.enable_model_cpu_offload() # pipe = pipe.to(device)
61
+
62
+ prompt = "A slice of delicious New York-style berry cheesecake"
63
+ image = pipe(
64
+ prompt=prompt,
65
+ height=1024,
66
+ width=1024,
67
+ guidance_scale=1.0,
68
+ num_inference_steps=4,
69
+ generator=torch.Generator(device="cpu").manual_seed(0)
70
+ ).images[0]
71
+ image.save("flux-klein.png")
72
+ image
73
+ ```