"""
TerraMind tokenizer - diffusers-style ModelMixin/ConfigMixin.

Usage:
    from terramind_tokenizer import TerraMindTokenizer

    # Load from local or Hub (diffusers-style)
    tokenizer = TerraMindTokenizer.from_pretrained("BiliSakura/TerraMind-1.0-Tokenizer-S1RTC")
    # or: tokenizer = TerraMindTokenizer.from_pretrained("./path/to/model")

    # Tokenize S1RTC input [B, 2, 256, 256]
    tokens = tokenizer.tokenize(x)
    quant, code_loss, tokens = tokenizer.encode(x)
"""

import inspect
import logging
import os
import sys
from pathlib import Path
from typing import Any, List, Optional, Tuple, Union

import torch

logger = logging.getLogger(__name__)

# Loader argument names that may carry the model location; both spellings are
# checked when searching the call stack for the repository root.
_PATH_ARG_NAMES = ("pretrained_model_name_or_path", "pretrained_model_or_path")

# Hub download options that from_pretrained forwards to snapshot_download.
_HUB_DOWNLOAD_KWARGS = (
    "revision",
    "cache_dir",
    "token",
    "local_files_only",
    "force_download",
    "proxies",
)


def _caller_path_args() -> List[Any]:
    """Collect truthy model-location arguments from the calling stack frames.

    Frames are visited from the immediate caller outward (the same order as
    ``inspect.stack()``), and within each frame the names in ``_PATH_ARG_NAMES``
    are checked in order. Unlike ``inspect.stack()`` this reads no source lines,
    and the frame reference is dropped in ``finally`` so no reference cycle
    through frame objects is left behind.
    """
    found: List[Any] = []
    frame = inspect.currentframe()
    # Skip this helper's own frame; its locals never hold a model path.
    frame = frame.f_back if frame is not None else None
    try:
        while frame is not None:
            frame_locals = frame.f_locals
            for name in _PATH_ARG_NAMES:
                if name in frame_locals:
                    value = frame_locals[name]
                    if value:
                        found.append(value)
            frame = frame.f_back
    finally:
        del frame
    return found


def _has_tokenizer_package(directory: Path) -> bool:
    """Return True if ``directory`` contains a ``_terramind_tokenizer`` entry."""
    return (directory / "_terramind_tokenizer").exists()


def _get_repo_root() -> Optional[Path]:
    """Locate the directory that holds the ``_terramind_tokenizer`` package.

    When loaded via ``trust_remote_code``, diffusers copies this file into its
    module cache, so the helper package may not sit next to it. Lookup order:

    1. The directory containing this file.
    2. A local directory (or the parent of a local file) passed as
       ``pretrained_model_name_or_path`` / ``pretrained_model_or_path`` to any
       function currently on the call stack.
    3. A Hub repo ID from the same arguments (a ``str`` containing ``/`` that is
       not an existing local path), fetched with ``snapshot_download``. The
       snapshot directory is returned without checking its contents.

    Returns:
        The resolved directory, or None if every strategy fails.
    """
    here = Path(__file__).resolve().parent
    if _has_tokenizer_package(here):
        return here

    candidates = _caller_path_args()

    # Local paths take priority over any Hub download.
    for value in candidates:
        path = Path(str(value)).resolve()
        if path.is_dir() and _has_tokenizer_package(path):
            return path
        if path.is_file() and _has_tokenizer_package(path.parent):
            return path.parent

    # Fallback for Hub IDs (e.g. "BiliSakura/TerraMind-1.0-Tokenizer-S1RTC").
    for value in candidates:
        if isinstance(value, str) and "/" in value and not Path(value).exists():
            try:
                from huggingface_hub import snapshot_download

                return Path(snapshot_download(value))
            except Exception as err:
                # Best effort: an import or download failure must not break
                # module import; try the next candidate instead.
                logger.debug(
                    "snapshot_download(%r) failed while locating _terramind_tokenizer: %s",
                    value,
                    err,
                )
    return None


# Make _terramind_tokenizer importable before the project import below.
_repo_root = _get_repo_root()
if _repo_root is not None and str(_repo_root) not in sys.path:
    sys.path.insert(0, str(_repo_root))

from diffusers import ConfigMixin, ModelMixin  # noqa: E402
from diffusers.configuration_utils import register_to_config  # noqa: E402
from safetensors.torch import load_file  # noqa: E402

from _terramind_tokenizer.tokenizer_register import build_vqvae  # noqa: E402


class TerraMindTokenizer(ModelMixin, ConfigMixin):
    """
    TerraMind VQ tokenizer in native diffusers style.

    Subclasses ModelMixin and ConfigMixin for standard from_pretrained/from_config,
    save_pretrained, and config handling. Uses ViT encoder + FSQ quantization.
    """

    config_name = "config.json"

    @register_to_config
    def __init__(
        self,
        model_type: str = "divae",
        image_size: int = 256,
        n_channels: int = 12,
        encoder_type: str = "vit_b_enc",
        decoder_type: Optional[str] = None,
        prediction_type: Optional[str] = None,
        post_mlp: bool = True,
        patch_size: int = 16,
        patch_size_dec: Optional[int] = None,
        quant_type: str = "fsq",
        codebook_size: str = "8-8-8-6-5",
        latent_dim: int = 5,
        clip_sample: Optional[bool] = None,
        auto_map: Optional[dict] = None,  # for diffusers AutoModel, ignored
        **kwargs,
    ):
        """Build the VQ-VAE via ``build_vqvae`` and keep its encoding submodules.

        ``encoder_type``/``decoder_type`` are passed to ``build_vqvae`` as
        ``enc_type``/``dec_type``. ``prediction_type``, ``patch_size_dec``,
        ``clip_sample`` and ``auto_map`` are accepted so config.json keys are
        recorded by ``register_to_config``, but they are not forwarded. Extra
        ``kwargs`` are forwarded, except ``_class_name``, ``model_type`` and
        ``auto_map``.
        """
        super().__init__()
        model = build_vqvae(
            model_type=model_type,
            ckpt_path=None,
            enc_type=encoder_type,
            dec_type=decoder_type,
            image_size=image_size,
            n_channels=n_channels,
            post_mlp=post_mlp,
            patch_size=patch_size,
            quant_type=quant_type,
            codebook_size=codebook_size,
            latent_dim=latent_dim,
            **{k: v for k, v in kwargs.items() if k not in ("_class_name", "model_type", "auto_map")},
        )
        # Expose submodules under their original names so state_dict keys match
        # the saved checkpoint.
        self.encoder = model.encoder
        self.quant_proj = model.quant_proj
        self.quantize = model.quantize
        self.cls_emb = getattr(model, "cls_emb", None)
        self.undo_std = getattr(model, "undo_std", False)

    def prepare_input(self, x: torch.Tensor) -> torch.Tensor:
        """Preprocess input before encoding.

        Raises NotImplementedError if the built model requested ``undo_std``.
        If the model provides ``cls_emb``, it is applied and its channels-last
        output is rearranged to channels-first ("b h w c -> b c h w").
        """
        if self.undo_std:
            raise NotImplementedError("undo_std is not supported in slim tokenizer.")
        if self.cls_emb is not None:
            from einops import rearrange

            x = rearrange(self.cls_emb(x), "b h w c -> b c h w")
        return x

    def encode(
        self, x: torch.Tensor, return_dict: bool = False
    ) -> Union[Tuple[torch.Tensor, torch.Tensor, torch.LongTensor], dict]:
        """
        Encode input to quantized latent and token indices.

        Args:
            x: Input tensor [B, C, H, W] (e.g. S1RTC [B, 2, 256, 256])
            return_dict: If True, return dict with keys quant, code_loss, tokens

        Returns:
            (quant, code_loss, tokens) or dict with those keys
        """
        x = self.prepare_input(x)
        h = self.encoder(x)
        h = self.quant_proj(h)
        quant, code_loss, tokens = self.quantize(h)
        if return_dict:
            return {"quant": quant, "code_loss": code_loss, "tokens": tokens}
        return quant, code_loss, tokens

    def tokenize(self, x: torch.Tensor) -> torch.LongTensor:
        """Tokenize input to discrete indices [B, H_Q, W_Q]."""
        _, _, tokens = self.encode(x)
        return tokens

    def forward(self, sample: torch.Tensor) -> torch.LongTensor:
        """Forward pass (diffusers-style) returns token indices."""
        return self.tokenize(sample)

    @classmethod
    def from_pretrained(
        cls,
        pretrained_model_name_or_path: Union[str, os.PathLike],
        *,
        torch_dtype: Optional[torch.dtype] = None,
        device: Optional[Union[str, torch.device]] = None,
        **kwargs,
    ):
        """
        Load from pretrained path or Hub (diffusers-style).

        Supports both model.safetensors (legacy) and
        diffusion_pytorch_model.safetensors.

        Args:
            pretrained_model_name_or_path: Local directory or Hub repo ID. A
                path that does not exist locally is treated as a Hub repo ID.
            torch_dtype: Optional dtype to cast the model to after loading.
            device: Optional device to move the model to after loading.
            **kwargs: Passed to ``load_config``; leftovers go to ``from_config``.
                Hub options (revision, cache_dir, token, local_files_only,
                force_download, proxies) are also forwarded to
                ``snapshot_download`` when downloading.

        Returns:
            The model in eval mode.

        Raises:
            FileNotFoundError: If config.json or both weight files are missing.
        """
        path = Path(pretrained_model_name_or_path)
        if not path.exists():
            # Resolve from Hub via huggingface_hub, honoring the caller's
            # download options instead of silently using defaults.
            from huggingface_hub import snapshot_download

            hub_kwargs = {k: kwargs[k] for k in _HUB_DOWNLOAD_KWARGS if k in kwargs}
            path = Path(snapshot_download(str(pretrained_model_name_or_path), **hub_kwargs))

        config_path = path / "config.json"
        if not config_path.exists():
            raise FileNotFoundError(f"config.json not found at {config_path}")

        # Load config (diffusers-style).
        config, unused = cls.load_config(str(path), return_unused_kwargs=True, **kwargs)
        config = {k: v for k, v in config.items() if k != "_class_name"}
        model = cls.from_config(config, **unused)

        # Load weights: try the standard diffusers name first, then model.safetensors.
        for weights_name in ("diffusion_pytorch_model.safetensors", "model.safetensors"):
            weights_path = path / weights_name
            if weights_path.exists():
                break
        else:
            raise FileNotFoundError(
                f"No weights found at {path}. Expected diffusion_pytorch_model.safetensors or model.safetensors"
            )

        # strict=False tolerates checkpoint keys this slim model does not hold,
        # but missing keys mean parameters were left at their random init, so
        # report them instead of ignoring the result.
        result = model.load_state_dict(load_file(str(weights_path)), strict=False)
        if result.missing_keys:
            logger.warning(
                "%s: %d parameter(s) missing from %s and left uninitialized: %s",
                cls.__name__,
                len(result.missing_keys),
                weights_path,
                result.missing_keys,
            )
        if result.unexpected_keys:
            logger.debug(
                "%s: ignored %d unexpected key(s) in %s",
                cls.__name__,
                len(result.unexpected_keys),
                weights_path,
            )

        if torch_dtype is not None:
            model = model.to(torch_dtype)
        if device is not None:
            model = model.to(device)
        model.eval()
        return model