| import torch |
| import torch.nn as nn |
| import torch.nn.functional as F |
| import math |
| import numpy as np |
|
|
| |
| from models.utils import add_coord_dim |
|
|
| |
|
|
class Identity(nn.Module):
    """
    No-op module: ``forward`` hands its argument back untouched.

    Handy as a placeholder inside ``nn.Sequential`` or wherever an optional
    layer should be switched off without changing the call sites.
    """

    # nn.Module.__init__ is inherited; there is no state to set up.

    def forward(self, x):
        return x
|
|
|
|
class Squeeze(nn.Module):
    """
    Squeeze Module.

    Drops the dimension ``dim`` from its input when that dimension has size 1
    (otherwise the tensor is returned with its shape unchanged, as with
    ``torch.squeeze``). Lets a squeeze live inside an ``nn.Sequential``.

    Args:
        dim (int): The dimension to squeeze.
    """
    def __init__(self, dim):
        super().__init__()
        self.dim = dim

    def forward(self, x):
        return torch.squeeze(x, self.dim)
|
|
| |
|
|
class SynapseUNET(nn.Module):
    """
    UNET-style architecture for the Synapse Model (f_theta1 in the paper).

    This module implements the connections between neurons in the CTM's latent
    space. It processes the combined input (previous post-activation state z^t
    and attention output o^t) to produce the pre-activations (a^t) for the
    next internal tick (Eq. 1 in the paper).

    While a simpler Linear or MLP layer can be used, the paper notes that this
    U-Net structure empirically performed better, suggesting benefit from more
    flexible synaptic connections.

    Structure: ``depth`` widths are spaced linearly from ``out_dims`` down to
    ``minimum_width`` (each truncated to int). A first projection maps the
    input to ``widths[0]``. Then ``depth - 1`` down blocks walk the widths down
    to the bottleneck, and ``depth - 1`` up blocks walk back up. Each up
    block's output is added to the matching down-path activation and
    layer-normalised.

    The input width is not a constructor argument. The first projection is an
    ``nn.LazyLinear``, so the input width is inferred on the first forward pass.

    Args:
        out_dims (int): Number of output features (d_model); also ``widths[0]``.
        depth (int): Number of widths in the linspace; creates ``depth - 1``
            down/up blocks. With ``depth == 1`` the module reduces to the first
            projection.
        minimum_width (int): Channel width at the U-Net bottleneck.
        dropout (float): Dropout rate applied at the start of each down/up block.
    """
    def __init__(self,
                 out_dims,
                 depth,
                 minimum_width=16,
                 dropout=0.0):
        super().__init__()
        self.width_out = out_dims
        self.n_deep = depth

        # `depth` widths from out_dims down to minimum_width -> `depth - 1` blocks.
        widths = np.linspace(out_dims, minimum_width, depth)

        # Input (any width, inferred lazily) -> widths[0].
        self.first_projection = nn.Sequential(
            nn.LazyLinear(int(widths[0])),
            nn.LayerNorm(int(widths[0])),
            nn.SiLU()
        )

        self.down_projections = nn.ModuleList()
        self.up_projections = nn.ModuleList()
        self.skip_lns = nn.ModuleList()
        num_blocks = len(widths) - 1

        for i in range(num_blocks):
            # Down block i: widths[i] -> widths[i+1].
            self.down_projections.append(nn.Sequential(
                nn.Dropout(dropout),
                nn.Linear(int(widths[i]), int(widths[i+1])),
                nn.LayerNorm(int(widths[i+1])),
                nn.SiLU()
            ))
            # Up block i: widths[i+1] -> widths[i]. Stored in the same order as the
            # down blocks but applied in reverse order in forward().
            self.up_projections.append(nn.Sequential(
                nn.Dropout(dropout),
                nn.Linear(int(widths[i+1]), int(widths[i])),
                nn.LayerNorm(int(widths[i])),
                nn.SiLU()
            ))
            # Normalises (up output + skip) at level i, which has width widths[i].
            self.skip_lns.append(nn.LayerNorm(int(widths[i])))

    def forward(self, x):
        """
        Args:
            x (torch.Tensor): Input whose last dimension is the feature axis
                (the combined z^t / o^t vector).

        Returns:
            torch.Tensor: Same leading dimensions as ``x``, last dimension ``out_dims``.
        """
        # Down path: keep every level's activation for the skip connections.
        outs_down = [self.first_projection(x)]
        for layer in self.down_projections:
            outs_down.append(layer(outs_down[-1]))

        # Up path: start from the bottleneck and apply the up blocks deepest-first.
        # Up block `level` maps widths[level+1] -> widths[level], which is the
        # width of outs_down[level], so the skip connection lines up.
        outs_up = outs_down[-1]
        for level in reversed(range(len(self.up_projections))):
            out_up = self.up_projections[level](outs_up)
            outs_up = self.skip_lns[level](out_up + outs_down[level])

        return outs_up
|
|
|
|
class SuperLinear(nn.Module):
    """
    SuperLinear Layer: Implements Neuron-Level Models (NLMs) for the CTM.

    This layer is the core component enabling Neuron-Level Models (NLMs),
    referred to as g_theta_d in the paper (Eq. 3). It applies N independent
    linear transformations (or small MLPs when used sequentially), one per
    slice along dimension 1 of the input (the neuron dimension).

    Shapes (letters match the einsum in ``forward``):
        x   : (B, D, M) - batch, N neurons, in_dims (e.g. history length)
        w1  : (M, H, D) - in_dims, out_dims, N
        b1  : (1, D, H)
        out : (B, D, H), then ``squeeze(-1)`` -> (B, D) when out_dims == 1

    For every neuron ``d``:
        ``out[:, d, :] = x[:, d, :] @ w1[:, :, d] + b1[:, d, :]``
    The result is then divided by the learnable scalar ``T``.

    This allows each neuron to process its own history with private
    parameters. It's typically used within the ``trace_processor`` module of
    the main CTM class.

    Args:
        in_dims (int): Input dimension (typically ``memory_length``).
        out_dims (int): Output dimension per neuron.
        N (int): Number of independent linear models (typically ``d_model``).
        T (float): Initial value of the learnable temperature that divides the output.
        do_norm (bool): Apply LayerNorm over the last input dimension before the linear map.
        dropout (float): Dropout rate applied to the input (no dropout module when 0).
    """
    def __init__(self,
                 in_dims,
                 out_dims,
                 N,
                 T=1.0,
                 do_norm=False,
                 dropout=0):
        super().__init__()
        # Dropout is only instantiated when requested; otherwise a no-op.
        self.dropout = nn.Dropout(dropout) if dropout > 0 else Identity()
        self.in_dims = in_dims
        # Optional LayerNorm over the last (in_dims) axis of the input.
        self.layernorm = nn.LayerNorm(in_dims, elementwise_affine=True) if do_norm else Identity()
        self.do_norm = do_norm

        # Per-neuron weights, uniform in +/- 1/sqrt(in_dims + out_dims).
        bound = 1 / math.sqrt(in_dims + out_dims)
        self.register_parameter('w1', nn.Parameter(
            torch.empty((in_dims, out_dims, N)).uniform_(-bound, bound),
            requires_grad=True,
        ))
        # Per-neuron biases, zero-initialised.
        self.register_parameter('b1', nn.Parameter(torch.zeros((1, N, out_dims)), requires_grad=True))
        # Learnable scalar temperature; float() keeps it floating-point even if an int is passed.
        self.register_parameter('T', nn.Parameter(torch.tensor([float(T)])))

    def forward(self, x):
        """
        Args:
            x (torch.Tensor): Input tensor of shape (B, N, in_dims).

        Returns:
            torch.Tensor: (B, N) when out_dims == 1, otherwise (B, N, out_dims).
        """
        out = self.layernorm(self.dropout(x))
        # One independent (in_dims -> out_dims) matmul per neuron D.
        out = torch.einsum('BDM,MHD->BDH', out, self.w1) + self.b1
        # squeeze(-1) only removes the last axis when out_dims == 1.
        return out.squeeze(-1) / self.T
|
|
|
|
| |
|
|
class ParityBackbone(nn.Module):
    """
    Embeds a parity sequence and moves the embedding axis to dimension 1.

    Args:
        n_embeddings (int): Size of the embedding table.
        d_embedding (int): Embedding dimension.
    """
    def __init__(self, n_embeddings, d_embedding):
        super().__init__()
        self.embedding = nn.Embedding(n_embeddings, d_embedding)

    def forward(self, x):
        """
        Maps entries equal to 1 (positive parity) to index 1 and every other
        value (e.g. -1, negative parity) to index 0, embeds them, and swaps
        axes 1 and 2 (a (B, L) input gives a (B, d_embedding, L) output).
        """
        indices = x.eq(1).long()
        embedded = self.embedding(indices)
        return embedded.transpose(1, 2)
|
|
class QAMNISTOperatorEmbeddings(nn.Module):
    """
    Embedding lookup for operator codes stored as negative integers.

    Code -1 maps to row 0, -2 to row 1, and so on (index = -x - 1).

    Args:
        num_operator_types (int): Number of embedding rows.
        d_projection (int): Embedding dimension.
    """
    def __init__(self, num_operator_types, d_projection):
        super().__init__()
        self.embedding = nn.Embedding(num_operator_types, d_projection)

    def forward(self, x):
        row_index = -(x + 1)
        return self.embedding(row_index)
|
|
class QAMNISTIndexEmbeddings(torch.nn.Module):
    """
    Fixed sinusoidal embeddings looked up by integer position.

    Builds a (max_seq_length, embedding_dim) table where even columns hold
    ``sin(pos * div)`` and odd columns hold ``cos(pos * div)``. The table is
    registered as a (non-trainable) buffer. Odd ``embedding_dim`` is
    supported: the final sine column has no cosine partner.

    Args:
        max_seq_length (int): Number of positions in the table.
        embedding_dim (int): Embedding dimension.
    """
    def __init__(self, max_seq_length, embedding_dim):
        super().__init__()
        self.max_seq_length = max_seq_length
        self.embedding_dim = embedding_dim

        embedding = torch.zeros(max_seq_length, embedding_dim)
        position = torch.arange(0, max_seq_length, dtype=torch.float).unsqueeze(1)
        div_term = torch.exp(torch.arange(0, embedding_dim, 2).float() * (-math.log(10000.0) / embedding_dim))
        angles = position * div_term  # (max_seq_length, ceil(embedding_dim / 2))

        embedding[:, 0::2] = torch.sin(angles)
        # The odd slice has floor(embedding_dim / 2) columns, which is one fewer
        # than `angles` when embedding_dim is odd.
        embedding[:, 1::2] = torch.cos(angles[:, : embedding_dim // 2])

        self.register_buffer('embedding', embedding)

    def forward(self, x):
        """Return the table rows selected by the integer tensor ``x``."""
        return self.embedding[x]
| |
class ThoughtSteps:
    """
    Helper class for managing "thought steps" in the ctm_qamnist pipeline.

    Steps are laid out as three consecutive phases: digit steps, then
    question steps, then answer steps.

    Args:
        iterations_per_digit (int): Number of iterations for each digit.
        iterations_per_question_part (int): Number of iterations for each question part.
        total_iterations_for_answering (int): Total number of iterations for answering.
        total_iterations_for_digits (int): Total number of iterations for digits.
        total_iterations_for_question (int): Total number of iterations for question.
    """
    def __init__(self, iterations_per_digit, iterations_per_question_part, total_iterations_for_answering, total_iterations_for_digits, total_iterations_for_question):
        self.iterations_per_digit = iterations_per_digit
        self.iterations_per_question_part = iterations_per_question_part
        self.total_iterations_for_digits = total_iterations_for_digits
        self.total_iterations_for_question = total_iterations_for_question
        self.total_iterations_for_answering = total_iterations_for_answering
        phase_lengths = (total_iterations_for_digits, total_iterations_for_question)
        self.total_iterations = phase_lengths[0] + phase_lengths[1] + total_iterations_for_answering

    def determine_step_type(self, stepi: int):
        """Return (is_digit_step, is_question_step, is_answer_step) for step ``stepi``."""
        question_start = self.total_iterations_for_digits
        answer_start = question_start + self.total_iterations_for_question
        is_digit_step = stepi < question_start
        is_question_step = question_start <= stepi < answer_start
        is_answer_step = stepi >= answer_start
        return is_digit_step, is_question_step, is_answer_step

    def determine_answer_step_type(self, stepi: int):
        """
        Return (is_index_step, is_operator_step) for step ``stepi``.

        The offset is measured from the end of the digit phase. Within each
        cycle of ``2 * iterations_per_question_part`` steps, the first half are
        index steps and the second half are operator steps.
        """
        offset = stepi - self.total_iterations_for_digits
        cycle = 2 * self.iterations_per_question_part
        in_index_half = bool(offset % cycle < self.iterations_per_question_part)
        return in_index_half, not in_index_half
|
|
class MNISTBackbone(nn.Module):
    """
    Simple backbone for MNIST feature extraction.

    Two identical stages of (3x3 conv -> BatchNorm -> ReLU -> 2x2 max-pool),
    each producing ``d_input`` channels. Each stage halves the spatial size.
    The first conv infers its input channels lazily.
    """
    def __init__(self, d_input):
        super(MNISTBackbone, self).__init__()
        stages = []
        for _ in range(2):
            stages.extend([
                nn.LazyConv2d(d_input, kernel_size=3, stride=1, padding=1),
                nn.BatchNorm2d(d_input),
                nn.ReLU(),
                nn.MaxPool2d(2, 2),
            ])
        self.layers = nn.Sequential(*stages)

    def forward(self, x):
        return self.layers(x)
|
|
|
|
class MiniGridBackbone(nn.Module):
    """
    Embeds a symbolic grid observation into per-cell feature vectors.

    The last axis of each cell holds integer codes at channels 0, 1 and 2
    (object, color, state). Each code is embedded. A learned embedding of the
    cell's row-major position index is appended as a fourth part. The
    concatenation goes through two (Linear -> GLU -> LayerNorm) stages.

    Args:
        d_input (int): Output feature size per cell.
        grid_size (int): Side length used to size the position table (grid_size ** 2 rows).
        num_objects, num_colors, num_states (int): Vocabulary sizes of the three codes.
        embedding_dim (int): Size of each of the four embeddings.
    """
    def __init__(self, d_input, grid_size=7, num_objects=11, num_colors=6, num_states=3, embedding_dim=8):
        super().__init__()
        self.object_embedding = nn.Embedding(num_objects, embedding_dim)
        self.color_embedding = nn.Embedding(num_colors, embedding_dim)
        self.state_embedding = nn.Embedding(num_states, embedding_dim)
        # One learned vector per grid cell, indexed row-major.
        self.position_embedding = nn.Embedding(grid_size * grid_size, embedding_dim)

        concat_dim = embedding_dim * 4
        self.project_to_d_projection = nn.Sequential(
            nn.Linear(concat_dim, d_input * 2), nn.GLU(), nn.LayerNorm(d_input),
            nn.Linear(d_input, d_input * 2), nn.GLU(), nn.LayerNorm(d_input),
        )

    def forward(self, x):
        """Map a (B, H, W, C) observation to (B, H, W, d_input) features."""
        cells = x.long()
        batch, height, width, _ = cells.size()
        cell_positions = torch.arange(height * width, device=cells.device).view(1, height, width).expand(batch, -1, -1)
        parts = [
            self.object_embedding(cells[..., 0]),
            self.color_embedding(cells[..., 1]),
            self.state_embedding(cells[..., 2]),
            self.position_embedding(cell_positions),
        ]
        return self.project_to_d_projection(torch.cat(parts, dim=-1))
|
|
class ClassicControlBackbone(nn.Module):
    """
    Flattens each observation and projects it to ``d_input`` features.

    Two (LazyLinear -> GLU -> LayerNorm) stages follow the flatten. The first
    linear layer infers its input size on the first call.

    Args:
        d_input (int): Output feature size.
    """
    def __init__(self, d_input):
        super().__init__()

        def glu_stage():
            # LazyLinear doubles the width; GLU halves it back to d_input.
            return [nn.LazyLinear(d_input * 2), nn.GLU(), nn.LayerNorm(d_input)]

        self.input_projector = nn.Sequential(nn.Flatten(), *glu_stage(), *glu_stage())

    def forward(self, x):
        return self.input_projector(x)
|
|
|
|
class ShallowWide(nn.Module):
    """
    Simple, wide, shallow convolutional backbone for image feature extraction.

    Alternative to ResNet: a stride-2 stem conv followed by one grouped
    (32-group) conv. Each conv emits 4096 channels, which a channel-wise GLU
    halves to 2048 before BatchNorm. The structure is fixed.
    """
    def __init__(self):
        super(ShallowWide, self).__init__()
        stem = [
            nn.LazyConv2d(4096, kernel_size=3, stride=2, padding=1),
            nn.GLU(dim=1),
            nn.BatchNorm2d(2048),
        ]
        grouped = [
            nn.Conv2d(2048, 4096, kernel_size=3, stride=1, padding=1, groups=32),
            nn.GLU(dim=1),
            nn.BatchNorm2d(2048),
        ]
        self.layers = nn.Sequential(*stem, *grouped)

    def forward(self, x):
        return self.layers(x)
|
|
|
|
class PretrainedResNetWrapper(nn.Module):
    """
    Wrapper to use standard pre-trained ResNet models from torchvision.

    Loads a ResNet architecture via ``torch.hub`` with ``pretrained=True``.
    It then replaces the backbone's ``avgpool`` and ``fc`` with ``Identity``
    so the model can serve as a feature-extractor backbone. Other layers
    (e.g. ``layer4``) are left in place; replace them on ``self.backbone`` if
    a shallower extractor is needed.

    Args:
        resnet_type (str): Name of the ResNet model (e.g., 'resnet18', 'resnet50').
        fine_tune (bool): If False, freezes the weights of the pre-trained backbone.
    """
    def __init__(self, resnet_type, fine_tune=True):
        super(PretrainedResNetWrapper, self).__init__()
        self.resnet_type = resnet_type
        self.backbone = torch.hub.load('pytorch/vision:v0.10.0', resnet_type, pretrained=True)

        if not fine_tune:
            for param in self.backbone.parameters():
                param.requires_grad = False

        # Strip the classification head so the backbone returns features.
        self.backbone.avgpool = Identity()
        self.backbone.fc = Identity()

    def forward(self, x):
        """
        Run the backbone and try to reshape its last-dim features to (B, nc, wh, wh).

        If the feature count is not ``nc`` times a perfect square, a warning is
        printed and the backbone output is returned unchanged.
        """
        out = self.backbone(x)

        # NOTE(review): heuristic channel count chosen from the model name; it
        # may not match the actual output channels of every configuration.
        # Verify against the backbone in use.
        if '18' in self.resnet_type or '34' in self.resnet_type:
            nc = 256
        elif '50' in self.resnet_type:
            nc = 512
        elif '101' in self.resnet_type:
            nc = 1024
        else:
            nc = 2048

        num_features = out.shape[-1]
        # The reshape assumes out's last dim is nc * wh * wh with a square
        # feature map. Check divisibility AND perfect-squareness. Checking only
        # divisibility lets values like 98 through, which then crash in reshape().
        spatial, remainder = divmod(num_features, nc)
        wh = math.isqrt(spatial)
        if remainder or wh * wh != spatial:
            print(f"Warning: Cannot reliably reshape PretrainedResNetWrapper output. nc={nc}, num_features={num_features}")
            return out

        return out.reshape(x.size(0), nc, wh, wh)
|
|
| |
|
|
class LearnableFourierPositionalEncoding(nn.Module):
    """
    Learnable Fourier Feature Positional Encoding.

    Implements Algorithm 1 from "Learnable Fourier Features for Multi-Dimensional
    Spatial Positional Encoding" (https://arxiv.org/pdf/2106.02795.pdf).
    Coordinates of a 2D feature map are projected by ``Wr``. Their cosine and
    sine are concatenated and scaled by 1/sqrt(F_dim), then mapped by a small
    GLU MLP to the output dimension.

    Args:
        d_model (int): The output dimension of the positional encoding (D).
        G (int): Positional groups (default 1). The MLP outputs D // G
            features; forward does not otherwise split into groups.
        M (int): Dimensionality of input coordinates (default 2 for H, W).
        F_dim (int): Dimension of the Fourier features.
        H_dim (int): Hidden dimension of the MLP (halved by the GLU).
        gamma (float): Scale parameter for Wr's init; Wr ~ N(0, std=gamma ** -2).
    """
    def __init__(self, d_model,
                 G=1, M=2,
                 F_dim=256,
                 H_dim=128,
                 gamma=1/2.5,
                 ):
        super().__init__()
        self.G = G
        self.M = M
        self.F_dim = F_dim
        self.H_dim = H_dim
        self.D = d_model
        self.gamma = gamma

        half_fourier = self.F_dim // 2
        out_per_group = self.D // self.G
        # Coordinate projection: cos and sin of its output give F_dim features.
        self.Wr = nn.Linear(self.M, half_fourier, bias=False)
        self.mlp = nn.Sequential(
            nn.Linear(self.F_dim, self.H_dim, bias=True),
            nn.GLU(),
            nn.Linear(self.H_dim // 2, out_per_group),
            nn.LayerNorm(out_per_group)
        )

        self.init_weights()

    def init_weights(self):
        """Re-draw Wr from a zero-mean normal with std ``gamma ** -2``."""
        nn.init.normal_(self.Wr.weight.data, mean=0, std=self.gamma ** -2)

    def forward(self, x):
        """
        Computes positional encodings for the input feature map x.

        Args:
            x (torch.Tensor): Input feature map, shape (B, C, H, W).

        Returns:
            torch.Tensor: Positional encoding tensor, shape (B, D, H, W).
        """
        _batch, _channels, _height, _width = x.shape  # also enforces a 4D input

        # NOTE(review): add_coord_dim is external; presumably it yields
        # (B, H, W, M) coordinates for the first channel. Confirm in models.utils.
        coords = add_coord_dim(x[:, 0])

        projected = self.Wr(coords)
        scale = 1.0 / math.sqrt(self.F_dim)
        fourier_features = scale * torch.cat([torch.cos(projected), torch.sin(projected)], dim=-1)

        encoded = self.mlp(fourier_features)
        # Channels-last -> channels-first.
        return encoded.permute(0, 3, 1, 2)
|
|
|
|
class MultiLearnableFourierPositionalEncoding(nn.Module):
    """
    Combines multiple LearnableFourierPositionalEncoding modules with different
    initialization scales (gamma) via a learnable weighted sum.

    The mixing weights are a softmax over a learnable logit vector
    (``combination``, initialised to ones, i.e. a uniform average).

    Args:
        d_model (int): Output dimension of the encoding.
        G, M, F_dim, H_dim: Parameters passed to underlying LearnableFourierPositionalEncoding.
        gamma_range (Sequence[float]): (start, stop) gamma values for ``np.linspace``.
            Defaults to ``(1.0, 0.1)``; a tuple is used instead of a list to avoid
            a shared mutable default.
        N (int): Number of parallel embedding modules to create.
    """
    def __init__(self, d_model,
                 G=1, M=2,
                 F_dim=256,
                 H_dim=128,
                 gamma_range=(1.0, 0.1),
                 N=10,
                 ):
        super().__init__()
        self.embedders = nn.ModuleList()
        for gamma in np.linspace(gamma_range[0], gamma_range[1], N):
            self.embedders.append(LearnableFourierPositionalEncoding(d_model, G, M, F_dim, H_dim, gamma))

        # Unnormalised mixing logits, one per embedder; softmax-ed in forward().
        self.register_parameter('combination', torch.nn.Parameter(torch.ones(N), requires_grad=True))
        self.N = N

    def forward(self, x):
        """
        Computes combined positional encoding.

        Args:
            x (torch.Tensor): Input feature map, shape (B, C, H, W).

        Returns:
            torch.Tensor: Combined positional encoding tensor, shape (B, D, H, W).
        """
        # Stack along a new leading axis: (N, *embedder_output_shape).
        pos_embs = torch.stack([emb(x) for emb in self.embedders], dim=0)

        # One softmax weight per embedder, broadcast over the remaining 4 axes.
        weights = F.softmax(self.combination, dim=-1).view(self.N, 1, 1, 1, 1)

        return (pos_embs * weights).sum(0)
|
|
|
|
class CustomRotationalEmbedding(nn.Module):
    """
    Custom Rotational Positional Embedding.

    Generates 2D positional embeddings by rotating a learnable 2D start vector.

    - Row ``i`` of an ``H x W`` grid uses the start vector rotated by the i-th
      of ``H`` angles evenly spaced over [0, 180] degrees.
    - Column ``j`` uses the j-th of ``W`` angles over the same range.
    - The two rotated vectors are concatenated (4 values) and linearly
      projected to ``d_model``.

    For square inputs (H == W) this is identical to deriving every angle from
    the width. For non-square inputs the output now has the correct (H, W)
    spatial size instead of (W, W).

    Args:
        d_model (int): Dimensionality of the output embeddings.
    """
    def __init__(self, d_model):
        super(CustomRotationalEmbedding, self).__init__()
        # Learnable 2D vector that gets rotated to produce positional features.
        self.register_parameter('start_vector', nn.Parameter(torch.tensor([0.0, 1.0]), requires_grad=True))
        # Maps the concatenated (row, column) rotated vectors to d_model features.
        self.projection = nn.Sequential(nn.Linear(4, d_model))

    def _rotated_start_vectors(self, steps, device):
        """Rotate ``start_vector`` by ``steps`` angles evenly spaced over [0, 180] deg; returns (steps, 2)."""
        theta_rad = torch.deg2rad(torch.linspace(0, 180, steps, device=device))
        cos_theta = torch.cos(theta_rad)
        sin_theta = torch.sin(theta_rad)
        # (steps, 2, 2) stack of [[cos, -sin], [sin, cos]] rotation matrices.
        rotation_matrices = torch.stack([
            torch.stack([cos_theta, -sin_theta], dim=-1),
            torch.stack([sin_theta, cos_theta], dim=-1)
        ], dim=1)
        return torch.einsum('sij,j->si', rotation_matrices, self.start_vector)

    def forward(self, x):
        """
        Computes rotational positional embeddings for the input's spatial size.

        Args:
            x (torch.Tensor): Input tensor (used for shape and device),
                shape (batch_size, channels, height, width).
        Returns:
            Output tensor containing positional embeddings,
            shape (1, d_model, height, width) - batch dim is 1 as the PE is the same for all.
        """
        _batch, _channels, H, W = x.shape
        row_vectors = self._rotated_start_vectors(H, x.device)  # (H, 2)
        col_vectors = self._rotated_start_vectors(W, x.device)  # (W, 2)

        # key[i, j] = concat(row_vectors[i], col_vectors[j]) -> (H, W, 4)
        key = torch.cat((
            row_vectors.unsqueeze(1).expand(H, W, 2),
            col_vectors.unsqueeze(0).expand(H, W, 2),
        ), dim=-1)

        pe_grid = self.projection(key)  # (H, W, d_model)
        return pe_grid.permute(2, 0, 1).unsqueeze(0)
|
|
class CustomRotationalEmbedding1D(nn.Module):
    """
    1D rotational positional embedding.

    The fixed vector (0, 1) is rotated by ``L = x.size(2)`` angles evenly
    spaced over [0, 180] degrees. Each rotated 2D vector is linearly projected
    to ``d_model``, and the result is repeated over the batch. The output
    shape is (x.size(0), d_model, L).

    Args:
        d_model (int): Dimensionality of the output embeddings.
    """
    def __init__(self, d_model):
        super(CustomRotationalEmbedding1D, self).__init__()
        self.projection = nn.Linear(2, d_model)

    def forward(self, x):
        base = torch.tensor([0., 1.], device=x.device, dtype=torch.float)
        angles = torch.deg2rad(torch.linspace(0, 180, x.size(2), device=x.device))
        cos_a, sin_a = torch.cos(angles), torch.sin(angles)

        # Apply [[cos, -sin], [sin, cos]] to `base` for every angle -> (L, 2).
        rotated = torch.stack((cos_a * base[0] - sin_a * base[1],
                               sin_a * base[0] + cos_a * base[1]), dim=-1)

        per_position = self.projection(rotated)                      # (L, d_model)
        batched = per_position.unsqueeze(0).repeat(x.size(0), 1, 1)  # (B, L, d_model)
        return batched.transpose(1, 2)
| |
|
|