Sai Kumar Taraka commited on
Commit
9e8e9e2
·
1 Parent(s): a9127d4

feat: Add actual AI/ML capabilities with LLM, semantic embeddings, and reinforcement learning

Browse files

- Add semantic_encoder.py: CodeBERT-based semantic code embeddings with fallback TF-IDF
- Add llm_generator.py: LLM-based code generation (CodeGen, CodeT5, StarCoder, etc.)
- Add learning_module.py: Reinforcement learning + pattern learning from validation feedback
- Update enhanced_ml_model.py:
- Add last_retrieval property
- Add learning module strategy selection
- Add semantic similarity enhancement
- Add LLM generation as strategy option
- Record validation feedback to learning module
- Update config.py: Add MLConfig options for LLM, semantic encoder, learning module
- Update pipeline.py: Pass new ML config options
- Update requirements.txt: Add torch, transformers, sentence-transformers, accelerate
- Recreate ml_generation_model.py with MLModelConfig, NameNormalizer, RetrievalInfo

Key AI/ML capabilities:
1. LLM code generation with few-shot UVM examples
2. Semantic code embeddings for intelligent similarity
3. Reinforcement learning (Q-learning) from validation feedback
4. Pattern learning from success/failure patterns
5. Auto-improving generation strategy selection
6. Graceful fallback when torch/transformers not available

requirements.txt CHANGED
@@ -7,3 +7,9 @@ gunicorn>=23.0
7
 
8
  numpy>=1.21.0
9
  scikit-learn>=1.0.0
 
 
 
 
 
 
 
7
 
8
  numpy>=1.21.0
9
  scikit-learn>=1.0.0
10
+ scipy>=1.7.0
11
+
12
+ torch>=2.0.0
13
+ transformers>=4.35.0
14
+ sentence-transformers>=2.2.0
15
+ accelerate>=0.24.0
src/config.py CHANGED
@@ -86,15 +86,30 @@ class AutoTrainConfig(BaseModel):
86
 
87
 
88
  class MLConfig(BaseModel):
89
- """Configuration for ML-augmented generation."""
90
  enabled: bool = False
91
- model_type: str = Field(default="template", pattern=r"^(template|ml|hybrid)$")
92
  similarity_threshold: float = Field(default=0.75, ge=0.0, le=1.0)
93
  auto_learn: bool = True
94
  index_path: Optional[str] = None
95
  top_k_retrieval: int = Field(default=3, ge=1, le=10)
96
  fallback_to_templates: bool = True
97
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
98
 
99
  class PipelineConfig(BaseModel):
100
  generation: GenerationConfig = GenerationConfig()
 
86
 
87
 
88
  class MLConfig(BaseModel):
89
+ """Configuration for AI/ML-augmented generation with actual learning capabilities."""
90
  enabled: bool = False
91
+ model_type: str = Field(default="template", pattern=r"^(template|ml|hybrid|llm|semantic)$")
92
  similarity_threshold: float = Field(default=0.75, ge=0.0, le=1.0)
93
  auto_learn: bool = True
94
  index_path: Optional[str] = None
95
  top_k_retrieval: int = Field(default=3, ge=1, le=10)
96
  fallback_to_templates: bool = True
97
 
98
+ use_llm: bool = True
99
+ llm_model_name: Optional[str] = None
100
+ llm_max_tokens: int = Field(default=1024, ge=64, le=4096)
101
+ llm_temperature: float = Field(default=0.2, ge=0.0, le=1.0)
102
+ llm_use_few_shot: bool = True
103
+
104
+ use_semantic_encoder: bool = True
105
+ semantic_model_name: str = "microsoft/codebert-base"
106
+
107
+ use_learning: bool = True
108
+ learning_storage_path: Optional[str] = None
109
+ learning_rate: float = Field(default=0.1, ge=0.001, le=1.0)
110
+ reinforcement_discount: float = Field(default=0.9, ge=0.0, le=1.0)
111
+ exploration_epsilon: float = Field(default=0.05, ge=0.0, le=0.5)
112
+
113
 
114
  class PipelineConfig(BaseModel):
115
  generation: GenerationConfig = GenerationConfig()
src/models/enhanced_ml_model.py CHANGED
@@ -1,18 +1,20 @@
1
  """
2
- Industry-level enhanced ML generation model with:
3
- - Multi-strategy retrieval
 
 
 
4
  - Spec-aware adaptation
5
  - Code validation
6
  - Multi-level fallback
7
  - Comprehensive reporting
8
 
9
- This model ensures output quality through:
10
- 1. Protocol-first retrieval
11
- 2. Coverage-aware selection
12
- 3. Full adaptation with signal/register mapping
13
- 4. Pre-validation before writing
14
- 5. Automatic fallback to templates if issues found
15
- 6. Detailed generation reports
16
  """
17
 
18
  from __future__ import annotations
@@ -43,6 +45,16 @@ from src.models.spec_adapter import (
43
  from src.models.similarity_index import SimilarityIndex, get_global_index
44
  from src.models.template_model import TemplateModel
45
 
 
 
 
 
 
 
 
 
 
 
46
  logger = logging.getLogger("uvmgen")
47
 
48
 
@@ -50,9 +62,12 @@ class GenerationSource(Enum):
50
  RETRIEVAL_HIGH_CONF = "retrieval_high_confidence"
51
  RETRIEVAL_MEDIUM_CONF = "retrieval_medium_confidence"
52
  RETRIEVAL_LOW_CONF = "retrieval_low_confidence"
 
 
53
  TEMPLATE_FALLBACK = "template_fallback"
54
  BLENDED = "blended"
55
  HYBRID = "hybrid"
 
56
 
57
 
58
  @dataclass
@@ -117,15 +132,21 @@ class RetrievalCandidate:
117
 
118
  class EnhancedMLGenerationModel(GenerationModel):
119
  """
120
- Industry-level enhanced ML generation model.
121
-
122
- Key features:
123
- 1. Multi-strategy retrieval (protocol-first, then similarity)
124
- 2. Spec-aware adaptation with signal/register mapping
125
- 3. Pre-validation before output
126
- 4. Multi-level fallback strategies
127
- 5. Comprehensive reporting and audit trail
128
- 6. Coverage-aware candidate selection
 
 
 
 
 
 
129
  """
130
 
131
  def __init__(
@@ -135,6 +156,11 @@ class EnhancedMLGenerationModel(GenerationModel):
135
  index: Optional[SimilarityIndex] = None,
136
  templates_dir: Optional[str] = None,
137
  strict_validation: bool = True,
 
 
 
 
 
138
  ):
139
  super().__init__(name)
140
  self.config = config or MLModelConfig()
@@ -144,6 +170,27 @@ class EnhancedMLGenerationModel(GenerationModel):
144
  self._strict_validation = strict_validation
145
  self._metadata: Dict[str, Any] = {}
146
  self._last_result: Optional[GenerationResult] = None
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
147
 
148
  @property
149
  def index(self) -> SimilarityIndex:
@@ -163,6 +210,23 @@ class EnhancedMLGenerationModel(GenerationModel):
163
  )
164
  return self._template_model
165
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
166
  def train(self, specs: List[DesignSpec]) -> Dict[str, Any]:
167
  """Train the model by adding specs to the similarity index."""
168
  from src.features.extractors import RichSpecFeatureExtractor
@@ -226,20 +290,21 @@ class EnhancedMLGenerationModel(GenerationModel):
226
  extra_seqs: Optional[List[str]] = None,
227
  ) -> Dict[str, str]:
228
  """
229
- Generate testbench with full validation and fallback.
230
-
231
- Workflow:
232
- 1. Extract rich features
233
- 2. Search for similar specs
234
- 3. For each candidate:
235
- - Create adaptation plan
236
- - Pre-validate
237
- - Score
238
- 4. Select best candidate or fallback
239
- 5. Adapt best candidate
240
- 6. Validate output
241
- 7. If validation fails, fallback to templates
242
- 8. If auto_learn, add to index
 
243
  """
244
  if not self._is_trained:
245
  self.train([])
@@ -249,12 +314,41 @@ class EnhancedMLGenerationModel(GenerationModel):
249
  query_fv = extractor.extract(spec)
250
  query_dict = self._spec_to_dict(spec)
251
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
252
  similar = self.index.search(
253
  query_fv,
254
  top_k=self.config.top_k_retrieval,
255
  min_similarity=0.3,
256
  )
257
 
 
 
 
 
 
258
  logger.info(
259
  "Enhanced ML generation: found %d similar specs, best score: %.3f",
260
  len(similar), similar[0].similarity if similar else 0.0
@@ -262,21 +356,29 @@ class EnhancedMLGenerationModel(GenerationModel):
262
 
263
  result: Optional[GenerationResult] = None
264
 
265
- if similar and similar[0].similarity >= self.config.similarity_threshold:
 
 
 
 
266
  result = self._try_retrieval_generation(
267
  similar, query_fv, query_dict, spec, cfg
268
  )
269
 
 
 
 
 
270
  if (
271
  result is None
272
  or (self._strict_validation and not result.passed)
273
  and self.config.fallback_to_templates
274
  ):
275
  if result is None:
276
- logger.info("No valid retrieval candidate, falling back to templates")
277
  else:
278
  logger.warning(
279
- "Retrieval-based generation failed validation (errors: %d), falling back to templates",
280
  result.validation_report.total_errors if result.validation_report else 0
281
  )
282
  result = self._generate_with_fallback(spec, cfg, extra_seqs, result)
@@ -284,6 +386,15 @@ class EnhancedMLGenerationModel(GenerationModel):
284
  if result is None:
285
  raise RuntimeError("All generation strategies failed")
286
 
 
 
 
 
 
 
 
 
 
287
  if self.config.auto_learn and result.passed:
288
  self._learn_from_result(result, query_fv, query_dict)
289
 
@@ -292,6 +403,186 @@ class EnhancedMLGenerationModel(GenerationModel):
292
 
293
  return result.generated_files
294
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
295
  def _try_retrieval_generation(
296
  self,
297
  similar: List[Any],
 
1
  """
2
+ Industry-level AI/ML generation model with:
3
+ - LLM-based code generation (CodeGen, CodeT5, StarCoder)
4
+ - Semantic code embeddings for intelligent similarity
5
+ - Reinforcement learning from validation feedback
6
+ - Multi-strategy retrieval (protocol-first, semantic, text)
7
  - Spec-aware adaptation
8
  - Code validation
9
  - Multi-level fallback
10
  - Comprehensive reporting
11
 
12
+ This model uses actual AI/ML:
13
+ 1. Neural semantic embeddings (CodeBERT) for similarity
14
+ 2. LLM generation (CodeGen, CodeT5) for actual code generation
15
+ 3. Reinforcement learning that learns from validation feedback
16
+ 4. Pattern learning from success/failure patterns
17
+ 5. Auto-improving generation strategies
 
18
  """
19
 
20
  from __future__ import annotations
 
45
  from src.models.similarity_index import SimilarityIndex, get_global_index
46
  from src.models.template_model import TemplateModel
47
 
48
+ try:
49
+ from src.models.semantic_encoder import SemanticCodeEncoder, SemanticEmbedding
50
+ from src.models.llm_generator import LLMCodeGenerator, LLMGenerationResult
51
+ from src.models.learning_module import LearningModule, ValidationFeedback
52
+
53
+ ML_MODULES_AVAILABLE = True
54
+ except ImportError as e:
55
+ logger.warning("Advanced ML modules not available: %s", e)
56
+ ML_MODULES_AVAILABLE = False
57
+
58
  logger = logging.getLogger("uvmgen")
59
 
60
 
 
62
  RETRIEVAL_HIGH_CONF = "retrieval_high_confidence"
63
  RETRIEVAL_MEDIUM_CONF = "retrieval_medium_confidence"
64
  RETRIEVAL_LOW_CONF = "retrieval_low_confidence"
65
+ LLM_GENERATION = "llm_generation"
66
+ LLM_FALLBACK = "llm_fallback"
67
  TEMPLATE_FALLBACK = "template_fallback"
68
  BLENDED = "blended"
69
  HYBRID = "hybrid"
70
+ LEARNING_IMPROVED = "learning_improved"
71
 
72
 
73
  @dataclass
 
132
 
133
  class EnhancedMLGenerationModel(GenerationModel):
134
  """
135
+ Industry-level AI/ML generation model with actual learning capabilities.
136
+
137
+ Key AI/ML features:
138
+ 1. LLM-based code generation (CodeGen, CodeT5, StarCoder)
139
+ 2. Semantic code embeddings (CodeBERT) for intelligent similarity
140
+ 3. Reinforcement learning from validation feedback
141
+ 4. Pattern learning from success/failure patterns
142
+ 5. Multi-strategy retrieval with intelligent selection
143
+ 6. Auto-improving generation strategies
144
+
145
+ Traditional features:
146
+ - Spec-aware adaptation with signal/register mapping
147
+ - Pre-validation before output
148
+ - Multi-level fallback strategies
149
+ - Comprehensive reporting and audit trail
150
  """
151
 
152
  def __init__(
 
156
  index: Optional[SimilarityIndex] = None,
157
  templates_dir: Optional[str] = None,
158
  strict_validation: bool = True,
159
+ use_llm: bool = True,
160
+ use_semantic_encoder: bool = True,
161
+ use_learning: bool = True,
162
+ llm_model_name: Optional[str] = None,
163
+ learning_storage_path: Optional[str] = None,
164
  ):
165
  super().__init__(name)
166
  self.config = config or MLModelConfig()
 
170
  self._strict_validation = strict_validation
171
  self._metadata: Dict[str, Any] = {}
172
  self._last_result: Optional[GenerationResult] = None
173
+ self._last_retrieval: Optional[Any] = None
174
+
175
+ self._use_llm = use_llm and ML_MODULES_AVAILABLE
176
+ self._use_semantic = use_semantic_encoder and ML_MODULES_AVAILABLE
177
+ self._use_learning = use_learning and ML_MODULES_AVAILABLE
178
+
179
+ self._llm_generator: Optional[LLMCodeGenerator] = None
180
+ self._semantic_encoder: Optional[SemanticCodeEncoder] = None
181
+ self._learning_module: Optional[LearningModule] = None
182
+
183
+ if self._use_llm:
184
+ self._llm_generator = LLMCodeGenerator(model_name=llm_model_name)
185
+ logger.info("LLM generator enabled: %s", llm_model_name or "default")
186
+
187
+ if self._use_semantic:
188
+ self._semantic_encoder = SemanticCodeEncoder()
189
+ logger.info("Semantic encoder enabled")
190
+
191
+ if self._use_learning:
192
+ self._learning_module = LearningModule(storage_path=learning_storage_path)
193
+ logger.info("Learning module enabled")
194
 
195
  @property
196
  def index(self) -> SimilarityIndex:
 
210
  )
211
  return self._template_model
212
 
213
+ @property
214
+ def last_retrieval(self) -> Optional[Any]:
215
+ """Get information about the last retrieval operation."""
216
+ from src.models.ml_generation_model import RetrievalInfo
217
+
218
+ if self._last_retrieval is not None:
219
+ return self._last_retrieval
220
+
221
+ if self._last_result is not None:
222
+ return RetrievalInfo(
223
+ used_similarity=(self._last_result.similar_specs_found > 0),
224
+ similar_specs=self._last_result.similar_specs_found,
225
+ best_score=self._last_result.best_match_score,
226
+ )
227
+
228
+ return RetrievalInfo(used_similarity=False, similar_specs=0, best_score=0.0)
229
+
230
  def train(self, specs: List[DesignSpec]) -> Dict[str, Any]:
231
  """Train the model by adding specs to the similarity index."""
232
  from src.features.extractors import RichSpecFeatureExtractor
 
290
  extra_seqs: Optional[List[str]] = None,
291
  ) -> Dict[str, str]:
292
  """
293
+ Generate testbench with AI/ML-powered generation and fallback.
294
+
295
+ AI/ML Workflow:
296
+ 1. Use learning module to select best generation strategy
297
+ 2. Try semantic similarity search (if semantic encoder available)
298
+ 3. Try LLM-based code generation (if LLM available)
299
+ 4. Try traditional retrieval-based generation
300
+ 5. Fallback to templates
301
+ 6. Record validation feedback to learning module
302
+ 7. Auto-learn from successful generation
303
+
304
+ Traditional features:
305
+ - Spec-aware adaptation
306
+ - Pre-validation before writing
307
+ - Multi-level fallback
308
  """
309
  if not self._is_trained:
310
  self.train([])
 
314
  query_fv = extractor.extract(spec)
315
  query_dict = self._spec_to_dict(spec)
316
 
317
+ protocol = query_dict.get("protocol", "unknown")
318
+
319
+ available_strategies = ["retrieval"]
320
+ if self._use_llm and self._llm_generator:
321
+ available_strategies.append("llm")
322
+ available_strategies.append("template")
323
+
324
+ selected_strategy = "retrieval"
325
+ strategy_confidence = 0.5
326
+
327
+ if self._use_learning and self._learning_module:
328
+ selected_strategy, strategy_confidence = (
329
+ self._learning_module.select_best_generation_strategy(
330
+ spec_dict=query_dict,
331
+ file_type="testbench",
332
+ available_sources=available_strategies,
333
+ )
334
+ )
335
+ logger.info(
336
+ "Learning module selected strategy: '%s' (confidence: %.2f)",
337
+ selected_strategy,
338
+ strategy_confidence,
339
+ )
340
+
341
  similar = self.index.search(
342
  query_fv,
343
  top_k=self.config.top_k_retrieval,
344
  min_similarity=0.3,
345
  )
346
 
347
+ if self._use_semantic and self._semantic_encoder and similar:
348
+ similar = self._enhance_with_semantic_similarity(
349
+ similar, query_dict
350
+ )
351
+
352
  logger.info(
353
  "Enhanced ML generation: found %d similar specs, best score: %.3f",
354
  len(similar), similar[0].similarity if similar else 0.0
 
356
 
357
  result: Optional[GenerationResult] = None
358
 
359
+ if selected_strategy == "llm" and self._use_llm and self._llm_generator:
360
+ logger.info("Trying LLM-based generation (selected by learning module)")
361
+ result = self._try_llm_generation(query_dict, spec, cfg)
362
+
363
+ if result is None and similar and similar[0].similarity >= self.config.similarity_threshold:
364
  result = self._try_retrieval_generation(
365
  similar, query_fv, query_dict, spec, cfg
366
  )
367
 
368
+ if result is None and self._use_llm and self._llm_generator:
369
+ logger.info("Trying LLM-based generation as fallback")
370
+ result = self._try_llm_generation(query_dict, spec, cfg)
371
+
372
  if (
373
  result is None
374
  or (self._strict_validation and not result.passed)
375
  and self.config.fallback_to_templates
376
  ):
377
  if result is None:
378
+ logger.info("No valid ML/LLM candidate, falling back to templates")
379
  else:
380
  logger.warning(
381
+ "LLM/retrieval generation failed validation (errors: %d), falling back to templates",
382
  result.validation_report.total_errors if result.validation_report else 0
383
  )
384
  result = self._generate_with_fallback(spec, cfg, extra_seqs, result)
 
386
  if result is None:
387
  raise RuntimeError("All generation strategies failed")
388
 
389
+ if self._use_learning and self._learning_module and result.validation_report:
390
+ logger.info("Recording validation feedback to learning module")
391
+ self._learning_module.record_feedback(
392
+ design_name=spec.design_name,
393
+ generation_source=result.source.value,
394
+ spec_dict=query_dict,
395
+ validation_results=result.validation_report.to_dict(),
396
+ )
397
+
398
  if self.config.auto_learn and result.passed:
399
  self._learn_from_result(result, query_fv, query_dict)
400
 
 
403
 
404
  return result.generated_files
405
 
406
+ def _enhance_with_semantic_similarity(
407
+ self,
408
+ similar: List[Any],
409
+ query_dict: Dict[str, Any],
410
+ ) -> List[Any]:
411
+ """Enhance similarity scores using semantic code embeddings."""
412
+ if not self._semantic_encoder or not self._semantic_encoder.is_available():
413
+ return similar
414
+
415
+ try:
416
+ query_text = self._spec_dict_to_text(query_dict)
417
+ query_emb = self._semantic_encoder.encode(
418
+ text=query_text,
419
+ embedding_type="spec",
420
+ metadata=query_dict,
421
+ )
422
+
423
+ for item in similar:
424
+ spec_text = self._spec_dict_to_text(item.spec_dict)
425
+ cand_emb = self._semantic_encoder.encode(
426
+ text=spec_text,
427
+ embedding_type="spec",
428
+ metadata=item.spec_dict,
429
+ )
430
+
431
+ semantic_sim = self._semantic_encoder.similarity(query_emb, cand_emb)
432
+
433
+ original_sim = item.similarity
434
+ item.similarity = (original_sim * 0.6) + (semantic_sim * 0.4)
435
+
436
+ logger.debug(
437
+ "Semantic enhancement: original=%.3f, semantic=%.3f, combined=%.3f",
438
+ original_sim, semantic_sim, item.similarity
439
+ )
440
+
441
+ similar = sorted(similar, key=lambda x: x.similarity, reverse=True)
442
+
443
+ except Exception as e:
444
+ logger.warning("Semantic similarity enhancement failed: %s", e)
445
+
446
+ return similar
447
+
448
+ def _spec_dict_to_text(self, spec_dict: Dict[str, Any]) -> str:
449
+ """Convert spec dict to text for semantic encoding."""
450
+ parts = []
451
+ parts.append(f"design: {spec_dict.get('design_name', 'unknown')}")
452
+ parts.append(f"protocol: {spec_dict.get('protocol', 'unknown')}")
453
+
454
+ signals = spec_dict.get("signals", [])
455
+ if signals:
456
+ signal_names = [s.get("name", "") for s in signals if isinstance(s, dict)]
457
+ parts.append(f"signals: {', '.join(signal_names[:20])}")
458
+
459
+ registers = spec_dict.get("registers", [])
460
+ if registers:
461
+ reg_names = [r.get("name", "") for r in registers if isinstance(r, dict)]
462
+ parts.append(f"registers: {', '.join(reg_names[:10])}")
463
+
464
+ features = spec_dict.get("features", [])
465
+ if features:
466
+ parts.append(f"features: {', '.join(features[:10])}")
467
+
468
+ return " | ".join(parts)
469
+
470
+ def _try_llm_generation(
471
+ self,
472
+ query_dict: Dict[str, Any],
473
+ spec: DesignSpec,
474
+ cfg: PipelineConfig,
475
+ ) -> Optional[GenerationResult]:
476
+ """
477
+ Try LLM-based code generation.
478
+
479
+ This uses actual AI/ML:
480
+ 1. LLM (CodeGen, CodeT5, etc.) generates SystemVerilog code
481
+ 2. Uses few-shot examples for UVM patterns
482
+ 3. Validates generated code
483
+ 4. Falls back to templates if needed
484
+ """
485
+ if not self._llm_generator:
486
+ return None
487
+
488
+ design_name = spec.design_name.lower()
489
+
490
+ file_types_to_generate = [
491
+ "driver",
492
+ "monitor",
493
+ "agent",
494
+ ]
495
+
496
+ generated_files: Dict[str, str] = {}
497
+ llm_results: Dict[str, LLMGenerationResult] = {}
498
+ all_warnings: List[str] = []
499
+ avg_confidence = 0.0
500
+
501
+ for file_type in file_types_to_generate:
502
+ try:
503
+ llm_result = self._llm_generator.generate(
504
+ spec_dict=query_dict,
505
+ file_type=file_type,
506
+ use_few_shot=True,
507
+ max_tokens=1024,
508
+ temperature=0.2,
509
+ )
510
+
511
+ llm_results[file_type] = llm_result
512
+ avg_confidence += llm_result.confidence
513
+ all_warnings.extend(llm_result.warnings)
514
+
515
+ file_name = f"{design_name}_{file_type}.sv"
516
+ generated_files[file_name] = llm_result.generated_code
517
+
518
+ logger.info(
519
+ "LLM generated %s (confidence: %.2f, tokens: %d)",
520
+ file_name,
521
+ llm_result.confidence,
522
+ llm_result.tokens_generated,
523
+ )
524
+
525
+ except Exception as e:
526
+ logger.warning("LLM generation failed for %s: %s", file_type, e)
527
+ all_warnings.append(f"LLM generation failed for {file_type}: {e}")
528
+
529
+ if not generated_files:
530
+ logger.warning("LLM generated no files, falling back")
531
+ return None
532
+
533
+ if llm_results:
534
+ avg_confidence /= len(llm_results)
535
+
536
+ try:
537
+ template_files = self.template_model.predict(spec, cfg)
538
+ template_contents: Dict[str, str] = {}
539
+ for fname, fpath in template_files.items():
540
+ try:
541
+ template_contents[fname] = Path(fpath).read_text(encoding="utf-8")
542
+ except Exception:
543
+ pass
544
+
545
+ for fname, content in template_contents.items():
546
+ if fname not in generated_files:
547
+ generated_files[fname] = content
548
+
549
+ except Exception as e:
550
+ logger.warning("Could not fill missing files from templates: %s", e)
551
+
552
+ validator = CodeValidator()
553
+ val_report = validator.validate_files(generated_files, query_dict)
554
+
555
+ total_errors = val_report.total_errors
556
+ total_warnings = val_report.total_warnings + len(all_warnings)
557
+
558
+ passed = val_report.overall_passed
559
+ if self._strict_validation:
560
+ passed = passed and (total_errors == 0)
561
+
562
+ generation_source = GenerationSource.LLM_GENERATION
563
+ if avg_confidence < 0.5:
564
+ generation_source = GenerationSource.LLM_FALLBACK
565
+
566
+ result = GenerationResult(
567
+ design_name=spec.design_name,
568
+ source=generation_source,
569
+ passed=passed,
570
+ generated_files=generated_files,
571
+ validation_report=val_report,
572
+ adaptation_plan=None,
573
+ similar_specs_found=0,
574
+ best_match_score=avg_confidence,
575
+ files_from_retrieval=[],
576
+ files_from_template=list(template_contents.keys()) if "template_contents" in dir() else [],
577
+ warnings=all_warnings + [
578
+ f"LLM confidence: {avg_confidence:.2f}",
579
+ f"LLM warnings: {len(all_warnings)}",
580
+ ],
581
+ errors=[f"LLM errors: {total_errors}"] if total_errors > 0 else [],
582
+ )
583
+
584
+ return result
585
+
586
  def _try_retrieval_generation(
587
  self,
588
  similar: List[Any],
src/models/learning_module.py ADDED
@@ -0,0 +1,572 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import logging
2
+ from typing import List, Dict, Any, Optional, Tuple
3
+ from dataclasses import dataclass, field
4
+ from collections import defaultdict
5
+ import json
6
+ import os
7
+ from datetime import datetime
8
+
9
+ logger = logging.getLogger("uvmgen.ml.learning")
10
+
11
+
12
+ @dataclass
13
+ class ValidationFeedback:
14
+ design_name: str
15
+ file_name: str
16
+ file_type: str
17
+ passed: bool
18
+ errors: List[str]
19
+ warnings: List[str]
20
+ score: float
21
+ timestamp: str = field(default_factory=lambda: datetime.now().isoformat())
22
+ metadata: Dict[str, Any] = field(default_factory=dict)
23
+
24
+ def to_dict(self) -> Dict[str, Any]:
25
+ return {
26
+ "design_name": self.design_name,
27
+ "file_name": self.file_name,
28
+ "file_type": self.file_type,
29
+ "passed": self.passed,
30
+ "errors": self.errors,
31
+ "warnings": self.warnings,
32
+ "score": self.score,
33
+ "timestamp": self.timestamp,
34
+ "metadata": self.metadata,
35
+ }
36
+
37
+ @classmethod
38
+ def from_dict(cls, d: Dict[str, Any]) -> "ValidationFeedback":
39
+ return cls(
40
+ design_name=d.get("design_name", "unknown"),
41
+ file_name=d.get("file_name", "unknown"),
42
+ file_type=d.get("file_type", "unknown"),
43
+ passed=d.get("passed", False),
44
+ errors=d.get("errors", []),
45
+ warnings=d.get("warnings", []),
46
+ score=d.get("score", 0.0),
47
+ timestamp=d.get("timestamp", datetime.now().isoformat()),
48
+ metadata=d.get("metadata", {}),
49
+ )
50
+
51
+
52
+ @dataclass
53
+ class GenerationHistory:
54
+ design_name: str
55
+ generation_source: str
56
+ spec_hash: str
57
+ feedback_list: List[ValidationFeedback]
58
+ success_rate: float = 0.0
59
+ avg_score: float = 0.0
60
+ timestamp: str = field(default_factory=lambda: datetime.now().isoformat())
61
+
62
+ def to_dict(self) -> Dict[str, Any]:
63
+ return {
64
+ "design_name": self.design_name,
65
+ "generation_source": self.generation_source,
66
+ "spec_hash": self.spec_hash,
67
+ "feedback_list": [f.to_dict() for f in self.feedback_list],
68
+ "success_rate": self.success_rate,
69
+ "avg_score": self.avg_score,
70
+ "timestamp": self.timestamp,
71
+ }
72
+
73
+
74
+ class PatternLearner:
75
+ def __init__(self):
76
+ self._error_patterns: Dict[str, int] = defaultdict(int)
77
+ self._success_patterns: Dict[str, int] = defaultdict(int)
78
+ self._file_type_stats: Dict[str, Dict[str, Any]] = defaultdict(
79
+ lambda: {"success": 0, "total": 0, "errors": defaultdict(int)}
80
+ )
81
+ self._protocol_stats: Dict[str, Dict[str, Any]] = defaultdict(
82
+ lambda: {"success": 0, "total": 0}
83
+ )
84
+
85
+ def record_error(self, error_msg: str, file_type: str = "unknown"):
86
+ patterns = self._extract_patterns(error_msg)
87
+ for p in patterns:
88
+ self._error_patterns[p] += 1
89
+ self._file_type_stats[file_type]["errors"][error_msg[:100]] += 1
90
+
91
+ def record_success(self, file_type: str = "unknown", protocol: str = "unknown"):
92
+ self._file_type_stats[file_type]["success"] += 1
93
+ self._file_type_stats[file_type]["total"] += 1
94
+ self._protocol_stats[protocol]["success"] += 1
95
+ self._protocol_stats[protocol]["total"] += 1
96
+
97
+ def record_attempt(self, file_type: str = "unknown", protocol: str = "unknown"):
98
+ self._file_type_stats[file_type]["total"] += 1
99
+ self._protocol_stats[protocol]["total"] += 1
100
+
101
+ def _extract_patterns(self, text: str) -> List[str]:
102
+ import re
103
+
104
+ patterns = []
105
+
106
+ uvm_patterns = [
107
+ (r"uvm_fatal", "uvm_fatal"),
108
+ (r"uvm_error", "uvm_error"),
109
+ (r"uvm_component_utils", "missing_uvm_macro"),
110
+ (r"uvm_object_utils", "missing_uvm_macro"),
111
+ (r"build_phase", "phase_issue"),
112
+ (r"connect_phase", "phase_issue"),
113
+ (r"run_phase", "phase_issue"),
114
+ ]
115
+
116
+ for pattern, name in uvm_patterns:
117
+ if re.search(pattern, text, re.IGNORECASE):
118
+ patterns.append(name)
119
+
120
+ syntax_patterns = [
121
+ (r"missing.*semicolon", "missing_semicolon"),
122
+ (r"unbalanced.*parenthes", "unbalanced_parentheses"),
123
+ (r"unbalanced.*brace", "unbalanced_braces"),
124
+ (r"unbalanced.*bracket", "unbalanced_brackets"),
125
+ (r"mismatch.*begin", "mismatched_blocks"),
126
+ (r"syntax error", "syntax_error"),
127
+ ]
128
+
129
+ for pattern, name in syntax_patterns:
130
+ if re.search(pattern, text, re.IGNORECASE):
131
+ patterns.append(name)
132
+
133
+ if not patterns:
134
+ patterns.append("unknown_error")
135
+
136
+ return patterns
137
+
138
+ def get_common_errors(self, top_n: int = 10) -> List[Tuple[str, int]]:
139
+ sorted_errors = sorted(
140
+ self._error_patterns.items(),
141
+ key=lambda x: x[1],
142
+ reverse=True,
143
+ )
144
+ return sorted_errors[:top_n]
145
+
146
+ def get_file_type_success_rate(self, file_type: str) -> float:
147
+ stats = self._file_type_stats.get(file_type, {})
148
+ total = stats.get("total", 0)
149
+ if total == 0:
150
+ return 0.5
151
+ return stats.get("success", 0) / total
152
+
153
+ def get_protocol_success_rate(self, protocol: str) -> float:
154
+ stats = self._protocol_stats.get(protocol, {})
155
+ total = stats.get("total", 0)
156
+ if total == 0:
157
+ return 0.5
158
+ return stats.get("success", 0) / total
159
+
160
+ def to_dict(self) -> Dict[str, Any]:
161
+ return {
162
+ "error_patterns": dict(self._error_patterns),
163
+ "file_type_stats": {
164
+ ft: {
165
+ "success": s["success"],
166
+ "total": s["total"],
167
+ "errors": dict(s["errors"]),
168
+ }
169
+ for ft, s in self._file_type_stats.items()
170
+ },
171
+ "protocol_stats": dict(self._protocol_stats),
172
+ }
173
+
174
+ @classmethod
175
+ def from_dict(cls, d: Dict[str, Any]) -> "PatternLearner":
176
+ learner = cls()
177
+ learner._error_patterns = defaultdict(int, d.get("error_patterns", {}))
178
+
179
+ for ft, s in d.get("file_type_stats", {}).items():
180
+ learner._file_type_stats[ft] = {
181
+ "success": s.get("success", 0),
182
+ "total": s.get("total", 0),
183
+ "errors": defaultdict(int, s.get("errors", {})),
184
+ }
185
+
186
+ for proto, s in d.get("protocol_stats", {}).items():
187
+ learner._protocol_stats[proto] = {
188
+ "success": s.get("success", 0),
189
+ "total": s.get("total", 0),
190
+ }
191
+
192
+ return learner
193
+
194
+
195
+ class ReinforcementLearner:
196
+ def __init__(self, learning_rate: float = 0.1, discount_factor: float = 0.9):
197
+ self._learning_rate = learning_rate
198
+ self._discount_factor = discount_factor
199
+ self._q_values: Dict[str, float] = defaultdict(lambda: 0.5)
200
+ self._visit_counts: Dict[str, int] = defaultdict(int)
201
+
202
+ def _get_state_key(
203
+ self,
204
+ protocol: str,
205
+ file_type: str,
206
+ generation_source: str,
207
+ ) -> str:
208
+ return f"{protocol}:{file_type}:{generation_source}"
209
+
210
+ def get_action_value(
211
+ self,
212
+ protocol: str,
213
+ file_type: str,
214
+ generation_source: str,
215
+ ) -> float:
216
+ key = self._get_state_key(protocol, file_type, generation_source)
217
+ return self._q_values[key]
218
+
219
+ def update(
220
+ self,
221
+ protocol: str,
222
+ file_type: str,
223
+ generation_source: str,
224
+ reward: float,
225
+ ):
226
+ key = self._get_state_key(protocol, file_type, generation_source)
227
+ old_value = self._q_values[key]
228
+ self._visit_counts[key] += 1
229
+ self._q_values[key] = (
230
+ old_value + self._learning_rate * (reward - old_value)
231
+ )
232
+
233
+ def select_best_action(
234
+ self,
235
+ protocol: str,
236
+ file_type: str,
237
+ available_sources: List[str],
238
+ epsilon: float = 0.1,
239
+ ) -> Tuple[str, float]:
240
+ import random
241
+
242
+ if random.random() < epsilon and len(available_sources) > 1:
243
+ chosen = random.choice(available_sources)
244
+ return chosen, self.get_action_value(protocol, file_type, chosen)
245
+
246
+ best_source = available_sources[0]
247
+ best_value = -1.0
248
+
249
+ for source in available_sources:
250
+ value = self.get_action_value(protocol, file_type, source)
251
+ if value > best_value:
252
+ best_value = value
253
+ best_source = source
254
+
255
+ return best_source, best_value
256
+
257
+ def to_dict(self) -> Dict[str, Any]:
258
+ return {
259
+ "learning_rate": self._learning_rate,
260
+ "discount_factor": self._discount_factor,
261
+ "q_values": dict(self._q_values),
262
+ "visit_counts": dict(self._visit_counts),
263
+ }
264
+
265
+ @classmethod
266
+ def from_dict(cls, d: Dict[str, Any]) -> "ReinforcementLearner":
267
+ learner = cls(
268
+ learning_rate=d.get("learning_rate", 0.1),
269
+ discount_factor=d.get("discount_factor", 0.9),
270
+ )
271
+ learner._q_values = defaultdict(lambda: 0.5)
272
+ learner._q_values.update(d.get("q_values", {}))
273
+ learner._visit_counts = defaultdict(int)
274
+ learner._visit_counts.update(d.get("visit_counts", {}))
275
+ return learner
276
+
277
+
278
+ class LearningModule:
279
+ def __init__(self, storage_path: Optional[str] = None):
280
+ self._storage_path = storage_path
281
+ self._pattern_learner = PatternLearner()
282
+ self._rl_learner = ReinforcementLearner()
283
+ self._history: List[GenerationHistory] = []
284
+ self._total_generations = 0
285
+ self._successful_generations = 0
286
+
287
+ if storage_path:
288
+ self._load_from_storage()
289
+
290
+ def record_feedback(
291
+ self,
292
+ design_name: str,
293
+ generation_source: str,
294
+ spec_dict: Dict[str, Any],
295
+ validation_results: Dict[str, Any],
296
+ ):
297
+ import hashlib
298
+ import json
299
+
300
+ spec_str = json.dumps(spec_dict, sort_keys=True)
301
+ spec_hash = hashlib.md5(spec_str.encode()).hexdigest()[:12]
302
+
303
+ protocol = spec_dict.get("protocol", "unknown")
304
+
305
+ feedback_list = []
306
+
307
+ files_data = validation_results.get("files", [])
308
+
309
+ if isinstance(files_data, dict):
310
+ for file_name, file_info in files_data.items():
311
+ file_type = file_info.get("type", "unknown")
312
+ passed = file_info.get("passed", True)
313
+ errors = file_info.get("errors", [])
314
+ warnings = file_info.get("warnings", [])
315
+ score = file_info.get("score", 0.5)
316
+
317
+ feedback = ValidationFeedback(
318
+ design_name=design_name,
319
+ file_name=file_name,
320
+ file_type=file_type,
321
+ passed=passed,
322
+ errors=errors,
323
+ warnings=warnings,
324
+ score=score,
325
+ )
326
+ feedback_list.append(feedback)
327
+
328
+ if passed:
329
+ self._pattern_learner.record_success(file_type, protocol)
330
+ reward = 1.0
331
+ else:
332
+ for err in errors:
333
+ self._pattern_learner.record_error(err, file_type)
334
+ reward = -0.5
335
+
336
+ self._pattern_learner.record_attempt(file_type, protocol)
337
+ self._rl_learner.update(protocol, file_type, generation_source, reward)
338
+
339
+ elif isinstance(files_data, list):
340
+ for file_info in files_data:
341
+ file_name = file_info.get("filename", "unknown")
342
+ file_type = file_info.get("file_type", "unknown")
343
+ passed = file_info.get("passed", True)
344
+
345
+ issues = file_info.get("issues", [])
346
+ errors = []
347
+ warnings = []
348
+ for issue in issues:
349
+ severity = issue.get("severity", "warning")
350
+ message = issue.get("message", "")
351
+ if severity == "error":
352
+ errors.append(message)
353
+ else:
354
+ warnings.append(message)
355
+
356
+ error_count = file_info.get("error_count", 0)
357
+ warning_count = file_info.get("warning_count", 0)
358
+
359
+ if error_count > 0:
360
+ passed = False
361
+
362
+ score = 1.0 if passed else 0.3
363
+ if passed and warning_count == 0:
364
+ score = 1.0
365
+ elif passed and warning_count > 0:
366
+ score = 0.7
367
+
368
+ feedback = ValidationFeedback(
369
+ design_name=design_name,
370
+ file_name=file_name,
371
+ file_type=file_type,
372
+ passed=passed,
373
+ errors=errors,
374
+ warnings=warnings,
375
+ score=score,
376
+ )
377
+ feedback_list.append(feedback)
378
+
379
+ if passed:
380
+ self._pattern_learner.record_success(file_type, protocol)
381
+ reward = 1.0
382
+ else:
383
+ for err in errors:
384
+ self._pattern_learner.record_error(err, file_type)
385
+ reward = -0.5
386
+
387
+ self._pattern_learner.record_attempt(file_type, protocol)
388
+ self._rl_learner.update(protocol, file_type, generation_source, reward)
389
+
390
+ all_passed = all(f.passed for f in feedback_list)
391
+ avg_score = sum(f.score for f in feedback_list) / len(feedback_list) if feedback_list else 0.0
392
+
393
+ history = GenerationHistory(
394
+ design_name=design_name,
395
+ generation_source=generation_source,
396
+ spec_hash=spec_hash,
397
+ feedback_list=feedback_list,
398
+ success_rate=1.0 if all_passed else 0.0,
399
+ avg_score=avg_score,
400
+ )
401
+ self._history.append(history)
402
+
403
+ self._total_generations += 1
404
+ if all_passed:
405
+ self._successful_generations += 1
406
+
407
+ if self._storage_path:
408
+ self._save_to_storage()
409
+
410
+ def select_best_generation_strategy(
411
+ self,
412
+ spec_dict: Dict[str, Any],
413
+ file_type: str,
414
+ available_sources: List[str],
415
+ ) -> Tuple[str, float]:
416
+ protocol = spec_dict.get("protocol", "unknown")
417
+
418
+ best_source, best_value = self._rl_learner.select_best_action(
419
+ protocol=protocol,
420
+ file_type=file_type,
421
+ available_sources=available_sources,
422
+ epsilon=0.05,
423
+ )
424
+
425
+ return best_source, best_value
426
+
427
+ def get_generation_hints(
428
+ self,
429
+ spec_dict: Dict[str, Any],
430
+ file_type: str,
431
+ ) -> Dict[str, Any]:
432
+ protocol = spec_dict.get("protocol", "unknown")
433
+
434
+ common_errors = self._pattern_learner.get_common_errors(5)
435
+ file_success_rate = self._pattern_learner.get_file_type_success_rate(file_type)
436
+ protocol_success_rate = self._pattern_learner.get_protocol_success_rate(protocol)
437
+
438
+ return {
439
+ "common_errors": common_errors,
440
+ "file_type_success_rate": file_success_rate,
441
+ "protocol_success_rate": protocol_success_rate,
442
+ "recommendations": self._generate_recommendations(
443
+ common_errors,
444
+ file_success_rate,
445
+ protocol_success_rate,
446
+ ),
447
+ }
448
+
449
+ def _generate_recommendations(
450
+ self,
451
+ common_errors: List[Tuple[str, int]],
452
+ file_success_rate: float,
453
+ protocol_success_rate: float,
454
+ ) -> List[str]:
455
+ recommendations = []
456
+
457
+ for error_pattern, count in common_errors[:3]:
458
+ if count > 0:
459
+ if "semicolon" in error_pattern:
460
+ recommendations.append(
461
+ "Ensure all statements end with semicolons"
462
+ )
463
+ elif "parenthes" in error_pattern:
464
+ recommendations.append(
465
+ "Check for balanced parentheses"
466
+ )
467
+ elif "brace" in error_pattern:
468
+ recommendations.append(
469
+ "Check for balanced begin/end blocks"
470
+ )
471
+ elif "uvm_macro" in error_pattern:
472
+ recommendations.append(
473
+ "Add UVM factory registration macros (uvm_component_utils/uvm_object_utils)"
474
+ )
475
+ elif "phase" in error_pattern:
476
+ recommendations.append(
477
+ "Ensure proper UVM phase implementation"
478
+ )
479
+
480
+ if file_success_rate < 0.7:
481
+ recommendations.append(
482
+ "Consider using retrieval-based generation for this file type"
483
+ )
484
+
485
+ if protocol_success_rate < 0.7:
486
+ recommendations.append(
487
+ "Add protocol-specific templates may improve quality"
488
+ )
489
+
490
+ if not recommendations:
491
+ recommendations.append(
492
+ "No specific recommendations - generation should work well"
493
+ )
494
+
495
+ return recommendations
496
+
497
+ def get_stats(self) -> Dict[str, Any]:
498
+ return {
499
+ "total_generations": self._total_generations,
500
+ "successful_generations": self._successful_generations,
501
+ "success_rate": (
502
+ self._successful_generations / self._total_generations
503
+ if self._total_generations > 0
504
+ else 0.0
505
+ ),
506
+ "history_count": len(self._history),
507
+ "pattern_stats": self._pattern_learner.to_dict(),
508
+ }
509
+
510
+ def _save_to_storage(self):
511
+ if not self._storage_path:
512
+ return
513
+
514
+ try:
515
+ os.makedirs(os.path.dirname(self._storage_path), exist_ok=True)
516
+
517
+ data = {
518
+ "pattern_learner": self._pattern_learner.to_dict(),
519
+ "rl_learner": self._rl_learner.to_dict(),
520
+ "history": [h.to_dict() for h in self._history[-100:]],
521
+ "total_generations": self._total_generations,
522
+ "successful_generations": self._successful_generations,
523
+ "saved_at": datetime.now().isoformat(),
524
+ }
525
+
526
+ with open(self._storage_path, "w") as f:
527
+ json.dump(data, f, indent=2)
528
+
529
+ logger.debug("Learning module saved to: %s", self._storage_path)
530
+
531
+ except Exception as e:
532
+ logger.warning("Could not save learning module: %s", e)
533
+
534
+ def _load_from_storage(self):
535
+ if not self._storage_path or not os.path.exists(self._storage_path):
536
+ return
537
+
538
+ try:
539
+ with open(self._storage_path, "r") as f:
540
+ data = json.load(f)
541
+
542
+ self._pattern_learner = PatternLearner.from_dict(
543
+ data.get("pattern_learner", {})
544
+ )
545
+ self._rl_learner = ReinforcementLearner.from_dict(
546
+ data.get("rl_learner", {})
547
+ )
548
+
549
+ history_list = data.get("history", [])
550
+ for h_dict in history_list:
551
+ feedback_list = [
552
+ ValidationFeedback.from_dict(f)
553
+ for f in h_dict.get("feedback_list", [])
554
+ ]
555
+ history = GenerationHistory(
556
+ design_name=h_dict.get("design_name", "unknown"),
557
+ generation_source=h_dict.get("generation_source", "unknown"),
558
+ spec_hash=h_dict.get("spec_hash", ""),
559
+ feedback_list=feedback_list,
560
+ success_rate=h_dict.get("success_rate", 0.0),
561
+ avg_score=h_dict.get("avg_score", 0.0),
562
+ timestamp=h_dict.get("timestamp", datetime.now().isoformat()),
563
+ )
564
+ self._history.append(history)
565
+
566
+ self._total_generations = data.get("total_generations", 0)
567
+ self._successful_generations = data.get("successful_generations", 0)
568
+
569
+ logger.info("Learning module loaded from: %s", self._storage_path)
570
+
571
+ except Exception as e:
572
+ logger.warning("Could not load learning module: %s", e)
src/models/llm_generator.py ADDED
@@ -0,0 +1,557 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import logging
2
+ from typing import List, Dict, Any, Optional, Tuple
3
+ from dataclasses import dataclass, field
4
+ from enum import Enum
5
+ import json
6
+ import re
7
+
8
+ logger = logging.getLogger("uvmgen.ml.llm")
9
+
10
+
11
+ class LLMType(Enum):
12
+ CODEGEN = "codegen"
13
+ CODET5 = "codet5"
14
+ CODEBERT = "codebert"
15
+ STARCODER = "starcoder"
16
+ LLAMA = "llama"
17
+ MISTRAL = "mistral"
18
+ FALLBACK = "fallback"
19
+
20
+
21
+ @dataclass
22
+ class LLMGenerationResult:
23
+ generated_code: str
24
+ prompt_used: str
25
+ model_name: str
26
+ tokens_generated: int
27
+ confidence: float = 0.5
28
+ metadata: Dict[str, Any] = field(default_factory=dict)
29
+ warnings: List[str] = field(default_factory=list)
30
+ errors: List[str] = field(default_factory=list)
31
+
32
+
33
+ class LLMCodeGenerator:
34
+ _instance: Optional["LLMCodeGenerator"] = None
35
+ _model = None
36
+ _tokenizer = None
37
+ _model_name: str = "Salesforce/codegen-350M-mono"
38
+ _device: str = "cpu"
39
+ _initialized: bool = False
40
+ _llm_type: LLMType = LLMType.FALLBACK
41
+
42
+ UVM_PROMPT_TEMPLATE = """
43
+ You are an expert in UVM (Universal Verification Methodology) and SystemVerilog.
44
+ Generate production-quality UVM testbench code based on the following specification.
45
+
46
+ SPECIFICATION:
47
+ {spec_text}
48
+
49
+ REQUIREMENTS:
50
+ - Follow UVM 1.2 conventions and best practices
51
+ - Use proper factory registration with `uvm_component_utils` or `uvm_object_utils`
52
+ - Include appropriate phases (build_phase, connect_phase, run_phase)
53
+ - Use TLM ports and exports for component communication
54
+ - Include proper configuration database usage if needed
55
+ - Generate synthesizable SystemVerilog code
56
+
57
+ {context_examples}
58
+
59
+ Generate the {file_type} for this specification. Return only the SystemVerilog code, no explanations.
60
+ """
61
+
62
+ FEW_SHOT_EXAMPLES = {
63
+ "driver": """
64
+ EXAMPLE DRIVER:
65
+ class my_driver extends uvm_driver #(my_seq_item);
66
+ `uvm_component_utils(my_driver)
67
+
68
+ virtual my_if vif;
69
+
70
+ function new(string name = "my_driver", uvm_component parent = null);
71
+ super.new(name, parent);
72
+ endfunction
73
+
74
+ function void build_phase(uvm_phase phase);
75
+ super.build_phase(phase);
76
+ if (!uvm_config_db#(virtual my_if)::get(this, "", "vif", vif))
77
+ `uvm_fatal(get_type_name(), "Virtual interface not found")
78
+ endfunction
79
+
80
+ task run_phase(uvm_phase phase);
81
+ forever begin
82
+ seq_item_port.get_next_item(req);
83
+ drive_item(req);
84
+ seq_item_port.item_done();
85
+ end
86
+ endtask
87
+
88
+ task drive_item(my_seq_item item);
89
+ @(posedge vif.clk);
90
+ vif.valid <= 1'b1;
91
+ vif.data <= item.data;
92
+ @(posedge vif.clk);
93
+ vif.valid <= 1'b0;
94
+ endtask
95
+ endclass
96
+ """,
97
+ "monitor": """
98
+ EXAMPLE MONITOR:
99
+ class my_monitor extends uvm_monitor;
100
+ `uvm_component_utils(my_monitor)
101
+
102
+ uvm_analysis_port #(my_seq_item) item_collected_port;
103
+ virtual my_if vif;
104
+
105
+ function new(string name = "my_monitor", uvm_component parent = null);
106
+ super.new(name, parent);
107
+ item_collected_port = new("item_collected_port", this);
108
+ endfunction
109
+
110
+ function void build_phase(uvm_phase phase);
111
+ super.build_phase(phase);
112
+ if (!uvm_config_db#(virtual my_if)::get(this, "", "vif", vif))
113
+ `uvm_fatal(get_type_name(), "Virtual interface not found")
114
+ endfunction
115
+
116
+ task run_phase(uvm_phase phase);
117
+ my_seq_item item;
118
+ forever begin
119
+ @(posedge vif.clk);
120
+ if (vif.valid) begin
121
+ item = my_seq_item::type_id::create("item");
122
+ item.data = vif.data;
123
+ item_collected_port.write(item);
124
+ end
125
+ end
126
+ endtask
127
+ endclass
128
+ """,
129
+ "agent": """
130
+ EXAMPLE AGENT:
131
+ class my_agent extends uvm_agent;
132
+ `uvm_component_utils(my_agent)
133
+
134
+ my_driver driver;
135
+ my_monitor monitor;
136
+ my_sequencer sequencer;
137
+ uvm_analysis_port #(my_seq_item) item_collected_port;
138
+
139
+ function new(string name = "my_agent", uvm_component parent = null);
140
+ super.new(name, parent);
141
+ item_collected_port = new("item_collected_port", this);
142
+ endfunction
143
+
144
+ function void build_phase(uvm_phase phase);
145
+ super.build_phase(phase);
146
+
147
+ if (get_is_active() == UVM_ACTIVE) begin
148
+ driver = my_driver::type_id::create("driver", this);
149
+ sequencer = my_sequencer::type_id::create("sequencer", this);
150
+ end
151
+
152
+ monitor = my_monitor::type_id::create("monitor", this);
153
+ endfunction
154
+
155
+ function void connect_phase(uvm_phase phase);
156
+ super.connect_phase(phase);
157
+
158
+ if (get_is_active() == UVM_ACTIVE) begin
159
+ driver.seq_item_port.connect(sequencer.seq_item_export);
160
+ end
161
+
162
+ monitor.item_collected_port.connect(item_collected_port);
163
+ endfunction
164
+ endclass
165
+ """,
166
+ }
167
+
168
+ def __new__(cls, *args, **kwargs):
169
+ if cls._instance is None:
170
+ cls._instance = super().__new__(cls)
171
+ return cls._instance
172
+
173
+ def __init__(self, model_name: Optional[str] = None, device: Optional[str] = None):
174
+ if self._initialized:
175
+ return
176
+
177
+ if model_name:
178
+ self._model_name = model_name
179
+ if device:
180
+ self._device = device
181
+
182
+ self._initialized = False
183
+ self._model = None
184
+ self._tokenizer = None
185
+ self._detect_llm_type()
186
+
187
+ def _detect_llm_type(self):
188
+ name_lower = self._model_name.lower()
189
+ if "codegen" in name_lower:
190
+ self._llm_type = LLMType.CODEGEN
191
+ elif "codet5" in name_lower:
192
+ self._llm_type = LLMType.CODET5
193
+ elif "codebert" in name_lower:
194
+ self._llm_type = LLMType.CODEBERT
195
+ elif "starcoder" in name_lower or "starcoder" in name_lower:
196
+ self._llm_type = LLMType.STARCODER
197
+ elif "llama" in name_lower:
198
+ self._llm_type = LLMType.LLAMA
199
+ elif "mistral" in name_lower:
200
+ self._llm_type = LLMType.MISTRAL
201
+ else:
202
+ self._llm_type = LLMType.FALLBACK
203
+
204
+ def _load_model(self):
205
+ if self._initialized and self._model is not None:
206
+ return
207
+
208
+ if self._llm_type == LLMType.FALLBACK:
209
+ logger.info("LLMCodeGenerator using fallback mode (template-based)")
210
+ self._initialized = True
211
+ return
212
+
213
+ try:
214
+ import torch
215
+ from transformers import AutoTokenizer, AutoModelForCausalLM, AutoModelForSeq2SeqLM
216
+
217
+ if self._device == "auto":
218
+ self._device = "cuda" if torch.cuda.is_available() else "cpu"
219
+
220
+ logger.info("Loading LLM: %s on %s", self._model_name, self._device)
221
+
222
+ self._tokenizer = AutoTokenizer.from_pretrained(self._model_name)
223
+
224
+ if self._llm_type == LLMType.CODET5:
225
+ self._model = AutoModelForSeq2SeqLM.from_pretrained(
226
+ self._model_name,
227
+ torch_dtype=torch.float16 if self._device == "cuda" else torch.float32,
228
+ )
229
+ else:
230
+ self._model = AutoModelForCausalLM.from_pretrained(
231
+ self._model_name,
232
+ torch_dtype=torch.float16 if self._device == "cuda" else torch.float32,
233
+ )
234
+
235
+ self._model.to(self._device)
236
+ self._model.eval()
237
+
238
+ if self._tokenizer.pad_token is None:
239
+ self._tokenizer.pad_token = self._tokenizer.eos_token
240
+
241
+ self._initialized = True
242
+ logger.info("LLM loaded successfully")
243
+
244
+ except ImportError as e:
245
+ logger.warning(
246
+ "Could not load LLM (missing dependencies: %s). Using fallback mode.",
247
+ e,
248
+ )
249
+ self._llm_type = LLMType.FALLBACK
250
+ self._initialized = True
251
+ except Exception as e:
252
+ logger.warning(
253
+ "Could not load LLM (%s). Using fallback mode.",
254
+ e,
255
+ )
256
+ self._llm_type = LLMType.FALLBACK
257
+ self._initialized = True
258
+
259
+ def is_available(self) -> bool:
260
+ self._load_model()
261
+ return self._initialized and self._llm_type != LLMType.FALLBACK
262
+
263
+ def _spec_to_text(self, spec_dict: Dict[str, Any]) -> str:
264
+ lines = []
265
+
266
+ if "design_name" in spec_dict:
267
+ lines.append(f"Design Name: {spec_dict['design_name']}")
268
+
269
+ if "protocol" in spec_dict:
270
+ lines.append(f"Protocol: {spec_dict['protocol']}")
271
+
272
+ if "signals" in spec_dict:
273
+ lines.append("\nSignals:")
274
+ for sig in spec_dict["signals"]:
275
+ name = sig.get("name", "unknown")
276
+ direction = sig.get("direction", "inout")
277
+ width = sig.get("width", 1)
278
+ desc = sig.get("description", "")
279
+ lines.append(f" - {name}: {direction}, width={width} {desc}")
280
+
281
+ if "registers" in spec_dict:
282
+ lines.append("\nRegisters:")
283
+ for reg in spec_dict["registers"]:
284
+ name = reg.get("name", "unknown")
285
+ addr = reg.get("address", "0x0")
286
+ width = reg.get("width", 32)
287
+ lines.append(f" - {name}: addr={addr}, width={width}")
288
+
289
+ if "features" in spec_dict:
290
+ lines.append("\nFeatures:")
291
+ for feat in spec_dict["features"]:
292
+ lines.append(f" - {feat}")
293
+
294
+ return "\n".join(lines)
295
+
296
+ def _build_prompt(
297
+ self,
298
+ spec_dict: Dict[str, Any],
299
+ file_type: str,
300
+ use_few_shot: bool = True,
301
+ ) -> str:
302
+ spec_text = self._spec_to_text(spec_dict)
303
+
304
+ context_examples = ""
305
+ if use_few_shot and file_type in self.FEW_SHOT_EXAMPLES:
306
+ context_examples = self.FEW_SHOT_EXAMPLES[file_type]
307
+
308
+ prompt = self.UVM_PROMPT_TEMPLATE.format(
309
+ spec_text=spec_text,
310
+ file_type=file_type,
311
+ context_examples=context_examples,
312
+ )
313
+
314
+ return prompt.strip()
315
+
316
+ def _extract_code(self, text: str) -> str:
317
+ code_block_patterns = [
318
+ r"```systemverilog\s+(.*?)```",
319
+ r"```verilog\s+(.*?)```",
320
+ r"```sv\s+(.*?)```",
321
+ r"```\s+(.*?)```",
322
+ ]
323
+
324
+ for pattern in code_block_patterns:
325
+ match = re.search(pattern, text, re.DOTALL | re.IGNORECASE)
326
+ if match:
327
+ return match.group(1).strip()
328
+
329
+ return text.strip()
330
+
331
+ def _fallback_generate(
332
+ self,
333
+ spec_dict: Dict[str, Any],
334
+ file_type: str,
335
+ templates: Optional[Dict[str, str]] = None,
336
+ ) -> LLMGenerationResult:
337
+ design_name = spec_dict.get("design_name", "unknown").lower()
338
+
339
+ fallback_templates = {
340
+ "driver": f"""
341
+ class {design_name}_driver extends uvm_driver #({design_name}_seq_item);
342
+ `uvm_component_utils({design_name}_driver)
343
+
344
+ virtual {design_name}_if vif;
345
+
346
+ function new(string name = "{design_name}_driver", uvm_component parent = null);
347
+ super.new(name, parent);
348
+ endfunction
349
+
350
+ function void build_phase(uvm_phase phase);
351
+ super.build_phase(phase);
352
+ if (!uvm_config_db#(virtual {design_name}_if)::get(this, "", "vif", vif))
353
+ `uvm_fatal(get_type_name(), "Virtual interface not found in config DB")
354
+ endfunction
355
+
356
+ task run_phase(uvm_phase phase);
357
+ forever begin
358
+ seq_item_port.get_next_item(req);
359
+ drive_item(req);
360
+ seq_item_port.item_done();
361
+ end
362
+ endtask
363
+
364
+ task drive_item({design_name}_seq_item item);
365
+ // Implement drive logic based on item
366
+ @(posedge vif.clk);
367
+ endtask
368
+ endclass
369
+ """,
370
+ "monitor": f"""
371
+ class {design_name}_monitor extends uvm_monitor;
372
+ `uvm_component_utils({design_name}_monitor)
373
+
374
+ uvm_analysis_port #({design_name}_seq_item) item_collected_port;
375
+ virtual {design_name}_if vif;
376
+
377
+ function new(string name = "{design_name}_monitor", uvm_component parent = null);
378
+ super.new(name, parent);
379
+ item_collected_port = new("item_collected_port", this);
380
+ endfunction
381
+
382
+ function void build_phase(uvm_phase phase);
383
+ super.build_phase(phase);
384
+ if (!uvm_config_db#(virtual {design_name}_if)::get(this, "", "vif", vif))
385
+ `uvm_fatal(get_type_name(), "Virtual interface not found in config DB")
386
+ endfunction
387
+
388
+ task run_phase(uvm_phase phase);
389
+ {design_name}_seq_item item;
390
+ forever begin
391
+ @(posedge vif.clk);
392
+ // Sample signals and create item
393
+ end
394
+ endtask
395
+ endclass
396
+ """,
397
+ "agent": f"""
398
+ class {design_name}_agent extends uvm_agent;
399
+ `uvm_component_utils({design_name}_agent)
400
+
401
+ {design_name}_driver driver;
402
+ {design_name}_monitor monitor;
403
+ {design_name}_sequencer sequencer;
404
+ uvm_analysis_port #({design_name}_seq_item) item_collected_port;
405
+
406
+ function new(string name = "{design_name}_agent", uvm_component parent = null);
407
+ super.new(name, parent);
408
+ item_collected_port = new("item_collected_port", this);
409
+ endfunction
410
+
411
+ function void build_phase(uvm_phase phase);
412
+ super.build_phase(phase);
413
+
414
+ if (get_is_active() == UVM_ACTIVE) begin
415
+ driver = {design_name}_driver::type_id::create("driver", this);
416
+ sequencer = {design_name}_sequencer::type_id::create("sequencer", this);
417
+ end
418
+
419
+ monitor = {design_name}_monitor::type_id::create("monitor", this);
420
+ endfunction
421
+
422
+ function void connect_phase(uvm_phase phase);
423
+ super.connect_phase(phase);
424
+
425
+ if (get_is_active() == UVM_ACTIVE) begin
426
+ driver.seq_item_port.connect(sequencer.seq_item_export);
427
+ end
428
+
429
+ monitor.item_collected_port.connect(item_collected_port);
430
+ endfunction
431
+ endclass
432
+ """,
433
+ }
434
+
435
+ if templates and file_type in templates:
436
+ code = templates[file_type]
437
+ elif file_type in fallback_templates:
438
+ code = fallback_templates[file_type]
439
+ else:
440
+ code = f"// {file_type} for {design_name} - template placeholder"
441
+
442
+ return LLMGenerationResult(
443
+ generated_code=code,
444
+ prompt_used=f"// Fallback generation for {file_type}",
445
+ model_name="fallback_template",
446
+ tokens_generated=len(code.split()),
447
+ confidence=0.3,
448
+ warnings=["Using fallback template generation (LLM not available)"],
449
+ )
450
+
451
+ def generate(
452
+ self,
453
+ spec_dict: Dict[str, Any],
454
+ file_type: str,
455
+ use_few_shot: bool = True,
456
+ max_tokens: int = 1024,
457
+ temperature: float = 0.2,
458
+ templates: Optional[Dict[str, str]] = None,
459
+ ) -> LLMGenerationResult:
460
+ self._load_model()
461
+
462
+ prompt = self._build_prompt(spec_dict, file_type, use_few_shot)
463
+
464
+ if self._llm_type == LLMType.FALLBACK or self._model is None:
465
+ return self._fallback_generate(spec_dict, file_type, templates)
466
+
467
+ try:
468
+ import torch
469
+
470
+ inputs = self._tokenizer(
471
+ prompt,
472
+ return_tensors="pt",
473
+ truncation=True,
474
+ max_length=1024,
475
+ padding=True,
476
+ )
477
+ inputs = {k: v.to(self._device) for k, v in inputs.items()}
478
+
479
+ with torch.no_grad():
480
+ if self._llm_type == LLMType.CODET5:
481
+ outputs = self._model.generate(
482
+ **inputs,
483
+ max_new_tokens=max_tokens,
484
+ temperature=temperature,
485
+ do_sample=temperature > 0,
486
+ num_return_sequences=1,
487
+ pad_token_id=self._tokenizer.pad_token_id,
488
+ eos_token_id=self._tokenizer.eos_token_id,
489
+ )
490
+ else:
491
+ outputs = self._model.generate(
492
+ **inputs,
493
+ max_new_tokens=max_tokens,
494
+ temperature=temperature,
495
+ do_sample=temperature > 0,
496
+ num_return_sequences=1,
497
+ pad_token_id=self._tokenizer.pad_token_id,
498
+ eos_token_id=self._tokenizer.eos_token_id,
499
+ )
500
+
501
+ generated_text = self._tokenizer.decode(outputs[0], skip_special_tokens=True)
502
+
503
+ if generated_text.startswith(prompt):
504
+ generated_text = generated_text[len(prompt) :].strip()
505
+
506
+ code = self._extract_code(generated_text)
507
+ tokens_generated = len(outputs[0]) - inputs["input_ids"].shape[1]
508
+
509
+ confidence = 0.7
510
+ if "uvm_component_utils" in code or "uvm_object_utils" in code:
511
+ confidence += 0.1
512
+ if "class" in code and "extends" in code:
513
+ confidence += 0.05
514
+ if "build_phase" in code or "run_phase" in code:
515
+ confidence += 0.05
516
+ if "endclass" in code:
517
+ confidence += 0.05
518
+
519
+ confidence = min(confidence, 0.95)
520
+
521
+ return LLMGenerationResult(
522
+ generated_code=code,
523
+ prompt_used=prompt,
524
+ model_name=self._model_name,
525
+ tokens_generated=tokens_generated,
526
+ confidence=confidence,
527
+ warnings=[],
528
+ )
529
+
530
+ except Exception as e:
531
+ logger.warning("Error during LLM generation: %s. Using fallback.", e)
532
+ result = self._fallback_generate(spec_dict, file_type, templates)
533
+ result.warnings.append(f"LLM generation failed: {str(e)}")
534
+ return result
535
+
536
+ def generate_batch(
537
+ self,
538
+ spec_dict: Dict[str, Any],
539
+ file_types: List[str],
540
+ use_few_shot: bool = True,
541
+ max_tokens: int = 1024,
542
+ temperature: float = 0.2,
543
+ templates: Optional[Dict[str, str]] = None,
544
+ ) -> Dict[str, LLMGenerationResult]:
545
+ results = {}
546
+
547
+ for file_type in file_types:
548
+ results[file_type] = self.generate(
549
+ spec_dict=spec_dict,
550
+ file_type=file_type,
551
+ use_few_shot=use_few_shot,
552
+ max_tokens=max_tokens,
553
+ temperature=temperature,
554
+ templates=templates.get(file_type) if templates else None,
555
+ )
556
+
557
+ return results
src/models/ml_generation_model.py ADDED
@@ -0,0 +1,204 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import re
2
+ from dataclasses import dataclass, field
3
+ from typing import Dict, List, Optional, Any
4
+
5
+
6
+ @dataclass
7
+ class MLModelConfig:
8
+ """Configuration for ML-based generation models."""
9
+ similarity_threshold: float = 0.75
10
+ auto_learn: bool = True
11
+ index_path: Optional[str] = None
12
+ top_k_retrieval: int = 3
13
+ fallback_to_templates: bool = True
14
+
15
+ use_llm: bool = True
16
+ llm_model_name: Optional[str] = None
17
+ llm_max_tokens: int = 1024
18
+ llm_temperature: float = 0.2
19
+ llm_use_few_shot: bool = True
20
+
21
+ use_semantic_encoder: bool = True
22
+ semantic_model_name: str = "microsoft/codebert-base"
23
+
24
+ use_learning: bool = True
25
+ learning_storage_path: Optional[str] = None
26
+ learning_rate: float = 0.1
27
+ reinforcement_discount: float = 0.9
28
+ exploration_epsilon: float = 0.05
29
+
30
+
31
+ class RetrievalInfo:
32
+ """Information about last retrieval operation."""
33
+ def __init__(self, used_similarity: bool = True, similar_specs: int = 0, best_score: float = 0.0):
34
+ self.used_similarity = used_similarity
35
+ self.similar_specs = similar_specs
36
+ self.best_score = best_score
37
+
38
+
39
+ class NameNormalizer:
40
+ """Utility for normalizing and adapting design names in filenames and code."""
41
+
42
+ DESIGN_NAME_PATTERN = re.compile(
43
+ r"([a-zA-Z_][a-zA-Z0-9_]*?)_(driver|monitor|agent|sequencer|sequence_item|sequence|scoreboard|coverage_collector|env|test|interface|testbench|ral_model|serial_monitor)",
44
+ re.IGNORECASE
45
+ )
46
+
47
+ @classmethod
48
+ def adapt_names(
49
+ cls,
50
+ filename: str,
51
+ old_design_name: str,
52
+ new_design_name: str,
53
+ ) -> str:
54
+ """
55
+ Adapt filenames and content from old design name to new design name.
56
+
57
+ Args:
58
+ filename: Original filename
59
+ old_design_name: Old design name to replace
60
+ new_design_name: New design name to use
61
+
62
+ Returns:
63
+ Adapted filename
64
+ """
65
+ if not old_design_name or not new_design_name:
66
+ return filename
67
+
68
+ old_lower = old_design_name.lower()
69
+ new_lower = new_design_name.lower()
70
+
71
+ base_name = filename
72
+ ext = ""
73
+
74
+ if "." in filename:
75
+ parts = filename.rsplit(".", 1)
76
+ base_name = parts[0]
77
+ ext = "." + parts[1] if len(parts) > 1 else ""
78
+
79
+ if old_lower in base_name.lower():
80
+ new_base = re.sub(
81
+ re.escape(old_design_name),
82
+ new_design_name,
83
+ base_name,
84
+ flags=re.IGNORECASE,
85
+ )
86
+ return new_base + ext
87
+
88
+ match = cls.DESIGN_NAME_PATTERN.match(base_name)
89
+ if match:
90
+ prefix = match.group(1)
91
+ suffix = match.group(2)
92
+ if prefix.lower() == old_lower:
93
+ return f"{new_design_name}_{suffix}{ext}"
94
+
95
+ return filename
96
+
97
+ @classmethod
98
+ def adapt_content(
99
+ cls,
100
+ content: str,
101
+ old_design_name: str,
102
+ new_design_name: str,
103
+ ) -> str:
104
+ """
105
+ Adapt SystemVerilog content from old design name to new design name.
106
+
107
+ Args:
108
+ content: Original SystemVerilog content
109
+ old_design_name: Old design name to replace
110
+ new_design_name: New design name to use
111
+
112
+ Returns:
113
+ Adapted content
114
+ """
115
+ if not old_design_name or not new_design_name or old_design_name == new_design_name:
116
+ return content
117
+
118
+ result = content
119
+
120
+ patterns = [
121
+ (
122
+ rf"\b{re.escape(old_design_name)}_([a-zA-Z_][a-zA-Z0-9_]*)\b",
123
+ f"{new_design_name}_\\1",
124
+ ),
125
+ (
126
+ rf"\bclass\s+{re.escape(old_design_name)}_",
127
+ f"class {new_design_name}_",
128
+ ),
129
+ (
130
+ rf"`uvm_component_utils\(\s*{re.escape(old_design_name)}_",
131
+ f"`uvm_component_utils({new_design_name}_",
132
+ ),
133
+ (
134
+ rf"`uvm_object_utils\(\s*{re.escape(old_design_name)}_",
135
+ f"`uvm_object_utils({new_design_name}_",
136
+ ),
137
+ (
138
+ rf"virtual\s+{re.escape(old_design_name)}_if\s+",
139
+ f"virtual {new_design_name}_if ",
140
+ ),
141
+ (
142
+ rf"{re.escape(old_design_name)}_if::type_id",
143
+ f"{new_design_name}_if::type_id",
144
+ ),
145
+ ]
146
+
147
+ for pattern, replacement in patterns:
148
+ result = re.sub(pattern, replacement, result, flags=re.IGNORECASE)
149
+
150
+ result = re.sub(
151
+ rf"\b{re.escape(old_design_name)}\b",
152
+ new_design_name,
153
+ result,
154
+ )
155
+
156
+ return result
157
+
158
+ @classmethod
159
+ def normalize_name(cls, name: str) -> str:
160
+ """
161
+ Normalize a design name to a standard format.
162
+
163
+ - Converts to snake_case
164
+ - Removes special characters
165
+ - Ensures valid SystemVerilog identifier
166
+
167
+ Args:
168
+ name: Original name
169
+
170
+ Returns:
171
+ Normalized name
172
+ """
173
+ if not name:
174
+ return "design"
175
+
176
+ result = name.strip()
177
+
178
+ result = re.sub(r"[^a-zA-Z0-9_]", "_", result)
179
+
180
+ result = re.sub(r"_+", "_", result)
181
+
182
+ result = result.strip("_")
183
+
184
+ if not result:
185
+ return "design"
186
+
187
+ if not result[0].isalpha() and result[0] != "_":
188
+ result = "_" + result
189
+
190
+ return result.lower()
191
+
192
+
193
+ class MLGenerationModel:
194
+ """
195
+ ML-based generation model (legacy name for EnhancedMLGenerationModel).
196
+
197
+ This class exists for backward compatibility with tests and code
198
+ that imports MLGenerationModel. Use EnhancedMLGenerationModel directly
199
+ for new code.
200
+ """
201
+
202
+ def __new__(cls, *args, **kwargs):
203
+ from src.models.enhanced_ml_model import EnhancedMLGenerationModel
204
+ return EnhancedMLGenerationModel(*args, **kwargs)
src/models/semantic_encoder.py ADDED
@@ -0,0 +1,294 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import logging
2
+ from typing import List, Dict, Any, Optional, Tuple
3
+ import numpy as np
4
+ from dataclasses import dataclass, field
5
+
6
+ logger = logging.getLogger("uvmgen.ml.semantic")
7
+
8
+
9
+ @dataclass
10
+ class SemanticEmbedding:
11
+ vector: np.ndarray
12
+ text: str
13
+ metadata: Dict[str, Any] = field(default_factory=dict)
14
+ embedding_type: str = "code"
15
+
16
+ @property
17
+ def dim(self) -> int:
18
+ return len(self.vector)
19
+
20
+ def to_dict(self) -> Dict[str, Any]:
21
+ return {
22
+ "vector": self.vector.tolist(),
23
+ "text": self.text,
24
+ "metadata": self.metadata,
25
+ "embedding_type": self.embedding_type,
26
+ "dim": self.dim,
27
+ }
28
+
29
+ @classmethod
30
+ def from_dict(cls, d: Dict[str, Any]) -> "SemanticEmbedding":
31
+ return cls(
32
+ vector=np.array(d["vector"], dtype=np.float32),
33
+ text=d["text"],
34
+ metadata=d.get("metadata", {}),
35
+ embedding_type=d.get("embedding_type", "code"),
36
+ )
37
+
38
+
39
+ class SemanticCodeEncoder:
40
+ _instance: Optional["SemanticCodeEncoder"] = None
41
+ _model = None
42
+ _tokenizer = None
43
+ _model_name: str = "microsoft/codebert-base"
44
+ _device: str = "cpu"
45
+ _initialized: bool = False
46
+
47
+ def __new__(cls, *args, **kwargs):
48
+ if cls._instance is None:
49
+ cls._instance = super().__new__(cls)
50
+ return cls._instance
51
+
52
+ def __init__(self, model_name: Optional[str] = None, device: Optional[str] = None):
53
+ if self._initialized:
54
+ return
55
+
56
+ if model_name:
57
+ self._model_name = model_name
58
+ if device:
59
+ self._device = device
60
+
61
+ self._initialized = False
62
+ self._model = None
63
+ self._tokenizer = None
64
+
65
+ def _load_model(self):
66
+ if self._initialized and self._model is not None:
67
+ return
68
+
69
+ try:
70
+ import torch
71
+ from transformers import AutoTokenizer, AutoModel
72
+
73
+ if self._device == "auto":
74
+ self._device = "cuda" if torch.cuda.is_available() else "cpu"
75
+
76
+ logger.info("Loading semantic encoder: %s on %s", self._model_name, self._device)
77
+
78
+ self._tokenizer = AutoTokenizer.from_pretrained(self._model_name)
79
+ self._model = AutoModel.from_pretrained(self._model_name)
80
+ self._model.to(self._device)
81
+ self._model.eval()
82
+
83
+ self._initialized = True
84
+ logger.info("Semantic encoder loaded successfully")
85
+
86
+ except ImportError as e:
87
+ logger.warning(
88
+ "Could not load semantic encoder (missing dependencies: %s). "
89
+ "Using fallback TF-IDF-based similarity.",
90
+ e,
91
+ )
92
+ self._initialized = False
93
+ self._model = None
94
+ self._tokenizer = None
95
+ except Exception as e:
96
+ logger.warning(
97
+ "Could not load semantic encoder (%s). Using fallback similarity.",
98
+ e,
99
+ )
100
+ self._initialized = False
101
+ self._model = None
102
+ self._tokenizer = None
103
+
104
+ def is_available(self) -> bool:
105
+ self._load_model()
106
+ return self._initialized and self._model is not None
107
+
108
+ def encode(
109
+ self,
110
+ text: str,
111
+ embedding_type: str = "code",
112
+ metadata: Optional[Dict[str, Any]] = None,
113
+ ) -> SemanticEmbedding:
114
+ self._load_model()
115
+
116
+ if not self.is_available():
117
+ return self._fallback_encode(text, embedding_type, metadata)
118
+
119
+ try:
120
+ import torch
121
+
122
+ inputs = self._tokenizer(
123
+ text,
124
+ return_tensors="pt",
125
+ truncation=True,
126
+ max_length=512,
127
+ padding=True,
128
+ )
129
+ inputs = {k: v.to(self._device) for k, v in inputs.items()}
130
+
131
+ with torch.no_grad():
132
+ outputs = self._model(**inputs)
133
+ embeddings = outputs.last_hidden_state[:, 0, :]
134
+ embeddings = embeddings.cpu().numpy().squeeze()
135
+
136
+ embeddings = embeddings / (np.linalg.norm(embeddings) + 1e-8)
137
+
138
+ return SemanticEmbedding(
139
+ vector=embeddings.astype(np.float32),
140
+ text=text,
141
+ metadata=metadata or {},
142
+ embedding_type=embedding_type,
143
+ )
144
+
145
+ except Exception as e:
146
+ logger.warning("Error encoding with neural model: %s. Using fallback.", e)
147
+ return self._fallback_encode(text, embedding_type, metadata)
148
+
149
+ def encode_batch(
150
+ self,
151
+ texts: List[str],
152
+ embedding_type: str = "code",
153
+ metadata_list: Optional[List[Dict[str, Any]]] = None,
154
+ ) -> List[SemanticEmbedding]:
155
+ self._load_model()
156
+
157
+ if not self.is_available():
158
+ return [
159
+ self._fallback_encode(text, embedding_type, metadata_list[i] if metadata_list else None)
160
+ for i, text in enumerate(texts)
161
+ ]
162
+
163
+ try:
164
+ import torch
165
+
166
+ inputs = self._tokenizer(
167
+ texts,
168
+ return_tensors="pt",
169
+ truncation=True,
170
+ max_length=512,
171
+ padding=True,
172
+ )
173
+ inputs = {k: v.to(self._device) for k, v in inputs.items()}
174
+
175
+ with torch.no_grad():
176
+ outputs = self._model(**inputs)
177
+ embeddings = outputs.last_hidden_state[:, 0, :]
178
+ embeddings = embeddings.cpu().numpy()
179
+
180
+ norms = np.linalg.norm(embeddings, axis=1, keepdims=True) + 1e-8
181
+ embeddings = embeddings / norms
182
+
183
+ results = []
184
+ for i, emb in enumerate(embeddings):
185
+ results.append(
186
+ SemanticEmbedding(
187
+ vector=emb.astype(np.float32),
188
+ text=texts[i],
189
+ metadata=metadata_list[i] if metadata_list else {},
190
+ embedding_type=embedding_type,
191
+ )
192
+ )
193
+ return results
194
+
195
+ except Exception as e:
196
+ logger.warning("Error batch encoding: %s. Using fallback.", e)
197
+ return [
198
+ self._fallback_encode(text, embedding_type, metadata_list[i] if metadata_list else None)
199
+ for i, text in enumerate(texts)
200
+ ]
201
+
202
+ def _fallback_encode(
203
+ self,
204
+ text: str,
205
+ embedding_type: str = "code",
206
+ metadata: Optional[Dict[str, Any]] = None,
207
+ ) -> SemanticEmbedding:
208
+ words = text.lower().split()
209
+ vocab = sorted(set(words))
210
+ vec = np.zeros(len(vocab), dtype=np.float32)
211
+
212
+ for w in words:
213
+ if w in vocab:
214
+ vec[vocab.index(w)] += 1
215
+
216
+ norm = np.linalg.norm(vec)
217
+ if norm > 0:
218
+ vec = vec / norm
219
+
220
+ pad_size = 128 - len(vec)
221
+ if pad_size > 0:
222
+ vec = np.pad(vec, (0, pad_size), mode="constant")
223
+ elif pad_size < 0:
224
+ vec = vec[:128]
225
+
226
+ return SemanticEmbedding(
227
+ vector=vec.astype(np.float32),
228
+ text=text,
229
+ metadata=metadata or {},
230
+ embedding_type=embedding_type,
231
+ )
232
+
233
+ def similarity(self, emb1: SemanticEmbedding, emb2: SemanticEmbedding) -> float:
234
+ if len(emb1.vector) != len(emb2.vector):
235
+ min_len = min(len(emb1.vector), len(emb2.vector))
236
+ v1 = emb1.vector[:min_len]
237
+ v2 = emb2.vector[:min_len]
238
+ else:
239
+ v1 = emb1.vector
240
+ v2 = emb2.vector
241
+
242
+ norm1 = np.linalg.norm(v1)
243
+ norm2 = np.linalg.norm(v2)
244
+
245
+ if norm1 < 1e-8 or norm2 < 1e-8:
246
+ return 0.0
247
+
248
+ return float(np.dot(v1, v2) / (norm1 * norm2))
249
+
250
+ def batch_similarity(
251
+ self,
252
+ query_emb: SemanticEmbedding,
253
+ embeddings: List[SemanticEmbedding],
254
+ ) -> List[Tuple[int, float]]:
255
+ if not embeddings:
256
+ return []
257
+
258
+ q_vec = query_emb.vector
259
+ q_norm = np.linalg.norm(q_vec)
260
+
261
+ if q_norm < 1e-8:
262
+ return [(i, 0.0) for i in range(len(embeddings))]
263
+
264
+ results = []
265
+ for i, emb in enumerate(embeddings):
266
+ e_vec = emb.vector
267
+
268
+ if len(e_vec) != len(q_vec):
269
+ min_len = min(len(q_vec), len(e_vec))
270
+ qv = q_vec[:min_len]
271
+ ev = e_vec[:min_len]
272
+ else:
273
+ qv = q_vec
274
+ ev = e_vec
275
+
276
+ e_norm = np.linalg.norm(ev)
277
+ if e_norm < 1e-8:
278
+ results.append((i, 0.0))
279
+ continue
280
+
281
+ sim = float(np.dot(qv, ev) / (q_norm * e_norm))
282
+ results.append((i, sim))
283
+
284
+ return results
285
+
286
+
287
+ def cosine_similarity(v1: np.ndarray, v2: np.ndarray) -> float:
288
+ norm1 = np.linalg.norm(v1)
289
+ norm2 = np.linalg.norm(v2)
290
+
291
+ if norm1 < 1e-8 or norm2 < 1e-8:
292
+ return 0.0
293
+
294
+ return float(np.dot(v1, v2) / (norm1 * norm2))
src/pipeline.py CHANGED
@@ -55,7 +55,7 @@ class TBPipeline:
55
  model_type = ml_cfg.model_type
56
  self.logger.info("ML generation enabled, model_type=%s", model_type)
57
 
58
- if model_type in ("ml", "hybrid"):
59
  ml_model_config = MLModelConfig(
60
  similarity_threshold=ml_cfg.similarity_threshold,
61
  auto_learn=ml_cfg.auto_learn,
@@ -68,8 +68,19 @@ class TBPipeline:
68
  config=ml_model_config,
69
  templates_dir=self.cfg.generation.templates_dir,
70
  strict_validation=True,
 
 
 
 
 
71
  )
72
  self.logger.info("Created EnhancedMLGenerationModel with index size: %d", len(model.index))
 
 
 
 
 
 
73
  return model
74
 
75
  self.logger.info("Falling back to template model")
 
55
  model_type = ml_cfg.model_type
56
  self.logger.info("ML generation enabled, model_type=%s", model_type)
57
 
58
+ if model_type in ("ml", "hybrid", "llm", "semantic"):
59
  ml_model_config = MLModelConfig(
60
  similarity_threshold=ml_cfg.similarity_threshold,
61
  auto_learn=ml_cfg.auto_learn,
 
68
  config=ml_model_config,
69
  templates_dir=self.cfg.generation.templates_dir,
70
  strict_validation=True,
71
+ use_llm=ml_cfg.use_llm,
72
+ use_semantic_encoder=ml_cfg.use_semantic_encoder,
73
+ use_learning=ml_cfg.use_learning,
74
+ llm_model_name=ml_cfg.llm_model_name,
75
+ learning_storage_path=ml_cfg.learning_storage_path,
76
  )
77
  self.logger.info("Created EnhancedMLGenerationModel with index size: %d", len(model.index))
78
+
79
+ if model_type == "llm":
80
+ self.logger.info("LLM mode: will prioritize LLM generation")
81
+ elif model_type == "semantic":
82
+ self.logger.info("Semantic mode: will use semantic embeddings for similarity")
83
+
84
  return model
85
 
86
  self.logger.info("Falling back to template model")