Quintus / src /sequence_packing.py
iamrahulreddy's picture
release: publish Quintus project files
4fc1bb9 verified
Raw
History Blame Contribute Delete
7.97 kB
from __future__ import annotations
from bisect import bisect_left, insort
import torch
import torch.nn.functional as F
from torch.utils.data import Dataset
from src.training_data import DistillationDataset
class SequencePackedDataset(Dataset):
    """Serve fixed-length rows built by packing several tokenized samples together.

    Source samples are grouped into bins with a best-fit-decreasing heuristic so
    that each bin holds at most ``pack_length`` tokens, counting one separator
    token between consecutive samples. Each dataset item is one bin: the samples
    are concatenated with ``eos_token_id`` separators (loss-masked) and the row
    is right-padded with ``pad_token_id`` (loss-masked) up to ``pack_length``.
    """

    def __init__(
        self,
        source: DistillationDataset,
        source_indices: list[int],
        pack_length: int,
        eos_token_id: int,
        pad_token_id: int,
        mask_first_after_separator: bool = True,
    ):
        """Validate the inputs, plan the bins and compute packing statistics.

        Args:
            source: Dataset exposing ``sample_lengths`` (indexable by source
                index) and returning dicts with ``"input_ids"`` and
                ``"loss_mask"`` tensors from ``__getitem__``.
            source_indices: Unique, non-negative indices of the rows to pack.
            pack_length: Length of every packed row; must be positive.
            eos_token_id: Token inserted between samples that share a bin.
            pad_token_id: Token used to right-pad a row up to ``pack_length``.
            mask_first_after_separator: When true, the first loss-mask entry of
                every sample that follows a separator is zeroed.

        Raises:
            ValueError: If ``pack_length`` is not positive, ``source`` has no
                ``sample_lengths``, ``source_indices`` is empty or contains
                duplicates, or a sample is longer than ``pack_length``.
            IndexError: If a source index is negative or out of range.
        """
        if pack_length <= 0:
            raise ValueError(f"pack_length must be positive, got {pack_length}.")
        if not hasattr(source, "sample_lengths"):
            raise ValueError("Packed training requires a source dataset with sample_lengths metadata.")
        if not source_indices:
            raise ValueError("Packed training requires at least one source row.")

        self.source = source
        self.source_indices = [int(index) for index in source_indices]
        self.source_index_set = set(self.source_indices)
        if len(self.source_index_set) != len(self.source_indices):
            raise ValueError("Packed training source indices contain duplicates.")

        self.pack_length = int(pack_length)
        self.eos_token_id = int(eos_token_id)
        self.pad_token_id = int(pad_token_id)
        self.mask_first_after_separator = bool(mask_first_after_separator)

        # Token length of every selected sample, keyed by source index.
        self._length_by_index: dict[int, int] = {}
        # plan[bin_id] lists the source indices packed into that row, in order.
        self.plan: list[list[int]] = []

        for source_index in self.source_indices:
            # A negative index would be resolved from the end of a sequence and
            # could alias a row that is also listed by its positive index, which
            # the duplicate check above cannot detect. Reject it explicitly.
            if source_index < 0:
                raise IndexError(f"Source index {source_index} is outside the tokenized dataset.")
            try:
                length = int(source.sample_lengths[source_index])
            except IndexError as exc:
                raise IndexError(f"Source index {source_index} is outside the tokenized dataset.") from exc
            if length > self.pack_length:
                raise ValueError(
                    f"Tokenized sample #{source_index} has length {length}, "
                    f"which exceeds pack_length={self.pack_length}."
                )
            self._length_by_index[source_index] = length

        self._build_plan()
        self._validate_plan()

        # Packing statistics derived from the plan and the recorded lengths.
        self.source_sample_count = len(self.source_indices)
        self.bin_count = len(self.plan)
        self.original_token_count = sum(self._length_by_index.values())
        self.separator_token_count = sum(max(0, len(bin_indices) - 1) for bin_indices in self.plan)
        self.packed_token_count = self.original_token_count + self.separator_token_count
        self.total_capacity = self.bin_count * self.pack_length
        self.pad_token_count = self.total_capacity - self.packed_token_count
        self.average_samples_per_bin = self.source_sample_count / max(self.bin_count, 1)
        self.utilization = self.packed_token_count / max(self.total_capacity, 1)

    def _build_plan(self) -> None:
        """Assign every source index to a bin using best-fit decreasing.

        Samples are visited from longest to shortest (ties broken by source
        index). Each one goes into the open bin with the least remaining
        capacity that still fits the sample plus one separator token; if no
        such bin exists a new bin is opened.

        Raises:
            ValueError: If a bin's remaining capacity would become negative.
        """
        # Start from scratch so that calling this method again does not append
        # a second copy of every bin to an existing plan.
        self.plan = []

        items = sorted(
            ((self._length_by_index[source_index], source_index) for source_index in self.source_indices),
            key=lambda item: (-item[0], item[1]),
        )
        # Sorted list of (remaining_capacity, bin_id) for every bin created so far.
        available: list[tuple[int, int]] = []
        for length, source_index in items:
            # Joining an existing bin costs the sample plus one separator token.
            required_existing = length + 1
            # Bin ids are >= 0, so (required, -1) sorts before every real entry
            # with that capacity: this finds the tightest bin that still fits.
            insert_at = bisect_left(available, (required_existing, -1))
            if insert_at == len(available):
                bin_id = len(self.plan)
                self.plan.append([source_index])
                remaining = self.pack_length - length
                insort(available, (remaining, bin_id))
                continue
            remaining, bin_id = available.pop(insert_at)
            next_remaining = remaining - required_existing
            if next_remaining < 0:
                raise ValueError("Internal packing error: bin capacity became negative.")
            self.plan[bin_id].append(source_index)
            insort(available, (next_remaining, bin_id))

    def _validate_plan(self) -> None:
        """Check that the plan is a valid partition of the source indices.

        Raises:
            ValueError: If a bin is empty, a bin (samples plus separators) is
                longer than ``pack_length``, a sample appears more than once,
                or a requested sample was not assigned to any bin.
        """
        seen: set[int] = set()
        for bin_id, bin_indices in enumerate(self.plan):
            if not bin_indices:
                raise ValueError(f"Packed bin #{bin_id} is empty.")
            real_length = sum(self._length_by_index[source_index] for source_index in bin_indices)
            # One separator token sits between each pair of neighbouring samples.
            real_length += max(0, len(bin_indices) - 1)
            if real_length > self.pack_length:
                raise ValueError(
                    f"Packed bin #{bin_id} has real_length={real_length}, "
                    f"which exceeds pack_length={self.pack_length}."
                )
            for source_index in bin_indices:
                if source_index in seen:
                    raise ValueError(f"Source sample #{source_index} appears in more than one packed bin.")
                seen.add(source_index)
        missing = self.source_index_set - seen
        if missing:
            first_missing = min(missing)
            raise ValueError(f"Source sample #{first_missing} was not assigned to a packed bin.")

    def __len__(self) -> int:
        """Return the number of packed rows (bins)."""
        return len(self.plan)

    def __getitem__(self, bin_idx: int) -> dict[str, torch.Tensor]:
        """Materialize one packed row.

        Args:
            bin_idx: Position of the bin in ``self.plan``.

        Returns:
            A dict with ``input_ids`` and ``loss_mask`` (both padded to
            ``pack_length``) and scalar tensors ``real_length`` (tokens before
            padding), ``source_samples``, ``original_tokens`` and
            ``separator_tokens``.

        Raises:
            ValueError: If a sample's ``loss_mask`` does not match its
                ``input_ids`` in shape, or the concatenated row is longer than
                ``pack_length``.
        """
        bin_indices = self.plan[bin_idx]
        input_parts: list[torch.Tensor] = []
        mask_parts: list[torch.Tensor] = []
        original_tokens = 0
        separator_tokens = 0
        for sample_offset, source_index in enumerate(bin_indices):
            item = self.source[source_index]
            input_ids = item["input_ids"].long()
            loss_mask = item["loss_mask"].long()
            # A mask of a different shape would shift every later mask entry in
            # this row relative to its token, so fail here with a clear message.
            if loss_mask.shape != input_ids.shape:
                raise ValueError(
                    f"Source sample #{source_index} has input_ids of shape {tuple(input_ids.shape)} "
                    f"but loss_mask of shape {tuple(loss_mask.shape)}."
                )
            original_tokens += int(input_ids.size(0))
            if sample_offset > 0:
                # Separator between samples; it never contributes to the loss.
                input_parts.append(torch.tensor([self.eos_token_id], dtype=torch.long))
                mask_parts.append(torch.zeros(1, dtype=torch.long))
                separator_tokens += 1
                if self.mask_first_after_separator and loss_mask.numel() > 0:
                    # Presumably avoids training on a token whose only context
                    # is the previous, unrelated sample. Clone first so the
                    # tensor returned by the source is not modified in place.
                    loss_mask = loss_mask.clone()
                    loss_mask[0] = 0
            input_parts.append(input_ids)
            mask_parts.append(loss_mask)

        input_ids = torch.cat(input_parts)
        loss_mask = torch.cat(mask_parts)
        real_length = int(input_ids.size(0))
        # The plan was built from sample_lengths; this guards against samples
        # whose actual tensors turn out longer than the recorded lengths.
        if real_length > self.pack_length:
            raise ValueError(
                f"Packed bin #{bin_idx} has real_length={real_length}, "
                f"which exceeds pack_length={self.pack_length}."
            )
        pad_len = self.pack_length - real_length
        if pad_len:
            input_ids = F.pad(input_ids, (0, pad_len), value=self.pad_token_id)
            loss_mask = F.pad(loss_mask, (0, pad_len), value=0)
        return {
            "input_ids": input_ids,
            "loss_mask": loss_mask,
            "real_length": torch.tensor(real_length, dtype=torch.long),
            "source_samples": torch.tensor(len(bin_indices), dtype=torch.long),
            "original_tokens": torch.tensor(original_tokens, dtype=torch.long),
            "separator_tokens": torch.tensor(separator_tokens, dtype=torch.long),
        }
def collate_packed_fn(batch: list[dict], pad_token_id: int) -> dict:
    """Stack packed rows into a batch and derive attention mask and labels.

    Rows are stacked as-is, so every item in ``batch`` must already have the
    same length. ``pad_token_id`` is accepted but not used.

    Args:
        batch: Items carrying ``input_ids``, ``loss_mask``, ``real_length``,
            ``source_samples``, ``original_tokens`` and ``separator_tokens``.
        pad_token_id: Unused; discarded immediately.

    Returns:
        A dict with the stacked fields plus ``attention_mask`` (1 for positions
        below each row's ``real_length``) and ``labels`` (``input_ids`` with
        -100 wherever ``loss_mask`` is 0).
    """
    del pad_token_id

    def _stacked(field: str) -> torch.Tensor:
        # Gather one field from every row into a single batch tensor.
        return torch.stack([row[field] for row in batch])

    token_rows = _stacked("input_ids")
    mask_rows = _stacked("loss_mask").long()
    lengths = _stacked("real_length").long()

    # Compare each column index with the per-row token count via broadcasting.
    column_ids = torch.arange(token_rows.size(1), dtype=torch.long)
    attention = (column_ids[None, :] < lengths[:, None]).long()

    # masked_fill is out-of-place, so the stacked ids are left untouched.
    targets = token_rows.masked_fill(mask_rows == 0, -100)

    collated = {
        "input_ids": token_rows,
        "attention_mask": attention,
        "loss_mask": mask_rows,
        "labels": targets,
        "real_length": lengths,
    }
    for counter in ("source_samples", "original_tokens", "separator_tokens"):
        collated[counter] = _stacked(counter).long()
    return collated