# HNTAI / services / ai-service / test_token_limits.py
# sachinchandrankallar's picture
# Revert "refactor(ai-service): optimize prompts, token counting, and benchmarking
#   - Externalize system prompts to text files for better maintainability.
#   - Integrate tiktoken for faster and more accurate token counting.
#   - Refactor BenchmarkLogger to use asynchronous logging for zero latency impact.
#   - Improve prompt echo removal logic with more robust markers.
#   - Add specialized system instruction for medical document processing.
#   - Update test expectations to reflect precise token counting."
# 8c76d6f
# Raw
# History Blame
# 5.07 kB
"""
Simple test to verify token limit detection works correctly.
"""
import sys
import os
# Set UTF-8 encoding for the Windows console so output renders correctly
# when this script is run directly.
if sys.platform == 'win32':
    os.system('chcp 65001 > nul')

# Put the service's ``src`` directory first on the import path so the
# ``ai_med_extract`` imports below resolve. The directory is resolved relative
# to this file rather than the current working directory, so the script (and
# pytest collection) works no matter where it is launched from.
SRC_DIR = os.path.join(os.path.dirname(os.path.abspath(__file__)), 'src')
sys.path.insert(0, SRC_DIR)
from ai_med_extract.utils.model_config import get_model_token_limit
from ai_med_extract.utils.unified_model_manager import count_tokens, check_token_limits, is_token_limit_error
def test_model_token_limits():
    """Check that ``get_model_token_limit`` maps model names to token limits.

    Covers exact Phi-3 model IDs, a GGUF file path under a model ID, a generic
    name containing ``128k``, and an unknown name (expected to yield 4096).
    Expectations reflect the 8192-token limit currently configured for the
    Phi-3 mini 4k / small 8k variants.

    Raises:
        AssertionError: naming the model whose limit did not match, so the
            ``__main__`` runner's failure line identifies the culprit.
    """
    print("Testing model token limits...")
    expected_limits = {
        "microsoft/Phi-3-mini-4k-instruct": 8192,
        "microsoft/Phi-3-mini-128k-instruct": 131072,
        "microsoft/Phi-3-small-8k-instruct": 8192,
        "microsoft/Phi-3-mini-4k-instruct-gguf/Phi-3-mini-4k-instruct-q4.gguf": 8192,
        "some-model-128k": 131072,
        "unknown-model": 4096,
    }
    for model_name, expected in expected_limits.items():
        actual = get_model_token_limit(model_name)
        assert actual == expected, (
            f"{model_name}: expected {expected} tokens, got {actual}"
        )
    print("[PASS] Model token limits working correctly\n")
def test_token_counting():
    """Check the token-count estimate returned by ``count_tokens``.

    The bounds are deliberately loose: the test only pins a plausible range
    for each input rather than an exact count.

    Raises:
        AssertionError: with the observed token count when a bound is missed.
    """
    print("Testing token counting...")
    empty_tokens = count_tokens("")
    assert empty_tokens == 0, f"Expected 0 tokens for empty text, got {empty_tokens}"

    small_text = "This is a test of the token counting system. It should estimate tokens based on character count."
    small_tokens = count_tokens(small_text)
    assert 20 < small_tokens < 35, f"Expected ~27 tokens, got {small_tokens}"

    # 20 characters x 1000 repetitions = 20,000 characters.
    large_text = "Patient visit data. " * 1000
    large_tokens = count_tokens(large_text)
    assert 5000 < large_tokens < 6000, f"Expected ~5,500 tokens, got {large_tokens}"

    # Report the counts already computed instead of re-running count_tokens.
    print("[PASS] Token counting working correctly")
    print(f" Small text ({len(small_text)} chars) = {small_tokens} tokens")
    print(f" Large text ({len(large_text)} chars) = {large_tokens} tokens\n")
def test_token_limit_checking():
    """Check ``check_token_limits`` for small, oversized and near-limit input.

    Uses the Phi-3 mini 4k model (configured at 8192 tokens) with 2048 tokens
    reserved for output, leaving 6144 tokens for input.

    Raises:
        AssertionError: including the full result dict on any mismatch.
    """
    print("Testing token limit checking...")
    model_name = "microsoft/Phi-3-mini-4k-instruct"
    max_tokens = 8192
    reserve = 2048
    available = max_tokens - reserve  # 6144

    # Small input: 23 chars x 10 = 230 characters, far below the limit.
    small_text = "Short patient summary. " * 10
    result = check_token_limits(small_text, model_name, reserve_for_output=reserve)
    assert result["within_limit"], f"Small input should fit: {result}"
    assert result["max_tokens"] == max_tokens, (
        f"Expected max_tokens={max_tokens}, got {result['max_tokens']}"
    )
    assert result["available_for_input"] == available, (
        f"Expected available_for_input={available}, got {result['available_for_input']}"
    )
    print(f"[PASS] Small input: {result['estimated_tokens']}/{result['available_for_input']} tokens ({result['usage_percentage']:.1f}%)")

    # Large input: 20 chars x 2000 = 40,000 characters, well over 6144 tokens.
    large_text = "Patient visit data. " * 2000
    result = check_token_limits(large_text, model_name, reserve_for_output=reserve)
    assert not result["within_limit"], f"Large input should exceed the limit: {result}"
    print(f"[PASS] Large input: {result['estimated_tokens']}/{result['available_for_input']} tokens ({result['usage_percentage']:.1f}%) - EXCEEDS LIMIT")

    # Medium input: 20 chars x 1000 = 20,000 characters. test_token_counting
    # bounds this text at 5000-6000 tokens, i.e. roughly 81-98% of the 6144
    # tokens available for input: inside the limit but approaching it.
    medium_text = "Patient visit data. " * 1000
    result = check_token_limits(medium_text, model_name, reserve_for_output=reserve)
    print(f"[INFO] Medium input: {result['estimated_tokens']}/{result['available_for_input']} tokens ({result['usage_percentage']:.1f}%)")
    assert result["within_limit"], f"Medium input should fit: {result}"
    assert result["usage_percentage"] > 80, f"Expected >80%, got {result['usage_percentage']:.1f}%"
    print("[PASS] Medium input - APPROACHING LIMIT\n")
def test_error_detection():
    """Check that ``is_token_limit_error`` classifies exceptions correctly.

    The first four cases (three message-based ``Exception`` instances and an
    ``IndexError`` with "position index out of range") must be flagged as
    token-limit errors; the last, unrelated message must not be.

    Raises:
        AssertionError: naming the exception (with its type) plus the
            expected and actual classification.
    """
    print("Testing error pattern detection...")
    test_cases = [
        (Exception("input is too long"), True),
        (Exception("maximum context length exceeded"), True),
        (Exception("Token limit exceeded"), True),
        (IndexError("position index out of range"), True),
        (Exception("some other error"), False),
    ]
    for error, expected in test_cases:
        result = is_token_limit_error(error)
        # repr() shows the exception type, which matters for the IndexError case.
        assert result == expected, (
            f"Failed for: {error!r}: expected token_limit={expected}, got {result}"
        )
        status = "[PASS]" if result else "[SKIP]"
        message = str(error)
        # Only mark the preview as truncated when it actually was.
        preview = message if len(message) <= 40 else message[:40] + "..."
        print(f" {status} '{preview}' -> token_limit={result}")
    print("[PASS] Error pattern detection working correctly\n")
if __name__ == "__main__":
    # Script entry point: run every check in order, print a summary, and exit
    # with status 1 on the first assertion failure or unexpected error.
    rule = "=" * 60
    print(rule)
    print("Token Limit Detection - Verification Tests")
    print(rule + "\n")
    try:
        for check in (
            test_model_token_limits,
            test_token_counting,
            test_token_limit_checking,
            test_error_detection,
        ):
            check()

        print(rule)
        print("[SUCCESS] ALL TESTS PASSED")
        print(rule)
        summary_lines = (
            "\nToken limit detection is working correctly!",
            "\nConfiguration:",
            " - Model limit: 8192 tokens",
            " - Reserve for output: 2048 tokens",
            " - Available for input: 6144 tokens",
        )
        for line in summary_lines:
            print(line)
    except AssertionError as exc:
        # A check's expectation did not hold.
        print(f"\n[FAILED] TEST FAILED: {exc}")
        sys.exit(1)
    except Exception as exc:
        # Anything else (e.g. an error raised inside the code under test).
        print(f"\n[ERROR] {exc}")
        import traceback
        traceback.print_exc()
        sys.exit(1)