"""
utils.pseudo_envs — synthetic environment construction without site labels.

Used to compute the IRMv1 penalty when the dataset has no real environment
metadata. The default heuristic splits images by mean brightness — a proxy
for scanner / acquisition differences which on real medical data correlates
with site identity.

Replace with metadata-driven splitting once a labelled multi-site dataset
(e.g. WILDS-CheXpert) is plugged in.
"""
from __future__ import annotations

from typing import List, Dict

import torch
from torch.utils.data import DataLoader


def _squeeze_keep_batch(labels: torch.Tensor) -> torch.Tensor:
    """Drop singleton non-batch dimensions from `labels`, keeping dim 0."""
    # A bare `.squeeze()` would also remove the batch dimension when a batch
    # holds a single sample, producing a 0-dim tensor that `torch.cat` rejects.
    kept = [size for size in labels.shape[1:] if size != 1]
    return labels.reshape(labels.shape[0], *kept)


def make_brightness_envs(dataset, n_envs: int, device: str) -> List[Dict[str, torch.Tensor]]:
    """
    Split `dataset` into `n_envs` pseudo-environments ordered by brightness.

    The whole dataset is materialised in memory, samples are sorted by the
    mean of each image over dims 1-3, and the sorted order is cut into
    `n_envs` consecutive slices of `len(dataset) // n_envs` samples. The last
    slice additionally absorbs the remainder. If `n_envs` exceeds the number
    of samples, every slice except the last is empty.

    Args:
        dataset: map-style dataset yielding `(image, label)` pairs that the
            default `DataLoader` collation can batch; images are expected to
            collate to a 4-D tensor [N, C, H, W].
        n_envs: number of environments to build; must be at least 1.
        device: device the returned tensors are moved to.

    Returns:
        A list of `n_envs` dicts, each {"x": tensor [N,C,H,W], "y": tensor [N]},
        ordered from darkest to brightest.

    Raises:
        ValueError: if `n_envs` is smaller than 1 or the dataset yields no
            samples.
    """
    if n_envs < 1:
        raise ValueError(f"n_envs must be >= 1, got {n_envs}")

    img_batches, label_batches = [], []
    for imgs, labels in DataLoader(dataset, batch_size=256, shuffle=False):
        img_batches.append(imgs)
        label_batches.append(_squeeze_keep_batch(labels).long())

    if not img_batches:
        raise ValueError("cannot build brightness environments from an empty dataset")

    all_imgs = torch.cat(img_batches)
    all_labels = torch.cat(label_batches)

    # One scalar per image: the mean over channel and spatial dimensions.
    brightness = all_imgs.mean(dim=[1, 2, 3])
    sorted_idx = torch.argsort(brightness)

    n_samples = len(sorted_idx)
    env_size = n_samples // n_envs

    envs = []
    for i in range(n_envs):
        start = i * env_size
        # The final environment takes whatever is left after integer division.
        end = (i + 1) * env_size if i < n_envs - 1 else n_samples
        idx = sorted_idx[start:end]
        envs.append({
            "x": all_imgs[idx].to(device),
            "y": all_labels[idx].to(device),
        })
    return envs