import logging
import re

import torch

logger = logging.getLogger(__name__)

# Reasoning-style models may wrap chain-of-thought in <think>...</think>. An unterminated
# block (generation cut off by max_new_tokens) is removed up to the end of the text.
# NOTE(review): the original pattern here was a bare r".*?" (a no-op — the tag delimiters
# appear to have been lost); confirm <think> is the tag the deployed model actually emits.
_THINK_BLOCK_RE = re.compile(r"<think>.*?(?:</think>|\Z)", flags=re.DOTALL)
# A line that holds only an "assistant"/"user" role marker (chat-template residue).
_ROLE_HEADER_RE = re.compile(r"^(assistant|user)\s*\n", flags=re.MULTILINE)
_MULTI_NEWLINE_RE = re.compile(r"\n{2,}")
# Whole-word match so e.g. "eyes" or "yesterday" are not read as an affirmative verdict.
_YES_WORD_RE = re.compile(r"\byes\b", flags=re.IGNORECASE)


def clean_model_output(text):
    """Strip reasoning blocks and chat-template residue from decoded model text.

    Steps, in order:
      1. Remove ``<think>...</think>`` blocks, including an unterminated trailing
         ``<think>`` block.
      2. Remove ``assistant`` / ``user`` role-marker lines.
      3. Collapse runs of two or more newlines into a single newline.
      4. Strip leading/trailing whitespace.

    Args:
        text: Decoded model output.

    Returns:
        The cleaned string.
    """
    text = _THINK_BLOCK_RE.sub("", text)
    text = _ROLE_HEADER_RE.sub("", text)
    text = _MULTI_NEWLINE_RE.sub("\n", text)
    return text.strip()


def is_unsafe_prompt(model, tokenizer, system_prompt=None, user_prompt=None, max_new_token=10):
    """Ask a chat model whether ``user_prompt`` is unsafe, instructed by ``system_prompt``.

    The two prompts are rendered with the tokenizer's chat template, the model generates
    up to ``max_new_token`` tokens greedily, and the decoded continuation is cleaned with
    :func:`clean_model_output`. The verdict is positive if the cleaned text contains the
    standalone word "yes" (case-insensitive).

    Args:
        model: A generation-capable model exposing ``.device`` and ``.generate(...)``.
        tokenizer: A tokenizer exposing ``apply_chat_template``, ``decode`` and
            ``eos_token_id``.
        system_prompt: Classifier instruction. Required; if empty/None returns False.
        user_prompt: The prompt to classify. Required; if empty/None returns False.
        max_new_token: Maximum number of tokens to generate for the verdict.

    Returns:
        True if the model answered "yes"; False otherwise. Also False when either prompt
        is missing or when tokenization/generation raises. The error is logged, but note
        this means the check fails *open*.
    """
    if not system_prompt or not user_prompt:
        return False

    messages = [
        {"role": "system", "content": system_prompt},
        {"role": "user", "content": user_prompt},
    ]

    try:
        input_ids = tokenizer.apply_chat_template(
            messages, add_generation_prompt=True, return_tensors="pt"
        ).to(model.device)
        with torch.no_grad():
            output_ids = model.generate(
                input_ids,
                # Single unpadded sequence: every position is a real token. Passing the
                # mask explicitly avoids ambiguity because pad_token_id == eos_token_id.
                attention_mask=torch.ones_like(input_ids),
                max_new_tokens=max_new_token,
                do_sample=False,
                pad_token_id=tokenizer.eos_token_id,
            )
        # Keep only the newly generated tokens, dropping the echoed prompt.
        generated = output_ids[0][input_ids.shape[-1]:]
        response = tokenizer.decode(generated, skip_special_tokens=True)
    except Exception:
        # Callers rely on a bool, so keep returning False, but never fail silently.
        logger.exception("is_unsafe_prompt: classification failed; treating prompt as safe")
        return False

    return _YES_WORD_RE.search(clean_model_output(response)) is not None