"""
Test script for the Advanced RAG features.

Demonstrates the new capabilities: indexing multiple texts/images per
document and advanced RAG chat.
"""

import json
from contextlib import ExitStack
from typing import List, Optional

import requests


class AdvancedRAGTester:
    """Test client for the Advanced RAG API.

    Each ``test_*`` method sends one request to the service at ``base_url``,
    prints a human-readable report, and returns the decoded JSON object
    (or ``None`` when the request failed).
    """

    # Upper bound on texts and on images per index request; longer lists are
    # truncated with a warning.
    MAX_ITEMS = 10
    # Number of characters of each retrieved text shown in the chat report.
    PREVIEW_CHARS = 100

    def __init__(self, base_url: str = "http://localhost:8000", timeout: float = 60.0):
        """
        Args:
            base_url: Root URL of the API; a trailing slash is removed so
                that endpoint paths can be appended without producing ``//``.
            timeout: Value passed as ``timeout`` to every ``requests.post``
                call, so an unresponsive server cannot block the script
                forever.
        """
        self.base_url = base_url.rstrip("/")
        self.timeout = timeout

    @classmethod
    def _cap(cls, items, label):
        """Return ``items`` limited to ``MAX_ITEMS`` entries, warning when truncated."""
        if len(items) > cls.MAX_ITEMS:
            print(f"WARNING: Maximum {cls.MAX_ITEMS} {label} allowed. Taking first {cls.MAX_ITEMS}.")
            return items[:cls.MAX_ITEMS]
        return items

    @classmethod
    def _preview(cls, text):
        """Return at most ``PREVIEW_CHARS`` characters of ``text``, with ``...`` only if cut."""
        text = "" if text is None else str(text)
        if len(text) <= cls.PREVIEW_CHARS:
            return text
        return text[:cls.PREVIEW_CHARS] + "..."

    @staticmethod
    def _format_confidence(value):
        """Format a numeric confidence as a percentage, or ``n/a`` when it is not a number."""
        if isinstance(value, (int, float)) and not isinstance(value, bool):
            return f"{value:.2%}"
        return "n/a"

    def _post_json(self, endpoint, **request_kwargs):
        """POST to ``endpoint`` and return the decoded JSON object.

        Prints the error and returns ``None`` when the request fails, the
        server answers with an error status, or the body is not a JSON object.
        """
        try:
            response = requests.post(
                f"{self.base_url}{endpoint}", timeout=self.timeout, **request_kwargs
            )
            response.raise_for_status()
            result = response.json()
        except requests.exceptions.RequestException as e:
            print(f"\n✗ ERROR: {e}")
            # No response object is attached when the failure happened before
            # any HTTP response was received.
            if e.response is not None:
                print(f" Response: {e.response.text}")
            return None
        except ValueError as e:
            # Depending on the installed requests version, a non-JSON body can
            # surface as a ValueError that is not a RequestException.
            print(f"\n✗ ERROR: Response is not valid JSON: {e}")
            return None

        if not isinstance(result, dict):
            print(f"\n✗ ERROR: Expected a JSON object, got {type(result).__name__}")
            return None
        return result

    def test_multiple_index(self, doc_id: str, texts: List[str], image_paths: Optional[List[str]] = None):
        """
        Test indexing with multiple texts and images

        Args:
            doc_id: Document ID, sent as the ``id`` form field
            texts: List of texts (only the first 10 are sent)
            image_paths: List of image file paths (only the first 10 are
                considered; files that cannot be opened are skipped with a
                warning)

        Returns:
            The decoded JSON response, or ``None`` if the request failed.
        """
        print(f"\n{'='*60}")
        print(f"TEST: Indexing document '{doc_id}' with multiple texts/images")
        print(f"{'='*60}")

        data = {'id': doc_id}
        if texts:
            texts = self._cap(texts, "texts")
            data['texts'] = texts
            print(f"✓ Texts: {len(texts)} items")

        # The ExitStack owns every opened image handle, so all of them are
        # closed on every exit path, including an unexpected error while
        # opening a later file.
        with ExitStack() as stack:
            files = []
            if image_paths:
                for img_path in self._cap(image_paths, "images"):
                    try:
                        handle = stack.enter_context(open(img_path, 'rb'))
                    except FileNotFoundError:
                        print(f"WARNING: Image not found: {img_path}")
                    except OSError as e:
                        print(f"WARNING: Cannot read image {img_path}: {e}")
                    else:
                        files.append(('images', handle))
                print(f"✓ Images: {len(files)} files")

            result = self._post_json("/index", data=data, files=files)

        if result is None:
            return None

        print("\n✓ SUCCESS")
        # .get() keeps the report going if the server omits a field.
        print(f" - Document ID: {result.get('id')}")
        print(f" - Message: {result.get('message')}")
        return result

    def test_advanced_rag_chat(
        self,
        message: str,
        hf_token: Optional[str] = None,
        use_advanced_rag: bool = True,
        use_reranking: bool = True,
        use_compression: bool = True,
        top_k: int = 3,
        score_threshold: float = 0.5
    ):
        """
        Test advanced RAG chat

        Args:
            message: User question
            hf_token: Hugging Face token (optional; only sent when given)
            use_advanced_rag: Use advanced RAG pipeline
            use_reranking: Enable reranking
            use_compression: Enable context compression
            top_k: Number of documents to retrieve
            score_threshold: Minimum relevance score

        Returns:
            The decoded JSON response, or ``None`` if the request failed.
        """
        print(f"\n{'='*60}")
        print("TEST: Advanced RAG Chat")
        print(f"{'='*60}")
        print(f"Question: {message}")
        print(f"Advanced RAG: {use_advanced_rag}")
        print(f"Reranking: {use_reranking}")
        print(f"Compression: {use_compression}")

        payload = {
            'message': message,
            'use_rag': True,
            'use_advanced_rag': use_advanced_rag,
            'use_reranking': use_reranking,
            'use_compression': use_compression,
            'top_k': top_k,
            'score_threshold': score_threshold,
        }
        if hf_token:
            payload['hf_token'] = hf_token

        result = self._post_json("/chat", json=payload)
        if result is None:
            return None

        print("\n✓ SUCCESS")
        print("\n--- Answer ---")
        print(result.get('response'))

        # A missing or null field is reported as empty instead of raising.
        contexts = result.get('context_used') or []
        print(f"\n--- Retrieved Context ({len(contexts)} documents) ---")
        for i, ctx in enumerate(contexts, 1):
            confidence = self._format_confidence(ctx.get('confidence'))
            print(f"{i}. [{ctx.get('id')}] Confidence: {confidence}")
            metadata = ctx.get('metadata') or {}
            print(f" Text: {self._preview(metadata.get('text'))}")

        stats = result.get('rag_stats')
        if stats:
            print("\n--- RAG Pipeline Statistics ---")
            print(f" Original query: {stats.get('original_query')}")
            print(f" Expanded queries: {stats.get('expanded_queries')}")
            print(f" Initial results: {stats.get('initial_results')}")
            print(f" After reranking: {stats.get('after_rerank')}")
            print(f" After compression: {stats.get('after_compression')}")

        return result

    def compare_basic_vs_advanced_rag(self, message: str, hf_token: Optional[str] = None):
        """Compare basic RAG vs advanced RAG side by side.

        Sends the same question twice, first with ``use_advanced_rag=False``
        and then with ``use_advanced_rag=True``, and prints a summary of how
        many documents each run retrieved. Returns nothing.
        """
        print(f"\n{'='*60}")
        print("COMPARISON: Basic RAG vs Advanced RAG")
        print(f"{'='*60}")
        print(f"Question: {message}\n")

        print("\n--- BASIC RAG ---")
        basic_result = self.test_advanced_rag_chat(
            message=message,
            hf_token=hf_token,
            use_advanced_rag=False
        )

        print("\n--- ADVANCED RAG ---")
        advanced_result = self.test_advanced_rag_chat(
            message=message,
            hf_token=hf_token,
            use_advanced_rag=True
        )

        print(f"\n{'='*60}")
        print("COMPARISON SUMMARY")
        print(f"{'='*60}")

        if basic_result is None or advanced_result is None:
            # Say so explicitly rather than ending on an empty summary.
            print("Comparison unavailable: at least one of the two requests failed.")
            return

        print("Basic RAG:")
        print(f" - Retrieved docs: {len(basic_result.get('context_used') or [])}")

        print("\nAdvanced RAG:")
        print(f" - Retrieved docs: {len(advanced_result.get('context_used') or [])}")
        stats = advanced_result.get('rag_stats')
        if stats:
            print(f" - Query expansion: {len(stats.get('expanded_queries') or [])} variants")
            print(f" - Initial retrieval: {stats.get('initial_results', 0)} docs")
            print(f" - After reranking: {stats.get('after_rerank', 0)} docs")


def main():
    """Run the demo sequence against a server at the default base URL.

    Indexes two sample documents, asks one question with basic RAG and one
    with advanced RAG, then runs the side-by-side comparison.
    """
    rule = "=" * 60
    client = AdvancedRAGTester()

    for line in (rule, "ADVANCED RAG FEATURE TESTS", rule):
        print(line)

    # (heading, document id, texts) for each indexing step, in run order.
    index_steps = (
        (
            "\n\n### TEST 1: Index Multiple Texts ###",
            "event_music_festival_2025",
            [
                "Festival âm nhạc quốc tế Hà Nội 2025",
                "Thời gian: 15-17 tháng 11 năm 2025",
                "Địa điểm: Công viên Thống Nhất, Hà Nội",
                "Line-up: Sơn Tùng MTP, Đen Vâu, Hoàng Thùy Linh, Mỹ Tâm",
                "Giá vé: Early bird 500.000đ, VIP 2.000.000đ",
                "Dự kiến 50.000 khán giả tham dự",
                "3 sân khấu chính, 5 food court, khu vực cắm trại",
            ],
        ),
        (
            "\n\n### TEST 2: Index Another Document ###",
            "safety_guidelines",
            [
                "Vũ khí và đồ vật nguy hiểm bị cấm mang vào sự kiện",
                "Dao, kiếm, súng và các loại vũ khí nguy hiểm nghiêm cấm",
                "An ninh sẽ kiểm tra tất cả túi xách và đồ mang theo",
                "Vi phạm sẽ bị tịch thu và có thể bị trục xuất khỏi sự kiện",
            ],
        ),
    )
    for heading, document_id, document_texts in index_steps:
        print(heading)
        client.test_multiple_index(doc_id=document_id, texts=document_texts)

    # (heading, keyword arguments) for each chat step, in run order.
    chat_steps = (
        (
            "\n\n### TEST 3: Basic RAG Chat (No LLM) ###",
            {
                "message": "Festival Hà Nội diễn ra khi nào?",
                "use_advanced_rag": False,
            },
        ),
        (
            "\n\n### TEST 4: Advanced RAG Chat (No LLM) ###",
            {
                "message": "Festival Hà Nội diễn ra khi nào và có những nghệ sĩ nào?",
                "use_advanced_rag": True,
                "use_reranking": True,
                "use_compression": True,
            },
        ),
    )
    for heading, chat_options in chat_steps:
        print(heading)
        client.test_advanced_rag_chat(**chat_options)

    print("\n\n### TEST 5: Comparison Test ###")
    client.compare_basic_vs_advanced_rag(
        message="Dao có được mang vào sự kiện không?"
    )

    print("\n\n" + rule)
    print("ALL TESTS COMPLETED")
    print(rule)
    print("\nNOTE: To test with actual LLM responses, add your Hugging Face token:")
    print(" tester.test_advanced_rag_chat(message='...', hf_token='hf_xxxxx')")


# Run the demo only when executed as a script, not when imported.
if __name__ == "__main__":
    main()
|