""" One-shot Colab training script. Run this single file in a Colab cell: !git clone https://github.com/mehular0ra/dental-aligner-claw4s.git %cd dental-aligner-claw4s !python run_colab.py Does everything: installs deps, stubs openenv, starts server, trains GRPO. """ import subprocess import sys import os import time # ============================================================ # Step 1: Install dependencies # ============================================================ print("=" * 60) print("STEP 1: Installing dependencies") print("=" * 60) packages = [ "unsloth", "trl==0.16.1", "wandb", "fastapi", "uvicorn", "pydantic", "scipy", "numpy", "Pillow", "matplotlib", "requests", "trimesh", ] for pkg in packages: subprocess.run( [ sys.executable, "-m", "pip", "install", "-q", "--no-deps" if pkg == "trl==0.16.1" else "", pkg, ], capture_output=True, ) print("Dependencies installed.") # ============================================================ # Step 2: Stub openenv-core # ============================================================ print("\n" + "=" * 60) print("STEP 2: Stubbing openenv-core") print("=" * 60) import types try: import openenv print("openenv-core found.") except ImportError: openenv = types.ModuleType("openenv") openenv.core = types.ModuleType("openenv.core") openenv.core.env_server = types.ModuleType("openenv.core.env_server") class _Base: pass imod = types.ModuleType("openenv.core.env_server.interfaces") imod.Environment = _Base tmod = types.ModuleType("openenv.core.env_server.types") tmod.Action = _Base tmod.Observation = _Base tmod.State = _Base openenv.core.env_server.interfaces = imod openenv.core.env_server.types = tmod def _create_app(env, **kw): from fastapi import FastAPI return FastAPI() openenv.core.env_server.create_fastapi_app = _create_app for k, v in { "openenv": openenv, "openenv.core": openenv.core, "openenv.core.env_server": openenv.core.env_server, "openenv.core.env_server.interfaces": imod, "openenv.core.env_server.types": tmod, }.items(): sys.modules[k] = v print("openenv-core stubbed.") # ============================================================ # Step 3: Start server (in-process, background thread) # ============================================================ print("\n" + "=" * 60) print("STEP 3: Starting dental aligner server") print("=" * 60) import threading import uvicorn def _run_server(): uvicorn.run("server.app:app", host="0.0.0.0", port=7860, log_level="warning") server_thread = threading.Thread(target=_run_server, daemon=True) server_thread.start() print("Server starting in background thread...") time.sleep(5) # Health check import json import math import urllib.request SERVER = "http://localhost:7860" def post(endpoint, data): req = urllib.request.Request( f"{SERVER}{endpoint}", data=json.dumps(data).encode(), headers={"Content-Type": "application/json"}, ) return json.loads(urllib.request.urlopen(req, timeout=30).read().decode(), strict=False) for attempt in range(10): try: resp = json.loads(urllib.request.urlopen(f"{SERVER}/health", timeout=5).read()) print(f"Server ready: {resp}") break except Exception as e: if attempt < 9: print(f" Attempt {attempt + 1}/10: waiting...") time.sleep(3) else: print(f"FATAL: Server failed to start after 10 attempts.") sys.exit(1) # ============================================================ # Step 4: Quick benchmark # ============================================================ print("\n" + "=" * 60) print("STEP 4: Running benchmark") print("=" * 60) from benchmarks import run_benchmarks results = run_benchmarks(quick=True) # ============================================================ # Step 5: Generate training prompts # ============================================================ print("\n" + "=" * 60) print("STEP 5: Generating training prompts") print("=" * 60) N_EPISODES = 20 TOOTH_IDS = [ 11, 12, 13, 14, 15, 16, 17, 21, 22, 23, 24, 25, 26, 27, 31, 32, 33, 34, 35, 36, 37, 41, 42, 43, 44, 45, 46, 47, ] def make_prompt(obs, stage=0): progress = obs.get("per_tooth_progress", [0] * 28) mean_p = sum(progress) / max(len(progress), 1) min_p = min(progress) if progress else 0 return f"""You are an orthodontic treatment planning AI. Plan aligner stage {stage + 1}/24. STATE: Stage {stage}/24, Progress: {mean_p:.0%}, Worst: {min_p:.0%}, Violations: {obs.get("cumulative_violations", 0)} RULES: Max 0.25mm/2deg per tooth per stage. Move incisors BEFORE molars. Output JSON: {{"strategy": "...", "tooth_groups": [{{"teeth": [...], "fraction": 0.0-1.0}}]}}""" difficulties = [ {"n_perturbed_teeth": 6, "translation_magnitude": 2.0}, {"n_perturbed_teeth": 12, "translation_magnitude": 3.5}, {"n_perturbed_teeth": 18, "translation_magnitude": 5.0}, {"n_perturbed_teeth": 10, "translation_magnitude": 3.0, "jitter_probability": 0.2}, ] prompts = [] for i in range(N_EPISODES): seed = i * 7 + 42 diff = difficulties[i % len(difficulties)] obs = post( "/reset_stepwise", { "task_id": "task_easy", "seed": seed, "source": "synthetic", "episode_id": f"prompt_{i}", "difficulty_params": diff, }, ) prompts.append({"prompt": make_prompt(obs), "seed": seed, "difficulty_params": diff}) print(f"Generated {len(prompts)} training prompts") # ============================================================ # Step 6: Define reward functions # ============================================================ print("\n" + "=" * 60) print("STEP 6: Setting up reward functions") print("=" * 60) def run_episode(completion_text, seed, difficulty_params): tooth_fractions = {} try: js = completion_text[completion_text.find("{") : completion_text.rfind("}") + 1] plan = json.loads(js) for g in plan.get("tooth_groups", []): for tid in g.get("teeth", []): tooth_fractions[tid] = float(g.get("fraction", 0.5)) except: pass eid = f"grpo_{seed}_{hash(completion_text) % 99999}" obs = post( "/reset_stepwise", { "task_id": "task_easy", "seed": seed, "source": "synthetic", "episode_id": eid, "difficulty_params": difficulty_params, }, ) init, tgt = obs["current_config"], obs["target_config"] final = None for s in range(1, 25): a_def = s / 25.0 poses = [] for i in range(28): f = tooth_fractions.get(TOOTH_IDS[i], a_def) f = max(0.0, min(1.0, f)) q = [init[i][j] * (1 - f) + tgt[i][j] * f for j in range(4)] qn = math.sqrt(sum(x * x for x in q)) q = [x / max(qn, 1e-10) for x in q] t = [init[i][4 + j] * (1 - f) + tgt[i][4 + j] * f for j in range(3)] poses.append(q + t) final = post("/step_stepwise", {"episode_id": eid, "poses": poses}) return final def reward_terminal(completions, **kwargs): rewards = [] for i, comp in enumerate(completions): try: text = comp[0]["content"] if isinstance(comp, list) else str(comp) idx = i % len(prompts) obs = run_episode(text, prompts[idx]["seed"], prompts[idx]["difficulty_params"]) rewards.append(float(obs.get("terminal_reward", 0.0))) except: rewards.append(0.0) return rewards def reward_occlusion(completions, **kwargs): rewards = [] for i, comp in enumerate(completions): try: text = comp[0]["content"] if isinstance(comp, list) else str(comp) idx = i % len(prompts) obs = run_episode(text, prompts[idx]["seed"], prompts[idx]["difficulty_params"]) rewards.append(float(obs.get("reward_breakdown", {}).get("occlusion_composite", 0.0))) except: rewards.append(0.0) return rewards # Quick test test = json.dumps({"strategy": "test", "tooth_groups": [{"teeth": [11, 21], "fraction": 0.5}]}) test_r = reward_terminal([test]) print(f"Test reward: {test_r}") # ============================================================ # Step 7: wandb (optional) # ============================================================ print("\n" + "=" * 60) print("STEP 7: wandb setup") print("=" * 60) USE_WANDB = False try: import wandb key = os.environ.get("WANDB_API_KEY", "") if key: wandb.login(key=key) USE_WANDB = True print("wandb logged in from env var.") else: os.environ["WANDB_DISABLED"] = "true" print("No WANDB_API_KEY set. Training without wandb.") except ImportError: os.environ["WANDB_DISABLED"] = "true" print("wandb not installed. Training without it.") # ============================================================ # Step 8: Load model with Unsloth # ============================================================ print("\n" + "=" * 60) print("STEP 8: Loading model with Unsloth") print("=" * 60) import torch print(f"CUDA: {torch.cuda.is_available()}") if torch.cuda.is_available(): print(f"GPU: {torch.cuda.get_device_name(0)}") from unsloth import FastLanguageModel MODEL_NAME = "unsloth/Qwen2.5-3B-Instruct-bnb-4bit" model, tokenizer = FastLanguageModel.from_pretrained( model_name=MODEL_NAME, max_seq_length=1024, load_in_4bit=True, ) model = FastLanguageModel.get_peft_model( model, r=16, target_modules=["q_proj", "k_proj", "v_proj", "o_proj", "gate_proj", "up_proj", "down_proj"], lora_alpha=16, lora_dropout=0, use_gradient_checkpointing="unsloth", ) print(f"Model loaded: {MODEL_NAME}") # ============================================================ # Step 9: GRPO Training # ============================================================ print("\n" + "=" * 60) print("STEP 9: GRPO TRAINING") print("=" * 60) from trl import GRPOConfig, GRPOTrainer train_dataset = [{"prompt": [{"role": "user", "content": p["prompt"]}]} for p in prompts] config = GRPOConfig( output_dir="./dental_grpo_output", learning_rate=5e-6, per_device_train_batch_size=2, num_generations=4, max_prompt_length=512, max_completion_length=512, num_train_epochs=1, save_steps=5, logging_steps=1, report_to="wandb" if USE_WANDB else "none", run_name="dental-grpo-unsloth" if USE_WANDB else None, bf16=True, gradient_accumulation_steps=2, ) trainer = GRPOTrainer( model=model, reward_funcs=[reward_terminal, reward_occlusion], args=config, train_dataset=train_dataset, processing_class=tokenizer, ) print( f"Training {len(train_dataset)} episodes, batch={config.per_device_train_batch_size}, gens={config.num_generations}" ) print("Starting...") trainer.train() # ============================================================ # Step 10: Save & Evaluate # ============================================================ print("\n" + "=" * 60) print("STEP 10: Saving model") print("=" * 60) model.save_pretrained("./dental_grpo_output/lora") tokenizer.save_pretrained("./dental_grpo_output/lora") print("LoRA adapters saved to ./dental_grpo_output/lora") # Evaluate FastLanguageModel.for_inference(model) test_obs = post( "/reset_stepwise", {"task_id": "task_easy", "seed": 999, "source": "synthetic", "episode_id": "eval_final"}, ) test_prompt = make_prompt(test_obs) inputs = tokenizer([test_prompt], return_tensors="pt").to("cuda") outputs = model.generate(**inputs, max_new_tokens=256, temperature=0.7, do_sample=True) completion = tokenizer.decode(outputs[0][inputs["input_ids"].shape[1] :], skip_special_tokens=True) print(f"\nModel output:\n{completion[:400]}") result = run_episode(completion, 999, {"n_perturbed_teeth": 10, "translation_magnitude": 3.0}) print(f"\nTrained model reward: {result.get('terminal_reward', 0):.4f}") print(f"SLERP baseline: ~0.87") print(f"Occlusion: {result.get('reward_breakdown', {}).get('occlusion_composite', 0):.4f}") print("\n" + "=" * 60) print("DONE! Training complete.") if USE_WANDB: print(f"wandb: {wandb.run.get_url()}") print("=" * 60)