| from typing import Optional, Tuple |
| import torch |
| from torch import nn |
| from .configuration_italia import ItaliaConfig |
| from transformers.models.gpt_neox import modeling_gpt_neox |
|
|
| |
class GPTNeoXLayer(nn.Module):
    """Replacement for ``modeling_gpt_neox.GPTNeoXLayer`` used by Italia.

    Unlike the stock layer, this one owns a single ``input_layernorm`` (there is
    no ``post_attention_layernorm``) and ``forward`` always applies the parallel
    residual form::

        h = input_layernorm(x)
        out = x + post_attention_dropout(attention(h)) + post_mlp_dropout(mlp(h))

    ``config.use_parallel_residual`` is copied onto the instance for attribute
    compatibility, but ``forward`` does not consult it.
    """

    def __init__(self, config):
        super().__init__()
        # Kept so code that inspects the attribute still finds it; forward()
        # always uses the parallel residual regardless of its value.
        self.use_parallel_residual = config.use_parallel_residual
        # One LayerNorm shared by both the attention and the MLP branch.
        self.input_layernorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
        self.post_attention_dropout = nn.Dropout(config.hidden_dropout)
        self.post_mlp_dropout = nn.Dropout(config.hidden_dropout)
        # The attention implementation is selected by config._attn_implementation,
        # used as a key into transformers' GPT_NEOX_ATTENTION_CLASSES.
        self.attention = modeling_gpt_neox.GPT_NEOX_ATTENTION_CLASSES[config._attn_implementation](config)
        self.mlp = modeling_gpt_neox.GPTNeoXMLP(config)

    def forward(
        self,
        hidden_states: Optional[torch.FloatTensor],
        attention_mask: Optional[torch.FloatTensor] = None,
        position_ids: Optional[torch.LongTensor] = None,
        head_mask: Optional[torch.FloatTensor] = None,
        use_cache: Optional[bool] = False,
        layer_past: Optional[Tuple[torch.Tensor]] = None,
        output_attentions: Optional[bool] = False,
    ):
        """Run one decoder layer with a shared pre-norm and parallel residual.

        Args:
            hidden_states: layer input; presumably (batch, seq_len, hidden_size)
                — TODO confirm against the caller.
            attention_mask, position_ids, head_mask, layer_past, use_cache,
            output_attentions: forwarded unchanged to ``self.attention``.

        Returns:
            A tuple whose first element is the updated hidden state. When
            ``use_cache`` is truthy, every element the attention module returned
            after its output follows; otherwise the element at index 1 of the
            attention outputs is dropped and the rest follow.
            NOTE(review): this slicing assumes the attention module returns
            ``(attn_output, present, [attn_weights])`` as the stock GPT-NeoX
            attention does — confirm for the configured implementation.
        """
        # Normalize once and feed the same tensor to both branches, so the
        # LayerNorm runs once per call instead of once per branch.
        normed = self.input_layernorm(hidden_states)

        attention_layer_outputs = self.attention(
            normed,
            attention_mask=attention_mask,
            position_ids=position_ids,
            layer_past=layer_past,
            head_mask=head_mask,
            use_cache=use_cache,
            output_attentions=output_attentions,
        )
        attn_output = self.post_attention_dropout(attention_layer_outputs[0])
        outputs = attention_layer_outputs[1:]

        mlp_output = self.post_mlp_dropout(self.mlp(normed))
        # Parallel residual: both branches read the same normalized input and
        # are added to the un-normalized residual stream.
        hidden_states = mlp_output + attn_output + hidden_states

        if use_cache:
            outputs = (hidden_states,) + outputs
        else:
            outputs = (hidden_states,) + outputs[1:]

        return outputs
|
|
# Install the Italia decoder layer into transformers' GPT-NeoX module.
# This rebinds an attribute on the imported module object itself, so the
# replacement is process-wide rather than scoped to Italia models.
# NOTE(review): presumably GPTNeoXModel resolves GPTNeoXLayer from its module
# globals when it builds its layers (which is why this patch takes effect);
# confirm for the pinned transformers version.
setattr(modeling_gpt_neox, "GPTNeoXLayer", GPTNeoXLayer)
|
|
| from transformers.models.gpt_neox.modeling_gpt_neox import GPTNeoXForCausalLM, GPTNeoXModel |
|
|
class ItaliaForCausalLM(GPTNeoXForCausalLM):
    """GPT-NeoX causal LM whose output projection ``embed_out`` carries a bias.

    Configuration is read through ``ItaliaConfig``. The decoder stack is the
    ``GPTNeoXModel`` built by the parent constructor.
    """

    config_class = ItaliaConfig

    def __init__(self, config):
        # The parent constructor (transformers' GPTNeoXForCausalLM.__init__)
        # already assigns ``self.gpt_neox = GPTNeoXModel(config)``. The layer
        # patch in this module runs at import time, before any instantiation,
        # so that stack is built with the patched layer class. Constructing a
        # second GPTNeoXModel here would only double construction time and
        # peak memory.
        super().__init__(config)

        # The only structural difference from the parent: a biased output head.
        self.embed_out = nn.Linear(config.hidden_size, config.vocab_size, bias=True)

        # Re-run weight initialization / final processing for the new head.
        self.post_init()
|
|