Spaces:
Sleeping
Sleeping
Update app.py
Browse files
app.py
CHANGED
|
@@ -1,189 +1,144 @@
|
|
| 1 |
-
|
| 2 |
-
import
|
| 3 |
-
|
| 4 |
-
from
|
| 5 |
-
from
|
| 6 |
-
|
| 7 |
-
import
|
| 8 |
-
|
| 9 |
-
|
| 10 |
-
|
| 11 |
-
# The dataset contains movie information and is split to use 80% for training.
|
| 12 |
-
dataset = load_dataset("MongoDB/embedded_movies", split='train[:80%]') # AIatMongoDB/embedded_movies
|
| 13 |
dataset_df = pd.DataFrame(dataset)
|
| 14 |
|
| 15 |
-
#
|
| 16 |
-
# It's crucial to ensure that every movie entry has a complete plot description.
|
| 17 |
-
# The 'fullplot' column is necessary for generating embeddings, which are numerical representations of the text.
|
| 18 |
-
# If this data is missing, the embeddings will be incomplete or nonsensical, reducing the accuracy of recommendations.
|
| 19 |
dataset_df = dataset_df.dropna(subset=["fullplot"])
|
| 20 |
-
|
| 21 |
-
# Drop the 'plot_embedding' column as we will generate new embeddings
|
| 22 |
-
# We drop the existing 'plot_embedding' column to create new embeddings using a different, potentially more effective model.
|
| 23 |
-
# This step ensures consistency and accuracy in the embeddings used for similarity searches.
|
| 24 |
dataset_df = dataset_df.drop(columns=["plot_embedding"])
|
| 25 |
|
| 26 |
-
# Load a
|
| 27 |
-
|
| 28 |
-
# These embeddings capture the semantic content of the plots, allowing us to perform efficient and meaningful similarity searches.
|
| 29 |
-
embedding_model = SentenceTransformer("thenlper/gte-large")
|
| 30 |
|
| 31 |
-
# Define a function to generate embeddings for a given text
|
| 32 |
-
# Embeddings are numerical representations of text that capture its semantic meaning.
|
| 33 |
-
# This function checks if the text is not empty and then generates an embedding using the loaded model.
|
| 34 |
def get_embedding(text: str) -> list:
|
| 35 |
-
if not text.strip():
|
| 36 |
-
# If the text is empty, return an empty list as it does not make sense to generate embeddings for empty text.
|
| 37 |
-
# This ensures that we avoid errors and meaningless embeddings.
|
| 38 |
print("Attempted to get embedding for empty text.")
|
| 39 |
return []
|
| 40 |
-
embedding = embedding_model.encode(text)
|
| 41 |
-
return embedding.tolist()
|
| 42 |
|
| 43 |
-
#
|
| 44 |
-
|
| 45 |
-
|
|
|
|
|
|
|
|
|
|
| 46 |
|
| 47 |
-
# Function to connect to MongoDB
|
| 48 |
-
# MongoDB is a NoSQL database used to store and retrieve large datasets efficiently.
|
| 49 |
-
# This function attempts to create a MongoDB client to connect to the database.
|
| 50 |
def get_mongo_client(mongo_uri):
|
| 51 |
try:
|
| 52 |
-
client = pymongo.MongoClient(mongo_uri)
|
| 53 |
print("Connection to MongoDB successful")
|
| 54 |
return client
|
| 55 |
except pymongo.errors.ConnectionFailure as e:
|
| 56 |
-
# Handle potential connection failures to provide feedback in case of issues with the MongoDB URI or network problems.
|
| 57 |
print(f"Connection failed: {e}")
|
| 58 |
return None
|
| 59 |
|
| 60 |
-
# Get the MongoDB URI from environment variables
|
| 61 |
-
# The MongoDB URI is required to connect to the database. It should be stored securely in environment variables to protect sensitive information.
|
| 62 |
mongo_uri = os.getenv("MONGO_URI")
|
| 63 |
if not mongo_uri:
|
| 64 |
print("MONGO_URI not set in environment variables")
|
| 65 |
|
| 66 |
-
# Connect to MongoDB using the URI
|
| 67 |
-
# The client connects to the 'movies' database and accesses the 'movie_collection_2' collection.
|
| 68 |
-
# This collection will store the movie data with their respective embeddings.
|
| 69 |
mongo_client = get_mongo_client(mongo_uri)
|
| 70 |
-
db = mongo_client["movies"]
|
| 71 |
-
collection = db["movie_collection_2"]
|
| 72 |
-
|
| 73 |
-
# Clear the collection and insert
|
| 74 |
-
|
| 75 |
-
|
| 76 |
-
collection.
|
| 77 |
-
documents = dataset_df.to_dict("records") # Convert DataFrame to list of dictionaries
|
| 78 |
-
collection.insert_many(documents) # Insert documents into the collection
|
| 79 |
print("Data ingestion into MongoDB completed")
|
| 80 |
|
| 81 |
-
# Function to perform a vector search on the user query
|
| 82 |
-
# This function generates an embedding for the user's query and uses it to search for similar movie plots in the MongoDB collection.
|
| 83 |
-
# Vector search allows us to find movies with plots that are semantically similar to the query.
|
| 84 |
def vector_search(user_query, collection):
|
| 85 |
-
query_embedding = get_embedding(user_query)
|
| 86 |
if query_embedding is None:
|
| 87 |
-
# Return an error message if the embedding generation fails, ensuring graceful handling of invalid queries.
|
| 88 |
return "Invalid query or embedding generation failed."
|
| 89 |
|
| 90 |
-
# Define the MongoDB aggregation pipeline for vector search
|
| 91 |
-
# This pipeline uses the generated query embedding to search for similar embeddings in the collection.
|
| 92 |
pipeline = [
|
| 93 |
{
|
| 94 |
"$vectorSearch": {
|
| 95 |
-
"index": "vector_index",
|
| 96 |
-
"queryVector": query_embedding,
|
| 97 |
-
"path": "embedding",
|
| 98 |
-
"numCandidates":
|
| 99 |
-
"limit":
|
| 100 |
}
|
| 101 |
},
|
| 102 |
{
|
| 103 |
"$project": {
|
| 104 |
-
"_id": 0,
|
| 105 |
-
"fullplot": 1,
|
| 106 |
-
"title": 1,
|
| 107 |
-
"genres": 1,
|
| 108 |
-
"score": {"$meta": "vectorSearchScore"},
|
| 109 |
}
|
| 110 |
},
|
| 111 |
]
|
| 112 |
-
results = collection.aggregate(pipeline)
|
| 113 |
-
return list(results)
|
| 114 |
|
| 115 |
-
# Function to format search results
|
| 116 |
-
# This function formats the search results into a user-friendly format, making it easier for users to read and understand the recommendations.
|
| 117 |
def get_search_result(query):
|
| 118 |
-
get_knowledge = vector_search(query, collection)
|
| 119 |
search_result = ""
|
| 120 |
-
for result in get_knowledge:
|
| 121 |
-
|
| 122 |
-
search_result += f"Title: {result.get('title', 'N/A')}\nGenres: {', '.join(result.get('genres', ['N/A']))}\nPlot: {result.get('fullplot', 'N/A')[:200]}...\n\n"
|
| 123 |
return search_result
|
| 124 |
|
| 125 |
-
#
|
| 126 |
-
|
| 127 |
-
|
| 128 |
-
|
| 129 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
| 130 |
|
| 131 |
-
# Function to generate a response based on the user's query
|
| 132 |
-
# This function combines the search results with the user's query and generates a response using the GPT-2 model.
|
| 133 |
def generate_response(query):
|
| 134 |
-
|
| 135 |
-
|
| 136 |
-
|
| 137 |
-
|
| 138 |
-
|
| 139 |
-
# Prepare input for the language model
|
| 140 |
-
# Ensures the input does not exceed the model's maximum token capacity.
|
| 141 |
-
max_length = tokenizer.model_max_length # Get the maximum token length to ensure input does not exceed model capacity
|
| 142 |
input_ids = tokenizer(combined_information, return_tensors="pt", max_length=max_length, truncation=True)
|
| 143 |
|
| 144 |
try:
|
| 145 |
response = model.generate(
|
| 146 |
**input_ids,
|
| 147 |
-
max_new_tokens=
|
| 148 |
-
num_return_sequences=1,
|
| 149 |
-
no_repeat_ngram_size=2,
|
| 150 |
-
top_k=
|
| 151 |
-
top_p=0.
|
| 152 |
-
temperature=0.7,
|
| 153 |
-
do_sample=True
|
| 154 |
)
|
| 155 |
-
return tokenizer.decode(response[0], skip_special_tokens=True)
|
| 156 |
except Exception as e:
|
| 157 |
-
# Handle potential errors during generation and provide a meaningful error message.
|
| 158 |
return f"An error occurred: {str(e)}"
|
| 159 |
|
| 160 |
-
# Function to handle user queries and generate responses
|
| 161 |
-
# This function ties together the query handling and response generation processes.
|
| 162 |
def query_movie_db(user_query):
|
| 163 |
return generate_response(user_query)
|
| 164 |
|
| 165 |
-
# Create the Gradio interface
|
| 166 |
-
# Gradio provides a simple interface to interact with the model, allowing users to enter queries and receive responses.
|
| 167 |
-
import gradio as gr
|
| 168 |
-
|
| 169 |
description_and_article = """
|
| 170 |
Ask this bot to recommend you a movie.
|
| 171 |
Checkout [my github repo](https://github.com/kanad13/Movie-Recommendation-Bot) to look at the code that powers this bot.
|
| 172 |
-
|
| 173 |
Note that the bot truncates replies due to token limitations in the free tier of Hugging Face resources.
|
| 174 |
-
This is not a coding issue but a result of operating within the token limitations of the free tier of Hugging Face resources.
|
| 175 |
-
To enhance response quality, better models and more resources could be used, but these come with higher costs, which I want to avoid as this is a hobby project.
|
| 176 |
"""
|
| 177 |
|
| 178 |
iface = gr.Interface(
|
| 179 |
-
fn=query_movie_db,
|
| 180 |
-
inputs=gr.Textbox(lines=2, placeholder="Enter your movie query here..."),
|
| 181 |
-
outputs="text",
|
| 182 |
-
title="Movie Recommendation Bot",
|
| 183 |
-
description=description_and_article,
|
| 184 |
-
examples=[["Suggest me a scary movie?"], ["What action movie can I watch?"]]
|
| 185 |
)
|
| 186 |
|
| 187 |
-
# Launch the interface
|
| 188 |
if __name__ == "__main__":
|
| 189 |
-
iface.launch()
|
|
|
|
| 1 |
+
import gradio as gr
|
| 2 |
+
import pandas as pd
|
| 3 |
+
from datasets import load_dataset
|
| 4 |
+
from sentence_transformers import SentenceTransformer
|
| 5 |
+
from transformers import AutoTokenizer, AutoModelForCausalLM
|
| 6 |
+
import pymongo
|
| 7 |
+
import os
|
| 8 |
+
|
| 9 |
+
# Pull the first 70% of the train split of the embedded-movies dataset.
dataset = load_dataset("MongoDB/embedded_movies", split='train[:70%]')

# Keep only the rows that have a "fullplot" value and remove the
# "plot_embedding" column; embeddings are produced later with the model
# loaded below.
dataset_df = (
    pd.DataFrame(dataset)
    .dropna(subset=["fullplot"])
    .drop(columns=["plot_embedding"])
)

# Sentence-embedding model used by get_embedding().
embedding_model = SentenceTransformer("all-MiniLM-L6-v2")
|
|
|
|
|
|
|
| 19 |
|
|
|
|
|
|
|
|
|
|
| 20 |
def get_embedding(text: str) -> list:
    """Return the embedding of ``text`` as a plain Python list.

    Args:
        text: The text to embed.

    Returns:
        The result of ``embedding_model.encode(text)`` converted with
        ``tolist()``, or an empty list when ``text`` is not a string or
        contains only whitespace.
    """
    # Guard against None / NaN / other non-string values as well as blank
    # strings: calling .strip() on a non-string would raise, and there is
    # nothing meaningful to embed in either case.
    if not isinstance(text, str) or not text.strip():
        print("Attempted to get embedding for empty text.")
        return []
    return embedding_model.encode(text).tolist()
|
| 26 |
|
| 27 |
+
# Compute an embedding for every plot, walking the frame in fixed-size batches.
# The vectors are collected in a list and attached as one new column at the
# end: adding a column to an ``.iloc`` slice and writing that slice back
# through ``.iloc`` does not create the column on ``dataset_df``, which would
# leave the stored documents without an "embedding" field.
batch_size = 100
plot_embeddings = []
for start in range(0, len(dataset_df), batch_size):
    plots = dataset_df["fullplot"].iloc[start:start + batch_size]
    plot_embeddings.extend(get_embedding(plot) for plot in plots)
# One entry per row, in row order, so a positional list assignment lines up.
dataset_df["embedding"] = plot_embeddings
|
| 33 |
|
|
|
|
|
|
|
|
|
|
| 34 |
def get_mongo_client(mongo_uri):
    """Create a MongoDB client and confirm that the server is reachable.

    Args:
        mongo_uri: Connection string passed to ``pymongo.MongoClient``.

    Returns:
        The ``pymongo.MongoClient`` once a ping has succeeded, or ``None``
        when a ``ConnectionFailure`` is raised.
    """
    try:
        client = pymongo.MongoClient(mongo_uri)
        # MongoClient connects lazily, so constructing it succeeds even when
        # the server is unreachable. A cheap "ping" forces a round trip so a
        # failure is detected here rather than on the first real operation.
        client.admin.command("ping")
    except pymongo.errors.ConnectionFailure as e:
        print(f"Connection failed: {e}")
        return None
    print("Connection to MongoDB successful")
    return client
|
| 42 |
|
|
|
|
|
|
|
| 43 |
# Read the connection string from the environment.
mongo_uri = os.getenv("MONGO_URI")
if not mongo_uri:
    print("MONGO_URI not set in environment variables")

mongo_client = get_mongo_client(mongo_uri)
if mongo_client is None:
    # get_mongo_client returns None on connection failure; stop with a clear
    # message instead of failing on the subscript below with a TypeError.
    raise RuntimeError(
        "Could not connect to MongoDB; check MONGO_URI and network access"
    )
db = mongo_client["movies"]
collection = db["movie_collection_2"]

# Clear the collection and insert new data in bulk
collection.delete_many({})
documents = dataset_df.to_dict("records")
if documents:
    # insert_many rejects an empty list, so skip it when there is nothing
    # to load.
    collection.insert_many(documents)
print("Data ingestion into MongoDB completed")
|
| 56 |
|
|
|
|
|
|
|
|
|
|
| 57 |
def vector_search(user_query, collection, num_candidates=100, limit=3):
    """Run a ``$vectorSearch`` aggregation for ``user_query``.

    Args:
        user_query: Free-text query to embed with ``get_embedding``.
        collection: Collection object whose ``aggregate`` method runs the
            pipeline.
        num_candidates: Value passed as ``numCandidates`` to ``$vectorSearch``.
        limit: Value passed as ``limit`` to ``$vectorSearch``.

    Returns:
        A list of result documents projected to ``fullplot``, ``title``,
        ``genres`` and ``score``, or the string
        ``"Invalid query or embedding generation failed."`` when no embedding
        could be produced for the query.
    """
    query_embedding = get_embedding(user_query)
    # get_embedding reports "nothing to embed" with an empty list rather than
    # None, so test for emptiness; this also covers None.
    if not query_embedding:
        return "Invalid query or embedding generation failed."

    pipeline = [
        {
            "$vectorSearch": {
                "index": "vector_index",
                "queryVector": query_embedding,
                "path": "embedding",
                "numCandidates": num_candidates,
                "limit": limit,
            }
        },
        {
            "$project": {
                "_id": 0,
                "fullplot": 1,
                "title": 1,
                "genres": 1,
                "score": {"$meta": "vectorSearchScore"},
            }
        },
    ]
    return list(collection.aggregate(pipeline))
|
| 84 |
|
|
|
|
|
|
|
| 85 |
def get_search_result(query):
    """Format the vector-search hits for ``query`` as one text block.

    Args:
        query: Free-text query forwarded to ``vector_search``.

    Returns:
        One "Title / Genres / Plot" paragraph per hit (plot cut to 150
        characters), concatenated. An empty string when there are no hits or
        when ``vector_search`` returned its error message instead of a list.
    """
    get_knowledge = vector_search(query, collection)
    if isinstance(get_knowledge, str):
        # vector_search returns a message string when the query cannot be
        # embedded; iterating it would yield characters, not documents.
        return ""

    entries = []
    for result in get_knowledge:
        title = result.get('title', 'N/A')
        # Fall back to 'N/A' when the field is missing, empty, or not a list,
        # so the join below cannot fail.
        genres = result.get('genres')
        if not isinstance(genres, (list, tuple)) or not genres:
            genres = ['N/A']
        plot = result.get('fullplot')
        if not isinstance(plot, str):
            plot = 'N/A'
        entries.append(
            f"Title: {title}\nGenres: {', '.join(map(str, genres))}\nPlot: {plot[:150]}...\n\n"
        )
    return "".join(entries)
|
| 91 |
|
| 92 |
+
# The language model and its tokenizer are created on first use rather than
# at import time.
model = None
tokenizer = None


def load_language_model():
    """Populate the module-level ``model`` and ``tokenizer`` if either is unset.

    Both are loaded from the "gpt2" checkpoint. A call made once both globals
    are set does nothing.
    """
    global model, tokenizer
    if model is not None and tokenizer is not None:
        return
    tokenizer = AutoTokenizer.from_pretrained("gpt2")
    model = AutoModelForCausalLM.from_pretrained("gpt2")
|
| 101 |
|
|
|
|
|
|
|
| 102 |
def generate_response(query):
    """Generate a text answer for ``query`` from the retrieved movie details.

    Args:
        query: The user's question.

    Returns:
        The decoded output of ``model.generate`` for a prompt built from the
        query and ``get_search_result(query)``, or an
        ``"An error occurred: ..."`` string if generation raises.
    """
    load_language_model()
    source_information = get_search_result(query)
    combined_information = f"Answer the question '{query}' based on these movie details:\n\n{source_information}"

    max_new_tokens = 100
    # The prompt and the newly generated tokens share the model's context
    # window, so reserve space for the generated part when truncating the
    # prompt instead of letting the prompt fill the whole window.
    max_length = max(1, tokenizer.model_max_length - max_new_tokens)
    input_ids = tokenizer(combined_information, return_tensors="pt", max_length=max_length, truncation=True)

    try:
        response = model.generate(
            **input_ids,
            max_new_tokens=max_new_tokens,
            num_return_sequences=1,
            no_repeat_ngram_size=2,
            top_k=40,
            top_p=0.9,
            temperature=0.7,
            do_sample=True,
            # The GPT-2 tokenizer defines no pad token; name one explicitly
            # so generate() does not have to pick a fallback and warn.
            pad_token_id=tokenizer.eos_token_id,
        )
        # NOTE(review): response[0] is decoded in full, which for a causal LM
        # presumably includes the prompt tokens as well as the continuation -
        # confirm whether the prompt echo is wanted in the UI.
        return tokenizer.decode(response[0], skip_special_tokens=True)
    except Exception as e:
        return f"An error occurred: {str(e)}"
|
| 124 |
|
|
|
|
|
|
|
| 125 |
def query_movie_db(user_query):
    """Return the generated answer for ``user_query``.

    Thin wrapper around ``generate_response``.
    """
    answer = generate_response(user_query)
    return answer
|
| 127 |
|
|
|
|
|
|
|
|
|
|
|
|
|
| 128 |
description_and_article = """
Ask this bot to recommend you a movie.
Checkout [my github repo](https://github.com/kanad13/Movie-Recommendation-Bot) to look at the code that powers this bot.
Note that the bot truncates replies due to token limitations in the free tier of Hugging Face resources.
"""

# Input widget and sample queries shown in the UI.
query_box = gr.Textbox(lines=2, placeholder="Enter your movie query here...")
example_queries = [["Suggest me a scary movie?"], ["What action movie can I watch?"]]

# Wire the query handler to a text-in / text-out Gradio interface.
iface = gr.Interface(
    fn=query_movie_db,
    inputs=query_box,
    outputs="text",
    title="Movie Recommendation Bot",
    description=description_and_article,
    examples=example_queries,
)

# Start the web UI only when the file is executed as a script.
if __name__ == "__main__":
    iface.launch()
|