TimurHromek committed on
Commit
2ed1fcb
·
verified ·
1 Parent(s): eeca056

Uploaded HROM-M1 model.

Browse files
HROM-M1.py ADDED
@@ -0,0 +1,1392 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ # Set parallelism env var *before* importing tokenizers
3
+ os.environ["TOKENIZERS_PARALLELISM"] = "false"
4
+
5
+ import torch
6
+ import torch.nn as nn
7
+ import torch.nn.functional as F # Added for softmax in MoE
8
+ from torch.utils.data import Dataset, DataLoader
9
+ # Import necessary dataset functions, including concatenate_datasets if needed later
10
+ from datasets import load_dataset, disable_caching, concatenate_datasets
11
+ from tokenizers import Tokenizer, models, trainers, pre_tokenizers, processors, decoders
12
+ import math
13
+ import re
14
+ from datetime import datetime
15
+ from contextlib import nullcontext
16
+ from collections import defaultdict
17
+ import logging
18
+ import random # For shuffling combined data
19
+
20
+ # Disable caching for datasets if needed, helps ensure reprocessing
21
+ # disable_caching()
22
+
23
# Logging: timestamped, INFO-level records. force=True makes this call
# replace any handlers that were configured before this point.
logging.basicConfig(
    format='%(asctime)s - %(levelname)s - %(message)s',
    level=logging.INFO,
    force=True,
)

# Central hyperparameter / path table read by the model, tokenizer and
# dataset code below.
CONFIG = {
    # --- Transformer shape ---
    "dim": 768,
    "n_layers": 8,
    "n_heads": 8,
    "ff_dim": 2048,  # hidden width of each individual expert
    "dropout": 0.1,
    "max_seq_len": 512,
    # --- Training loop ---
    "batch_size": 16,
    "checkpoint_interval": 2000,
    "debug_interval": 400,
    # --- Data / tokenizer ---
    "datasets": ["daily_dialog", "empathetic_dialogues", "blended_skill_talk", "AlekseyKorshuk/persona-chat", "papahawk/conversational-01"],
    "tokenizer_name": "hrom_moe_tokenizer.json",  # MoE-specific tokenizer file
    "checkpoint_dir": "checkpoints_moe",  # MoE-specific checkpoint directory
    "vocab_size": 32000,
    "tokenizer_train_samples_per_dataset": 50000,  # lower this for quicker experiments
    # --- Optimisation ---
    "learning_rate": 2e-5,
    "warmup_steps": 1000,
    "max_turns": 8,  # cap for multi-turn datasets; papahawk entries are 2 turns
    "max_checkpoints": 5,
    "num_epochs": 30,
    "grad_accum_steps": 8,

    # --- Mixture-of-Experts ---
    "num_experts": 8,  # experts per MoE layer
    "top_k_experts": 2,  # experts each token is routed to
    "moe_load_balancing_coeff": 0.01,  # weight of the load-balancing loss
}

# A token cannot be routed to more experts than exist: clamp top_k if needed.
if CONFIG["num_experts"] < CONFIG["top_k_experts"]:
    logging.warning(f"top_k_experts ({CONFIG['top_k_experts']}) > num_experts ({CONFIG['num_experts']}). Setting top_k_experts to num_experts.")
    CONFIG["top_k_experts"] = CONFIG["num_experts"]
63
+
64
+
65
+ # --- Model Definition (HROM, HROMBlock_MoE, HROMAttention, SwiGLU, RoPE, Expert, MoELayer) ---
66
+
67
class RotaryEmbedding(nn.Module):
    """Generate rotary position-embedding angles.

    Args:
        dim: Size of the feature dimension the angles are meant for. One
            inverse frequency is created for every second index in
            ``range(0, dim, 2)``.
    """

    def __init__(self, dim):
        super().__init__()
        # inv_freq[i] = 1 / 10000 ** (2i / dim)
        inv_freq = 1.0 / (10000 ** (torch.arange(0, dim, 2).float() / dim))
        self.register_buffer("inv_freq", inv_freq)

    def forward(self, seq_len):
        """Return the angle table for positions ``0 .. seq_len - 1``.

        Args:
            seq_len: Number of positions to generate.

        Returns:
            Tensor of shape ``(seq_len, 2 * len(inv_freq))`` on the same
            device and with the same dtype as ``inv_freq``. The outer product
            ``position * inv_freq`` is repeated twice along the last axis.
            ``seq_len == 0`` yields an empty ``(0, 2 * len(inv_freq))``
            tensor through the same code path, so no special case is needed.
        """
        positions = torch.arange(seq_len, device=self.inv_freq.device).type_as(self.inv_freq)
        # Outer product via broadcasting: (seq_len, 1) * (1, n_freq).
        freqs = positions[:, None] * self.inv_freq[None, :]
        return torch.cat((freqs, freqs), dim=-1)
83
+
84
def rotate_half(x):
    """Swap the two halves of the last dimension and negate the moved-up half.

    The last dimension is split into two chunks ``(a, b)``; the result is
    ``(-b, a)`` concatenated along that same dimension.
    """
    first_half, second_half = torch.chunk(x, 2, dim=-1)
    return torch.cat([second_half.neg(), first_half], dim=-1)
87
+
88
def apply_rotary_pos_emb(pos, t):
    """Rotate ``t`` by the rotary angles in ``pos``.

    Args:
        pos: Angle tensor indexed as ``(position, feature)``. It is moved to
            ``t``'s device and dtype and given two leading singleton
            dimensions so that it broadcasts against ``t``.
        t: Tensor whose dimension 2 is the sequence axis and whose last
            dimension is the feature axis (``HROMAttention`` passes
            ``(batch, heads, seq, head_dim)``).

    Returns:
        A tensor shaped like ``t``. Positions covered by ``pos`` are rotated
        as ``t * cos(pos) + rotate_half(t) * sin(pos)``. If ``pos`` has fewer
        positions than ``t``, the uncovered tail of ``t`` is passed through
        unrotated and a warning is logged. Surplus positions in ``pos`` are
        ignored.

    Raises:
        ValueError: If the feature size of ``pos`` differs from that of ``t``.
    """
    pos = pos.to(t.device, dtype=t.dtype)
    pos = pos.unsqueeze(0).unsqueeze(1)

    # Validate before any slicing so that both the normal and the
    # "pos shorter than t" paths reject a mismatched feature size.
    if pos.shape[-1] != t.shape[-1]:
        logging.error(f"Mismatched dimensions for RoPE: pos ({pos.shape[-1]}) vs t ({t.shape[-1]})")
        raise ValueError("Rotary embedding dimension must match head dimension.")

    tensor_seq_len = t.shape[2]
    pos_seq_len = pos.shape[2]
    if pos_seq_len < tensor_seq_len:
        logging.warning(f"RoPE Warning: pos sequence length ({pos_seq_len}) is shorter than tensor sequence length ({tensor_seq_len}). Using truncated tensor length for RoPE.")

    # Only the positions present in both tensors can be rotated.
    covered = min(pos_seq_len, tensor_seq_len)
    pos = pos[:, :, :covered, :]
    head = t[:, :, :covered, :]
    rotated = (head * pos.cos()) + (rotate_half(head) * pos.sin())

    if covered == tensor_seq_len:
        return rotated
    # Re-attach the positions that had no angle available.
    return torch.cat([rotated, t[:, :, covered:, :]], dim=2)
114
+
115
+
116
class SwiGLU(nn.Module):
    """Gated activation over the two halves of the last dimension.

    The input is split in two along its last dimension; the first half is
    multiplied element-wise by the GELU of the second half, so the output's
    last dimension is half the input's. Note that the gate uses GELU rather
    than SiLU, despite the class name.
    """

    def forward(self, x):
        value, gate = torch.chunk(x, 2, dim=-1)
        return value * F.gelu(gate)
120
+
121
class HROMAttention(nn.Module):
    """Multi-head self-attention with rotary position embeddings.

    Model width, head count and dropout rate are read from the module-level
    ``CONFIG``. Queries, keys and values come from one fused linear layer.

    Raises:
        ValueError: If ``CONFIG["dim"]`` is not a multiple of
            ``CONFIG["n_heads"]``.
    """

    def __init__(self):
        super().__init__()
        self.dim = CONFIG["dim"]
        self.n_heads = CONFIG["n_heads"]
        self.head_dim = self.dim // self.n_heads
        if self.dim % self.n_heads:
            raise ValueError("dim must be divisible by n_heads")
        # Sub-modules are created in the same order as before so that
        # parameter registration (and any seeded initialisation) is unchanged.
        self.qkv = nn.Linear(self.dim, 3 * self.dim)
        self.proj = nn.Linear(self.dim, self.dim)
        self.rotary = RotaryEmbedding(self.head_dim)
        self.dropout = nn.Dropout(CONFIG["dropout"])

    def forward(self, x, mask=None):
        """Attend over ``x``.

        Args:
            x: Input of shape ``(batch, seq, dim)``.
            mask: Optional additive mask added to the attention scores. A 2-D
                mask gets singleton dims inserted at positions 1 and 2; a 3-D
                mask gets one inserted at position 1; any other rank is used
                as given.

        Returns:
            Tensor of shape ``(batch, seq, dim)``.
        """
        batch, seq_len, _ = x.shape

        fused = self.qkv(x).reshape(batch, seq_len, 3, self.n_heads, self.head_dim)
        # Each of q/k/v: (batch, seq, heads, head_dim) -> (batch, heads, seq, head_dim)
        q, k, v = (part.transpose(1, 2) for part in fused.unbind(2))

        angles = self.rotary(seq_len)
        q = apply_rotary_pos_emb(angles, q)
        k = apply_rotary_pos_emb(angles, k)

        scale = 1.0 / math.sqrt(self.head_dim)
        scores = torch.matmul(q, k.transpose(-2, -1)) * scale

        if mask is not None:
            if mask.dim() == 2:
                mask = mask[:, None, None, :]
            elif mask.dim() == 3:
                mask = mask[:, None]
            scores = scores + mask

        # Softmax is evaluated in float32 and cast back to the input dtype.
        weights = torch.softmax(scores.float(), dim=-1).to(dtype=x.dtype)
        weights = self.dropout(weights)

        context = torch.matmul(weights, v)
        context = context.transpose(1, 2).reshape(batch, seq_len, self.dim)
        return self.proj(context)
157
+
158
+ # --- MoE Components ---
159
class Expert(nn.Module):
    """One feed-forward expert used inside the MoE layer.

    Args:
        dim: Input and output feature size.
        ff_dim: Hidden size after the gated activation. ``fc1`` projects to
            ``2 * ff_dim`` because the activation halves the last dimension.
    """

    def __init__(self, dim, ff_dim):
        super().__init__()
        self.fc1 = nn.Linear(dim, 2 * ff_dim)  # produces value and gate halves
        self.activation = SwiGLU()
        self.fc2 = nn.Linear(ff_dim, dim)  # maps the gated hidden state back to dim

    def forward(self, x):
        """Apply ``fc1`` -> gated activation -> ``fc2`` to ``x``."""
        return self.fc2(self.activation(self.fc1(x)))
171
+
172
class MoELayer(nn.Module):
    """Mixture of Experts layer with top-k gating.

    Args:
        dim: Feature size of the tokens.
        ff_dim: Hidden size passed to every ``Expert``.
        num_experts: Number of experts in the layer.
        top_k: Number of experts each token is routed to.
        load_balancing_coeff: Multiplier applied to the auxiliary
            load-balancing loss.
    """

    def __init__(self, dim, ff_dim, num_experts, top_k, load_balancing_coeff):
        super().__init__()
        self.dim = dim
        self.num_experts = num_experts
        self.top_k = top_k
        self.load_balancing_coeff = load_balancing_coeff

        self.experts = nn.ModuleList([Expert(dim, ff_dim) for _ in range(num_experts)])
        self.gate = nn.Linear(dim, num_experts)

    def forward(self, x):
        """Route every token through its top-k experts.

        Args:
            x: Tensor of shape ``(batch, seq, dim)``.

        Returns:
            A ``(output, load_balancing_loss)`` pair. ``output`` has the shape
            of ``x``. ``load_balancing_loss`` is a scalar tensor equal to
            ``load_balancing_coeff * num_experts * sum_e(f_e * p_e)``, where
            ``f_e`` is the mean of the top-k selection mask for expert ``e``
            and ``p_e`` is the mean router probability for expert ``e``. For
            an input with zero tokens the output is all zeros and the loss is
            ``0`` (not NaN).
        """
        batch_size, seq_len, dim = x.shape
        x_reshaped = x.reshape(-1, dim)  # (num_tokens, dim)

        # With no tokens the per-expert means below would be NaN and poison
        # the training loss; there is nothing to route, so return early.
        if x_reshaped.shape[0] == 0:
            return torch.zeros_like(x), x.new_zeros(())

        # 1. Gating mechanism
        gate_logits = self.gate(x_reshaped)  # (num_tokens, num_experts)
        gate_probs = F.softmax(gate_logits, dim=-1)  # (num_tokens, num_experts)

        # 2. Select top-k experts per token
        top_k_gate_probs, top_k_indices = torch.topk(gate_probs, self.top_k, dim=-1)  # (num_tokens, top_k)

        # Renormalise so each token's selected weights sum to ~1
        # (the epsilon guards against division by zero).
        top_k_weights_norm = top_k_gate_probs / (top_k_gate_probs.sum(dim=-1, keepdim=True) + 1e-6)

        # 3. Dispatch tokens to experts and combine their weighted outputs
        final_output = torch.zeros_like(x_reshaped)  # (num_tokens, dim)

        for i in range(self.num_experts):
            # Rows = tokens routed to expert i, cols = slot within that token's top-k
            token_indices_for_expert_i, position_in_top_k = torch.where(top_k_indices == i)

            if token_indices_for_expert_i.numel() > 0:
                tokens_for_this_expert = x_reshaped[token_indices_for_expert_i]
                weights_for_this_expert = top_k_weights_norm[token_indices_for_expert_i, position_in_top_k]

                expert_output = self.experts[i](tokens_for_this_expert)
                weighted_expert_output = expert_output * weights_for_this_expert.unsqueeze(-1)

                # A token appears once per selected expert, so accumulate.
                final_output.index_add_(0, token_indices_for_expert_i, weighted_expert_output.to(final_output.dtype))

        # 4. Load-balancing loss
        chosen_expert_mask = torch.zeros_like(gate_probs)  # (num_tokens, num_experts)
        chosen_expert_mask.scatter_(1, top_k_indices, 1)  # 1 where the expert was selected

        fraction_tokens_per_expert = chosen_expert_mask.mean(dim=0)  # (num_experts,)
        mean_router_probs_per_expert = gate_probs.mean(dim=0)  # (num_experts,)

        load_balancing_loss = self.load_balancing_coeff * self.num_experts * \
            torch.sum(fraction_tokens_per_expert * mean_router_probs_per_expert)

        final_output = final_output.reshape(batch_size, seq_len, dim)
        return final_output, load_balancing_loss
229
+
230
+
231
class HROMBlock(nn.Module):
    """Pre-norm transformer block: self-attention followed by an MoE layer.

    All sizes are taken from the module-level ``CONFIG``. Each sub-layer is
    applied to a layer-normalised copy of the stream, passed through dropout
    and added back to the residual.
    """

    def __init__(self):
        super().__init__()
        width = CONFIG["dim"]
        self.attn = HROMAttention()
        self.moe_layer = MoELayer(
            dim=width,
            ff_dim=CONFIG["ff_dim"],
            num_experts=CONFIG["num_experts"],
            top_k=CONFIG["top_k_experts"],
            load_balancing_coeff=CONFIG["moe_load_balancing_coeff"],
        )
        self.norm1 = nn.LayerNorm(width)
        self.norm2 = nn.LayerNorm(width)
        self.dropout = nn.Dropout(CONFIG["dropout"])

    def forward(self, x, mask=None):
        """Run the block.

        Args:
            x: Hidden states.
            mask: Optional attention mask forwarded to ``HROMAttention``.

        Returns:
            ``(hidden_states, moe_aux_loss)`` where ``moe_aux_loss`` is the
            load-balancing loss returned by this block's MoE layer.
        """
        # Attention sub-layer
        x = x + self.dropout(self.attn(self.norm1(x), mask))

        # Mixture-of-experts sub-layer
        moe_out, moe_aux_loss = self.moe_layer(self.norm2(x))
        x = x + self.dropout(moe_out)
        return x, moe_aux_loss
257
+
258
+
259
class HROM(nn.Module):
    """Decoder-only language model built from ``HROMBlock`` layers.

    Vocabulary size, width, depth and dropout come from the module-level
    ``CONFIG``. After construction every sub-module is passed through
    ``_init_weights``.
    """

    def __init__(self):
        super().__init__()
        width = CONFIG["dim"]
        vocab = CONFIG["vocab_size"]
        self.embed = nn.Embedding(vocab, width)
        self.blocks = nn.ModuleList([HROMBlock() for _ in range(CONFIG["n_layers"])])
        self.norm = nn.LayerNorm(width)
        self.head = nn.Linear(width, vocab)
        self.dropout = nn.Dropout(CONFIG["dropout"])
        self.apply(self._init_weights)

    def _init_weights(self, module):
        """Initialise one sub-module according to its type.

        Linear: Xavier-uniform weight, zero bias. Embedding: normal(0, 0.02).
        LayerNorm: zero bias, unit weight. Other module types are untouched.
        """
        if isinstance(module, nn.Linear):
            torch.nn.init.xavier_uniform_(module.weight)
            if module.bias is not None:
                torch.nn.init.zeros_(module.bias)
            return
        if isinstance(module, nn.Embedding):
            torch.nn.init.normal_(module.weight, mean=0.0, std=0.02)
            return
        if isinstance(module, nn.LayerNorm):
            torch.nn.init.zeros_(module.bias)
            torch.nn.init.ones_(module.weight)

    def forward(self, input_ids, attention_mask=None):
        """Compute next-token logits.

        Args:
            input_ids: Integer tensor of shape ``(batch, seq)``.
            attention_mask: Optional tensor where 1 marks a real token and 0 a
                padded one; padded positions receive a large negative bias
                (float32 minimum) on the key axis.

        Returns:
            ``(logits, avg_moe_aux_loss)``. ``logits`` is the output of the
            vocabulary head; ``avg_moe_aux_loss`` is the sum of the per-block
            MoE losses divided by ``CONFIG["n_layers"]`` (``0.0`` when there
            are no layers).
        """
        _, seq_len = input_ids.shape
        hidden = self.dropout(self.embed(input_ids))

        # Causal mask: -inf strictly above the diagonal, 0 elsewhere.
        neg_inf = torch.ones(seq_len, seq_len, device=input_ids.device) * float('-inf')
        mask = torch.triu(neg_inf, diagonal=1).unsqueeze(0).unsqueeze(1)

        if attention_mask is not None:
            pad_bias = (1.0 - attention_mask.to(torch.float32)) * torch.finfo(torch.float32).min
            mask = mask + pad_bias.unsqueeze(1).unsqueeze(2)
        mask = mask.to(dtype=hidden.dtype)

        aux_total = 0.0
        for block in self.blocks:
            hidden, block_aux = block(hidden, mask)
            aux_total = aux_total + block_aux

        logits = self.head(self.norm(hidden))

        n_layers = CONFIG["n_layers"]
        avg_moe_aux_loss = aux_total / n_layers if n_layers > 0 else 0.0
        return logits, avg_moe_aux_loss
305
+
306
+
307
+ # --- Tokenizer Training ---
308
class TokenizerTrainer:
    """Train, save and reload the byte-level BPE tokenizer.

    The tokenizer file lives at ``tokenizer/<CONFIG["tokenizer_name"]>``.
    Training text is assembled from the chat datasets named in the argument to
    ``train``; every turn is prefixed with ``<user>`` or ``<assistant>`` and
    turns are joined with `` </s> ``.
    """

    def __init__(self):
        self.tokenizer = Tokenizer(models.BPE())
        self.tokenizer.pre_tokenizer = pre_tokenizers.ByteLevel(add_prefix_space=False)
        self.tokenizer.decoder = decoders.ByteLevel()
        self.special_tokens = ["<pad>", "<s>", "</s>", "<unk>", "<user>", "<assistant>"]
        self.tokenizer_path = os.path.join("tokenizer", CONFIG["tokenizer_name"])
        self.tokenizer_dir = os.path.dirname(self.tokenizer_path)

    def _clean_text(self, text):
        """Normalise one utterance.

        Converts ``text`` to ``str``, replaces ``_comma_`` with ``,``, removes
        every character that is not a word character, whitespace or one of
        ``. , ! ? ' - : ; < > "`` (``<``/``>`` are kept for special tokens),
        then collapses whitespace runs to a single space and strips the ends.
        """
        as_string = str(text)
        with_commas = re.sub(r'_comma_', ',', as_string)
        allowed_only = re.sub(r'[^\w\s.,!?\'\-:;<>"]', '', with_commas)
        return re.sub(r'\s+', ' ', allowed_only).strip()

    def _join_alternating(self, utterances):
        """Format utterances with alternating roles, or return None if all are empty.

        The role depends on the utterance's index in ``utterances`` (even ->
        ``<user>``, odd -> ``<assistant>``); utterances that clean to an empty
        string are dropped without shifting the roles of later ones.
        """
        turns = []
        for position, utterance in enumerate(utterances):
            cleaned = self._clean_text(utterance)
            if not cleaned:
                continue
            speaker = "<assistant>" if position % 2 else "<user>"
            turns.append(f"{speaker} {cleaned}")
        return " </s> ".join(turns) if turns else None

    def _collect_daily_dialog(self, text_samples, samples_per_dataset):
        """Append formatted daily_dialog dialogues to ``text_samples``."""
        logging.info(f"Loading daily_dialog for tokenizer training (max {samples_per_dataset} dialogues)...")
        try:
            dd_dataset = load_dataset("daily_dialog", split=f"train[:{samples_per_dataset}]", trust_remote_code=True)
            logging.info("Processing daily_dialog...")
            for entry in dd_dataset:
                sample = self._join_alternating(entry['dialog'][:CONFIG["max_turns"]])
                if sample is not None:
                    text_samples.append(sample)
        except Exception as e:
            logging.error(f"Failed to load or process daily_dialog for tokenizer: {e}")

    def _collect_empathetic_dialogues(self, text_samples, samples_per_dataset):
        """Append formatted empathetic_dialogues conversations to ``text_samples``.

        Rows are grouped by ``conv_id`` and ordered by ``utterance_idx``. Three
        times ``samples_per_dataset`` rows are loaded because several rows
        make up one conversation; at most ``samples_per_dataset``
        conversations are kept.
        """
        logging.info(f"Loading empathetic_dialogues for tokenizer training (max {samples_per_dataset} dialogues)...")
        try:
            ed_dataset = load_dataset("empathetic_dialogues", split=f"train[:{samples_per_dataset * 3}]", trust_remote_code=True)
            logging.info("Processing empathetic_dialogues...")
            rows_by_conversation = defaultdict(list)
            for row in ed_dataset:
                rows_by_conversation[row['conv_id']].append(row)

            kept = 0
            for rows in rows_by_conversation.values():
                if kept >= samples_per_dataset:
                    break
                ordered = sorted(rows, key=lambda row: row['utterance_idx'])

                turns = []
                previous_role = None
                # The first row's context, when non-empty, opens the dialogue as a user turn.
                if ordered[0]['context']:
                    context = self._clean_text(ordered[0]['context'])
                    if context:
                        turns.append(f"<user> {context}")
                        previous_role = '<user>'

                for row in ordered:
                    utterance = self._clean_text(row['utterance'])
                    if not utterance:
                        continue
                    previous_role = '<assistant>' if previous_role == '<user>' else '<user>'
                    turns.append(f"{previous_role} {utterance}")

                turns = turns[:CONFIG["max_turns"]]
                if turns:
                    text_samples.append(" </s> ".join(turns))
                    kept += 1
        except Exception as e:
            logging.error(f"Failed to load or process empathetic_dialogues for tokenizer: {e}")

    def _collect_blended_skill_talk(self, text_samples, samples_per_dataset):
        """Append formatted blended_skill_talk dialogues to ``text_samples``.

        The turn list is ``previous_utterance`` followed by
        ``free_turker_utterance`` and ``guided_turker_utterance`` when those
        are non-empty; roles are assigned by simple alternation over that
        combined list.
        """
        logging.info(f"Loading blended_skill_talk for tokenizer training (max {samples_per_dataset} dialogues)...")
        try:
            bst_dataset = load_dataset("blended_skill_talk", split=f"train[:{samples_per_dataset}]", trust_remote_code=True)
            logging.info("Processing blended_skill_talk...")
            for entry in bst_dataset:
                raw_turns = list(entry['previous_utterance'])  # copy before extending
                for extra_key in ('free_turker_utterance', 'guided_turker_utterance'):
                    if entry.get(extra_key):
                        raw_turns.append(entry[extra_key])

                sample = self._join_alternating(raw_turns[:CONFIG["max_turns"]])
                if sample is not None:
                    text_samples.append(sample)
        except Exception as e:
            logging.error(f"Failed to load or process blended_skill_talk for tokenizer: {e}")

    def _collect_persona_chat(self, text_samples, samples_per_dataset):
        """Append formatted persona-chat dialogues to ``text_samples``.

        Uses the ``history`` of the last element of each entry's
        ``utterances``; entries without a non-empty ``utterances`` field are
        skipped with a warning.
        """
        pc_dataset_name = "AlekseyKorshuk/persona-chat"
        logging.info(f"Loading {pc_dataset_name} for tokenizer training (max {samples_per_dataset} dialogues)...")
        try:
            pc_dataset = load_dataset(pc_dataset_name, split=f"train[:{samples_per_dataset}]", trust_remote_code=True)
            logging.info(f"Processing {pc_dataset_name}...")
            for entry in pc_dataset:
                if not ('utterances' in entry and entry['utterances']):
                    logging.warning(f"Skipping {pc_dataset_name} entry due to unexpected structure: {entry}")
                    continue
                history = entry['utterances'][-1]['history']
                sample = self._join_alternating(history[:CONFIG["max_turns"]])
                if sample is not None:
                    text_samples.append(sample)
        except Exception as e:
            logging.error(f"Failed to load or process {pc_dataset_name} for tokenizer: {e}")

    def _collect_papahawk(self, text_samples, samples_per_dataset):
        """Append formatted instruction/response pairs to ``text_samples``.

        An entry with an instruction and a response becomes a two-turn sample;
        an instruction without a response is kept as a lone user turn; entries
        without an instruction are dropped.
        """
        ph_dataset_name = "papahawk/conversational-01"
        logging.info(f"Loading {ph_dataset_name} for tokenizer training (max {samples_per_dataset} entries)...")
        try:
            ph_dataset = load_dataset(ph_dataset_name, split=f"train[:{samples_per_dataset}]", trust_remote_code=True)
            logging.info(f"Processing {ph_dataset_name} for tokenizer...")
            for entry in ph_dataset:
                instruction = self._clean_text(entry.get('instruction', ''))
                response = self._clean_text(entry.get('response', ''))
                if not instruction:
                    continue
                user_turn = f"<user> {instruction}"
                if response:
                    text_samples.append(" </s> ".join([user_turn, f"<assistant> {response}"]))
                else:
                    text_samples.append(user_turn)
        except Exception as e:
            logging.error(f"Failed to load or process {ph_dataset_name} for tokenizer: {e}")

    def train(self, dataset_names):
        """Train the BPE tokenizer on the requested datasets and save it.

        Args:
            dataset_names: Container of dataset identifiers; only the five
                known names are acted on, always in the same fixed order.

        Raises:
            ValueError: If no text sample could be collected. Failures while
                loading or processing a single dataset are logged and do not
                abort training.
        """
        logging.info("Starting tokenizer training...")
        text_samples = []
        samples_per_dataset = CONFIG['tokenizer_train_samples_per_dataset']

        collectors = (
            ("daily_dialog", self._collect_daily_dialog),
            ("empathetic_dialogues", self._collect_empathetic_dialogues),
            ("blended_skill_talk", self._collect_blended_skill_talk),
            ("AlekseyKorshuk/persona-chat", self._collect_persona_chat),
            ("papahawk/conversational-01", self._collect_papahawk),
        )
        for dataset_name, collect in collectors:
            if dataset_name in dataset_names:
                collect(text_samples, samples_per_dataset)

        logging.info(f"Total text samples for tokenizer training: {len(text_samples)}")
        if not text_samples:
            raise ValueError("No text samples collected for tokenizer training. Check dataset loading and paths.")

        os.makedirs(self.tokenizer_dir, exist_ok=True)
        logging.info(f"Training BPE tokenizer with vocab size {CONFIG['vocab_size']}...")
        trainer = trainers.BpeTrainer(
            vocab_size=CONFIG["vocab_size"],
            special_tokens=self.special_tokens,
            min_frequency=2,
            show_progress=True
        )
        self.tokenizer.train_from_iterator(iter(text_samples), trainer=trainer, length=len(text_samples))

        eos_token_id = self.tokenizer.token_to_id("</s>")
        if eos_token_id is None:
            logging.warning("</s> token not found! Using <pad> as fallback for post-processor.")
            eos_token_id = self.tokenizer.token_to_id("<pad>") or 0  # never leave it as None

        # Append </s> after every encoded sequence (and after each half of a pair).
        self.tokenizer.post_processor = processors.TemplateProcessing(
            single="$A </s>",
            pair="$A </s> $B </s>",
            special_tokens=[("</s>", eos_token_id)],
        )
        logging.info(f"Saving tokenizer to {self.tokenizer_path}")
        self.tokenizer.save(self.tokenizer_path)
        logging.info("Tokenizer training complete.")

    def get_tokenizer(self):
        """Load the saved tokenizer and verify its special tokens.

        Raises:
            FileNotFoundError: If the tokenizer file does not exist.
            ValueError: If any required special token has no id in the loaded
                tokenizer.
        """
        if not os.path.exists(self.tokenizer_path):
            raise FileNotFoundError(f"Tokenizer file not found at {self.tokenizer_path}. Train tokenizer first.")
        tokenizer = Tokenizer.from_file(self.tokenizer_path)

        # A special token without an id cannot be used downstream, so fail fast.
        required_tokens = ("<pad>", "<s>", "</s>", "<unk>", "<user>", "<assistant>")
        for token in required_tokens:
            if tokenizer.token_to_id(token) is not None:
                continue
            raise ValueError(f"Crucial special token '{token}' not found in loaded tokenizer '{self.tokenizer_path}'!")
        return tokenizer
502
+
503
+ # --- Dataset Loading and Processing ---
504
+ class CombinedChatDataset(Dataset):
505
+ def __init__(self, tokenizer):
506
+ self.tokenizer = tokenizer
507
+ self.pad_id = self.tokenizer.token_to_id("<pad>")
508
+ self.eos_id = self.tokenizer.token_to_id("</s>")
509
+ self.bos_id = self.tokenizer.token_to_id("<s>")
510
+ self.user_id = self.tokenizer.token_to_id("<user>")
511
+ self.assistant_id = self.tokenizer.token_to_id("<assistant>")
512
+
513
+ if None in [self.pad_id, self.eos_id, self.bos_id, self.user_id, self.assistant_id]:
514
+ missing = [name for name, val in zip(["pad", "eos", "bos", "user", "assistant"], [self.pad_id, self.eos_id, self.bos_id, self.user_id, self.assistant_id]) if val is None]
515
+ raise ValueError(f"Tokenizer is missing critical special token IDs: {missing}. Tokenizer path: {self.tokenizer.model_path if hasattr(self.tokenizer, 'model_path') else 'N/A'}")
516
+
517
+ self.max_length = CONFIG["max_seq_len"]
518
+ self._clean_text = TokenizerTrainer()._clean_text # Use the same cleaning logic
519
+ self.all_processed_conversations = []
520
+
521
+ if "daily_dialog" in CONFIG["datasets"]:
522
+ logging.info("Loading and processing daily_dialog dataset...")
523
+ try:
524
+ dd_dataset = load_dataset("daily_dialog", split="train", trust_remote_code=True)
525
+ logging.info(f"Processing {len(dd_dataset)} daily_dialog conversations...")
526
+ for entry in dd_dataset:
527
+ conversation = []
528
+ dialogue = entry['dialog'][:CONFIG["max_turns"]]
529
+ if not dialogue: continue
530
+ for i, utterance in enumerate(dialogue):
531
+ role = "<user>" if i % 2 == 0 else "<assistant>"
532
+ cleaned_text = self._clean_text(utterance)
533
+ if cleaned_text:
534
+ conversation.append({'role': role, 'text': cleaned_text})
535
+ if conversation:
536
+ self.all_processed_conversations.append(conversation)
537
+ except Exception as e:
538
+ logging.error(f"Failed to load or process daily_dialog for training: {e}")
539
+
540
+ if "empathetic_dialogues" in CONFIG["datasets"]:
541
+ logging.info("Loading and processing empathetic_dialogues dataset...")
542
+ try:
543
+ ed_dataset = load_dataset("empathetic_dialogues", split="train", trust_remote_code=True)
544
+ logging.info("Grouping empathetic_dialogues by conversation ID...")
545
+ conversations_grouped = defaultdict(list)
546
+ for entry in ed_dataset:
547
+ conversations_grouped[entry['conv_id']].append(entry)
548
+ logging.info(f"Processing {len(conversations_grouped)} empathetic_dialogues conversations...")
549
+ for conv_id, entries in conversations_grouped.items():
550
+ conversation = []
551
+ sorted_entries = sorted(entries, key=lambda x: x['utterance_idx'])
552
+ if sorted_entries[0]['context']:
553
+ context_text = self._clean_text(sorted_entries[0]['context'])
554
+ if context_text:
555
+ conversation.append({'role': '<user>', 'text': context_text})
556
+ last_role = conversation[-1]['role'] if conversation else None
557
+ for entry in sorted_entries:
558
+ text = self._clean_text(entry['utterance'])
559
+ if not text: continue
560
+ current_role = '<assistant>' if last_role == '<user>' else '<user>'
561
+ conversation.append({'role': current_role, 'text': text})
562
+ last_role = current_role
563
+ conversation = conversation[:CONFIG["max_turns"]]
564
+ if conversation:
565
+ self.all_processed_conversations.append(conversation)
566
+ except Exception as e:
567
+ logging.error(f"Failed to load or process empathetic_dialogues for training: {e}")
568
+
569
+ if "blended_skill_talk" in CONFIG["datasets"]:
570
+ logging.info("Loading and processing blended_skill_talk dataset...")
571
+ try:
572
+ bst_dataset = load_dataset("blended_skill_talk", split="train", trust_remote_code=True)
573
+ logging.info(f"Processing {len(bst_dataset)} blended_skill_talk conversations...")
574
+ for entry in bst_dataset:
575
+ conversation = []
576
+ dialogue_turns_raw = list(entry['previous_utterance'])
577
+ if entry.get('free_turker_utterance'):
578
+ dialogue_turns_raw.append(entry['free_turker_utterance'])
579
+ if entry.get('guided_turker_utterance'):
580
+ dialogue_turns_raw.append(entry['guided_turker_utterance'])
581
+ if not dialogue_turns_raw: continue
582
+
583
+ turns_to_process = dialogue_turns_raw[:CONFIG["max_turns"]]
584
+ for i, utterance in enumerate(turns_to_process):
585
+ # Simplified role assignment, assuming alternation.
586
+ # For BST, the exact roles might depend on how 'previous_utterance' mixes with 'free_turker' and 'guided_turker'.
587
+ # A common pattern: prev_utterance (alternating), free_turker_utterance (user), guided_turker_utterance (agent).
588
+ # This simplified alternation should be mostly correct for a combined list.
589
+ role = "<user>" if i % 2 == 0 else "<assistant>"
590
+ cleaned_text = self._clean_text(utterance)
591
+ if cleaned_text:
592
+ conversation.append({'role': role, 'text': cleaned_text})
593
+ if conversation:
594
+ self.all_processed_conversations.append(conversation)
595
+ except Exception as e:
596
+ logging.error(f"Failed to load or process blended_skill_talk for training: {e}")
597
+
598
+ if "AlekseyKorshuk/persona-chat" in CONFIG["datasets"]:
599
+ pc_dataset_name = "AlekseyKorshuk/persona-chat"
600
+ logging.info(f"Loading and processing {pc_dataset_name} dataset...")
601
+ try:
602
+ pc_dataset = load_dataset(pc_dataset_name, split="train", trust_remote_code=True)
603
+ logging.info(f"Processing {len(pc_dataset)} {pc_dataset_name} conversations...")
604
+ for entry in pc_dataset:
605
+ conversation = []
606
+ if 'utterances' in entry and entry['utterances']:
607
+ history = entry['utterances'][-1]['history']
608
+ history = history[:CONFIG["max_turns"]] # Limit turns
609
+ for i, utterance in enumerate(history):
610
+ role = "<user>" if i % 2 == 0 else "<assistant>"
611
+ cleaned_text = self._clean_text(utterance)
612
+ if cleaned_text:
613
+ conversation.append({'role': role, 'text': cleaned_text})
614
+ if conversation:
615
+ self.all_processed_conversations.append(conversation)
616
+ else:
617
+ logging.warning(f"Skipping {pc_dataset_name} entry due to unexpected structure: {entry.keys()}")
618
+ except Exception as e:
619
+ logging.error(f"Failed to load or process {pc_dataset_name} for training: {e}")
620
+
621
+ if "papahawk/conversational-01" in CONFIG["datasets"]:
622
+ ph_dataset_name = "papahawk/conversational-01"
623
+ logging.info(f"Loading and processing {ph_dataset_name} dataset...")
624
+ try:
625
+ ph_dataset = load_dataset(ph_dataset_name, split="train", trust_remote_code=True)
626
+ logging.info(f"Processing {len(ph_dataset)} {ph_dataset_name} entries...")
627
+ for entry in ph_dataset:
628
+ instruction = self._clean_text(entry.get('instruction', ''))
629
+ response = self._clean_text(entry.get('response', ''))
630
+
631
+ if instruction and response: # Only process if both instruction and response exist
632
+ # Treat as a two-turn conversation
633
+ conversation = [
634
+ {'role': '<user>', 'text': instruction},
635
+ {'role': '<assistant>', 'text': response}
636
+ ]
637
+ # CONFIG["max_turns"] is not strictly applied here as each entry is 2 turns.
638
+ # If it were a multi-turn format from this dataset, truncation would apply.
639
+ self.all_processed_conversations.append(conversation)
640
+ # else:
641
+ # Optionally log skipped entries if instruction or response is missing
642
+ # logging.debug(f"Skipping entry from {ph_dataset_name} due to missing instruction or response.")
643
+ except Exception as e:
644
+ logging.error(f"Failed to load or process {ph_dataset_name} for training: {e}")
645
+
646
+
647
+ logging.info(f"Total processed conversations from all datasets: {len(self.all_processed_conversations)}")
648
+ if not self.all_processed_conversations:
649
+ raise ValueError("No processed conversations were created from any dataset. Check dataset paths and processing logic.")
650
+ logging.info("Shuffling combined dataset...")
651
+ random.shuffle(self.all_processed_conversations)
652
+
653
def __len__(self):
    """Return how many processed conversations this dataset holds."""
    conversations = self.all_processed_conversations
    return len(conversations)
655
+
656
def __getitem__(self, idx):
    """Build one next-token-prediction example from the conversation at ``idx``.

    The token sequence is ``bos_id`` followed, for each turn, by
    ``role_id, <utterance ids...>, eos_id``. It never exceeds
    ``self.max_length`` tokens: a turn that does not fit is not split.
    Instead, if at least two slots remain, a bare ``role_id, eos_id`` pair
    is appended in its place, and packing stops either way.

    A turn whose text cannot be encoded is logged and skipped.

    Args:
        idx: Index into ``self.all_processed_conversations``.

    Returns:
        ``{"input_ids": ids[:-1], "labels": ids[1:]}`` (plain Python lists),
        or ``None`` when fewer than two tokens were produced; ``collate_fn``
        drops ``None`` items from the batch.
    """
    conversation = self.all_processed_conversations[idx]
    formatted_ids = [self.bos_id]  # Start with BOS

    for turn in conversation:
        role_id = self.user_id if turn['role'] == '<user>' else self.assistant_id
        try:
            # Encode without letting the tokenizer add BOS/EOS on its own.
            utterance_ids = self.tokenizer.encode(turn['text'], add_special_tokens=False).ids
        except Exception as e:
            logging.error(f"Error encoding text at index {idx}, turn '{turn}': {e}")
            continue  # Skip the problematic turn instead of emitting an empty one

        # Space needed for this turn: role_id + utterance + eos_id.
        needed = 1 + len(utterance_ids) + 1
        if len(formatted_ids) + needed > self.max_length:
            # The full turn does not fit; close with role + EOS if there is room.
            if len(formatted_ids) + 2 <= self.max_length:
                formatted_ids.append(role_id)
                formatted_ids.append(self.eos_id)
            break  # Sequence is full

        formatted_ids.append(role_id)
        formatted_ids.extend(utterance_ids)
        formatted_ids.append(self.eos_id)  # EOS after each turn

    # The guard above keeps len(formatted_ids) <= max_length, so no
    # post-hoc truncation is required here.

    # Need BOS plus at least one more token to form an (input, label) pair.
    if len(formatted_ids) < 2:
        logging.warning(f"Sequence at index {idx} is too short after processing (<2 tokens): {formatted_ids}. Skipping.")
        return None

    return {"input_ids": formatted_ids[:-1], "labels": formatted_ids[1:]}
709
+
710
@staticmethod
def collate_fn(batch, pad_id=None):
    """Pad a list of dataset items into batched ``torch.long`` tensors.

    ``None`` items (as returned by ``__getitem__`` for unusable sequences)
    are dropped first. Every remaining item is right-padded to the longest
    ``input_ids`` in the batch.

    Args:
        batch: List of ``{"input_ids": list, "labels": list}`` dicts or ``None``.
        pad_id: Token id used to pad both inputs and labels. When omitted,
            it is read from the tokenizer file once and cached, instead of
            re-loading the tokenizer from disk for every batch.

    Returns:
        Dict with ``input_ids``, ``labels`` and ``attention_mask`` tensors
        (mask is 1 for real tokens, 0 for padding), or ``None`` if no usable
        items remain.
    """
    batch = [item for item in batch if item is not None]  # Filter out None items
    if not batch:
        logging.warning("Collate_fn received an entirely empty batch after filtering Nones.")
        return None

    max_len = max(
        (len(item["input_ids"]) for item in batch if "input_ids" in item),
        default=0,
    )
    if max_len == 0:
        logging.warning("Collate_fn: max_len is 0 after processing batch items.")
        return None

    if pad_id is None:
        pad_id = CombinedChatDataset._default_pad_id()

    inputs, labels, masks = [], [], []
    for item in batch:
        input_len = len(item["input_ids"])
        pad_len = max_len - input_len

        inputs.append(item["input_ids"] + [pad_id] * pad_len)
        # Labels are padded with pad_id as well so a loss configured with
        # ignore_index=pad_id skips the padded positions.
        labels.append(item["labels"] + [pad_id] * pad_len)
        masks.append([1] * input_len + [0] * pad_len)

    return {
        "input_ids": torch.tensor(inputs, dtype=torch.long),
        "labels": torch.tensor(labels, dtype=torch.long),
        "attention_mask": torch.tensor(masks, dtype=torch.long)
    }

@staticmethod
def _default_pad_id():
    """Return the ``<pad>`` id from the configured tokenizer file.

    A successful lookup is cached on the class, so the tokenizer file is
    read at most once per process. On failure the error is logged and 0 is
    returned without caching, so a later call can retry.
    """
    cached = getattr(CombinedChatDataset, "_collate_pad_id", None)
    if cached is not None:
        return cached
    try:
        tokenizer_path = os.path.join("tokenizer", CONFIG["tokenizer_name"])
        pad_id = Tokenizer.from_file(tokenizer_path).token_to_id("<pad>")
        if pad_id is None:
            raise ValueError("<pad> token not found in tokenizer for collate_fn")
    except Exception as e:
        logging.error(f"Collate Error: Failed to load tokenizer or get pad_id ('{CONFIG['tokenizer_name']}'): {e}. Using pad_id=0 as fallback.")
        return 0  # Fallback pad_id
    CombinedChatDataset._collate_pad_id = pad_id
    return pad_id
755
+
756
+ # --- Trainer, Safety Manager, Checkpoint Manager ---
757
+
758
class HROMTrainer:
    """Bundles the model, AdamW optimizer, loss and AMP state for training.

    ``train_step`` runs forward/backward for one micro-batch (gradients
    accumulate); ``clip_and_step`` applies the accumulated gradients.
    """

    def __init__(self, model, tokenizer):
        """Move ``model`` to the available device and set up optimizer and loss.

        Args:
            model: Module called as ``model(input_ids, attention_mask=...)`` and
                returning ``(logits, moe_aux_loss)``.
            tokenizer: Tokenizer providing ``token_to_id``; used to find ``<pad>``.
        """
        self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        logging.info(f"Using device: {self.device}")
        self.model = model.to(self.device)

        # AMP is used only on CUDA, and only if this torch build has GradScaler.
        self.use_amp = (self.device.type == "cuda" and hasattr(torch.cuda.amp, "GradScaler"))
        if self.use_amp:
            self.scaler = torch.cuda.amp.GradScaler()
        else:
            self.scaler = None
        logging.info(f"Automatic Mixed Precision (AMP): {'Enabled' if self.use_amp else 'Disabled'}")

        trainable = self.model.parameters()
        optimizer_options = {
            "lr": CONFIG["learning_rate"],
            "betas": (0.9, 0.95),
            "weight_decay": 0.1,
            "fused": (self.device.type == "cuda"),  # fused kernel only on CUDA
        }
        self.optimizer = torch.optim.AdamW(trainable, **optimizer_options)

        self.tokenizer = tokenizer  # Keep the tokenizer instance around
        self.pad_id = self.tokenizer.token_to_id("<pad>")
        if self.pad_id is None:
            # Fall back to the globally configured id (or 0) when <pad> is unknown.
            self.pad_id = CONFIG.get("pad_token_id", 0)
            logging.warning(f"<pad> token ID not found in provided tokenizer, using fallback ID: {self.pad_id}")

        # Padded label positions are excluded from the loss.
        self.criterion = nn.CrossEntropyLoss(ignore_index=self.pad_id)
        self.base_lr = CONFIG["learning_rate"]
        self.warmup_steps = CONFIG["warmup_steps"]

    def _adjust_learning_rate(self, step):
        """Set and return the learning rate for optimizer step ``step``.

        Linear warmup over ``self.warmup_steps`` steps, then constant ``base_lr``.
        """
        in_warmup = self.warmup_steps > 0 and step < self.warmup_steps
        if in_warmup:
            lr = self.base_lr * (step + 1) / self.warmup_steps
        else:
            lr = self.base_lr
        for group in self.optimizer.param_groups:
            group['lr'] = lr
        return lr

    def train_step(self, batch):
        """Run forward and backward for one micro-batch.

        The loss is divided by ``CONFIG["grad_accum_steps"]`` before backward so
        that accumulated gradients average over the accumulation window.

        Returns:
            ``(main_loss, moe_aux_loss)`` as Python floats, unscaled, for logging.
        """
        if self.use_amp:
            bf16_ok = torch.cuda.is_available() and torch.cuda.is_bf16_supported()
            amp_dtype = torch.bfloat16 if bf16_ok else torch.float16
            autocast_context = torch.cuda.amp.autocast(dtype=amp_dtype, enabled=self.use_amp)
        else:
            autocast_context = nullcontext()

        with autocast_context:
            input_ids = batch["input_ids"].to(self.device)
            attention_mask = batch["attention_mask"].to(self.device)
            labels = batch["labels"].to(self.device)

            outputs, moe_aux_loss = self.model(input_ids, attention_mask=attention_mask)

            vocab_dim = outputs.size(-1)
            logits_flat = outputs.view(-1, vocab_dim)
            labels_flat = labels.view(-1)

            # Cross-entropy is computed on float32 logits even under autocast.
            main_loss = self.criterion(logits_flat.float(), labels_flat)
            total_loss = main_loss + moe_aux_loss

        # Scale for gradient accumulation.
        scaled_loss = total_loss / CONFIG["grad_accum_steps"]

        if self.use_amp and self.scaler:
            self.scaler.scale(scaled_loss).backward()
        else:
            scaled_loss.backward()

        return main_loss.item(), moe_aux_loss.item()

    def clip_and_step(self, current_optimizer_step):
        """Clip gradients to norm 1.0, step the optimizer and clear gradients.

        Args:
            current_optimizer_step: Optimizer step index used for LR scheduling.

        Returns:
            The learning rate applied for this step.
        """
        current_lr = self._adjust_learning_rate(current_optimizer_step)
        with_scaler = self.use_amp and self.scaler
        if with_scaler:
            # Gradients must be unscaled before their norm is clipped.
            self.scaler.unscale_(self.optimizer)
        torch.nn.utils.clip_grad_norm_(self.model.parameters(), max_norm=1.0)
        if with_scaler:
            self.scaler.step(self.optimizer)
            self.scaler.update()
        else:
            self.optimizer.step()
        self.optimizer.zero_grad(set_to_none=True)
        return current_lr
841
+
842
+
843
class SafetyManager:
    """Token-level blocklist filter plus a sampling loop that stops on a match.

    The blocklist is a fixed word list encoded with the given tokenizer (with
    and without a leading space). ``generate_safely`` samples tokens one at a
    time and stops as soon as the generated part would contain a blocked
    token sequence.
    """

    # Textual form of the BOS token; matches the token looked up in __init__.
    _BOS_TEXT = "<s>"

    def __init__(self, model, tokenizer):
        """Encode the blocklist and resolve the special token ids.

        Args:
            model: Module called as ``model(ids, attention_mask=...)`` and
                returning ``(logits, aux)``.
            tokenizer: Tokenizer providing ``encode(...).ids``, ``decode`` and
                ``token_to_id``.
        """
        self.model = model
        self.tokenizer = tokenizer
        self.bad_words = ["kill", "murder", "suicide", "hate", "abuse", "violence", "illegal", "harm", "die", "attack", "rape", "molest", "exploit", "terror"]
        self.bad_word_ids = []
        logging.info("Initializing safety manager...")
        for word in self.bad_words:
            # Encode with and without a leading space, as tokenization can differ.
            ids_with_space = tokenizer.encode(f" {word}", add_special_tokens=False).ids
            if ids_with_space:
                self.bad_word_ids.append(ids_with_space)
                logging.debug(f"Encoded bad word ' {word}' to IDs: {ids_with_space}")

            ids_no_space = tokenizer.encode(word, add_special_tokens=False).ids
            if ids_no_space and ids_no_space != ids_with_space:  # Avoid duplicates
                self.bad_word_ids.append(ids_no_space)
                logging.debug(f"Encoded bad word '{word}' to IDs: {ids_no_space}")

            if not ids_with_space and not ids_no_space:
                logging.warning(f"Could not encode bad word '{word}' - skipping.")

        # Critical token IDs
        self.eos_id = self.tokenizer.token_to_id("</s>")
        self.bos_id = self.tokenizer.token_to_id("<s>")
        self.user_id = self.tokenizer.token_to_id("<user>")
        self.assistant_id = self.tokenizer.token_to_id("<assistant>")
        self.pad_id = self.tokenizer.token_to_id("<pad>")

        # Log missing tokens; eos/bos/pad fall back to 0, roles stay None.
        if self.eos_id is None:
            logging.error("</s> token ID not found in SafetyManager!")
            self.eos_id = 0
        if self.bos_id is None:
            logging.error("<s> token ID not found in SafetyManager!")
            self.bos_id = 0
        if self.user_id is None:
            logging.error("<user> token ID not found in SafetyManager!")
        if self.assistant_id is None:
            logging.error("<assistant> token ID not found in SafetyManager!")
        if self.pad_id is None:
            logging.error("<pad> token ID not found in SafetyManager!")
            self.pad_id = 0

    def contains_sequence(self, tokens, seq):
        """Return True if list ``seq`` occurs as a contiguous run inside ``tokens``."""
        if not seq or not tokens or len(tokens) < len(seq):
            return False
        seq_len = len(seq)
        return any(
            tokens[i:i + seq_len] == seq
            for i in range(len(tokens) - seq_len + 1)
        )

    def content_filter(self, text_ids):
        """Return True if ``text_ids`` is safe, False if it contains a blocked sequence.

        Non-list input is logged and treated as safe.
        """
        if not isinstance(text_ids, list):
            logging.warning(f"Content filter received non-list input: {type(text_ids)}")
            return True  # Default to safe if input is unexpected
        for bad_ids in self.bad_word_ids:
            if self.contains_sequence(text_ids, bad_ids):
                try:
                    detected_word = self.tokenizer.decode(bad_ids)
                except Exception:
                    detected_word = "unknown (decoding error)"
                logging.warning(f"Unsafe content detected: Found sequence for '{detected_word}' (IDs: {bad_ids}). Blocking generation.")
                return False  # Unsafe
        return True  # Safe

    def _strip_leading_bos(self, prompt):
        """Return ``prompt`` stripped of surrounding whitespace and one leading ``<s>``.

        The BOS text is compared literally. Decoding ``bos_id`` is not used
        because a decode that skips special tokens yields an empty string,
        which would match every prompt and strip nothing.
        """
        prompt = prompt.strip()
        if prompt.startswith(self._BOS_TEXT):
            prompt = prompt[len(self._BOS_TEXT):].strip()
        return prompt

    def generate_safely(self, prompt, max_new_tokens=50, temperature=0.7, top_k=50):
        """Sample a response to ``prompt``, stopping at EOS or on blocked content.

        The model is put in eval mode for the duration of the call and its
        previous train/eval mode is restored afterwards, including on the
        early error return and on exceptions.

        Args:
            prompt: Prompt text; a leading ``<s>`` is removed because BOS is
                always prepended here. An ``<assistant>`` cue is appended if
                the encoded prompt does not already end with one.
            max_new_tokens: Maximum number of tokens to sample.
            temperature: Logits are divided by this when it is > 0 and != 1.0.
            top_k: When 0 < top_k < vocab size, logits below the k-th largest
                are masked out (ties with the k-th value are kept).

        Returns:
            The decoded response text (special tokens skipped, stripped), or
            the string ``"Error: Assistant token not found."`` when the
            tokenizer has no ``<assistant>`` token.
        """
        was_training = self.model.training
        self.model.eval()
        try:
            return self._sample_response(prompt, max_new_tokens, temperature, top_k)
        finally:
            self.model.train(was_training)

    def _sample_response(self, prompt, max_new_tokens, temperature, top_k):
        """Build the input ids for ``prompt`` and run the guarded sampling loop."""
        if self.assistant_id is None:
            logging.error("Assistant token ID is None. Cannot properly cue model for generation.")
            return "Error: Assistant token not found."

        device = next(self.model.parameters()).device

        prompt_ids = self.tokenizer.encode(self._strip_leading_bos(prompt), add_special_tokens=False).ids
        input_ids = [self.bos_id] + prompt_ids
        # Cue the model to answer: the input must end with the assistant token.
        if input_ids[-1] != self.assistant_id:
            input_ids.append(self.assistant_id)

        generated_ids = list(input_ids)
        logging.debug(f"Starting safe generation with initial IDs: {generated_ids} (decoded: '{self.tokenizer.decode(generated_ids)}')")

        with torch.no_grad():
            for step in range(max_new_tokens):
                # Feed at most the last max_seq_len tokens.
                current_input_ids_trimmed = generated_ids[-CONFIG["max_seq_len"]:]
                current_input_tensor = torch.tensor([current_input_ids_trimmed], device=device)
                attention_mask = torch.ones_like(current_input_tensor, device=device)

                try:
                    outputs, _ = self.model(current_input_tensor, attention_mask=attention_mask)  # aux loss ignored
                    next_token_logits = outputs[:, -1, :]  # logits at the last position
                except Exception as e:
                    logging.error(f"Model forward pass failed during generation: {e}", exc_info=True)
                    break  # Stop generation on error

                # Temperature (skipped for <= 0 and for the no-op value 1.0)
                if temperature > 0 and temperature != 1.0:
                    next_token_logits = next_token_logits / temperature

                # Top-k filtering
                if top_k > 0 and top_k < next_token_logits.size(-1):
                    v, _ = torch.topk(next_token_logits, top_k, dim=-1)
                    threshold_val = v[:, -1].unsqueeze(-1)  # k-th largest logit
                    next_token_logits = next_token_logits.masked_fill(next_token_logits < threshold_val, -float('Inf'))

                probs = torch.softmax(next_token_logits, dim=-1)
                if torch.isnan(probs).any() or torch.isinf(probs).any():
                    logging.warning(f"NaN/Inf detected in probabilities at step {step}. Using uniform distribution as fallback.")
                    probs = torch.ones_like(probs) / probs.size(-1)

                next_token_id = torch.multinomial(probs, num_samples=1).item()

                # Safety check BEFORE appending: only the generated part plus
                # the candidate token is inspected, never the prompt.
                potential_sequence_for_check = generated_ids[len(input_ids):] + [next_token_id]
                if not self.content_filter(potential_sequence_for_check):
                    logging.warning(f"Unsafe token ID {next_token_id} ('{self.tokenizer.decode([next_token_id])}') blocked PRE-APPEND. Stopping generation.")
                    break

                generated_ids.append(next_token_id)

                if next_token_id == self.eos_id:
                    logging.debug(f"EOS token ({self.eos_id}) generated at step {step+1}. Stopping.")
                    break
                if step == max_new_tokens - 1:  # Token budget exhausted
                    logging.debug("Max new tokens reached.")
                    if generated_ids[-1] != self.eos_id and self.eos_id is not None:
                        generated_ids.append(self.eos_id)

        # Return only what was generated after the prepared input.
        response_ids = generated_ids[len(input_ids):]
        decoded_text = self.tokenizer.decode(response_ids, skip_special_tokens=True).strip()
        return decoded_text

    def debug_generation(self, prompt="<user> Tell me about your hobbies. </s>"):
        """Generate a response for ``prompt`` and log both prompt and response.

        A prompt without a leading role tag is prefixed with ``<user>``, and
        ``</s>`` is appended when missing, before it is passed to
        ``generate_safely``.
        """
        logging.info(f"\n--- Debug Generation & Safety Check ---")
        stripped = prompt.strip()
        if not stripped.startswith("<user>") and not stripped.startswith("<assistant>"):
            prompt = f"<user> {prompt.strip()}"  # Default to a user turn
        if not prompt.strip().endswith("</s>"):
            prompt = f"{prompt.strip()} </s>"

        # generate_safely adds BOS and the final assistant cue.
        generated_response = self.generate_safely(prompt, max_new_tokens=60, temperature=0.7, top_k=50)
        logging.info(f"Prompt Sent: '{prompt}'")
        logging.info(f"Generated Response: '{generated_response}'")
        logging.info(f"--- End Debug Generation ---\n")
1010
+
1011
+
1012
class CheckpointManager:
    """Saves, prunes and restores training checkpoints in ``CONFIG["checkpoint_dir"]``.

    Files are named ``hrom_<prefix>_step<step_info>_<YYYYmmdd_HHMMSS>.pt`` where
    ``<prefix>`` is the directory's base name without a ``checkpoints_`` prefix.
    """

    def __init__(self):
        """Create the checkpoint directory if it does not exist yet."""
        self.checkpoint_dir = CONFIG["checkpoint_dir"]
        os.makedirs(self.checkpoint_dir, exist_ok=True)
        logging.info(f"Checkpoint directory set to: {self.checkpoint_dir}")

    def save(self, model, optimizer, step_info):
        """Write model/optimizer state to a new timestamped file, then prune old ones.

        Args:
            model: Object providing ``state_dict()``.
            optimizer: Object providing ``state_dict()``.
            step_info: Int step or a label such as ``"epochX_stepY"``; stored
                as-is in the checkpoint and embedded (spaces -> ``_``) in the
                filename.

        Save failures are logged, not raised.
        """
        timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
        prefix_base = os.path.basename(self.checkpoint_dir).replace("checkpoints_", "")
        step_str = str(step_info).replace(" ", "_")  # filename-safe step label

        filename = f"hrom_{prefix_base}_step{step_str}_{timestamp}.pt"
        path = os.path.join(self.checkpoint_dir, filename)
        state = dict(
            model=model.state_dict(),
            optimizer=optimizer.state_dict(),
            step_info=step_info,  # original, unsanitized value
            config=CONFIG,  # current config, kept for reference
        )
        logging.info(f"Saving checkpoint to {path}...")
        try:
            torch.save(state, path)
            logging.info(f"Checkpoint saved successfully: {filename}")
            self._cleanup_old_checkpoints()
        except Exception as e:
            logging.error(f"Failed to save checkpoint '{path}': {e}", exc_info=True)

    def _parse_step_from_filename(self, filename_part):
        """Extract a numeric step from text such as ``"12000"`` or ``"epoch3_step12000"``.

        An ``epochN_stepM`` match wins; otherwise the first run of digits is
        used. Returns 0 when the text contains no digits.
        """
        for regex in (r'epoch\d+_step(\d+)', r'(\d+)'):
            found = re.search(regex, filename_part)
            if found:
                return int(found.group(1))
        return 0

    def _find_checkpoints(self):
        """List ``(path, mtime, step_label)`` for every matching checkpoint file.

        The result follows ``os.listdir`` order; callers sort as needed.
        """
        prefix_base = os.path.basename(self.checkpoint_dir).replace("checkpoints_", "")
        pattern_str = rf"hrom_{re.escape(prefix_base)}_step([\w\d_]+)_(\d{{8}}_\d{{6}})\.pt"
        pattern = re.compile(pattern_str)

        found = []
        for f_name in os.listdir(self.checkpoint_dir):
            match = pattern.match(f_name)
            if not match:
                continue
            filepath = os.path.join(self.checkpoint_dir, f_name)
            found.append((filepath, os.path.getmtime(filepath), match.group(1)))
        return found

    def _cleanup_old_checkpoints(self):
        """Delete the oldest checkpoints beyond ``CONFIG["max_checkpoints"]`` (default 5).

        Age is judged by file modification time. A limit <= 0 disables pruning.
        Errors are logged, not raised.
        """
        max_checkpoints = CONFIG.get("max_checkpoints", 5)
        if max_checkpoints <= 0:
            return  # Disabled

        try:
            checkpoints = self._find_checkpoints()
            checkpoints.sort(key=lambda entry: entry[1])  # oldest first

            num_to_delete = len(checkpoints) - max_checkpoints
            if num_to_delete > 0:
                logging.info(f"Max checkpoints ({max_checkpoints}) reached. Deleting {num_to_delete} oldest ones.")
                for file_to_remove, _mtime, _label in checkpoints[:num_to_delete]:
                    try:
                        os.remove(file_to_remove)
                        logging.info(f"Removed old checkpoint: {file_to_remove}")
                    except OSError as e:
                        logging.error(f"Error removing old checkpoint {file_to_remove}: {e}")
        except Exception as e:
            logging.error(f"Error during checkpoint cleanup: {e}", exc_info=True)

    def load_latest(self, model, optimizer):
        """Restore the most recently modified checkpoint into ``model`` and ``optimizer``.

        Config values stored in the checkpoint are compared with the current
        ``CONFIG`` for a set of architecture keys; mismatches are only logged.
        If the optimizer state cannot be loaded, it is reset and loading goes on.

        Returns:
            The optimizer step to resume from: saved step + 1 when a step can
            be determined, otherwise 0. Also 0 when no checkpoint exists, the
            model state cannot be loaded, or any other error occurs.
        """
        try:
            checkpoints = self._find_checkpoints()
            if not checkpoints:
                logging.info(f"No valid checkpoints found in '{self.checkpoint_dir}' matching pattern. Starting fresh.")
                return 0

            checkpoints.sort(key=lambda entry: entry[1], reverse=True)  # newest first
            latest_checkpoint_path = checkpoints[0][0]

            logging.info(f"Loading latest checkpoint from: {latest_checkpoint_path}")
            map_location = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
            checkpoint = torch.load(latest_checkpoint_path, map_location=map_location)

            loaded_config = checkpoint.get("config", {})
            if not loaded_config:
                logging.warning("Checkpoint does not contain configuration info. Cannot check for mismatches.")
            else:
                critical_keys = ["dim", "n_layers", "n_heads", "ff_dim", "vocab_size", "max_seq_len",
                                 "tokenizer_name", "num_experts", "top_k_experts"]
                mismatched_keys = []
                for key in critical_keys:
                    loaded_val = loaded_config.get(key)
                    current_val = CONFIG.get(key)
                    # A key missing on one side compares as None and counts as a mismatch.
                    if loaded_val != current_val:
                        mismatched_keys.append((key, loaded_val, current_val))
                if mismatched_keys:
                    logging.warning("--- CONFIG MISMATCH DETECTED (Loading Checkpoint) ---")
                    for key, loaded_val, current_val in mismatched_keys:
                        logging.warning(f" - {key}: Checkpoint='{loaded_val}', Current='{current_val}'")
                    logging.warning("Proceeding with loading, but this may impact model performance or cause errors if critical arch params changed.")

            try:
                model.load_state_dict(checkpoint['model'], strict=True)
            except RuntimeError as e:
                logging.error(f"Failed to load model state_dict: {e}. This often happens if model architecture changed or vocab_size is different. Starting fresh.")
                return 0

            try:
                optimizer.load_state_dict(checkpoint['optimizer'])
                # Move every tensor in the optimizer state to the current device.
                for param_state in optimizer.state.values():
                    for name, value in param_state.items():
                        if not isinstance(value, torch.Tensor):
                            continue
                        try:
                            param_state[name] = value.to(map_location)
                        except Exception as e_opt_move:
                            logging.error(f"Failed to move optimizer tensor '{name}' to device: {e_opt_move}")
            except Exception as e:
                logging.warning(f"Could not load optimizer state_dict: {e}. Optimizer state will be reset.")
                optimizer.state = defaultdict(dict)

            # 'step_info' is either an int step or a label like "epochX_stepY".
            step_info_loaded = checkpoint.get('step_info', 0)
            start_optimizer_step = 0
            if isinstance(step_info_loaded, int):
                start_optimizer_step = step_info_loaded + 1  # resume at the next step
            elif isinstance(step_info_loaded, str):
                parsed_step = self._parse_step_from_filename(step_info_loaded)
                if parsed_step > 0:
                    start_optimizer_step = parsed_step + 1
                else:
                    start_optimizer_step = 0
                if parsed_step == 0 and "epoch" in step_info_loaded.lower():
                    logging.warning(f"Loaded epoch checkpoint '{step_info_loaded}' but could not parse specific optimizer step. Optimizer step count might reset or be inaccurate for LR scheduling if not careful.")

            logging.info(f"Checkpoint loaded. Resuming from (or after) saved info '{step_info_loaded}'. Effective next optimizer_step: {start_optimizer_step}.")
            return start_optimizer_step

        except FileNotFoundError:
            logging.info(f"No checkpoint directory or files at '{self.checkpoint_dir}'. Starting fresh.")
            return 0
        except Exception as e:
            logging.error(f"Error loading checkpoint: {e}. Starting fresh.", exc_info=True)
            return 0
1174
+
1175
+
1176
+ # --- Training Function ---
1177
+
1178
+ def train():
1179
+ logging.info("Starting HROM-MoE training process...")
1180
+ logging.info(f"Initial Configuration: {CONFIG}")
1181
+
1182
+ tokenizer_trainer = TokenizerTrainer()
1183
+ tokenizer_path = tokenizer_trainer.tokenizer_path
1184
+ if not os.path.exists(tokenizer_path):
1185
+ logging.info(f"Tokenizer '{CONFIG['tokenizer_name']}' not found at '{tokenizer_path}'. Training new tokenizer...")
1186
+ try:
1187
+ # Pass only unique dataset names to tokenizer trainer
1188
+ tokenizer_datasets = list(set(CONFIG["datasets"]))
1189
+ tokenizer_trainer.train(tokenizer_datasets)
1190
+ except Exception as e:
1191
+ logging.error(f"Critical error during tokenizer training: {e}", exc_info=True)
1192
+ return # Cannot proceed without a tokenizer
1193
+ else:
1194
+ logging.info(f"Loading existing tokenizer from {tokenizer_path}")
1195
+
1196
+ try:
1197
+ tokenizer = tokenizer_trainer.get_tokenizer()
1198
+ # Update global config with actual token IDs from the loaded tokenizer
1199
+ CONFIG['pad_token_id'] = tokenizer.token_to_id("<pad>")
1200
+ CONFIG['bos_token_id'] = tokenizer.token_to_id("<s>")
1201
+ CONFIG['eos_token_id'] = tokenizer.token_to_id("</s>")
1202
+ # Check if all critical tokens were loaded
1203
+ if None in [CONFIG['pad_token_id'], CONFIG['bos_token_id'], CONFIG['eos_token_id'],
1204
+ tokenizer.token_to_id("<user>"), tokenizer.token_to_id("<assistant>")]:
1205
+ raise ValueError("One or more critical special tokens (<pad>, <s>, </s>, <user>, <assistant>) are missing from the tokenizer after loading.")
1206
+ logging.info(f"Tokenizer loaded. Vocab size: {tokenizer.get_vocab_size()}. PAD ID: {CONFIG['pad_token_id']}, BOS ID: {CONFIG['bos_token_id']}, EOS ID: {CONFIG['eos_token_id']}")
1207
+ except (FileNotFoundError, ValueError) as e:
1208
+ logging.error(f"Critical error loading tokenizer: {e}. Cannot continue.", exc_info=True)
1209
+ return
1210
+
1211
+ logging.info("Initializing HROM-MoE model...")
1212
+ if CONFIG['vocab_size'] != tokenizer.get_vocab_size():
1213
+ logging.warning(
1214
+ f"CONFIG vocab_size ({CONFIG['vocab_size']}) differs from tokenizer vocab_size ({tokenizer.get_vocab_size()}). "
1215
+ f"Updating CONFIG vocab_size to match tokenizer: {tokenizer.get_vocab_size()}."
1216
+ )
1217
+ CONFIG['vocab_size'] = tokenizer.get_vocab_size()
1218
+ model = HROM()
1219
+
1220
+ total_params = sum(p.numel() for p in model.parameters())
1221
+ trainable_params = sum(p.numel() for p in model.parameters() if p.requires_grad)
1222
+ logging.info(f"HROM-MoE Model initialized. Total params: {total_params:,} ({total_params/1e6:.2f}M)")
1223
+ logging.info(f"Trainable params: {trainable_params:,} ({trainable_params/1e6:.2f}M)")
1224
+
1225
+ logging.info("Setting up combined dataset and dataloader...")
1226
+ try:
1227
+ logging.info("Pre-checking specified datasets for availability...")
1228
+ for ds_name in CONFIG["datasets"]:
1229
+ logging.info(f"Attempting to quick-load '{ds_name}' to check cache/availability...")
1230
+ try:
1231
+ # Load a tiny slice to check if dataset is accessible and caches are populated
1232
+ _ = load_dataset(ds_name, split="train[:1%]", download_mode="reuse_cache_if_exists", trust_remote_code=True)
1233
+ logging.info(f"Successfully quick-loaded '{ds_name}'.")
1234
+ except Exception as e:
1235
+ logging.warning(f"Could not pre-check/quick-load dataset '{ds_name}': {e}. Full load will proceed but might take time or fail.")
1236
+
1237
+ dataset = CombinedChatDataset(tokenizer) # Pass tokenizer instance
1238
+ if len(dataset) == 0:
1239
+ logging.error("Dataset is empty after processing all specified sources. Cannot train.")
1240
+ return
1241
+
1242
+ # Determine num_workers carefully
1243
+ cpu_count = os.cpu_count()
1244
+ num_workers = 0 # Default to 0 for main process loading
1245
+ if cpu_count and cpu_count > 1:
1246
+ if torch.cuda.is_available(): # More workers if GPU is bottlenecked by CPU
1247
+ num_workers = min(4, cpu_count // 2)
1248
+ else: # Fewer if CPU is doing both compute and loading
1249
+ num_workers = min(2, cpu_count // 2)
1250
+ num_workers = max(0, num_workers) # Ensure non-negative
1251
+
1252
+ logging.info(f"Using num_workers: {num_workers} for DataLoader.")
1253
+
1254
+ dataloader = DataLoader(
1255
+ dataset,
1256
+ batch_size=CONFIG["batch_size"],
1257
+ collate_fn=CombinedChatDataset.collate_fn, # Static method, no instance needed
1258
+ shuffle=True,
1259
+ num_workers=num_workers,
1260
+ pin_memory=torch.cuda.is_available(), # Pin memory if using CUDA
1261
+ prefetch_factor=2 if num_workers > 0 else None, # Prefetch if using multiple workers
1262
+ drop_last=False # Process all data, even if last batch is smaller
1263
+ )
1264
+ except Exception as e:
1265
+ logging.error(f"Failed to initialize dataset/dataloader: {e}", exc_info=True)
1266
+ return
1267
+
1268
+ logging.info("Initializing Trainer, Checkpoint Manager, and Safety Manager...")
1269
+ trainer_obj = HROMTrainer(model, tokenizer) # Pass tokenizer instance
1270
+ checkpoint_manager = CheckpointManager()
1271
+ safety = SafetyManager(model, tokenizer) # Pass tokenizer instance
1272
+
1273
+ start_optimizer_step = checkpoint_manager.load_latest(model, trainer_obj.optimizer)
1274
+ model.to(trainer_obj.device) # Ensure model is on the correct device after loading state
1275
+
1276
+ logging.info(f"Starting/Resuming training from optimizer step {start_optimizer_step}")
1277
+ optimizer_step = start_optimizer_step
1278
+
1279
+ accum_main_loss_for_log = 0.0
1280
+ accum_aux_loss_for_log = 0.0
1281
+
1282
+ # Estimate current batch step and epoch based on loaded optimizer_step
1283
+ # This is an approximation if dataloader length varies or if not all epochs are full
1284
+ batches_per_epoch_est = len(dataloader) if len(dataloader) > 0 else 1 # Avoid division by zero
1285
+ current_total_batch_steps = optimizer_step * CONFIG["grad_accum_steps"]
1286
+ start_epoch = current_total_batch_steps // batches_per_epoch_est if batches_per_epoch_est > 0 else 0
1287
+
1288
+ try:
1289
+ if len(dataloader) == 0: raise ValueError("DataLoader has zero length, cannot train.")
1290
+ total_optimizer_steps_estimate = (len(dataloader) * CONFIG["num_epochs"]) // CONFIG["grad_accum_steps"]
1291
+ logging.info(f"Dataset size: {len(dataset)} samples, Batches per epoch: {len(dataloader)}")
1292
+ logging.info(f"Gradient Accumulation Steps: {CONFIG['grad_accum_steps']}, Effective Batch Size: {CONFIG['batch_size'] * CONFIG['grad_accum_steps']}")
1293
+ logging.info(f"Target Epochs: {CONFIG['num_epochs']}, Estimated Total Optimizer Steps: {total_optimizer_steps_estimate}")
1294
+ except Exception as e:
1295
+ logging.warning(f"Could not fully estimate training steps due to: {e}")
1296
+
1297
+ model.train() # Ensure model is in training mode
1298
+ for epoch in range(start_epoch, CONFIG["num_epochs"]):
1299
+ logging.info(f"--- Starting Epoch {epoch+1}/{CONFIG['num_epochs']} (Optimizer step: {optimizer_step}) ---")
1300
+ epoch_main_loss_sum = 0.0
1301
+ epoch_aux_loss_sum = 0.0
1302
+ epoch_batches_processed = 0 # Batches processed within this epoch execution
1303
+
1304
+ # Skip batches if resuming mid-epoch (simplified: we restart epoch if start_epoch > 0)
1305
+ # More precise mid-epoch resumption would require saving/loading dataloader state or batch index.
1306
+ # For now, if start_epoch > 0, previous epochs are considered complete.
1307
+
1308
+ for i, batch in enumerate(dataloader):
1309
+ if batch is None: # Should be handled by collate_fn returning None for entirely bad batches
1310
+ logging.warning(f"Skipping None batch at index {i} in epoch {epoch+1}. This might indicate data issues.")
1311
+ continue
1312
+ if not batch["input_ids"].numel(): # Check if batch tensors are empty
1313
+ logging.warning(f"Skipping batch with empty tensors at index {i} in epoch {epoch+1}.")
1314
+ continue
1315
+
1316
+ main_loss_val, aux_loss_val = trainer_obj.train_step(batch)
1317
+
1318
+ valid_loss = True
1319
+ if main_loss_val is None or math.isnan(main_loss_val) or math.isinf(main_loss_val):
1320
+ logging.error(f"NaN/Inf main loss detected: {main_loss_val}. Aux: {aux_loss_val}. Optimizer Step {optimizer_step}. Stopping training.")
1321
+ checkpoint_manager.save(model, trainer_obj.optimizer, f"error_main_loss_nan_inf_step{optimizer_step}")
1322
+ valid_loss = False
1323
+ if aux_loss_val is None or math.isnan(aux_loss_val) or math.isinf(aux_loss_val):
1324
+ logging.warning(f"NaN/Inf auxiliary loss detected: {aux_loss_val}. Main: {main_loss_val}. Optimizer Step {optimizer_step}. This is problematic.")
1325
+ # If main_loss was also bad, we're already stopping.
1326
+ # If only aux_loss is bad, the total_loss will be NaN/Inf, potentially corrupting gradients.
1327
+ # Consider setting aux_loss_val = 0.0 if this becomes a frequent issue and main_loss is fine.
1328
+ if not valid_loss: return # Already stopping from main_loss error
1329
+
1330
+ if not valid_loss: return # Stop training if critical loss error
1331
+
1332
+ accum_main_loss_for_log += main_loss_val
1333
+ accum_aux_loss_for_log += aux_loss_val
1334
+ epoch_main_loss_sum += main_loss_val
1335
+ epoch_aux_loss_sum += aux_loss_val
1336
+ epoch_batches_processed += 1
1337
+ current_total_batch_steps += 1 # This tracks raw batches processed by train_step
1338
+
1339
+ if current_total_batch_steps % CONFIG["grad_accum_steps"] == 0:
1340
+ current_lr = trainer_obj.clip_and_step(optimizer_step) # Perform optimizer step
1341
+
1342
+ avg_main_loss_accum = accum_main_loss_for_log / CONFIG["grad_accum_steps"]
1343
+ avg_aux_loss_accum = accum_aux_loss_for_log / CONFIG["grad_accum_steps"]
1344
+ accum_main_loss_for_log = 0.0 # Reset accumulators
1345
+ accum_aux_loss_for_log = 0.0
1346
+
1347
+ if optimizer_step % CONFIG["debug_interval"] == 0:
1348
+ logging.info(
1349
+ f"E {epoch+1} | OptSt {optimizer_step} | TotalBatchSt {current_total_batch_steps} | "
1350
+ f"AvgMainL: {avg_main_loss_accum:.4f} | AvgAuxL: {avg_aux_loss_accum:.4f} | LR: {current_lr:.2e}"
1351
+ )
1352
+ # Perform debug generation less frequently to save time
1353
+ if optimizer_step > 0 and optimizer_step % (CONFIG["debug_interval"] * 10) == 0: # e.g., every 5 * 400 = 2000 opt steps
1354
+ safety.debug_generation("<user> Hi there! How are you doing today? </s>")
1355
+
1356
+ if optimizer_step > 0 and optimizer_step % CONFIG["checkpoint_interval"] == 0:
1357
+ logging.info(f"Checkpoint interval reached at optimizer step {optimizer_step}.")
1358
+ checkpoint_manager.save(model, trainer_obj.optimizer, optimizer_step) # Save with optimizer_step
1359
+ # Optionally run debug generation after checkpointing
1360
+ # safety.debug_generation("<user> What is the capital of France? </s>")
1361
+
1362
+ optimizer_step += 1 # Increment optimizer_step *after* an optimization step
1363
+
1364
+ avg_epoch_main_loss = epoch_main_loss_sum / epoch_batches_processed if epoch_batches_processed > 0 else 0
1365
+ avg_epoch_aux_loss = epoch_aux_loss_sum / epoch_batches_processed if epoch_batches_processed > 0 else 0
1366
+ logging.info(
1367
+ f"--- Finished Epoch {epoch+1}/{CONFIG['num_epochs']} --- "
1368
+ f"Avg Epoch MainL: {avg_epoch_main_loss:.4f} | Avg Epoch AuxL: {avg_epoch_aux_loss:.4f} | "
1369
+ f"Optimizer Steps this epoch: {optimizer_step - (start_optimizer_step if epoch == start_epoch else current_epoch_start_opt_step)} (approx)"
1370
+ )
1371
+ # Save checkpoint at the end of each epoch
1372
+ # Use a string that includes epoch and current optimizer step
1373
+ checkpoint_manager.save(model, trainer_obj.optimizer, f"epoch{epoch+1}_step{optimizer_step}")
1374
+ safety.debug_generation("<user> That was an interesting epoch. What did you learn? </s>")
1375
+
1376
+ # For next epoch's calculation of "Optimizer Steps this epoch"
1377
+ current_epoch_start_opt_step = optimizer_step
1378
+
1379
+
1380
+ logging.info(f"Training finished after {CONFIG['num_epochs']} target epochs. Final optimizer step: {optimizer_step}.")
1381
+ logging.info("Saving final model state...")
1382
+ checkpoint_manager.save(model, trainer_obj.optimizer, f"final_step{optimizer_step}")
1383
+ safety.debug_generation("<user> The training is complete. How do you feel? </s>")
1384
+
1385
+
1386
if __name__ == "__main__":
    # Script entry point: launch the training run only when this file is
    # executed directly, not when it is imported as a module.
    #
    # NOTE: no random seeds are set in this block. For reproducible runs, seed
    # the RNGs before the call below, e.g. `random.seed(42)`,
    # `torch.manual_seed(42)` and, when CUDA is available,
    # `torch.cuda.manual_seed_all(42)`.
    train()
checkpoints_moe/hrom_moe_stepfinal_step12108_20250604_201629.pt ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:5d10a16df935a88f88010ff01eff76e078a51140bc4838b205c8342ce545f494
3
+ size 4445992437
tokenizer/hrom_moe_tokenizer.json ADDED
The diff for this file is too large to render. See raw diff