# Source: HNTAI / test_device_fix.py (Hugging Face repository file view)
# Author: sachinchandrankallar — commit 202f345 ("updates"), 4.36 kB
#!/usr/bin/env python3
"""
Test script to verify the device parameter fix for accelerate models.
This script tests the get_summarizer_pipeline function with different scenarios.
"""
import os
import sys
import logging
from pathlib import Path
# Make the in-repo ai_med_extract package importable without installing it.
# Both directories are prepended in one step; the resulting search order is
# services/ai-service first, then services/ai-service/src, then the rest.
current_dir = Path(__file__).parent
ai_med_extract_path = current_dir / "services" / "ai-service" / "src"
sys.path[:0] = [str(ai_med_extract_path.parent), str(ai_med_extract_path)]

# Runtime defaults for the service; values already set in the environment win.
for _name, _value in (("HF_SPACES", "true"), ("PYTHONUNBUFFERED", "1")):
    os.environ.setdefault(_name, _value)
del _name, _value

# Root logger: INFO and above, with a timestamp and level on every record.
logging.basicConfig(level=logging.INFO, format="%(asctime)s - %(levelname)s - %(message)s")
def test_device_parameter_fix():
    """Check that ``get_summarizer_pipeline`` builds and runs a summarizer.

    Builds a summarization pipeline for a small model and runs it on a short
    text. The check returns False when:

    * ``ai_med_extract.api.routes_fastapi`` cannot be imported,
    * building or running the pipeline raises an error whose message mentions
      both "accelerate" and "device" (the conflict this script targets), or
    * any other exception escapes the handlers below.

    Any other error raised while building or running the pipeline is printed
    as a warning and still counts as a pass, because it is not the
    device-parameter conflict under test.

    Returns:
        bool: True if the check passed (as defined above), False otherwise.
    """
    print("Testing device parameter fix...")
    try:
        # Imported lazily so a missing/broken module is reported as a failed
        # check instead of aborting the whole script.
        from ai_med_extract.api.routes_fastapi import get_summarizer_pipeline

        # A small model keeps download and inference time low.
        test_model_name = "sshleifer/distilbart-cnn-6-6"
        test_model_type = "summarization"
        print(f"Testing with model: {test_model_name}")
        try:
            # The device-parameter fix is expected to keep this call from
            # raising the accelerate/device error.
            pipeline = get_summarizer_pipeline(test_model_type, test_model_name)
            print("✓ Pipeline created successfully without device parameter conflicts")

            # Smoke-test an actual summarization call.
            test_text = (
                "This is a test document that needs to be summarized. "
                "It contains multiple sentences to test the summarization functionality."
            )
            result = pipeline(test_text, max_length=50, min_length=10, do_sample=False)
            print("✓ Pipeline works correctly for summarization")
            print(f" Result: {result[0]['summary_text']}")
            return True
        except Exception as e:
            message = str(e).lower()
            if "accelerate" in message and "device" in message:
                print(f"✗ Device parameter conflict still exists: {e}")
                return False
            # Deliberately lenient: unrelated failures are not what this
            # script checks, so they are reported but treated as a pass.
            print(f"⚠ Other error (not device-related): {e}")
            return True
    except ImportError as e:
        print(f"✗ Import error: {e}")
        return False
    except Exception as e:
        print(f"✗ Unexpected error: {e}")
        return False
def test_fallback_behavior():
    """Check that a pipeline can be built for a larger, accelerate-prone model.

    Clears any summarizer cache exposed by ``get_summarizer_pipeline``, then
    requests a summarization pipeline for ``google/flan-t5-large``. Only
    pipeline construction is checked; no inference is run. Whether a fallback
    path inside ``get_summarizer_pipeline`` is actually taken is not observable
    from here. This function only checks that construction succeeds.

    Returns:
        bool: True if the pipeline was created, False if the import or the
        pipeline construction failed.
    """
    print("\nTesting fallback behavior...")
    try:
        from ai_med_extract.api.routes_fastapi import get_summarizer_pipeline

        # Drop cached pipelines so results from earlier checks cannot leak in.
        # A custom ``.cache`` mapping is handled first. ``cache_clear`` covers
        # a functools.lru_cache-decorated implementation.
        if hasattr(get_summarizer_pipeline, "cache"):
            get_summarizer_pipeline.cache.clear()
        elif hasattr(get_summarizer_pipeline, "cache_clear"):
            get_summarizer_pipeline.cache_clear()

        # Per the original author, this model often loads via accelerate.
        test_model_name = "google/flan-t5-large"
        test_model_type = "summarization"
        print(f"Testing fallback with model: {test_model_name}")
        try:
            # Only successful construction matters; the pipeline is not used.
            get_summarizer_pipeline(test_model_type, test_model_name)
            print("✓ Fallback pipeline created successfully")
            return True
        except Exception as e:
            print(f"⚠ Fallback test failed: {e}")
            return False
    except Exception as e:
        print(f"✗ Fallback test error: {e}")
        return False
def main(tests=None):
    """Run the device-parameter checks and print a pass/fail summary.

    Args:
        tests: Optional sized iterable of zero-argument callables. Each returns
            a truthy value on success. Defaults to this script's two checks,
            ``test_device_parameter_fix`` and ``test_fallback_behavior``.

    Returns:
        bool: True if every check passed, False otherwise.
    """
    print("Running device parameter fix tests...")
    print("=" * 50)
    if tests is None:
        tests = [test_device_parameter_fix, test_fallback_behavior]
    tests = list(tests)
    passed = 0
    total = len(tests)
    for test in tests:
        if test():
            passed += 1
        print()  # blank line separating each check's output
    print("=" * 50)
    print(f"Tests passed: {passed}/{total}")
    all_passed = passed == total
    if all_passed:
        print("✓ All tests passed! The device parameter fix should work.")
    else:
        print("✗ Some tests failed. The fix may need additional work.")
    return all_passed
if __name__ == "__main__":
    # Process exit status: 0 when main() reports success, 1 otherwise.
    exit_code = 0 if main() else 1
    sys.exit(exit_code)