ropedia-xperience-10m-task-baselines / scripts /omni /train_qwen3_omni_lora.py
cy0307's picture
Update final Qwen model scripts
627e5d7 verified
Raw
History Blame Contribute Delete
32 kB
#!/usr/bin/env python3
"""Conservative LoRA SFT for Qwen3-Omni action/subtask label generation."""
from __future__ import annotations
import argparse
import json
import math
import random
import time
from pathlib import Path
from types import MethodType
import torch
import torch.nn.functional as F
from qwen3_omni_dataset_utils import (
build_messages,
DEFAULT_MODEL_ID,
has_empty_audio_items,
is_empty_audio_exception,
load_jsonl,
sample_has_audio,
sample_without_audio,
)
def parse_args() -> argparse.Namespace:
    """Define the trainer's command-line options and parse them from ``sys.argv``.

    Only ``--dataset-jsonl`` is required. ``--output-dir`` and ``--results-dir``
    have no parser-level default, so they parse to ``None`` when omitted. The
    default ``--backbone-config`` path is built relative to ``parents[2]`` of
    this file's resolved location.

    Returns:
        The populated ``argparse.Namespace``.
    """
    workspace_root = Path(__file__).resolve().parents[2]
    backbone_config_default = workspace_root / "configs" / "omni_backbones" / "qwen3_omni_lora.json"

    cli = argparse.ArgumentParser(description="Train Qwen3-Omni LoRA on exported Ropedia windows.")
    option = cli.add_argument

    # Run identity, data location and base model.
    option("--dataset-jsonl", type=Path, required=True)
    option("--run-id", default="qwen_lora_text_video_audio")
    option("--output-dir", type=Path)
    option("--results-dir", type=Path)
    option("--model-id", default=DEFAULT_MODEL_ID)
    option(
        "--backbone-config",
        type=Path,
        default=backbone_config_default,
        help="Backbone contract JSON recorded with the run for model-extension tracking.",
    )

    # Split selection and sample caps.
    option("--train-split", default="train")
    option("--val-split", default="val")
    option("--include-unspecified-in-train", action="store_true")
    option("--max-train-samples", type=int, default=0)
    option("--max-val-samples", type=int, default=64)

    # Optimisation hyper-parameters.
    option("--epochs", type=int, default=1)
    option("--batch-size", type=int, default=1)
    option("--gradient-accumulation-steps", type=int, default=8)
    option("--learning-rate", type=float, default=1e-4)
    option("--weight-decay", type=float, default=0.0)
    option("--max-grad-norm", type=float, default=1.0)
    option("--seed", type=int, default=7)

    # Model loading and runtime behaviour.
    option("--device-map", default="auto")
    option("--dtype", default="bfloat16", choices=["auto", "bfloat16", "float16", "float32"])
    option("--local-files-only", action="store_true")
    option("--trust-remote-code", action="store_true")
    option("--use-audio-in-video", action=argparse.BooleanOptionalAction, default=False)
    option("--gradient-checkpointing", action="store_true")
    option("--progress-every", type=int, default=1)
    option(
        "--loss-logit-tail-only",
        action=argparse.BooleanOptionalAction,
        default=True,
        help="For SFT, project only the supervised assistant-answer tail through lm_head before CE loss.",
    )

    # LoRA adapter shape.
    option("--lora-r", type=int, default=16)
    option("--lora-alpha", type=int, default=32)
    option("--lora-dropout", type=float, default=0.05)
    option(
        "--lora-target-modules",
        default="q_proj,k_proj,v_proj,o_proj,gate_proj,up_proj,down_proj",
        help="Comma-separated module names passed to PEFT LoRAConfig.",
    )
    return cli.parse_args()
def dtype_arg(value: str):
    """Map a ``--dtype`` choice to the value handed to the model loader.

    Args:
        value: One of ``"auto"``, ``"bfloat16"``, ``"float16"``, ``"float32"``.

    Returns:
        The string ``"auto"`` unchanged, otherwise the matching ``torch.dtype``.

    Raises:
        KeyError: If ``value`` is not one of the names above.
    """
    if value != "auto":
        named_dtypes = {
            "bfloat16": torch.bfloat16,
            "float16": torch.float16,
            "float32": torch.float32,
        }
        return named_dtypes[value]
    return "auto"
def select_samples(samples: list[dict], split: str, include_unspecified: bool) -> list[dict]:
    """Pick the samples whose ``"split"`` field equals ``split``.

    Args:
        samples: Sample dicts; a missing ``"split"`` key never matches.
        split: Split name to keep.
        include_unspecified: When true, samples tagged ``"unspecified"`` are
            appended after the matching ones.

    Returns:
        A new list: matches in input order, then any appended
        ``"unspecified"`` samples in input order.
    """
    chosen = []
    for row in samples:
        if row.get("split") == split:
            chosen.append(row)
    if not include_unspecified:
        return chosen
    for row in samples:
        if row.get("split") == "unspecified":
            chosen.append(row)
    return chosen
def patch_rotary_position_device(model) -> bool:
    """Wrap ``model.model.rotary_emb.forward`` so its inputs follow ``x``'s device.

    The wrapper moves the ``inv_freq`` buffer and ``position_ids`` to
    ``x.device`` when they differ, then delegates to the original forward.

    Args:
        model: Object expected to expose ``model.rotary_emb``.

    Returns:
        ``True`` if the wrapper was installed by this call; ``False`` when no
        rotary module was found or it had already been patched.
    """
    backbone = getattr(model, "model", None)
    rotary_emb = getattr(backbone, "rotary_emb", None)
    if rotary_emb is None:
        return False
    if getattr(rotary_emb, "_ropedia_position_device_patch", False):
        return False

    wrapped_forward = rotary_emb.forward

    def forward_with_aligned_position_ids(self, x, position_ids, *args, **kwargs):
        x_has_device = hasattr(x, "device")
        if hasattr(self, "inv_freq") and x_has_device and self.inv_freq.device != x.device:
            # Replace the registered buffer in place so later calls reuse it.
            self._buffers["inv_freq"] = self.inv_freq.to(x.device)
        if hasattr(position_ids, "to") and x_has_device and position_ids.device != x.device:
            position_ids = position_ids.to(x.device)
        return wrapped_forward(x, position_ids, *args, **kwargs)

    rotary_emb.forward = MethodType(forward_with_aligned_position_ids, rotary_emb)
    # Marker that makes a second call a no-op.
    rotary_emb._ropedia_position_device_patch = True
    return True
def patch_qwen3_omni_rotary_classes() -> None:
    """Replace ``forward`` on the Qwen3-Omni MRoPE rotary-embedding classes.

    The replacement moves ``inv_freq`` and ``position_ids`` onto ``x.device``
    before computing the cos/sin tables. Classes that the installed
    ``modeling_qwen3_omni_moe`` module does not define, or that already carry
    the ``_ropedia_class_device_patch`` marker, are left untouched.
    """
    from transformers.models.qwen3_omni_moe import modeling_qwen3_omni_moe as qwen3_omni_moe

    def patch_mrope_class(class_name: str) -> None:
        target_cls = getattr(qwen3_omni_moe, class_name, None)
        if target_cls is None:
            return
        if getattr(target_cls, "_ropedia_class_device_patch", False):
            return

        @torch.no_grad()
        def forward(self, x, position_ids):
            # 2-D position ids are expanded to a leading dimension of size 3.
            if position_ids.ndim == 2:
                position_ids = position_ids[None, ...].expand(3, position_ids.shape[0], -1)
            device = x.device
            inv_freq = self.inv_freq.to(device)
            position_ids = position_ids.to(device)
            inv_freq_expanded = inv_freq[None, None, :, None].float().expand(3, position_ids.shape[1], -1, 1)
            position_ids_expanded = position_ids[:, :, None, :].float()
            if isinstance(device.type, str) and device.type != "mps":
                autocast_device = device.type
            else:
                autocast_device = "cpu"
            # Autocast is disabled so the frequency product runs in float32.
            with qwen3_omni_moe.maybe_autocast(device_type=autocast_device, enabled=False):
                freqs = (inv_freq_expanded.float() @ position_ids_expanded.float()).transpose(2, 3)
                freqs = self.apply_interleaved_mrope(freqs, self.mrope_section)
                emb = torch.cat((freqs, freqs), dim=-1)
                cos = emb.cos() * self.attention_scaling
                sin = emb.sin() * self.attention_scaling
            return cos.to(dtype=x.dtype), sin.to(dtype=x.dtype)

        target_cls.forward = forward
        target_cls._ropedia_class_device_patch = True

    for rotary_class_name in (
        "Qwen3OmniMoeThinkerTextRotaryEmbedding",
        "Qwen3OmniMoeTalkerRotaryEmbedding",
    ):
        patch_mrope_class(rotary_class_name)
def patch_qwen3_omni_norm_classes() -> None:
    """Replace ``forward`` on the Qwen3-Omni RMSNorm classes.

    The replacement normalises in float32, casts back to the input dtype, and
    moves the norm weight to the device of ``hidden_states`` before scaling.
    Classes missing from the installed module, or already marked with
    ``_ropedia_class_device_patch``, are skipped.
    """
    from transformers.models.qwen3_omni_moe import modeling_qwen3_omni_moe as qwen3_omni_moe

    def patch_norm_class(class_name: str) -> None:
        target_cls = getattr(qwen3_omni_moe, class_name, None)
        if target_cls is None:
            return
        if getattr(target_cls, "_ropedia_class_device_patch", False):
            return

        def forward(self, hidden_states):
            original_dtype = hidden_states.dtype
            upcast = hidden_states.to(torch.float32)
            mean_square = upcast.pow(2).mean(-1, keepdim=True)
            upcast = upcast * torch.rsqrt(mean_square + self.variance_epsilon)
            # The weight may sit on another device under a model-parallel map.
            scale = self.weight.to(hidden_states.device)
            return scale * upcast.to(original_dtype)

        target_cls.forward = forward
        target_cls._ropedia_class_device_patch = True

    for norm_class_name in (
        "Qwen3OmniMoeRMSNorm",
        "Qwen3OmniMoeTextRMSNorm",
        "Qwen3OmniMoeThinkerTextRMSNorm",
        "Qwen3OmniMoeCode2WavRMSNorm",
    ):
        patch_norm_class(norm_class_name)
def cast_floating_parameters(model, target_dtype) -> None:
    """Cast every floating-point parameter of ``model`` to ``target_dtype`` in place.

    Args:
        model: Module whose ``parameters()`` are visited.
        target_dtype: A ``torch.dtype``; a string (such as ``"auto"``) makes
            the call a no-op.
    """
    if isinstance(target_dtype, str):
        return
    needs_cast = (
        param
        for param in model.parameters()
        if param.is_floating_point() and param.dtype != target_dtype
    )
    for param in needs_cast:
        # Rebind .data so the Parameter object itself stays the same.
        param.data = param.data.to(target_dtype)
def build_trainable_cpu_state_dict(model) -> dict[str, torch.Tensor]:
    """Copy every trainable parameter of ``model`` to CPU, keyed by name.

    A single leading ``"module."`` is removed from each name. Parameters with
    ``requires_grad`` false are omitted.

    Returns:
        Mapping from cleaned parameter name to a detached CPU copy.
    """
    return {
        full_name.removeprefix("module."): weight.detach().to("cpu", copy=True)
        for full_name, weight in model.named_parameters()
        if weight.requires_grad
    }
def strip_module_prefix(name: str) -> str:
    """Remove every leading ``"module."`` segment from ``name``."""
    wrapper_prefix = "module."
    width = len(wrapper_prefix)
    while name[:width] == wrapper_prefix:
        name = name[width:]
    return name
def thinker_adapter_state_from_full_state(state_dict: dict[str, torch.Tensor]) -> dict[str, torch.Tensor]:
    """Extract the LoRA tensors from a state dict and normalise their names.

    Each key has its ``"module."`` wrappers removed, then a leading
    ``"thinker."`` or ``"base_model.model.thinker."`` prefix. Only keys that
    still contain ``"lora_"`` are kept.

    Returns:
        Mapping from cleaned name to a detached CPU copy of the tensor.

    Raises:
        ValueError: If a LoRA tensor has zero elements, or if no LoRA tensor is
            found at all.
    """
    thinker_prefixes = ("thinker.", "base_model.model.thinker.")
    lora_tensors: dict[str, torch.Tensor] = {}
    for name, tensor in state_dict.items():
        key = strip_module_prefix(name)
        # At most one of the two prefixes is stripped, the first that matches.
        for prefix in thinker_prefixes:
            if key.startswith(prefix):
                key = key[len(prefix) :]
                break
        if "lora_" not in key:
            continue
        if tensor.numel() == 0:
            raise ValueError(f"Gathered LoRA state still has an empty tensor: {name}")
        lora_tensors[key] = tensor.detach().to("cpu", copy=True)
    if lora_tensors:
        return lora_tensors
    raise ValueError("Gathered state dict did not contain any LoRA tensors.")
def adapter_shape_summary(state_dict: dict[str, torch.Tensor]) -> dict[str, object]:
    """Summarise tensor counts and sizes of an adapter state dict.

    Tensors are grouped by the third dotted name component when the name starts
    with ``"base_model.model."``, otherwise by the first component.

    Returns:
        A dict with ``adapter_tensors`` (tensor count), ``adapter_bytes``
        (sum of ``numel * element_size``) and ``prefixes`` (per-group
        ``tensors`` / ``numel`` counters).
    """
    groups: dict[str, dict[str, int]] = {}
    byte_total = 0
    for key, tensor in state_dict.items():
        parts = key.split(".")
        if key.startswith("base_model.model.") and len(parts) > 2:
            group_name = parts[2]
        else:
            group_name = parts[0]
        counters = groups.setdefault(group_name, {"tensors": 0, "numel": 0})
        counters["tensors"] += 1
        counters["numel"] += int(tensor.numel())
        byte_total += tensor.numel() * tensor.element_size()
    return {
        "adapter_tensors": len(state_dict),
        "adapter_bytes": byte_total,
        "prefixes": groups,
    }
class TailSlicingLmHead(torch.nn.Module):
    """Wrapper around an ``lm_head`` that can project only a sequence tail.

    While ``tail_start`` is ``None`` the wrapper is transparent. When it is an
    integer, ``forward`` drops positions before ``tail_start`` along dim 1
    before calling the wrapped head.
    """

    def __init__(self, base: torch.nn.Module) -> None:
        super().__init__()
        # Marker used to locate this wrapper among a model's modules.
        self._ropedia_tail_slicing_lm_head = True
        self.tail_start: int | None = None
        self.base = base

    @property
    def weight(self):
        """Weight of the wrapped head."""
        return self.base.weight

    @property
    def bias(self):
        """Bias of the wrapped head, or ``None`` if it has none."""
        return getattr(self.base, "bias", None)

    def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
        start = self.tail_start
        if start is None:
            return self.base(hidden_states)
        return self.base(hidden_states[:, start:, :])
def install_tail_slicing_lm_head(model) -> bool:
    """Wrap the model's ``lm_head`` in a ``TailSlicingLmHead`` if not already wrapped.

    The head is looked up on ``model.get_base_model()`` when that method
    exists, otherwise on ``model`` itself.

    Returns:
        ``False`` when no ``lm_head`` attribute is found; ``True`` when the
        wrapper is present after the call (newly installed or pre-existing).
    """
    host = model.get_base_model() if hasattr(model, "get_base_model") else model
    current_head = getattr(host, "lm_head", None)
    if current_head is None:
        return False
    if not getattr(current_head, "_ropedia_tail_slicing_lm_head", False):
        host.lm_head = TailSlicingLmHead(current_head)
    return True
def set_tail_slicing_lm_head(model, tail_start: int | None) -> bool:
    """Set ``tail_start`` on every tail-slicing head found in ``model``.

    Args:
        model: Module whose ``modules()`` are scanned for the
            ``_ropedia_tail_slicing_lm_head`` marker.
        tail_start: New start index, or ``None`` to disable slicing.

    Returns:
        ``True`` if at least one marked module was updated.
    """
    marked_heads = [
        module
        for module in model.modules()
        if getattr(module, "_ropedia_tail_slicing_lm_head", False)
    ]
    for head in marked_heads:
        head.tail_start = tail_start
    return bool(marked_heads)
def first_supervised_label(labels: torch.Tensor) -> int | None:
    """Return the index of the first label in row 0 that is not ``-100``.

    Only ``labels[0]`` is inspected. Returns ``None`` when that row has no
    supervised position.
    """
    supervised_mask = labels[0] != -100
    positions = torch.nonzero(supervised_mask, as_tuple=False)
    if positions.numel() > 0:
        return int(positions[0].item())
    return None
def load_backbone_profile(path: Path | None) -> dict:
    """Load the backbone contract fields that are recorded with a run.

    Args:
        path: JSON file to read, or ``None`` for the built-in default profile.
            A relative path is resolved against ``parents[2]`` of this file.

    Returns:
        For ``None``, a fixed five-key profile. Otherwise the selected fields
        of the JSON payload plus ``config_path`` (the resolved path as a
        string); absent fields are ``None``, ``{}`` or ``[]`` depending on the
        field.

    Raises:
        FileNotFoundError: If the resolved path does not exist.
    """
    if path is None:
        return {
            "id": "qwen3_omni_lora",
            "display_name": "Qwen3-Omni LoRA",
            "dataset_contract": "xperience10m_episode_json_qa_v1",
            "training_objective": "structured_episode_understanding_json_qa",
            "primary_metrics": [],
        }

    config_file = path.expanduser()
    if not config_file.is_absolute():
        config_file = Path(__file__).resolve().parents[2] / config_file
    if not config_file.exists():
        raise FileNotFoundError(f"Backbone config not found: {config_file}")

    payload = json.loads(config_file.read_text(encoding="utf-8"))
    scalar_fields = (
        "id",
        "display_name",
        "status",
        "model_family",
        "dataset_contract",
        "training_objective",
    )
    profile = {field: payload.get(field) for field in scalar_fields}
    profile["split_policy"] = payload.get("split_policy", {})
    profile["modalities"] = payload.get("modalities", {})
    profile["primary_metrics"] = payload.get("primary_metrics", [])
    profile["config_path"] = str(config_file)
    return profile
def load_model_processor(args: argparse.Namespace):
    """Load the Qwen3-Omni thinker, attach LoRA adapters and load the processor.

    Compatibility patches are applied before the checkpoint is loaded. The
    thinker sub-model is then taken from the loaded model, optionally patched
    for device-mapped rotary embeddings and gradient checkpointing, wrapped
    with PEFT LoRA, optionally given the tail-slicing ``lm_head``, and cast to
    the requested dtype.

    Args:
        args: Parsed CLI namespace (model id, dtype, device map, LoRA settings).

    Returns:
        ``(model, processor)`` where ``model`` is the PEFT-wrapped thinker.
    """
    from qwen3_omni_compat import patch_qwen3_omni_config

    # The config patch runs before the transformers/peft imports below.
    patch_qwen3_omni_config()
    from peft import LoraConfig, get_peft_model
    from transformers import Qwen3OmniMoeForConditionalGeneration, Qwen3OmniMoeProcessor

    patch_qwen3_omni_rotary_classes()
    patch_qwen3_omni_norm_classes()

    requested_dtype = dtype_arg(args.dtype)
    uses_device_map = bool(args.device_map and args.device_map.lower() != "none")

    load_kwargs = {
        "dtype": requested_dtype,
        "local_files_only": args.local_files_only,
    }
    if uses_device_map:
        load_kwargs["device_map"] = args.device_map
    if args.trust_remote_code:
        load_kwargs["trust_remote_code"] = True
    full_model = Qwen3OmniMoeForConditionalGeneration.from_pretrained(args.model_id, **load_kwargs)
    if hasattr(full_model, "disable_talker"):
        full_model.disable_talker()

    # Only the thinker sub-model is trained.
    thinker = full_model.thinker
    if uses_device_map:
        patch_rotary_position_device(thinker)
    if args.gradient_checkpointing and hasattr(thinker, "gradient_checkpointing_enable"):
        thinker.gradient_checkpointing_enable()

    tokenizer_kwargs = {"local_files_only": args.local_files_only}
    if args.trust_remote_code:
        tokenizer_kwargs["trust_remote_code"] = True
    processor = Qwen3OmniMoeProcessor.from_pretrained(args.model_id, **tokenizer_kwargs)

    target_modules = [name.strip() for name in args.lora_target_modules.split(",") if name.strip()]
    lora_config = LoraConfig(
        r=args.lora_r,
        lora_alpha=args.lora_alpha,
        lora_dropout=args.lora_dropout,
        bias="none",
        target_modules=target_modules,
    )
    peft_thinker = get_peft_model(thinker, lora_config)
    if args.loss_logit_tail_only:
        install_tail_slicing_lm_head(peft_thinker)
    cast_floating_parameters(peft_thinker, requested_dtype)
    peft_thinker.print_trainable_parameters()
    return peft_thinker, processor
def move_inputs(inputs, device, dtype=None):
    """Move every value of ``inputs`` that has a ``.to`` method onto ``device``.

    Args:
        inputs: Mutable mapping; updated in place.
        device: Target device passed to ``.to``.
        dtype: When given, floating-point values are also cast to this dtype;
            other values only change device.

    Returns:
        The same ``inputs`` object.
    """
    for name, item in list(inputs.items()):
        if not hasattr(item, "to"):
            continue
        is_float = getattr(item, "is_floating_point", lambda: False)
        if dtype is not None and is_float():
            moved = item.to(device=device, dtype=dtype)
        else:
            moved = item.to(device)
        inputs[name] = moved
    return inputs
def compute_answer_token_loss(model, inputs: dict, tail_only: bool = True) -> torch.Tensor:
    """Cross-entropy over the supervised (non ``-100``) next-token targets only.

    ``"labels"`` is popped from ``inputs`` before the forward pass, so the
    model never computes its own loss. With ``tail_only`` the tail-slicing
    ``lm_head`` (when one is installed) is armed for the forward pass and
    always disarmed afterwards.

    Args:
        model: Callable returning an object with a ``logits`` attribute.
        inputs: Model keyword inputs including ``"labels"``; mutated.
        tail_only: Request projection of only the answer tail through lm_head.

    Returns:
        Scalar mean cross-entropy, or a zero that still depends on the logits
        when no position is supervised.
    """
    labels = inputs.pop("labels")
    offset = 0
    slicing_active = False
    if tail_only:
        first = first_supervised_label(labels)
        if first is None:
            return model(**inputs).logits.sum() * 0.0
        # Position first-1 is the one whose logits predict the first answer token.
        offset = max(first - 1, 0)
        # False when no tail-slicing head exists: fall back to the full path.
        slicing_active = set_tail_slicing_lm_head(model, offset)

    try:
        output = model(**inputs)
    finally:
        if slicing_active:
            set_tail_slicing_lm_head(model, None)

    logits = output.logits
    labels = labels.to(logits.device)
    if slicing_active:
        # Full-length logits mean the head did not slice; slice here instead.
        if logits.shape[1] == labels.shape[1]:
            logits = logits[:, offset:, :]
        pred_logits = logits[..., :-1, :]
        first_target = offset + 1
        targets = labels[..., first_target : first_target + pred_logits.shape[1]]
    else:
        pred_logits = logits[..., :-1, :]
        targets = labels[..., 1:]

    supervised = targets != -100
    if not supervised.any().item():
        return logits.sum() * 0.0
    return F.cross_entropy(pred_logits[supervised].float(), targets[supervised], reduction="mean")
def prepare_sample(processor, sample: dict, use_audio_in_video: bool, device, dtype=None) -> dict:
    """Turn one dataset sample into model inputs with answer-only labels.

    The sample is rendered twice through the chat template (with and without
    the assistant answer) and the full conversation is run through the
    processor. If the sample carries audio that turns out to be empty, it is
    retried once with the audio removed.

    Labels are a copy of ``input_ids`` with the prompt and padding positions
    set to ``-100``. The prompt length is derived from the number of answer
    tokens at the end of the sequence rather than from the tokenized prompt
    text: the processor's ``input_ids`` can be longer than the tokenized text
    when multimodal placeholders are expanded, and masking only the unexpanded
    prompt length would leave prompt and media tokens supervised. When nothing
    is expanded both computations give the same length.

    Args:
        processor: Qwen3-Omni processor exposing ``apply_chat_template``,
            ``__call__`` and ``tokenizer``.
        sample: Dataset row; must contain ``"label_options"``.
        use_audio_in_video: Forwarded to ``process_mm_info`` and the processor.
        device: Device the returned tensors are moved to.
        dtype: Optional dtype for floating-point tensors.

    Returns:
        The processor output including ``"labels"``, moved to ``device``.

    Raises:
        RuntimeError: Re-raised from the processor when it is not the
            recoverable empty-audio case, or if no attempt produced inputs.
    """
    from qwen_omni_utils import process_mm_info

    active_sample = sample
    for attempt in range(2):
        full_messages = build_messages(active_sample, active_sample["label_options"], include_answer=True)
        prompt_messages = build_messages(active_sample, active_sample["label_options"], include_answer=False)
        full_text = processor.apply_chat_template(full_messages, tokenize=False)
        prompt_text = processor.apply_chat_template(prompt_messages, add_generation_prompt=True, tokenize=False)
        audios, images, videos = process_mm_info(full_messages, use_audio_in_video=use_audio_in_video)
        # First attempt only: drop audio that decoded to nothing and rebuild.
        if attempt == 0 and sample_has_audio(active_sample) and has_empty_audio_items(audios):
            active_sample = sample_without_audio(active_sample)
            continue
        try:
            inputs = processor(
                text=full_text,
                audio=audios,
                images=images,
                videos=videos,
                return_tensors="pt",
                padding=True,
                use_audio_in_video=use_audio_in_video,
            )
            break
        except RuntimeError as exc:
            if attempt == 0 and sample_has_audio(active_sample) and is_empty_audio_exception(exc):
                active_sample = sample_without_audio(active_sample)
                continue
            raise
    else:
        raise RuntimeError("Unable to prepare multimodal sample after dropping empty audio.")

    tokenizer = processor.tokenizer
    labels = inputs["input_ids"].clone()
    total_len = labels.shape[1]

    # Token counts of the raw (unexpanded) template text, with and without
    # the assistant answer; their difference is the answer-suffix length.
    raw_prompt_len = tokenizer(prompt_text, add_special_tokens=False, return_tensors="pt")["input_ids"].shape[1]
    raw_full_len = tokenizer(full_text, add_special_tokens=False, return_tensors="pt")["input_ids"].shape[1]
    answer_len = max(raw_full_len - raw_prompt_len, 0)

    # Mask everything before the answer suffix. Taking the max with the raw
    # prompt length guarantees at least the previously masked prefix.
    prompt_len = min(max(raw_prompt_len, total_len - answer_len), total_len)
    labels[:, :prompt_len] = -100

    pad_id = tokenizer.pad_token_id
    if pad_id is not None:
        labels[inputs["input_ids"] == pad_id] = -100
    inputs["labels"] = labels
    return move_inputs(inputs, device, dtype=dtype)
def write_progress(path: Path, row: dict) -> None:
    """Append ``row`` as one JSON line to ``path``, creating parent directories.

    Non-ASCII characters are written as-is (``ensure_ascii=False``) in UTF-8.
    """
    path.parent.mkdir(parents=True, exist_ok=True)
    line = json.dumps(row, ensure_ascii=False)
    with path.open("a", encoding="utf-8") as handle:
        handle.write(f"{line}\n")
def distributed_slice(samples: list[dict], process_index: int, num_processes: int) -> list[dict]:
    """Return this process's strided shard of ``samples``, padded to equal length.

    With a single process the whole list is returned as a copy. Otherwise the
    shard is ``samples[process_index::num_processes]``; an empty shard is
    seeded with one sample, and shards shorter than
    ``ceil(len(samples) / num_processes)`` are topped up with random picks
    from the shard itself so that every process gets the same number of items.
    """
    if num_processes <= 1:
        return list(samples)

    mine = list(samples[process_index::num_processes])
    target_len = math.ceil(len(samples) / num_processes)
    if not samples:
        return []
    if not mine:
        mine = [samples[process_index % len(samples)]]
    # Duplicate already-assigned items until every rank has target_len entries.
    while len(mine) < target_len:
        mine.append(random.choice(mine))
    return mine
def evaluate_loss(model, processor, samples: list[dict], args: argparse.Namespace, device, dtype=None, accelerator=None) -> float | None:
    """Average the answer-token loss over ``samples`` without gradients.

    The model is switched to eval mode for the pass and back to train mode
    afterwards. With an ``accelerator``, per-process loss sums and counts are
    gathered and the average over all processes is returned.

    Returns:
        The mean loss, or ``None`` when ``samples`` is empty (or the gathered
        count is zero).
    """
    if not samples:
        return None

    per_sample: list[float] = []
    model.eval()
    with torch.no_grad():
        for sample in samples:
            batch = prepare_sample(processor, sample, args.use_audio_in_video, device, dtype=dtype)
            value = compute_answer_token_loss(model, batch, tail_only=args.loss_logit_tail_only)
            per_sample.append(float(value.detach().cpu()))
    model.train()

    # [sum, count] pair so ranks with different shard sizes combine correctly.
    local = torch.tensor([sum(per_sample), len(per_sample)], dtype=torch.float32, device=device)
    if accelerator is None:
        return sum(per_sample) / len(per_sample) if per_sample else None

    gathered = accelerator.gather(local)
    # The gathered tensor interleaves sums (even slots) and counts (odd slots).
    total_loss = float(gathered[0::2].sum().detach().cpu())
    total_count = float(gathered[1::2].sum().detach().cpu())
    if not total_count:
        return None
    return total_loss / total_count
def main() -> int:
    """Run LoRA fine-tuning end to end and write the adapter plus run artefacts.

    Steps: parse arguments, create output/results directories, shard the
    selected train/val samples across processes, load the model and processor,
    train for ``--epochs`` epochs with per-sample gradient accumulation,
    evaluate after every epoch, then (on the main process) save the LoRA
    adapter, processor, ``training_metadata.json``, ``config.yaml`` and
    ``RUN_REPORT.md``. Progress events are appended to ``progress.jsonl`` in
    the results directory by the main process.

    Non-positive ``--batch-size`` / ``--progress-every`` values are treated as
    1, and the report no longer fails when no epoch was run.

    Returns:
        ``0`` on success.

    Raises:
        ValueError: If no training samples are selected.
    """
    args = parse_args()
    from accelerate import Accelerator

    accelerator = Accelerator(gradient_accumulation_steps=args.gradient_accumulation_steps)
    workspace_default = Path(__file__).resolve().parents[2]
    if args.output_dir is None:
        args.output_dir = workspace_default / "checkpoints" / args.run_id / "adapter_lora"
    if args.results_dir is None:
        args.results_dir = workspace_default / "results" / "omni_finetune" / args.run_id
    args.output_dir.mkdir(parents=True, exist_ok=True)
    args.results_dir.mkdir(parents=True, exist_ok=True)

    # Clamp once: 0 would raise in range() (batch size) or in the modulo
    # (progress interval) inside the training loop.
    batch_size = max(args.batch_size, 1)
    progress_every = max(args.progress_every, 1)

    progress_path = args.results_dir / "progress.jsonl"
    # Start each run with a fresh progress log.
    if accelerator.is_main_process and progress_path.exists():
        progress_path.unlink()
    backbone_profile = load_backbone_profile(args.backbone_config)

    # Per-process seeds so ranks shuffle and pad differently.
    torch.manual_seed(args.seed + accelerator.process_index)
    random.seed(args.seed + accelerator.process_index)

    samples = load_jsonl(args.dataset_jsonl)
    train_samples = select_samples(samples, args.train_split, args.include_unspecified_in_train)
    val_samples = [sample for sample in samples if sample.get("split") == args.val_split]
    if args.max_train_samples > 0:
        train_samples = train_samples[: args.max_train_samples]
    if args.max_val_samples > 0:
        val_samples = val_samples[: args.max_val_samples]
    if not train_samples:
        raise ValueError("No training samples selected. Check --train-split or use --include-unspecified-in-train.")
    rank_train_samples = distributed_slice(train_samples, accelerator.process_index, accelerator.num_processes)
    rank_val_samples = distributed_slice(val_samples, accelerator.process_index, accelerator.num_processes) if val_samples else []

    if accelerator.is_main_process:
        write_progress(progress_path, {
            "event": "setup_done",
            "run_id": args.run_id,
            "dataset_jsonl": str(args.dataset_jsonl),
            "num_processes": accelerator.num_processes,
            "num_train_samples": len(train_samples),
            "num_val_samples": len(val_samples),
            "rank0_samples_per_epoch": len(rank_train_samples),
            "backbone_id": backbone_profile.get("id"),
            "dataset_contract": backbone_profile.get("dataset_contract"),
            "training_objective": backbone_profile.get("training_objective"),
            "loss_mode": "answer_token_ce",
            "loss_logit_tail_only": args.loss_logit_tail_only,
            "timestamp": time.time(),
        })

    # With several processes the automatic device map is switched off.
    if accelerator.num_processes > 1 and args.device_map == "auto":
        args.device_map = "none"
    if accelerator.is_main_process:
        write_progress(progress_path, {
            "event": "model_load_start",
            "run_id": args.run_id,
            "model_id": args.model_id,
            "backbone_id": backbone_profile.get("id"),
            "device_map": args.device_map,
            "dtype": args.dtype,
            "timestamp": time.time(),
        })
    model, processor = load_model_processor(args)
    if accelerator.is_main_process:
        write_progress(progress_path, {
            "event": "model_load_done",
            "run_id": args.run_id,
            "timestamp": time.time(),
        })

    optimizer = torch.optim.AdamW(
        (p for p in model.parameters() if p.requires_grad),
        lr=args.learning_rate,
        weight_decay=args.weight_decay,
    )
    if accelerator.is_main_process:
        write_progress(progress_path, {
            "event": "accelerator_prepare_start",
            "run_id": args.run_id,
            "timestamp": time.time(),
        })
    model, optimizer = accelerator.prepare(model, optimizer)
    if accelerator.is_main_process:
        write_progress(progress_path, {
            "event": "accelerator_prepare_done",
            "run_id": args.run_id,
            "timestamp": time.time(),
        })

    device = accelerator.device
    model_dtype = next(model.parameters()).dtype
    history = []
    global_step = 0
    optimizer.zero_grad(set_to_none=True)
    model.train()
    if accelerator.is_main_process:
        write_progress(progress_path, {
            "event": "train_loop_start",
            "run_id": args.run_id,
            "model_id": args.model_id,
            "dataset_jsonl": str(args.dataset_jsonl),
            "num_processes": accelerator.num_processes,
            "num_train_samples": len(train_samples),
            "num_val_samples": len(val_samples),
            "rank_samples_per_epoch": len(rank_train_samples),
            "epochs": args.epochs,
            "timestamp": time.time(),
        })

    # NOTE(review): no dataloader is prepared, so the accumulation window looks
    # like it only closes every `gradient_accumulation_steps` samples; a trailing
    # partial window at the end of training appears never to be applied — confirm.
    for epoch in range(1, args.epochs + 1):
        random.shuffle(rank_train_samples)
        epoch_loss = 0.0
        seen = 0
        steps_in_epoch = math.ceil(len(rank_train_samples) / batch_size)
        for batch_start in range(0, len(rank_train_samples), batch_size):
            batch = rank_train_samples[batch_start : batch_start + batch_size]
            batch_loss = 0.0
            # Samples are processed one at a time; each is one accumulation step.
            for sample in batch:
                with accelerator.accumulate(model):
                    inputs = prepare_sample(processor, sample, args.use_audio_in_video, device, dtype=model_dtype)
                    loss = compute_answer_token_loss(model, inputs, tail_only=args.loss_logit_tail_only)
                    accelerator.backward(loss)
                    batch_loss += float(loss.detach().cpu())
                    if accelerator.sync_gradients:
                        accelerator.clip_grad_norm_(model.parameters(), args.max_grad_norm)
                    optimizer.step()
                    optimizer.zero_grad(set_to_none=True)
            seen += len(batch)
            epoch_loss += batch_loss
            global_step += 1
            is_last_batch = batch_start // batch_size == steps_in_epoch - 1
            if accelerator.is_main_process and (global_step % progress_every == 0 or is_last_batch):
                write_progress(progress_path, {
                    "event": "train_step",
                    "epoch": epoch,
                    "global_step": global_step,
                    "rank0_seen": seen,
                    "rank0_samples_per_epoch": len(rank_train_samples),
                    "rank0_batch_loss": batch_loss / max(len(batch), 1),
                    "timestamp": time.time(),
                })
        val_loss = evaluate_loss(model, processor, rank_val_samples, args, device, dtype=model_dtype, accelerator=accelerator)
        epoch_row = {
            "epoch": epoch,
            "train_loss": epoch_loss / max(len(rank_train_samples), 1),
            "val_loss": val_loss,
            "global_step": global_step,
        }
        history.append(epoch_row)
        if accelerator.is_main_process:
            print(json.dumps(epoch_row, indent=2))
            write_progress(progress_path, {"event": "epoch_end", **epoch_row, "timestamp": time.time()})

    accelerator.wait_for_everyone()
    if accelerator.is_main_process:
        write_progress(progress_path, {
            "event": "save_start",
            "checkpoint_dir": str(args.output_dir),
            "save_mode": "trainable_lora_state_dict",
            "timestamp": time.time(),
        })
    accelerator.wait_for_everyone()
    unwrapped = accelerator.unwrap_model(model)
    # get_state_dict is called on every process in multi-process runs.
    gathered_state = accelerator.get_state_dict(model) if accelerator.num_processes > 1 else None
    if accelerator.is_main_process:
        if gathered_state is not None:
            adapter_state = thinker_adapter_state_from_full_state(gathered_state)
            state_source = "accelerator_full_state_dict"
        else:
            adapter_state = thinker_adapter_state_from_full_state(build_trainable_cpu_state_dict(unwrapped))
            state_source = "local_trainable_state_dict"
        shape_summary = adapter_shape_summary(adapter_state)
        write_progress(progress_path, {
            "event": "save_state_dict_built",
            "checkpoint_dir": str(args.output_dir),
            "state_source": state_source,
            "trainable_tensors": shape_summary["adapter_tensors"],
            "trainable_bytes": shape_summary["adapter_bytes"],
            "shape_summary": shape_summary,
            "timestamp": time.time(),
        })
        peft_model = getattr(unwrapped, "thinker", unwrapped)
        peft_model.save_pretrained(args.output_dir, state_dict=adapter_state, is_main_process=True)
        processor.save_pretrained(args.output_dir)
        write_progress(progress_path, {
            "event": "save_done",
            "checkpoint_dir": str(args.output_dir),
            "timestamp": time.time(),
        })

    metadata = {
        "run_id": args.run_id,
        "model_id": args.model_id,
        "backbone": backbone_profile,
        "dataset_jsonl": str(args.dataset_jsonl),
        "checkpoint_dir": str(args.output_dir),
        "num_processes": accelerator.num_processes,
        "num_train_samples": len(train_samples),
        "num_val_samples": len(val_samples),
        "history": history,
        "lora": {
            "r": args.lora_r,
            "alpha": args.lora_alpha,
            "dropout": args.lora_dropout,
            "target_modules": [item.strip() for item in args.lora_target_modules.split(",") if item.strip()],
        },
        "use_audio_in_video": args.use_audio_in_video,
        "loss_mode": "answer_token_ce",
        "loss_logit_tail_only": args.loss_logit_tail_only,
    }
    if accelerator.is_main_process:
        (args.output_dir / "training_metadata.json").write_text(json.dumps(metadata, indent=2), encoding="utf-8")
        (args.results_dir / "config.yaml").write_text(
            "\n".join([
                f"run_id: {args.run_id}",
                "stage: qwen_lora_text_video_audio",
                f"backbone_id: {backbone_profile.get('id')}",
                f"dataset_contract: {backbone_profile.get('dataset_contract')}",
                f"model_id: {args.model_id}",
                f"dataset_jsonl: {args.dataset_jsonl}",
                f"checkpoint_dir: {args.output_dir}",
                f"num_processes: {accelerator.num_processes}",
                f"epochs: {args.epochs}",
                f"learning_rate: {args.learning_rate}",
                f"lora_r: {args.lora_r}",
                f"lora_alpha: {args.lora_alpha}",
                "loss_mode: answer_token_ce",
                f"loss_logit_tail_only: {args.loss_logit_tail_only}",
            ]) + "\n",
            encoding="utf-8",
        )
        (args.results_dir / "training_metadata.json").write_text(json.dumps(metadata, indent=2), encoding="utf-8")

        # history is empty when --epochs is 0; report that instead of crashing
        # after the adapter has already been written.
        final_epoch = history[-1] if history else None
        final_train_loss = f"{final_epoch['train_loss']:.6f}" if final_epoch is not None else "n/a"
        report = [
            "# Qwen3-Omni LoRA Training",
            "",
            f"- Backbone profile: `{backbone_profile.get('display_name')}`",
            f"- Dataset contract: `{backbone_profile.get('dataset_contract')}`",
            f"- Training objective: `{backbone_profile.get('training_objective')}`",
            f"- Base model: `{args.model_id}`",
            f"- Dataset: `{args.dataset_jsonl}`",
            f"- Train samples: `{len(train_samples)}`",
            f"- Validation samples: `{len(val_samples)}`",
            f"- Processes: `{accelerator.num_processes}`",
            f"- Epochs: `{args.epochs}`",
            "- Loss: answer-token cross entropy over supervised JSON tokens",
            f"- Logit projection: `{'assistant-answer tail only' if args.loss_logit_tail_only else 'full sequence'}`",
            f"- Final train loss: `{final_train_loss}`",
        ]
        # Keep the validation bullet inside the bullet list, before the closing note.
        if final_epoch is not None and final_epoch["val_loss"] is not None:
            report.append(f"- Final val loss: `{final_epoch['val_loss']:.6f}`")
        report.extend([
            "",
            "Only LoRA parameters are trained; the base Qwen3-Omni weights remain frozen.",
        ])
        (args.results_dir / "RUN_REPORT.md").write_text("\n".join(report) + "\n", encoding="utf-8")
        write_progress(progress_path, {"event": "complete", "checkpoint_dir": str(args.output_dir), "timestamp": time.time()})
        print(f"Wrote LoRA adapter to {args.output_dir}")
    return 0
if __name__ == "__main__":
    # Use main()'s integer return value as the process exit status.
    exit_status = main()
    raise SystemExit(exit_status)