File size: 4,470 Bytes
e904452
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
#!/usr/bin/env python3
"""Merge Qwen3 future-task probe shards into one result package."""

from __future__ import annotations

import argparse
import json
import shutil
from pathlib import Path
from typing import Any

from eval_qwen3_omni_future_task_probes import TASK_SPECS, score_task, write_json, write_jsonl


def parse_args() -> argparse.Namespace:
    """Define the merge CLI and parse it from ``sys.argv``.

    Returns:
        Namespace with ``run_id`` (str), ``output_dir`` (Path) and
        ``shard_dir`` (non-empty list of Path).
    """
    cli = argparse.ArgumentParser(description=__doc__)
    cli.add_argument("--run-id", required=True)
    cli.add_argument("--output-dir", required=True, type=Path)
    # One or more shard output directories to merge, in priority order.
    cli.add_argument("--shard-dir", required=True, type=Path, nargs="+")
    parsed = cli.parse_args()
    return parsed


def read_jsonl(path: Path) -> list[dict[str, Any]]:
    """Load a JSON-lines file into a list of decoded records.

    A missing file yields an empty list. Lines that are empty after stripping
    whitespace are skipped; every other line is parsed with ``json.loads``.
    """
    if not path.exists():
        return []
    with path.open("r", encoding="utf-8") as stream:
        stripped_lines = (raw.strip() for raw in stream)
        return [json.loads(text) for text in stripped_lines if text]


def read_json(path: Path) -> dict[str, Any]:
    """Decode a UTF-8 JSON file, returning ``{}`` when the file does not exist."""
    if not path.exists():
        return {}
    payload = path.read_text(encoding="utf-8")
    return json.loads(payload)


def fake_args(run_id: str, first_metrics: dict[str, Any]) -> argparse.Namespace:
    """Rebuild an evaluation-args namespace for ``score_task`` from shard metrics.

    Args:
        run_id: Run identifier to record for the merged package.
        first_metrics: A shard ``metrics.json`` payload (may be empty). Missing
            keys and keys stored as null both fall back to the defaults below.

    Returns:
        Namespace with ``run_id``, ``model_id``, ``adapter_dir``,
        ``dataset_jsonl``, ``eval_split``, ``future_frames``,
        ``sample_offset`` and ``sample_stride``.
    """
    return argparse.Namespace(
        run_id=run_id,
        model_id=first_metrics.get("model_id"),
        # ``or ""`` instead of a .get default: a metrics file written with
        # ``"adapter_dir": null`` would otherwise make Path(None) raise TypeError.
        adapter_dir=Path(first_metrics.get("adapter_dir") or ""),
        dataset_jsonl=Path(first_metrics.get("dataset_jsonl") or ""),
        # Null/empty split falls back to "test", mirroring future_frames below.
        eval_split=first_metrics.get("eval_split") or "test",
        future_frames=int(first_metrics.get("future_frames") or 100),
        # Presumably the merged rows stand for the whole (unsharded) sample set,
        # so no shard offset/stride applies — confirm against score_task.
        sample_offset=0,
        sample_stride=1,
    )


def _row_sort_key(row: dict[str, Any]) -> tuple[str, int, str]:
    """Order merged prediction rows by episode, then start frame, then row id."""
    # ``or 0`` also covers rows where start_frame is present but null; the old
    # ``.get("start_frame", 0)`` default only covered a missing key, and
    # ``int(None)`` raised TypeError mid-sort.
    return (str(row.get("episode_id")), int(row.get("start_frame") or 0), str(row.get("id")))


def _format_score(score: Any) -> str:
    """Render a primary score for the Markdown table, using ``n/a`` when it is None."""
    return "n/a" if score is None else f"{score:.6f}"


def _progress_copy_name(shard_dir: Path, index: int, used: set[str]) -> str:
    """Choose a unique destination file name for a shard's copied progress log."""
    name = f"{shard_dir.name}.progress.jsonl"
    if name in used:
        # Shard dirs sharing a basename (e.g. ``a/shard0`` and ``b/shard0``)
        # would otherwise overwrite each other's copied progress log.
        name = f"{shard_dir.name}.{index}.progress.jsonl"
    used.add(name)
    return name


def _render_report(run_id: str, shard_count: int, task_metrics: dict[str, dict[str, Any]]) -> str:
    """Build the RUN_REPORT.md text with one table row per scored task."""
    lines = [
        "# Qwen3-Omni v6 Future Task Probes",
        "",
        f"- Run ID: `{run_id}`",
        f"- Shards: `{shard_count}`",
        "",
        "| Task | Metric | Score | Samples |",
        "| --- | --- | ---: | ---: |",
    ]
    lines.extend(
        f"| {metrics['task_label']} | {metrics['metric_key']} | "
        f"{_format_score(metrics['primary_score'])} | {metrics['num_samples']} |"
        for metrics in task_metrics.values()
    )
    return "\n".join(lines) + "\n"


def main() -> int:
    """Merge per-shard task outputs into a single result package.

    For each task in ``TASK_SPECS``:
      * predictions from every ``--shard-dir`` are de-duplicated (first
        occurrence wins) and sorted;
      * the merged rows are written to ``<output-dir>/<task_id>/predictions.jsonl``;
      * the rows are re-scored with ``score_task``.

    Tasks with no predictions in any shard are skipped. Each shard's
    ``progress.jsonl`` is copied into the output dir. Finally ``summary.json``
    and ``RUN_REPORT.md`` are written and the summary is printed.

    Returns:
        Process exit code (always 0; failures propagate as exceptions).
    """
    args = parse_args()
    args.output_dir.mkdir(parents=True, exist_ok=True)
    task_metrics: dict[str, dict[str, Any]] = {}
    # First non-empty shard metrics.json seen, across all tasks. It is reused
    # to rebuild the args namespace that score_task expects.
    first_metrics: dict[str, Any] | None = None

    for task_id, spec in TASK_SPECS.items():
        rows_by_id: dict[str, dict[str, Any]] = {}
        for shard_dir in args.shard_dir:
            for row in read_jsonl(shard_dir / task_id / "predictions.jsonl"):
                # Shards may overlap; keep the first copy of each prediction.
                # NOTE(review): rows lacking both prediction_id and id all map to
                # "<task_id>::None" and collapse into one — confirm ids are always set.
                key = str(row.get("prediction_id") or f"{task_id}::{row.get('id')}")
                rows_by_id.setdefault(key, row)
            shard_metrics = read_json(shard_dir / task_id / "metrics.json")
            if shard_metrics and first_metrics is None:
                first_metrics = shard_metrics
        if not rows_by_id:
            continue
        ordered_rows = sorted(rows_by_id.values(), key=_row_sort_key)
        task_dir = args.output_dir / task_id
        task_dir.mkdir(parents=True, exist_ok=True)
        write_jsonl(task_dir / "predictions.jsonl", ordered_rows)
        task_metrics[task_id] = score_task(
            task_id, spec, ordered_rows, args.output_dir, fake_args(args.run_id, first_metrics or {})
        )

    used_progress_names: set[str] = set()
    for index, shard_dir in enumerate(args.shard_dir):
        progress_path = shard_dir / "progress.jsonl"
        if progress_path.exists():
            destination = args.output_dir / _progress_copy_name(shard_dir, index, used_progress_names)
            shutil.copy2(progress_path, destination)

    summary = {
        "title": "Qwen3-Omni v6 Future Task Probes",
        "status": "pass",
        "run_id": args.run_id,
        "shard_dirs": [str(path) for path in args.shard_dir],
        "tasks": {
            task_id: {
                "task_number": metrics["task_number"],
                "task_label": metrics["task_label"],
                "metric_key": metrics["metric_key"],
                "primary_score": metrics["primary_score"],
                "num_samples": metrics["num_samples"],
                # NOTE(review): assumes score_task writes metrics.json at this path.
                "metrics_json": str(args.output_dir / task_id / "metrics.json"),
            }
            for task_id, metrics in task_metrics.items()
        },
    }
    write_json(args.output_dir / "summary.json", summary)
    report_text = _render_report(args.run_id, len(args.shard_dir), task_metrics)
    (args.output_dir / "RUN_REPORT.md").write_text(report_text, encoding="utf-8")
    print(json.dumps(summary, indent=2, sort_keys=True))
    return 0


if __name__ == "__main__":
    # Surface main()'s return value as the process exit status.
    exit_code = main()
    raise SystemExit(exit_code)