| from datetime import datetime
|
| import os
|
| from pyexpat import model
|
| from httpx import get
|
| os.environ["PYTORCH_CUDA_ALLOC_CONF"] = "max_split_size_mb:512"
|
|
|
| import torch
|
| import torch.nn as nn
|
| import torch.optim as optim
|
| import random
|
| import tiktoken
|
| import ai_extras as AI_ex
|
| from torch.utils.data import DataLoader, SubsetRandomSampler, TensorDataset, random_split
|
| import numpy as np
|
| import torch._dynamo
|
| import bitsandbytes as bnb
|
| import torch.utils.checkpoint as cp
|
| from safetensors.torch import save_model as sf_save_model
|
|
|
# --- Global PyTorch backend configuration (runs once at import time) ---

# Allow the flash kernel for scaled-dot-product attention.
torch.backends.cuda.enable_flash_sdp(True)

# Let TorchDynamo raise on compile failures instead of suppressing them, and
# log verbosely so failures can be diagnosed.
torch._dynamo.config.suppress_errors = False
torch._dynamo.config.verbose = True

# Trade matmul precision for speed: permit reduced-precision reductions for
# bf16/fp16 and a lower-precision internal format for float32 matmuls.
torch.backends.cuda.matmul.allow_bf16_reduced_precision_reduction = True
torch.backends.cuda.matmul.allow_fp16_reduced_precision_reduction = True
torch.set_float32_matmul_precision('medium')

# These two settings are *functions* on torch.backends.cuda.  Assigning
# ``True`` to them (as this file used to) only replaces the function object
# with a bool and changes nothing in the backend, so call them instead.
torch.backends.cuda.enable_mem_efficient_sdp(True)

# Only present in newer PyTorch releases; skip quietly when unavailable.
_allow_reduced_math_sdp = getattr(
    torch.backends.cuda, "allow_fp16_bf16_reduction_math_sdp", None
)
if callable(_allow_reduced_math_sdp):
    _allow_reduced_math_sdp(True)
|
|
|
# Shared tiktoken "gpt2" encoding.  The model sizes its embedding/output layer
# from `enc.max_token_value + 1` and generation stops on `enc.eot_token`.
enc = tiktoken.get_encoding("gpt2")
|
|
|
def save_model(model, filename):
    """Save *model* as safetensors at *filename*, keeping one previous copy.

    The weights are first written to ``filename + ".tmp"``.  An existing
    *filename* is then rotated to ``filename + ".old"`` and the temporary file
    is moved into place.  ``os.replace`` is used for both moves because it
    overwrites an existing destination in a single call, which removes the
    separate exists/remove/rename sequence and its race window.

    Args:
        model: Module passed straight to ``safetensors.torch.save_model``.
        filename: Final path of the saved file.

    Returns:
        None.  If a move fails with ``OSError`` the error is reported on
        stdout and the freshly written ``.tmp`` file is left on disk.
    """
    temp_filename = filename + ".tmp"
    old_file = filename + ".old"

    # Announce before the (potentially slow) write rather than after it.
    print("saving model...")
    sf_save_model(model, temp_filename)

    try:
        if os.path.exists(filename):
            # Overwrites any stale backup in one step.
            os.replace(filename, old_file)
        os.replace(temp_filename, filename)
    except OSError as e:
        print(f"Windows I/O Lock: {e}. Checkpoint kept at {temp_filename}")
|
|
|
def get_optimizer(self, lr):
    """Build the bitsandbytes ``AdamW8bit`` optimizer for a model.

    Trainable parameters are sorted by *name* into three groups, checked in
    this order:

    1. names containing ``alpha`` or ``mem_gate``: no weight decay, and the
       group additionally carries ``optim_bits=32``;
    2. names containing ``bias``, ``embed``, ``layerO1``, ``engram`` or
       (case-insensitively) ``norm``: no weight decay;
    3. everything else: weight decay 0.04.

    Args:
        self: Module whose ``named_parameters()`` are optimised (this is a
            plain function; the model is passed explicitly).
        lr: Learning rate shared by all groups.

    Returns:
        The configured ``bnb.optim.AdamW8bit`` instance.
    """
    full_precision_keys = ('alpha', 'mem_gate')
    # NOTE(review): 'engram' is matched case-sensitively, while biggerbrain
    # registers its memory module as `Engram` - confirm this pattern matches
    # the parameters it is meant to exclude from weight decay.
    plain_no_decay_keys = ('bias', 'embed', 'layerO1', 'engram')

    buckets = {'decay': [], 'no_decay': [], 'full_precision': []}
    trainable = (
        (pname, tensor)
        for pname, tensor in self.named_parameters()
        if tensor.requires_grad
    )
    for pname, tensor in trainable:
        if any(key in pname for key in full_precision_keys):
            target = 'full_precision'
        elif 'norm' in pname.lower() or any(key in pname for key in plain_no_decay_keys):
            target = 'no_decay'
        else:
            target = 'decay'
        buckets[target].append(tensor)

    # NOTE(review): whether bitsandbytes honours a per-group 'optim_bits' key
    # is not visible from this file - verify against the bitsandbytes docs.
    param_groups = [
        {'params': buckets['decay'], 'weight_decay': 0.04},
        {'params': buckets['no_decay'], 'weight_decay': 0.0},
        {'params': buckets['full_precision'], 'weight_decay': 0.0, 'optim_bits': 32},
    ]
    return bnb.optim.AdamW8bit(param_groups, lr=lr, betas=(0.9, 0.95))
|
|
|
class biggerbrain(nn.Module):
    """Language model with two "pre" blocks, a repeated middle stack and two "post" blocks.

    Data flow (see ``_compute_logits``): token embedding -> two pre blocks ->
    gated ``Engram`` memory -> ``iters`` repetitions of a five-layer
    attention + ``MoLLayer`` stack -> two post blocks -> output projection
    whose weight is tied to the embedding.  Every second layer of the pre and
    middle stacks only updates the last ``local_seq_len`` positions.

    Attention, memory, norm and MoL modules come from ``ai_extras``; their
    internals are not visible from this file.
    """

    def __init__(self, device, sequence_length=640):
        """Create all sub-modules.

        The creation order is kept stable on purpose: it determines parameter
        order, which optimizer checkpoints rely on.

        Args:
            device: Device stored on the instance and used for tensors the
                model creates or moves itself.
            sequence_length: Number of trailing tokens kept as context.
        """
        super().__init__()
        self.dim1 = 768
        self.ffndim = int(self.dim1 * 2.7)
        self.T_heads = 8
        self.P_heads = 12
        self.kv_heads = 4
        self.sequencelength = sequence_length
        self.device = device
        self.local_seq_len = 256
        self.MLA_dim = 384
        self.debugprints = False

        # NOTE(review): the rotary table is sized for T_heads but is also
        # handed to the P_heads attention layers - confirm RoPEAttention
        # copes with both head sizes.
        self.rope = AI_ex.RoPE(self.dim1 // self.T_heads, max_seq_len=self.sequencelength + 20)

        self.embed = nn.Embedding(enc.max_token_value + 1, self.dim1)
        self.Engram = AI_ex.engram(self.dim1, self.kv_heads, 3, memory_size=12192, bottleneck=self.MLA_dim)

        self.layerO1 = nn.Linear(self.dim1, enc.max_token_value + 1, bias=False)

        self.mem_gate = nn.Linear(self.dim1, 1, bias=True)
        nn.init.zeros_(self.mem_gate.weight)
        nn.init.constant_(self.mem_gate.bias, -1.0)

        # Learnable scalars scaling each residual branch.
        self.alpha_pre = nn.Parameter(torch.tensor(0.5))
        self.alpha_loop = nn.Parameter(torch.tensor(0.5))
        self.alpha_post = nn.Parameter(torch.tensor(0.5))
        self.alpha_mem = nn.Parameter(torch.tensor(0.5))

        self.layerPreA1 = AI_ex.RoPEAttention(self.dim1, self.P_heads, self.kv_heads, bottleneck=self.MLA_dim)
        self.preffn1 = AI_ex.MoLLayer(self.dim1, self.ffndim, n_experts=3, max_iter=1)
        self.normPre1 = AI_ex.RMSNorm(self.dim1)

        self.layerPreA2 = AI_ex.RoPEAttention(self.dim1, self.P_heads, self.kv_heads, bottleneck=self.MLA_dim)
        self.preffn2 = nn.Sequential(nn.Linear(self.dim1, self.ffndim * 2), AI_ex.SwiGLU(), nn.Linear(self.ffndim, self.dim1))
        self.normPre2 = AI_ex.RMSNorm(self.dim1)

        self.layerPostA1 = AI_ex.RoPEAttention(self.dim1, self.P_heads, self.kv_heads, bottleneck=self.MLA_dim)
        self.postffn1 = nn.Sequential(nn.Linear(self.dim1, self.ffndim * 2), AI_ex.SwiGLU(), nn.Linear(self.ffndim, self.dim1))
        self.normPost1 = AI_ex.RMSNorm(self.dim1)

        self.layerPostA2 = AI_ex.RoPEAttention(self.dim1, self.P_heads, self.kv_heads, bottleneck=self.MLA_dim)
        self.postffn2 = nn.Sequential(nn.Linear(self.dim1, self.ffndim * 2), AI_ex.SwiGLU(), nn.Linear(self.ffndim, self.dim1))
        self.normPost2 = AI_ex.RMSNorm(self.dim1)

        self.layerA1 = AI_ex.RoPEAttention(self.dim1, self.T_heads, self.kv_heads, bottleneck=self.MLA_dim)
        self.layerA2 = AI_ex.RoPEAttention(self.dim1, self.T_heads, self.kv_heads, bottleneck=self.MLA_dim)
        self.layerA3 = AI_ex.RoPEAttention(self.dim1, self.T_heads, self.kv_heads, bottleneck=self.MLA_dim)
        self.layerA4 = AI_ex.RoPEAttention(self.dim1, self.T_heads, self.kv_heads, bottleneck=self.MLA_dim)
        self.layerA5 = AI_ex.RoPEAttention(self.dim1, self.T_heads, self.kv_heads, bottleneck=self.MLA_dim)

        self.layerMi1 = AI_ex.MoLLayer(self.dim1, self.ffndim, n_experts=2)
        self.layerMi2 = AI_ex.MoLLayer(self.dim1, self.ffndim, n_experts=2)
        self.layerMi3 = AI_ex.MoLLayer(self.dim1, self.ffndim, n_experts=2)
        self.layerMi4 = AI_ex.MoLLayer(self.dim1, self.ffndim, n_experts=2)
        self.layerMi5 = AI_ex.MoLLayer(self.dim1, self.ffndim, n_experts=2)

        self.norm0 = AI_ex.RMSNorm(self.dim1)
        self.norm1 = AI_ex.RMSNorm(self.dim1)
        self.norm2 = AI_ex.RMSNorm(self.dim1)
        self.norm3 = AI_ex.RMSNorm(self.dim1)

        self.normM1 = AI_ex.RMSNorm(self.dim1)
        self.normM2 = AI_ex.RMSNorm(self.dim1)
        self.normM3 = AI_ex.RMSNorm(self.dim1)
        self.normM4 = AI_ex.RMSNorm(self.dim1)
        self.normM5 = AI_ex.RMSNorm(self.dim1)

        self.scratchpad_gate = nn.Linear(self.dim1, 1, bias=True)

        self.Mnorm = AI_ex.RMSNorm(self.dim1)
        self.M1 = AI_ex.custom_mem(self.dim1, self.T_heads)
        self.MM2 = AI_ex.GatedResidual(self.dim1)

        nn.init.zeros_(self.scratchpad_gate.weight)
        nn.init.constant_(self.scratchpad_gate.bias, -1.0)

        # Tie the output projection to the token embedding.
        self.layerO1.weight = self.embed.weight

        # NOTE(review): this pass re-initialises *every* nn.Linear, including
        # mem_gate and scratchpad_gate, so the zero-weight / -1.0-bias gate
        # initialisation above is overwritten (weights ~N(0, 0.01), bias 0).
        # Left as is because changing it alters training dynamics; note that
        # scratchpad_gate is applied without a sigmoid, so a -1.0 bias would
        # act as a raw multiplier.  Decide which behaviour is intended.
        self.apply(self._init_weights)

    def pick_word(self, output, k=50, temperature=0.8,
                  prev_tokens=None, rep_penalty=1.1):
        """Sample one token id from *output* with temperature, repetition penalty and top-k.

        Args:
            output: Logits for the next token; row 0 is the row that receives
                the bad-token mask and the repetition penalty.
            k: Number of highest logits kept before sampling.
            temperature: Divisor applied to the logits (plus a small epsilon).
            prev_tokens: Optional history of tokens (ints, single-element
                tensors or lists of ints); only the last 64 entries count.
            rep_penalty: Positive logits of seen tokens are divided by this,
                non-positive ones multiplied by it.

        Returns:
            Tensor of sampled token ids from ``torch.multinomial``; *output*
            itself is not modified.
        """
        logits = output / (temperature + 1e-8)
        vocab_size = logits.size(-1)

        # Token ids that must never be sampled (currently none).
        BAD_TOKENS = []
        for tok in BAD_TOKENS:
            if tok < vocab_size:
                logits[0, tok] = float('-inf')

        if prev_tokens is not None:
            recent = []
            for t in prev_tokens[-64:]:
                if isinstance(t, torch.Tensor):
                    recent.append(t.item())
                elif isinstance(t, list):
                    recent.extend(t)
                else:
                    recent.append(t)

            seen = sorted({tok for tok in recent if tok < vocab_size})
            if seen:
                # One vectorised update instead of a per-token Python loop,
                # which forced a device sync for every penalised token.
                idx = torch.tensor(seen, dtype=torch.long, device=logits.device)
                seen_logits = logits[0, idx]
                logits[0, idx] = torch.where(
                    seen_logits > 0,
                    seen_logits / rep_penalty,
                    seen_logits * rep_penalty,
                )

        top_values, _ = torch.topk(logits, min(k, vocab_size))
        # Everything below the k-th best logit of its row is excluded.
        logits[logits < top_values[..., -1].view(-1, 1)] = float('-inf')

        probabilities = torch.softmax(logits, dim=-1)
        return torch.multinomial(probabilities, num_samples=1)

    def _init_weights(self, module):
        """Initialise Linear/Embedding weights from N(0, 0.01) and zero Linear biases."""
        if isinstance(module, (nn.Linear, nn.Embedding)):
            module.weight.data.normal_(mean=0.0, std=0.01)
            if isinstance(module, nn.Linear) and module.bias is not None:
                module.bias.data.zero_()

    @staticmethod
    def _split_into_chunks(indices, subset_fraction):
        """Cut *indices* into consecutive slices of ``int(len * subset_fraction)`` items.

        Always yields at least one slice, so a fraction above 1.0 gives a
        single chunk instead of an empty list.
        """
        subset_size = int(len(indices) * subset_fraction)
        n_chunks = max(1, int(1.0 / subset_fraction))
        return [indices[c * subset_size:(c + 1) * subset_size] for c in range(n_chunks)]

    def _embedding_check_requests_stop(self):
        """Print the embedding std; on an out-of-range value ask the user whether to stop.

        Returns:
            True when the user typed ``stop`` (the model has then been saved
            to ``stopped_model.safetensors``), otherwise False.
        """
        std = self.embed.weight.std().item()
        print(f" Embed std: {std:.4f} (healthy = ~0.02, bad = <0.005 OR > 0.05)")
        if std < 0.005:
            problem = "Embeddings are collapsing. Warning! Attempt to pause."
        elif std > 0.05:
            problem = "Embeddings are too large. Warning! Attempt to pause."
        else:
            return False

        print(" WARNING: embedding problems detected!")
        print(problem)
        if input("Press Enter to continue...").lower() == "stop":
            print("Stopping training loop. Saving model...")
            save_model(self, "stopped_model.safetensors")
            return True
        return False

    def trainingloop(self, data, epochs=50, lr=3e-4, batchsize=32,
                     accumulation_steps=4,
                     subset_fraction=1.0, warmup_steps=2000):
        """Train the model on *data* with gradient accumulation and periodic checkpoints.

        Each epoch trains on one chunk of the shuffled dataset; the chunks are
        reshuffled whenever a full cycle over them has completed.  If
        ``checkpoint_full.pth`` exists, model, optimizer, scheduler, best
        loss and the epoch counter are restored and training continues from
        the saved epoch.  Ctrl-C saves ``emergency_checkpoint.safetensors``.

        Args:
            data: Dataset yielding ``(inputs, targets)`` pairs.
            epochs: Epoch index at which training stops.
            lr: Peak learning rate.
            batchsize: Micro-batch size.
            accumulation_steps: Micro-batches per optimizer step.
            subset_fraction: Fraction of the dataset used per epoch.
            warmup_steps: Optimizer steps of linear warm-up.

        Raises:
            ValueError: If *subset_fraction* is not positive.
        """
        if subset_fraction <= 0:
            raise ValueError("subset_fraction must be greater than 0")

        try:
            self.train()
            batchloss = 1000
            start_epoch = 0

            all_indices = np.arange(len(data))
            np.random.shuffle(all_indices)
            chunks = self._split_into_chunks(all_indices, subset_fraction)

            # Number of middle-stack repetitions is sampled per micro-batch.
            ITER_OPTIONS = [1, 2, 3]
            ITER_WEIGHTS = [0.75, 0.15, 0.1]

            optimizer = get_optimizer(self, lr)
            criterion = nn.CrossEntropyLoss(ignore_index=-1)

            warmup_scheduler = torch.optim.lr_scheduler.LinearLR(
                optimizer, 0.01, end_factor=1.0, total_iters=warmup_steps
            )
            # NOTE(review): no total_iters is given, so LinearLR's default
            # applies and the decay to 0.001x finishes within a handful of
            # steps after warm-up - confirm that this is intended.
            decay_scheduler = torch.optim.lr_scheduler.LinearLR(
                optimizer, start_factor=1.0, end_factor=0.001
            )
            scheduler = torch.optim.lr_scheduler.SequentialLR(
                optimizer,
                schedulers=[warmup_scheduler, decay_scheduler],
                milestones=[warmup_steps]
            )

            if os.path.exists("checkpoint_full.pth"):
                # weights_only=False unpickles arbitrary objects: only load
                # checkpoints produced by this script.
                checkpoint = torch.load("checkpoint_full.pth", weights_only=False, map_location='cpu')
                self.load_state_dict(checkpoint['model'])
                optimizer.load_state_dict(checkpoint['optimizer'])
                scheduler.load_state_dict(checkpoint['scheduler'])
                # Previously read but ignored, which restarted at epoch 0.
                start_epoch = checkpoint['epoch']
                batchloss = checkpoint['batchloss']
                print(f"Loaded checkpoint for training. {scheduler.get_last_lr()}")
                print(f"Restored batchloss: {batchloss:.4f}")

            for epoch in range(start_epoch, epochs):
                # Accumulated on-device to avoid a host sync per micro-batch.
                epoch_loss = torch.zeros((), dtype=torch.float32, device=self.device)
                batches_run = 0
                optimizer.zero_grad(set_to_none=True)

                # Reshuffle *before* picking this epoch's chunk so the whole
                # cycle draws from one consistent partition of the data.
                if epoch % len(chunks) == 0 and epoch > start_epoch:
                    np.random.shuffle(all_indices)
                    chunks = self._split_into_chunks(all_indices, subset_fraction)
                    print(f"Reshuffled data chunks for next cycle")

                chunk_idx = epoch % len(chunks)
                loader = DataLoader(
                    data,
                    batch_size=batchsize,
                    sampler=SubsetRandomSampler(chunks[chunk_idx]),
                    num_workers=0,
                    pin_memory=True,
                    drop_last=True
                )

                for i, (batch_inputs, batch_targets) in enumerate(loader):
                    step = i + 1
                    if i == 0:
                        print(f"Epoch {epoch} | Chunk {chunk_idx+1}/{len(chunks)} | "
                              f"Starting training...")

                    if step % (accumulation_steps * 1000) == 0:
                        if self._embedding_check_requests_stop():
                            return

                    batch_inputs = batch_inputs.to(self.device, non_blocking=True)
                    batch_targets = batch_targets.to(self.device, non_blocking=True)

                    with torch.amp.autocast('cuda', dtype=torch.bfloat16):
                        num_iter = random.choices(ITER_OPTIONS, weights=ITER_WEIGHTS, k=1)[0]
                        logits = self.forward_training(batch_inputs, iters=num_iter)
                        lm_loss = criterion(
                            logits.view(-1, logits.size(-1)),
                            batch_targets.reshape(-1)
                        )

                    # Scale so the accumulated gradient is an average.
                    lm_loss = lm_loss / accumulation_steps
                    lm_loss.backward()

                    epoch_loss += lm_loss.detach().float()
                    batches_run += 1

                    if step % accumulation_steps == 0:
                        torch.nn.utils.clip_grad_norm_(self.parameters(), max_norm=1.0, foreach=True)
                        optimizer.step()
                        scheduler.step()
                        optimizer.zero_grad(set_to_none=True)

                    if step % (accumulation_steps * 4) == 0:
                        current = lm_loss.detach().item()
                        if batchloss > current:
                            batchloss = current
                            print(f"New best loss: {batchloss * accumulation_steps:.4f}. Saving model...")
                            save_model(self, f"best_model.safetensors")

                    if step % (accumulation_steps * 100) == 0:
                        print(f"Epoch {epoch} | Batch {step // accumulation_steps} | loss={lm_loss.item() * accumulation_steps:.4f}")
                        # Write to a temp file first so an interrupt during
                        # the write cannot corrupt the resumable checkpoint.
                        tmp_checkpoint = "checkpoint_full.pth.tmp"
                        torch.save({
                            'model': self.state_dict(),
                            'optimizer': optimizer.state_dict(),
                            'scheduler': scheduler.state_dict(),
                            'epoch': epoch,
                            'chunk_idx': chunk_idx,
                            'batchloss': batchloss,
                        }, tmp_checkpoint)
                        os.replace(tmp_checkpoint, "checkpoint_full.pth")
                        print(f"Saved state.")

                    if step % (accumulation_steps * 50000) == 0:
                        save_model(self, "model.safetensors")

                # Undo the accumulation scaling; NaN when no batch was run.
                if batches_run:
                    avg_loss = epoch_loss.item() * accumulation_steps / batches_run
                else:
                    avg_loss = float('nan')
                timestamp = datetime.now().strftime("%H:%M:%S")
                print(f"[{timestamp}] Epoch {epoch:3d} | "
                      f"Loss: {avg_loss:.4f} | "
                      f"Chunk: {chunk_idx+1}/{len(chunks)}")
        except KeyboardInterrupt:
            print("\n[!] Manual Interrupt detected. Saving safety checkpoint...")
            # Must be `self`: the module-level name `model` is an unrelated import.
            save_model(self, "emergency_checkpoint.safetensors")
            print("Done. Safe to exit.")

    def _attn_ffn_block(self, x, norm, attn, ffn, alpha, local=False):
        """Norm -> attention -> feed-forward on the attention output -> scaled residual.

        With ``local=True`` attention and the residual update only touch the
        last ``local_seq_len`` positions; earlier positions pass through
        normalised but otherwise unchanged.
        """
        x = norm(x)
        if not local:
            attn_out = attn(x, rope=self.rope)
            return x + (alpha * (attn_out + ffn(attn_out)))

        tail = slice(-self.local_seq_len, None)
        attn_out = attn(x[:, tail], rope=self.rope)
        updated = x.clone()
        updated[:, tail] = x[:, tail] + (alpha * (attn_out + ffn(attn_out)))
        return updated

    def _attn_mol_block(self, x, norm, attn, mol, y_prev, anchor, iter_tensor, local=False):
        """Middle-stack layer: like ``_attn_ffn_block`` but with a MoLLayer.

        The MoLLayer receives the attention output together with the state at
        loop entry, the anchor and the iteration index (sliced to the local
        window when ``local=True``).  The residual is scaled by ``alpha_loop``.
        """
        x = norm(x)
        if not local:
            attn_out = attn(x, rope=self.rope)
            mol_out = mol(attn_out, y_prev, anchor, iter_tensor)
            return x + (self.alpha_loop * (attn_out + mol_out))

        tail = slice(-self.local_seq_len, None)
        attn_out = attn(x[:, tail], rope=self.rope)
        mol_out = mol(attn_out, y_prev[:, tail], anchor[:, tail], iter_tensor)
        updated = x.clone()
        updated[:, tail] = x[:, tail] + (self.alpha_loop * (attn_out + mol_out))
        return updated

    def _compute_logits(self, input_ids, iters):
        """Run the full network once and return logits for every kept position.

        Shared by ``forward_training`` and ``forward`` so the two code paths
        cannot drift apart.

        Args:
            input_ids: Token ids; only the last ``sequencelength`` columns are used.
            iters: Number of repetitions of the middle stack.
        """
        ids = input_ids[:, -self.sequencelength:]
        w = self.embed(ids.long().to(self.device))

        # Pre block 1: its MoLLayer also receives the raw embeddings.
        y = self.norm0(w)
        attn_out = self.layerPreA1(y, rope=self.rope)
        ffn_out = self.preffn1(attn_out, w, w, 1)
        y = y + (self.alpha_pre * (attn_out + ffn_out))

        # Pre block 2 (local window only).
        y = self._attn_ffn_block(y, self.norm1, self.layerPreA2, self.preffn2,
                                 self.alpha_pre, local=True)

        # Gated memory lookup.
        mem = self.Engram(y, ids)
        gate = torch.sigmoid(self.mem_gate(y))
        y = y + (gate * mem)
        z = y
        linguistic_anchor = y

        for j in range(iters):
            y_prev = y
            y = (self.M1(y, z, linguistic_anchor) * self.alpha_mem) + y
            y = self.Mnorm(y)

            iter_tensor = torch.tensor(j, device=self.device)

            # Layer 1 differs from the others: its MoLLayer reads the
            # normalised input rather than the attention output.
            y = self.normM1(y)
            attn_out = self.layerA1(y, rope=self.rope)
            moe_out = self.layerMi1(y, y_prev, linguistic_anchor, iter_tensor)
            y = y + (self.alpha_loop * (moe_out + attn_out))

            y = self._attn_mol_block(y, self.normM2, self.layerA2, self.layerMi2,
                                     y_prev, linguistic_anchor, iter_tensor, local=True)
            y = self._attn_mol_block(y, self.normM3, self.layerA3, self.layerMi3,
                                     y_prev, linguistic_anchor, iter_tensor)
            y = self._attn_mol_block(y, self.normM4, self.layerA4, self.layerMi4,
                                     y_prev, linguistic_anchor, iter_tensor, local=True)
            y = self._attn_mol_block(y, self.normM5, self.layerA5, self.layerMi5,
                                     y_prev, linguistic_anchor, iter_tensor)

            # z feeds the post blocks; y (gated, plus the loop-entry state)
            # feeds the next repetition.
            z = y
            y = ((self.scratchpad_gate(y)) * y) + y_prev

        z = self._attn_ffn_block(z, self.norm2, self.layerPostA1, self.postffn1, self.alpha_post)
        z = self._attn_ffn_block(z, self.norm3, self.layerPostA2, self.postffn2, self.alpha_post)

        return self.layerO1(z)

    def forward_training(self, input_ids, iters=3):
        """Return logits for every position of *input_ids* (teacher-forced pass).

        Args:
            input_ids: Token ids; truncated to the last ``sequencelength`` columns.
            iters: Number of repetitions of the middle stack.
        """
        return self._compute_logits(input_ids, iters)

    def forward(self, input_ids, max_tokens, iters=3, top_k=10, temperature=1.0, rep_penalty=1.5):
        """Generate up to *max_tokens* tokens autoregressively.

        Each step re-runs the network on the (truncated) context, samples a
        token with ``pick_word`` and appends it to the context.  Generation
        stops after the end-of-text token, which is included in the result.

        Returns:
            List of the sampled token tensors, in generation order.
        """
        generated_tokens = []

        for _ in range(max_tokens):
            input_ids = input_ids[:, -self.sequencelength:]
            logits = self._compute_logits(input_ids, iters)

            token1 = self.pick_word(logits[:, -1, :], k=top_k, temperature=temperature,
                                    prev_tokens=generated_tokens, rep_penalty=rep_penalty)
            generated_tokens.append(token1)
            input_ids = torch.cat([input_ids, token1], dim=1)

            # The id comparison alone identifies end-of-text; the former
            # decode-and-compare check was redundant with it.
            if token1.item() == enc.eot_token:
                break

        return generated_tokens
|
|
|
|
|
|
|
def initmodel(device):
    """Create a bfloat16 ``biggerbrain`` on *device*, compiled unless the device is CPU.

    Args:
        device: Target device, as a string or a ``torch.device``.

    Returns:
        The model; wrapped by ``torch.compile(dynamic=True)`` for non-CPU devices.
    """
    model = biggerbrain(device).to(device).to(torch.bfloat16)
    # Compare the device *type* so that torch.device('cpu') is recognised as
    # CPU too, not only the literal string 'cpu'.
    if torch.device(device).type != 'cpu':
        model = torch.compile(model, dynamic=True)
    return model
|
|
|
def think(prompt, model, max_length=100, iter=3, top_k=10, temperature=1.0, raw=False, rep_penalty=1.5):
    """Generate a text continuation for *prompt*.

    Args:
        prompt: User text.
        model: A ``biggerbrain`` or its ``torch.compile`` wrapper (unwrapped
            through ``_orig_mod`` before use).
        max_length: Maximum number of tokens to generate.
        iter: Number of repetitions of the model's middle stack.
        top_k: Top-k cutoff used when sampling.
        temperature: Sampling temperature.
        raw: When True the prompt is used verbatim; otherwise it is wrapped
            in the ``User: ... Assistant:`` template.
        rep_penalty: Repetition penalty used when sampling.

    Returns:
        The decoded generated text (the prompt itself is not included).

    Note:
        The model is switched to eval mode and left in it.
    """
    model = model._orig_mod if hasattr(model, '_orig_mod') else model
    formatted = prompt if raw else f"User: \n\n{prompt}\n\nAssistant:"

    input_ids = torch.tensor([enc.encode(formatted, allowed_special={'<|endoftext|>'})]).to(model.device)
    model.eval()
    with torch.no_grad():
        generated_tokens = model.forward(
            input_ids,
            max_tokens=max_length,
            iters=iter,
            top_k=top_k,
            temperature=temperature,
            rep_penalty=rep_penalty
        )

    # forward() returns single-element tensors; hand the tokenizer plain ints.
    return enc.decode([int(tok) for tok in generated_tokens])
|
|
|
def print_parameter_breakdown(model):
    """Print the element count of every trainable parameter and their total.

    A ``torch.compile`` wrapper is unwrapped through ``_orig_mod`` first.
    Parameters with more than one million elements are flagged in the
    listing; parameters with ``requires_grad=False`` are skipped entirely.

    Args:
        model: Module (or compiled wrapper) to inspect.
    """
    target = getattr(model, '_orig_mod', model)

    print("--- Parameter Breakdown ---")
    running_total = 0
    trainable_sizes = (
        (label, tensor.numel())
        for label, tensor in target.named_parameters()
        if tensor.requires_grad
    )
    for label, count in trainable_sizes:
        running_total += count
        flag = "MASSIVE LAYER DETECTED -> " if count > 1_000_000 else ""
        print(f"{flag}{label}: {count:,} parameters")

    print(f"---------------------------\nTrue Total: {running_total:,}")
|
|
|