Spaces:
Runtime error
Runtime error
Jose Benitez commited on
Commit 路
b16249c
1
Parent(s): a1e077b
add credits function to training
Browse files- gradio_app.py +25 -4
- services/image_generation.py +1 -0
gradio_app.py
CHANGED
|
@@ -67,6 +67,12 @@ def compress_and_train(request: gr.Request, files, model_name, trigger_word, tra
|
|
| 67 |
return "No hay im谩genes. Sube algunas im谩genes para poder entrenar."
|
| 68 |
|
| 69 |
user = request.session.get('user')
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 70 |
if not user:
|
| 71 |
raise gr.Error("User not authenticated. Please log in.")
|
| 72 |
|
|
@@ -98,7 +104,14 @@ def compress_and_train(request: gr.Request, files, model_name, trigger_word, tra
|
|
| 98 |
autocaption=True,
|
| 99 |
learning_rate=learning_rate)
|
| 100 |
|
| 101 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 102 |
|
| 103 |
def run_lora(request: gr.Request, prompt, cfg_scale, steps, selected_index, selected_gallery, width, height, lora_scale, progress=gr.Progress(track_tqdm=True)):
|
| 104 |
user = request.session.get('user')
|
|
@@ -278,11 +291,19 @@ with gr.Blocks(theme=gr.themes.Soft(), head=header, css=main_css) as main_demo:
|
|
| 278 |
batch_size = gr.Number(label='batch_size', value=1)
|
| 279 |
learning_rate = gr.Number(label='learning_rate', value=0.0004)
|
| 280 |
training_status = gr.Textbox(label="Training Status")
|
| 281 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 282 |
train_button.click(
|
| 283 |
-
compress_and_train,
|
|
|
|
| 284 |
inputs=[train_dataset, model_name, trigger_word, train_steps, lora_rank, batch_size, learning_rate],
|
| 285 |
-
outputs=training_status
|
| 286 |
)
|
| 287 |
|
| 288 |
|
|
|
|
| 67 |
return "No hay im谩genes. Sube algunas im谩genes para poder entrenar."
|
| 68 |
|
| 69 |
user = request.session.get('user')
|
| 70 |
+
|
| 71 |
+
_, training_credits = get_user_credits(user['id'])
|
| 72 |
+
|
| 73 |
+
if training_credits <= 0:
|
| 74 |
+
raise gr.Error("Ya no tienes creditos disponibles. Compra para continuar.")
|
| 75 |
+
|
| 76 |
if not user:
|
| 77 |
raise gr.Error("User not authenticated. Please log in.")
|
| 78 |
|
|
|
|
| 104 |
autocaption=True,
|
| 105 |
learning_rate=learning_rate)
|
| 106 |
|
| 107 |
+
new_training_credits = training_credits - 1
|
| 108 |
+
update_user_credits(user['id'], user['generation_credits'], new_training_credits)
|
| 109 |
+
|
| 110 |
+
# Update session data
|
| 111 |
+
user['training_credits'] = new_training_credits
|
| 112 |
+
request.session['user'] = user
|
| 113 |
+
|
| 114 |
+
return gr.Info("Tu modelo esta entrenando, En unos 20 minutos estar谩 listo para que lo pruebes en 'Generaci贸n'."), new_training_credits
|
| 115 |
|
| 116 |
def run_lora(request: gr.Request, prompt, cfg_scale, steps, selected_index, selected_gallery, width, height, lora_scale, progress=gr.Progress(track_tqdm=True)):
|
| 117 |
user = request.session.get('user')
|
|
|
|
| 291 |
batch_size = gr.Number(label='batch_size', value=1)
|
| 292 |
learning_rate = gr.Number(label='learning_rate', value=0.0004)
|
| 293 |
training_status = gr.Textbox(label="Training Status")
|
| 294 |
+
|
| 295 |
+
def fake_train(train_dataset, model_name, trigger_word, train_steps, lora_rank, batch_size, learning_rate):
|
| 296 |
+
print(f'fake training for test')
|
| 297 |
+
new_training_credits = 0
|
| 298 |
+
if new_training_credits <= 0:
|
| 299 |
+
raise gr.Error("Ya no tienes creditos disponibles. Compra para continuar.")
|
| 300 |
+
return gr.Info("Tu modelo esta entrenando, En unos 20 minutos estar谩 listo para que lo pruebes en 'Generaci贸n'."), new_training_credits
|
| 301 |
+
|
| 302 |
train_button.click(
|
| 303 |
+
#compress_and_train,
|
| 304 |
+
fake_train,
|
| 305 |
inputs=[train_dataset, model_name, trigger_word, train_steps, lora_rank, batch_size, learning_rate],
|
| 306 |
+
outputs=[training_status,train_credits_display]
|
| 307 |
)
|
| 308 |
|
| 309 |
|
services/image_generation.py
CHANGED
|
@@ -19,6 +19,7 @@ def generate_image(model_name, prompt, steps, cfg_scale, width, height, lora_sca
|
|
| 19 |
}
|
| 20 |
)
|
| 21 |
else:
|
|
|
|
| 22 |
img_url = replicate.run(
|
| 23 |
model_name,
|
| 24 |
input={
|
|
|
|
| 19 |
}
|
| 20 |
)
|
| 21 |
else:
|
| 22 |
+
model_name = model_name.lower().replace(' ', '_')
|
| 23 |
img_url = replicate.run(
|
| 24 |
model_name,
|
| 25 |
input={
|