Spaces:
Sleeping
Sleeping
Finetuner Studio GUI demo (planning mode)
Browse files- README.md +26 -6
- app.py +16 -0
- finetuner/__init__.py +3 -0
- finetuner/__main__.py +3 -0
- finetuner/app.py +45 -0
- finetuner/core/__init__.py +0 -0
- finetuner/core/codegen.py +105 -0
- finetuner/core/data.py +69 -0
- finetuner/core/detector.py +292 -0
- finetuner/core/engine.py +56 -0
- finetuner/core/export.py +28 -0
- finetuner/core/jobs.py +135 -0
- finetuner/core/models.py +69 -0
- finetuner/core/recipes.py +40 -0
- finetuner/core/registry.py +232 -0
- finetuner/core/state.py +35 -0
- finetuner/core/training.py +105 -0
- finetuner/ui/__init__.py +0 -0
- finetuner/ui/tab_dataset.py +117 -0
- finetuner/ui/tab_export.py +62 -0
- finetuner/ui/tab_model.py +122 -0
- finetuner/ui/tab_monitor.py +60 -0
- finetuner/ui/tab_playground.py +56 -0
- finetuner/ui/tab_train.py +140 -0
- requirements.txt +3 -0
README.md
CHANGED
|
@@ -1,13 +1,33 @@
|
|
| 1 |
---
|
| 2 |
title: Finetuner Studio
|
| 3 |
-
emoji:
|
| 4 |
-
colorFrom:
|
| 5 |
-
colorTo:
|
| 6 |
sdk: gradio
|
| 7 |
sdk_version: 6.17.3
|
| 8 |
-
python_version: '3.13'
|
| 9 |
app_file: app.py
|
| 10 |
-
|
|
|
|
| 11 |
---
|
| 12 |
|
| 13 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
---
|
| 2 |
title: Finetuner Studio
|
| 3 |
+
emoji: 🎛️
|
| 4 |
+
colorFrom: yellow
|
| 5 |
+
colorTo: red
|
| 6 |
sdk: gradio
|
| 7 |
sdk_version: 6.17.3
|
|
|
|
| 8 |
app_file: app.py
|
| 9 |
+
python_version: "3.12"
|
| 10 |
+
short_description: Low-code MLX fine-tuning studio (GUI demo)
|
| 11 |
---
|
| 12 |
|
| 13 |
+
# 🎛️ Finetuner Studio — GUI Demo
|
| 14 |
+
|
| 15 |
+
**Low-code fine-tuning on Apple Silicon**, powered by
|
| 16 |
+
[mlx-tune](https://github.com/ARahim3/mlx-tune). This Space runs in
|
| 17 |
+
**GUI-only mode** (Spaces have no Apple Silicon): explore the interface, load
|
| 18 |
+
any Hugging Face dataset and watch the **automatic format detection**, and
|
| 19 |
+
generate standalone mlx-tune training scripts with the code generator.
|
| 20 |
+
|
| 21 |
+
For actual training (12 paradigms: SFT, DPO, ORPO, SimPO, KTO, GRPO, CPT,
|
| 22 |
+
VLM, TTS, STT, Embedding, OCR), run it on a Mac:
|
| 23 |
+
|
| 24 |
+
```bash
|
| 25 |
+
git clone https://github.com/aykutcayir34/finetuner
|
| 26 |
+
cd finetuner && pip install -e '.[mlx]' && finetuner
|
| 27 |
+
```
|
| 28 |
+
|
| 29 |
+
🇹🇷 Bu Space, Finetuner Studio arayüzünün canlı demosudur (Spaces'te Apple
|
| 30 |
+
Silicon olmadığı için eğitim devre dışı; format algılama ve kod üretici
|
| 31 |
+
çalışır). Örnek model: [Llama-3.2-1B Turkish-Alpaca](https://huggingface.co/acayir64/Llama-3.2-1B-Instruct-Turkish-Alpaca-mlx)
|
| 32 |
+
|
| 33 |
+
Source: https://github.com/aykutcayir34/finetuner
|
app.py
ADDED
|
@@ -0,0 +1,16 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""Finetuner Studio — Hugging Face Space demo (GUI-only mode).
|
| 2 |
+
|
| 3 |
+
Spaces run on Linux, so MLX training is unavailable here; the Studio runs in
|
| 4 |
+
planning mode: dataset loading + automatic format detection, the code
|
| 5 |
+
generator, recipes and the full UI are live. Clone it on an Apple Silicon
|
| 6 |
+
Mac for actual training: https://github.com/aykutcayir34/finetuner
|
| 7 |
+
"""
|
| 8 |
+
|
| 9 |
+
import gradio as gr
|
| 10 |
+
|
| 11 |
+
from finetuner.app import build_app
|
| 12 |
+
|
| 13 |
+
demo = build_app()
|
| 14 |
+
|
| 15 |
+
if __name__ == "__main__":
|
| 16 |
+
demo.launch(theme=gr.themes.Soft(primary_hue="orange", secondary_hue="slate"))
|
finetuner/__init__.py
ADDED
|
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""Finetuner Studio — low-code fine-tuning on Apple Silicon, powered by mlx-tune."""
|
| 2 |
+
|
| 3 |
+
__version__ = "0.1.0"
|
finetuner/__main__.py
ADDED
|
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from .app import main
|
| 2 |
+
|
| 3 |
+
main()
|
finetuner/app.py
ADDED
|
@@ -0,0 +1,45 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""Finetuner Studio — application entry point."""
|
| 2 |
+
|
| 3 |
+
from __future__ import annotations
|
| 4 |
+
|
| 5 |
+
import gradio as gr
|
| 6 |
+
|
| 7 |
+
from .core.registry import TASKS, mlx_available
|
| 8 |
+
from .ui import tab_dataset, tab_export, tab_model, tab_monitor, tab_playground, tab_train
|
| 9 |
+
|
| 10 |
+
BANNER = """
|
| 11 |
+
# 🎛️ Finetuner Studio
|
| 12 |
+
**Low-code fine-tuning on Apple Silicon** · powered by
|
| 13 |
+
[mlx-tune](https://github.com/ARahim3/mlx-tune) · {n_tasks} training paradigms,
|
| 14 |
+
zero boilerplate
|
| 15 |
+
"""
|
| 16 |
+
|
| 17 |
+
|
| 18 |
+
def build_app() -> gr.Blocks:
|
| 19 |
+
ok, reason = mlx_available()
|
| 20 |
+
health = "🟢 mlx-tune ready" if ok else f"🟡 GUI-only mode — {reason}"
|
| 21 |
+
|
| 22 |
+
with gr.Blocks(title="Finetuner Studio") as app:
|
| 23 |
+
gr.Markdown(BANNER.format(n_tasks=len(TASKS)))
|
| 24 |
+
gr.Markdown(f"`{health}`")
|
| 25 |
+
with gr.Tabs():
|
| 26 |
+
tab_model.build(app)
|
| 27 |
+
tab_dataset.build(app)
|
| 28 |
+
tab_train.build(app)
|
| 29 |
+
tab_monitor.build(app)
|
| 30 |
+
tab_playground.build(app)
|
| 31 |
+
tab_export.build(app)
|
| 32 |
+
gr.Markdown(
|
| 33 |
+
"<center><small>Finetuner Studio · load a model → drop a dataset → "
|
| 34 |
+
"press train. The generated Python script is yours to keep.</small></center>"
|
| 35 |
+
)
|
| 36 |
+
return app
|
| 37 |
+
|
| 38 |
+
|
| 39 |
+
def main():
|
| 40 |
+
theme = gr.themes.Soft(primary_hue="orange", secondary_hue="slate")
|
| 41 |
+
build_app().launch(theme=theme)
|
| 42 |
+
|
| 43 |
+
|
| 44 |
+
if __name__ == "__main__":
|
| 45 |
+
main()
|
finetuner/core/__init__.py
ADDED
|
File without changes
|
finetuner/core/codegen.py
ADDED
|
@@ -0,0 +1,105 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""Low-code → code: generate a standalone, editable mlx-tune training script
|
| 2 |
+
that reproduces exactly what the GUI is configured to do."""
|
| 3 |
+
|
| 4 |
+
from __future__ import annotations
|
| 5 |
+
|
| 6 |
+
from .registry import get_task
|
| 7 |
+
from .training import RunConfig, build_trainer_args
|
| 8 |
+
|
| 9 |
+
|
| 10 |
+
def _fmt(v) -> str:
|
| 11 |
+
return repr(v)
|
| 12 |
+
|
| 13 |
+
|
| 14 |
+
def generate_script(cfg: RunConfig, dataset_source: str = "",
|
| 15 |
+
dataset_is_local: bool = False) -> str:
|
| 16 |
+
spec = get_task(cfg.task)
|
| 17 |
+
args = build_trainer_args(cfg)
|
| 18 |
+
args_body = ",\n ".join(f"{k}={_fmt(v)}" for k, v in args.items())
|
| 19 |
+
|
| 20 |
+
imports = {spec.trainer}
|
| 21 |
+
if spec.config_module == "mlx_tune":
|
| 22 |
+
imports.add(spec.config)
|
| 23 |
+
if spec.collator:
|
| 24 |
+
imports.add(spec.collator)
|
| 25 |
+
import_line = f"from mlx_tune import {spec.loader}, {', '.join(sorted(imports))}"
|
| 26 |
+
extra_import = ""
|
| 27 |
+
if spec.config_module != "mlx_tune":
|
| 28 |
+
extra_import = f"\nfrom {spec.config_module} import {spec.config}"
|
| 29 |
+
|
| 30 |
+
# --- model loading -------------------------------------------------------
|
| 31 |
+
load_kwargs = ""
|
| 32 |
+
if spec.modality == "text":
|
| 33 |
+
load_kwargs = (f",\n max_seq_length={cfg.max_seq_length},"
|
| 34 |
+
f"\n load_in_4bit={cfg.load_in_4bit},")
|
| 35 |
+
handle = "processor" if spec.modality in ("vision", "image") else "tokenizer"
|
| 36 |
+
|
| 37 |
+
# --- LoRA ------------------------------------------------------------------
|
| 38 |
+
lora_block = ""
|
| 39 |
+
if cfg.use_lora and spec.peft_supported:
|
| 40 |
+
tm = ""
|
| 41 |
+
if spec.modality == "text":
|
| 42 |
+
tm = f"\n target_modules={cfg.target_modules!r},"
|
| 43 |
+
lora_block = f"""
|
| 44 |
+
# --- 2. Attach LoRA adapters -------------------------------------------------
|
| 45 |
+
model = {spec.loader}.get_peft_model(
|
| 46 |
+
model,
|
| 47 |
+
r={cfg.lora_r},
|
| 48 |
+
lora_alpha={cfg.lora_alpha},{tm}
|
| 49 |
+
)
|
| 50 |
+
"""
|
| 51 |
+
|
| 52 |
+
# --- dataset ------------------------------------------------------------------
|
| 53 |
+
if dataset_is_local:
|
| 54 |
+
ds_block = f"""import json
|
| 55 |
+
|
| 56 |
+
with open({dataset_source!r}) as f:
|
| 57 |
+
train_dataset = [json.loads(line) for line in f if line.strip()]"""
|
| 58 |
+
elif dataset_source:
|
| 59 |
+
ds_block = f"""from datasets import load_dataset
|
| 60 |
+
|
| 61 |
+
train_dataset = load_dataset({dataset_source!r}, split="train")"""
|
| 62 |
+
else:
|
| 63 |
+
ds_block = """train_dataset = [
|
| 64 |
+
# TODO: fill with rows shaped like: %s
|
| 65 |
+
]""" % (dict.fromkeys(spec.dataset_schema, "..."),)
|
| 66 |
+
|
| 67 |
+
collator_line = ""
|
| 68 |
+
if spec.collator:
|
| 69 |
+
collator_line = f"\n data_collator={spec.collator}(model, {handle}),"
|
| 70 |
+
|
| 71 |
+
return f'''"""Auto-generated by Finetuner Studio — https://github.com/aykutcayir34/finetuner
|
| 72 |
+
|
| 73 |
+
Task : {spec.label}
|
| 74 |
+
Backend : mlx-tune (https://github.com/ARahim3/mlx-tune)
|
| 75 |
+
|
| 76 |
+
This script is fully standalone: edit it, version it, or move it to a cloud
|
| 77 |
+
GPU box — the mlx-tune API is Unsloth-compatible by design.
|
| 78 |
+
"""
|
| 79 |
+
|
| 80 |
+
{import_line}{extra_import}
|
| 81 |
+
|
| 82 |
+
# --- 1. Load the base model ----------------------------------------------------
|
| 83 |
+
model, {handle} = {spec.loader}.from_pretrained(
|
| 84 |
+
{cfg.model_name!r}{load_kwargs}
|
| 85 |
+
)
|
| 86 |
+
{lora_block}
|
| 87 |
+
# --- 3. Dataset ------------------------------------------------------------------
|
| 88 |
+
{ds_block}
|
| 89 |
+
|
| 90 |
+
# --- 4. Train ----------------------------------------------------------------------
|
| 91 |
+
trainer = {spec.trainer}(
|
| 92 |
+
model=model,
|
| 93 |
+
{handle}={handle},
|
| 94 |
+
train_dataset=train_dataset,{collator_line}
|
| 95 |
+
args={spec.config}(
|
| 96 |
+
{args_body},
|
| 97 |
+
),
|
| 98 |
+
)
|
| 99 |
+
trainer.train()
|
| 100 |
+
|
| 101 |
+
# --- 5. Save -----------------------------------------------------------------------
|
| 102 |
+
model.save_pretrained("lora_model") # adapters only
|
| 103 |
+
# model.save_pretrained_merged("merged", {handle}) # merged fp16 model
|
| 104 |
+
# model.push_to_hub("your-username/your-model") # upload to the Hub
|
| 105 |
+
'''
|
finetuner/core/data.py
ADDED
|
@@ -0,0 +1,69 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""Dataset loading from the Hugging Face Hub or local files."""
|
| 2 |
+
|
| 3 |
+
from __future__ import annotations
|
| 4 |
+
|
| 5 |
+
import csv
|
| 6 |
+
import json
|
| 7 |
+
from pathlib import Path
|
| 8 |
+
|
| 9 |
+
from huggingface_hub import HfApi
|
| 10 |
+
|
| 11 |
+
|
| 12 |
+
def search_hub_datasets(query: str, limit: int = 20) -> list[str]:
|
| 13 |
+
if not query.strip():
|
| 14 |
+
return []
|
| 15 |
+
api = HfApi()
|
| 16 |
+
return [d.id for d in api.list_datasets(search=query, limit=limit, sort="downloads")]
|
| 17 |
+
|
| 18 |
+
|
| 19 |
+
def load_hub_dataset(name: str, split: str = "train", config: str | None = None,
|
| 20 |
+
max_rows: int | None = None) -> list[dict]:
|
| 21 |
+
from datasets import load_dataset
|
| 22 |
+
|
| 23 |
+
ds = load_dataset(name, config or None, split=split)
|
| 24 |
+
if max_rows:
|
| 25 |
+
ds = ds.select(range(min(max_rows, len(ds))))
|
| 26 |
+
return [dict(r) for r in ds]
|
| 27 |
+
|
| 28 |
+
|
| 29 |
+
def load_local_dataset(path: str, max_rows: int | None = None) -> list[dict]:
|
| 30 |
+
"""Load a local dataset file: .jsonl, .json, .csv, .tsv or .parquet."""
|
| 31 |
+
p = Path(path).expanduser()
|
| 32 |
+
if not p.exists():
|
| 33 |
+
raise FileNotFoundError(f"No such file: {p}")
|
| 34 |
+
suffix = p.suffix.lower()
|
| 35 |
+
|
| 36 |
+
rows: list[dict]
|
| 37 |
+
if suffix == ".jsonl":
|
| 38 |
+
rows = []
|
| 39 |
+
with p.open() as f:
|
| 40 |
+
for line in f:
|
| 41 |
+
line = line.strip()
|
| 42 |
+
if line:
|
| 43 |
+
rows.append(json.loads(line))
|
| 44 |
+
if max_rows and len(rows) >= max_rows:
|
| 45 |
+
break
|
| 46 |
+
elif suffix == ".json":
|
| 47 |
+
data = json.loads(p.read_text())
|
| 48 |
+
if isinstance(data, dict): # e.g. {"data": [...]} wrappers
|
| 49 |
+
for v in data.values():
|
| 50 |
+
if isinstance(v, list):
|
| 51 |
+
data = v
|
| 52 |
+
break
|
| 53 |
+
if not isinstance(data, list):
|
| 54 |
+
raise ValueError("JSON file must contain a list of records.")
|
| 55 |
+
rows = data
|
| 56 |
+
elif suffix in (".csv", ".tsv"):
|
| 57 |
+
delim = "\t" if suffix == ".tsv" else ","
|
| 58 |
+
with p.open(newline="") as f:
|
| 59 |
+
rows = list(csv.DictReader(f, delimiter=delim))
|
| 60 |
+
elif suffix == ".parquet":
|
| 61 |
+
import pandas as pd
|
| 62 |
+
rows = pd.read_parquet(p).to_dict("records")
|
| 63 |
+
else:
|
| 64 |
+
raise ValueError(f"Unsupported dataset format: {suffix} "
|
| 65 |
+
"(supported: .jsonl, .json, .csv, .tsv, .parquet)")
|
| 66 |
+
|
| 67 |
+
if max_rows:
|
| 68 |
+
rows = rows[:max_rows]
|
| 69 |
+
return rows
|
finetuner/core/detector.py
ADDED
|
@@ -0,0 +1,292 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""Automatic dataset format detection and normalization.
|
| 2 |
+
|
| 3 |
+
Given a few sample rows, classify the dataset into one of the canonical
|
| 4 |
+
formats mlx-tune trainers understand, propose a column mapping, and convert
|
| 5 |
+
rows into the exact shape the selected trainer expects.
|
| 6 |
+
"""
|
| 7 |
+
|
| 8 |
+
from __future__ import annotations
|
| 9 |
+
|
| 10 |
+
from dataclasses import dataclass, field
|
| 11 |
+
|
| 12 |
+
# Column-name synonyms (English + Turkish), matched case-insensitively after
|
| 13 |
+
# stripping whitespace — real-world CSV headers are messy.
|
| 14 |
+
SYNONYMS = {
|
| 15 |
+
"instruction": {"instruction", "question", "query", "instruct", "task",
|
| 16 |
+
"talimat", "soru", "görev"},
|
| 17 |
+
"input": {"input", "context", "system", "giriş", "girdi", "bağlam"},
|
| 18 |
+
"output": {"output", "response", "answer", "completion", "target",
|
| 19 |
+
"çıktı", "cevap", "yanıt"},
|
| 20 |
+
"prompt": {"prompt", "question", "query", "instruction", "istem", "soru"},
|
| 21 |
+
"chosen": {"chosen", "preferred", "accepted", "good", "seçilen", "tercih"},
|
| 22 |
+
"rejected": {"rejected", "dispreferred", "bad", "reddedilen"},
|
| 23 |
+
"completion": {"completion", "response", "output", "answer", "cevap", "yanıt"},
|
| 24 |
+
"label": {"label", "thumbs_up", "is_good", "score", "etiket"},
|
| 25 |
+
"conversations": {"conversations", "messages", "dialogue", "dialog", "chat",
|
| 26 |
+
"turns", "konuşmalar", "mesajlar"},
|
| 27 |
+
"text": {"text", "content", "document", "body", "metin", "içerik"},
|
| 28 |
+
"anchor": {"anchor", "query", "sentence1", "question", "soru"},
|
| 29 |
+
"positive": {"positive", "passage", "sentence2", "answer", "document"},
|
| 30 |
+
"audio": {"audio", "audio_path", "audio_filepath", "wav", "file", "path", "ses"},
|
| 31 |
+
"transcription": {"text", "transcription", "transcript", "sentence", "caption",
|
| 32 |
+
"metin", "cümle"},
|
| 33 |
+
"image": {"image", "images", "image_path", "img", "picture", "görsel", "resim"},
|
| 34 |
+
}
|
| 35 |
+
|
| 36 |
+
FORMAT_LABELS = {
|
| 37 |
+
"alpaca": "Alpaca (instruction / input / output)",
|
| 38 |
+
"sharegpt": "ShareGPT (conversations with from/value turns)",
|
| 39 |
+
"chatml": "ChatML / OpenAI messages (role/content turns)",
|
| 40 |
+
"prompt_completion": "Prompt–completion pairs",
|
| 41 |
+
"preference": "Preference pairs (prompt / chosen / rejected)",
|
| 42 |
+
"kto": "KTO binary feedback (prompt / completion / label)",
|
| 43 |
+
"grpo": "GRPO prompts (prompt, optional answer)",
|
| 44 |
+
"text": "Raw text corpus",
|
| 45 |
+
"embedding_pairs": "Embedding pairs (anchor / positive)",
|
| 46 |
+
"audio_text": "Audio + text (TTS / STT)",
|
| 47 |
+
"vision_chat": "Image + conversation (VLM)",
|
| 48 |
+
"image_text": "Image + ground-truth text (OCR)",
|
| 49 |
+
"unknown": "Unknown — manual mapping required",
|
| 50 |
+
}
|
| 51 |
+
|
| 52 |
+
|
| 53 |
+
@dataclass
|
| 54 |
+
class Detection:
|
| 55 |
+
format: str
|
| 56 |
+
confidence: float # 0..1
|
| 57 |
+
mapping: dict[str, str] = field(default_factory=dict) # canonical -> actual column
|
| 58 |
+
suggested_tasks: list[str] = field(default_factory=list)
|
| 59 |
+
notes: list[str] = field(default_factory=list)
|
| 60 |
+
|
| 61 |
+
@property
|
| 62 |
+
def label(self) -> str:
|
| 63 |
+
return FORMAT_LABELS.get(self.format, self.format)
|
| 64 |
+
|
| 65 |
+
|
| 66 |
+
def _find(columns: list[str], canonical: str) -> str | None:
|
| 67 |
+
"""Find the actual column matching a canonical field name."""
|
| 68 |
+
lowered = {c.strip().lower(): c for c in columns}
|
| 69 |
+
if canonical in lowered:
|
| 70 |
+
return lowered[canonical]
|
| 71 |
+
for syn in SYNONYMS.get(canonical, ()):
|
| 72 |
+
if syn in lowered:
|
| 73 |
+
return lowered[syn]
|
| 74 |
+
return None
|
| 75 |
+
|
| 76 |
+
|
| 77 |
+
def _is_turn_list(value) -> str | None:
|
| 78 |
+
"""Classify a list-of-dicts column as 'sharegpt' or 'chatml' turns."""
|
| 79 |
+
if not isinstance(value, list) or not value or not isinstance(value[0], dict):
|
| 80 |
+
return None
|
| 81 |
+
keys = set(value[0].keys())
|
| 82 |
+
if {"from", "value"} <= keys:
|
| 83 |
+
return "sharegpt"
|
| 84 |
+
if {"role", "content"} <= keys:
|
| 85 |
+
return "chatml"
|
| 86 |
+
return None
|
| 87 |
+
|
| 88 |
+
|
| 89 |
+
def _looks_like_path(value, exts: tuple[str, ...]) -> bool:
|
| 90 |
+
return isinstance(value, str) and value.lower().endswith(exts)
|
| 91 |
+
|
| 92 |
+
|
| 93 |
+
AUDIO_EXTS = (".wav", ".mp3", ".flac", ".m4a", ".ogg")
|
| 94 |
+
IMAGE_EXTS = (".png", ".jpg", ".jpeg", ".webp", ".bmp", ".tiff")
|
| 95 |
+
|
| 96 |
+
|
| 97 |
+
def detect(rows: list[dict]) -> Detection:
|
| 98 |
+
"""Detect the dataset format from sample rows (a handful is enough)."""
|
| 99 |
+
if not rows:
|
| 100 |
+
return Detection("unknown", 0.0, notes=["Dataset is empty."])
|
| 101 |
+
|
| 102 |
+
row = rows[0]
|
| 103 |
+
columns = list(row.keys())
|
| 104 |
+
from .registry import tasks_for_format # local import to avoid cycles
|
| 105 |
+
|
| 106 |
+
def done(fmt: str, conf: float, mapping: dict[str, str], notes: list[str] | None = None) -> Detection:
|
| 107 |
+
tasks = [t.id for t in tasks_for_format(fmt)]
|
| 108 |
+
return Detection(fmt, conf, mapping, tasks, notes or [])
|
| 109 |
+
|
| 110 |
+
# --- multimodal first: presence of media columns dominates -------------
|
| 111 |
+
img_col = _find(columns, "image")
|
| 112 |
+
audio_col = _find(columns, "audio")
|
| 113 |
+
conv_col = _find(columns, "conversations")
|
| 114 |
+
|
| 115 |
+
if img_col is not None:
|
| 116 |
+
sample = row.get(img_col)
|
| 117 |
+
media_like = _looks_like_path(sample, IMAGE_EXTS) or not isinstance(sample, str)
|
| 118 |
+
if conv_col is not None:
|
| 119 |
+
return done("vision_chat", 0.9 if media_like else 0.6,
|
| 120 |
+
{"images": img_col, "messages": conv_col})
|
| 121 |
+
text_col = _find([c for c in columns if c != img_col], "text")
|
| 122 |
+
if text_col is not None:
|
| 123 |
+
return done("image_text", 0.85 if media_like else 0.55,
|
| 124 |
+
{"image": img_col, "text": text_col},
|
| 125 |
+
["Image + text detected — suitable for OCR SFT or vision tasks."])
|
| 126 |
+
|
| 127 |
+
if audio_col is not None:
|
| 128 |
+
sample = row.get(audio_col)
|
| 129 |
+
if _looks_like_path(sample, AUDIO_EXTS) or isinstance(sample, dict):
|
| 130 |
+
text_col = _find([c for c in columns if c != audio_col], "transcription")
|
| 131 |
+
if text_col is not None:
|
| 132 |
+
return done("audio_text", 0.9, {"audio": audio_col, "text": text_col},
|
| 133 |
+
["Audio + text detected — choose TTS (synthesis) or STT (recognition)."])
|
| 134 |
+
|
| 135 |
+
# --- conversation formats ----------------------------------------------
|
| 136 |
+
if conv_col is not None:
|
| 137 |
+
kind = _is_turn_list(row.get(conv_col))
|
| 138 |
+
if kind == "sharegpt":
|
| 139 |
+
return done("sharegpt", 0.95, {"conversations": conv_col})
|
| 140 |
+
if kind == "chatml":
|
| 141 |
+
return done("chatml", 0.95, {"messages": conv_col})
|
| 142 |
+
|
| 143 |
+
# --- preference / feedback ----------------------------------------------
|
| 144 |
+
chosen = _find(columns, "chosen")
|
| 145 |
+
rejected = _find(columns, "rejected")
|
| 146 |
+
prompt = _find(columns, "prompt")
|
| 147 |
+
if chosen and rejected:
|
| 148 |
+
mapping = {"chosen": chosen, "rejected": rejected}
|
| 149 |
+
if prompt:
|
| 150 |
+
mapping["prompt"] = prompt
|
| 151 |
+
return done("preference", 0.95, mapping)
|
| 152 |
+
return done("preference", 0.75, mapping,
|
| 153 |
+
["No explicit prompt column; chosen/rejected may embed the prompt."])
|
| 154 |
+
|
| 155 |
+
label = _find(columns, "label")
|
| 156 |
+
completion = _find(columns, "completion")
|
| 157 |
+
if prompt and completion and label is not None:
|
| 158 |
+
if isinstance(row.get(label), (bool, int)):
|
| 159 |
+
return done("kto", 0.9, {"prompt": prompt, "completion": completion, "label": label})
|
| 160 |
+
|
| 161 |
+
# --- instruction tuning ---------------------------------------------------
|
| 162 |
+
instruction = _find(columns, "instruction")
|
| 163 |
+
output = _find(columns, "output")
|
| 164 |
+
if instruction and output:
|
| 165 |
+
mapping = {"instruction": instruction, "output": output}
|
| 166 |
+
inp = _find([c for c in columns if c not in (instruction, output)], "input")
|
| 167 |
+
if inp:
|
| 168 |
+
mapping["input"] = inp
|
| 169 |
+
return done("alpaca", 0.95, mapping)
|
| 170 |
+
|
| 171 |
+
if prompt and completion:
|
| 172 |
+
return done("prompt_completion", 0.9, {"prompt": prompt, "completion": completion})
|
| 173 |
+
|
| 174 |
+
# --- embeddings -----------------------------------------------------------
|
| 175 |
+
anchor = _find(columns, "anchor")
|
| 176 |
+
positive = _find(columns, "positive")
|
| 177 |
+
if anchor and positive and anchor != positive:
|
| 178 |
+
return done("embedding_pairs", 0.8, {"anchor": anchor, "positive": positive},
|
| 179 |
+
["Anchor/positive pair detected — embedding contrastive training."])
|
| 180 |
+
|
| 181 |
+
# --- GRPO: bare prompts -----------------------------------------------------
|
| 182 |
+
if prompt and len(columns) <= 2:
|
| 183 |
+
return done("grpo", 0.6, {"prompt": prompt},
|
| 184 |
+
["Bare prompts — usable for GRPO with a custom reward function."])
|
| 185 |
+
|
| 186 |
+
# --- raw text ----------------------------------------------------------------
|
| 187 |
+
text = _find(columns, "text")
|
| 188 |
+
if text:
|
| 189 |
+
return done("text", 0.85, {"text": text},
|
| 190 |
+
["Raw text — suitable for CPT or completion-style SFT."])
|
| 191 |
+
|
| 192 |
+
# --- single string column fallback ---------------------------------------------
|
| 193 |
+
str_cols = [c for c in columns if isinstance(row.get(c), str)]
|
| 194 |
+
if len(str_cols) == 1:
|
| 195 |
+
return done("text", 0.5, {"text": str_cols[0]},
|
| 196 |
+
[f"Single text column `{str_cols[0]}` assumed to be raw text."])
|
| 197 |
+
|
| 198 |
+
return Detection("unknown", 0.0, {},
|
| 199 |
+
notes=[f"Could not classify columns: {columns}. Map fields manually."])
|
| 200 |
+
|
| 201 |
+
|
| 202 |
+
# ---------------------------------------------------------------------------
|
| 203 |
+
# Normalization: convert detected rows into trainer-ready shape
|
| 204 |
+
# ---------------------------------------------------------------------------
|
| 205 |
+
|
| 206 |
+
def _format_chat(turns: list[dict], tokenizer=None) -> str:
|
| 207 |
+
"""Render chat turns to text via the tokenizer's chat template when possible."""
|
| 208 |
+
if tokenizer is not None and hasattr(tokenizer, "apply_chat_template"):
|
| 209 |
+
try:
|
| 210 |
+
return tokenizer.apply_chat_template(turns, tokenize=False, add_generation_prompt=False)
|
| 211 |
+
except Exception:
|
| 212 |
+
pass
|
| 213 |
+
return "\n".join(f"<|{t['role']}|>\n{t['content']}" for t in turns) + "\n"
|
| 214 |
+
|
| 215 |
+
|
| 216 |
+
SHAREGPT_ROLES = {"human": "user", "user": "user", "gpt": "assistant",
|
| 217 |
+
"assistant": "assistant", "system": "system"}
|
| 218 |
+
|
| 219 |
+
|
| 220 |
+
def to_messages(row: dict, detection: Detection) -> list[dict]:
|
| 221 |
+
"""Convert a row of any chat-like format into ChatML messages."""
|
| 222 |
+
m = detection.mapping
|
| 223 |
+
fmt = detection.format
|
| 224 |
+
if fmt == "chatml":
|
| 225 |
+
return row[m["messages"]]
|
| 226 |
+
if fmt == "sharegpt":
|
| 227 |
+
return [{"role": SHAREGPT_ROLES.get(t["from"], "user"), "content": t["value"]}
|
| 228 |
+
for t in row[m["conversations"]]]
|
| 229 |
+
if fmt == "alpaca":
|
| 230 |
+
user = row[m["instruction"]]
|
| 231 |
+
if "input" in m and row.get(m["input"]):
|
| 232 |
+
user = f"{user}\n\n{row[m['input']]}"
|
| 233 |
+
return [{"role": "user", "content": user},
|
| 234 |
+
{"role": "assistant", "content": row[m["output"]]}]
|
| 235 |
+
if fmt == "prompt_completion":
|
| 236 |
+
return [{"role": "user", "content": row[m["prompt"]]},
|
| 237 |
+
{"role": "assistant", "content": row[m["completion"]]}]
|
| 238 |
+
raise ValueError(f"Cannot build messages from format {fmt!r}")
|
| 239 |
+
|
| 240 |
+
|
| 241 |
+
def normalize(rows: list[dict], detection: Detection, task_id: str, tokenizer=None) -> list[dict]:
|
| 242 |
+
"""Convert raw rows into the schema the chosen task's trainer expects."""
|
| 243 |
+
m = detection.mapping
|
| 244 |
+
fmt = detection.format
|
| 245 |
+
|
| 246 |
+
if task_id in ("sft",):
|
| 247 |
+
if fmt == "text":
|
| 248 |
+
return [{"text": r[m["text"]]} for r in rows]
|
| 249 |
+
return [{"text": _format_chat(to_messages(r, detection), tokenizer)} for r in rows]
|
| 250 |
+
|
| 251 |
+
if task_id == "cpt":
|
| 252 |
+
col = m.get("text")
|
| 253 |
+
if col is None:
|
| 254 |
+
raise ValueError("CPT needs a raw text column.")
|
| 255 |
+
return [{"text": r[col]} for r in rows]
|
| 256 |
+
|
| 257 |
+
if task_id in ("dpo", "orpo", "simpo"):
|
| 258 |
+
out = []
|
| 259 |
+
for r in rows:
|
| 260 |
+
item = {"chosen": r[m["chosen"]], "rejected": r[m["rejected"]]}
|
| 261 |
+
item["prompt"] = r[m["prompt"]] if "prompt" in m else ""
|
| 262 |
+
out.append(item)
|
| 263 |
+
return out
|
| 264 |
+
|
| 265 |
+
if task_id == "kto":
|
| 266 |
+
return [{"prompt": r[m["prompt"]], "completion": r[m["completion"]],
|
| 267 |
+
"label": bool(r[m["label"]])} for r in rows]
|
| 268 |
+
|
| 269 |
+
if task_id == "grpo":
|
| 270 |
+
return [{"prompt": r[m["prompt"]]} for r in rows]
|
| 271 |
+
|
| 272 |
+
if task_id == "embedding":
|
| 273 |
+
return [{"anchor": r[m["anchor"]], "positive": r[m["positive"]]} for r in rows]
|
| 274 |
+
|
| 275 |
+
if task_id in ("tts_sft", "stt_sft"):
|
| 276 |
+
return [{"audio": r[m["audio"]], "text": r[m["text"]]} for r in rows]
|
| 277 |
+
|
| 278 |
+
if task_id == "ocr_sft":
|
| 279 |
+
return [{"image": r[m["image"]], "text": r[m["text"]]} for r in rows]
|
| 280 |
+
|
| 281 |
+
if task_id == "vlm_sft":
|
| 282 |
+
out = []
|
| 283 |
+
for r in rows:
|
| 284 |
+
sub = {"images": r[m["images"]], "messages": r[m["messages"]]}
|
| 285 |
+
if isinstance(sub["messages"], list) and sub["messages"] \
|
| 286 |
+
and "from" in (sub["messages"][0] or {}):
|
| 287 |
+
sub["messages"] = [{"role": SHAREGPT_ROLES.get(t["from"], "user"),
|
| 288 |
+
"content": t["value"]} for t in sub["messages"]]
|
| 289 |
+
out.append(sub)
|
| 290 |
+
return out
|
| 291 |
+
|
| 292 |
+
raise ValueError(f"Unknown task {task_id!r}")
|
finetuner/core/engine.py
ADDED
|
@@ -0,0 +1,56 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""Single persistent MLX engine thread.
|
| 2 |
+
|
| 3 |
+
MLX streams are thread-local: a model loaded on one thread cannot reliably be
|
| 4 |
+
trained or sampled from another ("There is no Stream(gpu, N) in current
|
| 5 |
+
thread"). Gradio runs every event handler on a different worker thread, so all
|
| 6 |
+
MLX work — model loading, training, generation — is funneled through one
|
| 7 |
+
long-lived engine thread. This also serializes GPU work, which is what a
|
| 8 |
+
single-device machine wants anyway.
|
| 9 |
+
"""
|
| 10 |
+
|
| 11 |
+
from __future__ import annotations
|
| 12 |
+
|
| 13 |
+
import os
|
| 14 |
+
import queue
|
| 15 |
+
import sys
|
| 16 |
+
import threading
|
| 17 |
+
from pathlib import Path
|
| 18 |
+
|
| 19 |
+
# mlx-tune's subprocess fallback shells out to `mlx_lm.lora`; make sure the
|
| 20 |
+
# interpreter's bin directory is on PATH even when the venv isn't activated.
|
| 21 |
+
_bin = str(Path(sys.executable).parent)
|
| 22 |
+
if _bin not in os.environ.get("PATH", "").split(os.pathsep):
|
| 23 |
+
os.environ["PATH"] = _bin + os.pathsep + os.environ.get("PATH", "")
|
| 24 |
+
|
| 25 |
+
|
| 26 |
+
class _Engine:
|
| 27 |
+
def __init__(self):
|
| 28 |
+
self._q: queue.Queue = queue.Queue()
|
| 29 |
+
self._thread = threading.Thread(target=self._loop, name="finetuner-mlx-engine",
|
| 30 |
+
daemon=True)
|
| 31 |
+
self._thread.start()
|
| 32 |
+
|
| 33 |
+
def _loop(self):
|
| 34 |
+
while True:
|
| 35 |
+
fn, args, kwargs, done, box = self._q.get()
|
| 36 |
+
try:
|
| 37 |
+
box["result"] = fn(*args, **kwargs)
|
| 38 |
+
except BaseException as exc: # noqa: BLE001 — re-raised on the caller thread
|
| 39 |
+
box["error"] = exc
|
| 40 |
+
finally:
|
| 41 |
+
done.set()
|
| 42 |
+
|
| 43 |
+
def call(self, fn, *args, **kwargs):
|
| 44 |
+
"""Run `fn` on the engine thread and block until it returns."""
|
| 45 |
+
if threading.current_thread() is self._thread:
|
| 46 |
+
return fn(*args, **kwargs)
|
| 47 |
+
done = threading.Event()
|
| 48 |
+
box: dict = {}
|
| 49 |
+
self._q.put((fn, args, kwargs, done, box))
|
| 50 |
+
done.wait()
|
| 51 |
+
if "error" in box:
|
| 52 |
+
raise box["error"]
|
| 53 |
+
return box["result"]
|
| 54 |
+
|
| 55 |
+
|
| 56 |
+
ENGINE = _Engine()
|
finetuner/core/export.py
ADDED
|
@@ -0,0 +1,28 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""Export trained models: adapters, merged weights, GGUF, or push to the Hub."""
|
| 2 |
+
|
| 3 |
+
from __future__ import annotations
|
| 4 |
+
|
| 5 |
+
|
| 6 |
+
def save_adapters(model, path: str) -> str:
|
| 7 |
+
model.save_pretrained(path)
|
| 8 |
+
return f"LoRA adapters saved to {path}"
|
| 9 |
+
|
| 10 |
+
|
| 11 |
+
def save_merged(model, tokenizer, path: str) -> str:
|
| 12 |
+
model.save_pretrained_merged(path, tokenizer)
|
| 13 |
+
return f"Merged 16-bit model saved to {path}"
|
| 14 |
+
|
| 15 |
+
|
| 16 |
+
def save_gguf(model, tokenizer, path: str) -> str:
|
| 17 |
+
# mlx-lm limitation: GGUF export requires a non-quantized base model.
|
| 18 |
+
model.save_pretrained_gguf(path, tokenizer)
|
| 19 |
+
return f"GGUF model saved to {path}"
|
| 20 |
+
|
| 21 |
+
|
| 22 |
+
def push_to_hub(model, repo_id: str, token: str | None = None) -> str:
|
| 23 |
+
kwargs = {"token": token} if token else {}
|
| 24 |
+
try:
|
| 25 |
+
model.push_to_hub(repo_id, **kwargs)
|
| 26 |
+
except TypeError:
|
| 27 |
+
model.push_to_hub(repo_id)
|
| 28 |
+
return f"Pushed to https://huggingface.co/{repo_id}"
|
finetuner/core/jobs.py
ADDED
|
@@ -0,0 +1,135 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""Background training job manager.
|
| 2 |
+
|
| 3 |
+
Training runs on a worker thread so the GUI stays responsive. Trainer stdout
|
| 4 |
+
is captured into a ring buffer; loss values are parsed out of the log stream
|
| 5 |
+
for live charting.
|
| 6 |
+
"""
|
| 7 |
+
|
| 8 |
+
from __future__ import annotations
|
| 9 |
+
|
| 10 |
+
import contextlib
|
| 11 |
+
import io
|
| 12 |
+
import re
|
| 13 |
+
import threading
|
| 14 |
+
import time
|
| 15 |
+
import traceback
|
| 16 |
+
from collections import deque
|
| 17 |
+
from dataclasses import dataclass, field
|
| 18 |
+
|
| 19 |
+
# Matches lines like "step 10: loss 1.2345", "{'loss': 1.23, 'step': 10}", "10/100 | loss: 1.23"
|
| 20 |
+
_LOSS_RE = re.compile(r"loss[\"']?[:=\s]+([0-9]*\.?[0-9]+(?:e-?\d+)?)", re.IGNORECASE)
|
| 21 |
+
_STEP_RE = re.compile(r"(?:step|it(?:er)?)[\"']?[:=\s/]+(\d+)", re.IGNORECASE)
|
| 22 |
+
|
| 23 |
+
|
| 24 |
+
class _Tee(io.TextIOBase):
|
| 25 |
+
"""Write-through stream that feeds the job log and parses metrics."""
|
| 26 |
+
|
| 27 |
+
def __init__(self, job: "Job", original):
|
| 28 |
+
self.job = job
|
| 29 |
+
self.original = original
|
| 30 |
+
self._buf = ""
|
| 31 |
+
|
| 32 |
+
def write(self, s: str) -> int:
|
| 33 |
+
self.original.write(s)
|
| 34 |
+
self._buf += s
|
| 35 |
+
parts = re.split(r"[\n\r]", self._buf)
|
| 36 |
+
self._buf = parts.pop() # keep the unterminated tail
|
| 37 |
+
for line in parts:
|
| 38 |
+
if line.strip():
|
| 39 |
+
self.job.add_log(line)
|
| 40 |
+
return len(s)
|
| 41 |
+
|
| 42 |
+
def flush(self):
|
| 43 |
+
self.original.flush()
|
| 44 |
+
|
| 45 |
+
|
| 46 |
+
@dataclass
|
| 47 |
+
class Job:
|
| 48 |
+
id: int
|
| 49 |
+
name: str
|
| 50 |
+
status: str = "pending" # pending | running | finished | failed | stopped
|
| 51 |
+
logs: deque = field(default_factory=lambda: deque(maxlen=2000))
|
| 52 |
+
metrics: list = field(default_factory=list) # [(step, loss)]
|
| 53 |
+
error: str | None = None
|
| 54 |
+
started_at: float | None = None
|
| 55 |
+
finished_at: float | None = None
|
| 56 |
+
stop_event: threading.Event = field(default_factory=threading.Event)
|
| 57 |
+
_step_guess: int = 0
|
| 58 |
+
|
| 59 |
+
def add_log(self, line: str):
|
| 60 |
+
self.logs.append(line)
|
| 61 |
+
loss = _LOSS_RE.search(line)
|
| 62 |
+
if loss:
|
| 63 |
+
step_m = _STEP_RE.search(line)
|
| 64 |
+
if step_m:
|
| 65 |
+
self._step_guess = int(step_m.group(1))
|
| 66 |
+
else:
|
| 67 |
+
self._step_guess += 1
|
| 68 |
+
try:
|
| 69 |
+
self.metrics.append((self._step_guess, float(loss.group(1))))
|
| 70 |
+
except ValueError:
|
| 71 |
+
pass
|
| 72 |
+
|
| 73 |
+
def log_text(self, last_n: int = 200) -> str:
|
| 74 |
+
return "\n".join(list(self.logs)[-last_n:])
|
| 75 |
+
|
| 76 |
+
@property
|
| 77 |
+
def elapsed(self) -> float:
|
| 78 |
+
if self.started_at is None:
|
| 79 |
+
return 0.0
|
| 80 |
+
end = self.finished_at or time.time()
|
| 81 |
+
return end - self.started_at
|
| 82 |
+
|
| 83 |
+
|
| 84 |
+
class JobManager:
|
| 85 |
+
def __init__(self):
|
| 86 |
+
self._jobs: dict[int, Job] = {}
|
| 87 |
+
self._next_id = 1
|
| 88 |
+
self._lock = threading.Lock()
|
| 89 |
+
|
| 90 |
+
def submit(self, name: str, target, *args, **kwargs) -> Job:
|
| 91 |
+
"""Run `target(job, *args, **kwargs)` on a worker thread with log capture."""
|
| 92 |
+
with self._lock:
|
| 93 |
+
job = Job(id=self._next_id, name=name)
|
| 94 |
+
self._jobs[job.id] = job
|
| 95 |
+
self._next_id += 1
|
| 96 |
+
|
| 97 |
+
def runner():
|
| 98 |
+
job.status = "running"
|
| 99 |
+
job.started_at = time.time()
|
| 100 |
+
tee_out = _Tee(job, __import__("sys").stdout)
|
| 101 |
+
tee_err = _Tee(job, __import__("sys").stderr)
|
| 102 |
+
from .engine import ENGINE # local import: keep jobs importable standalone
|
| 103 |
+
try:
|
| 104 |
+
with contextlib.redirect_stdout(tee_out), contextlib.redirect_stderr(tee_err):
|
| 105 |
+
# All MLX work must run on the single engine thread; the
|
| 106 |
+
# stdout redirect is process-wide, so logs still reach us.
|
| 107 |
+
ENGINE.call(target, job, *args, **kwargs)
|
| 108 |
+
job.status = "stopped" if job.stop_event.is_set() else "finished"
|
| 109 |
+
except Exception:
|
| 110 |
+
job.error = traceback.format_exc()
|
| 111 |
+
job.add_log(job.error)
|
| 112 |
+
job.status = "failed"
|
| 113 |
+
finally:
|
| 114 |
+
job.finished_at = time.time()
|
| 115 |
+
|
| 116 |
+
threading.Thread(target=runner, name=f"finetuner-job-{job.id}", daemon=True).start()
|
| 117 |
+
return job
|
| 118 |
+
|
| 119 |
+
def get(self, job_id: int) -> Job | None:
|
| 120 |
+
return self._jobs.get(job_id)
|
| 121 |
+
|
| 122 |
+
def all(self) -> list[Job]:
|
| 123 |
+
return list(self._jobs.values())
|
| 124 |
+
|
| 125 |
+
def latest(self) -> Job | None:
|
| 126 |
+
return self._jobs[max(self._jobs)] if self._jobs else None
|
| 127 |
+
|
| 128 |
+
def stop(self, job_id: int):
|
| 129 |
+
job = self._jobs.get(job_id)
|
| 130 |
+
if job:
|
| 131 |
+
job.stop_event.set()
|
| 132 |
+
job.add_log("⏹ Stop requested — training will halt at the next step boundary.")
|
| 133 |
+
|
| 134 |
+
|
| 135 |
+
MANAGER = JobManager()
|
finetuner/core/models.py
ADDED
|
@@ -0,0 +1,69 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""Model discovery (Hugging Face Hub) and loading via mlx-tune."""
|
| 2 |
+
|
| 3 |
+
from __future__ import annotations
|
| 4 |
+
|
| 5 |
+
from pathlib import Path
|
| 6 |
+
|
| 7 |
+
from huggingface_hub import HfApi
|
| 8 |
+
|
| 9 |
+
from .engine import ENGINE
|
| 10 |
+
from .registry import get_task, resolve
|
| 11 |
+
|
| 12 |
+
|
| 13 |
+
def search_hub_models(query: str, limit: int = 20, mlx_only: bool = True) -> list[str]:
|
| 14 |
+
"""Search the Hub; by default biased to mlx-community / MLX-tagged models."""
|
| 15 |
+
if not query.strip():
|
| 16 |
+
return []
|
| 17 |
+
api = HfApi()
|
| 18 |
+
kwargs: dict = {"search": query, "limit": limit, "sort": "downloads"}
|
| 19 |
+
if mlx_only:
|
| 20 |
+
kwargs["filter"] = "mlx"
|
| 21 |
+
try:
|
| 22 |
+
results = [m.id for m in api.list_models(**kwargs)]
|
| 23 |
+
except TypeError: # huggingface_hub version drift on the filter kwarg
|
| 24 |
+
kwargs.pop("filter", None)
|
| 25 |
+
results = [m.id for m in api.list_models(**kwargs)]
|
| 26 |
+
if not results and mlx_only: # fall back to an unrestricted search
|
| 27 |
+
results = [m.id for m in api.list_models(search=query, limit=limit, sort="downloads")]
|
| 28 |
+
return results
|
| 29 |
+
|
| 30 |
+
|
| 31 |
+
def validate_local_model(path: str) -> str:
|
| 32 |
+
p = Path(path).expanduser()
|
| 33 |
+
if not p.is_dir():
|
| 34 |
+
raise FileNotFoundError(f"Not a directory: {p}")
|
| 35 |
+
if not (p / "config.json").exists():
|
| 36 |
+
raise ValueError(f"{p} does not look like a model directory (missing config.json).")
|
| 37 |
+
return str(p)
|
| 38 |
+
|
| 39 |
+
|
| 40 |
+
def load_model(task_id: str, model_name: str, max_seq_length: int = 2048,
|
| 41 |
+
load_in_4bit: bool = True):
|
| 42 |
+
"""Load a model + tokenizer/processor through the task's mlx-tune loader."""
|
| 43 |
+
spec = get_task(task_id)
|
| 44 |
+
loader = resolve(spec.loader)
|
| 45 |
+
kwargs: dict = {}
|
| 46 |
+
if spec.modality == "text":
|
| 47 |
+
kwargs["max_seq_length"] = max_seq_length
|
| 48 |
+
kwargs["load_in_4bit"] = load_in_4bit
|
| 49 |
+
# MLX streams are thread-local — load on the engine thread (see engine.py).
|
| 50 |
+
return ENGINE.call(loader.from_pretrained, model_name, **kwargs)
|
| 51 |
+
|
| 52 |
+
|
| 53 |
+
def apply_lora(task_id: str, model, r: int = 16, lora_alpha: int = 16,
|
| 54 |
+
lora_dropout: float = 0.0, target_modules: list[str] | None = None,
|
| 55 |
+
**extra):
|
| 56 |
+
spec = get_task(task_id)
|
| 57 |
+
loader = resolve(spec.loader)
|
| 58 |
+
kwargs: dict = {"r": r, "lora_alpha": lora_alpha}
|
| 59 |
+
if lora_dropout:
|
| 60 |
+
kwargs["lora_dropout"] = lora_dropout
|
| 61 |
+
if spec.modality == "text" and target_modules:
|
| 62 |
+
kwargs["target_modules"] = list(target_modules)
|
| 63 |
+
kwargs.update(extra)
|
| 64 |
+
try:
|
| 65 |
+
return ENGINE.call(loader.get_peft_model, model, **kwargs)
|
| 66 |
+
except TypeError:
|
| 67 |
+
# Older mlx-tune versions may not accept every kwarg (e.g. lora_dropout).
|
| 68 |
+
kwargs.pop("lora_dropout", None)
|
| 69 |
+
return ENGINE.call(loader.get_peft_model, model, **kwargs)
|
finetuner/core/recipes.py
ADDED
|
@@ -0,0 +1,40 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""Recipes: save/load complete run configurations as shareable YAML files."""
|
| 2 |
+
|
| 3 |
+
from __future__ import annotations
|
| 4 |
+
|
| 5 |
+
import time
|
| 6 |
+
from pathlib import Path
|
| 7 |
+
|
| 8 |
+
import yaml
|
| 9 |
+
|
| 10 |
+
from .training import RunConfig
|
| 11 |
+
|
| 12 |
+
RECIPE_DIR = Path("recipes")
|
| 13 |
+
|
| 14 |
+
|
| 15 |
+
def save_recipe(cfg: RunConfig, name: str = "", dataset_source: str = "",
|
| 16 |
+
dataset_is_local: bool = False) -> Path:
|
| 17 |
+
RECIPE_DIR.mkdir(exist_ok=True)
|
| 18 |
+
slug = (name.strip() or f"{cfg.task}-{time.strftime('%Y%m%d-%H%M%S')}").replace(" ", "-")
|
| 19 |
+
path = RECIPE_DIR / f"{slug}.yaml"
|
| 20 |
+
payload = {
|
| 21 |
+
"finetuner_recipe": 1,
|
| 22 |
+
"dataset": {"source": dataset_source, "local": dataset_is_local},
|
| 23 |
+
"run": cfg.to_dict(),
|
| 24 |
+
}
|
| 25 |
+
path.write_text(yaml.safe_dump(payload, sort_keys=False, allow_unicode=True))
|
| 26 |
+
return path
|
| 27 |
+
|
| 28 |
+
|
| 29 |
+
def load_recipe(path: str) -> tuple[RunConfig, str, bool]:
|
| 30 |
+
data = yaml.safe_load(Path(path).expanduser().read_text())
|
| 31 |
+
if not isinstance(data, dict) or "run" not in data:
|
| 32 |
+
raise ValueError("Not a Finetuner recipe (missing `run` section).")
|
| 33 |
+
ds = data.get("dataset", {})
|
| 34 |
+
return RunConfig.from_dict(data["run"]), ds.get("source", ""), bool(ds.get("local", False))
|
| 35 |
+
|
| 36 |
+
|
| 37 |
+
def list_recipes() -> list[str]:
|
| 38 |
+
if not RECIPE_DIR.exists():
|
| 39 |
+
return []
|
| 40 |
+
return sorted(str(p) for p in RECIPE_DIR.glob("*.yaml"))
|
finetuner/core/registry.py
ADDED
|
@@ -0,0 +1,232 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""Task registry: one entry per mlx-tune training paradigm.
|
| 2 |
+
|
| 3 |
+
Every public trainer interface that mlx-tune exposes is described here so the
|
| 4 |
+
GUI, the code generator and the recipe system all share a single source of
|
| 5 |
+
truth. mlx-tune itself is imported lazily — the Studio runs (for planning,
|
| 6 |
+
dataset inspection, recipe authoring and script generation) even on machines
|
| 7 |
+
where MLX is not installed.
|
| 8 |
+
"""
|
| 9 |
+
|
| 10 |
+
from __future__ import annotations
|
| 11 |
+
|
| 12 |
+
import importlib
|
| 13 |
+
import importlib.util
|
| 14 |
+
import platform
|
| 15 |
+
from dataclasses import dataclass, field
|
| 16 |
+
|
| 17 |
+
|
| 18 |
+
@dataclass(frozen=True)
|
| 19 |
+
class TaskSpec:
|
| 20 |
+
id: str
|
| 21 |
+
label: str
|
| 22 |
+
description: str
|
| 23 |
+
loader: str # mlx_tune Fast*Model class name
|
| 24 |
+
trainer: str # mlx_tune trainer class name
|
| 25 |
+
config: str # mlx_tune config class name
|
| 26 |
+
config_module: str = "mlx_tune" # module to import the config from
|
| 27 |
+
collator: str | None = None # optional data collator class name
|
| 28 |
+
dataset_schema: tuple[str, ...] = () # canonical required fields
|
| 29 |
+
detector_formats: tuple[str, ...] = () # detector format ids that map to this task
|
| 30 |
+
default_model: str = ""
|
| 31 |
+
default_target_modules: tuple[str, ...] = (
|
| 32 |
+
"q_proj", "k_proj", "v_proj", "o_proj",
|
| 33 |
+
)
|
| 34 |
+
peft_supported: bool = True
|
| 35 |
+
extra_config_defaults: dict = field(default_factory=dict)
|
| 36 |
+
modality: str = "text" # text | vision | audio | image
|
| 37 |
+
notes: str = ""
|
| 38 |
+
|
| 39 |
+
|
| 40 |
+
FULL_TARGETS = ("q_proj", "k_proj", "v_proj", "o_proj", "gate_proj", "up_proj", "down_proj")
|
| 41 |
+
|
| 42 |
+
TASKS: dict[str, TaskSpec] = {
|
| 43 |
+
"sft": TaskSpec(
|
| 44 |
+
id="sft",
|
| 45 |
+
label="SFT — Supervised Fine-Tuning",
|
| 46 |
+
description="Instruction tuning of chat/completion LLMs on text or conversation data.",
|
| 47 |
+
loader="FastLanguageModel",
|
| 48 |
+
trainer="SFTTrainer",
|
| 49 |
+
config="SFTConfig",
|
| 50 |
+
dataset_schema=("text",),
|
| 51 |
+
detector_formats=("alpaca", "sharegpt", "chatml", "prompt_completion", "text"),
|
| 52 |
+
default_model="mlx-community/Llama-3.2-1B-Instruct-4bit",
|
| 53 |
+
),
|
| 54 |
+
"dpo": TaskSpec(
|
| 55 |
+
id="dpo",
|
| 56 |
+
label="DPO — Direct Preference Optimization",
|
| 57 |
+
description="Align a model with human preferences from chosen/rejected pairs.",
|
| 58 |
+
loader="FastLanguageModel",
|
| 59 |
+
trainer="DPOTrainer",
|
| 60 |
+
config="DPOConfig",
|
| 61 |
+
dataset_schema=("prompt", "chosen", "rejected"),
|
| 62 |
+
detector_formats=("preference",),
|
| 63 |
+
default_model="mlx-community/Llama-3.2-1B-Instruct-4bit",
|
| 64 |
+
extra_config_defaults={"beta": 0.1},
|
| 65 |
+
),
|
| 66 |
+
"orpo": TaskSpec(
|
| 67 |
+
id="orpo",
|
| 68 |
+
label="ORPO — Odds Ratio Preference Optimization",
|
| 69 |
+
description="Reference-free preference optimization; combines SFT and alignment in one pass.",
|
| 70 |
+
loader="FastLanguageModel",
|
| 71 |
+
trainer="ORPOTrainer",
|
| 72 |
+
config="ORPOConfig",
|
| 73 |
+
dataset_schema=("prompt", "chosen", "rejected"),
|
| 74 |
+
detector_formats=("preference",),
|
| 75 |
+
default_model="mlx-community/Llama-3.2-1B-Instruct-4bit",
|
| 76 |
+
extra_config_defaults={"beta": 0.1},
|
| 77 |
+
),
|
| 78 |
+
"simpo": TaskSpec(
|
| 79 |
+
id="simpo",
|
| 80 |
+
label="SimPO — Simple Preference Optimization",
|
| 81 |
+
description="Length-normalized, reference-free preference optimization.",
|
| 82 |
+
loader="FastLanguageModel",
|
| 83 |
+
trainer="SimPOTrainer",
|
| 84 |
+
config="SimPOConfig",
|
| 85 |
+
dataset_schema=("prompt", "chosen", "rejected"),
|
| 86 |
+
detector_formats=("preference",),
|
| 87 |
+
default_model="mlx-community/Llama-3.2-1B-Instruct-4bit",
|
| 88 |
+
),
|
| 89 |
+
"kto": TaskSpec(
|
| 90 |
+
id="kto",
|
| 91 |
+
label="KTO — Kahneman-Tversky Optimization",
|
| 92 |
+
description="Alignment from simple binary thumbs-up/down feedback (no pairs needed).",
|
| 93 |
+
loader="FastLanguageModel",
|
| 94 |
+
trainer="KTOTrainer",
|
| 95 |
+
config="KTOConfig",
|
| 96 |
+
dataset_schema=("prompt", "completion", "label"),
|
| 97 |
+
detector_formats=("kto",),
|
| 98 |
+
default_model="mlx-community/Llama-3.2-1B-Instruct-4bit",
|
| 99 |
+
),
|
| 100 |
+
"grpo": TaskSpec(
|
| 101 |
+
id="grpo",
|
| 102 |
+
label="GRPO — Group Relative Policy Optimization",
|
| 103 |
+
description="Online RL with programmable reward functions (reasoning, math, code).",
|
| 104 |
+
loader="FastLanguageModel",
|
| 105 |
+
trainer="GRPOTrainer",
|
| 106 |
+
config="GRPOConfig",
|
| 107 |
+
dataset_schema=("prompt",),
|
| 108 |
+
detector_formats=("grpo",),
|
| 109 |
+
default_model="mlx-community/Llama-3.2-1B-Instruct-4bit",
|
| 110 |
+
notes="Reward functions are plain Python callables; edit them in the generated script.",
|
| 111 |
+
),
|
| 112 |
+
"cpt": TaskSpec(
|
| 113 |
+
id="cpt",
|
| 114 |
+
label="CPT — Continual Pretraining",
|
| 115 |
+
description="Inject domain knowledge by continuing pretraining on raw text.",
|
| 116 |
+
loader="FastLanguageModel",
|
| 117 |
+
trainer="CPTTrainer",
|
| 118 |
+
config="CPTConfig",
|
| 119 |
+
dataset_schema=("text",),
|
| 120 |
+
detector_formats=("text",),
|
| 121 |
+
default_model="mlx-community/SmolLM2-360M-Instruct",
|
| 122 |
+
default_target_modules=FULL_TARGETS,
|
| 123 |
+
extra_config_defaults={"embedding_learning_rate": 5e-6, "include_embeddings": True},
|
| 124 |
+
),
|
| 125 |
+
"vlm_sft": TaskSpec(
|
| 126 |
+
id="vlm_sft",
|
| 127 |
+
label="Vision SFT — Vision-Language Models",
|
| 128 |
+
description="Fine-tune VLMs (Qwen-VL, LLaVA-style) on image + conversation data.",
|
| 129 |
+
loader="FastVisionModel",
|
| 130 |
+
trainer="VLMSFTTrainer",
|
| 131 |
+
config="VLMSFTConfig",
|
| 132 |
+
config_module="mlx_tune.vlm",
|
| 133 |
+
dataset_schema=("images", "messages"),
|
| 134 |
+
detector_formats=("vision_chat",),
|
| 135 |
+
default_model="mlx-community/Qwen2.5-VL-3B-Instruct-4bit",
|
| 136 |
+
modality="vision",
|
| 137 |
+
),
|
| 138 |
+
"tts_sft": TaskSpec(
|
| 139 |
+
id="tts_sft",
|
| 140 |
+
label="TTS SFT — Text-to-Speech",
|
| 141 |
+
description="Fine-tune speech synthesis models (Orpheus, OuteTTS, CSM…) on audio+text pairs.",
|
| 142 |
+
loader="FastTTSModel",
|
| 143 |
+
trainer="TTSSFTTrainer",
|
| 144 |
+
config="TTSSFTConfig",
|
| 145 |
+
collator="TTSDataCollator",
|
| 146 |
+
dataset_schema=("audio", "text"),
|
| 147 |
+
detector_formats=("audio_text",),
|
| 148 |
+
default_model="mlx-community/orpheus-3b-0.1-ft-bf16",
|
| 149 |
+
modality="audio",
|
| 150 |
+
notes="Audio training currently supports batch_size=1 (mlx-tune limitation).",
|
| 151 |
+
),
|
| 152 |
+
"stt_sft": TaskSpec(
|
| 153 |
+
id="stt_sft",
|
| 154 |
+
label="STT SFT — Speech-to-Text",
|
| 155 |
+
description="Fine-tune ASR models (Whisper, Parakeet, Canary…) on audio+transcription pairs.",
|
| 156 |
+
loader="FastSTTModel",
|
| 157 |
+
trainer="STTSFTTrainer",
|
| 158 |
+
config="STTSFTConfig",
|
| 159 |
+
collator="STTDataCollator",
|
| 160 |
+
dataset_schema=("audio", "text"),
|
| 161 |
+
detector_formats=("audio_text",),
|
| 162 |
+
default_model="mlx-community/whisper-tiny-asr-fp16",
|
| 163 |
+
modality="audio",
|
| 164 |
+
notes="Audio training currently supports batch_size=1 (mlx-tune limitation).",
|
| 165 |
+
),
|
| 166 |
+
"embedding": TaskSpec(
|
| 167 |
+
id="embedding",
|
| 168 |
+
label="Embedding SFT — Sentence Embeddings",
|
| 169 |
+
description="Contrastive fine-tuning of embedding models (anchor/positive pairs).",
|
| 170 |
+
loader="FastEmbeddingModel",
|
| 171 |
+
trainer="EmbeddingSFTTrainer",
|
| 172 |
+
config="EmbeddingSFTConfig",
|
| 173 |
+
dataset_schema=("anchor", "positive"),
|
| 174 |
+
detector_formats=("embedding_pairs",),
|
| 175 |
+
default_model="mlx-community/all-MiniLM-L6-v2-bf16",
|
| 176 |
+
extra_config_defaults={"loss_type": "infonce", "temperature": 0.05},
|
| 177 |
+
),
|
| 178 |
+
"ocr_sft": TaskSpec(
|
| 179 |
+
id="ocr_sft",
|
| 180 |
+
label="OCR SFT — Optical Character Recognition",
|
| 181 |
+
description="Fine-tune OCR models (DeepSeek-OCR, olmOCR…) on image + ground-truth text.",
|
| 182 |
+
loader="FastOCRModel",
|
| 183 |
+
trainer="OCRSFTTrainer",
|
| 184 |
+
config="OCRSFTConfig",
|
| 185 |
+
dataset_schema=("image", "text"),
|
| 186 |
+
detector_formats=("image_text",),
|
| 187 |
+
default_model="mlx-community/DeepSeek-OCR-8bit",
|
| 188 |
+
modality="image",
|
| 189 |
+
extra_config_defaults={"learning_rate": 5e-5},
|
| 190 |
+
),
|
| 191 |
+
}
|
| 192 |
+
|
| 193 |
+
|
| 194 |
+
def get_task(task_id: str) -> TaskSpec:
|
| 195 |
+
return TASKS[task_id]
|
| 196 |
+
|
| 197 |
+
|
| 198 |
+
def task_choices() -> list[tuple[str, str]]:
|
| 199 |
+
"""(label, id) pairs for a Gradio dropdown."""
|
| 200 |
+
return [(spec.label, spec.id) for spec in TASKS.values()]
|
| 201 |
+
|
| 202 |
+
|
| 203 |
+
def tasks_for_format(format_id: str) -> list[TaskSpec]:
|
| 204 |
+
return [spec for spec in TASKS.values() if format_id in spec.detector_formats]
|
| 205 |
+
|
| 206 |
+
|
| 207 |
+
# ---------------------------------------------------------------------------
|
| 208 |
+
# Lazy mlx-tune access
|
| 209 |
+
# ---------------------------------------------------------------------------
|
| 210 |
+
|
| 211 |
+
def mlx_available() -> tuple[bool, str]:
|
| 212 |
+
"""Whether mlx-tune is importable on this machine, plus a human reason."""
|
| 213 |
+
if platform.machine() != "arm64" or platform.system() != "Darwin":
|
| 214 |
+
return False, "mlx-tune requires an Apple Silicon Mac (arm64/macOS)."
|
| 215 |
+
if importlib.util.find_spec("mlx_tune") is None:
|
| 216 |
+
return False, "mlx-tune is not installed. Run: pip install 'finetuner[mlx]'"
|
| 217 |
+
return True, "mlx-tune is available."
|
| 218 |
+
|
| 219 |
+
|
| 220 |
+
def resolve(name: str, module: str = "mlx_tune"):
|
| 221 |
+
"""Import `name` from an mlx_tune module, raising a friendly error if missing."""
|
| 222 |
+
ok, reason = mlx_available()
|
| 223 |
+
if not ok:
|
| 224 |
+
raise RuntimeError(reason)
|
| 225 |
+
mod = importlib.import_module(module)
|
| 226 |
+
try:
|
| 227 |
+
return getattr(mod, name)
|
| 228 |
+
except AttributeError as exc:
|
| 229 |
+
raise RuntimeError(
|
| 230 |
+
f"`{name}` not found in `{module}`. Your mlx-tune version may be too old; "
|
| 231 |
+
"try: pip install -U mlx-tune"
|
| 232 |
+
) from exc
|
finetuner/core/state.py
ADDED
|
@@ -0,0 +1,35 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""Shared application state for the (single-user, local) Studio session."""
|
| 2 |
+
|
| 3 |
+
from __future__ import annotations
|
| 4 |
+
|
| 5 |
+
from dataclasses import dataclass, field
|
| 6 |
+
|
| 7 |
+
from .detector import Detection
|
| 8 |
+
from .training import RunConfig
|
| 9 |
+
|
| 10 |
+
|
| 11 |
+
@dataclass
|
| 12 |
+
class AppState:
|
| 13 |
+
# Model
|
| 14 |
+
model = None
|
| 15 |
+
tokenizer = None
|
| 16 |
+
model_name: str = ""
|
| 17 |
+
model_loaded_for_task: str = ""
|
| 18 |
+
lora_attached: bool = False
|
| 19 |
+
# Dataset
|
| 20 |
+
raw_rows: list[dict] = field(default_factory=list)
|
| 21 |
+
detection: Detection | None = None
|
| 22 |
+
dataset_source: str = ""
|
| 23 |
+
dataset_is_local: bool = False
|
| 24 |
+
# Run configuration
|
| 25 |
+
config: RunConfig = field(default_factory=RunConfig)
|
| 26 |
+
|
| 27 |
+
def reset_model(self):
|
| 28 |
+
self.model = None
|
| 29 |
+
self.tokenizer = None
|
| 30 |
+
self.model_name = ""
|
| 31 |
+
self.model_loaded_for_task = ""
|
| 32 |
+
self.lora_attached = False
|
| 33 |
+
|
| 34 |
+
|
| 35 |
+
STATE = AppState()
|
finetuner/core/training.py
ADDED
|
@@ -0,0 +1,105 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""Build and run mlx-tune trainers from a flat GUI config dict."""
|
| 2 |
+
|
| 3 |
+
from __future__ import annotations
|
| 4 |
+
|
| 5 |
+
from dataclasses import dataclass, field
|
| 6 |
+
|
| 7 |
+
from .jobs import Job
|
| 8 |
+
from .registry import get_task, resolve
|
| 9 |
+
|
| 10 |
+
|
| 11 |
+
@dataclass
|
| 12 |
+
class RunConfig:
|
| 13 |
+
"""Everything the GUI collects, flattened into one serializable object."""
|
| 14 |
+
task: str = "sft"
|
| 15 |
+
model_name: str = ""
|
| 16 |
+
max_seq_length: int = 2048
|
| 17 |
+
load_in_4bit: bool = True
|
| 18 |
+
# LoRA
|
| 19 |
+
use_lora: bool = True
|
| 20 |
+
lora_r: int = 16
|
| 21 |
+
lora_alpha: int = 16
|
| 22 |
+
lora_dropout: float = 0.0
|
| 23 |
+
target_modules: list[str] = field(default_factory=lambda: ["q_proj", "k_proj", "v_proj", "o_proj"])
|
| 24 |
+
# Training hyperparameters
|
| 25 |
+
output_dir: str = "outputs"
|
| 26 |
+
batch_size: int = 2
|
| 27 |
+
gradient_accumulation_steps: int = 1
|
| 28 |
+
learning_rate: float = 2e-4
|
| 29 |
+
max_steps: int = 100
|
| 30 |
+
num_train_epochs: float | None = None
|
| 31 |
+
warmup_steps: int = 5
|
| 32 |
+
gradient_checkpointing: bool = False
|
| 33 |
+
seed: int = 42
|
| 34 |
+
# Task-specific extras (beta for DPO, temperature for embeddings, ...)
|
| 35 |
+
extra: dict = field(default_factory=dict)
|
| 36 |
+
|
| 37 |
+
def to_dict(self) -> dict:
|
| 38 |
+
from dataclasses import asdict
|
| 39 |
+
return asdict(self)
|
| 40 |
+
|
| 41 |
+
@classmethod
|
| 42 |
+
def from_dict(cls, d: dict) -> "RunConfig":
|
| 43 |
+
known = {f for f in cls.__dataclass_fields__}
|
| 44 |
+
return cls(**{k: v for k, v in d.items() if k in known})
|
| 45 |
+
|
| 46 |
+
|
| 47 |
+
def _filtered_kwargs(config_cls, kwargs: dict) -> dict:
|
| 48 |
+
"""Drop kwargs the dataclass config doesn't accept (version drift safety)."""
|
| 49 |
+
fields = getattr(config_cls, "__dataclass_fields__", None)
|
| 50 |
+
if fields is None:
|
| 51 |
+
return kwargs
|
| 52 |
+
return {k: v for k, v in kwargs.items() if k in fields}
|
| 53 |
+
|
| 54 |
+
|
| 55 |
+
def build_trainer_args(cfg: RunConfig) -> dict:
|
| 56 |
+
spec = get_task(cfg.task)
|
| 57 |
+
args: dict = {
|
| 58 |
+
"output_dir": cfg.output_dir,
|
| 59 |
+
"per_device_train_batch_size": cfg.batch_size,
|
| 60 |
+
"gradient_accumulation_steps": cfg.gradient_accumulation_steps,
|
| 61 |
+
"learning_rate": cfg.learning_rate,
|
| 62 |
+
"warmup_steps": cfg.warmup_steps,
|
| 63 |
+
"seed": cfg.seed,
|
| 64 |
+
}
|
| 65 |
+
if cfg.num_train_epochs:
|
| 66 |
+
args["num_train_epochs"] = cfg.num_train_epochs
|
| 67 |
+
else:
|
| 68 |
+
args["max_steps"] = cfg.max_steps
|
| 69 |
+
if cfg.gradient_checkpointing:
|
| 70 |
+
args["gradient_checkpointing"] = True
|
| 71 |
+
args.update(spec.extra_config_defaults)
|
| 72 |
+
args.update(cfg.extra)
|
| 73 |
+
return args
|
| 74 |
+
|
| 75 |
+
|
| 76 |
+
def run_training(job: Job, cfg: RunConfig, model, tokenizer, dataset: list[dict]):
|
| 77 |
+
"""Job target: construct the task's trainer and train. Runs on a worker thread."""
|
| 78 |
+
spec = get_task(cfg.task)
|
| 79 |
+
trainer_cls = resolve(spec.trainer)
|
| 80 |
+
config_cls = resolve(spec.config, spec.config_module)
|
| 81 |
+
|
| 82 |
+
args = config_cls(**_filtered_kwargs(config_cls, build_trainer_args(cfg)))
|
| 83 |
+
|
| 84 |
+
trainer_kwargs: dict = {"model": model, "train_dataset": dataset, "args": args}
|
| 85 |
+
if spec.modality == "vision" or spec.id == "ocr_sft":
|
| 86 |
+
trainer_kwargs["processor"] = tokenizer
|
| 87 |
+
else:
|
| 88 |
+
trainer_kwargs["tokenizer"] = tokenizer
|
| 89 |
+
if spec.collator:
|
| 90 |
+
collator_cls = resolve(spec.collator)
|
| 91 |
+
trainer_kwargs["data_collator"] = collator_cls(model, tokenizer)
|
| 92 |
+
|
| 93 |
+
try:
|
| 94 |
+
trainer = trainer_cls(**trainer_kwargs)
|
| 95 |
+
except TypeError:
|
| 96 |
+
# Some trainers take `processor` instead of `tokenizer` or vice versa.
|
| 97 |
+
if "tokenizer" in trainer_kwargs:
|
| 98 |
+
trainer_kwargs["processor"] = trainer_kwargs.pop("tokenizer")
|
| 99 |
+
else:
|
| 100 |
+
trainer_kwargs["tokenizer"] = trainer_kwargs.pop("processor")
|
| 101 |
+
trainer = trainer_cls(**trainer_kwargs)
|
| 102 |
+
|
| 103 |
+
job.add_log(f"▶ {spec.label} started — {len(dataset)} samples, output → {cfg.output_dir}")
|
| 104 |
+
trainer.train()
|
| 105 |
+
job.add_log("✅ Training finished.")
|
finetuner/ui/__init__.py
ADDED
|
File without changes
|
finetuner/ui/tab_dataset.py
ADDED
|
@@ -0,0 +1,117 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""Dataset tab: load from the Hub, a local path, or upload — then auto-detect the format."""
|
| 2 |
+
|
| 3 |
+
from __future__ import annotations
|
| 4 |
+
|
| 5 |
+
import pandas as pd
|
| 6 |
+
import gradio as gr
|
| 7 |
+
|
| 8 |
+
from ..core import data as datalib
|
| 9 |
+
from ..core.detector import detect
|
| 10 |
+
from ..core.registry import get_task
|
| 11 |
+
from ..core.state import STATE
|
| 12 |
+
|
| 13 |
+
PREVIEW_ROWS = 8
|
| 14 |
+
|
| 15 |
+
|
| 16 |
+
def _preview_df(rows: list[dict]) -> pd.DataFrame:
|
| 17 |
+
if not rows:
|
| 18 |
+
return pd.DataFrame()
|
| 19 |
+
df = pd.DataFrame(rows[:PREVIEW_ROWS])
|
| 20 |
+
return df.map(lambda v: str(v)[:300] if not isinstance(v, (int, float, bool)) else v)
|
| 21 |
+
|
| 22 |
+
|
| 23 |
+
def _detection_md() -> str:
|
| 24 |
+
det = STATE.detection
|
| 25 |
+
if det is None:
|
| 26 |
+
return ""
|
| 27 |
+
bar = "🟩" * round(det.confidence * 10) + "⬜" * (10 - round(det.confidence * 10))
|
| 28 |
+
lines = [
|
| 29 |
+
f"### 🔎 Detected format: **{det.label}**",
|
| 30 |
+
f"Confidence: {bar} **{det.confidence:.0%}**",
|
| 31 |
+
]
|
| 32 |
+
if det.mapping:
|
| 33 |
+
mapped = ", ".join(f"`{k}` ← `{v}`" for k, v in det.mapping.items())
|
| 34 |
+
lines.append(f"Column mapping: {mapped}")
|
| 35 |
+
if det.suggested_tasks:
|
| 36 |
+
tasks = ", ".join(f"**{get_task(t).label}**" for t in det.suggested_tasks)
|
| 37 |
+
lines.append(f"Compatible trainers: {tasks}")
|
| 38 |
+
for note in det.notes:
|
| 39 |
+
lines.append(f"> 💡 {note}")
|
| 40 |
+
return "\n\n".join(lines)
|
| 41 |
+
|
| 42 |
+
|
| 43 |
+
def _ingest(rows: list[dict], source: str, is_local: bool) -> tuple[str, pd.DataFrame, str]:
|
| 44 |
+
STATE.raw_rows = rows
|
| 45 |
+
STATE.detection = detect(rows)
|
| 46 |
+
STATE.dataset_source = source
|
| 47 |
+
STATE.dataset_is_local = is_local
|
| 48 |
+
return (f"✅ Loaded **{len(rows)}** rows from `{source}`.",
|
| 49 |
+
_preview_df(rows), _detection_md())
|
| 50 |
+
|
| 51 |
+
|
| 52 |
+
def build(app):
|
| 53 |
+
with gr.Tab("📚 Dataset", id="dataset"):
|
| 54 |
+
gr.Markdown("### Load a dataset — the format is detected automatically")
|
| 55 |
+
source = gr.Radio(["Hugging Face Hub", "Local file", "Upload"],
|
| 56 |
+
value="Hugging Face Hub", label="Source")
|
| 57 |
+
|
| 58 |
+
with gr.Group() as hub_group:
|
| 59 |
+
with gr.Row():
|
| 60 |
+
query = gr.Textbox(label="Search Hub datasets", placeholder="e.g. alpaca turkish", scale=3)
|
| 61 |
+
search_btn = gr.Button("🔍 Search", scale=1)
|
| 62 |
+
with gr.Row():
|
| 63 |
+
ds_name = gr.Dropdown(label="Dataset", allow_custom_value=True, choices=[],
|
| 64 |
+
info="Pick a result or type any dataset id.", scale=3)
|
| 65 |
+
split = gr.Textbox(value="train", label="Split", scale=1)
|
| 66 |
+
subset = gr.Textbox(value="", label="Config (optional)", scale=1)
|
| 67 |
+
|
| 68 |
+
local_path = gr.Textbox(label="Local dataset path", visible=False,
|
| 69 |
+
placeholder="~/data/train.jsonl (.jsonl/.json/.csv/.tsv/.parquet)")
|
| 70 |
+
upload = gr.File(label="Upload dataset", visible=False,
|
| 71 |
+
file_types=[".jsonl", ".json", ".csv", ".tsv", ".parquet"])
|
| 72 |
+
|
| 73 |
+
with gr.Row():
|
| 74 |
+
max_rows = gr.Number(value=0, precision=0, label="Max rows (0 = all)")
|
| 75 |
+
load_btn = gr.Button("📥 Load dataset", variant="primary", scale=2)
|
| 76 |
+
|
| 77 |
+
status = gr.Markdown()
|
| 78 |
+
detection_panel = gr.Markdown()
|
| 79 |
+
preview = gr.Dataframe(label=f"Preview (first {PREVIEW_ROWS} rows)", interactive=False, wrap=True)
|
| 80 |
+
|
| 81 |
+
# ----- events -------------------------------------------------------
|
| 82 |
+
def on_source(src):
|
| 83 |
+
return (gr.update(visible=src == "Hugging Face Hub"),
|
| 84 |
+
gr.update(visible=src == "Local file"),
|
| 85 |
+
gr.update(visible=src == "Upload"))
|
| 86 |
+
|
| 87 |
+
source.change(on_source, source, [hub_group, local_path, upload])
|
| 88 |
+
|
| 89 |
+
def on_search(q):
|
| 90 |
+
results = datalib.search_hub_datasets(q)
|
| 91 |
+
if not results:
|
| 92 |
+
gr.Warning(f"No Hub datasets found for {q!r}")
|
| 93 |
+
return gr.update()
|
| 94 |
+
return gr.update(choices=results, value=results[0])
|
| 95 |
+
|
| 96 |
+
search_btn.click(on_search, query, ds_name)
|
| 97 |
+
query.submit(on_search, query, ds_name)
|
| 98 |
+
|
| 99 |
+
def on_load(src, name, split_v, subset_v, path, file, n, progress=gr.Progress()):
|
| 100 |
+
limit = int(n) or None
|
| 101 |
+
try:
|
| 102 |
+
if src == "Hugging Face Hub":
|
| 103 |
+
if not name:
|
| 104 |
+
return "❌ Choose a dataset first.", gr.update(), ""
|
| 105 |
+
progress(0.2, desc=f"Downloading {name} …")
|
| 106 |
+
rows = datalib.load_hub_dataset(name, split_v or "train", subset_v or None, limit)
|
| 107 |
+
return _ingest(rows, name, is_local=False)
|
| 108 |
+
target = path if src == "Local file" else (file.name if file else "")
|
| 109 |
+
if not target:
|
| 110 |
+
return "❌ Provide a file first.", gr.update(), ""
|
| 111 |
+
rows = datalib.load_local_dataset(target, limit)
|
| 112 |
+
return _ingest(rows, target, is_local=True)
|
| 113 |
+
except Exception as exc:
|
| 114 |
+
return f"❌ Failed to load dataset: {exc}", gr.update(), ""
|
| 115 |
+
|
| 116 |
+
load_btn.click(on_load, [source, ds_name, split, subset, local_path, upload, max_rows],
|
| 117 |
+
[status, preview, detection_panel])
|
finetuner/ui/tab_export.py
ADDED
|
@@ -0,0 +1,62 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""Export tab: adapters, merged weights, GGUF, and Hugging Face Hub upload."""
|
| 2 |
+
|
| 3 |
+
from __future__ import annotations
|
| 4 |
+
|
| 5 |
+
import gradio as gr
|
| 6 |
+
|
| 7 |
+
from ..core import export
|
| 8 |
+
from ..core.state import STATE
|
| 9 |
+
|
| 10 |
+
|
| 11 |
+
def _guard():
|
| 12 |
+
if STATE.model is None:
|
| 13 |
+
raise gr.Error("No model in memory — load and train one first.")
|
| 14 |
+
|
| 15 |
+
|
| 16 |
+
def build(app):
|
| 17 |
+
with gr.Tab("📦 Export", id="export"):
|
| 18 |
+
gr.Markdown("### Save or publish the fine-tuned model")
|
| 19 |
+
|
| 20 |
+
with gr.Group():
|
| 21 |
+
gr.Markdown("**LoRA adapters** — small, fast to share")
|
| 22 |
+
with gr.Row():
|
| 23 |
+
adapter_path = gr.Textbox(value="lora_model", label="Directory", scale=3)
|
| 24 |
+
adapter_btn = gr.Button("💾 Save adapters", scale=1)
|
| 25 |
+
|
| 26 |
+
with gr.Group():
|
| 27 |
+
gr.Markdown("**Merged model** — base weights + adapters fused to 16-bit")
|
| 28 |
+
with gr.Row():
|
| 29 |
+
merged_path = gr.Textbox(value="merged", label="Directory", scale=3)
|
| 30 |
+
merged_btn = gr.Button("🔗 Save merged", scale=1)
|
| 31 |
+
|
| 32 |
+
with gr.Group():
|
| 33 |
+
gr.Markdown("**GGUF** — for llama.cpp / Ollama. "
|
| 34 |
+
"⚠️ Requires a *non-quantized* base model (mlx-lm limitation).")
|
| 35 |
+
with gr.Row():
|
| 36 |
+
gguf_path = gr.Textbox(value="model_gguf", label="Directory", scale=3)
|
| 37 |
+
gguf_btn = gr.Button("🦙 Export GGUF", scale=1)
|
| 38 |
+
|
| 39 |
+
with gr.Group():
|
| 40 |
+
gr.Markdown("**Hugging Face Hub** — publish the model to your account")
|
| 41 |
+
with gr.Row():
|
| 42 |
+
repo_id = gr.Textbox(label="Repo id", placeholder="username/my-finetuned-model", scale=2)
|
| 43 |
+
hf_token = gr.Textbox(label="HF token (optional if logged in)", type="password", scale=2)
|
| 44 |
+
push_btn = gr.Button("🤗 Push to Hub", variant="primary", scale=1)
|
| 45 |
+
|
| 46 |
+
status = gr.Markdown()
|
| 47 |
+
|
| 48 |
+
def run(fn, *args):
|
| 49 |
+
_guard()
|
| 50 |
+
try:
|
| 51 |
+
return f"✅ {fn(*args)}"
|
| 52 |
+
except Exception as exc:
|
| 53 |
+
return f"❌ {exc}"
|
| 54 |
+
|
| 55 |
+
adapter_btn.click(lambda p: run(export.save_adapters, STATE.model, p),
|
| 56 |
+
adapter_path, status)
|
| 57 |
+
merged_btn.click(lambda p: run(export.save_merged, STATE.model, STATE.tokenizer, p),
|
| 58 |
+
merged_path, status)
|
| 59 |
+
gguf_btn.click(lambda p: run(export.save_gguf, STATE.model, STATE.tokenizer, p),
|
| 60 |
+
gguf_path, status)
|
| 61 |
+
push_btn.click(lambda r, t: run(export.push_to_hub, STATE.model, r, t or None),
|
| 62 |
+
[repo_id, hf_token], status)
|
finetuner/ui/tab_model.py
ADDED
|
@@ -0,0 +1,122 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""Model tab: pick a task, find a model (Hub search or local path), load it, attach LoRA."""
|
| 2 |
+
|
| 3 |
+
from __future__ import annotations
|
| 4 |
+
|
| 5 |
+
import gradio as gr
|
| 6 |
+
|
| 7 |
+
from ..core import models
|
| 8 |
+
from ..core.registry import get_task, mlx_available, task_choices
|
| 9 |
+
from ..core.state import STATE
|
| 10 |
+
|
| 11 |
+
|
| 12 |
+
def _task_info(task_id: str) -> str:
|
| 13 |
+
spec = get_task(task_id)
|
| 14 |
+
lines = [f"**{spec.label}**", "", spec.description, "",
|
| 15 |
+
f"- Backend: `mlx_tune.{spec.trainer}` + `{spec.config}`",
|
| 16 |
+
f"- Dataset schema: `{', '.join(spec.dataset_schema)}`"]
|
| 17 |
+
if spec.notes:
|
| 18 |
+
lines.append(f"- ⚠️ {spec.notes}")
|
| 19 |
+
return "\n".join(lines)
|
| 20 |
+
|
| 21 |
+
|
| 22 |
+
def build(app):
|
| 23 |
+
with gr.Tab("🧠 Model", id="model"):
|
| 24 |
+
gr.Markdown("### 1 · Choose a task and a base model")
|
| 25 |
+
with gr.Row():
|
| 26 |
+
with gr.Column(scale=1):
|
| 27 |
+
task = gr.Dropdown(choices=task_choices(), value="sft", label="Training task",
|
| 28 |
+
info="Every mlx-tune trainer is available here.")
|
| 29 |
+
task_info = gr.Markdown(_task_info("sft"))
|
| 30 |
+
with gr.Column(scale=2):
|
| 31 |
+
source = gr.Radio(["Hugging Face Hub", "Local path"], value="Hugging Face Hub",
|
| 32 |
+
label="Model source")
|
| 33 |
+
with gr.Group() as hub_group:
|
| 34 |
+
with gr.Row():
|
| 35 |
+
query = gr.Textbox(label="Search the Hub",
|
| 36 |
+
placeholder="e.g. llama 3.2 instruct 4bit", scale=3)
|
| 37 |
+
search_btn = gr.Button("🔍 Search", scale=1)
|
| 38 |
+
model_name = gr.Dropdown(label="Model", allow_custom_value=True,
|
| 39 |
+
value=get_task("sft").default_model,
|
| 40 |
+
choices=[get_task("sft").default_model],
|
| 41 |
+
info="Pick a search result or type any repo id.")
|
| 42 |
+
local_path = gr.Textbox(label="Local model directory", visible=False,
|
| 43 |
+
placeholder="/path/to/converted-mlx-model")
|
| 44 |
+
with gr.Row():
|
| 45 |
+
max_seq = gr.Slider(256, 32768, value=2048, step=256, label="Max sequence length")
|
| 46 |
+
four_bit = gr.Checkbox(value=True, label="Load in 4-bit")
|
| 47 |
+
|
| 48 |
+
gr.Markdown("### 2 · LoRA adapters")
|
| 49 |
+
with gr.Row():
|
| 50 |
+
use_lora = gr.Checkbox(value=True, label="Attach LoRA", scale=1)
|
| 51 |
+
lora_r = gr.Slider(1, 256, value=16, step=1, label="Rank (r)", scale=2)
|
| 52 |
+
lora_alpha = gr.Slider(1, 256, value=16, step=1, label="Alpha", scale=2)
|
| 53 |
+
lora_dropout = gr.Slider(0.0, 0.5, value=0.0, step=0.01, label="Dropout", scale=2)
|
| 54 |
+
target_modules = gr.CheckboxGroup(
|
| 55 |
+
choices=["q_proj", "k_proj", "v_proj", "o_proj", "gate_proj", "up_proj", "down_proj"],
|
| 56 |
+
value=["q_proj", "k_proj", "v_proj", "o_proj"],
|
| 57 |
+
label="Target modules (text models)")
|
| 58 |
+
|
| 59 |
+
load_btn = gr.Button("⚡ Load model", variant="primary")
|
| 60 |
+
status = gr.Markdown()
|
| 61 |
+
|
| 62 |
+
# ----- events -------------------------------------------------------
|
| 63 |
+
def on_task(task_id):
|
| 64 |
+
spec = get_task(task_id)
|
| 65 |
+
return (_task_info(task_id),
|
| 66 |
+
gr.update(value=spec.default_model, choices=[spec.default_model]),
|
| 67 |
+
gr.update(value=list(spec.default_target_modules)))
|
| 68 |
+
|
| 69 |
+
task.change(on_task, task, [task_info, model_name, target_modules])
|
| 70 |
+
|
| 71 |
+
def on_source(src):
|
| 72 |
+
hub = src == "Hugging Face Hub"
|
| 73 |
+
return gr.update(visible=hub), gr.update(visible=not hub)
|
| 74 |
+
|
| 75 |
+
source.change(on_source, source, [hub_group, local_path])
|
| 76 |
+
|
| 77 |
+
def on_search(q):
|
| 78 |
+
results = models.search_hub_models(q)
|
| 79 |
+
if not results:
|
| 80 |
+
gr.Warning(f"No Hub models found for {q!r}")
|
| 81 |
+
return gr.update()
|
| 82 |
+
return gr.update(choices=results, value=results[0])
|
| 83 |
+
|
| 84 |
+
search_btn.click(on_search, query, model_name)
|
| 85 |
+
query.submit(on_search, query, model_name)
|
| 86 |
+
|
| 87 |
+
def on_load(task_id, src, name, path, seq, fourbit,
|
| 88 |
+
lora, r, alpha, dropout, targets, progress=gr.Progress()):
|
| 89 |
+
ok, reason = mlx_available()
|
| 90 |
+
if not ok:
|
| 91 |
+
return f"❌ {reason}"
|
| 92 |
+
resolved = name
|
| 93 |
+
try:
|
| 94 |
+
if src == "Local path":
|
| 95 |
+
resolved = models.validate_local_model(path)
|
| 96 |
+
progress(0.1, desc=f"Loading {resolved} …")
|
| 97 |
+
model, tok = models.load_model(task_id, resolved, int(seq), bool(fourbit))
|
| 98 |
+
if lora and get_task(task_id).peft_supported:
|
| 99 |
+
progress(0.7, desc="Attaching LoRA adapters …")
|
| 100 |
+
model = models.apply_lora(task_id, model, int(r), int(alpha),
|
| 101 |
+
float(dropout), list(targets))
|
| 102 |
+
STATE.model, STATE.tokenizer = model, tok
|
| 103 |
+
STATE.model_name = resolved
|
| 104 |
+
STATE.model_loaded_for_task = task_id
|
| 105 |
+
STATE.lora_attached = bool(lora)
|
| 106 |
+
cfg = STATE.config
|
| 107 |
+
cfg.task, cfg.model_name = task_id, resolved
|
| 108 |
+
cfg.max_seq_length, cfg.load_in_4bit = int(seq), bool(fourbit)
|
| 109 |
+
cfg.use_lora, cfg.lora_r, cfg.lora_alpha = bool(lora), int(r), int(alpha)
|
| 110 |
+
cfg.lora_dropout, cfg.target_modules = float(dropout), list(targets)
|
| 111 |
+
return (f"✅ **{resolved}** loaded for **{get_task(task_id).label}**"
|
| 112 |
+
+ (" with LoRA attached." if lora else "."))
|
| 113 |
+
except Exception as exc: # surfaced to the user, not crashed
|
| 114 |
+
return f"❌ Load failed: {exc}"
|
| 115 |
+
|
| 116 |
+
load_btn.click(
|
| 117 |
+
on_load,
|
| 118 |
+
[task, source, model_name, local_path, max_seq, four_bit,
|
| 119 |
+
use_lora, lora_r, lora_alpha, lora_dropout, target_modules],
|
| 120 |
+
status)
|
| 121 |
+
|
| 122 |
+
return {"task": task}
|
finetuner/ui/tab_monitor.py
ADDED
|
@@ -0,0 +1,60 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""Monitor tab: live logs, loss curve and job control, refreshed by a timer."""
|
| 2 |
+
|
| 3 |
+
from __future__ import annotations
|
| 4 |
+
|
| 5 |
+
import pandas as pd
|
| 6 |
+
import gradio as gr
|
| 7 |
+
|
| 8 |
+
from ..core.jobs import MANAGER
|
| 9 |
+
|
| 10 |
+
STATUS_ICONS = {"pending": "⏳", "running": "🏃", "finished": "✅",
|
| 11 |
+
"failed": "❌", "stopped": "⏹"}
|
| 12 |
+
|
| 13 |
+
|
| 14 |
+
def _job_choices() -> list[tuple[str, int]]:
|
| 15 |
+
return [(f"#{j.id} {STATUS_ICONS.get(j.status, '')} {j.name}", j.id)
|
| 16 |
+
for j in reversed(MANAGER.all())]
|
| 17 |
+
|
| 18 |
+
|
| 19 |
+
def _snapshot(job_id):
|
| 20 |
+
job = MANAGER.get(int(job_id)) if job_id else MANAGER.latest()
|
| 21 |
+
if job is None:
|
| 22 |
+
return ("*No jobs yet — start one from the 🚀 Train tab.*", "",
|
| 23 |
+
pd.DataFrame({"step": [], "loss": []}))
|
| 24 |
+
header = (f"**Job #{job.id}** · {job.name} · "
|
| 25 |
+
f"{STATUS_ICONS.get(job.status, '')} **{job.status}** · "
|
| 26 |
+
f"⏱ {job.elapsed:.0f}s · {len(job.metrics)} loss points")
|
| 27 |
+
df = pd.DataFrame(job.metrics, columns=["step", "loss"]) if job.metrics \
|
| 28 |
+
else pd.DataFrame({"step": [], "loss": []})
|
| 29 |
+
return header, job.log_text(), df
|
| 30 |
+
|
| 31 |
+
|
| 32 |
+
def build(app):
|
| 33 |
+
with gr.Tab("📈 Monitor", id="monitor"):
|
| 34 |
+
with gr.Row():
|
| 35 |
+
job_pick = gr.Dropdown(label="Job", choices=_job_choices(), scale=3)
|
| 36 |
+
refresh_btn = gr.Button("🔄 Refresh list", scale=1)
|
| 37 |
+
stop_btn = gr.Button("⏹ Stop job", variant="stop", scale=1)
|
| 38 |
+
header = gr.Markdown("*No jobs yet — start one from the 🚀 Train tab.*")
|
| 39 |
+
with gr.Row():
|
| 40 |
+
with gr.Column(scale=1):
|
| 41 |
+
loss_plot = gr.LinePlot(x="step", y="loss", label="Training loss",
|
| 42 |
+
value=pd.DataFrame({"step": [], "loss": []}))
|
| 43 |
+
with gr.Column(scale=1):
|
| 44 |
+
logs = gr.Textbox(label="Live logs", lines=20, max_lines=20,
|
| 45 |
+
autoscroll=True, interactive=False)
|
| 46 |
+
|
| 47 |
+
timer = gr.Timer(2.0)
|
| 48 |
+
timer.tick(_snapshot, job_pick, [header, logs, loss_plot])
|
| 49 |
+
|
| 50 |
+
refresh_btn.click(lambda: gr.update(choices=_job_choices()), None, job_pick)
|
| 51 |
+
job_pick.change(_snapshot, job_pick, [header, logs, loss_plot])
|
| 52 |
+
|
| 53 |
+
def on_stop(job_id):
|
| 54 |
+
job = MANAGER.get(int(job_id)) if job_id else MANAGER.latest()
|
| 55 |
+
if job is None:
|
| 56 |
+
return gr.update()
|
| 57 |
+
MANAGER.stop(job.id)
|
| 58 |
+
return gr.update(choices=_job_choices())
|
| 59 |
+
|
| 60 |
+
stop_btn.click(on_stop, job_pick, job_pick)
|
finetuner/ui/tab_playground.py
ADDED
|
@@ -0,0 +1,56 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""Playground tab: chat with the currently loaded (and freshly tuned) model."""
|
| 2 |
+
|
| 3 |
+
from __future__ import annotations
|
| 4 |
+
|
| 5 |
+
import gradio as gr
|
| 6 |
+
|
| 7 |
+
from ..core.engine import ENGINE
|
| 8 |
+
from ..core.state import STATE
|
| 9 |
+
|
| 10 |
+
|
| 11 |
+
def _generate(message: str, history: list[dict], max_tokens: int, temperature: float) -> str:
|
| 12 |
+
# MLX generation must run on the engine thread (streams are thread-local).
|
| 13 |
+
return ENGINE.call(_generate_inner, message, history, max_tokens, temperature)
|
| 14 |
+
|
| 15 |
+
|
| 16 |
+
def _generate_inner(message: str, history: list[dict], max_tokens: int, temperature: float) -> str:
|
| 17 |
+
if STATE.model is None or STATE.tokenizer is None:
|
| 18 |
+
return "⚠️ No model loaded — load one in the 🧠 Model tab first."
|
| 19 |
+
|
| 20 |
+
messages = [{"role": h["role"], "content": h["content"]} for h in history]
|
| 21 |
+
messages.append({"role": "user", "content": message})
|
| 22 |
+
|
| 23 |
+
tok = STATE.tokenizer
|
| 24 |
+
try:
|
| 25 |
+
prompt = tok.apply_chat_template(messages, tokenize=False, add_generation_prompt=True)
|
| 26 |
+
except Exception:
|
| 27 |
+
prompt = "\n".join(m["content"] for m in messages)
|
| 28 |
+
|
| 29 |
+
# mlx-tune models are mlx-lm compatible; prefer its generate().
|
| 30 |
+
try:
|
| 31 |
+
from mlx_lm import generate as mlx_generate
|
| 32 |
+
from mlx_lm.sample_utils import make_sampler
|
| 33 |
+
return mlx_generate(STATE.model, tok, prompt=prompt, max_tokens=int(max_tokens),
|
| 34 |
+
sampler=make_sampler(temp=float(temperature)), verbose=False)
|
| 35 |
+
except Exception:
|
| 36 |
+
pass
|
| 37 |
+
try: # older mlx-lm signature
|
| 38 |
+
from mlx_lm import generate as mlx_generate
|
| 39 |
+
return mlx_generate(STATE.model, tok, prompt=prompt,
|
| 40 |
+
max_tokens=int(max_tokens), verbose=False)
|
| 41 |
+
except Exception as exc:
|
| 42 |
+
return f"⚠️ Generation failed: {exc}"
|
| 43 |
+
|
| 44 |
+
|
| 45 |
+
def build(app):
|
| 46 |
+
with gr.Tab("💬 Playground", id="playground"):
|
| 47 |
+
gr.Markdown("### Test the loaded model — before and after fine-tuning")
|
| 48 |
+
with gr.Row():
|
| 49 |
+
max_tokens = gr.Slider(16, 4096, value=512, step=16, label="Max new tokens")
|
| 50 |
+
temperature = gr.Slider(0.0, 2.0, value=0.7, step=0.05, label="Temperature")
|
| 51 |
+
gr.ChatInterface(
|
| 52 |
+
fn=_generate,
|
| 53 |
+
additional_inputs=[max_tokens, temperature],
|
| 54 |
+
examples=[["Merhaba! Kendini tanıtır mısın?"],
|
| 55 |
+
["Explain LoRA fine-tuning in two sentences."]],
|
| 56 |
+
)
|
finetuner/ui/tab_train.py
ADDED
|
@@ -0,0 +1,140 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""Train tab: hyperparameters, recipes, the code generator, and the launch button."""
|
| 2 |
+
|
| 3 |
+
from __future__ import annotations
|
| 4 |
+
|
| 5 |
+
import gradio as gr
|
| 6 |
+
|
| 7 |
+
from ..core import recipes
|
| 8 |
+
from ..core.codegen import generate_script
|
| 9 |
+
from ..core.detector import normalize
|
| 10 |
+
from ..core.jobs import MANAGER
|
| 11 |
+
from ..core.registry import get_task, mlx_available
|
| 12 |
+
from ..core.state import STATE
|
| 13 |
+
from ..core.training import run_training
|
| 14 |
+
|
| 15 |
+
|
| 16 |
+
def _collect(cfg_fields: dict) -> None:
|
| 17 |
+
cfg = STATE.config
|
| 18 |
+
cfg.output_dir = cfg_fields["output_dir"]
|
| 19 |
+
cfg.batch_size = int(cfg_fields["batch_size"])
|
| 20 |
+
cfg.gradient_accumulation_steps = int(cfg_fields["grad_accum"])
|
| 21 |
+
cfg.learning_rate = float(cfg_fields["lr"])
|
| 22 |
+
cfg.max_steps = int(cfg_fields["max_steps"])
|
| 23 |
+
cfg.num_train_epochs = float(cfg_fields["epochs"]) or None
|
| 24 |
+
cfg.warmup_steps = int(cfg_fields["warmup"])
|
| 25 |
+
cfg.gradient_checkpointing = bool(cfg_fields["grad_ckpt"])
|
| 26 |
+
cfg.seed = int(cfg_fields["seed"])
|
| 27 |
+
extra = {}
|
| 28 |
+
if cfg.task in ("dpo", "orpo") and cfg_fields["beta"]:
|
| 29 |
+
extra["beta"] = float(cfg_fields["beta"])
|
| 30 |
+
cfg.extra = extra
|
| 31 |
+
|
| 32 |
+
|
| 33 |
+
def build(app):
|
| 34 |
+
with gr.Tab("🚀 Train", id="train"):
|
| 35 |
+
gr.Markdown("### Hyperparameters")
|
| 36 |
+
with gr.Row():
|
| 37 |
+
output_dir = gr.Textbox(value="outputs", label="Output directory")
|
| 38 |
+
batch_size = gr.Slider(1, 32, value=2, step=1, label="Batch size")
|
| 39 |
+
grad_accum = gr.Slider(1, 64, value=1, step=1, label="Gradient accumulation")
|
| 40 |
+
with gr.Row():
|
| 41 |
+
lr = gr.Number(value=2e-4, label="Learning rate")
|
| 42 |
+
max_steps = gr.Slider(10, 10000, value=100, step=10, label="Max steps")
|
| 43 |
+
epochs = gr.Number(value=0, label="Epochs (0 → use max steps)")
|
| 44 |
+
with gr.Row():
|
| 45 |
+
warmup = gr.Slider(0, 500, value=5, step=5, label="Warmup steps")
|
| 46 |
+
seed = gr.Number(value=42, precision=0, label="Seed")
|
| 47 |
+
grad_ckpt = gr.Checkbox(value=False, label="Gradient checkpointing (saves memory)")
|
| 48 |
+
beta = gr.Number(value=0.1, label="β (DPO/ORPO only)")
|
| 49 |
+
|
| 50 |
+
with gr.Row():
|
| 51 |
+
start_btn = gr.Button("🏁 Start training", variant="primary", scale=2)
|
| 52 |
+
gen_btn = gr.Button("🧾 Generate Python script", scale=1)
|
| 53 |
+
status = gr.Markdown()
|
| 54 |
+
|
| 55 |
+
with gr.Accordion("Generated script (standalone mlx-tune code)", open=False):
|
| 56 |
+
script_out = gr.Code(language="python", label="train.py")
|
| 57 |
+
gr.Markdown("Copy this script anywhere — it reproduces this run without the GUI.")
|
| 58 |
+
|
| 59 |
+
with gr.Accordion("Recipes (save / load runs as YAML)", open=False):
|
| 60 |
+
with gr.Row():
|
| 61 |
+
recipe_name = gr.Textbox(label="Recipe name", placeholder="my-sft-run")
|
| 62 |
+
save_recipe_btn = gr.Button("💾 Save recipe")
|
| 63 |
+
with gr.Row():
|
| 64 |
+
recipe_pick = gr.Dropdown(label="Saved recipes", choices=recipes.list_recipes(),
|
| 65 |
+
allow_custom_value=True)
|
| 66 |
+
load_recipe_btn = gr.Button("📂 Load recipe")
|
| 67 |
+
recipe_status = gr.Markdown()
|
| 68 |
+
|
| 69 |
+
hp_inputs = [output_dir, batch_size, grad_accum, lr, max_steps, epochs,
|
| 70 |
+
warmup, seed, grad_ckpt, beta]
|
| 71 |
+
|
| 72 |
+
def _fields(*vals) -> dict:
|
| 73 |
+
keys = ["output_dir", "batch_size", "grad_accum", "lr", "max_steps",
|
| 74 |
+
"epochs", "warmup", "seed", "grad_ckpt", "beta"]
|
| 75 |
+
return dict(zip(keys, vals))
|
| 76 |
+
|
| 77 |
+
# ----- start training -------------------------------------------------
|
| 78 |
+
def on_start(*vals):
|
| 79 |
+
_collect(_fields(*vals))
|
| 80 |
+
cfg = STATE.config
|
| 81 |
+
ok, reason = mlx_available()
|
| 82 |
+
if not ok:
|
| 83 |
+
return f"❌ {reason}"
|
| 84 |
+
if STATE.model is None:
|
| 85 |
+
return "❌ Load a model first (🧠 Model tab)."
|
| 86 |
+
if not STATE.raw_rows:
|
| 87 |
+
return "❌ Load a dataset first (📚 Dataset tab)."
|
| 88 |
+
if STATE.model_loaded_for_task != cfg.task:
|
| 89 |
+
cfg.task = STATE.model_loaded_for_task
|
| 90 |
+
try:
|
| 91 |
+
dataset = normalize(STATE.raw_rows, STATE.detection, cfg.task, STATE.tokenizer)
|
| 92 |
+
except Exception as exc:
|
| 93 |
+
return (f"❌ Dataset incompatible with **{get_task(cfg.task).label}**: {exc}\n\n"
|
| 94 |
+
f"Detected format: {STATE.detection.label if STATE.detection else '—'}")
|
| 95 |
+
job = MANAGER.submit(f"{cfg.task} · {cfg.model_name}", run_training,
|
| 96 |
+
cfg, STATE.model, STATE.tokenizer, dataset)
|
| 97 |
+
return (f"🏃 **Job #{job.id}** started ({len(dataset)} samples). "
|
| 98 |
+
"Follow it in the **📈 Monitor** tab.")
|
| 99 |
+
|
| 100 |
+
start_btn.click(on_start, hp_inputs, status)
|
| 101 |
+
|
| 102 |
+
# ----- codegen --------------------------------------------------------
|
| 103 |
+
def on_generate(*vals):
|
| 104 |
+
_collect(_fields(*vals))
|
| 105 |
+
cfg = STATE.config
|
| 106 |
+
if STATE.model_loaded_for_task:
|
| 107 |
+
cfg.task = STATE.model_loaded_for_task
|
| 108 |
+
if not cfg.model_name:
|
| 109 |
+
cfg.model_name = get_task(cfg.task).default_model
|
| 110 |
+
return generate_script(cfg, STATE.dataset_source, STATE.dataset_is_local)
|
| 111 |
+
|
| 112 |
+
gen_btn.click(on_generate, hp_inputs, script_out)
|
| 113 |
+
|
| 114 |
+
# ----- recipes ---------------------------------------------------------
|
| 115 |
+
def on_save_recipe(name, *vals):
|
| 116 |
+
_collect(_fields(*vals))
|
| 117 |
+
path = recipes.save_recipe(STATE.config, name, STATE.dataset_source,
|
| 118 |
+
STATE.dataset_is_local)
|
| 119 |
+
return f"💾 Saved `{path}`", gr.update(choices=recipes.list_recipes())
|
| 120 |
+
|
| 121 |
+
save_recipe_btn.click(on_save_recipe, [recipe_name, *hp_inputs],
|
| 122 |
+
[recipe_status, recipe_pick])
|
| 123 |
+
|
| 124 |
+
def on_load_recipe(path):
|
| 125 |
+
if not path:
|
| 126 |
+
return ["❌ Pick a recipe."] + [gr.update()] * len(hp_inputs)
|
| 127 |
+
try:
|
| 128 |
+
cfg, src, is_local = recipes.load_recipe(path)
|
| 129 |
+
except Exception as exc:
|
| 130 |
+
return [f"❌ {exc}"] + [gr.update()] * len(hp_inputs)
|
| 131 |
+
STATE.config = cfg
|
| 132 |
+
STATE.dataset_source, STATE.dataset_is_local = src, is_local
|
| 133 |
+
return [f"📂 Loaded `{path}` — task **{cfg.task}**, model `{cfg.model_name}`. "
|
| 134 |
+
"Reload model/dataset to run it.",
|
| 135 |
+
cfg.output_dir, cfg.batch_size, cfg.gradient_accumulation_steps,
|
| 136 |
+
cfg.learning_rate, cfg.max_steps, cfg.num_train_epochs or 0,
|
| 137 |
+
cfg.warmup_steps, cfg.seed, cfg.gradient_checkpointing,
|
| 138 |
+
cfg.extra.get("beta", 0.1)]
|
| 139 |
+
|
| 140 |
+
load_recipe_btn.click(on_load_recipe, recipe_pick, [recipe_status, *hp_inputs])
|
requirements.txt
ADDED
|
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
datasets>=3.0
|
| 2 |
+
pandas>=2.0
|
| 3 |
+
pyyaml>=6.0
|