ropedia-xperience-10m-task-baselines / scripts /omni /merge_qwen3_omni_retrieval_task_probe_shards.py
| #!/usr/bin/env python3 | |
| """Merge Qwen3 retrieval-task probe shards into one result package.""" | |
| from __future__ import annotations | |
| import argparse | |
| import json | |
| import shutil | |
| from pathlib import Path | |
| from typing import Any | |
| from eval_qwen3_omni_retrieval_task_probes import TASK_SPECS, score_task, write_json, write_jsonl | |
| def parse_args() -> argparse.Namespace: | |
| parser = argparse.ArgumentParser(description=__doc__) | |
| parser.add_argument("--run-id", required=True) | |
| parser.add_argument("--output-dir", type=Path, required=True) | |
| parser.add_argument("--shard-dir", type=Path, nargs="+", required=True) | |
| return parser.parse_args() | |
def read_jsonl(path: Path) -> list[dict[str, Any]]:
    """Load every non-blank line of a UTF-8 JSON Lines file.

    A missing file yields an empty list, so callers can treat "no file" the
    same as "no rows".

    Args:
        path: JSONL file to read.

    Returns:
        One parsed JSON value per non-blank line, in file order.

    Raises:
        json.JSONDecodeError: if a line is not valid JSON. The message is
            prefixed with ``path`` and the 1-based line number so a truncated
            or corrupt shard file can be located directly.
    """
    rows: list[dict[str, Any]] = []
    if not path.exists():
        return rows
    with path.open("r", encoding="utf-8") as handle:
        for line_number, raw_line in enumerate(handle, start=1):
            text = raw_line.strip()
            if not text:
                continue
            try:
                rows.append(json.loads(text))
            except json.JSONDecodeError as err:
                # Keep the exception type callers may catch; add file/line context.
                raise json.JSONDecodeError(
                    f"{path}:{line_number}: {err.msg}", err.doc, err.pos
                ) from err
    return rows
def read_json(path: Path) -> dict[str, Any]:
    """Parse a UTF-8 JSON file, returning ``{}`` when the file does not exist.

    Only absence is tolerated: an existing but malformed file still raises
    ``json.JSONDecodeError``.
    """
    if not path.exists():
        return {}
    payload = path.read_text(encoding="utf-8")
    return json.loads(payload)
def fake_args(run_id: str, first_metrics: dict[str, Any]) -> argparse.Namespace:
    """Rebuild the evaluation-script namespace that ``score_task`` receives.

    Run metadata is copied from a shard's ``metrics.json`` dict, with a fallback
    whenever a field is missing or null. ``sample_offset``/``sample_stride`` are
    fixed at 0/1 — presumably so the merged package is recorded as covering the
    unsharded sample set (TODO confirm against the eval script).

    Args:
        run_id: Run identifier for the merged package.
        first_metrics: Metrics dict from a shard; may be empty.

    Returns:
        A Namespace with the fields listed below.
    """
    return argparse.Namespace(
        run_id=run_id,
        model_id=first_metrics.get("model_id"),
        # ``or ""`` instead of a .get() default: a metrics file that stores the
        # field as null would otherwise reach Path(None) and raise TypeError.
        adapter_dir=Path(first_metrics.get("adapter_dir") or ""),
        dataset_jsonl=Path(first_metrics.get("dataset_jsonl") or ""),
        eval_split=first_metrics.get("eval_split", "test"),
        candidate_count=int(first_metrics.get("candidate_count", 4) or 4),
        future_frames=int(first_metrics.get("future_frames", 100) or 100),
        sample_offset=0,
        sample_stride=1,
    )
def _row_sort_key(row: dict[str, Any]) -> tuple[str, int, str]:
    """Order merged predictions by episode, then start frame, then row id."""
    # ``or 0`` also covers an explicit null start_frame, which int() rejects.
    return (str(row.get("episode_id")), int(row.get("start_frame") or 0), str(row.get("id")))


def _first_shard_metrics(shard_dirs: list[Path]) -> dict[str, Any]:
    """Return the first non-empty per-task ``metrics.json`` found across shards.

    Tasks are scanned in ``TASK_SPECS`` order and shards in command-line order.
    This runs before any task is scored, so every task gets the same run
    metadata instead of defaults for tasks processed before the first
    metrics file was encountered.
    """
    for task_id in TASK_SPECS:
        for shard_dir in shard_dirs:
            shard_metrics = read_json(shard_dir / task_id / "metrics.json")
            if shard_metrics:
                return shard_metrics
    return {}


def _merge_task_rows(
    task_id: str, shard_dirs: list[Path]
) -> tuple[list[dict[str, Any]], list[dict[str, Any]]]:
    """Collect one task's predictions across shards, keeping the first copy of each.

    Returns:
        ``(ordered_rows, duplicates)``: the unique rows sorted by
        ``_row_sort_key``, and one record per dropped duplicate noting which
        shard's copy was kept and whether the dropped copy differed from it.
    """
    rows_by_id: dict[str, dict[str, Any]] = {}
    row_sources: dict[str, str] = {}
    duplicates: list[dict[str, Any]] = []
    for shard_dir in shard_dirs:
        for row in read_jsonl(shard_dir / task_id / "predictions.jsonl"):
            # Fall back to a task-scoped id when a row carries no prediction_id.
            key = str(row.get("prediction_id") or f"{task_id}::{row.get('id')}")
            if key in rows_by_id:
                duplicates.append(
                    {
                        "task_id": task_id,
                        "prediction_id": key,
                        "kept_shard": row_sources.get(key),
                        "duplicate_shard": str(shard_dir),
                        "conflict": rows_by_id[key] != row,
                    }
                )
                continue
            rows_by_id[key] = row
            row_sources[key] = str(shard_dir)
    return sorted(rows_by_id.values(), key=_row_sort_key), duplicates


def _copy_progress_logs(shard_dirs: list[Path], output_dir: Path) -> None:
    """Copy each shard's ``progress.jsonl`` into ``output_dir``.

    Copies are named ``<shard dir name>.progress.jsonl``. When two shard
    directories share a basename (e.g. ``node1/run`` and ``node2/run``), the
    later one gets its position on the command line appended instead of
    silently overwriting the earlier copy.
    """
    used_names: set[str] = set()
    for index, shard_dir in enumerate(shard_dirs):
        source = shard_dir / "progress.jsonl"
        if not source.exists():
            continue
        target_name = f"{shard_dir.name}.progress.jsonl"
        if target_name in used_names:
            target_name = f"{shard_dir.name}.{index}.progress.jsonl"
        used_names.add(target_name)
        shutil.copy2(source, output_dir / target_name)


def _format_score(score: Any) -> str:
    """Format a score with 6 decimals, falling back to ``str`` if it is not numeric."""
    try:
        return f"{score:.6f}"
    except (TypeError, ValueError):
        # e.g. a None score: keep writing the report rather than crashing
        # after summary.json is already on disk.
        return str(score)


def main() -> int:
    """Merge per-shard task predictions, re-score them, and write one result package.

    For every task in ``TASK_SPECS`` with at least one prediction, this:
    - writes the de-duplicated, sorted rows to ``<output_dir>/<task>/predictions.jsonl``;
    - scores them with ``score_task``.
    Afterwards it copies the shards' progress logs, writes ``summary.json`` and
    ``RUN_REPORT.md``, and prints the summary as JSON to stdout.

    Returns:
        Process exit code (always 0; failures propagate as exceptions).
    """
    args = parse_args()
    args.output_dir.mkdir(parents=True, exist_ok=True)
    first_metrics = _first_shard_metrics(args.shard_dir)
    task_metrics: dict[str, dict[str, Any]] = {}
    duplicate_predictions: list[dict[str, Any]] = []
    for task_id, spec in TASK_SPECS.items():
        ordered_rows, duplicates = _merge_task_rows(task_id, args.shard_dir)
        duplicate_predictions.extend(duplicates)
        if not ordered_rows:
            continue
        task_dir = args.output_dir / task_id
        task_dir.mkdir(parents=True, exist_ok=True)
        write_jsonl(task_dir / "predictions.jsonl", ordered_rows)
        task_metrics[task_id] = score_task(
            task_id, spec, ordered_rows, args.output_dir, fake_args(args.run_id, first_metrics)
        )
    _copy_progress_logs(args.shard_dir, args.output_dir)
    summary = {
        "title": "Qwen3-Omni v6 Retrieval Task Probes",
        # NOTE(review): reported as "pass" unconditionally, even when no task had
        # predictions or duplicates conflict; consumers may rely on this value.
        "status": "pass",
        "run_id": args.run_id,
        "shard_dirs": [str(path) for path in args.shard_dir],
        "duplicate_prediction_count": len(duplicate_predictions),
        "duplicate_prediction_conflict_count": sum(1 for row in duplicate_predictions if row["conflict"]),
        # Only a sample is embedded to keep summary.json small; the counts above are complete.
        "duplicate_predictions": duplicate_predictions[:50],
        "tasks": {
            task_id: {
                "task_number": metrics["task_number"],
                "task_label": metrics["task_label"],
                "metric_key": metrics["metric_key"],
                "primary_score": metrics["primary_score"],
                "num_samples": metrics["num_samples"],
                # Assumes score_task writes its metrics under <output_dir>/<task_id>/ — TODO confirm.
                "metrics_json": str(args.output_dir / task_id / "metrics.json"),
            }
            for task_id, metrics in task_metrics.items()
        },
    }
    write_json(args.output_dir / "summary.json", summary)
    report = [
        "# Qwen3-Omni v6 Retrieval Task Probes",
        "",
        f"- Run ID: `{args.run_id}`",
        f"- Shards: `{len(args.shard_dir)}`",
        "",
        "| Task | Metric | Score | Samples |",
        "| --- | --- | ---: | ---: |",
    ]
    for metrics in task_metrics.values():
        report.append(
            f"| {metrics['task_label']} | {metrics['metric_key']} | {_format_score(metrics['primary_score'])} | {metrics['num_samples']} |"
        )
    (args.output_dir / "RUN_REPORT.md").write_text("\n".join(report) + "\n", encoding="utf-8")
    print(json.dumps(summary, indent=2, sort_keys=True))
    return 0
if __name__ == "__main__":
    # Hand main()'s integer return value to the interpreter as the exit status.
    exit_status = main()
    raise SystemExit(exit_status)