from __future__ import annotations from bisect import bisect_left, insort import torch import torch.nn.functional as F from torch.utils.data import Dataset from src.training_data import DistillationDataset class SequencePackedDataset(Dataset): def __init__( self, source: DistillationDataset, source_indices: list[int], pack_length: int, eos_token_id: int, pad_token_id: int, mask_first_after_separator: bool = True, ): if pack_length <= 0: raise ValueError(f"pack_length must be positive, got {pack_length}.") if not hasattr(source, "sample_lengths"): raise ValueError("Packed training requires a source dataset with sample_lengths metadata.") if not source_indices: raise ValueError("Packed training requires at least one source row.") self.source = source self.source_indices = [int(index) for index in source_indices] self.source_index_set = set(self.source_indices) if len(self.source_index_set) != len(self.source_indices): raise ValueError("Packed training source indices contain duplicates.") self.pack_length = int(pack_length) self.eos_token_id = int(eos_token_id) self.pad_token_id = int(pad_token_id) self.mask_first_after_separator = bool(mask_first_after_separator) self._length_by_index: dict[int, int] = {} self.plan: list[list[int]] = [] for source_index in self.source_indices: try: length = int(source.sample_lengths[source_index]) except IndexError as exc: raise IndexError(f"Source index {source_index} is outside the tokenized dataset.") from exc if length > self.pack_length: raise ValueError( f"Tokenized sample #{source_index} has length {length}, " f"which exceeds pack_length={self.pack_length}." ) self._length_by_index[source_index] = length self._build_plan() self._validate_plan() self.source_sample_count = len(self.source_indices) self.bin_count = len(self.plan) self.original_token_count = sum(self._length_by_index.values()) self.separator_token_count = sum(max(0, len(bin_indices) - 1) for bin_indices in self.plan) self.packed_token_count = self.original_token_count + self.separator_token_count self.total_capacity = self.bin_count * self.pack_length self.pad_token_count = self.total_capacity - self.packed_token_count self.average_samples_per_bin = self.source_sample_count / max(self.bin_count, 1) self.utilization = self.packed_token_count / max(self.total_capacity, 1) def _build_plan(self) -> None: items = sorted( ((self._length_by_index[source_index], source_index) for source_index in self.source_indices), key=lambda item: (-item[0], item[1]), ) available: list[tuple[int, int]] = [] for length, source_index in items: required_existing = length + 1 insert_at = bisect_left(available, (required_existing, -1)) if insert_at == len(available): bin_id = len(self.plan) self.plan.append([source_index]) remaining = self.pack_length - length insort(available, (remaining, bin_id)) continue remaining, bin_id = available.pop(insert_at) next_remaining = remaining - required_existing if next_remaining < 0: raise ValueError("Internal packing error: bin capacity became negative.") self.plan[bin_id].append(source_index) insort(available, (next_remaining, bin_id)) def _validate_plan(self) -> None: seen: set[int] = set() for bin_id, bin_indices in enumerate(self.plan): if not bin_indices: raise ValueError(f"Packed bin #{bin_id} is empty.") real_length = sum(self._length_by_index[source_index] for source_index in bin_indices) real_length += max(0, len(bin_indices) - 1) if real_length > self.pack_length: raise ValueError( f"Packed bin #{bin_id} has real_length={real_length}, " f"which exceeds pack_length={self.pack_length}." ) for source_index in bin_indices: if source_index in seen: raise ValueError(f"Source sample #{source_index} appears in more than one packed bin.") seen.add(source_index) missing = self.source_index_set - seen if missing: first_missing = min(missing) raise ValueError(f"Source sample #{first_missing} was not assigned to a packed bin.") def __len__(self) -> int: return len(self.plan) def __getitem__(self, bin_idx: int) -> dict[str, torch.Tensor]: bin_indices = self.plan[bin_idx] input_parts: list[torch.Tensor] = [] mask_parts: list[torch.Tensor] = [] original_tokens = 0 separator_tokens = 0 for sample_offset, source_index in enumerate(bin_indices): item = self.source[source_index] input_ids = item["input_ids"].long() loss_mask = item["loss_mask"].long() original_tokens += int(input_ids.size(0)) if sample_offset > 0: input_parts.append(torch.tensor([self.eos_token_id], dtype=torch.long)) mask_parts.append(torch.zeros(1, dtype=torch.long)) separator_tokens += 1 if self.mask_first_after_separator and loss_mask.numel() > 0: loss_mask = loss_mask.clone() loss_mask[0] = 0 input_parts.append(input_ids) mask_parts.append(loss_mask) input_ids = torch.cat(input_parts) loss_mask = torch.cat(mask_parts) real_length = int(input_ids.size(0)) if real_length > self.pack_length: raise ValueError( f"Packed bin #{bin_idx} has real_length={real_length}, " f"which exceeds pack_length={self.pack_length}." ) pad_len = self.pack_length - real_length if pad_len: input_ids = F.pad(input_ids, (0, pad_len), value=self.pad_token_id) loss_mask = F.pad(loss_mask, (0, pad_len), value=0) return { "input_ids": input_ids, "loss_mask": loss_mask, "real_length": torch.tensor(real_length, dtype=torch.long), "source_samples": torch.tensor(len(bin_indices), dtype=torch.long), "original_tokens": torch.tensor(original_tokens, dtype=torch.long), "separator_tokens": torch.tensor(separator_tokens, dtype=torch.long), } def collate_packed_fn(batch: list[dict], pad_token_id: int) -> dict: del pad_token_id input_ids = torch.stack([item["input_ids"] for item in batch]) loss_mask = torch.stack([item["loss_mask"] for item in batch]).long() real_lengths = torch.stack([item["real_length"] for item in batch]).long() seq_len = input_ids.size(1) positions = torch.arange(seq_len, dtype=torch.long).unsqueeze(0) attention_mask = (positions < real_lengths.unsqueeze(1)).long() labels = input_ids.clone() labels = labels.masked_fill(loss_mask == 0, -100) return { "input_ids": input_ids, "attention_mask": attention_mask, "loss_mask": loss_mask, "labels": labels, "real_length": real_lengths, "source_samples": torch.stack([item["source_samples"] for item in batch]).long(), "original_tokens": torch.stack([item["original_tokens"] for item in batch]).long(), "separator_tokens": torch.stack([item["separator_tokens"] for item in batch]).long(), }