kanad13 commited on
Commit
88d0739
·
verified ·
1 Parent(s): b811610

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +14 -37
app.py CHANGED
@@ -2,17 +2,15 @@ import gradio as gr
2
  import pandas as pd
3
  from datasets import load_dataset
4
  from sentence_transformers import SentenceTransformer
5
- import pymongo
6
- import os
7
- from sklearn.metrics.pairwise import cosine_similarity
8
  import numpy as np
 
9
 
10
  # Load a smaller portion of the dataset
11
- dataset = load_dataset("MongoDB/embedded_movies", split='train[:10%]')
12
  dataset_df = pd.DataFrame(dataset)
13
 
14
  # Data cleaning and preprocessing
15
- dataset_df = dataset_df.dropna(subset=["fullplot"])
16
  dataset_df = dataset_df.drop(columns=["plot_embedding"])
17
 
18
  # Load a smaller embedding model
@@ -25,52 +23,31 @@ def get_embedding(text: str) -> list:
25
  embedding = embedding_model.encode(text)
26
  return embedding.tolist()
27
 
28
- # Process embeddings in batches
29
- batch_size = 100
 
30
  for i in range(0, len(dataset_df), batch_size):
31
- batch = dataset_df.iloc[i:i+batch_size]
32
- batch["embedding"] = batch["fullplot"].apply(get_embedding)
33
- dataset_df.iloc[i:i+batch_size] = batch
34
-
35
- def get_mongo_client(mongo_uri):
36
- try:
37
- client = pymongo.MongoClient(mongo_uri)
38
- print("Connection to MongoDB successful")
39
- return client
40
- except pymongo.errors.ConnectionFailure as e:
41
- print(f"Connection failed: {e}")
42
- return None
43
-
44
- mongo_uri = os.getenv("MONGO_URI")
45
- if not mongo_uri:
46
- print("MONGO_URI not set in environment variables")
47
-
48
- mongo_client = get_mongo_client(mongo_uri)
49
- db = mongo_client["movies"]
50
- collection = db["movie_collection_2"]
51
 
52
- # Clear the collection and insert new data in bulk
53
- collection.delete_many({})
54
- documents = dataset_df.to_dict("records")
55
- collection.insert_many(documents)
56
- print("Data ingestion into MongoDB completed")
57
 
58
- # Load all embeddings into memory for faster similarity search
59
- all_embeddings = np.array(dataset_df["embedding"].tolist())
60
- all_titles = dataset_df["title"].tolist()
61
 
62
  def vector_search(user_query):
63
  query_embedding = get_embedding(user_query)
64
  if not query_embedding:
65
  return "Invalid query or embedding generation failed."
66
 
67
- similarities = cosine_similarity([query_embedding], all_embeddings)[0]
68
  top_indices = similarities.argsort()[-3:][::-1]
69
 
70
  results = []
71
  for idx in top_indices:
72
  results.append({
73
- "title": all_titles[idx],
74
  "fullplot": dataset_df.iloc[idx]["fullplot"],
75
  "genres": dataset_df.iloc[idx]["genres"],
76
  "score": similarities[idx]
 
2
  import pandas as pd
3
  from datasets import load_dataset
4
  from sentence_transformers import SentenceTransformer
 
 
 
5
  import numpy as np
6
+ from sklearn.metrics.pairwise import cosine_similarity
7
 
8
  # Load a smaller portion of the dataset
9
+ dataset = load_dataset("MongoDB/embedded_movies", split='train[:5%]')
10
  dataset_df = pd.DataFrame(dataset)
11
 
12
  # Data cleaning and preprocessing
13
+ dataset_df = dataset_df.dropna(subset=["fullplot"]).reset_index(drop=True)
14
  dataset_df = dataset_df.drop(columns=["plot_embedding"])
15
 
16
  # Load a smaller embedding model
 
23
  embedding = embedding_model.encode(text)
24
  return embedding.tolist()
25
 
26
+ # Generate embeddings for all plots
27
+ all_embeddings = []
28
+ batch_size = 32
29
  for i in range(0, len(dataset_df), batch_size):
30
+ batch = dataset_df['fullplot'].iloc[i:i+batch_size].tolist()
31
+ batch_embeddings = embedding_model.encode(batch)
32
+ all_embeddings.extend(batch_embeddings)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
33
 
34
+ # Add embeddings to the DataFrame
35
+ dataset_df['embedding'] = all_embeddings
 
 
 
36
 
37
+ print("Embeddings generated and added to DataFrame")
 
 
38
 
39
  def vector_search(user_query):
40
  query_embedding = get_embedding(user_query)
41
  if not query_embedding:
42
  return "Invalid query or embedding generation failed."
43
 
44
+ similarities = cosine_similarity([query_embedding], list(dataset_df['embedding']))[0]
45
  top_indices = similarities.argsort()[-3:][::-1]
46
 
47
  results = []
48
  for idx in top_indices:
49
  results.append({
50
+ "title": dataset_df.iloc[idx]["title"],
51
  "fullplot": dataset_df.iloc[idx]["fullplot"],
52
  "genres": dataset_df.iloc[idx]["genres"],
53
  "score": similarities[idx]