Spaces:
Sleeping
Sleeping
github-actions[bot]
Sync backend to Hugging Face Space (commit: 39b5c807918249fa80049d49f4b6a74d6a0ed1fc)
6d86412 | """Hot-swappable embedding model registry. | |
| Architecture: | |
| ┌─────────────────────────────────────────────────────────┐ | |
| │ EmbeddingRegistry (singleton) │ | |
| │ ┌───────────────┐ ┌──────────────┐ ┌─────────────┐ │ | |
| │ │ ONNX MiniLM │ │ Ollama embed │ │ Sentence-TF │ │ | |
| │ │ (built-in, │ │ (cloud/local │ │ (any HF │ │ | |
| │ │ 384-dim) │ │ 768-dim) │ │ model) │ │ | |
| │ └───────────────┘ └──────────────┘ └─────────────┘ │ | |
| │ │ | |
| │ Active model selected via EMBEDDING_MODEL env var │ | |
| │ Hot-swap via admin API without server restart │ | |
| └─────────────────────────────────────────────────────────┘ | |
| Supported model identifiers: | |
| "onnx-minilm" → built-in ChromaDB ONNX (default, 384-dim, local) | |
| "ollama:<model>" → Ollama embed API (e.g. "ollama:nomic-embed-text") | |
| <any HuggingFace model> → loaded via sentence-transformers | |
| """ | |
| from __future__ import annotations | |
| import logging | |
| import threading | |
| from typing import Any, Protocol | |
| from chromadb import EmbeddingFunction, Embeddings, Documents | |
| from chromadb.utils.embedding_functions import ONNXMiniLM_L6_V2 | |
| from backend.app.core.config import settings | |
logger = logging.getLogger(__name__)

# Model identifier that selects the bundled ChromaDB ONNX all-MiniLM-L6-v2 backend.
BUILTIN_ONNX_ALIAS: str = "onnx-minilm"
# Model identifiers starting with this prefix are served through the Ollama embed
# API; the remainder is the Ollama model name (e.g. "ollama:nomic-embed-text").
OLLAMA_PREFIX: str = "ollama:"
def _to_python_float_vectors(vectors: Any) -> list[list[float]]:
    """Coerce an iterable of vectors into nested lists of built-in ``float``.

    Backends may hand back numpy arrays or numpy scalar types; ChromaDB's
    embedding validator is strict about element types, so every value is
    converted explicitly.
    """
    return [list(map(float, row)) for row in vectors]
class _EmbedderBackend(Protocol):
    """Structural interface implemented by every embedding backend.

    The registry reads ``dimension`` and ``model_name`` as attributes
    (e.g. ``backend.model_name``), so both are declared as read-only properties.
    """

    def encode(self, texts: list[str]) -> list[list[float]]:
        """Embed *texts*, returning one vector of Python floats per input string."""
        ...

    @property
    def dimension(self) -> int:
        """Length of each embedding vector produced by ``encode``."""
        ...

    @property
    def model_name(self) -> str:
        """Registry identifier of the backing model."""
        ...
class _OnnxMiniLMBackend:
    """Wraps ChromaDB's built-in ONNX all-MiniLM-L6-v2 (zero extra deps)."""

    def __init__(self) -> None:
        self._ef = ONNXMiniLM_L6_V2()

    def encode(self, texts: list[str]) -> list[list[float]]:
        """Embed *texts* with the ONNX model; an empty batch returns ``[]``."""
        # ChromaDB's embedding functions can fail on an empty document list,
        # so an empty batch never reaches the model.
        if not texts:
            return []
        return _to_python_float_vectors(self._ef(texts))

    @property
    def dimension(self) -> int:
        """Vector width of all-MiniLM-L6-v2 (fixed at 384)."""
        return 384

    @property
    def model_name(self) -> str:
        """Registry identifier for this backend."""
        return BUILTIN_ONNX_ALIAS
class _SentenceTransformerBackend:
    """Wraps any HuggingFace sentence-transformers model.

    The model is loaded lazily: on the first non-empty ``encode`` call or the
    first read of ``dimension``. Useful for domain-specific models such as
    medical-domain PubMedBERT variants for clinical text.
    """

    def __init__(self, hf_model_name: str) -> None:
        self._hf_name = hf_model_name
        self._model: Any = None
        self._dim: int | None = None

    def _ensure_loaded(self) -> None:
        """Load the model once and record its output dimension.

        Raises:
            RuntimeError: if the ``sentence-transformers`` package is not installed.
        """
        if self._model is not None:
            return
        try:
            from sentence_transformers import SentenceTransformer
        except ImportError as exc:
            raise RuntimeError(
                "sentence-transformers is required for custom embedding models. "
                "Install it: pip install sentence-transformers"
            ) from exc
        logger.info("Loading embedding model: %s (first load may download weights)", self._hf_name)
        model = SentenceTransformer(self._hf_name)
        # Encode a throwaway string to learn the vector width from the output shape.
        probe = model.encode(["probe"], convert_to_numpy=True)
        self._dim = int(probe.shape[1])
        # Publish the model only after the probe succeeded; otherwise a failed probe
        # would leave _model set with _dim unset, and later calls would skip loading.
        self._model = model
        logger.info("Embedding model ready: %s (%d-dim)", self._hf_name, self._dim)

    def encode(self, texts: list[str]) -> list[list[float]]:
        """Embed *texts*; an empty batch returns ``[]`` without loading the model."""
        if not texts:
            return []
        self._ensure_loaded()
        embeddings = self._model.encode(texts, convert_to_numpy=True, show_progress_bar=False)
        return _to_python_float_vectors(embeddings)

    @property
    def dimension(self) -> int:
        """Output vector width; loads the model if it is not loaded yet."""
        self._ensure_loaded()
        assert self._dim is not None  # invariant: _dim is set before _model is published
        return self._dim

    @property
    def model_name(self) -> str:
        """The HuggingFace model name this backend was created with."""
        return self._hf_name
class _OllamaEmbeddingBackend:
    """Uses the Ollama embed API (local or cloud) for embeddings.

    The client is created on the first request, and the vector width is learned
    from the first non-empty response.
    """

    def __init__(self, ollama_model: str) -> None:
        self._ollama_model = ollama_model
        self._dim: int | None = None
        self._client: Any = None

    def _ensure_client(self) -> Any:
        """Create (once) and return an ``ollama.Client`` for the configured host.

        A bearer ``Authorization`` header is added when ``settings.ollama_api_key`` is set.
        """
        if self._client is None:
            from ollama import Client

            headers: dict[str, str] = {}
            if settings.ollama_api_key:
                headers["Authorization"] = f"Bearer {settings.ollama_api_key}"
            # NOTE(review): timeout=None presumably disables the HTTP timeout, so a
            # hung server blocks the caller indefinitely — confirm this is intended.
            self._client = Client(host=settings.resolved_ollama_host, headers=headers, timeout=None)
        return self._client

    def encode(self, texts: list[str]) -> list[list[float]]:
        """Embed *texts* via ``client.embed``; an empty batch returns ``[]`` with no request."""
        if not texts:
            return []
        response = self._ensure_client().embed(model=self._ollama_model, input=texts)
        embeddings = response.embeddings
        if self._dim is None and embeddings:
            self._dim = len(embeddings[0])
            logger.info("Ollama embedding model %s: %d-dim", self._ollama_model, self._dim)
        return _to_python_float_vectors(embeddings)

    @property
    def dimension(self) -> int:
        """Vector width; sends a one-item probe request if it is not known yet.

        Raises:
            RuntimeError: if the server returns no embeddings for the probe.
        """
        if self._dim is None:
            self.encode(["probe"])
        if self._dim is None:
            raise RuntimeError(
                f"Ollama model {self._ollama_model!r} returned no embeddings; "
                "cannot determine dimension"
            )
        return self._dim

    @property
    def model_name(self) -> str:
        """Registry identifier, i.e. the prefix ``ollama:`` plus the Ollama model name."""
        return f"{OLLAMA_PREFIX}{self._ollama_model}"
class _RegistryEmbeddingFunction(EmbeddingFunction[Documents]):
    """Adapter exposing an embedder backend through ChromaDB's ``EmbeddingFunction``.

    The backend is captured when the adapter is constructed. A later hot-swap on
    the registry does not change which backend an existing adapter uses.
    """

    def __init__(self, backend: _EmbedderBackend) -> None:
        self._backend = backend

    def __call__(self, input: Documents) -> Embeddings:
        # ChromaDB requires the parameter to be named ``input``.
        documents = list(input)
        return self._backend.encode(documents)  # type: ignore[return-value]
class EmbeddingRegistry:
    """Singleton owning the active embedding backend, with hot-swap support.

    ``get()`` creates the shared instance lazily under a class-level lock. Every
    replacement of the active backend happens under the instance's ``_swap_lock``.
    """

    _instance: EmbeddingRegistry | None = None
    _lock = threading.Lock()

    def __init__(self) -> None:
        self._backend: _EmbedderBackend | None = None
        self._swap_lock = threading.Lock()

    @classmethod
    def get(cls) -> EmbeddingRegistry:
        """Return the process-wide registry, creating it on first call."""
        # Unlocked fast path once created; re-check under the lock so concurrent
        # first callers still create only one instance.
        if cls._instance is None:
            with cls._lock:
                if cls._instance is None:
                    cls._instance = cls()
        return cls._instance

    @staticmethod
    def _create_backend(model_name: str) -> _EmbedderBackend:
        """Instantiate the backend selected by *model_name*.

        ``"onnx-minilm"`` selects the bundled ONNX model. ``"ollama:<model>"``
        selects the Ollama embed API. Any other string is treated as a
        sentence-transformers model name. The sentence-transformers and Ollama
        backends defer model loading and network access until first use.
        """
        if model_name == BUILTIN_ONNX_ALIAS:
            return _OnnxMiniLMBackend()
        if model_name.startswith(OLLAMA_PREFIX):
            return _OllamaEmbeddingBackend(model_name[len(OLLAMA_PREFIX):])
        return _SentenceTransformerBackend(model_name)

    @property
    def active_backend(self) -> _EmbedderBackend:
        """The backend currently serving embeddings.

        If none is installed yet, the model named by ``settings.embedding_model``
        is installed on first access.
        """
        backend = self._backend
        if backend is None:
            with self._swap_lock:
                # Re-check under the lock: a concurrent first access or a swap_model()
                # may have installed a backend while we waited, and it must not be
                # overwritten with the default model.
                if self._backend is None:
                    default_name = settings.embedding_model
                    self._backend = self._create_backend(default_name)
                    logger.info("Active embedding model set to: %s", default_name)
                backend = self._backend
        return backend

    def load_model(self, model_name: str) -> None:
        """Unconditionally make *model_name* the active model, without probing it."""
        with self._swap_lock:
            self._backend = self._create_backend(model_name)
            logger.info("Active embedding model set to: %s", model_name)

    def swap_model(self, model_name: str) -> dict:
        """Hot-swap the active model. Returns status info.

        The new backend is validated by reading its ``dimension`` before it
        replaces the current one. For sentence-transformers this loads the model;
        for Ollama it sends a probe request. An identifier that fails therefore
        raises here and leaves the previous model active.

        IMPORTANT: embeddings produced by different models live in
        incompatible vector spaces. After swapping, existing ChromaDB
        collections must be re-indexed.

        Returns:
            dict with keys ``previous_model``, ``active_model``, ``dimension`` and
            ``reindex_required``.
        """
        candidate = self._create_backend(model_name)
        dimension = candidate.dimension  # raises before anything is replaced
        with self._swap_lock:
            previous = self._backend
            self._backend = candidate
            logger.info("Active embedding model set to: %s", model_name)
        old_name = previous.model_name if previous is not None else "none"
        new_name = candidate.model_name
        return {
            "previous_model": old_name,
            "active_model": new_name,
            "dimension": dimension,
            "reindex_required": old_name != new_name,
        }

    @property
    def model_name(self) -> str:
        """Identifier of the active backend's model."""
        return self.active_backend.model_name

    @property
    def dimension(self) -> int:
        """Vector width of the active backend."""
        return self.active_backend.dimension
| # ── Public API (unchanged signatures for backward compatibility) ── | |
def embed_text(text: str) -> list[float]:
    """Embed one string with the currently active model and return its vector."""
    backend = EmbeddingRegistry.get().active_backend
    vectors = backend.encode([text])
    return vectors[0]
def embed_texts(texts: list[str]) -> list[list[float]]:
    """Embed a batch of strings with the currently active model."""
    registry = EmbeddingRegistry.get()
    backend = registry.active_backend
    return backend.encode(texts)
def get_embedding_function() -> EmbeddingFunction:
    """Build a ChromaDB embedding function bound to the backend active at call time."""
    current = EmbeddingRegistry.get().active_backend
    adapter = _RegistryEmbeddingFunction(current)
    return adapter
def get_model_info() -> dict:
    """Describe the active embedding model as ``{"model_name": ..., "dimension": ...}``."""
    registry = EmbeddingRegistry.get()
    return dict(model_name=registry.model_name, dimension=registry.dimension)