scenarist / app /core /embedder.py
github-actions[bot]
Sync backend to Hugging Face Space (commit: 39b5c807918249fa80049d49f4b6a74d6a0ed1fc)
6d86412
Raw
History Blame Contribute Delete
9.14 kB
"""Hot-swappable embedding model registry.
Architecture:
┌─────────────────────────────────────────────────────────┐
│ EmbeddingRegistry (singleton) │
│ ┌───────────────┐ ┌──────────────┐ ┌─────────────┐ │
│ │ ONNX MiniLM │ │ Ollama embed │ │ Sentence-TF │ │
│ │ (built-in, │ │ (cloud/local │ │ (any HF │ │
│ │ 384-dim) │ │ 768-dim) │ │ model) │ │
│ └───────────────┘ └──────────────┘ └─────────────┘ │
│ │
│ Active model selected via EMBEDDING_MODEL env var │
│ Hot-swap via admin API without server restart │
└─────────────────────────────────────────────────────────┘
Supported model identifiers:
"onnx-minilm" → built-in ChromaDB ONNX (default, 384-dim, local)
"ollama:<model>" → Ollama embed API (e.g. "ollama:nomic-embed-text")
<any HuggingFace model> → loaded via sentence-transformers
"""
from __future__ import annotations
import logging
import threading
from typing import Any, Protocol
from chromadb import EmbeddingFunction, Embeddings, Documents
from chromadb.utils.embedding_functions import ONNXMiniLM_L6_V2
from backend.app.core.config import settings
# Module-level logger, named after this module.
logger = logging.getLogger(__name__)
# Model identifier that selects the built-in ChromaDB ONNX MiniLM backend.
BUILTIN_ONNX_ALIAS = "onnx-minilm"
# Prefix marking a model identifier as an Ollama model; the text after the
# prefix is the Ollama model name (e.g. "ollama:nomic-embed-text").
OLLAMA_PREFIX = "ollama:"
def _to_python_float_vectors(vectors: Any) -> list[list[float]]:
"""Normalize embedding outputs for ChromaDB's strict type validator."""
return [[float(value) for value in vec] for vec in vectors]
class _EmbedderBackend(Protocol):
def encode(self, texts: list[str]) -> list[list[float]]: ...
@property
def dimension(self) -> int: ...
@property
def model_name(self) -> str: ...
class _OnnxMiniLMBackend:
"""Wraps ChromaDB's built-in ONNX all-MiniLM-L6-v2 (zero extra deps)."""
def __init__(self) -> None:
self._ef = ONNXMiniLM_L6_V2()
def encode(self, texts: list[str]) -> list[list[float]]:
return _to_python_float_vectors(self._ef(texts))
@property
def dimension(self) -> int:
return 384
@property
def model_name(self) -> str:
return BUILTIN_ONNX_ALIAS
class _SentenceTransformerBackend:
"""Wraps any HuggingFace sentence-transformers model.
Loads the model on first use. Works with medical-domain models
like PubMedBERT that produce superior embeddings for clinical text.
"""
def __init__(self, hf_model_name: str) -> None:
self._hf_name = hf_model_name
self._model: Any = None
self._dim: int | None = None
def _ensure_loaded(self) -> None:
if self._model is not None:
return
try:
from sentence_transformers import SentenceTransformer
except ImportError as exc:
raise RuntimeError(
"sentence-transformers is required for custom embedding models. "
"Install it: pip install sentence-transformers"
) from exc
logger.info("Loading embedding model: %s (first load may download weights)", self._hf_name)
self._model = SentenceTransformer(self._hf_name)
probe = self._model.encode(["probe"], convert_to_numpy=True)
self._dim = int(probe.shape[1])
logger.info("Embedding model ready: %s (%d-dim)", self._hf_name, self._dim)
def encode(self, texts: list[str]) -> list[list[float]]:
self._ensure_loaded()
embeddings = self._model.encode(texts, convert_to_numpy=True, show_progress_bar=False)
return _to_python_float_vectors(embeddings)
@property
def dimension(self) -> int:
self._ensure_loaded()
assert self._dim is not None
return self._dim
@property
def model_name(self) -> str:
return self._hf_name
class _OllamaEmbeddingBackend:
"""Uses the Ollama embed API (local or cloud) for embeddings."""
def __init__(self, ollama_model: str) -> None:
self._ollama_model = ollama_model
self._dim: int | None = None
self._client: Any = None
def _ensure_client(self) -> Any:
if self._client is None:
from ollama import Client
headers = {}
if settings.ollama_api_key:
headers["Authorization"] = f"Bearer {settings.ollama_api_key}"
self._client = Client(host=settings.resolved_ollama_host, headers=headers, timeout=None)
return self._client
def encode(self, texts: list[str]) -> list[list[float]]:
client = self._ensure_client()
response = client.embed(model=self._ollama_model, input=texts)
embeddings = response.embeddings
if self._dim is None and embeddings:
self._dim = len(embeddings[0])
logger.info("Ollama embedding model %s: %d-dim", self._ollama_model, self._dim)
return _to_python_float_vectors(embeddings)
@property
def dimension(self) -> int:
if self._dim is None:
self.encode(["probe"])
assert self._dim is not None
return self._dim
@property
def model_name(self) -> str:
return f"ollama:{self._ollama_model}"
class _RegistryEmbeddingFunction(EmbeddingFunction[Documents]):
    """Adapter exposing an embedding backend through ChromaDB's callable interface."""

    def __init__(self, backend: _EmbedderBackend) -> None:
        # The backend is fixed at construction time for this wrapper.
        self._backend = backend

    def __call__(self, input: Documents) -> Embeddings:
        """Embed the given documents with the wrapped backend."""
        documents = list(input)
        return self._backend.encode(documents)  # type: ignore[return-value]
class EmbeddingRegistry:
    """Process-wide singleton holding the active embedding backend.

    Obtain the shared instance with :meth:`get`. Replacing the backend is
    serialized through an internal lock.
    """

    _instance: EmbeddingRegistry | None = None
    # Guards creation of the singleton instance.
    _lock = threading.Lock()

    def __init__(self) -> None:
        self._backend: _EmbedderBackend | None = None
        # Guards replacement of ``_backend``.
        self._swap_lock = threading.Lock()

    @classmethod
    def get(cls) -> EmbeddingRegistry:
        """Return the shared registry, creating it on first call."""
        if cls._instance is None:
            with cls._lock:
                # Re-check under the lock: another thread may have created
                # the instance while this one was waiting.
                if cls._instance is None:
                    cls._instance = cls()
        return cls._instance

    @staticmethod
    def _build_backend(model_name: str) -> _EmbedderBackend:
        """Construct (without activating) the backend for *model_name*.

        ``"onnx-minilm"`` selects the built-in ONNX backend, an ``"ollama:"``
        prefix selects the Ollama backend, and anything else is treated as a
        sentence-transformers model name.
        """
        if model_name == BUILTIN_ONNX_ALIAS:
            return _OnnxMiniLMBackend()
        if model_name.startswith(OLLAMA_PREFIX):
            return _OllamaEmbeddingBackend(model_name[len(OLLAMA_PREFIX):])
        return _SentenceTransformerBackend(model_name)

    @property
    def active_backend(self) -> _EmbedderBackend:
        """The active backend, loading ``settings.embedding_model`` if none is set."""
        backend = self._backend
        if backend is None:
            self.load_model(settings.embedding_model)
            backend = self._backend
        assert backend is not None
        return backend

    def load_model(self, model_name: str) -> None:
        """Make the backend for *model_name* the active one.

        Lazy backends are not forced to load here; they load on first use.
        """
        backend = self._build_backend(model_name)
        with self._swap_lock:
            self._backend = backend
        logger.info("Active embedding model set to: %s", model_name)

    def swap_model(self, model_name: str) -> dict:
        """Hot-swap the active model. Returns status info.

        The new backend is fully loaded (by reading its dimension) *before*
        it replaces the current one, so if the model cannot be loaded the
        exception propagates and the previously active backend stays in place.

        IMPORTANT: embeddings produced by different models live in
        incompatible vector spaces. After swapping, existing ChromaDB
        collections must be re-indexed.

        Returns:
            A dict with keys ``previous_model``, ``active_model``,
            ``dimension`` and ``reindex_required``.
        """
        candidate = self._build_backend(model_name)
        # Forces lazy backends to load; raises here, before any state change,
        # when the model is unavailable.
        dimension = candidate.dimension
        with self._swap_lock:
            previous = self._backend
            self._backend = candidate
        logger.info("Active embedding model set to: %s", model_name)
        old_name = previous.model_name if previous is not None else "none"
        new_name = candidate.model_name
        return {
            "previous_model": old_name,
            "active_model": new_name,
            "dimension": dimension,
            "reindex_required": old_name != new_name,
        }

    @property
    def model_name(self) -> str:
        """Name of the active model (loads the default backend if needed)."""
        return self.active_backend.model_name

    @property
    def dimension(self) -> int:
        """Embedding width of the active model (loads the default backend if needed)."""
        return self.active_backend.dimension
# ── Public API (unchanged signatures for backward compatibility) ──
def embed_text(text: str) -> list[float]:
    """Embed one string with the active model and return its vector."""
    backend = EmbeddingRegistry.get().active_backend
    vectors = backend.encode([text])
    return vectors[0]
def embed_texts(texts: list[str]) -> list[list[float]]:
    """Return embeddings for a batch of text strings.

    An empty batch yields ``[]`` without invoking (or loading) the active
    backend, so callers need not special-case "nothing to embed".
    """
    if not texts:
        return []
    return EmbeddingRegistry.get().active_backend.encode(texts)
def get_embedding_function() -> EmbeddingFunction:
    """Build a ChromaDB-compatible embedding function around the active backend.

    The wrapper holds the backend that is active at call time.
    """
    backend = EmbeddingRegistry.get().active_backend
    return _RegistryEmbeddingFunction(backend)
def get_model_info() -> dict:
    """Describe the active embedding model as a ``model_name``/``dimension`` dict."""
    registry = EmbeddingRegistry.get()
    name = registry.model_name
    width = registry.dimension
    return {"model_name": name, "dimension": width}