Spaces:
Sleeping
Sleeping
Update app.py
Browse files
app.py
CHANGED
|
@@ -2,17 +2,15 @@ import gradio as gr
|
|
| 2 |
import pandas as pd
|
| 3 |
from datasets import load_dataset
|
| 4 |
from sentence_transformers import SentenceTransformer
|
| 5 |
-
import pymongo
|
| 6 |
-
import os
|
| 7 |
-
from sklearn.metrics.pairwise import cosine_similarity
|
| 8 |
import numpy as np
|
|
|
|
| 9 |
|
| 10 |
# Load a smaller portion of the dataset
|
| 11 |
-
dataset = load_dataset("MongoDB/embedded_movies", split='train[:
|
| 12 |
dataset_df = pd.DataFrame(dataset)
|
| 13 |
|
| 14 |
# Data cleaning and preprocessing
|
| 15 |
-
dataset_df = dataset_df.dropna(subset=["fullplot"])
|
| 16 |
dataset_df = dataset_df.drop(columns=["plot_embedding"])
|
| 17 |
|
| 18 |
# Load a smaller embedding model
|
|
@@ -25,52 +23,31 @@ def get_embedding(text: str) -> list:
|
|
| 25 |
embedding = embedding_model.encode(text)
|
| 26 |
return embedding.tolist()
|
| 27 |
|
| 28 |
-
#
|
| 29 |
-
|
|
|
|
| 30 |
for i in range(0, len(dataset_df), batch_size):
|
| 31 |
-
batch = dataset_df.iloc[i:i+batch_size]
|
| 32 |
-
|
| 33 |
-
|
| 34 |
-
|
| 35 |
-
def get_mongo_client(mongo_uri):
|
| 36 |
-
try:
|
| 37 |
-
client = pymongo.MongoClient(mongo_uri)
|
| 38 |
-
print("Connection to MongoDB successful")
|
| 39 |
-
return client
|
| 40 |
-
except pymongo.errors.ConnectionFailure as e:
|
| 41 |
-
print(f"Connection failed: {e}")
|
| 42 |
-
return None
|
| 43 |
-
|
| 44 |
-
mongo_uri = os.getenv("MONGO_URI")
|
| 45 |
-
if not mongo_uri:
|
| 46 |
-
print("MONGO_URI not set in environment variables")
|
| 47 |
-
|
| 48 |
-
mongo_client = get_mongo_client(mongo_uri)
|
| 49 |
-
db = mongo_client["movies"]
|
| 50 |
-
collection = db["movie_collection_2"]
|
| 51 |
|
| 52 |
-
#
|
| 53 |
-
|
| 54 |
-
documents = dataset_df.to_dict("records")
|
| 55 |
-
collection.insert_many(documents)
|
| 56 |
-
print("Data ingestion into MongoDB completed")
|
| 57 |
|
| 58 |
-
|
| 59 |
-
all_embeddings = np.array(dataset_df["embedding"].tolist())
|
| 60 |
-
all_titles = dataset_df["title"].tolist()
|
| 61 |
|
| 62 |
def vector_search(user_query):
|
| 63 |
query_embedding = get_embedding(user_query)
|
| 64 |
if not query_embedding:
|
| 65 |
return "Invalid query or embedding generation failed."
|
| 66 |
|
| 67 |
-
similarities = cosine_similarity([query_embedding],
|
| 68 |
top_indices = similarities.argsort()[-3:][::-1]
|
| 69 |
|
| 70 |
results = []
|
| 71 |
for idx in top_indices:
|
| 72 |
results.append({
|
| 73 |
-
"title":
|
| 74 |
"fullplot": dataset_df.iloc[idx]["fullplot"],
|
| 75 |
"genres": dataset_df.iloc[idx]["genres"],
|
| 76 |
"score": similarities[idx]
|
|
|
|
| 2 |
import pandas as pd
|
| 3 |
from datasets import load_dataset
|
| 4 |
from sentence_transformers import SentenceTransformer
|
|
|
|
|
|
|
|
|
|
| 5 |
import numpy as np
|
| 6 |
+
from sklearn.metrics.pairwise import cosine_similarity
|
| 7 |
|
| 8 |
# Load a smaller portion of the dataset
|
| 9 |
+
dataset = load_dataset("MongoDB/embedded_movies", split='train[:5%]')
|
| 10 |
dataset_df = pd.DataFrame(dataset)
|
| 11 |
|
| 12 |
# Data cleaning and preprocessing
|
| 13 |
+
dataset_df = dataset_df.dropna(subset=["fullplot"]).reset_index(drop=True)
|
| 14 |
dataset_df = dataset_df.drop(columns=["plot_embedding"])
|
| 15 |
|
| 16 |
# Load a smaller embedding model
|
|
|
|
| 23 |
embedding = embedding_model.encode(text)
|
| 24 |
return embedding.tolist()
|
| 25 |
|
| 26 |
+
# Generate embeddings for all plots
|
| 27 |
+
all_embeddings = []
|
| 28 |
+
batch_size = 32
|
| 29 |
for i in range(0, len(dataset_df), batch_size):
|
| 30 |
+
batch = dataset_df['fullplot'].iloc[i:i+batch_size].tolist()
|
| 31 |
+
batch_embeddings = embedding_model.encode(batch)
|
| 32 |
+
all_embeddings.extend(batch_embeddings)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 33 |
|
| 34 |
+
# Add embeddings to the DataFrame
|
| 35 |
+
dataset_df['embedding'] = all_embeddings
|
|
|
|
|
|
|
|
|
|
| 36 |
|
| 37 |
+
print("Embeddings generated and added to DataFrame")
|
|
|
|
|
|
|
| 38 |
|
| 39 |
def vector_search(user_query):
|
| 40 |
query_embedding = get_embedding(user_query)
|
| 41 |
if not query_embedding:
|
| 42 |
return "Invalid query or embedding generation failed."
|
| 43 |
|
| 44 |
+
similarities = cosine_similarity([query_embedding], list(dataset_df['embedding']))[0]
|
| 45 |
top_indices = similarities.argsort()[-3:][::-1]
|
| 46 |
|
| 47 |
results = []
|
| 48 |
for idx in top_indices:
|
| 49 |
results.append({
|
| 50 |
+
"title": dataset_df.iloc[idx]["title"],
|
| 51 |
"fullplot": dataset_df.iloc[idx]["fullplot"],
|
| 52 |
"genres": dataset_df.iloc[idx]["genres"],
|
| 53 |
"score": similarities[idx]
|