kanad13 committed on
Commit
35e476e
·
verified ·
1 Parent(s): a965c00

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +75 -120
app.py CHANGED
@@ -1,189 +1,144 @@
1
- # Import necessary libraries
2
- import gradio as gr # Gradio is used to create a web interface to interact with the model.
3
- import pandas as pd # Pandas is used for data manipulation and analysis.
4
- from datasets import load_dataset # This function loads datasets from Hugging Face.
5
- from sentence_transformers import SentenceTransformer # Used for generating text embeddings.
6
- from transformers import AutoTokenizer, AutoModelForCausalLM # Transformers are used for natural language processing tasks.
7
- import pymongo # Pymongo is used to interact with MongoDB.
8
- import os # Used for accessing environment variables.
9
-
10
- # Load Dataset from Hugging Face and convert it to a pandas DataFrame
11
- # The dataset contains movie information and is split to use 80% for training.
12
- dataset = load_dataset("MongoDB/embedded_movies", split='train[:80%]') # AIatMongoDB/embedded_movies
13
  dataset_df = pd.DataFrame(dataset)
14
 
15
- # Remove rows where the 'fullplot' column is empty
16
- # It's crucial to ensure that every movie entry has a complete plot description.
17
- # The 'fullplot' column is necessary for generating embeddings, which are numerical representations of the text.
18
- # If this data is missing, the embeddings will be incomplete or nonsensical, reducing the accuracy of recommendations.
19
  dataset_df = dataset_df.dropna(subset=["fullplot"])
20
-
21
- # Drop the 'plot_embedding' column as we will generate new embeddings
22
- # We drop the existing 'plot_embedding' column to create new embeddings using a different, potentially more effective model.
23
- # This step ensures consistency and accuracy in the embeddings used for similarity searches.
24
  dataset_df = dataset_df.drop(columns=["plot_embedding"])
25
 
26
- # Load a pre-trained embedding model
27
- # We use a pre-trained model from Sentence Transformers to convert movie plots into numerical embeddings.
28
- # These embeddings capture the semantic content of the plots, allowing us to perform efficient and meaningful similarity searches.
29
- embedding_model = SentenceTransformer("thenlper/gte-large")
30
 
31
- # Define a function to generate embeddings for a given text
32
- # Embeddings are numerical representations of text that capture its semantic meaning.
33
- # This function checks if the text is not empty and then generates an embedding using the loaded model.
34
  def get_embedding(text: str) -> list:
35
- if not text.strip(): # Check if the text is not empty
36
- # If the text is empty, return an empty list as it does not make sense to generate embeddings for empty text.
37
- # This ensures that we avoid errors and meaningless embeddings.
38
  print("Attempted to get embedding for empty text.")
39
  return []
40
- embedding = embedding_model.encode(text) # Generate the embedding
41
- return embedding.tolist() # Convert embedding to a list for storage and manipulation
42
 
43
- # Apply the embedding function to the 'fullplot' column in the DataFrame
44
- # This step generates embeddings for each movie plot in the dataset, storing them in the DataFrame for later use in similarity searches.
45
- dataset_df["embedding"] = dataset_df["fullplot"].apply(get_embedding)
 
 
 
46
 
47
- # Function to connect to MongoDB
48
- # MongoDB is a NoSQL database used to store and retrieve large datasets efficiently.
49
- # This function attempts to create a MongoDB client to connect to the database.
50
  def get_mongo_client(mongo_uri):
51
  try:
52
- client = pymongo.MongoClient(mongo_uri) # Create a MongoDB client
53
  print("Connection to MongoDB successful")
54
  return client
55
  except pymongo.errors.ConnectionFailure as e:
56
- # Handle potential connection failures to provide feedback in case of issues with the MongoDB URI or network problems.
57
  print(f"Connection failed: {e}")
58
  return None
59
 
60
- # Get the MongoDB URI from environment variables
61
- # The MongoDB URI is required to connect to the database. It should be stored securely in environment variables to protect sensitive information.
62
  mongo_uri = os.getenv("MONGO_URI")
63
  if not mongo_uri:
64
  print("MONGO_URI not set in environment variables")
65
 
66
- # Connect to MongoDB using the URI
67
- # The client connects to the 'movies' database and accesses the 'movie_collection_2' collection.
68
- # This collection will store the movie data with their respective embeddings.
69
  mongo_client = get_mongo_client(mongo_uri)
70
- db = mongo_client["movies"] # Access the 'movies' database
71
- collection = db["movie_collection_2"] # Access the 'movie_collection_2' collection
72
-
73
- # Clear the collection and insert the new data
74
- # Clearing the collection to avoid duplication of records and ensure we start with a fresh set of data.
75
- # This step ensures that the collection only contains the most recent data with newly generated embeddings.
76
- collection.delete_many({}) # Delete any existing records in the collection
77
- documents = dataset_df.to_dict("records") # Convert DataFrame to list of dictionaries
78
- collection.insert_many(documents) # Insert documents into the collection
79
  print("Data ingestion into MongoDB completed")
80
 
81
- # Function to perform a vector search on the user query
82
- # This function generates an embedding for the user's query and uses it to search for similar movie plots in the MongoDB collection.
83
- # Vector search allows us to find movies with plots that are semantically similar to the query.
84
  def vector_search(user_query, collection):
85
- query_embedding = get_embedding(user_query) # Generate embedding for the user query
86
  if query_embedding is None:
87
- # Return an error message if the embedding generation fails, ensuring graceful handling of invalid queries.
88
  return "Invalid query or embedding generation failed."
89
 
90
- # Define the MongoDB aggregation pipeline for vector search
91
- # This pipeline uses the generated query embedding to search for similar embeddings in the collection.
92
  pipeline = [
93
  {
94
  "$vectorSearch": {
95
- "index": "vector_index", # Name of the vector index
96
- "queryVector": query_embedding, # Embedding of the user query
97
- "path": "embedding", # Path to the embedding field in the documents
98
- "numCandidates": 150, # Number of candidate matches to consider for broad retrieval
99
- "limit": 4, # Return top 4 matches to keep results concise and relevant
100
  }
101
  },
102
  {
103
  "$project": {
104
- "_id": 0, # Exclude the '_id' field from the results for cleaner output
105
- "fullplot": 1, # Include the 'fullplot' field in the results for detailed descriptions
106
- "title": 1, # Include the 'title' field in the results to identify movies
107
- "genres": 1, # Include the 'genres' field in the results for additional context
108
- "score": {"$meta": "vectorSearchScore"}, # Include the search score to assess relevance
109
  }
110
  },
111
  ]
112
- results = collection.aggregate(pipeline) # Execute the aggregation pipeline
113
- return list(results) # Return the results as a list
114
 
115
- # Function to format search results
116
- # This function formats the search results into a user-friendly format, making it easier for users to read and understand the recommendations.
117
  def get_search_result(query):
118
- get_knowledge = vector_search(query, collection) # Perform vector search on the query
119
  search_result = ""
120
- for result in get_knowledge: # Iterate through search results
121
- # Format the search results to be user-friendly, including only the first 200 characters of the plot for brevity.
122
- search_result += f"Title: {result.get('title', 'N/A')}\nGenres: {', '.join(result.get('genres', ['N/A']))}\nPlot: {result.get('fullplot', 'N/A')[:200]}...\n\n"
123
  return search_result
124
 
125
- # Load a pre-trained language model for generating responses
126
- # Using GPT-2 to generate human-like responses based on the search results.
127
- # The tokenizer converts text to a format that the model can understand, and the model generates responses.
128
- tokenizer = AutoTokenizer.from_pretrained("gpt2-medium")
129
- model = AutoModelForCausalLM.from_pretrained("gpt2-medium")
 
 
 
 
130
 
131
- # Function to generate a response based on the user's query
132
- # This function combines the search results with the user's query and generates a response using the GPT-2 model.
133
  def generate_response(query):
134
- source_information = get_search_result(query) # Get search results for the query
135
- combined_information = (
136
- f"Answer the question '{query}' based on these movie details:\n\n{source_information}"
137
- )
138
-
139
- # Prepare input for the language model
140
- # Ensures the input does not exceed the model's maximum token capacity.
141
- max_length = tokenizer.model_max_length # Get the maximum token length to ensure input does not exceed model capacity
142
  input_ids = tokenizer(combined_information, return_tensors="pt", max_length=max_length, truncation=True)
143
 
144
  try:
145
  response = model.generate(
146
  **input_ids,
147
- max_new_tokens=150, # Limit the number of tokens to generate to control response length
148
- num_return_sequences=1, # Generate a single response sequence
149
- no_repeat_ngram_size=2, # Avoid repeating n-grams to improve response quality
150
- top_k=50, # Use top-k sampling for diversity in responses
151
- top_p=0.95, # Use nucleus sampling to focus on high-probability words
152
- temperature=0.7, # Control the randomness of predictions to balance between creativity and coherence
153
- do_sample=True # Enable sampling
154
  )
155
- return tokenizer.decode(response[0], skip_special_tokens=True) # Decode and return the response
156
  except Exception as e:
157
- # Handle potential errors during generation and provide a meaningful error message.
158
  return f"An error occurred: {str(e)}"
159
 
160
- # Function to handle user queries and generate responses
161
- # This function ties together the query handling and response generation processes.
162
  def query_movie_db(user_query):
163
  return generate_response(user_query)
164
 
165
- # Create the Gradio interface
166
- # Gradio provides a simple interface to interact with the model, allowing users to enter queries and receive responses.
167
- import gradio as gr
168
-
169
  description_and_article = """
170
  Ask this bot to recommend you a movie.
171
  Checkout [my github repo](https://github.com/kanad13/Movie-Recommendation-Bot) to look at the code that powers this bot.
172
-
173
  Note that the bot truncates replies due to token limitations in the free tier of Hugging Face resources.
174
- This is not a coding issue but a result of operating within the token limitations of the free tier of Hugging Face resources.
175
- To enhance response quality, better models and more resources could be used, but these come with higher costs, which I want to avoid as this is a hobby project.
176
  """
177
 
178
  iface = gr.Interface(
179
- fn=query_movie_db, # Function to handle user queries
180
- inputs=gr.Textbox(lines=2, placeholder="Enter your movie query here..."), # Textbox input for user queries
181
- outputs="text", # Text output for responses
182
- title="Movie Recommendation Bot", # Title of the interface
183
- description=description_and_article, # Combined description and article content
184
- examples=[["Suggest me a scary movie?"], ["What action movie can I watch?"]] # Example queries
185
  )
186
 
187
- # Launch the interface
188
  if __name__ == "__main__":
189
- iface.launch() # This launches the Gradio interface.
 
1
+ import gradio as gr
2
+ import pandas as pd
3
+ from datasets import load_dataset
4
+ from sentence_transformers import SentenceTransformer
5
+ from transformers import AutoTokenizer, AutoModelForCausalLM
6
+ import pymongo
7
+ import os
8
+
9
# Fetch the first 70% of the train split of the embedded-movies dataset.
dataset = load_dataset("MongoDB/embedded_movies", split="train[:70%]")

# Build the working frame in one pipeline: rows without a full plot cannot be
# embedded, and the shipped "plot_embedding" column is discarded because fresh
# embeddings are computed further down.
dataset_df = (
    pd.DataFrame(dataset)
    .dropna(subset=["fullplot"])
    .drop(columns=["plot_embedding"])
)

# Sentence-embedding model used for both the stored plots and user queries.
# NOTE(review): the Atlas vector index presumably has to be defined with this
# model's output dimension - confirm against the index definition.
embedding_model = SentenceTransformer("all-MiniLM-L6-v2")
 
 
19
 
 
 
 
20
def get_embedding(text: str) -> list:
    """Return the sentence embedding of ``text`` as a plain Python list.

    Args:
        text: The text to embed (a movie plot or a user query).

    Returns:
        The vector produced by the module-level ``embedding_model``, converted
        with ``tolist()``. An empty list is returned when ``text`` is not a
        string or contains only whitespace.
    """
    # Guard against None/NaN as well as blank strings: calling ``.strip()`` on
    # a non-string would raise AttributeError before the model is reached.
    if not isinstance(text, str) or not text.strip():
        print("Attempted to get embedding for empty text.")
        return []
    return embedding_model.encode(text).tolist()
26
 
27
# Process embeddings in batches
# The vectors are collected in row order and attached as a new column in one
# assignment. Writing a slice back through ``iloc`` cannot add a column to
# ``dataset_df``, so the embeddings would otherwise be lost before ingestion.
batch_size = 100
plot_embeddings = []
for start in range(0, len(dataset_df), batch_size):
    plots = dataset_df["fullplot"].iloc[start:start + batch_size]
    plot_embeddings.extend(get_embedding(plot) for plot in plots)

# Reuse the frame's own index: rows were dropped earlier, so the index is not
# necessarily 0..n-1 and a positional default index would misalign.
dataset_df["embedding"] = pd.Series(
    plot_embeddings, index=dataset_df.index, dtype=object
)
33
 
 
 
 
34
def get_mongo_client(mongo_uri):
    """Create a MongoDB client for ``mongo_uri`` and check the server is reachable.

    Args:
        mongo_uri: Connection string passed to ``pymongo.MongoClient``.

    Returns:
        The ``pymongo.MongoClient`` on success, or ``None`` when a
        ``ConnectionFailure`` is raised while connecting.
    """
    try:
        client = pymongo.MongoClient(mongo_uri)
        # MongoClient connects lazily, so constructing it does not prove the
        # server is reachable; a ping forces a real round trip.
        client.admin.command("ping")
    except pymongo.errors.ConnectionFailure as e:
        print(f"Connection failed: {e}")
        return None
    print("Connection to MongoDB successful")
    return client
42
 
 
 
43
# Connection string comes from the environment so it never lives in the code.
mongo_uri = os.getenv("MONGO_URI")
if not mongo_uri:
    print("MONGO_URI not set in environment variables")

mongo_client = get_mongo_client(mongo_uri)
if mongo_client is None:
    # get_mongo_client signals failure with None; stop with a clear message
    # instead of failing on the subscript below.
    raise RuntimeError("Could not connect to MongoDB; check MONGO_URI")
db = mongo_client["movies"]
collection = db["movie_collection_2"]

# Clear the collection and insert new data in bulk
collection.delete_many({})
documents = dataset_df.to_dict("records")
if documents:
    # insert_many rejects an empty list, so skip it when nothing was loaded.
    collection.insert_many(documents)
print("Data ingestion into MongoDB completed")
56
 
 
 
 
57
def vector_search(user_query, collection):
    """Return the stored movies whose embeddings are closest to ``user_query``.

    Args:
        user_query: Free-text query; embedded with ``get_embedding``.
        collection: Object exposing ``aggregate(pipeline)``, queried through
            its ``vector_index`` index on the ``embedding`` field.

    Returns:
        A list of result documents projected to ``fullplot``, ``title``,
        ``genres`` and the vector-search ``score``. An empty list is returned
        when the query produces no embedding.
    """
    query_embedding = get_embedding(user_query)
    # get_embedding reports a blank query with an empty list, not None. Return
    # an empty result so callers that iterate over documents keep working.
    if not query_embedding:
        return []

    pipeline = [
        {
            "$vectorSearch": {
                "index": "vector_index",
                "queryVector": query_embedding,
                "path": "embedding",
                "numCandidates": 100,
                "limit": 3,
            }
        },
        {
            "$project": {
                "_id": 0,
                "fullplot": 1,
                "title": 1,
                "genres": 1,
                "score": {"$meta": "vectorSearchScore"},
            }
        },
    ]
    return list(collection.aggregate(pipeline))
84
 
 
 
85
def get_search_result(query):
    """Format the vector-search hits for ``query`` as a text block.

    Args:
        query: Free-text user query forwarded to ``vector_search``.

    Returns:
        One "Title / Genres / Plot" entry per hit, each followed by a blank
        line, with the plot cut to its first 150 characters. An empty string
        is returned when there are no hits.
    """
    entries = []
    for result in vector_search(query, collection):
        # ``or`` also covers keys that are present but null, which would
        # otherwise break the join and the slice below.
        genres = result.get('genres') or ['N/A']
        plot = result.get('fullplot') or 'N/A'
        entries.append(
            f"Title: {result.get('title', 'N/A')}\n"
            f"Genres: {', '.join(genres)}\n"
            f"Plot: {plot[:150]}...\n\n"
        )
    return "".join(entries)
91
 
92
# Lazy loading of the language model: both globals stay None until the first
# request needs them.
model = None
tokenizer = None


def load_language_model():
    """Fill the module-level ``model`` and ``tokenizer`` with GPT-2 if either is missing."""
    global model, tokenizer
    if model is not None and tokenizer is not None:
        return
    tokenizer = AutoTokenizer.from_pretrained("gpt2")
    model = AutoModelForCausalLM.from_pretrained("gpt2")
101
 
 
 
102
def generate_response(query):
    """Generate a reply to ``query`` from the retrieved movie details.

    The language model is loaded on first use, the search results are placed
    in a prompt, and the model samples a continuation.

    Args:
        query: Free-text user question.

    Returns:
        The decoded model output, or an ``"An error occurred: ..."`` message
        when generation raises.
    """
    load_language_model()
    source_information = get_search_result(query)
    combined_information = f"Answer the question '{query}' based on these movie details:\n\n{source_information}"

    max_new_tokens = 100
    # Truncate the prompt so that prompt + generated tokens still fit.
    # Assumes tokenizer.model_max_length reflects the model's context window.
    max_prompt_length = max(1, tokenizer.model_max_length - max_new_tokens)
    model_inputs = tokenizer(
        combined_information,
        return_tensors="pt",
        max_length=max_prompt_length,
        truncation=True,
    )

    try:
        response = model.generate(
            **model_inputs,
            max_new_tokens=max_new_tokens,
            num_return_sequences=1,
            no_repeat_ngram_size=2,
            top_k=40,
            top_p=0.9,
            temperature=0.7,
            do_sample=True,
        )
        return tokenizer.decode(response[0], skip_special_tokens=True)
    except Exception as e:
        # Surface the failure as text so the web UI shows it.
        return f"An error occurred: {str(e)}"
124
 
 
 
125
def query_movie_db(user_query):
    """Gradio callback: return the reply ``generate_response`` gives for ``user_query``."""
    answer = generate_response(user_query)
    return answer
127
 
 
 
 
 
128
description_and_article = """
Ask this bot to recommend you a movie.
Checkout [my github repo](https://github.com/kanad13/Movie-Recommendation-Bot) to look at the code that powers this bot.
Note that the bot truncates replies due to token limitations in the free tier of Hugging Face resources.
"""

# The input widget and sample prompts are named first so the Interface call
# below is only wiring.
_query_box = gr.Textbox(lines=2, placeholder="Enter your movie query here...")
_example_queries = [["Suggest me a scary movie?"], ["What action movie can I watch?"]]

iface = gr.Interface(
    fn=query_movie_db,
    inputs=_query_box,
    outputs="text",
    title="Movie Recommendation Bot",
    description=description_and_article,
    examples=_example_queries,
)

# Start the web UI only when the file is run as a script.
if __name__ == "__main__":
    iface.launch()