"""PIRD model. Stream B is a transformer encoder (`roberta-base` by default; any Hugging Face
`AutoModel` checkpoint, e.g. DeBERTa-v3-small, can be passed as `encoder_name`) with mean pooling
and an MLP head. It is designed to also accept concatenated Stream A/C feature vectors (`extra`)
for the later fusion version; in PIRD-lite `n_extra=0` and only the encoder is used."""
| from __future__ import annotations |
| import torch |
| import torch.nn as nn |
| from transformers import AutoModel |
|
|
|
|
class PIRDModel(nn.Module):
    """Transformer encoder + masked mean pooling + two-layer MLP head.

    The encoder is loaded with ``AutoModel.from_pretrained(encoder_name)``. Its
    last hidden layer is mean-pooled over positions where the attention mask is
    non-zero, optionally concatenated with ``n_extra`` caller-supplied features,
    and mapped by the MLP head to a single output value per example.

    Args:
        encoder_name: Hugging Face model id or local path for the encoder.
        n_extra: Width of the optional ``extra`` feature vector appended to the
            pooled embedding; ``0`` disables the fusion input.
        hidden: Width of the MLP hidden layer.
        dropout: Dropout probability applied after the hidden activation.

    Raises:
        ValueError: If ``n_extra`` is negative.
    """

    def __init__(self, encoder_name: str = "roberta-base",
                 n_extra: int = 0, hidden: int = 256, dropout: float = 0.2):
        super().__init__()
        # Validate before the (potentially slow) checkpoint download/load.
        if n_extra < 0:
            raise ValueError(f"n_extra must be >= 0, got {n_extra}")
        self.encoder_name = encoder_name
        self.n_extra = n_extra
        self.encoder = AutoModel.from_pretrained(encoder_name)
        enc_dim = self.encoder.config.hidden_size
        self.head = nn.Sequential(
            nn.Linear(enc_dim + n_extra, hidden),
            nn.GELU(),
            nn.Dropout(dropout),
            nn.Linear(hidden, 1),
        )

    @staticmethod
    def _masked_mean(hidden_states: torch.Tensor,
                     attention_mask: torch.Tensor) -> torch.Tensor:
        """Average ``hidden_states`` over positions where ``attention_mask`` is non-zero.

        Expects ``hidden_states`` as (batch, seq, dim) and ``attention_mask`` as
        (batch, seq); the mask is broadcast over the last dimension. The sum is
        accumulated in at least float32 (so fp16/bf16 states don't lose precision),
        and the result is cast back to ``hidden_states.dtype`` so it matches the
        head's parameter dtype. Rows with no unmasked positions yield zeros.
        """
        acc_dtype = torch.promote_types(hidden_states.dtype, torch.float32)
        mask = attention_mask.unsqueeze(-1).to(acc_dtype)
        summed = (hidden_states.to(acc_dtype) * mask).sum(1)
        counts = mask.sum(1).clamp_min(1e-6)  # avoid 0/0 on fully padded rows
        return (summed / counts).to(hidden_states.dtype)

    def encode(self, input_ids, attention_mask):
        """Run the encoder and return the masked mean of its last hidden layer."""
        out = self.encoder(input_ids=input_ids, attention_mask=attention_mask)
        return self._masked_mean(out.last_hidden_state, attention_mask)

    def _check_extra(self, extra) -> None:
        """Raise ``ValueError`` unless ``extra`` is a tensor whose last dim is ``n_extra``."""
        if extra is None:
            raise ValueError(
                f"PIRDModel was built with n_extra={self.n_extra} but forward() "
                "received extra=None"
            )
        if extra.dim() == 0 or extra.shape[-1] != self.n_extra:
            raise ValueError(
                f"extra has shape {tuple(extra.shape)}; expected last dimension "
                f"n_extra={self.n_extra}"
            )

    def forward(self, input_ids, attention_mask, extra=None):
        """Compute one head output per example.

        Args:
            input_ids: Token ids passed to the encoder.
            attention_mask: Padding mask passed to the encoder and used for pooling.
            extra: Feature tensor whose last dimension is ``n_extra``, concatenated
                to the pooled embedding. Required when ``n_extra > 0``; ignored
                when ``n_extra == 0``.

        Returns:
            The head's single output per example, with the trailing size-1
            dimension squeezed away.

        Raises:
            ValueError: If ``n_extra > 0`` and ``extra`` is missing or has the
                wrong last-dimension width.
        """
        # Fail fast, before paying for the encoder forward pass.
        if self.n_extra:
            self._check_extra(extra)
        pooled = self.encode(input_ids, attention_mask)
        if self.n_extra:
            # Align dtype/device with the pooled embedding so e.g. float64
            # numpy-derived features or CPU tensors don't break cat/Linear.
            extra = extra.to(device=pooled.device, dtype=pooled.dtype)
            pooled = torch.cat([pooled, extra], dim=-1)
        return self.head(pooled).squeeze(-1)
|
|