henribonamy commited on
Commit
5ddb674
·
verified ·
1 Parent(s): 04a61e1

Upload finetune_hf.py with huggingface_hub

Browse files
Files changed (1) hide show
  1. finetune_hf.py +9 -9
finetune_hf.py CHANGED
@@ -19,13 +19,13 @@ HF_DATA_REPO = "henribonamy/chess-puzzles-data"
19
  PRETRAINED_CHECKPOINT_PATH = "outputs/model_checkpoint.pt"
20
  FINETUNED_CHECKPOINT_PATH = "outputs/model_checkpoint_finetuned.pt"
21
  DATA_PATH = "data/encoded_fens.npy"
22
- HIGH_RATED_INDICES_PATH = "data/high_rated_indices.npy"
23
 
24
  BATCH_SIZE = 128
25
  LR = 1e-6
26
- EPOCHS = 1
27
- LOG_INTERVAL = 100
28
- CHECKPOINT_INTERVAL = 5000
29
 
30
  log_lines: list[str] = []
31
 
@@ -99,22 +99,22 @@ def push_checkpoint(local_path: str) -> None:
99
 
100
 
101
  def main() -> None:
102
- """Fine-tune pretrained model on high-rated puzzle positions."""
103
  start_log_server()
104
 
105
  device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
106
  log(f"Using device: {device}")
107
 
108
  ensure_file(DATA_PATH, HF_DATA_REPO, "encoded_fens.npy", "dataset")
109
- ensure_file(HIGH_RATED_INDICES_PATH, HF_DATA_REPO, "high_rated_indices.npy", "dataset")
110
  ensure_file(PRETRAINED_CHECKPOINT_PATH, HF_PRETRAINED_REPO, "model_checkpoint.pt", "model")
111
 
112
  log("Loading data...")
113
  all_fens = np.load(DATA_PATH)
114
- high_rated_indices = np.load(HIGH_RATED_INDICES_PATH)
115
- log(f"Total encoded FENs: {len(all_fens):,} | High-rated indices: {len(high_rated_indices):,}")
116
 
117
- dataset = SubsetDataset(all_fens, high_rated_indices)
118
  dataloader = DataLoader(dataset, batch_size=BATCH_SIZE, shuffle=True, num_workers=0, pin_memory=True)
119
  log(f"Dataset size: {len(dataset):,} | Batches per epoch: {len(dataloader):,}")
120
 
 
19
  PRETRAINED_CHECKPOINT_PATH = "outputs/model_checkpoint.pt"
20
  FINETUNED_CHECKPOINT_PATH = "outputs/model_checkpoint_finetuned.pt"
21
  DATA_PATH = "data/encoded_fens.npy"
22
+ INDICES_PATH = "data/counter_intuitive_indices.npy"
23
 
24
  BATCH_SIZE = 128
25
  LR = 1e-6
26
+ EPOCHS = 10
27
+ LOG_INTERVAL = 20
28
+ CHECKPOINT_INTERVAL = 300
29
 
30
  log_lines: list[str] = []
31
 
 
99
 
100
 
101
  def main() -> None:
102
+ """Fine-tune pretrained model on counter-intuitive puzzle positions."""
103
  start_log_server()
104
 
105
  device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
106
  log(f"Using device: {device}")
107
 
108
  ensure_file(DATA_PATH, HF_DATA_REPO, "encoded_fens.npy", "dataset")
109
+ ensure_file(INDICES_PATH, HF_DATA_REPO, "counter_intuitive_indices.npy", "dataset")
110
  ensure_file(PRETRAINED_CHECKPOINT_PATH, HF_PRETRAINED_REPO, "model_checkpoint.pt", "model")
111
 
112
  log("Loading data...")
113
  all_fens = np.load(DATA_PATH)
114
+ indices = np.load(INDICES_PATH)
115
+ log(f"Total encoded FENs: {len(all_fens):,} | Counter-intuitive indices: {len(indices):,}")
116
 
117
+ dataset = SubsetDataset(all_fens, indices)
118
  dataloader = DataLoader(dataset, batch_size=BATCH_SIZE, shuffle=True, num_workers=0, pin_memory=True)
119
  log(f"Dataset size: {len(dataset):,} | Batches per epoch: {len(dataloader):,}")
120