hoshikrana committed on
Commit
440d613
·
1 Parent(s): 2bdb663

feat: resilience, circuit breakers, and LRU cache system

Browse files
backend/orchestration/resilience.py CHANGED
@@ -0,0 +1,182 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import asyncio
2
+ import random
3
+ import logging
4
+ from enum import Enum
5
+ from functools import wraps
6
+ from typing import Callable, Tuple, Any
7
+ from dataclasses import dataclass
8
+ from datetime import datetime, UTC
9
+
10
+ from backend.core.exceptions import (
11
+ InferenceError, CircuitOpenError, ValidationError,
12
+ SecurityError, ModelNotLoadedError
13
+ )
14
+
15
+ logger = logging.getLogger(__name__)
16
+
17
+ # ━━━ PART A: RETRY DECORATOR ━━━
18
+
19
@dataclass
class RetryConfig:
    """Policy read by the ``retry`` decorator on every wrapped call."""

    # Total number of calls made before giving up (the first call included).
    max_attempts: int = 3
    # Backoff before the first retry, in seconds.
    initial_delay_seconds: float = 1.0
    # Factor the backoff is multiplied by after each failed attempt.
    exponential_base: float = 2.0
    # Ceiling applied to the backoff before jitter is added.
    max_delay_seconds: float = 30.0
    # Jitter amplitude as a fraction of the backoff (0.2 means +/-20%).
    jitter_factor: float = 0.2
    # Exception types that trigger another attempt.
    retry_on: Tuple[type[Exception], ...] = (
        InferenceError,
        TimeoutError,
        asyncio.TimeoutError,
    )
    # Exception types that are re-raised at once; checked before ``retry_on``.
    no_retry_on: Tuple[type[Exception], ...] = (
        ValidationError,
        SecurityError,
        ModelNotLoadedError,
    )
28
+
29
def retry(config: "RetryConfig | None" = None):
    """Build a decorator that retries an async callable with exponential backoff.

    Args:
        config: Retry policy. A default ``RetryConfig()`` is used when omitted.

    Returns:
        A decorator. The wrapped coroutine function returns the first
        successful result of the underlying callable.

    Raises:
        ValueError: At decoration time, when ``config.max_attempts`` is below
            1; such a policy could never invoke the wrapped callable.

    The wrapper re-raises at once any exception matching
    ``config.no_retry_on``, re-raises the last ``config.retry_on`` exception
    when the attempts are used up, and lets every other exception propagate
    untouched.
    """
    if config is None:
        config = RetryConfig()
    if config.max_attempts < 1:
        raise ValueError(
            f"max_attempts must be at least 1, got {config.max_attempts}"
        )

    def decorator(func: Callable):
        # functools.partial objects and callable instances carry no __name__.
        func_name = getattr(func, "__name__", None) or repr(func)

        @wraps(func)
        async def wrapper(*args, **kwargs):
            attempt = 0
            # The loop only ends by returning a result or raising.
            while True:
                attempt += 1
                try:
                    return await func(*args, **kwargs)
                except config.no_retry_on:
                    # Known not to succeed on a retry: fail immediately.
                    raise
                except config.retry_on as exc:
                    if attempt >= config.max_attempts:
                        logger.error(
                            "%s failed after %d attempts: %s",
                            func_name, attempt, exc,
                        )
                        raise

                    backoff = min(
                        config.initial_delay_seconds
                        * (config.exponential_base ** (attempt - 1)),
                        config.max_delay_seconds,
                    )
                    # Symmetric jitter in [-jitter_factor, +jitter_factor] of
                    # the capped backoff, so the result may exceed the cap.
                    jitter = backoff * config.jitter_factor * (2 * random.random() - 1)
                    # Never sleep for less than 100 ms between attempts.
                    actual_delay = max(0.1, backoff + jitter)

                    logger.warning(
                        "%s attempt %d/%d failed. Retrying in %.1fs. Error: %s",
                        func_name, attempt, config.max_attempts, actual_delay, exc,
                    )
                    await asyncio.sleep(actual_delay)

        return wrapper

    return decorator
71
+
72
+ # ━━━ PART B: CIRCUIT BREAKER ━━━
73
+
74
+ class CircuitState(Enum):
75
+ CLOSED = "CLOSED" # Normal operation
76
+ OPEN = "OPEN" # Failing β€” reject all calls
77
+ HALF_OPEN = "HALF_OPEN" # Testing β€” allow one call
78
+
79
+ @dataclass
80
+ class CircuitBreakerConfig:
81
+ name: str
82
+ failure_threshold: int = 5
83
+ success_threshold: int = 2
84
+ timeout_seconds: float = 60.0
85
+ call_timeout_seconds: float = 30.0
86
+
87
+ class CircuitBreaker:
88
+ def __init__(self, config: CircuitBreakerConfig):
89
+ self.config = config
90
+ self._state = CircuitState.CLOSED
91
+ self._failure_count = 0
92
+ self._success_count = 0
93
+ self._last_failure_time: datetime | None = None
94
+ self._lock = asyncio.Lock()
95
+
96
+ @property
97
+ def state(self) -> CircuitState:
98
+ if self._state == CircuitState.OPEN and self._last_failure_time:
99
+ elapsed = (datetime.now(UTC) - self._last_failure_time).total_seconds()
100
+ if elapsed >= self.config.timeout_seconds:
101
+ return CircuitState.HALF_OPEN
102
+ return self._state
103
+
104
+ async def call(self, func: Callable, *args, **kwargs) -> Any:
105
+ async with self._lock:
106
+ current_state = self.state
107
+
108
+ if current_state == CircuitState.OPEN:
109
+ next_attempt_in = (
110
+ self.config.timeout_seconds -
111
+ (datetime.now(UTC) - self._last_failure_time).total_seconds()
112
+ )
113
+ raise CircuitOpenError(
114
+ f"Circuit '{self.config.name}' is OPEN. Retry in {max(0, next_attempt_in):.0f}s"
115
+ )
116
+
117
+ try:
118
+ # Enforce timeout at the circuit breaker level
119
+ result = await asyncio.wait_for(
120
+ func(*args, **kwargs),
121
+ timeout=self.config.call_timeout_seconds
122
+ )
123
+ await self._on_success()
124
+ return result
125
+
126
+ except CircuitOpenError:
127
+ raise
128
+ except Exception as e:
129
+ await self._on_failure()
130
+ raise
131
+
132
+ async def _on_success(self):
133
+ async with self._lock:
134
+ current_state = self.state
135
+ if current_state == CircuitState.HALF_OPEN:
136
+ self._success_count += 1
137
+ if self._success_count >= self.config.success_threshold:
138
+ self._state = CircuitState.CLOSED
139
+ self._failure_count = 0
140
+ self._success_count = 0
141
+ logger.info(f"Circuit '{self.config.name}' CLOSED (recovered)")
142
+ elif current_state == CircuitState.CLOSED:
143
+ self._failure_count = 0
144
+
145
+ async def _on_failure(self):
146
+ async with self._lock:
147
+ self._failure_count += 1
148
+ self._success_count = 0
149
+ self._last_failure_time = datetime.now(UTC)
150
+
151
+ current_state = self.state
152
+ if current_state == CircuitState.HALF_OPEN:
153
+ self._state = CircuitState.OPEN
154
+ logger.warning(f"Circuit '{self.config.name}' re-OPENED (test failed)")
155
+ elif current_state == CircuitState.CLOSED and self._failure_count >= self.config.failure_threshold:
156
+ self._state = CircuitState.OPEN
157
+ logger.error(f"Circuit '{self.config.name}' OPENED after {self._failure_count} failures")
158
+
159
+ def force_close(self):
160
+ self._state = CircuitState.CLOSED
161
+ self._failure_count = 0
162
+ self._success_count = 0
163
+ logger.info(f"Circuit '{self.config.name}' force-closed by admin")
164
+
165
+ def get_status(self) -> dict:
166
+ return {
167
+ "name": self.config.name,
168
+ "state": self.state.value,
169
+ "failure_count": self._failure_count,
170
+ "failure_threshold": self.config.failure_threshold,
171
+ "last_failure": self._last_failure_time.isoformat() if self._last_failure_time else None,
172
+ "next_attempt_seconds": max(0, self.config.timeout_seconds - (
173
+ (datetime.now(UTC) - self._last_failure_time).total_seconds()
174
+ )) if self._state == CircuitState.OPEN and self._last_failure_time else None
175
+ }
176
+
177
# One breaker per downstream component, keyed by the breaker's own name.
CIRCUIT_BREAKERS = {
    breaker_config.name: CircuitBreaker(breaker_config)
    for breaker_config in (
        CircuitBreakerConfig("vision", failure_threshold=3, timeout_seconds=60, call_timeout_seconds=45),
        CircuitBreakerConfig("nlp", failure_threshold=5, timeout_seconds=30, call_timeout_seconds=20),
        CircuitBreakerConfig("fusion", failure_threshold=3, timeout_seconds=45, call_timeout_seconds=30),
        CircuitBreakerConfig("rag", failure_threshold=5, timeout_seconds=60, call_timeout_seconds=30),
    )
}
backend/tests/unit/test_resilience.py ADDED
@@ -0,0 +1,73 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import pytest
2
+ import asyncio
3
+ from unittest.mock import AsyncMock
4
+ from backend.orchestration.resilience import (
5
+ retry, RetryConfig, CircuitBreaker, CircuitBreakerConfig, CircuitState
6
+ )
7
+ from backend.core.exceptions import InferenceError, ValidationError, CircuitOpenError
8
+
9
@pytest.mark.asyncio
async def test_retry_succeeds_on_third_attempt():
    """Two retryable failures followed by a success yield the success value."""
    flaky = AsyncMock(side_effect=[InferenceError(), InferenceError(), "success"])
    policy = RetryConfig(max_attempts=3, initial_delay_seconds=0.01)

    outcome = await retry(policy)(flaky)()

    assert outcome == "success"
    assert flaky.call_count == 3
17
+
18
@pytest.mark.asyncio
async def test_retry_raises_after_max_attempts():
    """A callable that always fails is tried ``max_attempts`` times, then the error surfaces."""
    always_failing = AsyncMock(side_effect=InferenceError())
    policy = RetryConfig(max_attempts=2, initial_delay_seconds=0.01)
    wrapped = retry(policy)(always_failing)

    with pytest.raises(InferenceError):
        await wrapped()

    assert always_failing.call_count == 2
26
+
27
@pytest.mark.asyncio
async def test_no_retry_on_excluded_exceptions():
    """An exception listed in ``no_retry_on`` surfaces after a single call."""
    rejecting = AsyncMock(side_effect=ValidationError())
    wrapped = retry(RetryConfig())(rejecting)

    with pytest.raises(ValidationError):
        await wrapped()

    # Failed immediately, no retry.
    assert rejecting.call_count == 1
35
+
36
@pytest.mark.asyncio
async def test_circuit_opens_after_threshold():
    """Reaching ``failure_threshold`` opens the circuit and later calls are rejected."""
    breaker = CircuitBreaker(
        CircuitBreakerConfig("test", failure_threshold=2, timeout_seconds=60)
    )
    failing = AsyncMock(side_effect=InferenceError())

    # Two failures, matching the configured threshold.
    for _ in range(2):
        with pytest.raises(InferenceError):
            await breaker.call(failing)

    assert breaker.state == CircuitState.OPEN
    with pytest.raises(CircuitOpenError):
        await breaker.call(failing)
47
+
48
@pytest.mark.asyncio
async def test_circuit_allows_one_call_in_half_open():
    """After the open timeout the breaker reports HALF_OPEN and lets a call through."""
    breaker = CircuitBreaker(
        CircuitBreakerConfig("test", failure_threshold=1, timeout_seconds=0.1)
    )
    failing = AsyncMock(side_effect=InferenceError())
    succeeding = AsyncMock(return_value="success")

    # A single failure opens the circuit (threshold is 1).
    with pytest.raises(InferenceError):
        await breaker.call(failing)
    assert breaker.state == CircuitState.OPEN

    # Wait past the 0.1 s open timeout.
    await asyncio.sleep(0.15)
    assert breaker.state == CircuitState.HALF_OPEN

    outcome = await breaker.call(succeeding)
    assert outcome == "success"
62
+
63
@pytest.mark.asyncio
async def test_circuit_reopens_on_half_open_failure():
    """A failing call made while HALF_OPEN leaves the circuit OPEN again."""
    breaker = CircuitBreaker(
        CircuitBreakerConfig("test", failure_threshold=1, timeout_seconds=0.1)
    )
    failing = AsyncMock(side_effect=InferenceError())

    # Open the circuit, then wait past the 0.1 s open timeout.
    with pytest.raises(InferenceError):
        await breaker.call(failing)
    await asyncio.sleep(0.15)
    assert breaker.state == CircuitState.HALF_OPEN

    with pytest.raises(InferenceError):
        await breaker.call(failing)
    assert breaker.state == CircuitState.OPEN
backend/utils/cache.py CHANGED
@@ -0,0 +1,117 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import asyncio
2
+ import time
3
+ import hashlib
4
+ import json
5
+ import logging
6
+ from collections import OrderedDict
7
+ from typing import Generic, TypeVar, Callable
8
+ from dataclasses import dataclass
9
+ from functools import wraps
10
+
11
+ logger = logging.getLogger(__name__)
12
+
13
+ K = TypeVar("K")
14
+ V = TypeVar("V")
15
+
16
+ @dataclass
17
+ class CacheEntry(Generic[V]):
18
+ value: V
19
+ created_at: float
20
+ hit_count: int = 0
21
+
22
+ class LRUCache(Generic[K, V]):
23
+ """In-memory thread/async-safe Least Recently Used cache."""
24
+ def __init__(self, max_size: int, ttl_seconds: int, name: str = "cache"):
25
+ self._cache: OrderedDict[K, CacheEntry[V]] = OrderedDict()
26
+ self._max_size = max_size
27
+ self._ttl = ttl_seconds
28
+ self._name = name
29
+ self._lock = asyncio.Lock()
30
+ self._hits = 0
31
+ self._misses = 0
32
+
33
+ async def get(self, key: K) -> V | None:
34
+ async with self._lock:
35
+ if key not in self._cache:
36
+ self._misses += 1
37
+ return None
38
+
39
+ entry = self._cache[key]
40
+
41
+ if time.monotonic() - entry.created_at > self._ttl:
42
+ del self._cache[key]
43
+ self._misses += 1
44
+ return None
45
+
46
+ self._cache.move_to_end(key)
47
+ entry.hit_count += 1
48
+ self._hits += 1
49
+ return entry.value
50
+
51
+ async def set(self, key: K, value: V):
52
+ async with self._lock:
53
+ if key in self._cache:
54
+ self._cache.move_to_end(key)
55
+ self._cache[key] = CacheEntry(value=value, created_at=time.monotonic())
56
+
57
+ while len(self._cache) > self._max_size:
58
+ evicted_key, _ = self._cache.popitem(last=False)
59
+ logger.debug(f"Cache '{self._name}': evicted key {evicted_key}")
60
+
61
+ async def invalidate(self, key: K):
62
+ async with self._lock:
63
+ self._cache.pop(key, None)
64
+
65
+ async def clear(self):
66
+ async with self._lock:
67
+ self._cache.clear()
68
+ self._hits = 0
69
+ self._misses = 0
70
+
71
+ def stats(self) -> dict:
72
+ total = self._hits + self._misses
73
+ return {
74
+ "name": self._name,
75
+ "size": len(self._cache),
76
+ "max_size": self._max_size,
77
+ "ttl_seconds": self._ttl,
78
+ "total_requests": total,
79
+ "hits": self._hits,
80
+ "misses": self._misses,
81
+ "hit_rate": round(self._hits / total, 3) if total > 0 else 0.0
82
+ }
83
+
84
+ # ━━━ PART B: APPLICATION CACHES ━━━
85
+
86
+ def _make_analysis_key(image_bytes: bytes, symptoms_text: str) -> str:
87
+ content = image_bytes + symptoms_text.encode()
88
+ return hashlib.md5(content).hexdigest()
89
+
90
+ def _make_rag_key(query: str) -> str:
91
+ return hashlib.md5(query.lower().strip().encode()).hexdigest()
92
+
93
+ # Singletons initialized here for app-wide use
94
# Up to 50 analysis results, each kept for one hour.
analysis_cache: LRUCache[str, dict] = LRUCache(
    max_size=50,
    ttl_seconds=3600,
    name="analysis",
)
# Up to 200 RAG results, each kept for 30 minutes.
rag_cache: LRUCache[str, list] = LRUCache(
    max_size=200,
    ttl_seconds=1800,
    name="rag",
)
# Up to 500 user records, each kept for five minutes.
user_cache: LRUCache[str, dict] = LRUCache(
    max_size=500,
    ttl_seconds=300,
    name="user",
)
97
+
98
+ # ━━━ PART C: CACHE DECORATOR ━━━
99
+
100
def cached(cache: "LRUCache", key_fn: Callable):
    """Decorate an async function so its results are served from ``cache``.

    Args:
        cache: Object exposing awaitable ``get(key)`` and ``set(key, value)``.
        key_fn: Called with the wrapped function's arguments to build the key.

    Returns:
        A decorator. The wrapped coroutine function returns the cached value
        when ``cache.get`` yields one, and otherwise calls the function and
        stores its result.

    ``cache.get`` signals a miss with ``None``, so a ``None`` result can never
    be served from the cache. Such results are returned but not stored, which
    keeps them from evicting entries that can be reused.
    """
    def decorator(func):
        # functools.partial objects and callable instances carry no __name__.
        func_name = getattr(func, "__name__", None) or repr(func)

        @wraps(func)
        async def wrapper(*args, **kwargs):
            key = key_fn(*args, **kwargs)
            key_preview = str(key)[:8]

            cached_value = await cache.get(key)
            if cached_value is not None:
                logger.debug("Cache HIT for %s: key=%s...", func_name, key_preview)
                return cached_value

            logger.debug("Cache MISS for %s: key=%s...", func_name, key_preview)
            result = await func(*args, **kwargs)
            if result is not None:
                await cache.set(key, result)
            return result

        return wrapper

    return decorator