import sys import torch from safetensors.torch import load_file, save_file # Key mapping rules: (source_substring, target_substring) KEY_REMAP = [ (".qkv.", ".attn.qkv."), (".proj.", ".attn.proj."), (".mlp.0.", ".mlp.fc1."), (".mlp.3.", ".mlp.fc2."), ] def remap_key(key: str) -> str: key = key.replace("tag_head.", "head.") for src, dst in KEY_REMAP: if src in key: return key.replace(src, dst) return key def convert_to_timm_bgr( input_path: str, output_path: str, bgr_key: str = "patch_embed.weight", ) -> None: """Convert a safetensors checkpoint to timm layout with BGR channel reorder.""" weights = load_file(input_path) print(f"Converting: {input_path}") converted: dict[str, torch.Tensor] = {} for key, tensor in weights.items(): new_key = remap_key(key) if bgr_key in key: # patch_embed → patch_embed.proj + RGB→BGR channel reorder new_key = new_key.replace("patch_embed.weight", "patch_embed.proj.weight") tensor = tensor[:, [2, 1, 0], :, :] elif "patch_embed.bias" in key: new_key = new_key.replace("patch_embed.bias", "patch_embed.proj.bias") converted[new_key] = tensor save_file(converted, output_path) print(f"Saved: {output_path}") if __name__ == "__main__": if len(sys.argv) != 3: print(f"Usage: python {sys.argv[0]} ", file=sys.stderr) sys.exit(1) src, dst = sys.argv[1], sys.argv[2] convert_to_timm_bgr(src, dst)