Robotics
LeRobot
Safetensors
xvla
xvla-4dof-search-tracking / fix_dataset_and_train.sh
kavinrajkrupsurge's picture
Upload fix_dataset_and_train.sh with huggingface_hub
efad9c3 verified
Raw
History Blame Contribute Delete
5.39 kB
#!/bin/bash
# fix_dataset_and_train.sh - repair the cached X-VLA 4-DoF dataset metadata,
# then fine-tune X-VLA on it with LeRobot.
#
# Usage: bash /workspace/fix_dataset_and_train.sh
#
# Abort on the first failing command so training never starts after a failed
# fix-up step.
set -e
printf '%s\n' \
  "=========================================" \
  "X-VLA Dataset Fix & Training Script" \
  "========================================="
# Step 1: Clear caches and output
printf '\n%s\n' "Step 1: Clearing caches..."
# Remove the Hugging Face datasets cache and any previous run directory so
# this run starts from a clean slate.
for stale_dir in ~/.cache/huggingface/datasets/ /workspace/outputs/xvla_4dof; do
  rm -rf -- "$stale_dir"
done
printf '%s\n' "✅ Caches cleared"
# Step 2: Ensure all dataset metadata is correct
echo ""
echo "Step 2: Fixing dataset metadata..."
# The fix-up below needs pandas plus a parquet engine. Run it inside the same
# conda env that Step 3 trains in, so the metadata is written by the library
# stack that will later read it (otherwise whatever `python3` is first on PATH
# would be used). Activating the env again in Step 3 is harmless.
source /root/anaconda3/etc/profile.d/conda.sh
conda activate lerobot
python3 << 'PYTHON_SCRIPT'
"""Patch the cached LeRobot dataset metadata so X-VLA training can load it.

Rewrites meta/tasks.parquet, meta/info.json, meta/stats.json and
meta/episodes/chunk-000/file-000.parquet, and adds a `task` column to every
data parquet file that lacks one. Any uncaught error exits non-zero.
"""
import json
import os
from itertools import groupby
from pathlib import Path

import pandas as pd

TASKS = [
    'search for a person in the room by scanning the room and stop when you find',
    'track the moving person and then stop when the person stops, start when the person starts moving',
]
# task_index -> task string, in the same order as TASKS.
TASK_MAP = dict(enumerate(TASKS))
# Episodes with an index below this are labelled TASKS[0], the rest TASKS[1].
# NOTE(review): hard-coded split (presumably 25 search + 25 tracking episodes) - confirm.
SEARCH_EPISODE_COUNT = 25

REPO_ID = "kavinrajkrupsurge/xvla-4dof-tracking-dataset"
LOCAL_DATASET_DIR = Path("/workspace/lerobot_dataset")


def lerobot_cache_root():
    """Return the LeRobot cache root, honouring the Hugging Face env vars.

    Resolves $HF_LEROBOT_HOME, else $HF_HOME/lerobot, else
    ${XDG_CACHE_HOME:-~/.cache}/huggingface/lerobot. When running as root with
    none of these set, the result is /root/.cache/huggingface/lerobot.
    """
    xdg_cache = Path(os.environ.get("XDG_CACHE_HOME", "~/.cache")).expanduser()
    hf_home = Path(os.environ.get("HF_HOME", xdg_cache / "huggingface")).expanduser()
    return Path(os.environ.get("HF_LEROBOT_HOME", hf_home / "lerobot")).expanduser()


def read_json(path):
    """Load and return the JSON document at `path` (UTF-8)."""
    with open(path, encoding="utf-8") as f:
        return json.load(f)


def write_json(obj, path):
    """Write `obj` to `path` as indented JSON."""
    with open(path, "w", encoding="utf-8") as f:
        json.dump(obj, f, indent=2)


def fix_tasks(cache_dir):
    """Rewrite tasks.parquet (local copy and cache) with task strings as the index."""
    tasks_df = pd.DataFrame({"task_index": list(range(len(TASKS)))}, index=TASKS)
    tasks_df.index.name = "task"
    for path in (LOCAL_DATASET_DIR / "meta/tasks.parquet", cache_dir / "meta/tasks.parquet"):
        path.parent.mkdir(parents=True, exist_ok=True)
        tasks_df.to_parquet(path)
    print("✅ Fixed tasks.parquet (task strings as index)")


def fix_info(cache_dir):
    """Declare `task` as a string feature in meta/info.json."""
    info_path = cache_dir / "meta/info.json"
    info = read_json(info_path)
    info["features"]["task"] = {"dtype": "string", "shape": [1], "names": None}
    write_json(info, info_path)
    print("✅ Fixed info.json")


def fix_stats(cache_dir):
    """Drop the observation.images.laptop entry from meta/stats.json, if present."""
    # NOTE(review): Step 3 renames observation.images.laptop -> observation.images.image
    # via --rename_map; confirm training does not need normalization stats for it.
    stats_path = cache_dir / "meta/stats.json"
    stats = read_json(stats_path)
    stats.pop("observation.images.laptop", None)
    write_json(stats, stats_path)
    print("✅ Fixed stats.json")


def add_task_columns(cache_dir):
    """Add a `task` string column, derived from `task_index`, where it is missing."""
    # Scan every data chunk, not just chunk-000, so larger datasets are covered.
    for pq_file in sorted((cache_dir / "data").glob("chunk-*/*.parquet")):
        df = pd.read_parquet(pq_file)
        if 'task' in df.columns:
            print(f" ✅ {pq_file.name} has task column")
            continue
        # KeyError here means a task_index outside TASKS - fail rather than guess.
        df['task'] = df['task_index'].apply(lambda x: TASK_MAP[int(x)])
        df.to_parquet(pq_file, index=False)
        print(f" Added task to {pq_file.name}")


def task_for_episode(episode_index):
    """Return the single-element task list for `episode_index`."""
    return [TASKS[0]] if episode_index < SEARCH_EPISODE_COUNT else [TASKS[1]]


def build_episodes(frames):
    """Group consecutive frames sharing an episode_index into episode rows.

    Each row holds episode_index, tasks, length and the half-open frame range
    [dataset_from_index, dataset_to_index). Episode indices come from the data
    itself, so a dataset that does not start at episode 0 gets no phantom
    zero-length episode.
    """
    episodes = []
    start = 0
    for episode_index, run in groupby(frames, key=lambda item: item['episode_index']):
        length = sum(1 for _ in run)
        episodes.append({
            'episode_index': episode_index,
            'tasks': task_for_episode(episode_index),
            'length': length,
            'dataset_from_index': start,
            'dataset_to_index': start + length,
        })
        start += length
    return episodes


def fix_episodes(cache_dir):
    """Rebuild meta/episodes/chunk-000/file-000.parquet from the local dataset.json."""
    # Assumes dataset.json is a list of frame dicts ordered by episode - TODO confirm.
    frames = read_json(LOCAL_DATASET_DIR / "dataset.json")
    episodes = build_episodes(frames)
    if not episodes:
        raise SystemExit("❌ dataset.json contains no frames; refusing to write empty episodes metadata")
    # NOTE(review): the file is overwritten with only these five columns; any other
    # per-episode columns it already had are dropped - confirm LeRobot does not need them.
    df = pd.DataFrame(episodes)
    output_dir = cache_dir / "meta/episodes/chunk-000"
    output_dir.mkdir(parents=True, exist_ok=True)
    df.to_parquet(output_dir / "file-000.parquet", index=False)
    print(f"✅ Fixed episodes metadata ({len(df)} episodes)")


def main():
    cache_dir = lerobot_cache_root() / REPO_ID
    fix_tasks(cache_dir)
    fix_info(cache_dir)
    fix_stats(cache_dir)
    add_task_columns(cache_dir)
    fix_episodes(cache_dir)
    print("\n✅ All dataset fixes complete!")


if __name__ == "__main__":
    main()
PYTHON_SCRIPT
# Step 3: Start Training
printf '\n'
banner_rule="========================================="
printf '%s\n' \
  "$banner_rule" \
  "Step 3: Starting X-VLA Training..." \
  "$banner_rule" \
  "Model: 879M parameters" \
  "Dataset: 50 episodes, 9799 frames" \
  "Steps: 3000" \
  "$banner_rule"
# Training is launched from /workspace/lerobot inside the `lerobot` conda env.
cd /workspace/lerobot
source /root/anaconda3/etc/profile.d/conda.sh
conda activate lerobot
# Fine-tune lerobot/xvla-base on the 4-DoF search/tracking dataset.
# Flags are kept in an array so each one is a single, individually quoted word.
train_args=(
  --dataset.repo_id="kavinrajkrupsurge/xvla-4dof-tracking-dataset"
  --dataset.use_imagenet_stats=false
  --output_dir="/workspace/outputs/xvla_4dof"
  --job_name="xvla_4dof_search_tracking"
  --policy.path="lerobot/xvla-base"
  --policy.repo_id="kavinrajkrupsurge/xvla-4dof-search-tracking"
  --policy.dtype=bfloat16
  --steps=3000
  --policy.device=cuda
  --policy.freeze_vision_encoder=false
  --policy.freeze_language_encoder=false
  --policy.train_policy_transformer=true
  --policy.train_soft_prompts=true
  --policy.action_mode="auto"
  --policy.max_action_dim=20
  --policy.use_proprio=true
  --policy.max_state_dim=4
  --policy.len_soft_prompts=32
  --policy.num_domains=30
  --policy.chunk_size=32
  --policy.n_action_steps=32
  --optimizer.type="xvla-adamw"
  --optimizer.lr=5e-4
  --optimizer.betas='[0.9,0.99]'
  --optimizer.weight_decay=0.01
  --optimizer.grad_clip_norm=10.0
  --optimizer.soft_prompt_lr_scale=1.0
  --batch_size=16
  --num_workers=4
  --eval.n_episodes=5
  --eval.batch_size=1
  --save_freq=1000
  --eval_freq=2000
  --log_freq=50
  --wandb.enable=false
  --seed=42
  --rename_map='{"observation.images.laptop": "observation.images.image"}'
)
python3 -m lerobot.scripts.lerobot_train "${train_args[@]}"
# Only reached when training exited successfully.
printf '\n%s\n%s\n%s\n' \
  "=========================================" \
  "✅ Training completed!" \
  "========================================="