YejinGPT-space / app.py
YejinJ's picture
Update app.py
d01c9ab verified
Raw
History Blame Contribute Delete
2.34 kB
import json
import logging
import pickle

import gradio as gr
import pandas as pd
import spaces
import torch
from mario_gpt import MarioLM
# Random Forest that maps player statistics to level features; its
# predictions are turned into the text prompt inside gradio_generate.
# NOTE: pickle can execute arbitrary code, so model.pkl must be a trusted,
# bundled artifact.
with open("model.pkl", mode="rb") as model_file:
    rf_model = pickle.load(model_file)
# YejinGPT: fine-tuned language-model weights come from the Hugging Face repo
# below, while the tokenizer is taken from the base MarioGPT checkpoint.
BASE = "shyamsn97/Mario-GPT2-700-context-length"
YEJIN_MODEL_PATH = "YejinJ/LLM"
mario_lm = MarioLM(tokenizer_path=BASE, lm_path=YEJIN_MODEL_PATH)
logger = logging.getLogger(__name__)


# This function is called by Unity and the Gradio UI
@spaces.GPU
def gradio_generate(jump_count, damage_count, death_count, playtime, coin_count, item_count):
    """Predict level features from player stats and sample a Mario level.

    The six player statistics are packed into a one-row DataFrame and passed
    to ``rf_model``. The first four values of its first prediction row are
    formatted into a text prompt, and ``mario_lm`` samples a level from that
    prompt.

    Args:
        jump_count: Jump count; converted with ``int()``.
        damage_count: Damage count; converted with ``int()``.
        death_count: Death count; converted with ``int()``.
        playtime: Play time; converted with ``float()`` (units not
            established here).
        coin_count: Coin count; converted with ``int()``.
        item_count: Item count; converted with ``int()``.

    Returns:
        A ``(prompt, level_string)`` tuple of strings. On any failure this
        function does not raise. It returns ``("ERROR", message)`` so callers
        such as the Unity client always get two strings, and it logs the full
        traceback.
    """
    try:
        # Column names must match the features the Random Forest was trained on.
        features = pd.DataFrame([{
            "jump_count": int(jump_count),
            "damage_count": int(damage_count),
            "death_count": int(death_count),
            "playtime": float(playtime),
            "coin_count": int(coin_count),
            "item_count": int(item_count),
        }])

        # NOTE(review): assumes rf_model is multi-output and that each
        # prediction row holds (pipes, enemies, blocks, elevation) in this
        # order. Confirm against the training script.
        prediction = rf_model.predict(features)[0]
        pipes = prediction[0]
        enemies = prediction[1]
        blocks = prediction[2]
        elevation = prediction[3]
        prompt = f"{pipes} pipes, {enemies} enemies, {blocks} blocks, {elevation} elevation"

        device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        mario_lm.to(device)
        generated_level = mario_lm.sample(
            prompts=[prompt],
            num_steps=1400,
            temperature=1.0,
            use_tqdm=False,
        )

        # The sampled level may be a list of row strings; join them into one
        # newline-separated grid. Otherwise fall back to its string form.
        level_data = generated_level.level
        if isinstance(level_data, list):
            level_string = "\n".join(level_data)
        else:
            level_string = str(level_data)
        return prompt, level_string
    except Exception as e:
        # Top-level boundary: keep the ("ERROR", message) reply that clients
        # parse, but record the traceback instead of silently discarding it.
        logger.exception("Level generation failed")
        return "ERROR", str(e)
# Gradio front end for gradio_generate. The endpoint is also exposed over the
# API under the name "generate" (presumably the route the Unity client calls).
demo = gr.Interface(
    fn=gradio_generate,
    # One numeric input per player statistic, with its default value, listed
    # in the same order as gradio_generate's parameters.
    inputs=[
        gr.Number(label=field, value=default)
        for field, default in (
            ("jump_count", 73),
            ("damage_count", 5),
            ("death_count", 2),
            ("playtime", 91),
            ("coin_count", 13),
            ("item_count", 3),
        )
    ],
    outputs=[
        gr.Textbox(label="Generated Prompt"),
        gr.Textbox(label="Generated Level String", lines=15),
    ],
    title="YejinGPT Level Generator",
    api_name="generate",
)

# Enable Gradio's request queue, then serve on all interfaces, port 7860.
demo.queue()
demo.launch(server_name="0.0.0.0", server_port=7860)