MohsinEli's picture
Add PIRD app + trained checkpoint
bd743a9
Raw
History Blame
1.58 kB
"""PIRD model. Stream B is a transformer encoder (`roberta-base` by default) followed by
attention-masked mean pooling and an MLP head. The model can also accept concatenated Stream A/C
feature vectors (`extra`) for the later fusion version; in PIRD-lite `n_extra=0`, so the head sees
only the pooled encoder embedding and any `extra` passed in is ignored."""
from __future__ import annotations
import torch
import torch.nn as nn
from transformers import AutoModel
class PIRDModel(nn.Module):
    """Transformer encoder + masked mean pooling + MLP head producing one logit per example.

    The head consumes the pooled encoder embedding, optionally concatenated with
    ``n_extra`` additional per-example features passed to :meth:`forward` as
    ``extra``. ``forward`` returns a raw (un-sigmoided) logit; higher = more AI.

    Note: the helpers below add no parameters or buffers, so the ``state_dict``
    layout (``encoder.*``, ``head.0.*``, ``head.3.*``) is unchanged.
    """

    def __init__(self, encoder_name: str = "roberta-base",
                 n_extra: int = 0, hidden: int = 256, dropout: float = 0.2):
        """Load the encoder and build the classification head.

        Args:
            encoder_name: Model id or path passed to ``AutoModel.from_pretrained``.
            n_extra: Width of the optional extra feature vector concatenated to
                the pooled embedding before the head; 0 disables fusion.
            hidden: Width of the head's hidden layer.
            dropout: Dropout probability applied inside the head.
        """
        super().__init__()
        self.encoder_name = encoder_name
        self.n_extra = n_extra
        self.encoder = AutoModel.from_pretrained(encoder_name)
        enc_dim = self.encoder.config.hidden_size
        self.head = nn.Sequential(
            nn.Linear(enc_dim + n_extra, hidden),
            nn.GELU(),
            nn.Dropout(dropout),
            nn.Linear(hidden, 1),
        )

    @staticmethod
    def _masked_mean(hidden: torch.Tensor, attention_mask: torch.Tensor) -> torch.Tensor:
        """Average ``hidden`` over dim 1, counting only positions where the mask is nonzero."""
        # Accumulate in at least float32 so fp16/bf16 activations neither overflow
        # nor lose precision, then cast back so the result matches the head's dtype
        # (a float32 result would crash a half-precision nn.Linear).
        acc_dtype = torch.promote_types(hidden.dtype, torch.float32)
        mask = attention_mask.unsqueeze(-1).to(acc_dtype)
        summed = (hidden.to(acc_dtype) * mask).sum(dim=1)
        # The clamp keeps all-padding rows at zero instead of 0/0 = NaN.
        count = mask.sum(dim=1).clamp_min(1e-6)
        return (summed / count).to(hidden.dtype)

    def encode(self, input_ids: torch.Tensor, attention_mask: torch.Tensor) -> torch.Tensor:
        """Run the encoder and mean-pool its last hidden state over unmasked tokens.

        Returns:
            A ``[B, H]`` tensor, pooled from the encoder's ``[B, T, H]`` last hidden state.
        """
        out = self.encoder(input_ids=input_ids, attention_mask=attention_mask)
        return self._masked_mean(out.last_hidden_state, attention_mask)

    def _fuse(self, pooled: torch.Tensor, extra) -> torch.Tensor:
        """Append ``extra`` features to ``pooled`` when the head was built to expect them."""
        if not self.n_extra:
            # PIRD-lite: the head only sees the encoder embedding; ``extra`` is ignored.
            return pooled
        if extra is None:
            raise ValueError(
                f"PIRDModel was built with n_extra={self.n_extra}, "
                "but forward() received extra=None"
            )
        if extra.shape[-1] != self.n_extra:
            raise ValueError(
                f"expected extra with last dimension {self.n_extra}, "
                f"got shape {tuple(extra.shape)}"
            )
        # Align dtype/device so, e.g., float64 features built with NumPy don't
        # promote the head input (dtype mismatch) or sit on the wrong device.
        extra = extra.to(device=pooled.device, dtype=pooled.dtype)
        return torch.cat([pooled, extra], dim=-1)

    def forward(self, input_ids: torch.Tensor, attention_mask: torch.Tensor,
                extra=None) -> torch.Tensor:
        """Score a batch.

        Args:
            input_ids: Token ids forwarded to the encoder.
            attention_mask: Mask forwarded to the encoder and used for pooling.
            extra: Per-example features of width ``n_extra``. Required when
                ``n_extra > 0`` and ignored when ``n_extra == 0``.

        Returns:
            Raw logits with the trailing singleton dimension squeezed away
            (higher = more AI).

        Raises:
            ValueError: if ``n_extra > 0`` and ``extra`` is missing or has the
                wrong last dimension.
        """
        pooled = self.encode(input_ids, attention_mask)
        return self.head(self._fuse(pooled, extra)).squeeze(-1)