| from typing import List
|
| from langchain_huggingface import HuggingFaceEmbeddings
|
| from fastembed import SparseTextEmbedding
|
| from src.config import EMBEDDING_MODEL
|
|
|
| """
|
| embedding.py
|
|
|
| What it does:
|
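Generates dense vector embeddings with the HuggingFaceEmbeddings wrapper and sparse (SPLADE) query embeddings with fastembed's SparseTextEmbedding. Both models are loaded lazily on first use.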
|
| """
|
|
|
# Lazily-populated model singletons: each stays None until the matching
# loader (get_model / get_sparse_model) constructs and caches it.
_model: "HuggingFaceEmbeddings | None" = None
_sparse_model: "SparseTextEmbedding | None" = None
|
|
|
def get_model() -> HuggingFaceEmbeddings:
    """Return the shared dense embedding model, constructing it on first use.

    The model named by ``EMBEDDING_MODEL`` is built once and cached in the
    module-level ``_model`` global; every later call returns that instance.
    """
    global _model
    if _model is not None:
        return _model
    print(f"Loading embedding model: {EMBEDDING_MODEL}")
    _model = HuggingFaceEmbeddings(model_name=EMBEDDING_MODEL)
    return _model
|
|
|
def get_sparse_model() -> SparseTextEmbedding:
    """Return the shared fastembed sparse embedding model, loading it lazily.

    The ``prithivida/Splade_PP_en_v1`` model is constructed on the first call
    and cached in the module-level ``_sparse_model`` global; later calls
    return the cached instance.

    Returns:
        The cached ``SparseTextEmbedding`` instance.
    """
    global _sparse_model
    if _sparse_model is None:
        # Bind the model id once so the log line and the model actually
        # loaded can never drift apart (previously the literal was duplicated).
        model_name = "prithivida/Splade_PP_en_v1"
        print(f"Loading sparse embedding model: {model_name}")
        _sparse_model = SparseTextEmbedding(model_name=model_name)
    return _sparse_model
|
|
|
def generate_embeddings(chunks: List[str]) -> List[List[float]]:
    """Embed a batch of text chunks with the dense model.

    Args:
        chunks: Texts to embed.

    Returns:
        The vectors produced by the model's ``embed_documents``, or an empty
        list when ``chunks`` is empty (the model is not loaded in that case).
    """
    return get_model().embed_documents(chunks) if chunks else []
|
|
|
def embed_query(query: str) -> List[float]:
    """Embed a single search query with the dense model.

    Args:
        query: The query text.

    Returns:
        The query vector from the model's ``embed_query``, or an empty list
        for an empty query (the model is not loaded in that case).
    """
    if query:
        return get_model().embed_query(query)
    return []
|
|
|
def embed_sparse_query(query: str) -> tuple[list[int], list[float]]:
    """Compute a sparse embedding for one query using the fastembed model.

    Args:
        query: The query text.

    Returns:
        ``(indices, values)`` converted to plain Python lists. Both lists are
        empty when ``query`` is empty or the model yields no embedding.
    """
    if not query:
        return ([], [])

    # Drain the model's output fully, then use the single embedding produced
    # for the single submitted query.
    embeddings = list(get_sparse_model().embed([query]))
    if embeddings:
        first = embeddings[0]
        return (first.indices.tolist(), first.values.tolist())
    return ([], [])
|
|
|
if __name__ == "__main__":
    # Manual smoke test: embed one sample query and report the vector size.
    print("Generating embeddings...")
    dimension = len(embed_query("What is Pemetrexed?"))
    print(f"Generated vector of dimension {dimension}.")
|
|
|