# Source: Hugging Face file viewer page for app.py
# (kanad13; commit 35e476e "Update app.py", verified; 4.68 kB)
import gradio as gr
import pandas as pd
from datasets import load_dataset
from sentence_transformers import SentenceTransformer
from transformers import AutoTokenizer, AutoModelForCausalLM
import pymongo
import os
# Pull only the first 70% of the training split to keep the data set small.
dataset = load_dataset("MongoDB/embedded_movies", split='train[:70%]')
# Build the frame, discard rows that have no "fullplot" value, and remove the
# existing "plot_embedding" column, all in one chained expression.
dataset_df = (
    pd.DataFrame(dataset)
    .dropna(subset=["fullplot"])
    .drop(columns=["plot_embedding"])
)
# Sentence-embedding model used for both the stored plots and user queries.
embedding_model = SentenceTransformer("all-MiniLM-L6-v2")
def get_embedding(text: str) -> list:
    """Encode *text* with the module-level embedding model.

    Returns the embedding as a plain Python list. Text that is empty or
    whitespace-only is not encoded: a message is printed and an empty list
    is returned instead.
    """
    if text.strip():
        vector = embedding_model.encode(text)
        return vector.tolist()
    print("Attempted to get embedding for empty text.")
    return []
# Compute one embedding per plot, walking the frame in fixed-size batches.
# The vectors are gathered into a plain list and assigned as a new column in
# a single step: adding the column to an ``iloc`` slice and writing the slice
# back does not reliably add the column to ``dataset_df`` itself.
batch_size = 100
embeddings = []
for start in range(0, len(dataset_df), batch_size):
    plots = dataset_df["fullplot"].iloc[start:start + batch_size]
    embeddings.extend(plots.apply(get_embedding))
# A list is assigned positionally, so the non-contiguous index left behind by
# the earlier row filtering does not matter here.
dataset_df["embedding"] = embeddings
def get_mongo_client(mongo_uri):
    """Create a MongoDB client for *mongo_uri* and verify it can reach a server.

    Args:
        mongo_uri: MongoDB connection string.

    Returns:
        The connected ``pymongo.MongoClient``, or ``None`` if the server
        could not be reached (the failure is printed).
    """
    try:
        client = pymongo.MongoClient(mongo_uri)
        # Constructing MongoClient does not block on the network, so issue a
        # cheap command to surface connection problems here instead of at the
        # first real query.
        client.admin.command("ping")
    except pymongo.errors.ConnectionFailure as e:
        print(f"Connection failed: {e}")
        return None
    print("Connection to MongoDB successful")
    return client
# Connection settings come from the environment; stop immediately when they
# are missing instead of continuing with an unusable client.
mongo_uri = os.getenv("MONGO_URI")
if not mongo_uri:
    raise RuntimeError("MONGO_URI not set in environment variables")
mongo_client = get_mongo_client(mongo_uri)
if mongo_client is None:
    # get_mongo_client has already printed the underlying failure.
    raise RuntimeError("Could not connect to MongoDB")
db = mongo_client["movies"]
collection = db["movie_collection_2"]
# Clear the collection and insert new data in bulk
collection.delete_many({})
documents = dataset_df.to_dict("records")
# insert_many rejects an empty document list, so skip it when there is
# nothing to insert.
if documents:
    collection.insert_many(documents)
print("Data ingestion into MongoDB completed")
def vector_search(user_query, collection):
    """Return the stored movies whose embedding is closest to *user_query*.

    Args:
        user_query: Free-text query to embed and search with.
        collection: MongoDB collection to run the aggregation on.

    Returns:
        A list of up to three result documents containing ``fullplot``,
        ``title``, ``genres`` and the vector-search ``score``. The list is
        empty when no embedding could be produced for the query (for
        example, a blank query).
    """
    query_embedding = get_embedding(user_query)
    # get_embedding signals "nothing to embed" with an empty list, so test
    # for emptiness rather than None. Returning an empty result list keeps
    # the return type uniform for callers that iterate over the documents.
    if not query_embedding:
        return []
    pipeline = [
        {
            "$vectorSearch": {
                "index": "vector_index",
                "queryVector": query_embedding,
                "path": "embedding",
                "numCandidates": 100,
                "limit": 3,
            }
        },
        {
            "$project": {
                "_id": 0,
                "fullplot": 1,
                "title": 1,
                "genres": 1,
                "score": {"$meta": "vectorSearchScore"},
            }
        },
    ]
    results = collection.aggregate(pipeline)
    return list(results)
def get_search_result(query):
    """Format the vector-search hits for *query* as a text block.

    Each hit contributes a "Title / Genres / Plot" section, with the plot cut
    to its first 150 characters. Returns an empty string when there are no
    hits.
    """
    sections = []
    for result in vector_search(query, collection):
        title = result.get("title", "N/A")
        # A field can be present but null (or empty) in a stored document;
        # fall back to the placeholder so join/slicing below cannot fail.
        genres = result.get("genres") or ["N/A"]
        plot = result.get("fullplot") or "N/A"
        sections.append(
            f"Title: {title}\nGenres: {', '.join(genres)}\nPlot: {plot[:150]}...\n\n"
        )
    return "".join(sections)
# The language model and its tokenizer are created on first use; both stay
# None until load_language_model() has run.
model = None
tokenizer = None


def load_language_model():
    """Populate the module-level ``model`` and ``tokenizer`` with GPT-2.

    Does nothing when both are already set.
    """
    global model, tokenizer
    if model is not None and tokenizer is not None:
        return
    tokenizer = AutoTokenizer.from_pretrained("gpt2")
    model = AutoModelForCausalLM.from_pretrained("gpt2")
def generate_response(query):
    """Generate a text answer to *query* grounded in the retrieved movies.

    The search results for the query are embedded in a prompt, which is fed
    to the lazily loaded language model.

    Returns:
        The decoded model output (prompt followed by the generated
        continuation), or an "An error occurred: ..." message if generation
        raises.
    """
    load_language_model()
    source_information = get_search_result(query)
    combined_information = f"Answer the question '{query}' based on these movie details:\n\n{source_information}"
    max_new_tokens = 100
    # The prompt and the generated tokens share one context budget, so leave
    # room for the new tokens when truncating the prompt. Truncating to the
    # full model_max_length would leave no space for generation.
    max_prompt_tokens = max(1, tokenizer.model_max_length - max_new_tokens)
    inputs = tokenizer(
        combined_information,
        return_tensors="pt",
        max_length=max_prompt_tokens,
        truncation=True,
    )
    try:
        response = model.generate(
            **inputs,
            max_new_tokens=max_new_tokens,
            num_return_sequences=1,
            no_repeat_ngram_size=2,
            top_k=40,
            top_p=0.9,
            temperature=0.7,
            do_sample=True,
            # Pass the padding id explicitly, using the end-of-sequence id,
            # rather than relying on generate() to pick a default.
            pad_token_id=tokenizer.eos_token_id,
        )
        return tokenizer.decode(response[0], skip_special_tokens=True)
    except Exception as e:
        # UI boundary: report the failure as text instead of raising.
        return f"An error occurred: {str(e)}"
def query_movie_db(user_query):
    """Interface callback: return the generated answer for *user_query*."""
    answer = generate_response(user_query)
    return answer
# Text shown under the title of the web interface.
description_and_article = """
Ask this bot to recommend you a movie.
Checkout [my github repo](https://github.com/kanad13/Movie-Recommendation-Bot) to look at the code that powers this bot.
Note that the bot truncates replies due to token limitations in the free tier of Hugging Face resources.
"""

# Input widget and clickable sample queries for the interface.
_query_box = gr.Textbox(lines=2, placeholder="Enter your movie query here...")
_example_queries = [["Suggest me a scary movie?"], ["What action movie can I watch?"]]

# Single-function interface: text in, text out, answered by query_movie_db.
iface = gr.Interface(
    fn=query_movie_db,
    inputs=_query_box,
    outputs="text",
    title="Movie Recommendation Bot",
    description=description_and_article,
    examples=_example_queries,
)

# Start the web app only when executed as a script.
if __name__ == "__main__":
    iface.launch()