pharmaspine-backend / src /embedding.py
ashish1265659565's picture
Upload folder using huggingface_hub
08fd094 verified
Raw
History Blame Contribute Delete
2.24 kB
from typing import List
from langchain_huggingface import HuggingFaceEmbeddings
from fastembed import SparseTextEmbedding
from src.config import EMBEDDING_MODEL
"""
embedding.py
What it does:
Generates dense vector embeddings using the HuggingFaceEmbeddings wrapper and
sparse query embeddings using fastembed's SparseTextEmbedding model.
"""
# Lazily-created model instances, kept at module level so they are built once.
# Both start as None and are filled in by get_model() / get_sparse_model()
# the first time the corresponding getter is called.
_model = None
_sparse_model = None
def get_model() -> HuggingFaceEmbeddings:
    """Return the shared dense embedding model, creating it on first use.

    The instance is kept in the module-level ``_model`` variable, so the
    loading message is printed and the model is constructed only on the
    first call; later calls hand back the same object.

    Returns:
        The ``HuggingFaceEmbeddings`` instance built from ``EMBEDDING_MODEL``.
    """
    global _model

    # Fast path: the model was already built by an earlier call.
    if _model is not None:
        return _model

    print(f"Loading embedding model: {EMBEDDING_MODEL}")
    _model = HuggingFaceEmbeddings(model_name=EMBEDDING_MODEL)
    return _model
def get_sparse_model() -> SparseTextEmbedding:
    """Return the shared fastembed sparse model, creating it on first use.

    The instance is kept in the module-level ``_sparse_model`` variable, so
    the loading message is printed and the model is constructed only on the
    first call; later calls hand back the same object.

    Returns:
        The ``SparseTextEmbedding`` instance for
        ``prithivida/Splade_PP_en_v1``.
    """
    global _sparse_model

    # Fast path: the model was already built by an earlier call.
    if _sparse_model is not None:
        return _sparse_model

    print("Loading sparse embedding model: prithivida/Splade_PP_en_v1")
    _sparse_model = SparseTextEmbedding(model_name="prithivida/Splade_PP_en_v1")
    return _sparse_model
def generate_embeddings(chunks: List[str]) -> List[List[float]]:
    """Embed a batch of text chunks with the dense model.

    Args:
        chunks: The texts to embed.

    Returns:
        Whatever ``embed_documents`` of the model from ``get_model()``
        returns for ``chunks``. A falsy ``chunks`` (for example an empty
        list) yields ``[]`` without loading the model.
    """
    if chunks:
        return get_model().embed_documents(chunks)
    # Nothing to embed: skip the (potentially expensive) model load.
    return []
def embed_query(query: str) -> List[float]:
    """Embed a single search query with the dense model.

    Args:
        query: The query text.

    Returns:
        Whatever ``embed_query`` of the model from ``get_model()`` returns
        for ``query``. A falsy ``query`` (for example ``""``) yields ``[]``
        without loading the model.
    """
    # The model is only fetched when there is actually something to embed.
    return get_model().embed_query(query) if query else []
def embed_sparse_query(query: str) -> tuple[list[int], list[float]]:
    """Embed a single query with the fastembed sparse model.

    Args:
        query: The query text.

    Returns:
        A tuple ``(indices, values)`` taken from the first embedding the
        model produces, each converted with ``.tolist()``. A falsy ``query``
        or an empty result yields ``([], [])``.
    """
    if not query:
        return ([], [])

    # fastembed yields embeddings lazily; materialise them so the result can
    # be checked for emptiness and indexed.
    embeddings = list(get_sparse_model().embed([query]))
    if embeddings:
        # Only one text was submitted, so the first entry is the query's
        # embedding; it exposes .indices and .values.
        first = embeddings[0]
        return (first.indices.tolist(), first.values.tolist())
    return ([], [])
if __name__ == "__main__":
    # Manual smoke test: runs only when this file is executed as a script.
    # Embeds one sample query with the dense model and prints the length of
    # the returned vector.
    print("Generating embeddings...")
    vector = embed_query("What is Pemetrexed?")
    print(f"Generated vector of dimension {len(vector)}.")