ropedia-xperience-10m-task-baselines / scripts /omni /export_cosmos3_camera_pose_targets.py
cy0307's picture
Publish Ropedia Xperience-10M task baseline cards
3a10443 verified
Raw
History Blame
10.2 kB
#!/usr/bin/env python3
"""Augment exported Xperience windows with Cosmos3 camera-pose action targets.
This does not invent robot-control labels. It converts frame-aligned SLAM poses
from `annotation.hdf5` into the Cosmos3-supported `camera_pose` action domain:
9D per-transition vectors with translation delta, rotation delta as a rotation
vector, and absolute displacement from the window start. The target is a
continuous egocentric-motion proxy suitable for a first Cosmos3 action-packer
smoke run; it is intentionally separate from the semantic JSON QA target.
"""
from __future__ import annotations
import argparse
import json
import math
from collections import Counter
from pathlib import Path
from typing import Any
import h5py
import numpy as np
from qwen3_omni_dataset_utils import load_jsonl, write_jsonl
RAW_ACTION_DIM = 9
DOMAIN_NAME = "camera_pose"
def parse_args() -> argparse.Namespace:
workspace_default = Path(__file__).resolve().parents[2]
parser = argparse.ArgumentParser(description=__doc__)
parser.add_argument("--dataset-jsonl", type=Path, required=True)
parser.add_argument("--output-jsonl", type=Path, required=True)
parser.add_argument("--output-manifest", type=Path, required=True)
parser.add_argument("--chunk-size", type=int, default=8)
parser.add_argument("--resolution-tier", type=int, default=480, choices=[256, 480, 704, 720])
parser.add_argument("--view-point", default="ego_view")
parser.add_argument("--max-records", type=int, default=0)
parser.add_argument("--strict", action="store_true")
return parser.parse_args()
def read_pose_cache(annotation_path: Path) -> dict[str, np.ndarray]:
with h5py.File(annotation_path, "r") as h5:
slam = h5["slam"]
trans = np.asarray(slam["trans_xyz"], dtype=np.float64)
quat = np.asarray(slam["quat_wxyz"], dtype=np.float64)
frame_numbers = np.asarray(h5["video"]["frame_number"], dtype=np.int64)
return {"trans": trans, "quat": normalize_quat_array(quat), "frame_numbers": frame_numbers}
def normalize_quat_array(quat: np.ndarray) -> np.ndarray:
norm = np.linalg.norm(quat, axis=-1, keepdims=True)
norm[norm <= 1e-12] = 1.0
quat = quat / norm
# Keep quaternion sign continuous enough for simple deltas.
for idx in range(1, len(quat)):
if np.dot(quat[idx - 1], quat[idx]) < 0:
quat[idx] *= -1.0
return quat
def quat_inverse(q: np.ndarray) -> np.ndarray:
return np.asarray([q[0], -q[1], -q[2], -q[3]], dtype=np.float64) / max(float(np.dot(q, q)), 1e-12)
def quat_multiply(a: np.ndarray, b: np.ndarray) -> np.ndarray:
aw, ax, ay, az = a
bw, bx, by, bz = b
return np.asarray(
[
aw * bw - ax * bx - ay * by - az * bz,
aw * bx + ax * bw + ay * bz - az * by,
aw * by - ax * bz + ay * bw + az * bx,
aw * bz + ax * by - ay * bx + az * bw,
],
dtype=np.float64,
)
def quat_to_rotvec(q: np.ndarray) -> np.ndarray:
q = q / max(float(np.linalg.norm(q)), 1e-12)
if q[0] < 0:
q = -q
w = float(np.clip(q[0], -1.0, 1.0))
xyz = q[1:]
sin_half = float(np.linalg.norm(xyz))
if sin_half < 1e-8:
return 2.0 * xyz
angle = 2.0 * math.atan2(sin_half, w)
if angle > math.pi:
angle -= 2.0 * math.pi
return xyz / sin_half * angle
def nearest_index(frame_numbers: np.ndarray, frame: int) -> int:
if frame <= int(frame_numbers[0]):
return 0
if frame >= int(frame_numbers[-1]):
return len(frame_numbers) - 1
return int(np.searchsorted(frame_numbers, frame, side="left"))
def sampled_frame_pairs(start_frame: int, end_frame: int, chunk_size: int) -> list[tuple[int, int]]:
if chunk_size < 1:
raise ValueError("chunk_size must be >= 1")
if end_frame <= start_frame:
end_frame = start_frame + chunk_size
points = np.linspace(start_frame, end_frame, chunk_size + 1)
frames = [int(round(value)) for value in points]
pairs: list[tuple[int, int]] = []
for left, right in zip(frames[:-1], frames[1:]):
if right <= left:
right = left + 1
pairs.append((left, right))
return pairs
def camera_pose_actions(pose: dict[str, np.ndarray], start_frame: int, end_frame: int, chunk_size: int) -> list[list[float]]:
trans = pose["trans"]
quat = pose["quat"]
frame_numbers = pose["frame_numbers"]
start_idx = nearest_index(frame_numbers, start_frame)
origin = trans[start_idx]
rows: list[list[float]] = []
for left_frame, right_frame in sampled_frame_pairs(start_frame, end_frame, chunk_size):
li = nearest_index(frame_numbers, left_frame)
ri = nearest_index(frame_numbers, right_frame)
delta_t = trans[ri] - trans[li]
delta_q = quat_multiply(quat[ri], quat_inverse(quat[li]))
delta_r = quat_to_rotvec(delta_q)
displacement = trans[ri] - origin
row = np.concatenate([delta_t, delta_r, displacement]).astype(np.float32)
if row.shape[0] != RAW_ACTION_DIM:
raise AssertionError(row.shape)
rows.append([float(value) for value in row])
return rows
def media_condition(row: dict[str, Any]) -> dict[str, Any]:
media = row.get("media") if isinstance(row.get("media"), dict) else {}
return {
"mosaic_video_path": media.get("mosaic_video_path"),
"video_paths": media.get("video_paths") if isinstance(media.get("video_paths"), list) else [],
"context_start_frame": media.get("context_start_frame"),
"context_end_frame": media.get("context_end_frame"),
}
def augment_rows(rows: list[dict[str, Any]], args: argparse.Namespace) -> tuple[list[dict[str, Any]], dict[str, Any]]:
pose_cache: dict[str, dict[str, np.ndarray]] = {}
counters = Counter()
issues: list[dict[str, Any]] = []
augmented: list[dict[str, Any]] = []
selected = rows[: args.max_records] if args.max_records > 0 else rows
for idx, row in enumerate(selected):
counters["rows_seen"] += 1
episode_path_raw = row.get("episode_path")
window = row.get("center_window") if isinstance(row.get("center_window"), dict) else {}
if not episode_path_raw or "start_frame" not in window or "end_frame" not in window:
counters["rows_skipped_missing_source_fields"] += 1
issues.append({"row_index": idx, "id": row.get("id"), "reason": "missing episode_path or center_window"})
if args.strict:
raise ValueError(issues[-1])
continue
annotation_path = Path(str(episode_path_raw)) / "annotation.hdf5"
if not annotation_path.exists():
counters["rows_skipped_missing_annotation"] += 1
issues.append({"row_index": idx, "id": row.get("id"), "reason": f"missing {annotation_path}"})
if args.strict:
raise FileNotFoundError(annotation_path)
continue
key = str(annotation_path)
if key not in pose_cache:
pose_cache[key] = read_pose_cache(annotation_path)
start_frame = int(window["start_frame"])
end_frame = int(window["end_frame"])
try:
raw_actions = camera_pose_actions(pose_cache[key], start_frame, end_frame, args.chunk_size)
except Exception as exc:
counters["rows_skipped_action_build_error"] += 1
issues.append({"row_index": idx, "id": row.get("id"), "reason": repr(exc)})
if args.strict:
raise
continue
copied = dict(row)
copied["cosmos_action_target"] = {
"mode": "forward_dynamics",
"domain_name": DOMAIN_NAME,
"chunk_size": args.chunk_size,
"raw_action_dim": RAW_ACTION_DIM,
"raw_actions": raw_actions,
"resolution_tier": args.resolution_tier,
"view_point": args.view_point,
"source": {
"kind": "slam_camera_pose_delta_proxy_v1",
"annotation_hdf5": str(annotation_path),
"frame_range": {"start_frame": start_frame, "end_frame": end_frame},
"fields": [
"slam/trans_xyz delta",
"slam/quat_wxyz delta as rotation vector",
"slam/trans_xyz displacement from window start",
],
"units": "translation in annotation coordinate units; rotation in radians",
},
"conditioning": media_condition(row),
}
augmented.append(copied)
counters["rows_augmented"] += 1
manifest = {
"status": "pass" if counters["rows_augmented"] else "fail",
"input_dataset_jsonl": str(args.dataset_jsonl),
"output_jsonl": str(args.output_jsonl),
"domain_name": DOMAIN_NAME,
"raw_action_dim": RAW_ACTION_DIM,
"chunk_size": args.chunk_size,
"resolution_tier": args.resolution_tier,
"view_point": args.view_point,
"target_kind": "slam_camera_pose_delta_proxy_v1",
"counts": dict(counters),
"episode_annotation_files_read": len(pose_cache),
"issues": issues[:100],
"limitations": [
"This is an egocentric camera-motion proxy, not a robot gripper or human hand-control action.",
"Use it for Cosmos3 action-packer and one-episode overfit smoke tests before claiming model-quality improvement.",
"Fit any normalization on train episodes only before a full publishable Cosmos adapter run.",
],
}
return augmented, manifest
def main() -> int:
args = parse_args()
rows = load_jsonl(args.dataset_jsonl)
augmented, manifest = augment_rows(rows, args)
args.output_jsonl.parent.mkdir(parents=True, exist_ok=True)
args.output_manifest.parent.mkdir(parents=True, exist_ok=True)
write_jsonl(args.output_jsonl, augmented)
args.output_manifest.write_text(json.dumps(manifest, indent=2, ensure_ascii=False) + "\n", encoding="utf-8")
print(json.dumps(manifest, indent=2, ensure_ascii=False))
return 0 if manifest["status"] == "pass" else 1
if __name__ == "__main__":
raise SystemExit(main())