| from smolagents import Tool |
| from typing import Any, Optional |
|
|
class SimpleTool(Tool):
    """Tool that zero-shot classifies whether Vietnamese text relates to a topic."""

    name = "classify_topic"
    description = "This tool classifies whether the given Vietnamese text is related to the specified topic."
    inputs = {
        "text": {"type": "string", "description": "The Vietnamese text to be classified."},
        "topic": {"type": "string", "description": "The string representing the topic to be checked."},
    }
    output_type = "boolean"

    def forward(self, text: str, topic: str) -> bool:
        """
        Classify whether the given Vietnamese text is related to the specified topic.

        Runs a zero-shot classification pipeline over two candidate labels: the
        topic itself and "không liên quan <topic>" ("not related <topic>"). The
        text counts as related when the topic label scores highest.

        Args:
            text: The Vietnamese text to be classified.
            topic: The string representing the topic to be checked.

        Returns:
            bool: True if the text is related to the topic; False otherwise.
            Empty or whitespace-only text returns False without loading the model.
        """
        # Nothing to classify: blank text cannot be about any topic, and this
        # avoids loading a large model just to score an empty string.
        if not text or not text.strip():
            return False

        # Loading the model is expensive, so build the pipeline once per tool
        # instance and reuse it on later calls instead of reloading every time.
        classifier = getattr(self, "_classifier", None)
        if classifier is None:
            # Imports stay local to the method, as in the original tool code.
            from transformers import pipeline
            import torch

            device = "cuda" if torch.cuda.is_available() else "cpu"
            # trust_remote_code is deliberately not passed: XLM-RoBERTa is a
            # built-in transformers architecture, so no Hub code needs to run.
            classifier = pipeline(
                "zero-shot-classification",
                model="vicgalle/xlm-roberta-large-xnli-anli",
                device=device,
            )
            self._classifier = classifier

        candidate_labels = [topic, f"không liên quan {topic}"]
        result = classifier(text, candidate_labels)
        # The pipeline returns labels ordered by descending score, so the first
        # entry is the winning label.
        predicted_label = result["labels"][0]
        return predicted_label == topic