Text Classification
Transformers
Safetensors
modernbert
Generated from Trainer
text-embeddings-inference
Instructions to use param-bharat/ModernBERT-large-nli-scorer with libraries, inference providers, notebooks, and local apps. Follow these links to get started.
- Libraries
- Transformers
How to use param-bharat/ModernBERT-large-nli-scorer with Transformers:
```python
# Use a pipeline as a high-level helper
from transformers import pipeline

pipe = pipeline("text-classification", model="param-bharat/ModernBERT-large-nli-scorer")

# Load model directly
from transformers import AutoTokenizer, AutoModelForSequenceClassification

tokenizer = AutoTokenizer.from_pretrained("param-bharat/ModernBERT-large-nli-scorer")
model = AutoModelForSequenceClassification.from_pretrained("param-bharat/ModernBERT-large-nli-scorer")
```
- Notebooks
- Google Colab
- Kaggle
| from pydantic import BaseModel, ConfigDict | |
| from transformers import ( | |
| AutoTokenizer, | |
| PreTrainedTokenizerFast, | |
| PreTrainedTokenizer, | |
| BatchEncoding, | |
| ) | |
| from transformers import Pipeline | |
class NLIInstruction(BaseModel):
    """Prompt template for one NLI-style scoring task.

    An instance renders three text segments -- ``instruction``, ``premise``
    and ``hypothesis`` -- and ``as_model_inputs`` packs them into a single
    ``[CLS] instruction [SEP] premise [SEP] hypothesis [SEP]`` sequence.
    Subclasses provide ``instruction`` and ``hypothesis`` as field defaults;
    the premise is assembled from whichever of the optional fields
    (``Context``, ``ChatHistory``, ``Prompt``, ``Completion``) are set.
    """

    # Tokenizer used to encode each segment. Tokenizers are not pydantic
    # types, hence ``arbitrary_types_allowed`` in ``model_config``.
    tokenizer: AutoTokenizer | PreTrainedTokenizerFast | PreTrainedTokenizer
    instruction: str
    hypothesis: str
    Prompt: str | None = None
    Completion: str | None = None
    Context: str | None = None
    # Each message must provide "role" and "content" keys (read by
    # format_chat_history).
    ChatHistory: list[dict[str, str]] | None = None

    model_config = ConfigDict(arbitrary_types_allowed=True)

    def format_chat_history(self, chat_history: list[dict[str, str]]) -> str:
        """Render every message under its own ``### Background`` header.

        Messages are joined with single newlines; the result carries no
        trailing newline.
        """
        return "\n".join(
            [
                f"### Background\n{message['role']}: {message['content']}"
                for message in chat_history
            ]
        )

    @property
    def premise(self) -> str:
        """Return the ``## Premise`` section built from the populated fields.

        Sections appear in the order Context, ChatHistory, Prompt,
        Completion; empty or ``None`` fields are skipped.
        """
        base_template = "## Premise\n"
        if self.Context:
            base_template += f"### Context\n{self.Context}\n"
        if self.ChatHistory:
            # NOTE(review): the chat history has no trailing newline, so a
            # following "### Prompt"/"### Completion" header is glued to the
            # last message. Left unchanged because this text is model input
            # and presumably matches the training format -- confirm first.
            base_template += self.format_chat_history(self.ChatHistory)
        if self.Prompt:
            base_template += f"### Prompt\n{self.Prompt}\n"
        if self.Completion:
            base_template += f"### Completion\n{self.Completion}\n"
        return base_template

    def as_str(self):
        """Return instruction, premise and hypothesis joined by newlines."""
        return f"{self.instruction}\n{self.premise}\n{self.hypothesis}"

    @property
    def as_model_inputs(self) -> BatchEncoding:
        """Tokenize the task into a single-row PyTorch ``BatchEncoding``.

        Only the premise is truncated: it is cut so that instruction,
        premise, hypothesis and the four special tokens together fit within
        ``tokenizer.model_max_length`` (as long as instruction and hypothesis
        alone leave room).
        """
        instruction_ids = self.tokenizer(
            self.instruction, add_special_tokens=False
        ).input_ids
        premise_ids = self.tokenizer(self.premise, add_special_tokens=False).input_ids
        hypothesis_ids = self.tokenizer(
            self.hypothesis, add_special_tokens=False
        ).input_ids

        # One [CLS] and three [SEP] tokens are added around the segments
        # below; they count against the length limit too.
        num_special_tokens = 4
        premise_budget = (
            self.tokenizer.model_max_length
            - len(instruction_ids)
            - len(hypothesis_ids)
            - num_special_tokens
        )
        # Clamp at zero: a negative slice bound would drop only the last few
        # premise tokens instead of truncating the premise away.
        premise_ids = premise_ids[: max(premise_budget, 0)]

        input_ids = (
            [self.tokenizer.cls_token_id]
            + instruction_ids
            + [self.tokenizer.sep_token_id]
            + premise_ids
            + [self.tokenizer.sep_token_id]
            + hypothesis_ids
            + [self.tokenizer.sep_token_id]
        )
        attention_mask = [1] * len(input_ids)
        return BatchEncoding(
            data={"input_ids": input_ids, "attention_mask": attention_mask},
            tensor_type="pt",
            prepend_batch_axis=True,
        )
class DetectionResponseFactualConsistency(NLIInstruction):
    """NLI task template: is the completion factually consistent with the context?"""

    instruction: str = (
        "## Task\n"
        "Determine if the completion is factually consistent with the context."
    )
    hypothesis: str = (
        "## Hypothesis\n"
        "The completion is factually consistent with the context."
    )
class DetectionContextPromptInjection(NLIInstruction):
    """NLI task template: does the prompt contain a prompt injection attack?"""

    instruction: str = (
        "## Task\n"
        "Determine if the prompt contains a prompt injection attack."
    )
    hypothesis: str = (
        "## Hypothesis\n"
        "The prompt contains a prompt injection attack."
    )
class DetectionContextSourceCode(NLIInstruction):
    """NLI task template: does the context contain source code?"""

    instruction: str = (
        "## Task\n"
        "Determine if the given context contains any source code or code snippets."
    )
    hypothesis: str = (
        "## Hypothesis\n"
        "The given context contains source code or code blocks."
    )
class DetectionContextToxicity(NLIInstruction):
    """NLI task template: does the context contain toxic content of any listed kind?"""

    instruction: str = (
        "## Task\n"
        "Determine if the context contains any toxic content related to one of "
        "the following: disability, gender, identity hate, insult, obscene, "
        "race, religion, toxicity, violence."
    )
    hypothesis: str = (
        "## Hypothesis\n"
        "The context contains some toxic content related to one of the "
        "following: disability, gender, identity hate, insult, obscene, race, "
        "religion, toxicity, violence."
    )
class DetectionContextToxicityDisability(NLIInstruction):
    """NLI task template: toxic content related to disability."""

    instruction: str = (
        "## Task\n"
        "Determine if the context contains any toxic content related to disability."
    )
    hypothesis: str = (
        "## Hypothesis\n"
        "The context contains some toxic content related to disability."
    )
class DetectionContextToxicityGender(NLIInstruction):
    """NLI task template: toxic content related to gender."""

    instruction: str = (
        "## Task\n"
        "Determine if the context contains any toxic content related to gender."
    )
    hypothesis: str = (
        "## Hypothesis\n"
        "The context contains some toxic content related to gender."
    )
class DetectionContextToxicityIdentityHate(NLIInstruction):
    """NLI task template: toxic content related to identity hate."""

    instruction: str = (
        "## Task\n"
        "Determine if the context contains any toxic content related to identity hate."
    )
    hypothesis: str = (
        "## Hypothesis\n"
        "The context contains some toxic content related to identity hate."
    )
class DetectionContextToxicityInsult(NLIInstruction):
    """NLI task template: insulting content in the context."""

    instruction: str = (
        "## Task\n"
        "Determine if the context contains any insulting content."
    )
    hypothesis: str = (
        "## Hypothesis\n"
        "The context contains some insulting content."
    )
class DetectionContextToxicityObscene(NLIInstruction):
    """NLI task template: obscene content in the context."""

    instruction: str = (
        "## Task\n"
        "Determine if the context contains any obscene content."
    )
    hypothesis: str = (
        "## Hypothesis\n"
        "The context contains some obscene content."
    )
class DetectionContextToxicityRace(NLIInstruction):
    """NLI task template: racist content in the context."""

    instruction: str = (
        "## Task\n"
        "Determine if the context contains any racist content."
    )
    hypothesis: str = (
        "## Hypothesis\n"
        "The context contains some racist content."
    )
class DetectionContextToxicityReligion(NLIInstruction):
    """NLI task template: toxic content related to religion."""

    instruction: str = (
        "## Task\n"
        "Determine if the context contains any toxic content related to religion."
    )
    hypothesis: str = (
        "## Hypothesis\n"
        "The context contains some toxic content related to religion."
    )
class DetectionContextToxicityViolence(NLIInstruction):
    """NLI task template: violent content in the context."""

    instruction: str = (
        "## Task\n"
        "Determine if the context contains any violent content."
    )
    hypothesis: str = (
        "## Hypothesis\n"
        "The context contains some violent content."
    )
class QualityContextDocumentRelevance(NLIInstruction):
    """NLI task template: does the context hold information the completion used correctly?"""

    instruction: str = (
        "## Task\n"
        "Determine if the context contains relevant information used by the "
        "completion to answer the question in the given prompt correctly."
    )
    hypothesis: str = (
        "## Hypothesis\n"
        "The context contains relevant information used by the completion to "
        "answer the question in the given prompt correctly."
    )
class QualityContextDocumentUtilization(NLIInstruction):
    """NLI task template: was the context utilized by the completion?"""

    instruction: str = (
        "## Task\n"
        "Determine if the context was utilized in the completion to answer the "
        "question in the given prompt correctly."
    )
    hypothesis: str = (
        "## Hypothesis\n"
        "The context was utilized in the completion to answer the question in "
        "the given prompt correctly."
    )
class QualityContextSentenceRelevance(NLIInstruction):
    """NLI task template for sentence-level context relevance.

    Uses the same instruction and hypothesis text as
    ``QualityContextDocumentRelevance``; the premise additionally ends with a
    ``### Sentence`` section holding ``Sentence``.
    """

    # NOTE(review): unlike QualityContextSentenceUtilization, this text never
    # mentions the sentence. Kept verbatim because it is model input and
    # presumably matches the training prompts -- confirm before editing.
    instruction: str = (
        "## Task\n"
        "Determine if the context contains relevant information used by the "
        "completion to answer the question in the given prompt correctly."
    )
    hypothesis: str = (
        "## Hypothesis\n"
        "The context contains relevant information used by the completion to "
        "answer the question in the given prompt correctly."
    )
    # Sentence text appended to the premise.
    Sentence: str

    @property
    def premise(self) -> str:
        """Return the base premise followed by a ``### Sentence`` section."""
        return super().premise + f"\n### Sentence\n{self.Sentence}\n"
class QualityContextSentenceUtilization(NLIInstruction):
    """NLI task template: was the selected sentence utilized by the completion?

    The premise is the base premise plus a trailing ``### Sentence`` section
    holding ``Sentence``.
    """

    instruction: str = (
        "## Task\n"
        "Determine if the selected sentence was utilized in the completion to "
        "answer the question in the given prompt correctly."
    )
    hypothesis: str = (
        "## Hypothesis\n"
        "The selected sentence was utilized in the completion to answer the "
        "question in the given prompt correctly."
    )
    # Sentence text appended to the premise.
    Sentence: str

    @property
    def premise(self) -> str:
        """Return the base premise followed by a ``### Sentence`` section."""
        return super().premise + f"\n### Sentence\n{self.Sentence}\n"
class QualityResponseAdherence(NLIInstruction):
    """NLI task template: does the completion adhere to the context?"""

    instruction: str = (
        "## Task\n"
        "Determine if the completion adheres to the context when answering the "
        "question in the given prompt."
    )
    hypothesis: str = (
        "## Hypothesis\n"
        "The completion adheres to the context when answering the question in "
        "the given prompt."
    )
class QualityResponseAttribution(NLIInstruction):
    """NLI task template: does the completion attribute the context?"""

    instruction: str = (
        "## Task\n"
        "Determine if the completion attributes the context when answering the "
        "question in the given prompt."
    )
    hypothesis: str = (
        "## Hypothesis\n"
        "The completion attributes the context when answering the question in "
        "the given prompt."
    )
class QualityResponseCoherence(NLIInstruction):
    """NLI task template: is the completion coherent for the given context?"""

    # NOTE(review): "coherent and for the given context" reads as if a word is
    # missing; kept verbatim because the text is model input and presumably
    # matches the training prompts.
    instruction: str = (
        "## Task\n"
        "Determine if the completion is coherent and for the given context."
    )
    hypothesis: str = (
        "## Hypothesis\n"
        "The completion is coherent and for the given context."
    )
class QualityResponseComplexity(NLIInstruction):
    """NLI task template: is the completion complex / multi-step?"""

    instruction: str = (
        "## Task\n"
        "Determine if the completion is complex and contains multiple steps to "
        "answer the question."
    )
    hypothesis: str = (
        "## Hypothesis\n"
        "The completion is complex and contains multiple steps to answer the "
        "question."
    )
class QualityResponseCorrectness(NLIInstruction):
    """NLI task template: is the completion correct for the prompt and context?"""

    instruction: str = (
        "## Task\n"
        "Determine if the completion is correct with respect to the given "
        "prompt and context."
    )
    hypothesis: str = (
        "## Hypothesis\n"
        "The completion is correct with respect to the given prompt and context."
    )
class QualityResponseHelpfulness(NLIInstruction):
    """NLI task template: is the completion helpful for the prompt and context?"""

    instruction: str = (
        "## Task\n"
        "Determine if the completion is helpful with respect to the given "
        "prompt and context."
    )
    hypothesis: str = (
        "## Hypothesis\n"
        "The completion is helpful with respect to the given prompt and context."
    )
class QualityResponseInstructionFollowing(NLIInstruction):
    """NLI task template: does the completion follow the prompt's instructions?"""

    instruction: str = (
        "## Task\n"
        "Determine if the completion follows the instructions provided in the "
        "given prompt."
    )
    hypothesis: str = (
        "## Hypothesis\n"
        "The completion follows the instructions provided in the given prompt."
    )
class QualityResponseRelevance(NLIInstruction):
    """NLI task template: is the completion relevant to the prompt and context?"""

    instruction: str = (
        "## Task\n"
        "Determine if the completion is relevant to the given prompt and "
        "context."
    )
    hypothesis: str = (
        "## Hypothesis\n"
        "The completion is relevant to the given prompt and context."
    )
class QualityResponseVerbosity(NLIInstruction):
    """NLI task template: is the completion too verbose?"""

    instruction: str = (
        "## Task\n"
        "Determine if the completion is too verbose with respect to the given "
        "prompt and context."
    )
    hypothesis: str = (
        "## Hypothesis\n"
        "The completion is too verbose with respect to the given prompt and "
        "context."
    )
# Registry of supported task names -> NLIInstruction subclass that renders the
# model input for that task; NLIScorer.preprocess looks ``task_type`` up here.
# "Detection/Toxicity/Toxicity" and "Detection/Toxicity/Toxic" are two names
# for the same aggregate toxicity task.
TASK_CLASSES: dict[str, type[NLIInstruction]] = {
    # -- Detection tasks --
    "Detection/Hallucination/Factual Consistency": DetectionResponseFactualConsistency,
    "Detection/Prompt Injection": DetectionContextPromptInjection,
    "Detection/Source Code": DetectionContextSourceCode,
    # Toxicity: per-category tasks plus the aggregate task (two alias keys).
    "Detection/Toxicity/Disability": DetectionContextToxicityDisability,
    "Detection/Toxicity/Gender": DetectionContextToxicityGender,
    "Detection/Toxicity/Identity Hate": DetectionContextToxicityIdentityHate,
    "Detection/Toxicity/Insult": DetectionContextToxicityInsult,
    "Detection/Toxicity/Obscene": DetectionContextToxicityObscene,
    "Detection/Toxicity/Race": DetectionContextToxicityRace,
    "Detection/Toxicity/Religion": DetectionContextToxicityReligion,
    "Detection/Toxicity/Toxicity": DetectionContextToxicity,
    "Detection/Toxicity/Toxic": DetectionContextToxicity,
    "Detection/Toxicity/Violence": DetectionContextToxicityViolence,
    # -- Quality: use of the context --
    "Quality/Context/Document Relevance": QualityContextDocumentRelevance,
    "Quality/Context/Document Utilization": QualityContextDocumentUtilization,
    "Quality/Context/Sentence Relevance": QualityContextSentenceRelevance,
    "Quality/Context/Sentence Utilization": QualityContextSentenceUtilization,
    # -- Quality: properties of the completion --
    "Quality/Response/Adherence": QualityResponseAdherence,
    "Quality/Response/Attribution": QualityResponseAttribution,
    "Quality/Response/Coherence": QualityResponseCoherence,
    "Quality/Response/Complexity": QualityResponseComplexity,
    "Quality/Response/Correctness": QualityResponseCorrectness,
    "Quality/Response/Helpfulness": QualityResponseHelpfulness,
    "Quality/Response/Instruction Following": QualityResponseInstructionFollowing,
    "Quality/Response/Relevance": QualityResponseRelevance,
    "Quality/Response/Verbosity": QualityResponseVerbosity,
}
# Decision threshold per task name: NLIScorer.postprocess labels an input as
# positive when the positive-class probability exceeds this value.
# Keys mirror TASK_CLASSES. "Detection/Toxicity/Toxicity" and
# "Detection/Toxicity/Toxic" both map to DetectionContextToxicity there, so
# they share one threshold; without the alias, postprocess raised KeyError
# for "Detection/Toxicity/Toxicity" after the model had already run.
TASK_THRESHOLDS: dict[str, float] = {
    "Detection/Hallucination/Factual Consistency": 0.5895,
    "Detection/Prompt Injection": 0.4147,
    "Detection/Source Code": 0.4001,
    "Detection/Toxicity/Disability": 0.5547,
    "Detection/Toxicity/Gender": 0.4007,
    "Detection/Toxicity/Identity Hate": 0.5502,
    "Detection/Toxicity/Insult": 0.4913,
    "Detection/Toxicity/Obscene": 0.448,
    "Detection/Toxicity/Race": 0.5983,
    "Detection/Toxicity/Religion": 0.4594,
    "Detection/Toxicity/Toxicity": 0.5034,
    "Detection/Toxicity/Toxic": 0.5034,
    "Detection/Toxicity/Violence": 0.4031,
    "Quality/Context/Document Relevance": 0.5809,
    "Quality/Context/Document Utilization": 0.4005,
    "Quality/Context/Sentence Relevance": 0.6003,
    "Quality/Context/Sentence Utilization": 0.5417,
    "Quality/Response/Adherence": 0.59,
    "Quality/Response/Attribution": 0.5304,
    "Quality/Response/Coherence": 0.6891,
    "Quality/Response/Complexity": 0.7235,
    "Quality/Response/Correctness": 0.6535,
    "Quality/Response/Helpfulness": 0.4445,
    "Quality/Response/Instruction Following": 0.5323,
    "Quality/Response/Relevance": 0.4011,
    "Quality/Response/Verbosity": 0.4243,
}
class NLIScorer(Pipeline):
    """Transformers ``Pipeline`` that scores one NLI task instance.

    Call it with a mapping of task fields (e.g. ``Context``, ``Prompt``,
    ``Completion``) and a ``task_type`` keyword that must be a key of
    ``TASK_CLASSES``. The result is ``{"score": float, "label": int}``.
    """

    def _sanitize_parameters(self, **kwargs):
        """Route the ``task_type`` keyword to both preprocess and postprocess."""
        preprocess_kwargs = {}
        postprocess_kwargs = {}
        if "task_type" in kwargs:
            preprocess_kwargs["task_type"] = kwargs["task_type"]
            postprocess_kwargs["task_type"] = kwargs["task_type"]
        # _forward takes no extra parameters.
        return preprocess_kwargs, {}, postprocess_kwargs

    def preprocess(self, inputs, task_type=None):
        """Render ``inputs`` with the task's template and tokenize it.

        Raises:
            TypeError: if no ``task_type`` was supplied to the pipeline call.
            KeyError: if ``task_type`` is not a key of ``TASK_CLASSES``.
        """
        if task_type is None:
            # Previously surfaced as an opaque "missing positional argument"
            # error; name the keyword and the accepted values instead.
            raise TypeError(
                "NLIScorer requires a `task_type` keyword argument; expected one of: "
                + ", ".join(TASK_CLASSES)
            )
        TaskClass = TASK_CLASSES[task_type]
        task_class = TaskClass(tokenizer=self.tokenizer, **inputs)
        # Read as an attribute (not called), so the task class must expose
        # ``as_model_inputs`` as a property.
        return task_class.as_model_inputs

    def _forward(self, model_inputs):
        """Run the classification model on the tokenized inputs."""
        outputs = self.model(**model_inputs)
        return outputs

    def postprocess(self, model_outputs, task_type):
        """Threshold the positive-class probability into a binary label.

        ``label`` is 1 when the positive-class probability exceeds the task's
        entry in ``TASK_THRESHOLDS``, else 0. ``score`` is that probability
        when ``label`` is 1 and one minus it otherwise.
        """
        threshold = TASK_THRESHOLDS[task_type]
        # [0] selects the first (single) batch row, [1] the positive class.
        # NOTE(review): assumes a binary classification head -- confirm
        # num_labels == 2.
        pos_scores = model_outputs["logits"].softmax(-1)[0][1]
        best_class = int(pos_scores > threshold)
        score = pos_scores if best_class == 1 else 1 - pos_scores
        return {"score": score.item(), "label": best_class}