import logging
import traceback
from typing import Any, Dict, List

import torch
import torch.nn.functional as F
from transformers import AutoModelForCausalLM, AutoTokenizer

logging.basicConfig(level=logging.DEBUG)

logger = logging.getLogger(__name__)


class EndpointHandler:
    """Callable handler that asks a causal LM to correct Malay Rumi spelling.

    The tokenizer and model are loaded once from ``path`` at construction.
    Each call wraps the request text in a fixed instruction, generates a
    completion, and returns every generated token together with the
    probability the model assigned to that token.
    """

    # Fixed instruction sent as the user turn; the request text is appended
    # on a new line.
    _INSTRUCTION = (
        "You are a Malay language spelling corrector. I will give you some "
        "text written in messy Rumi (shortened or mistyped). Rewrite it in "
        "correct Malay Rumi spelling."
    )

    # Upper bound on the number of tokens generated per request.
    _MAX_NEW_TOKENS = 128

    def __init__(self, path: str = ""):
        """Load the tokenizer and the model from ``path``.

        Args:
            path: Location handed to ``from_pretrained`` for both the
                tokenizer and the model.
        """
        logger.info("CUDA STATUS: %s", torch.cuda.is_available())
        self.device = torch.device(
            "cuda" if torch.cuda.is_available() else "cpu"
        )

        logger.info("Starting tokenizer initialization from %s", path)
        self.tokenizer = AutoTokenizer.from_pretrained(path)

        logger.info("Starting model initialization from %s", path)
        self.model = AutoModelForCausalLM.from_pretrained(
            path,
            device_map="auto",
            torch_dtype=torch.bfloat16,
        )
        logger.info("Model loaded from %s", path)

    def __call__(self, data: Dict[str, Any]) -> List[Dict[str, Any]]:
        """Run spelling correction on ``data["inputs"]``.

        Args:
            data: Request payload. The text to correct is read from the
                ``"inputs"`` key; a missing key is treated as ``""``.

        Returns:
            One ``{"token": str, "confidence": float}`` dict per generated
            token, in generation order. If anything fails, a single-element
            list ``[{"error": "Error during inference: ..."}]`` is returned
            instead of raising.
        """
        # Top-level request boundary: every failure is logged and converted
        # into an error payload so the caller always receives a list.
        try:
            inputs = data.get("inputs", "")
            return self._generate_with_confidence(inputs)
        except Exception as e:
            logger.error("Error during inference: %s", e)
            logger.error("traceback = %s", traceback.format_exc())
            return [{"error": f"Error during inference: {str(e)}"}]

    def _generate_with_confidence(self, inputs: Any) -> List[Dict[str, Any]]:
        """Generate a correction and pair each new token with its probability."""
        messages = [
            {
                "role": "user",
                "content": f"{self._INSTRUCTION}\n{inputs}",
            },
        ]

        # return_dict=True yields a mapping that also carries the attention
        # mask, so generate() does not have to guess it from the input ids.
        encoded = self.tokenizer.apply_chat_template(
            messages,
            tokenize=True,
            add_generation_prompt=True,
            return_tensors="pt",
            return_dict=True,
        ).to(self.device)
        input_ids = encoded["input_ids"]

        outputs = self.model.generate(
            input_ids=input_ids,
            attention_mask=encoded.get("attention_mask"),
            max_new_tokens=self._MAX_NEW_TOKENS,
            use_cache=True,
            output_scores=True,
            return_dict_in_generate=True,
        )

        # Explicit check instead of `assert`, which is stripped under
        # `python -O` and produces an empty error message when it fires.
        step_scores = getattr(outputs, "scores", None)
        if not step_scores:
            raise RuntimeError(
                "model.generate() did not return per-step scores"
            )

        # The returned sequence starts with the prompt; keep only new tokens.
        prompt_length = input_ids.shape[-1]
        token_ids = outputs.sequences[0][prompt_length:].tolist()

        # Probability of the token that was actually emitted at each step.
        # Under greedy decoding this equals the maximum probability; indexing
        # by the emitted id keeps the value paired with the right token when
        # the generation config samples. Softmax runs in float32 so the
        # reported value is not rounded to bfloat16 precision.
        confidences = [
            F.softmax(scores[0].float(), dim=-1)[token_id].item()
            for scores, token_id in zip(step_scores, token_ids)
        ]

        # Decode every id as its own one-token sequence so the result has
        # one string per generated token (special tokens decode to "").
        tokens = self.tokenizer.batch_decode(
            [[token_id] for token_id in token_ids],
            skip_special_tokens=True,
        )

        return [
            {"token": token, "confidence": confidence}
            for token, confidence in zip(tokens, confidences)
        ]