"""Gradio front-end for the VTuber tag generator, targeting Hugging Face ZeroGPU."""

import contextlib
import io
import re
import sys

import gradio as gr
# NOTE(review): Hugging Face's ZeroGPU guidance is to import `spaces` before
# any CUDA-touching package; test_lora presumably pulls in torch - confirm.
import spaces

from test_lora import DanbooruTagTester

# Matches "(" or ")" that is not already preceded by a backslash, so escaping
# text that was escaped before (e.g. pasted back from a previous run) is a no-op.
_UNESCAPED_PAREN = re.compile(r"(?<!\\)([()])")


def _build_final_tags(prompt: str, generated_tags: str) -> str:
    """Merge the prompt and the raw completion into one clean, escaped tag list.

    Args:
        prompt: The user's input prompt (comma-separated tags).
        generated_tags: Raw text returned by the tag generator.

    Returns:
        A single ", "-joined string of non-empty, whitespace-stripped tags in
        which every parenthesis is backslash-escaped exactly once.
    """
    # Drop the "completion:" marker and treat newlines as tag separators.
    processed = generated_tags.replace("completion:", "").replace("\n", ", ")
    combined = f"{prompt} {processed}"
    # Splitting on commas can leave blank fragments (e.g. ", ,"); discard them.
    stripped = (fragment.strip() for fragment in combined.split(","))
    return ", ".join(
        _UNESCAPED_PAREN.sub(r"\\\1", tag) for tag in stripped if tag
    )


# For ZeroGPU, we combine loading and generation into a single, stateless function call.
@spaces.GPU(duration=300)  # 5-minute timeout for loading and generation
def run_generation(use_4bit, prompt, max_new_tokens, temperature, top_k, top_p, do_sample, progress=gr.Progress(track_tqdm=True)):
    """
    Loads the model and generates tags in a single call to be compatible with ZeroGPU.

    Args:
        use_4bit: Forwarded to DanbooruTagTester as ``use_4bit``.
        prompt: Input prompt passed to the generator and prepended to the
            final combined tag string.
        max_new_tokens: Generation length limit; cast to ``int`` before use.
        temperature: Sampling temperature, forwarded unchanged.
        top_k: Top-K setting; cast to ``int`` before use.
        top_p: Top-P setting, forwarded unchanged.
        do_sample: Forwarded to ``generate_tags`` as ``do_sample``.
        progress: Gradio progress tracker used to report the two stages.

    Returns:
        A ``(log, raw_tags, final_tags)`` tuple of strings: everything printed
        to stdout during the call followed by a status message, the generator's
        raw output, and the cleaned, escaped combination of prompt and output.
        On failure the last two are both ``"Generation failed."`` and the log
        ends with the error message.
    """
    # Hardcoded model paths
    model_path = "kayfahaarukku/vtuber-tag-gen"
    base_model = "google/gemma-3-1b-it"

    captured_output = io.StringIO()
    status_message = ""
    generated_tags = ""
    final_output = ""

    # Capture anything printed while loading/generating so it can be shown in
    # the UI; the context manager restores sys.stdout even if an error escapes.
    with contextlib.redirect_stdout(captured_output):
        try:
            # Step 1: Load the model
            progress(0, desc="Initializing and loading model...")
            tester = DanbooruTagTester(
                model_path=model_path,
                base_model_id=base_model,
                use_4bit=use_4bit,
                non_interactive=True
            )
            status_message = "Model loaded successfully! Generating tags..."
            progress(0.7, desc=status_message)

            # Step 2: Generate tags
            generated_tags = tester.generate_tags(
                input_prompt=prompt,
                max_new_tokens=int(max_new_tokens),
                temperature=temperature,
                top_k=int(top_k),
                top_p=top_p,
                do_sample=do_sample
            )
            status_message += "\nGeneration complete."

            # The raw generated_tags are displayed as-is; the final output is
            # the cleaned, de-duplicated-whitespace, escaped combination.
            final_output = _build_final_tags(prompt, generated_tags)

        except Exception as e:
            # UI boundary: report the failure in the output boxes instead of
            # letting the request crash.
            status_message = f"An error occurred: {e}"
            generated_tags = "Generation failed."
            final_output = "Generation failed."

    final_log = captured_output.getvalue() + "\n" + status_message

    return final_log, generated_tags, final_output


# --- Gradio Interface Definition ---
with gr.Blocks(theme=gr.themes.Soft()) as demo:
    gr.Markdown("# VTuber Tag Gen UI (ZeroGPU)")
    gr.Markdown("Enter a prompt and adjust settings, then click Generate. The model is loaded from scratch for each generation to comply with Hugging Face's stateless GPU environment.")
    gr.Markdown("**Model Info:** Based on Gemma 3 1B, fine-tuned using PEFT (LoRA) for 8 epochs on a 450k VTuber tag dataset.")
    gr.Markdown("This is a pilot project. If you like to see this with a bigger dataset, consider donating to my [Ko-fi](https://ko-fi.com/kayfahaarukku)")

    with gr.Row():
        with gr.Column(scale=1):
            gr.Markdown("### Model Configuration")
            use_4bit_checkbox = gr.Checkbox(label="Use 4-bit Quantization", value=True)

            gr.Markdown("### Generation Settings")
            with gr.Accordion("Adjust Settings", open=False):
                max_new_tokens_slider = gr.Slider(minimum=10, maximum=500, value=150, step=10, label="Max New Tokens")
                temperature_slider = gr.Slider(minimum=0.1, maximum=2.0, value=0.8, step=0.1, label="Temperature")
                top_k_slider = gr.Slider(minimum=1, maximum=100, value=50, step=1, label="Top-K")
                top_p_slider = gr.Slider(minimum=0.1, maximum=1.0, value=0.95, step=0.05, label="Top-P")
                do_sample_checkbox = gr.Checkbox(label="Use Sampling", value=True)

        with gr.Column(scale=2):
            gr.Markdown("### Input & Output")
            prompt_input = gr.Textbox(label="Input Prompt", lines=3, placeholder="e.g., 1girl, inugami korone, hololive,")
            generate_button = gr.Button("Generate", variant="primary")
            raw_completion_output = gr.Textbox(label="Generated Tags (Raw)", lines=5, interactive=False)
            final_output_textbox = gr.Textbox(label="Final Combined Tags", lines=5, interactive=False)
            status_output = gr.Textbox(label="Logs", lines=8, interactive=False)

    # --- Event Handler ---
    # Order must match the positional parameters of run_generation.
    inputs = [
        use_4bit_checkbox,
        prompt_input,
        max_new_tokens_slider,
        temperature_slider,
        top_k_slider,
        top_p_slider,
        do_sample_checkbox
    ]
    # Order must match the tuple returned by run_generation.
    outputs = [status_output, raw_completion_output, final_output_textbox]

    generate_button.click(
        fn=run_generation,
        inputs=inputs,
        outputs=outputs
    )

if __name__ == "__main__":
    demo.launch(server_name="0.0.0.0")