vtuber-tag-gen / gradio_app.py
kayfahaarukku's picture
Update gradio_app.py
680b7a9 verified
Raw
History Blame Contribute Delete
5.42 kB
import contextlib
import io
import sys

import gradio as gr
import spaces

from test_lora import DanbooruTagTester
def _escape_parentheses(tag):
    """Return *tag* with every parenthesis backslash-escaped exactly once."""
    # Drop any existing escapes first so that tags which arrive already
    # escaped (for instance a previous "Final Combined Tags" result pasted
    # back in as the prompt) are not double-escaped into `\\(`.
    unescaped = tag.replace(r"\(", "(").replace(r"\)", ")")
    return unescaped.replace("(", r"\(").replace(")", r"\)")


def _combine_tags(prompt, generated_tags):
    """Merge the prompt and the raw completion into one clean tag string.

    The "completion:" marker is removed, newlines are treated as tag
    separators, empty tags are dropped, and parentheses are escaped.
    """
    completion = generated_tags.replace("completion:", "")
    # Join with a comma (not a space): if the prompt lacks a trailing comma,
    # a space would fuse its last tag with the first generated tag. Any empty
    # tag produced by a doubled comma is filtered out below.
    combined = f"{prompt}, {completion}".replace("\n", ", ")
    stripped = (piece.strip() for piece in combined.split(","))
    return ", ".join(_escape_parentheses(tag) for tag in stripped if tag)


# For ZeroGPU, we combine loading and generation into a single, stateless function call.
@spaces.GPU(duration=300)  # 5-minute timeout for loading and generation
def run_generation(use_4bit, prompt, max_new_tokens, temperature, top_k, top_p, do_sample, progress=gr.Progress(track_tqdm=True)):
    """
    Loads the model and generates tags in a single call to be compatible with ZeroGPU.

    Args:
        use_4bit: Forwarded to ``DanbooruTagTester`` as ``use_4bit``.
        prompt: Comma-separated tags used as the generation prompt.
        max_new_tokens: Generation length limit; cast to ``int`` before use.
        temperature: Forwarded unchanged to ``generate_tags``.
        top_k: Top-k setting; cast to ``int`` before use.
        top_p: Forwarded unchanged to ``generate_tags``.
        do_sample: Forwarded unchanged to ``generate_tags``.
        progress: Gradio progress tracker updated before loading and before
            generation.

    Returns:
        A ``(final_log, generated_tags, final_output)`` tuple: the captured
        stdout followed by a status message, the raw value returned by
        ``generate_tags``, and the cleaned prompt + completion tag string.
        If anything raises, the log carries the error message and both tag
        values are ``"Generation failed."``.
    """
    # Hardcoded model paths
    model_path = "kayfahaarukku/vtuber-tag-gen"
    base_model = "google/gemma-3-1b-it"

    captured_output = io.StringIO()
    status_message = ""
    generated_tags = ""
    final_output = ""

    # Capture everything printed during loading/generation so it can be shown
    # in the "Logs" box; redirect_stdout restores sys.stdout on exit even if
    # an exception escapes.
    with contextlib.redirect_stdout(captured_output):
        try:
            # Step 1: Load the model
            progress(0, desc="Initializing and loading model...")
            tester = DanbooruTagTester(
                model_path=model_path,
                base_model_id=base_model,
                use_4bit=use_4bit,
                non_interactive=True,
            )
            status_message = "Model loaded successfully! Generating tags..."
            progress(0.7, desc=status_message)

            # Step 2: Generate tags
            # NOTE(review): generate_tags is assumed to return a str (it is
            # post-processed with str methods below) - confirm in test_lora.
            generated_tags = tester.generate_tags(
                input_prompt=prompt,
                max_new_tokens=int(max_new_tokens),
                temperature=temperature,
                top_k=int(top_k),
                top_p=top_p,
                do_sample=do_sample,
            )
            status_message += "\nGeneration complete."

            # Step 3: Build the clean combined list; the raw tags are
            # returned as-is alongside it.
            final_output = _combine_tags(prompt, generated_tags)
        except Exception as e:
            # UI boundary: report the failure in the log box instead of
            # letting the event handler crash.
            status_message = f"An error occurred: {e}"
            generated_tags = "Generation failed."
            final_output = "Generation failed."

    final_log = captured_output.getvalue() + "\n" + status_message
    return final_log, generated_tags, final_output
# --- Gradio Interface Definition ---
with gr.Blocks(theme=gr.themes.Soft()) as demo:
    # Page header: title, usage note, model info and donation link.
    gr.Markdown("# VTuber Tag Gen UI (ZeroGPU)")
    gr.Markdown("Enter a prompt and adjust settings, then click Generate. The model is loaded from scratch for each generation to comply with Hugging Face's stateless GPU environment.")
    gr.Markdown("**Model Info:** Based on Gemma 3 1B, fine-tuned using PEFT (LoRA) for 8 epochs on a 450k VTuber tag dataset.")
    gr.Markdown("This is a pilot project. If you like to see this with a bigger dataset, consider donating to my [Ko-fi](https://ko-fi.com/kayfahaarukku)")

    with gr.Row():
        # Narrow column: model and sampling controls.
        with gr.Column(scale=1):
            gr.Markdown("### Model Configuration")
            use_4bit_checkbox = gr.Checkbox(
                label="Use 4-bit Quantization",
                value=True,
            )

            gr.Markdown("### Generation Settings")
            # Sampling knobs stay collapsed until the user opens them.
            with gr.Accordion("Adjust Settings", open=False):
                max_new_tokens_slider = gr.Slider(
                    label="Max New Tokens",
                    minimum=10,
                    maximum=500,
                    value=150,
                    step=10,
                )
                temperature_slider = gr.Slider(
                    label="Temperature",
                    minimum=0.1,
                    maximum=2.0,
                    value=0.8,
                    step=0.1,
                )
                top_k_slider = gr.Slider(
                    label="Top-K",
                    minimum=1,
                    maximum=100,
                    value=50,
                    step=1,
                )
                top_p_slider = gr.Slider(
                    label="Top-P",
                    minimum=0.1,
                    maximum=1.0,
                    value=0.95,
                    step=0.05,
                )
                do_sample_checkbox = gr.Checkbox(
                    label="Use Sampling",
                    value=True,
                )

        # Wide column: prompt entry, trigger button and read-only results.
        with gr.Column(scale=2):
            gr.Markdown("### Input & Output")
            prompt_input = gr.Textbox(
                label="Input Prompt",
                lines=3,
                placeholder="e.g., 1girl, inugami korone, hololive,",
            )
            generate_button = gr.Button("Generate", variant="primary")
            raw_completion_output = gr.Textbox(
                label="Generated Tags (Raw)",
                lines=5,
                interactive=False,
            )
            final_output_textbox = gr.Textbox(
                label="Final Combined Tags",
                lines=5,
                interactive=False,
            )
            status_output = gr.Textbox(
                label="Logs",
                lines=8,
                interactive=False,
            )

    # --- Event Handler ---
    # Order must match run_generation's positional parameters.
    inputs = [
        use_4bit_checkbox,
        prompt_input,
        max_new_tokens_slider,
        temperature_slider,
        top_k_slider,
        top_p_slider,
        do_sample_checkbox,
    ]
    # Order must match run_generation's returned (log, raw tags, final tags).
    outputs = [
        status_output,
        raw_completion_output,
        final_output_textbox,
    ]
    generate_button.click(fn=run_generation, inputs=inputs, outputs=outputs)


if __name__ == "__main__":
    # Bind to all interfaces when the file is run directly as a script.
    demo.launch(server_name="0.0.0.0")