"""Gradio app that turns player statistics into a generated Mario level.

A pre-trained model loaded from ``model.pkl`` maps six play-session
statistics to four level attributes. Those attributes are formatted into a
text prompt and fed to a fine-tuned MarioGPT model, which samples a level.
The handler is exposed through a Gradio Interface (API name ``generate``)
and, per the original comment, is called by a Unity client as well as the UI.
"""

import json
import logging
import pickle

# NOTE: third-party import order kept as in the original file; on ZeroGPU
# Spaces the ``spaces`` package is sensitive to CUDA-initialisation order.
import torch
import pandas as pd
import gradio as gr
import spaces
from mario_gpt import MarioLM

logger = logging.getLogger(__name__)

# Load trained Random Forest model.
# NOTE(security): pickle.load can execute arbitrary code stored in the file;
# model.pkl must come from a trusted source and never from user input.
with open("model.pkl", "rb") as f:
    rf_model = pickle.load(f)

# Load YejinGPT model from Hugging Face: language-model weights come from
# YEJIN_MODEL_PATH, the tokenizer from the BASE MarioGPT checkpoint.
BASE = "shyamsn97/Mario-GPT2-700-context-length"
YEJIN_MODEL_PATH = "YejinJ/LLM"
mario_lm = MarioLM(
    lm_path=YEJIN_MODEL_PATH,
    tokenizer_path=BASE,
)

# Sampling settings passed to MarioLM.sample.
NUM_STEPS = 1400
TEMPERATURE = 1.0


def _build_features(jump_count, damage_count, death_count, playtime, coin_count, item_count):
    """Return a one-row DataFrame holding the six model input features."""
    # Column names/order must match what rf_model was trained on
    # (assumed from the original code; verify if model.pkl is retrained).
    return pd.DataFrame([{
        "jump_count": int(jump_count),
        "damage_count": int(damage_count),
        "death_count": int(death_count),
        "playtime": float(playtime),
        "coin_count": int(coin_count),
        "item_count": int(item_count),
    }])


def _build_prompt(prediction):
    """Format the first four predicted values as a MarioGPT text prompt.

    The values are interpolated as-is into
    "<pipes> pipes, <enemies> enemies, <blocks> blocks, <elevation> elevation".
    """
    pipes = prediction[0]
    enemies = prediction[1]
    blocks = prediction[2]
    elevation = prediction[3]
    return f"{pipes} pipes, {enemies} enemies, {blocks} blocks, {elevation} elevation"


def _level_to_string(level_data):
    """Join a list of level rows with newlines; stringify anything else."""
    if isinstance(level_data, list):
        return "\n".join(level_data)
    return str(level_data)


# This function is called by Unity and the Gradio UI
@spaces.GPU
def gradio_generate(jump_count, damage_count, death_count, playtime, coin_count, item_count):
    """Predict level attributes from player stats and sample a level.

    Args:
        jump_count, damage_count, death_count, coin_count, item_count:
            Player statistics; each is converted with ``int()``.
        playtime: Play time; converted with ``float()``.

    Returns:
        A ``(prompt, level_string)`` tuple. On any failure, returns
        ``("ERROR", <exception message>)`` instead of raising, so clients
        always receive two strings; the traceback is logged.
    """
    try:
        features = _build_features(
            jump_count, damage_count, death_count, playtime, coin_count, item_count
        )
        # predict() returns one row per input row; take the single row.
        prediction = rf_model.predict(features)[0]
        prompt = _build_prompt(prediction)

        device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        mario_lm.to(device)

        generated_level = mario_lm.sample(
            prompts=[prompt],
            num_steps=NUM_STEPS,
            temperature=TEMPERATURE,
            use_tqdm=False,
        )
        return prompt, _level_to_string(generated_level.level)
    except Exception as e:
        # Preserve the ("ERROR", message) contract the clients rely on, but
        # record the full traceback so server-side failures are diagnosable.
        logger.exception("Level generation failed")
        return "ERROR", str(e)


demo = gr.Interface(
    fn=gradio_generate,
    inputs=[
        gr.Number(label="jump_count", value=73),
        gr.Number(label="damage_count", value=5),
        gr.Number(label="death_count", value=2),
        gr.Number(label="playtime", value=91),
        gr.Number(label="coin_count", value=13),
        gr.Number(label="item_count", value=3),
    ],
    outputs=[
        gr.Textbox(label="Generated Prompt"),
        gr.Textbox(label="Generated Level String", lines=15),
    ],
    title="YejinGPT Level Generator",
    api_name="generate",
)

demo.queue()
demo.launch(server_name="0.0.0.0", server_port=7860)