Spaces:
Paused
Paused
Commit ·
baf854e
1
Parent(s): 733c0c5
feat: Implement unified model management, centralized constants, and error handling for AI medical extraction services.
Browse files- __pycache__/app.cpython-311.pyc +0 -0
- services/ai-service/src/ai_med_extract/__pycache__/inference_service.cpython-311.pyc +0 -0
- services/ai-service/src/ai_med_extract/__pycache__/phi_scrubber_service.cpython-311.pyc +0 -0
- services/ai-service/src/ai_med_extract/agents/__pycache__/patient_summary_agent.cpython-311.pyc +0 -0
- services/ai-service/src/ai_med_extract/agents/__pycache__/summarizer.cpython-311.pyc +0 -0
- services/ai-service/src/ai_med_extract/services/error_handler.py +16 -0
- services/ai-service/src/ai_med_extract/utils/__pycache__/model_config.cpython-311.pyc +0 -0
- services/ai-service/src/ai_med_extract/utils/__pycache__/openvino_summarizer_utils.cpython-311.pyc +0 -0
- services/ai-service/src/ai_med_extract/utils/__pycache__/performance_monitor.cpython-311.pyc +0 -0
- services/ai-service/src/ai_med_extract/utils/constants.py +1 -0
- services/ai-service/src/ai_med_extract/utils/model_config.py +59 -0
- services/ai-service/src/ai_med_extract/utils/unified_model_manager.py +175 -1
- services/ai-service/test_token_limits.py +120 -0
__pycache__/app.cpython-311.pyc
CHANGED
|
Binary files a/__pycache__/app.cpython-311.pyc and b/__pycache__/app.cpython-311.pyc differ
|
|
|
services/ai-service/src/ai_med_extract/__pycache__/inference_service.cpython-311.pyc
CHANGED
|
Binary files a/services/ai-service/src/ai_med_extract/__pycache__/inference_service.cpython-311.pyc and b/services/ai-service/src/ai_med_extract/__pycache__/inference_service.cpython-311.pyc differ
|
|
|
services/ai-service/src/ai_med_extract/__pycache__/phi_scrubber_service.cpython-311.pyc
CHANGED
|
Binary files a/services/ai-service/src/ai_med_extract/__pycache__/phi_scrubber_service.cpython-311.pyc and b/services/ai-service/src/ai_med_extract/__pycache__/phi_scrubber_service.cpython-311.pyc differ
|
|
|
services/ai-service/src/ai_med_extract/agents/__pycache__/patient_summary_agent.cpython-311.pyc
CHANGED
|
Binary files a/services/ai-service/src/ai_med_extract/agents/__pycache__/patient_summary_agent.cpython-311.pyc and b/services/ai-service/src/ai_med_extract/agents/__pycache__/patient_summary_agent.cpython-311.pyc differ
|
|
|
services/ai-service/src/ai_med_extract/agents/__pycache__/summarizer.cpython-311.pyc
CHANGED
|
Binary files a/services/ai-service/src/ai_med_extract/agents/__pycache__/summarizer.cpython-311.pyc and b/services/ai-service/src/ai_med_extract/agents/__pycache__/summarizer.cpython-311.pyc differ
|
|
|
services/ai-service/src/ai_med_extract/services/error_handler.py
CHANGED
|
@@ -23,6 +23,7 @@ class ErrorCategory(Enum):
|
|
| 23 |
MEMORY = "memory"
|
| 24 |
VALIDATION = "validation"
|
| 25 |
GENERATION = "generation"
|
|
|
|
| 26 |
CACHE = "cache"
|
| 27 |
UNKNOWN = "unknown"
|
| 28 |
|
|
@@ -79,10 +80,16 @@ def categorize_error(error: Exception) -> ErrorCategory:
|
|
| 79 |
return ErrorCategory.MEMORY
|
| 80 |
elif "validation" in error_str or "value" in error_str or isinstance(error, ValueError):
|
| 81 |
return ErrorCategory.VALIDATION
|
|
|
|
|
|
|
|
|
|
| 82 |
# Detect model/generation failures
|
| 83 |
try:
|
| 84 |
from ..utils.unified_model_manager import ModelError # type: ignore
|
| 85 |
if isinstance(error, ModelError):
|
|
|
|
|
|
|
|
|
|
| 86 |
return ErrorCategory.GENERATION
|
| 87 |
except Exception:
|
| 88 |
pass
|
|
@@ -233,6 +240,15 @@ def _get_default_recommendations(category: ErrorCategory, error_str: str) -> lis
|
|
| 233 |
"Check data format and types",
|
| 234 |
"Review API documentation"
|
| 235 |
]
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 236 |
elif category == ErrorCategory.GENERATION:
|
| 237 |
recommendations = [
|
| 238 |
"Verify model availability and internet access",
|
|
|
|
| 23 |
MEMORY = "memory"
|
| 24 |
VALIDATION = "validation"
|
| 25 |
GENERATION = "generation"
|
| 26 |
+
TOKEN_LIMIT = "token_limit"
|
| 27 |
CACHE = "cache"
|
| 28 |
UNKNOWN = "unknown"
|
| 29 |
|
|
|
|
| 80 |
return ErrorCategory.MEMORY
|
| 81 |
elif "validation" in error_str or "value" in error_str or isinstance(error, ValueError):
|
| 82 |
return ErrorCategory.VALIDATION
|
| 83 |
+
# Detect token limit errors
|
| 84 |
+
elif "token_limit_exceeded" in error_str or "token limit" in error_str or "input is too long" in error_str or "maximum context length" in error_str:
|
| 85 |
+
return ErrorCategory.TOKEN_LIMIT
|
| 86 |
# Detect model/generation failures
|
| 87 |
try:
|
| 88 |
from ..utils.unified_model_manager import ModelError # type: ignore
|
| 89 |
if isinstance(error, ModelError):
|
| 90 |
+
# Check if it's specifically a token limit error
|
| 91 |
+
if hasattr(error, 'error_type') and error.error_type == "token_limit_exceeded":
|
| 92 |
+
return ErrorCategory.TOKEN_LIMIT
|
| 93 |
return ErrorCategory.GENERATION
|
| 94 |
except Exception:
|
| 95 |
pass
|
|
|
|
| 240 |
"Check data format and types",
|
| 241 |
"Review API documentation"
|
| 242 |
]
|
| 243 |
+
elif category == ErrorCategory.TOKEN_LIMIT:
|
| 244 |
+
recommendations = [
|
| 245 |
+
"Reduce the number of patient visits in the request",
|
| 246 |
+
"Use a model with larger context window (e.g., Phi-3-mini-128k-instruct instead of 4k)",
|
| 247 |
+
"Split patient data into multiple requests",
|
| 248 |
+
"Use chunking endpoints for large datasets",
|
| 249 |
+
"Filter visits by date range to reduce data size",
|
| 250 |
+
"Check logs for exact token count and model limits"
|
| 251 |
+
]
|
| 252 |
elif category == ErrorCategory.GENERATION:
|
| 253 |
recommendations = [
|
| 254 |
"Verify model availability and internet access",
|
services/ai-service/src/ai_med_extract/utils/__pycache__/model_config.cpython-311.pyc
CHANGED
|
Binary files a/services/ai-service/src/ai_med_extract/utils/__pycache__/model_config.cpython-311.pyc and b/services/ai-service/src/ai_med_extract/utils/__pycache__/model_config.cpython-311.pyc differ
|
|
|
services/ai-service/src/ai_med_extract/utils/__pycache__/openvino_summarizer_utils.cpython-311.pyc
CHANGED
|
Binary files a/services/ai-service/src/ai_med_extract/utils/__pycache__/openvino_summarizer_utils.cpython-311.pyc and b/services/ai-service/src/ai_med_extract/utils/__pycache__/openvino_summarizer_utils.cpython-311.pyc differ
|
|
|
services/ai-service/src/ai_med_extract/utils/__pycache__/performance_monitor.cpython-311.pyc
CHANGED
|
Binary files a/services/ai-service/src/ai_med_extract/utils/__pycache__/performance_monitor.cpython-311.pyc and b/services/ai-service/src/ai_med_extract/utils/__pycache__/performance_monitor.cpython-311.pyc differ
|
|
|
services/ai-service/src/ai_med_extract/utils/constants.py
CHANGED
|
@@ -80,6 +80,7 @@ ERROR_MESSAGES = {
|
|
| 80 |
"model_load_failed": "Failed to load AI model. Please try again or contact support.",
|
| 81 |
"generation_timeout": "Summary generation timed out. Please try again with a simpler request.",
|
| 82 |
"generation_failed": "Summary generation failed. Please try again or contact support.",
|
|
|
|
| 83 |
"cache_error": "Cache operation failed. Continuing with fresh generation."
|
| 84 |
}
|
| 85 |
|
|
|
|
| 80 |
"model_load_failed": "Failed to load AI model. Please try again or contact support.",
|
| 81 |
"generation_timeout": "Summary generation timed out. Please try again with a simpler request.",
|
| 82 |
"generation_failed": "Summary generation failed. Please try again or contact support.",
|
| 83 |
+
"token_limit_exceeded": "Patient data exceeds model's token limit. Please reduce the number of visits or use a model with larger context window.",
|
| 84 |
"cache_error": "Cache operation failed. Continuing with fresh generation."
|
| 85 |
}
|
| 86 |
|
services/ai-service/src/ai_med_extract/utils/model_config.py
CHANGED
|
@@ -172,6 +172,65 @@ QUANTIZATION_CONFIG = {
|
|
| 172 |
"skip_layers": ["embeddings", "lm_head", "shared", "embed_positions"] # Layers to skip quantization
|
| 173 |
}
|
| 174 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 175 |
def get_default_model(model_type: str, use_spaces_optimized: bool = False) -> str:
|
| 176 |
"""Get the default model for a given type, optimized for T4 Spaces"""
|
| 177 |
# Always use T4-optimized models when on T4 Medium
|
|
|
|
| 172 |
"skip_layers": ["embeddings", "lm_head", "shared", "embed_positions"] # Layers to skip quantization
|
| 173 |
}
|
| 174 |
|
| 175 |
+
# ========== MODEL-SPECIFIC TOKEN LIMITS ==========
# Maximum context window sizes for different models
MODEL_TOKEN_LIMITS = {
    # Phi-3 models
    "microsoft/Phi-3-mini-4k-instruct": 4096,
    "microsoft/Phi-3-mini-4k-instruct-gguf": 4096,
    "microsoft/Phi-3-mini-4k-instruct-GGUF": 4096,
    "microsoft/Phi-3-mini-128k-instruct": 131072,
    "microsoft/Phi-3-mini-128k-instruct-gguf": 131072,
    "microsoft/Phi-3-small-8k-instruct": 8192,
    "microsoft/Phi-3-small-128k-instruct": 131072,

    # OpenVINO models
    "OpenVINO/Phi-3-mini-4k-instruct-fp16-ov": 4096,
    "OpenVINO/Phi-3-mini-128k-instruct-int4-ov": 131072,

    # DialoGPT and BART models
    "microsoft/DialoGPT-small": 1024,
    "microsoft/DialoGPT-medium": 1024,
    "facebook/bart-base": 1024,
    "facebook/bart-large-cnn": 1024,
    "sshleifer/distilbart-cnn-6-6": 1024,

    # Default fallback
    "default": 4096
}


def get_model_token_limit(model_name: str) -> int:
    """
    Get the maximum token limit (context window) for a specific model.

    Resolution order:
      1. Exact key in ``MODEL_TOKEN_LIMITS``.
      2. Case-insensitive partial match against the known model names
         (e.g. a GGUF file path that embeds the repository name). When
         several known names are contained in ``model_name`` the longest
         one wins.
      3. A context-size tag in the name (``128k``, ``8k``, ``4k``).
      4. ``MODEL_TOKEN_LIMITS["default"]``.

    Args:
        model_name: Name or path of the model. ``None`` or an empty string
            resolves to the default limit.

    Returns:
        Maximum token limit for the model.
    """
    # Local import: the module's top-level import block is outside this block.
    import re

    default_limit = MODEL_TOKEN_LIMITS["default"]
    if not model_name:
        return default_limit

    # 1. Exact match first
    if model_name in MODEL_TOKEN_LIMITS:
        return MODEL_TOKEN_LIMITS[model_name]

    lowered_name = model_name.lower()
    # "default" is a fallback sentinel, not a model name: it must not take
    # part in substring matching, otherwise any name containing "default"
    # short-circuits to the fallback before its size tag is inspected.
    known_models = [
        (key.lower(), limit)
        for key, limit in MODEL_TOKEN_LIMITS.items()
        if key != "default"
    ]

    # 2a. A known name embedded in a longer name/path; most specific wins.
    embedded = [(key, limit) for key, limit in known_models if key in lowered_name]
    if embedded:
        return max(embedded, key=lambda pair: len(pair[0]))[1]

    # 2b. An abbreviated name that is part of a known name; table order wins.
    for key, limit in known_models:
        if lowered_name in key:
            return limit

    # 3. Size tags. The look-behind rejects a preceding digit so that, for
    #    example, "48k" is not read as "8k".
    for tag, limit in (("128k", 131072), ("8k", 8192), ("4k", 4096)):
        if re.search(rf"(?<!\d){tag}", lowered_name):
            return limit

    # 4. Nothing matched
    return default_limit
|
| 232 |
+
|
| 233 |
+
|
| 234 |
def get_default_model(model_type: str, use_spaces_optimized: bool = False) -> str:
|
| 235 |
"""Get the default model for a given type, optimized for T4 Spaces"""
|
| 236 |
# Always use T4-optimized models when on T4 Medium
|
services/ai-service/src/ai_med_extract/utils/unified_model_manager.py
CHANGED
|
@@ -25,7 +25,7 @@ import torch
|
|
| 25 |
from .model_config import (
|
| 26 |
get_default_model, get_fallback_model, get_t4_model_kwargs,
|
| 27 |
get_t4_generation_config, is_model_supported_on_t4, detect_model_type,
|
| 28 |
-
IS_T4_MEDIUM
|
| 29 |
)
|
| 30 |
|
| 31 |
# Configure logging
|
|
@@ -76,6 +76,96 @@ class ModelError(Exception):
|
|
| 76 |
self.timestamp = time.time()
|
| 77 |
super().__init__(f"Model {model_name} failed ({error_type}): {details}")
|
| 78 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 79 |
# Global unified model manager instance
|
| 80 |
unified_model_manager = None
|
| 81 |
|
|
@@ -229,6 +319,24 @@ class TransformersModel(BaseModel):
|
|
| 229 |
if self._model is None:
|
| 230 |
raise ModelError(self.name, "not_loaded", "Model not loaded")
|
| 231 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 232 |
try:
|
| 233 |
# Get T4-optimized generation config
|
| 234 |
gen_config = get_t4_generation_config(self.model_type)
|
|
@@ -275,6 +383,16 @@ class TransformersModel(BaseModel):
|
|
| 275 |
return generated_text
|
| 276 |
|
| 277 |
except Exception as e:
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 278 |
raise ModelError(self.name, "generation_failed", f"Generation failed: {str(e)}", e)
|
| 279 |
|
| 280 |
class GGUFModel(BaseModel):
|
|
@@ -330,6 +448,24 @@ class GGUFModel(BaseModel):
|
|
| 330 |
if self._model is None:
|
| 331 |
raise ModelError(self.name, "not_loaded", "Model not loaded")
|
| 332 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 333 |
try:
|
| 334 |
# Get T4-optimized generation config
|
| 335 |
gen_config = get_t4_generation_config("gguf")
|
|
@@ -347,6 +483,16 @@ class GGUFModel(BaseModel):
|
|
| 347 |
return result['choices'][0]['text'] if result and 'choices' in result else ""
|
| 348 |
|
| 349 |
except Exception as e:
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 350 |
raise ModelError(self.name, "generation_failed", f"Generation failed: {str(e)}", e)
|
| 351 |
|
| 352 |
class OpenVINOModel(BaseModel):
|
|
@@ -390,6 +536,24 @@ class OpenVINOModel(BaseModel):
|
|
| 390 |
if self._model is None or self._tokenizer is None:
|
| 391 |
raise ModelError(self.name, "not_loaded", "Model not loaded")
|
| 392 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 393 |
try:
|
| 394 |
inputs = self._tokenizer(prompt, return_tensors="pt")
|
| 395 |
if torch.cuda.is_available():
|
|
@@ -413,6 +577,16 @@ class OpenVINOModel(BaseModel):
|
|
| 413 |
return generated_text
|
| 414 |
|
| 415 |
except Exception as e:
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 416 |
raise ModelError(self.name, "generation_failed", f"Generation failed: {str(e)}", e)
|
| 417 |
|
| 418 |
class FallbackModel(BaseModel):
|
|
|
|
| 25 |
from .model_config import (
|
| 26 |
get_default_model, get_fallback_model, get_t4_model_kwargs,
|
| 27 |
get_t4_generation_config, is_model_supported_on_t4, detect_model_type,
|
| 28 |
+
get_model_token_limit, IS_T4_MEDIUM
|
| 29 |
)
|
| 30 |
|
| 31 |
# Configure logging
|
|
|
|
| 76 |
self.timestamp = time.time()
|
| 77 |
super().__init__(f"Model {model_name} failed ({error_type}): {details}")
|
| 78 |
|
| 79 |
+
def count_tokens(text: str, model_name: str = None) -> int:
    """
    Estimate the token count for a given text.

    This is a heuristic, not a tokenizer: it assumes ~4 characters per token
    and adds 10% overhead for special tokens and formatting. The result is
    rounded up, so any non-empty text counts as at least one token and the
    estimate never drops below the heuristic's own value.

    Args:
        text: Text to count tokens for. ``None`` or an empty string yields 0.
        model_name: Optional model name. Accepted for interface
            compatibility; the estimate does not currently depend on it.

    Returns:
        Estimated token count.
    """
    if not text:
        return 0

    # ceil(len(text) / 4 * 1.1) computed in exact integer arithmetic
    # (1.1 / 4 == 11 / 40), avoiding float truncation artefacts.
    return (len(text) * 11 + 39) // 40
|
| 103 |
+
|
| 104 |
+
def check_token_limits(text: str, model_name: str, reserve_for_output: int = 8192) -> dict:
    """
    Compare the estimated size of ``text`` with a model's input budget.

    The input budget is the model's context limit, as reported by
    ``get_model_token_limit``, minus ``reserve_for_output``. The size of
    ``text`` comes from ``count_tokens``.

    Args:
        text: Input text to check.
        model_name: Name of the model whose limit applies.
        reserve_for_output: Tokens to set aside for the model's output.

    Returns:
        Dictionary with the check results:
            - within_limit: bool, estimate fits in the input budget
            - estimated_tokens: int
            - max_tokens: int, the model's context limit
            - available_for_input: int, context limit minus the reserve
            - reserve_for_output: int, echoed back
            - usage_percentage: estimate as a percentage of the input
              budget; 0 when the budget is not positive
    """
    context_limit = get_model_token_limit(model_name)
    prompt_tokens = count_tokens(text, model_name)
    input_budget = context_limit - reserve_for_output

    # NOTE(review): when the reserve is at least the context limit the budget
    # is <= 0 and usage is reported as 0 even though a non-empty prompt will
    # not fit — confirm this is the intended reading for callers that log it.
    if input_budget > 0:
        usage = prompt_tokens / input_budget * 100
    else:
        usage = 0

    report = {}
    report["within_limit"] = prompt_tokens <= input_budget
    report["estimated_tokens"] = prompt_tokens
    report["max_tokens"] = context_limit
    report["available_for_input"] = input_budget
    report["reserve_for_output"] = reserve_for_output
    report["usage_percentage"] = usage
    return report
|
| 133 |
+
|
| 134 |
+
def is_token_limit_error(error: Exception) -> bool:
    """
    Detect if an error is related to token limits being exceeded.

    Detection is text-based: the lower-cased ``str(error)`` is searched for
    phrases that commonly appear in context-overflow errors. An
    ``IndexError`` is additionally treated as an overflow only when its
    message points at position/embedding lookups; a generic
    ``"list index out of range"`` is not a token-limit error.

    Args:
        error: Exception to check.

    Returns:
        True if the error looks token-limit related, False otherwise.
    """
    message = str(error).lower()

    overflow_phrases = (
        "input is too long",
        "maximum context length",
        "exceeds the maximum",
        "context_length",
        "too many tokens",
        "input too long",
        "sequence length",
        "max_position_embeddings",
        "position_ids",
        "token limit",  # direct token limit messages
    )
    if any(phrase in message for phrase in overflow_phrases):
        return True

    if isinstance(error, IndexError):
        # Only IndexErrors that mention positions, or the embedding-lookup
        # message ("index out of range in self" — presumably what PyTorch
        # raises for an out-of-range embedding index; confirm against the
        # deployed torch version), indicate overflow. Matching on the bare
        # word "index" would flag almost every IndexError.
        index_overflow_phrases = ("position", "index out of range in self")
        return any(phrase in message for phrase in index_overflow_phrases)

    return False
|
| 168 |
+
|
| 169 |
# Global unified model manager instance
|
| 170 |
unified_model_manager = None
|
| 171 |
|
|
|
|
| 319 |
if self._model is None:
|
| 320 |
raise ModelError(self.name, "not_loaded", "Model not loaded")
|
| 321 |
|
| 322 |
+
# Check token limits before generation
|
| 323 |
+
token_check = check_token_limits(prompt, self.name, config.max_tokens)
|
| 324 |
+
logger.info(f"Token check for {self.name}: {token_check['estimated_tokens']}/{token_check['available_for_input']} tokens ({token_check['usage_percentage']:.1f}%)")
|
| 325 |
+
|
| 326 |
+
if not token_check["within_limit"]:
|
| 327 |
+
error_msg = (
|
| 328 |
+
f"Input exceeds token limit for model {self.name}. "
|
| 329 |
+
f"Estimated tokens: {token_check['estimated_tokens']}, "
|
| 330 |
+
f"Available for input: {token_check['available_for_input']} "
|
| 331 |
+
f"(max: {token_check['max_tokens']}, reserved for output: {token_check['reserve_for_output']}). "
|
| 332 |
+
f"Please reduce the input size or use a model with larger context window."
|
| 333 |
+
)
|
| 334 |
+
logger.error(error_msg)
|
| 335 |
+
raise ModelError(self.name, "token_limit_exceeded", error_msg)
|
| 336 |
+
|
| 337 |
+
if token_check["usage_percentage"] > 80:
|
| 338 |
+
logger.warning(f"Approaching token limit for {self.name}: {token_check['usage_percentage']:.1f}% of available tokens")
|
| 339 |
+
|
| 340 |
try:
|
| 341 |
# Get T4-optimized generation config
|
| 342 |
gen_config = get_t4_generation_config(self.model_type)
|
|
|
|
| 383 |
return generated_text
|
| 384 |
|
| 385 |
except Exception as e:
|
| 386 |
+
# Check if this is a token limit error
|
| 387 |
+
if is_token_limit_error(e):
|
| 388 |
+
error_msg = (
|
| 389 |
+
f"Token limit exceeded for model {self.name}. "
|
| 390 |
+
f"Input length: ~{token_check['estimated_tokens']} tokens, "
|
| 391 |
+
f"Model limit: {token_check['max_tokens']} tokens. "
|
| 392 |
+
f"Original error: {str(e)}"
|
| 393 |
+
)
|
| 394 |
+
logger.error(error_msg)
|
| 395 |
+
raise ModelError(self.name, "token_limit_exceeded", error_msg, e)
|
| 396 |
raise ModelError(self.name, "generation_failed", f"Generation failed: {str(e)}", e)
|
| 397 |
|
| 398 |
class GGUFModel(BaseModel):
|
|
|
|
| 448 |
if self._model is None:
|
| 449 |
raise ModelError(self.name, "not_loaded", "Model not loaded")
|
| 450 |
|
| 451 |
+
# Check token limits before generation
|
| 452 |
+
token_check = check_token_limits(prompt, self.name, config.max_tokens)
|
| 453 |
+
logger.info(f"Token check for {self.name}: {token_check['estimated_tokens']}/{token_check['available_for_input']} tokens ({token_check['usage_percentage']:.1f}%)")
|
| 454 |
+
|
| 455 |
+
if not token_check["within_limit"]:
|
| 456 |
+
error_msg = (
|
| 457 |
+
f"Input exceeds token limit for model {self.name}. "
|
| 458 |
+
f"Estimated tokens: {token_check['estimated_tokens']}, "
|
| 459 |
+
f"Available for input: {token_check['available_for_input']} "
|
| 460 |
+
f"(max: {token_check['max_tokens']}, reserved for output: {token_check['reserve_for_output']}). "
|
| 461 |
+
f"Please reduce the input size or use a model with larger context window."
|
| 462 |
+
)
|
| 463 |
+
logger.error(error_msg)
|
| 464 |
+
raise ModelError(self.name, "token_limit_exceeded", error_msg)
|
| 465 |
+
|
| 466 |
+
if token_check["usage_percentage"] > 80:
|
| 467 |
+
logger.warning(f"Approaching token limit for {self.name}: {token_check['usage_percentage']:.1f}% of available tokens")
|
| 468 |
+
|
| 469 |
try:
|
| 470 |
# Get T4-optimized generation config
|
| 471 |
gen_config = get_t4_generation_config("gguf")
|
|
|
|
| 483 |
return result['choices'][0]['text'] if result and 'choices' in result else ""
|
| 484 |
|
| 485 |
except Exception as e:
|
| 486 |
+
# Check if this is a token limit error
|
| 487 |
+
if is_token_limit_error(e):
|
| 488 |
+
error_msg = (
|
| 489 |
+
f"Token limit exceeded for model {self.name}. "
|
| 490 |
+
f"Input length: ~{token_check['estimated_tokens']} tokens, "
|
| 491 |
+
f"Model limit: {token_check['max_tokens']} tokens. "
|
| 492 |
+
f"Original error: {str(e)}"
|
| 493 |
+
)
|
| 494 |
+
logger.error(error_msg)
|
| 495 |
+
raise ModelError(self.name, "token_limit_exceeded", error_msg, e)
|
| 496 |
raise ModelError(self.name, "generation_failed", f"Generation failed: {str(e)}", e)
|
| 497 |
|
| 498 |
class OpenVINOModel(BaseModel):
|
|
|
|
| 536 |
if self._model is None or self._tokenizer is None:
|
| 537 |
raise ModelError(self.name, "not_loaded", "Model not loaded")
|
| 538 |
|
| 539 |
+
# Check token limits before generation
|
| 540 |
+
token_check = check_token_limits(prompt, self.name, config.max_tokens)
|
| 541 |
+
logger.info(f"Token check for {self.name}: {token_check['estimated_tokens']}/{token_check['available_for_input']} tokens ({token_check['usage_percentage']:.1f}%)")
|
| 542 |
+
|
| 543 |
+
if not token_check["within_limit"]:
|
| 544 |
+
error_msg = (
|
| 545 |
+
f"Input exceeds token limit for model {self.name}. "
|
| 546 |
+
f"Estimated tokens: {token_check['estimated_tokens']}, "
|
| 547 |
+
f"Available for input: {token_check['available_for_input']} "
|
| 548 |
+
f"(max: {token_check['max_tokens']}, reserved for output: {token_check['reserve_for_output']}). "
|
| 549 |
+
f"Please reduce the input size or use a model with larger context window."
|
| 550 |
+
)
|
| 551 |
+
logger.error(error_msg)
|
| 552 |
+
raise ModelError(self.name, "token_limit_exceeded", error_msg)
|
| 553 |
+
|
| 554 |
+
if token_check["usage_percentage"] > 80:
|
| 555 |
+
logger.warning(f"Approaching token limit for {self.name}: {token_check['usage_percentage']:.1f}% of available tokens")
|
| 556 |
+
|
| 557 |
try:
|
| 558 |
inputs = self._tokenizer(prompt, return_tensors="pt")
|
| 559 |
if torch.cuda.is_available():
|
|
|
|
| 577 |
return generated_text
|
| 578 |
|
| 579 |
except Exception as e:
|
| 580 |
+
# Check if this is a token limit error
|
| 581 |
+
if is_token_limit_error(e):
|
| 582 |
+
error_msg = (
|
| 583 |
+
f"Token limit exceeded for model {self.name}. "
|
| 584 |
+
f"Input length: ~{token_check['estimated_tokens']} tokens, "
|
| 585 |
+
f"Model limit: {token_check['max_tokens']} tokens. "
|
| 586 |
+
f"Original error: {str(e)}"
|
| 587 |
+
)
|
| 588 |
+
logger.error(error_msg)
|
| 589 |
+
raise ModelError(self.name, "token_limit_exceeded", error_msg, e)
|
| 590 |
raise ModelError(self.name, "generation_failed", f"Generation failed: {str(e)}", e)
|
| 591 |
|
| 592 |
class FallbackModel(BaseModel):
|
services/ai-service/test_token_limits.py
ADDED
|
@@ -0,0 +1,120 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
Simple test to verify token limit detection works correctly.
|
| 3 |
+
"""
|
| 4 |
+
|
| 5 |
+
import sys
|
| 6 |
+
import os
|
| 7 |
+
|
| 8 |
+
# Set UTF-8 encoding for Windows console
if sys.platform == 'win32':
    os.system('chcp 65001 > nul')

# Resolve "src" relative to this file rather than the current working
# directory, so the script can be launched from any location.
sys.path.insert(0, os.path.join(os.path.dirname(os.path.abspath(__file__)), 'src'))
|
| 13 |
+
|
| 14 |
+
from ai_med_extract.utils.model_config import get_model_token_limit
|
| 15 |
+
from ai_med_extract.utils.unified_model_manager import count_tokens, check_token_limits, is_token_limit_error
|
| 16 |
+
|
| 17 |
+
def test_model_token_limits():
    """Check that known, partially matching and unknown model names resolve to the expected limits."""
    print("Testing model token limits...")

    # (model name, expected context limit)
    expectations = (
        ("microsoft/Phi-3-mini-4k-instruct", 4096),
        ("microsoft/Phi-3-mini-128k-instruct", 131072),
        ("microsoft/Phi-3-small-8k-instruct", 8192),
        ("microsoft/Phi-3-mini-4k-instruct-gguf/Phi-3-mini-4k-instruct-q4.gguf", 4096),
        ("some-model-128k", 131072),
        ("unknown-model", 4096),
    )
    for model_name, expected_limit in expectations:
        assert get_model_token_limit(model_name) == expected_limit

    print("[PASS] Model token limits working correctly\n")
|
| 29 |
+
|
| 30 |
+
def test_token_counting():
    """Check that the token estimate is zero for empty text and lands in the expected range otherwise."""
    print("Testing token counting...")

    assert count_tokens("") == 0

    small_text = "This is a test of the token counting system. It should estimate tokens based on character count."
    small_count = count_tokens(small_text)
    assert 20 < small_count < 35, f"Expected ~27 tokens, got {small_count}"

    large_text = "Patient visit data. " * 1000
    large_count = count_tokens(large_text)
    assert 5000 < large_count < 6000, f"Expected ~5,500 tokens, got {large_count}"

    print("[PASS] Token counting working correctly")
    for label, sample in (("Small", small_text), ("Large", large_text)):
        # The blank line after the last sample separates this test's output from the next.
        trailer = "\n" if label == "Large" else ""
        print(f" {label} text ({len(sample)} chars) = {count_tokens(sample)} tokens{trailer}")
|
| 46 |
+
|
| 47 |
+
def test_token_limit_checking():
    """Check the limit verdict for small, oversized and near-limit inputs on a 4k model."""
    print("Testing token limit checking...")

    model_name = "microsoft/Phi-3-mini-4k-instruct"

    # Small input: comfortably inside the input budget
    small_text = "Short patient summary. " * 10
    result = check_token_limits(small_text, model_name, reserve_for_output=2048)
    assert result["within_limit"]
    print(f"[PASS] Small input: {result['estimated_tokens']}/{result['available_for_input']} tokens ({result['usage_percentage']:.1f}%)")

    # Large input: far beyond the input budget
    large_text = "Patient visit data. " * 2000
    result = check_token_limits(large_text, model_name, reserve_for_output=2048)
    assert not result["within_limit"]
    print(f"[PASS] Large input: {result['estimated_tokens']}/{result['available_for_input']} tokens ({result['usage_percentage']:.1f}%) - EXCEEDS LIMIT")

    # Medium input: ~7000 chars = ~1925 tokens, about 94% of the 2048-token
    # budget — inside the limit but above the 80% warning threshold.
    medium_text = "Patient visit data. " * 350
    result = check_token_limits(medium_text, model_name, reserve_for_output=2048)
    print(f"[INFO] Medium input: {result['estimated_tokens']}/{result['available_for_input']} tokens ({result['usage_percentage']:.1f}%)")
    assert result["within_limit"]
    assert result["usage_percentage"] > 80, f"Expected >80%, got {result['usage_percentage']:.1f}%"
    print("[PASS] Medium input - APPROACHING LIMIT\n")
|
| 72 |
+
|
| 73 |
+
def test_error_detection():
    """Check that token-limit errors are recognised and unrelated errors are not."""
    print("Testing error pattern detection...")

    # (exception, whether it should be classified as a token-limit error)
    test_cases = [
        (Exception("input is too long"), True),
        (Exception("maximum context length exceeded"), True),
        (Exception("Token limit exceeded"), True),
        (IndexError("position index out of range"), True),
        (Exception("some other error"), False),
    ]

    for error, expected in test_cases:
        result = is_token_limit_error(error)
        assert result == expected, f"Failed for: {error}"
        # Reaching this line means the case passed, for positive and negative
        # expectations alike, so both are reported as PASS.
        message = str(error)
        shown = message[:40] + ("..." if len(message) > 40 else "")
        print(f" [PASS] '{shown}' -> token_limit={result}")

    print("[PASS] Error pattern detection working correctly\n")
|
| 92 |
+
|
| 93 |
+
if __name__ == "__main__":
    # Script entry point: run every check in order, stop at the first failure
    # and exit with status 1 so callers can detect it.
    banner = "=" * 60

    print(banner)
    print("Token Limit Detection - Verification Tests")
    print(banner + "\n")

    try:
        for run_check in (
            test_model_token_limits,
            test_token_counting,
            test_token_limit_checking,
            test_error_detection,
        ):
            run_check()

        print(banner)
        print("[SUCCESS] ALL TESTS PASSED")
        print(banner)
        for line in (
            "\nToken limit detection is working correctly!",
            "\nNext steps:",
            "1. Test with real patient data containing many visits",
            "2. Verify error messages appear in API responses",
            "3. Check logs for token diagnostics",
        ):
            print(line)

    except AssertionError as e:
        # A failed expectation: report the assertion message only.
        print(f"\n[FAILED] TEST FAILED: {e}")
        sys.exit(1)
    except Exception as e:
        # Anything else (import problems, unexpected crashes): show the traceback.
        print(f"\n[ERROR] {e}")
        import traceback
        traceback.print_exc()
        sys.exit(1)
|