#!/usr/bin/env python3 # Copyright 2026 Xiaomi Corp. (authors: Han Zhu) # # See ../../LICENSE for clarification regarding multiple authors # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. """Dataset and data-loading utilities for training and evaluation. Provides WebDataset-based iterable datasets, manifest parsing, and audio/token loading. Used by ``omnivoice.training.builder.build_dataloaders()`` to construct train and eval data loaders. Key functions: - ``prepare_data_manifests_from_json()``: Parses a data config JSON into train/dev manifests. Key classes: - ``WebDatasetReader``: Reads audio/text pairs from WebDataset tar shards as an iterable dataset. - ``MuxWebDatasetReader``: Multiplexes multiple WebDataset readers for multilingual data. - ``JsonlDatasetReader``: Reads audio/text pairs from a JSONL manifest file. Used by data processing scripts (e.g. ``omnivoice/scripts/``). - ``SampleDecoder``: Decodes individual samples (audio or tokens + labels). """ import io import json import logging import os import random from typing import Any, Dict, Iterator, List, Optional, Tuple import torch import torch.distributed as dist import torchaudio import webdataset as wds from torch.utils.data import IterableDataset def load_audio_webdataset(data, sample_rate: int = 24000, device="cpu"): """ Load audio from bytes data and resample to the target sample rate if needed. Return a tensor of shape (1, num_samples) """ audio, sr = torchaudio.load(io.BytesIO(data)) audio = audio.to(device) if audio.size(dim=0) > 1: audio = torch.mean(audio, dim=0) if sr != sample_rate: audio = torchaudio.functional.resample(audio, sr, sample_rate) return audio def prepare_data_manifests_from_json( data_config: str, ) -> Tuple[List[Tuple[str, str, int, float]], List[Tuple[str, str, int, float]]]: """ Prepare data manifests from a json file. A typical multilingual json file is in the following format: { "train": [ { "language_id": "en", "manifest_path": [ "/Emilia/EN/data.lst" ], "repeat": 1 }, { "language_id": "zh", "manifest_path": [ "/Emilia/ZH/data.lst" ], "repeat": 1 } ], "dev": [ { "language_id": "en", "manifest_path": [ "/Emilia/EN-dev/data.lst" ], "repeat": 1 }, { "language_id": "zh", "manifest_path": [ "/Emilia/ZH-dev/data.lst" ], "repeat": 1 } ] } "language_id" is not used, just for better organization of multilingual data. "repeat" is an optional field, default to 1, which indicates how many times the manifest should be repeated. The simplist format is like: { "train": [ { "manifest_path": [ "/Emilia/EN/data.lst", "/Emilia/ZH/data.lst" ], } ], "dev": [ { "manifest_path": [ "/Emilia/EN-dev/data.lst", "/Emilia/ZH-dev/data.lst" ], } ] data.lst format (items separated by space): /path/to/data.tar /path/to/label.jsonl num_items num_seconds """ train_manifests = [] dev_manifests = [] with open(data_config, "r", encoding="utf-8") as f: data = json.load(f) for item in data["train"]: manifest_paths = item["manifest_path"] repeat = item.get("repeat", 1) for manifest_path in manifest_paths: # assert manifest_path is a file assert os.path.isfile(manifest_path), f"{manifest_path} is not a file." train_manifests.extend( webdataset_manifest_reader(manifest_path) * repeat ) if "dev" in data: for item in data["dev"]: manifest_paths = item["manifest_path"] repeat = item.get("repeat", 1) for manifest_path in manifest_paths: dev_manifests.extend( webdataset_manifest_reader(manifest_path) * repeat ) return train_manifests, dev_manifests def webdataset_manifest_reader( manifest_path: str, ) -> List[Tuple[str, str]]: """ Read a manifest file containing webdataset tar paths and label jsonl paths. Each line in the manifest file is in the format of: /path/to/data.tar /path/to/label.jsonl num_items num_seconds """ manifests = [] with open(manifest_path, "r", encoding="utf-8") as f: for line in f: line = line.strip() if not line: continue parts = line.split() if len(parts) != 4: raise ValueError( f"Invalid manifest line: {line}. " f"Each line must contain " "tar_path, label_jsonl_path, num_items, num_seconds." ) tar_path, label_jsonl_path, num_items, num_seconds = ( parts[0], parts[1], int(parts[2]), float(parts[3]), ) manifests.append((tar_path, label_jsonl_path, num_items, num_seconds)) return manifests class SampleDecoder: """ Decode a sample from webdataset, including loading audio/tokens and fetching label. """ def __init__( self, tar_to_label: Dict, sample_rate: int = 24000, audio_format: Optional[Tuple[str]] = None, normalize_audio: bool = True, ): """ Args: tar_to_label: A dict mapping from audio tar file to label tar file. sample_rate: Target sample rate for audio. Required if audio is loaded. audio_format: Tuple of audio file extensions to look for in the sample. """ self.tar_to_label = tar_to_label self.sample_rate = sample_rate self.label_dataset = None if audio_format is None: self.audio_format = ("flac", "wav", "mp3") else: self.audio_format = audio_format self.normalize_audio = normalize_audio def __call__(self, sample): return_dict = {} src = sample["__url__"] key = sample["__key__"] if ( self.label_dataset is None or self.label_dataset.path != self.tar_to_label[src] ): self.label_dataset = LabelDataset(self.tar_to_label[src]) audio = torch.empty(0) if "npy" in sample: audio_tokens = torch.from_numpy(sample["npy"]) return_dict["audio_tokens"] = audio_tokens else: for ext in self.audio_format: if ext in sample: # load audio (1, num_samples) audio = load_audio_webdataset( sample[ext], sample_rate=self.sample_rate ) if self.normalize_audio: audio = (audio / (audio.abs().max() + 1e-7)) * 0.9 break return_dict["audio"] = audio return_dict["audio_duration"] = audio.size(-1) / self.sample_rate label = self.label_dataset[key] return_dict["label"] = label return return_dict class LabelDataset: def __init__(self, jsonl_path: str): """ Load labels from a jsonl file. Args: jsonl_path: Path to the jsonl file containing labels. Each line in the manifest file is in the format of: {"idx": "idx", "text": "transcription text"} """ self._labels = {} self.path = jsonl_path if not os.path.exists(jsonl_path): raise FileNotFoundError(f"Label jsonl file {jsonl_path} does not exist.") with open(jsonl_path, "r", encoding="utf-8") as f: for line in f: line = line.strip() if not line: continue item = json.loads(line) if "id" in item: self._labels[item["id"]] = item def __getitem__(self, key): return self._labels[key] class IterableDataReader: "Interfaces for classes reading data." sample_rate: int def set_epoch(self, epoch: int): raise NotImplementedError def __iter__(self) -> Iterator[Dict[str, Any]]: raise NotImplementedError def __len__(self) -> int: raise NotImplementedError class WrappedIterableDataset(IterableDataset): "IterableDataset interfaces in this project." def set_epoch(self, epoch: int): raise NotImplementedError def __iter__(self) -> Iterator[List[Dict[str, Any]]]: raise NotImplementedError class WebDatasetReader(IterableDataReader): def __init__( self, manifests: List[Tuple[str, str, int, float]], evaluation: bool = False, shuffle_buffer_size: int = 20000, sample_rate: int = 24000, ): self.shuffle_buffer_size = shuffle_buffer_size self.evaluation = evaluation self.epoch = 0 self.orig_urls = [] self.tar_to_label = {} self.num_items = 0 self.num_seconds = 0.0 for tar_path, label_jsonl_path, num_items, num_seconds in manifests: self.orig_urls.append(tar_path) self.tar_to_label[tar_path] = label_jsonl_path self.num_items += num_items self.num_seconds += num_seconds self.urls = self.orig_urls.copy() self.sample_decoder = SampleDecoder( tar_to_label=self.tar_to_label, sample_rate=sample_rate, ) self.sample_rate = sample_rate def set_epoch(self, epoch: int): """ Set the epoch for shuffling. """ self.epoch = epoch self.urls = self.orig_urls.copy() if not self.evaluation: random.Random(epoch).shuffle(self.urls) def __iter__(self) -> Iterator[Dict[str, Any]]: dataset = wds.WebDataset( self.urls, shardshuffle=False, workersplitter=wds.split_by_worker, nodesplitter=wds.split_by_node, ) pipeline = dataset.decode().map(self.sample_decoder) if not self.evaluation: pipeline = pipeline.shuffle(self.shuffle_buffer_size, seed=self.epoch) return iter(pipeline) def __len__(self) -> int: return self.num_items class JsonlDatasetReader(IterableDataReader): """Read raw JSONL and load audio files, matching WebDatasetReader output format. Each JSONL line should be a JSON object with at least: {"id": "...", "audio_path": "/path/to/audio.wav", ...} Yields dicts of the form: {"audio": Tensor(1, T), "label": dict} """ def __init__( self, jsonl_path: str, sample_rate: int = 24_000, shuffle: bool = True, shuffle_seed: int = 42, normalize_audio: bool = True, ): self.jsonl_path = jsonl_path self.sample_rate = sample_rate self.shuffle = shuffle self.shuffle_seed = shuffle_seed self.normalize_audio = normalize_audio def set_epoch(self, epoch: int): self.shuffle_seed = epoch def _read_lines(self) -> list[dict]: entries = [] with open(self.jsonl_path, "r", encoding="utf-8") as f: for line in f: line = line.strip() if line: entries.append(json.loads(line)) if self.shuffle: random.seed(self.shuffle_seed) random.shuffle(entries) logging.info( f"Shuffled {len(entries)} JSONL entries (seed={self.shuffle_seed})" ) return entries def _stream_lines(self): with open(self.jsonl_path, "r", encoding="utf-8") as f: for line in f: line = line.strip() if line: yield json.loads(line) def __iter__(self): source = self._read_lines() if self.shuffle else self._stream_lines() # Split data across distributed ranks (multi-GPU / DDP) if dist.is_initialized(): rank = dist.get_rank() world_size = dist.get_world_size() source = [item for i, item in enumerate(source) if i % world_size == rank] # Split data across DataLoader workers to avoid duplication worker_info = torch.utils.data.get_worker_info() if worker_info is not None: source = ( item for i, item in enumerate(source) if i % worker_info.num_workers == worker_info.id ) for meta in source: audio_path = meta.get("audio_path") if not audio_path or not os.path.exists(audio_path): logging.warning( f"Skipping {meta.get('id', '?')}: audio_path missing or not found" ) continue try: waveform, sr = torchaudio.load(audio_path) if waveform.shape[0] > 1: waveform = waveform.mean(dim=0, keepdim=True) if sr != self.sample_rate: waveform = torchaudio.functional.resample( waveform, sr, self.sample_rate ) if self.normalize_audio: waveform = (waveform / (waveform.abs().max() + 1e-7)) * 0.9 meta["audio_duration"] = waveform.shape[1] / self.sample_rate yield {"audio": waveform, "label": meta} except Exception as e: logging.warning(f"Skipping {meta.get('id', '?')}: {e}") class MuxWebDatasetReader(IterableDataReader): def __init__( self, readers: List[WebDatasetReader], weights: Optional[List[float]] = None, stop_early: bool = False, seed: int = 0, ): self.readers = readers self.stop_early = stop_early self.mux_iterator = LazyIteratorMultiplexer( *readers, stop_early=stop_early, weights=weights, seed=seed, ) def set_epoch(self, epoch: int): """ Set the epoch for shuffling. """ for reader in self.readers: reader.set_epoch(epoch) def __iter__(self) -> Iterator[Dict[str, Any]]: return iter(self.mux_iterator) class LazyIteratorMultiplexer: """ A wrapper over multiple iterators that enables to combine lazy manifests in Lhotse. During iteration, unlike :class:`.LazyIteratorChain`, :class:`.LazyIteratorMultiplexer` at each step randomly selects the iterable used to yield an item. Since the iterables might be of different length, we provide a ``weights`` parameter to let the user decide which iterables should be sampled more frequently than others. When an iterable is exhausted, we will keep sampling from the other iterables, until we exhaust them all, unless ``stop_early`` is set to ``True``. """ def __init__( self, *iterators: IterableDataReader, stop_early: bool = False, weights: Optional[List[float]] = None, seed: int = 0, ) -> None: self.iterators = list(iterators) self.stop_early = stop_early self.seed = seed assert ( len(self.iterators) > 1 ), "There have to be at least two iterables to multiplex." if weights is None: if all(hasattr(it, "__len__") for it in self.iterators): lengths = [len(it) for it in self.iterators] total_length = sum(lengths) self.weights = [length / total_length for length in lengths] else: self.weights = [1] * len(self.iterators) else: self.weights = weights assert len(self.iterators) == len(self.weights) def __iter__(self): rng = random.Random(self.seed) iters = [iter(it) for it in self.iterators] exhausted = [False for _ in range(len(iters))] def should_continue(): if self.stop_early: return not any(exhausted) else: return not all(exhausted) while should_continue(): active_indexes, active_weights = zip( *[ (i, w) for i, (is_exhausted, w) in enumerate(zip(exhausted, self.weights)) if not is_exhausted ] ) idx = rng.choices(active_indexes, weights=active_weights, k=1)[0] selected = iters[idx] try: item = next(selected) yield item except StopIteration: exhausted[idx] = True continue def __len__(self) -> int: return sum(len(iterator) for iterator in self.iterators)