wheattoast11 committed on
Commit
1b9ca25
·
verified ·
1 Parent(s): 1482784

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +681 -56
app.py CHANGED
@@ -1,81 +1,706 @@
1
  """
2
- Unsloth Training Hub - LLM Fine-tuning & RL Platform
3
  Supports: SFT, GRPO, GSPO, DPO, Dr-GRPO, DAPO, BNPO
 
4
  """
 
5
  import gradio as gr
6
  import os
7
  import json
8
  from datetime import datetime
9
 
10
- MODELS = [
11
- "unsloth/Qwen2.5-7B-Instruct",
12
- "unsloth/Qwen2.5-3B-Instruct",
13
- "unsloth/Qwen2.5-14B-Instruct",
14
- "unsloth/Meta-Llama-3.1-8B-Instruct",
15
- "unsloth/DeepSeek-R1-Distill-Qwen-7B",
16
- "unsloth/gemma-3-4b-it",
17
- "unsloth/Phi-4-mini-instruct",
18
- ]
19
-
20
- RL_METHODS = ["grpo", "gspo", "dr_grpo", "dapo", "bnpo", "dpo"]
21
- PRESETS = ["test_run", "small_run", "medium_run", "large_run", "grokking_run"]
22
-
23
- def get_status():
24
- s = {"cuda": False, "gpu": "None", "unsloth": False, "vllm": False}
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
25
  try:
26
  import torch
27
- s["cuda"] = torch.cuda.is_available()
28
- if s["cuda"]: s["gpu"] = torch.cuda.get_device_name(0)
29
- except: pass
 
 
 
 
30
  try:
31
  import unsloth
32
- s["unsloth"] = True
33
- except: pass
 
 
34
  try:
35
  import vllm
36
- s["vllm"] = True
37
- except: pass
38
- return s
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
39
 
40
  def create_ui():
 
 
41
  with gr.Blocks(title="Unsloth Training Hub", theme=gr.themes.Soft()) as demo:
42
  gr.Markdown("# Unsloth Training Hub")
43
  gr.Markdown("Comprehensive LLM Fine-tuning & RL Platform")
44
-
45
- status = get_status()
46
- gr.Markdown(f"**CUDA**: {status['cuda']} | **GPU**: {status['gpu']} | **Unsloth**: {status['unsloth']} | **vLLM**: {status['vllm']}")
47
-
 
48
  with gr.Tabs():
49
- with gr.Tab("Model & Mode"):
50
- model = gr.Dropdown(choices=MODELS, value=MODELS[0], label="Model")
51
- mode = gr.Radio(choices=["sft", "rl"], value="sft", label="Training Mode")
52
- rl_method = gr.Dropdown(choices=RL_METHODS, value="grpo", label="RL Method", visible=False)
53
- mode.change(lambda m: gr.Dropdown(visible=m=="rl"), mode, rl_method)
54
-
55
- with gr.Tab("Training Config"):
56
- preset = gr.Radio(choices=PRESETS, value="small_run", label="Preset")
57
- lora_rank = gr.Dropdown(choices=[8,16,32,64,128], value=32, label="LoRA Rank")
58
- lr = gr.Number(value=5e-6, label="Learning Rate")
59
-
60
- with gr.Tab("Output"):
61
- hub_id = gr.Textbox(value="wheattoast11/trained-model", label="Hub Model ID")
62
- push = gr.Checkbox(value=True, label="Push to Hub")
63
-
64
- output = gr.Markdown("Configure and click Generate")
65
- btn = gr.Button("Generate Training Script", variant="primary")
66
-
67
- def generate(model, mode, rl_method, preset, lora_rank, lr, hub_id, push):
68
- return f"**Model**: {model}
69
- **Mode**: {mode}
70
- **Preset**: {preset}
71
- **LoRA**: {lora_rank}
72
- **LR**: {lr}"
73
-
74
- btn.click(generate, [model, mode, rl_method, preset, lora_rank, lr, hub_id, push], output)
75
- gr.Markdown("---
76
- **Intuition Labs** | L40S ~$1.80/hr - PAUSE when not training!")
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
77
  return demo
78
 
 
79
  if __name__ == "__main__":
80
  demo = create_ui()
81
  demo.launch(server_name="0.0.0.0", server_port=7860)
 
1
  """
2
+ Unsloth Training Hub - Comprehensive LLM Fine-tuning & RL Platform
3
  Supports: SFT, GRPO, GSPO, DPO, Dr-GRPO, DAPO, BNPO
4
+ Models: All Unsloth-optimized models (LLM, VLM, Embedding, Multimodal)
5
  """
6
+
7
  import gradio as gr
8
  import os
9
  import json
10
  from datetime import datetime
11
 
12
+ # ============================================================================
13
+ # MODEL CATALOG - All Unsloth Pre-optimized Models
14
+ # ============================================================================
15
+
16
# Catalog of Unsloth model repositories offered in the UI.
# "text_llm" and "vision_vlm" map a family name to its model ids; "embedding"
# and "multimodal_omni" are flat lists of model ids.
UNSLOTH_MODELS = {
    "text_llm": {
        "Qwen3": [
            "unsloth/Qwen3-0.6B",
            "unsloth/Qwen3-1.7B",
            "unsloth/Qwen3-4B",
            "unsloth/Qwen3-8B",
            "unsloth/Qwen3-14B",
            "unsloth/Qwen3-32B",
            "unsloth/Qwen3-30B-A3B",
            "unsloth/Qwen3-235B-A22B",
        ],
        "Qwen2.5": [
            "unsloth/Qwen2.5-0.5B-Instruct",
            "unsloth/Qwen2.5-1.5B-Instruct",
            "unsloth/Qwen2.5-3B-Instruct",
            "unsloth/Qwen2.5-7B-Instruct",
            "unsloth/Qwen2.5-14B-Instruct",
            "unsloth/Qwen2.5-32B-Instruct",
            "unsloth/Qwen2.5-72B-Instruct",
        ],
        "Qwen2.5-Coder": [
            "unsloth/Qwen2.5-Coder-0.5B-Instruct",
            "unsloth/Qwen2.5-Coder-1.5B-Instruct",
            "unsloth/Qwen2.5-Coder-3B-Instruct",
            "unsloth/Qwen2.5-Coder-7B-Instruct",
            "unsloth/Qwen2.5-Coder-14B-Instruct",
            "unsloth/Qwen2.5-Coder-32B-Instruct",
        ],
        "Llama-4": [
            "unsloth/Llama-4-Scout-17B-16E-Instruct",
            "unsloth/Llama-4-Maverick-17B-128E-Instruct",
        ],
        "Llama-3.3": [
            "unsloth/Llama-3.3-70B-Instruct",
        ],
        "Llama-3.1": [
            "unsloth/Meta-Llama-3.1-8B-Instruct",
            "unsloth/Meta-Llama-3.1-70B-Instruct",
            "unsloth/Meta-Llama-3.1-405B-Instruct",
        ],
        "Llama-3.2": [
            "unsloth/Llama-3.2-1B-Instruct",
            "unsloth/Llama-3.2-3B-Instruct",
        ],
        "DeepSeek-R1": [
            "unsloth/DeepSeek-R1-Distill-Qwen-1.5B",
            "unsloth/DeepSeek-R1-Distill-Qwen-7B",
            "unsloth/DeepSeek-R1-Distill-Qwen-14B",
            "unsloth/DeepSeek-R1-Distill-Qwen-32B",
            "unsloth/DeepSeek-R1-Distill-Llama-8B",
            "unsloth/DeepSeek-R1-Distill-Llama-70B",
        ],
        "Gemma-3": [
            "unsloth/gemma-3-1b-it",
            "unsloth/gemma-3-4b-it",
            "unsloth/gemma-3-12b-it",
            "unsloth/gemma-3-27b-it",
        ],
        "Mistral": [
            "unsloth/Mistral-Small-3.2-24B-Instruct-2506",
            "unsloth/mistral-7b-instruct-v0.3",
            "unsloth/Mistral-Nemo-Instruct-2407",
        ],
        "Phi-4": [
            "unsloth/Phi-4-mini-instruct",
            "unsloth/Phi-4-Instruct",
        ],
        "GLM": [
            "unsloth/GLM-4.7-Flash",
            "unsloth/GLM-4.5-Air",
        ],
        "Nemotron": [
            "unsloth/Nemotron-3-Nano-30B-A3B",
        ],
    },
    "vision_vlm": {
        "Qwen3-VL": [
            "unsloth/Qwen3-VL-2B-Instruct",
            "unsloth/Qwen3-VL-4B-Instruct",
            "unsloth/Qwen3-VL-8B-Instruct",
            "unsloth/Qwen3-VL-32B-Instruct",
        ],
        "Qwen2.5-VL": [
            "unsloth/Qwen2.5-VL-3B-Instruct",
            "unsloth/Qwen2.5-VL-7B-Instruct",
            "unsloth/Qwen2.5-VL-32B-Instruct",
            "unsloth/Qwen2.5-VL-72B-Instruct",
        ],
        "Llama-Vision": [
            "unsloth/Llama-3.2-11B-Vision-Instruct",
            "unsloth/Llama-3.2-90B-Vision-Instruct",
        ],
        "Pixtral": [
            "unsloth/Pixtral-12B-2409",
        ],
        "Gemma-3-Vision": [
            "unsloth/gemma-3-4b-it",  # Vision capable
            "unsloth/gemma-3-12b-it",
            "unsloth/gemma-3-27b-it",
        ],
    },
    "embedding": [
        "unsloth/Qwen3-Embedding-0.6B",
        "unsloth/Qwen3-Embedding-4B",
        "unsloth/Qwen3-Embedding-8B",
        "unsloth/embeddinggemma-300m",
        "unsloth/bge-m3",
        "unsloth/ModernBERT-base",
        "unsloth/ModernBERT-large",
    ],
    "multimodal_omni": [
        "unsloth/Qwen2.5-Omni-3B",
        "unsloth/Qwen2.5-Omni-7B",
    ],
}
132
+
133
+ # ============================================================================
134
+ # RL METHODS CONFIGURATION
135
+ # ============================================================================
136
+
137
# RL method registry keyed by the identifier shown in the UI dropdown.
# Each entry has a display "name", a short "description" and a "config" dict of
# trainer options for that method.
# NOTE(review): the bnpo options (epsilon_high / delta) and its description
# read like DAPO-style asymmetric clipping — confirm against the TRL docs.
RL_METHODS = {
    "grpo": {
        "name": "GRPO (Group Relative Policy Optimization)",
        "description": "Token-level importance sampling. Default DeepSeek method.",
        "config": {"loss_type": "grpo", "importance_sampling_level": "token"},
    },
    "gspo": {
        "name": "GSPO (Group Sequence Policy Optimization)",
        "description": "Sequence-level importance sampling. Qwen team variant.",
        "config": {"loss_type": "grpo", "importance_sampling_level": "sequence"},
    },
    "dr_grpo": {
        "name": "Dr-GRPO (GRPO Done Right)",
        "description": "Avoids difficulty bias in training.",
        "config": {"loss_type": "dr_grpo", "scale_rewards": False},
    },
    "dapo": {
        "name": "DAPO (Decoupled Clip and Dynamic Sampling Policy Optimization)",
        "description": "Token-level normalization for long chain-of-thought.",
        "config": {"loss_type": "dapo", "mask_truncated_completions": True},
    },
    "bnpo": {
        "name": "BNPO (Beta Normalization Policy Optimization)",
        "description": "Asymmetric clipping for better exploration.",
        "config": {"loss_type": "bnpo", "epsilon": 0.2, "epsilon_high": 0.28, "delta": 1.5},
    },
    "dpo": {
        "name": "DPO (Direct Preference Optimization)",
        "description": "Preference-based training without reward model.",
        "config": {"method": "dpo"},
    },
}
169
+
170
+ # ============================================================================
171
+ # SAMPLE SIZE PRESETS
172
+ # ============================================================================
173
+
174
# Run-size presets: sample count, optimizer step budget and a human-readable
# description shown next to the preset selector.
SAMPLE_PRESETS = {
    "test_run": {"samples": 100, "max_steps": 50, "description": "Quick test (5-10 min)"},
    "small_run": {"samples": 1000, "max_steps": 250, "description": "Small training (30-60 min)"},
    "medium_run": {"samples": 5000, "max_steps": 1000, "description": "Medium training (2-4 hours)"},
    "large_run": {"samples": 25000, "max_steps": 5000, "description": "Large training (8-12 hours)"},
    "grokking_run": {"samples": 100000, "max_steps": 50000, "description": "Grokking/extended (24+ hours)"},
}
181
+
182
+ # ============================================================================
183
+ # REWARD FUNCTION PACKS
184
+ # ============================================================================
185
+
186
# Named bundles of reward functions selectable for RL runs. "functions" lists
# reward-function names; their implementations are not defined in this block.
REWARD_PACKS = {
    "reasoning_xml": {
        "name": "XML Reasoning Format",
        "description": "Rewards <reasoning>...</reasoning><answer>...</answer> format",
        "functions": ["xmlcount_reward", "soft_format_reward", "strict_format_reward"],
    },
    "code_quality": {
        "name": "Code Quality",
        "description": "Rewards syntactically correct, well-formatted code",
        "functions": ["syntax_reward", "docstring_reward", "type_hint_reward"],
    },
    "math_accuracy": {
        "name": "Math Accuracy",
        "description": "Rewards correct numerical answers with step verification",
        "functions": ["correctness_reward", "int_reward", "step_count_reward"],
    },
    "instruction_following": {
        "name": "Instruction Following",
        "description": "Rewards adherence to specific output formats",
        "functions": ["format_reward", "length_reward", "keyword_reward"],
    },
    "safety_alignment": {
        "name": "Safety & Alignment",
        "description": "Rewards helpful, harmless, honest outputs",
        "functions": ["helpfulness_reward", "safety_reward", "factuality_reward"],
    },
}
213
+
214
+
215
+ def get_environment_status():
216
+ """Check environment and return status."""
217
+ import subprocess
218
+
219
+ status = {
220
+ "cuda_available": False,
221
+ "gpu_name": "Not detected",
222
+ "gpu_memory": "Unknown",
223
+ "unsloth_installed": False,
224
+ "vllm_installed": False,
225
+ "trl_installed": False,
226
+ "anthropic_key": bool(os.environ.get("ANTHROPIC_API_KEY")),
227
+ "hf_token": bool(os.environ.get("HF_TOKEN")),
228
+ }
229
+
230
  try:
231
  import torch
232
+ status["cuda_available"] = torch.cuda.is_available()
233
+ if status["cuda_available"]:
234
+ status["gpu_name"] = torch.cuda.get_device_name(0)
235
+ status["gpu_memory"] = f"{torch.cuda.get_device_properties(0).total_memory / 1e9:.1f} GB"
236
+ except:
237
+ pass
238
+
239
  try:
240
  import unsloth
241
+ status["unsloth_installed"] = True
242
+ except:
243
+ pass
244
+
245
  try:
246
  import vllm
247
+ status["vllm_installed"] = True
248
+ except:
249
+ pass
250
+
251
+ try:
252
+ import trl
253
+ status["trl_installed"] = True
254
+ except:
255
+ pass
256
+
257
+ return status
258
+
259
+
260
def format_status_markdown(status):
    """Render an environment-status dict as a Markdown bullet list.

    Args:
        status: dict with the keys ``cuda_available``, ``gpu_name``,
            ``gpu_memory``, ``unsloth_installed``, ``vllm_installed``,
            ``trl_installed``, ``anthropic_key`` and ``hf_token``.

    Returns:
        A Markdown string: a level-2 heading, a blank line, then one bullet
        per status item.
    """

    def installed(flag):
        return "Installed" if flag else "Not installed"

    def is_set(flag):
        return "Set" if flag else "Not set"

    lines = [
        # The trailing newline plus the join below leaves a blank line after
        # the heading.
        "## Environment Status\n",
        f"- **CUDA**: {'Available' if status['cuda_available'] else 'Not available'}",
        f"- **GPU**: {status['gpu_name']} ({status['gpu_memory']})",
        f"- **Unsloth**: {installed(status['unsloth_installed'])}",
        f"- **vLLM**: {installed(status['vllm_installed'])}",
        f"- **TRL**: {installed(status['trl_installed'])}",
        f"- **ANTHROPIC_API_KEY**: {is_set(status['anthropic_key'])}",
        f"- **HF_TOKEN**: {is_set(status['hf_token'])}",
    ]
    return "\n".join(lines)
273
+
274
+
275
def get_model_list(model_type):
    """Return a flat list of model ids for a UI model type.

    Args:
        model_type: one of ``"text_llm"``, ``"vision_vlm"``, ``"embedding"``
            or ``"multimodal"``.

    Returns:
        A new list of model id strings (family-grouped categories are
        flattened in catalog order), or ``[]`` for an unknown type. A fresh
        list is always returned so callers cannot mutate the shared catalog.
    """
    # The UI label "multimodal" is stored under "multimodal_omni" in the catalog.
    catalog_keys = {
        "text_llm": "text_llm",
        "vision_vlm": "vision_vlm",
        "embedding": "embedding",
        "multimodal": "multimodal_omni",
    }
    catalog_key = catalog_keys.get(model_type)
    if catalog_key is None:
        return []

    entry = UNSLOTH_MODELS[catalog_key]
    if isinstance(entry, dict):
        # Family-grouped category: concatenate every family's models.
        return [model for family_models in entry.values() for model in family_models]
    return list(entry)
292
+
293
+
294
def start_training(
    model_name,
    model_type,
    training_mode,
    rl_method,
    sample_preset,
    reward_pack,
    custom_reward_code,
    lora_rank,
    learning_rate,
    num_generations,
    temperature,
    max_seq_length,
    hub_model_id,
    push_to_hub,
):
    """Build a run configuration and a training-script preview.

    Nothing is written to disk or executed here: the function assembles a
    config dict, renders the matching script (SFT when ``training_mode`` is
    ``"sft"``, otherwise RL) and returns a Markdown report containing the
    config as JSON and the script, truncated to 2000 characters.

    RL-only settings (``rl_method``, ``reward_pack``, ``num_generations``) are
    stored as ``None`` unless ``training_mode`` is ``"rl"``.

    NOTE(review): ``custom_reward_code`` is accepted but not used anywhere in
    the configuration or the generated script.

    Returns:
        Markdown string for display in the UI.
    """
    timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
    run_dir = f"/app/runs/{training_mode}_{timestamp}"
    is_rl = training_mode == "rl"

    config = {
        "model_name": model_name,
        "model_type": model_type,
        "training_mode": training_mode,
        "rl_method": rl_method if is_rl else None,
        "sample_preset": sample_preset,
        "reward_pack": reward_pack if is_rl else None,
        "lora_rank": lora_rank,
        "learning_rate": learning_rate,
        "num_generations": num_generations if is_rl else None,
        "temperature": temperature,
        "max_seq_length": max_seq_length,
        "hub_model_id": hub_model_id,
        "push_to_hub": push_to_hub,
        "run_dir": run_dir,
        "timestamp": timestamp,
    }

    # Generate training script
    if training_mode == "sft":
        script = generate_sft_script(config)
    else:
        script = generate_rl_script(config)

    # Only mark the preview as truncated when it really was cut.
    preview_limit = 2000
    if len(script) > preview_limit:
        script_preview = script[:preview_limit] + "..."
    else:
        script_preview = script

    return f"""
## Training Configuration Saved

**Run Directory**: `{run_dir}`
**Timestamp**: {timestamp}

### Configuration:
```json
{json.dumps(config, indent=2)}
```

### Generated Training Script:
```python
{script_preview}
```

**Status**: Ready to execute. Click 'Execute Training' to start.
"""
357
+
358
+
359
def generate_sft_script(config):
    """Render a standalone Unsloth SFT training script as a string.

    Args:
        config: run configuration dict. Reads ``sample_preset``,
            ``timestamp``, ``max_seq_length``, ``lora_rank``, ``model_name``,
            ``learning_rate`` and ``run_dir``; optionally ``push_to_hub`` and
            ``hub_model_id``.

    Returns:
        Python source text. When ``push_to_hub`` is truthy and
        ``hub_model_id`` is non-empty, the script ends with an upload of the
        merged model to the Hugging Face Hub.

    Raises:
        KeyError: if ``sample_preset`` is not a key of ``SAMPLE_PRESETS`` or a
            required config key is missing.
    """
    preset = SAMPLE_PRESETS[config["sample_preset"]]

    # repr() produces escaped Python string literals, so quotes or backslashes
    # in user-entered values cannot break the generated code.
    model_name = repr(config["model_name"])
    output_dir = repr(config["run_dir"])
    merged_dir = repr(f'{config["run_dir"]}/merged')

    push_section = ""
    hub_model_id = config.get("hub_model_id")
    if config.get("push_to_hub") and hub_model_id:
        push_section = f'''
# Push merged weights to the Hugging Face Hub (token is read from HF_TOKEN)
import os
model.push_to_hub_merged({hub_model_id!r}, tokenizer, save_method="merged_16bit", token=os.environ.get("HF_TOKEN"))
'''

    return f'''
# Unsloth SFT Training Script
# Generated: {config["timestamp"]}

from unsloth import FastLanguageModel
from trl import SFTTrainer, SFTConfig
from datasets import load_dataset

max_seq_length = {config["max_seq_length"]}
lora_rank = {config["lora_rank"]}

# Load model with Unsloth optimizations
model, tokenizer = FastLanguageModel.from_pretrained(
    model_name={model_name},
    max_seq_length=max_seq_length,
    load_in_4bit=True,
    dtype=None,
)

# Add LoRA adapters
model = FastLanguageModel.get_peft_model(
    model,
    r=lora_rank,
    target_modules=["q_proj", "k_proj", "v_proj", "o_proj",
                    "gate_proj", "up_proj", "down_proj"],
    lora_alpha=lora_rank,
    lora_dropout=0,
    bias="none",
    use_gradient_checkpointing="unsloth",
    random_state=3407,
)

# Load and prepare dataset
dataset = load_dataset("your_dataset", split="train")

# Configure trainer
trainer = SFTTrainer(
    model=model,
    tokenizer=tokenizer,
    train_dataset=dataset,
    args=SFTConfig(
        per_device_train_batch_size=2,
        gradient_accumulation_steps=4,
        warmup_steps=10,
        max_steps={preset["max_steps"]},
        learning_rate={config["learning_rate"]},
        optim="adamw_8bit",
        packing=True,
        max_length=max_seq_length,
        output_dir={output_dir},
        report_to="none",
    ),
)

# Train
trainer.train()

# Save
model.save_pretrained_merged({merged_dir}, tokenizer, save_method="merged_16bit")
{push_section}'''
423
+
424
+
425
# Static middle section of the generated RL script: reward functions and
# dataset preparation. Raw string, so "\n" reaches the generated file as an
# escape sequence rather than a literal line break.
_RL_REWARDS_AND_DATASET = r'''
def xmlcount_reward_func(completions, **kwargs):
    def count_xml(text):
        count = 0.0
        if text.count("<reasoning>\n") == 1: count += 0.125
        if text.count("\n</reasoning>\n") == 1: count += 0.125
        if text.count("\n<answer>\n") == 1: count += 0.125
        if text.count("\n</answer>") == 1: count += 0.125
        return count
    return [count_xml(c[0]["content"]) for c in completions]

def correctness_reward_func(prompts, completions, answer, **kwargs):
    def extract_answer(text):
        if "<answer>" in text and "</answer>" in text:
            return text.split("<answer>")[-1].split("</answer>")[0].strip()
        return text.strip()
    responses = [c[0]["content"] for c in completions]
    extracted = [extract_answer(r) for r in responses]
    return [2.0 if r == a else 0.0 for r, a in zip(extracted, answer)]

# Load dataset
SYSTEM_PROMPT = (
    "Respond in the following format:\n"
    "<reasoning>\n...\n</reasoning>\n"
    "<answer>\n...\n</answer>\n"
)

def to_grpo_example(row):
    # GRPOTrainer reads a "prompt" column; GSM8K puts the final answer after "####".
    return {
        "prompt": [
            {"role": "system", "content": SYSTEM_PROMPT},
            {"role": "user", "content": row["question"]},
        ],
        "answer": row["answer"].split("####")[-1].strip(),
    }

dataset = load_dataset("openai/gsm8k", "main", split="train").map(to_grpo_example)
'''


def generate_rl_script(config):
    """Render a standalone Unsloth GRPO-family training script as a string.

    Args:
        config: run configuration dict. Reads ``sample_preset``,
            ``rl_method``, ``timestamp``, ``max_seq_length``, ``lora_rank``,
            ``model_name``, ``reward_pack``, ``learning_rate``,
            ``num_generations``, ``temperature`` and ``run_dir``; optionally
            ``push_to_hub`` and ``hub_model_id``.

    Returns:
        Python source text. Every option in the selected method's ``config``
        (except the ``method`` marker) is forwarded to ``GRPOConfig``;
        ``loss_type`` and ``importance_sampling_level`` default to ``"grpo"``
        and ``"token"`` when the method does not set them.

    Raises:
        KeyError: if ``sample_preset`` or ``rl_method`` is unknown, or a
            required config key is missing.
    """
    preset = SAMPLE_PRESETS[config["sample_preset"]]
    method = config["rl_method"]
    rl_config = RL_METHODS[method]["config"]

    # repr() produces escaped Python string literals, so quotes or backslashes
    # in user-entered values cannot break the generated code.
    model_name = repr(config["model_name"])
    output_dir = repr(config["run_dir"])
    lora_dir = repr(f'{config["run_dir"]}/lora')

    # Forward all method options, not just two of them, so settings such as
    # scale_rewards, mask_truncated_completions or epsilon_high take effect.
    trainer_options = {"loss_type": "grpo", "importance_sampling_level": "token"}
    trainer_options.update(
        (key, value) for key, value in rl_config.items() if key != "method"
    )
    method_kwargs = "\n".join(
        f"    {key}={value!r}," for key, value in trainer_options.items()
    )

    # No DPO template exists; say so in the script instead of silently
    # emitting GRPO code under a DPO header.
    warning = ""
    if rl_config.get("method") == "dpo":
        warning = (
            "# WARNING: DPO requires a preference dataset and trl.DPOTrainer.\n"
            "# No DPO template is available yet; the code below is the GRPO template.\n"
        )

    push_section = ""
    hub_model_id = config.get("hub_model_id")
    if config.get("push_to_hub") and hub_model_id:
        push_section = f'''
# Push LoRA adapter to the Hugging Face Hub (token is read from HF_TOKEN)
import os
model.push_to_hub({hub_model_id!r}, token=os.environ.get("HF_TOKEN"))
tokenizer.push_to_hub({hub_model_id!r}, token=os.environ.get("HF_TOKEN"))
'''

    header = f'''
# Unsloth RL Training Script ({method.upper()})
# Generated: {config["timestamp"]}
{warning}
from unsloth import FastLanguageModel, PatchFastRL
PatchFastRL("GRPO", FastLanguageModel)  # CRITICAL: Must be BEFORE trl import

from trl import GRPOConfig, GRPOTrainer
from datasets import load_dataset

max_seq_length = {config["max_seq_length"]}
lora_rank = {config["lora_rank"]}

# Load model with Unsloth optimizations + vLLM fast inference
model, tokenizer = FastLanguageModel.from_pretrained(
    model_name={model_name},
    max_seq_length=max_seq_length,
    load_in_4bit=True,
    fast_inference=True,
    max_lora_rank=lora_rank,
    gpu_memory_utilization=0.6,
)

# Add LoRA adapters
model = FastLanguageModel.get_peft_model(
    model,
    r=lora_rank,
    target_modules=["q_proj", "k_proj", "v_proj", "o_proj",
                    "gate_proj", "up_proj", "down_proj"],
    lora_alpha=lora_rank,
    use_gradient_checkpointing="unsloth",
    random_state=3407,
)

# Reward functions from pack: {config["reward_pack"]}'''

    trainer_section = f'''
# Configure GRPO trainer
training_args = GRPOConfig(
    output_dir={output_dir},
    learning_rate={config["learning_rate"]},
    per_device_train_batch_size=1,
    gradient_accumulation_steps=4,
    num_generations={int(config["num_generations"])},
    max_prompt_length=256,
    max_completion_length={config["max_seq_length"]} - 256,
    max_steps={preset["max_steps"]},
    temperature={config["temperature"]},
{method_kwargs}
    optim="adamw_8bit",
    warmup_ratio=0.1,
    lr_scheduler_type="cosine",
    max_grad_norm=0.1,
    report_to="none",
)

# Initialize trainer
trainer = GRPOTrainer(
    model=model,
    processing_class=tokenizer,
    reward_funcs=[xmlcount_reward_func, correctness_reward_func],
    args=training_args,
    train_dataset=dataset,
)

# Train
trainer.train()

# Save
model.save_pretrained({lora_dir})
tokenizer.save_pretrained({lora_dir})
{push_section}'''

    return header + _RL_REWARDS_AND_DATASET + trainer_section
523
+
524
+
525
+ # ============================================================================
526
+ # GRADIO UI
527
+ # ============================================================================
528
 
529
def create_ui():
    """Build and return the Gradio Blocks app.

    Layout: an environment-status panel, four tabs (model selection, training
    mode, training config, output/hub), and a button that renders the
    training configuration and script preview into a Markdown panel.

    Returns:
        The ``gr.Blocks`` instance; the caller is responsible for launching it.
    """
    with gr.Blocks(title="Unsloth Training Hub", theme=gr.themes.Soft()) as demo:
        gr.Markdown("# Unsloth Training Hub")
        gr.Markdown("Comprehensive LLM Fine-tuning & RL Platform")

        # Status (probed once, when the UI is built)
        status = get_environment_status()
        gr.Markdown(format_status_markdown(status))

        with gr.Tabs():
            # Tab 1: Model Selection
            with gr.Tab("1. Model Selection"):
                model_type = gr.Radio(
                    choices=["text_llm", "vision_vlm", "embedding", "multimodal"],
                    value="text_llm",
                    label="Model Type",
                )

                model_dropdown = gr.Dropdown(
                    choices=get_model_list("text_llm"),
                    value="unsloth/Qwen2.5-7B-Instruct",
                    label="Select Model",
                    filterable=True,
                )

                def update_models(model_type):
                    # Repopulate the dropdown and preselect the first model.
                    models = get_model_list(model_type)
                    return gr.Dropdown(choices=models, value=models[0] if models else None)

                model_type.change(update_models, model_type, model_dropdown)

            # Tab 2: Training Mode
            with gr.Tab("2. Training Mode"):
                training_mode = gr.Radio(
                    choices=["sft", "rl"],
                    value="sft",
                    label="Training Mode",
                    info="SFT: Supervised Fine-Tuning | RL: Reinforcement Learning",
                )

                # RL-only controls; hidden until "rl" is selected.
                with gr.Group(visible=False) as rl_options:
                    rl_method = gr.Dropdown(
                        choices=list(RL_METHODS.keys()),
                        value="grpo",
                        label="RL Method",
                    )

                    rl_info = gr.Markdown(RL_METHODS["grpo"]["description"])

                    def update_rl_info(method):
                        return RL_METHODS[method]["description"]

                    rl_method.change(update_rl_info, rl_method, rl_info)

                    reward_pack = gr.Dropdown(
                        choices=list(REWARD_PACKS.keys()),
                        value="reasoning_xml",
                        label="Reward Pack",
                    )

                    custom_reward = gr.Code(
                        label="Custom Reward Function (Optional)",
                        language="python",
                        value="# def custom_reward(completions, **kwargs):\n#     return [1.0 for _ in completions]",
                    )

                    num_generations = gr.Slider(
                        minimum=2, maximum=16, value=8, step=2,
                        label="Generations per Prompt",
                    )

                    temperature = gr.Slider(
                        minimum=0.1, maximum=2.0, value=1.0, step=0.1,
                        label="Generation Temperature",
                    )

                def toggle_rl_options(mode):
                    return gr.Group(visible=(mode == "rl"))

                training_mode.change(toggle_rl_options, training_mode, rl_options)

            # Tab 3: Training Config
            with gr.Tab("3. Training Config"):
                sample_preset = gr.Radio(
                    choices=list(SAMPLE_PRESETS.keys()),
                    value="small_run",
                    label="Sample Size Preset",
                )

                preset_info = gr.Markdown(
                    f"**{SAMPLE_PRESETS['small_run']['description']}** - "
                    f"{SAMPLE_PRESETS['small_run']['samples']} samples, "
                    f"{SAMPLE_PRESETS['small_run']['max_steps']} steps"
                )

                def update_preset_info(preset):
                    p = SAMPLE_PRESETS[preset]
                    return f"**{p['description']}** - {p['samples']} samples, {p['max_steps']} steps"

                sample_preset.change(update_preset_info, sample_preset, preset_info)

                with gr.Row():
                    lora_rank = gr.Dropdown(
                        choices=[8, 16, 32, 64, 128],
                        value=32,
                        label="LoRA Rank",
                    )

                    learning_rate = gr.Number(
                        value=5e-6,
                        label="Learning Rate",
                    )

                # NOTE(review): placement relative to the row above is a
                # layout guess; it does not affect behavior.
                max_seq_length = gr.Dropdown(
                    choices=[512, 1024, 2048, 4096, 8192, 16384, 32768],
                    value=2048,
                    label="Max Sequence Length",
                )

            # Tab 4: Output & Hub
            with gr.Tab("4. Output & Hub"):
                hub_model_id = gr.Textbox(
                    value="wheattoast11/unsloth-trained-model",
                    label="HuggingFace Hub Model ID",
                )

                push_to_hub = gr.Checkbox(
                    value=True,
                    label="Push to HuggingFace Hub after training",
                )

                # NOTE(review): output_format is displayed but never passed
                # to start_training.
                output_format = gr.CheckboxGroup(
                    choices=["merged_16bit", "merged_4bit", "lora", "gguf_q4_k_m", "gguf_q8_0"],
                    value=["merged_16bit", "lora"],
                    label="Output Formats",
                )

        # Start Training
        gr.Markdown("---")

        with gr.Row():
            start_btn = gr.Button("Generate Training Config", variant="primary", scale=2)
            # NOTE(review): execute_btn has no click handler, so pressing it
            # does nothing.
            execute_btn = gr.Button("Execute Training", variant="secondary", scale=1)

        output = gr.Markdown("Configure your training and click 'Generate Training Config'")

        # Input order must match the start_training parameter order.
        start_btn.click(
            start_training,
            inputs=[
                model_dropdown,
                model_type,
                training_mode,
                rl_method,
                sample_preset,
                reward_pack,
                custom_reward,
                lora_rank,
                learning_rate,
                num_generations,
                temperature,
                max_seq_length,
                hub_model_id,
                push_to_hub,
            ],
            outputs=output,
        )

        gr.Markdown("---")
        gr.Markdown("**Intuition Labs** | Unsloth Training Hub | L40S ~$1.80/hr - PAUSE when not training!")

    return demo
702
 
703
+
704
if __name__ == "__main__":
    # Bind to all interfaces on port 7860.
    demo = create_ui()
    demo.launch(server_name="0.0.0.0", server_port=7860)