"""CPU offloading for optimizer states. Manages a pinned CPU memory pool and async CUDA streams to offload optimizer state tensors (momentum buffers, Adam moments) to CPU between optimizer steps, freeing GPU memory. All tracked tensors are packed into a single flat pinned CPU buffer (per dtype). D2H and H2D copies are performed per-tensor directly between individual GPU tensors and their slice of the CPU flat buffer — no GPU staging buffer is allocated, so there is **no temporary GPU memory spike** during offload or reload. Individual tensor storages are freed after offload via ``untyped_storage().resize_(0)``, preserving tensor identity so downstream caches remain valid. """ import logging from collections import defaultdict import torch from torch.distributed.tensor import DTensor logger = logging.getLogger(__name__) class CPUOffloadPool: """Pinned CPU memory pool for async optimizer state offloading. Tracked tensors are grouped by dtype. Each group gets a single flat pinned CPU buffer. D2H / H2D copies are per-tensor (into slices of the flat buffer) to avoid allocating a GPU staging buffer. """ def __init__(self): self._managed: list[torch.Tensor] = [] self._storage_nbytes: dict[int, int] = {} # id(t) → bytes # Per-dtype group: populated on first offload. # dtype → dict with keys: # "indices" : list[int] managed-list indices # "offsets" : list[tuple[int,int]] (start, numel) in flat buf # "total" : int total numel # "cpu_flat" : Tensor pinned CPU buffer self._groups: dict[torch.dtype, dict] = {} self._offload_stream: torch.cuda.Stream | None = None self._device: torch.device | None = None self._initialized: bool = False self._logged: bool = False # ------------------------------------------------------------------ @staticmethod def _local(t: torch.Tensor) -> torch.Tensor: """Unwrap DTensor to its local CUDA tensor.""" return t._local_tensor if isinstance(t, DTensor) else t def _ensure_stream(self): if self._offload_stream is None: self._offload_stream = torch.cuda.Stream(device=self._device) # ------------------------------------------------------------------ def track(self, tensor: torch.Tensor): """Register a GPU tensor for CPU offloading. Idempotent.""" tid = id(tensor) if tid in self._storage_nbytes: return local = self._local(tensor) if self._device is None: self._device = local.device storage = local.untyped_storage() # Skip tensors with empty storage (e.g. empty FSDP shards) if storage.size() == 0: return self._storage_nbytes[tid] = storage.size() self._managed.append(tensor) # ------------------------------------------------------------------ def _init_buffers(self): """Build per-dtype flat buffers on first offload.""" # Group managed tensors by dtype. dtype_map: dict[torch.dtype, list[tuple[int, int]]] = defaultdict(list) for idx, t in enumerate(self._managed): local = self._local(t) dtype_map[local.dtype].append((idx, local.numel())) total_cpu_bytes = 0 for dtype, entries in dtype_map.items(): offsets: list[tuple[int, int]] = [] indices: list[int] = [] off = 0 for idx, n in entries: indices.append(idx) offsets.append((off, n)) off += n cpu_flat = torch.empty(off, dtype=dtype, device="cpu", pin_memory=True) self._groups[dtype] = { "indices": indices, "offsets": offsets, "total": off, "cpu_flat": cpu_flat, } total_cpu_bytes += off * cpu_flat.element_size() self._initialized = True logger.info( "[CPUOffload] Pool initialized: %d tensors, %d dtype group(s), " "%.2f MB pinned CPU memory", len(self._managed), len(self._groups), total_cpu_bytes / (1024**2), ) # ------------------------------------------------------------------ def offload(self): """Per-tensor async D2H into CPU flat buffer, then free GPU storage.""" if not self._managed: return if not self._initialized: self._init_buffers() self._ensure_stream() # Offload stream waits for compute to finish. compute_event = torch.cuda.current_stream(self._device).record_event() self._offload_stream.wait_event(compute_event) offloaded_bytes = 0 # Per-tensor D2H copies directly into CPU flat buffer slices. # No GPU staging buffer → no temporary GPU memory spike. with torch.cuda.stream(self._offload_stream): for dtype, grp in self._groups.items(): indices = grp["indices"] offsets = grp["offsets"] cpu_flat = grp["cpu_flat"] for i, mgd_idx in enumerate(indices): local = self._local(self._managed[mgd_idx]) off, n = offsets[i] cpu_flat[off : off + n].copy_(local.reshape(-1), non_blocking=True) offloaded_bytes += grp["total"] * cpu_flat.element_size() # Wait for all D2H copies to land, then free GPU storage. self._offload_stream.synchronize() for t in self._managed: storage = self._local(t).untyped_storage() if storage.size() != 0: storage.resize_(0) else: raise RuntimeError( f"Tensor storage is already freed (size=0) before offload. " f"This indicates a double-free or external interference. " f"Tensor shape: {t.shape}, dtype: {t.dtype}" ) if not self._logged: logger.info( "[CPUOffload] Offloaded %.2f MB (GPU → CPU)", offloaded_bytes / (1024**2), ) # ------------------------------------------------------------------ def reload(self): """Per-tensor H2D from CPU flat buffer on the default stream. Runs on the current (default) CUDA stream to avoid stream interaction issues with the parallel Muon pipeline. Since pinned CPU memory is the source, the copies overlap with GPU idle time between steps. """ if not self._managed or not self._initialized: return reloaded_bytes = 0 # Re-allocate all GPU storages first. for t in self._managed: local = self._local(t) storage = local.untyped_storage() if storage.size() != 0: raise RuntimeError( f"Storage should have been freed (size=0) before reload, " f"but got size={storage.size()}. " f"Tensor shape: {t.shape}, dtype: {t.dtype}" ) storage.resize_(self._storage_nbytes[id(t)]) # Per-tensor H2D copies from CPU flat buffer slices. # non_blocking=True with pinned source allows DMA overlap. for dtype, grp in self._groups.items(): indices = grp["indices"] offsets = grp["offsets"] cpu_flat = grp["cpu_flat"] for i, mgd_idx in enumerate(indices): local = self._local(self._managed[mgd_idx]) off, n = offsets[i] local.reshape(-1).copy_(cpu_flat[off : off + n], non_blocking=True) reloaded_bytes += grp["total"] * cpu_flat.element_size() if not self._logged: logger.info( "[CPUOffload] Reloaded %.2f MB (CPU → GPU)", reloaded_bytes / (1024**2) )