# Page header captured with this file (not code): user LukeDarlow, commit 68b32f4,
# message "Welcome to the CTM. This is the first commit of the public repo. Enjoy!"
# (file size shown on the page: 28.9 kB).
import torch
import torch.nn as nn
import torch.nn.functional as F # Used for GLU
import math
import numpy as np
# Assuming 'add_coord_dim' is defined in models.utils
from models.utils import add_coord_dim
# --- Basic Utility Modules ---
class Identity(nn.Module):
    """
    No-op module whose forward pass hands back its argument object untouched.

    Handy as a stand-in wherever a module is structurally required but no
    transformation is wanted (e.g. an optional norm or dropout slot).
    """
    def __init__(self):
        super(Identity, self).__init__()

    def forward(self, x):
        # Same object in, same object out: no copy, no computation.
        output = x
        return output
class Squeeze(nn.Module):
    """
    Module wrapper around ``torch.squeeze`` for use inside ``nn.Sequential``.

    Drops the configured dimension when its size is 1; leaves the tensor
    unchanged otherwise.

    Args:
        dim (int): Dimension to squeeze.
    """
    def __init__(self, dim):
        super(Squeeze, self).__init__()
        self.dim = dim

    def forward(self, x):
        return torch.squeeze(x, self.dim)
# --- Core CTM Component Modules ---
class SynapseUNET(nn.Module):
    """
    UNET-style synapse model (f_theta1 in the CTM paper).

    Maps the combined input (previous post-activations concatenated with the
    attention output) to the pre-activations of the next internal tick.

    Layer widths are ``int(w) for w in np.linspace(out_dims, minimum_width, depth)``.
    Consecutive width pairs give ``depth - 1`` down blocks (wide -> narrow) and
    ``depth - 1`` matching up blocks (narrow -> wide). On the way back up, each up
    block's output is added to the same-width activation saved on the way down,
    and the sum is passed through that level's LayerNorm.

    The input width is NOT a constructor argument: the first projection is an
    ``nn.LazyLinear``, so the input width is inferred on the first forward call.

    Args:
        out_dims (int): Output width (also the width of the first projection).
        depth (int): Number of width levels, must be >= 1; ``depth - 1``
            down/up blocks are built. With ``depth == 1`` only the first
            projection is used and ``minimum_width`` has no effect.
        minimum_width (int): Width of the deepest (bottleneck) level.
        dropout (float): Dropout rate applied at the start of every down/up block.

    Raises:
        ValueError: If ``depth < 1``.
    """
    def __init__(self,
                 out_dims,
                 depth,
                 minimum_width=16,
                 dropout=0.0):
        super().__init__()
        if depth < 1:
            # np.linspace(..., 0) is empty, so widths[0] below would fail with
            # an opaque IndexError; report the real problem instead.
            raise ValueError(f"depth must be >= 1, got {depth}")
        self.width_out = out_dims
        self.n_deep = depth  # kept for reference/introspection only

        widths = [int(w) for w in np.linspace(out_dims, minimum_width, depth)]

        # Input projection; LazyLinear infers the input width at first call.
        self.first_projection = nn.Sequential(
            nn.LazyLinear(widths[0]),
            nn.LayerNorm(widths[0]),
            nn.SiLU()
        )

        self.down_projections = nn.ModuleList()
        self.up_projections = nn.ModuleList()
        self.skip_lns = nn.ModuleList()
        # Block i maps widths[i] -> widths[i+1] (down) and back (up). Up blocks
        # are stored in the same order as down blocks and applied in reverse.
        for wide, narrow in zip(widths[:-1], widths[1:]):
            self.down_projections.append(nn.Sequential(
                nn.Dropout(dropout),
                nn.Linear(wide, narrow),
                nn.LayerNorm(narrow),
                nn.SiLU()
            ))
            self.up_projections.append(nn.Sequential(
                nn.Dropout(dropout),
                nn.Linear(narrow, wide),
                nn.LayerNorm(wide),
                nn.SiLU()
            ))
            # Normalises (up output + skip) at width `wide`.
            self.skip_lns.append(nn.LayerNorm(wide))

    def forward(self, x):
        # skips[k] is the activation at width level k (skips[0] = first projection).
        skips = [self.first_projection(x)]
        for down in self.down_projections:
            skips.append(down(skips[-1]))

        # Start from the bottleneck and climb back up, deepest block first.
        out = skips[-1]
        for level in reversed(range(len(self.up_projections))):
            # up_projections[level] outputs width level `level`, matching skips[level].
            out = self.skip_lns[level](self.up_projections[level](out) + skips[level])
        return out
class SuperLinear(nn.Module):
    """
    Neuron-level models (NLMs, g_theta_d in the CTM paper): N independent linear maps.

    Neuron ``n`` owns its own weight slice ``w1[:, :, n]`` (in_dims x out_dims)
    and bias ``b1[0, n, :]``. Given per-neuron histories ``x`` of shape
    (B, N, in_dims), the forward pass computes, for all neurons in parallel,
    ``(x[:, n, :] @ w1[:, :, n] + b1[0, n, :]) / T`` and then squeezes the last
    axis (which only has an effect when ``out_dims == 1``).

    Weights are drawn from U(-1/sqrt(in_dims + out_dims), +1/sqrt(in_dims + out_dims));
    biases start at zero; ``T`` is a learnable scalar divisor.

    Args:
        in_dims (int): Length of each neuron's input history (memory length).
        out_dims (int): Output features per neuron.
        N (int): Number of neurons, i.e. independent linear models.
        T (float): Initial value of the learnable output divisor.
        do_norm (bool): If True, LayerNorm over the history axis before the map.
        dropout (float): Input dropout rate; an Identity is used when it is 0.
    """
    def __init__(self,
                 in_dims,
                 out_dims,
                 N,
                 T=1.0,
                 do_norm=False,
                 dropout=0):
        super().__init__()
        if dropout > 0:
            self.dropout = nn.Dropout(dropout)
        else:
            self.dropout = Identity()
        self.in_dims = in_dims
        if do_norm:
            # Normalises each neuron's history independently (last axis).
            self.layernorm = nn.LayerNorm(in_dims, elementwise_affine=True)
        else:
            self.layernorm = Identity()
        self.do_norm = do_norm

        bound = 1 / math.sqrt(in_dims + out_dims)
        weight = torch.empty((in_dims, out_dims, N)).uniform_(-bound, bound)
        self.register_parameter('w1', nn.Parameter(weight, requires_grad=True))  # (M, H, N)
        bias = torch.zeros((1, N, out_dims))
        self.register_parameter('b1', nn.Parameter(bias, requires_grad=True))  # (1, N, H)
        self.register_parameter('T', nn.Parameter(torch.Tensor([T])))

    def forward(self, x):
        """
        Args:
            x (torch.Tensor): Histories of shape (B, N, in_dims).

        Returns:
            torch.Tensor: (B, N) when out_dims == 1, otherwise (B, N, out_dims).
        """
        histories = self.layernorm(self.dropout(x))
        # One (M -> H) matmul per neuron D, batched over B.
        per_neuron = torch.einsum('BDM,MHD->BDH', histories, self.w1)
        per_neuron = per_neuron + self.b1
        return per_neuron.squeeze(-1) / self.T
# --- Backbone Modules ---
class ParityBackbone(nn.Module):
    """
    Embedding backbone for the parity task.

    Args:
        n_embeddings (int): Size of the embedding table.
        d_embedding (int): Embedding width.
    """
    def __init__(self, n_embeddings, d_embedding):
        super().__init__()
        self.embedding = nn.Embedding(n_embeddings, d_embedding)

    def forward(self, x):
        """
        Maps entries equal to 1 to index 1 and every other value (e.g. -1) to
        index 0, embeds them, and swaps the last two axes so the embedding
        width sits on dim 1 like the other backbones' channel axis.
        """
        indices = x.eq(1).long()
        embedded = self.embedding(indices)
        return embedded.transpose(1, 2)
class QAMNISTOperatorEmbeddings(nn.Module):
    """
    Learned embeddings for QAMNIST operator tokens.

    Operators are encoded as negative integers (-1 for plus, -2 for minus) and
    are shifted to table indices via ``-(x + 1)``: -1 -> 0, -2 -> 1, ...

    Args:
        num_operator_types (int): Number of operator embeddings.
        d_projection (int): Embedding width.
    """
    def __init__(self, num_operator_types, d_projection):
        super().__init__()
        self.embedding = nn.Embedding(num_operator_types, d_projection)

    def forward(self, x):
        table_indices = -(x + 1)
        return self.embedding(table_indices)
class QAMNISTIndexEmbeddings(torch.nn.Module):
    """
    Fixed sinusoidal embeddings looked up by integer position.

    Column ``2k`` holds ``sin(pos * 10000^(-2k/d))`` and column ``2k+1`` holds
    ``cos(pos * 10000^(-2k/d))``. The table is stored as a (non-trainable)
    buffer named ``embedding``.

    Args:
        max_seq_length (int): Number of positions in the table.
        embedding_dim (int): Embedding width; odd widths are supported (the
            final column is then a sine column).
    """
    def __init__(self, max_seq_length, embedding_dim):
        super().__init__()
        self.max_seq_length = max_seq_length
        self.embedding_dim = embedding_dim
        embedding = torch.zeros(max_seq_length, embedding_dim)
        position = torch.arange(0, max_seq_length, dtype=torch.float).unsqueeze(1)
        # One frequency per sine column: ceil(embedding_dim / 2) entries.
        div_term = torch.exp(torch.arange(0, embedding_dim, 2).float() * (-math.log(10000.0) / embedding_dim))
        embedding[:, 0::2] = torch.sin(position * div_term)
        # There are only floor(embedding_dim / 2) cosine columns; trimming
        # div_term keeps shapes aligned when embedding_dim is odd (for even
        # widths the slice is the whole tensor, so results are unchanged).
        embedding[:, 1::2] = torch.cos(position * div_term[: embedding_dim // 2])
        self.register_buffer('embedding', embedding)

    def forward(self, x):
        """Return the rows of the table selected by ``x`` (int or index tensor)."""
        return self.embedding[x]
class ThoughtSteps:
    """
    Bookkeeping for the phases of "thought steps" in the ctm_qamnist pipeline.

    Step indices run through three consecutive phases: digits
    ``[0, D)``, question ``[D, D + Q)`` and answering ``[D + Q, ...)``, where
    ``D = total_iterations_for_digits`` and ``Q = total_iterations_for_question``.

    Args:
        iterations_per_digit (int): Iterations spent on each digit.
        iterations_per_question_part (int): Iterations spent on each question part
            (index or operator).
        total_iterations_for_answering (int): Length of the answering phase.
        total_iterations_for_digits (int): Length of the digit phase.
        total_iterations_for_question (int): Length of the question phase.
    """
    def __init__(self, iterations_per_digit, iterations_per_question_part, total_iterations_for_answering, total_iterations_for_digits, total_iterations_for_question):
        self.iterations_per_digit = iterations_per_digit
        self.iterations_per_question_part = iterations_per_question_part
        self.total_iterations_for_digits = total_iterations_for_digits
        self.total_iterations_for_question = total_iterations_for_question
        self.total_iterations_for_answering = total_iterations_for_answering
        self.total_iterations = (
            total_iterations_for_digits
            + total_iterations_for_question
            + total_iterations_for_answering
        )

    def determine_step_type(self, stepi: int):
        """Return ``(is_digit_step, is_question_step, is_answer_step)`` for ``stepi``."""
        digits_end = self.total_iterations_for_digits
        question_end = digits_end + self.total_iterations_for_question
        return (
            stepi < digits_end,
            digits_end <= stepi < question_end,
            stepi >= question_end,
        )

    def determine_answer_step_type(self, stepi: int):
        """
        Return ``(is_index_step, is_operator_step)`` for ``stepi``.

        The offset is measured from the end of the digit phase, so this
        classifies steps of the question phase: blocks of
        ``iterations_per_question_part`` steps alternate index, operator,
        index, ...
        """
        part = self.iterations_per_question_part
        offset = stepi - self.total_iterations_for_digits
        is_index_step = bool(offset % (2 * part) < part)
        return is_index_step, not is_index_step
class MNISTBackbone(nn.Module):
    """
    Small convolutional feature extractor for MNIST.

    Two identical stages of conv(3x3, same padding) -> BatchNorm -> ReLU ->
    2x2 max-pool, so spatial size is divided by 4 overall. The first conv
    infers its input channel count lazily.

    Args:
        d_input (int): Number of output channels of both conv stages.
    """
    def __init__(self, d_input):
        super().__init__()

        def conv_stage():
            return [
                nn.LazyConv2d(d_input, kernel_size=3, stride=1, padding=1),
                nn.BatchNorm2d(d_input),
                nn.ReLU(),
                nn.MaxPool2d(2, 2),
            ]

        self.layers = nn.Sequential(*conv_stage(), *conv_stage())

    def forward(self, x):
        return self.layers(x)
class MiniGridBackbone(nn.Module):
    """
    Embeds a MiniGrid observation grid cell by cell.

    Each cell's object, color and state ids (channels 0, 1, 2 of the last axis)
    are embedded, together with a learned embedding of the cell's flattened
    position ``row * W + col``. The four embeddings are concatenated and passed
    through two GLU + LayerNorm projection stages.

    Args:
        d_input (int): Output width per cell.
        grid_size (int): Side of the grid; sizes the position table (grid_size ** 2).
        num_objects (int): Object vocabulary size.
        num_colors (int): Color vocabulary size.
        num_states (int): State vocabulary size.
        embedding_dim (int): Width of each of the four embeddings.
    """
    def __init__(self, d_input, grid_size=7, num_objects=11, num_colors=6, num_states=3, embedding_dim=8):
        super().__init__()
        self.object_embedding = nn.Embedding(num_objects, embedding_dim)
        self.color_embedding = nn.Embedding(num_colors, embedding_dim)
        self.state_embedding = nn.Embedding(num_states, embedding_dim)
        self.position_embedding = nn.Embedding(grid_size * grid_size, embedding_dim)
        concat_dim = 4 * embedding_dim
        self.project_to_d_projection = nn.Sequential(
            nn.Linear(concat_dim, 2 * d_input),
            nn.GLU(),
            nn.LayerNorm(d_input),
            nn.Linear(d_input, 2 * d_input),
            nn.GLU(),
            nn.LayerNorm(d_input),
        )

    def forward(self, x):
        """Map a (B, H, W, 3+) id grid to (B, H, W, d_input) features."""
        x = x.long()
        batch, height, width, _ = x.size()
        positions = torch.arange(height * width, device=x.device).view(1, height, width).expand(batch, -1, -1)
        pieces = [
            self.object_embedding(x[..., 0]),
            self.color_embedding(x[..., 1]),
            self.state_embedding(x[..., 2]),
            self.position_embedding(positions),
        ]
        return self.project_to_d_projection(torch.cat(pieces, dim=-1))
class ClassicControlBackbone(nn.Module):
    """
    Projects flattened classic-control observations to ``d_input`` features.

    Flatten, then two stages of LazyLinear(2 * d_input) -> GLU -> LayerNorm.
    The first linear layer infers its input width on first use.

    Args:
        d_input (int): Output feature width.
    """
    def __init__(self, d_input):
        super().__init__()

        def glu_stage():
            return [nn.LazyLinear(d_input * 2), nn.GLU(), nn.LayerNorm(d_input)]

        self.input_projector = nn.Sequential(nn.Flatten(), *glu_stage(), *glu_stage())

    def forward(self, x):
        return self.input_projector(x)
class ShallowWide(nn.Module):
    """
    Wide, shallow convolutional image backbone (alternative to a ResNet).

    Stage 1: lazy 3x3 stride-2 conv to 4096 channels, GLU over channels
    (-> 2048), BatchNorm. Stage 2: grouped (32 groups) 3x3 conv 2048 -> 4096,
    GLU (-> 2048), BatchNorm. Output has 2048 channels at half resolution.
    """
    def __init__(self):
        super().__init__()
        expanded, kept = 4096, 2048  # GLU halves `expanded` down to `kept`
        self.layers = nn.Sequential(
            nn.LazyConv2d(expanded, kernel_size=3, stride=2, padding=1),
            nn.GLU(dim=1),
            nn.BatchNorm2d(kept),
            nn.Conv2d(kept, expanded, kernel_size=3, stride=1, padding=1, groups=32),
            nn.GLU(dim=1),
            nn.BatchNorm2d(kept),
        )

    def forward(self, x):
        return self.layers(x)
class PretrainedResNetWrapper(nn.Module):
    """
    Wrapper to use a pre-trained torchvision ResNet as a feature extractor.

    Loads ``resnet_type`` (with ImageNet weights) from the ``pytorch/vision:v0.10.0``
    hub entry point and replaces ``avgpool`` and ``fc`` with ``Identity``. The
    ResNet forward then returns flattened features of shape (B, C * H * W),
    which ``forward`` tries to reshape back to (B, C, H, W).
    ``layer4`` is kept by default; callers may replace it on the instance.

    Args:
        resnet_type (str): Hub model name, e.g. 'resnet18', 'resnet50'.
        fine_tune (bool): If False, freezes all backbone parameters.
    """
    def __init__(self, resnet_type, fine_tune=True):
        super(PretrainedResNetWrapper, self).__init__()
        self.resnet_type = resnet_type
        self.backbone = torch.hub.load('pytorch/vision:v0.10.0', resnet_type, pretrained=True)
        if not fine_tune:
            for param in self.backbone.parameters():
                param.requires_grad = False
        # Remove final layers to use as feature extractor
        self.backbone.avgpool = Identity()
        self.backbone.fc = Identity()
        # Keep layer4 by default, user can modify instance if needed
        # self.backbone.layer4 = Identity()

    def forward(self, x):
        """
        Returns (B, nc, wh, wh) when the flattened features factor exactly as
        nc * wh * wh; otherwise warns (via print) and returns the flattened
        (B, num_features) tensor unchanged.
        """
        out = self.backbone(x)
        # NOTE(review): heuristic channel count keyed on the model name; it does
        # not track which layers are kept (e.g. resnet18 with layer4 has 512).
        nc = 256 if ('18' in self.resnet_type or '34' in self.resnet_type) else 512 if '50' in self.resnet_type else 1024 if '101' in self.resnet_type else 2048
        num_features = out.shape[-1]
        # Require an exact square spatial grid: checking only that
        # num_features / nc is an integer let non-squares (e.g. 98) through,
        # and the reshape below then raised instead of falling back.
        cells, remainder = divmod(num_features, nc)
        wh = math.isqrt(cells)
        if remainder != 0 or wh * wh != cells:
            print(f"Warning: Cannot reliably reshape PretrainedResNetWrapper output. nc={nc}, num_features={num_features}")
            # Return potentially flattened features if reshape fails
            return out
        return out.reshape(x.size(0), nc, wh, wh)
# --- Positional Encoding Modules ---
class LearnableFourierPositionalEncoding(nn.Module):
    """
    Learnable Fourier-feature positional encoding for 2D feature maps.

    Follows Algorithm 1 of "Learnable Fourier Features for Multi-Dimensional
    Spatial Positional Encoding" (https://arxiv.org/pdf/2106.02795.pdf):
    coordinates are projected by ``Wr`` (M -> F_dim // 2, no bias), mapped to
    ``[cos, sin] / sqrt(F_dim)``, then passed through Linear -> GLU -> Linear ->
    LayerNorm to produce ``D // G`` channels per position.

    ``Wr`` is initialised from a normal distribution with std ``gamma ** -2``
    (NOTE(review): confirm against the paper whether this was meant as the
    variance rather than the std).

    Args:
        d_model (int): Target encoding width D.
        G (int): Positional groups; the MLP emits D // G channels and no
            regrouping is done, so the output has D channels only when G == 1.
        M (int): Coordinate dimensionality (2 for H, W).
        F_dim (int): Fourier feature width.
        H_dim (int): MLP hidden width (halved by the GLU).
        gamma (float): Scale used to initialise ``Wr``.
    """
    def __init__(self, d_model,
                 G=1, M=2,
                 F_dim=256,
                 H_dim=128,
                 gamma=1/2.5,
                 ):
        super().__init__()
        self.G, self.M = G, M
        self.F_dim, self.H_dim = F_dim, H_dim
        self.D = d_model
        self.gamma = gamma
        channels_per_group = self.D // self.G
        self.Wr = nn.Linear(self.M, self.F_dim // 2, bias=False)
        self.mlp = nn.Sequential(
            nn.Linear(self.F_dim, self.H_dim, bias=True),
            nn.GLU(),  # H_dim -> H_dim // 2
            nn.Linear(self.H_dim // 2, channels_per_group),
            nn.LayerNorm(channels_per_group),
        )
        self.init_weights()

    def init_weights(self):
        nn.init.normal_(self.Wr.weight.data, mean=0, std=self.gamma ** -2)

    def forward(self, x):
        """
        Args:
            x (torch.Tensor): Feature map (B, C, H, W); only channel 0 is passed
                on to build coordinates.

        Returns:
            torch.Tensor: (B, D // G, H, W) positional encoding.
        """
        _batch, _channels, _height, _width = x.shape  # requires a 4-D input
        # add_coord_dim is project-provided; presumably (B, H, W) -> (B, H, W, M) coords — verify.
        coords = add_coord_dim(x[:, 0])
        angles = self.Wr(coords)
        scale = 1.0 / math.sqrt(self.F_dim)
        fourier_features = scale * torch.cat([torch.cos(angles), torch.sin(angles)], dim=-1)
        encoded = self.mlp(fourier_features)  # channels last
        return encoded.permute(0, 3, 1, 2)
class MultiLearnableFourierPositionalEncoding(nn.Module):
    """
    Softmax-weighted mixture of several LearnableFourierPositionalEncoding modules.

    ``N`` encoders are built with ``gamma`` values evenly spaced (np.linspace)
    from ``gamma_range[0]`` to ``gamma_range[1]``. Their outputs are combined
    with weights ``softmax(combination)``, where ``combination`` is a learnable
    length-N vector initialised to ones (i.e. a uniform average at start).

    Args:
        d_model (int): Output encoding width.
        G, M, F_dim, H_dim: Forwarded to each LearnableFourierPositionalEncoding.
        gamma_range (sequence of two floats): Start and stop gamma values.
        N (int): Number of encoders in the mixture.
    """
    def __init__(self, d_model,
                 G=1, M=2,
                 F_dim=256,
                 H_dim=128,
                 gamma_range=(1.0, 0.1),  # immutable default; lists are still accepted
                 N=10,
                 ):
        super().__init__()
        self.embedders = nn.ModuleList(
            LearnableFourierPositionalEncoding(d_model, G, M, F_dim, H_dim, gamma)
            for gamma in np.linspace(gamma_range[0], gamma_range[1], N)
        )
        # Mixture logits; the registered name 'combination' is part of the state_dict.
        self.register_parameter('combination', torch.nn.Parameter(torch.ones(N), requires_grad=True))
        self.N = N

    def forward(self, x):
        """
        Args:
            x (torch.Tensor): Input feature map, shape (B, C, H, W).

        Returns:
            torch.Tensor: Weighted sum of the N encodings (same shape as one
            encoder's output).
        """
        # (N, B, D, H, W)
        pos_embs = torch.stack([emb(x) for emb in self.embedders], dim=0)
        # (N,) -> (N, 1, 1, 1, 1) so the weights broadcast over each encoding.
        weights = F.softmax(self.combination, dim=-1).view(self.N, 1, 1, 1, 1)
        return (pos_embs * weights).sum(0)
class CustomRotationalEmbedding(nn.Module):
    """
    Custom rotational positional embedding built from the input width only.

    A learnable 2D ``start_vector`` is rotated by ``W`` angles evenly spaced
    over 0..180 degrees (``W = x.size(-1)``). Cell (i, j) of a W x W grid is
    the concatenation of rotated vectors i and j (4 values), which a linear
    layer projects to ``d_model``.

    The output is always (1, d_model, W, W): the height of the input is not
    used, so when H != W the caller must crop/resize to match.

    Args:
        d_model (int): Dimensionality of the output embeddings.
    """
    def __init__(self, d_model):
        super(CustomRotationalEmbedding, self).__init__()
        # Learnable 2D vector that is rotated once per column angle.
        self.register_parameter('start_vector', nn.Parameter(torch.Tensor([0, 1]), requires_grad=True))
        # Each grid cell concatenates two rotated 2D vectors -> 4 input features.
        self.projection = nn.Sequential(nn.Linear(4, d_model))

    def forward(self, x):
        """
        Args:
            x (torch.Tensor): 4-D input; only its last dimension (W) and its
                device are used.

        Returns:
            torch.Tensor: Positional embeddings of shape (1, d_model, W, W),
            shared across the batch.
        """
        _, _, _, W = x.shape  # unpacking also enforces a 4-D input
        theta_rad = torch.deg2rad(torch.linspace(0, 180, W, device=x.device))
        cos_theta = torch.cos(theta_rad)
        sin_theta = torch.sin(theta_rad)
        # (W, 2, 2) rotation matrices, one per angle.
        rotation_matrices = torch.stack([
            torch.stack([cos_theta, -sin_theta], dim=-1),
            torch.stack([sin_theta, cos_theta], dim=-1),
        ], dim=1)
        rotated_vectors = torch.einsum('wij,j->wi', rotation_matrices, self.start_vector)  # (W, 2)
        key = torch.cat((
            torch.repeat_interleave(rotated_vectors.unsqueeze(1), W, dim=1),  # cell (i, j) -> vector i
            torch.repeat_interleave(rotated_vectors.unsqueeze(0), W, dim=0),  # cell (i, j) -> vector j
        ), dim=-1)  # (W, W, 4)
        pe_grid = self.projection(key)  # (W, W, d_model)
        return pe_grid.permute(2, 0, 1).unsqueeze(0)
class CustomRotationalEmbedding1D(nn.Module):
    """
    1D rotational positional embedding.

    The fixed vector [0, 1] is rotated by ``L = x.size(2)`` angles evenly
    spaced over 0..180 degrees, each rotated 2D vector is projected to
    ``d_model``, and the result is repeated over the batch dimension
    ``x.size(0)``. Output shape: (B, d_model, L).

    Args:
        d_model (int): Embedding width.
    """
    def __init__(self, d_model):
        super().__init__()
        self.projection = nn.Linear(2, d_model)

    def forward(self, x):
        length = x.size(2)
        angles = torch.deg2rad(torch.linspace(0, 180, length, device=x.device))
        c, s = torch.cos(angles), torch.sin(angles)
        # (L, 2, 2): row 0 = [cos, -sin], row 1 = [sin, cos].
        rotations = torch.stack(
            (torch.stack((c, -s), dim=-1), torch.stack((s, c), dim=-1)),
            dim=1,
        )
        base = torch.tensor([0., 1.], device=x.device, dtype=torch.float)
        rotated = torch.einsum('lij,j->li', rotations, base)  # (L, 2)
        embedded = self.projection(rotated)  # (L, d_model)
        batched = embedded.unsqueeze(0).repeat_interleave(x.size(0), dim=0)  # (B, L, d_model)
        # Channels-first, matching the other backbones.
        return batched.transpose(1, 2)