| |
| """Parallel episode export for Qwen3-Omni train/validation datasets.""" |
|
|
| from __future__ import annotations |
|
|
| import argparse |
| import concurrent.futures |
| import json |
| import subprocess |
| import sys |
| import time |
| from collections import Counter |
| from pathlib import Path |
|
|
| from qwen3_omni_dataset_utils import build_messages, label_counts, load_jsonl, write_jsonl |
|
|
|
|
def parse_args() -> argparse.Namespace:
    """Parse the command-line options for the parallel export driver.

    The default workspace is the third ancestor (``parents[2]``) of this
    file's resolved path. The default cache directory is derived from that
    default, so it does not follow a ``--workspace`` override.

    Returns:
        The populated ``argparse.Namespace``. ``--manifest`` is the only
        required option; ``output_dir`` is ``None`` unless given explicitly.
    """
    default_workspace = Path(__file__).resolve().parents[2]
    default_cache = default_workspace / "outputs/omni_exploration/feature_cache"

    # (flag, add_argument keyword arguments), registered in declaration order
    # so the generated --help output keeps the same layout.
    option_table = (
        ("--workspace", {"type": Path, "default": default_workspace}),
        ("--manifest", {"type": Path, "required": True}),
        ("--run-id", {"default": "xperience10m_qwen3_parallel_export"}),
        ("--output-dir", {"type": Path}),
        ("--cache-dir", {"type": Path, "default": default_cache}),
        ("--num-workers", {"type": int, "default": 8}),
        ("--max-windows-per-episode", {"type": int, "default": 32}),
        ("--max-video-frames", {"type": int, "default": 16}),
        ("--audio-source", {"default": "fisheye_cam0"}),
        ("--audio-sample-rate", {"type": int, "default": 16000}),
        ("--audio-band-count", {"type": int, "default": 16}),
        # BooleanOptionalAction also registers the --no-render-media form.
        ("--render-media", {"action": argparse.BooleanOptionalAction, "default": True}),
        ("--force-rebuild-cache", {"action": "store_true"}),
    )

    cli = argparse.ArgumentParser(
        description="Export Qwen3-Omni JSON-QA records with per-episode workers."
    )
    for flag, options in option_table:
        cli.add_argument(flag, **options)
    return cli.parse_args()
|
|
|
|
def shard_episodes(episodes: list[dict], workers: int) -> list[list[dict]]:
    """Distribute episodes over at most *workers* shards, split by split.

    Episodes are grouped by their ``"split"`` value (a missing key counts as
    ``"unspecified"``) and each group is dealt round-robin over the shards,
    restarting at shard 0 for every group. The groups are processed in the
    order ``train``, ``val``, ``test``, ``unspecified``, followed by any other
    split values in the order they first appear in *episodes*. Those other
    values (for example ``"validation"`` or an explicit ``None``) used to be
    dropped silently and are now sharded like every other episode.

    Args:
        episodes: Episode dicts from the manifest.
        workers: Requested shard count; clamped to ``1..len(episodes)``.

    Returns:
        The non-empty shards, each a list of episode dicts. Empty input
        yields an empty list.
    """
    workers = max(1, min(workers, len(episodes)))

    # Known splits keep their historical priority; unknown values are appended
    # as encountered. Membership uses equality, so unhashable values are fine.
    split_order: list = ["train", "val", "test", "unspecified"]
    for episode in episodes:
        split = episode.get("split", "unspecified")
        if split not in split_order:
            split_order.append(split)

    shards: list[list[dict]] = [[] for _ in range(workers)]
    for split in split_order:
        split_eps = [ep for ep in episodes if ep.get("split", "unspecified") == split]
        for idx, episode in enumerate(split_eps):
            shards[idx % workers].append(episode)
    return [shard for shard in shards if shard]
|
|
|
|
def write_shard_manifest(base_payload: dict, episodes: list[dict], path: Path, shard_index: int) -> None:
    """Write the manifest JSON for one shard to *path*.

    The shard summary starts as a copy of ``base_payload["summary"]`` (empty
    when absent) and has ``parallel_shard_index``, ``num_episodes`` and
    ``split_counts`` set for this shard. Only ``summary`` and ``episodes``
    are written; other top-level keys of *base_payload* are not copied.

    Args:
        base_payload: The full input manifest; it is not modified.
        episodes: Episodes assigned to this shard.
        path: Destination file; missing parent directories are created.
        shard_index: Index of this shard, recorded in the summary.
    """
    # A missing "split" key is counted under "unspecified".
    per_split = Counter()
    for episode in episodes:
        per_split[episode.get("split", "unspecified")] += 1

    shard_summary = dict(base_payload.get("summary", {}))
    shard_summary["parallel_shard_index"] = shard_index
    shard_summary["num_episodes"] = len(episodes)
    shard_summary["split_counts"] = dict(per_split)

    document = {"summary": shard_summary, "episodes": episodes}
    path.parent.mkdir(parents=True, exist_ok=True)
    path.write_text(json.dumps(document, indent=2), encoding="utf-8")
|
|
|
|
def run_shard(args: argparse.Namespace, shard_manifest: Path, shard_output: Path, shard_index: int) -> dict:
    """Export one shard by running the single-shard exporter as a subprocess.

    The exporter script ``export_qwen3_omni_action_dataset.py`` is looked up
    next to this file and run with the current interpreter. The command line
    is written as the first line of ``<shard_output>/export.log``, and the
    child's stdout and stderr are appended to the same file.

    Args:
        args: Parsed driver options; the clip, audio and cache settings are
            forwarded to the exporter.
        shard_manifest: Manifest file listing this shard's episodes.
        shard_output: Output directory for this shard (created if missing).
        shard_index: Shard number, used as a suffix of the child's run id.

    Returns:
        A dict with the shard index, manifest path, output directory,
        expected ``dataset.jsonl`` path and elapsed wall-clock seconds.

    Raises:
        subprocess.CalledProcessError: If the exporter exits non-zero.
    """
    exporter = Path(__file__).with_name("export_qwen3_omni_action_dataset.py")

    # Option/value pairs forwarded to the exporter, in command-line order.
    forwarded = (
        ("--workspace", str(args.workspace)),
        ("--manifest", str(shard_manifest)),
        ("--run-id", f"{args.run_id}_shard_{shard_index:02d}"),
        ("--output-dir", str(shard_output)),
        ("--cache-dir", str(args.cache_dir)),
        ("--max-windows-per-episode", str(args.max_windows_per_episode)),
        ("--max-video-frames", str(args.max_video_frames)),
        ("--audio-source", args.audio_source),
        ("--audio-sample-rate", str(args.audio_sample_rate)),
        ("--audio-band-count", str(args.audio_band_count)),
    )
    command = [sys.executable, str(exporter)]
    for flag, value in forwarded:
        command.extend((flag, value))
    command.append("--allow-empty")
    if not args.render_media:
        command.append("--no-render-media")
    if args.force_rebuild_cache:
        command.append("--force-rebuild-cache")

    shard_output.mkdir(parents=True, exist_ok=True)
    t0 = time.time()
    with (shard_output / "export.log").open("w", encoding="utf-8") as sink:
        # Flush the header before the child starts writing to the same
        # file descriptor, so it stays the first line of the log.
        sink.write(" ".join(command) + "\n")
        sink.flush()
        subprocess.run(command, check=True, stdout=sink, stderr=subprocess.STDOUT)
    elapsed = round(time.time() - t0, 3)

    return {
        "shard_index": shard_index,
        "manifest": str(shard_manifest),
        "output_dir": str(shard_output),
        "dataset_jsonl": str(shard_output / "dataset.jsonl"),
        "seconds": elapsed,
    }
|
|
|
|
def merge_shards(args: argparse.Namespace, shard_results: list[dict], output_dir: Path) -> dict:
    """Merge per-shard exports into one dataset and manifest in *output_dir*.

    Shards are processed in ascending ``shard_index`` order. Each record is
    tagged with ``parallel_export_shard``. The action and subtask option
    lists are then recomputed over all records (``"unknown"`` excluded) and
    written back to every record, and ``messages`` is rebuilt with
    ``build_messages``. Writes ``dataset.jsonl`` and
    ``dataset_manifest.json``.

    Args:
        args: Parsed driver options (run id, worker count, clip limits).
        shard_results: Result dicts as returned by ``run_shard``.
        output_dir: Directory that receives the merged files.

    Returns:
        The merged manifest dict that was written to disk.
    """
    merged = []
    modalities = []
    feature_manifests = []
    skipped = []

    for shard in sorted(shard_results, key=lambda row: row["shard_index"]):
        shard_index = shard["shard_index"]
        rows = load_jsonl(Path(shard["dataset_jsonl"]))
        for row in rows:
            row["parallel_export_shard"] = shard_index
        merged.extend(rows)

        # The shard manifest is optional; without it only records are merged.
        shard_manifest_path = Path(shard["output_dir"]) / "dataset_manifest.json"
        if not shard_manifest_path.exists():
            continue
        shard_payload = json.loads(shard_manifest_path.read_text(encoding="utf-8"))
        modalities.extend(shard_payload.get("available_modalities", []))
        skipped.extend(
            {"shard_index": shard_index, **entry}
            for entry in shard_payload.get("skipped_episodes", [])
        )
        feature_manifests.append({
            "shard_index": shard_index,
            "feature_manifest": shard_payload.get("feature_manifest", []),
        })

    def known_labels(field: str) -> list:
        """Sorted distinct ``answer_json[field]`` values, excluding "unknown"."""
        return sorted({
            row["answer_json"][field]
            for row in merged
            if row["answer_json"][field] != "unknown"
        })

    action_options = known_labels("action")
    subtask_options = known_labels("subtask")
    for row in merged:
        # Options must be on the record before build_messages receives it.
        row.update(
            action_options=action_options,
            subtask_options=subtask_options,
            label_options=action_options,
        )
        row["messages"] = build_messages(row, action_options, include_answer=True)

    dataset_path = output_dir / "dataset.jsonl"
    write_jsonl(dataset_path, merged)

    episode_ids = {row["episode_id"] for row in merged}
    per_split = dict(Counter(row["split"] for row in merged))
    per_label = label_counts(merged)

    dataset_manifest = {
        "run_id": args.run_id,
        "dataset_path": str(dataset_path),
        "num_samples": len(merged),
        "num_episodes": len(episode_ids),
        "split_counts": per_split,
        "label_counts": per_label,
        "action_options": action_options,
        "subtask_options": subtask_options,
        "parallel_export": {
            "num_workers": args.num_workers,
            "shards": shard_results,
        },
        "clip_policy": {
            "max_windows_per_episode": args.max_windows_per_episode,
            "max_video_frames": args.max_video_frames,
            "audio_span": "same_as_video_context",
            "mosaic": "2x3 multi-camera grid",
        },
        "feature_manifest": feature_manifests,
        "available_modalities": modalities,
        "skipped_episodes": skipped,
        "notes": [
            "Shard media and sensor-feature paths remain in shard output directories.",
            "Assistant answers are strict JSON for episode understanding, not robot-control policies.",
            "Merged label options are recomputed globally across all shards.",
            "Episodes with no labeled windows under the configured label rule are skipped and reported.",
        ],
    }
    manifest_text = json.dumps(dataset_manifest, indent=2)
    (output_dir / "dataset_manifest.json").write_text(manifest_text, encoding="utf-8")
    return dataset_manifest
|
|
|
|
def main() -> int:
    """Run the parallel export: shard the manifest, export, then merge.

    Steps: resolve the CLI paths, split the manifest's episodes into shards,
    write one manifest per shard under ``<output_dir>/shards``, run every
    shard concurrently through ``run_shard``, merge the shard outputs, and
    rewrite the merged manifest with the total elapsed time.

    Returns:
        ``0`` on success.

    Raises:
        ValueError: If the manifest contains no episodes.
    """
    def absolute(path: Path) -> Path:
        """Expand ``~`` and resolve *path* to an absolute path."""
        return path.expanduser().resolve()

    args = parse_args()
    args.workspace = absolute(args.workspace)
    args.manifest = absolute(args.manifest)
    args.cache_dir = absolute(args.cache_dir)
    if args.output_dir is None:
        # Default location is keyed by run id under the resolved workspace.
        args.output_dir = args.workspace / "results" / "omni_finetune" / args.run_id
    args.output_dir = absolute(args.output_dir)
    args.output_dir.mkdir(parents=True, exist_ok=True)

    payload = json.loads(args.manifest.read_text(encoding="utf-8"))
    episodes = payload.get("episodes", [])
    if not episodes:
        raise ValueError(f"No episodes found in manifest: {args.manifest}")

    shard_root = args.output_dir / "shards"
    shard_jobs = []
    for shard_index, shard in enumerate(shard_episodes(episodes, args.num_workers)):
        manifest_path = shard_root / f"manifest_shard_{shard_index:02d}.json"
        output_path = shard_root / f"shard_{shard_index:02d}"
        write_shard_manifest(payload, shard, manifest_path, shard_index)
        shard_jobs.append((manifest_path, output_path, shard_index))

    t0 = time.time()
    results = []
    # One thread per shard; each thread blocks on its exporter subprocess.
    with concurrent.futures.ThreadPoolExecutor(max_workers=len(shard_jobs)) as pool:
        pending = [pool.submit(run_shard, args, *job) for job in shard_jobs]
        for done in concurrent.futures.as_completed(pending):
            # result() re-raises any exception raised while running the shard.
            outcome = done.result()
            results.append(outcome)
            print(json.dumps({"event": "shard_done", **outcome}, sort_keys=True), flush=True)

    dataset_manifest = merge_shards(args, results, args.output_dir)
    # merge_shards already wrote the manifest; rewrite it with the total time.
    dataset_manifest["parallel_export"]["seconds"] = round(time.time() - t0, 3)
    manifest_text = json.dumps(dataset_manifest, indent=2)
    (args.output_dir / "dataset_manifest.json").write_text(manifest_text, encoding="utf-8")
    print(json.dumps(dataset_manifest, indent=2))
    return 0
|
|
|
|
if __name__ == "__main__":
    # Use main()'s integer return value as the process exit status.
    sys.exit(main())
|
|