from datetime import datetime
import os
from pyexpat import model  # NOTE(review): looks like an accidental editor auto-import; nothing below needs it.
from httpx import get  # NOTE(review): appears unused in this file.

# Set before importing torch so the CUDA caching allocator picks it up.
os.environ["PYTORCH_CUDA_ALLOC_CONF"] = "max_split_size_mb:512"
# os.environ["PYTORCH_CUDA_ALLOC_CONF"] = "expandable_segments=True"
import torch
import torch.nn as nn
import torch.optim as optim
import random
import tiktoken
import ai_extras as AI_ex
from torch.utils.data import DataLoader, SubsetRandomSampler, TensorDataset, random_split
import numpy as np
import torch._dynamo
import bitsandbytes as bnb
import torch.utils.checkpoint as cp
from safetensors.torch import save_model as sf_save_model

torch.backends.cuda.enable_flash_sdp(True)
torch._dynamo.config.suppress_errors = False  # make errors visible
torch._dynamo.config.verbose = True  # print detailed compilation logs
# torch._inductor.config.triton.cudagraphs = True
torch.backends.cuda.matmul.allow_bf16_reduced_precision_reduction = True
torch.backends.cuda.matmul.allow_fp16_reduced_precision_reduction = True
torch.set_float32_matmul_precision('medium')
# The next two settings are *functions*. Assigning True to them (as this file
# used to do) replaced the function object and changed no setting at all.
torch.backends.cuda.enable_mem_efficient_sdp(True)
_allow_math_sdp_reduction = getattr(torch.backends.cuda, "allow_fp16_bf16_reduction_math_sdp", None)
if callable(_allow_math_sdp_reduction):  # only present in newer PyTorch releases
    _allow_math_sdp_reduction(True)

enc = tiktoken.get_encoding("gpt2")


def save_model(model, filename):
    """Save ``model``'s weights to ``filename`` (safetensors) via a temp-file swap.

    The weights are written to ``filename + ".tmp"`` first. Any existing
    ``filename`` is moved to ``filename + ".old"`` (replacing an older
    ``.old``), and the temp file is then renamed into place. This is a
    workaround for Windows file-lock errors (e.g. error 1224). If a rename
    fails, the temp checkpoint is left on disk and a message is printed
    instead of raising.
    """
    temp_filename = filename + ".tmp"
    print("saving model...")
    # safetensors' save_model (unlike save_file) deduplicates shared tensors
    # such as the tied embed / layerO1 weight.
    sf_save_model(model, temp_filename)
    try:
        if os.path.exists(filename):
            old_file = filename + ".old"
            if os.path.exists(old_file):
                os.remove(old_file)
            os.rename(filename, old_file)
        os.rename(temp_filename, filename)
    except OSError as e:
        print(f"Windows I/O Lock: {e}. Checkpoint kept at {temp_filename}")


def get_optimizer(self, lr):
    """Build a bitsandbytes ``AdamW8bit`` optimizer for the module ``self``.

    Trainable parameters are split into three groups by name:
      * names containing 'alpha' or 'mem_gate': no weight decay, 32-bit state requested;
      * names containing 'bias', 'norm' (case-insensitive), 'embed', 'layerO1'
        or 'engram': no weight decay;
      * everything else: weight decay 0.04.

    Args:
        self: the model whose parameters are optimised.
        lr: base learning rate.

    Returns:
        ``bnb.optim.AdamW8bit`` with betas (0.9, 0.95).
    """
    decay_params = []
    no_decay_params = []
    alpha_32bit_params = []  # high-precision scalars / gates

    for name, param in self.named_parameters():
        if not param.requires_grad:
            continue
        if any(x in name for x in ['alpha', 'mem_gate']):
            alpha_32bit_params.append(param)
        # NOTE(review): 'engram' is matched case-sensitively while the module
        # attribute is `Engram`, so Engram parameters most likely land in the
        # decay group. Not changed here: regrouping would make optimizer state
        # saved in existing checkpoints fail to load.
        elif ('bias' in name or 'norm' in name.lower() or 'embed' in name
              or 'layerO1' in name or 'engram' in name):
            no_decay_params.append(param)
        else:
            decay_params.append(param)

    optim_groups = [
        {'params': decay_params, 'weight_decay': 0.04},
        {'params': no_decay_params, 'weight_decay': 0.0},
        # NOTE(review): bitsandbytes may not read a per-group 'optim_bits' key
        # (per-parameter overrides normally go through GlobalOptimManager).
        # These tensors are tiny, and tensors below bnb's min_8bit_size keep
        # 32-bit state anyway. Confirm against the installed bnb version.
        {'params': alpha_32bit_params, 'weight_decay': 0.0, 'optim_bits': 32},
    ]
    return bnb.optim.AdamW8bit(optim_groups, lr=lr, betas=(0.9, 0.95))


class biggerbrain(nn.Module):
    """Looped-transformer language model over the GPT-2 tiktoken vocabulary.

    Forward pass (see ``_forward_logits``):
      * token embedding;
      * two "pre" blocks; the second attends only over the last ``local_seq_len`` positions;
      * an Engram memory read mixed in through a sigmoid gate;
      * ``iters`` passes through five looped blocks (blocks 2 and 4 are local);
      * two "post" blocks;
      * an output projection tied to the embedding matrix.
    """

    def __init__(self, device, sequence_length=640):
        super(biggerbrain, self).__init__()
        # NOTE: module/parameter registration order below determines
        # self.parameters() order, which optimizer checkpoints rely on.
        # Do not reorder.
        self.dim1 = 768
        self.ffndim = int(self.dim1 * 2.7)
        self.T_heads = 8
        self.P_heads = 12
        self.kv_heads = 4
        self.sequencelength = sequence_length
        self.device = device
        self.local_seq_len = 256
        self.MLA_dim = 384
        self.debugprints = False
        # A bit of extra RoPE range for dependencies that exceed the training context.
        self.rope = AI_ex.RoPE(self.dim1 // self.T_heads, max_seq_len=self.sequencelength + 20)
        self.embed = nn.Embedding(enc.max_token_value + 1, self.dim1)
        self.Engram = AI_ex.engram(self.dim1, self.kv_heads, 3, memory_size=12192, bottleneck=self.MLA_dim)
        self.layerO1 = nn.Linear(self.dim1, enc.max_token_value + 1, bias=False)
        self.mem_gate = nn.Linear(self.dim1, 1, bias=True)

        # Block weighting parameters
        self.alpha_pre = nn.Parameter(torch.tensor(0.5))
        self.alpha_loop = nn.Parameter(torch.tensor(0.5))
        self.alpha_post = nn.Parameter(torch.tensor(0.5))
        self.alpha_mem = nn.Parameter(torch.tensor(0.5))

        # Pre block 1
        self.layerPreA1 = AI_ex.RoPEAttention(self.dim1, self.P_heads, self.kv_heads, bottleneck=self.MLA_dim)
        self.preffn1 = AI_ex.MoLLayer(self.dim1, self.ffndim, n_experts=3, max_iter=1)
        self.normPre1 = AI_ex.RMSNorm(self.dim1)  # NOTE(review): unused in the forward pass
        # Pre block 2 (half-length local attention)
        self.layerPreA2 = AI_ex.RoPEAttention(self.dim1, self.P_heads, self.kv_heads, bottleneck=self.MLA_dim)
        self.preffn2 = nn.Sequential(nn.Linear(self.dim1, self.ffndim * 2), AI_ex.SwiGLU(), nn.Linear(self.ffndim, self.dim1))
        self.normPre2 = AI_ex.RMSNorm(self.dim1)  # NOTE(review): unused in the forward pass
        # Post block 1
        self.layerPostA1 = AI_ex.RoPEAttention(self.dim1, self.P_heads, self.kv_heads, bottleneck=self.MLA_dim)
        self.postffn1 = nn.Sequential(nn.Linear(self.dim1, self.ffndim * 2), AI_ex.SwiGLU(), nn.Linear(self.ffndim, self.dim1))
        self.normPost1 = AI_ex.RMSNorm(self.dim1)  # NOTE(review): unused in the forward pass
        # Post block 2
        self.layerPostA2 = AI_ex.RoPEAttention(self.dim1, self.P_heads, self.kv_heads, bottleneck=self.MLA_dim)
        self.postffn2 = nn.Sequential(nn.Linear(self.dim1, self.ffndim * 2), AI_ex.SwiGLU(), nn.Linear(self.ffndim, self.dim1))
        self.normPost2 = AI_ex.RMSNorm(self.dim1)  # NOTE(review): unused in the forward pass

        # Middle (looped) attention; A2 and A4 are half-length local attention.
        self.layerA1 = AI_ex.RoPEAttention(self.dim1, self.T_heads, self.kv_heads, bottleneck=self.MLA_dim)
        self.layerA2 = AI_ex.RoPEAttention(self.dim1, self.T_heads, self.kv_heads, bottleneck=self.MLA_dim)
        self.layerA3 = AI_ex.RoPEAttention(self.dim1, self.T_heads, self.kv_heads, bottleneck=self.MLA_dim)
        self.layerA4 = AI_ex.RoPEAttention(self.dim1, self.T_heads, self.kv_heads, bottleneck=self.MLA_dim)
        self.layerA5 = AI_ex.RoPEAttention(self.dim1, self.T_heads, self.kv_heads, bottleneck=self.MLA_dim)
        self.layerMi1 = AI_ex.MoLLayer(self.dim1, self.ffndim, n_experts=2)
        self.layerMi2 = AI_ex.MoLLayer(self.dim1, self.ffndim, n_experts=2)
        self.layerMi3 = AI_ex.MoLLayer(self.dim1, self.ffndim, n_experts=2)
        self.layerMi4 = AI_ex.MoLLayer(self.dim1, self.ffndim, n_experts=2)
        self.layerMi5 = AI_ex.MoLLayer(self.dim1, self.ffndim, n_experts=2)
        self.norm0 = AI_ex.RMSNorm(self.dim1)
        self.norm1 = AI_ex.RMSNorm(self.dim1)
        self.norm2 = AI_ex.RMSNorm(self.dim1)
        self.norm3 = AI_ex.RMSNorm(self.dim1)
        self.normM1 = AI_ex.RMSNorm(self.dim1)
        self.normM2 = AI_ex.RMSNorm(self.dim1)
        self.normM3 = AI_ex.RMSNorm(self.dim1)
        self.normM4 = AI_ex.RMSNorm(self.dim1)
        self.normM5 = AI_ex.RMSNorm(self.dim1)
        self.scratchpad_gate = nn.Linear(self.dim1, 1, bias=True)
        self.Mnorm = AI_ex.RMSNorm(self.dim1)
        self.M1 = AI_ex.custom_mem(self.dim1, self.T_heads)  # memory integration
        self.MM2 = AI_ex.GatedResidual(self.dim1)  # NOTE(review): unused in the forward pass

        # NOTE(review): self.apply(...) below overwrites this init, so the
        # effective init is weight ~ N(0, 0.01), bias 0. It is deliberately left
        # before apply: the forward pass uses the raw (un-sigmoided) gate, and
        # bias -1 would start training at y = y_prev - y.
        nn.init.zeros_(self.scratchpad_gate.weight)
        nn.init.constant_(self.scratchpad_gate.bias, -1.0)

        self.layerO1.weight = self.embed.weight  # tied weights
        self.apply(self._init_weights)

        # Must come *after* self.apply(...), which would otherwise overwrite it.
        # Zero weights and bias -1 start the sigmoid memory gate at ~0.27.
        nn.init.zeros_(self.mem_gate.weight)
        nn.init.constant_(self.mem_gate.bias, -1.0)

    def pick_word(self, output, k=50, temperature=0.8, prev_tokens=None, rep_penalty=1.1):
        """Sample one token id from a row of logits.

        Args:
            output: logits tensor; indexing uses ``[0, tok]``, so a single row
                (batch size 1) is assumed.
            k: top-k cutoff.
            temperature: logits are divided by ``temperature + 1e-8``.
            prev_tokens: recent tokens (ints, tensors or lists of ints). Tokens
                among the last 64 entries get the repetition penalty.
            rep_penalty: positive logits are divided by it, others multiplied.

        Returns:
            The ``torch.multinomial`` sample (one id per row).
        """
        logits = output / (temperature + 1e-8)

        # Token ids that must never be sampled (e.g. chat-format artifacts).
        BAD_TOKENS = [
            # 25,    # ":" single colon
            # 3712,  # "::" double colon
            # 1058,  # ":" alternate encoding
        ]
        for tok in BAD_TOKENS:
            if tok < logits.size(-1):
                logits[0, tok] = float('-inf')

        # Repetition penalty
        if prev_tokens is not None:
            # Flatten to plain ints whether entries are tensors, lists or ints.
            flat_tokens = []
            for t in prev_tokens[-64:]:
                if isinstance(t, torch.Tensor):
                    flat_tokens.append(t.item())
                elif isinstance(t, list):
                    flat_tokens.extend(t)
                else:
                    flat_tokens.append(t)
            for tok in set(flat_tokens):
                if tok < logits.size(-1):
                    if logits[0, tok] > 0:
                        logits[0, tok] /= rep_penalty
                    else:
                        logits[0, tok] *= rep_penalty

        # Top-k filtering
        v, _ = torch.topk(logits, min(k, logits.size(-1)))
        logits[logits < v[..., -1].view(-1, 1)] = float('-inf')

        probabilities = torch.softmax(logits, dim=-1)
        word_id = torch.multinomial(probabilities, num_samples=1)
        return word_id

    def _init_weights(self, module):
        """Init Linear/Embedding weights from N(0, 0.01) and zero Linear biases.

        GPT-2-style normal init (GPT-2 itself uses std 0.02). Note that
        ``self.apply`` also reaches Linear layers inside the AI_ex submodules.
        """
        if isinstance(module, (nn.Linear, nn.Embedding)):
            module.weight.data.normal_(mean=0.0, std=0.01)
        if isinstance(module, nn.Linear) and module.bias is not None:
            module.bias.data.zero_()

    def _embedding_health_check(self):
        """Print the embedding std and, if it looks unhealthy, pause for the operator.

        Returns True when the operator typed "stop". The model has then already
        been saved to ``stopped_model.safetensors`` and training should end.
        """
        std = self.embed.weight.std().item()
        print(f" Embed std: {std:.4f} (healthy = ~0.02, bad = <0.005 OR > 0.05)")
        if std < 0.005:
            reason = "Embeddings are collapsing. Warning! Attempt to pause."
        elif std > 0.05:
            reason = "Embeddings are too large. Warning! Attempt to pause."
        else:
            return False
        print(" WARNING: embedding problems detected!")
        print(reason)
        if input("Press Enter to continue...").lower() == "stop":
            print("Stopping training loop. Saving model...")
            save_model(self, "stopped_model.safetensors")
            return True
        return False

    def trainingloop(self, data, epochs=50, lr=3e-4, batchsize=32, accumulation_steps=4, subset_fraction=1.0, warmup_steps=2000):
        """Train on ``data`` with gradient accumulation, bf16 autocast and LR warmup/decay.

        Each epoch trains on one chunk of a shuffled index permutation. There
        are ``int(1 / subset_fraction)`` chunks, each ``subset_fraction`` of the
        dataset; the permutation is reshuffled when the chunks wrap around.
        The number of loop iterations per batch is sampled from {1, 2, 3}.

        If ``checkpoint_full.pth`` exists, model, optimizer, scheduler, best
        loss and epoch are restored, and training resumes at the saved epoch.
        ``epochs`` is therefore the total epoch count, not an increment.

        Args:
            data: dataset whose DataLoader batches unpack as (inputs, targets);
                target value -1 is ignored by the loss.
            epochs: total number of epochs.
            lr: peak learning rate.
            batchsize: micro-batch size.
            accumulation_steps: micro-batches per optimizer step.
            subset_fraction: fraction of the dataset used per epoch.
            warmup_steps: optimizer steps of linear warmup (from 1% of ``lr``).
        """
        try:
            self.train()
            batchloss = 1000  # best (1/accumulation_steps-scaled) micro-batch loss so far

            dataset_size = len(data)
            all_indices = np.arange(dataset_size)
            np.random.shuffle(all_indices)  # reshuffled in place at each cycle
            subset_size = int(dataset_size * subset_fraction)
            num_chunks = max(1, int(1.0 / subset_fraction))

            def make_chunks():
                return [all_indices[c * subset_size:(c + 1) * subset_size] for c in range(num_chunks)]

            chunks = make_chunks()
            ITER_OPTIONS = [1, 2, 3]
            ITER_WEIGHTS = [0.75, 0.15, 0.1]

            optimizer = get_optimizer(self, lr)
            criterion = nn.CrossEntropyLoss(ignore_index=-1)

            # LinearLR's total_iters defaults to 5. Without an explicit value
            # the post-warmup decay hit 0.1% of lr after only 5 steps. Decay
            # over the remaining optimizer steps instead (loader uses
            # drop_last=True). A resumed checkpoint restores its own saved
            # scheduler state.
            batches_per_epoch = len(chunks[0]) // batchsize
            total_steps = epochs * (batches_per_epoch // accumulation_steps)
            decay_iters = max(1, total_steps - warmup_steps)

            warmup_scheduler = torch.optim.lr_scheduler.LinearLR(
                optimizer, 0.01, end_factor=1.0, total_iters=warmup_steps
            )
            decay_scheduler = torch.optim.lr_scheduler.LinearLR(
                optimizer, start_factor=1.0, end_factor=0.001, total_iters=decay_iters
            )
            scheduler = torch.optim.lr_scheduler.SequentialLR(
                optimizer, schedulers=[warmup_scheduler, decay_scheduler], milestones=[warmup_steps]
            )

            start_epoch = 0
            if os.path.exists("checkpoint_full.pth"):
                # weights_only=False unpickles arbitrary objects: only load trusted checkpoints.
                checkpoint = torch.load("checkpoint_full.pth", weights_only=False, map_location='cpu')
                self.load_state_dict(checkpoint['model'])
                optimizer.load_state_dict(checkpoint['optimizer'])
                scheduler.load_state_dict(checkpoint['scheduler'])
                start_epoch = checkpoint['epoch']
                batchloss = checkpoint['batchloss']
                print(f"Loaded checkpoint for training. {scheduler.get_last_lr()}")
                print(f"Restored batchloss: {batchloss:.4f}")

            for epoch in range(start_epoch, epochs):
                epoch_loss = 0.0
                batches_run = 0
                optimizer.zero_grad(set_to_none=True)

                chunk_idx = epoch % len(chunks)
                # Reshuffle *before* building this epoch's sampler. Otherwise the
                # old chunk 0 is reused and the new chunk 0 is never visited.
                if chunk_idx == 0 and epoch > 0:
                    np.random.shuffle(all_indices)
                    chunks = make_chunks()
                    print(f"Reshuffled data chunks for next cycle")

                sampler = SubsetRandomSampler(chunks[chunk_idx])
                loader = DataLoader(
                    data, batch_size=batchsize, sampler=sampler,
                    num_workers=0, pin_memory=True, drop_last=True
                )

                for i, (batch_inputs, batch_targets) in enumerate(loader):
                    if i == 0:
                        print(f"Epoch {epoch} | Chunk {chunk_idx+1}/{len(chunks)} | "
                              f"Starting training...")
                    if (i + 1) % (accumulation_steps * 1000) == 0 and self._embedding_health_check():
                        return

                    batch_inputs = batch_inputs.to(self.device, non_blocking=True)
                    batch_targets = batch_targets.to(self.device, non_blocking=True)

                    with torch.amp.autocast('cuda', dtype=torch.bfloat16):
                        num_iter = random.choices(ITER_OPTIONS, weights=ITER_WEIGHTS, k=1)[0]
                        logits = self.forward_training(batch_inputs, iters=num_iter)
                        lm_loss = criterion(
                            logits.view(-1, enc.max_token_value + 1),
                            batch_targets.reshape(-1)
                        )
                        lm_loss = lm_loss / accumulation_steps
                    lm_loss.backward()

                    # Accumulate on-device to avoid a GPU sync every micro-batch.
                    epoch_loss = epoch_loss + lm_loss.detach() * accumulation_steps
                    batches_run += 1

                    if (i + 1) % accumulation_steps == 0:
                        torch.nn.utils.clip_grad_norm_(self.parameters(), max_norm=1.0, foreach=True)
                        optimizer.step()
                        scheduler.step()
                        optimizer.zero_grad(set_to_none=True)

                        if (i + 1) % (accumulation_steps * 4) == 0:
                            current_loss = lm_loss.detach().item()
                            if batchloss > current_loss:
                                batchloss = current_loss
                                print(f"New best loss: {batchloss * accumulation_steps:.4f}. Saving model...")
                                save_model(self, f"best_model.safetensors")

                        if (i + 1) % (accumulation_steps * 100) == 0:
                            print(f"Epoch {epoch} | Batch {(i + 1) // accumulation_steps} | loss={lm_loss.item() * accumulation_steps:.4f}")
                            torch.save({
                                'model': self.state_dict(),
                                'optimizer': optimizer.state_dict(),
                                'scheduler': scheduler.state_dict(),
                                'epoch': epoch,
                                'chunk_idx': chunk_idx,
                                'batchloss': batchloss,
                            }, "checkpoint_full.pth")
                            print(f"Saved state.")

                        if (i + 1) % (accumulation_steps * 50000) == 0:
                            save_model(self, "model.safetensors")

                avg_loss = float(epoch_loss) / max(batches_run, 1)
                timestamp = datetime.now().strftime("%H:%M:%S")
                print(f"[{timestamp}] Epoch {epoch:3d} | "
                      f"Loss: {avg_loss:.4f} | "
                      f"Chunk: {chunk_idx+1}/{len(chunks)}")

        except KeyboardInterrupt:
            print("\n[!] Manual Interrupt detected. Saving safety checkpoint...")
            # Previously saved the module-level `model` name (a pyexpat module), not the network.
            save_model(self, "emergency_checkpoint.safetensors")
            print("Done. Safe to exit.")

    def _residual_tail(self, base, update, alpha):
        """Return a copy of ``base`` with ``alpha * update`` added to only its last ``local_seq_len`` positions."""
        out = base.clone()  # positions before the local window are kept as-is
        out[:, -self.local_seq_len:] = base[:, -self.local_seq_len:] + (alpha * update)
        return out

    def _forward_logits(self, input_ids, iters):
        """Run the full network on ``input_ids`` and return output logits for every position.

        Only the last ``self.sequencelength`` positions are used. This is shared
        by ``forward_training`` and the generation loop in ``forward``.
        """
        input_ids = input_ids[:, -self.sequencelength:]
        local = self.local_seq_len
        w = self.embed(input_ids.long().to(self.device))

        # Pre block 1 (full-length attention + MoL layer)
        y = self.norm0(w)
        attn_out = self.layerPreA1(y, rope=self.rope)
        ffn_out = self.preffn1(attn_out, w, w, 1)
        y = y + (self.alpha_pre * (attn_out + ffn_out))

        # Pre block 2 (attention over only the last `local` positions)
        y = self.norm1(y)
        attn_out = self.layerPreA2(y[:, -local:], rope=self.rope)
        ffn_out = self.preffn2(attn_out)
        y = self._residual_tail(y, attn_out + ffn_out, self.alpha_pre)

        # Engram memory read, mixed in through a learned sigmoid gate.
        mem = self.Engram(y, input_ids)
        gate = torch.sigmoid(self.mem_gate(y))
        y = y + (gate * mem)

        z = y
        linguistic_anchor = y
        for j in range(iters):
            y_prev = y
            y = (self.M1(y, z, linguistic_anchor) * self.alpha_mem) + y
            y = self.Mnorm(y)
            iter_tensor = torch.tensor(j, device=self.device)

            # Block 1 (full length)
            y = self.normM1(y)
            attn_out = self.layerA1(y, rope=self.rope)
            moe_out = self.layerMi1(y, y_prev, linguistic_anchor, iter_tensor)
            y = y + (self.alpha_loop * (moe_out + attn_out))

            # Block 2 (local)
            y = self.normM2(y)
            attn_out = self.layerA2(y[:, -local:], rope=self.rope)
            moe_out = self.layerMi2(attn_out, y_prev[:, -local:], linguistic_anchor[:, -local:], iter_tensor)
            y = self._residual_tail(y, attn_out + moe_out, self.alpha_loop)

            # Block 3 (full length)
            y = self.normM3(y)
            attn_out = self.layerA3(y, rope=self.rope)
            moe_out = self.layerMi3(attn_out, y_prev, linguistic_anchor, iter_tensor)
            y = y + (self.alpha_loop * (attn_out + moe_out))

            # Block 4 (local)
            y = self.normM4(y)
            attn_out = self.layerA4(y[:, -local:], rope=self.rope)
            moe_out = self.layerMi4(attn_out, y_prev[:, -local:], linguistic_anchor[:, -local:], iter_tensor)
            y = self._residual_tail(y, attn_out + moe_out, self.alpha_loop)

            # Block 5 (full length)
            y = self.normM5(y)
            attn_out = self.layerA5(y, rope=self.rope)
            moe_out = self.layerMi5(attn_out, y_prev, linguistic_anchor, iter_tensor)
            y = y + (self.alpha_loop * (attn_out + moe_out))

            # z carries this pass's output to M1 next pass and to the post blocks.
            z = y
            # Scratchpad residual feeding the next pass.
            # NOTE(review): the gate is the raw linear output (no sigmoid).
            y = (self.scratchpad_gate(y) * y) + y_prev

        # Post block 1
        z = self.norm2(z)
        attn_out = self.layerPostA1(z, rope=self.rope)
        ffn_out = self.postffn1(attn_out)
        z = z + (self.alpha_post * (attn_out + ffn_out))

        # Post block 2
        z = self.norm3(z)
        attn_out = self.layerPostA2(z, rope=self.rope)
        ffn_out = self.postffn2(attn_out)
        z = z + (self.alpha_post * (attn_out + ffn_out))

        # Output projection, weight-tied with the embedding.
        return self.layerO1(z)

    def forward_training(self, input_ids, iters=3):
        """Return logits for every position of ``input_ids`` using ``iters`` loop passes."""
        return self._forward_logits(input_ids, iters)

    def forward(self, input_ids, max_tokens, iters=3, top_k=10, temperature=1.0, rep_penalty=1.5):
        """Autoregressively sample up to ``max_tokens`` tokens after ``input_ids``.

        The context is truncated to the last ``self.sequencelength`` tokens at
        each step. Generation stops after emitting the end-of-text token, which
        is included in the result.

        Returns:
            List of sampled token tensors, as returned by ``pick_word``.
        """
        generated_tokens = []
        for _ in range(max_tokens):
            input_ids = input_ids[:, -self.sequencelength:]
            logits = self._forward_logits(input_ids, iters)
            token1 = self.pick_word(logits[:, -1, :], k=top_k, temperature=temperature,
                                    prev_tokens=generated_tokens, rep_penalty=rep_penalty)
            generated_tokens.append(token1)
            input_ids = torch.cat([input_ids, token1], dim=1)
            if token1.item() == enc.eot_token:
                break
        return generated_tokens


def initmodel(device):
    """Create a ``biggerbrain`` on ``device`` in bfloat16; wrap it in ``torch.compile`` unless device is 'cpu'."""
    model = biggerbrain(device).to(device).to(torch.bfloat16)
    if device != 'cpu':
        model = torch.compile(model, dynamic=True)
    return model


def think(prompt, model, max_length=100, iter=3, top_k=10, temperature=1.0, raw=False, rep_penalty=1.5):
    """Generate a reply to ``prompt`` and return it as decoded text.

    Unless ``raw`` is set, the prompt is wrapped in a "User: ... Assistant:"
    template. A compiled model is unwrapped via ``_orig_mod`` first.
    """
    model = model._orig_mod if hasattr(model, '_orig_mod') else model
    formatted = prompt if raw else f"User: \n\n{prompt}\n\nAssistant:"
    input_ids = torch.tensor([enc.encode(formatted, allowed_special={'<|endoftext|>'})]).to(model.device)
    model.eval()
    with torch.no_grad():
        generated_tokens = model.forward(
            input_ids, max_tokens=max_length, iters=iter,
            top_k=top_k, temperature=temperature, rep_penalty=rep_penalty
        )
    # tiktoken decodes plain ints; forward() returns sampled tensors.
    return enc.decode([int(t.item()) for t in generated_tokens])


def print_parameter_breakdown(model):
    """Print every trainable parameter's element count and the total (unwrapping compiled models)."""
    base_model = model._orig_mod if hasattr(model, '_orig_mod') else model
    total_params = 0
    print("--- Parameter Breakdown ---")
    for name, param in base_model.named_parameters():
        if not param.requires_grad:
            continue
        params = param.numel()
        total_params += params
        # Flag layers with a suspiciously high count (> 1 million).
        if params > 1_000_000:
            print(f"MASSIVE LAYER DETECTED -> {name}: {params:,} parameters")
        else:
            print(f"{name}: {params:,} parameters")
    print(f"---------------------------\nTrue Total: {total_params:,}")