import logging
from typing import Generator

logger = logging.getLogger(__name__)


class _Task:
    """Internal: wraps a generator, advances one yield at a time.

    Constructing a task immediately runs the generator up to its first
    yield. Check ``finished`` afterwards: a generator that returns before
    yielding is already done and must not be scheduled again.
    """

    def __init__(self, generator: Generator[None, None, None], index: int):
        self._generator = generator
        self._index = index
        self._steps_completed = 0
        self._finished = False
        self.step()  # run to first yield

    @property
    def index(self) -> int:
        """Position of this pipeline in admission order (used in logs)."""
        return self._index

    @property
    def finished(self) -> bool:
        """True once the wrapped generator has been exhausted."""
        return self._finished

    def step(self) -> bool:
        """Advance the generator by one yield.

        Returns:
            True if the generator yielded again (more stages remain),
            False if it is exhausted. Calling ``step`` on an already
            finished task is a no-op that returns False.

        Any exception other than ``StopIteration`` raised by the generator
        propagates to the caller.
        """
        if self._finished:
            return False
        try:
            next(self._generator)
        except StopIteration:
            self._finished = True
            logger.debug(
                "pipeline[%d] finished after %d stages",
                self._index,
                self._steps_completed,
            )
            return False
        self._steps_completed += 1
        logger.debug(
            "pipeline[%d] completed stage %d",
            self._index,
            self._steps_completed,
        )
        return True

    def close(self):
        """Close the wrapped generator (no-op if it already finished)."""
        self._generator.close()


def _close_tasks(tasks) -> None:
    """Close every distinct task in *tasks*, continuing past close errors."""
    seen: set[int] = set()
    for task in tasks:
        if id(task) in seen:
            continue
        seen.add(id(task))
        try:
            task.close()
        except Exception:
            # Keep going so one misbehaving generator cannot prevent the
            # others from being closed or mask the original exception.
            logger.exception(
                "pipeline[%d] raised while being closed during cleanup",
                task.index,
            )


def run_pipeline(
    pipelines: Generator[Generator[None, None, None], None, None],
    max_concurrent: int,
) -> None:
    """Run generator-based pipelines with bounded concurrency.

    Each pipeline is a generator that yields at stage boundaries. The
    runtime interleaves pipelines so communication and computation overlap
    across chunks.

    Args:
        pipelines: Iterator producing the pipeline generators; it is
            consumed lazily, one pipeline per scheduling iteration.
        max_concurrent: Maximum number of pipelines in flight at once.

    Raises:
        ValueError: If ``max_concurrent`` is not positive.
        BaseException: Anything raised by a pipeline (or by ``pipelines``
            itself) is re-raised after every in-flight pipeline has been
            closed.
    """
    if max_concurrent <= 0:
        raise ValueError(f"max_concurrent must be > 0, got {max_concurrent}")

    have_new = True
    task_index = 0
    previous_tasks: list[_Task] = []
    # Bound before the ``try`` so the cleanup handler can always see it.
    running_tasks: list[_Task] = []
    try:
        while have_new or previous_tasks:
            running_tasks = []

            # Admit one new pipeline per iteration (staggered admission).
            # Admitting one at a time ensures that while chunk N does NS
            # compute on the default stream, chunk N+1's NCCL all-to-all
            # runs concurrently on the NCCL stream — creating real
            # communication/computation overlap on the GPU.
            if have_new and len(previous_tasks) < max_concurrent:
                try:
                    gen = next(pipelines)
                except StopIteration:
                    have_new = False
                else:
                    task = _Task(gen, task_index)
                    task_index += 1
                    # A pipeline that finished before its first yield has
                    # nothing left to run; do not let it hold a slot.
                    if not task.finished:
                        running_tasks.append(task)

            # Advance every previously-yielded task by one step.
            for task in previous_tasks:
                if task.step():
                    running_tasks.append(task)

            previous_tasks = running_tasks
    except BaseException:
        # Clean up all in-flight generators to release GPU resources.
        # ``running_tasks`` is included because a pipeline admitted in the
        # failing iteration is not yet part of ``previous_tasks``.
        _close_tasks([*previous_tasks, *running_tasks])
        raise