sachinchandrankallar committed on
Commit
299444a
·
1 Parent(s): 84a9aa6

Refactor PyTorch compatibility handling by centralizing the RMSNorm patch into a dedicated utility function. This ensures consistent application across modules and improves maintainability. Update logging to reflect the new approach.

Browse files
services/ai-service/src/ai_med_extract/api/routes_fastapi.py CHANGED
@@ -15,35 +15,9 @@ from ..core_logger import log_with_memory, log_exception_with_memory
15
  logger = logging.getLogger(__name__)
16
  from concurrent.futures import ThreadPoolExecutor, as_completed
17
  import torch
18
- # Patch torch.rms_norm for compatibility with models like Phi-3 that expect this function
19
- if not hasattr(torch, 'rms_norm'):
20
- def rms_norm(input_tensor, normalized_shape=None, weight=None, eps=1e-6):
21
- """Simple RMS normalization implementation compatible with various call signatures"""
22
- # Handle different input formats
23
- if normalized_shape is None:
24
- # If no shape specified, normalize over last dimension
25
- dim = -1
26
- keepdim = True
27
- else:
28
- # If shape is specified, normalize over those dimensions
29
- if isinstance(normalized_shape, int):
30
- dim = normalized_shape
31
- keepdim = True
32
- else:
33
- # Multiple dimensions - normalize over all of them
34
- dim = tuple(range(-len(normalized_shape), 0))
35
- keepdim = True
36
-
37
- # Calculate RMS (root mean square)
38
- variance = input_tensor.pow(2).mean(dim=dim, keepdim=keepdim)
39
- # Normalize
40
- output = input_tensor * torch.rsqrt(variance + eps)
41
- # Apply weight if provided
42
- if weight is not None:
43
- output = output * weight
44
- return output
45
- torch.rms_norm = rms_norm
46
- logger.info("Patched torch.rms_norm for compatibility with Phi-3 and similar models")
47
 
48
  from transformers import AutoTokenizer, AutoModelForCausalLM, pipeline as transformers_pipeline
49
  import requests
 
15
  logger = logging.getLogger(__name__)
16
  from concurrent.futures import ThreadPoolExecutor, as_completed
17
  import torch
18
+ # Ensure PyTorch compatibility patches are applied early
19
+ from ..utils.torch_compat import ensure_torch_compatibility
20
+ ensure_torch_compatibility()
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
21
 
22
  from transformers import AutoTokenizer, AutoModelForCausalLM, pipeline as transformers_pipeline
23
  import requests
services/ai-service/src/ai_med_extract/app.py CHANGED
@@ -7,6 +7,9 @@ from fastapi.responses import JSONResponse
7
  from fastapi.middleware.cors import CORSMiddleware
8
  from starlette.middleware.base import BaseHTTPMiddleware
9
  import torch
 
 
 
10
  from contextlib import asynccontextmanager
11
  from datetime import datetime
12
  import redis.asyncio as redis
 
7
  from fastapi.middleware.cors import CORSMiddleware
8
  from starlette.middleware.base import BaseHTTPMiddleware
9
  import torch
10
+ # Ensure PyTorch compatibility patches are applied early
11
+ from .utils.torch_compat import ensure_torch_compatibility
12
+ ensure_torch_compatibility()
13
  from contextlib import asynccontextmanager
14
  from datetime import datetime
15
  import redis.asyncio as redis
services/ai-service/src/ai_med_extract/utils/fallback_pipeline.py ADDED
@@ -0,0 +1,228 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ Fallback Pipeline Implementation
3
+ Provides a transformers-compatible pipeline wrapper for fallback scenarios
4
+ """
5
+ import logging
6
+ import torch
7
+ from typing import Dict, Any, Optional, Union
8
+
9
+ logger = logging.getLogger(__name__)
10
+
11
+
12
class FallbackPipeline:
    """
    A transformers-compatible pipeline wrapper for fallback scenarios.

    This class provides a consistent ``generate`` interface when OpenVINO
    loading fails and we need to fall back to standard transformers models.
    It accepts either a prompt string or a dict of tokenized inputs, filters
    generation kwargs down to a known-safe set, and returns decoded text.
    """

    # Generation kwargs forwarded to ``model.generate``; everything else
    # (e.g. ``loss_type``, ``max_length``) is dropped or converted.
    _ALLOWED_GENERATION_KWARGS = frozenset({
        'max_new_tokens', 'do_sample', 'temperature', 'top_k', 'top_p',
        'num_return_sequences', 'pad_token_id', 'eos_token_id', 'num_beams',
        'early_stopping', 'repetition_penalty', 'use_cache',
        'output_attentions', 'output_hidden_states', 'return_dict_in_generate',
    })

    # Substrings of architecture names (e.g. "MistralForCausalLM") for which
    # the KV cache is disabled during generation.
    _DYNAMIC_CACHE_ARCH_MARKERS = ('mistral', 'llama', 'phi')

    # Config attributes removed before generation because they may cause issues.
    _LEGACY_CACHE_ATTRS = ('get_max_length', 'max_cache_length')

    _DEFAULT_MAX_LENGTH = 2048
    _DEFAULT_MAX_NEW_TOKENS = 256

    def __init__(self, model, tokenizer):
        """
        Initialize the fallback pipeline.

        Args:
            model: The transformers model instance
            tokenizer: The tokenizer instance (may be None)
        """
        self.model = model
        self.tokenizer = tokenizer

        # Use the device of the first parameter; a model without parameters
        # (or without a ``parameters`` method) leaves the device unset.
        self.device = None
        if hasattr(model, 'parameters'):
            first_param = next(iter(model.parameters()), None)
            if first_param is not None:
                self.device = first_param.device

        self.cache_settings = {}
        if hasattr(model, 'config'):
            self.cache_settings['max_length'] = getattr(
                model.config, 'max_position_embeddings', self._DEFAULT_MAX_LENGTH
            )

    def _has_dynamic_cache(self) -> bool:
        """
        Check if model has dynamic cache support.

        Returns True when the config exposes a sliding-window attribute or any
        architecture name *contains* one of the known markers. Matching is by
        substring because architecture names look like "MistralForCausalLM".
        """
        model_config = getattr(self.model, 'config', None)
        if model_config is None:
            return False

        if hasattr(model_config, 'sliding_window') or hasattr(model_config, 'sliding_window_size'):
            return True

        architectures = getattr(model_config, 'architectures', None) or ()
        return any(
            marker in str(arch).lower()
            for arch in architectures
            for marker in self._DYNAMIC_CACHE_ARCH_MARKERS
        )

    def _cleanup_legacy_cache_attrs(self):
        """Remove legacy cache attributes that may cause issues (best effort)."""
        model_config = getattr(self.model, 'config', None)
        if model_config is None:
            return

        for legacy_attr in self._LEGACY_CACHE_ATTRS:
            if hasattr(model_config, legacy_attr):
                try:
                    delattr(model_config, legacy_attr)
                except Exception as exc:
                    # Attributes defined on the class cannot be deleted from
                    # the instance; this cleanup is optional, so keep going.
                    logger.debug("Could not remove legacy cache attribute %s: %s", legacy_attr, exc)

    def _get_safe_generation_kwargs(self, kwargs: Dict[str, Any], inputs: Any = None) -> Dict[str, Any]:
        """
        Extract and sanitize generation kwargs.

        Args:
            kwargs: Original generation kwargs (not modified)
            inputs: The inputs passed to ``generate``; used to measure the
                prompt length when converting ``max_length`` to
                ``max_new_tokens``. Falls back to ``kwargs['inputs']`` when
                omitted.

        Returns:
            Sanitized kwargs safe for model generation
        """
        # Filtering through the allow-list also drops unsupported keys such
        # as 'loss_type' without mutating the caller's dict.
        safe_kwargs = {k: v for k, v in kwargs.items() if k in self._ALLOWED_GENERATION_KWARGS}

        # Cache is always decided here, overriding any caller-supplied value.
        safe_kwargs['use_cache'] = not self._has_dynamic_cache()
        self._cleanup_legacy_cache_attrs()

        # Convert max_length to max_new_tokens if needed
        if 'max_length' in kwargs and 'max_new_tokens' not in safe_kwargs:
            if inputs is None:
                inputs = kwargs.get('inputs')
            try:
                remaining = int(kwargs['max_length']) - int(self._get_input_length(inputs))
                safe_kwargs['max_new_tokens'] = min(
                    max(1, remaining),
                    self.cache_settings.get('max_length', self._DEFAULT_MAX_LENGTH),
                )
            except (TypeError, ValueError):
                # Non-numeric max_length (or missing model limit): use a
                # conservative default instead of failing the request.
                safe_kwargs['max_new_tokens'] = self._DEFAULT_MAX_NEW_TOKENS

        return safe_kwargs

    def _get_input_length(self, inputs: Any) -> int:
        """
        Extract input length from various input formats.

        Supports tensor-like ``input_ids`` (uses the last shape dimension),
        nested lists (length of the first sequence) and flat lists of token
        ids. Returns 0 when the length cannot be determined.
        """
        if not (isinstance(inputs, dict) and 'input_ids' in inputs):
            return 0

        input_ids = inputs['input_ids']
        try:
            return int(input_ids.shape[-1])
        except (AttributeError, IndexError, TypeError):
            pass

        try:
            first = input_ids[0]
            # A first element with a length means a batch of sequences;
            # otherwise input_ids is itself one flat sequence of token ids.
            return len(first) if hasattr(first, '__len__') else len(input_ids)
        except Exception:
            return 0

    @staticmethod
    def _first_sequence(input_ids: Any) -> Any:
        """Return the first sequence of a batch, or ``input_ids`` itself if it is already one sequence."""
        shape = getattr(input_ids, 'shape', None)
        if shape is not None:
            return input_ids[0] if len(shape) > 1 else input_ids
        if isinstance(input_ids, (list, tuple)):
            first = input_ids[0]  # IndexError on empty input is handled by the caller
            return input_ids if isinstance(first, int) else first
        return input_ids

    def _extract_prompt_text(self, inputs: Union[str, Dict[str, Any]]) -> str:
        """
        Extract prompt text from various input formats.

        Args:
            inputs: Can be a string prompt or a dict with tokenized inputs

        Returns:
            Extracted prompt text ("" when nothing usable is found)
        """
        if not isinstance(inputs, dict):
            return str(inputs) if inputs is not None else ""

        # Try to decode tokenized inputs
        if 'input_ids' in inputs and self.tokenizer is not None:
            try:
                return self.tokenizer.decode(
                    self._first_sequence(inputs['input_ids']),
                    skip_special_tokens=True
                )
            except Exception as exc:
                logger.debug("Could not decode input_ids to prompt text: %s", exc)

        # Fallback to text/prompt keys
        return inputs.get('text') or inputs.get('prompt') or ""

    def _move_to_device(self, tokenized: Dict[str, Any]) -> Dict[str, Any]:
        """Move tensor values of tokenized inputs to the model device (best effort)."""
        if self.device is None:
            return tokenized
        try:
            return {
                k: v.to(self.device) if torch.is_tensor(v) else v
                for k, v in tokenized.items()
            }
        except Exception as exc:
            logger.debug("Could not move inputs to %s: %s", self.device, exc)
            return tokenized

    def _decode_outputs(self, outputs: Any) -> str:
        """Decode the first generated sequence, or stringify when decoding is not possible."""
        if not hasattr(self.tokenizer, 'decode'):
            return str(outputs)
        # With return_dict_in_generate=True the token ids live on `.sequences`.
        sequences = getattr(outputs, 'sequences', outputs)
        try:
            return self.tokenizer.decode(sequences[0], skip_special_tokens=True)
        except Exception:
            return str(outputs)

    def generate(self, inputs: Union[str, Dict[str, Any]], **kwargs) -> str:
        """
        Generate text from inputs.

        Args:
            inputs: Input prompt (string or tokenized dict)
            **kwargs: Generation parameters

        Returns:
            Generated text

        Raises:
            Exception: Whatever the unified model manager raises when the
                last-resort fallback after a ``TypeError`` also fails.
        """
        # Get safe kwargs (inputs are passed so max_length can account for
        # the prompt length)
        safe_kwargs = self._get_safe_generation_kwargs(dict(kwargs), inputs=inputs)

        # Extract prompt text
        prompt_text = self._extract_prompt_text(inputs)

        try:
            # Tokenize and generate if we have a tokenizer and string input
            if self.tokenizer is not None and isinstance(prompt_text, str) and prompt_text:
                tokenized = self.tokenizer([prompt_text], return_tensors='pt')
                tokenized = self._move_to_device(tokenized)

                with torch.no_grad():
                    outputs = self.model.generate(**tokenized, **safe_kwargs)

                return self._decode_outputs(outputs)

            # Try direct generation with provided inputs
            if isinstance(inputs, dict):
                inputs = self._move_to_device(inputs)

            # Non-mapping inputs raise TypeError here, which deliberately
            # routes to the last-resort fallback below.
            with torch.no_grad():
                outputs = self.model.generate(**inputs, **safe_kwargs)

            return self._decode_outputs(outputs)

        except TypeError as te:
            logger.warning(f"Generation signature mismatch: {te}")
            # Last resort: try unified model manager
            try:
                from .unified_model_manager import unified_model_manager as _umm
                return _umm.generate_text(
                    getattr(self.model, 'name', str(self.model)),
                    prompt_text,
                    model_type="text-generation"
                )
            except Exception as e:
                logger.warning(f"Fallback generation failed: {e}")
                raise
228
+
services/ai-service/src/ai_med_extract/utils/model_loader_spaces.py CHANGED
@@ -107,38 +107,9 @@ def get_openvino_pipeline(model_name: str, device: str = None):
107
 
108
  logging.info(f"Loading OpenVINO model {model_name} on device: {device}")
109
 
110
- # Check for torch.rms_norm compatibility issue and patch if needed
111
- # Some models (like Phi-3) may reference torch.rms_norm which doesn't exist in older PyTorch versions
112
- if not hasattr(torch, 'rms_norm'):
113
- # Add a simple RMS normalization function to torch if missing
114
- # This is a workaround for models that expect torch.rms_norm to exist
115
- def rms_norm(input_tensor, normalized_shape=None, weight=None, eps=1e-6):
116
- """Simple RMS normalization implementation compatible with various call signatures"""
117
- # Handle different input formats
118
- if normalized_shape is None:
119
- # If no shape specified, normalize over last dimension
120
- dim = -1
121
- keepdim = True
122
- else:
123
- # If shape is specified, normalize over those dimensions
124
- if isinstance(normalized_shape, int):
125
- dim = normalized_shape
126
- keepdim = True
127
- else:
128
- # Multiple dimensions - normalize over all of them
129
- dim = tuple(range(-len(normalized_shape), 0))
130
- keepdim = True
131
-
132
- # Calculate RMS (root mean square)
133
- variance = input_tensor.pow(2).mean(dim=dim, keepdim=keepdim)
134
- # Normalize
135
- output = input_tensor * torch.rsqrt(variance + eps)
136
- # Apply weight if provided
137
- if weight is not None:
138
- output = output * weight
139
- return output
140
- torch.rms_norm = rms_norm
141
- logging.info("Patched torch.rms_norm for compatibility")
142
 
143
  try:
144
  # If model_name is a directory, try to load IR from there; else, download and export
@@ -225,112 +196,8 @@ def get_openvino_pipeline(model_name: str, device: str = None):
225
  cache_dir=os.environ.get('HF_HOME', '/tmp/huggingface')
226
  )
227
 
228
- # Use the FallbackPipeline class defined below
229
- class FallbackPipeline:
230
- def __init__(self, model, tokenizer):
231
- self.model = model
232
- self.tokenizer = tokenizer
233
- self.device = next(model.parameters()).device if hasattr(model, 'parameters') else None
234
- self.cache_settings = {}
235
- if hasattr(model, 'config'):
236
- self.cache_settings['max_length'] = getattr(model.config, 'max_position_embeddings', 2048)
237
-
238
- def generate(self, inputs, **kwargs):
239
- import logging as _logging
240
- if 'loss_type' in kwargs:
241
- kwargs.pop('loss_type', None)
242
-
243
- use_cache_value = False
244
- if hasattr(self.model, 'config'):
245
- model_config = self.model.config
246
- has_dynamic_cache = (
247
- hasattr(model_config, 'sliding_window') or
248
- hasattr(model_config, 'sliding_window_size') or
249
- (hasattr(model_config, 'architectures') and
250
- model_config.architectures and
251
- any('mistral' in arch.lower() or 'llama' in arch.lower() or 'phi' in arch.lower()
252
- for arch in model_config.architectures))
253
- )
254
- if has_dynamic_cache:
255
- use_cache_value = False
256
- else:
257
- use_cache_value = True
258
-
259
- for legacy_cache_attr in ['get_max_length', 'max_cache_length']:
260
- if hasattr(model_config, legacy_cache_attr):
261
- delattr(model_config, legacy_cache_attr)
262
-
263
- kwargs['use_cache'] = use_cache_value
264
- allowed = {
265
- 'max_new_tokens', 'do_sample', 'temperature', 'top_k', 'top_p', 'num_return_sequences',
266
- 'pad_token_id', 'eos_token_id', 'num_beams', 'early_stopping', 'repetition_penalty',
267
- 'use_cache', 'output_attentions', 'output_hidden_states', 'return_dict_in_generate'
268
- }
269
- safe_kwargs = {k: v for k, v in kwargs.items() if k in allowed}
270
- if 'max_length' in kwargs and 'max_new_tokens' not in safe_kwargs:
271
- try:
272
- input_len = 0
273
- if isinstance(inputs, dict) and 'input_ids' in inputs:
274
- input_ids = inputs['input_ids']
275
- try:
276
- input_len = input_ids.shape[-1]
277
- except Exception:
278
- try:
279
- input_len = len(input_ids[0])
280
- except Exception:
281
- input_len = 0
282
- max_len_val = kwargs.get('max_length')
283
- computed_new = max(1, int(max_len_val) - int(input_len))
284
- safe_kwargs['max_new_tokens'] = min(computed_new, self.cache_settings.get('max_length', 2048))
285
- except Exception:
286
- safe_kwargs['max_new_tokens'] = 256
287
-
288
- prompt_text = None
289
- if isinstance(inputs, dict):
290
- if 'input_ids' in inputs and self.tokenizer is not None:
291
- try:
292
- input_ids = inputs['input_ids']
293
- if hasattr(input_ids, 'tolist'):
294
- decoded = self.tokenizer.decode(input_ids[0], skip_special_tokens=True)
295
- else:
296
- decoded = self.tokenizer.decode(input_ids[0], skip_special_tokens=True)
297
- prompt_text = decoded
298
- except Exception:
299
- prompt_text = None
300
- else:
301
- prompt_text = inputs.get('text') or inputs.get('prompt')
302
- else:
303
- prompt_text = inputs
304
- if prompt_text is None:
305
- prompt_text = ""
306
- try:
307
- if self.tokenizer is not None and isinstance(prompt_text, str):
308
- tokenized = self.tokenizer([prompt_text], return_tensors='pt')
309
- try:
310
- if self.device is not None and hasattr(tokenized['input_ids'], 'to'):
311
- tokenized = {k: v.to(self.device) for k, v in tokenized.items()}
312
- except Exception:
313
- _pass = None
314
- outputs = self.model.generate(**tokenized, **safe_kwargs)
315
- if hasattr(self.tokenizer, 'decode'):
316
- return self.tokenizer.decode(outputs[0], skip_special_tokens=True)
317
- return outputs
318
- else:
319
- outputs = self.model.generate(**inputs, **safe_kwargs)
320
- if hasattr(self.tokenizer, 'decode'):
321
- try:
322
- return self.tokenizer.decode(outputs[0], skip_special_tokens=True)
323
- except Exception:
324
- return str(outputs)
325
- return outputs
326
- except TypeError as te:
327
- _logging.warning(f"Fallback generate signature mismatch: {te}")
328
- try:
329
- from .unified_model_manager import unified_model_manager as _umm
330
- return _umm.generate_text(self.model.name if hasattr(self.model, 'name') else str(self.model), prompt_text)
331
- except Exception as e:
332
- _logging.warning(f"Fallback final generation failed: {e}")
333
- raise
334
 
335
  logging.info(f"Successfully loaded transformers model as fallback: {original_model_name}")
336
  return FallbackPipeline(model, tokenizer)
@@ -385,150 +252,17 @@ def get_openvino_pipeline(model_name: str, device: str = None):
385
  original_model_name,
386
  torch_dtype=torch.float32,
387
  device_map="auto" if torch.cuda.is_available() and not is_hf_spaces else None,
388
- trust_remote_code=True
 
389
  )
390
  tokenizer = AutoTokenizer.from_pretrained(
391
  original_model_name,
392
- trust_remote_code=True
 
393
  )
394
 
395
- # Create a compatible pipeline
396
- class FallbackPipeline:
397
- def __init__(self, model, tokenizer):
398
- self.model = model
399
- self.tokenizer = tokenizer
400
- # determine device for tensors
401
- self.device = next(model.parameters()).device if hasattr(model, 'parameters') else None
402
- # Modern cache configuration for transformer models
403
- self.cache_settings = {}
404
- if hasattr(model, 'config'):
405
- # Don't set use_cache at model config level - handle it per generation
406
- # This prevents issues with dynamic cache systems
407
- # Store max length for reference but don't enforce it
408
- self.cache_settings['max_length'] = getattr(model.config, 'max_position_embeddings', 2048)
409
-
410
- def generate(self, inputs, **kwargs):
411
- """Robust generate wrapper that accepts either a prompt string or a tokenized inputs dict.
412
- It sanitizes unsupported kwargs (e.g., loss_type) before delegating to the underlying model.
413
- """
414
- import logging as _logging
415
- # Sanitize unsupported kwargs forwarded from callers
416
- if 'loss_type' in kwargs:
417
- kwargs.pop('loss_type', None)
418
-
419
- # Modern cache handling for transformers models with dynamic cache support
420
- # For single independent generations, explicitly disable cache to prevent stale cache issues
421
- use_cache_value = False # Default to False for single generations
422
-
423
- # Check if model has dynamic cache support
424
- if hasattr(self.model, 'config'):
425
- model_config = self.model.config
426
- # Check for dynamic cache indicators
427
- has_dynamic_cache = (
428
- hasattr(model_config, 'sliding_window') or
429
- hasattr(model_config, 'sliding_window_size') or
430
- (hasattr(model_config, 'architectures') and
431
- model_config.architectures and
432
- any('mistral' in arch.lower() or 'llama' in arch.lower() or 'phi' in arch.lower()
433
- for arch in model_config.architectures))
434
- )
435
-
436
- if has_dynamic_cache:
437
- use_cache_value = False # Disable cache for dynamic cache models in single generations
438
- else:
439
- # For standard models without dynamic cache, we can use cache
440
- use_cache_value = True
441
-
442
- # Ensure we're not passing legacy cache attributes
443
- for legacy_cache_attr in ['get_max_length', 'max_cache_length']:
444
- if hasattr(model_config, legacy_cache_attr):
445
- delattr(model_config, legacy_cache_attr)
446
-
447
- # Set use_cache in kwargs for generation
448
- kwargs['use_cache'] = use_cache_value
449
- # Known-safe generation args (prefer max_new_tokens for causal models)
450
- allowed = {
451
- 'max_new_tokens', 'do_sample', 'temperature', 'top_k', 'top_p', 'num_return_sequences',
452
- 'pad_token_id', 'eos_token_id', 'num_beams', 'early_stopping', 'repetition_penalty',
453
- 'use_cache', 'output_attentions', 'output_hidden_states', 'return_dict_in_generate'
454
- }
455
- safe_kwargs = {k: v for k, v in kwargs.items() if k in allowed}
456
- # If callers provided max_length (common in some code paths), convert it to max_new_tokens
457
- # for causal models to avoid internal calls that rely on deprecated cache APIs.
458
- if 'max_length' in kwargs and 'max_new_tokens' not in safe_kwargs:
459
- try:
460
- input_len = 0
461
- if isinstance(inputs, dict) and 'input_ids' in inputs:
462
- input_ids = inputs['input_ids']
463
- # support tensor-like or list-like input_ids
464
- try:
465
- input_len = input_ids.shape[-1]
466
- except Exception:
467
- try:
468
- input_len = len(input_ids[0])
469
- except Exception:
470
- input_len = 0
471
- max_len_val = kwargs.get('max_length')
472
- computed_new = max(1, int(max_len_val) - int(input_len))
473
- safe_kwargs['max_new_tokens'] = min(computed_new, self.cache_settings.get('max_length', 2048))
474
- except Exception:
475
- # If anything goes wrong, default to a conservative value
476
- safe_kwargs['max_new_tokens'] = 256
477
- # Accept prompt string or tokenized dict
478
- prompt_text = None
479
- if isinstance(inputs, dict):
480
- # If tokenized tensors provided, try to decode to text when tokenizer exists
481
- if 'input_ids' in inputs and self.tokenizer is not None:
482
- try:
483
- input_ids = inputs['input_ids']
484
- # handle tensors or lists
485
- if hasattr(input_ids, 'tolist'):
486
- decoded = self.tokenizer.decode(input_ids[0], skip_special_tokens=True)
487
- else:
488
- decoded = self.tokenizer.decode(input_ids[0], skip_special_tokens=True)
489
- prompt_text = decoded
490
- except Exception:
491
- prompt_text = None
492
- else:
493
- prompt_text = inputs.get('text') or inputs.get('prompt')
494
- else:
495
- prompt_text = inputs
496
- if prompt_text is None:
497
- prompt_text = ""
498
- try:
499
- # If tokenizer available, tokenize prompt and generate
500
- if self.tokenizer is not None and isinstance(prompt_text, str):
501
- tokenized = self.tokenizer([prompt_text], return_tensors='pt')
502
- # move tensors to device if needed
503
- try:
504
- if self.device is not None and hasattr(tokenized['input_ids'], 'to'):
505
- tokenized = {k: v.to(self.device) for k, v in tokenized.items()}
506
- except Exception:
507
- _pass = None
508
- outputs = self.model.generate(**tokenized, **safe_kwargs)
509
- # decode if tokenizer has decode
510
- if hasattr(self.tokenizer, 'decode'):
511
- return self.tokenizer.decode(outputs[0], skip_special_tokens=True)
512
- return outputs
513
- else:
514
- # Try calling model.generate with provided inputs
515
- outputs = self.model.generate(**inputs, **safe_kwargs)
516
- # If tokenizer exists and outputs is tensor-like, decode
517
- if hasattr(self.tokenizer, 'decode'):
518
- try:
519
- return self.tokenizer.decode(outputs[0], skip_special_tokens=True)
520
- except Exception:
521
- return str(outputs)
522
- return outputs
523
- except TypeError as te:
524
- _logging.warning(f"Fallback generate signature mismatch: {te}")
525
- # As a last resort, try to call unified_model_manager if available
526
- try:
527
- from .unified_model_manager import unified_model_manager as _umm
528
- return _umm.generate_text(self.model.name if hasattr(self.model, 'name') else str(self.model), prompt_text)
529
- except Exception as e:
530
- _logging.warning(f"Fallback final generation failed: {e}")
531
- raise
532
 
533
  logging.info(f"Successfully loaded fallback transformers model: {original_model_name}")
534
  return FallbackPipeline(model, tokenizer)
 
107
 
108
  logging.info(f"Loading OpenVINO model {model_name} on device: {device}")
109
 
110
+ # Ensure torch compatibility patches are applied
111
+ from .torch_compat import ensure_torch_compatibility
112
+ ensure_torch_compatibility()
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
113
 
114
  try:
115
  # If model_name is a directory, try to load IR from there; else, download and export
 
196
  cache_dir=os.environ.get('HF_HOME', '/tmp/huggingface')
197
  )
198
 
199
+ # Use optimized FallbackPipeline from dedicated module
200
+ from .fallback_pipeline import FallbackPipeline
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
201
 
202
  logging.info(f"Successfully loaded transformers model as fallback: {original_model_name}")
203
  return FallbackPipeline(model, tokenizer)
 
252
  original_model_name,
253
  torch_dtype=torch.float32,
254
  device_map="auto" if torch.cuda.is_available() and not is_hf_spaces else None,
255
+ trust_remote_code=True,
256
+ cache_dir=os.environ.get('HF_HOME', '/tmp/huggingface')
257
  )
258
  tokenizer = AutoTokenizer.from_pretrained(
259
  original_model_name,
260
+ trust_remote_code=True,
261
+ cache_dir=os.environ.get('HF_HOME', '/tmp/huggingface')
262
  )
263
 
264
+ # Use optimized FallbackPipeline from dedicated module
265
+ from .fallback_pipeline import FallbackPipeline
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
266
 
267
  logging.info(f"Successfully loaded fallback transformers model: {original_model_name}")
268
  return FallbackPipeline(model, tokenizer)
services/ai-service/src/ai_med_extract/utils/torch_compat.py ADDED
@@ -0,0 +1,120 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ PyTorch Compatibility Utilities
3
+ Provides compatibility patches and optimizations for PyTorch operations
4
+ """
5
+ import logging
6
+ import torch
7
+ from typing import Optional, Union, Tuple
8
+
9
+ logger = logging.getLogger(__name__)
10
+
11
+ # Track if patches have been applied
12
+ _RMS_NORM_PATCHED = False
13
+
14
+
15
+ def _create_rms_norm_implementation():
16
+ """
17
+ Create an optimized RMS normalization function.
18
+
19
+ RMS normalization formula: output = input * rsqrt(mean(input^2) + eps) * weight
20
+
21
+ Optimizations:
22
+ - Uses in-place operations where possible
23
+ - Efficient tensor operations
24
+ - Handles various input shapes and signatures
25
+ """
26
+ def rms_norm(
27
+ input_tensor: torch.Tensor,
28
+ normalized_shape: Optional[Union[int, Tuple[int, ...]]] = None,
29
+ weight: Optional[torch.Tensor] = None,
30
+ eps: float = 1e-6
31
+ ) -> torch.Tensor:
32
+ """
33
+ RMS normalization implementation compatible with PyTorch's expected signature.
34
+
35
+ Args:
36
+ input_tensor: Input tensor to normalize
37
+ normalized_shape: Shape of dimensions to normalize over (None = last dim)
38
+ weight: Optional weight tensor to apply after normalization
39
+ eps: Small epsilon value for numerical stability
40
+
41
+ Returns:
42
+ Normalized tensor
43
+ """
44
+ # Determine normalization dimensions
45
+ if normalized_shape is None:
46
+ # Default: normalize over last dimension
47
+ dim = -1
48
+ keepdim = True
49
+ elif isinstance(normalized_shape, int):
50
+ # Single dimension specified
51
+ dim = normalized_shape
52
+ keepdim = True
53
+ else:
54
+ # Multiple dimensions specified (tuple/list)
55
+ if isinstance(normalized_shape, (list, tuple)):
56
+ # Normalize over trailing dimensions matching the shape
57
+ dim = tuple(range(-len(normalized_shape), 0))
58
+ else:
59
+ dim = normalized_shape
60
+ keepdim = True
61
+
62
+ # Compute RMS: sqrt(mean(x^2))
63
+ # Use pow(2) instead of **2 for better performance in some cases
64
+ variance = input_tensor.pow(2).mean(dim=dim, keepdim=keepdim)
65
+
66
+ # Normalize: x * rsqrt(variance + eps)
67
+ # Using rsqrt is more efficient than 1/sqrt
68
+ output = input_tensor * torch.rsqrt(variance + eps)
69
+
70
+ # Apply weight if provided
71
+ if weight is not None:
72
+ output = output * weight
73
+
74
+ return output
75
+
76
+ return rms_norm
77
+
78
+
79
def patch_torch_rms_norm() -> bool:
    """
    Install a replacement ``torch.rms_norm`` when PyTorch does not provide one.

    Models like Phi-3 expect torch.rms_norm to be available, but it may not
    exist in older PyTorch versions. The module-level flag makes repeated
    calls cheap no-ops.

    Returns:
        True only when this call installed the patch; False when the patch
        was handled earlier, torch already has rms_norm, or patching failed.
    """
    global _RMS_NORM_PATCHED

    # Nothing to do if an earlier call handled it or torch ships rms_norm.
    if _RMS_NORM_PATCHED or hasattr(torch, 'rms_norm'):
        _RMS_NORM_PATCHED = True
        return False

    try:
        torch.rms_norm = _create_rms_norm_implementation()
        _RMS_NORM_PATCHED = True
        logger.info("Patched torch.rms_norm for compatibility with Phi-3 and similar models")
    except Exception as e:
        logger.warning(f"Failed to patch torch.rms_norm: {e}")
        return False

    return True
108
+
109
+
110
def ensure_torch_compatibility() -> None:
    """
    Apply every PyTorch compatibility patch this module provides.

    Currently that is only the torch.rms_norm patch. Intended to be called
    at module initialization time.
    """
    patch_torch_rms_norm()
116
+
117
+
118
+ # Auto-apply patch on import
119
+ ensure_torch_compatibility()
120
+
services/ai-service/src/ai_med_extract/utils/unified_model_manager.py CHANGED
@@ -16,44 +16,14 @@ from enum import Enum
16
  from collections import OrderedDict
17
  import psutil
18
  import torch
19
- # Patch torch.rms_norm for compatibility with models like Phi-3 that expect this function
20
- if not hasattr(torch, 'rms_norm'):
21
- def rms_norm(input_tensor, normalized_shape=None, weight=None, eps=1e-6):
22
- """Simple RMS normalization implementation compatible with various call signatures"""
23
- # Handle different input formats
24
- if normalized_shape is None:
25
- # If no shape specified, normalize over last dimension
26
- dim = -1
27
- keepdim = True
28
- else:
29
- # If shape is specified, normalize over those dimensions
30
- if isinstance(normalized_shape, int):
31
- dim = normalized_shape
32
- keepdim = True
33
- else:
34
- # Multiple dimensions - normalize over all of them
35
- dim = tuple(range(-len(normalized_shape), 0))
36
- keepdim = True
37
-
38
- # Calculate RMS (root mean square)
39
- variance = input_tensor.pow(2).mean(dim=dim, keepdim=keepdim)
40
- # Normalize
41
- output = input_tensor * torch.rsqrt(variance + eps)
42
- # Apply weight if provided
43
- if weight is not None:
44
- output = output * weight
45
- return output
46
- torch.rms_norm = rms_norm
47
- _rms_norm_patched = True
48
- else:
49
- _rms_norm_patched = False
50
 
51
  from concurrent.futures import ThreadPoolExecutor, as_completed
52
 
53
  # Configure logging
54
  logger = logging.getLogger(__name__)
55
- if _rms_norm_patched:
56
- logger.info("Patched torch.rms_norm for compatibility with Phi-3 and similar models")
57
 
58
  class ModelType(Enum):
59
  """Supported model types"""
 
16
  from collections import OrderedDict
17
  import psutil
18
  import torch
19
+ # Ensure PyTorch compatibility patches are applied early
20
+ from .torch_compat import ensure_torch_compatibility
21
+ ensure_torch_compatibility()
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
22
 
23
  from concurrent.futures import ThreadPoolExecutor, as_completed
24
 
25
  # Configure logging
26
  logger = logging.getLogger(__name__)
 
 
27
 
28
  class ModelType(Enum):
29
  """Supported model types"""