| |
| import torch |
| import torch.nn as nn |
| import math |
| from .positional_encodings import PositionalEncoding |
| from .multihead_attention import MultiHeadAttention |
| from .encoding_layers import position_wide_feed_forward, Residual_layer |
| from .masking_for_attention import mask |
| from .embeddings import Embeddings |
|
|
class EncoderLayer(nn.Module):
    """
    A single post-norm Transformer encoder block.

    Each of the two sublayers is wrapped as ``LayerNorm(x + Dropout(sublayer(x)))``:
    first multi-head self-attention, then the position-wise feed-forward network.

    Args:
        dimension_for_model: model width; must match the embedding dimension.
        num_of_heads: number of attention heads for the multi-head attention.
        dim_feedforward: hidden width of the position-wise feed-forward network.
        dropout: dropout probability used on both residual branches and also
            passed to the attention and feed-forward modules.
    """

    def __init__(self, dimension_for_model, num_of_heads, dim_feedforward, dropout = 0.1):
        super().__init__()
        width, p_drop = dimension_for_model, dropout

        # Attribute names define the state-dict keys and construction order
        # defines the RNG order for weight init, so both stay fixed.
        self.self_attn = MultiHeadAttention(width, num_of_heads, p_drop)
        self.norm1 = nn.LayerNorm(width)
        self.dropout1 = nn.Dropout(p_drop)

        self.ffn = position_wide_feed_forward(width, dim_feedforward, p_drop)
        self.norm2 = nn.LayerNorm(width)
        self.dropout2 = nn.Dropout(p_drop)

    def forward(self, src: torch.Tensor, src_mask: torch.Tensor = None) -> torch.Tensor:
        """
        Run self-attention and the feed-forward sublayer over ``src``.

        Args:
            src: input sequence; presumably (batch, seq_len, dimension_for_model)
                -- TODO confirm against MultiHeadAttention.
            src_mask: optional attention mask forwarded to self-attention as ``mask``.

        Returns:
            Tensor of the same shape as ``src``.
        """
        # The attention module returns a pair; only the first element is used.
        attended, _unused = self.self_attn(src, src, src, mask=src_mask)
        hidden = self.norm1(src + self.dropout1(attended))
        return self.norm2(hidden + self.dropout2(self.ffn(hidden)))
|
|
|
|
class Encoder(nn.Module):
    """
    Stacked Transformer encoder:
      - embedding (tokens + roles + turns) + positional encoding
      - N encoder layers
      - final layer norm

    Args:
        vocab_size: number of token ids accepted by the embedding.
        dimension_of_model: model width shared by every sublayer.
        num_of_heads: attention heads per encoder layer.
        num_layers: number of stacked ``EncoderLayer`` blocks.
        dim_feedforward: hidden width of each layer's feed-forward network.
        dropout: dropout probability for positional encoding and layers.
        max_len: maximum sequence length supported by the positional encoding.
        num_of_roles: number of distinct role ids passed to ``Embeddings``.
        max_turns: number of distinct turn ids passed to ``Embeddings``.
    """

    def __init__(self, vocab_size, dimension_of_model, num_of_heads, num_layers, dim_feedforward = 2048, dropout = 0.1, max_len = 5000, num_of_roles=2, max_turns=16):
        super().__init__()

        self.embed = Embeddings(vocab_size, dimension_for_model=dimension_of_model, num_of_roles=num_of_roles, max_turns=max_turns)

        self.pe = PositionalEncoding(dimension_of_model, dropout=dropout, max_len=max_len)

        self.layers = nn.ModuleList([
            EncoderLayer(dimension_of_model, num_of_heads, dim_feedforward, dropout)
            for _ in range(num_layers)
        ])

        # NOTE(review): EncoderLayer is post-norm (its output already passes
        # through a LayerNorm), so this final norm is applied right after
        # another LayerNorm. Kept for checkpoint compatibility.
        self.norm = nn.LayerNorm(dimension_of_model)

    def forward(self, src_ids, roles, turns, src_mask = None) -> torch.Tensor:
        """
        Args:
            src_ids: [batch_size x seq_len] input token indices
            roles: [batch_size x seq_len] role ids
            turns: [batch_size x seq_len] turn ids
            src_mask: [batch_size, 1, 1, seq_len] mask to prevent attending to padding tokens

        Returns:
            Encoded sequence after every layer and the final LayerNorm.
        """
        x = self.embed(src_ids, roles, turns)
        x = self.pe(x)

        for layer in self.layers:
            x = layer(x, src_mask)

        return self.norm(x)

    def load_state_dict(self, state_dict, strict=True, **kwargs):
        """
        Load parameters, migrating checkpoints that used a single embedding table.

        Older checkpoints stored the token embedding as ``embed.weight``. The
        current layout expects ``embed.lut.weight`` (tokens), ``embed.lut_roles.weight``,
        ``embed.lut_turns.weight``, ``embed.norm.weight`` and ``embed.norm.bias``.
        During migration:
          - the legacy table is moved to ``embed.lut.weight``;
          - the role and turn tables are zero-filled;
          - the embedding norm weight and bias are set to ones and zeros.
        Shapes are taken from this module's own ``state_dict()``, so the role
        and turn tables get ``num_of_roles`` and ``max_turns`` rows rather than
        ``vocab_size`` rows.

        Both the un-prefixed form (``embed.weight``, which is what
        ``Encoder.state_dict()`` produces) and the ``encoder.``-prefixed form
        are recognised.

        Args:
            state_dict: mapping of parameter names to tensors. It is not
                modified; a shallow copy is migrated instead.
            strict: forwarded to ``nn.Module.load_state_dict``.
            **kwargs: extra keyword arguments (e.g. ``assign`` on newer PyTorch)
                forwarded to ``nn.Module.load_state_dict``.

        Returns:
            The value returned by ``nn.Module.load_state_dict`` (missing and
            unexpected keys).
        """
        from collections import OrderedDict

        # Copy so the caller's mapping is untouched. Keep `_metadata`, which
        # nn.Module.load_state_dict reads for per-module versioning.
        metadata = getattr(state_dict, "_metadata", None)
        state_dict = OrderedDict(state_dict)
        if metadata is not None:
            state_dict._metadata = metadata

        own_state = None

        def ref_shape(name, default):
            # Prefer the live module's parameter shape; fall back to a guess.
            ref = own_state.get(name)
            return ref.shape if ref is not None else default

        for prefix in ("", "encoder."):
            legacy_key = prefix + "embed.weight"
            if legacy_key not in state_dict:
                continue
            if own_state is None:
                own_state = self.state_dict()
            if "embed.weight" in own_state:
                # The live module genuinely has `embed.weight`; nothing to migrate.
                continue

            old_weight = state_dict.pop(legacy_key)
            d_model = old_weight.size(1)

            state_dict[prefix + "embed.lut.weight"] = old_weight
            state_dict[prefix + "embed.lut_roles.weight"] = old_weight.new_zeros(
                ref_shape("embed.lut_roles.weight", old_weight.shape))
            state_dict[prefix + "embed.lut_turns.weight"] = old_weight.new_zeros(
                ref_shape("embed.lut_turns.weight", old_weight.shape))
            state_dict[prefix + "embed.norm.weight"] = old_weight.new_ones(
                ref_shape("embed.norm.weight", (d_model,)))
            state_dict[prefix + "embed.norm.bias"] = old_weight.new_zeros(
                ref_shape("embed.norm.bias", (d_model,)))

        return super().load_state_dict(state_dict, strict=strict, **kwargs)
| |
|
|