# Biggerbrain2_136m / biggerbrain.py
# Skull18500's picture
# Upload 2 files
# e00bb21 verified
from datetime import datetime
import os
from pyexpat import model
from httpx import get
os.environ["PYTORCH_CUDA_ALLOC_CONF"] = "max_split_size_mb:512"
#os.environ["PYTORCH_CUDA_ALLOC_CONF"] = "expandable_segments=True"
import torch
import torch.nn as nn
import torch.optim as optim
import random
import tiktoken
import ai_extras as AI_ex
from torch.utils.data import DataLoader, SubsetRandomSampler, TensorDataset, random_split
import numpy as np
import torch._dynamo
import bitsandbytes as bnb
import torch.utils.checkpoint as cp
from safetensors.torch import save_model as sf_save_model
# ---- Global PyTorch backend configuration (runs once at import time) ----

# Allow the flash-attention kernel for scaled-dot-product attention.
torch.backends.cuda.enable_flash_sdp(True)
torch._dynamo.config.suppress_errors = False # make errors visible
torch._dynamo.config.verbose = True # print detailed compilation logs
#torch._inductor.config.triton.cudagraphs = True
torch.backends.cuda.matmul.allow_bf16_reduced_precision_reduction = True
torch.backends.cuda.matmul.allow_fp16_reduced_precision_reduction = True
torch.set_float32_matmul_precision('medium')
# `enable_mem_efficient_sdp` is a function: assigning `True` to it (as the
# previous code did) replaces the function with a bool and never reaches the
# backend, so it has to be called.
torch.backends.cuda.enable_mem_efficient_sdp(True)
# Same story for this toggle; it only exists in newer PyTorch releases, so it
# is guarded to keep older installs importable.
if hasattr(torch.backends.cuda, "allow_fp16_bf16_reduction_math_sdp"):
    torch.backends.cuda.allow_fp16_bf16_reduction_math_sdp(True)

# Shared GPT-2 BPE tokenizer used by the model, the training loop and think().
enc = tiktoken.get_encoding("gpt2")
def save_model(model, filename):
    """Write ``model`` to ``filename`` as safetensors via a temp-file swap.

    The weights are first written to ``filename + ".tmp"``. If ``filename``
    already exists it is moved to ``filename + ".old"`` (replacing any earlier
    ``.old`` file), and the temp file is then renamed into place.

    Any ``OSError`` raised during the swap is reported on stdout rather than
    propagated; in that case the fresh weights remain in the ``.tmp`` file.
    """
    staging_path = filename + ".tmp"
    backup_path = filename + ".old"

    # safetensors' save_model (rather than save_file) copes with shared
    # tensors such as the tied embed / layerO1 weight.
    sf_save_model(model, staging_path)
    print("saving model...")

    # Rename dance used to work around Windows file locks (error 1224).
    try:
        had_previous = os.path.exists(filename)
        if had_previous and os.path.exists(backup_path):
            os.remove(backup_path)
        if had_previous:
            os.rename(filename, backup_path)
        os.rename(staging_path, filename)
    except OSError as e:
        print(f"Windows I/O Lock: {e}. Checkpoint kept at {staging_path}")
def get_optimizer(self, lr):
    """Build the bitsandbytes ``AdamW8bit`` optimizer for a model.

    Trainable parameters are sorted by name into three groups:

    * names containing ``alpha`` or ``mem_gate``: no weight decay, tagged
      with ``optim_bits=32``;
    * names containing ``bias``, ``norm`` (case-insensitive), ``embed``,
      ``layerO1`` or ``engram``: no weight decay;
    * everything else: weight decay 0.04.

    Args:
        self: the module whose ``named_parameters()`` are optimized (this is
            a free function; the model is passed explicitly).
        lr: base learning rate.

    Returns:
        The configured ``bnb.optim.AdamW8bit`` instance.
    """

    def group_key(param_name):
        # Scalars / memory gate take priority over every other rule.
        if any(tag in param_name for tag in ('alpha', 'mem_gate')):
            return 'scalars'
        # NOTE(review): 'engram' is matched case-sensitively while the model
        # attribute is spelled `Engram` -- confirm this catches the intended
        # parameters. Changing the grouping alters param-group sizes, which
        # would invalidate previously saved optimizer state.
        exempt = (
            'bias' in param_name
            or 'norm' in param_name.lower()
            or 'embed' in param_name
            or 'layerO1' in param_name
            or 'engram' in param_name
        )
        return 'no_decay' if exempt else 'decay'

    grouped = {'decay': [], 'no_decay': [], 'scalars': []}
    for param_name, tensor in self.named_parameters():
        if tensor.requires_grad:
            grouped[group_key(param_name)].append(tensor)

    # NOTE(review): verify that bitsandbytes honours a per-group `optim_bits`
    # key; it may only read that setting from the optimizer arguments or from
    # GlobalOptimManager overrides -- TODO confirm.
    optim_groups = [
        {'params': grouped['decay'], 'weight_decay': 0.04},
        {'params': grouped['no_decay'], 'weight_decay': 0.0},
        {'params': grouped['scalars'], 'weight_decay': 0.0, 'optim_bits': 32},
    ]
    return bnb.optim.AdamW8bit(optim_groups, lr=lr, betas=(0.9, 0.95))
class biggerbrain(nn.Module):
    """Decoder-style language model with a looped middle stack.

    Pipeline implemented by :meth:`forward_training`:

    1. token embedding (weights tied with the output projection ``layerO1``);
    2. two "pre" blocks (the second only updates the last ``local_seq_len``
       positions);
    3. a gated ``Engram`` memory read;
    4. ``iters`` passes through five middle blocks (blocks 2 and 4 only
       update the last ``local_seq_len`` positions);
    5. two "post" blocks and the tied output projection.

    :meth:`forward` wraps that pipeline in an autoregressive sampling loop and
    :meth:`trainingloop` trains the model with gradient accumulation.
    """

    def __init__(self, device, sequence_length=640):
        """Create all sub-modules.

        Args:
            device: device the model will live on; stored as ``self.device``
                and used when inputs / helper tensors are created.
            sequence_length: maximum context length; longer inputs are
                cropped to their last ``sequence_length`` tokens.
        """
        super().__init__()
        self.dim1 = 768
        self.ffndim = int(self.dim1 * 2.7)
        self.T_heads = 8
        self.P_heads = 12
        self.kv_heads = 4
        self.sequencelength = sequence_length
        self.device = device
        self.local_seq_len = 256
        self.MLA_dim = 384
        self.debugprints = False
        vocab_size = enc.max_token_value + 1

        def attention(n_heads):
            return AI_ex.RoPEAttention(self.dim1, n_heads, self.kv_heads, bottleneck=self.MLA_dim)

        def swiglu_ffn():
            # First Linear is twice as wide as the second one's input; the
            # SwiGLU in between is assumed to halve the width -- TODO confirm.
            return nn.Sequential(
                nn.Linear(self.dim1, self.ffndim * 2),
                AI_ex.SwiGLU(),
                nn.Linear(self.ffndim, self.dim1),
            )

        # NOTE: sub-modules are created in exactly the original order.
        # Parameter order determines how saved optimizer state is mapped back
        # onto parameters, so do not reorder these assignments.

        # A bit of extra rope for dependencies that exceed the training context.
        self.rope = AI_ex.RoPE(self.dim1 // self.T_heads, max_seq_len=self.sequencelength + 20)
        self.embed = nn.Embedding(vocab_size, self.dim1)
        self.Engram = AI_ex.engram(self.dim1, self.kv_heads, 3, memory_size=12192, bottleneck=self.MLA_dim)
        self.layerO1 = nn.Linear(self.dim1, vocab_size, bias=False)
        self.mem_gate = nn.Linear(self.dim1, 1, bias=True)

        # Block weighting parameters (learned residual scales).
        self.alpha_pre = nn.Parameter(torch.tensor(0.5))
        self.alpha_loop = nn.Parameter(torch.tensor(0.5))
        self.alpha_post = nn.Parameter(torch.tensor(0.5))
        self.alpha_mem = nn.Parameter(torch.tensor(0.5))

        # Pre block 1
        self.layerPreA1 = attention(self.P_heads)
        self.preffn1 = AI_ex.MoLLayer(self.dim1, self.ffndim, n_experts=3, max_iter=1)
        self.normPre1 = AI_ex.RMSNorm(self.dim1)
        # Pre block 2 (applied to the local window only)
        self.layerPreA2 = attention(self.P_heads)
        self.preffn2 = swiglu_ffn()
        self.normPre2 = AI_ex.RMSNorm(self.dim1)
        # Post block 1
        self.layerPostA1 = attention(self.P_heads)
        self.postffn1 = swiglu_ffn()
        self.normPost1 = AI_ex.RMSNorm(self.dim1)
        # Post block 2
        self.layerPostA2 = attention(self.P_heads)
        self.postffn2 = swiglu_ffn()
        self.normPost2 = AI_ex.RMSNorm(self.dim1)

        # Middle (looped) attention; A2 / A4 see the local window only.
        self.layerA1 = attention(self.T_heads)
        self.layerA2 = attention(self.T_heads)
        self.layerA3 = attention(self.T_heads)
        self.layerA4 = attention(self.T_heads)
        self.layerA5 = attention(self.T_heads)
        # Middle mixture layers; Mi2 / Mi4 see the local window only.
        self.layerMi1 = AI_ex.MoLLayer(self.dim1, self.ffndim, n_experts=2)
        self.layerMi2 = AI_ex.MoLLayer(self.dim1, self.ffndim, n_experts=2)
        self.layerMi3 = AI_ex.MoLLayer(self.dim1, self.ffndim, n_experts=2)
        self.layerMi4 = AI_ex.MoLLayer(self.dim1, self.ffndim, n_experts=2)
        self.layerMi5 = AI_ex.MoLLayer(self.dim1, self.ffndim, n_experts=2)

        self.norm0 = AI_ex.RMSNorm(self.dim1)
        self.norm1 = AI_ex.RMSNorm(self.dim1)
        self.norm2 = AI_ex.RMSNorm(self.dim1)
        self.norm3 = AI_ex.RMSNorm(self.dim1)
        self.normM1 = AI_ex.RMSNorm(self.dim1)
        self.normM2 = AI_ex.RMSNorm(self.dim1)
        self.normM3 = AI_ex.RMSNorm(self.dim1)
        self.normM4 = AI_ex.RMSNorm(self.dim1)
        self.normM5 = AI_ex.RMSNorm(self.dim1)

        self.scratchpad_gate = nn.Linear(self.dim1, 1, bias=True)
        self.Mnorm = AI_ex.RMSNorm(self.dim1)
        self.M1 = AI_ex.custom_mem(self.dim1, self.T_heads)  # memory integration
        self.MM2 = AI_ex.GatedResidual(self.dim1)  # linguistic anchor pull

        self.layerO1.weight = self.embed.weight  # tied weights
        self.apply(self._init_weights)

        # The gate initialisation must run AFTER the generic init above:
        # `_init_weights` touches every nn.Linear and would otherwise reset
        # these weights/biases to its own defaults.
        for gate in (self.mem_gate, self.scratchpad_gate):
            nn.init.zeros_(gate.weight)
            nn.init.constant_(gate.bias, -1.0)

    def pick_word(self, output, k=50, temperature=0.8,
                  prev_tokens=None, rep_penalty=1.1):
        """Sample one token id per row from a batch of next-token logits.

        Args:
            output: logits of shape ``(batch, vocab)``; not modified.
            k: number of highest logits kept before sampling. Values below 1
                are treated as 1 and values above the vocabulary size are
                clamped to it.
            temperature: divisor applied to the logits before sampling.
            prev_tokens: optional sequence of previously generated tokens
                (ints, single-element tensors, or lists of either). The last
                64 entries are penalized.
            rep_penalty: positive logits of penalized tokens are divided by
                this value, non-positive ones multiplied by it.

        Returns:
            ``LongTensor`` of shape ``(batch, 1)`` with the sampled ids.
        """
        logits = output / (temperature + 1e-8)
        vocab_size = logits.size(-1)

        # Permanently blacklist chat-format artifact tokens.
        # These have high probability from training data contamination
        # but should never appear in normal prose output.
        bad_tokens = [
            #25, # ":" single colon
            #3712, # "::" double colon
            #1058, # ":" alternate encoding
        ]
        for tok in bad_tokens:
            if 0 <= tok < vocab_size:
                logits[:, tok] = float('-inf')

        # Repetition penalty over the most recent tokens.
        if prev_tokens is not None:
            recent = set()
            for t in prev_tokens[-64:]:
                if isinstance(t, torch.Tensor):
                    recent.update(int(v) for v in t.reshape(-1).tolist())
                elif isinstance(t, (list, tuple)):
                    recent.update(int(v) for v in t)
                else:
                    recent.add(int(t))
            # Out-of-range ids are ignored (negative ids used to wrap around).
            valid = [tok for tok in recent if 0 <= tok < vocab_size]
            if valid:
                idx = torch.tensor(valid, dtype=torch.long, device=logits.device)
                picked = logits[:, idx]
                logits[:, idx] = torch.where(
                    picked > 0, picked / rep_penalty, picked * rep_penalty
                )

        # Top-K filtering: everything below the k-th largest logit is removed.
        k = max(1, min(int(k), vocab_size))
        kth_value = torch.topk(logits, k).values[..., -1, None]
        logits = logits.masked_fill(logits < kth_value, float('-inf'))

        # Sample
        probabilities = torch.softmax(logits, dim=-1)
        return torch.multinomial(probabilities, num_samples=1)

    def _init_weights(self, module):
        """Initialise Linear/Embedding weights ~ N(0, 0.01) and zero biases.

        Intended for ``self.apply``. Note the std is 0.01, i.e. half of the
        0.02 used by GPT-2.
        """
        if isinstance(module, (nn.Linear, nn.Embedding)):
            module.weight.data.normal_(mean=0.0, std=0.01)
            if isinstance(module, nn.Linear) and module.bias is not None:
                module.bias.data.zero_()

    @staticmethod
    def _split_indices(all_indices, subset_size, n_chunks):
        """Cut ``all_indices`` into ``n_chunks`` consecutive slices of ``subset_size``."""
        return [
            all_indices[c * subset_size:(c + 1) * subset_size]
            for c in range(n_chunks)
        ]

    def _embedding_check_requests_stop(self):
        """Print embedding statistics; return True if the user typed "stop".

        When the embedding std is outside [0.005, 0.05] the user is prompted
        on stdin. Answering ``stop`` saves ``stopped_model.safetensors``.
        """
        std = self.embed.weight.std().item()
        print(f" Embed std: {std:.4f} (healthy = ~0.02, bad = <0.005 OR > 0.05)")
        if not (std < 0.005 or std > 0.05):
            return False
        print(" WARNING: embedding problems detected!")
        if std < 0.005:
            print("Embeddings are collapsing. Warning! Attempt to pause.")
        else:
            print("Embeddings are too large. Warning! Attempt to pause.")
        if input("Press Enter to continue...").lower() != "stop":
            return False
        print("Stopping training loop. Saving model...")
        save_model(self, "stopped_model.safetensors")
        return True

    def trainingloop(self, data, epochs=50, lr=3e-4, batchsize=32,
                     accumulation_steps=4,
                     subset_fraction=1.0, warmup_steps=2000, decay_steps=None):
        """Train the model on ``data`` with gradient accumulation.

        Args:
            data: indexable dataset yielding ``(inputs, targets)`` pairs.
            epochs: total number of epochs. When ``checkpoint_full.pth``
                exists, training resumes at the epoch stored in it.
            lr: peak learning rate.
            batchsize: micro-batch size.
            accumulation_steps: micro-batches per optimizer step (>= 1).
            subset_fraction: fraction of the dataset used per epoch, in
                ``(0, 1]``. The dataset is cut into ``int(1 / fraction)``
                chunks that are cycled through, one per epoch.
            warmup_steps: optimizer steps of linear warm-up (1% -> 100% lr).
            decay_steps: optimizer steps over which the lr then decays
                linearly to 0.1% of ``lr``. ``None`` (default) spreads the
                decay over the remaining optimizer steps of the run.

        Side effects:
            Writes ``best_model.safetensors``, ``checkpoint_full.pth`` and
            ``model.safetensors`` periodically, and
            ``emergency_checkpoint.safetensors`` on Ctrl-C.

        Raises:
            ValueError: for an invalid ``subset_fraction`` or
                ``accumulation_steps``.
        """
        if not 0.0 < subset_fraction <= 1.0:
            raise ValueError("subset_fraction must be in the interval (0, 1]")
        if accumulation_steps < 1:
            raise ValueError("accumulation_steps must be >= 1")

        try:
            self.train()
            batchloss = 1000  # best (accumulation-scaled) micro-batch loss so far
            start_epoch = 0

            # Build and shuffle the full index list once.
            dataset_size = len(data)
            all_indices = np.arange(dataset_size)
            np.random.shuffle(all_indices)

            # How many samples per epoch, and how many chunks to cycle through.
            subset_size = int(dataset_size * subset_fraction)
            n_chunks = max(1, int(1.0 / subset_fraction))
            chunks = self._split_indices(all_indices, subset_size, n_chunks)

            # Number of loop passes is sampled per micro-batch.
            ITER_OPTIONS = [1, 2, 3]
            ITER_WEIGHTS = [0.75, 0.15, 0.1]

            optimizer = get_optimizer(self, lr)
            criterion = nn.CrossEntropyLoss(ignore_index=-1)

            if decay_steps is None:
                # drop_last=True, so only full micro-batches count.
                updates_per_epoch = (subset_size // batchsize) // accumulation_steps
                decay_steps = epochs * updates_per_epoch - warmup_steps
            decay_steps = max(1, int(decay_steps))

            warmup_scheduler = torch.optim.lr_scheduler.LinearLR(
                optimizer, start_factor=0.01, end_factor=1.0, total_iters=warmup_steps
            )
            # total_iters must be given explicitly: LinearLR defaults to 5,
            # which would finish the whole decay 5 steps after warm-up.
            decay_scheduler = torch.optim.lr_scheduler.LinearLR(
                optimizer, start_factor=1.0, end_factor=0.001, total_iters=decay_steps
            )
            scheduler = torch.optim.lr_scheduler.SequentialLR(
                optimizer,
                schedulers=[warmup_scheduler, decay_scheduler],
                milestones=[warmup_steps]
            )

            if os.path.exists("checkpoint_full.pth"):
                # SECURITY: weights_only=False unpickles arbitrary objects;
                # only load checkpoints written by this script.
                checkpoint = torch.load("checkpoint_full.pth", weights_only=False, map_location='cpu')
                self.load_state_dict(checkpoint['model'])
                optimizer.load_state_dict(checkpoint['optimizer'])
                scheduler.load_state_dict(checkpoint['scheduler'])
                start_epoch = checkpoint['epoch']
                batchloss = checkpoint['batchloss']
                print(f"Loaded checkpoint for training. {scheduler.get_last_lr()}")
                print(f"Restored batchloss: {batchloss:.4f}")
                if start_epoch >= epochs:
                    print(f"Checkpoint is already at epoch {start_epoch}; "
                          f"raise `epochs` above that to keep training.")

            for epoch in range(start_epoch, epochs):
                epoch_loss = 0.0  # becomes a tensor; summed without GPU syncs
                batches_run = 0
                optimizer.zero_grad(set_to_none=True)

                chunk_idx = epoch % len(chunks)
                if chunk_idx == 0 and epoch > 0:
                    np.random.shuffle(all_indices)
                    chunks = self._split_indices(all_indices, subset_size, n_chunks)
                    print(f"Reshuffled data chunks for next cycle")

                loader = DataLoader(
                    data,
                    batch_size=batchsize,
                    sampler=SubsetRandomSampler(chunks[chunk_idx]),
                    num_workers=0,
                    pin_memory=True,
                    drop_last=True
                )

                for i, (batch_inputs, batch_targets) in enumerate(loader):
                    step = i + 1
                    if i == 0:
                        print(f"Epoch {epoch} | Chunk {chunk_idx+1}/{len(chunks)} | "
                              f"Starting training...")

                    if step % (accumulation_steps * 1000) == 0:
                        if self._embedding_check_requests_stop():
                            return

                    batch_inputs = batch_inputs.to(self.device, non_blocking=True)
                    batch_targets = batch_targets.to(self.device, non_blocking=True)

                    with torch.amp.autocast('cuda', dtype=torch.bfloat16):
                        num_iter = random.choices(ITER_OPTIONS, weights=ITER_WEIGHTS, k=1)[0]
                        logits = self.forward_training(batch_inputs, iters=num_iter)
                        lm_loss = criterion(
                            logits.view(-1, logits.size(-1)),
                            batch_targets.reshape(-1)
                        )

                    # Track the unscaled loss for the end-of-epoch average.
                    epoch_loss = epoch_loss + lm_loss.detach().float()
                    batches_run += 1

                    lm_loss = lm_loss / accumulation_steps
                    lm_loss.backward()

                    if step % accumulation_steps != 0:
                        continue

                    torch.nn.utils.clip_grad_norm_(self.parameters(), max_norm=1.0, foreach=True)
                    optimizer.step()
                    scheduler.step()
                    optimizer.zero_grad(set_to_none=True)

                    if step % (accumulation_steps * 4) == 0:
                        current = lm_loss.detach().item()
                        if batchloss > current:
                            batchloss = current
                            print(f"New best loss: {batchloss * accumulation_steps:.4f}. Saving model...")
                            save_model(self, f"best_model.safetensors")

                    if step % (accumulation_steps * 100) == 0:
                        print(f"Epoch {epoch} | Batch {step // accumulation_steps} | loss={lm_loss.item() * accumulation_steps:.4f}")
                        torch.save({
                            'model': self.state_dict(),
                            'optimizer': optimizer.state_dict(),
                            'scheduler': scheduler.state_dict(),
                            'epoch': epoch,
                            'chunk_idx': chunk_idx,
                            'batchloss': batchloss,
                        }, "checkpoint_full.pth")
                        print(f"Saved state.")

                    if step % (accumulation_steps * 50000) == 0:
                        save_model(self, "model.safetensors")

                avg_loss = float(epoch_loss) / batches_run if batches_run else float('nan')
                timestamp = datetime.now().strftime("%H:%M:%S")
                print(f"[{timestamp}] Epoch {epoch:3d} | "
                      f"Loss: {avg_loss:.4f} | "
                      f"Chunk: {chunk_idx+1}/{len(chunks)}")
        except KeyboardInterrupt:
            print("\n[!] Manual Interrupt detected. Saving safety checkpoint...")
            save_model(self, "emergency_checkpoint.safetensors")
            print("Done. Safe to exit.")

    def _local_residual(self, base, update, scale):
        """Return ``base`` with ``scale * update`` added to its last ``local_seq_len`` positions."""
        window = self.local_seq_len
        out = base.clone()  # positions outside the window are kept as is
        out[:, -window:] = base[:, -window:] + (scale * update)
        return out

    def forward_training(self, input_ids, iters=3):
        """Compute next-token logits for every position of ``input_ids``.

        Args:
            input_ids: integer tensor indexed as ``[batch, position]``; only
                the last ``sequencelength`` positions are used.
            iters: number of passes through the five middle blocks.

        Returns:
            The output of the tied projection ``layerO1`` for each position.
        """
        window = self.local_seq_len
        input_ids = input_ids[:, -self.sequencelength:]
        w = self.embed(input_ids.long())

        # Pre block 1
        y = self.norm0(w)
        attn_out = self.layerPreA1(y, rope=self.rope)
        ffn_out = self.preffn1(attn_out, w, w, 1)
        y = y + (self.alpha_pre * (attn_out + ffn_out))

        # Pre block 2 (local window only)
        y = self.norm1(y)
        attn_out = self.layerPreA2(y[:, -window:], rope=self.rope)
        ffn_out = self.preffn2(attn_out)
        y = self._local_residual(y, attn_out + ffn_out, self.alpha_pre)

        # Engram block: a learned gate decides how much memory to mix in.
        mem = self.Engram(y, input_ids)
        gate = torch.sigmoid(self.mem_gate(y))
        y = y + (gate * mem)

        z = y
        linguistic_anchor = y
        for j in range(iters):
            y_prev = y
            y = (self.M1(y, z, linguistic_anchor) * self.alpha_mem) + y
            y = self.Mnorm(y)
            iter_tensor = torch.tensor(j, device=self.device)

            # ---- BLOCK 1 ---- (mixture layer reads the normed input, not attn_out)
            y = self.normM1(y)
            attn_out = self.layerA1(y, rope=self.rope)
            moe_out = self.layerMi1(y, y_prev, linguistic_anchor, iter_tensor)
            y = y + (self.alpha_loop * (moe_out + attn_out))

            # ---- BLOCK 2 ---- (local window only)
            y = self.normM2(y)
            attn_out = self.layerA2(y[:, -window:], rope=self.rope)
            moe_out = self.layerMi2(attn_out, y_prev[:, -window:], linguistic_anchor[:, -window:], iter_tensor)
            y = self._local_residual(y, attn_out + moe_out, self.alpha_loop)

            # ---- BLOCK 3 ----
            y = self.normM3(y)
            attn_out = self.layerA3(y, rope=self.rope)
            moe_out = self.layerMi3(attn_out, y_prev, linguistic_anchor, iter_tensor)
            y = y + (self.alpha_loop * (attn_out + moe_out))

            # ---- BLOCK 4 ---- (local window only)
            y = self.normM4(y)
            attn_out = self.layerA4(y[:, -window:], rope=self.rope)
            moe_out = self.layerMi4(attn_out, y_prev[:, -window:], linguistic_anchor[:, -window:], iter_tensor)
            y = self._local_residual(y, attn_out + moe_out, self.alpha_loop)

            # ---- BLOCK 5 ----
            y = self.normM5(y)
            attn_out = self.layerA5(y, rope=self.rope)
            moe_out = self.layerMi5(attn_out, y_prev, linguistic_anchor, iter_tensor)
            y = y + (self.alpha_loop * (attn_out + moe_out))

            z = y  # output of this pass feeds the post blocks
            y = ((self.scratchpad_gate(y)) * y) + y_prev  # gated scratchpad residual

        # Post block 1
        z = self.norm2(z)
        attn_out = self.layerPostA1(z, rope=self.rope)
        ffn_out = self.postffn1(attn_out)
        z = z + (self.alpha_post * (attn_out + ffn_out))

        # Post block 2
        z = self.norm3(z)
        attn_out = self.layerPostA2(z, rope=self.rope)
        ffn_out = self.postffn2(attn_out)
        z = z + (self.alpha_post * (attn_out + ffn_out))

        # Weight-tied output layer (shares its weight with the embedding).
        return self.layerO1(z)

    def forward(self, input_ids, max_tokens, iters=3, top_k=10, temperature=1.0, rep_penalty=1.5):
        """Autoregressively sample up to ``max_tokens`` tokens.

        Each step re-runs :meth:`forward_training` on the (cropped) context
        and samples from the last position with :meth:`pick_word`. Generation
        stops early once the end-of-text token is produced (that token is
        included in the result). Only batch size 1 is supported because the
        stop test calls ``.item()`` on the sampled token.

        Args:
            input_ids: prompt token ids indexed as ``[batch, position]``.
            max_tokens: maximum number of tokens to generate.
            iters: passes through the middle blocks per step.
            top_k, temperature, rep_penalty: forwarded to :meth:`pick_word`.

        Returns:
            List of sampled token tensors, each of shape ``(1, 1)``.
        """
        generated_tokens = []
        for _ in range(max_tokens):
            input_ids = input_ids[:, -self.sequencelength:].to(self.device)
            logits = self.forward_training(input_ids, iters=iters)
            token1 = self.pick_word(
                logits[:, -1, :],
                k=top_k,
                temperature=temperature,
                prev_tokens=generated_tokens,
                rep_penalty=rep_penalty,
            )
            generated_tokens.append(token1)
            input_ids = torch.cat([input_ids, token1], dim=1)
            if token1.item() == enc.eot_token:
                break
        return generated_tokens
#Setup model
def initmodel(device):
    """Build a ``biggerbrain`` on ``device`` in bfloat16.

    The model is wrapped with ``torch.compile(dynamic=True)`` unless the
    target device is a CPU.

    Args:
        device: device string (``'cpu'``, ``'cuda'``, ``'cuda:0'``, ...) or
            ``torch.device``.

    Returns:
        The (possibly compiled) model.
    """
    model = biggerbrain(device).to(device).to(torch.bfloat16)
    # Compare device *types*: a `torch.device('cpu')` object does not compare
    # equal to the string 'cpu', so the previous `device != 'cpu'` test sent
    # CPU models through torch.compile whenever a device object was passed.
    if torch.device(device).type != 'cpu':
        model = torch.compile(model, dynamic=True)
    return model
def think(prompt, model, max_length=100, iter=3, top_k=10, temperature=1.0, raw=False, rep_penalty=1.5):
    """Generate a text continuation for ``prompt``.

    Args:
        prompt: input text.
        model: a ``biggerbrain`` instance or its ``torch.compile`` wrapper
            (the wrapper is unwrapped through ``_orig_mod``).
        max_length: maximum number of tokens to generate.
        iter: number of passes through the model's middle blocks.
        top_k, temperature, rep_penalty: sampling settings forwarded to the
            model.
        raw: when False the prompt is wrapped in the
            ``User: ... Assistant:`` chat template; when True it is used
            verbatim.

    Returns:
        The decoded generated text (the prompt itself is not included).
    """
    model = getattr(model, '_orig_mod', model)
    formatted = prompt if raw else f"User: \n\n{prompt}\n\nAssistant:"
    input_ids = torch.tensor(
        [enc.encode(formatted, allowed_special={'<|endoftext|>'})]
    ).to(model.device)

    # Remember the current mode so that sampling in the middle of a training
    # run does not silently leave the model in eval mode.
    was_training = model.training
    model.eval()
    try:
        with torch.no_grad():
            generated_tokens = model.forward(
                input_ids,
                max_tokens=max_length,
                iters=iter,
                top_k=top_k,
                temperature=temperature,
                rep_penalty=rep_penalty
            )
    finally:
        model.train(was_training)

    # forward() returns single-element tensors; the tokenizer wants ints.
    return enc.decode([int(tok) for tok in generated_tokens])
def print_parameter_breakdown(model):
    """Print the element count of every trainable parameter and their sum.

    A ``torch.compile`` wrapper is unwrapped through ``_orig_mod`` first.
    Every trainable parameter gets one line; those with more than one
    million elements are prefixed with ``MASSIVE LAYER DETECTED``. Frozen
    parameters are neither printed nor counted.
    """
    target = getattr(model, '_orig_mod', model)
    running_total = 0
    print("--- Parameter Breakdown ---")
    for label, tensor in target.named_parameters():
        if tensor.requires_grad:
            count = tensor.numel()
            running_total += count
            prefix = "MASSIVE LAYER DETECTED -> " if count > 1_000_000 else ""
            print(f"{prefix}{label}: {count:,} parameters")
    print(f"---------------------------\nTrue Total: {running_total:,}")