| import torch |
| import torch.nn as nn |
| import math |
|
|
| |
| from built_transformer.embeddings import Embeddings |
| from built_transformer.encoder import Encoder, EncoderLayer |
| from built_transformer.decoders import Decoder, DecoderLayer |
| from built_transformer.positional_encodings import PositionalEncoding |
| from built_transformer.slot_classifier import SlotClassifier |
|
|
class TransformerChatbot(nn.Module):
    """
    Unified Transformer-based chatbot model that combines:
    - Joint token/role/turn embeddings
    - Encoder-decoder architecture with attention
    - Slot-filling classification
    - Generation capabilities

    ``forward`` returns generation logits over the vocabulary together with
    slot-classification logits computed from the first encoder position.
    """

    def __init__(
        self,
        vocab_size: int,
        d_model: int = 512,
        num_heads: int = 8,
        d_ff: int = 2048,
        num_encoder_layers: int = 6,
        num_decoder_layers: int = 6,
        num_roles: int = 2,
        max_turns: int = 16,
        num_slots: int = 4,
        dropout: float = 0.1,
        max_len: int = 5000
    ):
        super().__init__()

        # Target-side joint embedding (tokens + roles + turns) and positional
        # encoding.
        # NOTE(review): decode() does not feed these into the decoder. The
        # Decoder receives raw token IDs and is constructed with its own
        # vocab_size/max_len, so it presumably embeds internally. The modules
        # are kept registered so checkpoints containing ``embed.*`` /
        # ``pos_enc.*`` keys still load with strict=True.
        self.embed = Embeddings(
            char=vocab_size,
            dimension_for_model=d_model,
            num_of_roles=num_roles,
            max_turns=max_turns
        )
        self.pos_enc = PositionalEncoding(d_model, dropout, max_len)

        # Context encoder. It receives token, role and turn IDs and owns its
        # own embedding submodule (the ``encoder.embed.*`` keys handled in
        # load_state_dict below).
        self.encoder = Encoder(
            vocab_size=vocab_size,
            dimension_of_model=d_model,
            num_of_heads=num_heads,
            num_layers=num_encoder_layers,
            dim_feedforward=d_ff,
            dropout=dropout,
            max_len=max_len,
            num_of_roles=num_roles,
            max_turns=max_turns
        )

        self.decoder = Decoder(
            vocab_size=vocab_size,
            dimension_for_model=d_model,
            num_layers=num_decoder_layers,
            num_of_heads=num_heads,
            dim_feedforward=d_ff,
            dropout=dropout,
            max_len=max_len
        )

        # Heads: vocabulary projection for generation, and a slot classifier
        # applied to a single encoder position (see forward).
        self.out_proj = nn.Linear(d_model, vocab_size)
        self.slot_classifier = SlotClassifier(d_model, num_slots)

        self._init_parameters()

    def _init_parameters(self):
        """Xavier-uniform initialise every parameter with more than one dimension.

        This covers weight matrices and embedding tables. 1-D parameters
        (biases, norm scales) keep their module defaults. Because this runs
        after all submodules are built, it overrides any initialisation those
        submodules performed for their >1-D parameters.
        """
        for param in self.parameters():
            if param.dim() > 1:
                nn.init.xavier_uniform_(param)

    def load_state_dict(self, state_dict, strict=True, **kwargs):
        """Load parameters, upgrading legacy checkpoints on the fly.

        Older checkpoints stored the encoder's token embedding as a bare
        ``encoder.embed.weight`` tensor. When that key is present, it is
        renamed to ``encoder.embed.lut.weight``. The role/turn embeddings and
        the embedding LayerNorm, which such checkpoints lack, are then filled
        with neutral values: zeros for the role/turn tables and the norm bias,
        ones for the norm weight. The fillers are sized from this model's own
        parameters, so non-default ``num_roles``/``max_turns`` work.

        The caller's mapping is never modified; the upgrade is applied to a
        copy.

        Args:
            state_dict: Mapping of parameter/buffer names to tensors.
            strict: Forwarded to ``nn.Module.load_state_dict``.
            **kwargs: Forwarded unchanged (e.g. ``assign`` on newer PyTorch).

        Returns:
            The result of ``nn.Module.load_state_dict`` (missing/unexpected keys).
        """
        legacy_key = 'encoder.embed.weight'
        if legacy_key not in state_dict:
            return super().load_state_dict(state_dict, strict=strict, **kwargs)

        from collections import OrderedDict

        # Copy so the caller's checkpoint is left intact. Re-attach the
        # ``_metadata`` attribute that PyTorch uses for per-module versioning.
        metadata = getattr(state_dict, '_metadata', None)
        upgraded = OrderedDict(state_dict)
        if metadata is not None:
            upgraded._metadata = metadata

        upgraded['encoder.embed.lut.weight'] = upgraded.pop(legacy_key)

        # Size each filler from the tensor it will be copied into, rather than
        # from hard-coded defaults, so the shapes always match this model.
        own_state = self.state_dict()
        fillers = (
            ('encoder.embed.lut_roles.weight', torch.zeros),
            ('encoder.embed.lut_turns.weight', torch.zeros),
            ('encoder.embed.norm.weight', torch.ones),
            ('encoder.embed.norm.bias', torch.zeros),
        )
        for key, factory in fillers:
            if key in own_state:
                target = own_state[key]
                upgraded[key] = factory(target.shape, dtype=target.dtype)

        return super().load_state_dict(upgraded, strict=strict, **kwargs)

    def encode(self, src_tokens, src_roles, src_turns, src_mask=None):
        """
        Encode source sequences with role and turn information.
        Args:
            src_tokens: [B, S] token IDs
            src_roles: [B, S] role IDs
            src_turns: [B, S] turn IDs
            src_mask: padding mask [B, 1, 1, S]
        Returns:
            enc_out: [B, S, d_model]
        """
        return self.encoder(src_tokens, src_roles, src_turns, src_mask)

    def decode(
        self,
        tgt_tokens,
        enc_out,
        tgt_roles,
        tgt_turns,
        src_mask=None,
        tgt_mask=None
    ):
        """
        Decode target sequences with encoder context.
        Args:
            tgt_tokens: [B, T] target token IDs
            enc_out: [B, S, d_model] encoder output
            tgt_roles: [B, T] target role IDs (currently unused, see note below)
            tgt_turns: [B, T] target turn IDs (currently unused, see note below)
            src_mask: [B, 1, 1, S] source mask
            tgt_mask: [B, 1, T, T] target mask
        Returns:
            logits: [B, T, vocab_size]
        """
        # NOTE(review): the decoder is given raw token IDs, so tgt_roles and
        # tgt_turns do not influence the output. They are kept in the
        # signature for API compatibility. No target-side embedding is
        # computed here, because its result would be discarded. Wire
        # self.embed / self.pos_enc in once the Decoder accepts pre-embedded
        # input.
        # The decoder is called with (tgt_mask, src_mask) in that order;
        # confirm this matches the Decoder's signature.
        dec_out = self.decoder(tgt_tokens, enc_out, tgt_mask, src_mask)
        return self.out_proj(dec_out)

    def forward(
        self,
        src_tokens,
        tgt_tokens,
        src_roles,
        tgt_roles,
        src_turns,
        tgt_turns,
        src_mask=None,
        tgt_mask=None
    ):
        """
        Full forward pass combining encoding, decoding, and slot classification.
        Args:
            src_tokens: [B, S] source token IDs
            tgt_tokens: [B, T] target token IDs
            src_roles: [B, S] source role IDs
            tgt_roles: [B, T] target role IDs
            src_turns: [B, S] source turn IDs
            tgt_turns: [B, T] target turn IDs
            src_mask: [B, 1, 1, S] source mask
            tgt_mask: [B, 1, T, T] target mask
        Returns:
            gen_logits: [B, T, vocab_size] generation logits
            slot_logits: [B, num_slots] slot classification logits
        """
        enc_out = self.encode(src_tokens, src_roles, src_turns, src_mask)

        gen_logits = self.decode(
            tgt_tokens,
            enc_out,
            tgt_roles,
            tgt_turns,
            src_mask,
            tgt_mask
        )

        # Sequence-level representation for slot filling: the encoder output at
        # position 0. This presumably relies on a CLS/BOS-style first token;
        # confirm against the data pipeline.
        cls_rep = enc_out[:, 0, :]
        slot_logits = self.slot_classifier(cls_rep)

        return gen_logits, slot_logits
|
|