ropedia-xperience-10m-task-baselines / scripts /omni /parallel_export_qwen3_omni_action_dataset.py
cy0307's picture
Update final Qwen model scripts
627e5d7 verified
Raw
History Blame
9.54 kB
#!/usr/bin/env python3
"""Parallel episode export for Qwen3-Omni train/validation datasets."""
from __future__ import annotations
import argparse
import concurrent.futures
import json
import subprocess
import sys
import time
from collections import Counter
from pathlib import Path
from qwen3_omni_dataset_utils import build_messages, label_counts, load_jsonl, write_jsonl
def parse_args() -> argparse.Namespace:
    """Parse the command-line options of the parallel export driver.

    The default ``--workspace`` is the directory two levels above the folder
    that contains this script, and the default ``--cache-dir`` lives beneath
    it. ``--manifest`` is the only required option; ``--output-dir`` stays
    ``None`` when it is not given.

    Returns:
        The populated ``argparse.Namespace``.
    """
    root = Path(__file__).resolve().parents[2]

    # (flag, add_argument keyword arguments), kept in declaration order so the
    # generated --help text lists the options in the same sequence.
    option_table = (
        ("--workspace", {"type": Path, "default": root}),
        ("--manifest", {"type": Path, "required": True}),
        ("--run-id", {"default": "xperience10m_qwen3_parallel_export"}),
        ("--output-dir", {"type": Path}),
        ("--cache-dir", {"type": Path, "default": root / "outputs/omni_exploration/feature_cache"}),
        ("--num-workers", {"type": int, "default": 8}),
        ("--max-windows-per-episode", {"type": int, "default": 32}),
        ("--max-video-frames", {"type": int, "default": 16}),
        ("--audio-source", {"default": "fisheye_cam0"}),
        ("--audio-sample-rate", {"type": int, "default": 16000}),
        ("--audio-band-count", {"type": int, "default": 16}),
        # BooleanOptionalAction also registers the --no-render-media form.
        ("--render-media", {"action": argparse.BooleanOptionalAction, "default": True}),
        ("--force-rebuild-cache", {"action": "store_true"}),
    )

    cli = argparse.ArgumentParser(
        description="Export Qwen3-Omni JSON-QA records with per-episode workers.",
    )
    for flag, spec in option_table:
        cli.add_argument(flag, **spec)
    return cli.parse_args()
def shard_episodes(episodes: list[dict], workers: int) -> list[list[dict]]:
    """Distribute episodes over at most ``workers`` shards, split by split.

    Episodes are grouped by their ``"split"`` value (``"unspecified"`` when
    the key is absent). Each group is dealt round-robin over the shards,
    starting again at shard 0 for every group. Groups are processed in the
    order train, val, test, unspecified, followed by any other split value in
    order of first appearance, so episodes carrying an unexpected split label
    (for example ``"validation"``) are still exported instead of being
    silently dropped.

    Args:
        episodes: Episode dictionaries from the manifest.
        workers: Requested shard count; clamped to ``1..len(episodes)``.

    Returns:
        The non-empty shards, each a list of episodes. An empty ``episodes``
        list yields an empty result.
    """
    workers = max(1, min(workers, len(episodes)))
    known_splits = ("train", "val", "test", "unspecified")

    # Single pass: bucket the episodes by split, preserving manifest order
    # inside each bucket and remembering first-appearance order of the keys.
    by_split: dict = {}
    for episode in episodes:
        by_split.setdefault(episode.get("split", "unspecified"), []).append(episode)

    split_order = [name for name in known_splits if name in by_split]
    split_order.extend(name for name in by_split if name not in known_splits)

    shards: list[list[dict]] = [[] for _ in range(workers)]
    for split in split_order:
        for idx, episode in enumerate(by_split[split]):
            shards[idx % workers].append(episode)
    return [shard for shard in shards if shard]
def write_shard_manifest(base_payload: dict, episodes: list[dict], path: Path, shard_index: int) -> None:
    """Write a manifest file describing a single shard.

    The shard summary starts as a shallow copy of ``base_payload["summary"]``
    (empty when absent) and is then stamped with the shard index, the number
    of episodes in the shard and a per-split episode count. The result is
    saved as indented JSON with ``summary`` and ``episodes`` keys.

    Args:
        base_payload: The parsed source manifest; only ``"summary"`` is read.
        episodes: Episodes assigned to this shard.
        path: Destination file; missing parent directories are created.
        shard_index: Index recorded as ``parallel_shard_index``.
    """
    per_split = Counter()
    for episode in episodes:
        per_split[episode.get("split", "unspecified")] += 1

    # Copy so the caller's summary dictionary is left untouched.
    shard_summary = dict(base_payload.get("summary", {}))
    shard_summary["parallel_shard_index"] = shard_index
    shard_summary["num_episodes"] = len(episodes)
    shard_summary["split_counts"] = dict(per_split)

    document = {"summary": shard_summary, "episodes": episodes}
    path.parent.mkdir(parents=True, exist_ok=True)
    path.write_text(json.dumps(document, indent=2), encoding="utf-8")
def run_shard(args: argparse.Namespace, shard_manifest: Path, shard_output: Path, shard_index: int) -> dict:
    """Run the single-process exporter script on one shard and time it.

    The exporter (``export_qwen3_omni_action_dataset.py``, located next to
    this file) is started with the current interpreter. Its stdout and stderr
    are both captured in ``<shard_output>/export.log``, whose first line is
    the space-joined command.

    Args:
        args: Parsed driver options that are forwarded to the exporter.
        shard_manifest: Manifest file listing this shard's episodes.
        shard_output: Directory receiving the shard's outputs and log.
        shard_index: Shard number; also used as a suffix of the run id.

    Returns:
        A dictionary with the shard index, the manifest and output paths, the
        expected ``dataset.jsonl`` path and the elapsed wall-clock seconds.

    Raises:
        subprocess.CalledProcessError: If the exporter exits non-zero.
    """
    exporter = Path(__file__).with_name("export_qwen3_omni_action_dataset.py")

    # Flag/value pairs forwarded to the exporter, in command-line order.
    forwarded = (
        ("--workspace", str(args.workspace)),
        ("--manifest", str(shard_manifest)),
        ("--run-id", f"{args.run_id}_shard_{shard_index:02d}"),
        ("--output-dir", str(shard_output)),
        ("--cache-dir", str(args.cache_dir)),
        ("--max-windows-per-episode", str(args.max_windows_per_episode)),
        ("--max-video-frames", str(args.max_video_frames)),
        ("--audio-source", args.audio_source),
        ("--audio-sample-rate", str(args.audio_sample_rate)),
        ("--audio-band-count", str(args.audio_band_count)),
    )
    command = [sys.executable, str(exporter)]
    for pair in forwarded:
        command.extend(pair)
    command.append("--allow-empty")

    # Optional switches are only emitted when they differ from the defaults.
    if not args.render_media:
        command.append("--no-render-media")
    if args.force_rebuild_cache:
        command.append("--force-rebuild-cache")

    shard_output.mkdir(parents=True, exist_ok=True)
    t0 = time.time()
    with (shard_output / "export.log").open("w", encoding="utf-8") as sink:
        # Record the command before the child starts writing to the same file.
        print(" ".join(command), file=sink, flush=True)
        subprocess.run(command, check=True, stdout=sink, stderr=subprocess.STDOUT)

    elapsed = round(time.time() - t0, 3)
    return {
        "shard_index": shard_index,
        "manifest": str(shard_manifest),
        "output_dir": str(shard_output),
        "dataset_jsonl": str(shard_output / "dataset.jsonl"),
        "seconds": elapsed,
    }
def merge_shards(args: argparse.Namespace, shard_results: list[dict], output_dir: Path) -> dict:
    """Combine the per-shard exports into one dataset and one manifest.

    Shards are visited in ascending ``shard_index`` order. Every record is
    tagged with the shard it came from, the action/subtask option lists are
    recomputed over all records (ignoring the ``"unknown"`` label), and each
    record's ``messages`` are rebuilt against the merged action options. The
    merged records go to ``<output_dir>/dataset.jsonl`` and the summary to
    ``<output_dir>/dataset_manifest.json``.

    Args:
        args: Parsed driver options (run id, worker count, clip settings).
        shard_results: Result dictionaries as returned by ``run_shard``.
        output_dir: Existing directory that receives the merged files.

    Returns:
        The merged dataset manifest that was written to disk.
    """
    merged = []
    modalities = []
    per_shard_features = []
    skipped = []

    for shard in sorted(shard_results, key=lambda row: row["shard_index"]):
        index = shard["shard_index"]

        loaded = load_jsonl(Path(shard["dataset_jsonl"]))
        for rec in loaded:
            rec["parallel_export_shard"] = index
        merged.extend(loaded)

        # The shard manifest is optional; without it only records are merged.
        shard_manifest_file = Path(shard["output_dir"]) / "dataset_manifest.json"
        if not shard_manifest_file.exists():
            continue
        shard_info = json.loads(shard_manifest_file.read_text(encoding="utf-8"))
        modalities.extend(shard_info.get("available_modalities", []))
        skipped.extend(
            {"shard_index": index, **entry}
            for entry in shard_info.get("skipped_episodes", [])
        )
        per_shard_features.append({
            "shard_index": index,
            "feature_manifest": shard_info.get("feature_manifest", []),
        })

    def distinct_labels(field: str) -> list:
        """Return the sorted distinct ``answer_json[field]`` values, minus "unknown"."""
        seen = set()
        for rec in merged:
            value = rec["answer_json"][field]
            if value != "unknown":
                seen.add(value)
        return sorted(seen)

    action_options = distinct_labels("action")
    subtask_options = distinct_labels("subtask")

    for rec in merged:
        rec.update({
            "action_options": action_options,
            "subtask_options": subtask_options,
            "label_options": action_options,
        })
        # Rebuilt last so the prompt sees the globally merged option lists.
        rec["messages"] = build_messages(rec, action_options, include_answer=True)

    dataset_path = output_dir / "dataset.jsonl"
    write_jsonl(dataset_path, merged)

    episode_ids = {rec["episode_id"] for rec in merged}
    split_counts = dict(Counter(rec["split"] for rec in merged))
    parallel_info = {
        "num_workers": args.num_workers,
        "shards": shard_results,
    }
    clip_policy = {
        "max_windows_per_episode": args.max_windows_per_episode,
        "max_video_frames": args.max_video_frames,
        "audio_span": "same_as_video_context",
        "mosaic": "2x3 multi-camera grid",
    }
    notes = [
        "Shard media and sensor-feature paths remain in shard output directories.",
        "Assistant answers are strict JSON for episode understanding, not robot-control policies.",
        "Merged label options are recomputed globally across all shards.",
        "Episodes with no labeled windows under the configured label rule are skipped and reported.",
    ]

    dataset_manifest = {
        "run_id": args.run_id,
        "dataset_path": str(dataset_path),
        "num_samples": len(merged),
        "num_episodes": len(episode_ids),
        "split_counts": split_counts,
        "label_counts": label_counts(merged),
        "action_options": action_options,
        "subtask_options": subtask_options,
        "parallel_export": parallel_info,
        "clip_policy": clip_policy,
        "feature_manifest": per_shard_features,
        "available_modalities": modalities,
        "skipped_episodes": skipped,
        "notes": notes,
    }
    manifest_file = output_dir / "dataset_manifest.json"
    manifest_file.write_text(json.dumps(dataset_manifest, indent=2), encoding="utf-8")
    return dataset_manifest
def main() -> int:
    """Shard the manifest, export every shard in parallel, then merge.

    Steps:
      1. Parse and normalise the command-line paths; the output directory
         defaults to ``<workspace>/results/omni_finetune/<run-id>``.
      2. Split the manifest's episodes into shards and write one shard
         manifest per shard under ``<output-dir>/shards``.
      3. Run one exporter subprocess per shard, each driven by its own
         thread, printing a ``shard_done`` JSON event as shards finish.
      4. Merge the shard outputs and write the final dataset manifest,
         including the total wall-clock time of the parallel phase.

    Returns:
        ``0`` on success.

    Raises:
        ValueError: If the manifest lists no episodes, or if sharding yields
            no shard at all.
        subprocess.CalledProcessError: If any shard's exporter fails.
    """
    args = parse_args()
    args.workspace = args.workspace.expanduser().resolve()
    args.manifest = args.manifest.expanduser().resolve()
    args.cache_dir = args.cache_dir.expanduser().resolve()
    if args.output_dir is None:
        args.output_dir = args.workspace / "results" / "omni_finetune" / args.run_id
    args.output_dir = args.output_dir.expanduser().resolve()
    args.output_dir.mkdir(parents=True, exist_ok=True)

    payload = json.loads(args.manifest.read_text(encoding="utf-8"))
    episodes = payload.get("episodes", [])
    if not episodes:
        raise ValueError(f"No episodes found in manifest: {args.manifest}")

    shards = shard_episodes(episodes, args.num_workers)
    if not shards:
        # Without this guard ThreadPoolExecutor(max_workers=0) fails with an
        # error message that says nothing about the manifest.
        raise ValueError(
            f"No shards could be built from manifest (check episode 'split' values): {args.manifest}"
        )
    sharded_count = sum(len(shard) for shard in shards)
    if sharded_count != len(episodes):
        # Surface episodes that sharding left out rather than losing them
        # silently; stderr keeps the JSON event stream on stdout clean.
        print(
            json.dumps({
                "event": "episodes_not_sharded",
                "manifest_episodes": len(episodes),
                "sharded_episodes": sharded_count,
            }, sort_keys=True),
            file=sys.stderr,
            flush=True,
        )

    shard_root = args.output_dir / "shards"
    shard_jobs = []
    for shard_index, shard in enumerate(shards):
        shard_manifest = shard_root / f"manifest_shard_{shard_index:02d}.json"
        shard_output = shard_root / f"shard_{shard_index:02d}"
        write_shard_manifest(payload, shard, shard_manifest, shard_index)
        shard_jobs.append((shard_manifest, shard_output, shard_index))

    started = time.time()
    results = []
    # Threads only wait on subprocesses, so one thread per shard is enough.
    with concurrent.futures.ThreadPoolExecutor(max_workers=len(shard_jobs)) as pool:
        futures = [
            pool.submit(run_shard, args, shard_manifest, shard_output, shard_index)
            for shard_manifest, shard_output, shard_index in shard_jobs
        ]
        for future in concurrent.futures.as_completed(futures):
            # .result() re-raises the exception of a failed shard here.
            result = future.result()
            results.append(result)
            print(json.dumps({"event": "shard_done", **result}, sort_keys=True), flush=True)

    dataset_manifest = merge_shards(args, results, args.output_dir)
    # merge_shards already wrote the manifest; rewrite it with the total time.
    dataset_manifest["parallel_export"]["seconds"] = round(time.time() - started, 3)
    (args.output_dir / "dataset_manifest.json").write_text(
        json.dumps(dataset_manifest, indent=2), encoding="utf-8"
    )
    print(json.dumps(dataset_manifest, indent=2))
    return 0
# Script entry point: main()'s return value becomes the process exit status.
if __name__ == "__main__":
    sys.exit(main())