sachinchandrankallar commited on
Commit
baf854e
·
1 Parent(s): 733c0c5

feat: Implement unified model management, centralized constants, and error handling for AI medical extraction services.

Browse files
__pycache__/app.cpython-311.pyc CHANGED
Binary files a/__pycache__/app.cpython-311.pyc and b/__pycache__/app.cpython-311.pyc differ
 
services/ai-service/src/ai_med_extract/__pycache__/inference_service.cpython-311.pyc CHANGED
Binary files a/services/ai-service/src/ai_med_extract/__pycache__/inference_service.cpython-311.pyc and b/services/ai-service/src/ai_med_extract/__pycache__/inference_service.cpython-311.pyc differ
 
services/ai-service/src/ai_med_extract/__pycache__/phi_scrubber_service.cpython-311.pyc CHANGED
Binary files a/services/ai-service/src/ai_med_extract/__pycache__/phi_scrubber_service.cpython-311.pyc and b/services/ai-service/src/ai_med_extract/__pycache__/phi_scrubber_service.cpython-311.pyc differ
 
services/ai-service/src/ai_med_extract/agents/__pycache__/patient_summary_agent.cpython-311.pyc CHANGED
Binary files a/services/ai-service/src/ai_med_extract/agents/__pycache__/patient_summary_agent.cpython-311.pyc and b/services/ai-service/src/ai_med_extract/agents/__pycache__/patient_summary_agent.cpython-311.pyc differ
 
services/ai-service/src/ai_med_extract/agents/__pycache__/summarizer.cpython-311.pyc CHANGED
Binary files a/services/ai-service/src/ai_med_extract/agents/__pycache__/summarizer.cpython-311.pyc and b/services/ai-service/src/ai_med_extract/agents/__pycache__/summarizer.cpython-311.pyc differ
 
services/ai-service/src/ai_med_extract/services/error_handler.py CHANGED
@@ -23,6 +23,7 @@ class ErrorCategory(Enum):
23
  MEMORY = "memory"
24
  VALIDATION = "validation"
25
  GENERATION = "generation"
 
26
  CACHE = "cache"
27
  UNKNOWN = "unknown"
28
 
@@ -79,10 +80,16 @@ def categorize_error(error: Exception) -> ErrorCategory:
79
  return ErrorCategory.MEMORY
80
  elif "validation" in error_str or "value" in error_str or isinstance(error, ValueError):
81
  return ErrorCategory.VALIDATION
 
 
 
82
  # Detect model/generation failures
83
  try:
84
  from ..utils.unified_model_manager import ModelError # type: ignore
85
  if isinstance(error, ModelError):
 
 
 
86
  return ErrorCategory.GENERATION
87
  except Exception:
88
  pass
@@ -233,6 +240,15 @@ def _get_default_recommendations(category: ErrorCategory, error_str: str) -> lis
233
  "Check data format and types",
234
  "Review API documentation"
235
  ]
 
 
 
 
 
 
 
 
 
236
  elif category == ErrorCategory.GENERATION:
237
  recommendations = [
238
  "Verify model availability and internet access",
 
23
  MEMORY = "memory"
24
  VALIDATION = "validation"
25
  GENERATION = "generation"
26
+ TOKEN_LIMIT = "token_limit"
27
  CACHE = "cache"
28
  UNKNOWN = "unknown"
29
 
 
80
  return ErrorCategory.MEMORY
81
  elif "validation" in error_str or "value" in error_str or isinstance(error, ValueError):
82
  return ErrorCategory.VALIDATION
83
+ # Detect token limit errors
84
+ elif "token_limit_exceeded" in error_str or "token limit" in error_str or "input is too long" in error_str or "maximum context length" in error_str:
85
+ return ErrorCategory.TOKEN_LIMIT
86
  # Detect model/generation failures
87
  try:
88
  from ..utils.unified_model_manager import ModelError # type: ignore
89
  if isinstance(error, ModelError):
90
+ # Check if it's specifically a token limit error
91
+ if hasattr(error, 'error_type') and error.error_type == "token_limit_exceeded":
92
+ return ErrorCategory.TOKEN_LIMIT
93
  return ErrorCategory.GENERATION
94
  except Exception:
95
  pass
 
240
  "Check data format and types",
241
  "Review API documentation"
242
  ]
243
+ elif category == ErrorCategory.TOKEN_LIMIT:
244
+ recommendations = [
245
+ "Reduce the number of patient visits in the request",
246
+ "Use a model with larger context window (e.g., Phi-3-mini-128k-instruct instead of 4k)",
247
+ "Split patient data into multiple requests",
248
+ "Use chunking endpoints for large datasets",
249
+ "Filter visits by date range to reduce data size",
250
+ "Check logs for exact token count and model limits"
251
+ ]
252
  elif category == ErrorCategory.GENERATION:
253
  recommendations = [
254
  "Verify model availability and internet access",
services/ai-service/src/ai_med_extract/utils/__pycache__/model_config.cpython-311.pyc CHANGED
Binary files a/services/ai-service/src/ai_med_extract/utils/__pycache__/model_config.cpython-311.pyc and b/services/ai-service/src/ai_med_extract/utils/__pycache__/model_config.cpython-311.pyc differ
 
services/ai-service/src/ai_med_extract/utils/__pycache__/openvino_summarizer_utils.cpython-311.pyc CHANGED
Binary files a/services/ai-service/src/ai_med_extract/utils/__pycache__/openvino_summarizer_utils.cpython-311.pyc and b/services/ai-service/src/ai_med_extract/utils/__pycache__/openvino_summarizer_utils.cpython-311.pyc differ
 
services/ai-service/src/ai_med_extract/utils/__pycache__/performance_monitor.cpython-311.pyc CHANGED
Binary files a/services/ai-service/src/ai_med_extract/utils/__pycache__/performance_monitor.cpython-311.pyc and b/services/ai-service/src/ai_med_extract/utils/__pycache__/performance_monitor.cpython-311.pyc differ
 
services/ai-service/src/ai_med_extract/utils/constants.py CHANGED
@@ -80,6 +80,7 @@ ERROR_MESSAGES = {
80
  "model_load_failed": "Failed to load AI model. Please try again or contact support.",
81
  "generation_timeout": "Summary generation timed out. Please try again with a simpler request.",
82
  "generation_failed": "Summary generation failed. Please try again or contact support.",
 
83
  "cache_error": "Cache operation failed. Continuing with fresh generation."
84
  }
85
 
 
80
  "model_load_failed": "Failed to load AI model. Please try again or contact support.",
81
  "generation_timeout": "Summary generation timed out. Please try again with a simpler request.",
82
  "generation_failed": "Summary generation failed. Please try again or contact support.",
83
+ "token_limit_exceeded": "Patient data exceeds model's token limit. Please reduce the number of visits or use a model with larger context window.",
84
  "cache_error": "Cache operation failed. Continuing with fresh generation."
85
  }
86
 
services/ai-service/src/ai_med_extract/utils/model_config.py CHANGED
@@ -172,6 +172,65 @@ QUANTIZATION_CONFIG = {
172
  "skip_layers": ["embeddings", "lm_head", "shared", "embed_positions"] # Layers to skip quantization
173
  }
174
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
175
  def get_default_model(model_type: str, use_spaces_optimized: bool = False) -> str:
176
  """Get the default model for a given type, optimized for T4 Spaces"""
177
  # Always use T4-optimized models when on T4 Medium
 
172
  "skip_layers": ["embeddings", "lm_head", "shared", "embed_positions"] # Layers to skip quantization
173
  }
174
 
175
+ # ========== MODEL-SPECIFIC TOKEN LIMITS ==========
176
+ # Maximum context window sizes for different models
177
+ MODEL_TOKEN_LIMITS = {
178
+ # Phi-3 models
179
+ "microsoft/Phi-3-mini-4k-instruct": 4096,
180
+ "microsoft/Phi-3-mini-4k-instruct-gguf": 4096,
181
+ "microsoft/Phi-3-mini-4k-instruct-GGUF": 4096,
182
+ "microsoft/Phi-3-mini-128k-instruct": 131072,
183
+ "microsoft/Phi-3-mini-128k-instruct-gguf": 131072,
184
+ "microsoft/Phi-3-small-8k-instruct": 8192,
185
+ "microsoft/Phi-3-small-128k-instruct": 131072,
186
+
187
+ # OpenVINO models
188
+ "OpenVINO/Phi-3-mini-4k-instruct-fp16-ov": 4096,
189
+ "OpenVINO/Phi-3-mini-128k-instruct-int4-ov": 131072,
190
+
191
+ # DialoGPT and BART models
192
+ "microsoft/DialoGPT-small": 1024,
193
+ "microsoft/DialoGPT-medium": 1024,
194
+ "facebook/bart-base": 1024,
195
+ "facebook/bart-large-cnn": 1024,
196
+ "sshleifer/distilbart-cnn-6-6": 1024,
197
+
198
+ # Default fallback
199
+ "default": 4096
200
+ }
201
+
202
+ def get_model_token_limit(model_name: str) -> int:
203
+ """
204
+ Get the maximum token limit for a specific model.
205
+
206
+ Args:
207
+ model_name: Name of the model
208
+
209
+ Returns:
210
+ Maximum token limit for the model
211
+ """
212
+ # Check exact match first
213
+ if model_name in MODEL_TOKEN_LIMITS:
214
+ return MODEL_TOKEN_LIMITS[model_name]
215
+
216
+ # Check for partial matches (e.g., for GGUF files with full paths)
217
+ model_name_lower = model_name.lower()
218
+ for key, limit in MODEL_TOKEN_LIMITS.items():
219
+ if key.lower() in model_name_lower or model_name_lower in key.lower():
220
+ return limit
221
+
222
+ # Check for common patterns
223
+ if "128k" in model_name_lower:
224
+ return 131072
225
+ elif "8k" in model_name_lower:
226
+ return 8192
227
+ elif "4k" in model_name_lower:
228
+ return 4096
229
+
230
+ # Return default
231
+ return MODEL_TOKEN_LIMITS["default"]
232
+
233
+
234
  def get_default_model(model_type: str, use_spaces_optimized: bool = False) -> str:
235
  """Get the default model for a given type, optimized for T4 Spaces"""
236
  # Always use T4-optimized models when on T4 Medium
services/ai-service/src/ai_med_extract/utils/unified_model_manager.py CHANGED
@@ -25,7 +25,7 @@ import torch
25
  from .model_config import (
26
  get_default_model, get_fallback_model, get_t4_model_kwargs,
27
  get_t4_generation_config, is_model_supported_on_t4, detect_model_type,
28
- IS_T4_MEDIUM
29
  )
30
 
31
  # Configure logging
@@ -76,6 +76,96 @@ class ModelError(Exception):
76
  self.timestamp = time.time()
77
  super().__init__(f"Model {model_name} failed ({error_type}): {details}")
78
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
79
  # Global unified model manager instance
80
  unified_model_manager = None
81
 
@@ -229,6 +319,24 @@ class TransformersModel(BaseModel):
229
  if self._model is None:
230
  raise ModelError(self.name, "not_loaded", "Model not loaded")
231
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
232
  try:
233
  # Get T4-optimized generation config
234
  gen_config = get_t4_generation_config(self.model_type)
@@ -275,6 +383,16 @@ class TransformersModel(BaseModel):
275
  return generated_text
276
 
277
  except Exception as e:
 
 
 
 
 
 
 
 
 
 
278
  raise ModelError(self.name, "generation_failed", f"Generation failed: {str(e)}", e)
279
 
280
  class GGUFModel(BaseModel):
@@ -330,6 +448,24 @@ class GGUFModel(BaseModel):
330
  if self._model is None:
331
  raise ModelError(self.name, "not_loaded", "Model not loaded")
332
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
333
  try:
334
  # Get T4-optimized generation config
335
  gen_config = get_t4_generation_config("gguf")
@@ -347,6 +483,16 @@ class GGUFModel(BaseModel):
347
  return result['choices'][0]['text'] if result and 'choices' in result else ""
348
 
349
  except Exception as e:
 
 
 
 
 
 
 
 
 
 
350
  raise ModelError(self.name, "generation_failed", f"Generation failed: {str(e)}", e)
351
 
352
  class OpenVINOModel(BaseModel):
@@ -390,6 +536,24 @@ class OpenVINOModel(BaseModel):
390
  if self._model is None or self._tokenizer is None:
391
  raise ModelError(self.name, "not_loaded", "Model not loaded")
392
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
393
  try:
394
  inputs = self._tokenizer(prompt, return_tensors="pt")
395
  if torch.cuda.is_available():
@@ -413,6 +577,16 @@ class OpenVINOModel(BaseModel):
413
  return generated_text
414
 
415
  except Exception as e:
 
 
 
 
 
 
 
 
 
 
416
  raise ModelError(self.name, "generation_failed", f"Generation failed: {str(e)}", e)
417
 
418
  class FallbackModel(BaseModel):
 
25
  from .model_config import (
26
  get_default_model, get_fallback_model, get_t4_model_kwargs,
27
  get_t4_generation_config, is_model_supported_on_t4, detect_model_type,
28
+ get_model_token_limit, IS_T4_MEDIUM
29
  )
30
 
31
  # Configure logging
 
76
  self.timestamp = time.time()
77
  super().__init__(f"Model {model_name} failed ({error_type}): {details}")
78
 
79
+ def count_tokens(text: str, model_name: str = None) -> int:
80
+ """
81
+ Estimate token count for a given text.
82
+ Uses a simple heuristic: ~4 characters per token for English text.
83
+ This is a conservative estimate that works reasonably well for medical text.
84
+
85
+ Args:
86
+ text: Text to count tokens for
87
+ model_name: Optional model name for model-specific counting
88
+
89
+ Returns:
90
+ Estimated token count
91
+ """
92
+ if not text:
93
+ return 0
94
+
95
+ # Simple heuristic: ~4 characters per token
96
+ # This is conservative and works well for medical/clinical text
97
+ estimated_tokens = len(text) // 4
98
+
99
+ # Add some overhead for special tokens and formatting
100
+ estimated_tokens = int(estimated_tokens * 1.1)
101
+
102
+ return estimated_tokens
103
+
104
+ def check_token_limits(text: str, model_name: str, reserve_for_output: int = 8192) -> dict:
105
+ """
106
+ Check if text exceeds model's token limit.
107
+
108
+ Args:
109
+ text: Input text to check
110
+ model_name: Name of the model
111
+ reserve_for_output: Tokens to reserve for model output
112
+
113
+ Returns:
114
+ Dictionary with check results:
115
+ - within_limit: bool
116
+ - estimated_tokens: int
117
+ - max_tokens: int
118
+ - available_for_input: int
119
+ - usage_percentage: float
120
+ """
121
+ max_tokens = get_model_token_limit(model_name)
122
+ estimated_tokens = count_tokens(text, model_name)
123
+ available_for_input = max_tokens - reserve_for_output
124
+
125
+ return {
126
+ "within_limit": estimated_tokens <= available_for_input,
127
+ "estimated_tokens": estimated_tokens,
128
+ "max_tokens": max_tokens,
129
+ "available_for_input": available_for_input,
130
+ "reserve_for_output": reserve_for_output,
131
+ "usage_percentage": (estimated_tokens / available_for_input * 100) if available_for_input > 0 else 0
132
+ }
133
+
134
+ def is_token_limit_error(error: Exception) -> bool:
135
+ """
136
+ Detect if an error is related to token limits being exceeded.
137
+
138
+ Args:
139
+ error: Exception to check
140
+
141
+ Returns:
142
+ True if error is token-limit related
143
+ """
144
+ error_str = str(error).lower()
145
+ error_patterns = [
146
+ "input is too long",
147
+ "maximum context length",
148
+ "exceeds the maximum",
149
+ "context_length",
150
+ "too many tokens",
151
+ "input too long",
152
+ "sequence length",
153
+ "max_position_embeddings",
154
+ "position_ids",
155
+ "token limit" # Added for direct token limit messages
156
+ ]
157
+
158
+ # Check error message
159
+ for pattern in error_patterns:
160
+ if pattern in error_str:
161
+ return True
162
+
163
+ # Check for IndexError which can indicate token overflow
164
+ if isinstance(error, IndexError) and ("position" in error_str or "index" in error_str):
165
+ return True
166
+
167
+ return False
168
+
169
  # Global unified model manager instance
170
  unified_model_manager = None
171
 
 
319
  if self._model is None:
320
  raise ModelError(self.name, "not_loaded", "Model not loaded")
321
 
322
+ # Check token limits before generation
323
+ token_check = check_token_limits(prompt, self.name, config.max_tokens)
324
+ logger.info(f"Token check for {self.name}: {token_check['estimated_tokens']}/{token_check['available_for_input']} tokens ({token_check['usage_percentage']:.1f}%)")
325
+
326
+ if not token_check["within_limit"]:
327
+ error_msg = (
328
+ f"Input exceeds token limit for model {self.name}. "
329
+ f"Estimated tokens: {token_check['estimated_tokens']}, "
330
+ f"Available for input: {token_check['available_for_input']} "
331
+ f"(max: {token_check['max_tokens']}, reserved for output: {token_check['reserve_for_output']}). "
332
+ f"Please reduce the input size or use a model with larger context window."
333
+ )
334
+ logger.error(error_msg)
335
+ raise ModelError(self.name, "token_limit_exceeded", error_msg)
336
+
337
+ if token_check["usage_percentage"] > 80:
338
+ logger.warning(f"Approaching token limit for {self.name}: {token_check['usage_percentage']:.1f}% of available tokens")
339
+
340
  try:
341
  # Get T4-optimized generation config
342
  gen_config = get_t4_generation_config(self.model_type)
 
383
  return generated_text
384
 
385
  except Exception as e:
386
+ # Check if this is a token limit error
387
+ if is_token_limit_error(e):
388
+ error_msg = (
389
+ f"Token limit exceeded for model {self.name}. "
390
+ f"Input length: ~{token_check['estimated_tokens']} tokens, "
391
+ f"Model limit: {token_check['max_tokens']} tokens. "
392
+ f"Original error: {str(e)}"
393
+ )
394
+ logger.error(error_msg)
395
+ raise ModelError(self.name, "token_limit_exceeded", error_msg, e)
396
  raise ModelError(self.name, "generation_failed", f"Generation failed: {str(e)}", e)
397
 
398
  class GGUFModel(BaseModel):
 
448
  if self._model is None:
449
  raise ModelError(self.name, "not_loaded", "Model not loaded")
450
 
451
+ # Check token limits before generation
452
+ token_check = check_token_limits(prompt, self.name, config.max_tokens)
453
+ logger.info(f"Token check for {self.name}: {token_check['estimated_tokens']}/{token_check['available_for_input']} tokens ({token_check['usage_percentage']:.1f}%)")
454
+
455
+ if not token_check["within_limit"]:
456
+ error_msg = (
457
+ f"Input exceeds token limit for model {self.name}. "
458
+ f"Estimated tokens: {token_check['estimated_tokens']}, "
459
+ f"Available for input: {token_check['available_for_input']} "
460
+ f"(max: {token_check['max_tokens']}, reserved for output: {token_check['reserve_for_output']}). "
461
+ f"Please reduce the input size or use a model with larger context window."
462
+ )
463
+ logger.error(error_msg)
464
+ raise ModelError(self.name, "token_limit_exceeded", error_msg)
465
+
466
+ if token_check["usage_percentage"] > 80:
467
+ logger.warning(f"Approaching token limit for {self.name}: {token_check['usage_percentage']:.1f}% of available tokens")
468
+
469
  try:
470
  # Get T4-optimized generation config
471
  gen_config = get_t4_generation_config("gguf")
 
483
  return result['choices'][0]['text'] if result and 'choices' in result else ""
484
 
485
  except Exception as e:
486
+ # Check if this is a token limit error
487
+ if is_token_limit_error(e):
488
+ error_msg = (
489
+ f"Token limit exceeded for model {self.name}. "
490
+ f"Input length: ~{token_check['estimated_tokens']} tokens, "
491
+ f"Model limit: {token_check['max_tokens']} tokens. "
492
+ f"Original error: {str(e)}"
493
+ )
494
+ logger.error(error_msg)
495
+ raise ModelError(self.name, "token_limit_exceeded", error_msg, e)
496
  raise ModelError(self.name, "generation_failed", f"Generation failed: {str(e)}", e)
497
 
498
  class OpenVINOModel(BaseModel):
 
536
  if self._model is None or self._tokenizer is None:
537
  raise ModelError(self.name, "not_loaded", "Model not loaded")
538
 
539
+ # Check token limits before generation
540
+ token_check = check_token_limits(prompt, self.name, config.max_tokens)
541
+ logger.info(f"Token check for {self.name}: {token_check['estimated_tokens']}/{token_check['available_for_input']} tokens ({token_check['usage_percentage']:.1f}%)")
542
+
543
+ if not token_check["within_limit"]:
544
+ error_msg = (
545
+ f"Input exceeds token limit for model {self.name}. "
546
+ f"Estimated tokens: {token_check['estimated_tokens']}, "
547
+ f"Available for input: {token_check['available_for_input']} "
548
+ f"(max: {token_check['max_tokens']}, reserved for output: {token_check['reserve_for_output']}). "
549
+ f"Please reduce the input size or use a model with larger context window."
550
+ )
551
+ logger.error(error_msg)
552
+ raise ModelError(self.name, "token_limit_exceeded", error_msg)
553
+
554
+ if token_check["usage_percentage"] > 80:
555
+ logger.warning(f"Approaching token limit for {self.name}: {token_check['usage_percentage']:.1f}% of available tokens")
556
+
557
  try:
558
  inputs = self._tokenizer(prompt, return_tensors="pt")
559
  if torch.cuda.is_available():
 
577
  return generated_text
578
 
579
  except Exception as e:
580
+ # Check if this is a token limit error
581
+ if is_token_limit_error(e):
582
+ error_msg = (
583
+ f"Token limit exceeded for model {self.name}. "
584
+ f"Input length: ~{token_check['estimated_tokens']} tokens, "
585
+ f"Model limit: {token_check['max_tokens']} tokens. "
586
+ f"Original error: {str(e)}"
587
+ )
588
+ logger.error(error_msg)
589
+ raise ModelError(self.name, "token_limit_exceeded", error_msg, e)
590
  raise ModelError(self.name, "generation_failed", f"Generation failed: {str(e)}", e)
591
 
592
  class FallbackModel(BaseModel):
services/ai-service/test_token_limits.py ADDED
@@ -0,0 +1,120 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ Simple test to verify token limit detection works correctly.
3
+ """
4
+
5
+ import sys
6
+ import os
7
+
8
+ # Set UTF-8 encoding for Windows console
9
+ if sys.platform == 'win32':
10
+ os.system('chcp 65001 > nul')
11
+
12
+ sys.path.insert(0, 'src')
13
+
14
+ from ai_med_extract.utils.model_config import get_model_token_limit
15
+ from ai_med_extract.utils.unified_model_manager import count_tokens, check_token_limits, is_token_limit_error
16
+
17
+ def test_model_token_limits():
18
+ """Test that model token limits are configured correctly"""
19
+ print("Testing model token limits...")
20
+
21
+ assert get_model_token_limit("microsoft/Phi-3-mini-4k-instruct") == 4096
22
+ assert get_model_token_limit("microsoft/Phi-3-mini-128k-instruct") == 131072
23
+ assert get_model_token_limit("microsoft/Phi-3-small-8k-instruct") == 8192
24
+ assert get_model_token_limit("microsoft/Phi-3-mini-4k-instruct-gguf/Phi-3-mini-4k-instruct-q4.gguf") == 4096
25
+ assert get_model_token_limit("some-model-128k") == 131072
26
+ assert get_model_token_limit("unknown-model") == 4096
27
+
28
+ print("[PASS] Model token limits working correctly\n")
29
+
30
+ def test_token_counting():
31
+ """Test token counting estimation"""
32
+ print("Testing token counting...")
33
+
34
+ assert count_tokens("") == 0
35
+ small_text = "This is a test of the token counting system. It should estimate tokens based on character count."
36
+ tokens = count_tokens(small_text)
37
+ assert 20 < tokens < 35, f"Expected ~27 tokens, got {tokens}"
38
+
39
+ large_text = "Patient visit data. " * 1000
40
+ tokens = count_tokens(large_text)
41
+ assert 5000 < tokens < 6000, f"Expected ~5,500 tokens, got {tokens}"
42
+
43
+ print(f"[PASS] Token counting working correctly")
44
+ print(f" Small text ({len(small_text)} chars) = {count_tokens(small_text)} tokens")
45
+ print(f" Large text ({len(large_text)} chars) = {count_tokens(large_text)} tokens\n")
46
+
47
+ def test_token_limit_checking():
48
+ """Test token limit validation"""
49
+ print("Testing token limit checking...")
50
+
51
+ model_name = "microsoft/Phi-3-mini-4k-instruct"
52
+
53
+ # Small input
54
+ small_text = "Short patient summary. " * 10
55
+ result = check_token_limits(small_text, model_name, reserve_for_output=2048)
56
+ assert result["within_limit"] == True
57
+ print(f"[PASS] Small input: {result['estimated_tokens']}/{result['available_for_input']} tokens ({result['usage_percentage']:.1f}%)")
58
+
59
+ # Large input
60
+ large_text = "Patient visit data. " * 2000
61
+ result = check_token_limits(large_text, model_name, reserve_for_output=2048)
62
+ assert result["within_limit"] == False
63
+ print(f"[PASS] Large input: {result['estimated_tokens']}/{result['available_for_input']} tokens ({result['usage_percentage']:.1f}%) - EXCEEDS LIMIT")
64
+
65
+ # Medium input - adjust to actually be ~80-90% of limit
66
+ medium_text = "Patient visit data. " * 350 # ~7000 chars = ~1925 tokens (~94% of 2048)
67
+ result = check_token_limits(medium_text, model_name, reserve_for_output=2048)
68
+ print(f"[INFO] Medium input: {result['estimated_tokens']}/{result['available_for_input']} tokens ({result['usage_percentage']:.1f}%)")
69
+ assert result["within_limit"] == True
70
+ assert result["usage_percentage"] > 80, f"Expected >80%, got {result['usage_percentage']:.1f}%"
71
+ print(f"[PASS] Medium input - APPROACHING LIMIT\n")
72
+
73
+ def test_error_detection():
74
+ """Test token limit error pattern detection"""
75
+ print("Testing error pattern detection...")
76
+
77
+ test_cases = [
78
+ (Exception("input is too long"), True),
79
+ (Exception("maximum context length exceeded"), True),
80
+ (Exception("Token limit exceeded"), True),
81
+ (IndexError("position index out of range"), True),
82
+ (Exception("some other error"), False),
83
+ ]
84
+
85
+ for error, expected in test_cases:
86
+ result = is_token_limit_error(error)
87
+ assert result == expected, f"Failed for: {error}"
88
+ status = "[PASS]" if result else "[SKIP]"
89
+ print(f" {status} '{str(error)[:40]}...' -> token_limit={result}")
90
+
91
+ print("[PASS] Error pattern detection working correctly\n")
92
+
93
+ if __name__ == "__main__":
94
+ print("="*60)
95
+ print("Token Limit Detection - Verification Tests")
96
+ print("="*60 + "\n")
97
+
98
+ try:
99
+ test_model_token_limits()
100
+ test_token_counting()
101
+ test_token_limit_checking()
102
+ test_error_detection()
103
+
104
+ print("="*60)
105
+ print("[SUCCESS] ALL TESTS PASSED")
106
+ print("="*60)
107
+ print("\nToken limit detection is working correctly!")
108
+ print("\nNext steps:")
109
+ print("1. Test with real patient data containing many visits")
110
+ print("2. Verify error messages appear in API responses")
111
+ print("3. Check logs for token diagnostics")
112
+
113
+ except AssertionError as e:
114
+ print(f"\n[FAILED] TEST FAILED: {e}")
115
+ sys.exit(1)
116
+ except Exception as e:
117
+ print(f"\n[ERROR] {e}")
118
+ import traceback
119
+ traceback.print_exc()
120
+ sys.exit(1)