acayir64 committed on
Commit
e187c2c
·
verified ·
1 Parent(s): 269350e

Finetuner Studio GUI demo (planning mode)

Browse files
README.md CHANGED
@@ -1,13 +1,33 @@
1
  ---
2
  title: Finetuner Studio
3
- emoji: 🐨
4
- colorFrom: gray
5
- colorTo: green
6
  sdk: gradio
7
  sdk_version: 6.17.3
8
- python_version: '3.13'
9
  app_file: app.py
10
- pinned: false
 
11
  ---
12
 
13
- Check out the configuration reference at https://huggingface.co/docs/hub/spaces-config-reference
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
  ---
2
  title: Finetuner Studio
3
+ emoji: 🎛️
4
+ colorFrom: yellow
5
+ colorTo: red
6
  sdk: gradio
7
  sdk_version: 6.17.3
 
8
  app_file: app.py
9
+ python_version: "3.12"
10
+ short_description: Low-code MLX fine-tuning studio (GUI demo)
11
  ---
12
 
13
+ # 🎛️ Finetuner Studio GUI Demo
14
+
15
+ **Low-code fine-tuning on Apple Silicon**, powered by
16
+ [mlx-tune](https://github.com/ARahim3/mlx-tune). This Space runs in
17
+ **GUI-only mode** (Spaces have no Apple Silicon): explore the interface, load
18
+ any Hugging Face dataset and watch the **automatic format detection**, and
19
+ generate standalone mlx-tune training scripts with the code generator.
20
+
21
+ For actual training (12 paradigms: SFT, DPO, ORPO, SimPO, KTO, GRPO, CPT,
22
+ VLM, TTS, STT, Embedding, OCR), run it on a Mac:
23
+
24
+ ```bash
25
+ git clone https://github.com/aykutcayir34/finetuner
26
+ cd finetuner && pip install -e '.[mlx]' && finetuner
27
+ ```
28
+
29
+ 🇹🇷 Bu Space, Finetuner Studio arayüzünün canlı demosudur (Spaces'te Apple
30
+ Silicon olmadığı için eğitim devre dışı; format algılama ve kod üretici
31
+ çalışır). Örnek model: [Llama-3.2-1B Turkish-Alpaca](https://huggingface.co/acayir64/Llama-3.2-1B-Instruct-Turkish-Alpaca-mlx)
32
+
33
+ Source: https://github.com/aykutcayir34/finetuner
app.py ADDED
@@ -0,0 +1,16 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """Finetuner Studio — Hugging Face Space demo (GUI-only mode).
2
+
3
+ Spaces run on Linux, so MLX training is unavailable here; the Studio runs in
4
+ planning mode: dataset loading + automatic format detection, the code
5
+ generator, recipes and the full UI are live. Clone it on an Apple Silicon
6
+ Mac for actual training: https://github.com/aykutcayir34/finetuner
7
+ """
8
+
9
+ import gradio as gr
10
+
11
+ from finetuner.app import build_app
12
+
13
+ demo = build_app()
14
+
15
+ if __name__ == "__main__":
16
+ demo.launch(theme=gr.themes.Soft(primary_hue="orange", secondary_hue="slate"))
finetuner/__init__.py ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ """Finetuner Studio — low-code fine-tuning on Apple Silicon, powered by mlx-tune."""
2
+
3
+ __version__ = "0.1.0"
finetuner/__main__.py ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ from .app import main
2
+
3
+ main()
finetuner/app.py ADDED
@@ -0,0 +1,45 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """Finetuner Studio — application entry point."""
2
+
3
+ from __future__ import annotations
4
+
5
+ import gradio as gr
6
+
7
+ from .core.registry import TASKS, mlx_available
8
+ from .ui import tab_dataset, tab_export, tab_model, tab_monitor, tab_playground, tab_train
9
+
10
+ BANNER = """
11
+ # 🎛️ Finetuner Studio
12
+ **Low-code fine-tuning on Apple Silicon** · powered by
13
+ [mlx-tune](https://github.com/ARahim3/mlx-tune) · {n_tasks} training paradigms,
14
+ zero boilerplate
15
+ """
16
+
17
+
18
def build_app() -> gr.Blocks:
    """Assemble the Finetuner Studio interface and return it un-launched.

    The header shows the banner (with the number of registered tasks) and a
    one-line health badge derived from ``mlx_available()``. Each tab module
    then contributes its own UI through ``build(app)``.
    """
    mlx_ok, why_not = mlx_available()
    if mlx_ok:
        status = "🟢 mlx-tune ready"
    else:
        status = f"🟡 GUI-only mode — {why_not}"

    # Tabs are built in display order.
    tab_modules = (tab_model, tab_dataset, tab_train,
                   tab_monitor, tab_playground, tab_export)
    footer = (
        "<center><small>Finetuner Studio · load a model → drop a dataset → "
        "press train. The generated Python script is yours to keep.</small></center>"
    )

    with gr.Blocks(title="Finetuner Studio") as app:
        gr.Markdown(BANNER.format(n_tasks=len(TASKS)))
        gr.Markdown(f"`{status}`")
        with gr.Tabs():
            for module in tab_modules:
                module.build(app)
        gr.Markdown(footer)
    return app
37
+
38
+
39
def main():
    """Build the Studio and launch it with the orange/slate Soft theme."""
    app = build_app()
    app.launch(theme=gr.themes.Soft(primary_hue="orange", secondary_hue="slate"))
42
+
43
+
44
+ if __name__ == "__main__":
45
+ main()
finetuner/core/__init__.py ADDED
File without changes
finetuner/core/codegen.py ADDED
@@ -0,0 +1,105 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """Low-code → code: generate a standalone, editable mlx-tune training script
2
+ that reproduces exactly what the GUI is configured to do."""
3
+
4
+ from __future__ import annotations
5
+
6
+ from .registry import get_task
7
+ from .training import RunConfig, build_trainer_args
8
+
9
+
10
def _fmt(v) -> str:
    """Render a value as the Python literal text used in generated scripts."""
    return f"{v!r}"
12
+
13
+
14
def generate_script(cfg: RunConfig, dataset_source: str = "",
                    dataset_is_local: bool = False) -> str:
    """Return the source of a standalone mlx-tune training script for ``cfg``.

    Args:
        cfg: The run configuration; the task spec is looked up from
            ``cfg.task`` and trainer arguments come from
            ``build_trainer_args(cfg)``.
        dataset_source: Hub dataset id or local file path. When empty, a
            placeholder list with a TODO is emitted instead.
        dataset_is_local: When true, ``dataset_source`` is read as a local
            JSON-lines file by the generated script.

    Returns:
        The generated script as a single string.

    NOTE(review): whitespace inside the templates was reconstructed with
    4-space indentation; confirm against the committed file.
    """
    spec = get_task(cfg.task)
    args = build_trainer_args(cfg)
    args_body = ",\n        ".join(f"{k}={_fmt(v)}" for k, v in args.items())

    # The config class is imported from mlx_tune unless the task spec names
    # another module for it.
    imports = {spec.trainer}
    if spec.config_module == "mlx_tune":
        imports.add(spec.config)
    if spec.collator:
        imports.add(spec.collator)
    import_line = f"from mlx_tune import {spec.loader}, {', '.join(sorted(imports))}"
    extra_import = ""
    if spec.config_module != "mlx_tune":
        extra_import = f"\nfrom {spec.config_module} import {spec.config}"

    # --- model loading -------------------------------------------------------
    load_kwargs = ""
    if spec.modality == "text":
        load_kwargs = (f",\n    max_seq_length={cfg.max_seq_length},"
                       f"\n    load_in_4bit={cfg.load_in_4bit},")
    handle = "processor" if spec.modality in ("vision", "image") else "tokenizer"

    # --- LoRA ------------------------------------------------------------------
    lora_block = ""
    if cfg.use_lora and spec.peft_supported:
        tm = ""
        if spec.modality == "text":
            tm = f"\n    target_modules={cfg.target_modules!r},"
        lora_block = f"""
# --- 2. Attach LoRA adapters -------------------------------------------------
model = {spec.loader}.get_peft_model(
    model,
    r={cfg.lora_r},
    lora_alpha={cfg.lora_alpha},{tm}
)
"""

    # --- dataset ------------------------------------------------------------------
    if dataset_is_local:
        # Explicit UTF-8: the generated script must not depend on the
        # platform's default encoding to read non-ASCII training data.
        ds_block = f"""import json

with open({dataset_source!r}, encoding="utf-8") as f:
    train_dataset = [json.loads(line) for line in f if line.strip()]"""
    elif dataset_source:
        ds_block = f"""from datasets import load_dataset

train_dataset = load_dataset({dataset_source!r}, split="train")"""
    else:
        ds_block = """train_dataset = [
    # TODO: fill with rows shaped like: %s
]""" % (dict.fromkeys(spec.dataset_schema, "..."),)

    collator_line = ""
    if spec.collator:
        collator_line = f"\n    data_collator={spec.collator}(model, {handle}),"

    return f'''"""Auto-generated by Finetuner Studio — https://github.com/aykutcayir34/finetuner

Task    : {spec.label}
Backend : mlx-tune (https://github.com/ARahim3/mlx-tune)

This script is fully standalone: edit it, version it, or move it to a cloud
GPU box — the mlx-tune API is Unsloth-compatible by design.
"""

{import_line}{extra_import}

# --- 1. Load the base model ----------------------------------------------------
model, {handle} = {spec.loader}.from_pretrained(
    {cfg.model_name!r}{load_kwargs}
)
{lora_block}
# --- 3. Dataset ------------------------------------------------------------------
{ds_block}

# --- 4. Train ----------------------------------------------------------------------
trainer = {spec.trainer}(
    model=model,
    {handle}={handle},
    train_dataset=train_dataset,{collator_line}
    args={spec.config}(
        {args_body},
    ),
)
trainer.train()

# --- 5. Save -----------------------------------------------------------------------
model.save_pretrained("lora_model")  # adapters only
# model.save_pretrained_merged("merged", {handle})  # merged fp16 model
# model.push_to_hub("your-username/your-model")  # upload to the Hub
'''
finetuner/core/data.py ADDED
@@ -0,0 +1,69 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """Dataset loading from the Hugging Face Hub or local files."""
2
+
3
+ from __future__ import annotations
4
+
5
+ import csv
6
+ import json
7
+ from pathlib import Path
8
+
9
+ from huggingface_hub import HfApi
10
+
11
+
12
def search_hub_datasets(query: str, limit: int = 20) -> list[str]:
    """Search the Hugging Face Hub for datasets matching ``query``.

    Args:
        query: Free-text search string. ``None``, empty and whitespace-only
            queries return no results without contacting the Hub.
        limit: Maximum number of dataset ids to return.

    Returns:
        Dataset ids, requested from the Hub sorted by downloads.
    """
    # UI text inputs can deliver None as well as "", so guard both.
    if not query or not query.strip():
        return []
    api = HfApi()
    return [d.id for d in api.list_datasets(search=query, limit=limit, sort="downloads")]
17
+
18
+
19
def load_hub_dataset(name: str, split: str = "train", config: str | None = None,
                     max_rows: int | None = None) -> list[dict]:
    """Load one split of a Hub dataset and return its rows as plain dicts.

    Args:
        name: Dataset id on the Hub.
        split: Split to load.
        config: Optional dataset configuration name; an empty string is
            treated like ``None``.
        max_rows: When truthy, keep at most this many leading rows.
    """
    # Imported lazily so the module can be imported without `datasets`.
    from datasets import load_dataset

    config_name = config if config else None
    dataset = load_dataset(name, config_name, split=split)
    if max_rows:
        keep = min(max_rows, len(dataset))
        dataset = dataset.select(range(keep))
    return list(map(dict, dataset))
27
+
28
+
29
def load_local_dataset(path: str, max_rows: int | None = None) -> list[dict]:
    """Load a local dataset file: .jsonl, .json, .csv, .tsv or .parquet.

    Text formats are decoded as UTF-8 regardless of the platform's default
    encoding, and a leading byte-order mark is tolerated (``utf-8-sig``).

    Args:
        path: File path; ``~`` is expanded.
        max_rows: When truthy, return at most this many leading rows.

    Returns:
        The rows as a list. For ``.json`` the file must hold a list, or a
        dict whose first list-valued entry is used as the rows.

    Raises:
        FileNotFoundError: If ``path`` does not exist.
        ValueError: If the extension is unsupported or a ``.json`` file does
            not contain a list of records.
    """
    p = Path(path).expanduser()
    if not p.exists():
        raise FileNotFoundError(f"No such file: {p}")
    suffix = p.suffix.lower()
    # utf-8-sig reads plain UTF-8 unchanged and strips a BOM when present,
    # which would otherwise break json.loads and the first CSV header.
    text_encoding = "utf-8-sig"

    rows: list[dict]
    if suffix == ".jsonl":
        rows = []
        with p.open(encoding=text_encoding) as f:
            for line in f:
                line = line.strip()
                if not line:
                    continue
                rows.append(json.loads(line))
                # Stop early instead of parsing rows that would be dropped.
                if max_rows and len(rows) >= max_rows:
                    break
    elif suffix == ".json":
        data = json.loads(p.read_text(encoding=text_encoding))
        if isinstance(data, dict):  # e.g. {"data": [...]} wrappers
            for v in data.values():
                if isinstance(v, list):
                    data = v
                    break
        if not isinstance(data, list):
            raise ValueError("JSON file must contain a list of records.")
        rows = data
    elif suffix in (".csv", ".tsv"):
        delim = "\t" if suffix == ".tsv" else ","
        with p.open(newline="", encoding=text_encoding) as f:
            rows = list(csv.DictReader(f, delimiter=delim))
    elif suffix == ".parquet":
        import pandas as pd
        rows = pd.read_parquet(p).to_dict("records")
    else:
        raise ValueError(f"Unsupported dataset format: {suffix} "
                         "(supported: .jsonl, .json, .csv, .tsv, .parquet)")

    if max_rows:
        rows = rows[:max_rows]
    return rows
finetuner/core/detector.py ADDED
@@ -0,0 +1,292 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """Automatic dataset format detection and normalization.
2
+
3
+ Given a few sample rows, classify the dataset into one of the canonical
4
+ formats mlx-tune trainers understand, propose a column mapping, and convert
5
+ rows into the exact shape the selected trainer expects.
6
+ """
7
+
8
+ from __future__ import annotations
9
+
10
+ from dataclasses import dataclass, field
11
+
12
+ # Column-name synonyms (English + Turkish), matched case-insensitively after
13
+ # stripping whitespace — real-world CSV headers are messy.
14
+ SYNONYMS = {
15
+ "instruction": {"instruction", "question", "query", "instruct", "task",
16
+ "talimat", "soru", "görev"},
17
+ "input": {"input", "context", "system", "giriş", "girdi", "bağlam"},
18
+ "output": {"output", "response", "answer", "completion", "target",
19
+ "çıktı", "cevap", "yanıt"},
20
+ "prompt": {"prompt", "question", "query", "instruction", "istem", "soru"},
21
+ "chosen": {"chosen", "preferred", "accepted", "good", "seçilen", "tercih"},
22
+ "rejected": {"rejected", "dispreferred", "bad", "reddedilen"},
23
+ "completion": {"completion", "response", "output", "answer", "cevap", "yanıt"},
24
+ "label": {"label", "thumbs_up", "is_good", "score", "etiket"},
25
+ "conversations": {"conversations", "messages", "dialogue", "dialog", "chat",
26
+ "turns", "konuşmalar", "mesajlar"},
27
+ "text": {"text", "content", "document", "body", "metin", "içerik"},
28
+ "anchor": {"anchor", "query", "sentence1", "question", "soru"},
29
+ "positive": {"positive", "passage", "sentence2", "answer", "document"},
30
+ "audio": {"audio", "audio_path", "audio_filepath", "wav", "file", "path", "ses"},
31
+ "transcription": {"text", "transcription", "transcript", "sentence", "caption",
32
+ "metin", "cümle"},
33
+ "image": {"image", "images", "image_path", "img", "picture", "görsel", "resim"},
34
+ }
35
+
36
+ FORMAT_LABELS = {
37
+ "alpaca": "Alpaca (instruction / input / output)",
38
+ "sharegpt": "ShareGPT (conversations with from/value turns)",
39
+ "chatml": "ChatML / OpenAI messages (role/content turns)",
40
+ "prompt_completion": "Prompt–completion pairs",
41
+ "preference": "Preference pairs (prompt / chosen / rejected)",
42
+ "kto": "KTO binary feedback (prompt / completion / label)",
43
+ "grpo": "GRPO prompts (prompt, optional answer)",
44
+ "text": "Raw text corpus",
45
+ "embedding_pairs": "Embedding pairs (anchor / positive)",
46
+ "audio_text": "Audio + text (TTS / STT)",
47
+ "vision_chat": "Image + conversation (VLM)",
48
+ "image_text": "Image + ground-truth text (OCR)",
49
+ "unknown": "Unknown — manual mapping required",
50
+ }
51
+
52
+
53
@dataclass
class Detection:
    """Outcome of dataset format detection."""

    format: str  # format id; looked up in FORMAT_LABELS for display
    confidence: float  # 0..1
    # canonical field name -> actual column name in the dataset
    mapping: dict[str, str] = field(default_factory=dict)
    suggested_tasks: list[str] = field(default_factory=list)
    notes: list[str] = field(default_factory=list)

    @property
    def label(self) -> str:
        """Display label for ``format``; the raw id when no label is registered."""
        try:
            return FORMAT_LABELS[self.format]
        except KeyError:
            return self.format
64
+
65
+
66
def _find(columns: list[str], canonical: str) -> str | None:
    """Find the actual column matching a canonical field name.

    Column names are compared case-insensitively after stripping whitespace.
    An exact match on ``canonical`` wins; otherwise the first column, in
    dataset order, whose name is a registered synonym is returned. Scanning
    the columns rather than the synonym set keeps the result stable between
    runs when several synonym columns are present.

    Args:
        columns: The dataset's column names. Non-string names are compared
            by their ``str()`` form.
        canonical: Canonical field name (a key of ``SYNONYMS``).

    Returns:
        The original column name, or ``None`` if nothing matches.
    """
    normalized = [(str(c).strip().lower(), c) for c in columns]
    for name, column in normalized:
        if name == canonical:
            return column
    synonyms = SYNONYMS.get(canonical, ())
    for name, column in normalized:
        if name in synonyms:
            return column
    return None
75
+
76
+
77
def _is_turn_list(value) -> str | None:
    """Classify a list-of-dicts column as 'sharegpt' or 'chatml' turns.

    Only the first element is inspected. Returns ``None`` for anything that
    is not a non-empty list starting with a dict, or whose first turn has
    neither key pair.
    """
    if not (isinstance(value, list) and value and isinstance(value[0], dict)):
        return None
    first_keys = set(value[0])
    # ShareGPT is checked first, so it wins if both key pairs are present.
    for kind, required in (("sharegpt", {"from", "value"}),
                           ("chatml", {"role", "content"})):
        if required <= first_keys:
            return kind
    return None
87
+
88
+
89
def _looks_like_path(value, exts: tuple[str, ...]) -> bool:
    """Return True when ``value`` is a string ending in one of ``exts`` (case-insensitive)."""
    if not isinstance(value, str):
        return False
    return value.lower().endswith(exts)
91
+
92
+
93
+ AUDIO_EXTS = (".wav", ".mp3", ".flac", ".m4a", ".ogg")
94
+ IMAGE_EXTS = (".png", ".jpg", ".jpeg", ".webp", ".bmp", ".tiff")
95
+
96
+
97
def detect(rows: list[dict]) -> Detection:
    """Detect the dataset format from sample rows.

    Only the first row is inspected. Checks run in priority order:
    image/audio columns, conversation turns, preference pairs, KTO feedback,
    Alpaca, prompt–completion, embedding pairs, bare prompts (GRPO), raw
    text, and finally a single-string-column fallback.

    Args:
        rows: Sample rows as column → value records.

    Returns:
        A ``Detection`` with format id, confidence, column mapping, suggested
        task ids and notes. Empty input, non-record rows and unclassifiable
        columns all yield format ``"unknown"`` with confidence 0.
    """
    if not rows:
        return Detection("unknown", 0.0, notes=["Dataset is empty."])

    row = rows[0]
    # A JSON file may hold a list of scalars or lists; report that as
    # "unknown" instead of crashing on .keys().
    if not (hasattr(row, "keys") and hasattr(row, "get")):
        return Detection("unknown", 0.0, {},
                         notes=[f"Rows are {type(row).__name__} values, not records. "
                                "Map fields manually."])
    columns = list(row.keys())
    from .registry import tasks_for_format  # local import to avoid cycles

    def done(fmt: str, conf: float, mapping: dict[str, str],
             notes: list[str] | None = None) -> Detection:
        """Build the result, attaching the tasks registered for ``fmt``."""
        tasks = [t.id for t in tasks_for_format(fmt)]
        return Detection(fmt, conf, mapping, tasks, notes or [])

    # --- multimodal first: presence of media columns dominates -------------
    img_col = _find(columns, "image")
    audio_col = _find(columns, "audio")
    conv_col = _find(columns, "conversations")

    if img_col is not None:
        sample = row.get(img_col)
        # Non-string values count as media; strings need an image extension.
        media_like = _looks_like_path(sample, IMAGE_EXTS) or not isinstance(sample, str)
        if conv_col is not None:
            return done("vision_chat", 0.9 if media_like else 0.6,
                        {"images": img_col, "messages": conv_col})
        text_col = _find([c for c in columns if c != img_col], "text")
        if text_col is not None:
            return done("image_text", 0.85 if media_like else 0.55,
                        {"image": img_col, "text": text_col},
                        ["Image + text detected — suitable for OCR SFT or vision tasks."])

    if audio_col is not None:
        sample = row.get(audio_col)
        if _looks_like_path(sample, AUDIO_EXTS) or isinstance(sample, dict):
            text_col = _find([c for c in columns if c != audio_col], "transcription")
            if text_col is not None:
                return done("audio_text", 0.9, {"audio": audio_col, "text": text_col},
                            ["Audio + text detected — choose TTS (synthesis) or STT (recognition)."])

    # --- conversation formats ----------------------------------------------
    if conv_col is not None:
        kind = _is_turn_list(row.get(conv_col))
        if kind == "sharegpt":
            return done("sharegpt", 0.95, {"conversations": conv_col})
        if kind == "chatml":
            return done("chatml", 0.95, {"messages": conv_col})

    # --- preference / feedback ----------------------------------------------
    chosen = _find(columns, "chosen")
    rejected = _find(columns, "rejected")
    prompt = _find(columns, "prompt")
    if chosen and rejected:
        mapping = {"chosen": chosen, "rejected": rejected}
        if prompt:
            mapping["prompt"] = prompt
            return done("preference", 0.95, mapping)
        return done("preference", 0.75, mapping,
                    ["No explicit prompt column; chosen/rejected may embed the prompt."])

    label = _find(columns, "label")
    completion = _find(columns, "completion")
    if prompt and completion and label is not None:
        # Only boolean/integer labels are treated as KTO feedback.
        if isinstance(row.get(label), (bool, int)):
            return done("kto", 0.9, {"prompt": prompt, "completion": completion, "label": label})

    # --- instruction tuning ---------------------------------------------------
    instruction = _find(columns, "instruction")
    output = _find(columns, "output")
    if instruction and output:
        mapping = {"instruction": instruction, "output": output}
        inp = _find([c for c in columns if c not in (instruction, output)], "input")
        if inp:
            mapping["input"] = inp
        return done("alpaca", 0.95, mapping)

    if prompt and completion:
        return done("prompt_completion", 0.9, {"prompt": prompt, "completion": completion})

    # --- embeddings -----------------------------------------------------------
    anchor = _find(columns, "anchor")
    positive = _find(columns, "positive")
    if anchor and positive and anchor != positive:
        return done("embedding_pairs", 0.8, {"anchor": anchor, "positive": positive},
                    ["Anchor/positive pair detected — embedding contrastive training."])

    # --- GRPO: bare prompts -----------------------------------------------------
    if prompt and len(columns) <= 2:
        return done("grpo", 0.6, {"prompt": prompt},
                    ["Bare prompts — usable for GRPO with a custom reward function."])

    # --- raw text ----------------------------------------------------------------
    text = _find(columns, "text")
    if text:
        return done("text", 0.85, {"text": text},
                    ["Raw text — suitable for CPT or completion-style SFT."])

    # --- single string column fallback ---------------------------------------------
    str_cols = [c for c in columns if isinstance(row.get(c), str)]
    if len(str_cols) == 1:
        return done("text", 0.5, {"text": str_cols[0]},
                    [f"Single text column `{str_cols[0]}` assumed to be raw text."])

    return Detection("unknown", 0.0, {},
                     notes=[f"Could not classify columns: {columns}. Map fields manually."])
200
+
201
+
202
+ # ---------------------------------------------------------------------------
203
+ # Normalization: convert detected rows into trainer-ready shape
204
+ # ---------------------------------------------------------------------------
205
+
206
def _format_chat(turns: list[dict], tokenizer=None) -> str:
    """Render chat turns to text via the tokenizer's chat template when possible.

    If the tokenizer is missing, has no ``apply_chat_template``, or the call
    raises, a plain ``<|role|>`` / content rendering is used instead.
    """
    render = None
    if tokenizer is not None and hasattr(tokenizer, "apply_chat_template"):
        render = tokenizer.apply_chat_template
    if render is not None:
        try:
            return render(turns, tokenize=False, add_generation_prompt=False)
        except Exception:
            # Deliberate best-effort: any template failure falls back below.
            pass
    pieces = [f"<|{turn['role']}|>\n{turn['content']}" for turn in turns]
    return "\n".join(pieces) + "\n"
214
+
215
+
216
+ SHAREGPT_ROLES = {"human": "user", "user": "user", "gpt": "assistant",
217
+ "assistant": "assistant", "system": "system"}
218
+
219
+
220
def to_messages(row: dict, detection: Detection) -> list[dict]:
    """Convert a row of any chat-like format into ChatML messages.

    Supported formats are ``chatml`` (returned as stored), ``sharegpt``
    (roles translated through ``SHAREGPT_ROLES``, unknown speakers become
    ``user``), ``alpaca`` and ``prompt_completion``.

    Raises:
        ValueError: For any other detected format.
    """
    cols = detection.mapping
    kind = detection.format

    if kind == "chatml":
        return row[cols["messages"]]

    if kind == "sharegpt":
        converted = []
        for turn in row[cols["conversations"]]:
            role = SHAREGPT_ROLES.get(turn["from"], "user")
            converted.append({"role": role, "content": turn["value"]})
        return converted

    if kind == "alpaca":
        question = row[cols["instruction"]]
        # The optional input is appended only when mapped and non-empty.
        if "input" in cols and row.get(cols["input"]):
            question = f"{question}\n\n{row[cols['input']]}"
        answer = row[cols["output"]]
        return [{"role": "user", "content": question},
                {"role": "assistant", "content": answer}]

    if kind == "prompt_completion":
        return [{"role": "user", "content": row[cols["prompt"]]},
                {"role": "assistant", "content": row[cols["completion"]]}]

    raise ValueError(f"Cannot build messages from format {kind!r}")
239
+
240
+
241
def normalize(rows: list[dict], detection: Detection, task_id: str, tokenizer=None) -> list[dict]:
    """Convert raw rows into the schema the chosen task's trainer expects.

    Args:
        rows: Raw dataset rows.
        detection: Detected format and canonical → column mapping.
        task_id: Target task id (``sft``, ``cpt``, ``dpo``, ``orpo``,
            ``simpo``, ``kto``, ``grpo``, ``embedding``, ``tts_sft``,
            ``stt_sft``, ``ocr_sft`` or ``vlm_sft``).
        tokenizer: Optional tokenizer used to render chat rows for SFT.

    Raises:
        ValueError: For an unknown task, or CPT without a mapped text column.
    """
    cols = detection.mapping
    kind = detection.format

    if task_id == "sft":
        if kind == "text":
            text_col = cols["text"]
            return [{"text": row[text_col]} for row in rows]
        # Every other format is rendered to text through ChatML messages.
        return [{"text": _format_chat(to_messages(row, detection), tokenizer)}
                for row in rows]

    if task_id == "cpt":
        text_col = cols.get("text")
        if text_col is None:
            raise ValueError("CPT needs a raw text column.")
        return [{"text": row[text_col]} for row in rows]

    if task_id in ("dpo", "orpo", "simpo"):
        pairs = []
        for row in rows:
            pair = {"chosen": row[cols["chosen"]], "rejected": row[cols["rejected"]]}
            # An unmapped prompt becomes an empty string.
            pair["prompt"] = row[cols["prompt"]] if "prompt" in cols else ""
            pairs.append(pair)
        return pairs

    if task_id == "kto":
        return [
            {
                "prompt": row[cols["prompt"]],
                "completion": row[cols["completion"]],
                "label": bool(row[cols["label"]]),
            }
            for row in rows
        ]

    if task_id == "grpo":
        return [{"prompt": row[cols["prompt"]]} for row in rows]

    if task_id == "embedding":
        return [{"anchor": row[cols["anchor"]], "positive": row[cols["positive"]]}
                for row in rows]

    if task_id in ("tts_sft", "stt_sft"):
        return [{"audio": row[cols["audio"]], "text": row[cols["text"]]} for row in rows]

    if task_id == "ocr_sft":
        return [{"image": row[cols["image"]], "text": row[cols["text"]]} for row in rows]

    if task_id == "vlm_sft":
        samples = []
        for row in rows:
            sample = {"images": row[cols["images"]], "messages": row[cols["messages"]]}
            turns = sample["messages"]
            # ShareGPT-style turns (first turn has "from") are converted to
            # role/content form.
            if isinstance(turns, list) and turns and "from" in (turns[0] or {}):
                sample["messages"] = [
                    {"role": SHAREGPT_ROLES.get(turn["from"], "user"),
                     "content": turn["value"]}
                    for turn in turns
                ]
            samples.append(sample)
        return samples

    raise ValueError(f"Unknown task {task_id!r}")
finetuner/core/engine.py ADDED
@@ -0,0 +1,56 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """Single persistent MLX engine thread.
2
+
3
+ MLX streams are thread-local: a model loaded on one thread cannot reliably be
4
+ trained or sampled from another ("There is no Stream(gpu, N) in current
5
+ thread"). Gradio runs every event handler on a different worker thread, so all
6
+ MLX work — model loading, training, generation — is funneled through one
7
+ long-lived engine thread. This also serializes GPU work, which is what a
8
+ single-device machine wants anyway.
9
+ """
10
+
11
+ from __future__ import annotations
12
+
13
+ import os
14
+ import queue
15
+ import sys
16
+ import threading
17
+ from pathlib import Path
18
+
19
+ # mlx-tune's subprocess fallback shells out to `mlx_lm.lora`; make sure the
20
+ # interpreter's bin directory is on PATH even when the venv isn't activated.
21
+ _bin = str(Path(sys.executable).parent)
22
+ if _bin not in os.environ.get("PATH", "").split(os.pathsep):
23
+ os.environ["PATH"] = _bin + os.pathsep + os.environ.get("PATH", "")
24
+
25
+
26
class _Engine:
    """Runs submitted callables one at a time on a single daemon thread."""

    def __init__(self):
        self._q: queue.Queue = queue.Queue()
        worker = threading.Thread(target=self._loop, name="finetuner-mlx-engine",
                                  daemon=True)
        self._thread = worker
        worker.start()

    def _loop(self):
        """Worker body: execute queued calls forever, recording result or error."""
        while True:
            func, pos_args, kw_args, finished, outcome = self._q.get()
            try:
                outcome["result"] = func(*pos_args, **kw_args)
            except BaseException as exc:  # noqa: BLE001 — re-raised on the caller thread
                outcome["error"] = exc
            finally:
                # Always wake the caller, whether the call succeeded or not.
                finished.set()

    def call(self, fn, *args, **kwargs):
        """Run `fn` on the engine thread and block until it returns.

        A call made from the engine thread itself runs inline, since queueing
        it would wait on the thread that is already busy. Exceptions raised
        by `fn` are re-raised in the caller.
        """
        if threading.current_thread() is self._thread:
            return fn(*args, **kwargs)
        finished = threading.Event()
        outcome: dict = {}
        request = (fn, args, kwargs, finished, outcome)
        self._q.put(request)
        finished.wait()
        if "error" in outcome:
            raise outcome["error"]
        return outcome["result"]
54
+
55
+
56
+ ENGINE = _Engine()
finetuner/core/export.py ADDED
@@ -0,0 +1,28 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """Export trained models: adapters, merged weights, GGUF, or push to the Hub."""
2
+
3
+ from __future__ import annotations
4
+
5
+
6
def save_adapters(model, path: str) -> str:
    """Save the model via ``save_pretrained(path)`` and return a status message."""
    model.save_pretrained(path)
    message = f"LoRA adapters saved to {path}"
    return message
9
+
10
+
11
def save_merged(model, tokenizer, path: str) -> str:
    """Save via ``save_pretrained_merged(path, tokenizer)`` and return a status message."""
    model.save_pretrained_merged(path, tokenizer)
    message = f"Merged 16-bit model saved to {path}"
    return message
14
+
15
+
16
def save_gguf(model, tokenizer, path: str) -> str:
    """Save via ``save_pretrained_gguf(path, tokenizer)`` and return a status message."""
    # mlx-lm limitation: GGUF export requires a non-quantized base model.
    model.save_pretrained_gguf(path, tokenizer)
    message = f"GGUF model saved to {path}"
    return message
20
+
21
+
22
def push_to_hub(model, repo_id: str, token: str | None = None) -> str:
    """Upload the model to the Hugging Face Hub and return a status message.

    The token is forwarded only when ``model.push_to_hub`` accepts a
    ``token`` keyword (or ``**kwargs``). This is decided from the method
    signature, so a ``TypeError`` raised inside the upload itself propagates
    rather than triggering a second attempt without the token.

    Args:
        model: Object exposing ``push_to_hub(repo_id, ...)``.
        repo_id: Target repository id, e.g. ``"user/model"``.
        token: Optional access token.
    """
    import inspect

    if not token:
        model.push_to_hub(repo_id)
        return f"Pushed to https://huggingface.co/{repo_id}"

    try:
        params = inspect.signature(model.push_to_hub).parameters
    except (TypeError, ValueError):
        params = None  # signature unavailable (e.g. some C-implemented callables)

    if params is None:
        # Cannot introspect: keep the try-then-retry behaviour.
        try:
            model.push_to_hub(repo_id, token=token)
        except TypeError:
            model.push_to_hub(repo_id)
    elif "token" in params or any(
            p.kind is inspect.Parameter.VAR_KEYWORD for p in params.values()):
        model.push_to_hub(repo_id, token=token)
    else:
        model.push_to_hub(repo_id)
    return f"Pushed to https://huggingface.co/{repo_id}"
finetuner/core/jobs.py ADDED
@@ -0,0 +1,135 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """Background training job manager.
2
+
3
+ Training runs on a worker thread so the GUI stays responsive. Trainer stdout
4
+ is captured into a ring buffer; loss values are parsed out of the log stream
5
+ for live charting.
6
+ """
7
+
8
+ from __future__ import annotations
9
+
10
+ import contextlib
11
+ import io
12
+ import re
13
+ import threading
14
+ import time
15
+ import traceback
16
+ from collections import deque
17
+ from dataclasses import dataclass, field
18
+
19
# Loss: "step 10: loss 1.2345", "{'loss': 1.23, 'step': 10}", "loss=2.5e+03".
# The exponent may carry either sign so values like 1e+05 are not truncated.
_LOSS_RE = re.compile(r"loss[\"']?[:=\s]+([0-9]*\.?[0-9]+(?:e[-+]?\d+)?)", re.IGNORECASE)
# Step: "step 10", "iter 10", "it 10", "global_step=10". The lookbehind
# rejects a preceding letter so word tails such as "limit 100" or "commit 5"
# are not mistaken for a step counter; "_" is still allowed.
_STEP_RE = re.compile(r"(?<![A-Za-z])(?:step|it(?:er)?)[\"']?[:=\s/]+(\d+)", re.IGNORECASE)
22
+
23
+
24
class _Tee(io.TextIOBase):
    """Write-through stream that feeds the job log and parses metrics.

    Text is forwarded unchanged to ``original``. Complete lines, ended by
    ``\\n`` or ``\\r``, are passed to ``job.add_log``; blank lines are
    dropped and the unterminated tail is held until a later write.
    """

    def __init__(self, job: "Job", original):
        self.job = job
        self.original = original
        self._buf = ""

    def write(self, s: str) -> int:
        self.original.write(s)
        # Split on either terminator; the last piece is the unterminated tail.
        *complete, self._buf = re.split(r"[\n\r]", self._buf + s)
        for entry in complete:
            if entry.strip():
                self.job.add_log(entry)
        return len(s)

    def flush(self):
        self.original.flush()
44
+
45
+
46
@dataclass
class Job:
    """State of one background training job: status, log buffer and metrics."""

    id: int
    name: str
    status: str = "pending"  # pending | running | finished | failed | stopped
    # Bounded log buffer: only the most recent 2000 lines are kept.
    logs: deque = field(default_factory=lambda: deque(maxlen=2000))
    metrics: list = field(default_factory=list)  # [(step, loss)]
    error: str | None = None
    started_at: float | None = None
    finished_at: float | None = None
    stop_event: threading.Event = field(default_factory=threading.Event)
    # Last step seen; incremented by one when a loss line carries no step.
    _step_guess: int = 0

    def add_log(self, line: str):
        """Append a log line and record a (step, loss) point if it contains a loss."""
        self.logs.append(line)
        loss_match = _LOSS_RE.search(line)
        if not loss_match:
            return
        step_match = _STEP_RE.search(line)
        if step_match:
            self._step_guess = int(step_match.group(1))
        else:
            self._step_guess += 1
        try:
            point = (self._step_guess, float(loss_match.group(1)))
            self.metrics.append(point)
        except ValueError:
            pass

    def log_text(self, last_n: int = 200) -> str:
        """Return the last ``last_n`` log lines joined by newlines."""
        recent = list(self.logs)[-last_n:]
        return "\n".join(recent)

    @property
    def elapsed(self) -> float:
        """Seconds since start, up to the finish time if set; 0.0 before start."""
        if self.started_at is None:
            return 0.0
        stop = self.finished_at or time.time()
        return stop - self.started_at
82
+
83
+
84
class JobManager:
    """Registry of background jobs; starts each job on its own daemon thread."""

    def __init__(self):
        self._jobs: dict[int, Job] = {}
        self._next_id = 1
        self._lock = threading.Lock()  # guards id allocation and registration

    def submit(self, name: str, target, *args, **kwargs) -> Job:
        """Run `target(job, *args, **kwargs)` on a worker thread with log capture.

        The job is registered and returned immediately. Its status becomes
        ``running``, then ``finished``, ``stopped`` (stop was requested) or
        ``failed`` (traceback stored in ``job.error`` and appended to the log).
        """
        import sys

        with self._lock:
            job = Job(id=self._next_id, name=name)
            self._jobs[job.id] = job
            self._next_id += 1

        def runner():
            job.status = "running"
            job.started_at = time.time()
            tee_out = _Tee(job, sys.stdout)
            tee_err = _Tee(job, sys.stderr)
            try:
                # Imported inside the try so an import failure marks the job
                # failed instead of killing the thread silently.
                from .engine import ENGINE  # local import: keep jobs importable standalone
                with contextlib.redirect_stdout(tee_out), contextlib.redirect_stderr(tee_err):
                    # All MLX work must run on the single engine thread; the
                    # stdout redirect is process-wide, so logs still reach us.
                    ENGINE.call(target, job, *args, **kwargs)
                job.status = "stopped" if job.stop_event.is_set() else "finished"
            except BaseException:  # noqa: BLE001
                # The engine re-raises BaseException (e.g. SystemExit from a
                # trainer); catching only Exception would leave the job
                # reported as "running" forever.
                job.error = traceback.format_exc()
                job.add_log(job.error)
                job.status = "failed"
            finally:
                job.finished_at = time.time()

        threading.Thread(target=runner, name=f"finetuner-job-{job.id}", daemon=True).start()
        return job

    def get(self, job_id: int) -> Job | None:
        """Return the job with this id, or ``None``."""
        return self._jobs.get(job_id)

    def all(self) -> list[Job]:
        """Return all registered jobs."""
        return list(self._jobs.values())

    def latest(self) -> Job | None:
        """Return the job with the highest id, or ``None`` when there are none."""
        return self._jobs[max(self._jobs)] if self._jobs else None

    def stop(self, job_id: int):
        """Set the job's stop event and log the request; unknown ids are ignored."""
        job = self._jobs.get(job_id)
        if job:
            job.stop_event.set()
            job.add_log("⏹ Stop requested — training will halt at the next step boundary.")
133
+
134
+
135
+ MANAGER = JobManager()
finetuner/core/models.py ADDED
@@ -0,0 +1,69 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """Model discovery (Hugging Face Hub) and loading via mlx-tune."""
2
+
3
+ from __future__ import annotations
4
+
5
+ from pathlib import Path
6
+
7
+ from huggingface_hub import HfApi
8
+
9
+ from .engine import ENGINE
10
+ from .registry import get_task, resolve
11
+
12
+
13
def search_hub_models(query: str, limit: int = 20, mlx_only: bool = True) -> list[str]:
    """Return Hub model ids matching `query`, sorted by downloads.

    A blank query yields an empty list without contacting the Hub. With
    `mlx_only` the search is first filtered by the "mlx" tag; if that filter
    is rejected (TypeError) or matches nothing, an unfiltered search is used.
    """
    if query.strip() == "":
        return []

    hub = HfApi()

    def model_ids(**params) -> list[str]:
        return [entry.id for entry in hub.list_models(**params)]

    base = {"search": query, "limit": limit, "sort": "downloads"}
    first_try = {**base, "filter": "mlx"} if mlx_only else dict(base)
    try:
        found = model_ids(**first_try)
    except TypeError:  # huggingface_hub version drift on the filter kwarg
        found = model_ids(**base)
    if mlx_only and not found:  # fall back to an unrestricted search
        found = model_ids(**base)
    return found
29
+
30
+
31
def validate_local_model(path: str) -> str:
    """Check that `path` is a local model directory and return it as a string.

    `~` is expanded. A directory counts as a model directory when it contains
    a ``config.json`` file.

    Raises:
        ValueError: if `path` is empty/blank, or the directory has no
            ``config.json``.
        FileNotFoundError: if the path is not an existing directory.
    """
    if not path or not path.strip():
        # Path("") resolves to the current directory, which would silently
        # validate whatever folder the app was started from.
        raise ValueError("No model directory given.")
    p = Path(path).expanduser()
    if not p.is_dir():
        raise FileNotFoundError(f"Not a directory: {p}")
    if not (p / "config.json").exists():
        raise ValueError(f"{p} does not look like a model directory (missing config.json).")
    return str(p)
38
+
39
+
40
def load_model(task_id: str, model_name: str, max_seq_length: int = 2048,
               load_in_4bit: bool = True):
    """Load `model_name` with the loader class registered for `task_id`.

    `max_seq_length` is forwarded only for text-modality tasks; `load_in_4bit`
    is always forwarded. Returns whatever the loader's `from_pretrained`
    returns.
    """
    task = get_task(task_id)
    model_cls = resolve(task.loader)
    text_only = {"max_seq_length": max_seq_length} if task.modality == "text" else {}
    options = {**text_only, "load_in_4bit": load_in_4bit}
    # MLX streams are thread-local — load on the engine thread (see engine.py).
    return ENGINE.call(model_cls.from_pretrained, model_name, **options)
51
+
52
+
53
def apply_lora(task_id: str, model, r: int = 16, lora_alpha: int = 16,
               lora_dropout: float = 0.0, target_modules: list[str] | None = None,
               **extra):
    """Attach LoRA adapters to `model` via the task loader's `get_peft_model`.

    `lora_dropout` is forwarded only when non-zero, and `target_modules` only
    for text-modality tasks. Any `extra` keyword arguments are passed through.

    If the call raises TypeError while `lora_dropout` was passed, it is
    retried once without that argument. Any other TypeError propagates.
    """
    spec = get_task(task_id)
    loader = resolve(spec.loader)
    kwargs: dict = {"r": r, "lora_alpha": lora_alpha}
    if lora_dropout:
        kwargs["lora_dropout"] = lora_dropout
    if spec.modality == "text" and target_modules:
        kwargs["target_modules"] = list(target_modules)
    kwargs.update(extra)
    try:
        return ENGINE.call(loader.get_peft_model, model, **kwargs)
    except TypeError:
        # Older mlx-tune versions may not accept every kwarg (e.g. lora_dropout).
        if "lora_dropout" not in kwargs:
            # Nothing to drop: retrying would repeat the identical failing
            # call, so surface the original error instead.
            raise
        kwargs.pop("lora_dropout")
        return ENGINE.call(loader.get_peft_model, model, **kwargs)
finetuner/core/recipes.py ADDED
@@ -0,0 +1,40 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """Recipes: save/load complete run configurations as shareable YAML files."""
2
+
3
+ from __future__ import annotations
4
+
5
+ import time
6
+ from pathlib import Path
7
+
8
+ import yaml
9
+
10
+ from .training import RunConfig
11
+
12
+ RECIPE_DIR = Path("recipes")
13
+
14
+
15
def save_recipe(cfg: RunConfig, name: str = "", dataset_source: str = "",
                dataset_is_local: bool = False) -> Path:
    """Write `cfg` and its dataset reference to ``RECIPE_DIR/<slug>.yaml``.

    The slug is `name` (or ``<task>-<timestamp>`` when `name` is blank) with
    spaces and path separators replaced by ``-``, so the file always lands
    directly inside `RECIPE_DIR`. Returns the path that was written.
    """
    RECIPE_DIR.mkdir(exist_ok=True)
    slug = name.strip() or f"{cfg.task}-{time.strftime('%Y%m%d-%H%M%S')}"
    for unsafe in (" ", "/", "\\"):
        slug = slug.replace(unsafe, "-")
    path = RECIPE_DIR / f"{slug}.yaml"
    payload = {
        "finetuner_recipe": 1,
        "dataset": {"source": dataset_source, "local": dataset_is_local},
        "run": cfg.to_dict(),
    }
    # allow_unicode emits non-ASCII characters verbatim, so pin the encoding
    # instead of relying on the platform default.
    path.write_text(yaml.safe_dump(payload, sort_keys=False, allow_unicode=True),
                    encoding="utf-8")
    return path


def load_recipe(path: str) -> tuple[RunConfig, str, bool]:
    """Read a recipe file and return ``(config, dataset_source, dataset_is_local)``.

    Raises:
        ValueError: if the file is not a mapping with a mapping-valued
            ``run`` section.
    """
    data = yaml.safe_load(Path(path).expanduser().read_text(encoding="utf-8"))
    if not isinstance(data, dict) or not isinstance(data.get("run"), dict):
        raise ValueError("Not a Finetuner recipe (missing `run` section).")
    ds = data.get("dataset")
    if not isinstance(ds, dict):
        # A missing, null or malformed `dataset` section means "no dataset".
        ds = {}
    return RunConfig.from_dict(data["run"]), ds.get("source", ""), bool(ds.get("local", False))
35
+
36
+
37
def list_recipes() -> list[str]:
    """Return the sorted paths of all ``*.yaml`` files directly in `RECIPE_DIR`."""
    if RECIPE_DIR.exists():
        return sorted(map(str, RECIPE_DIR.glob("*.yaml")))
    return []
finetuner/core/registry.py ADDED
@@ -0,0 +1,232 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """Task registry: one entry per mlx-tune training paradigm.
2
+
3
+ Every public trainer interface that mlx-tune exposes is described here so the
4
+ GUI, the code generator and the recipe system all share a single source of
5
+ truth. mlx-tune itself is imported lazily — the Studio runs (for planning,
6
+ dataset inspection, recipe authoring and script generation) even on machines
7
+ where MLX is not installed.
8
+ """
9
+
10
+ from __future__ import annotations
11
+
12
+ import importlib
13
+ import importlib.util
14
+ import platform
15
+ from dataclasses import dataclass, field
16
+
17
+
18
@dataclass(frozen=True)
class TaskSpec:
    """Immutable description of one training task.

    `loader`, `trainer`, `config` and `collator` are class *names* (strings),
    not the classes themselves.
    """
    id: str                                # key of this task
    label: str                             # human-readable name
    description: str                       # one-sentence summary
    loader: str                            # mlx_tune Fast*Model class name
    trainer: str                           # mlx_tune trainer class name
    config: str                            # mlx_tune config class name
    config_module: str = "mlx_tune"        # module to import the config from
    collator: str | None = None            # optional data collator class name
    dataset_schema: tuple[str, ...] = ()   # canonical required fields
    detector_formats: tuple[str, ...] = () # detector format ids that map to this task
    default_model: str = ""                # model id used as the default for this task
    default_target_modules: tuple[str, ...] = (
        "q_proj", "k_proj", "v_proj", "o_proj",
    )
    peft_supported: bool = True
    # Extra config values for this task (e.g. {"beta": 0.1}).
    extra_config_defaults: dict = field(default_factory=dict)
    modality: str = "text"                 # text | vision | audio | image
    notes: str = ""                        # optional caveat text
38
+
39
+
40
+ FULL_TARGETS = ("q_proj", "k_proj", "v_proj", "o_proj", "gate_proj", "up_proj", "down_proj")
41
+
42
+ TASKS: dict[str, TaskSpec] = {
43
+ "sft": TaskSpec(
44
+ id="sft",
45
+ label="SFT — Supervised Fine-Tuning",
46
+ description="Instruction tuning of chat/completion LLMs on text or conversation data.",
47
+ loader="FastLanguageModel",
48
+ trainer="SFTTrainer",
49
+ config="SFTConfig",
50
+ dataset_schema=("text",),
51
+ detector_formats=("alpaca", "sharegpt", "chatml", "prompt_completion", "text"),
52
+ default_model="mlx-community/Llama-3.2-1B-Instruct-4bit",
53
+ ),
54
+ "dpo": TaskSpec(
55
+ id="dpo",
56
+ label="DPO — Direct Preference Optimization",
57
+ description="Align a model with human preferences from chosen/rejected pairs.",
58
+ loader="FastLanguageModel",
59
+ trainer="DPOTrainer",
60
+ config="DPOConfig",
61
+ dataset_schema=("prompt", "chosen", "rejected"),
62
+ detector_formats=("preference",),
63
+ default_model="mlx-community/Llama-3.2-1B-Instruct-4bit",
64
+ extra_config_defaults={"beta": 0.1},
65
+ ),
66
+ "orpo": TaskSpec(
67
+ id="orpo",
68
+ label="ORPO — Odds Ratio Preference Optimization",
69
+ description="Reference-free preference optimization; combines SFT and alignment in one pass.",
70
+ loader="FastLanguageModel",
71
+ trainer="ORPOTrainer",
72
+ config="ORPOConfig",
73
+ dataset_schema=("prompt", "chosen", "rejected"),
74
+ detector_formats=("preference",),
75
+ default_model="mlx-community/Llama-3.2-1B-Instruct-4bit",
76
+ extra_config_defaults={"beta": 0.1},
77
+ ),
78
+ "simpo": TaskSpec(
79
+ id="simpo",
80
+ label="SimPO — Simple Preference Optimization",
81
+ description="Length-normalized, reference-free preference optimization.",
82
+ loader="FastLanguageModel",
83
+ trainer="SimPOTrainer",
84
+ config="SimPOConfig",
85
+ dataset_schema=("prompt", "chosen", "rejected"),
86
+ detector_formats=("preference",),
87
+ default_model="mlx-community/Llama-3.2-1B-Instruct-4bit",
88
+ ),
89
+ "kto": TaskSpec(
90
+ id="kto",
91
+ label="KTO — Kahneman-Tversky Optimization",
92
+ description="Alignment from simple binary thumbs-up/down feedback (no pairs needed).",
93
+ loader="FastLanguageModel",
94
+ trainer="KTOTrainer",
95
+ config="KTOConfig",
96
+ dataset_schema=("prompt", "completion", "label"),
97
+ detector_formats=("kto",),
98
+ default_model="mlx-community/Llama-3.2-1B-Instruct-4bit",
99
+ ),
100
+ "grpo": TaskSpec(
101
+ id="grpo",
102
+ label="GRPO — Group Relative Policy Optimization",
103
+ description="Online RL with programmable reward functions (reasoning, math, code).",
104
+ loader="FastLanguageModel",
105
+ trainer="GRPOTrainer",
106
+ config="GRPOConfig",
107
+ dataset_schema=("prompt",),
108
+ detector_formats=("grpo",),
109
+ default_model="mlx-community/Llama-3.2-1B-Instruct-4bit",
110
+ notes="Reward functions are plain Python callables; edit them in the generated script.",
111
+ ),
112
+ "cpt": TaskSpec(
113
+ id="cpt",
114
+ label="CPT — Continual Pretraining",
115
+ description="Inject domain knowledge by continuing pretraining on raw text.",
116
+ loader="FastLanguageModel",
117
+ trainer="CPTTrainer",
118
+ config="CPTConfig",
119
+ dataset_schema=("text",),
120
+ detector_formats=("text",),
121
+ default_model="mlx-community/SmolLM2-360M-Instruct",
122
+ default_target_modules=FULL_TARGETS,
123
+ extra_config_defaults={"embedding_learning_rate": 5e-6, "include_embeddings": True},
124
+ ),
125
+ "vlm_sft": TaskSpec(
126
+ id="vlm_sft",
127
+ label="Vision SFT — Vision-Language Models",
128
+ description="Fine-tune VLMs (Qwen-VL, LLaVA-style) on image + conversation data.",
129
+ loader="FastVisionModel",
130
+ trainer="VLMSFTTrainer",
131
+ config="VLMSFTConfig",
132
+ config_module="mlx_tune.vlm",
133
+ dataset_schema=("images", "messages"),
134
+ detector_formats=("vision_chat",),
135
+ default_model="mlx-community/Qwen2.5-VL-3B-Instruct-4bit",
136
+ modality="vision",
137
+ ),
138
+ "tts_sft": TaskSpec(
139
+ id="tts_sft",
140
+ label="TTS SFT — Text-to-Speech",
141
+ description="Fine-tune speech synthesis models (Orpheus, OuteTTS, CSM…) on audio+text pairs.",
142
+ loader="FastTTSModel",
143
+ trainer="TTSSFTTrainer",
144
+ config="TTSSFTConfig",
145
+ collator="TTSDataCollator",
146
+ dataset_schema=("audio", "text"),
147
+ detector_formats=("audio_text",),
148
+ default_model="mlx-community/orpheus-3b-0.1-ft-bf16",
149
+ modality="audio",
150
+ notes="Audio training currently supports batch_size=1 (mlx-tune limitation).",
151
+ ),
152
+ "stt_sft": TaskSpec(
153
+ id="stt_sft",
154
+ label="STT SFT — Speech-to-Text",
155
+ description="Fine-tune ASR models (Whisper, Parakeet, Canary…) on audio+transcription pairs.",
156
+ loader="FastSTTModel",
157
+ trainer="STTSFTTrainer",
158
+ config="STTSFTConfig",
159
+ collator="STTDataCollator",
160
+ dataset_schema=("audio", "text"),
161
+ detector_formats=("audio_text",),
162
+ default_model="mlx-community/whisper-tiny-asr-fp16",
163
+ modality="audio",
164
+ notes="Audio training currently supports batch_size=1 (mlx-tune limitation).",
165
+ ),
166
+ "embedding": TaskSpec(
167
+ id="embedding",
168
+ label="Embedding SFT — Sentence Embeddings",
169
+ description="Contrastive fine-tuning of embedding models (anchor/positive pairs).",
170
+ loader="FastEmbeddingModel",
171
+ trainer="EmbeddingSFTTrainer",
172
+ config="EmbeddingSFTConfig",
173
+ dataset_schema=("anchor", "positive"),
174
+ detector_formats=("embedding_pairs",),
175
+ default_model="mlx-community/all-MiniLM-L6-v2-bf16",
176
+ extra_config_defaults={"loss_type": "infonce", "temperature": 0.05},
177
+ ),
178
+ "ocr_sft": TaskSpec(
179
+ id="ocr_sft",
180
+ label="OCR SFT — Optical Character Recognition",
181
+ description="Fine-tune OCR models (DeepSeek-OCR, olmOCR…) on image + ground-truth text.",
182
+ loader="FastOCRModel",
183
+ trainer="OCRSFTTrainer",
184
+ config="OCRSFTConfig",
185
+ dataset_schema=("image", "text"),
186
+ detector_formats=("image_text",),
187
+ default_model="mlx-community/DeepSeek-OCR-8bit",
188
+ modality="image",
189
+ extra_config_defaults={"learning_rate": 5e-5},
190
+ ),
191
+ }
192
+
193
+
194
def get_task(task_id: str) -> TaskSpec:
    """Return the `TaskSpec` registered under `task_id`.

    Raises:
        KeyError: if `task_id` is not a key of `TASKS`; the message lists the
            known ids.
    """
    try:
        return TASKS[task_id]
    except KeyError:
        known = ", ".join(TASKS)
        raise KeyError(f"Unknown task {task_id!r}; known tasks: {known}") from None
196
+
197
+
198
def task_choices() -> list[tuple[str, str]]:
    """(label, id) pairs for a Gradio dropdown."""
    pairs: list[tuple[str, str]] = []
    for task in TASKS.values():
        pairs.append((task.label, task.id))
    return pairs
201
+
202
+
203
def tasks_for_format(format_id: str) -> list[TaskSpec]:
    """Return every registered task whose `detector_formats` contains `format_id`."""
    matches: list[TaskSpec] = []
    for task in TASKS.values():
        if format_id in task.detector_formats:
            matches.append(task)
    return matches
205
+
206
+
207
+ # ---------------------------------------------------------------------------
208
+ # Lazy mlx-tune access
209
+ # ---------------------------------------------------------------------------
210
+
211
def mlx_available() -> tuple[bool, str]:
    """Report whether mlx-tune can be used here, with a human-readable reason.

    The check passes only on macOS/arm64 with an importable `mlx_tune`
    package; the package itself is located, not imported.
    """
    on_apple_silicon = platform.system() == "Darwin" and platform.machine() == "arm64"
    if not on_apple_silicon:
        return False, "mlx-tune requires an Apple Silicon Mac (arm64/macOS)."
    installed = importlib.util.find_spec("mlx_tune") is not None
    if installed:
        return True, "mlx-tune is available."
    return False, "mlx-tune is not installed. Run: pip install 'finetuner[mlx]'"
218
+
219
+
220
def resolve(name: str, module: str = "mlx_tune"):
    """Import `name` from an mlx_tune module, raising a friendly error if missing.

    Raises:
        RuntimeError: if mlx-tune is unavailable on this machine, if `module`
            cannot be imported, or if `module` has no attribute `name`.
    """
    ok, reason = mlx_available()
    if not ok:
        raise RuntimeError(reason)
    try:
        mod = importlib.import_module(module)
    except ImportError as exc:
        # A missing or broken (sub)module gets the same friendly treatment as
        # a missing attribute, instead of leaking a raw ImportError.
        raise RuntimeError(
            f"Could not import `{module}` ({exc}). Your mlx-tune version may be "
            "too old; try: pip install -U mlx-tune"
        ) from exc
    try:
        return getattr(mod, name)
    except AttributeError as exc:
        raise RuntimeError(
            f"`{name}` not found in `{module}`. Your mlx-tune version may be too old; "
            "try: pip install -U mlx-tune"
        ) from exc
finetuner/core/state.py ADDED
@@ -0,0 +1,35 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """Shared application state for the (single-user, local) Studio session."""
2
+
3
+ from __future__ import annotations
4
+
5
+ from dataclasses import dataclass, field
6
+
7
+ from .detector import Detection
8
+ from .training import RunConfig
9
+
10
+
11
@dataclass
class AppState:
    """Mutable state shared by the Studio tabs."""

    # --- model ---
    # `model` and `tokenizer` carry no annotation, so they are plain class
    # attributes rather than dataclass fields (absent from __init__/repr/eq).
    model = None
    tokenizer = None
    model_name: str = ""
    model_loaded_for_task: str = ""
    lora_attached: bool = False
    # --- dataset ---
    raw_rows: list[dict] = field(default_factory=list)
    detection: Detection | None = None
    dataset_source: str = ""
    dataset_is_local: bool = False
    # --- run configuration ---
    config: RunConfig = field(default_factory=RunConfig)

    def reset_model(self):
        """Clear every model-related attribute back to its empty value."""
        self.model = self.tokenizer = None
        self.model_name = self.model_loaded_for_task = ""
        self.lora_attached = False
33
+
34
+
35
+ STATE = AppState()
finetuner/core/training.py ADDED
@@ -0,0 +1,105 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """Build and run mlx-tune trainers from a flat GUI config dict."""
2
+
3
+ from __future__ import annotations
4
+
5
+ from dataclasses import dataclass, field
6
+
7
+ from .jobs import Job
8
+ from .registry import get_task, resolve
9
+
10
+
11
@dataclass
class RunConfig:
    """Flat, serializable bundle of every setting collected by the GUI."""

    task: str = "sft"
    model_name: str = ""
    max_seq_length: int = 2048
    load_in_4bit: bool = True
    # --- LoRA ---
    use_lora: bool = True
    lora_r: int = 16
    lora_alpha: int = 16
    lora_dropout: float = 0.0
    target_modules: list[str] = field(
        default_factory=lambda: ["q_proj", "k_proj", "v_proj", "o_proj"]
    )
    # --- training hyperparameters ---
    output_dir: str = "outputs"
    batch_size: int = 2
    gradient_accumulation_steps: int = 1
    learning_rate: float = 2e-4
    max_steps: int = 100
    num_train_epochs: float | None = None
    warmup_steps: int = 5
    gradient_checkpointing: bool = False
    seed: int = 42
    # --- task-specific extras (beta for DPO, temperature for embeddings, ...) ---
    extra: dict = field(default_factory=dict)

    def to_dict(self) -> dict:
        """Return the configuration as a plain dict, one key per field."""
        import dataclasses
        return dataclasses.asdict(self)

    @classmethod
    def from_dict(cls, d: dict) -> "RunConfig":
        """Build a config from `d`, ignoring keys that are not fields."""
        accepted = cls.__dataclass_fields__.keys()
        kept = {}
        for key, value in d.items():
            if key in accepted:
                kept[key] = value
        return cls(**kept)
45
+
46
+
47
+ def _filtered_kwargs(config_cls, kwargs: dict) -> dict:
48
+ """Drop kwargs the dataclass config doesn't accept (version drift safety)."""
49
+ fields = getattr(config_cls, "__dataclass_fields__", None)
50
+ if fields is None:
51
+ return kwargs
52
+ return {k: v for k, v in kwargs.items() if k in fields}
53
+
54
+
55
def build_trainer_args(cfg: RunConfig) -> dict:
    """Translate a `RunConfig` into keyword arguments for a trainer config.

    `num_train_epochs` is used when it is set (truthy), otherwise `max_steps`.
    The task's `extra_config_defaults` and then `cfg.extra` are merged on top.
    """
    task = get_task(cfg.task)
    if cfg.num_train_epochs:
        schedule = {"num_train_epochs": cfg.num_train_epochs}
    else:
        schedule = {"max_steps": cfg.max_steps}
    merged: dict = {
        "output_dir": cfg.output_dir,
        "per_device_train_batch_size": cfg.batch_size,
        "gradient_accumulation_steps": cfg.gradient_accumulation_steps,
        "learning_rate": cfg.learning_rate,
        "warmup_steps": cfg.warmup_steps,
        "seed": cfg.seed,
        **schedule,
    }
    if cfg.gradient_checkpointing:
        merged["gradient_checkpointing"] = True
    # NOTE(review): task defaults are applied after the GUI values, so a task
    # default such as `learning_rate` replaces `cfg.learning_rate` — confirm
    # this precedence is intended.
    for overrides in (task.extra_config_defaults, cfg.extra):
        merged.update(overrides)
    return merged
74
+
75
+
76
def run_training(job: Job, cfg: RunConfig, model, tokenizer, dataset: list[dict]):
    """Job target: construct the task's trainer and train. Runs on a worker thread.

    `tokenizer` is passed to the trainer as `processor` for vision tasks and
    OCR, and as `tokenizer` otherwise; if the constructor raises TypeError the
    other keyword is tried once.
    """
    task = get_task(cfg.task)
    trainer_cls = resolve(task.trainer)
    config_cls = resolve(task.config, task.config_module)

    train_args = config_cls(**_filtered_kwargs(config_cls, build_trainer_args(cfg)))

    uses_processor = task.modality == "vision" or task.id == "ocr_sft"
    text_key = "processor" if uses_processor else "tokenizer"
    params: dict = {
        "model": model,
        "train_dataset": dataset,
        "args": train_args,
        text_key: tokenizer,
    }
    if task.collator:
        params["data_collator"] = resolve(task.collator)(model, tokenizer)

    try:
        trainer = trainer_cls(**params)
    except TypeError:
        # Some trainers take `processor` instead of `tokenizer` or vice versa.
        other_key = "tokenizer" if text_key == "processor" else "processor"
        params[other_key] = params.pop(text_key)
        trainer = trainer_cls(**params)

    job.add_log(f"▶ {task.label} started — {len(dataset)} samples, output → {cfg.output_dir}")
    trainer.train()
    job.add_log("✅ Training finished.")
finetuner/ui/__init__.py ADDED
File without changes
finetuner/ui/tab_dataset.py ADDED
@@ -0,0 +1,117 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """Dataset tab: load from the Hub, a local path, or upload — then auto-detect the format."""
2
+
3
+ from __future__ import annotations
4
+
5
+ import pandas as pd
6
+ import gradio as gr
7
+
8
+ from ..core import data as datalib
9
+ from ..core.detector import detect
10
+ from ..core.registry import get_task
11
+ from ..core.state import STATE
12
+
13
+ PREVIEW_ROWS = 8
14
+
15
+
16
+ def _preview_df(rows: list[dict]) -> pd.DataFrame:
17
+ if not rows:
18
+ return pd.DataFrame()
19
+ df = pd.DataFrame(rows[:PREVIEW_ROWS])
20
+ return df.map(lambda v: str(v)[:300] if not isinstance(v, (int, float, bool)) else v)
21
+
22
+
23
def _detection_md() -> str:
    """Render `STATE.detection` as Markdown; empty string when nothing is detected.

    The output has the detected label, a ten-cell confidence bar, and — when
    present — the column mapping, the compatible trainers and any notes.
    """
    found = STATE.detection
    if found is None:
        return ""

    filled = round(found.confidence * 10)
    meter = "🟩" * filled + "⬜" * (10 - filled)
    parts = [
        f"### 🔎 Detected format: **{found.label}**",
        f"Confidence: {meter} **{found.confidence:.0%}**",
    ]
    if found.mapping:
        pairs = [f"`{k}` ← `{v}`" for k, v in found.mapping.items()]
        parts.append("Column mapping: " + ", ".join(pairs))
    if found.suggested_tasks:
        names = [f"**{get_task(t).label}**" for t in found.suggested_tasks]
        parts.append("Compatible trainers: " + ", ".join(names))
    parts.extend(f"> 💡 {note}" for note in found.notes)
    return "\n\n".join(parts)
41
+
42
+
43
def _ingest(rows: list[dict], source: str, is_local: bool) -> tuple[str, pd.DataFrame, str]:
    """Store freshly loaded rows in `STATE`, run detection, and build the UI outputs.

    Returns ``(status message, preview frame, detection Markdown)``.
    """
    STATE.raw_rows = rows
    STATE.detection = detect(rows)
    STATE.dataset_source, STATE.dataset_is_local = source, is_local
    summary = f"✅ Loaded **{len(rows)}** rows from `{source}`."
    return summary, _preview_df(rows), _detection_md()
50
+
51
+
52
def build(app):
    """Create the Dataset tab and wire its events.

    The tab offers three sources (Hub, local file, upload). Loading a dataset
    updates the status line, the preview table and the detection panel.
    """
    with gr.Tab("📚 Dataset", id="dataset"):
        gr.Markdown("### Load a dataset — the format is detected automatically")
        source = gr.Radio(["Hugging Face Hub", "Local file", "Upload"],
                          value="Hugging Face Hub", label="Source")

        with gr.Group() as hub_group:
            with gr.Row():
                query = gr.Textbox(label="Search Hub datasets", placeholder="e.g. alpaca turkish", scale=3)
                search_btn = gr.Button("🔍 Search", scale=1)
            with gr.Row():
                ds_name = gr.Dropdown(label="Dataset", allow_custom_value=True, choices=[],
                                      info="Pick a result or type any dataset id.", scale=3)
                split = gr.Textbox(value="train", label="Split", scale=1)
                subset = gr.Textbox(value="", label="Config (optional)", scale=1)

        local_path = gr.Textbox(label="Local dataset path", visible=False,
                                placeholder="~/data/train.jsonl (.jsonl/.json/.csv/.tsv/.parquet)")
        upload = gr.File(label="Upload dataset", visible=False,
                         file_types=[".jsonl", ".json", ".csv", ".tsv", ".parquet"])

        with gr.Row():
            max_rows = gr.Number(value=0, precision=0, label="Max rows (0 = all)")
            load_btn = gr.Button("📥 Load dataset", variant="primary", scale=2)

        status = gr.Markdown()
        detection_panel = gr.Markdown()
        preview = gr.Dataframe(label=f"Preview (first {PREVIEW_ROWS} rows)", interactive=False, wrap=True)

        # ----- events -------------------------------------------------------
        def on_source(src):
            # Show only the input widget that belongs to the chosen source.
            return (gr.update(visible=src == "Hugging Face Hub"),
                    gr.update(visible=src == "Local file"),
                    gr.update(visible=src == "Upload"))

        source.change(on_source, source, [hub_group, local_path, upload])

        def on_search(q):
            results = datalib.search_hub_datasets(q)
            if not results:
                gr.Warning(f"No Hub datasets found for {q!r}")
                return gr.update()
            return gr.update(choices=results, value=results[0])

        search_btn.click(on_search, query, ds_name)
        query.submit(on_search, query, ds_name)

        def on_load(src, name, split_v, subset_v, path, file, n, progress=gr.Progress()):
            try:
                # A cleared Number field arrives as None; zero or a negative
                # value means "no limit".
                requested = int(n or 0)
                limit = requested if requested > 0 else None
                if src == "Hugging Face Hub":
                    if not name:
                        return "❌ Choose a dataset first.", gr.update(), ""
                    progress(0.2, desc=f"Downloading {name} …")
                    rows = datalib.load_hub_dataset(name, split_v or "train", subset_v or None, limit)
                    return _ingest(rows, name, is_local=False)
                if src == "Local file":
                    target = path
                else:
                    # The upload value may be an object with `.name` or a plain
                    # path string; accept both.
                    target = getattr(file, "name", file) if file else ""
                if not target:
                    return "❌ Provide a file first.", gr.update(), ""
                rows = datalib.load_local_dataset(target, limit)
                return _ingest(rows, target, is_local=True)
            except Exception as exc:
                return f"❌ Failed to load dataset: {exc}", gr.update(), ""

        load_btn.click(on_load, [source, ds_name, split, subset, local_path, upload, max_rows],
                       [status, preview, detection_panel])
finetuner/ui/tab_export.py ADDED
@@ -0,0 +1,62 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """Export tab: adapters, merged weights, GGUF, and Hugging Face Hub upload."""
2
+
3
+ from __future__ import annotations
4
+
5
+ import gradio as gr
6
+
7
+ from ..core import export
8
+ from ..core.state import STATE
9
+
10
+
11
def _guard():
    """Raise `gr.Error` unless a model is currently held in `STATE`."""
    if STATE.model is not None:
        return
    raise gr.Error("No model in memory — load and train one first.")
14
+
15
+
16
def build(app):
    """Create the Export tab: four export targets sharing one status line."""
    with gr.Tab("📦 Export", id="export"):
        gr.Markdown("### Save or publish the fine-tuned model")

        with gr.Group():
            gr.Markdown("**LoRA adapters** — small, fast to share")
            with gr.Row():
                adapter_path = gr.Textbox(value="lora_model", label="Directory", scale=3)
                adapter_btn = gr.Button("💾 Save adapters", scale=1)

        with gr.Group():
            gr.Markdown("**Merged model** — base weights + adapters fused to 16-bit")
            with gr.Row():
                merged_path = gr.Textbox(value="merged", label="Directory", scale=3)
                merged_btn = gr.Button("🔗 Save merged", scale=1)

        with gr.Group():
            gr.Markdown("**GGUF** — for llama.cpp / Ollama. "
                        "⚠️ Requires a *non-quantized* base model (mlx-lm limitation).")
            with gr.Row():
                gguf_path = gr.Textbox(value="model_gguf", label="Directory", scale=3)
                gguf_btn = gr.Button("🦙 Export GGUF", scale=1)

        with gr.Group():
            gr.Markdown("**Hugging Face Hub** — publish the model to your account")
            with gr.Row():
                repo_id = gr.Textbox(label="Repo id", placeholder="username/my-finetuned-model", scale=2)
                hf_token = gr.Textbox(label="HF token (optional if logged in)", type="password", scale=2)
                push_btn = gr.Button("🤗 Push to Hub", variant="primary", scale=1)

        status = gr.Markdown()

        def run(fn, *args):
            # `_guard` raises before the try, so "no model" surfaces as a
            # gr.Error; failures of the export call itself become a ❌ status.
            _guard()
            try:
                return f"✅ {fn(*args)}"
            except Exception as exc:
                return f"❌ {exc}"

        # The lambdas read STATE at click time, so each export uses whatever
        # model is in STATE at that moment.
        adapter_btn.click(lambda p: run(export.save_adapters, STATE.model, p),
                          adapter_path, status)
        merged_btn.click(lambda p: run(export.save_merged, STATE.model, STATE.tokenizer, p),
                         merged_path, status)
        gguf_btn.click(lambda p: run(export.save_gguf, STATE.model, STATE.tokenizer, p),
                       gguf_path, status)
        # An empty token field is passed on as None.
        push_btn.click(lambda r, t: run(export.push_to_hub, STATE.model, r, t or None),
                       [repo_id, hf_token], status)
finetuner/ui/tab_model.py ADDED
@@ -0,0 +1,122 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """Model tab: pick a task, find a model (Hub search or local path), load it, attach LoRA."""
2
+
3
+ from __future__ import annotations
4
+
5
+ import gradio as gr
6
+
7
+ from ..core import models
8
+ from ..core.registry import get_task, mlx_available, task_choices
9
+ from ..core.state import STATE
10
+
11
+
12
def _task_info(task_id: str) -> str:
    """Return a Markdown summary (label, description, backend, schema, notes) for a task."""
    spec = get_task(task_id)
    schema = ", ".join(spec.dataset_schema)
    bullets = [
        f"- Backend: `mlx_tune.{spec.trainer}` + `{spec.config}`",
        f"- Dataset schema: `{schema}`",
    ]
    if spec.notes:
        bullets.append(f"- ⚠️ {spec.notes}")
    return "\n".join([f"**{spec.label}**", "", spec.description, "", *bullets])
20
+
21
+
22
def build(app):
    """Create the Model tab and wire its events.

    Returns a dict exposing the task dropdown under the key ``"task"``.
    """
    with gr.Tab("🧠 Model", id="model"):
        gr.Markdown("### 1 · Choose a task and a base model")
        with gr.Row():
            with gr.Column(scale=1):
                task = gr.Dropdown(choices=task_choices(), value="sft", label="Training task",
                                   info="Every mlx-tune trainer is available here.")
                task_info = gr.Markdown(_task_info("sft"))
            with gr.Column(scale=2):
                source = gr.Radio(["Hugging Face Hub", "Local path"], value="Hugging Face Hub",
                                  label="Model source")
                with gr.Group() as hub_group:
                    with gr.Row():
                        query = gr.Textbox(label="Search the Hub",
                                           placeholder="e.g. llama 3.2 instruct 4bit", scale=3)
                        search_btn = gr.Button("🔍 Search", scale=1)
                    model_name = gr.Dropdown(label="Model", allow_custom_value=True,
                                             value=get_task("sft").default_model,
                                             choices=[get_task("sft").default_model],
                                             info="Pick a search result or type any repo id.")
                local_path = gr.Textbox(label="Local model directory", visible=False,
                                        placeholder="/path/to/converted-mlx-model")
                with gr.Row():
                    max_seq = gr.Slider(256, 32768, value=2048, step=256, label="Max sequence length")
                    four_bit = gr.Checkbox(value=True, label="Load in 4-bit")

        gr.Markdown("### 2 · LoRA adapters")
        with gr.Row():
            use_lora = gr.Checkbox(value=True, label="Attach LoRA", scale=1)
            lora_r = gr.Slider(1, 256, value=16, step=1, label="Rank (r)", scale=2)
            lora_alpha = gr.Slider(1, 256, value=16, step=1, label="Alpha", scale=2)
            lora_dropout = gr.Slider(0.0, 0.5, value=0.0, step=0.01, label="Dropout", scale=2)
        target_modules = gr.CheckboxGroup(
            choices=["q_proj", "k_proj", "v_proj", "o_proj", "gate_proj", "up_proj", "down_proj"],
            value=["q_proj", "k_proj", "v_proj", "o_proj"],
            label="Target modules (text models)")

        load_btn = gr.Button("⚡ Load model", variant="primary")
        status = gr.Markdown()

        # ----- events -------------------------------------------------------
        def on_task(task_id):
            # Switching task resets the model and target modules to its defaults.
            spec = get_task(task_id)
            return (_task_info(task_id),
                    gr.update(value=spec.default_model, choices=[spec.default_model]),
                    gr.update(value=list(spec.default_target_modules)))

        task.change(on_task, task, [task_info, model_name, target_modules])

        def on_source(src):
            hub = src == "Hugging Face Hub"
            return gr.update(visible=hub), gr.update(visible=not hub)

        source.change(on_source, source, [hub_group, local_path])

        def on_search(q):
            results = models.search_hub_models(q)
            if not results:
                gr.Warning(f"No Hub models found for {q!r}")
                return gr.update()
            return gr.update(choices=results, value=results[0])

        search_btn.click(on_search, query, model_name)
        query.submit(on_search, query, model_name)

        def on_load(task_id, src, name, path, seq, fourbit,
                    lora, r, alpha, dropout, targets, progress=gr.Progress()):
            ok, reason = mlx_available()
            if not ok:
                return f"❌ {reason}"
            resolved = name
            try:
                spec = get_task(task_id)
                # LoRA is attached only when requested AND supported by the
                # task; state and status message must report the same thing.
                attach = bool(lora) and spec.peft_supported
                if src == "Local path":
                    resolved = models.validate_local_model(path)
                progress(0.1, desc=f"Loading {resolved} …")
                model, tok = models.load_model(task_id, resolved, int(seq), bool(fourbit))
                if attach:
                    progress(0.7, desc="Attaching LoRA adapters …")
                    model = models.apply_lora(task_id, model, int(r), int(alpha),
                                              float(dropout), list(targets))
                STATE.model, STATE.tokenizer = model, tok
                STATE.model_name = resolved
                STATE.model_loaded_for_task = task_id
                STATE.lora_attached = attach
                cfg = STATE.config
                cfg.task, cfg.model_name = task_id, resolved
                cfg.max_seq_length, cfg.load_in_4bit = int(seq), bool(fourbit)
                # The run config records what the user asked for.
                cfg.use_lora, cfg.lora_r, cfg.lora_alpha = bool(lora), int(r), int(alpha)
                cfg.lora_dropout, cfg.target_modules = float(dropout), list(targets)
                return (f"✅ **{resolved}** loaded for **{spec.label}**"
                        + (" with LoRA attached." if attach else "."))
            except Exception as exc:  # surfaced to the user, not crashed
                return f"❌ Load failed: {exc}"

        load_btn.click(
            on_load,
            [task, source, model_name, local_path, max_seq, four_bit,
             use_lora, lora_r, lora_alpha, lora_dropout, target_modules],
            status)

    return {"task": task}
finetuner/ui/tab_monitor.py ADDED
@@ -0,0 +1,60 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """Monitor tab: live logs, loss curve and job control, refreshed by a timer."""
2
+
3
+ from __future__ import annotations
4
+
5
+ import pandas as pd
6
+ import gradio as gr
7
+
8
+ from ..core.jobs import MANAGER
9
+
10
# Emoji displayed for each job status; unknown statuses are looked up with
# ``.get(status, '')`` by the functions below and therefore render no icon.
STATUS_ICONS = dict(
    pending="⏳",
    running="🏃",
    finished="✅",
    failed="❌",
    stopped="⏹",
)
12
+
13
+
14
def _job_choices() -> list[tuple[str, int]]:
    """Return ``(label, job id)`` pairs for the job dropdown.

    Jobs are listed in the reverse of the order ``MANAGER.all()`` yields them.
    Each label shows the job id, its status icon (empty for an unknown status)
    and the job name.
    """
    choices = []
    for job in reversed(MANAGER.all()):
        icon = STATUS_ICONS.get(job.status, '')
        choices.append((f"#{job.id} {icon} {job.name}", job.id))
    return choices
17
+
18
+
19
def _snapshot(job_id):
    """Return ``(header markdown, log text, loss DataFrame)`` for one job.

    A falsy ``job_id`` selects ``MANAGER.latest()``; otherwise the job is
    looked up by ``int(job_id)``. When no job is found, a placeholder header,
    an empty log and an empty ``step``/``loss`` frame are returned.
    """
    if job_id:
        job = MANAGER.get(int(job_id))
    else:
        job = MANAGER.latest()

    if job is None:
        placeholder = "*No jobs yet — start one from the 🚀 Train tab.*"
        return (placeholder, "", pd.DataFrame({"step": [], "loss": []}))

    summary = (f"**Job #{job.id}** · {job.name} · "
               f"{STATUS_ICONS.get(job.status, '')} **{job.status}** · "
               f"⏱ {job.elapsed:.0f}s · {len(job.metrics)} loss points")

    # An empty metrics list still yields a frame with both columns so the
    # plot always has its x/y fields available.
    if job.metrics:
        curve = pd.DataFrame(job.metrics, columns=["step", "loss"])
    else:
        curve = pd.DataFrame({"step": [], "loss": []})
    return summary, job.log_text(), curve
30
+
31
+
32
def build(app):
    """Create the Monitor tab: job picker, loss plot, live logs and controls.

    A 2-second timer and the picker's change event both re-render the header,
    logs and loss plot from ``_snapshot``.
    """
    with gr.Tab("📈 Monitor", id="monitor"):
        with gr.Row():
            job_pick = gr.Dropdown(label="Job", choices=_job_choices(), scale=3)
            refresh_btn = gr.Button("🔄 Refresh list", scale=1)
            stop_btn = gr.Button("⏹ Stop job", variant="stop", scale=1)
        header = gr.Markdown("*No jobs yet — start one from the 🚀 Train tab.*")
        with gr.Row():
            with gr.Column(scale=1):
                loss_plot = gr.LinePlot(x="step", y="loss", label="Training loss",
                                        value=pd.DataFrame({"step": [], "loss": []}))
            with gr.Column(scale=1):
                logs = gr.Textbox(label="Live logs", lines=20, max_lines=20,
                                  autoscroll=True, interactive=False)

        # The three widgets that every snapshot refresh rewrites.
        snapshot_outputs = [header, logs, loss_plot]

        timer = gr.Timer(2.0)
        timer.tick(_snapshot, job_pick, snapshot_outputs)

        def refresh_choices():
            """Re-read the job list into the dropdown."""
            return gr.update(choices=_job_choices())

        refresh_btn.click(refresh_choices, None, job_pick)
        job_pick.change(_snapshot, job_pick, snapshot_outputs)

        def on_stop(job_id):
            """Stop the picked job (or the latest one) and refresh the dropdown."""
            target = MANAGER.get(int(job_id)) if job_id else MANAGER.latest()
            if target is not None:
                MANAGER.stop(target.id)
                return gr.update(choices=_job_choices())
            return gr.update()

        stop_btn.click(on_stop, job_pick, job_pick)
finetuner/ui/tab_playground.py ADDED
@@ -0,0 +1,56 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """Playground tab: chat with the currently loaded (and freshly tuned) model."""
2
+
3
+ from __future__ import annotations
4
+
5
+ import gradio as gr
6
+
7
+ from ..core.engine import ENGINE
8
+ from ..core.state import STATE
9
+
10
+
11
def _generate(message: str, history: list[dict], max_tokens: int, temperature: float) -> str:
    """Chat callback: delegate to ``_generate_inner`` through ``ENGINE.call``."""
    # MLX generation must run on the engine thread (streams are thread-local).
    call_args = (message, history, max_tokens, temperature)
    return ENGINE.call(_generate_inner, *call_args)
14
+
15
+
16
def _generate_inner(message: str, history: list[dict], max_tokens: int, temperature: float) -> str:
    """Generate one chat reply with the model and tokenizer held in ``STATE``.

    Args:
        message: The new user message.
        history: Earlier turns as dicts with ``"role"`` and ``"content"`` keys.
        max_tokens: Maximum number of new tokens (coerced to ``int``).
        temperature: Sampling temperature (coerced to ``float``).

    Returns:
        The generated text, or a "⚠️ …" message when no model is loaded or
        generation fails. Errors are reported as text, never raised.
    """
    if STATE.model is None or STATE.tokenizer is None:
        return "⚠️ No model loaded — load one in the 🧠 Model tab first."

    messages = [{"role": h["role"], "content": h["content"]} for h in history]
    messages.append({"role": "user", "content": message})

    tok = STATE.tokenizer
    try:
        prompt = tok.apply_chat_template(messages, tokenize=False, add_generation_prompt=True)
    except Exception:
        # The tokenizer could not apply a chat template: fall back to the
        # plain concatenation of the message contents.
        prompt = "\n".join(m["content"] for m in messages)

    try:
        from mlx_lm import generate as mlx_generate
        budget = int(max_tokens)
    except Exception as exc:
        return f"⚠️ Generation failed: {exc}"

    # mlx-tune models are mlx-lm compatible; prefer its sampler-based generate().
    try:
        from mlx_lm.sample_utils import make_sampler
        return mlx_generate(STATE.model, tok, prompt=prompt, max_tokens=budget,
                            sampler=make_sampler(temp=float(temperature)), verbose=False)
    except (ImportError, TypeError):
        # Older mlx-lm: no make_sampler, or generate() rejects ``sampler=``.
        # Only these cases justify a retry; any other error is a real
        # generation failure and must not trigger a second full generation.
        pass
    except Exception as exc:
        return f"⚠️ Generation failed: {exc}"

    try:  # older mlx-lm signature
        return mlx_generate(STATE.model, tok, prompt=prompt,
                            max_tokens=budget, verbose=False)
    except Exception as exc:
        return f"⚠️ Generation failed: {exc}"
43
+
44
+
45
def build(app):
    """Create the Playground tab: two generation sliders above a chat interface."""
    sample_prompts = [
        ["Merhaba! Kendini tanıtır mısın?"],
        ["Explain LoRA fine-tuning in two sentences."],
    ]
    with gr.Tab("💬 Playground", id="playground"):
        gr.Markdown("### Test the loaded model — before and after fine-tuning")
        with gr.Row():
            token_slider = gr.Slider(16, 4096, value=512, step=16, label="Max new tokens")
            temp_slider = gr.Slider(0.0, 2.0, value=0.7, step=0.05, label="Temperature")
        # The sliders are passed to ``_generate`` after (message, history).
        gr.ChatInterface(
            fn=_generate,
            additional_inputs=[token_slider, temp_slider],
            examples=sample_prompts,
        )
finetuner/ui/tab_train.py ADDED
@@ -0,0 +1,140 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """Train tab: hyperparameters, recipes, the code generator, and the launch button."""
2
+
3
+ from __future__ import annotations
4
+
5
+ import gradio as gr
6
+
7
+ from ..core import recipes
8
+ from ..core.codegen import generate_script
9
+ from ..core.detector import normalize
10
+ from ..core.jobs import MANAGER
11
+ from ..core.registry import get_task, mlx_available
12
+ from ..core.state import STATE
13
+ from ..core.training import run_training
14
+
15
+
16
def _collect(cfg_fields: dict) -> None:
    """Copy the Train-tab hyperparameter values into ``STATE.config``.

    Args:
        cfg_fields: Mapping with the keys ``output_dir``, ``batch_size``,
            ``grad_accum``, ``lr``, ``max_steps``, ``epochs``, ``warmup``,
            ``grad_ckpt``, ``seed`` and ``beta``.

    ``cfg.extra`` is rebuilt on every call: it holds ``beta`` only when the
    current task is ``dpo`` or ``orpo`` and a non-zero beta was supplied.
    """
    cfg = STATE.config
    cfg.output_dir = cfg_fields["output_dir"]
    cfg.batch_size = int(cfg_fields["batch_size"])
    cfg.gradient_accumulation_steps = int(cfg_fields["grad_accum"])
    cfg.learning_rate = float(cfg_fields["lr"])
    cfg.max_steps = int(cfg_fields["max_steps"])
    # 0 means "use max steps" and is stored as None. A cleared number field
    # arrives as None and is treated the same way instead of crashing float().
    cfg.num_train_epochs = float(cfg_fields["epochs"] or 0) or None
    cfg.warmup_steps = int(cfg_fields["warmup"])
    cfg.gradient_checkpointing = bool(cfg_fields["grad_ckpt"])
    cfg.seed = int(cfg_fields["seed"])
    extra = {}
    if cfg.task in ("dpo", "orpo") and cfg_fields["beta"]:
        extra["beta"] = float(cfg_fields["beta"])
    cfg.extra = extra
31
+
32
+
33
def build(app):
    """Create the Train tab: hyperparameters, launch, script generator, recipes.

    All handlers read the ten hyperparameter widgets in ``hp_inputs`` order and
    push them into ``STATE.config`` through ``_collect``.
    """
    with gr.Tab("🚀 Train", id="train"):
        gr.Markdown("### Hyperparameters")
        with gr.Row():
            output_dir = gr.Textbox(value="outputs", label="Output directory")
            batch_size = gr.Slider(1, 32, value=2, step=1, label="Batch size")
            grad_accum = gr.Slider(1, 64, value=1, step=1, label="Gradient accumulation")
        with gr.Row():
            lr = gr.Number(value=2e-4, label="Learning rate")
            max_steps = gr.Slider(10, 10000, value=100, step=10, label="Max steps")
            epochs = gr.Number(value=0, label="Epochs (0 → use max steps)")
        with gr.Row():
            warmup = gr.Slider(0, 500, value=5, step=5, label="Warmup steps")
            seed = gr.Number(value=42, precision=0, label="Seed")
            grad_ckpt = gr.Checkbox(value=False, label="Gradient checkpointing (saves memory)")
            beta = gr.Number(value=0.1, label="β (DPO/ORPO only)")

        with gr.Row():
            start_btn = gr.Button("🏁 Start training", variant="primary", scale=2)
            gen_btn = gr.Button("🧾 Generate Python script", scale=1)
        status = gr.Markdown()

        with gr.Accordion("Generated script (standalone mlx-tune code)", open=False):
            script_out = gr.Code(language="python", label="train.py")
            gr.Markdown("Copy this script anywhere — it reproduces this run without the GUI.")

        with gr.Accordion("Recipes (save / load runs as YAML)", open=False):
            with gr.Row():
                recipe_name = gr.Textbox(label="Recipe name", placeholder="my-sft-run")
                save_recipe_btn = gr.Button("💾 Save recipe")
            with gr.Row():
                recipe_pick = gr.Dropdown(label="Saved recipes", choices=recipes.list_recipes(),
                                          allow_custom_value=True)
                load_recipe_btn = gr.Button("📂 Load recipe")
            recipe_status = gr.Markdown()

        hp_inputs = [output_dir, batch_size, grad_accum, lr, max_steps, epochs,
                     warmup, seed, grad_ckpt, beta]

        def _fields(*vals) -> dict:
            """Name the positional widget values (same order as ``hp_inputs``)."""
            keys = ["output_dir", "batch_size", "grad_accum", "lr", "max_steps",
                    "epochs", "warmup", "seed", "grad_ckpt", "beta"]
            return dict(zip(keys, vals))

        def _sync_task() -> None:
            """Point ``cfg.task`` at the task the loaded model was prepared for.

            Must run BEFORE ``_collect``: ``_collect`` decides whether to keep
            β from ``cfg.task``, so a stale task (e.g. after loading a recipe)
            would silently drop or keep β for the wrong paradigm.
            """
            if STATE.model_loaded_for_task:
                STATE.config.task = STATE.model_loaded_for_task

        # ----- start training -------------------------------------------------
        def on_start(*vals):
            """Validate prerequisites, normalize the dataset and submit a job."""
            _sync_task()
            _collect(_fields(*vals))
            cfg = STATE.config
            ok, reason = mlx_available()
            if not ok:
                return f"❌ {reason}"
            if STATE.model is None:
                return "❌ Load a model first (🧠 Model tab)."
            if not STATE.raw_rows:
                return "❌ Load a dataset first (📚 Dataset tab)."
            try:
                dataset = normalize(STATE.raw_rows, STATE.detection, cfg.task, STATE.tokenizer)
            except Exception as exc:
                return (f"❌ Dataset incompatible with **{get_task(cfg.task).label}**: {exc}\n\n"
                        f"Detected format: {STATE.detection.label if STATE.detection else '—'}")
            job = MANAGER.submit(f"{cfg.task} · {cfg.model_name}", run_training,
                                 cfg, STATE.model, STATE.tokenizer, dataset)
            return (f"🏃 **Job #{job.id}** started ({len(dataset)} samples). "
                    "Follow it in the **📈 Monitor** tab.")

        start_btn.click(on_start, hp_inputs, status)

        # ----- codegen --------------------------------------------------------
        def on_generate(*vals):
            """Render a standalone training script for the current settings."""
            _sync_task()
            _collect(_fields(*vals))
            cfg = STATE.config
            if not cfg.model_name:
                cfg.model_name = get_task(cfg.task).default_model
            return generate_script(cfg, STATE.dataset_source, STATE.dataset_is_local)

        gen_btn.click(on_generate, hp_inputs, script_out)

        # ----- recipes ---------------------------------------------------------
        def on_save_recipe(name, *vals):
            """Save the current config as a recipe and refresh the recipe list."""
            try:
                _collect(_fields(*vals))
                path = recipes.save_recipe(STATE.config, name, STATE.dataset_source,
                                           STATE.dataset_is_local)
            except Exception as exc:
                # Report the failure in the status area, like on_load_recipe does.
                return f"❌ {exc}", gr.update()
            return f"💾 Saved `{path}`", gr.update(choices=recipes.list_recipes())

        save_recipe_btn.click(on_save_recipe, [recipe_name, *hp_inputs],
                              [recipe_status, recipe_pick])

        def on_load_recipe(path):
            """Load a recipe into ``STATE`` and push its values into the widgets."""
            if not path:
                return ["❌ Pick a recipe."] + [gr.update()] * len(hp_inputs)
            try:
                cfg, src, is_local = recipes.load_recipe(path)
            except Exception as exc:
                return [f"❌ {exc}"] + [gr.update()] * len(hp_inputs)
            STATE.config = cfg
            STATE.dataset_source, STATE.dataset_is_local = src, is_local
            # Values after the status message follow the order of hp_inputs.
            return [f"📂 Loaded `{path}` — task **{cfg.task}**, model `{cfg.model_name}`. "
                    "Reload model/dataset to run it.",
                    cfg.output_dir, cfg.batch_size, cfg.gradient_accumulation_steps,
                    cfg.learning_rate, cfg.max_steps, cfg.num_train_epochs or 0,
                    cfg.warmup_steps, cfg.seed, cfg.gradient_checkpointing,
                    cfg.extra.get("beta", 0.1)]

        load_recipe_btn.click(on_load_recipe, recipe_pick, [recipe_status, *hp_inputs])
requirements.txt ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ datasets>=3.0
2
+ pandas>=2.0
3
+ pyyaml>=6.0