#!/usr/bin/env python3
"""Evaluate Qwen3-Omni on target-backed retrieval probes.

This runner covers model-friendly retrieval tasks whose targets can be formed
from the staged 128-episode JSON export without inventing labels. It currently
implements Task 08, language grounding, as text-query-to-video-window
retrieval: the query is derived from the held-out window's
action/subtask/object labels, and Qwen ranks shuffled candidate mosaic video
windows. Task 09, cross-modal retrieval, reuses the same ranking protocol with
a query that summarizes the window's sensor/motion feature blocks instead of
text labels.
"""

from __future__ import annotations

import argparse
import csv
import hashlib
import json
import re
import time
from collections import OrderedDict
from pathlib import Path
from typing import Any

import numpy as np
import torch

from eval_qwen3_omni_lora import load_model_processor, move_inputs
from qwen3_omni_dataset_utils import has_empty_audio_items, is_empty_audio_exception, load_jsonl

# Registry of supported probes; insertion order is the order used by --tasks all.
TASK_SPECS: OrderedDict[str, dict[str, Any]] = OrderedDict(
    [
        (
            "caption_grounding",
            {
                "task_number": 8,
                "label": "Language Grounding",
                "family": "retrieval",
                "metric_key": "caption_grounding_mrr",
                "prediction_key": "ranked_candidates",
            },
        ),
        (
            "cross_modal_retrieval",
            {
                "task_number": 9,
                "label": "Cross-Modal Retrieval",
                "family": "retrieval",
                "metric_key": "cross_modal_retrieval_mrr",
                "prediction_key": "ranked_candidates",
            },
        ),
    ]
)

# Named [start, end) slices of the per-window sensor feature vector that are
# summarized into the cross-modal query text.
MOTION_POSE_QUERY_BLOCKS: OrderedDict[str, tuple[int, int]] = OrderedDict(
    [
        ("hand_left_joints", (0, 441)),
        ("hand_right_joints", (441, 882)),
        ("body_joints", (882, 1974)),
        ("body_contacts", (1974, 2121)),
        ("camera_translation", (2121, 2142)),
        ("camera_rotation_matrix", (2142, 2205)),
        ("imu_accel_gyro", (2205, 2247)),
    ]
)

SYSTEM_PROMPT = (
    "You are an embodied episode-understanding model for Ropedia/Xperience-10M. "
    "Return exactly one compact valid JSON object and no markdown, prose, code "
    "fences, explanations, or repeated text."
)

# Candidate letters run A..H, so at most eight candidates can be labelled.
_MIN_CANDIDATES = 2
_MAX_CANDIDATES = 8
_LETTER_PATTERN = re.compile(r"\b[A-H]\b")


def parse_args() -> argparse.Namespace:
    """Parse the command-line options for the retrieval probe runner."""
    parser = argparse.ArgumentParser(description=__doc__)
    parser.add_argument("--dataset-jsonl", type=Path, required=True)
    parser.add_argument("--run-id", default="qwen3_retrieval_task_probes")
    parser.add_argument("--output-dir", type=Path)
    parser.add_argument("--model-id", required=True)
    parser.add_argument("--adapter-dir", type=Path, required=True)
    parser.add_argument("--eval-split", default="test")
    parser.add_argument("--tasks", default="caption_grounding")
    parser.add_argument("--candidate-count", type=int, default=4)
    parser.add_argument("--sample-limit", type=int, default=0)
    parser.add_argument("--sample-offset", type=int, default=0)
    parser.add_argument("--sample-stride", type=int, default=1)
    parser.add_argument("--max-new-tokens", type=int, default=64)
    parser.add_argument("--device-map", default="auto")
    parser.add_argument("--dtype", default="bfloat16", choices=["auto", "bfloat16", "float16", "float32"])
    parser.add_argument("--local-files-only", action="store_true")
    parser.add_argument("--trust-remote-code", action="store_true")
    parser.add_argument("--use-audio-in-video", action=argparse.BooleanOptionalAction, default=False)
    parser.add_argument("--resume", action=argparse.BooleanOptionalAction, default=True)
    parser.add_argument("--progress-jsonl", type=Path)
    return parser.parse_args()


def write_json(path: Path, payload: dict[str, Any]) -> None:
    """Write ``payload`` as indented, key-sorted JSON, creating parent directories."""
    path.parent.mkdir(parents=True, exist_ok=True)
    path.write_text(json.dumps(payload, indent=2, sort_keys=True) + "\n", encoding="utf-8")


def write_jsonl(path: Path, rows: list[dict[str, Any]]) -> None:
    """Overwrite ``path`` with one key-sorted JSON object per line."""
    path.parent.mkdir(parents=True, exist_ok=True)
    with path.open("w", encoding="utf-8") as handle:
        for row in rows:
            handle.write(json.dumps(row, ensure_ascii=False, sort_keys=True) + "\n")


def append_jsonl(path: Path, row: dict[str, Any]) -> None:
    """Append a single key-sorted JSON line to ``path``, creating it if needed."""
    path.parent.mkdir(parents=True, exist_ok=True)
    with path.open("a", encoding="utf-8") as handle:
        handle.write(json.dumps(row, ensure_ascii=False, sort_keys=True) + "\n")


def write_csv(path: Path, rows: list[dict[str, Any]], fieldnames: list[str]) -> None:
    """Write ``rows`` to a CSV file with a header; keys outside ``fieldnames`` are dropped."""
    path.parent.mkdir(parents=True, exist_ok=True)
    with path.open("w", newline="", encoding="utf-8") as handle:
        writer = csv.DictWriter(handle, fieldnames=fieldnames, extrasaction="ignore", lineterminator="\n")
        writer.writeheader()
        writer.writerows(rows)


def read_jsonl_if_exists(path: Path) -> list[dict[str, Any]]:
    """Return the JSON objects stored in ``path``, or ``[]`` when it is missing.

    Blank lines and lines that fail to decode are skipped (best effort, so a
    truncated trailing line does not abort the read). Lines that decode to
    something other than a JSON object are skipped as well, because callers
    treat every row as a dict.
    """
    if not path.exists():
        return []
    rows: list[dict[str, Any]] = []
    with path.open("r", encoding="utf-8") as handle:
        for line in handle:
            line = line.strip()
            if not line:
                continue
            try:
                row = json.loads(line)
            except json.JSONDecodeError:
                continue
            if isinstance(row, dict):
                rows.append(row)
    return rows


def normalize_text(value: Any) -> str:
    """Stringify ``value``, trim surrounding quote/backtick/period/space noise, collapse whitespace.

    Falsy values (``None``, ``""``, ``0``) normalize to the empty string.
    """
    return " ".join(str(value or "").strip().strip("`'\". ").split())


def answer(sample: dict[str, Any]) -> dict[str, Any]:
    """Return the sample's ``answer_json`` dict, or ``{}`` when absent or not a dict."""
    payload = sample.get("answer_json")
    return payload if isinstance(payload, dict) else {}


def row_start(sample: dict[str, Any]) -> int:
    """Return ``center_window.start_frame`` as an int, defaulting to 0."""
    window = sample.get("center_window") if isinstance(sample.get("center_window"), dict) else {}
    return int(window.get("start_frame", 0) or 0)


def row_end(sample: dict[str, Any]) -> int:
    """Return ``center_window.end_frame`` as an int, falling back to the start frame."""
    window = sample.get("center_window") if isinstance(sample.get("center_window"), dict) else {}
    return int(window.get("end_frame", row_start(sample)) or row_start(sample))


def media_video_path(sample: dict[str, Any]) -> str | None:
    """Return ``media.mosaic_video_path`` if set, else the sample's ``primary_video_path``."""
    media = sample.get("media") if isinstance(sample.get("media"), dict) else {}
    return media.get("mosaic_video_path") or sample.get("primary_video_path")


def parse_json_object(text: str) -> dict[str, Any]:
    """Leniently parse a JSON object out of model output.

    Surrounding backticks and a leading ``json`` tag are removed first. If the
    remainder is not valid JSON, the span from the first ``{`` to the last
    ``}`` is tried. Returns ``{}`` when nothing parses or the parsed value is
    not a dict.
    """
    raw = str(text or "").strip()
    if raw.startswith("```"):
        raw = raw.strip("`").strip()
        if raw.lower().startswith("json"):
            raw = raw[4:].strip()
    try:
        payload = json.loads(raw)
    except json.JSONDecodeError:
        start = raw.find("{")
        end = raw.rfind("}")
        if start < 0 or end <= start:
            return {}
        try:
            payload = json.loads(raw[start : end + 1])
        except json.JSONDecodeError:
            return {}
    return payload if isinstance(payload, dict) else {}


def select_tasks(spec: str) -> list[str]:
    """Expand a comma-separated task list (or ``all``) into known task ids.

    Raises:
        ValueError: if any requested task is not a key of ``TASK_SPECS``.
    """
    if spec.strip().lower() == "all":
        return list(TASK_SPECS)
    tasks = [item.strip() for item in spec.split(",") if item.strip()]
    unknown = [task for task in tasks if task not in TASK_SPECS]
    if unknown:
        raise ValueError(f"unknown tasks: {unknown}")
    return tasks


def select_eval_indices(samples: list[dict[str, Any]], args: argparse.Namespace) -> list[int]:
    """Return indices of the samples to evaluate.

    A sample qualifies when it belongs to ``args.eval_split``, has a video
    path, and has a non-empty answer dict. The qualifying list is then sharded
    by ``sample_offset``/``sample_stride`` and truncated to ``sample_limit``
    (when positive).

    Raises:
        ValueError: if the stride is below 1 or the offset is outside ``[0, stride)``.
    """
    if args.sample_stride < 1:
        raise ValueError("--sample-stride must be >= 1")
    if args.sample_offset < 0 or args.sample_offset >= args.sample_stride:
        raise ValueError("--sample-offset must satisfy 0 <= offset < stride")
    indices = [
        idx
        for idx, sample in enumerate(samples)
        if sample.get("split") == args.eval_split and media_video_path(sample) and answer(sample)
    ]
    if args.sample_stride > 1:
        indices = [idx for local_idx, idx in enumerate(indices) if local_idx % args.sample_stride == args.sample_offset]
    if args.sample_limit > 0:
        indices = indices[: args.sample_limit]
    return indices


def prediction_id(task_id: str, sample: dict[str, Any]) -> str:
    """Return the key used to identify one (task, sample) prediction."""
    return f"{task_id}::{sample.get('id')}"


def stable_score(*parts: Any) -> str:
    """Return a SHA-1 hex digest of the ``::``-joined parts.

    Used as a deterministic pseudo-random sort key, so candidate selection and
    ordering are reproducible across runs.
    """
    return hashlib.sha1("::".join(str(part) for part in parts).encode("utf-8")).hexdigest()


def _validate_candidate_count(candidate_count: int) -> None:
    """Raise ``ValueError`` unless the candidate count fits the A..H letter range."""
    if candidate_count < _MIN_CANDIDATES or candidate_count > _MAX_CANDIDATES:
        raise ValueError("--candidate-count must be between 2 and 8")


def build_candidate_indices(
    samples: list[dict[str, Any]],
    eval_pool: list[int],
    sample_idx: int,
    task_id: str,
    candidate_count: int,
) -> list[int]:
    """Pick the true window plus distractors and return them in shuffled order.

    Distractors are preferred from other episodes with a different action
    label. If that leaves too few, any other pool entry with a video path is
    allowed. Both the choice of distractors and the final order are driven by
    ``stable_score`` hashes, so they are deterministic per (task, sample). The
    result may hold fewer than ``candidate_count`` entries when the pool is
    small.

    Raises:
        ValueError: if ``candidate_count`` is outside ``[2, 8]``.
    """
    _validate_candidate_count(candidate_count)
    sample = samples[sample_idx]
    true_action = normalize_text(answer(sample).get("action")).casefold()
    true_episode = sample.get("episode_id")
    negatives = [
        idx
        for idx in eval_pool
        if idx != sample_idx
        and media_video_path(samples[idx])
        and samples[idx].get("episode_id") != true_episode
        and normalize_text(answer(samples[idx]).get("action")).casefold() != true_action
    ]
    if len(negatives) < candidate_count - 1:
        # Not enough "clean" distractors: relax the episode/action constraints.
        negatives = [idx for idx in eval_pool if idx != sample_idx and media_video_path(samples[idx])]
    negatives.sort(key=lambda idx: stable_score(task_id, sample.get("id"), samples[idx].get("id")))
    selected = [sample_idx] + negatives[: candidate_count - 1]
    # A second hash (salted with "order") decides the presentation order so the
    # target does not always sit in the first slot.
    selected.sort(key=lambda idx: stable_score(task_id, "order", sample.get("id"), samples[idx].get("id")))
    return selected


def query_text(sample: dict[str, Any]) -> str:
    """Build the three-line text query (action, subtask, up to eight objects) for a sample."""
    payload = answer(sample)
    objects = payload.get("objects") if isinstance(payload.get("objects"), list) else []
    object_text = ", ".join(normalize_text(item) for item in objects[:8] if normalize_text(item))
    return "\n".join(
        [
            f"Action: {normalize_text(payload.get('action')) or 'unknown'}",
            f"Procedure step: {normalize_text(payload.get('subtask')) or 'unknown'}",
            f"Relevant objects: {object_text or 'unknown'}",
        ]
    )


class SensorFeatureCache:
    """In-memory cache of ``features`` arrays keyed by the file path they were loaded from."""

    def __init__(self) -> None:
        self._features_by_path: dict[str, np.ndarray] = {}

    def get(self, path_text: str, index: int) -> np.ndarray:
        """Return row ``index`` of the ``features`` array stored at ``path_text``.

        The array is loaded on first use and kept as float32 for later calls.

        Raises:
            IndexError: if ``index`` is outside the first axis of the array.
        """
        path = str(path_text)
        if path not in self._features_by_path:
            data = np.load(path, allow_pickle=False)
            try:
                self._features_by_path[path] = np.asarray(data["features"], dtype=np.float32)
            finally:
                # An .npz archive keeps its file open until closed; plain
                # arrays have no close(), hence the getattr guard.
                close = getattr(data, "close", None)
                if callable(close):
                    close()
        features = self._features_by_path[path]
        if index < 0 or index >= features.shape[0]:
            raise IndexError(f"sensor feature index {index} out of range for {path}")
        return features[index]


def has_sensor_feature(sample: dict[str, Any]) -> bool:
    """Return True when the sample names both a sensor feature file and a row index."""
    return bool(sample.get("sensor_feature_path")) and sample.get("sensor_feature_index") is not None


def summarize_vector_block(values: np.ndarray) -> dict[str, float]:
    """Summarize the finite entries of ``values`` (mean, std, mean_abs, l2, max_abs).

    All statistics are 0.0 when no finite entry exists.
    """
    finite = values[np.isfinite(values)]
    if finite.size == 0:
        return {"mean": 0.0, "std": 0.0, "mean_abs": 0.0, "l2": 0.0, "max_abs": 0.0}
    return {
        "mean": float(np.mean(finite)),
        "std": float(np.std(finite)),
        "mean_abs": float(np.mean(np.abs(finite))),
        "l2": float(np.linalg.norm(finite)),
        "max_abs": float(np.max(np.abs(finite))),
    }


def sensor_query_text(sample: dict[str, Any], cache: SensorFeatureCache) -> str:
    """Render the sensor/motion query: a header plus one statistics line per feature block.

    Blocks whose end offset exceeds the length of the feature vector are omitted.
    """
    vector = cache.get(str(sample.get("sensor_feature_path")), int(sample.get("sensor_feature_index")))
    lines = [
        "Sensor/motion query for the current 20-frame window.",
        "Only motion capture, body contact, camera pose, and IMU blocks are summarized.",
        "The target is the candidate depth/video window synchronized with this sensor window.",
        f"Window frames: {row_start(sample)}-{row_end(sample)}",
    ]
    for name, (start, end) in MOTION_POSE_QUERY_BLOCKS.items():
        if end > vector.shape[0]:
            continue
        stats = summarize_vector_block(vector[start:end])
        lines.append(
            (
                f"{name}: mean={stats['mean']:.5g}, std={stats['std']:.5g}, "
                f"mean_abs={stats['mean_abs']:.5g}, l2={stats['l2']:.5g}, "
                f"max_abs={stats['max_abs']:.5g}"
            )
        )
    return "\n".join(lines)


def build_messages(
    samples: list[dict[str, Any]],
    sample_idx: int,
    candidate_indices: list[int],
    task_id: str,
    spec: dict[str, Any],
    sensor_cache: SensorFeatureCache | None = None,
) -> tuple[list[dict[str, Any]], str, list[dict[str, Any]]]:
    """Assemble the chat messages for one retrieval prompt.

    Returns:
        A tuple ``(messages, true_letter, candidate_records)``: the
        system+user messages, the letter assigned to the target window, and
        one bookkeeping record per candidate.

    Raises:
        ValueError: if ``task_id`` is ``cross_modal_retrieval`` and no sensor
            cache is supplied, or if ``sample_idx`` is not among the candidates.
    """
    letters = [chr(ord("A") + pos) for pos in range(len(candidate_indices))]
    true_letter = letters[candidate_indices.index(sample_idx)]
    candidate_records: list[dict[str, Any]] = []
    if task_id == "cross_modal_retrieval":
        if sensor_cache is None:
            raise ValueError("cross_modal_retrieval requires a sensor feature cache")
        task_instruction = "Rank the candidate video windows by which one is synchronized with the sensor/motion query."
        query = sensor_query_text(samples[sample_idx], sensor_cache)
        query_header = "Sensor/motion query:"
    else:
        task_instruction = "Rank the candidate video windows by how well they match the text query."
        query = query_text(samples[sample_idx])
        query_header = "Text query:"
    content: list[dict[str, Any]] = [
        {
            "type": "text",
            "text": "\n".join(
                [
                    f"Task {spec['task_number']}: {spec['label']}",
                    task_instruction,
                    "Return JSON only with this schema:",
                    '{"ranked_candidates":["","", "..."]}',
                    "Use each candidate letter at most once.",
                    "",
                    query_header,
                    query,
                ]
            ),
        }
    ]
    for letter, idx in zip(letters, candidate_indices):
        sample = samples[idx]
        candidate_records.append(
            {
                "letter": letter,
                "id": sample.get("id"),
                "episode_id": sample.get("episode_id"),
                "start_frame": row_start(sample),
                "end_frame": row_end(sample),
                "is_target": idx == sample_idx,
            }
        )
        content.append({"type": "text", "text": f"Candidate {letter} video window:"})
        content.append({"type": "video", "video": media_video_path(sample)})
    return (
        [
            {"role": "system", "content": [{"type": "text", "text": SYSTEM_PROMPT}]},
            {"role": "user", "content": content},
        ],
        true_letter,
        candidate_records,
    )


def generate_messages(model, processor, messages: list[dict[str, Any]], args: argparse.Namespace) -> str:
    """Run one generation for ``messages`` and return the decoded completion text.

    Only the tokens generated after the prompt are decoded. Returns ``""``
    when decoding yields nothing.

    Raises:
        RuntimeError: if the processor fails to prepare the inputs.
    """
    from qwen_omni_utils import process_mm_info

    # The loop makes a single pass with include_audio=False, so the
    # audio-specific skip/retry branches below are currently inactive.
    for include_audio in (False,):
        text = processor.apply_chat_template(messages, add_generation_prompt=True, tokenize=False)
        audios, images, videos = process_mm_info(messages, use_audio_in_video=args.use_audio_in_video)
        if include_audio and has_empty_audio_items(audios):
            continue
        try:
            inputs = processor(
                text=text,
                audio=audios,
                images=images,
                videos=videos,
                return_tensors="pt",
                padding=True,
                use_audio_in_video=args.use_audio_in_video,
            )
            break
        except RuntimeError as exc:
            if include_audio and is_empty_audio_exception(exc):
                continue
            raise
    else:
        raise RuntimeError("Unable to prepare retrieval prompt.")
    inputs = move_inputs(inputs, model)
    with torch.no_grad():
        generated = model.generate(
            **inputs,
            thinker_return_dict_in_generate=True,
            use_audio_in_video=args.use_audio_in_video,
            return_audio=False,
            max_new_tokens=args.max_new_tokens,
        )
    # generate() may hand back a tuple and/or an object exposing .sequences;
    # both shapes are unwrapped down to the token-id tensor.
    text_ids = generated[0] if isinstance(generated, tuple) else generated
    sequences = text_ids.sequences if hasattr(text_ids, "sequences") else text_ids
    output_ids = sequences[:, inputs["input_ids"].shape[1] :]
    decoded = processor.batch_decode(output_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False)
    return decoded[0] if decoded else ""


def extract_ranking(raw: str, valid_letters: list[str]) -> list[str]:
    """Turn raw model output into a complete ranking over ``valid_letters``.

    Letters are read from the ``ranked_candidates`` / ``ranking`` /
    ``candidates`` field of the parsed JSON, or from the raw text when no such
    field is present. Invalid and repeated letters are dropped. Any valid
    letter the model did not mention is appended in its original order, so the
    result always contains every valid letter exactly once.
    """
    payload = parse_json_object(raw)
    value = payload.get("ranked_candidates") or payload.get("ranking") or payload.get("candidates")
    letters: list[str] = []
    if isinstance(value, list):
        source = " ".join(str(item) for item in value)
    else:
        source = str(value or raw)
    valid = set(valid_letters)
    for match in _LETTER_PATTERN.findall(source.upper()):
        if match in valid and match not in letters:
            letters.append(match)
    for letter in valid_letters:
        if letter not in letters:
            letters.append(letter)
    return letters


def score_retrieval(rows: list[dict[str, Any]]) -> dict[str, float]:
    """Compute MRR and top-1 accuracy over prediction rows.

    A target missing from a row's ranking is scored as rank ``len(ranking) + 1``.
    The MRR is reported under the generic key and under both task-specific
    metric keys. All scores are 0.0 for an empty input.
    """
    reciprocal_ranks = []
    top1 = 0
    for row in rows:
        ranking = row.get("predicted_ranking") or []
        true_letter = row.get("true_letter")
        rank = ranking.index(true_letter) + 1 if true_letter in ranking else len(ranking) + 1
        reciprocal_ranks.append(1.0 / rank)
        top1 += int(bool(ranking) and ranking[0] == true_letter)
    mrr = sum(reciprocal_ranks) / len(reciprocal_ranks) if reciprocal_ranks else 0.0
    return {
        "num_samples": len(rows),
        "mrr": mrr,
        "caption_grounding_mrr": mrr,
        "cross_modal_retrieval_mrr": mrr,
        "top1_accuracy": top1 / len(rows) if rows else 0.0,
    }


def score_task(
    task_id: str,
    spec: dict[str, Any],
    rows: list[dict[str, Any]],
    output_dir: Path,
    args: argparse.Namespace,
) -> dict[str, Any]:
    """Write predictions (JSONL and CSV) and ``metrics.json`` for one task.

    Returns:
        The metrics dict that was written, including run metadata.
    """
    task_dir = output_dir / task_id
    task_dir.mkdir(parents=True, exist_ok=True)
    write_jsonl(task_dir / "predictions.jsonl", rows)
    write_csv(
        task_dir / "predictions.csv",
        [
            {
                "id": row["id"],
                "episode_id": row["episode_id"],
                "split": row["split"],
                "start_frame": row["start_frame"],
                "end_frame": row["end_frame"],
                "true_letter": row["true_letter"],
                "predicted_ranking": json.dumps(row["predicted_ranking"], ensure_ascii=False),
                "reciprocal_rank": row["reciprocal_rank"],
                "top1_correct": row["top1_correct"],
                "raw_prediction": row["raw_prediction"],
            }
            for row in rows
        ],
        [
            "id",
            "episode_id",
            "split",
            "start_frame",
            "end_frame",
            "true_letter",
            "predicted_ranking",
            "reciprocal_rank",
            "top1_correct",
            "raw_prediction",
        ],
    )
    metrics = score_retrieval(rows)
    primary_score = metrics[spec["metric_key"]]
    if task_id == "cross_modal_retrieval":
        score_policy = (
            "GPU-backed Qwen3-Omni v6 sensor-to-video retrieval probe. The query is a compact "
            "summary of held-out motion-capture, body-contact, camera-pose, and IMU feature blocks; "
            "candidates are shuffled staged mosaic video windows, and the score is MRR of the "
            "synchronized true window. No action/subtask/object labels are included in the query."
        )
    else:
        score_policy = (
            "GPU-backed Qwen3-Omni v6 text-to-video retrieval probe. The text query is built "
            "from held-out action/subtask/object labels, candidates are shuffled staged mosaic "
            "video windows, and the score is MRR of the true window. This does not score tasks "
            "whose numeric/raw targets are absent from the export."
        )
    metrics.update(
        {
            "title": f"Qwen3-Omni v6 {spec['label']}",
            "status": "pass",
            "run_id": args.run_id,
            "task_id": task_id,
            "task_number": spec["task_number"],
            "task_label": spec["label"],
            "metric_key": spec["metric_key"],
            "primary_metric": spec["metric_key"],
            "primary_score": primary_score,
            "model_id": args.model_id,
            "adapter_dir": str(args.adapter_dir),
            "dataset_jsonl": str(args.dataset_jsonl),
            "eval_split": args.eval_split,
            "candidate_count": args.candidate_count,
            "sample_offset": args.sample_offset,
            "sample_stride": args.sample_stride,
            "scope": "held_out_test_qwen3_retrieval_task_probe",
            "score_policy": score_policy,
        }
    )
    write_json(task_dir / "metrics.json", metrics)
    return metrics


def _load_partial_predictions(path: Path, *, resume: bool) -> dict[str, dict[str, Any]]:
    """Return previously saved rows keyed by prediction id, honoring ``--resume``.

    With ``resume=False`` any stale partial file is deleted, so the rows
    appended during this run are not mixed with output from an earlier run.
    """
    if not resume:
        path.unlink(missing_ok=True)
        return {}
    return {row["prediction_id"]: row for row in read_jsonl_if_exists(path) if row.get("prediction_id")}


def main() -> int:
    """Run the selected retrieval probes end to end and write all reports.

    Returns:
        0 on success. Failures surface as exceptions.
    """
    args = parse_args()
    # Fail fast on a bad candidate count, before the model is loaded.
    _validate_candidate_count(args.candidate_count)
    if args.output_dir is None:
        args.output_dir = Path(__file__).resolve().parents[2] / "results/omni_finetune" / args.run_id
    args.output_dir.mkdir(parents=True, exist_ok=True)
    args.progress_jsonl = args.progress_jsonl or args.output_dir / "progress.jsonl"
    selected_tasks = select_tasks(args.tasks)
    samples = load_jsonl(args.dataset_jsonl)
    eval_pool = [
        idx
        for idx, sample in enumerate(samples)
        if sample.get("split") == args.eval_split and media_video_path(sample)
    ]
    eval_indices = select_eval_indices(samples, args)
    if "cross_modal_retrieval" in selected_tasks:
        # The sensor requirement is applied to every selected task, not only
        # to cross_modal_retrieval.
        eval_indices = [idx for idx in eval_indices if has_sensor_feature(samples[idx])]
        eval_pool = [idx for idx in eval_pool if has_sensor_feature(samples[idx])]
    if not eval_indices:
        raise ValueError("No evaluation samples with retrieval candidates selected.")

    append_jsonl(
        args.progress_jsonl,
        {
            "event": "eval_start",
            "timestamp": time.time(),
            "run_id": args.run_id,
            "tasks": selected_tasks,
            "num_eval_samples": len(eval_indices),
            "sample_offset": args.sample_offset,
            "sample_stride": args.sample_stride,
            "candidate_count": args.candidate_count,
        },
    )

    model, processor = load_model_processor(args)
    sensor_cache = SensorFeatureCache() if "cross_modal_retrieval" in selected_tasks else None
    partial_by_task = {
        task_id: _load_partial_predictions(
            args.output_dir / task_id / "predictions.partial.jsonl",
            resume=args.resume,
        )
        for task_id in selected_tasks
    }

    for task_id in selected_tasks:
        spec = TASK_SPECS[task_id]
        partial_path = args.output_dir / task_id / "predictions.partial.jsonl"
        for local_pos, sample_idx in enumerate(eval_indices, start=1):
            sample = samples[sample_idx]
            pred_id = prediction_id(task_id, sample)
            if pred_id in partial_by_task[task_id]:
                # Already predicted in an earlier (resumed) run.
                continue
            started = time.time()
            candidate_indices = build_candidate_indices(samples, eval_pool, sample_idx, task_id, args.candidate_count)
            messages, true_letter, candidate_records = build_messages(
                samples,
                sample_idx,
                candidate_indices,
                task_id,
                spec,
                sensor_cache=sensor_cache,
            )
            raw = generate_messages(model, processor, messages, args)
            valid_letters = [record["letter"] for record in candidate_records]
            ranking = extract_ranking(raw, valid_letters)
            rank = ranking.index(true_letter) + 1 if true_letter in ranking else len(ranking) + 1
            row = {
                "prediction_id": pred_id,
                "id": sample.get("id"),
                "task_id": task_id,
                "task_label": spec["label"],
                "split": sample.get("split"),
                "episode_id": sample.get("episode_id"),
                "start_frame": row_start(sample),
                "end_frame": row_end(sample),
                "query_text": sensor_query_text(sample, sensor_cache)
                if task_id == "cross_modal_retrieval"
                else query_text(sample),
                "candidates": candidate_records,
                "true_letter": true_letter,
                "predicted_ranking": ranking,
                "reciprocal_rank": 1.0 / rank,
                "top1_correct": int(bool(ranking) and ranking[0] == true_letter),
                "raw_prediction": raw,
            }
            partial_by_task[task_id][pred_id] = row
            # Persist each row immediately so an interrupted run can resume.
            append_jsonl(partial_path, row)
            append_jsonl(
                args.progress_jsonl,
                {
                    "event": "sample_done",
                    "timestamp": time.time(),
                    "task_id": task_id,
                    "sample_index": local_pos,
                    "num_eval_samples": len(eval_indices),
                    "completed_samples_for_task": len(partial_by_task[task_id]),
                    "sample_id": sample.get("id"),
                    "seconds": round(time.time() - started, 3),
                },
            )

    task_metrics = {}
    for task_id in selected_tasks:
        rows = [partial_by_task[task_id][prediction_id(task_id, samples[idx])] for idx in eval_indices]
        task_metrics[task_id] = score_task(task_id, TASK_SPECS[task_id], rows, args.output_dir, args)

    summary = {
        "title": "Qwen3-Omni v6 Retrieval Task Probes",
        "status": "pass",
        "run_id": args.run_id,
        "model_id": args.model_id,
        "adapter_dir": str(args.adapter_dir),
        "dataset_jsonl": str(args.dataset_jsonl),
        "eval_split": args.eval_split,
        "candidate_count": args.candidate_count,
        "sample_offset": args.sample_offset,
        "sample_stride": args.sample_stride,
        "tasks": {
            task_id: {
                "task_number": metrics["task_number"],
                "task_label": metrics["task_label"],
                "metric_key": metrics["metric_key"],
                "primary_score": metrics["primary_score"],
                "num_samples": metrics["num_samples"],
                "metrics_json": str(args.output_dir / task_id / "metrics.json"),
            }
            for task_id, metrics in task_metrics.items()
        },
    }
    write_json(args.output_dir / "summary.json", summary)

    report_lines = [
        "# Qwen3-Omni v6 Retrieval Task Probes",
        "",
        f"- Run ID: `{args.run_id}`",
        f"- Dataset: `{args.dataset_jsonl}`",
        f"- Candidate count: `{args.candidate_count}`",
        f"- Shard: offset `{args.sample_offset}` / stride `{args.sample_stride}`",
        "",
        "| Task | Metric | Score | Samples |",
        "| --- | --- | ---: | ---: |",
    ]
    for task_id, metrics in task_metrics.items():
        report_lines.append(
            f"| {metrics['task_label']} | {metrics['metric_key']} | {metrics['primary_score']:.6f} | {metrics['num_samples']} |"
        )
    (args.output_dir / "RUN_REPORT.md").write_text("\n".join(report_lines) + "\n", encoding="utf-8")
    append_jsonl(args.progress_jsonl, {"event": "eval_complete", "timestamp": time.time(), "run_id": args.run_id})
    print(json.dumps(summary, indent=2, sort_keys=True))
    return 0


if __name__ == "__main__":
    raise SystemExit(main())