Agents-A1-GGUF / scripts /build_agents_a1_mtp_snapshot.py
LordNeel's picture
Add Q4 MTP GGUF variants
7d16500 verified
Raw
History Blame Contribute Delete
4.95 kB
#!/usr/bin/env python3
from __future__ import annotations
import argparse
import json
import os
import shutil
from pathlib import Path
def link_or_copy(src: Path, dst: Path) -> str:
    """Make the file *src* available at *dst* as cheaply as possible.

    Tries, in order:

    1. a hard link;
    2. a symbolic link pointing at *src*;
    3. a metadata-preserving copy (``shutil.copy2``).

    Any file or symlink already at *dst* is removed first. The exception is
    when *dst* already refers to the very same file as *src* (the same path,
    or an existing hard link to it); then it is left untouched.

    Args:
        src: File to expose. Callers in this file pass an already-resolved
            (absolute) path, so a symlink created here is absolute too.
        dst: Destination path to (re)create.

    Returns:
        The method used: ``"hardlink"``, ``"symlink"`` or ``"copy"``. The
        already-linked case reports ``"hardlink"``.
    """
    if dst.is_file() and not dst.is_symlink():
        try:
            already_linked = os.path.samefile(src, dst)
        except OSError:
            # e.g. src is missing: fall through to the normal replace path.
            already_linked = False
        if already_linked:
            # dst is src itself or a hard link to it. Unlinking it first would,
            # in the same-path case, delete the only copy of the data. The
            # fallbacks below would then leave a dangling self-referencing
            # symlink.
            return "hardlink"
    if dst.exists() or dst.is_symlink():
        # is_symlink() also catches dangling links, for which exists() is False.
        dst.unlink()
    try:
        os.link(src, dst)
        return "hardlink"
    except OSError:
        pass  # e.g. cross-device link or a filesystem without hard links
    try:
        dst.symlink_to(src)
        return "symlink"
    except OSError:
        pass  # e.g. symlinks not permitted on this platform/filesystem
    shutil.copy2(src, dst)
    return "copy"
def _remove_entry(path: Path) -> None:
    """Delete whatever is at *path* (directory tree, file or symlink), if anything."""
    if path.is_dir() and not path.is_symlink():
        shutil.rmtree(path)
    elif path.exists() or path.is_symlink():
        path.unlink()


def _write_json(path: Path, payload: object) -> None:
    """Write *payload* to *path* as sorted, 2-space-indented JSON plus a newline.

    The existing entry at *path* is removed first. If it were a hard link or
    symlink into the source snapshot, writing through it would silently
    modify the source file.
    """
    _remove_entry(path)
    path.write_text(json.dumps(payload, indent=2, sort_keys=True) + "\n", encoding="utf-8")


def _mirror_source_entries(source: Path, output: Path, skip_names: set[str]) -> dict[str, str]:
    """Expose each top-level entry of *source* in *output*, except *skip_names*.

    Directories are copied with ``shutil.copytree`` (symlinks inside them are
    preserved as symlinks). Files and symlinks are handed to ``link_or_copy``
    with their resolved target. Returns ``{entry name: method used}``.
    """
    linked: dict[str, str] = {}
    for item in sorted(source.iterdir()):
        if item.name in skip_names:
            continue
        dst = output / item.name
        if item.is_dir():
            # Clear any stale entry (tree, file or symlink) before copying.
            _remove_entry(dst)
            shutil.copytree(item, dst, symlinks=True)
            linked[item.name] = "copytree"
        elif item.is_file() or item.is_symlink():
            if dst.is_dir() and not dst.is_symlink():
                # A stale directory cannot be unlinked by link_or_copy.
                shutil.rmtree(dst)
            linked[item.name] = link_or_copy(item.resolve(), dst)
    return linked


def _merge_mtp_index(
    base_index: dict,
    donor_index: dict,
    donor_index_path: Path,
    sidecar_size: int,
    donor_repo: str,
) -> tuple[dict, list[str]]:
    """Route every ``mtp.*`` key of *donor_index* to ``mtp.safetensors`` in a copy of *base_index*.

    Returns:
        The merged index (weight map sorted by key) and the sorted list of
        grafted keys.

    Raises:
        ValueError: the donor weight map has no ``mtp.*`` keys, or some of
            them are already present in the base weight map.
    """
    base_weight_map = dict(base_index.get("weight_map", {}))
    donor_weight_map = donor_index.get("weight_map", {})
    mtp_keys = sorted(k for k in donor_weight_map if k.startswith("mtp."))
    if not mtp_keys:
        raise ValueError(f"No mtp.* keys found in donor index {donor_index_path}")
    overlapping = [k for k in mtp_keys if k in base_weight_map]  # already sorted
    if overlapping:
        raise ValueError(f"MTP keys already present in base index, first overlap: {overlapping[:5]}")
    base_weight_map.update(dict.fromkeys(mtp_keys, "mtp.safetensors"))

    metadata = dict(base_index.get("metadata", {}))
    total_size = metadata.get("total_size")
    if isinstance(total_size, int):
        # NOTE(review): sidecar_size is the whole file size (presumably
        # including the safetensors header), so this may slightly overstate
        # tensor bytes. Confirm if an exact total matters downstream.
        metadata["total_size"] = total_size + sidecar_size
    metadata["mtp_sidecar_total_size"] = sidecar_size
    metadata["mtp_sidecar_keys"] = len(mtp_keys)
    metadata["mtp_sidecar_repo"] = donor_repo
    merged_index = {
        "metadata": metadata,
        "weight_map": dict(sorted(base_weight_map.items())),
    }
    return merged_index, mtp_keys


def main() -> None:
    """Build an HF snapshot directory that indexes an MTP sidecar next to a base model.

    Steps:

    * Mirror every entry of ``--source`` into ``--output``, except
      ``config.json``, ``model.safetensors.index.json`` and
      ``mtp.safetensors``.
    * Link ``--mtp-sidecar`` in as ``mtp.safetensors``.
    * Write a ``config.json`` with ``mtp_num_hidden_layers = 1`` at the top
      level and in ``text_config``.
    * Write a ``model.safetensors.index.json`` whose weight map routes every
      ``mtp.*`` key listed in ``--donor-index`` to the sidecar.
    * Write a ``mtp_snapshot_report.json`` summary and print it.

    Raises:
        FileNotFoundError: a required input file is missing.
        ValueError: ``--output`` is ``--source`` or lies inside it; or the
            donor index has no ``mtp.*`` keys; or some of those keys already
            exist in the base index.
    """
    parser = argparse.ArgumentParser(
        description="Build an Agents-A1 HF snapshot with the MTPLX MTP sidecar indexed."
    )
    parser.add_argument("--source", type=Path, required=True)
    parser.add_argument("--donor-index", type=Path, required=True)
    parser.add_argument("--mtp-sidecar", type=Path, required=True)
    parser.add_argument("--output", type=Path, required=True)
    parser.add_argument("--donor-repo", default="wang-yang/Agents-A1-MTPLX-Q4")
    args = parser.parse_args()

    source = args.source.resolve()
    output = args.output.resolve()
    donor_index_path = args.donor_index.resolve()
    mtp_sidecar = args.mtp_sidecar.resolve()
    source_config_path = source / "config.json"
    source_index_path = source / "model.safetensors.index.json"

    for required in (source_config_path, source_index_path, donor_index_path, mtp_sidecar):
        if not required.exists():
            raise FileNotFoundError(required)
    # Building into the source tree would delete or overwrite source entries:
    # rmtree of sub-directories, unlink-then-relink of files, and a
    # regenerated config.json. An output nested inside source would also be
    # mirrored into itself.
    if output == source or source in output.parents:
        raise ValueError(f"--output ({output}) must not be --source or lie inside it ({source})")

    # Parse and validate every input before touching the output directory, so
    # a bad donor index does not leave a half-built snapshot behind.
    config = json.loads(source_config_path.read_text(encoding="utf-8"))
    base_index = json.loads(source_index_path.read_text(encoding="utf-8"))
    donor_index = json.loads(donor_index_path.read_text(encoding="utf-8"))
    sidecar_size = mtp_sidecar.stat().st_size
    merged_index, mtp_keys = _merge_mtp_index(
        base_index, donor_index, donor_index_path, sidecar_size, args.donor_repo
    )

    config["mtp_num_hidden_layers"] = 1
    # Created empty if the source config has no text_config.
    text_config = config.setdefault("text_config", {})
    text_config["mtp_num_hidden_layers"] = 1
    config["_mtp_graft_provenance"] = {
        "donor_repo": args.donor_repo,
        "sidecar_file": "mtp.safetensors",
        "method": "HF safetensors sidecar indexed before llama.cpp Qwen3.5-MoE conversion",
    }

    output.mkdir(parents=True, exist_ok=True)
    skip_names = {"config.json", "model.safetensors.index.json", "mtp.safetensors"}
    linked = _mirror_source_entries(source, output, skip_names)
    mtp_link_method = link_or_copy(mtp_sidecar, output / "mtp.safetensors")
    _write_json(output / "config.json", config)
    _write_json(output / "model.safetensors.index.json", merged_index)

    report = {
        "source": str(source),
        "output": str(output),
        "donor_repo": args.donor_repo,
        "donor_index": str(donor_index_path),
        "mtp_sidecar": str(mtp_sidecar),
        "mtp_sidecar_link_method": mtp_link_method,
        "source_entries_linked": linked,
        "base_weight_count": len(base_index.get("weight_map", {})),
        "mtp_weight_count": len(mtp_keys),
        "merged_weight_count": len(merged_index["weight_map"]),
        "config_mtp_num_hidden_layers": config.get("mtp_num_hidden_layers"),
        "text_config_mtp_num_hidden_layers": text_config.get("mtp_num_hidden_layers"),
    }
    _write_json(output / "mtp_snapshot_report.json", report)
    print(json.dumps(report, indent=2, sort_keys=True))
# Run the CLI only when executed as a script; importing this module does not call main().
if __name__ == "__main__":
    main()