| |
| """ |
| Enhanced Augmentation Pipeline for RetinaSense |
| =============================================== |
| Provides class-aware augmentation that applies progressively stronger |
| transforms to minority classes (Glaucoma, Cataract, AMD), CutMix batch |
| augmentation, and a DataFrame-level oversampling utility. |
| |
| Designed to be imported by the training scripts (retinasense_v3.py or |
| any future variant) without modifying the existing DataLoader contract. |
| |
| Class distribution context (APTOS 3,662 + ODIR 4,878 = 8,540 total): |
| Diabetes/DR : ~5,500 (majority) |
| Normal : ~1,800 (majority) |
| Glaucoma : ~500 (minority -- weakest class) |
| Cataract : ~400 (minority) |
| AMD : ~340 (minority) |
| |
| Strategy |
| -------- |
| - **Majority classes (Normal=0, DR=1):** Standard augmentation -- |
| random flips, mild rotation, moderate colour jitter. |
| - **Minority classes (Glaucoma=2, Cataract=3, AMD=4):** All of the |
| above, *plus* elastic deformation, grid distortion, random |
| occlusion (CoarseDropout), and aggressive colour jitter. |
| - **CutMix (batch-level):** Operates on already-augmented batches |
| to generate additional inter-class decision-boundary examples. |
| - **Oversampling:** Simple DataFrame replication of minority rows |
| (5x by default) so that WeightedRandomSampler draws from a |
| larger and more varied pool. |
| |
| All transforms output 224x224 tensors normalised with the project's |
| fundus statistics: |
| mean = [0.4298, 0.2784, 0.1559] |
| std = [0.2857, 0.2065, 0.1465] |
| |
| Usage |
| ----- |
| from enhanced_augmentation import ( |
| ClassAwareAugmentation, |
| create_train_transforms, |
| cutmix_data, |
| create_oversampled_dataset, |
| ) |
| |
| # Per-sample transform (pass to Dataset.__getitem__) |
| transform = create_train_transforms(class_label=3, |
| norm_mean=NORM_MEAN, |
| norm_std=NORM_STD) |
| |
| # Batch-level CutMix (call inside training loop) |
| mixed_x, y_a, y_b, lam = cutmix_data(images, labels, alpha=1.0) |
| |
| # DataFrame oversampling (call before constructing the Dataset) |
| train_df_oversampled = create_oversampled_dataset(train_df, |
| minority_multiplier=5) |
| """ |
|
|
| import math |
| import random |
| from typing import List, Optional, Tuple |
|
|
| import numpy as np |
| import pandas as pd |
| import torch |
| import torch.nn.functional as F |
| from torchvision import transforms |
| from PIL import Image, ImageFilter |
|
|
| |
| |
| |
# Default per-channel normalisation statistics, used whenever a caller does
# not supply its own mean/std (the project's fundus statistics quoted in the
# module docstring).
_DEFAULT_MEAN = [0.4298, 0.2784, 0.1559]
_DEFAULT_STD = [0.2857, 0.2065, 0.1465]


# Class-index groups. Per the module docstring: 0=Normal and 1=DR are the
# majority classes; 2=Glaucoma, 3=Cataract and 4=AMD are the minority classes.
_MAJORITY_CLASSES = {0, 1}
_MINORITY_CLASSES = {2, 3, 4}


# Default side length, in pixels, of the square images the pipelines produce.
IMG_SIZE = 224
|
|
|
|
| |
| |
| |
|
|
class ElasticDeformation:
    """Lightweight stand-in for elastic deformation on PIL images.

    A single random perspective transform is applied to the whole image:
    the linear part is a small perturbation of the identity, the
    translation part is a random shift, and the two perspective terms are
    tiny.  No scipy displacement field is involved, which keeps the
    transform cheap enough to run per sample.

    Parameters
    ----------
    alpha : float
        Strength of the warp.  The translation coefficients are drawn
        from ``[-alpha, alpha]`` and the linear coefficients are perturbed
        by at most ``alpha / max(width, height)``.
    """

    def __init__(self, alpha: float = 8.0):
        self.alpha = alpha

    def __call__(self, img: Image.Image) -> Image.Image:
        """Return a randomly warped copy of *img* with the same size."""
        width, height = img.size
        shift_limit = self.alpha
        # Normalising by the longer side bounds the linear terms so that
        # their effect across the image stays on the order of ``alpha``.
        linear_limit = shift_limit / max(width, height)
        perspective_limit = linear_limit * 0.001

        def jitter(limit):
            return random.uniform(-limit, limit)

        # The eight coefficients are drawn in PIL's (a, b, c, d, e, f, g, h)
        # order; keeping this order keeps seeded runs reproducible.
        a = 1.0 + jitter(linear_limit)
        b = jitter(linear_limit)
        c = jitter(shift_limit)
        d = jitter(linear_limit)
        e = 1.0 + jitter(linear_limit)
        f = jitter(shift_limit)
        g = jitter(perspective_limit)
        h = jitter(perspective_limit)

        return img.transform(
            (width, height),
            Image.PERSPECTIVE,
            [a, b, c, d, e, f, g, h],
            Image.BICUBIC,
        )

    def __repr__(self):
        return f"{type(self).__name__}(alpha={self.alpha})"
|
|
|
|
class GridDistortion:
    """Grid-based spatial distortion for fundus images.

    The image is divided into ``num_steps x num_steps`` rectangular cells.
    Every grid intersection is displaced by a random offset and each
    output cell is then resampled from the quadrilateral spanned by its
    four displaced corners (PIL ``Image.MESH`` transform).  Intersections
    lying on the image border are only allowed to slide along that border
    and the four image corners stay fixed, so the warp never pulls
    out-of-image pixels into the frame.

    Parameters
    ----------
    num_steps : int
        Number of grid cells along each axis.  Values below 1 are treated
        as 1 (which leaves the image undistorted).
    distort_limit : float
        Maximum displacement of a grid point, expressed as a fraction of
        the cell size along the corresponding axis.  Values below 0.5
        guarantee that neighbouring grid points cannot cross.
    """

    def __init__(self, num_steps: int = 5, distort_limit: float = 0.15):
        self.num_steps = num_steps
        self.distort_limit = distort_limit

    def _build_mesh(self, width: int, height: int) -> list:
        """Return the ``(bbox, quad)`` list consumed by ``Image.MESH``.

        ``bbox`` is the target cell ``(x0, y0, x1, y1)`` in the output
        image; ``quad`` holds the source corners in PIL order: upper-left,
        lower-left, lower-right, upper-right.
        """
        steps = max(1, int(self.num_steps))
        limit = abs(self.distort_limit)

        # Regular (undistorted) grid lines in output-pixel coordinates.
        xs = [round(i * width / steps) for i in range(steps + 1)]
        ys = [round(j * height / steps) for j in range(steps + 1)]
        max_dx = limit * width / steps
        max_dy = limit * height / steps

        # Displaced source position of every grid intersection.  A point on
        # the left/right border gets no x-offset and a point on the
        # top/bottom border gets no y-offset.
        nodes = []
        for j, y in enumerate(ys):
            row = []
            for i, x in enumerate(xs):
                dx = random.uniform(-max_dx, max_dx) if 0 < i < steps else 0.0
                dy = random.uniform(-max_dy, max_dy) if 0 < j < steps else 0.0
                row.append((x + dx, y + dy))
            nodes.append(row)

        mesh = []
        for j in range(steps):
            for i in range(steps):
                box = (xs[i], ys[j], xs[i + 1], ys[j + 1])
                # Images smaller than the grid yield zero-sized cells.
                if box[2] <= box[0] or box[3] <= box[1]:
                    continue
                nw, ne = nodes[j][i], nodes[j][i + 1]
                sw, se = nodes[j + 1][i], nodes[j + 1][i + 1]
                mesh.append((box, (*nw, *sw, *se, *ne)))
        return mesh

    def __call__(self, img: Image.Image) -> Image.Image:
        """Return a grid-distorted copy of *img* with the same size."""
        w, h = img.size
        mesh = self._build_mesh(w, h)
        if not mesh:
            # Degenerate (zero-area) image: nothing to warp.
            return img.copy()
        return img.transform((w, h), Image.MESH, mesh, Image.BICUBIC)

    def __repr__(self):
        return (
            f"{self.__class__.__name__}"
            f"(num_steps={self.num_steps}, distort_limit={self.distort_limit})"
        )
|
|
|
|
class RandomOcclusion:
    """Randomly occlude rectangular patches of the image.

    Simulates eyelid shadow, dust, or partial field-of-view loss that
    occurs in real-world fundus photography.  Operates on PIL images
    (before ``ToTensor``) so it is independent of ``RandomErasing``.

    Parameters
    ----------
    num_patches : int
        Maximum number of rectangular patches to occlude.  Between 1 and
        *num_patches* patches are drawn per call; a value below 1
        disables the transform.
    max_ratio : float
        Maximum fraction of image area a single patch may cover.
    fill : tuple
        Fill value for occluded regions (default: black for RGB).  It
        must be a valid colour for the mode of the images passed in.
    """

    def __init__(
        self,
        num_patches: int = 3,
        max_ratio: float = 0.08,
        fill: tuple = (0, 0, 0),
    ):
        self.num_patches = num_patches
        self.max_ratio = max_ratio
        self.fill = fill

    def __call__(self, img: Image.Image) -> Image.Image:
        """Return a copy of *img* with random patches painted over."""
        # Work on a copy so the caller's image is never modified.
        out = img.copy()
        if self.num_patches < 1:
            # random.randint(1, 0) would raise; nothing to occlude.
            return out

        w, h = out.size
        for _ in range(random.randint(1, self.num_patches)):
            # Sample the patch area first, then split it into width and
            # height according to a random aspect ratio.
            area = w * h * random.uniform(0.01, self.max_ratio)
            aspect = random.uniform(0.5, 2.0)
            pw = min(int(math.sqrt(area * aspect)), w)
            ph = min(int(math.sqrt(area / aspect)), h)
            x0 = random.randint(0, w - pw)
            y0 = random.randint(0, h - ph)
            if pw > 0 and ph > 0:
                # A single colour-fill paste replaces the per-pixel loop.
                out.paste(self.fill, (x0, y0, x0 + pw, y0 + ph))
        return out

    def __repr__(self):
        return (
            f"{self.__class__.__name__}"
            f"(num_patches={self.num_patches}, max_ratio={self.max_ratio})"
        )
|
|
|
|
class GaussianNoise:
    """Additive zero-mean Gaussian noise for tensors.

    Intended to run after ``ToTensor`` and normalisation, so ``std`` is
    expressed in units of the normalised pixel values.

    Parameters
    ----------
    std : float
        Standard deviation of the noise added to every element.
    """

    def __init__(self, std: float = 0.02):
        self.std = std

    def __call__(self, tensor: torch.Tensor) -> torch.Tensor:
        """Return a new tensor equal to *tensor* plus scaled unit noise."""
        unit_noise = torch.randn_like(tensor)
        return tensor + self.std * unit_noise

    def __repr__(self):
        return f"{type(self).__name__}(std={self.std})"
|
|
|
|
| |
| |
| |
|
|
class _EnsurePIL:
    """Convert arrays/tensors to PIL and pass PIL images through unchanged.

    ``transforms.ToPILImage`` only accepts tensors and ndarrays, so using
    it as the first pipeline step makes the pipeline reject inputs that
    are already PIL images.  This wrapper accepts both.
    """

    def __init__(self):
        self._to_pil = transforms.ToPILImage()

    def __call__(self, pic):
        if isinstance(pic, Image.Image):
            return pic
        return self._to_pil(pic)

    def __repr__(self):
        return f"{self.__class__.__name__}()"


class ClassAwareAugmentation:
    """Dispatcher that selects standard or strong augmentation based
    on the disease class label.

    Majority classes (Normal, DR) receive a lightweight augmentation
    pipeline.  Minority classes (Glaucoma, Cataract, AMD) receive a
    substantially heavier pipeline that includes elastic deformation,
    grid distortion, random occlusion, stronger colour jitter, and
    Gaussian noise.  Both pipelines are built once in ``__init__`` and
    reused for every call.

    Parameters
    ----------
    norm_mean : list[float] or None
        Per-channel mean for normalisation.  ``None`` (or an empty
        sequence) selects the module default.
    norm_std : list[float] or None
        Per-channel std for normalisation.  ``None`` (or an empty
        sequence) selects the module default.
    img_size : int
        Output spatial resolution (default 224).
    """

    def __init__(
        self,
        norm_mean: Optional[List[float]] = None,
        norm_std: Optional[List[float]] = None,
        img_size: int = IMG_SIZE,
    ):
        # Explicit None/empty checks instead of ``or`` so that NumPy arrays
        # are accepted; the defaults are copied so that mutating an
        # instance never alters the module-level constants.
        if norm_mean is None or len(norm_mean) == 0:
            norm_mean = list(_DEFAULT_MEAN)
        if norm_std is None or len(norm_std) == 0:
            norm_std = list(_DEFAULT_STD)
        self.norm_mean = norm_mean
        self.norm_std = norm_std
        self.img_size = img_size

        self._standard = self._build_standard()
        self._strong = self._build_strong()

    def _build_standard(self) -> transforms.Compose:
        """Standard augmentation for majority classes."""
        normalize = transforms.Normalize(self.norm_mean, self.norm_std)
        return transforms.Compose([
            _EnsurePIL(),
            transforms.Resize((self.img_size, self.img_size)),
            transforms.RandomHorizontalFlip(p=0.5),
            transforms.RandomVerticalFlip(p=0.2),
            transforms.RandomRotation(15),
            transforms.ColorJitter(
                brightness=0.2,
                contrast=0.2,
                saturation=0.15,
                hue=0.02,
            ),
            transforms.ToTensor(),
            normalize,
            # Tensor-space erasing runs after normalisation.
            transforms.RandomErasing(p=0.15, scale=(0.02, 0.1)),
        ])

    def _build_strong(self) -> transforms.Compose:
        """Strong augmentation for minority classes.

        Adds elastic deformation, grid distortion, random occlusion,
        aggressive colour jitter, and Gaussian noise on top of the
        standard spatial transforms.
        """
        normalize = transforms.Normalize(self.norm_mean, self.norm_std)
        return transforms.Compose([
            _EnsurePIL(),
            transforms.Resize((self.img_size, self.img_size)),
            # Geometric transforms (PIL space).
            transforms.RandomHorizontalFlip(p=0.5),
            transforms.RandomVerticalFlip(p=0.3),
            transforms.RandomRotation(30),
            transforms.RandomAffine(
                degrees=0,
                translate=(0.08, 0.08),
                scale=(0.90, 1.10),
                shear=5,
            ),
            # Non-rigid warps, each applied with its own probability.
            transforms.RandomApply([ElasticDeformation(alpha=10.0)], p=0.4),
            transforms.RandomApply(
                [GridDistortion(num_steps=5, distort_limit=0.15)], p=0.3
            ),
            # Occlusion on the PIL image.
            transforms.RandomApply(
                [RandomOcclusion(num_patches=3, max_ratio=0.08)], p=0.4
            ),
            # Photometric transforms.
            transforms.ColorJitter(
                brightness=0.4,
                contrast=0.4,
                saturation=0.3,
                hue=0.04,
            ),
            transforms.RandomGrayscale(p=0.05),
            transforms.RandomApply(
                [transforms.GaussianBlur(kernel_size=3, sigma=(0.1, 1.5))],
                p=0.2,
            ),
            # Tensor conversion and normalisation.
            transforms.ToTensor(),
            normalize,
            # Tensor-space augmentation on the normalised image.
            transforms.RandomErasing(p=0.3, scale=(0.02, 0.15)),
            GaussianNoise(std=0.02),
        ])

    def get_transform(self, class_label: int) -> transforms.Compose:
        """Return the appropriate transform pipeline for *class_label*.

        Parameters
        ----------
        class_label : int
            Disease class index (0=Normal, 1=DR, 2=Glaucoma,
            3=Cataract, 4=AMD).  A single-element tensor is accepted and
            unwrapped.  Any label that is not a minority class selects
            the standard pipeline.

        Returns
        -------
        transforms.Compose
        """
        # Tensors hash by identity, so ``tensor(2) in {2, 3, 4}`` is False;
        # unwrap scalar tensors to a Python number before the lookup.
        if isinstance(class_label, torch.Tensor) and class_label.numel() == 1:
            class_label = class_label.item()
        if class_label in _MINORITY_CLASSES:
            return self._strong
        return self._standard

    def __call__(self, img, class_label: int):
        """Apply the class-appropriate transform to *img*.

        Parameters
        ----------
        img : np.ndarray or PIL.Image.Image
            Input image (H, W, 3) uint8 or PIL RGB.
        class_label : int
            Disease class index.

        Returns
        -------
        torch.Tensor
            Augmented image tensor of shape (3, img_size, img_size).
        """
        return self.get_transform(class_label)(img)
|
|
|
|
| |
| |
| |
|
|
def create_train_transforms(
    class_label: int,
    norm_mean: List[float] = None,
    norm_std: List[float] = None,
    img_size: int = IMG_SIZE,
) -> transforms.Compose:
    """Build the training-time transform pipeline for one class.

    A fresh ``ClassAwareAugmentation`` is constructed on every call and
    the sub-pipeline matching *class_label* is returned.  For
    validation/test, use ``create_val_transforms()`` instead.

    Parameters
    ----------
    class_label : int
        Disease class index (0-4).
    norm_mean : list[float]
        Normalisation mean.  Falls back to the fundus defaults when
        omitted or empty.
    norm_std : list[float]
        Normalisation std.  Falls back to the fundus defaults when
        omitted or empty.
    img_size : int
        Output spatial resolution (default 224).

    Returns
    -------
    transforms.Compose
    """
    mean = norm_mean if norm_mean else _DEFAULT_MEAN
    std = norm_std if norm_std else _DEFAULT_STD
    dispatcher = ClassAwareAugmentation(
        norm_mean=mean,
        norm_std=std,
        img_size=img_size,
    )
    pipeline = dispatcher.get_transform(class_label)
    return pipeline
|
|
|
|
def create_val_transforms(
    norm_mean: List[float] = None,
    norm_std: List[float] = None,
    img_size: int = IMG_SIZE,
) -> transforms.Compose:
    """Build the augmentation-free pipeline for validation / calibration / test.

    The pipeline converts the input to PIL, resizes it to a square,
    converts it to a tensor and normalises it; it contains no random
    steps.

    Parameters
    ----------
    norm_mean : list[float]
        Normalisation mean.  Falls back to the fundus defaults when
        omitted or empty.
    norm_std : list[float]
        Normalisation std.  Falls back to the fundus defaults when
        omitted or empty.
    img_size : int
        Output spatial resolution.

    Returns
    -------
    transforms.Compose
    """
    steps = [
        transforms.ToPILImage(),
        transforms.Resize((img_size, img_size)),
        transforms.ToTensor(),
        transforms.Normalize(
            norm_mean if norm_mean else _DEFAULT_MEAN,
            norm_std if norm_std else _DEFAULT_STD,
        ),
    ]
    return transforms.Compose(steps)
|
|
|
|
| |
| |
| |
|
|
def cutmix_data(
    x: torch.Tensor,
    y: torch.Tensor,
    alpha: float = 1.0,
) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, float]:
    """Apply CutMix augmentation to a batch of images.

    A rectangular region is cut from a randomly-permuted copy of the
    batch and pasted onto the original images.  The same rectangle is
    used for every image in the batch.  Labels are mixed proportionally
    to the area ratio.

    This is complementary to MixUp (which blends pixels globally):
    CutMix preserves local structure in both the pasted and background
    regions, forcing the model to attend to all image regions rather
    than relying on a single discriminative patch.

    Parameters
    ----------
    x : torch.Tensor
        Batch of images, shape ``(B, C, H, W)``.  Not modified.
    y : torch.Tensor
        Batch of labels whose first dimension is ``B``.
    alpha : float
        Parameter of the Beta distribution used to sample the mixing
        ratio ``lam``.  ``alpha=1.0`` gives a uniform distribution
        over [0, 1].  ``alpha <= 0`` disables mixing (``lam == 1``).

    Returns
    -------
    mixed_x : torch.Tensor
        Images with CutMix applied, shape ``(B, C, H, W)``.
    y_a : torch.Tensor
        Original labels (background region).
    y_b : torch.Tensor
        Permuted labels (pasted region).
    lam : float
        Effective mixing ratio (fraction of area from *y_a*).
        Use as: ``loss = lam * L(pred, y_a) + (1-lam) * L(pred, y_b)``

    Raises
    ------
    ValueError
        If *x* is not 4-D or *y* does not have one entry per image.
    """
    # Explicit exceptions rather than ``assert`` so the checks survive -O.
    if x.dim() != 4:
        raise ValueError(
            f"Expected 4-D input (B, C, H, W), got {x.dim()}-D"
        )
    batch_size, _, h, w = x.shape
    if y.shape[:1] != (batch_size,):
        raise ValueError(
            f"Expected labels with first dimension {batch_size}, "
            f"got shape {tuple(y.shape)}"
        )

    # Target mixing ratio; the effective value is recomputed below.
    if alpha > 0:
        lam = float(np.random.beta(alpha, alpha))
    else:
        lam = 1.0

    # Pair every image with a random partner from the same batch.
    index = torch.randperm(batch_size, device=x.device)
    y_a = y
    # Labels may live on a different device than the images.
    y_b = y[index.to(y.device)]

    # Rectangle whose area is (1 - lam) of the image, before clipping.
    cut_ratio = math.sqrt(1.0 - lam)
    cut_w = int(w * cut_ratio)
    cut_h = int(h * cut_ratio)

    # Random centre; the box is clipped to the image bounds.
    cx = random.randint(0, w)
    cy = random.randint(0, h)
    x1 = max(0, cx - cut_w // 2)
    y1 = max(0, cy - cut_h // 2)
    x2 = min(w, cx + cut_w // 2)
    y2 = min(h, cy + cut_h // 2)

    mixed_x = x.clone()
    mixed_x[:, :, y1:y2, x1:x2] = x[index, :, y1:y2, x1:x2]

    # Clipping and integer rounding change the pasted area, so report the
    # ratio that was actually applied.
    pasted_area = (x2 - x1) * (y2 - y1)
    lam = 1.0 - pasted_area / (w * h)

    return mixed_x, y_a, y_b, lam
|
|
|
|
| |
| |
| |
|
|
def create_oversampled_dataset(
    df: pd.DataFrame,
    minority_multiplier: int = 5,
    minority_classes: Optional[set] = None,
    label_col: str = "disease_label",
) -> pd.DataFrame:
    """Build a shuffled copy of *df* with minority-class rows repeated.

    Rows whose label is in *minority_classes* appear
    *minority_multiplier* times in the result; all other rows appear
    once.  The result is shuffled with a fixed seed, so repeated calls on
    the same input give the same output, and a per-class summary is
    printed to stdout.  The input frame is left untouched.

    Parameters
    ----------
    df : pd.DataFrame
        Training metadata.  Must contain *label_col*.
    minority_multiplier : int
        Number of times each minority-class row appears in the output
        (5 means the original plus 4 copies).
    minority_classes : set[int] or None
        Class indices considered minority.  Defaults to the module-level
        minority set.
    label_col : str
        Column name for the disease label.

    Returns
    -------
    pd.DataFrame
        Majority rows plus replicated minority rows, shuffled, with a
        fresh 0..N-1 index.
    """
    if minority_classes is None:
        minority_classes = _MINORITY_CLASSES

    is_minority = df[label_col].isin(minority_classes)
    common_rows = df[~is_minority]
    rare_rows = df[is_minority]

    # Stack the minority subset on itself the requested number of times.
    replicated = pd.concat(
        [rare_rows for _ in range(minority_multiplier)],
        ignore_index=True,
    )

    merged = pd.concat([common_rows, replicated], ignore_index=True)
    shuffled = merged.sample(frac=1.0, random_state=42)
    combined = shuffled.reset_index(drop=True)

    # Before/after class counts, reported for every class present in df.
    orig_counts = df[label_col].value_counts().sort_index()
    new_counts = combined[label_col].value_counts().sort_index()
    print(f" Oversampling summary (minority x{minority_multiplier}):")
    for cls_idx in sorted(df[label_col].unique()):
        tag = " [minority]" if cls_idx in minority_classes else ""
        orig = orig_counts.get(cls_idx, 0)
        new = new_counts.get(cls_idx, 0)
        print(f" Class {cls_idx}: {orig:5d} -> {new:5d}{tag}")
    print(f" Total: {len(df)} -> {len(combined)}")

    return combined
|
|
|
|
| |
| |
| |
if __name__ == "__main__":
    # Smoke test: push a random image through every pipeline, run CutMix on
    # a random batch, and oversample a small synthetic DataFrame.
    print("Enhanced Augmentation Pipeline -- sanity check")
    print("=" * 55)

    # Random uint8 "image" standing in for a real fundus photograph.
    dummy_img = np.random.randint(0, 255, (224, 224, 3), dtype=np.uint8)
    aug = ClassAwareAugmentation()

    # Both class-aware pipelines: report shape, dtype and value range.
    for heading, label in (
        ("\nStandard augmentation (Normal, class=0):", 0),
        ("\nStrong augmentation (Glaucoma, class=2):", 2),
    ):
        print(heading)
        out = aug.get_transform(label)(dummy_img)
        print(f" Output shape: {out.shape}, dtype: {out.dtype}")
        print(f" Value range : [{out.min():.3f}, {out.max():.3f}]")

    print("\nFactory function (Cataract, class=3):")
    out = create_train_transforms(class_label=3)(dummy_img)
    print(f" Output shape: {out.shape}")

    print("\nVal transforms:")
    out = create_val_transforms()(dummy_img)
    print(f" Output shape: {out.shape}")

    print("\nCutMix:")
    batch_x = torch.randn(8, 3, 224, 224)
    batch_y = torch.tensor([0, 1, 2, 3, 4, 0, 1, 2])
    mixed_x, y_a, y_b, lam = cutmix_data(batch_x, batch_y, alpha=1.0)
    print(f" mixed_x shape: {mixed_x.shape}")
    print(f" lam : {lam:.4f}")
    print(f" y_a : {y_a.tolist()}")
    print(f" y_b : {y_b.tolist()}")

    print("\nOversampling:")
    # 100 synthetic rows with an imbalanced label distribution.
    rows_per_class = {0: 30, 1: 40, 2: 12, 3: 10, 4: 8}
    labels = [cls for cls, count in rows_per_class.items() for _ in range(count)]
    dummy_df = pd.DataFrame({
        "image_path": [f"img_{i}.png" for i in range(100)],
        "disease_label": labels,
        "severity_label": [0] * 100,
    })
    oversampled = create_oversampled_dataset(dummy_df, minority_multiplier=5)

    print("\nDone.")
|
|