#!/usr/bin/env python3
"""Build an Agents-A1 HF snapshot with the MTPLX MTP sidecar indexed.

Given a base Hugging Face snapshot (``--source``), a donor checkpoint's
``model.safetensors.index.json`` (``--donor-index``) and a safetensors file
holding the donor's ``mtp.*`` tensors (``--mtp-sidecar``), this script fills
``--output`` so that:

* every other source entry is mirrored (hard link, else symlink, else copy;
  directories are copied with ``shutil.copytree``);
* the sidecar is placed at ``mtp.safetensors``;
* ``config.json`` gets ``mtp_num_hidden_layers = 1`` (top level and inside
  ``text_config``) plus a ``_mtp_graft_provenance`` record;
* ``model.safetensors.index.json`` maps every donor ``mtp.*`` key to
  ``mtp.safetensors``.

A JSON report is written to ``mtp_snapshot_report.json`` and printed.
"""
from __future__ import annotations

import argparse
import json
import os
import shutil
from pathlib import Path

# Output file name of the MTP sidecar; every grafted mtp.* key points here.
_SIDECAR_NAME = "mtp.safetensors"
# Entries this script (re)generates in the output, so they are not mirrored.
_SKIP_NAMES = frozenset({"config.json", "model.safetensors.index.json", _SIDECAR_NAME})


def _remove_path(path: Path) -> None:
    """Delete *path* if present, whether it is a file, a symlink or a directory."""
    # is_dir() follows symlinks and rmtree refuses them, so links are unlinked.
    if path.is_dir() and not path.is_symlink():
        shutil.rmtree(path)
    elif path.exists() or path.is_symlink():
        path.unlink()


def link_or_copy(src: Path, dst: Path) -> str:
    """Materialise *src* at *dst* as cheaply as possible.

    Any existing entry at *dst* (file, symlink or directory) is removed first.
    A hard link is attempted, then a symbolic link to *src*, and finally a
    full copy with ``shutil.copy2``.

    Returns:
        The method that succeeded: ``"hardlink"``, ``"symlink"`` or ``"copy"``.
    """
    _remove_path(dst)
    try:
        os.link(src, dst)
        return "hardlink"
    except OSError:
        pass  # e.g. cross-device link; fall through to a symlink
    try:
        dst.symlink_to(src)
        return "symlink"
    except OSError:
        shutil.copy2(src, dst)
        return "copy"


def _read_json(path: Path) -> dict:
    """Parse a UTF-8 JSON file."""
    return json.loads(path.read_text(encoding="utf-8"))


def _write_json(path: Path, obj: dict) -> None:
    """Write *obj* as indented, key-sorted UTF-8 JSON with a trailing newline."""
    path.write_text(json.dumps(obj, indent=2, sort_keys=True) + "\n", encoding="utf-8")


def _read_safetensors_header(path: Path) -> dict:
    """Return the JSON header of a safetensors file.

    The format starts with an 8-byte little-endian length N followed by N bytes
    of UTF-8 JSON mapping tensor names to ``dtype``/``shape``/``data_offsets``
    (plus an optional ``__metadata__`` entry).

    Raises:
        ValueError: if the file is truncated or the header is not a JSON object.
    """
    file_size = path.stat().st_size
    with path.open("rb") as fh:
        prefix = fh.read(8)
        if len(prefix) != 8:
            raise ValueError(f"{path} is too small to be a safetensors file")
        header_len = int.from_bytes(prefix, "little")
        if header_len > file_size - 8:
            raise ValueError(
                f"{path} declares a {header_len}-byte header but is only {file_size} bytes"
            )
        raw = fh.read(header_len)
    try:
        header = json.loads(raw.decode("utf-8"))
    except ValueError as err:  # covers UnicodeDecodeError and JSONDecodeError
        raise ValueError(f"{path} has an unreadable safetensors header") from err
    if not isinstance(header, dict):
        raise ValueError(f"{path} safetensors header is not a JSON object")
    return header


def _tensor_payload_bytes(header: dict, path: Path) -> int:
    """Sum the byte lengths of all tensors described by a safetensors *header*."""
    total = 0
    for name, info in header.items():
        if name == "__metadata__":
            continue
        try:
            begin, end = info["data_offsets"]
        except (KeyError, TypeError, ValueError) as err:
            raise ValueError(f"{path}: tensor {name!r} has no valid data_offsets") from err
        total += end - begin
    return total


def _mirror_source(source: Path, output: Path) -> dict[str, str]:
    """Mirror every non-skipped entry of *source* into *output*.

    Returns:
        Mapping of entry name to how it was materialised (``"copytree"`` for
        directories, otherwise the ``link_or_copy`` method).
    """
    linked: dict[str, str] = {}
    for item in sorted(source.iterdir()):
        if item.name in _SKIP_NAMES:
            continue
        dst = output / item.name
        if item.is_dir():
            _remove_path(dst)
            shutil.copytree(item, dst, symlinks=True)
            linked[item.name] = "copytree"
        elif item.is_file() or item.is_symlink():
            # Resolve so HF-cache symlinks are linked to their underlying blob.
            linked[item.name] = link_or_copy(item.resolve(), dst)
    return linked


def _parse_args() -> argparse.Namespace:
    """Parse the command-line arguments."""
    parser = argparse.ArgumentParser(
        description="Build an Agents-A1 HF snapshot with the MTPLX MTP sidecar indexed."
    )
    parser.add_argument("--source", type=Path, required=True,
                        help="base HF snapshot directory (config.json + index)")
    parser.add_argument("--donor-index", type=Path, required=True,
                        help="donor model.safetensors.index.json listing the mtp.* keys")
    parser.add_argument("--mtp-sidecar", type=Path, required=True,
                        help="safetensors file containing the donor mtp.* tensors")
    parser.add_argument("--output", type=Path, required=True,
                        help="directory to build (must not be or lie inside --source)")
    parser.add_argument("--donor-repo", default="wang-yang/Agents-A1-MTPLX-Q4",
                        help="donor repo id recorded as provenance")
    return parser.parse_args()


def main() -> None:
    """Build the output snapshot described by the command-line arguments.

    Raises:
        FileNotFoundError: if a required input file is missing.
        ValueError: if ``--output`` is the source or inside it, the donor index
            has no ``mtp.*`` keys, those keys already exist in the base index,
            or the sidecar is not a readable safetensors file containing them.
    """
    args = _parse_args()
    source = args.source.resolve()
    output = args.output.resolve()
    donor_index_path = args.donor_index.resolve()
    mtp_sidecar = args.mtp_sidecar.resolve()

    config_path = source / "config.json"
    base_index_path = source / "model.safetensors.index.json"
    for required in (config_path, base_index_path, donor_index_path, mtp_sidecar):
        if not required.exists():
            raise FileNotFoundError(required)
    # Mirroring removes each destination entry before relinking it, so an
    # output equal to (or nested inside) the source would destroy source files.
    if output == source or source in output.parents:
        raise ValueError(f"--output must not be the source snapshot or lie inside it: {output}")

    # Load and validate every input before touching the output directory so a
    # bad input cannot leave a half-built snapshot behind.
    config = _read_json(config_path)
    base_index = _read_json(base_index_path)
    donor_index = _read_json(donor_index_path)
    sidecar_header = _read_safetensors_header(mtp_sidecar)
    sidecar_tensors = {name for name in sidecar_header if name != "__metadata__"}
    sidecar_payload = _tensor_payload_bytes(sidecar_header, mtp_sidecar)
    sidecar_file_size = mtp_sidecar.stat().st_size

    base_weight_map = dict(base_index.get("weight_map") or {})
    base_weight_count = len(base_weight_map)
    donor_weight_map = donor_index.get("weight_map") or {}
    mtp_keys = sorted(k for k in donor_weight_map if k.startswith("mtp."))
    if not mtp_keys:
        raise ValueError(f"No mtp.* keys found in donor index {donor_index_path}")
    overlapping = [k for k in mtp_keys if k in base_weight_map]
    if overlapping:
        raise ValueError(f"MTP keys already present in base index, first overlap: {overlapping[:5]}")
    # Every key we point at mtp.safetensors must actually live in that file.
    missing = [k for k in mtp_keys if k not in sidecar_tensors]
    if missing:
        raise ValueError(
            f"MTP sidecar {mtp_sidecar} lacks {len(missing)} donor-indexed tensors, "
            f"first missing: {missing[:5]}"
        )

    output.mkdir(parents=True, exist_ok=True)
    linked = _mirror_source(source, output)
    mtp_link_method = link_or_copy(mtp_sidecar, output / _SIDECAR_NAME)

    config["mtp_num_hidden_layers"] = 1
    # Unlike setdefault, this also replaces an explicit "text_config": null.
    if config.get("text_config") is None:
        config["text_config"] = {}
    text_config = config["text_config"]
    text_config["mtp_num_hidden_layers"] = 1
    config["_mtp_graft_provenance"] = {
        "donor_repo": args.donor_repo,
        "sidecar_file": _SIDECAR_NAME,
        "method": "HF safetensors sidecar indexed before llama.cpp Qwen3.5-MoE conversion",
    }
    _write_json(output / "config.json", config)

    for key in mtp_keys:
        base_weight_map[key] = _SIDECAR_NAME
    metadata = dict(base_index.get("metadata") or {})
    total_size = metadata.get("total_size")
    if isinstance(total_size, int):
        # HF total_size counts tensor bytes, not file bytes, so add the sidecar's
        # tensor payload rather than its on-disk size (which includes the header).
        metadata["total_size"] = total_size + sidecar_payload
    metadata["mtp_sidecar_total_size"] = sidecar_payload
    metadata["mtp_sidecar_keys"] = len(mtp_keys)
    metadata["mtp_sidecar_repo"] = args.donor_repo
    _write_json(
        output / "model.safetensors.index.json",
        {"metadata": metadata, "weight_map": dict(sorted(base_weight_map.items()))},
    )

    report = {
        "source": str(source),
        "output": str(output),
        "donor_repo": args.donor_repo,
        "donor_index": str(donor_index_path),
        "mtp_sidecar": str(mtp_sidecar),
        "mtp_sidecar_link_method": mtp_link_method,
        "mtp_sidecar_file_size": sidecar_file_size,
        "mtp_sidecar_tensor_bytes": sidecar_payload,
        "source_entries_linked": linked,
        "base_weight_count": base_weight_count,
        "mtp_weight_count": len(mtp_keys),
        "merged_weight_count": len(base_weight_map),
        "config_mtp_num_hidden_layers": config.get("mtp_num_hidden_layers"),
        "text_config_mtp_num_hidden_layers": text_config.get("mtp_num_hidden_layers"),
    }
    _write_json(output / "mtp_snapshot_report.json", report)
    print(json.dumps(report, indent=2, sort_keys=True))


if __name__ == "__main__":
    main()