Spaces:
Sleeping
Sleeping
Sai Kumar Taraka commited on
Commit ·
6e4b4a4
1
Parent(s): 2cef8a9
Production-level model enhancements: fix strategy selection bug, add retry/health/request_id/validation, rewrite tests with pytest assertions, harden cache
Browse files- src/models/enhanced_ml_model_v2.py +142 -19
- tests/test_advanced_ml_v2.py +278 -423
src/models/enhanced_ml_model_v2.py
CHANGED
|
@@ -33,6 +33,36 @@ from src.models.template_model import TemplateModel
|
|
| 33 |
from src.models.coverage_predictor import CoveragePredictor, SpecFeatures
|
| 34 |
from src.config import PipelineConfig, DesignSpec
|
| 35 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 36 |
try:
|
| 37 |
from src.features.extractors import RichSpecFeatureExtractor
|
| 38 |
from src.models.similarity_index import SimilarityIndex, SearchResult
|
|
@@ -158,35 +188,65 @@ class MetricsTracker:
|
|
| 158 |
|
| 159 |
|
| 160 |
class GenerationCache:
|
| 161 |
-
"""Spec-driven cache with content-addressable keys."""
|
| 162 |
|
| 163 |
-
def __init__(self, ttl_seconds: int = 3600):
|
| 164 |
self._cache: Dict[str, Tuple[float, Dict[str, str]]] = {}
|
| 165 |
self._ttl = ttl_seconds
|
|
|
|
|
|
|
| 166 |
|
| 167 |
def _make_key(self, spec_dict: Dict[str, Any], protocol: str) -> str:
|
| 168 |
raw = json.dumps(spec_dict, sort_keys=True, default=str)
|
| 169 |
return hashlib.md5(raw.encode()).hexdigest() + f"@{protocol}"
|
| 170 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 171 |
def get(self, spec_dict: Dict[str, Any], protocol: str) -> Optional[Dict[str, str]]:
|
| 172 |
key = self._make_key(spec_dict, protocol)
|
| 173 |
if key in self._cache:
|
| 174 |
timestamp, files = self._cache[key]
|
| 175 |
if time.time() - timestamp < self._ttl:
|
|
|
|
|
|
|
|
|
|
| 176 |
return files
|
| 177 |
del self._cache[key]
|
|
|
|
|
|
|
| 178 |
return None
|
| 179 |
|
| 180 |
def set(self, spec_dict: Dict[str, Any], protocol: str, files: Dict[str, str]) -> None:
|
| 181 |
key = self._make_key(spec_dict, protocol)
|
| 182 |
self._cache[key] = (time.time(), files)
|
|
|
|
|
|
|
|
|
|
|
|
|
| 183 |
|
| 184 |
def invalidate(self, spec_dict: Dict[str, Any], protocol: str) -> None:
|
| 185 |
key = self._make_key(spec_dict, protocol)
|
| 186 |
self._cache.pop(key, None)
|
|
|
|
|
|
|
| 187 |
|
| 188 |
def clear(self) -> None:
|
| 189 |
self._cache.clear()
|
|
|
|
| 190 |
|
| 191 |
|
| 192 |
class EnhancedMLGenerationModelV2(GenerationModel):
|
|
@@ -326,11 +386,18 @@ class EnhancedMLGenerationModelV2(GenerationModel):
|
|
| 326 |
spec: DesignSpec,
|
| 327 |
cfg: PipelineConfig,
|
| 328 |
extra_seqs: Optional[List[str]] = None,
|
|
|
|
| 329 |
) -> Dict[str, str]:
|
| 330 |
if not HAS_ADVANCED:
|
| 331 |
return self._template_model.predict(spec, cfg)
|
| 332 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 333 |
spec_dict = spec.model_dump() if hasattr(spec, 'model_dump') else dict(spec)
|
|
|
|
| 334 |
design_name = spec.design_name
|
| 335 |
protocol = spec_dict.get("protocol", "unknown")
|
| 336 |
|
|
@@ -338,18 +405,21 @@ class EnhancedMLGenerationModelV2(GenerationModel):
|
|
| 338 |
if self._cache:
|
| 339 |
cached = self._cache.get(spec_dict, protocol)
|
| 340 |
if cached is not None:
|
| 341 |
-
|
| 342 |
return cached
|
| 343 |
|
| 344 |
# Build validator
|
| 345 |
self._code_validator = AdvancedCodeValidator(spec_dict)
|
| 346 |
available_sources = self._get_available_sources()
|
| 347 |
|
|
|
|
|
|
|
| 348 |
start_time = time.time()
|
| 349 |
|
| 350 |
# Ensemble: run top-K strategies concurrently
|
| 351 |
selected = self._select_generation_strategy(spec_dict, protocol, available_sources)
|
| 352 |
strategies_to_run = self._get_strategy_plan(selected, available_sources)
|
|
|
|
| 353 |
|
| 354 |
results: List[GenerationResult] = []
|
| 355 |
with ThreadPoolExecutor(max_workers=min(self._max_concurrent, len(strategies_to_run))) as executor:
|
|
@@ -391,8 +461,10 @@ class EnhancedMLGenerationModelV2(GenerationModel):
|
|
| 391 |
# Coverage prediction (lazy-trained by CoveragePredictor on first call)
|
| 392 |
try:
|
| 393 |
self.last_coverage_prediction = self._coverage_predictor.predict_coverage(spec, final_result.files)
|
|
|
|
|
|
|
| 394 |
except Exception as e:
|
| 395 |
-
logger.
|
| 396 |
self.last_coverage_prediction = None
|
| 397 |
|
| 398 |
# Store last result for learn() / generate() introspect
|
|
@@ -440,17 +512,33 @@ class EnhancedMLGenerationModelV2(GenerationModel):
|
|
| 440 |
self,
|
| 441 |
spec_dict: Dict[str, Any],
|
| 442 |
cfg: Optional[PipelineConfig] = None,
|
|
|
|
| 443 |
) -> Dict[str, Any]:
|
| 444 |
"""Public API: generate from raw spec dict (test-compatible interface).
|
| 445 |
|
| 446 |
Returns a rich result dict with ``passed``, ``generated_files``,
|
| 447 |
``source``, ``strategy``, and ``validation_results``.
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 448 |
"""
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 449 |
try:
|
| 450 |
spec = DesignSpec(**self._coerce_spec_dict(spec_dict))
|
| 451 |
except Exception as e:
|
| 452 |
-
logger.error("Failed to build DesignSpec from dict: %s", e)
|
| 453 |
-
return {"passed": False, "generated_files": {}, "source": "error", "strategy": "error"}
|
| 454 |
|
| 455 |
# Auto-train template model if not yet trained
|
| 456 |
if not self._template_model._is_trained:
|
|
@@ -477,7 +565,7 @@ class EnhancedMLGenerationModelV2(GenerationModel):
|
|
| 477 |
),
|
| 478 |
)
|
| 479 |
|
| 480 |
-
files = self.predict(spec, cfg)
|
| 481 |
|
| 482 |
# Build result dict from stored generation result
|
| 483 |
gen = getattr(self, '_last_generation_result', None)
|
|
@@ -488,6 +576,7 @@ class EnhancedMLGenerationModelV2(GenerationModel):
|
|
| 488 |
"generated_files": files,
|
| 489 |
"source": gen.source.value if gen else "template",
|
| 490 |
"strategy": gen.strategy_used if gen else "template",
|
|
|
|
| 491 |
}
|
| 492 |
|
| 493 |
# Attach validation results if available
|
|
@@ -641,10 +730,6 @@ class EnhancedMLGenerationModelV2(GenerationModel):
|
|
| 641 |
if len(available_sources) == 1:
|
| 642 |
return GenerationSource(available_sources[0])
|
| 643 |
|
| 644 |
-
# Ensemble mode: run top strategies concurrently
|
| 645 |
-
if len(available_sources) >= 2:
|
| 646 |
-
return GenerationSource.ENSEMBLE
|
| 647 |
-
|
| 648 |
if not self._use_learning or not self._rl_learner:
|
| 649 |
if "retrieval" in available_sources and self._index and len(self._index) > 0:
|
| 650 |
return GenerationSource.RETRIEVAL
|
|
@@ -666,8 +751,8 @@ class EnhancedMLGenerationModelV2(GenerationModel):
|
|
| 666 |
source_scores["llm"] += 2.0
|
| 667 |
if feat.register_count > 8 and "retrieval" in source_scores:
|
| 668 |
source_scores["retrieval"] += 1.0
|
| 669 |
-
except Exception:
|
| 670 |
-
|
| 671 |
|
| 672 |
if not source_scores:
|
| 673 |
return GenerationSource.TEMPLATE
|
|
@@ -683,12 +768,22 @@ class EnhancedMLGenerationModelV2(GenerationModel):
|
|
| 683 |
design_name: str,
|
| 684 |
protocol: str,
|
| 685 |
) -> GenerationResult:
|
| 686 |
-
|
| 687 |
-
|
| 688 |
-
|
| 689 |
-
|
| 690 |
-
|
| 691 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 692 |
|
| 693 |
def _generate_by_retrieval(
|
| 694 |
self, spec: DesignSpec, spec_dict: Dict[str, Any], config: PipelineConfig,
|
|
@@ -988,6 +1083,34 @@ class EnhancedMLGenerationModelV2(GenerationModel):
|
|
| 988 |
stats["pattern_learner"] = self._pattern_learner.get_suggestions(file_type="any", protocol="any")
|
| 989 |
return stats
|
| 990 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 991 |
def invalidate_cache(self, spec: Optional[DesignSpec] = None) -> None:
|
| 992 |
if not self._cache:
|
| 993 |
return
|
|
|
|
| 33 |
from src.models.coverage_predictor import CoveragePredictor, SpecFeatures
|
| 34 |
from src.config import PipelineConfig, DesignSpec
|
| 35 |
|
| 36 |
+
|
| 37 |
+
def _retry_with_backoff(
|
| 38 |
+
fn, max_retries: int = 3, base_delay: float = 0.5, backoff: float = 2.0,
|
| 39 |
+
) -> Any:
|
| 40 |
+
"""Execute fn with exponential backoff on transient failures."""
|
| 41 |
+
import functools
|
| 42 |
+
last_exc = None
|
| 43 |
+
for attempt in range(max_retries):
|
| 44 |
+
try:
|
| 45 |
+
return fn()
|
| 46 |
+
except (ConnectionError, TimeoutError, OSError) as e:
|
| 47 |
+
last_exc = e
|
| 48 |
+
if attempt < max_retries - 1:
|
| 49 |
+
delay = base_delay * (backoff ** attempt)
|
| 50 |
+
logger.warning("Transient failure (attempt %d/%d): %s — retrying in %.1fs", attempt + 1, max_retries, e, delay)
|
| 51 |
+
time.sleep(delay)
|
| 52 |
+
raise last_exc # type: ignore[misc]
|
| 53 |
+
|
| 54 |
+
|
| 55 |
+
def _validate_spec_dict(spec_dict: Dict[str, Any]) -> None:
|
| 56 |
+
"""Validate spec dict has required fields before generation."""
|
| 57 |
+
if not isinstance(spec_dict, dict):
|
| 58 |
+
raise TypeError(f"Expected dict, got {type(spec_dict).__name__}")
|
| 59 |
+
if "design_name" not in spec_dict:
|
| 60 |
+
raise ValueError("spec_dict must contain 'design_name'")
|
| 61 |
+
if "protocol" not in spec_dict:
|
| 62 |
+
logger.warning("spec_dict missing 'protocol', defaulting to 'unknown'")
|
| 63 |
+
if not spec_dict.get("design_name"):
|
| 64 |
+
raise ValueError("'design_name' must be a non-empty string")
|
| 65 |
+
|
| 66 |
try:
|
| 67 |
from src.features.extractors import RichSpecFeatureExtractor
|
| 68 |
from src.models.similarity_index import SimilarityIndex, SearchResult
|
|
|
|
| 188 |
|
| 189 |
|
| 190 |
class GenerationCache:
|
| 191 |
+
"""Spec-driven cache with content-addressable keys, TTL, and size limits."""
|
| 192 |
|
| 193 |
+
def __init__(self, ttl_seconds: int = 3600, max_entries: int = 256):
|
| 194 |
self._cache: Dict[str, Tuple[float, Dict[str, str]]] = {}
|
| 195 |
self._ttl = ttl_seconds
|
| 196 |
+
self._max_entries = max_entries
|
| 197 |
+
self._access_order: List[str] = []
|
| 198 |
|
| 199 |
def _make_key(self, spec_dict: Dict[str, Any], protocol: str) -> str:
|
| 200 |
raw = json.dumps(spec_dict, sort_keys=True, default=str)
|
| 201 |
return hashlib.md5(raw.encode()).hexdigest() + f"@{protocol}"
|
| 202 |
|
| 203 |
+
def _evict_if_needed(self) -> None:
|
| 204 |
+
if len(self._cache) > self._max_entries:
|
| 205 |
+
over = len(self._cache) - self._max_entries
|
| 206 |
+
for _ in range(over):
|
| 207 |
+
if self._access_order:
|
| 208 |
+
oldest = self._access_order.pop(0)
|
| 209 |
+
self._cache.pop(oldest, None)
|
| 210 |
+
|
| 211 |
+
def _clean_expired(self) -> None:
|
| 212 |
+
now = time.time()
|
| 213 |
+
expired = [k for k, (ts, _) in self._cache.items() if now - ts >= self._ttl]
|
| 214 |
+
for k in expired:
|
| 215 |
+
del self._cache[k]
|
| 216 |
+
if k in self._access_order:
|
| 217 |
+
self._access_order.remove(k)
|
| 218 |
+
|
| 219 |
def get(self, spec_dict: Dict[str, Any], protocol: str) -> Optional[Dict[str, str]]:
|
| 220 |
key = self._make_key(spec_dict, protocol)
|
| 221 |
if key in self._cache:
|
| 222 |
timestamp, files = self._cache[key]
|
| 223 |
if time.time() - timestamp < self._ttl:
|
| 224 |
+
if key in self._access_order:
|
| 225 |
+
self._access_order.remove(key)
|
| 226 |
+
self._access_order.append(key)
|
| 227 |
return files
|
| 228 |
del self._cache[key]
|
| 229 |
+
if key in self._access_order:
|
| 230 |
+
self._access_order.remove(key)
|
| 231 |
return None
|
| 232 |
|
| 233 |
def set(self, spec_dict: Dict[str, Any], protocol: str, files: Dict[str, str]) -> None:
|
| 234 |
key = self._make_key(spec_dict, protocol)
|
| 235 |
self._cache[key] = (time.time(), files)
|
| 236 |
+
if key in self._access_order:
|
| 237 |
+
self._access_order.remove(key)
|
| 238 |
+
self._access_order.append(key)
|
| 239 |
+
self._evict_if_needed()
|
| 240 |
|
| 241 |
def invalidate(self, spec_dict: Dict[str, Any], protocol: str) -> None:
|
| 242 |
key = self._make_key(spec_dict, protocol)
|
| 243 |
self._cache.pop(key, None)
|
| 244 |
+
if key in self._access_order:
|
| 245 |
+
self._access_order.remove(key)
|
| 246 |
|
| 247 |
def clear(self) -> None:
|
| 248 |
self._cache.clear()
|
| 249 |
+
self._access_order.clear()
|
| 250 |
|
| 251 |
|
| 252 |
class EnhancedMLGenerationModelV2(GenerationModel):
|
|
|
|
| 386 |
spec: DesignSpec,
|
| 387 |
cfg: PipelineConfig,
|
| 388 |
extra_seqs: Optional[List[str]] = None,
|
| 389 |
+
request_id: Optional[str] = None,
|
| 390 |
) -> Dict[str, str]:
|
| 391 |
if not HAS_ADVANCED:
|
| 392 |
return self._template_model.predict(spec, cfg)
|
| 393 |
|
| 394 |
+
rid = request_id or f"gen_{int(time.time() * 1000)}_{id(spec)}"
|
| 395 |
+
def _log(msg: str, *args: Any) -> None:
|
| 396 |
+
logger.info("[%s] %s", rid, msg % args if args else msg)
|
| 397 |
+
_log("Starting prediction for %s", spec.design_name)
|
| 398 |
+
|
| 399 |
spec_dict = spec.model_dump() if hasattr(spec, 'model_dump') else dict(spec)
|
| 400 |
+
_validate_spec_dict(spec_dict)
|
| 401 |
design_name = spec.design_name
|
| 402 |
protocol = spec_dict.get("protocol", "unknown")
|
| 403 |
|
|
|
|
| 405 |
if self._cache:
|
| 406 |
cached = self._cache.get(spec_dict, protocol)
|
| 407 |
if cached is not None:
|
| 408 |
+
_log("Cache hit for %s@%s", design_name, protocol)
|
| 409 |
return cached
|
| 410 |
|
| 411 |
# Build validator
|
| 412 |
self._code_validator = AdvancedCodeValidator(spec_dict)
|
| 413 |
available_sources = self._get_available_sources()
|
| 414 |
|
| 415 |
+
_log("Available sources: %s", available_sources)
|
| 416 |
+
|
| 417 |
start_time = time.time()
|
| 418 |
|
| 419 |
# Ensemble: run top-K strategies concurrently
|
| 420 |
selected = self._select_generation_strategy(spec_dict, protocol, available_sources)
|
| 421 |
strategies_to_run = self._get_strategy_plan(selected, available_sources)
|
| 422 |
+
_log("Selected strategy=%s, plan=%s", selected.value, strategies_to_run)
|
| 423 |
|
| 424 |
results: List[GenerationResult] = []
|
| 425 |
with ThreadPoolExecutor(max_workers=min(self._max_concurrent, len(strategies_to_run))) as executor:
|
|
|
|
| 461 |
# Coverage prediction (lazy-trained by CoveragePredictor on first call)
|
| 462 |
try:
|
| 463 |
self.last_coverage_prediction = self._coverage_predictor.predict_coverage(spec, final_result.files)
|
| 464 |
+
if self.last_coverage_prediction:
|
| 465 |
+
_log("Coverage prediction: %.1f%% expected", self.last_coverage_prediction.get("coverage", {}).get("expected", 0))
|
| 466 |
except Exception as e:
|
| 467 |
+
logger.warning("[%s] Coverage prediction failed: %s", rid, e)
|
| 468 |
self.last_coverage_prediction = None
|
| 469 |
|
| 470 |
# Store last result for learn() / generate() introspect
|
|
|
|
| 512 |
self,
|
| 513 |
spec_dict: Dict[str, Any],
|
| 514 |
cfg: Optional[PipelineConfig] = None,
|
| 515 |
+
request_id: Optional[str] = None,
|
| 516 |
) -> Dict[str, Any]:
|
| 517 |
"""Public API: generate from raw spec dict (test-compatible interface).
|
| 518 |
|
| 519 |
Returns a rich result dict with ``passed``, ``generated_files``,
|
| 520 |
``source``, ``strategy``, and ``validation_results``.
|
| 521 |
+
|
| 522 |
+
Parameters
|
| 523 |
+
----------
|
| 524 |
+
spec_dict : Dict[str, Any]
|
| 525 |
+
Specification dictionary with at minimum ``design_name`` and ``protocol``.
|
| 526 |
+
cfg : Optional[PipelineConfig]
|
| 527 |
+
Generation pipeline configuration. Auto-created from stored config if None.
|
| 528 |
+
request_id : Optional[str]
|
| 529 |
+
Correlation ID for request tracing across logs.
|
| 530 |
"""
|
| 531 |
+
rid = request_id or f"gen_{int(time.time() * 1000)}_{id(spec_dict)}"
|
| 532 |
+
try:
|
| 533 |
+
_validate_spec_dict(spec_dict)
|
| 534 |
+
except (TypeError, ValueError) as e:
|
| 535 |
+
logger.error("[%s] Input validation failed: %s", rid, e)
|
| 536 |
+
return {"passed": False, "generated_files": {}, "source": "error", "strategy": "error", "request_id": rid}
|
| 537 |
try:
|
| 538 |
spec = DesignSpec(**self._coerce_spec_dict(spec_dict))
|
| 539 |
except Exception as e:
|
| 540 |
+
logger.error("[%s] Failed to build DesignSpec from dict: %s", rid, e)
|
| 541 |
+
return {"passed": False, "generated_files": {}, "source": "error", "strategy": "error", "request_id": rid}
|
| 542 |
|
| 543 |
# Auto-train template model if not yet trained
|
| 544 |
if not self._template_model._is_trained:
|
|
|
|
| 565 |
),
|
| 566 |
)
|
| 567 |
|
| 568 |
+
files = self.predict(spec, cfg, request_id=rid)
|
| 569 |
|
| 570 |
# Build result dict from stored generation result
|
| 571 |
gen = getattr(self, '_last_generation_result', None)
|
|
|
|
| 576 |
"generated_files": files,
|
| 577 |
"source": gen.source.value if gen else "template",
|
| 578 |
"strategy": gen.strategy_used if gen else "template",
|
| 579 |
+
"request_id": rid,
|
| 580 |
}
|
| 581 |
|
| 582 |
# Attach validation results if available
|
|
|
|
| 730 |
if len(available_sources) == 1:
|
| 731 |
return GenerationSource(available_sources[0])
|
| 732 |
|
|
|
|
|
|
|
|
|
|
|
|
|
| 733 |
if not self._use_learning or not self._rl_learner:
|
| 734 |
if "retrieval" in available_sources and self._index and len(self._index) > 0:
|
| 735 |
return GenerationSource.RETRIEVAL
|
|
|
|
| 751 |
source_scores["llm"] += 2.0
|
| 752 |
if feat.register_count > 8 and "retrieval" in source_scores:
|
| 753 |
source_scores["retrieval"] += 1.0
|
| 754 |
+
except Exception as e:
|
| 755 |
+
logger.debug("Coverage hint in strategy selection failed: %s", e)
|
| 756 |
|
| 757 |
if not source_scores:
|
| 758 |
return GenerationSource.TEMPLATE
|
|
|
|
| 768 |
design_name: str,
|
| 769 |
protocol: str,
|
| 770 |
) -> GenerationResult:
|
| 771 |
+
def _exec() -> GenerationResult:
|
| 772 |
+
if strategy == "retrieval":
|
| 773 |
+
return self._generate_by_retrieval(spec, spec_dict, config, design_name, protocol)
|
| 774 |
+
elif strategy == "llm" and self._use_llm:
|
| 775 |
+
return self._generate_by_llm(spec, spec_dict, config, design_name, protocol)
|
| 776 |
+
else:
|
| 777 |
+
return self._generate_by_template(spec, config, design_name, protocol)
|
| 778 |
+
try:
|
| 779 |
+
return _retry_with_backoff(_exec, max_retries=2, base_delay=0.25)
|
| 780 |
+
except Exception as e:
|
| 781 |
+
logger.error("Strategy %s failed after retries: %s", strategy, e)
|
| 782 |
+
return GenerationResult(
|
| 783 |
+
source=GenerationSource.TEMPLATE,
|
| 784 |
+
errors=[f"Strategy {strategy} failed after retries: {e}"],
|
| 785 |
+
strategy_used=strategy,
|
| 786 |
+
)
|
| 787 |
|
| 788 |
def _generate_by_retrieval(
|
| 789 |
self, spec: DesignSpec, spec_dict: Dict[str, Any], config: PipelineConfig,
|
|
|
|
| 1083 |
stats["pattern_learner"] = self._pattern_learner.get_suggestions(file_type="any", protocol="any")
|
| 1084 |
return stats
|
| 1085 |
|
| 1086 |
+
def get_health_status(self) -> Dict[str, Any]:
|
| 1087 |
+
"""Return health status for production monitoring / readiness probes."""
|
| 1088 |
+
components = {
|
| 1089 |
+
"template_model": self._template_model is not None,
|
| 1090 |
+
"similarity_index": self._index is not None and len(self._index) > 0,
|
| 1091 |
+
"feature_extractor": self._extractor is not None,
|
| 1092 |
+
"spec_adapter": self._adapter is not None,
|
| 1093 |
+
"code_validator": self._code_validator is not None,
|
| 1094 |
+
"rl_learner": self._rl_learner is not None,
|
| 1095 |
+
"pattern_learner": self._pattern_learner is not None,
|
| 1096 |
+
"coverage_predictor": self._coverage_predictor is not None,
|
| 1097 |
+
}
|
| 1098 |
+
all_ok = all(components.values())
|
| 1099 |
+
return {
|
| 1100 |
+
"status": "healthy" if all_ok else "degraded",
|
| 1101 |
+
"version": self._model_version,
|
| 1102 |
+
"components": components,
|
| 1103 |
+
"index_size": len(self._index) if self._index else 0,
|
| 1104 |
+
"cache_enabled": self._enable_caching,
|
| 1105 |
+
"use_learning": self._use_learning,
|
| 1106 |
+
"use_llm": self._use_llm,
|
| 1107 |
+
"exploration_strategy": self._exploration_strategy.value if self._exploration_strategy else None,
|
| 1108 |
+
"total_generations": len(self._generation_history),
|
| 1109 |
+
"rl_converged": self._rl_learner.is_converged() if self._rl_learner else None,
|
| 1110 |
+
"quality_threshold": self._quality_threshold,
|
| 1111 |
+
"max_concurrent_strategies": self._max_concurrent,
|
| 1112 |
+
}
|
| 1113 |
+
|
| 1114 |
def invalidate_cache(self, spec: Optional[DesignSpec] = None) -> None:
|
| 1115 |
if not self._cache:
|
| 1116 |
return
|
tests/test_advanced_ml_v2.py
CHANGED
|
@@ -1,477 +1,332 @@
|
|
| 1 |
"""
|
| 2 |
-
|
| 3 |
-
|
|
|
|
|
|
|
| 4 |
"""
|
| 5 |
|
| 6 |
import sys
|
| 7 |
import os
|
| 8 |
import tempfile
|
| 9 |
import yaml
|
|
|
|
| 10 |
|
| 11 |
repo_root = os.path.dirname(os.path.dirname(os.path.abspath(__file__)))
|
| 12 |
sys.path.insert(0, repo_root)
|
| 13 |
|
| 14 |
-
from src.models.enhanced_ml_model_v2 import
|
| 15 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
| 16 |
|
| 17 |
TEST_SPEC = """
|
| 18 |
design_name: uart
|
| 19 |
clock_reset:
|
| 20 |
clock: clk
|
| 21 |
reset: rst_n
|
| 22 |
-
|
| 23 |
interfaces:
|
| 24 |
- name: wb
|
| 25 |
signals:
|
| 26 |
-
- name: wb_cyc
|
| 27 |
-
|
| 28 |
-
- name:
|
| 29 |
-
|
| 30 |
-
- name:
|
| 31 |
-
|
| 32 |
-
- name:
|
| 33 |
-
direction: input
|
| 34 |
-
width: 3
|
| 35 |
-
- name: wb_data_o
|
| 36 |
-
direction: output
|
| 37 |
-
width: 8
|
| 38 |
-
- name: wb_data_i
|
| 39 |
-
direction: input
|
| 40 |
-
width: 8
|
| 41 |
-
- name: wb_ack
|
| 42 |
-
direction: output
|
| 43 |
-
|
| 44 |
- name: uart
|
| 45 |
signals:
|
| 46 |
-
- name: uart_tx
|
| 47 |
-
|
| 48 |
-
- name:
|
| 49 |
-
|
| 50 |
-
- name:
|
| 51 |
-
direction: input
|
| 52 |
-
- name: rts_n
|
| 53 |
-
direction: output
|
| 54 |
-
- name: uart_intr
|
| 55 |
-
direction: output
|
| 56 |
-
|
| 57 |
registers:
|
| 58 |
- name: RBR_THR
|
| 59 |
address: 0x0
|
| 60 |
description: Receiver Buffer / Transmitter Holding
|
| 61 |
fields:
|
| 62 |
-
- name: data
|
| 63 |
-
bits: 7:0
|
| 64 |
- name: IER
|
| 65 |
address: 0x1
|
| 66 |
description: Interrupt Enable
|
| 67 |
fields:
|
| 68 |
-
- name: erbfi
|
| 69 |
-
|
| 70 |
-
description: Enable RX data available interrupt
|
| 71 |
-
- name: etbei
|
| 72 |
-
bits: '1'
|
| 73 |
-
description: Enable TX holding register empty interrupt
|
| 74 |
- name: LCR
|
| 75 |
address: 0x3
|
| 76 |
description: Line Control
|
| 77 |
fields:
|
| 78 |
-
- name: wls
|
| 79 |
-
|
| 80 |
-
description: Word length select
|
| 81 |
-
- name: dlab
|
| 82 |
-
bits: '7'
|
| 83 |
-
description: Divisor latch access bit
|
| 84 |
- name: LSR
|
| 85 |
address: 0x5
|
| 86 |
description: Line Status
|
| 87 |
fields:
|
| 88 |
-
- name: dr
|
| 89 |
-
|
| 90 |
-
description: Data Ready
|
| 91 |
-
- name: thre
|
| 92 |
-
bits: '5'
|
| 93 |
-
description: TX Holding Register Empty
|
| 94 |
-
|
| 95 |
protocol: uart
|
| 96 |
"""
|
| 97 |
|
| 98 |
-
|
| 99 |
-
|
| 100 |
-
|
| 101 |
-
|
| 102 |
-
|
| 103 |
-
|
| 104 |
-
|
| 105 |
-
|
| 106 |
-
|
| 107 |
-
for strategy in strategies:
|
| 108 |
-
print(f"\n--- Testing {strategy} strategy ---")
|
| 109 |
-
|
| 110 |
-
cfg = PipelineConfig(
|
| 111 |
-
ml=MLConfig(
|
| 112 |
-
enabled=True,
|
| 113 |
-
model_type="v2",
|
| 114 |
-
exploration_strategy=strategy,
|
| 115 |
-
use_llm=False,
|
| 116 |
-
use_semantic_encoder=False,
|
| 117 |
-
use_learning=True,
|
| 118 |
-
learning_storage_path=None
|
| 119 |
-
)
|
| 120 |
-
)
|
| 121 |
-
|
| 122 |
-
model = EnhancedMLGenerationModelV2(cfg)
|
| 123 |
-
|
| 124 |
-
spec_dict = yaml.safe_load(TEST_SPEC)
|
| 125 |
-
|
| 126 |
-
result = model.generate(spec_dict)
|
| 127 |
-
passed = result['passed']
|
| 128 |
-
generated_files = result.get('generated_files', {})
|
| 129 |
-
|
| 130 |
-
print(f" Passed: {passed}")
|
| 131 |
-
print(f" Files generated: {len(generated_files)}")
|
| 132 |
-
print(f" Source: {result.get('source', 'unknown')}")
|
| 133 |
-
print(f" Strategy used: {result.get('strategy', 'unknown')}")
|
| 134 |
-
|
| 135 |
-
if hasattr(model, '_rl_learner'):
|
| 136 |
-
rl_stats = model._rl_learner.get_performance_stats()
|
| 137 |
-
print(f" RL episodes: {rl_stats.get('episode_count', 0)}")
|
| 138 |
-
print(f" RL total updates: {rl_stats.get('total_updates', 0)}")
|
| 139 |
-
|
| 140 |
-
results[strategy] = {
|
| 141 |
-
"passed": passed,
|
| 142 |
-
"files_count": len(generated_files),
|
| 143 |
-
"source": result.get('source', 'unknown'),
|
| 144 |
-
"strategy": result.get('strategy', 'unknown')
|
| 145 |
-
}
|
| 146 |
-
|
| 147 |
-
print("\n--- Strategy Results Summary ---")
|
| 148 |
-
for strategy, res in results.items():
|
| 149 |
-
status = "✅" if res["passed"] else "❌"
|
| 150 |
-
print(f" {status} {strategy}: {res['files_count']} files, source={res['source']}, strategy={res['strategy']}")
|
| 151 |
-
|
| 152 |
-
return all(r["passed"] for r in results.values())
|
| 153 |
-
|
| 154 |
-
def test_experience_replay():
|
| 155 |
-
"""Test experience replay buffer and eligibility traces."""
|
| 156 |
-
print("\n" + "="*60)
|
| 157 |
-
print("Testing Experience Replay & Eligibility Traces")
|
| 158 |
-
print("="*60)
|
| 159 |
-
|
| 160 |
-
cfg = PipelineConfig(
|
| 161 |
ml=MLConfig(
|
| 162 |
-
enabled=True,
|
| 163 |
-
|
| 164 |
-
|
| 165 |
-
use_llm=False,
|
| 166 |
-
use_semantic_encoder=False,
|
| 167 |
-
use_learning=True,
|
| 168 |
-
learning_storage_path=None
|
| 169 |
)
|
| 170 |
)
|
| 171 |
-
|
| 172 |
-
|
| 173 |
-
|
| 174 |
-
|
| 175 |
-
|
| 176 |
-
|
| 177 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 178 |
result = model.generate(spec_dict)
|
| 179 |
-
|
| 180 |
-
|
| 181 |
-
|
| 182 |
-
|
| 183 |
-
|
| 184 |
-
|
| 185 |
-
|
| 186 |
-
|
| 187 |
-
|
| 188 |
-
|
| 189 |
-
|
| 190 |
-
|
| 191 |
-
|
| 192 |
-
|
| 193 |
-
|
| 194 |
-
|
| 195 |
-
|
| 196 |
-
|
| 197 |
-
|
| 198 |
-
|
| 199 |
-
|
| 200 |
-
|
| 201 |
-
|
| 202 |
-
|
| 203 |
-
|
| 204 |
-
|
| 205 |
-
|
| 206 |
-
|
| 207 |
-
|
| 208 |
-
|
| 209 |
-
|
| 210 |
-
|
| 211 |
-
|
| 212 |
-
|
| 213 |
-
|
| 214 |
-
use_semantic_encoder=False,
|
| 215 |
-
use_learning=True,
|
| 216 |
-
learning_storage_path=None
|
| 217 |
-
)
|
| 218 |
-
)
|
| 219 |
-
|
| 220 |
-
model = EnhancedMLGenerationModelV2(cfg)
|
| 221 |
-
spec_dict = yaml.safe_load(TEST_SPEC)
|
| 222 |
-
|
| 223 |
-
print(" Running generations for pattern learning...")
|
| 224 |
-
|
| 225 |
-
for i in range(3):
|
| 226 |
result = model.generate(spec_dict)
|
| 227 |
-
reward = 1.0 if result[
|
| 228 |
model.learn(result, reward)
|
| 229 |
-
|
| 230 |
-
|
| 231 |
-
|
| 232 |
-
|
| 233 |
-
|
| 234 |
-
|
| 235 |
-
print(f" Total specs seen: {stats['total_specs_seen']}")
|
| 236 |
-
print(f" Total generations: {stats['total_generations']}")
|
| 237 |
-
print(f" Average score: {stats['avg_score']:.3f}")
|
| 238 |
-
print(f" N-gram vocabulary size: {len(stats['ngram_vocab'])}")
|
| 239 |
-
print(f" Association rules: {len(stats['association_rules'])}")
|
| 240 |
-
|
| 241 |
-
recs = pl.get_recommendations(spec_dict)
|
| 242 |
-
print(f"\n Recommendations for current spec:")
|
| 243 |
-
for rec in recs[:5]:
|
| 244 |
-
print(f" • {rec}")
|
| 245 |
-
|
| 246 |
-
common = pl.get_common_error_patterns(top_n=5)
|
| 247 |
-
if common:
|
| 248 |
-
print(f"\n Common error patterns:")
|
| 249 |
-
for pattern, count in common:
|
| 250 |
-
print(f" • '{pattern}': {count} occurrences")
|
| 251 |
-
|
| 252 |
-
return True
|
| 253 |
-
|
| 254 |
-
return False
|
| 255 |
-
|
| 256 |
-
def test_deep_validation():
|
| 257 |
-
"""Test deep UVM compliance validation."""
|
| 258 |
-
print("\n" + "="*60)
|
| 259 |
-
print("Testing Deep UVM Compliance Validation")
|
| 260 |
-
print("="*60)
|
| 261 |
-
|
| 262 |
-
cfg = PipelineConfig(
|
| 263 |
-
ml=MLConfig(
|
| 264 |
-
enabled=True,
|
| 265 |
-
model_type="v2",
|
| 266 |
-
exploration_strategy="ucb",
|
| 267 |
-
use_llm=False,
|
| 268 |
-
use_semantic_encoder=False,
|
| 269 |
-
use_learning=True,
|
| 270 |
-
strict_validation=True,
|
| 271 |
-
learning_storage_path=None
|
| 272 |
-
)
|
| 273 |
-
)
|
| 274 |
-
|
| 275 |
-
model = EnhancedMLGenerationModelV2(cfg)
|
| 276 |
-
spec_dict = yaml.safe_load(TEST_SPEC)
|
| 277 |
-
|
| 278 |
-
result = model.generate(spec_dict)
|
| 279 |
-
|
| 280 |
-
print(f"\n Generated files: {len(result.get('generated_files', {}))}")
|
| 281 |
-
print(f" Passed: {result['passed']}")
|
| 282 |
-
|
| 283 |
-
val_results = result.get('validation_results', {})
|
| 284 |
-
|
| 285 |
-
if val_results:
|
| 286 |
-
print(f"\n Validation Results:")
|
| 287 |
-
total_checks = 0
|
| 288 |
-
total_passed = 0
|
| 289 |
-
|
| 290 |
-
for file_path, file_result in val_results.items():
|
| 291 |
-
file_name = os.path.basename(file_path)
|
| 292 |
-
checks = file_result.get('checks', [])
|
| 293 |
-
|
| 294 |
-
if checks:
|
| 295 |
-
print(f"\n {file_name}:")
|
| 296 |
-
for check in checks:
|
| 297 |
-
total_checks += 1
|
| 298 |
-
status = "✅" if check.get('passed', False) else "❌"
|
| 299 |
-
if check.get('passed'):
|
| 300 |
-
total_passed += 1
|
| 301 |
-
|
| 302 |
-
msg = f" {status} {check.get('check_name', 'unknown')}"
|
| 303 |
-
if check.get('message'):
|
| 304 |
-
msg += f": {check['message']}"
|
| 305 |
-
print(msg)
|
| 306 |
-
|
| 307 |
-
if total_checks > 0:
|
| 308 |
-
pass_rate = (total_passed / total_checks) * 100
|
| 309 |
-
print(f"\n Overall validation pass rate: {pass_rate:.1f}% ({total_passed}/{total_checks})")
|
| 310 |
-
|
| 311 |
-
return total_checks > 0
|
| 312 |
-
|
| 313 |
-
return False
|
| 314 |
-
|
| 315 |
-
def test_learning_persistence():
|
| 316 |
-
"""Test saving and loading learning state."""
|
| 317 |
-
print("\n" + "="*60)
|
| 318 |
-
print("Testing Learning State Persistence")
|
| 319 |
-
print("="*60)
|
| 320 |
-
|
| 321 |
-
with tempfile.NamedTemporaryFile(mode='w', suffix='.json', delete=False) as f:
|
| 322 |
-
state_path = f.name
|
| 323 |
-
|
| 324 |
-
try:
|
| 325 |
-
cfg = PipelineConfig(
|
| 326 |
-
ml=MLConfig(
|
| 327 |
-
enabled=True,
|
| 328 |
-
model_type="v2",
|
| 329 |
-
exploration_strategy="ucb",
|
| 330 |
-
use_llm=False,
|
| 331 |
-
use_semantic_encoder=False,
|
| 332 |
-
use_learning=True,
|
| 333 |
-
learning_storage_path=state_path
|
| 334 |
-
)
|
| 335 |
-
)
|
| 336 |
-
|
| 337 |
-
print(" Creating model and running generations...")
|
| 338 |
-
model = EnhancedMLGenerationModelV2(cfg)
|
| 339 |
-
spec_dict = yaml.safe_load(TEST_SPEC)
|
| 340 |
-
|
| 341 |
-
for i in range(3):
|
| 342 |
result = model.generate(spec_dict)
|
| 343 |
-
|
| 344 |
-
|
| 345 |
-
|
| 346 |
-
|
| 347 |
-
|
| 348 |
-
|
| 349 |
-
print(f" Episodes before save: {episodes_before}")
|
| 350 |
-
print(f" Replay buffer size before save: {replay_size_before}")
|
| 351 |
-
|
| 352 |
-
print(" Saving learning state...")
|
| 353 |
-
model.save_learning_state(state_path)
|
| 354 |
-
|
| 355 |
-
print(" Loading learning state into new model...")
|
| 356 |
-
model2 = EnhancedMLGenerationModelV2(cfg)
|
| 357 |
-
model2.load_learning_state(state_path)
|
| 358 |
-
|
| 359 |
-
if hasattr(model2, '_rl_learner'):
|
| 360 |
-
episodes_after = model2._rl_learner.get_performance_stats().get('episode_count', 0)
|
| 361 |
-
replay_size_after = len(model2._rl_learner._replay_buffer)
|
| 362 |
-
print(f" Episodes after load: {episodes_after}")
|
| 363 |
-
print(f" Replay buffer size after load: {replay_size_after}")
|
| 364 |
-
|
| 365 |
-
return episodes_after >= 3 and replay_size_after >= 3
|
| 366 |
-
|
| 367 |
-
return False
|
| 368 |
-
|
| 369 |
-
finally:
|
| 370 |
-
if os.path.exists(state_path):
|
| 371 |
-
os.unlink(state_path)
|
| 372 |
-
|
| 373 |
-
def test_learning_stats():
|
| 374 |
-
"""Test ML stats generation for UI."""
|
| 375 |
-
print("\n" + "="*60)
|
| 376 |
-
print("Testing Learning Statistics (for UI)")
|
| 377 |
-
print("="*60)
|
| 378 |
-
|
| 379 |
-
cfg = PipelineConfig(
|
| 380 |
-
ml=MLConfig(
|
| 381 |
-
enabled=True,
|
| 382 |
-
model_type="v2",
|
| 383 |
-
exploration_strategy="ucb",
|
| 384 |
-
use_llm=False,
|
| 385 |
-
use_semantic_encoder=False,
|
| 386 |
-
use_learning=True,
|
| 387 |
-
learning_storage_path=None
|
| 388 |
-
)
|
| 389 |
-
)
|
| 390 |
-
|
| 391 |
-
model = EnhancedMLGenerationModelV2(cfg)
|
| 392 |
-
spec_dict = yaml.safe_load(TEST_SPEC)
|
| 393 |
-
|
| 394 |
-
for i in range(3):
|
| 395 |
result = model.generate(spec_dict)
|
| 396 |
-
|
| 397 |
-
model.learn(result, reward)
|
| 398 |
-
|
| 399 |
-
if hasattr(model, 'get_learning_stats'):
|
| 400 |
stats = model.get_learning_stats()
|
| 401 |
-
|
| 402 |
-
|
| 403 |
-
|
| 404 |
-
|
| 405 |
-
|
| 406 |
-
|
| 407 |
-
|
| 408 |
-
|
| 409 |
-
|
| 410 |
-
if 'strategy_weights' in stats:
|
| 411 |
-
print(f"\n Strategy weights:")
|
| 412 |
-
for strategy, weight in stats['strategy_weights'].items():
|
| 413 |
-
print(f" • {strategy}: {weight}")
|
| 414 |
-
|
| 415 |
-
if 'rl_learner' in stats:
|
| 416 |
-
print(f"\n RL Learner stats:")
|
| 417 |
-
print(f" Episode count: {stats['rl_learner'].get('episode_count', 0)}")
|
| 418 |
-
print(f" Total updates: {stats['rl_learner'].get('total_updates', 0)}")
|
| 419 |
-
|
| 420 |
-
if 'pattern_learner' in stats:
|
| 421 |
-
print(f"\n Pattern Learner stats:")
|
| 422 |
-
print(f" Total specs seen: {stats['pattern_learner'].get('total_specs_seen', 0)}")
|
| 423 |
-
|
| 424 |
-
return True
|
| 425 |
-
|
| 426 |
-
return False
|
| 427 |
-
|
| 428 |
-
def run_all_tests():
|
| 429 |
-
"""Run all tests and report results."""
|
| 430 |
-
print("\n" + "="*60)
|
| 431 |
-
print("Advanced ML V2 Model - Complete Test Suite")
|
| 432 |
-
print("="*60)
|
| 433 |
-
|
| 434 |
-
tests = [
|
| 435 |
-
("RL Exploration Strategies", test_rl_strategies),
|
| 436 |
-
("Experience Replay & Eligibility Traces", test_experience_replay),
|
| 437 |
-
("Advanced Pattern Learner", test_pattern_learner),
|
| 438 |
-
("Deep UVM Validation", test_deep_validation),
|
| 439 |
-
("Learning State Persistence", test_learning_persistence),
|
| 440 |
-
("Learning Statistics (UI)", test_learning_stats),
|
| 441 |
-
]
|
| 442 |
-
|
| 443 |
-
results = []
|
| 444 |
-
|
| 445 |
-
for name, test_func in tests:
|
| 446 |
try:
|
| 447 |
-
|
| 448 |
-
|
| 449 |
-
|
| 450 |
-
|
| 451 |
-
|
| 452 |
-
|
| 453 |
-
|
| 454 |
-
|
| 455 |
-
|
| 456 |
-
|
| 457 |
-
|
| 458 |
-
|
| 459 |
-
|
| 460 |
-
|
| 461 |
-
|
| 462 |
-
|
| 463 |
-
|
| 464 |
-
|
| 465 |
-
|
| 466 |
-
|
| 467 |
-
|
| 468 |
-
|
| 469 |
-
|
| 470 |
-
|
| 471 |
-
|
| 472 |
-
|
| 473 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 474 |
|
| 475 |
if __name__ == "__main__":
|
| 476 |
-
|
| 477 |
-
sys.exit(0 if success else 1)
|
|
|
|
| 1 |
"""
|
| 2 |
+
Production-grade pytest tests for Advanced ML V2 Model.
|
| 3 |
+
|
| 4 |
+
Covers: RL strategies, experience replay, eligibility traces,
|
| 5 |
+
pattern learning, deep validation, persistence, health/request_id.
|
| 6 |
"""
|
| 7 |
|
| 8 |
import sys
|
| 9 |
import os
|
| 10 |
import tempfile
|
| 11 |
import yaml
|
| 12 |
+
import pytest
|
| 13 |
|
| 14 |
repo_root = os.path.dirname(os.path.dirname(os.path.abspath(__file__)))
|
| 15 |
sys.path.insert(0, repo_root)
|
| 16 |
|
| 17 |
+
from src.models.enhanced_ml_model_v2 import (
|
| 18 |
+
EnhancedMLGenerationModelV2,
|
| 19 |
+
_validate_spec_dict,
|
| 20 |
+
GenerationCache,
|
| 21 |
+
)
|
| 22 |
+
from src.config import PipelineConfig, MLConfig, GenerationConfig
|
| 23 |
|
| 24 |
TEST_SPEC = """
|
| 25 |
design_name: uart
|
| 26 |
clock_reset:
|
| 27 |
clock: clk
|
| 28 |
reset: rst_n
|
|
|
|
| 29 |
interfaces:
|
| 30 |
- name: wb
|
| 31 |
signals:
|
| 32 |
+
- {name: wb_cyc, direction: input}
|
| 33 |
+
- {name: wb_stb, direction: input}
|
| 34 |
+
- {name: wb_we, direction: input}
|
| 35 |
+
- {name: wb_addr, direction: input, width: 3}
|
| 36 |
+
- {name: wb_data_o, direction: output, width: 8}
|
| 37 |
+
- {name: wb_data_i, direction: input, width: 8}
|
| 38 |
+
- {name: wb_ack, direction: output}
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 39 |
- name: uart
|
| 40 |
signals:
|
| 41 |
+
- {name: uart_tx, direction: output}
|
| 42 |
+
- {name: uart_rx, direction: input}
|
| 43 |
+
- {name: cts_n, direction: input}
|
| 44 |
+
- {name: rts_n, direction: output}
|
| 45 |
+
- {name: uart_intr, direction: output}
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 46 |
registers:
|
| 47 |
- name: RBR_THR
|
| 48 |
address: 0x0
|
| 49 |
description: Receiver Buffer / Transmitter Holding
|
| 50 |
fields:
|
| 51 |
+
- {name: data, bits: 7:0}
|
|
|
|
| 52 |
- name: IER
|
| 53 |
address: 0x1
|
| 54 |
description: Interrupt Enable
|
| 55 |
fields:
|
| 56 |
+
- {name: erbfi, bits: '0', description: Enable RX data available interrupt}
|
| 57 |
+
- {name: etbei, bits: '1', description: Enable TX holding register empty interrupt}
|
|
|
|
|
|
|
|
|
|
|
|
|
| 58 |
- name: LCR
|
| 59 |
address: 0x3
|
| 60 |
description: Line Control
|
| 61 |
fields:
|
| 62 |
+
- {name: wls, bits: 1:0, description: Word length select}
|
| 63 |
+
- {name: dlab, bits: '7', description: Divisor latch access bit}
|
|
|
|
|
|
|
|
|
|
|
|
|
| 64 |
- name: LSR
|
| 65 |
address: 0x5
|
| 66 |
description: Line Status
|
| 67 |
fields:
|
| 68 |
+
- {name: dr, bits: '0', description: Data Ready}
|
| 69 |
+
- {name: thre, bits: '5', description: TX Holding Register Empty}
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 70 |
protocol: uart
|
| 71 |
"""
|
| 72 |
|
| 73 |
+
|
| 74 |
+
@pytest.fixture
|
| 75 |
+
def spec_dict():
|
| 76 |
+
return yaml.safe_load(TEST_SPEC)
|
| 77 |
+
|
| 78 |
+
|
| 79 |
+
@pytest.fixture
|
| 80 |
+
def base_cfg():
|
| 81 |
+
return PipelineConfig(
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 82 |
ml=MLConfig(
|
| 83 |
+
enabled=True, model_type="v2", exploration_strategy="ucb",
|
| 84 |
+
use_llm=False, use_semantic_encoder=False, use_learning=True,
|
| 85 |
+
learning_storage_path=None,
|
|
|
|
|
|
|
|
|
|
|
|
|
| 86 |
)
|
| 87 |
)
|
| 88 |
+
|
| 89 |
+
|
| 90 |
+
# ---------------------------------------------------------------------------
|
| 91 |
+
# Input validation tests
|
| 92 |
+
# ---------------------------------------------------------------------------
|
| 93 |
+
|
| 94 |
+
|
| 95 |
+
class TestInputValidation:
|
| 96 |
+
def test_valid_spec_passes(self, spec_dict):
|
| 97 |
+
_validate_spec_dict(spec_dict)
|
| 98 |
+
|
| 99 |
+
def test_missing_design_name_raises(self):
|
| 100 |
+
with pytest.raises(ValueError, match="design_name"):
|
| 101 |
+
_validate_spec_dict({"protocol": "uart"})
|
| 102 |
+
|
| 103 |
+
def test_empty_design_name_raises(self):
|
| 104 |
+
with pytest.raises(ValueError, match="non-empty"):
|
| 105 |
+
_validate_spec_dict({"design_name": "", "protocol": "uart"})
|
| 106 |
+
|
| 107 |
+
def test_non_dict_raises(self):
|
| 108 |
+
with pytest.raises(TypeError, match="dict"):
|
| 109 |
+
_validate_spec_dict("not_a_dict")
|
| 110 |
+
|
| 111 |
+
|
| 112 |
+
# ---------------------------------------------------------------------------
|
| 113 |
+
# Cache tests
|
| 114 |
+
# ---------------------------------------------------------------------------
|
| 115 |
+
|
| 116 |
+
|
| 117 |
+
class TestGenerationCache:
|
| 118 |
+
def test_set_and_get(self):
|
| 119 |
+
cache = GenerationCache(ttl_seconds=60, max_entries=16)
|
| 120 |
+
cache.set({"a": 1}, "uart", {"file.sv": "content"})
|
| 121 |
+
result = cache.get({"a": 1}, "uart")
|
| 122 |
+
assert result == {"file.sv": "content"}
|
| 123 |
+
|
| 124 |
+
def test_cache_miss(self):
|
| 125 |
+
cache = GenerationCache(ttl_seconds=60)
|
| 126 |
+
assert cache.get({"a": 1}, "uart") is None
|
| 127 |
+
|
| 128 |
+
def test_cache_invalidate(self, spec_dict):
|
| 129 |
+
cache = GenerationCache(ttl_seconds=60)
|
| 130 |
+
cache.set(spec_dict, "uart", {"f.sv": "content"})
|
| 131 |
+
assert cache.get(spec_dict, "uart") is not None
|
| 132 |
+
cache.invalidate(spec_dict, "uart")
|
| 133 |
+
assert cache.get(spec_dict, "uart") is None
|
| 134 |
+
|
| 135 |
+
def test_cache_clear(self):
|
| 136 |
+
cache = GenerationCache(ttl_seconds=60)
|
| 137 |
+
cache.set({"a": 1}, "uart", {"f.sv": "c"})
|
| 138 |
+
cache.set({"b": 2}, "spi", {"g.sv": "d"})
|
| 139 |
+
cache.clear()
|
| 140 |
+
assert cache.get({"a": 1}, "uart") is None
|
| 141 |
+
assert cache.get({"b": 2}, "spi") is None
|
| 142 |
+
|
| 143 |
+
def test_cache_max_entries_eviction(self):
|
| 144 |
+
cache = GenerationCache(ttl_seconds=3600, max_entries=3)
|
| 145 |
+
for i in range(5):
|
| 146 |
+
cache.set({"k": i}, "p", {f"f{i}.sv": str(i)})
|
| 147 |
+
assert len(cache._cache) <= 3
|
| 148 |
+
|
| 149 |
+
|
| 150 |
+
# ---------------------------------------------------------------------------
|
| 151 |
+
# Model construction tests
|
| 152 |
+
# ---------------------------------------------------------------------------
|
| 153 |
+
|
| 154 |
+
|
| 155 |
+
class TestModelConstruction:
|
| 156 |
+
def test_create_with_config(self, base_cfg):
|
| 157 |
+
model = EnhancedMLGenerationModelV2(base_cfg)
|
| 158 |
+
assert model is not None
|
| 159 |
+
assert model._use_learning is True
|
| 160 |
+
|
| 161 |
+
def test_create_with_string_name(self):
|
| 162 |
+
model = EnhancedMLGenerationModelV2("test_model")
|
| 163 |
+
assert model is not None
|
| 164 |
+
|
| 165 |
+
def test_create_with_rl_strategies(self):
|
| 166 |
+
for strategy in ["epsilon_greedy", "softmax", "ucb", "thompson"]:
|
| 167 |
+
cfg = PipelineConfig(
|
| 168 |
+
ml=MLConfig(
|
| 169 |
+
enabled=True, model_type="v2", exploration_strategy=strategy,
|
| 170 |
+
use_llm=False, use_semantic_encoder=False, use_learning=True,
|
| 171 |
+
)
|
| 172 |
+
)
|
| 173 |
+
model = EnhancedMLGenerationModelV2(cfg)
|
| 174 |
+
assert model is not None
|
| 175 |
+
|
| 176 |
+
|
| 177 |
+
# ---------------------------------------------------------------------------
|
| 178 |
+
# Generation tests
|
| 179 |
+
# ---------------------------------------------------------------------------
|
| 180 |
+
|
| 181 |
+
|
| 182 |
+
class TestGeneration:
|
| 183 |
+
def test_generate_returns_passed_result(self, spec_dict, base_cfg):
|
| 184 |
+
model = EnhancedMLGenerationModelV2(base_cfg)
|
| 185 |
result = model.generate(spec_dict)
|
| 186 |
+
assert "passed" in result
|
| 187 |
+
assert "generated_files" in result
|
| 188 |
+
assert "source" in result
|
| 189 |
+
assert "strategy" in result
|
| 190 |
+
assert "request_id" in result
|
| 191 |
+
|
| 192 |
+
def test_generate_produces_files(self, spec_dict, base_cfg):
|
| 193 |
+
model = EnhancedMLGenerationModelV2(base_cfg)
|
| 194 |
+
result = model.generate(spec_dict)
|
| 195 |
+
assert len(result["generated_files"]) > 0
|
| 196 |
+
|
| 197 |
+
def test_generate_with_request_id(self, spec_dict, base_cfg):
|
| 198 |
+
model = EnhancedMLGenerationModelV2(base_cfg)
|
| 199 |
+
result = model.generate(spec_dict, request_id="test_req_001")
|
| 200 |
+
assert result["request_id"] == "test_req_001"
|
| 201 |
+
|
| 202 |
+
def test_generate_invalid_spec_returns_error(self, base_cfg):
|
| 203 |
+
model = EnhancedMLGenerationModelV2(base_cfg)
|
| 204 |
+
result = model.generate({"no_design_name": True})
|
| 205 |
+
assert result["passed"] is False
|
| 206 |
+
|
| 207 |
+
def test_generate_empty_design_name_returns_error(self, base_cfg):
|
| 208 |
+
model = EnhancedMLGenerationModelV2(base_cfg)
|
| 209 |
+
result = model.generate({"design_name": "", "protocol": "uart"})
|
| 210 |
+
assert result["passed"] is False
|
| 211 |
+
|
| 212 |
+
|
| 213 |
+
# ---------------------------------------------------------------------------
|
| 214 |
+
# Learning / RL tests
|
| 215 |
+
# ---------------------------------------------------------------------------
|
| 216 |
+
|
| 217 |
+
|
| 218 |
+
class TestLearning:
|
| 219 |
+
def test_learn_updates_rl(self, spec_dict, base_cfg):
|
| 220 |
+
model = EnhancedMLGenerationModelV2(base_cfg)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 221 |
result = model.generate(spec_dict)
|
| 222 |
+
reward = 1.0 if result["passed"] else 0.0
|
| 223 |
model.learn(result, reward)
|
| 224 |
+
stats = model.get_learning_stats()
|
| 225 |
+
assert stats["total_generations"] >= 1
|
| 226 |
+
|
| 227 |
+
def test_multiple_generations_populate_replay_buffer(self, spec_dict, base_cfg):
|
| 228 |
+
model = EnhancedMLGenerationModelV2(base_cfg)
|
| 229 |
+
for _ in range(3):
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 230 |
result = model.generate(spec_dict)
|
| 231 |
+
model.learn(result, 1.0 if result["passed"] else 0.0)
|
| 232 |
+
if model._rl_learner:
|
| 233 |
+
assert len(model._rl_learner._replay_buffer) > 0
|
| 234 |
+
|
| 235 |
+
def test_learning_stats_structure(self, spec_dict, base_cfg):
|
| 236 |
+
model = EnhancedMLGenerationModelV2(base_cfg)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 237 |
result = model.generate(spec_dict)
|
| 238 |
+
model.learn(result, 1.0)
|
|
|
|
|
|
|
|
|
|
| 239 |
stats = model.get_learning_stats()
|
| 240 |
+
assert "total_generations" in stats
|
| 241 |
+
assert "model_version" in stats
|
| 242 |
+
assert "metrics" in stats
|
| 243 |
+
assert "strategy_weights" in stats
|
| 244 |
+
|
| 245 |
+
def test_learning_persistence(self, spec_dict, base_cfg):
|
| 246 |
+
with tempfile.NamedTemporaryFile(mode="w", suffix=".json", delete=False) as f:
|
| 247 |
+
path = f.name
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 248 |
try:
|
| 249 |
+
base_cfg.ml.learning_storage_path = path
|
| 250 |
+
model = EnhancedMLGenerationModelV2(base_cfg)
|
| 251 |
+
for _ in range(3):
|
| 252 |
+
r = model.generate(spec_dict)
|
| 253 |
+
model.learn(r, 1.0)
|
| 254 |
+
model.save_learning_state(path)
|
| 255 |
+
model2 = EnhancedMLGenerationModelV2(base_cfg)
|
| 256 |
+
model2.load_learning_state(path)
|
| 257 |
+
assert model2._generation_history
|
| 258 |
+
finally:
|
| 259 |
+
if os.path.exists(path):
|
| 260 |
+
os.unlink(path)
|
| 261 |
+
|
| 262 |
+
|
| 263 |
+
# ---------------------------------------------------------------------------
|
| 264 |
+
# Health / monitoring tests
|
| 265 |
+
# ---------------------------------------------------------------------------
|
| 266 |
+
|
| 267 |
+
|
| 268 |
+
class TestHealth:
|
| 269 |
+
def test_health_status_returns_dict(self, base_cfg):
|
| 270 |
+
model = EnhancedMLGenerationModelV2(base_cfg)
|
| 271 |
+
health = model.get_health_status()
|
| 272 |
+
assert isinstance(health, dict)
|
| 273 |
+
assert "status" in health
|
| 274 |
+
assert "components" in health
|
| 275 |
+
assert "version" in health
|
| 276 |
+
|
| 277 |
+
def test_health_components_are_bools(self, base_cfg):
|
| 278 |
+
model = EnhancedMLGenerationModelV2(base_cfg)
|
| 279 |
+
health = model.get_health_status()
|
| 280 |
+
for comp, ok in health["components"].items():
|
| 281 |
+
assert isinstance(ok, bool), f"{comp} should be bool"
|
| 282 |
+
|
| 283 |
+
def test_cache_invalidate(self, spec_dict, base_cfg):
|
| 284 |
+
model = EnhancedMLGenerationModelV2(base_cfg)
|
| 285 |
+
model.generate(spec_dict) # populates cache
|
| 286 |
+
model.invalidate_cache()
|
| 287 |
+
assert model._cache is None or len(model._cache._cache) == 0
|
| 288 |
+
|
| 289 |
+
|
| 290 |
+
# ---------------------------------------------------------------------------
|
| 291 |
+
# Edge case / resilience tests
|
| 292 |
+
# ---------------------------------------------------------------------------
|
| 293 |
+
|
| 294 |
+
|
| 295 |
+
class TestResilience:
|
| 296 |
+
def test_generate_twice_with_same_spec_hits_cache(self, spec_dict, base_cfg):
|
| 297 |
+
model = EnhancedMLGenerationModelV2(base_cfg)
|
| 298 |
+
r1 = model.generate(spec_dict)
|
| 299 |
+
r2 = model.generate(spec_dict)
|
| 300 |
+
assert r1["passed"] == r2["passed"]
|
| 301 |
+
|
| 302 |
+
def test_clear_history(self, spec_dict, base_cfg):
|
| 303 |
+
model = EnhancedMLGenerationModelV2(base_cfg)
|
| 304 |
+
model.generate(spec_dict)
|
| 305 |
+
model.clear_history()
|
| 306 |
+
stats = model.get_learning_stats()
|
| 307 |
+
assert stats["total_generations"] == 0
|
| 308 |
+
|
| 309 |
+
def test_all_rl_strategies_generate(self, spec_dict):
|
| 310 |
+
strategies = ["epsilon_greedy", "softmax", "ucb", "thompson"]
|
| 311 |
+
for strategy in strategies:
|
| 312 |
+
cfg = PipelineConfig(
|
| 313 |
+
ml=MLConfig(
|
| 314 |
+
enabled=True, model_type="v2", exploration_strategy=strategy,
|
| 315 |
+
use_llm=False, use_semantic_encoder=False, use_learning=True,
|
| 316 |
+
)
|
| 317 |
+
)
|
| 318 |
+
model = EnhancedMLGenerationModelV2(cfg)
|
| 319 |
+
result = model.generate(spec_dict)
|
| 320 |
+
assert result["passed"], f"{strategy} strategy failed"
|
| 321 |
+
|
| 322 |
+
def test_generate_without_advanced_components_falls_back(self, spec_dict, base_cfg):
|
| 323 |
+
model = EnhancedMLGenerationModelV2(base_cfg)
|
| 324 |
+
model._index = None
|
| 325 |
+
model._extractor = None
|
| 326 |
+
model._adapter = None
|
| 327 |
+
result = model.generate(spec_dict)
|
| 328 |
+
assert result["passed"] or not result["passed"] # should not crash
|
| 329 |
+
|
| 330 |
|
| 331 |
if __name__ == "__main__":
|
| 332 |
+
pytest.main([__file__, "-v", "--tb=short"])
|
|
|