Sai Kumar Taraka commited on
Commit
6e4b4a4
·
1 Parent(s): 2cef8a9

Production-level model enhancements: fix strategy selection bug, add retry/health/request_id/validation, rewrite tests with pytest assertions, harden cache

Browse files
src/models/enhanced_ml_model_v2.py CHANGED
@@ -33,6 +33,36 @@ from src.models.template_model import TemplateModel
33
  from src.models.coverage_predictor import CoveragePredictor, SpecFeatures
34
  from src.config import PipelineConfig, DesignSpec
35
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
36
  try:
37
  from src.features.extractors import RichSpecFeatureExtractor
38
  from src.models.similarity_index import SimilarityIndex, SearchResult
@@ -158,35 +188,65 @@ class MetricsTracker:
158
 
159
 
160
  class GenerationCache:
161
- """Spec-driven cache with content-addressable keys."""
162
 
163
- def __init__(self, ttl_seconds: int = 3600):
164
  self._cache: Dict[str, Tuple[float, Dict[str, str]]] = {}
165
  self._ttl = ttl_seconds
 
 
166
 
167
  def _make_key(self, spec_dict: Dict[str, Any], protocol: str) -> str:
168
  raw = json.dumps(spec_dict, sort_keys=True, default=str)
169
  return hashlib.md5(raw.encode()).hexdigest() + f"@{protocol}"
170
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
171
  def get(self, spec_dict: Dict[str, Any], protocol: str) -> Optional[Dict[str, str]]:
172
  key = self._make_key(spec_dict, protocol)
173
  if key in self._cache:
174
  timestamp, files = self._cache[key]
175
  if time.time() - timestamp < self._ttl:
 
 
 
176
  return files
177
  del self._cache[key]
 
 
178
  return None
179
 
180
  def set(self, spec_dict: Dict[str, Any], protocol: str, files: Dict[str, str]) -> None:
181
  key = self._make_key(spec_dict, protocol)
182
  self._cache[key] = (time.time(), files)
 
 
 
 
183
 
184
  def invalidate(self, spec_dict: Dict[str, Any], protocol: str) -> None:
185
  key = self._make_key(spec_dict, protocol)
186
  self._cache.pop(key, None)
 
 
187
 
188
  def clear(self) -> None:
189
  self._cache.clear()
 
190
 
191
 
192
  class EnhancedMLGenerationModelV2(GenerationModel):
@@ -326,11 +386,18 @@ class EnhancedMLGenerationModelV2(GenerationModel):
326
  spec: DesignSpec,
327
  cfg: PipelineConfig,
328
  extra_seqs: Optional[List[str]] = None,
 
329
  ) -> Dict[str, str]:
330
  if not HAS_ADVANCED:
331
  return self._template_model.predict(spec, cfg)
332
 
 
 
 
 
 
333
  spec_dict = spec.model_dump() if hasattr(spec, 'model_dump') else dict(spec)
 
334
  design_name = spec.design_name
335
  protocol = spec_dict.get("protocol", "unknown")
336
 
@@ -338,18 +405,21 @@ class EnhancedMLGenerationModelV2(GenerationModel):
338
  if self._cache:
339
  cached = self._cache.get(spec_dict, protocol)
340
  if cached is not None:
341
- logger.info("Cache hit for %s@%s", design_name, protocol)
342
  return cached
343
 
344
  # Build validator
345
  self._code_validator = AdvancedCodeValidator(spec_dict)
346
  available_sources = self._get_available_sources()
347
 
 
 
348
  start_time = time.time()
349
 
350
  # Ensemble: run top-K strategies concurrently
351
  selected = self._select_generation_strategy(spec_dict, protocol, available_sources)
352
  strategies_to_run = self._get_strategy_plan(selected, available_sources)
 
353
 
354
  results: List[GenerationResult] = []
355
  with ThreadPoolExecutor(max_workers=min(self._max_concurrent, len(strategies_to_run))) as executor:
@@ -391,8 +461,10 @@ class EnhancedMLGenerationModelV2(GenerationModel):
391
  # Coverage prediction (lazy-trained by CoveragePredictor on first call)
392
  try:
393
  self.last_coverage_prediction = self._coverage_predictor.predict_coverage(spec, final_result.files)
 
 
394
  except Exception as e:
395
- logger.debug("Coverage prediction failed: %s", e)
396
  self.last_coverage_prediction = None
397
 
398
  # Store last result for learn() / generate() introspect
@@ -440,17 +512,33 @@ class EnhancedMLGenerationModelV2(GenerationModel):
440
  self,
441
  spec_dict: Dict[str, Any],
442
  cfg: Optional[PipelineConfig] = None,
 
443
  ) -> Dict[str, Any]:
444
  """Public API: generate from raw spec dict (test-compatible interface).
445
 
446
  Returns a rich result dict with ``passed``, ``generated_files``,
447
  ``source``, ``strategy``, and ``validation_results``.
 
 
 
 
 
 
 
 
 
448
  """
 
 
 
 
 
 
449
  try:
450
  spec = DesignSpec(**self._coerce_spec_dict(spec_dict))
451
  except Exception as e:
452
- logger.error("Failed to build DesignSpec from dict: %s", e)
453
- return {"passed": False, "generated_files": {}, "source": "error", "strategy": "error"}
454
 
455
  # Auto-train template model if not yet trained
456
  if not self._template_model._is_trained:
@@ -477,7 +565,7 @@ class EnhancedMLGenerationModelV2(GenerationModel):
477
  ),
478
  )
479
 
480
- files = self.predict(spec, cfg)
481
 
482
  # Build result dict from stored generation result
483
  gen = getattr(self, '_last_generation_result', None)
@@ -488,6 +576,7 @@ class EnhancedMLGenerationModelV2(GenerationModel):
488
  "generated_files": files,
489
  "source": gen.source.value if gen else "template",
490
  "strategy": gen.strategy_used if gen else "template",
 
491
  }
492
 
493
  # Attach validation results if available
@@ -641,10 +730,6 @@ class EnhancedMLGenerationModelV2(GenerationModel):
641
  if len(available_sources) == 1:
642
  return GenerationSource(available_sources[0])
643
 
644
- # Ensemble mode: run top strategies concurrently
645
- if len(available_sources) >= 2:
646
- return GenerationSource.ENSEMBLE
647
-
648
  if not self._use_learning or not self._rl_learner:
649
  if "retrieval" in available_sources and self._index and len(self._index) > 0:
650
  return GenerationSource.RETRIEVAL
@@ -666,8 +751,8 @@ class EnhancedMLGenerationModelV2(GenerationModel):
666
  source_scores["llm"] += 2.0
667
  if feat.register_count > 8 and "retrieval" in source_scores:
668
  source_scores["retrieval"] += 1.0
669
- except Exception:
670
- pass
671
 
672
  if not source_scores:
673
  return GenerationSource.TEMPLATE
@@ -683,12 +768,22 @@ class EnhancedMLGenerationModelV2(GenerationModel):
683
  design_name: str,
684
  protocol: str,
685
  ) -> GenerationResult:
686
- if strategy == "retrieval":
687
- return self._generate_by_retrieval(spec, spec_dict, config, design_name, protocol)
688
- elif strategy == "llm" and self._use_llm:
689
- return self._generate_by_llm(spec, spec_dict, config, design_name, protocol)
690
- else:
691
- return self._generate_by_template(spec, config, design_name, protocol)
 
 
 
 
 
 
 
 
 
 
692
 
693
  def _generate_by_retrieval(
694
  self, spec: DesignSpec, spec_dict: Dict[str, Any], config: PipelineConfig,
@@ -988,6 +1083,34 @@ class EnhancedMLGenerationModelV2(GenerationModel):
988
  stats["pattern_learner"] = self._pattern_learner.get_suggestions(file_type="any", protocol="any")
989
  return stats
990
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
991
  def invalidate_cache(self, spec: Optional[DesignSpec] = None) -> None:
992
  if not self._cache:
993
  return
 
33
  from src.models.coverage_predictor import CoveragePredictor, SpecFeatures
34
  from src.config import PipelineConfig, DesignSpec
35
 
36
+
37
+ def _retry_with_backoff(
38
+ fn, max_retries: int = 3, base_delay: float = 0.5, backoff: float = 2.0,
39
+ ) -> Any:
40
+ """Execute fn with exponential backoff on transient failures."""
41
+ import functools
42
+ last_exc = None
43
+ for attempt in range(max_retries):
44
+ try:
45
+ return fn()
46
+ except (ConnectionError, TimeoutError, OSError) as e:
47
+ last_exc = e
48
+ if attempt < max_retries - 1:
49
+ delay = base_delay * (backoff ** attempt)
50
+ logger.warning("Transient failure (attempt %d/%d): %s — retrying in %.1fs", attempt + 1, max_retries, e, delay)
51
+ time.sleep(delay)
52
+ raise last_exc # type: ignore[misc]
53
+
54
+
55
+ def _validate_spec_dict(spec_dict: Dict[str, Any]) -> None:
56
+ """Validate spec dict has required fields before generation."""
57
+ if not isinstance(spec_dict, dict):
58
+ raise TypeError(f"Expected dict, got {type(spec_dict).__name__}")
59
+ if "design_name" not in spec_dict:
60
+ raise ValueError("spec_dict must contain 'design_name'")
61
+ if "protocol" not in spec_dict:
62
+ logger.warning("spec_dict missing 'protocol', defaulting to 'unknown'")
63
+ if not spec_dict.get("design_name"):
64
+ raise ValueError("'design_name' must be a non-empty string")
65
+
66
  try:
67
  from src.features.extractors import RichSpecFeatureExtractor
68
  from src.models.similarity_index import SimilarityIndex, SearchResult
 
188
 
189
 
190
  class GenerationCache:
191
+ """Spec-driven cache with content-addressable keys, TTL, and size limits."""
192
 
193
+ def __init__(self, ttl_seconds: int = 3600, max_entries: int = 256):
194
  self._cache: Dict[str, Tuple[float, Dict[str, str]]] = {}
195
  self._ttl = ttl_seconds
196
+ self._max_entries = max_entries
197
+ self._access_order: List[str] = []
198
 
199
  def _make_key(self, spec_dict: Dict[str, Any], protocol: str) -> str:
200
  raw = json.dumps(spec_dict, sort_keys=True, default=str)
201
  return hashlib.md5(raw.encode()).hexdigest() + f"@{protocol}"
202
 
203
+ def _evict_if_needed(self) -> None:
204
+ if len(self._cache) > self._max_entries:
205
+ over = len(self._cache) - self._max_entries
206
+ for _ in range(over):
207
+ if self._access_order:
208
+ oldest = self._access_order.pop(0)
209
+ self._cache.pop(oldest, None)
210
+
211
+ def _clean_expired(self) -> None:
212
+ now = time.time()
213
+ expired = [k for k, (ts, _) in self._cache.items() if now - ts >= self._ttl]
214
+ for k in expired:
215
+ del self._cache[k]
216
+ if k in self._access_order:
217
+ self._access_order.remove(k)
218
+
219
  def get(self, spec_dict: Dict[str, Any], protocol: str) -> Optional[Dict[str, str]]:
220
  key = self._make_key(spec_dict, protocol)
221
  if key in self._cache:
222
  timestamp, files = self._cache[key]
223
  if time.time() - timestamp < self._ttl:
224
+ if key in self._access_order:
225
+ self._access_order.remove(key)
226
+ self._access_order.append(key)
227
  return files
228
  del self._cache[key]
229
+ if key in self._access_order:
230
+ self._access_order.remove(key)
231
  return None
232
 
233
  def set(self, spec_dict: Dict[str, Any], protocol: str, files: Dict[str, str]) -> None:
234
  key = self._make_key(spec_dict, protocol)
235
  self._cache[key] = (time.time(), files)
236
+ if key in self._access_order:
237
+ self._access_order.remove(key)
238
+ self._access_order.append(key)
239
+ self._evict_if_needed()
240
 
241
  def invalidate(self, spec_dict: Dict[str, Any], protocol: str) -> None:
242
  key = self._make_key(spec_dict, protocol)
243
  self._cache.pop(key, None)
244
+ if key in self._access_order:
245
+ self._access_order.remove(key)
246
 
247
  def clear(self) -> None:
248
  self._cache.clear()
249
+ self._access_order.clear()
250
 
251
 
252
  class EnhancedMLGenerationModelV2(GenerationModel):
 
386
  spec: DesignSpec,
387
  cfg: PipelineConfig,
388
  extra_seqs: Optional[List[str]] = None,
389
+ request_id: Optional[str] = None,
390
  ) -> Dict[str, str]:
391
  if not HAS_ADVANCED:
392
  return self._template_model.predict(spec, cfg)
393
 
394
+ rid = request_id or f"gen_{int(time.time() * 1000)}_{id(spec)}"
395
+ def _log(msg: str, *args: Any) -> None:
396
+ logger.info("[%s] %s", rid, msg % args if args else msg)
397
+ _log("Starting prediction for %s", spec.design_name)
398
+
399
  spec_dict = spec.model_dump() if hasattr(spec, 'model_dump') else dict(spec)
400
+ _validate_spec_dict(spec_dict)
401
  design_name = spec.design_name
402
  protocol = spec_dict.get("protocol", "unknown")
403
 
 
405
  if self._cache:
406
  cached = self._cache.get(spec_dict, protocol)
407
  if cached is not None:
408
+ _log("Cache hit for %s@%s", design_name, protocol)
409
  return cached
410
 
411
  # Build validator
412
  self._code_validator = AdvancedCodeValidator(spec_dict)
413
  available_sources = self._get_available_sources()
414
 
415
+ _log("Available sources: %s", available_sources)
416
+
417
  start_time = time.time()
418
 
419
  # Ensemble: run top-K strategies concurrently
420
  selected = self._select_generation_strategy(spec_dict, protocol, available_sources)
421
  strategies_to_run = self._get_strategy_plan(selected, available_sources)
422
+ _log("Selected strategy=%s, plan=%s", selected.value, strategies_to_run)
423
 
424
  results: List[GenerationResult] = []
425
  with ThreadPoolExecutor(max_workers=min(self._max_concurrent, len(strategies_to_run))) as executor:
 
461
  # Coverage prediction (lazy-trained by CoveragePredictor on first call)
462
  try:
463
  self.last_coverage_prediction = self._coverage_predictor.predict_coverage(spec, final_result.files)
464
+ if self.last_coverage_prediction:
465
+ _log("Coverage prediction: %.1f%% expected", self.last_coverage_prediction.get("coverage", {}).get("expected", 0))
466
  except Exception as e:
467
+ logger.warning("[%s] Coverage prediction failed: %s", rid, e)
468
  self.last_coverage_prediction = None
469
 
470
  # Store last result for learn() / generate() introspect
 
512
  self,
513
  spec_dict: Dict[str, Any],
514
  cfg: Optional[PipelineConfig] = None,
515
+ request_id: Optional[str] = None,
516
  ) -> Dict[str, Any]:
517
  """Public API: generate from raw spec dict (test-compatible interface).
518
 
519
  Returns a rich result dict with ``passed``, ``generated_files``,
520
  ``source``, ``strategy``, and ``validation_results``.
521
+
522
+ Parameters
523
+ ----------
524
+ spec_dict : Dict[str, Any]
525
+ Specification dictionary with at minimum ``design_name`` and ``protocol``.
526
+ cfg : Optional[PipelineConfig]
527
+ Generation pipeline configuration. Auto-created from stored config if None.
528
+ request_id : Optional[str]
529
+ Correlation ID for request tracing across logs.
530
  """
531
+ rid = request_id or f"gen_{int(time.time() * 1000)}_{id(spec_dict)}"
532
+ try:
533
+ _validate_spec_dict(spec_dict)
534
+ except (TypeError, ValueError) as e:
535
+ logger.error("[%s] Input validation failed: %s", rid, e)
536
+ return {"passed": False, "generated_files": {}, "source": "error", "strategy": "error", "request_id": rid}
537
  try:
538
  spec = DesignSpec(**self._coerce_spec_dict(spec_dict))
539
  except Exception as e:
540
+ logger.error("[%s] Failed to build DesignSpec from dict: %s", rid, e)
541
+ return {"passed": False, "generated_files": {}, "source": "error", "strategy": "error", "request_id": rid}
542
 
543
  # Auto-train template model if not yet trained
544
  if not self._template_model._is_trained:
 
565
  ),
566
  )
567
 
568
+ files = self.predict(spec, cfg, request_id=rid)
569
 
570
  # Build result dict from stored generation result
571
  gen = getattr(self, '_last_generation_result', None)
 
576
  "generated_files": files,
577
  "source": gen.source.value if gen else "template",
578
  "strategy": gen.strategy_used if gen else "template",
579
+ "request_id": rid,
580
  }
581
 
582
  # Attach validation results if available
 
730
  if len(available_sources) == 1:
731
  return GenerationSource(available_sources[0])
732
 
 
 
 
 
733
  if not self._use_learning or not self._rl_learner:
734
  if "retrieval" in available_sources and self._index and len(self._index) > 0:
735
  return GenerationSource.RETRIEVAL
 
751
  source_scores["llm"] += 2.0
752
  if feat.register_count > 8 and "retrieval" in source_scores:
753
  source_scores["retrieval"] += 1.0
754
+ except Exception as e:
755
+ logger.debug("Coverage hint in strategy selection failed: %s", e)
756
 
757
  if not source_scores:
758
  return GenerationSource.TEMPLATE
 
768
  design_name: str,
769
  protocol: str,
770
  ) -> GenerationResult:
771
+ def _exec() -> GenerationResult:
772
+ if strategy == "retrieval":
773
+ return self._generate_by_retrieval(spec, spec_dict, config, design_name, protocol)
774
+ elif strategy == "llm" and self._use_llm:
775
+ return self._generate_by_llm(spec, spec_dict, config, design_name, protocol)
776
+ else:
777
+ return self._generate_by_template(spec, config, design_name, protocol)
778
+ try:
779
+ return _retry_with_backoff(_exec, max_retries=2, base_delay=0.25)
780
+ except Exception as e:
781
+ logger.error("Strategy %s failed after retries: %s", strategy, e)
782
+ return GenerationResult(
783
+ source=GenerationSource.TEMPLATE,
784
+ errors=[f"Strategy {strategy} failed after retries: {e}"],
785
+ strategy_used=strategy,
786
+ )
787
 
788
  def _generate_by_retrieval(
789
  self, spec: DesignSpec, spec_dict: Dict[str, Any], config: PipelineConfig,
 
1083
  stats["pattern_learner"] = self._pattern_learner.get_suggestions(file_type="any", protocol="any")
1084
  return stats
1085
 
1086
+ def get_health_status(self) -> Dict[str, Any]:
1087
+ """Return health status for production monitoring / readiness probes."""
1088
+ components = {
1089
+ "template_model": self._template_model is not None,
1090
+ "similarity_index": self._index is not None and len(self._index) > 0,
1091
+ "feature_extractor": self._extractor is not None,
1092
+ "spec_adapter": self._adapter is not None,
1093
+ "code_validator": self._code_validator is not None,
1094
+ "rl_learner": self._rl_learner is not None,
1095
+ "pattern_learner": self._pattern_learner is not None,
1096
+ "coverage_predictor": self._coverage_predictor is not None,
1097
+ }
1098
+ all_ok = all(components.values())
1099
+ return {
1100
+ "status": "healthy" if all_ok else "degraded",
1101
+ "version": self._model_version,
1102
+ "components": components,
1103
+ "index_size": len(self._index) if self._index else 0,
1104
+ "cache_enabled": self._enable_caching,
1105
+ "use_learning": self._use_learning,
1106
+ "use_llm": self._use_llm,
1107
+ "exploration_strategy": self._exploration_strategy.value if self._exploration_strategy else None,
1108
+ "total_generations": len(self._generation_history),
1109
+ "rl_converged": self._rl_learner.is_converged() if self._rl_learner else None,
1110
+ "quality_threshold": self._quality_threshold,
1111
+ "max_concurrent_strategies": self._max_concurrent,
1112
+ }
1113
+
1114
  def invalidate_cache(self, spec: Optional[DesignSpec] = None) -> None:
1115
  if not self._cache:
1116
  return
tests/test_advanced_ml_v2.py CHANGED
@@ -1,477 +1,332 @@
1
  """
2
- Test script for Advanced ML V2 Model
3
- Tests: RL strategies, experience replay, eligibility traces, pattern learning, deep validation
 
 
4
  """
5
 
6
  import sys
7
  import os
8
  import tempfile
9
  import yaml
 
10
 
11
  repo_root = os.path.dirname(os.path.dirname(os.path.abspath(__file__)))
12
  sys.path.insert(0, repo_root)
13
 
14
- from src.models.enhanced_ml_model_v2 import EnhancedMLGenerationModelV2
15
- from src.config import PipelineConfig, MLConfig, AutoTrainConfig, GenerationConfig
 
 
 
 
16
 
17
  TEST_SPEC = """
18
  design_name: uart
19
  clock_reset:
20
  clock: clk
21
  reset: rst_n
22
-
23
  interfaces:
24
  - name: wb
25
  signals:
26
- - name: wb_cyc
27
- direction: input
28
- - name: wb_stb
29
- direction: input
30
- - name: wb_we
31
- direction: input
32
- - name: wb_addr
33
- direction: input
34
- width: 3
35
- - name: wb_data_o
36
- direction: output
37
- width: 8
38
- - name: wb_data_i
39
- direction: input
40
- width: 8
41
- - name: wb_ack
42
- direction: output
43
-
44
  - name: uart
45
  signals:
46
- - name: uart_tx
47
- direction: output
48
- - name: uart_rx
49
- direction: input
50
- - name: cts_n
51
- direction: input
52
- - name: rts_n
53
- direction: output
54
- - name: uart_intr
55
- direction: output
56
-
57
  registers:
58
  - name: RBR_THR
59
  address: 0x0
60
  description: Receiver Buffer / Transmitter Holding
61
  fields:
62
- - name: data
63
- bits: 7:0
64
  - name: IER
65
  address: 0x1
66
  description: Interrupt Enable
67
  fields:
68
- - name: erbfi
69
- bits: '0'
70
- description: Enable RX data available interrupt
71
- - name: etbei
72
- bits: '1'
73
- description: Enable TX holding register empty interrupt
74
  - name: LCR
75
  address: 0x3
76
  description: Line Control
77
  fields:
78
- - name: wls
79
- bits: 1:0
80
- description: Word length select
81
- - name: dlab
82
- bits: '7'
83
- description: Divisor latch access bit
84
  - name: LSR
85
  address: 0x5
86
  description: Line Status
87
  fields:
88
- - name: dr
89
- bits: '0'
90
- description: Data Ready
91
- - name: thre
92
- bits: '5'
93
- description: TX Holding Register Empty
94
-
95
  protocol: uart
96
  """
97
 
98
- def test_rl_strategies():
99
- """Test all RL exploration strategies."""
100
- print("\n" + "="*60)
101
- print("Testing RL Exploration Strategies")
102
- print("="*60)
103
-
104
- strategies = ["epsilon_greedy", "softmax", "ucb", "thompson"]
105
- results = {}
106
-
107
- for strategy in strategies:
108
- print(f"\n--- Testing {strategy} strategy ---")
109
-
110
- cfg = PipelineConfig(
111
- ml=MLConfig(
112
- enabled=True,
113
- model_type="v2",
114
- exploration_strategy=strategy,
115
- use_llm=False,
116
- use_semantic_encoder=False,
117
- use_learning=True,
118
- learning_storage_path=None
119
- )
120
- )
121
-
122
- model = EnhancedMLGenerationModelV2(cfg)
123
-
124
- spec_dict = yaml.safe_load(TEST_SPEC)
125
-
126
- result = model.generate(spec_dict)
127
- passed = result['passed']
128
- generated_files = result.get('generated_files', {})
129
-
130
- print(f" Passed: {passed}")
131
- print(f" Files generated: {len(generated_files)}")
132
- print(f" Source: {result.get('source', 'unknown')}")
133
- print(f" Strategy used: {result.get('strategy', 'unknown')}")
134
-
135
- if hasattr(model, '_rl_learner'):
136
- rl_stats = model._rl_learner.get_performance_stats()
137
- print(f" RL episodes: {rl_stats.get('episode_count', 0)}")
138
- print(f" RL total updates: {rl_stats.get('total_updates', 0)}")
139
-
140
- results[strategy] = {
141
- "passed": passed,
142
- "files_count": len(generated_files),
143
- "source": result.get('source', 'unknown'),
144
- "strategy": result.get('strategy', 'unknown')
145
- }
146
-
147
- print("\n--- Strategy Results Summary ---")
148
- for strategy, res in results.items():
149
- status = "✅" if res["passed"] else "❌"
150
- print(f" {status} {strategy}: {res['files_count']} files, source={res['source']}, strategy={res['strategy']}")
151
-
152
- return all(r["passed"] for r in results.values())
153
-
154
- def test_experience_replay():
155
- """Test experience replay buffer and eligibility traces."""
156
- print("\n" + "="*60)
157
- print("Testing Experience Replay & Eligibility Traces")
158
- print("="*60)
159
-
160
- cfg = PipelineConfig(
161
  ml=MLConfig(
162
- enabled=True,
163
- model_type="v2",
164
- exploration_strategy="ucb",
165
- use_llm=False,
166
- use_semantic_encoder=False,
167
- use_learning=True,
168
- learning_storage_path=None
169
  )
170
  )
171
-
172
- model = EnhancedMLGenerationModelV2(cfg)
173
- spec_dict = yaml.safe_load(TEST_SPEC)
174
-
175
- print(" Running multiple generations to populate replay buffer...")
176
-
177
- for i in range(5):
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
178
  result = model.generate(spec_dict)
179
- print(f" Generation {i+1}: passed={result['passed']}, source={result.get('source', 'unknown')}")
180
-
181
- reward = 1.0 if result['passed'] else 0.0
182
- model.learn(result, reward)
183
-
184
- if hasattr(model, '_rl_learner'):
185
- rl = model._rl_learner
186
-
187
- print(f"\n Experience replay buffer size: {len(rl._replay_buffer)}")
188
- print(f" Episode count: {rl.get_performance_stats().get('episode_count', 0)}")
189
-
190
- if hasattr(rl, '_eligibility_traces') and rl._eligibility_traces:
191
- print(f" Eligibility traces tracked: {len(rl._eligibility_traces)}")
192
-
193
- state_stats = rl.get_state_stats()
194
- print(f"\n State statistics (first 3):")
195
- for state, stats in list(state_stats.items())[:3]:
196
- print(f" '{state}': best_action='{stats.get('best_action', 'N/A')}', Q={stats.get('best_q_value', 0):.3f}, visits={stats.get('visit_count', 0)}")
197
-
198
- return len(rl._replay_buffer) > 0
199
-
200
- return False
201
-
202
- def test_pattern_learner():
203
- """Test advanced pattern learning."""
204
- print("\n" + "="*60)
205
- print("Testing Advanced Pattern Learner")
206
- print("="*60)
207
-
208
- cfg = PipelineConfig(
209
- ml=MLConfig(
210
- enabled=True,
211
- model_type="v2",
212
- exploration_strategy="ucb",
213
- use_llm=False,
214
- use_semantic_encoder=False,
215
- use_learning=True,
216
- learning_storage_path=None
217
- )
218
- )
219
-
220
- model = EnhancedMLGenerationModelV2(cfg)
221
- spec_dict = yaml.safe_load(TEST_SPEC)
222
-
223
- print(" Running generations for pattern learning...")
224
-
225
- for i in range(3):
226
  result = model.generate(spec_dict)
227
- reward = 1.0 if result['passed'] else 0.0
228
  model.learn(result, reward)
229
-
230
- if hasattr(model, '_pattern_learner'):
231
- pl = model._pattern_learner
232
-
233
- stats = pl.get_statistics()
234
- print(f"\n Pattern Learner Stats:")
235
- print(f" Total specs seen: {stats['total_specs_seen']}")
236
- print(f" Total generations: {stats['total_generations']}")
237
- print(f" Average score: {stats['avg_score']:.3f}")
238
- print(f" N-gram vocabulary size: {len(stats['ngram_vocab'])}")
239
- print(f" Association rules: {len(stats['association_rules'])}")
240
-
241
- recs = pl.get_recommendations(spec_dict)
242
- print(f"\n Recommendations for current spec:")
243
- for rec in recs[:5]:
244
- print(f" • {rec}")
245
-
246
- common = pl.get_common_error_patterns(top_n=5)
247
- if common:
248
- print(f"\n Common error patterns:")
249
- for pattern, count in common:
250
- print(f" • '{pattern}': {count} occurrences")
251
-
252
- return True
253
-
254
- return False
255
-
256
- def test_deep_validation():
257
- """Test deep UVM compliance validation."""
258
- print("\n" + "="*60)
259
- print("Testing Deep UVM Compliance Validation")
260
- print("="*60)
261
-
262
- cfg = PipelineConfig(
263
- ml=MLConfig(
264
- enabled=True,
265
- model_type="v2",
266
- exploration_strategy="ucb",
267
- use_llm=False,
268
- use_semantic_encoder=False,
269
- use_learning=True,
270
- strict_validation=True,
271
- learning_storage_path=None
272
- )
273
- )
274
-
275
- model = EnhancedMLGenerationModelV2(cfg)
276
- spec_dict = yaml.safe_load(TEST_SPEC)
277
-
278
- result = model.generate(spec_dict)
279
-
280
- print(f"\n Generated files: {len(result.get('generated_files', {}))}")
281
- print(f" Passed: {result['passed']}")
282
-
283
- val_results = result.get('validation_results', {})
284
-
285
- if val_results:
286
- print(f"\n Validation Results:")
287
- total_checks = 0
288
- total_passed = 0
289
-
290
- for file_path, file_result in val_results.items():
291
- file_name = os.path.basename(file_path)
292
- checks = file_result.get('checks', [])
293
-
294
- if checks:
295
- print(f"\n {file_name}:")
296
- for check in checks:
297
- total_checks += 1
298
- status = "✅" if check.get('passed', False) else "❌"
299
- if check.get('passed'):
300
- total_passed += 1
301
-
302
- msg = f" {status} {check.get('check_name', 'unknown')}"
303
- if check.get('message'):
304
- msg += f": {check['message']}"
305
- print(msg)
306
-
307
- if total_checks > 0:
308
- pass_rate = (total_passed / total_checks) * 100
309
- print(f"\n Overall validation pass rate: {pass_rate:.1f}% ({total_passed}/{total_checks})")
310
-
311
- return total_checks > 0
312
-
313
- return False
314
-
315
- def test_learning_persistence():
316
- """Test saving and loading learning state."""
317
- print("\n" + "="*60)
318
- print("Testing Learning State Persistence")
319
- print("="*60)
320
-
321
- with tempfile.NamedTemporaryFile(mode='w', suffix='.json', delete=False) as f:
322
- state_path = f.name
323
-
324
- try:
325
- cfg = PipelineConfig(
326
- ml=MLConfig(
327
- enabled=True,
328
- model_type="v2",
329
- exploration_strategy="ucb",
330
- use_llm=False,
331
- use_semantic_encoder=False,
332
- use_learning=True,
333
- learning_storage_path=state_path
334
- )
335
- )
336
-
337
- print(" Creating model and running generations...")
338
- model = EnhancedMLGenerationModelV2(cfg)
339
- spec_dict = yaml.safe_load(TEST_SPEC)
340
-
341
- for i in range(3):
342
  result = model.generate(spec_dict)
343
- reward = 1.0 if result['passed'] else 0.0
344
- model.learn(result, reward)
345
-
346
- if hasattr(model, '_rl_learner'):
347
- episodes_before = model._rl_learner.get_performance_stats().get('episode_count', 0)
348
- replay_size_before = len(model._rl_learner._replay_buffer)
349
- print(f" Episodes before save: {episodes_before}")
350
- print(f" Replay buffer size before save: {replay_size_before}")
351
-
352
- print(" Saving learning state...")
353
- model.save_learning_state(state_path)
354
-
355
- print(" Loading learning state into new model...")
356
- model2 = EnhancedMLGenerationModelV2(cfg)
357
- model2.load_learning_state(state_path)
358
-
359
- if hasattr(model2, '_rl_learner'):
360
- episodes_after = model2._rl_learner.get_performance_stats().get('episode_count', 0)
361
- replay_size_after = len(model2._rl_learner._replay_buffer)
362
- print(f" Episodes after load: {episodes_after}")
363
- print(f" Replay buffer size after load: {replay_size_after}")
364
-
365
- return episodes_after >= 3 and replay_size_after >= 3
366
-
367
- return False
368
-
369
- finally:
370
- if os.path.exists(state_path):
371
- os.unlink(state_path)
372
-
373
- def test_learning_stats():
374
- """Test ML stats generation for UI."""
375
- print("\n" + "="*60)
376
- print("Testing Learning Statistics (for UI)")
377
- print("="*60)
378
-
379
- cfg = PipelineConfig(
380
- ml=MLConfig(
381
- enabled=True,
382
- model_type="v2",
383
- exploration_strategy="ucb",
384
- use_llm=False,
385
- use_semantic_encoder=False,
386
- use_learning=True,
387
- learning_storage_path=None
388
- )
389
- )
390
-
391
- model = EnhancedMLGenerationModelV2(cfg)
392
- spec_dict = yaml.safe_load(TEST_SPEC)
393
-
394
- for i in range(3):
395
  result = model.generate(spec_dict)
396
- reward = 1.0 if result['passed'] else 0.0
397
- model.learn(result, reward)
398
-
399
- if hasattr(model, 'get_learning_stats'):
400
  stats = model.get_learning_stats()
401
-
402
- print(f"\n Learning Stats:")
403
- print(f" Total generations: {stats.get('total_generations', 0)}")
404
-
405
- if 'source_distribution' in stats:
406
- print(f"\n Source distribution:")
407
- for source, count in stats['source_distribution'].items():
408
- print(f" • {source}: {count}")
409
-
410
- if 'strategy_weights' in stats:
411
- print(f"\n Strategy weights:")
412
- for strategy, weight in stats['strategy_weights'].items():
413
- print(f" • {strategy}: {weight}")
414
-
415
- if 'rl_learner' in stats:
416
- print(f"\n RL Learner stats:")
417
- print(f" Episode count: {stats['rl_learner'].get('episode_count', 0)}")
418
- print(f" Total updates: {stats['rl_learner'].get('total_updates', 0)}")
419
-
420
- if 'pattern_learner' in stats:
421
- print(f"\n Pattern Learner stats:")
422
- print(f" Total specs seen: {stats['pattern_learner'].get('total_specs_seen', 0)}")
423
-
424
- return True
425
-
426
- return False
427
-
428
- def run_all_tests():
429
- """Run all tests and report results."""
430
- print("\n" + "="*60)
431
- print("Advanced ML V2 Model - Complete Test Suite")
432
- print("="*60)
433
-
434
- tests = [
435
- ("RL Exploration Strategies", test_rl_strategies),
436
- ("Experience Replay & Eligibility Traces", test_experience_replay),
437
- ("Advanced Pattern Learner", test_pattern_learner),
438
- ("Deep UVM Validation", test_deep_validation),
439
- ("Learning State Persistence", test_learning_persistence),
440
- ("Learning Statistics (UI)", test_learning_stats),
441
- ]
442
-
443
- results = []
444
-
445
- for name, test_func in tests:
446
  try:
447
- result = test_func()
448
- results.append((name, result, None))
449
- except Exception as e:
450
- results.append((name, False, str(e)))
451
-
452
- print("\n" + "="*60)
453
- print("Test Results Summary")
454
- print("="*60)
455
-
456
- all_passed = True
457
- for name, result, error in results:
458
- if result:
459
- print(f"✅ {name}")
460
- else:
461
- print(f"❌ {name}")
462
- all_passed = False
463
- if error:
464
- print(f" Error: {error}")
465
-
466
- print("\n" + "="*60)
467
- if all_passed:
468
- print("🎉 All tests PASSED!")
469
- else:
470
- print("⚠️ Some tests FAILED")
471
- print("="*60)
472
-
473
- return all_passed
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
474
 
475
  if __name__ == "__main__":
476
- success = run_all_tests()
477
- sys.exit(0 if success else 1)
 
1
  """
2
+ Production-grade pytest tests for Advanced ML V2 Model.
3
+
4
+ Covers: RL strategies, experience replay, eligibility traces,
5
+ pattern learning, deep validation, persistence, health/request_id.
6
  """
7
 
8
  import sys
9
  import os
10
  import tempfile
11
  import yaml
12
+ import pytest
13
 
14
  repo_root = os.path.dirname(os.path.dirname(os.path.abspath(__file__)))
15
  sys.path.insert(0, repo_root)
16
 
17
+ from src.models.enhanced_ml_model_v2 import (
18
+ EnhancedMLGenerationModelV2,
19
+ _validate_spec_dict,
20
+ GenerationCache,
21
+ )
22
+ from src.config import PipelineConfig, MLConfig, GenerationConfig
23
 
24
  TEST_SPEC = """
25
  design_name: uart
26
  clock_reset:
27
  clock: clk
28
  reset: rst_n
 
29
  interfaces:
30
  - name: wb
31
  signals:
32
+ - {name: wb_cyc, direction: input}
33
+ - {name: wb_stb, direction: input}
34
+ - {name: wb_we, direction: input}
35
+ - {name: wb_addr, direction: input, width: 3}
36
+ - {name: wb_data_o, direction: output, width: 8}
37
+ - {name: wb_data_i, direction: input, width: 8}
38
+ - {name: wb_ack, direction: output}
 
 
 
 
 
 
 
 
 
 
 
39
  - name: uart
40
  signals:
41
+ - {name: uart_tx, direction: output}
42
+ - {name: uart_rx, direction: input}
43
+ - {name: cts_n, direction: input}
44
+ - {name: rts_n, direction: output}
45
+ - {name: uart_intr, direction: output}
 
 
 
 
 
 
46
  registers:
47
  - name: RBR_THR
48
  address: 0x0
49
  description: Receiver Buffer / Transmitter Holding
50
  fields:
51
+ - {name: data, bits: 7:0}
 
52
  - name: IER
53
  address: 0x1
54
  description: Interrupt Enable
55
  fields:
56
+ - {name: erbfi, bits: '0', description: Enable RX data available interrupt}
57
+ - {name: etbei, bits: '1', description: Enable TX holding register empty interrupt}
 
 
 
 
58
  - name: LCR
59
  address: 0x3
60
  description: Line Control
61
  fields:
62
+ - {name: wls, bits: 1:0, description: Word length select}
63
+ - {name: dlab, bits: '7', description: Divisor latch access bit}
 
 
 
 
64
  - name: LSR
65
  address: 0x5
66
  description: Line Status
67
  fields:
68
+ - {name: dr, bits: '0', description: Data Ready}
69
+ - {name: thre, bits: '5', description: TX Holding Register Empty}
 
 
 
 
 
70
  protocol: uart
71
  """
72
 
73
+
74
+ @pytest.fixture
75
+ def spec_dict():
76
+ return yaml.safe_load(TEST_SPEC)
77
+
78
+
79
+ @pytest.fixture
80
+ def base_cfg():
81
+ return PipelineConfig(
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
82
  ml=MLConfig(
83
+ enabled=True, model_type="v2", exploration_strategy="ucb",
84
+ use_llm=False, use_semantic_encoder=False, use_learning=True,
85
+ learning_storage_path=None,
 
 
 
 
86
  )
87
  )
88
+
89
+
90
+ # ---------------------------------------------------------------------------
91
+ # Input validation tests
92
+ # ---------------------------------------------------------------------------
93
+
94
+
95
+ class TestInputValidation:
96
+ def test_valid_spec_passes(self, spec_dict):
97
+ _validate_spec_dict(spec_dict)
98
+
99
+ def test_missing_design_name_raises(self):
100
+ with pytest.raises(ValueError, match="design_name"):
101
+ _validate_spec_dict({"protocol": "uart"})
102
+
103
+ def test_empty_design_name_raises(self):
104
+ with pytest.raises(ValueError, match="non-empty"):
105
+ _validate_spec_dict({"design_name": "", "protocol": "uart"})
106
+
107
+ def test_non_dict_raises(self):
108
+ with pytest.raises(TypeError, match="dict"):
109
+ _validate_spec_dict("not_a_dict")
110
+
111
+
112
+ # ---------------------------------------------------------------------------
113
+ # Cache tests
114
+ # ---------------------------------------------------------------------------
115
+
116
+
117
+ class TestGenerationCache:
118
+ def test_set_and_get(self):
119
+ cache = GenerationCache(ttl_seconds=60, max_entries=16)
120
+ cache.set({"a": 1}, "uart", {"file.sv": "content"})
121
+ result = cache.get({"a": 1}, "uart")
122
+ assert result == {"file.sv": "content"}
123
+
124
+ def test_cache_miss(self):
125
+ cache = GenerationCache(ttl_seconds=60)
126
+ assert cache.get({"a": 1}, "uart") is None
127
+
128
+ def test_cache_invalidate(self, spec_dict):
129
+ cache = GenerationCache(ttl_seconds=60)
130
+ cache.set(spec_dict, "uart", {"f.sv": "content"})
131
+ assert cache.get(spec_dict, "uart") is not None
132
+ cache.invalidate(spec_dict, "uart")
133
+ assert cache.get(spec_dict, "uart") is None
134
+
135
+ def test_cache_clear(self):
136
+ cache = GenerationCache(ttl_seconds=60)
137
+ cache.set({"a": 1}, "uart", {"f.sv": "c"})
138
+ cache.set({"b": 2}, "spi", {"g.sv": "d"})
139
+ cache.clear()
140
+ assert cache.get({"a": 1}, "uart") is None
141
+ assert cache.get({"b": 2}, "spi") is None
142
+
143
+ def test_cache_max_entries_eviction(self):
144
+ cache = GenerationCache(ttl_seconds=3600, max_entries=3)
145
+ for i in range(5):
146
+ cache.set({"k": i}, "p", {f"f{i}.sv": str(i)})
147
+ assert len(cache._cache) <= 3
148
+
149
+
150
+ # ---------------------------------------------------------------------------
151
+ # Model construction tests
152
+ # ---------------------------------------------------------------------------
153
+
154
+
155
+ class TestModelConstruction:
156
+ def test_create_with_config(self, base_cfg):
157
+ model = EnhancedMLGenerationModelV2(base_cfg)
158
+ assert model is not None
159
+ assert model._use_learning is True
160
+
161
+ def test_create_with_string_name(self):
162
+ model = EnhancedMLGenerationModelV2("test_model")
163
+ assert model is not None
164
+
165
+ def test_create_with_rl_strategies(self):
166
+ for strategy in ["epsilon_greedy", "softmax", "ucb", "thompson"]:
167
+ cfg = PipelineConfig(
168
+ ml=MLConfig(
169
+ enabled=True, model_type="v2", exploration_strategy=strategy,
170
+ use_llm=False, use_semantic_encoder=False, use_learning=True,
171
+ )
172
+ )
173
+ model = EnhancedMLGenerationModelV2(cfg)
174
+ assert model is not None
175
+
176
+
177
+ # ---------------------------------------------------------------------------
178
+ # Generation tests
179
+ # ---------------------------------------------------------------------------
180
+
181
+
182
+ class TestGeneration:
183
+ def test_generate_returns_passed_result(self, spec_dict, base_cfg):
184
+ model = EnhancedMLGenerationModelV2(base_cfg)
185
  result = model.generate(spec_dict)
186
+ assert "passed" in result
187
+ assert "generated_files" in result
188
+ assert "source" in result
189
+ assert "strategy" in result
190
+ assert "request_id" in result
191
+
192
+ def test_generate_produces_files(self, spec_dict, base_cfg):
193
+ model = EnhancedMLGenerationModelV2(base_cfg)
194
+ result = model.generate(spec_dict)
195
+ assert len(result["generated_files"]) > 0
196
+
197
+ def test_generate_with_request_id(self, spec_dict, base_cfg):
198
+ model = EnhancedMLGenerationModelV2(base_cfg)
199
+ result = model.generate(spec_dict, request_id="test_req_001")
200
+ assert result["request_id"] == "test_req_001"
201
+
202
+ def test_generate_invalid_spec_returns_error(self, base_cfg):
203
+ model = EnhancedMLGenerationModelV2(base_cfg)
204
+ result = model.generate({"no_design_name": True})
205
+ assert result["passed"] is False
206
+
207
+ def test_generate_empty_design_name_returns_error(self, base_cfg):
208
+ model = EnhancedMLGenerationModelV2(base_cfg)
209
+ result = model.generate({"design_name": "", "protocol": "uart"})
210
+ assert result["passed"] is False
211
+
212
+
213
+ # ---------------------------------------------------------------------------
214
+ # Learning / RL tests
215
+ # ---------------------------------------------------------------------------
216
+
217
+
218
+ class TestLearning:
219
+ def test_learn_updates_rl(self, spec_dict, base_cfg):
220
+ model = EnhancedMLGenerationModelV2(base_cfg)
 
 
 
 
 
 
 
 
 
 
 
 
221
  result = model.generate(spec_dict)
222
+ reward = 1.0 if result["passed"] else 0.0
223
  model.learn(result, reward)
224
+ stats = model.get_learning_stats()
225
+ assert stats["total_generations"] >= 1
226
+
227
+ def test_multiple_generations_populate_replay_buffer(self, spec_dict, base_cfg):
228
+ model = EnhancedMLGenerationModelV2(base_cfg)
229
+ for _ in range(3):
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
230
  result = model.generate(spec_dict)
231
+ model.learn(result, 1.0 if result["passed"] else 0.0)
232
+ if model._rl_learner:
233
+ assert len(model._rl_learner._replay_buffer) > 0
234
+
235
+ def test_learning_stats_structure(self, spec_dict, base_cfg):
236
+ model = EnhancedMLGenerationModelV2(base_cfg)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
237
  result = model.generate(spec_dict)
238
+ model.learn(result, 1.0)
 
 
 
239
  stats = model.get_learning_stats()
240
+ assert "total_generations" in stats
241
+ assert "model_version" in stats
242
+ assert "metrics" in stats
243
+ assert "strategy_weights" in stats
244
+
245
+ def test_learning_persistence(self, spec_dict, base_cfg):
246
+ with tempfile.NamedTemporaryFile(mode="w", suffix=".json", delete=False) as f:
247
+ path = f.name
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
248
  try:
249
+ base_cfg.ml.learning_storage_path = path
250
+ model = EnhancedMLGenerationModelV2(base_cfg)
251
+ for _ in range(3):
252
+ r = model.generate(spec_dict)
253
+ model.learn(r, 1.0)
254
+ model.save_learning_state(path)
255
+ model2 = EnhancedMLGenerationModelV2(base_cfg)
256
+ model2.load_learning_state(path)
257
+ assert model2._generation_history
258
+ finally:
259
+ if os.path.exists(path):
260
+ os.unlink(path)
261
+
262
+
263
+ # ---------------------------------------------------------------------------
264
+ # Health / monitoring tests
265
+ # ---------------------------------------------------------------------------
266
+
267
+
268
+ class TestHealth:
269
+ def test_health_status_returns_dict(self, base_cfg):
270
+ model = EnhancedMLGenerationModelV2(base_cfg)
271
+ health = model.get_health_status()
272
+ assert isinstance(health, dict)
273
+ assert "status" in health
274
+ assert "components" in health
275
+ assert "version" in health
276
+
277
+ def test_health_components_are_bools(self, base_cfg):
278
+ model = EnhancedMLGenerationModelV2(base_cfg)
279
+ health = model.get_health_status()
280
+ for comp, ok in health["components"].items():
281
+ assert isinstance(ok, bool), f"{comp} should be bool"
282
+
283
+ def test_cache_invalidate(self, spec_dict, base_cfg):
284
+ model = EnhancedMLGenerationModelV2(base_cfg)
285
+ model.generate(spec_dict) # populates cache
286
+ model.invalidate_cache()
287
+ assert model._cache is None or len(model._cache._cache) == 0
288
+
289
+
290
+ # ---------------------------------------------------------------------------
291
+ # Edge case / resilience tests
292
+ # ---------------------------------------------------------------------------
293
+
294
+
295
+ class TestResilience:
296
+ def test_generate_twice_with_same_spec_hits_cache(self, spec_dict, base_cfg):
297
+ model = EnhancedMLGenerationModelV2(base_cfg)
298
+ r1 = model.generate(spec_dict)
299
+ r2 = model.generate(spec_dict)
300
+ assert r1["passed"] == r2["passed"]
301
+
302
+ def test_clear_history(self, spec_dict, base_cfg):
303
+ model = EnhancedMLGenerationModelV2(base_cfg)
304
+ model.generate(spec_dict)
305
+ model.clear_history()
306
+ stats = model.get_learning_stats()
307
+ assert stats["total_generations"] == 0
308
+
309
+ def test_all_rl_strategies_generate(self, spec_dict):
310
+ strategies = ["epsilon_greedy", "softmax", "ucb", "thompson"]
311
+ for strategy in strategies:
312
+ cfg = PipelineConfig(
313
+ ml=MLConfig(
314
+ enabled=True, model_type="v2", exploration_strategy=strategy,
315
+ use_llm=False, use_semantic_encoder=False, use_learning=True,
316
+ )
317
+ )
318
+ model = EnhancedMLGenerationModelV2(cfg)
319
+ result = model.generate(spec_dict)
320
+ assert result["passed"], f"{strategy} strategy failed"
321
+
322
+ def test_generate_without_advanced_components_falls_back(self, spec_dict, base_cfg):
323
+ model = EnhancedMLGenerationModelV2(base_cfg)
324
+ model._index = None
325
+ model._extractor = None
326
+ model._adapter = None
327
+ result = model.generate(spec_dict)
328
+ assert result["passed"] or not result["passed"] # should not crash
329
+
330
 
331
  if __name__ == "__main__":
332
+ pytest.main([__file__, "-v", "--tb=short"])