Spaces:
Running
Running
| import gradio as gr | |
| from test_lora import DanbooruTagTester | |
| import sys | |
| import io | |
| import spaces | |
def _build_final_tags(prompt, generated_tags):
    """Merge the prompt and the raw completion into one clean tag string.

    Args:
        prompt: The user-supplied, comma-separated tag prompt.
        generated_tags: Raw text returned by the model.

    Returns:
        A single ", "-joined string of non-empty, whitespace-trimmed tags
        with every "(" and ")" backslash-escaped.
    """
    # Drop the "completion:" marker and treat newlines as tag separators.
    processed_tags = generated_tags.replace("completion:", "").replace("\n", ", ")
    full_string = f"{prompt} {processed_tags}"
    # Split on commas, trim each tag and discard the empty ones that result
    # from trailing or doubled separators.
    stripped = (tag.strip() for tag in full_string.split(","))
    tags = [tag for tag in stripped if tag]
    # Escape brackets in each tag for the final output.
    return ", ".join(tag.replace("(", r"\(").replace(")", r"\)") for tag in tags)


# For ZeroGPU, we combine loading and generation into a single, stateless function call.
# 5-minute timeout for loading and generation
@spaces.GPU(duration=300)
def run_generation(use_4bit, prompt, max_new_tokens, temperature, top_k, top_p, do_sample, progress=gr.Progress(track_tqdm=True)):
    """Load the model and generate tags in a single call (ZeroGPU compatible).

    Anything printed to stdout while the model is loaded and run is captured
    and returned as part of the log text instead of reaching the console.

    Args:
        use_4bit: Forwarded to ``DanbooruTagTester`` as ``use_4bit``.
        prompt: Input prompt; also prepended to the final combined tags.
        max_new_tokens: Cast to ``int`` and forwarded to ``generate_tags``.
        temperature: Forwarded to ``generate_tags``.
        top_k: Cast to ``int`` and forwarded to ``generate_tags``.
        top_p: Forwarded to ``generate_tags``.
        do_sample: Forwarded to ``generate_tags``.
        progress: Gradio progress tracker used to report the two stages.

    Returns:
        A ``(log_text, raw_generated_tags, final_combined_tags)`` tuple. If an
        ``Exception`` is raised during loading, generation or formatting, both
        tag strings are ``"Generation failed."`` and the log ends with the
        error message.
    """
    # Imported locally so this function does not depend on a new file-level import.
    from contextlib import redirect_stdout

    # Hardcoded model paths
    model_path = "kayfahaarukku/vtuber-tag-gen"
    base_model = "google/gemma-3-1b-it"

    status_message = ""
    generated_tags = ""
    final_output = ""
    captured_output = io.StringIO()

    # redirect_stdout restores sys.stdout on exit, even if an exception
    # (including KeyboardInterrupt/SystemExit) escapes the block.
    with redirect_stdout(captured_output):
        try:
            # Step 1: Load the model
            progress(0, desc="Initializing and loading model...")
            tester = DanbooruTagTester(
                model_path=model_path,
                base_model_id=base_model,
                use_4bit=use_4bit,
                non_interactive=True,
            )
            status_message = "Model loaded successfully! Generating tags..."
            progress(0.7, desc=status_message)

            # Step 2: Generate tags
            generated_tags = tester.generate_tags(
                input_prompt=prompt,
                max_new_tokens=int(max_new_tokens),
                temperature=temperature,
                top_k=int(top_k),
                top_p=top_p,
                do_sample=do_sample,
            )
            status_message += "\nGeneration complete."

            # The raw generated_tags are displayed as-is; the final output is
            # the cleaned, escaped combination of prompt and generated tags.
            final_output = _build_final_tags(prompt, generated_tags)
        except Exception as e:
            # UI boundary: report the failure in the output boxes rather than
            # letting the request crash.
            status_message = f"An error occurred: {e}"
            generated_tags = "Generation failed."
            final_output = "Generation failed."

    # Returning after the block (not from a ``finally``) lets non-Exception
    # interrupts propagate instead of being silently discarded.
    final_log = captured_output.getvalue() + "\n" + status_message
    return final_log, generated_tags, final_output
| # --- Gradio Interface Definition --- | |
| # Builds the page: header text, a settings column (scale=1) and an | |
| # input/output column (scale=2), then wires the Generate button to | |
| # run_generation. | |
| with gr.Blocks(theme=gr.themes.Soft()) as demo: | |
| gr.Markdown("# VTuber Tag Gen UI (ZeroGPU)") | |
| gr.Markdown("Enter a prompt and adjust settings, then click Generate. The model is loaded from scratch for each generation to comply with Hugging Face's stateless GPU environment.") | |
| gr.Markdown("**Model Info:** Based on Gemma 3 1B, fine-tuned using PEFT (LoRA) for 8 epochs on a 450k VTuber tag dataset.") | |
| gr.Markdown("This is a pilot project. If you like to see this with a bigger dataset, consider donating to my [Ko-fi](https://ko-fi.com/kayfahaarukku)") | |
| with gr.Row(): | |
| with gr.Column(scale=1): | |
| gr.Markdown("### Model Configuration") | |
| use_4bit_checkbox = gr.Checkbox(label="Use 4-bit Quantization", value=True) | |
| gr.Markdown("### Generation Settings") | |
| # Sampling controls; their values are passed to run_generation on click. | |
| with gr.Accordion("Adjust Settings", open=False): | |
| max_new_tokens_slider = gr.Slider(minimum=10, maximum=500, value=150, step=10, label="Max New Tokens") | |
| temperature_slider = gr.Slider(minimum=0.1, maximum=2.0, value=0.8, step=0.1, label="Temperature") | |
| top_k_slider = gr.Slider(minimum=1, maximum=100, value=50, step=1, label="Top-K") | |
| top_p_slider = gr.Slider(minimum=0.1, maximum=1.0, value=0.95, step=0.05, label="Top-P") | |
| do_sample_checkbox = gr.Checkbox(label="Use Sampling", value=True) | |
| with gr.Column(scale=2): | |
| gr.Markdown("### Input & Output") | |
| prompt_input = gr.Textbox(label="Input Prompt", lines=3, placeholder="e.g., 1girl, inugami korone, hololive,") | |
| generate_button = gr.Button("Generate", variant="primary") | |
| # The three output boxes are read-only (interactive=False). | |
| raw_completion_output = gr.Textbox(label="Generated Tags (Raw)", lines=5, interactive=False) | |
| final_output_textbox = gr.Textbox(label="Final Combined Tags", lines=5, interactive=False) | |
| status_output = gr.Textbox(label="Logs", lines=8, interactive=False) | |
| # --- Event Handler --- | |
| # Order must match run_generation's positional parameters: | |
| # use_4bit, prompt, max_new_tokens, temperature, top_k, top_p, do_sample. | |
| inputs = [ | |
| use_4bit_checkbox, | |
| prompt_input, max_new_tokens_slider, temperature_slider, | |
| top_k_slider, top_p_slider, do_sample_checkbox | |
| ] | |
| # Order must match run_generation's return tuple: | |
| # (log text, raw generated tags, final combined tags). | |
| outputs = [status_output, raw_completion_output, final_output_textbox] | |
| generate_button.click( | |
| fn=run_generation, | |
| inputs=inputs, | |
| outputs=outputs | |
| ) | |
| # Start the server only when this file is run as a script; | |
| # server_name="0.0.0.0" makes it listen on all network interfaces. | |
| if __name__ == "__main__": | |
| demo.launch(server_name="0.0.0.0") | |