"""Camelyon17 dataset utilities via WILDS."""

from collections import Counter

from wilds import get_dataset


def load_camelyon17(root_dir="data/wilds", download=False):
    """Load Camelyon17 from WILDS.

    Args:
        root_dir: Directory where WILDS stores / looks for the dataset.
        download: If True, let WILDS download the data when it is missing.

    Returns:
        The WILDS dataset object; call ``.get_subset(split)`` on it to obtain
        an individual split.
    """
    return get_dataset("camelyon17", download=download, root_dir=root_dir)


def get_camelyon_subsets(root_dir="data/wilds", download=True):
    """Load Camelyon17 and return its train / ID-val / OOD-test subsets.

    Uses the WILDS "official" split names. Per the WILDS Camelyon17
    documentation, hospitals are assigned as follows:

    - ``train``:  hospitals 0, 3, 4
    - ``id_val``: held-out patches from the *same* training hospitals
      (in-distribution validation)
    - ``val``:    hospital 1 (OOD validation; not returned here)
    - ``test``:   hospital 2 (OOD test)

    ``camelyon_stats`` prints the hospital IDs actually present in each
    subset, which can be used to confirm the assignment.

    Note: unlike ``load_camelyon17``, ``download`` defaults to True here.

    Returns:
        tuple: (train_subset, id_val_subset, ood_test_subset, dataset)
    """
    ds = load_camelyon17(root_dir=root_dir, download=download)
    train = ds.get_subset("train")
    id_val = ds.get_subset("id_val")
    ood_test = ds.get_subset("test")
    return train, id_val, ood_test, ds


def _hospital_counts(subset, dataset):
    """Return a {hospital_id: sample_count} dict for *subset*, sorted by ID."""
    # Look the column up by name rather than assuming it is column 0.
    col = dataset.metadata_fields.index("hospital")
    hospitals = subset.metadata_array[:, col].tolist()
    return dict(sorted(Counter(hospitals).items()))


def camelyon_stats(root_dir="data/wilds", download=True):
    """Print size and per-hospital statistics for every Camelyon17 split.

    Hospital IDs are read from each subset's metadata instead of being
    hard-coded, so the output reflects the split that was actually loaded.

    Args:
        root_dir: Directory where WILDS stores / looks for the dataset.
        download: Forwarded to ``get_camelyon_subsets``. Defaults to True,
            which matches this function's previous implicit behaviour.
    """
    train, id_val, ood_test, ds = get_camelyon_subsets(root_dir, download=download)
    # The OOD validation split is not returned by get_camelyon_subsets, but it
    # is needed for the total to cover the whole dataset.
    ood_val = ds.get_subset("val")
    splits = [
        ("Train", train),
        ("ID Val", id_val),
        ("OOD Val", ood_val),
        ("OOD Test", ood_test),
    ]
    counts = {name: _hospital_counts(subset, ds) for name, subset in splits}

    print("Camelyon17 Dataset Statistics")
    for name, subset in splits:
        hospital_label = "+".join(f"H{h}" for h in counts[name])
        print(f"  {name} ({hospital_label}): {len(subset):,} samples")
    total = sum(len(subset) for _, subset in splits)
    print(f"  Total: {total:,} samples")
    print(f"  Metadata fields: {ds.metadata_fields}")

    print("\n  Hospital distribution (train):")
    n_train = len(train)
    for hospital, count in counts["Train"].items():
        print(f"    Hospital {hospital}: {count:,} samples ({100 * count / n_train:.1f}%)")