"""Hot-swappable embedding model registry.

Architecture::

    EmbeddingRegistry (singleton)
    ├── ONNX MiniLM   built-in ChromaDB model, 384-dim, runs locally
    ├── Ollama embed  Ollama embed API (local or cloud), e.g. 768-dim
    └── Sentence-TF   any HuggingFace sentence-transformers model

The active model is selected via the ``EMBEDDING_MODEL`` setting
(``settings.embedding_model``) and can be hot-swapped through the admin API
without restarting the server.

Supported model identifiers:

    "onnx-minilm"      → built-in ChromaDB ONNX (default, 384-dim, local)
    "ollama:<model>"   → Ollama embed API (e.g. "ollama:nomic-embed-text")
    "<hf-model-name>"  → any other value is treated as a HuggingFace model id
                         and loaded via sentence-transformers
"""

from __future__ import annotations

import logging
import threading
from typing import Any, Protocol

from chromadb import EmbeddingFunction, Embeddings, Documents
from chromadb.utils.embedding_functions import ONNXMiniLM_L6_V2

from backend.app.core.config import settings

logger = logging.getLogger(__name__)

BUILTIN_ONNX_ALIAS = "onnx-minilm"
OLLAMA_PREFIX = "ollama:"


def _to_python_float_vectors(vectors: Any) -> list[list[float]]:
    """Convert an iterable of vectors into a plain ``list[list[float]]``.

    Backends may hand back numpy arrays / numpy scalar types; ChromaDB's strict
    type validator rejects those, so every value is coerced to a Python float.
    """
    return [[float(value) for value in vec] for vec in vectors]


class _EmbedderBackend(Protocol):
    """Structural interface every embedding backend implements."""

    def encode(self, texts: list[str]) -> list[list[float]]:
        """Embed *texts*, returning one vector per input text."""
        ...

    @property
    def dimension(self) -> int:
        """Length of each vector produced by :meth:`encode`."""
        ...

    @property
    def model_name(self) -> str:
        """Registry identifier of this backend's model."""
        ...


class _OnnxMiniLMBackend:
    """Wraps ChromaDB's built-in ONNX all-MiniLM-L6-v2 (zero extra deps)."""

    def __init__(self) -> None:
        self._ef = ONNXMiniLM_L6_V2()

    def encode(self, texts: list[str]) -> list[list[float]]:
        """Embed *texts* with the bundled ONNX model."""
        return _to_python_float_vectors(self._ef(texts))

    @property
    def dimension(self) -> int:
        # all-MiniLM-L6-v2 produces fixed 384-dimensional vectors.
        return 384

    @property
    def model_name(self) -> str:
        return BUILTIN_ONNX_ALIAS


class _SentenceTransformerBackend:
    """Wraps any HuggingFace sentence-transformers model.

    Loads the model on first use. Works with medical-domain models like
    PubMedBERT that produce superior embeddings for clinical text.
    """

    def __init__(self, hf_model_name: str) -> None:
        self._hf_name = hf_model_name
        self._model: Any = None
        self._dim: int | None = None
        # Serializes the (slow, possibly downloading) first load so concurrent
        # callers neither load the weights twice nor observe a half-initialized
        # backend whose model is set but whose dimension is not yet known.
        self._load_lock = threading.Lock()

    def _ensure_loaded(self) -> None:
        """Load the model and probe its output dimension, once.

        Raises:
            RuntimeError: if ``sentence-transformers`` is not installed.
        """
        if self._model is not None:
            return
        with self._load_lock:
            if self._model is not None:  # another thread finished the load
                return
            try:
                from sentence_transformers import SentenceTransformer
            except ImportError as exc:
                raise RuntimeError(
                    "sentence-transformers is required for custom embedding models. "
                    "Install it: pip install sentence-transformers"
                ) from exc

            logger.info("Loading embedding model: %s (first load may download weights)", self._hf_name)
            model = SentenceTransformer(self._hf_name)
            probe = model.encode(["probe"], convert_to_numpy=True)
            self._dim = int(probe.shape[1])
            # Publish the model last: ``_model is not None`` is the "ready" flag,
            # so the dimension must already be set when readers see it.
            self._model = model
            logger.info("Embedding model ready: %s (%d-dim)", self._hf_name, self._dim)

    def encode(self, texts: list[str]) -> list[list[float]]:
        """Embed *texts*, loading the model first if necessary."""
        self._ensure_loaded()
        embeddings = self._model.encode(texts, convert_to_numpy=True, show_progress_bar=False)
        return _to_python_float_vectors(embeddings)

    @property
    def dimension(self) -> int:
        """Vector length, determined by probing the loaded model."""
        self._ensure_loaded()
        # Invariant: _ensure_loaded sets _dim before publishing _model.
        assert self._dim is not None
        return self._dim

    @property
    def model_name(self) -> str:
        return self._hf_name


class _OllamaEmbeddingBackend:
    """Uses the Ollama embed API (local or cloud) for embeddings."""

    def __init__(self, ollama_model: str) -> None:
        self._ollama_model = ollama_model
        self._dim: int | None = None
        self._client: Any = None

    def _ensure_client(self) -> Any:
        """Create the Ollama client lazily, attaching a bearer token if configured."""
        if self._client is None:
            from ollama import Client

            headers: dict[str, str] = {}
            if settings.ollama_api_key:
                headers["Authorization"] = f"Bearer {settings.ollama_api_key}"
            # NOTE(review): timeout=None disables the HTTP timeout, so an
            # unresponsive Ollama server blocks the caller indefinitely — confirm
            # this is intended.
            self._client = Client(host=settings.resolved_ollama_host, headers=headers, timeout=None)
        return self._client

    def encode(self, texts: list[str]) -> list[list[float]]:
        """Embed *texts* via the Ollama embed API.

        The vector length is recorded from the first non-empty response.
        """
        client = self._ensure_client()
        response = client.embed(model=self._ollama_model, input=texts)
        embeddings = response.embeddings
        if self._dim is None and embeddings:
            self._dim = len(embeddings[0])
            logger.info("Ollama embedding model %s: %d-dim", self._ollama_model, self._dim)
        return _to_python_float_vectors(embeddings)

    @property
    def dimension(self) -> int:
        """Vector length; issues a probe request if none has been seen yet.

        Raises:
            RuntimeError: if the server returns no embeddings for the probe.
        """
        if self._dim is None:
            self.encode(["probe"])
        if self._dim is None:
            # Explicit error instead of ``assert``: this is reachable at runtime
            # and asserts are stripped under ``python -O``.
            raise RuntimeError(
                f"Ollama model {self._ollama_model!r} returned no embeddings; "
                "cannot determine embedding dimension"
            )
        return self._dim

    @property
    def model_name(self) -> str:
        return f"{OLLAMA_PREFIX}{self._ollama_model}"


class _RegistryEmbeddingFunction(EmbeddingFunction[Documents]):
    """ChromaDB-compatible wrapper around a single embedding backend.

    The backend is fixed at construction time; a later hot-swap in the
    registry does not affect an existing instance.
    """

    def __init__(self, backend: _EmbedderBackend) -> None:
        self._backend = backend

    def __call__(self, input: Documents) -> Embeddings:
        vecs = self._backend.encode(list(input))
        return vecs  # type: ignore[return-value]


class EmbeddingRegistry:
    """Process-wide singleton that owns the active embedding backend.

    Obtain the instance via :meth:`get`. Installing a backend is serialized by
    ``_swap_lock``; readers grab a reference to the current backend without
    locking.
    """

    _instance: EmbeddingRegistry | None = None
    _lock = threading.Lock()

    def __init__(self) -> None:
        self._backend: _EmbedderBackend | None = None
        self._swap_lock = threading.Lock()

    @classmethod
    def get(cls) -> EmbeddingRegistry:
        """Return the shared registry, creating it on first call."""
        if cls._instance is None:
            with cls._lock:
                if cls._instance is None:
                    cls._instance = cls()
        return cls._instance

    @staticmethod
    def _build_backend(model_name: str) -> _EmbedderBackend:
        """Construct the backend for a registry identifier (see module docstring).

        The sentence-transformers and Ollama backends defer model loading /
        network access until first use.

        Raises:
            ValueError: for ``"ollama:"`` with no model name after the prefix.
        """
        if model_name == BUILTIN_ONNX_ALIAS:
            return _OnnxMiniLMBackend()
        if model_name.startswith(OLLAMA_PREFIX):
            ollama_model = model_name[len(OLLAMA_PREFIX):]
            if not ollama_model:
                raise ValueError(f"Missing Ollama model name after {OLLAMA_PREFIX!r}")
            return _OllamaEmbeddingBackend(ollama_model)
        return _SentenceTransformerBackend(model_name)

    def _install(self, backend: _EmbedderBackend, model_name: str) -> None:
        """Make *backend* active. Caller must hold ``_swap_lock``."""
        self._backend = backend
        logger.info("Active embedding model set to: %s", model_name)

    @property
    def active_backend(self) -> _EmbedderBackend:
        """The current backend, lazily created from ``settings.embedding_model``."""
        backend = self._backend
        if backend is None:
            with self._swap_lock:
                # Re-check under the lock so a concurrent swap_model/load_model
                # is not overwritten by the default model.
                if self._backend is None:
                    default_name = settings.embedding_model
                    self._install(self._build_backend(default_name), default_name)
                backend = self._backend
        return backend

    def load_model(self, model_name: str) -> None:
        """Unconditionally make *model_name* the active model.

        The backend is installed immediately; for lazily-loaded backends any
        load failure surfaces on first use. Prefer :meth:`swap_model` for
        runtime changes.
        """
        with self._swap_lock:
            self._install(self._build_backend(model_name), model_name)

    def swap_model(self, model_name: str) -> dict:
        """Hot-swap the active model. Returns status info.

        The new backend is fully loaded (its dimension probed) BEFORE it
        replaces the current one, so a failed load raises and leaves the
        previously active model in place.

        IMPORTANT: embeddings produced by different models live in incompatible
        vector spaces. After swapping, existing ChromaDB collections must be
        re-indexed.

        Returns:
            dict with ``previous_model``, ``active_model``, ``dimension`` and
            ``reindex_required`` keys.
        """
        new_backend = self._build_backend(model_name)
        # Forces weight download / server round-trip outside the lock, so
        # embedding calls keep using the old model while this runs.
        dimension = new_backend.dimension
        with self._swap_lock:
            old_backend = self._backend
            old_name = old_backend.model_name if old_backend is not None else "none"
            self._install(new_backend, model_name)
        return {
            "previous_model": old_name,
            "active_model": new_backend.model_name,
            "dimension": dimension,
            "reindex_required": old_name != new_backend.model_name,
        }

    @property
    def model_name(self) -> str:
        return self.active_backend.model_name

    @property
    def dimension(self) -> int:
        return self.active_backend.dimension


# ── Public API (unchanged signatures for backward compatibility) ──

def embed_text(text: str) -> list[float]:
    """Return an embedding for a single text string."""
    return EmbeddingRegistry.get().active_backend.encode([text])[0]


def embed_texts(texts: list[str]) -> list[list[float]]:
    """Return embeddings for a batch of text strings."""
    return EmbeddingRegistry.get().active_backend.encode(texts)


def get_embedding_function() -> EmbeddingFunction:
    """Return a ChromaDB-compatible embedding function using the active model.

    The function is bound to the backend active at call time and does not
    follow later hot-swaps.
    """
    return _RegistryEmbeddingFunction(EmbeddingRegistry.get().active_backend)


def get_model_info() -> dict:
    """Return metadata about the active embedding model."""
    reg = EmbeddingRegistry.get()
    return {
        "model_name": reg.model_name,
        "dimension": reg.dimension,
    }