| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
|
|
"""Dataset and data-loading utilities for training and evaluation.

Provides WebDataset-based iterable datasets, manifest parsing, and audio/token
loading. Used by ``omnivoice.training.builder.build_dataloaders()`` to construct
train and eval data loaders.

Key functions:
- ``prepare_data_manifests_from_json()``: Parses a data config JSON into train/dev
  manifests.
- ``webdataset_manifest_reader()``: Parses a ``.lst`` manifest listing tar shards,
  their label JSONL files, item counts and durations.
- ``load_audio_webdataset()``: Decodes in-memory audio bytes, down-mixes to mono
  and resamples to a target rate.

Key classes:
- ``WebDatasetReader``: Reads audio/text pairs from WebDataset tar shards as an
  iterable dataset.
- ``MuxWebDatasetReader``: Multiplexes multiple WebDataset readers for
  multilingual data.
- ``JsonlDatasetReader``: Reads audio/text pairs from a JSONL manifest file.
  Used by data processing scripts (e.g. ``omnivoice/scripts/``).
- ``SampleDecoder``: Decodes individual samples (audio or tokens + labels).
- ``LabelDataset``: In-memory lookup of label records loaded from a JSONL file,
  keyed by their ``"id"`` field.
- ``LazyIteratorMultiplexer``: Weighted, seeded random interleaving of several
  iterables.
"""
|
|
| import io |
| import json |
| import logging |
| import os |
| import random |
| from typing import Any, Dict, Iterator, List, Optional, Tuple |
|
|
| import torch |
| import torch.distributed as dist |
| import torchaudio |
| import webdataset as wds |
| from torch.utils.data import IterableDataset |
|
|
|
|
def load_audio_webdataset(data, sample_rate: int = 24000, device="cpu"):
    """
    Decode in-memory audio bytes and resample to ``sample_rate`` if needed.

    Args:
        data: Raw bytes of an encoded audio file, as stored in a WebDataset
            sample; decoded with ``torchaudio.load`` via a file-like wrapper.
        sample_rate: Target sample rate in Hz.
        device: Device the decoded waveform is moved to before down-mixing
            and resampling.

    Returns:
        A tensor of shape ``(1, num_samples)``. Multi-channel audio is
        down-mixed to mono by averaging over the channel axis.
    """
    audio, sr = torchaudio.load(io.BytesIO(data))
    audio = audio.to(device)
    if audio.size(0) > 1:
        # keepdim=True keeps the channel axis, so down-mixed multi-channel
        # audio has the same (1, num_samples) shape as mono input.
        audio = audio.mean(dim=0, keepdim=True)
    if sr != sample_rate:
        audio = torchaudio.functional.resample(audio, sr, sample_rate)
    return audio
|
|
|
|
def prepare_data_manifests_from_json(
    data_config: str,
) -> Tuple[List[Tuple[str, str, int, float]], List[Tuple[str, str, int, float]]]:
    """
    Prepare train/dev data manifests from a JSON data config.

    A typical multilingual config looks like::

        {
            "train": [
                {
                    "language_id": "en",
                    "manifest_path": ["/Emilia/EN/data.lst"],
                    "repeat": 1
                },
                {
                    "language_id": "zh",
                    "manifest_path": ["/Emilia/ZH/data.lst"],
                    "repeat": 1
                }
            ],
            "dev": [
                {
                    "language_id": "en",
                    "manifest_path": ["/Emilia/EN-dev/data.lst"],
                    "repeat": 1
                },
                {
                    "language_id": "zh",
                    "manifest_path": ["/Emilia/ZH-dev/data.lst"],
                    "repeat": 1
                }
            ]
        }

    ``"language_id"`` is not read by this function; it only helps organize
    multilingual data. ``"repeat"`` is optional (default 1) and sets how many
    times each manifest's entries are repeated. The ``"dev"`` section is
    optional.

    The simplest config is::

        {
            "train": [
                {"manifest_path": ["/Emilia/EN/data.lst", "/Emilia/ZH/data.lst"]}
            ],
            "dev": [
                {"manifest_path": ["/Emilia/EN-dev/data.lst", "/Emilia/ZH-dev/data.lst"]}
            ]
        }

    Each ``data.lst`` line holds four space-separated fields::

        /path/to/data.tar /path/to/label.jsonl num_items num_seconds

    Args:
        data_config: Path to the JSON config file.

    Returns:
        ``(train_manifests, dev_manifests)``, each a list of
        ``(tar_path, label_jsonl_path, num_items, num_seconds)`` tuples.
        ``dev_manifests`` is empty when the config has no (or a null) ``"dev"``.

    Raises:
        FileNotFoundError: If a listed manifest path is not a regular file.
    """
    with open(data_config, "r", encoding="utf-8") as f:
        data = json.load(f)
    train_manifests = _collect_split_manifests(data["train"])
    dev_manifests = _collect_split_manifests(data.get("dev") or [])
    return train_manifests, dev_manifests


def _collect_split_manifests(items) -> List[Tuple[str, str, int, float]]:
    """Expand one split's config entries into a flat, repeated manifest list."""
    manifests = []
    for item in items:
        repeat = item.get("repeat", 1)
        for manifest_path in item["manifest_path"]:
            # Validate explicitly (``assert`` is stripped under ``python -O``),
            # and do it for every split, not just train.
            if not os.path.isfile(manifest_path):
                raise FileNotFoundError(f"{manifest_path} is not a file.")
            manifests.extend(webdataset_manifest_reader(manifest_path) * repeat)
    return manifests
|
|
|
|
def webdataset_manifest_reader(
    manifest_path: str,
) -> List[Tuple[str, str, int, float]]:
    """
    Read a manifest listing WebDataset tar shards and their label files.

    Each non-blank line has four whitespace-separated fields::

        /path/to/data.tar /path/to/label.jsonl num_items num_seconds

    Args:
        manifest_path: Path to the manifest (``.lst``) file.

    Returns:
        A list of ``(tar_path, label_jsonl_path, num_items, num_seconds)``
        tuples in file order, with ``num_items`` parsed as ``int`` and
        ``num_seconds`` as ``float``.

    Raises:
        ValueError: If a line does not have exactly four fields, or its
            numeric fields cannot be parsed.
    """
    manifests = []
    with open(manifest_path, "r", encoding="utf-8") as f:
        for line_no, line in enumerate(f, start=1):
            line = line.strip()
            if not line:
                continue
            parts = line.split()
            if len(parts) != 4:
                raise ValueError(
                    f"Invalid manifest line: {line}. "
                    f"Each line must contain "
                    "tar_path, label_jsonl_path, num_items, num_seconds. "
                    f"({manifest_path}:{line_no})"
                )
            tar_path, label_jsonl_path, num_items, num_seconds = parts
            try:
                entry = (tar_path, label_jsonl_path, int(num_items), float(num_seconds))
            except ValueError as e:
                # Re-raise with location info; the bare int()/float() error
                # does not say which file or line was bad.
                raise ValueError(
                    f"Invalid num_items/num_seconds in {manifest_path}:{line_no}: {line}"
                ) from e
            manifests.append(entry)
    return manifests
|
|
|
|
class SampleDecoder:
    """
    Decode a WebDataset sample: load its audio (or pre-computed audio tokens)
    and attach the matching label from the shard's label JSONL file.
    """

    def __init__(
        self,
        tar_to_label: Dict,
        sample_rate: int = 24000,
        audio_format: Optional[Tuple[str, ...]] = None,
        normalize_audio: bool = True,
    ):
        """
        Args:
            tar_to_label:
                Maps each tar shard URL (the sample's ``__url__``) to the path
                of its label JSONL file.
            sample_rate:
                Target sample rate for decoded audio.
            audio_format:
                Sample keys (file extensions) tried in order when looking for
                the audio payload. Defaults to ``("flac", "wav", "mp3")``.
            normalize_audio:
                If True, peak-normalize decoded audio so its maximum absolute
                amplitude is about 0.9.
        """
        self.tar_to_label = tar_to_label
        self.sample_rate = sample_rate
        # Labels of the most recently seen shard; reloaded when the shard changes.
        self.label_dataset = None
        if audio_format is None:
            self.audio_format = ("flac", "wav", "mp3")
        else:
            self.audio_format = audio_format
        self.normalize_audio = normalize_audio

    def __call__(self, sample):
        """
        Args:
            sample: A WebDataset sample dict with ``__url__`` and ``__key__``
                plus either an ``npy`` entry or an encoded audio entry.

        Returns:
            A dict with ``label`` and either ``audio_tokens`` (for ``npy``
            samples) or ``audio`` plus ``audio_duration`` in seconds.
            ``audio`` is an empty tensor when no ``audio_format`` key matches.
        """
        return_dict = {}
        key = sample["__key__"]
        label_path = self.tar_to_label[sample["__url__"]]
        if self.label_dataset is None or self.label_dataset.path != label_path:
            self.label_dataset = LabelDataset(label_path)

        if "npy" in sample:
            return_dict["audio_tokens"] = torch.from_numpy(sample["npy"])
        else:
            audio = torch.empty(0)
            for ext in self.audio_format:
                if ext in sample:
                    audio = load_audio_webdataset(
                        sample[ext], sample_rate=self.sample_rate
                    )
                    # Skip normalization for empty clips: max() of an empty
                    # tensor raises and would crash the data pipeline.
                    if self.normalize_audio and audio.numel() > 0:
                        audio = (audio / (audio.abs().max() + 1e-7)) * 0.9
                    break
            return_dict["audio"] = audio
            return_dict["audio_duration"] = audio.size(-1) / self.sample_rate

        return_dict["label"] = self.label_dataset[key]
        return return_dict
|
|
|
|
class LabelDataset:
    """In-memory lookup of label records from a JSONL file, keyed by ``"id"``."""

    def __init__(self, jsonl_path: str):
        """
        Load labels from a JSONL file.

        Args:
            jsonl_path:
                Path to the JSONL label file. Each non-blank line is a JSON
                object such as ``{"id": "utt_0001", "text": "transcription"}``.
                Records without an ``"id"`` field are skipped; a later record
                with a duplicate ``"id"`` replaces the earlier one.

        Raises:
            FileNotFoundError: If ``jsonl_path`` does not exist.
        """
        self._labels = {}
        self.path = jsonl_path
        if not os.path.exists(jsonl_path):
            raise FileNotFoundError(f"Label jsonl file {jsonl_path} does not exist.")
        with open(jsonl_path, "r", encoding="utf-8") as f:
            for line in f:
                line = line.strip()
                if not line:
                    continue
                item = json.loads(line)
                if "id" in item:
                    self._labels[item["id"]] = item

    def __getitem__(self, key):
        """
        Return the full label record (the parsed JSON object) for ``key``.

        Raises:
            KeyError: If no record with that ``"id"`` was loaded; the message
                names the label file to ease debugging.
        """
        try:
            return self._labels[key]
        except KeyError:
            raise KeyError(f"Label for {key!r} not found in {self.path}") from None
|
|
|
|
class IterableDataReader:
    """
    Interface shared by the data readers in this module.

    Every method here is a stub that raises ``NotImplementedError``;
    concrete readers override the ones they support.
    """

    # Target sample rate (Hz) of the audio the reader produces.
    sample_rate: int

    def set_epoch(self, epoch: int):
        """Prepare the reader for ``epoch`` (e.g. reshuffle); must be overridden."""
        raise NotImplementedError()

    def __iter__(self) -> Iterator[Dict[str, Any]]:
        """Yield sample dicts one at a time; must be overridden."""
        raise NotImplementedError()

    def __len__(self) -> int:
        """Return the number of samples; must be overridden."""
        raise NotImplementedError()
|
|
|
|
class WrappedIterableDataset(IterableDataset):
    """
    Base ``torch`` iterable dataset used by this project.

    Iteration yields lists of sample dicts (see the ``__iter__`` annotation).
    Both methods are stubs that subclasses must override.
    """

    def set_epoch(self, epoch: int):
        """Prepare the dataset for ``epoch``; must be overridden."""
        raise NotImplementedError()

    def __iter__(self) -> Iterator[List[Dict[str, Any]]]:
        """Yield lists of sample dicts; must be overridden."""
        raise NotImplementedError()
|
|
|
|
class WebDatasetReader(IterableDataReader):
    """
    Iterate over audio/label samples stored in WebDataset tar shards.

    Shards are split across distributed nodes and dataloader workers via
    ``webdataset``'s node/worker splitters, and every sample goes through a
    :class:`SampleDecoder`. In training mode, :meth:`set_epoch` reshuffles the
    shard order and samples pass through a shuffle buffer. In evaluation
    mode, the original order is kept.
    """

    def __init__(
        self,
        manifests: List[Tuple[str, str, int, float]],
        evaluation: bool = False,
        shuffle_buffer_size: int = 20000,
        sample_rate: int = 24000,
    ):
        """
        Args:
            manifests: ``(tar_path, label_jsonl_path, num_items, num_seconds)``
                tuples, one per shard.
            evaluation: If True, disable shard and sample shuffling.
            shuffle_buffer_size: Sample shuffle buffer size (training only).
            sample_rate: Target sample rate for decoded audio.
        """
        self.evaluation = evaluation
        self.shuffle_buffer_size = shuffle_buffer_size
        self.epoch = 0

        # Accumulate shard list, label mapping and totals in a single pass.
        shard_urls = []
        labels_by_tar = {}
        total_items = 0
        total_seconds = 0.0
        for shard_url, label_path, shard_items, shard_seconds in manifests:
            shard_urls.append(shard_url)
            labels_by_tar[shard_url] = label_path
            total_items += shard_items
            total_seconds += shard_seconds
        self.orig_urls = shard_urls
        self.tar_to_label = labels_by_tar
        self.num_items = total_items
        self.num_seconds = total_seconds

        # Working shard order; replaced by set_epoch().
        self.urls = list(self.orig_urls)
        self.sample_decoder = SampleDecoder(
            tar_to_label=self.tar_to_label,
            sample_rate=sample_rate,
        )
        self.sample_rate = sample_rate

    def set_epoch(self, epoch: int):
        """
        Record ``epoch`` and, when training, reshuffle shards seeded by it.
        """
        self.epoch = epoch
        shard_order = list(self.orig_urls)
        if not self.evaluation:
            random.Random(epoch).shuffle(shard_order)
        self.urls = shard_order

    def __iter__(self) -> Iterator[Dict[str, Any]]:
        """Build a fresh decoding pipeline over the current shard order."""
        stream = (
            wds.WebDataset(
                self.urls,
                shardshuffle=False,  # shard order is handled by set_epoch()
                nodesplitter=wds.split_by_node,
                workersplitter=wds.split_by_worker,
            )
            .decode()
            .map(self.sample_decoder)
        )
        if self.evaluation:
            return iter(stream)
        return iter(stream.shuffle(self.shuffle_buffer_size, seed=self.epoch))

    def __len__(self) -> int:
        """Total number of items summed over all manifests."""
        return self.num_items
|
|
|
|
class JsonlDatasetReader(IterableDataReader):
    """Read raw JSONL and load audio files, close to WebDatasetReader's output.

    Each JSONL line should be a JSON object with at least:
        {"id": "...", "audio_path": "/path/to/audio.wav", ...}

    Yields dicts of the form ``{"audio": Tensor(1, T), "label": dict}``, where
    ``label`` is the JSONL record with an added ``"audio_duration"`` (seconds).
    Entries whose audio file is missing or fails to load are skipped with a
    warning.
    """

    def __init__(
        self,
        jsonl_path: str,
        sample_rate: int = 24_000,
        shuffle: bool = True,
        shuffle_seed: int = 42,
        normalize_audio: bool = True,
    ):
        """
        Args:
            jsonl_path: Path to the JSONL manifest.
            sample_rate: Target sample rate; audio is resampled to it.
            shuffle: If True, load all entries and shuffle them each iteration.
                Otherwise stream the file in order.
            shuffle_seed: Seed for the shuffle (replaced by ``set_epoch``).
            normalize_audio: If True, peak-normalize audio to about 0.9.
        """
        self.jsonl_path = jsonl_path
        self.sample_rate = sample_rate
        self.shuffle = shuffle
        self.shuffle_seed = shuffle_seed
        self.normalize_audio = normalize_audio

    def set_epoch(self, epoch: int):
        """Use ``epoch`` as the shuffle seed for subsequent iterations.

        Note: this replaces (does not offset) the constructor's ``shuffle_seed``.
        """
        self.shuffle_seed = epoch

    def _read_lines(self) -> list[dict]:
        """Load all entries into memory and shuffle them if enabled."""
        entries = list(self._stream_lines())
        if self.shuffle:
            # A local RNG yields the same permutation as seeding the global
            # ``random`` module, without clobbering process-wide random state.
            random.Random(self.shuffle_seed).shuffle(entries)
            logging.info(
                f"Shuffled {len(entries)} JSONL entries (seed={self.shuffle_seed})"
            )
        return entries

    def _stream_lines(self):
        """Yield parsed JSON objects for each non-blank line, in file order."""
        with open(self.jsonl_path, "r", encoding="utf-8") as f:
            for line in f:
                line = line.strip()
                if line:
                    yield json.loads(line)

    def _load_audio(self, audio_path: str):
        """Load ``audio_path`` as mono ``(1, T)`` audio at ``self.sample_rate``."""
        waveform, sr = torchaudio.load(audio_path)
        if waveform.shape[0] > 1:
            waveform = waveform.mean(dim=0, keepdim=True)
        if sr != self.sample_rate:
            waveform = torchaudio.functional.resample(waveform, sr, self.sample_rate)
        if self.normalize_audio:
            waveform = (waveform / (waveform.abs().max() + 1e-7)) * 0.9
        return waveform

    def __iter__(self):
        source = self._read_lines() if self.shuffle else self._stream_lines()

        # Shard across distributed ranks, then across dataloader workers.
        # Generators keep the non-shuffled path streaming instead of
        # materializing the whole file. is_available() guards torch builds
        # without distributed support, where is_initialized may be missing.
        if dist.is_available() and dist.is_initialized():
            rank = dist.get_rank()
            world_size = dist.get_world_size()
            source = (
                item for i, item in enumerate(source) if i % world_size == rank
            )

        worker_info = torch.utils.data.get_worker_info()
        if worker_info is not None:
            source = (
                item
                for i, item in enumerate(source)
                if i % worker_info.num_workers == worker_info.id
            )

        for meta in source:
            audio_path = meta.get("audio_path")
            if not audio_path or not os.path.exists(audio_path):
                logging.warning(
                    f"Skipping {meta.get('id', '?')}: audio_path missing or not found"
                )
                continue
            # Best-effort: a corrupt or unreadable file skips the entry
            # rather than aborting the whole iteration. The try covers only
            # loading, so the yield is never inside it.
            try:
                waveform = self._load_audio(audio_path)
            except Exception as e:
                logging.warning(f"Skipping {meta.get('id', '?')}: {e}")
                continue
            meta["audio_duration"] = waveform.shape[1] / self.sample_rate
            yield {"audio": waveform, "label": meta}
|
|
|
|
class MuxWebDatasetReader(IterableDataReader):
    """
    Randomly interleave samples from several readers (e.g. one per language).

    Reader selection is delegated to :class:`LazyIteratorMultiplexer`.
    """

    def __init__(
        self,
        readers: List[WebDatasetReader],
        weights: Optional[List[float]] = None,
        stop_early: bool = False,
        seed: int = 0,
    ):
        """
        Args:
            readers: The readers to multiplex.
            weights: Relative sampling weight per reader. If None, the
                multiplexer derives weights from the readers' lengths.
            stop_early: If True, stop as soon as any reader is exhausted;
                otherwise continue until all readers are exhausted.
            seed: Seed of the multiplexer's reader-selection RNG.
        """
        self.readers = readers
        self.stop_early = stop_early
        self.mux_iterator = LazyIteratorMultiplexer(
            *readers,
            stop_early=stop_early,
            weights=weights,
            seed=seed,
        )

    def set_epoch(self, epoch: int):
        """
        Set the epoch for shuffling on every wrapped reader.
        """
        for reader in self.readers:
            reader.set_epoch(epoch)

    def __iter__(self) -> Iterator[Dict[str, Any]]:
        return iter(self.mux_iterator)

    def __len__(self) -> int:
        """Total number of samples across all wrapped readers.

        Overrides the raising interface stub, so ``len()`` works and this
        reader can itself be multiplexed with length-derived weights.
        """
        return len(self.mux_iterator)
|
|
|
|
| class LazyIteratorMultiplexer: |
| """ |
| A wrapper over multiple iterators that enables to combine |
| lazy manifests in Lhotse. During iteration, unlike |
| :class:`.LazyIteratorChain`, |
| :class:`.LazyIteratorMultiplexer` at each step randomly |
| selects the iterable used to yield an item. |
| |
| Since the iterables might be of different length, we provide |
| a ``weights`` parameter to let the user decide which iterables |
| should be sampled more frequently than others. |
| When an iterable is exhausted, we will keep sampling from the other iterables, until |
| we exhaust them all, unless ``stop_early`` is set to ``True``. |
| """ |
|
|
| def __init__( |
| self, |
| *iterators: IterableDataReader, |
| stop_early: bool = False, |
| weights: Optional[List[float]] = None, |
| seed: int = 0, |
| ) -> None: |
| self.iterators = list(iterators) |
| self.stop_early = stop_early |
| self.seed = seed |
|
|
| assert ( |
| len(self.iterators) > 1 |
| ), "There have to be at least two iterables to multiplex." |
|
|
| if weights is None: |
| if all(hasattr(it, "__len__") for it in self.iterators): |
| lengths = [len(it) for it in self.iterators] |
| total_length = sum(lengths) |
| self.weights = [length / total_length for length in lengths] |
| else: |
| self.weights = [1] * len(self.iterators) |
| else: |
| self.weights = weights |
|
|
| assert len(self.iterators) == len(self.weights) |
|
|
| def __iter__(self): |
|
|
| rng = random.Random(self.seed) |
| iters = [iter(it) for it in self.iterators] |
| exhausted = [False for _ in range(len(iters))] |
|
|
| def should_continue(): |
| if self.stop_early: |
| return not any(exhausted) |
| else: |
| return not all(exhausted) |
|
|
| while should_continue(): |
| active_indexes, active_weights = zip( |
| *[ |
| (i, w) |
| for i, (is_exhausted, w) in enumerate(zip(exhausted, self.weights)) |
| if not is_exhausted |
| ] |
| ) |
| idx = rng.choices(active_indexes, weights=active_weights, k=1)[0] |
| selected = iters[idx] |
| try: |
| item = next(selected) |
| yield item |
| except StopIteration: |
| exhausted[idx] = True |
| continue |
|
|
| def __len__(self) -> int: |
| return sum(len(iterator) for iterator in self.iterators) |
|
|