| import sys |
| import os |
| import time |
| import threading |
| from http.server import BaseHTTPRequestHandler, HTTPServer |
|
|
| import numpy as np |
| import torch |
| from torch.utils.data import DataLoader, Dataset |
| from huggingface_hub import HfApi, hf_hub_download |
|
|
| sys.path.insert(0, os.path.join(os.path.dirname(os.path.abspath(__file__)), "src", "pretraining")) |
|
|
| from model import AutoRegressiveTransformer |
| from tokenizer import FENTokenizer |
|
|
# Hugging Face Hub repositories: the model repo is read for the pretrained
# checkpoint and written with the fine-tuned one; the data repo is read as a
# "dataset"-type repo.
HF_PRETRAINED_REPO = "henribonamy/chess-puzzles-pretrained"
HF_DATA_REPO = "henribonamy/chess-puzzles-data"
# Local file locations for the checkpoints and the numpy arrays.
PRETRAINED_CHECKPOINT_PATH = "outputs/model_checkpoint.pt"
FINETUNED_CHECKPOINT_PATH = "outputs/model_checkpoint_finetuned.pt"
DATA_PATH = "data/encoded_fens.npy"
INDICES_PATH = "data/counter_intuitive_indices.npy"


# Training hyperparameters.
BATCH_SIZE = 128
LR = 1e-6
EPOCHS = 10
# Both intervals are counted in optimizer steps across the whole run.
LOG_INTERVAL = 20
CHECKPOINT_INTERVAL = 300


# In-memory log buffer: appended to by log() and served by LogHandler.
log_lines: list[str] = []
|
|
|
|
class SubsetDataset(Dataset):
    """Dataset over the rows of an encoded-FEN array selected by an index array."""

    def __init__(self, data: np.ndarray, indices: np.ndarray) -> None:
        """Keep only the rows of ``data`` addressed by ``indices``."""
        selected = data[indices]
        self.data = selected

    def __len__(self) -> int:
        """Return how many rows were selected."""
        n_samples = len(self.data)
        return n_samples

    def __getitem__(self, idx: int) -> torch.Tensor:
        """Return the row at ``idx`` converted to a ``torch.long`` tensor."""
        row = self.data[idx]
        return torch.tensor(row, dtype=torch.long)
|
|
|
|
def log(msg: str) -> None:
    """Record a timestamped message in the log buffer and echo it to stdout."""
    stamp = time.strftime("%H:%M:%S")
    entry = "[{}] {}".format(stamp, msg)
    log_lines.append(entry)
    # Flush immediately so the line shows up without waiting on stdout buffering.
    print(entry, flush=True)
|
|
|
|
class LogHandler(BaseHTTPRequestHandler):
    """HTTP handler that returns the in-memory log buffer as plain text."""

    def do_GET(self) -> None:
        """Serve accumulated log lines.

        Every GET, whatever the path, receives the full contents of
        ``log_lines`` joined by newlines.
        """
        body = "\n".join(log_lines).encode("utf-8")
        self.send_response(200)
        # The body is UTF-8 encoded, so declare the charset explicitly rather
        # than leaving clients to guess the encoding.
        self.send_header("Content-Type", "text/plain; charset=utf-8")
        # An explicit length lets clients know when the response is complete.
        self.send_header("Content-Length", str(len(body)))
        self.end_headers()
        self.wfile.write(body)

    def log_message(self, *args) -> None:
        """Suppress default HTTP access logs."""
|
|
|
|
def start_log_server() -> None:
    """Serve the log buffer over HTTP on port 7860 from a daemon thread."""
    bind_address = ("0.0.0.0", 7860)
    httpd = HTTPServer(bind_address, LogHandler)
    # Daemon thread: it does not keep the process alive once main() returns.
    worker = threading.Thread(target=httpd.serve_forever, daemon=True)
    worker.start()
    log("Log server started on port 7860")
|
|
|
|
def ensure_file(local_path: str, repo_id: str, filename: str, repo_type: str) -> None:
    """Fetch ``filename`` from the given HF Hub repo unless ``local_path`` exists.

    The file is downloaded into the directory component of ``local_path``.
    """
    if not os.path.exists(local_path):
        parent_dir = os.path.dirname(local_path)
        # A bare filename has an empty dirname; create "." in that case.
        os.makedirs(parent_dir or ".", exist_ok=True)
        log(f"Downloading {filename} from {repo_id}...")
        hf_hub_download(
            repo_id=repo_id,
            filename=filename,
            repo_type=repo_type,
            local_dir=parent_dir,
        )
        log(f"{filename} ready.")
|
|
|
|
def push_checkpoint(local_path: str) -> None:
    """Upload the checkpoint at ``local_path`` to the HF pretrained model repo.

    The repo is created first if it does not exist yet; the file is stored
    there as ``model_checkpoint_finetuned.pt``.
    """
    hub = HfApi()
    hub.create_repo(HF_PRETRAINED_REPO, repo_type="model", exist_ok=True)
    upload_args = {
        "path_or_fileobj": local_path,
        "path_in_repo": "model_checkpoint_finetuned.pt",
        "repo_id": HF_PRETRAINED_REPO,
        "repo_type": "model",
    }
    hub.upload_file(**upload_args)
    log(f"Pushed fine-tuned checkpoint to {HF_PRETRAINED_REPO}")
|
|
|
|
def _save_and_push_checkpoint(model: torch.nn.Module) -> None:
    """Write the model's state dict to FINETUNED_CHECKPOINT_PATH and upload it."""
    os.makedirs(os.path.dirname(FINETUNED_CHECKPOINT_PATH) or ".", exist_ok=True)
    torch.save(model.state_dict(), FINETUNED_CHECKPOINT_PATH)
    push_checkpoint(FINETUNED_CHECKPOINT_PATH)


def main() -> None:
    """Fine-tune pretrained model on counter-intuitive puzzle positions.

    Starts the HTTP log server, downloads any missing data/checkpoint files,
    loads the pretrained weights, then trains for ``EPOCHS`` epochs on the
    rows of the encoded-FEN array selected by the index file. Progress is
    logged every ``LOG_INTERVAL`` optimizer steps and a checkpoint is saved
    and pushed every ``CHECKPOINT_INTERVAL`` steps, plus once at the end.
    """
    start_log_server()

    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    use_cuda = device.type == "cuda"
    log(f"Using device: {device}")

    ensure_file(DATA_PATH, HF_DATA_REPO, "encoded_fens.npy", "dataset")
    ensure_file(INDICES_PATH, HF_DATA_REPO, "counter_intuitive_indices.npy", "dataset")
    ensure_file(PRETRAINED_CHECKPOINT_PATH, HF_PRETRAINED_REPO, "model_checkpoint.pt", "model")

    log("Loading data...")
    all_fens = np.load(DATA_PATH)
    indices = np.load(INDICES_PATH)
    log(f"Total encoded FENs: {len(all_fens):,} | Counter-intuitive indices: {len(indices):,}")

    dataset = SubsetDataset(all_fens, indices)
    # Pinned host memory only helps host-to-GPU copies, so request it only
    # when training on CUDA.
    dataloader = DataLoader(
        dataset,
        batch_size=BATCH_SIZE,
        shuffle=True,
        num_workers=0,
        pin_memory=use_cuda,
    )
    log(f"Dataset size: {len(dataset):,} | Batches per epoch: {len(dataloader):,}")

    model = AutoRegressiveTransformer().to(device)
    # NOTE(security): torch.load unpickles the file. The checkpoint is fetched
    # from the Hub, so only use repos you trust (consider weights_only=True if
    # the installed torch version supports it).
    ckpt = torch.load(PRETRAINED_CHECKPOINT_PATH, map_location=device)
    model.load_state_dict(ckpt)
    log("Loaded pretrained checkpoint.")

    optimizer = torch.optim.AdamW(model.parameters(), lr=LR, weight_decay=1e-6)
    scaler = torch.cuda.amp.GradScaler(enabled=use_cuda)

    total_steps = 0
    for epoch in range(EPOCHS):
        model.train()
        # The loss/time window restarts with each epoch, while the logging
        # trigger uses the global step count. Track how many steps are in the
        # window so the average is taken over the right number of steps.
        running_loss = 0.0
        window_steps = 0
        t0 = time.time()

        for batch_idx, batch in enumerate(dataloader):
            batch = batch.to(device)
            # Next-token setup: predict position i+1 from positions up to i.
            inputs = batch[:, :-1]
            targets = batch[:, 1:]

            with torch.cuda.amp.autocast(enabled=use_cuda):
                _, loss = model(inputs, targets)

            scaler.scale(loss).backward()
            # Unscale before clipping so the threshold applies to the real
            # gradient magnitudes rather than the loss-scaled ones.
            scaler.unscale_(optimizer)
            torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)
            scaler.step(optimizer)
            scaler.update()
            optimizer.zero_grad()

            running_loss += loss.item()
            window_steps += 1
            total_steps += 1

            if total_steps % LOG_INTERVAL == 0:
                elapsed = time.time() - t0
                avg_loss = running_loss / window_steps
                # batch_idx is 0-based; this step has already completed.
                pct = 100.0 * (batch_idx + 1) / len(dataloader)
                log(f"Epoch {epoch+1} | Step {total_steps} | {pct:.1f}% | Loss: {avg_loss:.4f} | {elapsed:.1f}s/{window_steps}steps")
                running_loss = 0.0
                window_steps = 0
                t0 = time.time()

            if total_steps % CHECKPOINT_INTERVAL == 0:
                _save_and_push_checkpoint(model)
                log(f"Checkpoint saved and pushed at step {total_steps}")

        log(f"Epoch {epoch+1} complete.")

    _save_and_push_checkpoint(model)
    log("Fine-tuning complete.")
|
|
|
|
# Run the fine-tuning job only when executed as a script, not on import.
if __name__ == "__main__":
    main()
|
|