Spaces:
Paused
Paused
| #!/usr/bin/env python3 | |
| """ | |
| Test script to verify the device parameter fix for accelerate models. | |
| This script tests the get_summarizer_pipeline function with different scenarios. | |
| """ | |
| import os | |
| import sys | |
| import logging | |
| from pathlib import Path | |
| # Add the ai_med_extract module to Python path | |
| current_dir = Path(__file__).parent | |
| ai_med_extract_path = current_dir / "services" / "ai-service" / "src" | |
| sys.path.insert(0, str(ai_med_extract_path)) | |
| # Also add the parent directory for proper module resolution | |
| sys.path.insert(0, str(ai_med_extract_path.parent)) | |
| # Set environment variables | |
| os.environ.setdefault('HF_SPACES', 'true') | |
| os.environ.setdefault('PYTHONUNBUFFERED', '1') | |
| # Configure logging | |
| logging.basicConfig( | |
| level=logging.INFO, | |
| format="%(asctime)s - %(levelname)s - %(message)s" | |
| ) | |
| def test_device_parameter_fix(): | |
| """Test that the device parameter fix works correctly.""" | |
| print("Testing device parameter fix...") | |
| try: | |
| # Import the fixed function | |
| from ai_med_extract.api.routes_fastapi import get_summarizer_pipeline | |
| # Test with a small model that's likely to work | |
| test_model_name = "sshleifer/distilbart-cnn-6-6" | |
| test_model_type = "summarization" | |
| print(f"Testing with model: {test_model_name}") | |
| # This should not raise the accelerate device error | |
| try: | |
| pipeline = get_summarizer_pipeline(test_model_type, test_model_name) | |
| print("β Pipeline created successfully without device parameter conflicts") | |
| # Test a simple summarization | |
| test_text = "This is a test document that needs to be summarized. It contains multiple sentences to test the summarization functionality." | |
| result = pipeline(test_text, max_length=50, min_length=10, do_sample=False) | |
| print("β Pipeline works correctly for summarization") | |
| print(f" Result: {result[0]['summary_text']}") | |
| return True | |
| except Exception as e: | |
| if "accelerate" in str(e).lower() and "device" in str(e).lower(): | |
| print(f"β Device parameter conflict still exists: {e}") | |
| return False | |
| else: | |
| print(f"β Other error (not device-related): {e}") | |
| return True # This might be a different issue | |
| except ImportError as e: | |
| print(f"β Import error: {e}") | |
| return False | |
| except Exception as e: | |
| print(f"β Unexpected error: {e}") | |
| return False | |
| def test_fallback_behavior(): | |
| """Test that the fallback behavior works when device parameters fail.""" | |
| print("\nTesting fallback behavior...") | |
| try: | |
| from ai_med_extract.api.routes_fastapi import get_summarizer_pipeline | |
| # Clear any existing cache | |
| if hasattr(get_summarizer_pipeline, "cache"): | |
| get_summarizer_pipeline.cache.clear() | |
| # Test with a model that might trigger accelerate issues | |
| test_model_name = "google/flan-t5-large" # This model often uses accelerate | |
| test_model_type = "summarization" | |
| print(f"Testing fallback with model: {test_model_name}") | |
| try: | |
| pipeline = get_summarizer_pipeline(test_model_type, test_model_name) | |
| print("β Fallback pipeline created successfully") | |
| return True | |
| except Exception as e: | |
| print(f"β Fallback test failed: {e}") | |
| return False | |
| except Exception as e: | |
| print(f"β Fallback test error: {e}") | |
| return False | |
| def main(): | |
| """Run all tests.""" | |
| print("Running device parameter fix tests...") | |
| print("=" * 50) | |
| tests = [ | |
| test_device_parameter_fix, | |
| test_fallback_behavior | |
| ] | |
| passed = 0 | |
| total = len(tests) | |
| for test in tests: | |
| if test(): | |
| passed += 1 | |
| print() | |
| print("=" * 50) | |
| print(f"Tests passed: {passed}/{total}") | |
| if passed == total: | |
| print("β All tests passed! The device parameter fix should work.") | |
| else: | |
| print("β Some tests failed. The fix may need additional work.") | |
| return passed == total | |
| if __name__ == "__main__": | |
| success = main() | |
| sys.exit(0 if success else 1) | |