"""
Load ModernBERT-NLI from HuggingFace base model + task heads.

Usage:
    from load_model import load_modernbert_nli

    model, tokenizer = load_modernbert_nli("path/to/task_heads.pt")

    # NLI classification (inputs must live on the same device as the model)
    logits = model(**tokenizer(premise, hypothesis, return_tensors="pt"), mode="nli")

    # With abstention
    nli_logits, abstention_logits = model(**inputs, mode="abstention")
"""

import warnings
from typing import Optional

import torch
import torch.nn as nn
from transformers import AutoModel, AutoTokenizer

# Label names, index-aligned with the three logits produced by the NLI head.
NLI_LABELS = ("entailment", "neutral", "contradiction")

# Modes accepted by ModernBERTWithNLI.forward.
_VALID_MODES = ("embed", "late_interaction", "nli", "abstention")


class ModernBERTWithNLI(nn.Module):
    """ModernBERT encoder with an NLI head and an abstention head.

    Submodule names (``encoder``, ``nli_hidden``, ``nli_output``,
    ``abstention_head``) are the state-dict key prefixes, so they must stay
    stable for saved task-head checkpoints to keep loading.
    """

    def __init__(self, base_model_name: str = "answerdotai/ModernBERT-large"):
        """Build the encoder and the freshly initialised task heads.

        Args:
            base_model_name: HuggingFace model ID (or local path) of the base
                encoder passed to ``AutoModel.from_pretrained``.
        """
        super().__init__()

        # Load base encoder from HuggingFace
        self.encoder = AutoModel.from_pretrained(base_model_name)
        hidden_size = self.encoder.config.hidden_size  # 1024 for large

        # NLI head, split in two so the abstention head can read the
        # intermediate 512-d representation as well as the logits.
        self.nli_hidden = nn.Sequential(
            nn.Linear(hidden_size, 512),
            nn.LayerNorm(512),
            nn.GELU(),
            nn.Dropout(0.1),
        )
        self.nli_output = nn.Linear(512, 3)

        # Abstention head: takes [nli_hidden, nli_logits] concatenated.
        self.abstention_head = nn.Sequential(
            nn.Linear(512 + 3, 128),
            nn.LayerNorm(128),
            nn.GELU(),
            nn.Dropout(0.1),
            nn.Linear(128, 2),
        )

        # Freeze encoder by default
        for param in self.encoder.parameters():
            param.requires_grad = False

    def forward(
        self,
        input_ids: torch.Tensor,
        attention_mask: Optional[torch.Tensor] = None,
        mode: str = "nli",
    ):
        """Forward pass with multiple modes.

        Args:
            input_ids: Token IDs.
            attention_mask: Attention mask forwarded to the encoder.
            mode: One of "embed", "late_interaction", "nli", "abstention".

        Returns:
            Depends on mode:
            - "embed": (batch, hidden_size) first-token (CLS) embeddings
            - "late_interaction": (batch, seq_len, hidden_size) all token
              embeddings
            - "nli": (batch, 3) NLI logits
            - "abstention": tuple of (nli_logits, abstention_logits)

        Raises:
            ValueError: If ``mode`` is not one of the supported modes.
        """
        # Reject bad modes before paying for an encoder pass.
        if mode not in _VALID_MODES:
            raise ValueError(f"Unknown mode: {mode}")

        outputs = self.encoder(input_ids, attention_mask=attention_mask)
        hidden_states = outputs.last_hidden_state

        if mode == "late_interaction":
            return hidden_states  # All tokens

        cls_hidden = hidden_states[:, 0]  # CLS token
        if mode == "embed":
            return cls_hidden

        # Both remaining modes need the NLI head; compute it once.
        nli_hidden = self.nli_hidden(cls_hidden)
        nli_logits = self.nli_output(nli_hidden)
        if mode == "nli":
            return nli_logits

        # mode == "abstention": concat hidden and logits for the abstention head
        abstention_input = torch.cat([nli_hidden, nli_logits], dim=-1)
        abstention_logits = self.abstention_head(abstention_input)
        return nli_logits, abstention_logits


def load_modernbert_nli(
    task_heads_path: str,
    base_model: str = "answerdotai/ModernBERT-large",
    device: str = "cuda" if torch.cuda.is_available() else "cpu",
):
    """Load the ModernBERT-NLI model and its tokenizer.

    Args:
        task_heads_path: Path to the task_heads.pt state-dict file.
        base_model: HuggingFace model ID for the base encoder.
        device: Device to load the model on.

    Returns:
        (model, tokenizer) tuple; the model is on ``device`` and in eval mode.

    Warns:
        UserWarning: If the checkpoint lacks weights for any non-encoder
            parameter (that head keeps its random initialisation), or if it
            contains keys the model does not recognise.
    """
    # Create model (downloads base from HuggingFace if needed)
    model = ModernBERTWithNLI(base_model)

    # Load task heads. weights_only=True restricts unpickling to tensors and
    # plain containers, which is all a state dict needs, so a tampered file
    # cannot execute arbitrary code on load.
    task_heads = torch.load(task_heads_path, map_location=device, weights_only=True)

    # strict=False because the checkpoint is expected to omit encoder weights;
    # anything else that fails to line up is worth telling the caller about.
    load_result = model.load_state_dict(task_heads, strict=False)
    missing_head_keys = [
        key for key in load_result.missing_keys if not key.startswith("encoder.")
    ]
    if missing_head_keys:
        warnings.warn(
            f"{task_heads_path} has no weights for {missing_head_keys}; "
            "these parameters keep their random initialisation.",
            stacklevel=2,
        )
    if load_result.unexpected_keys:
        warnings.warn(
            f"{task_heads_path} contains keys that were ignored: "
            f"{list(load_result.unexpected_keys)}",
            stacklevel=2,
        )

    model = model.to(device)
    model.eval()

    # Load tokenizer
    tokenizer = AutoTokenizer.from_pretrained(base_model)

    return model, tokenizer


def _resolve_device(model, device):
    """Return ``device`` as a torch.device, defaulting to the model's own device."""
    if device is not None:
        return torch.device(device)
    try:
        return next(model.parameters()).device
    except (StopIteration, AttributeError):
        # Parameter-free or duck-typed model: nothing to infer from.
        return torch.device("cpu")


def _encode_pair(tokenizer, premise: str, hypothesis: str, device):
    """Tokenize a premise/hypothesis pair and move every tensor to ``device``."""
    inputs = tokenizer(
        premise,
        hypothesis,
        return_tensors="pt",
        truncation=True,
        max_length=512,
    )
    return {name: tensor.to(device) for name, tensor in inputs.items()}


# Convenience functions
def predict_nli(
    model,
    tokenizer,
    premise: str,
    hypothesis: str,
    device: Optional[str] = None,
):
    """Predict the NLI label for a premise-hypothesis pair.

    Args:
        model: Model callable as ``model(**inputs, mode="nli")``.
        tokenizer: Tokenizer matching the model's base encoder.
        premise: Premise text.
        hypothesis: Hypothesis text.
        device: Device for the input tensors. ``None`` (default) uses the
            device the model's parameters are on, so the call works on
            CPU-only hosts without extra arguments.

    Returns:
        Dict with "label" (predicted class name), "confidence" (its softmax
        probability) and "probs" (probability per class name).
    """
    inputs = _encode_pair(tokenizer, premise, hypothesis, _resolve_device(model, device))

    with torch.no_grad():
        logits = model(**inputs, mode="nli")

    probs = torch.softmax(logits, dim=-1)[0]
    pred = probs.argmax().item()

    return {
        "label": NLI_LABELS[pred],
        "confidence": probs[pred].item(),
        "probs": {label: prob.item() for label, prob in zip(NLI_LABELS, probs)},
    }


def predict_with_abstention(
    model,
    tokenizer,
    premise: str,
    hypothesis: str,
    device: Optional[str] = None,
    threshold: float = 0.5,
):
    """Predict the NLI label together with an abstention flag.

    Args:
        model: Model callable as ``model(**inputs, mode="abstention")``.
        tokenizer: Tokenizer matching the model's base encoder.
        premise: Premise text.
        hypothesis: Hypothesis text.
        device: Device for the input tensors. ``None`` (default) uses the
            device the model's parameters are on.
        threshold: Abstain when the softmax probability of abstention class
            index 1 is strictly greater than this value.

    Returns:
        Dict with "label", "confidence", "abstain" (bool), "uncertainty"
        (probability of abstention class 1) and "probs" (per-class NLI
        probabilities).
    """
    inputs = _encode_pair(tokenizer, premise, hypothesis, _resolve_device(model, device))

    with torch.no_grad():
        nli_logits, abstention_logits = model(**inputs, mode="abstention")

    nli_probs = torch.softmax(nli_logits, dim=-1)[0]
    abstention_probs = torch.softmax(abstention_logits, dim=-1)[0]

    pred = nli_probs.argmax().item()
    uncertainty = abstention_probs[1].item()

    return {
        "label": NLI_LABELS[pred],
        "confidence": nli_probs[pred].item(),
        "abstain": uncertainty > threshold,
        "uncertainty": uncertainty,
        "probs": {label: prob.item() for label, prob in zip(NLI_LABELS, nli_probs)},
    }


def main() -> None:
    """Example usage: run a few premise/hypothesis pairs and print the results."""
    model, tokenizer = load_modernbert_nli("task_heads.pt")

    examples = [
        ("A man is playing guitar.", "A person is making music."),
        ("The cat is sleeping.", "The cat is running outside."),
        ("A woman walks down the street.", "She is going to work."),
    ]

    print("NLI Predictions with Abstention:\n")
    for premise, hypothesis in examples:
        result = predict_with_abstention(model, tokenizer, premise, hypothesis)
        status = "ABSTAIN" if result["abstain"] else "CONFIDENT"
        print(f"P: {premise}")
        print(f"H: {hypothesis}")
        print(f"-> {result['label']} ({result['confidence']:.1%}) [{status}]\n")


if __name__ == "__main__":
    main()