File size: 2,242 Bytes
08fd094
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
from typing import List
from langchain_huggingface import HuggingFaceEmbeddings
from fastembed import SparseTextEmbedding
from src.config import EMBEDDING_MODEL

"""

embedding.py



What it does:

Generates vector embeddings using the HuggingFaceEmbeddings wrapper.

"""

# Lazily-initialised model instances. Each is None until the first call to
# get_model() / get_sparse_model() constructs it; later calls reuse it.
_model: "HuggingFaceEmbeddings | None" = None
_sparse_model: "SparseTextEmbedding | None" = None

def get_model() -> HuggingFaceEmbeddings:
    """Return the cached dense embedding model, constructing it on first use.

    The model named by ``EMBEDDING_MODEL`` is built once, stored in the
    module-level ``_model``, and that same instance is returned on every
    subsequent call.
    """
    global _model
    if _model is not None:
        return _model
    print(f"Loading embedding model: {EMBEDDING_MODEL}")
    _model = HuggingFaceEmbeddings(model_name=EMBEDDING_MODEL)
    return _model

def get_sparse_model() -> SparseTextEmbedding:
    """Return the cached fastembed sparse model, constructing it on first use.

    A ``SparseTextEmbedding`` for the SPLADE++ checkpoint is built once,
    stored in the module-level ``_sparse_model``, and reused afterwards.
    """
    global _sparse_model
    if _sparse_model is not None:
        return _sparse_model
    # Single source of truth for the checkpoint id (used in the log line and
    # the constructor alike).
    splade_id = "prithivida/Splade_PP_en_v1"
    print(f"Loading sparse embedding model: {splade_id}")
    _sparse_model = SparseTextEmbedding(model_name=splade_id)
    return _sparse_model

def generate_embeddings(chunks: List[str]) -> List[List[float]]:
    """Embed a batch of text chunks with the dense model.

    Args:
        chunks: Texts to embed.

    Returns:
        The vectors produced by ``HuggingFaceEmbeddings.embed_documents``.
        A falsy (empty) ``chunks`` returns ``[]`` without loading the model.
    """
    return get_model().embed_documents(chunks) if chunks else []

def embed_query(query: str) -> List[float]:
    """Embed one search query with the dense model.

    Args:
        query: The search text.

    Returns:
        The vector from ``HuggingFaceEmbeddings.embed_query``, or ``[]`` when
        ``query`` is empty (the model is not loaded in that case).
    """
    if query:
        return get_model().embed_query(query)
    return []

def embed_sparse_query(query: str) -> tuple[list[int], list[float]]:
    """Encode a single query as a sparse vector using fastembed.

    Args:
        query: The search text.

    Returns:
        An ``(indices, values)`` pair of plain Python lists taken from the
        first ``SparseEmbedding`` fastembed yields. Returns ``([], [])`` when
        ``query`` is empty (without loading the model) or when no embedding
        is produced.
    """
    if not query:
        return [], []
    # embed() yields SparseEmbedding objects lazily; collect them, then use
    # the one that corresponds to our single input text.
    sparse_vectors = [*get_sparse_model().embed([query])]
    if len(sparse_vectors) == 0:
        return [], []
    head = sparse_vectors[0]
    # .indices / .values are array-like; tolist() converts to Python lists.
    return head.indices.tolist(), head.values.tolist()

if __name__ == "__main__":
    # Test block
    print("Generating embeddings...")
    vector = embed_query("What is Pemetrexed?")
    print(f"Generated vector of dimension {len(vector)}.")