feat: Add actual AI/ML capabilities with LLM, semantic embeddings, and reinforcement learning
- Add semantic_encoder.py: CodeBERT-based semantic code embeddings with a TF-IDF fallback
- Add llm_generator.py: LLM-based code generation (CodeGen, CodeT5, StarCoder, etc.)
- Add learning_module.py: Reinforcement learning + pattern learning from validation feedback
- Update enhanced_ml_model.py:
  - Add last_retrieval property
  - Add learning module strategy selection
  - Add semantic similarity enhancement
  - Add LLM generation as strategy option
  - Record validation feedback to learning module
- Update config.py: Add MLConfig options for LLM, semantic encoder, learning module
- Update pipeline.py: Pass new ML config options
- Update requirements.txt: Add torch, transformers, sentence-transformers, accelerate
- Recreate ml_generation_model.py with MLModelConfig, NameNormalizer, RetrievalInfo
Key AI/ML capabilities:
1. LLM code generation with few-shot UVM examples
2. Semantic code embeddings for intelligent similarity
3. Reinforcement learning (Q-learning) from validation feedback
4. Pattern learning from success/failure patterns
5. Auto-improving generation strategy selection
6. Graceful fallback when torch/transformers not available
- requirements.txt +6 -0
- src/config.py +17 -2
- src/models/enhanced_ml_model.py +326 -35
- src/models/learning_module.py +572 -0
- src/models/llm_generator.py +557 -0
- src/models/ml_generation_model.py +204 -0
- src/models/semantic_encoder.py +294 -0
- src/pipeline.py +12 -1
|
@@ -7,3 +7,9 @@ gunicorn>=23.0
|
|
| 7 |
|
| 8 |
numpy>=1.21.0
|
| 9 |
scikit-learn>=1.0.0
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 7 |
|
| 8 |
numpy>=1.21.0
|
| 9 |
scikit-learn>=1.0.0
|
| 10 |
+
scipy>=1.7.0
|
| 11 |
+
|
| 12 |
+
torch>=2.0.0
|
| 13 |
+
transformers>=4.35.0
|
| 14 |
+
sentence-transformers>=2.2.0
|
| 15 |
+
accelerate>=0.24.0
|
|
@@ -86,15 +86,30 @@ class AutoTrainConfig(BaseModel):
|
|
| 86 |
|
| 87 |
|
| 88 |
class MLConfig(BaseModel):
|
| 89 |
-
"""Configuration for ML-augmented generation."""
|
| 90 |
enabled: bool = False
|
| 91 |
-
model_type: str = Field(default="template", pattern=r"^(template|ml|hybrid)$")
|
| 92 |
similarity_threshold: float = Field(default=0.75, ge=0.0, le=1.0)
|
| 93 |
auto_learn: bool = True
|
| 94 |
index_path: Optional[str] = None
|
| 95 |
top_k_retrieval: int = Field(default=3, ge=1, le=10)
|
| 96 |
fallback_to_templates: bool = True
|
| 97 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 98 |
|
| 99 |
class PipelineConfig(BaseModel):
|
| 100 |
generation: GenerationConfig = GenerationConfig()
|
|
|
|
| 86 |
|
| 87 |
|
| 88 |
class MLConfig(BaseModel):
|
| 89 |
+
"""Configuration for AI/ML-augmented generation with actual learning capabilities."""
|
| 90 |
enabled: bool = False
|
| 91 |
+
model_type: str = Field(default="template", pattern=r"^(template|ml|hybrid|llm|semantic)$")
|
| 92 |
similarity_threshold: float = Field(default=0.75, ge=0.0, le=1.0)
|
| 93 |
auto_learn: bool = True
|
| 94 |
index_path: Optional[str] = None
|
| 95 |
top_k_retrieval: int = Field(default=3, ge=1, le=10)
|
| 96 |
fallback_to_templates: bool = True
|
| 97 |
|
| 98 |
+
use_llm: bool = True
|
| 99 |
+
llm_model_name: Optional[str] = None
|
| 100 |
+
llm_max_tokens: int = Field(default=1024, ge=64, le=4096)
|
| 101 |
+
llm_temperature: float = Field(default=0.2, ge=0.0, le=1.0)
|
| 102 |
+
llm_use_few_shot: bool = True
|
| 103 |
+
|
| 104 |
+
use_semantic_encoder: bool = True
|
| 105 |
+
semantic_model_name: str = "microsoft/codebert-base"
|
| 106 |
+
|
| 107 |
+
use_learning: bool = True
|
| 108 |
+
learning_storage_path: Optional[str] = None
|
| 109 |
+
learning_rate: float = Field(default=0.1, ge=0.001, le=1.0)
|
| 110 |
+
reinforcement_discount: float = Field(default=0.9, ge=0.0, le=1.0)
|
| 111 |
+
exploration_epsilon: float = Field(default=0.05, ge=0.0, le=0.5)
|
| 112 |
+
|
| 113 |
|
| 114 |
class PipelineConfig(BaseModel):
|
| 115 |
generation: GenerationConfig = GenerationConfig()
|
|
@@ -1,18 +1,20 @@
|
|
| 1 |
"""
|
| 2 |
-
Industry-level
|
| 3 |
-
-
|
|
|
|
|
|
|
|
|
|
| 4 |
- Spec-aware adaptation
|
| 5 |
- Code validation
|
| 6 |
- Multi-level fallback
|
| 7 |
- Comprehensive reporting
|
| 8 |
|
| 9 |
-
This model
|
| 10 |
-
1.
|
| 11 |
-
2.
|
| 12 |
-
3.
|
| 13 |
-
4.
|
| 14 |
-
5.
|
| 15 |
-
6. Detailed generation reports
|
| 16 |
"""
|
| 17 |
|
| 18 |
from __future__ import annotations
|
|
@@ -43,6 +45,16 @@ from src.models.spec_adapter import (
|
|
| 43 |
from src.models.similarity_index import SimilarityIndex, get_global_index
|
| 44 |
from src.models.template_model import TemplateModel
|
| 45 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 46 |
logger = logging.getLogger("uvmgen")
|
| 47 |
|
| 48 |
|
|
@@ -50,9 +62,12 @@ class GenerationSource(Enum):
|
|
| 50 |
RETRIEVAL_HIGH_CONF = "retrieval_high_confidence"
|
| 51 |
RETRIEVAL_MEDIUM_CONF = "retrieval_medium_confidence"
|
| 52 |
RETRIEVAL_LOW_CONF = "retrieval_low_confidence"
|
|
|
|
|
|
|
| 53 |
TEMPLATE_FALLBACK = "template_fallback"
|
| 54 |
BLENDED = "blended"
|
| 55 |
HYBRID = "hybrid"
|
|
|
|
| 56 |
|
| 57 |
|
| 58 |
@dataclass
|
|
@@ -117,15 +132,21 @@ class RetrievalCandidate:
|
|
| 117 |
|
| 118 |
class EnhancedMLGenerationModel(GenerationModel):
|
| 119 |
"""
|
| 120 |
-
Industry-level
|
| 121 |
-
|
| 122 |
-
Key features:
|
| 123 |
-
1.
|
| 124 |
-
2.
|
| 125 |
-
3.
|
| 126 |
-
4.
|
| 127 |
-
5.
|
| 128 |
-
6.
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 129 |
"""
|
| 130 |
|
| 131 |
def __init__(
|
|
@@ -135,6 +156,11 @@ class EnhancedMLGenerationModel(GenerationModel):
|
|
| 135 |
index: Optional[SimilarityIndex] = None,
|
| 136 |
templates_dir: Optional[str] = None,
|
| 137 |
strict_validation: bool = True,
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 138 |
):
|
| 139 |
super().__init__(name)
|
| 140 |
self.config = config or MLModelConfig()
|
|
@@ -144,6 +170,27 @@ class EnhancedMLGenerationModel(GenerationModel):
|
|
| 144 |
self._strict_validation = strict_validation
|
| 145 |
self._metadata: Dict[str, Any] = {}
|
| 146 |
self._last_result: Optional[GenerationResult] = None
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 147 |
|
| 148 |
@property
|
| 149 |
def index(self) -> SimilarityIndex:
|
|
@@ -163,6 +210,23 @@ class EnhancedMLGenerationModel(GenerationModel):
|
|
| 163 |
)
|
| 164 |
return self._template_model
|
| 165 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 166 |
def train(self, specs: List[DesignSpec]) -> Dict[str, Any]:
|
| 167 |
"""Train the model by adding specs to the similarity index."""
|
| 168 |
from src.features.extractors import RichSpecFeatureExtractor
|
|
@@ -226,20 +290,21 @@ class EnhancedMLGenerationModel(GenerationModel):
|
|
| 226 |
extra_seqs: Optional[List[str]] = None,
|
| 227 |
) -> Dict[str, str]:
|
| 228 |
"""
|
| 229 |
-
Generate testbench with
|
| 230 |
-
|
| 231 |
-
Workflow:
|
| 232 |
-
1.
|
| 233 |
-
2.
|
| 234 |
-
3.
|
| 235 |
-
|
| 236 |
-
|
| 237 |
-
|
| 238 |
-
|
| 239 |
-
|
| 240 |
-
|
| 241 |
-
|
| 242 |
-
|
|
|
|
| 243 |
"""
|
| 244 |
if not self._is_trained:
|
| 245 |
self.train([])
|
|
@@ -249,12 +314,41 @@ class EnhancedMLGenerationModel(GenerationModel):
|
|
| 249 |
query_fv = extractor.extract(spec)
|
| 250 |
query_dict = self._spec_to_dict(spec)
|
| 251 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 252 |
similar = self.index.search(
|
| 253 |
query_fv,
|
| 254 |
top_k=self.config.top_k_retrieval,
|
| 255 |
min_similarity=0.3,
|
| 256 |
)
|
| 257 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 258 |
logger.info(
|
| 259 |
"Enhanced ML generation: found %d similar specs, best score: %.3f",
|
| 260 |
len(similar), similar[0].similarity if similar else 0.0
|
|
@@ -262,21 +356,29 @@ class EnhancedMLGenerationModel(GenerationModel):
|
|
| 262 |
|
| 263 |
result: Optional[GenerationResult] = None
|
| 264 |
|
| 265 |
-
if
|
|
|
|
|
|
|
|
|
|
|
|
|
| 266 |
result = self._try_retrieval_generation(
|
| 267 |
similar, query_fv, query_dict, spec, cfg
|
| 268 |
)
|
| 269 |
|
|
|
|
|
|
|
|
|
|
|
|
|
| 270 |
if (
|
| 271 |
result is None
|
| 272 |
or (self._strict_validation and not result.passed)
|
| 273 |
and self.config.fallback_to_templates
|
| 274 |
):
|
| 275 |
if result is None:
|
| 276 |
-
logger.info("No valid
|
| 277 |
else:
|
| 278 |
logger.warning(
|
| 279 |
-
"
|
| 280 |
result.validation_report.total_errors if result.validation_report else 0
|
| 281 |
)
|
| 282 |
result = self._generate_with_fallback(spec, cfg, extra_seqs, result)
|
|
@@ -284,6 +386,15 @@ class EnhancedMLGenerationModel(GenerationModel):
|
|
| 284 |
if result is None:
|
| 285 |
raise RuntimeError("All generation strategies failed")
|
| 286 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 287 |
if self.config.auto_learn and result.passed:
|
| 288 |
self._learn_from_result(result, query_fv, query_dict)
|
| 289 |
|
|
@@ -292,6 +403,186 @@ class EnhancedMLGenerationModel(GenerationModel):
|
|
| 292 |
|
| 293 |
return result.generated_files
|
| 294 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 295 |
def _try_retrieval_generation(
|
| 296 |
self,
|
| 297 |
similar: List[Any],
|
|
|
|
| 1 |
"""
|
| 2 |
+
Industry-level AI/ML generation model with:
|
| 3 |
+
- LLM-based code generation (CodeGen, CodeT5, StarCoder)
|
| 4 |
+
- Semantic code embeddings for intelligent similarity
|
| 5 |
+
- Reinforcement learning from validation feedback
|
| 6 |
+
- Multi-strategy retrieval (protocol-first, semantic, text)
|
| 7 |
- Spec-aware adaptation
|
| 8 |
- Code validation
|
| 9 |
- Multi-level fallback
|
| 10 |
- Comprehensive reporting
|
| 11 |
|
| 12 |
+
This model uses actual AI/ML:
|
| 13 |
+
1. Neural semantic embeddings (CodeBERT) for similarity
|
| 14 |
+
2. LLM generation (CodeGen, CodeT5) for actual code generation
|
| 15 |
+
3. Reinforcement learning that learns from validation feedback
|
| 16 |
+
4. Pattern learning from success/failure patterns
|
| 17 |
+
5. Auto-improving generation strategies
|
|
|
|
| 18 |
"""
|
| 19 |
|
| 20 |
from __future__ import annotations
|
|
|
|
| 45 |
from src.models.similarity_index import SimilarityIndex, get_global_index
|
| 46 |
from src.models.template_model import TemplateModel
|
| 47 |
|
| 48 |
+
# The logger has to exist before the optional-import guard below: the
# ImportError handler reports through it, and referencing it before
# assignment would turn a recoverable missing dependency into a NameError.
logger = logging.getLogger("uvmgen")

# Optional AI/ML helpers. If any of these imports fails, the module still
# loads; ML_MODULES_AVAILABLE is set to False and a warning is logged.
try:
    from src.models.semantic_encoder import SemanticCodeEncoder, SemanticEmbedding
    from src.models.llm_generator import LLMCodeGenerator, LLMGenerationResult
    from src.models.learning_module import LearningModule, ValidationFeedback
except ImportError as e:
    logger.warning("Advanced ML modules not available: %s", e)
    ML_MODULES_AVAILABLE = False
else:
    ML_MODULES_AVAILABLE = True
|
| 59 |
|
| 60 |
|
|
|
|
| 62 |
RETRIEVAL_HIGH_CONF = "retrieval_high_confidence"
|
| 63 |
RETRIEVAL_MEDIUM_CONF = "retrieval_medium_confidence"
|
| 64 |
RETRIEVAL_LOW_CONF = "retrieval_low_confidence"
|
| 65 |
+
LLM_GENERATION = "llm_generation"
|
| 66 |
+
LLM_FALLBACK = "llm_fallback"
|
| 67 |
TEMPLATE_FALLBACK = "template_fallback"
|
| 68 |
BLENDED = "blended"
|
| 69 |
HYBRID = "hybrid"
|
| 70 |
+
LEARNING_IMPROVED = "learning_improved"
|
| 71 |
|
| 72 |
|
| 73 |
@dataclass
|
|
|
|
| 132 |
|
| 133 |
class EnhancedMLGenerationModel(GenerationModel):
|
| 134 |
"""
|
| 135 |
+
Industry-level AI/ML generation model with actual learning capabilities.
|
| 136 |
+
|
| 137 |
+
Key AI/ML features:
|
| 138 |
+
1. LLM-based code generation (CodeGen, CodeT5, StarCoder)
|
| 139 |
+
2. Semantic code embeddings (CodeBERT) for intelligent similarity
|
| 140 |
+
3. Reinforcement learning from validation feedback
|
| 141 |
+
4. Pattern learning from success/failure patterns
|
| 142 |
+
5. Multi-strategy retrieval with intelligent selection
|
| 143 |
+
6. Auto-improving generation strategies
|
| 144 |
+
|
| 145 |
+
Traditional features:
|
| 146 |
+
- Spec-aware adaptation with signal/register mapping
|
| 147 |
+
- Pre-validation before output
|
| 148 |
+
- Multi-level fallback strategies
|
| 149 |
+
- Comprehensive reporting and audit trail
|
| 150 |
"""
|
| 151 |
|
| 152 |
def __init__(
|
|
|
|
| 156 |
index: Optional[SimilarityIndex] = None,
|
| 157 |
templates_dir: Optional[str] = None,
|
| 158 |
strict_validation: bool = True,
|
| 159 |
+
use_llm: bool = True,
|
| 160 |
+
use_semantic_encoder: bool = True,
|
| 161 |
+
use_learning: bool = True,
|
| 162 |
+
llm_model_name: Optional[str] = None,
|
| 163 |
+
learning_storage_path: Optional[str] = None,
|
| 164 |
):
|
| 165 |
super().__init__(name)
|
| 166 |
self.config = config or MLModelConfig()
|
|
|
|
| 170 |
self._strict_validation = strict_validation
|
| 171 |
self._metadata: Dict[str, Any] = {}
|
| 172 |
self._last_result: Optional[GenerationResult] = None
|
| 173 |
+
self._last_retrieval: Optional[Any] = None
|
| 174 |
+
|
| 175 |
+
self._use_llm = use_llm and ML_MODULES_AVAILABLE
|
| 176 |
+
self._use_semantic = use_semantic_encoder and ML_MODULES_AVAILABLE
|
| 177 |
+
self._use_learning = use_learning and ML_MODULES_AVAILABLE
|
| 178 |
+
|
| 179 |
+
self._llm_generator: Optional[LLMCodeGenerator] = None
|
| 180 |
+
self._semantic_encoder: Optional[SemanticCodeEncoder] = None
|
| 181 |
+
self._learning_module: Optional[LearningModule] = None
|
| 182 |
+
|
| 183 |
+
if self._use_llm:
|
| 184 |
+
self._llm_generator = LLMCodeGenerator(model_name=llm_model_name)
|
| 185 |
+
logger.info("LLM generator enabled: %s", llm_model_name or "default")
|
| 186 |
+
|
| 187 |
+
if self._use_semantic:
|
| 188 |
+
self._semantic_encoder = SemanticCodeEncoder()
|
| 189 |
+
logger.info("Semantic encoder enabled")
|
| 190 |
+
|
| 191 |
+
if self._use_learning:
|
| 192 |
+
self._learning_module = LearningModule(storage_path=learning_storage_path)
|
| 193 |
+
logger.info("Learning module enabled")
|
| 194 |
|
| 195 |
@property
|
| 196 |
def index(self) -> SimilarityIndex:
|
|
|
|
| 210 |
)
|
| 211 |
return self._template_model
|
| 212 |
|
| 213 |
+
@property
def last_retrieval(self) -> Optional[Any]:
    """Return information about the last retrieval operation.

    Resolution order:
      1. the value stored in ``self._last_retrieval``, if one is set;
      2. a ``RetrievalInfo`` derived from ``self._last_result``;
      3. an empty ``RetrievalInfo`` (no similarity used, zero matches).
    """
    if self._last_retrieval is not None:
        return self._last_retrieval

    # Imported lazily and only on the paths that actually build a
    # RetrievalInfo, so the cached path above does not depend on it.
    from src.models.ml_generation_model import RetrievalInfo

    last = self._last_result
    if last is not None:
        return RetrievalInfo(
            used_similarity=(last.similar_specs_found > 0),
            similar_specs=last.similar_specs_found,
            best_score=last.best_match_score,
        )

    return RetrievalInfo(used_similarity=False, similar_specs=0, best_score=0.0)
|
| 229 |
+
|
| 230 |
def train(self, specs: List[DesignSpec]) -> Dict[str, Any]:
|
| 231 |
"""Train the model by adding specs to the similarity index."""
|
| 232 |
from src.features.extractors import RichSpecFeatureExtractor
|
|
|
|
| 290 |
extra_seqs: Optional[List[str]] = None,
|
| 291 |
) -> Dict[str, str]:
|
| 292 |
"""
|
| 293 |
+
Generate testbench with AI/ML-powered generation and fallback.
|
| 294 |
+
|
| 295 |
+
AI/ML Workflow:
|
| 296 |
+
1. Use learning module to select best generation strategy
|
| 297 |
+
2. Try semantic similarity search (if semantic encoder available)
|
| 298 |
+
3. Try LLM-based code generation (if LLM available)
|
| 299 |
+
4. Try traditional retrieval-based generation
|
| 300 |
+
5. Fallback to templates
|
| 301 |
+
6. Record validation feedback to learning module
|
| 302 |
+
7. Auto-learn from successful generation
|
| 303 |
+
|
| 304 |
+
Traditional features:
|
| 305 |
+
- Spec-aware adaptation
|
| 306 |
+
- Pre-validation before writing
|
| 307 |
+
- Multi-level fallback
|
| 308 |
"""
|
| 309 |
if not self._is_trained:
|
| 310 |
self.train([])
|
|
|
|
| 314 |
query_fv = extractor.extract(spec)
|
| 315 |
query_dict = self._spec_to_dict(spec)
|
| 316 |
|
| 317 |
+
protocol = query_dict.get("protocol", "unknown")
|
| 318 |
+
|
| 319 |
+
available_strategies = ["retrieval"]
|
| 320 |
+
if self._use_llm and self._llm_generator:
|
| 321 |
+
available_strategies.append("llm")
|
| 322 |
+
available_strategies.append("template")
|
| 323 |
+
|
| 324 |
+
selected_strategy = "retrieval"
|
| 325 |
+
strategy_confidence = 0.5
|
| 326 |
+
|
| 327 |
+
if self._use_learning and self._learning_module:
|
| 328 |
+
selected_strategy, strategy_confidence = (
|
| 329 |
+
self._learning_module.select_best_generation_strategy(
|
| 330 |
+
spec_dict=query_dict,
|
| 331 |
+
file_type="testbench",
|
| 332 |
+
available_sources=available_strategies,
|
| 333 |
+
)
|
| 334 |
+
)
|
| 335 |
+
logger.info(
|
| 336 |
+
"Learning module selected strategy: '%s' (confidence: %.2f)",
|
| 337 |
+
selected_strategy,
|
| 338 |
+
strategy_confidence,
|
| 339 |
+
)
|
| 340 |
+
|
| 341 |
similar = self.index.search(
|
| 342 |
query_fv,
|
| 343 |
top_k=self.config.top_k_retrieval,
|
| 344 |
min_similarity=0.3,
|
| 345 |
)
|
| 346 |
|
| 347 |
+
if self._use_semantic and self._semantic_encoder and similar:
|
| 348 |
+
similar = self._enhance_with_semantic_similarity(
|
| 349 |
+
similar, query_dict
|
| 350 |
+
)
|
| 351 |
+
|
| 352 |
logger.info(
|
| 353 |
"Enhanced ML generation: found %d similar specs, best score: %.3f",
|
| 354 |
len(similar), similar[0].similarity if similar else 0.0
|
|
|
|
| 356 |
|
| 357 |
result: Optional[GenerationResult] = None
|
| 358 |
|
| 359 |
+
if selected_strategy == "llm" and self._use_llm and self._llm_generator:
|
| 360 |
+
logger.info("Trying LLM-based generation (selected by learning module)")
|
| 361 |
+
result = self._try_llm_generation(query_dict, spec, cfg)
|
| 362 |
+
|
| 363 |
+
if result is None and similar and similar[0].similarity >= self.config.similarity_threshold:
|
| 364 |
result = self._try_retrieval_generation(
|
| 365 |
similar, query_fv, query_dict, spec, cfg
|
| 366 |
)
|
| 367 |
|
| 368 |
+
if result is None and self._use_llm and self._llm_generator:
|
| 369 |
+
logger.info("Trying LLM-based generation as fallback")
|
| 370 |
+
result = self._try_llm_generation(query_dict, spec, cfg)
|
| 371 |
+
|
| 372 |
if (
|
| 373 |
result is None
|
| 374 |
or (self._strict_validation and not result.passed)
|
| 375 |
and self.config.fallback_to_templates
|
| 376 |
):
|
| 377 |
if result is None:
|
| 378 |
+
logger.info("No valid ML/LLM candidate, falling back to templates")
|
| 379 |
else:
|
| 380 |
logger.warning(
|
| 381 |
+
"LLM/retrieval generation failed validation (errors: %d), falling back to templates",
|
| 382 |
result.validation_report.total_errors if result.validation_report else 0
|
| 383 |
)
|
| 384 |
result = self._generate_with_fallback(spec, cfg, extra_seqs, result)
|
|
|
|
| 386 |
if result is None:
|
| 387 |
raise RuntimeError("All generation strategies failed")
|
| 388 |
|
| 389 |
+
if self._use_learning and self._learning_module and result.validation_report:
|
| 390 |
+
logger.info("Recording validation feedback to learning module")
|
| 391 |
+
self._learning_module.record_feedback(
|
| 392 |
+
design_name=spec.design_name,
|
| 393 |
+
generation_source=result.source.value,
|
| 394 |
+
spec_dict=query_dict,
|
| 395 |
+
validation_results=result.validation_report.to_dict(),
|
| 396 |
+
)
|
| 397 |
+
|
| 398 |
if self.config.auto_learn and result.passed:
|
| 399 |
self._learn_from_result(result, query_fv, query_dict)
|
| 400 |
|
|
|
|
| 403 |
|
| 404 |
return result.generated_files
|
| 405 |
|
| 406 |
+
def _enhance_with_semantic_similarity(
|
| 407 |
+
self,
|
| 408 |
+
similar: List[Any],
|
| 409 |
+
query_dict: Dict[str, Any],
|
| 410 |
+
) -> List[Any]:
|
| 411 |
+
"""Enhance similarity scores using semantic code embeddings."""
|
| 412 |
+
if not self._semantic_encoder or not self._semantic_encoder.is_available():
|
| 413 |
+
return similar
|
| 414 |
+
|
| 415 |
+
try:
|
| 416 |
+
query_text = self._spec_dict_to_text(query_dict)
|
| 417 |
+
query_emb = self._semantic_encoder.encode(
|
| 418 |
+
text=query_text,
|
| 419 |
+
embedding_type="spec",
|
| 420 |
+
metadata=query_dict,
|
| 421 |
+
)
|
| 422 |
+
|
| 423 |
+
for item in similar:
|
| 424 |
+
spec_text = self._spec_dict_to_text(item.spec_dict)
|
| 425 |
+
cand_emb = self._semantic_encoder.encode(
|
| 426 |
+
text=spec_text,
|
| 427 |
+
embedding_type="spec",
|
| 428 |
+
metadata=item.spec_dict,
|
| 429 |
+
)
|
| 430 |
+
|
| 431 |
+
semantic_sim = self._semantic_encoder.similarity(query_emb, cand_emb)
|
| 432 |
+
|
| 433 |
+
original_sim = item.similarity
|
| 434 |
+
item.similarity = (original_sim * 0.6) + (semantic_sim * 0.4)
|
| 435 |
+
|
| 436 |
+
logger.debug(
|
| 437 |
+
"Semantic enhancement: original=%.3f, semantic=%.3f, combined=%.3f",
|
| 438 |
+
original_sim, semantic_sim, item.similarity
|
| 439 |
+
)
|
| 440 |
+
|
| 441 |
+
similar = sorted(similar, key=lambda x: x.similarity, reverse=True)
|
| 442 |
+
|
| 443 |
+
except Exception as e:
|
| 444 |
+
logger.warning("Semantic similarity enhancement failed: %s", e)
|
| 445 |
+
|
| 446 |
+
return similar
|
| 447 |
+
|
| 448 |
+
def _spec_dict_to_text(self, spec_dict: Dict[str, Any]) -> str:
|
| 449 |
+
"""Convert spec dict to text for semantic encoding."""
|
| 450 |
+
parts = []
|
| 451 |
+
parts.append(f"design: {spec_dict.get('design_name', 'unknown')}")
|
| 452 |
+
parts.append(f"protocol: {spec_dict.get('protocol', 'unknown')}")
|
| 453 |
+
|
| 454 |
+
signals = spec_dict.get("signals", [])
|
| 455 |
+
if signals:
|
| 456 |
+
signal_names = [s.get("name", "") for s in signals if isinstance(s, dict)]
|
| 457 |
+
parts.append(f"signals: {', '.join(signal_names[:20])}")
|
| 458 |
+
|
| 459 |
+
registers = spec_dict.get("registers", [])
|
| 460 |
+
if registers:
|
| 461 |
+
reg_names = [r.get("name", "") for r in registers if isinstance(r, dict)]
|
| 462 |
+
parts.append(f"registers: {', '.join(reg_names[:10])}")
|
| 463 |
+
|
| 464 |
+
features = spec_dict.get("features", [])
|
| 465 |
+
if features:
|
| 466 |
+
parts.append(f"features: {', '.join(features[:10])}")
|
| 467 |
+
|
| 468 |
+
return " | ".join(parts)
|
| 469 |
+
|
| 470 |
+
def _try_llm_generation(
|
| 471 |
+
self,
|
| 472 |
+
query_dict: Dict[str, Any],
|
| 473 |
+
spec: DesignSpec,
|
| 474 |
+
cfg: PipelineConfig,
|
| 475 |
+
) -> Optional[GenerationResult]:
|
| 476 |
+
"""
|
| 477 |
+
Try LLM-based code generation.
|
| 478 |
+
|
| 479 |
+
This uses actual AI/ML:
|
| 480 |
+
1. LLM (CodeGen, CodeT5, etc.) generates SystemVerilog code
|
| 481 |
+
2. Uses few-shot examples for UVM patterns
|
| 482 |
+
3. Validates generated code
|
| 483 |
+
4. Falls back to templates if needed
|
| 484 |
+
"""
|
| 485 |
+
if not self._llm_generator:
|
| 486 |
+
return None
|
| 487 |
+
|
| 488 |
+
design_name = spec.design_name.lower()
|
| 489 |
+
|
| 490 |
+
file_types_to_generate = [
|
| 491 |
+
"driver",
|
| 492 |
+
"monitor",
|
| 493 |
+
"agent",
|
| 494 |
+
]
|
| 495 |
+
|
| 496 |
+
generated_files: Dict[str, str] = {}
|
| 497 |
+
llm_results: Dict[str, LLMGenerationResult] = {}
|
| 498 |
+
all_warnings: List[str] = []
|
| 499 |
+
avg_confidence = 0.0
|
| 500 |
+
|
| 501 |
+
for file_type in file_types_to_generate:
|
| 502 |
+
try:
|
| 503 |
+
llm_result = self._llm_generator.generate(
|
| 504 |
+
spec_dict=query_dict,
|
| 505 |
+
file_type=file_type,
|
| 506 |
+
use_few_shot=True,
|
| 507 |
+
max_tokens=1024,
|
| 508 |
+
temperature=0.2,
|
| 509 |
+
)
|
| 510 |
+
|
| 511 |
+
llm_results[file_type] = llm_result
|
| 512 |
+
avg_confidence += llm_result.confidence
|
| 513 |
+
all_warnings.extend(llm_result.warnings)
|
| 514 |
+
|
| 515 |
+
file_name = f"{design_name}_{file_type}.sv"
|
| 516 |
+
generated_files[file_name] = llm_result.generated_code
|
| 517 |
+
|
| 518 |
+
logger.info(
|
| 519 |
+
"LLM generated %s (confidence: %.2f, tokens: %d)",
|
| 520 |
+
file_name,
|
| 521 |
+
llm_result.confidence,
|
| 522 |
+
llm_result.tokens_generated,
|
| 523 |
+
)
|
| 524 |
+
|
| 525 |
+
except Exception as e:
|
| 526 |
+
logger.warning("LLM generation failed for %s: %s", file_type, e)
|
| 527 |
+
all_warnings.append(f"LLM generation failed for {file_type}: {e}")
|
| 528 |
+
|
| 529 |
+
if not generated_files:
|
| 530 |
+
logger.warning("LLM generated no files, falling back")
|
| 531 |
+
return None
|
| 532 |
+
|
| 533 |
+
if llm_results:
|
| 534 |
+
avg_confidence /= len(llm_results)
|
| 535 |
+
|
| 536 |
+
try:
|
| 537 |
+
template_files = self.template_model.predict(spec, cfg)
|
| 538 |
+
template_contents: Dict[str, str] = {}
|
| 539 |
+
for fname, fpath in template_files.items():
|
| 540 |
+
try:
|
| 541 |
+
template_contents[fname] = Path(fpath).read_text(encoding="utf-8")
|
| 542 |
+
except Exception:
|
| 543 |
+
pass
|
| 544 |
+
|
| 545 |
+
for fname, content in template_contents.items():
|
| 546 |
+
if fname not in generated_files:
|
| 547 |
+
generated_files[fname] = content
|
| 548 |
+
|
| 549 |
+
except Exception as e:
|
| 550 |
+
logger.warning("Could not fill missing files from templates: %s", e)
|
| 551 |
+
|
| 552 |
+
validator = CodeValidator()
|
| 553 |
+
val_report = validator.validate_files(generated_files, query_dict)
|
| 554 |
+
|
| 555 |
+
total_errors = val_report.total_errors
|
| 556 |
+
total_warnings = val_report.total_warnings + len(all_warnings)
|
| 557 |
+
|
| 558 |
+
passed = val_report.overall_passed
|
| 559 |
+
if self._strict_validation:
|
| 560 |
+
passed = passed and (total_errors == 0)
|
| 561 |
+
|
| 562 |
+
generation_source = GenerationSource.LLM_GENERATION
|
| 563 |
+
if avg_confidence < 0.5:
|
| 564 |
+
generation_source = GenerationSource.LLM_FALLBACK
|
| 565 |
+
|
| 566 |
+
result = GenerationResult(
|
| 567 |
+
design_name=spec.design_name,
|
| 568 |
+
source=generation_source,
|
| 569 |
+
passed=passed,
|
| 570 |
+
generated_files=generated_files,
|
| 571 |
+
validation_report=val_report,
|
| 572 |
+
adaptation_plan=None,
|
| 573 |
+
similar_specs_found=0,
|
| 574 |
+
best_match_score=avg_confidence,
|
| 575 |
+
files_from_retrieval=[],
|
| 576 |
+
files_from_template=list(template_contents.keys()) if "template_contents" in dir() else [],
|
| 577 |
+
warnings=all_warnings + [
|
| 578 |
+
f"LLM confidence: {avg_confidence:.2f}",
|
| 579 |
+
f"LLM warnings: {len(all_warnings)}",
|
| 580 |
+
],
|
| 581 |
+
errors=[f"LLM errors: {total_errors}"] if total_errors > 0 else [],
|
| 582 |
+
)
|
| 583 |
+
|
| 584 |
+
return result
|
| 585 |
+
|
| 586 |
def _try_retrieval_generation(
|
| 587 |
self,
|
| 588 |
similar: List[Any],
|
|
@@ -0,0 +1,572 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import logging
|
| 2 |
+
from typing import List, Dict, Any, Optional, Tuple
|
| 3 |
+
from dataclasses import dataclass, field
|
| 4 |
+
from collections import defaultdict
|
| 5 |
+
import json
|
| 6 |
+
import os
|
| 7 |
+
from datetime import datetime
|
| 8 |
+
|
| 9 |
+
logger = logging.getLogger("uvmgen.ml.learning")
|
| 10 |
+
|
| 11 |
+
|
| 12 |
+
@dataclass
class ValidationFeedback:
    """Validation outcome recorded for a single generated file.

    ``to_dict`` and ``from_dict`` copy the list/dict fields, so the
    serialized form and the instance never share mutable state with each
    other or with the caller's input.
    """

    design_name: str
    file_name: str
    file_type: str
    passed: bool
    errors: List[str]
    warnings: List[str]
    score: float
    # ISO-8601 timestamp taken when the record is created.
    timestamp: str = field(default_factory=lambda: datetime.now().isoformat())
    metadata: Dict[str, Any] = field(default_factory=dict)

    def to_dict(self) -> Dict[str, Any]:
        """Return a plain-dict snapshot of this record."""
        return {
            "design_name": self.design_name,
            "file_name": self.file_name,
            "file_type": self.file_type,
            "passed": self.passed,
            # Copies: mutating the snapshot must not change this record.
            "errors": list(self.errors),
            "warnings": list(self.warnings),
            "score": self.score,
            "timestamp": self.timestamp,
            "metadata": dict(self.metadata),
        }

    @classmethod
    def from_dict(cls, d: Dict[str, Any]) -> "ValidationFeedback":
        """Build a record from a dict, tolerating missing or ``None`` fields.

        Missing text fields become ``"unknown"``, ``passed`` defaults to
        ``False`` and ``score`` to ``0.0``. A missing or empty timestamp is
        replaced by the current time.
        """
        return cls(
            design_name=d.get("design_name", "unknown"),
            file_name=d.get("file_name", "unknown"),
            file_type=d.get("file_type", "unknown"),
            passed=d.get("passed", False),
            # ``or`` also covers an explicit None stored under the key.
            errors=list(d.get("errors") or []),
            warnings=list(d.get("warnings") or []),
            score=d.get("score", 0.0),
            timestamp=d.get("timestamp") or datetime.now().isoformat(),
            metadata=dict(d.get("metadata") or {}),
        )
|
| 50 |
+
|
| 51 |
+
|
| 52 |
+
@dataclass
class GenerationHistory:
    """One generation run together with the per-file feedback it produced."""

    design_name: str
    generation_source: str
    spec_hash: str
    feedback_list: List[ValidationFeedback]
    success_rate: float = 0.0
    avg_score: float = 0.0
    # ISO-8601 string captured when the entry is created.
    timestamp: str = field(default_factory=lambda: datetime.now().isoformat())

    def to_dict(self) -> Dict[str, Any]:
        """Return the entry as a dictionary, serializing each feedback item."""
        serialized_feedback = [item.to_dict() for item in self.feedback_list]
        return dict(
            design_name=self.design_name,
            generation_source=self.generation_source,
            spec_hash=self.spec_hash,
            feedback_list=serialized_feedback,
            success_rate=self.success_rate,
            avg_score=self.avg_score,
            timestamp=self.timestamp,
        )
|
| 72 |
+
|
| 73 |
+
|
| 74 |
+
class PatternLearner:
    """Counts recurring error categories and per-file-type / per-protocol outcomes.

    Error messages are mapped to coarse category labels by regular-expression
    matching; success and attempt counters are kept per file type and per
    protocol so that success rates can be reported.
    """

    # (regex, label) pairs; several regexes may share one label.
    _UVM_PATTERNS = (
        (r"uvm_fatal", "uvm_fatal"),
        (r"uvm_error", "uvm_error"),
        (r"uvm_component_utils", "missing_uvm_macro"),
        (r"uvm_object_utils", "missing_uvm_macro"),
        (r"build_phase", "phase_issue"),
        (r"connect_phase", "phase_issue"),
        (r"run_phase", "phase_issue"),
    )
    _SYNTAX_PATTERNS = (
        (r"missing.*semicolon", "missing_semicolon"),
        (r"unbalanced.*parenthes", "unbalanced_parentheses"),
        (r"unbalanced.*brace", "unbalanced_braces"),
        (r"unbalanced.*bracket", "unbalanced_brackets"),
        (r"mismatch.*begin", "mismatched_blocks"),
        (r"syntax error", "syntax_error"),
    )

    def __init__(self):
        self._error_patterns: Dict[str, int] = defaultdict(int)
        # NOTE(review): never written or serialized in this class — confirm it is still needed.
        self._success_patterns: Dict[str, int] = defaultdict(int)
        self._file_type_stats: Dict[str, Dict[str, Any]] = defaultdict(
            lambda: {"success": 0, "total": 0, "errors": defaultdict(int)}
        )
        self._protocol_stats: Dict[str, Dict[str, Any]] = defaultdict(
            lambda: {"success": 0, "total": 0}
        )

    def record_error(self, error_msg: str, file_type: str = "unknown"):
        """Count the categories found in ``error_msg`` and the message itself.

        Each category is counted once per message. The message (truncated to
        100 characters) is also counted once under ``file_type``.
        """
        for label in self._extract_patterns(error_msg):
            self._error_patterns[label] += 1
        self._file_type_stats[file_type]["errors"][error_msg[:100]] += 1

    def record_success(self, file_type: str = "unknown", protocol: str = "unknown"):
        """Count one successful attempt (increments both success and total)."""
        self._file_type_stats[file_type]["success"] += 1
        self._file_type_stats[file_type]["total"] += 1
        self._protocol_stats[protocol]["success"] += 1
        self._protocol_stats[protocol]["total"] += 1

    def record_attempt(self, file_type: str = "unknown", protocol: str = "unknown"):
        """Count one attempt without a success (increments total only)."""
        self._file_type_stats[file_type]["total"] += 1
        self._protocol_stats[protocol]["total"] += 1

    def _extract_patterns(self, text: str) -> List[str]:
        """Return the distinct category labels whose regex matches ``text``.

        Matching is case-insensitive. Labels keep first-match order and each
        label appears at most once, so one message cannot inflate a category
        count by matching several regexes that share a label. Returns
        ``["unknown_error"]`` when nothing matches.
        """
        import re

        found: List[str] = []
        for pattern, label in self._UVM_PATTERNS + self._SYNTAX_PATTERNS:
            if label not in found and re.search(pattern, text, re.IGNORECASE):
                found.append(label)

        return found or ["unknown_error"]

    def get_common_errors(self, top_n: int = 10) -> List[Tuple[str, int]]:
        """Return up to ``top_n`` ``(label, count)`` pairs, most frequent first."""
        ranked = sorted(
            self._error_patterns.items(),
            key=lambda pair: pair[1],
            reverse=True,
        )
        return ranked[:top_n]

    @staticmethod
    def _success_rate(stats: Dict[str, Any]) -> float:
        """Return success/total, or the neutral prior 0.5 when nothing was recorded."""
        total = stats.get("total", 0)
        if total == 0:
            return 0.5
        return stats.get("success", 0) / total

    def get_file_type_success_rate(self, file_type: str) -> float:
        """Return the success rate for ``file_type`` (0.5 when no attempts are recorded)."""
        # .get() avoids creating an empty entry for an unseen file type.
        return self._success_rate(self._file_type_stats.get(file_type, {}))

    def get_protocol_success_rate(self, protocol: str) -> float:
        """Return the success rate for ``protocol`` (0.5 when no attempts are recorded)."""
        return self._success_rate(self._protocol_stats.get(protocol, {}))

    def to_dict(self) -> Dict[str, Any]:
        """Return a JSON-serializable snapshot of all counters.

        The snapshot holds copies, so mutating it does not affect the learner.
        """
        return {
            "error_patterns": dict(self._error_patterns),
            "file_type_stats": {
                ft: {
                    "success": s["success"],
                    "total": s["total"],
                    "errors": dict(s["errors"]),
                }
                for ft, s in self._file_type_stats.items()
            },
            "protocol_stats": {
                proto: dict(s) for proto, s in self._protocol_stats.items()
            },
        }

    @classmethod
    def from_dict(cls, d: Dict[str, Any]) -> "PatternLearner":
        """Rebuild a learner from a snapshot produced by :meth:`to_dict`.

        Missing sections or counters default to empty / zero.
        """
        learner = cls()
        learner._error_patterns = defaultdict(int, d.get("error_patterns", {}))

        for ft, s in d.get("file_type_stats", {}).items():
            learner._file_type_stats[ft] = {
                "success": s.get("success", 0),
                "total": s.get("total", 0),
                "errors": defaultdict(int, s.get("errors", {})),
            }

        for proto, s in d.get("protocol_stats", {}).items():
            learner._protocol_stats[proto] = {
                "success": s.get("success", 0),
                "total": s.get("total", 0),
            }

        return learner
|
| 193 |
+
|
| 194 |
+
|
| 195 |
+
class ReinforcementLearner:
    """Tabular value estimates for ``protocol:file_type:generation_source`` keys.

    Each key holds one running value that :meth:`update` moves toward the
    observed reward by ``learning_rate``. Unseen keys read as 0.5.
    ``discount_factor`` is stored and serialized but is not used by
    :meth:`update` (there is no next-state term in the update rule).
    """

    # Value reported for keys that have never been updated.
    _DEFAULT_VALUE = 0.5

    def __init__(self, learning_rate: float = 0.1, discount_factor: float = 0.9):
        self._learning_rate = learning_rate
        self._discount_factor = discount_factor
        self._q_values: Dict[str, float] = defaultdict(lambda: 0.5)
        self._visit_counts: Dict[str, int] = defaultdict(int)

    def _get_state_key(
        self,
        protocol: str,
        file_type: str,
        generation_source: str,
    ) -> str:
        """Return the table key for the given triple."""
        return f"{protocol}:{file_type}:{generation_source}"

    def get_action_value(
        self,
        protocol: str,
        file_type: str,
        generation_source: str,
    ) -> float:
        """Return the current value for the triple (0.5 if never updated)."""
        key = self._get_state_key(protocol, file_type, generation_source)
        # .get() keeps reads from inserting default entries that would later
        # be persisted by to_dict().
        return self._q_values.get(key, self._DEFAULT_VALUE)

    def update(
        self,
        protocol: str,
        file_type: str,
        generation_source: str,
        reward: float,
    ):
        """Move the triple's value toward ``reward`` and bump its visit count."""
        key = self._get_state_key(protocol, file_type, generation_source)
        old_value = self._q_values.get(key, self._DEFAULT_VALUE)
        self._visit_counts[key] += 1
        self._q_values[key] = old_value + self._learning_rate * (reward - old_value)

    def select_best_action(
        self,
        protocol: str,
        file_type: str,
        available_sources: List[str],
        epsilon: float = 0.1,
    ) -> Tuple[str, float]:
        """Pick a generation source epsilon-greedily.

        With probability ``epsilon`` (and more than one candidate) a random
        source is returned; otherwise the source with the highest value is
        returned, the earliest one winning ties.

        Returns:
            ``(source, value)`` where ``value`` is that source's current value.

        Raises:
            IndexError: if ``available_sources`` is empty.
        """
        import random

        if not available_sources:
            raise IndexError("available_sources must not be empty")

        if random.random() < epsilon and len(available_sources) > 1:
            chosen = random.choice(available_sources)
            return chosen, self.get_action_value(protocol, file_type, chosen)

        # Compare actual values instead of a fixed sentinel so the result is
        # correct even when every value is very negative.
        scored = [
            (self.get_action_value(protocol, file_type, source), source)
            for source in available_sources
        ]
        best_value, best_source = max(scored, key=lambda pair: pair[0])
        return best_source, best_value

    def to_dict(self) -> Dict[str, Any]:
        """Return a JSON-serializable snapshot of parameters and tables."""
        return {
            "learning_rate": self._learning_rate,
            "discount_factor": self._discount_factor,
            "q_values": dict(self._q_values),
            "visit_counts": dict(self._visit_counts),
        }

    @classmethod
    def from_dict(cls, d: Dict[str, Any]) -> "ReinforcementLearner":
        """Rebuild a learner from a snapshot produced by :meth:`to_dict`."""
        learner = cls(
            learning_rate=d.get("learning_rate", 0.1),
            discount_factor=d.get("discount_factor", 0.9),
        )
        learner._q_values.update(d.get("q_values", {}))
        learner._visit_counts.update(d.get("visit_counts", {}))
        return learner
|
| 276 |
+
|
| 277 |
+
|
| 278 |
+
class LearningModule:
    """Learns from validation feedback and advises on generation strategy.

    Feedback for each generated file is fed to a :class:`PatternLearner`
    (error categories, success rates) and a :class:`ReinforcementLearner`
    (value per protocol / file type / generation source). When a storage path
    is given, state is loaded on construction and saved after every
    :meth:`record_feedback` call.
    """

    # (substring of error category, advice) pairs; the first match wins.
    _RECOMMENDATION_RULES = (
        ("semicolon", "Ensure all statements end with semicolons"),
        ("parenthes", "Check for balanced parentheses"),
        ("brace", "Check for balanced begin/end blocks"),
        (
            "uvm_macro",
            "Add UVM factory registration macros (uvm_component_utils/uvm_object_utils)",
        ),
        ("phase", "Ensure proper UVM phase implementation"),
    )

    def __init__(self, storage_path: Optional[str] = None):
        self._storage_path = storage_path
        self._pattern_learner = PatternLearner()
        self._rl_learner = ReinforcementLearner()
        self._history: List[GenerationHistory] = []
        self._total_generations = 0
        self._successful_generations = 0

        if storage_path:
            self._load_from_storage()

    @staticmethod
    def _normalize_file_results(files_data: Any) -> List[Dict[str, Any]]:
        """Flatten the ``files`` payload into uniform per-file records.

        Two shapes are understood:

        * a dict mapping file name to ``{"type", "passed", "errors",
          "warnings", "score"}``;
        * a list of ``{"filename", "file_type", "passed", "issues",
          "error_count", "warning_count"}`` where each issue carries
          ``severity`` and ``message``.

        Any other payload, and any entry that is not a dict, is ignored.
        Each returned record has the keys ``file_name``, ``file_type``,
        ``passed``, ``errors``, ``warnings`` and ``score``.
        """
        records: List[Dict[str, Any]] = []

        if isinstance(files_data, dict):
            for file_name, info in files_data.items():
                if not isinstance(info, dict):
                    continue
                records.append(
                    {
                        "file_name": file_name,
                        "file_type": info.get("type", "unknown"),
                        "passed": info.get("passed", True),
                        "errors": info.get("errors", []),
                        "warnings": info.get("warnings", []),
                        "score": info.get("score", 0.5),
                    }
                )

        elif isinstance(files_data, list):
            for info in files_data:
                if not isinstance(info, dict):
                    continue

                errors: List[str] = []
                warnings: List[str] = []
                for issue in info.get("issues", []):
                    message = issue.get("message", "")
                    # Anything that is not explicitly an error is a warning.
                    if issue.get("severity", "warning") == "error":
                        errors.append(message)
                    else:
                        warnings.append(message)

                passed = info.get("passed", True)
                if info.get("error_count", 0) > 0:
                    passed = False

                # Score is derived: failed 0.3, passed with warnings 0.7, clean 1.0.
                if not passed:
                    score = 0.3
                elif info.get("warning_count", 0) > 0:
                    score = 0.7
                else:
                    score = 1.0

                records.append(
                    {
                        "file_name": info.get("filename", "unknown"),
                        "file_type": info.get("file_type", "unknown"),
                        "passed": passed,
                        "errors": errors,
                        "warnings": warnings,
                        "score": score,
                    }
                )

        return records

    def _learn_from_feedback(
        self,
        feedback: "ValidationFeedback",
        protocol: str,
        generation_source: str,
    ):
        """Feed one file's outcome to the pattern and reinforcement learners."""
        file_type = feedback.file_type

        if feedback.passed:
            # record_success counts the attempt itself; calling record_attempt
            # as well would double the total and cap success rates at 0.5.
            self._pattern_learner.record_success(file_type, protocol)
            reward = 1.0
        else:
            for err in feedback.errors:
                self._pattern_learner.record_error(err, file_type)
            self._pattern_learner.record_attempt(file_type, protocol)
            reward = -0.5

        self._rl_learner.update(protocol, file_type, generation_source, reward)

    def record_feedback(
        self,
        design_name: str,
        generation_source: str,
        spec_dict: Dict[str, Any],
        validation_results: Dict[str, Any],
    ):
        """Record the validation results of one generation run.

        Args:
            design_name: Name of the design the files were generated for.
            generation_source: Identifier of the strategy that produced them.
            spec_dict: Specification used; must be JSON-serializable. Its
                ``protocol`` entry (default ``"unknown"``) keys the statistics.
            validation_results: Mapping whose ``files`` entry holds per-file
                results in either shape accepted by
                ``_normalize_file_results``.

        A run counts as successful only if it produced at least one file
        result and every file passed. State is saved afterwards when a
        storage path is configured.
        """
        import hashlib

        spec_str = json.dumps(spec_dict, sort_keys=True)
        # Short fingerprint used only to identify the spec in history.
        spec_hash = hashlib.md5(spec_str.encode()).hexdigest()[:12]
        protocol = spec_dict.get("protocol", "unknown")

        feedback_list = []
        for record in self._normalize_file_results(validation_results.get("files", [])):
            feedback = ValidationFeedback(design_name=design_name, **record)
            feedback_list.append(feedback)
            self._learn_from_feedback(feedback, protocol, generation_source)

        # An empty result set is not evidence of success.
        all_passed = bool(feedback_list) and all(f.passed for f in feedback_list)
        avg_score = (
            sum(f.score for f in feedback_list) / len(feedback_list)
            if feedback_list
            else 0.0
        )

        self._history.append(
            GenerationHistory(
                design_name=design_name,
                generation_source=generation_source,
                spec_hash=spec_hash,
                feedback_list=feedback_list,
                success_rate=1.0 if all_passed else 0.0,
                avg_score=avg_score,
            )
        )

        self._total_generations += 1
        if all_passed:
            self._successful_generations += 1

        if self._storage_path:
            self._save_to_storage()

    def select_best_generation_strategy(
        self,
        spec_dict: Dict[str, Any],
        file_type: str,
        available_sources: List[str],
    ) -> Tuple[str, float]:
        """Return ``(source, value)`` chosen by the reinforcement learner.

        Selection uses an exploration rate of 0.05.
        """
        protocol = spec_dict.get("protocol", "unknown")

        return self._rl_learner.select_best_action(
            protocol=protocol,
            file_type=file_type,
            available_sources=available_sources,
            epsilon=0.05,
        )

    def get_generation_hints(
        self,
        spec_dict: Dict[str, Any],
        file_type: str,
    ) -> Dict[str, Any]:
        """Return common errors, success rates and textual recommendations."""
        protocol = spec_dict.get("protocol", "unknown")

        common_errors = self._pattern_learner.get_common_errors(5)
        file_success_rate = self._pattern_learner.get_file_type_success_rate(file_type)
        protocol_success_rate = self._pattern_learner.get_protocol_success_rate(protocol)

        return {
            "common_errors": common_errors,
            "file_type_success_rate": file_success_rate,
            "protocol_success_rate": protocol_success_rate,
            "recommendations": self._generate_recommendations(
                common_errors,
                file_success_rate,
                protocol_success_rate,
            ),
        }

    def _generate_recommendations(
        self,
        common_errors: List[Tuple[str, int]],
        file_success_rate: float,
        protocol_success_rate: float,
    ) -> List[str]:
        """Translate the top three error categories and low success rates into advice.

        Always returns at least one entry.
        """
        recommendations = []

        for error_pattern, count in common_errors[:3]:
            if count <= 0:
                continue
            for marker, advice in self._RECOMMENDATION_RULES:
                if marker in error_pattern:
                    recommendations.append(advice)
                    break

        if file_success_rate < 0.7:
            recommendations.append(
                "Consider using retrieval-based generation for this file type"
            )

        if protocol_success_rate < 0.7:
            recommendations.append(
                "Add protocol-specific templates may improve quality"
            )

        if not recommendations:
            recommendations.append(
                "No specific recommendations - generation should work well"
            )

        return recommendations

    def get_stats(self) -> Dict[str, Any]:
        """Return overall counters, success rate and the pattern-learner snapshot."""
        total = self._total_generations
        return {
            "total_generations": total,
            "successful_generations": self._successful_generations,
            "success_rate": self._successful_generations / total if total > 0 else 0.0,
            "history_count": len(self._history),
            "pattern_stats": self._pattern_learner.to_dict(),
        }

    def _save_to_storage(self):
        """Write state as JSON to the storage path (best effort).

        Only the most recent 100 history entries are written. Failures are
        logged as warnings and not raised.
        """
        if not self._storage_path:
            return

        try:
            # A bare filename has no directory part; makedirs("") would raise.
            directory = os.path.dirname(self._storage_path)
            if directory:
                os.makedirs(directory, exist_ok=True)

            data = {
                "pattern_learner": self._pattern_learner.to_dict(),
                "rl_learner": self._rl_learner.to_dict(),
                "history": [h.to_dict() for h in self._history[-100:]],
                "total_generations": self._total_generations,
                "successful_generations": self._successful_generations,
                "saved_at": datetime.now().isoformat(),
            }

            with open(self._storage_path, "w", encoding="utf-8") as f:
                json.dump(data, f, indent=2)

            logger.debug("Learning module saved to: %s", self._storage_path)

        except Exception as e:
            # Persistence must never break generation.
            logger.warning("Could not save learning module: %s", e)

    def _load_from_storage(self):
        """Load state from the storage path if the file exists (best effort).

        The file is parsed completely before any attribute is replaced, so a
        corrupt file leaves the current state untouched. Failures are logged
        as warnings and not raised.
        """
        if not self._storage_path or not os.path.exists(self._storage_path):
            return

        try:
            with open(self._storage_path, "r", encoding="utf-8") as f:
                data = json.load(f)

            pattern_learner = PatternLearner.from_dict(data.get("pattern_learner", {}))
            rl_learner = ReinforcementLearner.from_dict(data.get("rl_learner", {}))

            loaded_history = []
            for h_dict in data.get("history", []):
                loaded_history.append(
                    GenerationHistory(
                        design_name=h_dict.get("design_name", "unknown"),
                        generation_source=h_dict.get("generation_source", "unknown"),
                        spec_hash=h_dict.get("spec_hash", ""),
                        feedback_list=[
                            ValidationFeedback.from_dict(f)
                            for f in h_dict.get("feedback_list", [])
                        ],
                        success_rate=h_dict.get("success_rate", 0.0),
                        avg_score=h_dict.get("avg_score", 0.0),
                        timestamp=h_dict.get("timestamp", datetime.now().isoformat()),
                    )
                )

            total_generations = data.get("total_generations", 0)
            successful_generations = data.get("successful_generations", 0)

        except Exception as e:
            logger.warning("Could not load learning module: %s", e)
            return

        self._pattern_learner = pattern_learner
        self._rl_learner = rl_learner
        self._history.extend(loaded_history)
        self._total_generations = total_generations
        self._successful_generations = successful_generations

        logger.info("Learning module loaded from: %s", self._storage_path)
|
|
@@ -0,0 +1,557 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import logging
|
| 2 |
+
from typing import List, Dict, Any, Optional, Tuple
|
| 3 |
+
from dataclasses import dataclass, field
|
| 4 |
+
from enum import Enum
|
| 5 |
+
import json
|
| 6 |
+
import re
|
| 7 |
+
|
| 8 |
+
logger = logging.getLogger("uvmgen.ml.llm")
|
| 9 |
+
|
| 10 |
+
|
| 11 |
+
class LLMType(Enum):
    """Model families recognised by the LLM code generator."""

    CODEGEN = "codegen"
    CODET5 = "codet5"
    CODEBERT = "codebert"
    STARCODER = "starcoder"
    LLAMA = "llama"
    MISTRAL = "mistral"
    # Selected when no known family matches or a model cannot be loaded.
    FALLBACK = "fallback"
|
| 19 |
+
|
| 20 |
+
|
| 21 |
+
@dataclass
class LLMGenerationResult:
    """Output of one LLM generation call plus bookkeeping about how it was made."""

    generated_code: str
    prompt_used: str
    model_name: str
    tokens_generated: int
    confidence: float = 0.5
    # Mutable fields use factories so instances never share containers.
    metadata: Dict[str, Any] = field(default_factory=dict)
    warnings: List[str] = field(default_factory=list)
    errors: List[str] = field(default_factory=list)
|
| 31 |
+
|
| 32 |
+
|
| 33 |
+
class LLMCodeGenerator:
|
| 34 |
+
_instance: Optional["LLMCodeGenerator"] = None
|
| 35 |
+
_model = None
|
| 36 |
+
_tokenizer = None
|
| 37 |
+
_model_name: str = "Salesforce/codegen-350M-mono"
|
| 38 |
+
_device: str = "cpu"
|
| 39 |
+
_initialized: bool = False
|
| 40 |
+
_llm_type: LLMType = LLMType.FALLBACK
|
| 41 |
+
|
| 42 |
+
UVM_PROMPT_TEMPLATE = """
|
| 43 |
+
You are an expert in UVM (Universal Verification Methodology) and SystemVerilog.
|
| 44 |
+
Generate production-quality UVM testbench code based on the following specification.
|
| 45 |
+
|
| 46 |
+
SPECIFICATION:
|
| 47 |
+
{spec_text}
|
| 48 |
+
|
| 49 |
+
REQUIREMENTS:
|
| 50 |
+
- Follow UVM 1.2 conventions and best practices
|
| 51 |
+
- Use proper factory registration with `uvm_component_utils` or `uvm_object_utils`
|
| 52 |
+
- Include appropriate phases (build_phase, connect_phase, run_phase)
|
| 53 |
+
- Use TLM ports and exports for component communication
|
| 54 |
+
- Include proper configuration database usage if needed
|
| 55 |
+
- Generate synthesizable SystemVerilog code
|
| 56 |
+
|
| 57 |
+
{context_examples}
|
| 58 |
+
|
| 59 |
+
Generate the {file_type} for this specification. Return only the SystemVerilog code, no explanations.
|
| 60 |
+
"""
|
| 61 |
+
|
| 62 |
+
FEW_SHOT_EXAMPLES = {
|
| 63 |
+
"driver": """
|
| 64 |
+
EXAMPLE DRIVER:
|
| 65 |
+
class my_driver extends uvm_driver #(my_seq_item);
|
| 66 |
+
`uvm_component_utils(my_driver)
|
| 67 |
+
|
| 68 |
+
virtual my_if vif;
|
| 69 |
+
|
| 70 |
+
function new(string name = "my_driver", uvm_component parent = null);
|
| 71 |
+
super.new(name, parent);
|
| 72 |
+
endfunction
|
| 73 |
+
|
| 74 |
+
function void build_phase(uvm_phase phase);
|
| 75 |
+
super.build_phase(phase);
|
| 76 |
+
if (!uvm_config_db#(virtual my_if)::get(this, "", "vif", vif))
|
| 77 |
+
`uvm_fatal(get_type_name(), "Virtual interface not found")
|
| 78 |
+
endfunction
|
| 79 |
+
|
| 80 |
+
task run_phase(uvm_phase phase);
|
| 81 |
+
forever begin
|
| 82 |
+
seq_item_port.get_next_item(req);
|
| 83 |
+
drive_item(req);
|
| 84 |
+
seq_item_port.item_done();
|
| 85 |
+
end
|
| 86 |
+
endtask
|
| 87 |
+
|
| 88 |
+
task drive_item(my_seq_item item);
|
| 89 |
+
@(posedge vif.clk);
|
| 90 |
+
vif.valid <= 1'b1;
|
| 91 |
+
vif.data <= item.data;
|
| 92 |
+
@(posedge vif.clk);
|
| 93 |
+
vif.valid <= 1'b0;
|
| 94 |
+
endtask
|
| 95 |
+
endclass
|
| 96 |
+
""",
|
| 97 |
+
"monitor": """
|
| 98 |
+
EXAMPLE MONITOR:
|
| 99 |
+
class my_monitor extends uvm_monitor;
|
| 100 |
+
`uvm_component_utils(my_monitor)
|
| 101 |
+
|
| 102 |
+
uvm_analysis_port #(my_seq_item) item_collected_port;
|
| 103 |
+
virtual my_if vif;
|
| 104 |
+
|
| 105 |
+
function new(string name = "my_monitor", uvm_component parent = null);
|
| 106 |
+
super.new(name, parent);
|
| 107 |
+
item_collected_port = new("item_collected_port", this);
|
| 108 |
+
endfunction
|
| 109 |
+
|
| 110 |
+
function void build_phase(uvm_phase phase);
|
| 111 |
+
super.build_phase(phase);
|
| 112 |
+
if (!uvm_config_db#(virtual my_if)::get(this, "", "vif", vif))
|
| 113 |
+
`uvm_fatal(get_type_name(), "Virtual interface not found")
|
| 114 |
+
endfunction
|
| 115 |
+
|
| 116 |
+
task run_phase(uvm_phase phase);
|
| 117 |
+
my_seq_item item;
|
| 118 |
+
forever begin
|
| 119 |
+
@(posedge vif.clk);
|
| 120 |
+
if (vif.valid) begin
|
| 121 |
+
item = my_seq_item::type_id::create("item");
|
| 122 |
+
item.data = vif.data;
|
| 123 |
+
item_collected_port.write(item);
|
| 124 |
+
end
|
| 125 |
+
end
|
| 126 |
+
endtask
|
| 127 |
+
endclass
|
| 128 |
+
""",
|
| 129 |
+
"agent": """
|
| 130 |
+
EXAMPLE AGENT:
|
| 131 |
+
class my_agent extends uvm_agent;
|
| 132 |
+
`uvm_component_utils(my_agent)
|
| 133 |
+
|
| 134 |
+
my_driver driver;
|
| 135 |
+
my_monitor monitor;
|
| 136 |
+
my_sequencer sequencer;
|
| 137 |
+
uvm_analysis_port #(my_seq_item) item_collected_port;
|
| 138 |
+
|
| 139 |
+
function new(string name = "my_agent", uvm_component parent = null);
|
| 140 |
+
super.new(name, parent);
|
| 141 |
+
item_collected_port = new("item_collected_port", this);
|
| 142 |
+
endfunction
|
| 143 |
+
|
| 144 |
+
function void build_phase(uvm_phase phase);
|
| 145 |
+
super.build_phase(phase);
|
| 146 |
+
|
| 147 |
+
if (get_is_active() == UVM_ACTIVE) begin
|
| 148 |
+
driver = my_driver::type_id::create("driver", this);
|
| 149 |
+
sequencer = my_sequencer::type_id::create("sequencer", this);
|
| 150 |
+
end
|
| 151 |
+
|
| 152 |
+
monitor = my_monitor::type_id::create("monitor", this);
|
| 153 |
+
endfunction
|
| 154 |
+
|
| 155 |
+
function void connect_phase(uvm_phase phase);
|
| 156 |
+
super.connect_phase(phase);
|
| 157 |
+
|
| 158 |
+
if (get_is_active() == UVM_ACTIVE) begin
|
| 159 |
+
driver.seq_item_port.connect(sequencer.seq_item_export);
|
| 160 |
+
end
|
| 161 |
+
|
| 162 |
+
monitor.item_collected_port.connect(item_collected_port);
|
| 163 |
+
endfunction
|
| 164 |
+
endclass
|
| 165 |
+
""",
|
| 166 |
+
}
|
| 167 |
+
|
| 168 |
+
def __new__(cls, *args, **kwargs):
    """Return the shared instance, creating it on the first call."""
    instance = cls._instance
    if instance is None:
        instance = super().__new__(cls)
        cls._instance = instance
    return instance
|
| 172 |
+
|
| 173 |
+
def __init__(self, model_name: Optional[str] = None, device: Optional[str] = None):
    """Configure the generator unless it has already been initialized.

    Args:
        model_name: Model identifier; the class default is kept when empty.
        device: Target device; the class default is kept when empty.
    """
    if not self._initialized:
        if model_name:
            self._model_name = model_name
        if device:
            self._device = device

        # Model and tokenizer stay unset until _load_model() runs.
        self._model = None
        self._tokenizer = None
        self._initialized = False
        self._detect_llm_type()
|
| 186 |
+
|
| 187 |
+
def _detect_llm_type(self):
    """Set ``_llm_type`` from the first family marker found in the model name.

    Matching is case-insensitive and markers are tried in the order listed;
    ``LLMType.FALLBACK`` is used when none is present.
    """
    lowered = self._model_name.lower()
    families = (
        ("codegen", LLMType.CODEGEN),
        ("codet5", LLMType.CODET5),
        ("codebert", LLMType.CODEBERT),
        ("starcoder", LLMType.STARCODER),
        ("llama", LLMType.LLAMA),
        ("mistral", LLMType.MISTRAL),
    )
    self._llm_type = next(
        (kind for marker, kind in families if marker in lowered),
        LLMType.FALLBACK,
    )
|
| 203 |
+
|
| 204 |
+
def _load_model(self):
    """Load tokenizer and model once; switch to fallback mode on any failure.

    Does nothing after the first call. In fallback mode no model is loaded.
    ``torch`` and ``transformers`` are imported lazily so the module works
    without them. A device of ``"auto"`` resolves to ``"cuda"`` when
    available, else ``"cpu"``. Half precision is used for any CUDA device
    (``"cuda"`` or ``"cuda:N"``), full precision otherwise.
    """
    # _initialized is only ever set here, after a load or a switch to fallback,
    # so it alone marks the work as done (and avoids re-logging fallback mode).
    if self._initialized:
        return

    if self._llm_type == LLMType.FALLBACK:
        logger.info("LLMCodeGenerator using fallback mode (template-based)")
        self._initialized = True
        return

    try:
        import torch
        from transformers import AutoTokenizer, AutoModelForCausalLM, AutoModelForSeq2SeqLM

        if self._device == "auto":
            self._device = "cuda" if torch.cuda.is_available() else "cpu"

        logger.info("Loading LLM: %s on %s", self._model_name, self._device)

        tokenizer = AutoTokenizer.from_pretrained(self._model_name)

        # Prefix match so indexed devices such as "cuda:0" also get float16.
        on_cuda = str(self._device).startswith("cuda")
        dtype = torch.float16 if on_cuda else torch.float32

        # CodeT5 is loaded as a seq2seq model; everything else as a causal LM.
        if self._llm_type == LLMType.CODET5:
            model_cls = AutoModelForSeq2SeqLM
        else:
            model_cls = AutoModelForCausalLM

        model = model_cls.from_pretrained(self._model_name, torch_dtype=dtype)
        model.to(self._device)
        model.eval()

        if tokenizer.pad_token is None:
            tokenizer.pad_token = tokenizer.eos_token

        # Publish only after every step succeeded, so a failure never leaves
        # a half-loaded model or tokenizer attached to the instance.
        self._tokenizer = tokenizer
        self._model = model
        self._initialized = True
        logger.info("LLM loaded successfully")

    except ImportError as e:
        logger.warning(
            "Could not load LLM (missing dependencies: %s). Using fallback mode.",
            e,
        )
        self._llm_type = LLMType.FALLBACK
        self._initialized = True
    except Exception as e:
        # Deliberately broad: any load failure degrades to fallback mode.
        logger.warning(
            "Could not load LLM (%s). Using fallback mode.",
            e,
        )
        self._llm_type = LLMType.FALLBACK
        self._initialized = True
|
| 258 |
+
|
| 259 |
+
def is_available(self) -> bool:
    """Report whether a non-fallback model can be used.

    Calls ``_load_model()`` first, so the first call may trigger loading.
    """
    self._load_model()
    if not self._initialized:
        return False
    return self._llm_type != LLMType.FALLBACK
|
| 262 |
+
|
| 263 |
+
def _spec_to_text(self, spec_dict: Dict[str, Any]) -> str:
    """Render a spec dictionary as plain text, one item per line.

    Only keys present in ``spec_dict`` produce output, in the fixed order
    design_name, protocol, signals, registers, features. Missing
    per-signal and per-register fields fall back to the defaults used below.

    Args:
        spec_dict: Specification data; ``signals`` and ``registers`` are
            iterated as mappings, ``features`` as plain values.

    Returns:
        The rendered lines joined with newlines ("" for an empty spec).
    """
    out = []

    for key, label in (("design_name", "Design Name"), ("protocol", "Protocol")):
        if key in spec_dict:
            out.append(f"{label}: {spec_dict[key]}")

    if "signals" in spec_dict:
        out.append("\nSignals:")
        for sig in spec_dict["signals"]:
            name, direction, width, desc = (
                sig.get("name", "unknown"),
                sig.get("direction", "inout"),
                sig.get("width", 1),
                sig.get("description", ""),
            )
            out.append(f" - {name}: {direction}, width={width} {desc}")

    if "registers" in spec_dict:
        out.append("\nRegisters:")
        for reg in spec_dict["registers"]:
            name, addr, width = (
                reg.get("name", "unknown"),
                reg.get("address", "0x0"),
                reg.get("width", 32),
            )
            out.append(f" - {name}: addr={addr}, width={width}")

    if "features" in spec_dict:
        out.append("\nFeatures:")
        out.extend(f" - {feat}" for feat in spec_dict["features"])

    return "\n".join(out)
|
| 295 |
+
|
| 296 |
+
def _build_prompt(
    self,
    spec_dict: Dict[str, Any],
    file_type: str,
    use_few_shot: bool = True,
) -> str:
    """Fill ``UVM_PROMPT_TEMPLATE`` for one file type.

    Args:
        spec_dict: Specification rendered through ``_spec_to_text``.
        file_type: Key selecting the few-shot example and named in the prompt.
        use_few_shot: When True and ``FEW_SHOT_EXAMPLES`` has an entry for
            ``file_type``, that entry is inserted; otherwise the example
            slot is left empty.

    Returns:
        The formatted prompt with surrounding whitespace removed.
    """
    spec_text = self._spec_to_text(spec_dict)

    has_example = use_few_shot and file_type in self.FEW_SHOT_EXAMPLES
    context_examples = self.FEW_SHOT_EXAMPLES[file_type] if has_example else ""

    rendered = self.UVM_PROMPT_TEMPLATE.format(
        spec_text=spec_text,
        file_type=file_type,
        context_examples=context_examples,
    )
    return rendered.strip()
|
| 315 |
+
|
| 316 |
+
def _extract_code(self, text: str) -> str:
    """Pull the code out of a fenced block in ``text``.

    Fence tags are tried in priority order (systemverilog, verilog, sv,
    then an untagged fence), case-insensitively. The first tag that has a
    matching fenced block wins, regardless of where it sits in the text.

    Returns:
        The stripped contents of the matched block, or the stripped input
        when no fenced block is found.
    """
    for fence_tag in ("systemverilog", "verilog", "sv", ""):
        found = re.search(
            rf"```{fence_tag}\s+(.*?)```",
            text,
            re.DOTALL | re.IGNORECASE,
        )
        if found is not None:
            return found.group(1).strip()

    return text.strip()
|
| 330 |
+
|
| 331 |
+
def _fallback_generate(
    self,
    spec_dict: Dict[str, Any],
    file_type: str,
    templates: Optional[Dict[str, str]] = None,
) -> "LLMGenerationResult":
    """Produce template-based code without calling a model.

    Selection order: a caller-supplied template for ``file_type``, then the
    built-in skeleton for ``file_type`` (driver, monitor, agent), then a
    one-line placeholder comment.

    Args:
        spec_dict: Specification; only ``design_name`` is read (lower-cased,
            default "unknown").
        file_type: Kind of file to generate.
        templates: Optional mapping of file type to ready-made code.

    Returns:
        An ``LLMGenerationResult`` with confidence 0.3 and a warning noting
        that fallback generation was used.
    """
    design_name = spec_dict.get("design_name", "unknown").lower()

    if templates and file_type in templates:
        # Caller-provided code is used verbatim.
        code = templates[file_type]
    else:
        builtin_skeletons = {
            "driver": f"""
class {design_name}_driver extends uvm_driver #({design_name}_seq_item);
`uvm_component_utils({design_name}_driver)

virtual {design_name}_if vif;

function new(string name = "{design_name}_driver", uvm_component parent = null);
super.new(name, parent);
endfunction

function void build_phase(uvm_phase phase);
super.build_phase(phase);
if (!uvm_config_db#(virtual {design_name}_if)::get(this, "", "vif", vif))
`uvm_fatal(get_type_name(), "Virtual interface not found in config DB")
endfunction

task run_phase(uvm_phase phase);
forever begin
seq_item_port.get_next_item(req);
drive_item(req);
seq_item_port.item_done();
end
endtask

task drive_item({design_name}_seq_item item);
// Implement drive logic based on item
@(posedge vif.clk);
endtask
endclass
""",
            "monitor": f"""
class {design_name}_monitor extends uvm_monitor;
`uvm_component_utils({design_name}_monitor)

uvm_analysis_port #({design_name}_seq_item) item_collected_port;
virtual {design_name}_if vif;

function new(string name = "{design_name}_monitor", uvm_component parent = null);
super.new(name, parent);
item_collected_port = new("item_collected_port", this);
endfunction

function void build_phase(uvm_phase phase);
super.build_phase(phase);
if (!uvm_config_db#(virtual {design_name}_if)::get(this, "", "vif", vif))
`uvm_fatal(get_type_name(), "Virtual interface not found in config DB")
endfunction

task run_phase(uvm_phase phase);
{design_name}_seq_item item;
forever begin
@(posedge vif.clk);
// Sample signals and create item
end
endtask
endclass
""",
            "agent": f"""
class {design_name}_agent extends uvm_agent;
`uvm_component_utils({design_name}_agent)

{design_name}_driver driver;
{design_name}_monitor monitor;
{design_name}_sequencer sequencer;
uvm_analysis_port #({design_name}_seq_item) item_collected_port;

function new(string name = "{design_name}_agent", uvm_component parent = null);
super.new(name, parent);
item_collected_port = new("item_collected_port", this);
endfunction

function void build_phase(uvm_phase phase);
super.build_phase(phase);

if (get_is_active() == UVM_ACTIVE) begin
driver = {design_name}_driver::type_id::create("driver", this);
sequencer = {design_name}_sequencer::type_id::create("sequencer", this);
end

monitor = {design_name}_monitor::type_id::create("monitor", this);
endfunction

function void connect_phase(uvm_phase phase);
super.connect_phase(phase);

if (get_is_active() == UVM_ACTIVE) begin
driver.seq_item_port.connect(sequencer.seq_item_export);
end

monitor.item_collected_port.connect(item_collected_port);
endfunction
endclass
""",
        }
        # Unknown file types get a single placeholder comment line.
        code = builtin_skeletons.get(
            file_type,
            f"// {file_type} for {design_name} - template placeholder",
        )

    return LLMGenerationResult(
        generated_code=code,
        prompt_used=f"// Fallback generation for {file_type}",
        model_name="fallback_template",
        # Whitespace-separated word count, used as a stand-in for a token count.
        tokens_generated=len(code.split()),
        confidence=0.3,
        warnings=["Using fallback template generation (LLM not available)"],
    )
|
| 450 |
+
|
| 451 |
+
def generate(
    self,
    spec_dict: Dict[str, Any],
    file_type: str,
    use_few_shot: bool = True,
    max_tokens: int = 1024,
    temperature: float = 0.2,
    templates: Optional[Dict[str, str]] = None,
) -> "LLMGenerationResult":
    """Generate code for one file type, using the model when one is loaded.

    Args:
        spec_dict: Design specification passed to the prompt builder.
        file_type: Kind of file to generate (e.g. "driver").
        use_few_shot: Include the few-shot example for ``file_type`` if any.
        max_tokens: Upper bound on newly generated tokens.
        temperature: Sampling temperature; values <= 0 disable sampling.
        templates: Optional mapping of file type to code, used only by the
            template fallback.

    Returns:
        An ``LLMGenerationResult``. In fallback mode, or if generation
        raises, the result comes from ``_fallback_generate``; in the latter
        case the error text is appended to its warnings.
    """
    self._load_model()

    # The fallback path needs no prompt, so decide before building one.
    if self._llm_type == LLMType.FALLBACK or self._model is None:
        return self._fallback_generate(spec_dict, file_type, templates)

    prompt = self._build_prompt(spec_dict, file_type, use_few_shot)

    try:
        import torch

        inputs = self._tokenizer(
            prompt,
            return_tensors="pt",
            truncation=True,
            max_length=1024,
            padding=True,
        )
        inputs = {k: v.to(self._device) for k, v in inputs.items()}
        prompt_len = inputs["input_ids"].shape[1]

        do_sample = temperature > 0
        gen_kwargs = {
            "max_new_tokens": max_tokens,
            "do_sample": do_sample,
            "num_return_sequences": 1,
            "pad_token_id": self._tokenizer.pad_token_id,
            "eos_token_id": self._tokenizer.eos_token_id,
        }
        if do_sample:
            # Temperature only has meaning when sampling; passing it with
            # greedy decoding triggers generation-config warnings.
            gen_kwargs["temperature"] = temperature

        with torch.no_grad():
            outputs = self._model.generate(**inputs, **gen_kwargs)

        sequence = outputs[0]
        if self._llm_type != LLMType.CODET5:
            # Causal LMs return prompt tokens followed by the continuation.
            # Drop the prompt by token position, which stays correct even
            # when the prompt was truncated or decodes back differently.
            sequence = sequence[prompt_len:]
        # Seq2seq (CodeT5) output holds only decoder tokens, so it is used
        # whole; the count then includes the decoder's special tokens.
        tokens_generated = len(sequence)

        generated_text = self._tokenizer.decode(sequence, skip_special_tokens=True)
        if generated_text.startswith(prompt):
            # Defensive: strip an echoed prompt if the model reproduced it.
            generated_text = generated_text[len(prompt):]
        generated_text = generated_text.strip()

        code = self._extract_code(generated_text)

        # Heuristic confidence: start at 0.7 and add a little for each
        # UVM/SystemVerilog marker found in the output, capped at 0.95.
        confidence = 0.7
        if "uvm_component_utils" in code or "uvm_object_utils" in code:
            confidence += 0.1
        if "class" in code and "extends" in code:
            confidence += 0.05
        if "build_phase" in code or "run_phase" in code:
            confidence += 0.05
        if "endclass" in code:
            confidence += 0.05

        confidence = min(confidence, 0.95)

        return LLMGenerationResult(
            generated_code=code,
            prompt_used=prompt,
            model_name=self._model_name,
            tokens_generated=tokens_generated,
            confidence=confidence,
            warnings=[],
        )

    except Exception as e:
        # Best effort: any generation failure degrades to the templates.
        logger.warning("Error during LLM generation: %s. Using fallback.", e)
        result = self._fallback_generate(spec_dict, file_type, templates)
        result.warnings.append(f"LLM generation failed: {str(e)}")
        return result
|
| 535 |
+
|
| 536 |
+
def generate_batch(
    self,
    spec_dict: Dict[str, Any],
    file_types: List[str],
    use_few_shot: bool = True,
    max_tokens: int = 1024,
    temperature: float = 0.2,
    templates: Optional[Dict[str, str]] = None,
) -> Dict[str, "LLMGenerationResult"]:
    """Run ``generate`` once per file type and collect the results.

    Args:
        spec_dict: Design specification shared by every generation.
        file_types: File types to generate; a repeated entry keeps the
            result of its last occurrence.
        use_few_shot: Forwarded to ``generate``.
        max_tokens: Forwarded to ``generate``.
        temperature: Forwarded to ``generate``.
        templates: Optional mapping of file type to template code,
            forwarded unchanged to ``generate``.

    Returns:
        Mapping of file type to its generation result.
    """
    # Pass the whole mapping through: generate() and the template fallback
    # look the file type up themselves. Passing a single template string
    # here would make that lookup index into a str.
    return {
        file_type: self.generate(
            spec_dict=spec_dict,
            file_type=file_type,
            use_few_shot=use_few_shot,
            max_tokens=max_tokens,
            temperature=temperature,
            templates=templates,
        )
        for file_type in file_types
    }
|
|
@@ -0,0 +1,204 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import re
|
| 2 |
+
from dataclasses import dataclass, field
|
| 3 |
+
from typing import Dict, List, Optional, Any
|
| 4 |
+
|
| 5 |
+
|
| 6 |
+
@dataclass
class MLModelConfig:
    """Configuration for ML-based generation models.

    Every field has a default, so ``MLModelConfig()`` is a valid
    configuration. This class only stores the values; how each one is
    interpreted is decided by the code that reads it.
    """
    # --- Retrieval settings ---
    # NOTE(review): the per-field notes in this class are inferred from the
    # field names only; confirm against the consuming model.
    similarity_threshold: float = 0.75  # presumably the minimum score for a match
    auto_learn: bool = True
    index_path: Optional[str] = None  # None presumably means "no persisted index"
    top_k_retrieval: int = 3
    fallback_to_templates: bool = True

    # --- LLM generation settings ---
    use_llm: bool = True
    llm_model_name: Optional[str] = None  # None presumably keeps the generator's default
    llm_max_tokens: int = 1024
    llm_temperature: float = 0.2
    llm_use_few_shot: bool = True

    # --- Semantic encoder settings ---
    use_semantic_encoder: bool = True
    semantic_model_name: str = "microsoft/codebert-base"

    # --- Learning module settings ---
    use_learning: bool = True
    learning_storage_path: Optional[str] = None
    learning_rate: float = 0.1
    reinforcement_discount: float = 0.9
    exploration_epsilon: float = 0.05
|
| 29 |
+
|
| 30 |
+
|
| 31 |
+
class RetrievalInfo:
    """Plain record describing the most recent retrieval operation.

    The three constructor arguments are stored unchanged as attributes of
    the same name; no validation or conversion is applied.
    """

    def __init__(self, used_similarity: bool = True, similar_specs: int = 0, best_score: float = 0.0):
        self.used_similarity, self.similar_specs, self.best_score = (
            used_similarity,
            similar_specs,
            best_score,
        )
|
| 37 |
+
|
| 38 |
+
|
| 39 |
+
class NameNormalizer:
    """Utility for normalizing and adapting design names in filenames and code."""

    # Matches "<design>_<component>" for the known component suffixes.
    # No method of this class uses it; it is kept as a class attribute for
    # any external code that refers to it.
    DESIGN_NAME_PATTERN = re.compile(
        r"([a-zA-Z_][a-zA-Z0-9_]*?)_(driver|monitor|agent|sequencer|sequence_item|sequence|scoreboard|coverage_collector|env|test|interface|testbench|ral_model|serial_monitor)",
        re.IGNORECASE
    )

    @classmethod
    def adapt_names(
        cls,
        filename: str,
        old_design_name: str,
        new_design_name: str,
    ) -> str:
        """
        Adapt a filename from the old design name to the new design name.

        Every case-insensitive occurrence of ``old_design_name`` in the part
        of ``filename`` before the last "." is replaced; the extension is
        left untouched.

        Args:
            filename: Original filename
            old_design_name: Old design name to replace
            new_design_name: New design name to use

        Returns:
            Adapted filename, or ``filename`` unchanged when either name is
            empty or the old name does not occur in the base name
        """
        if not old_design_name or not new_design_name:
            return filename

        # Split on the last dot only, so "a.b.sv" -> ("a.b", ".sv").
        base_name, dot, suffix = filename.rpartition(".")
        if dot:
            ext = dot + suffix
        else:
            base_name, ext = filename, ""

        if old_design_name.lower() not in base_name.lower():
            return filename

        new_base = re.sub(
            re.escape(old_design_name),
            # A callable inserts the new name literally; a plain string
            # would be parsed as a regex replacement template.
            lambda _match: new_design_name,
            base_name,
            flags=re.IGNORECASE,
        )
        return new_base + ext

    @classmethod
    def adapt_content(
        cls,
        content: str,
        old_design_name: str,
        new_design_name: str,
    ) -> str:
        """
        Adapt SystemVerilog content from old design name to new design name.

        Two kinds of occurrence are rewritten, in a single pass:

        - the old name used as an identifier prefix, i.e. at a word boundary
          and followed by "_" (matched case-insensitively), e.g.
          ``spi_driver`` -> ``uart_driver``;
        - the old name as a whole word (matched case-sensitively).

        A single pass guarantees already-replaced text is never rewritten
        again, which matters when the new name starts with the old one
        (``spi`` -> ``spi_master``).

        Args:
            content: Original SystemVerilog content
            old_design_name: Old design name to replace
            new_design_name: New design name to use

        Returns:
            Adapted content
        """
        if not old_design_name or not new_design_name or old_design_name == new_design_name:
            return content

        escaped = re.escape(old_design_name)
        pattern = re.compile(rf"\b(?:(?i:{escaped})(?=_)|{escaped}\b)")
        # Callable replacement: the new name is inserted literally.
        return pattern.sub(lambda _match: new_design_name, content)

    @classmethod
    def normalize_name(cls, name: str) -> str:
        """
        Normalize a design name to a lower-case identifier.

        - Replaces every character outside ``[a-zA-Z0-9_]`` with "_"
        - Collapses runs of "_" and trims them from both ends
        - Prefixes "_" when the result would not start with a letter
        - Lower-cases the result (CamelCase words are not split)

        Args:
            name: Original name

        Returns:
            Normalized name, or "design" when nothing usable remains
        """
        if not name:
            return "design"

        result = re.sub(r"[^a-zA-Z0-9_]", "_", name.strip())
        result = re.sub(r"_+", "_", result).strip("_")

        if not result:
            return "design"

        # Identifiers may not start with a digit.
        if not result[0].isalpha() and result[0] != "_":
            result = "_" + result

        return result.lower()
|
| 191 |
+
|
| 192 |
+
|
| 193 |
+
class MLGenerationModel:
    """
    Legacy entry point kept so that existing imports of ``MLGenerationModel``
    continue to work.

    Constructing this class does not produce an ``MLGenerationModel``:
    ``__new__`` builds and returns an ``EnhancedMLGenerationModel`` from the
    same arguments. New code should use ``EnhancedMLGenerationModel`` directly.
    """

    def __new__(cls, *args, **kwargs):
        # Imported at call time rather than at module import time.
        from src.models import enhanced_ml_model as _enhanced

        return _enhanced.EnhancedMLGenerationModel(*args, **kwargs)
|
|
@@ -0,0 +1,294 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import logging
|
| 2 |
+
from typing import List, Dict, Any, Optional, Tuple
|
| 3 |
+
import numpy as np
|
| 4 |
+
from dataclasses import dataclass, field
|
| 5 |
+
|
| 6 |
+
logger = logging.getLogger("uvmgen.ml.semantic")
|
| 7 |
+
|
| 8 |
+
|
| 9 |
+
@dataclass
class SemanticEmbedding:
    """An embedding vector together with the text it was computed from."""

    vector: np.ndarray
    text: str
    metadata: Dict[str, Any] = field(default_factory=dict)
    embedding_type: str = "code"

    @property
    def dim(self) -> int:
        """Number of entries along the vector's first axis."""
        return len(self.vector)

    def to_dict(self) -> Dict[str, Any]:
        """Return a plain-dict form with the vector converted to a list.

        The ``dim`` entry is informational; ``from_dict`` does not read it.
        """
        payload: Dict[str, Any] = dict(
            vector=self.vector.tolist(),
            text=self.text,
            metadata=self.metadata,
            embedding_type=self.embedding_type,
        )
        payload["dim"] = self.dim
        return payload

    @classmethod
    def from_dict(cls, d: Dict[str, Any]) -> "SemanticEmbedding":
        """Rebuild an embedding from ``to_dict`` output.

        ``vector`` and ``text`` are required keys; ``metadata`` and
        ``embedding_type`` fall back to ``{}`` and ``"code"``. The vector is
        always materialised as float32.
        """
        as_array = np.array(d["vector"], dtype=np.float32)
        source_text = d["text"]
        meta = d.get("metadata", {})
        kind = d.get("embedding_type", "code")
        return cls(vector=as_array, text=source_text, metadata=meta, embedding_type=kind)
|
| 37 |
+
|
| 38 |
+
|
| 39 |
+
class SemanticCodeEncoder:
    """Singleton that turns text into embedding vectors.

    When ``torch`` and ``transformers`` can be imported and the configured
    model loads, embeddings are taken from the model's first-token hidden
    state and L2-normalised. Otherwise a dependency-free fallback produces
    a hashed bag-of-words vector of fixed length.
    """

    _instance: Optional["SemanticCodeEncoder"] = None
    _model = None
    _tokenizer = None
    _model_name: str = "microsoft/codebert-base"
    _device: str = "cpu"
    _initialized: bool = False
    # Set once a load attempt has failed, so the failing import/load is not
    # retried (and its warning not re-logged) on every encode call.
    _load_failed: bool = False
    # Length of the vectors produced by the fallback encoder.
    _FALLBACK_DIM = 128

    def __new__(cls, *args, **kwargs):
        if cls._instance is None:
            cls._instance = super().__new__(cls)
        return cls._instance

    def __init__(self, model_name: Optional[str] = None, device: Optional[str] = None):
        """Configure the encoder unless a model has already been loaded.

        Args:
            model_name: Model identifier; a falsy value keeps the current one.
            device: Device string; a falsy value keeps the current one.
        """
        if self._initialized:
            return

        if model_name and model_name != self._model_name:
            self._model_name = model_name
            # A different model deserves a fresh load attempt.
            self._load_failed = False
        if device and device != self._device:
            self._device = device
            self._load_failed = False

        self._initialized = False
        self._model = None
        self._tokenizer = None

    def _mark_load_failed(self):
        """Record a failed load and clear any partially loaded state."""
        self._initialized = False
        self._model = None
        self._tokenizer = None
        self._load_failed = True

    def _load_model(self):
        """Load tokenizer and model on first use; failures are logged, not raised."""
        if self._load_failed or (self._initialized and self._model is not None):
            return

        try:
            import torch
            from transformers import AutoTokenizer, AutoModel

            if self._device == "auto":
                self._device = "cuda" if torch.cuda.is_available() else "cpu"

            logger.info("Loading semantic encoder: %s on %s", self._model_name, self._device)

            tokenizer = AutoTokenizer.from_pretrained(self._model_name)
            model = AutoModel.from_pretrained(self._model_name)
            model.to(self._device)
            model.eval()

            # Publish only after every step succeeded.
            self._tokenizer = tokenizer
            self._model = model
            self._initialized = True
            logger.info("Semantic encoder loaded successfully")

        except ImportError as e:
            logger.warning(
                "Could not load semantic encoder (missing dependencies: %s). "
                "Using fallback TF-IDF-based similarity.",
                e,
            )
            self._mark_load_failed()
        except Exception as e:
            logger.warning(
                "Could not load semantic encoder (%s). Using fallback similarity.",
                e,
            )
            self._mark_load_failed()

    def is_available(self) -> bool:
        """Return True when the neural model is loaded (may trigger loading)."""
        self._load_model()
        return self._initialized and self._model is not None

    def encode(
        self,
        text: str,
        embedding_type: str = "code",
        metadata: Optional[Dict[str, Any]] = None,
    ) -> "SemanticEmbedding":
        """Embed a single text.

        Uses the neural model when available; otherwise, or if the model
        raises, the fallback encoder is used.
        """
        if not self.is_available():
            return self._fallback_encode(text, embedding_type, metadata)

        try:
            import torch

            inputs = self._tokenizer(
                text,
                return_tensors="pt",
                truncation=True,
                max_length=512,
                padding=True,
            )
            inputs = {k: v.to(self._device) for k, v in inputs.items()}

            with torch.no_grad():
                outputs = self._model(**inputs)
                # Hidden state of the first token, used as the text embedding.
                embeddings = outputs.last_hidden_state[:, 0, :]
                embeddings = embeddings.cpu().numpy().squeeze()

            # Epsilon guards against division by zero for an all-zero vector.
            embeddings = embeddings / (np.linalg.norm(embeddings) + 1e-8)

            return SemanticEmbedding(
                vector=embeddings.astype(np.float32),
                text=text,
                metadata=metadata or {},
                embedding_type=embedding_type,
            )

        except Exception as e:
            logger.warning("Error encoding with neural model: %s. Using fallback.", e)
            return self._fallback_encode(text, embedding_type, metadata)

    def encode_batch(
        self,
        texts: List[str],
        embedding_type: str = "code",
        metadata_list: Optional[List[Dict[str, Any]]] = None,
    ) -> List["SemanticEmbedding"]:
        """Embed several texts, returning one embedding per text in order.

        ``metadata_list``, when given, is indexed in parallel with ``texts``.
        """
        if not texts:
            # Nothing to encode; avoid calling the tokenizer with an empty batch.
            return []

        def fallback_all() -> List["SemanticEmbedding"]:
            return [
                self._fallback_encode(text, embedding_type, metadata_list[i] if metadata_list else None)
                for i, text in enumerate(texts)
            ]

        if not self.is_available():
            return fallback_all()

        try:
            import torch

            inputs = self._tokenizer(
                texts,
                return_tensors="pt",
                truncation=True,
                max_length=512,
                padding=True,
            )
            inputs = {k: v.to(self._device) for k, v in inputs.items()}

            with torch.no_grad():
                outputs = self._model(**inputs)
                embeddings = outputs.last_hidden_state[:, 0, :]
                embeddings = embeddings.cpu().numpy()

            norms = np.linalg.norm(embeddings, axis=1, keepdims=True) + 1e-8
            embeddings = embeddings / norms

            return [
                SemanticEmbedding(
                    vector=emb.astype(np.float32),
                    text=texts[i],
                    metadata=metadata_list[i] if metadata_list else {},
                    embedding_type=embedding_type,
                )
                for i, emb in enumerate(embeddings)
            ]

        except Exception as e:
            logger.warning("Error batch encoding: %s. Using fallback.", e)
            return fallback_all()

    @staticmethod
    def _hashed_bag_of_words(text: str, dim: int = 128) -> np.ndarray:
        """Return an L2-normalised word-count vector of length ``dim``.

        Each lower-cased, whitespace-separated word is mapped to a bucket by
        a CRC32 of its UTF-8 bytes. The mapping depends only on the word, so
        vectors built from different texts share the same axes and can be
        compared; distinct words may collide in a bucket.
        """
        import zlib

        vec = np.zeros(dim, dtype=np.float32)
        for word in text.lower().split():
            vec[zlib.crc32(word.encode("utf-8")) % dim] += 1.0

        norm = np.linalg.norm(vec)
        if norm > 0:
            vec = vec / norm
        return vec.astype(np.float32)

    def _fallback_encode(
        self,
        text: str,
        embedding_type: str = "code",
        metadata: Optional[Dict[str, Any]] = None,
    ) -> "SemanticEmbedding":
        """Embed ``text`` without any model, via ``_hashed_bag_of_words``."""
        return SemanticEmbedding(
            vector=self._hashed_bag_of_words(text, self._FALLBACK_DIM),
            text=text,
            metadata=metadata or {},
            embedding_type=embedding_type,
        )

    def similarity(self, emb1: "SemanticEmbedding", emb2: "SemanticEmbedding") -> float:
        """Cosine similarity of two embeddings.

        Vectors of different length are both cut to the shorter length
        first. Returns 0.0 when either (cut) vector has near-zero norm.
        """
        if len(emb1.vector) != len(emb2.vector):
            min_len = min(len(emb1.vector), len(emb2.vector))
            v1 = emb1.vector[:min_len]
            v2 = emb2.vector[:min_len]
        else:
            v1 = emb1.vector
            v2 = emb2.vector

        norm1 = np.linalg.norm(v1)
        norm2 = np.linalg.norm(v2)

        if norm1 < 1e-8 or norm2 < 1e-8:
            return 0.0

        return float(np.dot(v1, v2) / (norm1 * norm2))

    def batch_similarity(
        self,
        query_emb: "SemanticEmbedding",
        embeddings: List["SemanticEmbedding"],
    ) -> List[Tuple[int, float]]:
        """Score ``query_emb`` against each embedding.

        Returns:
            ``(index, similarity)`` pairs in input order, each computed by
            ``similarity`` so both methods always agree.
        """
        return [
            (i, self.similarity(query_emb, emb))
            for i, emb in enumerate(embeddings)
        ]
|
| 285 |
+
|
| 286 |
+
|
| 287 |
+
def cosine_similarity(v1: np.ndarray, v2: np.ndarray) -> float:
    """Return the cosine of the angle between ``v1`` and ``v2``.

    Yields 0.0 when either vector's norm is below 1e-8, so the division
    is never attempted with a (near-)zero magnitude.
    """
    magnitudes = (np.linalg.norm(v1), np.linalg.norm(v2))

    if any(mag < 1e-8 for mag in magnitudes):
        return 0.0

    scale = magnitudes[0] * magnitudes[1]
    return float(np.dot(v1, v2) / scale)
|
|
@@ -55,7 +55,7 @@ class TBPipeline:
|
|
| 55 |
model_type = ml_cfg.model_type
|
| 56 |
self.logger.info("ML generation enabled, model_type=%s", model_type)
|
| 57 |
|
| 58 |
-
if model_type in ("ml", "hybrid"):
|
| 59 |
ml_model_config = MLModelConfig(
|
| 60 |
similarity_threshold=ml_cfg.similarity_threshold,
|
| 61 |
auto_learn=ml_cfg.auto_learn,
|
|
@@ -68,8 +68,19 @@ class TBPipeline:
|
|
| 68 |
config=ml_model_config,
|
| 69 |
templates_dir=self.cfg.generation.templates_dir,
|
| 70 |
strict_validation=True,
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 71 |
)
|
| 72 |
self.logger.info("Created EnhancedMLGenerationModel with index size: %d", len(model.index))
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 73 |
return model
|
| 74 |
|
| 75 |
self.logger.info("Falling back to template model")
|
|
|
|
| 55 |
model_type = ml_cfg.model_type
|
| 56 |
self.logger.info("ML generation enabled, model_type=%s", model_type)
|
| 57 |
|
| 58 |
+
if model_type in ("ml", "hybrid", "llm", "semantic"):
|
| 59 |
ml_model_config = MLModelConfig(
|
| 60 |
similarity_threshold=ml_cfg.similarity_threshold,
|
| 61 |
auto_learn=ml_cfg.auto_learn,
|
|
|
|
| 68 |
config=ml_model_config,
|
| 69 |
templates_dir=self.cfg.generation.templates_dir,
|
| 70 |
strict_validation=True,
|
| 71 |
+
use_llm=ml_cfg.use_llm,
|
| 72 |
+
use_semantic_encoder=ml_cfg.use_semantic_encoder,
|
| 73 |
+
use_learning=ml_cfg.use_learning,
|
| 74 |
+
llm_model_name=ml_cfg.llm_model_name,
|
| 75 |
+
learning_storage_path=ml_cfg.learning_storage_path,
|
| 76 |
)
|
| 77 |
self.logger.info("Created EnhancedMLGenerationModel with index size: %d", len(model.index))
|
| 78 |
+
|
| 79 |
+
if model_type == "llm":
|
| 80 |
+
self.logger.info("LLM mode: will prioritize LLM generation")
|
| 81 |
+
elif model_type == "semantic":
|
| 82 |
+
self.logger.info("Semantic mode: will use semantic embeddings for similarity")
|
| 83 |
+
|
| 84 |
return model
|
| 85 |
|
| 86 |
self.logger.info("Falling back to template model")
|