#!/usr/bin/env python3
"""
Test script to verify the device parameter fix for accelerate models.
This script tests the get_summarizer_pipeline function with different scenarios.
"""
import os
import sys
import logging
from pathlib import Path
# Locate the ai_med_extract sources relative to this script so the package can
# be imported without being installed.
current_dir = Path(__file__).parent
ai_med_extract_path = current_dir.joinpath("services", "ai-service", "src")
# Prepend services/ai-service/src and its parent (services/ai-service) to the
# import path; the parent ends up first, ahead of src/.
sys.path[0:0] = [str(ai_med_extract_path.parent), str(ai_med_extract_path)]

# Environment defaults: values already present in the environment are kept.
# NOTE(review): PYTHONUNBUFFERED is normally read at interpreter start-up, so
# setting it here presumably only affects child processes -- confirm intent.
os.environ.setdefault("HF_SPACES", "true")
os.environ.setdefault("PYTHONUNBUFFERED", "1")

# Root logging at INFO with a timestamped format (no-op if the root logger
# already has handlers).
logging.basicConfig(
    format="%(asctime)s - %(levelname)s - %(message)s",
    level=logging.INFO,
)
def test_device_parameter_fix():
    """Check that get_summarizer_pipeline no longer hits the accelerate device conflict.

    Builds a summarization pipeline for a small model and runs one short
    summarization through it, printing a status line for each step.

    Returns:
        True if the pipeline was built and produced a summary, or if it failed
        with an error whose message does not mention both "accelerate" and
        "device" (treated as unrelated to this fix).
        False if the project module could not be imported, if any other error
        occurred while importing it, or if the accelerate/device error is
        still raised.
    """
    print("Testing device parameter fix...")
    try:
        # Import the fixed function
        from ai_med_extract.api.routes_fastapi import get_summarizer_pipeline
    except ImportError as e:
        print(f"✗ Import error: {e}")
        return False
    except Exception as e:
        print(f"✗ Unexpected error: {e}")
        return False

    # Test with a small model that's likely to work
    test_model_name = "sshleifer/distilbart-cnn-6-6"
    test_model_type = "summarization"
    print(f"Testing with model: {test_model_name}")
    try:
        # This should not raise the accelerate device error
        summarizer = get_summarizer_pipeline(test_model_type, test_model_name)
        print("✓ Pipeline created successfully without device parameter conflicts")
        # Test a simple summarization
        test_text = (
            "This is a test document that needs to be summarized. "
            "It contains multiple sentences to test the summarization functionality."
        )
        result = summarizer(test_text, max_length=50, min_length=10, do_sample=False)
        print("✓ Pipeline works correctly for summarization")
        print(f" Result: {result[0]['summary_text']}")
    except Exception as e:
        message = str(e).lower()
        if "accelerate" in message and "device" in message:
            print(f"✗ Device parameter conflict still exists: {e}")
            return False
        # Deliberately not counted as a failure: only the accelerate/device
        # conflict is what this check targets.
        # NOTE(review): this also lets unrelated failures (e.g. a model that
        # cannot be downloaded) pass -- confirm that is intended.
        print(f"⚠ Other error (not device-related): {e}")
        return True
    return True
def test_fallback_behavior():
    """Check that a pipeline can be built for a model prone to accelerate issues.

    Clears any cache held by get_summarizer_pipeline so the pipeline is built
    fresh, then builds a summarization pipeline for google/flan-t5-large.

    Returns:
        True if the pipeline was created; False if the project module could not
        be imported, if clearing the cache failed, or if pipeline creation
        raised.
    """
    print("\nTesting fallback behavior...")
    try:
        from ai_med_extract.api.routes_fastapi import get_summarizer_pipeline

        # Clear any existing cache so the pipeline is built from scratch.
        # NOTE(review): `.cache` presumably is a dict-like cache attached to
        # the function -- confirm. A functools.lru_cache wrapper exposes
        # `cache_clear()` instead, so handle that form too.
        cache = getattr(get_summarizer_pipeline, "cache", None)
        if cache is not None:
            cache.clear()
        cache_clear = getattr(get_summarizer_pipeline, "cache_clear", None)
        if callable(cache_clear):
            cache_clear()
    except ImportError as e:
        print(f"✗ Import error: {e}")
        return False
    except Exception as e:
        print(f"✗ Fallback test error: {e}")
        return False

    # Test with a model that might trigger accelerate issues
    test_model_name = "google/flan-t5-large"  # This model often uses accelerate
    test_model_type = "summarization"
    print(f"Testing fallback with model: {test_model_name}")
    try:
        get_summarizer_pipeline(test_model_type, test_model_name)
    except Exception as e:
        print(f"✗ Fallback test failed: {e}")
        return False
    print("✓ Fallback pipeline created successfully")
    return True
def main(*, tests=None):
    """Run the device-fix checks and print a pass/fail summary.

    Args:
        tests: Optional iterable of zero-argument callables; a truthy return
            value counts as a pass. Defaults to this script's two checks,
            test_device_parameter_fix and test_fallback_behavior.

    Returns:
        True if every check passed, False otherwise.
    """
    if tests is None:
        tests = [test_device_parameter_fix, test_fallback_behavior]
    else:
        # Materialize so len() works for any iterable input.
        tests = list(tests)

    print("Running device parameter fix tests...")
    print("=" * 50)
    passed = 0
    for test in tests:
        if test():
            passed += 1
        print()
    total = len(tests)

    print("=" * 50)
    print(f"Tests passed: {passed}/{total}")
    all_passed = passed == total
    if all_passed:
        print("✓ All tests passed! The device parameter fix should work.")
    else:
        print("✗ Some tests failed. The fix may need additional work.")
    return all_passed
if __name__ == "__main__":
    # Map the overall result to the process exit status: 0 when every check
    # passed, 1 otherwise.
    sys.exit(0 if main() else 1)