Instructions to use kavinrajkrupsurge/xvla-4dof-search-tracking with libraries, inference providers, notebooks, and local apps. Follow these links to get started.
- Libraries
- LeRobot
How to use kavinrajkrupsurge/xvla-4dof-search-tracking with LeRobot:
- Notebooks
- Google Colab
- Kaggle
# Complete Dataset Fix and Training Script
# Run: bash /workspace/fix_dataset_and_train.sh

# Abort the whole script as soon as any command exits with a non-zero status.
set -e

# Opening banner. The heredoc delimiter is quoted so the text is printed
# verbatim, with no shell expansion.
cat <<'BANNER'
=========================================
X-VLA Dataset Fix & Training Script
=========================================
BANNER

# Step 1: Clear caches and output
# Deletes the Hugging Face datasets cache directory and the output directory
# of any previous training run, so the following steps start from scratch.
printf '\n%s\n' "Step 1: Clearing caches..."
rm -rf ~/.cache/huggingface/datasets/
rm -rf /workspace/outputs/xvla_4dof
echo "✅ Caches cleared"
# Step 2: Ensure all dataset metadata is correct
# The embedded Python program rewrites the metadata files of the locally cached
# dataset in place. The heredoc delimiter is quoted, so the shell performs no
# expansion inside the Python source.
# NOTE(review): this runs with whichever python3 is on PATH *before* the
# `lerobot` conda env is activated in Step 3; that interpreter must provide
# pandas and a parquet engine -- confirm.
echo ""
echo "Step 2: Fixing dataset metadata..."
python3 << 'PYTHON_SCRIPT'
"""Patch the cached dataset metadata: tasks, info, stats, frame data, episodes."""
import json
from itertools import groupby
from operator import itemgetter
from pathlib import Path

import pandas as pd

TASKS = [
    'search for a person in the room by scanning the room and stop when you find',
    'track the moving person and then stop when the person stops, start when the person starts moving',
]
# Episodes whose index is below this value are labelled with TASKS[0];
# every other episode is labelled with TASKS[1].
FIRST_TASK_EPISODE_COUNT = 25

# NOTE(review): hard-coded to root's home, whereas Step 1 of the surrounding
# shell script uses `~` -- the two only agree when the script runs as root.
cache_dir = Path("/root/.cache/huggingface/lerobot/kavinrajkrupsurge/xvla-4dof-tracking-dataset")


def load_json(path):
    """Return the parsed content of the JSON file at *path*."""
    with open(path, encoding="utf-8") as f:
        return json.load(f)


def dump_json(path, obj):
    """Overwrite the JSON file at *path* with *obj*, indented by two spaces."""
    with open(path, 'w', encoding="utf-8") as f:
        json.dump(obj, f, indent=2)


# 1. Fix tasks.parquet - task strings must be the INDEX
tasks_df = pd.DataFrame({"task_index": [0, 1]}, index=TASKS)
tasks_df.index.name = "task"
for path in [
    Path("/workspace/lerobot_dataset/meta/tasks.parquet"),
    cache_dir / "meta/tasks.parquet",
]:
    path.parent.mkdir(parents=True, exist_ok=True)
    tasks_df.to_parquet(path)
print("✅ Fixed tasks.parquet (task strings as index)")

# 2. Fix info.json - declare the per-frame "task" string feature
info_path = cache_dir / "meta/info.json"
info = load_json(info_path)
info["features"]["task"] = {"dtype": "string", "shape": [1], "names": None}
dump_json(info_path, info)
print("✅ Fixed info.json")

# 3. Fix stats.json - drop the entry for the laptop camera, if there is one
stats_path = cache_dir / "meta/stats.json"
stats = load_json(stats_path)
stats.pop("observation.images.laptop", None)
dump_json(stats_path, stats)
print("✅ Fixed stats.json")

# 4. Ensure parquet files have task column
chunk_dir = cache_dir / "data/chunk-000"
TASK_MAP = {0: TASKS[0], 1: TASKS[1]}
pq_files = sorted(chunk_dir.glob("*.parquet"))
if not pq_files:
    # A missing or empty directory would otherwise be skipped without a trace.
    print(f"⚠️  No parquet files found in {chunk_dir}")
for pq_file in pq_files:
    df = pd.read_parquet(pq_file)
    if 'task' in df.columns:
        print(f" ✅ {pq_file.name} has task column")
        continue
    # An unknown task_index raises KeyError instead of writing a wrong label.
    df['task'] = df['task_index'].apply(lambda x: TASK_MAP[int(x)])
    df.to_parquet(pq_file, index=False)
    print(f" Added task to {pq_file.name}")

# 5. Fix episodes metadata
dataset_path = Path("/workspace/lerobot_dataset/dataset.json")
dataset = load_json(dataset_path)
if not dataset:
    # Without frames there is nothing to segment; stop with a non-zero exit
    # status instead of writing a fake zero-length episode.
    raise SystemExit(f"❌ {dataset_path} contains no frames")

# Every run of consecutive frames sharing an episode_index becomes one row.
# Grouping the runs directly means no zero-length placeholder row is emitted
# when the first frame does not belong to episode 0.
episodes_data = []
start_index = 0
for episode_index, frames in groupby(dataset, key=itemgetter('episode_index')):
    length = sum(1 for _ in frames)
    task = TASKS[0] if episode_index < FIRST_TASK_EPISODE_COUNT else TASKS[1]
    episodes_data.append({
        'episode_index': episode_index,
        'tasks': [task],
        'length': length,
        'dataset_from_index': start_index,
        'dataset_to_index': start_index + length,  # exclusive end
    })
    start_index += length

df = pd.DataFrame(episodes_data)
output_dir = cache_dir / "meta/episodes/chunk-000"
output_dir.mkdir(parents=True, exist_ok=True)
df.to_parquet(output_dir / "file-000.parquet", index=False)
print(f"✅ Fixed episodes metadata ({len(df)} episodes)")
print("\n✅ All dataset fixes complete!")
PYTHON_SCRIPT
# Step 3: Start Training
# Banner text is printed verbatim through a quoted heredoc; its first line is
# intentionally blank.
cat <<'BANNER'

=========================================
Step 3: Starting X-VLA Training...
=========================================
Model: 879M parameters
Dataset: 50 episodes, 9799 frames
Steps: 3000
=========================================
BANNER

# Switch to the lerobot checkout and activate the `lerobot` conda environment.
cd /workspace/lerobot
source /root/anaconda3/etc/profile.d/conda.sh
conda activate lerobot

# Command-line flags for lerobot_train, one array element per flag so that
# each value reaches the trainer as a single, unsplit argument.
train_args=(
    # Dataset, output location and base policy
    --dataset.repo_id="kavinrajkrupsurge/xvla-4dof-tracking-dataset"
    --dataset.use_imagenet_stats=false
    --output_dir="/workspace/outputs/xvla_4dof"
    --job_name="xvla_4dof_search_tracking"
    --policy.path="lerobot/xvla-base"
    --policy.repo_id="kavinrajkrupsurge/xvla-4dof-search-tracking"
    --policy.dtype=bfloat16
    --steps=3000
    --policy.device=cuda
    # Policy configuration
    --policy.freeze_vision_encoder=false
    --policy.freeze_language_encoder=false
    --policy.train_policy_transformer=true
    --policy.train_soft_prompts=true
    --policy.action_mode="auto"
    --policy.max_action_dim=20
    --policy.use_proprio=true
    --policy.max_state_dim=4
    --policy.len_soft_prompts=32
    --policy.num_domains=30
    --policy.chunk_size=32
    --policy.n_action_steps=32
    # Optimizer
    --optimizer.type="xvla-adamw"
    --optimizer.lr=5e-4
    --optimizer.betas='[0.9,0.99]'
    --optimizer.weight_decay=0.01
    --optimizer.grad_clip_norm=10.0
    --optimizer.soft_prompt_lr_scale=1.0
    # Data loading, evaluation, checkpointing and logging
    --batch_size=16
    --num_workers=4
    --eval.n_episodes=5
    --eval.batch_size=1
    --save_freq=1000
    --eval_freq=2000
    --log_freq=50
    --wandb.enable=false
    --seed=42
    # Pass the laptop camera feature under the key "observation.images.image"
    --rename_map='{"observation.images.laptop": "observation.images.image"}'
)
python3 -m lerobot.scripts.lerobot_train "${train_args[@]}"

# Reached only when training exits successfully, because `set -e` is in effect.
cat <<'BANNER'

=========================================
✅ Training completed!
=========================================
BANNER