Spaces:
Runtime error
Runtime error
| import pickle | |
| import json | |
| import torch | |
| import pandas as pd | |
| import gradio as gr | |
| import spaces | |
| from mario_gpt import MarioLM | |
# Trained Random Forest model, deserialized from the pickle file that sits next
# to this script. NOTE: unpickling executes arbitrary code, so model.pkl must
# only ever come from a trusted source.
with open("model.pkl", mode="rb") as f:
    rf_model = pickle.loads(f.read())

# YejinGPT level generator loaded from Hugging Face: the language-model weights
# come from YEJIN_MODEL_PATH, the tokenizer from the BASE checkpoint.
BASE = "shyamsn97/Mario-GPT2-700-context-length"
YEJIN_MODEL_PATH = "YejinJ/LLM"
mario_lm = MarioLM(lm_path=YEJIN_MODEL_PATH, tokenizer_path=BASE)
# This function is called by Unity and the Gradio UI
# On ZeroGPU Spaces a GPU is only attached while a @spaces.GPU-decorated
# function runs; the decorator is documented as having no effect elsewhere.
# NOTE(review): assumes this Space targets ZeroGPU (the file imports `spaces`
# but never used it) - confirm against the Space's hardware setting.
@spaces.GPU
def gradio_generate(jump_count, damage_count, death_count, playtime, coin_count, item_count):
    """Turn one set of play statistics into a prompt and a generated level.

    The six statistics are packed into a one-row DataFrame and passed to
    ``rf_model.predict``. The first four values of the first prediction row
    are formatted into the prompt
    ``"<pipes> pipes, <enemies> enemies, <blocks> blocks, <elevation> elevation"``,
    which ``mario_lm.sample`` uses to generate a level.

    Args:
        jump_count: Coerced with ``int()``.
        damage_count: Coerced with ``int()``.
        death_count: Coerced with ``int()``.
        playtime: Coerced with ``float()``.
        coin_count: Coerced with ``int()``.
        item_count: Coerced with ``int()``.

    Returns:
        ``(prompt, level_string)`` on success. If anything raises, the
        exception is logged and ``("ERROR", str(exception))`` is returned
        instead, so callers always receive a two-element result.
    """
    # Local import keeps this block self-contained; logging is stdlib.
    import logging

    try:
        features = pd.DataFrame([{
            "jump_count": int(jump_count),
            "damage_count": int(damage_count),
            "death_count": int(death_count),
            "playtime": float(playtime),
            "coin_count": int(coin_count),
            "item_count": int(item_count),
        }])

        # One input row -> take the first (only) prediction row.
        prediction = rf_model.predict(features)[0]
        pipes, enemies, blocks, elevation = prediction[:4]
        prompt = f"{pipes} pipes, {enemies} enemies, {blocks} blocks, {elevation} elevation"

        # Run on the GPU when one is visible, otherwise fall back to the CPU.
        device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        mario_lm.to(device)

        generated_level = mario_lm.sample(
            prompts=[prompt],
            num_steps=1400,
            temperature=1.0,
            use_tqdm=False,
        )

        # A list of rows is joined into one newline-separated string; any
        # other representation is stringified as-is.
        level_data = generated_level.level
        if isinstance(level_data, list):
            level_string = "\n".join(level_data)
        else:
            level_string = str(level_data)
        return prompt, level_string
    except Exception as e:
        # Top-level API boundary: keep the ("ERROR", message) contract for the
        # caller, but leave a traceback in the server log for debugging.
        logging.getLogger(__name__).exception("Level generation failed")
        return "ERROR", str(e)
# Gradio front end for gradio_generate: six numeric inputs pre-filled with
# sample statistics, and two text outputs (the prompt and the level string).
_DEFAULT_STATS = (
    ("jump_count", 73),
    ("damage_count", 5),
    ("death_count", 2),
    ("playtime", 91),
    ("coin_count", 13),
    ("item_count", 3),
)
_stat_inputs = [
    gr.Number(label=stat_name, value=stat_default)
    for stat_name, stat_default in _DEFAULT_STATS
]
_result_outputs = [
    gr.Textbox(label="Generated Prompt"),
    gr.Textbox(label="Generated Level String", lines=15),
]

demo = gr.Interface(
    fn=gradio_generate,
    inputs=_stat_inputs,
    outputs=_result_outputs,
    title="YejinGPT Level Generator",
    api_name="generate",
)

# Enable request queuing, then serve on all interfaces at port 7860.
demo.queue()
demo.launch(server_name="0.0.0.0", server_port=7860)