# chess-puzzles-training / finetune_hf.py
# henribonamy's picture
# Upload finetune_hf.py with huggingface_hub
# 5ddb674 verified
# raw
# history blame
# 6.17 kB
import sys
import os
import time
import threading
from http.server import BaseHTTPRequestHandler, HTTPServer
import numpy as np
import torch
from torch.utils.data import DataLoader, Dataset
from huggingface_hub import HfApi, hf_hub_download
sys.path.insert(0, os.path.join(os.path.dirname(os.path.abspath(__file__)), "src", "pretraining"))
from model import AutoRegressiveTransformer
from tokenizer import FENTokenizer
# Hugging Face Hub repositories. The model repo supplies the pretrained
# checkpoint and receives the fine-tuned one; the dataset repo supplies the
# encoded positions and the index file.
HF_PRETRAINED_REPO = "henribonamy/chess-puzzles-pretrained"
HF_DATA_REPO = "henribonamy/chess-puzzles-data"
# Local paths of the downloaded inputs and of the fine-tuned output checkpoint.
PRETRAINED_CHECKPOINT_PATH = "outputs/model_checkpoint.pt"
FINETUNED_CHECKPOINT_PATH = "outputs/model_checkpoint_finetuned.pt"
DATA_PATH = "data/encoded_fens.npy"
INDICES_PATH = "data/counter_intuitive_indices.npy"
# Training hyperparameters.
BATCH_SIZE = 128
LR = 1e-6  # learning rate handed to AdamW in main()
EPOCHS = 10
LOG_INTERVAL = 20  # optimizer steps between progress log lines
CHECKPOINT_INTERVAL = 300  # optimizer steps between checkpoint save + upload
# In-memory log buffer: log() appends to it and LogHandler serves its contents.
log_lines: list[str] = []
class SubsetDataset(Dataset):
    """Dataset exposing only the rows of an encoded-FEN array picked by an index array."""

    def __init__(self, data: np.ndarray, indices: np.ndarray) -> None:
        """Select ``data[indices]`` once and keep the result for later lookups."""
        selected = data[indices]
        self.data = selected

    def __len__(self) -> int:
        """Report how many samples were selected."""
        return len(self.data)

    def __getitem__(self, idx: int) -> torch.Tensor:
        """Look up sample ``idx`` and hand it back as a ``torch.long`` tensor."""
        sample = self.data[idx]
        return torch.tensor(sample, dtype=torch.long)
def log(msg: str) -> None:
    """Prefix ``msg`` with the current HH:MM:SS time, buffer it, and echo it to stdout."""
    stamp = time.strftime("%H:%M:%S")
    entry = f"[{stamp}] {msg}"
    # Buffer first so the HTTP log endpoint sees the line, then print unbuffered.
    log_lines.append(entry)
    print(entry, flush=True)
class LogHandler(BaseHTTPRequestHandler):
    """Request handler that answers GET requests with the buffered log as plain text."""

    def do_GET(self) -> None:
        """Send a 200 response whose body is every entry of ``log_lines``, newline-joined."""
        payload = "\n".join(log_lines).encode()
        self.send_response(200)
        self.send_header("Content-Type", "text/plain")
        self.end_headers()
        self.wfile.write(payload)

    def log_message(self, *args) -> None:
        """Replace the base-class access logging with a no-op."""
        return None
def start_log_server() -> None:
    """Serve the log buffer over HTTP on 0.0.0.0:7860 from a daemon thread."""
    httpd = HTTPServer(("0.0.0.0", 7860), LogHandler)
    # Daemon thread: it does not keep the process alive once main() returns.
    worker = threading.Thread(target=httpd.serve_forever, daemon=True)
    worker.start()
    log("Log server started on port 7860")
def ensure_file(local_path: str, repo_id: str, filename: str, repo_type: str) -> None:
    """Fetch ``filename`` from a Hub repo unless ``local_path`` already exists.

    Args:
        local_path: Path checked for existence; its parent directory is created
            if needed and used as the download directory.
        repo_id: Hub repository to download from.
        filename: Name of the file inside the repository.
        repo_type: Repository type forwarded to ``hf_hub_download``.
    """
    if not os.path.exists(local_path):
        target_dir = os.path.dirname(local_path)
        # A bare filename has an empty dirname; create "." in that case.
        os.makedirs(target_dir or ".", exist_ok=True)
        log(f"Downloading {filename} from {repo_id}...")
        # NOTE(review): the file is requested into the parent directory of
        # local_path, so it presumably only ends up at local_path when the
        # basename of local_path equals filename -- confirm for new callers.
        hf_hub_download(
            repo_id=repo_id,
            filename=filename,
            repo_type=repo_type,
            local_dir=target_dir,
        )
        log(f"{filename} ready.")
def push_checkpoint(local_path: str) -> None:
    """Upload the file at ``local_path`` to the pretrained model repo.

    The repo is created if it does not exist yet, and the file is stored there
    as ``model_checkpoint_finetuned.pt``.
    """
    hub = HfApi()
    repo, kind = HF_PRETRAINED_REPO, "model"
    hub.create_repo(repo, repo_type=kind, exist_ok=True)
    hub.upload_file(
        path_or_fileobj=local_path,
        path_in_repo="model_checkpoint_finetuned.pt",
        repo_id=repo,
        repo_type=kind,
    )
    log(f"Pushed fine-tuned checkpoint to {HF_PRETRAINED_REPO}")
def _save_and_push_checkpoint(model: torch.nn.Module) -> None:
    """Write the model's state dict to FINETUNED_CHECKPOINT_PATH and upload it."""
    os.makedirs(os.path.dirname(FINETUNED_CHECKPOINT_PATH) or ".", exist_ok=True)
    torch.save(model.state_dict(), FINETUNED_CHECKPOINT_PATH)
    push_checkpoint(FINETUNED_CHECKPOINT_PATH)


def main() -> None:
    """Fine-tune the pretrained model on the counter-intuitive puzzle positions.

    Workflow:
      1. Start the HTTP log server and pick the device (CUDA when available).
      2. Make sure the encoded positions, the index file and the pretrained
         checkpoint exist locally, downloading them from the Hub otherwise.
      3. Build a shuffled DataLoader over the indexed subset and load the
         pretrained weights.
      4. Train for EPOCHS epochs with AdamW, gradient-norm clipping at 1.0 and
         (on CUDA only) autocast plus gradient scaling. Inputs are each sequence
         without its last token; targets are the same sequence shifted by one.
      5. Log the mean loss every LOG_INTERVAL optimizer steps, save and upload
         a checkpoint every CHECKPOINT_INTERVAL steps, and once more at the end.
    """
    start_log_server()

    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    use_cuda = device.type == "cuda"
    log(f"Using device: {device}")

    ensure_file(DATA_PATH, HF_DATA_REPO, "encoded_fens.npy", "dataset")
    ensure_file(INDICES_PATH, HF_DATA_REPO, "counter_intuitive_indices.npy", "dataset")
    ensure_file(PRETRAINED_CHECKPOINT_PATH, HF_PRETRAINED_REPO, "model_checkpoint.pt", "model")

    log("Loading data...")
    all_fens = np.load(DATA_PATH)
    indices = np.load(INDICES_PATH)
    log(f"Total encoded FENs: {len(all_fens):,} | Counter-intuitive indices: {len(indices):,}")

    dataset = SubsetDataset(all_fens, indices)
    # Pinned host memory only helps host-to-GPU copies, so request it on CUDA only.
    dataloader = DataLoader(
        dataset,
        batch_size=BATCH_SIZE,
        shuffle=True,
        num_workers=0,
        pin_memory=use_cuda,
    )
    num_batches = len(dataloader)
    log(f"Dataset size: {len(dataset):,} | Batches per epoch: {num_batches:,}")

    model = AutoRegressiveTransformer().to(device)
    # NOTE(security): torch.load unpickles the checkpoint, which is downloaded
    # from the Hub. Consider weights_only=True if the installed torch supports it.
    ckpt = torch.load(PRETRAINED_CHECKPOINT_PATH, map_location=device)
    model.load_state_dict(ckpt)
    log("Loaded pretrained checkpoint.")

    optimizer = torch.optim.AdamW(model.parameters(), lr=LR, weight_decay=1e-6)
    # NOTE(review): recent PyTorch deprecates torch.cuda.amp.* in favour of
    # torch.amp.*; the original API is kept because the torch version is unknown.
    scaler = torch.cuda.amp.GradScaler(enabled=use_cuda)

    total_steps = 0
    # The logging window is keyed on the global step counter, so its state must
    # persist across epochs. Resetting it at each epoch start would make the
    # first log line of an epoch divide a partial loss sum by LOG_INTERVAL.
    running_loss = 0.0
    t0 = time.time()

    for epoch in range(EPOCHS):
        model.train()
        for batch_idx, batch in enumerate(dataloader):
            batch = batch.to(device)
            inputs = batch[:, :-1]
            targets = batch[:, 1:]

            with torch.cuda.amp.autocast(enabled=use_cuda):
                _, loss = model(inputs, targets)

            scaler.scale(loss).backward()
            # Unscale before clipping so the threshold applies to true gradients.
            scaler.unscale_(optimizer)
            torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)
            scaler.step(optimizer)
            scaler.update()
            optimizer.zero_grad()

            running_loss += loss.item()
            total_steps += 1

            if total_steps % LOG_INTERVAL == 0:
                elapsed = time.time() - t0
                avg_loss = running_loss / LOG_INTERVAL
                # batch_idx is 0-based; this batch has already been processed.
                pct = 100.0 * (batch_idx + 1) / num_batches
                log(f"Epoch {epoch+1} | Step {total_steps} | {pct:.1f}% | Loss: {avg_loss:.4f} | {elapsed:.1f}s/{LOG_INTERVAL}steps")
                running_loss = 0.0
                t0 = time.time()

            if total_steps % CHECKPOINT_INTERVAL == 0:
                _save_and_push_checkpoint(model)
                log(f"Checkpoint saved and pushed at step {total_steps}")

        log(f"Epoch {epoch+1} complete.")

    # Final checkpoint once all epochs have run.
    _save_and_push_checkpoint(model)
    log("Fine-tuning complete.")


if __name__ == "__main__":
    main()