File size: 2,798 Bytes
da0fb08
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
2d37781
da0fb08
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
2d37781
 
da0fb08
2d37781
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
"""
Visual Search API - HuggingFace Space
"""

import os
import gradio as gr
import torch
import numpy as np
from PIL import Image

# Pinecone connection settings, read from the environment (HF Space secrets).
# Either value is None when its variable is unset; query_pinecone() checks
# both and returns an empty result list if either is missing.
PINECONE_API_KEY = os.getenv("PINECONE_API_KEY")
PINECONE_HOST = os.getenv("PINECONE_HOST")

# Cached model handle; stays None until load_model() populates it.
model = None


def load_model():
    """Return the cached Jina CLIP v2 model, loading it on the first call.

    On first use this lazily imports ``transformers``, loads
    ``jinaai/jina-clip-v2`` with ``trust_remote_code=True`` (which lets
    transformers execute the model repository's own Python code), switches
    it to eval mode and stores it in the module-level ``model`` global.
    Subsequent calls return that stored object without reloading.
    """
    global model
    if model is not None:
        return model

    print("Loading Jina CLIP v2...")
    from transformers import AutoModel

    model = AutoModel.from_pretrained("jinaai/jina-clip-v2", trust_remote_code=True)
    model.eval()  # inference only
    print("Model loaded!")
    return model


def get_embedding(image: Image.Image) -> list:
    """Generate 512-dim embedding for an image."""
    m = load_model()

    with torch.no_grad():
        emb = m.encode_image(image)
        if hasattr(emb, 'cpu'):
            emb = emb.cpu().numpy()
        emb = emb.flatten()
        emb = emb / np.linalg.norm(emb)
        if len(emb) > 512:
            emb = emb[:512]
        return emb.tolist()


def query_pinecone(embedding: list, top_k: int = 12) -> list:
    """Query Pinecone for similar products."""
    if not PINECONE_API_KEY or not PINECONE_HOST:
        return []

    import requests

    resp = requests.post(
        f"https://{PINECONE_HOST}/query",
        headers={
            "Api-Key": PINECONE_API_KEY,
            "Content-Type": "application/json"
        },
        json={
            "vector": embedding,
            "topK": top_k,
            "includeMetadata": True
        },
        timeout=15
    )

    if resp.status_code != 200:
        return []

    matches = resp.json().get('matches', [])
    return [
        {
            'handle': m.get('metadata', {}).get('handle', m.get('id')),
            'title': m.get('metadata', {}).get('title', ''),
            'score': m.get('score', 0),
        }
        for m in matches
    ]


def search(image):
    """Embed *image*, query Pinecone and format the matches for display.

    Args:
        image: PIL image from the Gradio input, or None if nothing was
            uploaded.

    Returns:
        A newline-separated, 1-based ranked list of
        ``"<title> (<handle>) - score: <score>"`` lines, or a short status
        message when there is no image, no match, or an error occurred.
    """
    if image is None:
        return "No image provided"

    try:
        embedding = get_embedding(image)
        products = query_pinecone(embedding)

        if not products:
            return "No similar products found"

        return "\n".join(
            f"{i}. {p['title']} ({p['handle']}) - score: {p['score']:.3f}"
            for i, p in enumerate(products, start=1)
        )
    except Exception as e:
        # UI boundary: show the message to the user, but also dump the full
        # traceback to the server logs so failures can be diagnosed.
        import traceback
        traceback.print_exc()
        return f"Error: {e}"


# Minimal Gradio UI: a single PIL image input mapped by search() to a
# text box listing the ranked matches.
_image_input = gr.Image(type="pil", label="Upload Image")
_results_output = gr.Textbox(label="Similar Products", lines=15)

demo = gr.Interface(
    fn=search,
    inputs=_image_input,
    outputs=_results_output,
    title="Visual Product Search",
    description="Upload an image to find similar products.",
)

if __name__ == "__main__":
    demo.launch()