Upload finetune_hf.py with huggingface_hub
Browse files
- finetune_hf.py +9 -9
finetune_hf.py
CHANGED
|
@@ -19,13 +19,13 @@ HF_DATA_REPO = "henribonamy/chess-puzzles-data"
|
|
| 19 |
PRETRAINED_CHECKPOINT_PATH = "outputs/model_checkpoint.pt"
|
| 20 |
FINETUNED_CHECKPOINT_PATH = "outputs/model_checkpoint_finetuned.pt"
|
| 21 |
DATA_PATH = "data/encoded_fens.npy"
|
| 22 |
-
|
| 23 |
|
| 24 |
BATCH_SIZE = 128
|
| 25 |
LR = 1e-6
|
| 26 |
-
EPOCHS =
|
| 27 |
-
LOG_INTERVAL =
|
| 28 |
-
CHECKPOINT_INTERVAL =
|
| 29 |
|
| 30 |
log_lines: list[str] = []
|
| 31 |
|
|
@@ -99,22 +99,22 @@ def push_checkpoint(local_path: str) -> None:
|
|
| 99 |
|
| 100 |
|
| 101 |
def main() -> None:
|
| 102 |
-
"""Fine-tune pretrained model on
|
| 103 |
start_log_server()
|
| 104 |
|
| 105 |
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
|
| 106 |
log(f"Using device: {device}")
|
| 107 |
|
| 108 |
ensure_file(DATA_PATH, HF_DATA_REPO, "encoded_fens.npy", "dataset")
|
| 109 |
-
ensure_file(
|
| 110 |
ensure_file(PRETRAINED_CHECKPOINT_PATH, HF_PRETRAINED_REPO, "model_checkpoint.pt", "model")
|
| 111 |
|
| 112 |
log("Loading data...")
|
| 113 |
all_fens = np.load(DATA_PATH)
|
| 114 |
-
|
| 115 |
-
log(f"Total encoded FENs: {len(all_fens):,} |
|
| 116 |
|
| 117 |
-
dataset = SubsetDataset(all_fens,
|
| 118 |
dataloader = DataLoader(dataset, batch_size=BATCH_SIZE, shuffle=True, num_workers=0, pin_memory=True)
|
| 119 |
log(f"Dataset size: {len(dataset):,} | Batches per epoch: {len(dataloader):,}")
|
| 120 |
|
|
|
|
| 19 |
PRETRAINED_CHECKPOINT_PATH = "outputs/model_checkpoint.pt"
|
| 20 |
FINETUNED_CHECKPOINT_PATH = "outputs/model_checkpoint_finetuned.pt"
|
| 21 |
DATA_PATH = "data/encoded_fens.npy"
|
| 22 |
+
INDICES_PATH = "data/counter_intuitive_indices.npy"
|
| 23 |
|
| 24 |
BATCH_SIZE = 128
|
| 25 |
LR = 1e-6
|
| 26 |
+
EPOCHS = 10
|
| 27 |
+
LOG_INTERVAL = 20
|
| 28 |
+
CHECKPOINT_INTERVAL = 300
|
| 29 |
|
| 30 |
log_lines: list[str] = []
|
| 31 |
|
|
|
|
| 99 |
|
| 100 |
|
| 101 |
def main() -> None:
|
| 102 |
+
"""Fine-tune pretrained model on counter-intuitive puzzle positions."""
|
| 103 |
start_log_server()
|
| 104 |
|
| 105 |
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
|
| 106 |
log(f"Using device: {device}")
|
| 107 |
|
| 108 |
ensure_file(DATA_PATH, HF_DATA_REPO, "encoded_fens.npy", "dataset")
|
| 109 |
+
ensure_file(INDICES_PATH, HF_DATA_REPO, "counter_intuitive_indices.npy", "dataset")
|
| 110 |
ensure_file(PRETRAINED_CHECKPOINT_PATH, HF_PRETRAINED_REPO, "model_checkpoint.pt", "model")
|
| 111 |
|
| 112 |
log("Loading data...")
|
| 113 |
all_fens = np.load(DATA_PATH)
|
| 114 |
+
indices = np.load(INDICES_PATH)
|
| 115 |
+
log(f"Total encoded FENs: {len(all_fens):,} | Counter-intuitive indices: {len(indices):,}")
|
| 116 |
|
| 117 |
+
dataset = SubsetDataset(all_fens, indices)
|
| 118 |
dataloader = DataLoader(dataset, batch_size=BATCH_SIZE, shuffle=True, num_workers=0, pin_memory=True)
|
| 119 |
log(f"Dataset size: {len(dataset):,} | Batches per epoch: {len(dataloader):,}")
|
| 120 |
|