""" Model Registry & Versioning Module Author: AI Generated Created: 2025-11-24 Purpose: Track and version AI models and configurations """ import pickle import json from datetime import datetime from pathlib import Path from typing import Any, Dict, Optional import hashlib from database import db class ModelRegistry: """ Manage model versions and configurations. Stores models locally and metadata in MongoDB. """ def __init__(self, storage_dir: str = "./model_storage"): self.storage_dir = Path(storage_dir) self.storage_dir.mkdir(exist_ok=True) self.collection = "ModelRegistry" def _generate_version_id(self, model_name: str) -> str: """Generate a unique version ID based on timestamp.""" timestamp = datetime.utcnow().strftime("%Y%m%d_%H%M%S") return f"{model_name}_v{timestamp}" def _calculate_hash(self, file_path: Path) -> str: """Calculate MD5 hash of model file for integrity check.""" md5_hash = hashlib.md5() with open(file_path, "rb") as f: for chunk in iter(lambda: f.read(4096), b""): md5_hash.update(chunk) return md5_hash.hexdigest() def save_model(self, model: Any, model_name: str, metadata: Dict = None) -> str: """ Save a model with versioning. Args: model: The model object (sklearn, torch, etc.) model_name: Name identifier (e.g., "kmeans_segmentation") metadata: Additional info (params, metrics, etc.) Returns: version_id: Unique version identifier """ version_id = self._generate_version_id(model_name) model_path = self.storage_dir / f"{version_id}.pkl" # Save model file with open(model_path, 'wb') as f: pickle.dump(model, f) # Calculate file hash file_hash = self._calculate_hash(model_path) # Save metadata to MongoDB doc = { "version_id": version_id, "model_name": model_name, "file_path": str(model_path), "file_hash": file_hash, "file_size": model_path.stat().st_size, "created_at": datetime.utcnow(), "metadata": metadata or {}, "status": "active" } db.get_collection(self.collection).insert_one(doc) print(f"✓ Saved model: {version_id}") return version_id def load_model(self, version_id: str = None, model_name: str = None) -> tuple: """ Load a model by version_id or latest version of model_name. Returns: (model, metadata) """ if version_id: doc = db.get_collection(self.collection).find_one({"version_id": version_id}) elif model_name: # Get latest version doc = db.get_collection(self.collection).find_one( {"model_name": model_name, "status": "active"}, sort=[("created_at", -1)] ) else: raise ValueError("Must provide version_id or model_name") if not doc: raise ValueError("Model not found") # Load model file model_path = Path(doc["file_path"]) # Verify integrity current_hash = self._calculate_hash(model_path) if current_hash != doc["file_hash"]: raise ValueError("Model file corrupted (hash mismatch)") with open(model_path, 'rb') as f: model = pickle.load(f) print(f"✓ Loaded model: {doc['version_id']}") return model, doc["metadata"] def save_prompt_template(self, template_name: str, prompt: str, metadata: Dict = None) -> str: """ Save a prompt template with versioning. """ version_id = self._generate_version_id(template_name) # Save to file template_path = self.storage_dir / f"{version_id}.txt" with open(template_path, 'w', encoding='utf-8') as f: f.write(prompt) # Save metadata doc = { "version_id": version_id, "template_name": template_name, "file_path": str(template_path), "created_at": datetime.utcnow(), "metadata": metadata or {}, "status": "active" } db.get_collection("PromptTemplates").insert_one(doc) print(f"✓ Saved prompt template: {version_id}") return version_id def load_prompt_template(self, version_id: str = None, template_name: str = None) -> str: """ Load a prompt template. """ if version_id: doc = db.get_collection("PromptTemplates").find_one({"version_id": version_id}) elif template_name: doc = db.get_collection("PromptTemplates").find_one( {"template_name": template_name, "status": "active"}, sort=[("created_at", -1)] ) else: raise ValueError("Must provide version_id or template_name") if not doc: raise ValueError("Template not found") with open(doc["file_path"], 'r', encoding='utf-8') as f: prompt = f.read() return prompt def list_versions(self, model_name: str) -> list: """ List all versions of a model. """ versions = list(db.get_collection(self.collection).find( {"model_name": model_name}, sort=[("created_at", -1)] )) return [{ "version_id": v["version_id"], "created_at": v["created_at"], "status": v["status"], "metadata": v.get("metadata", {}) } for v in versions] def archive_version(self, version_id: str): """ Archive (deactivate) a model version. """ db.get_collection(self.collection).update_one( {"version_id": version_id}, {"$set": {"status": "archived"}} ) print(f"✓ Archived model: {version_id}") # Global registry instance registry = ModelRegistry()