| import torch |
| import torch.nn as nn |
| import torch.nn.functional as F |
| from typing import List, Optional |
| from torch import Tensor |
| from transformers import PretrainedConfig, PreTrainedModel |
|
|
|
|
| |
class BlaserConfig(PretrainedConfig):
    """Hugging Face configuration holding the hyper-parameters of a BLASER head.

    Every field is read by ``BlaserModel.__init__`` and forwarded to
    ``BlaserCore``, which validates ``input_form`` and ``activation``.

    Args:
        embedding_dim: Size of each individual input embedding.
        output_dim: Size of the MLP output.
        hidden_dims: Widths of the hidden layers; ``None`` selects the
            default ``[3072, 1536]``.
        dropout: Dropout probability used inside the MLP.
        activation: Name of the hidden-layer activation ("TANH" or "RELU").
        input_form: Feature construction, "COMET" (with reference) or "QE".
        norm_emb: Whether input embeddings are L2-normalised first.
        output_act: Whether a final Tanh is appended to the MLP.
        **kwargs: Passed through to ``PretrainedConfig.__init__``.
    """

    model_type = "blaser"

    def __init__(
        self,
        embedding_dim=1024,
        output_dim=1,
        hidden_dims=None,
        dropout=0.1,
        activation="TANH",
        input_form="COMET",
        norm_emb=True,
        output_act=False,
        **kwargs,
    ):
        super().__init__(**kwargs)

        if hidden_dims is None:
            # Build the default inside the call so no two configs share one list.
            hidden_dims = [3072, 1536]

        # Layer sizes.
        self.embedding_dim = embedding_dim
        self.output_dim = output_dim
        self.hidden_dims = hidden_dims
        # Regularisation and non-linearity.
        self.dropout = dropout
        self.activation = activation
        # Feature construction and output shaping.
        self.input_form = input_form
        self.norm_emb = norm_emb
        self.output_act = output_act
|
|
|
|
| |
# Maps the ``activation`` config string to the torch module class that is
# instantiated after every hidden Linear layer of the BLASER MLP.
ACTIVATIONS = dict(
    TANH=nn.Tanh,
    RELU=nn.ReLU,
)
|
|
|
|
class BlaserCore(nn.Module):
    """BLASER scoring head: embedding normalisation, feature construction and MLP.

    Input embeddings (source ``src``, translation ``mt`` and, for the COMET
    form, a reference ``ref``) are optionally L2-normalised, combined into one
    feature vector by :meth:`_featurize`, and scored by ``self.mlp``.

    Args:
        embedding_dim: Size of each individual input embedding.
        output_dim: Size of the MLP output.
        hidden_dims: Widths of the hidden layers. An empty list (or ``None``)
            yields a single Linear layer.
        dropout: Dropout probability; ``0`` omits the dropout layers entirely.
        activation: Key into ``ACTIVATIONS`` ("TANH" or "RELU").
        input_form: "COMET" (uses src, mt and ref; 6x features) or
            "QE" (uses src and mt; 4x features).
        norm_emb: Whether each input embedding is L2-normalised first.
        output_act: Whether to append a final ``nn.Tanh``. This only applies
            when ``hidden_dims`` is non-empty.

    Raises:
        ValueError: If ``input_form`` or ``activation`` is not recognised.
    """

    def __init__(
        self,
        embedding_dim: int,
        output_dim: int,
        hidden_dims: List[int],
        dropout: float,
        activation: str,
        input_form: str,
        norm_emb: bool,
        output_act: bool,
    ):
        super().__init__()
        self.input_form = input_form
        self.norm_emb = norm_emb

        # _featurize concatenates 6 (COMET) or 4 (QE) embedding-sized tensors,
        # so the MLP input width is that multiple of the per-input size.
        if input_form == "COMET":
            embedding_dim *= 6
        elif input_form == "QE":
            embedding_dim *= 4
        else:
            raise ValueError(f"Unrecognized input_form: {input_form}")
        if activation not in ACTIVATIONS:
            raise ValueError(f"Unrecognized activation: {activation}")

        # Keep this exact layer order. nn.Sequential names its children by
        # index, so the parameter names (mlp.<i>.weight, ...) depend on it.
        modules: List[nn.Module] = []
        if hidden_dims:
            if dropout > 0:
                modules.append(nn.Dropout(p=dropout))
            nprev = embedding_dim
            for h in hidden_dims:
                modules.append(nn.Linear(nprev, h))
                modules.append(ACTIVATIONS[activation]())
                if dropout > 0:
                    modules.append(nn.Dropout(p=dropout))
                nprev = h
            modules.append(nn.Linear(nprev, output_dim))
            if output_act:
                modules.append(nn.Tanh())
        else:
            modules.append(nn.Linear(embedding_dim, output_dim))

        self.mlp = nn.Sequential(*modules)

    def forward(self, src: Tensor, mt: Tensor, ref: Optional[Tensor] = None) -> Tensor:
        """Score ``mt`` against ``src`` (and ``ref`` for the COMET form).

        Args:
            src: Source embeddings, embedding on the last axis.
            mt: Translation embeddings, same shape as ``src``.
            ref: Reference embeddings; required when ``input_form == "COMET"``.

        Returns:
            The MLP output, whose last dimension is ``output_dim``.

        Raises:
            ValueError: If the COMET form is used without ``ref``.
        """
        features = self._featurize(self._norm(src), self._norm(mt), self._norm(ref))
        return self.mlp(features)

    def _norm(self, emb: Optional[Tensor]) -> Optional[Tensor]:
        """L2-normalise ``emb`` over its last axis when ``norm_emb`` is set.

        ``None`` (e.g. an absent reference) is passed through unchanged.
        """
        if emb is None or not self.norm_emb:
            return emb
        # Normalise over the last (embedding) axis. That is the axis
        # _featurize concatenates on and the one the Linear layers consume.
        # F.normalize's default dim=1 only matches this for 2-D inputs.
        return F.normalize(emb, dim=-1)

    def _featurize(self, src: Tensor, mt: Tensor, ref: Optional[Tensor] = None) -> Tensor:
        """Combine the embeddings into the MLP input along the last axis.

        COMET: ``[ref, mt, src*mt, ref*mt, |mt-src|, |mt-ref|]``.
        QE:    ``[src, mt, src*mt, |mt-src|]``.

        Raises:
            ValueError: If COMET is used without ``ref``, or ``input_form`` is
                not recognised.
        """
        if self.input_form == "COMET":
            if ref is None:
                raise ValueError("COMET input_form requires reference embedding")
            return torch.cat(
                [ref, mt, src * mt, ref * mt, torch.abs(mt - src), torch.abs(mt - ref)],
                dim=-1,
            )
        if self.input_form == "QE":
            return torch.cat([src, mt, src * mt, torch.abs(mt - src)], dim=-1)
        # Previously fell through and returned None silently.
        raise ValueError(f"Unrecognized input_form: {self.input_form}")
|
|
|
|
| |
class BlaserModel(PreTrainedModel):
    """Hugging Face ``PreTrainedModel`` wrapper around the BLASER MLP head.

    ``BlaserCore`` is used only as a builder. It validates the config and
    assembles the layer stack. Just its ``mlp`` is registered on this model,
    so parameters are named ``mlp.*``.
    """

    config_class = BlaserConfig

    def __init__(self, config: BlaserConfig):
        super().__init__(config)

        # The core is a local, not an attribute, so only `mlp` becomes a
        # submodule. This keeps state_dict keys as mlp.<i>.*.
        core = BlaserCore(
            embedding_dim=config.embedding_dim,
            output_dim=config.output_dim,
            hidden_dims=config.hidden_dims,
            dropout=config.dropout,
            activation=config.activation,
            input_form=config.input_form,
            norm_emb=config.norm_emb,
            output_act=config.output_act,
        )
        self.mlp = core.mlp
        self.input_form = core.input_form
        self.norm_emb = core.norm_emb

        self.post_init()

    def forward(self, src: Tensor, mt: Tensor, ref: Optional[Tensor] = None) -> Tensor:
        """Score translation embeddings.

        Args:
            src: Source embeddings, embedding on the last axis.
            mt: Translation embeddings, same shape as ``src``.
            ref: Reference embeddings; required when ``input_form == "COMET"``.

        Returns:
            The MLP output, whose last dimension is ``config.output_dim``.

        Raises:
            ValueError: If the COMET form is used without ``ref``.
        """
        if self.input_form == "COMET" and ref is None:
            raise ValueError("COMET input_form requires reference embedding")

        if self.norm_emb:
            # Normalise over the last (embedding) axis, the axis concatenated
            # below and consumed by the Linear layers. F.normalize's default
            # dim=1 only coincides with this for 2-D inputs.
            src = F.normalize(src, dim=-1)
            mt = F.normalize(mt, dim=-1)
            if ref is not None:
                ref = F.normalize(ref, dim=-1)

        if self.input_form == "COMET":
            proc = torch.cat(
                [ref, mt, src * mt, ref * mt, torch.abs(mt - src), torch.abs(mt - ref)],
                dim=-1,
            )
        else:
            # input_form was validated by BlaserCore in __init__, so this is QE.
            proc = torch.cat([src, mt, src * mt, torch.abs(mt - src)], dim=-1)

        return self.mlp(proc)