#!/usr/bin/env python3
# File size: 4,470 Bytes | e904452
"""Merge Qwen3 future-task probe shards into one result package."""
from __future__ import annotations
import argparse
import json
import shutil
from pathlib import Path
from typing import Any
from eval_qwen3_omni_future_task_probes import TASK_SPECS, score_task, write_json, write_jsonl
def parse_args() -> argparse.Namespace:
    """Parse the merge script's command-line options.

    Returns:
        Namespace with ``run_id`` (str), ``output_dir`` (Path) and
        ``shard_dir`` (list of one or more Path values). All three options
        are required.
    """
    cli = argparse.ArgumentParser(description=__doc__)
    cli.add_argument("--run-id", required=True)
    # Both remaining options are required Path values; only --shard-dir
    # accepts several of them.
    path_options = (
        ("--output-dir", {}),
        ("--shard-dir", {"nargs": "+"}),
    )
    for flag, extra in path_options:
        cli.add_argument(flag, type=Path, required=True, **extra)
    return cli.parse_args()
def read_jsonl(path: Path) -> list[dict[str, Any]]:
    """Load every non-blank line of a JSON Lines file.

    Args:
        path: File to read as UTF-8.

    Returns:
        The decoded value of each non-blank line, in file order. An empty
        list when ``path`` does not exist.
    """
    if not path.exists():
        return []
    with path.open("r", encoding="utf-8") as stream:
        # Strip first so whitespace-only lines are skipped rather than decoded.
        stripped_lines = (raw.strip() for raw in stream)
        return [json.loads(text) for text in stripped_lines if text]
def read_json(path: Path) -> dict[str, Any]:
    """Load a UTF-8 JSON document, or return ``{}`` when the file is absent.

    Args:
        path: File to read.

    Returns:
        The decoded JSON value, or an empty dict if ``path`` does not exist.
    """
    if not path.exists():
        return {}
    payload = path.read_text(encoding="utf-8")
    return json.loads(payload)
def fake_args(run_id: str, first_metrics: dict[str, Any]) -> argparse.Namespace:
    """Build the argument namespace that ``main`` hands to ``score_task``.

    The values are recovered from a shard's ``metrics.json`` so that the
    merged run is scored with the settings recorded by the shards.

    Args:
        run_id: Identifier of the merged run.
        first_metrics: Metrics dict from a shard; may be empty.

    Returns:
        Namespace with ``run_id``, ``model_id``, ``adapter_dir``,
        ``dataset_jsonl``, ``eval_split``, ``future_frames``, and the fixed
        values ``sample_offset=0`` / ``sample_stride=1``.
    """
    return argparse.Namespace(
        run_id=run_id,
        model_id=first_metrics.get("model_id"),
        # `or ""` rather than a .get default: a key stored as JSON null would
        # otherwise reach Path(None) and raise TypeError. This matches the
        # null handling already used for future_frames below.
        adapter_dir=Path(first_metrics.get("adapter_dir") or ""),
        dataset_jsonl=Path(first_metrics.get("dataset_jsonl") or ""),
        eval_split=first_metrics.get("eval_split", "test"),
        future_frames=int(first_metrics.get("future_frames", 100) or 100),
        # The merged package covers every sample, so no offset/stride applies.
        sample_offset=0,
        sample_stride=1,
    )
def _row_sort_key(row: dict[str, Any]) -> tuple[str, int, str]:
    """Order prediction rows by episode id, then start frame, then row id."""
    # `or 0` so that a JSON null start_frame sorts like a missing one instead
    # of crashing int(None).
    return (
        str(row.get("episode_id")),
        int(row.get("start_frame") or 0),
        str(row.get("id")),
    )


def _merge_task_rows(task_id: str, shard_dirs: list[Path]) -> list[dict[str, Any]]:
    """Collect one task's predictions from all shards, de-duplicated and sorted."""
    merged: dict[str, dict[str, Any]] = {}
    for shard_dir in shard_dirs:
        for row in read_jsonl(shard_dir / task_id / "predictions.jsonl"):
            key = str(row.get("prediction_id") or f"{task_id}::{row.get('id')}")
            # First occurrence wins when several shards contain the same key.
            merged.setdefault(key, row)
    return sorted(merged.values(), key=_row_sort_key)


def _first_shard_metrics(task_id: str, shard_dirs: list[Path]) -> dict[str, Any]:
    """Return the first non-empty ``metrics.json`` for a task, or ``{}``."""
    for shard_dir in shard_dirs:
        shard_metrics = read_json(shard_dir / task_id / "metrics.json")
        if shard_metrics:
            return shard_metrics
    return {}


def _copy_progress_logs(shard_dirs: list[Path], output_dir: Path) -> None:
    """Copy each shard's ``progress.jsonl`` (when present) into the output dir."""
    for shard_dir in shard_dirs:
        progress = shard_dir / "progress.jsonl"
        if progress.exists():
            # NOTE(review): the target name uses only the shard directory's
            # basename, so shard dirs sharing a basename overwrite each other.
            shutil.copy2(progress, output_dir / f"{shard_dir.name}.progress.jsonl")


def _build_summary(
    args: argparse.Namespace, task_metrics: dict[str, dict[str, Any]]
) -> dict[str, Any]:
    """Assemble the ``summary.json`` payload from the per-task metrics."""
    return {
        "title": "Qwen3-Omni v6 Future Task Probes",
        "status": "pass",
        "run_id": args.run_id,
        "shard_dirs": [str(path) for path in args.shard_dir],
        "tasks": {
            task_id: {
                "task_number": metrics["task_number"],
                "task_label": metrics["task_label"],
                "metric_key": metrics["metric_key"],
                "primary_score": metrics["primary_score"],
                "num_samples": metrics["num_samples"],
                "metrics_json": str(args.output_dir / task_id / "metrics.json"),
            }
            for task_id, metrics in task_metrics.items()
        },
    }


def _render_report(
    args: argparse.Namespace, task_metrics: dict[str, dict[str, Any]]
) -> str:
    """Render the Markdown run report with one table row per scored task."""
    report = [
        "# Qwen3-Omni v6 Future Task Probes",
        "",
        f"- Run ID: `{args.run_id}`",
        f"- Shards: `{len(args.shard_dir)}`",
        "",
        "| Task | Metric | Score | Samples |",
        "| --- | --- | ---: | ---: |",
    ]
    report.extend(
        f"| {metrics['task_label']} | {metrics['metric_key']} | {metrics['primary_score']:.6f} | {metrics['num_samples']} |"
        for metrics in task_metrics.values()
    )
    return "\n".join(report) + "\n"


def main() -> int:
    """Merge per-shard predictions, re-score each task and write the package.

    For every task in ``TASK_SPECS`` the shard predictions are merged,
    written to ``<output-dir>/<task>/predictions.jsonl`` and passed to
    ``score_task``. Shard progress logs are copied, then ``summary.json``
    and ``RUN_REPORT.md`` are written and the summary is printed to stdout.

    Returns:
        Process exit code; always 0 when no exception is raised.
    """
    args = parse_args()
    args.output_dir.mkdir(parents=True, exist_ok=True)

    task_metrics: dict[str, dict[str, Any]] = {}
    # Run settings recovered from the first shard metrics file found; stays
    # empty until one is seen, in which case fake_args uses its defaults.
    run_metadata: dict[str, Any] = {}
    for task_id, spec in TASK_SPECS.items():
        if not run_metadata:
            # Looked up before the empty-task check so a task with metrics
            # but no predictions can still supply the metadata.
            run_metadata = _first_shard_metrics(task_id, args.shard_dir)
        ordered_rows = _merge_task_rows(task_id, args.shard_dir)
        if not ordered_rows:
            continue
        task_dir = args.output_dir / task_id
        task_dir.mkdir(parents=True, exist_ok=True)
        write_jsonl(task_dir / "predictions.jsonl", ordered_rows)
        task_metrics[task_id] = score_task(
            task_id,
            spec,
            ordered_rows,
            args.output_dir,
            fake_args(args.run_id, run_metadata),
        )

    _copy_progress_logs(args.shard_dir, args.output_dir)

    summary = _build_summary(args, task_metrics)
    write_json(args.output_dir / "summary.json", summary)
    (args.output_dir / "RUN_REPORT.md").write_text(
        _render_report(args, task_metrics), encoding="utf-8"
    )
    print(json.dumps(summary, indent=2, sort_keys=True))
    return 0
if __name__ == "__main__":
    # Run the merge and hand main()'s return value to the interpreter as the
    # process exit status.
    exit_code = main()
    raise SystemExit(exit_code)
|