| """ |
| Data loading and preprocessing for medical image classification. |
| |
| Converts medical images to graphs: |
| 1. Divide image into patches (8x8 with stride 6) |
| 2. Extract features per patch (color histogram + texture + spatial position) |
| 3. Build k-NN graph for structure construction |
| 4. Create SimplicialComplex + Hypergraph via build_topohyper_structure() |
| """ |
|
|
| import torch |
| import numpy as np |
| from .structures import build_topohyper_structure |
|
|
|
|
def extract_patch_features(image, patch_size=8, stride=6):
    """
    Extract patch-level features from a medical image.

    Features per patch (38-dim):
    - Color histogram: 24 values. Each of the 3 channels gets an 8-bin
      histogram over [0, 1], normalized to sum to ~1. Single-channel images
      are replicated to 3 channels, so their three histograms are identical.
      For images with more than 3 channels, only the first 3 are used.
    - Texture: 8 values computed from the horizontal (dx) and vertical (dy)
      first differences of the channel-mean ("gray") patch:
      mean(dx), std(dx), mean(dy), std(dy), mean|dx|, mean|dy|,
      RMS(dx), RMS(dy).
    - Spatial position: 6 values (cx, cy, cx^2, cy^2, cx*cy, 1), where
      (cx, cy) is the patch center divided by (W, H).

    Args:
        image: (C, H, W) or (H, W) tensor with C == 1 or C >= 3. Values are
            expected in [0, 1]; the histogram ignores values outside that
            range.
        patch_size: Side length of each square patch. Must be >= 2 so that
            the first differences are non-empty.
        stride: Step between consecutive patch origins. Must be >= 1.

    Returns:
        features: (N_patches, 38) tensor
        positions: (N_patches, 2) tensor of normalized (cx, cy) patch centers

    Raises:
        ValueError: if the image rank or channel count is unsupported, if
            patch_size/stride are invalid, or if the image is smaller than a
            single patch.
    """
    if image.dim() == 2:
        image = image.unsqueeze(0)
    if image.dim() != 3:
        raise ValueError(
            f"expected a (C, H, W) or (H, W) image, got shape {tuple(image.shape)}"
        )
    if patch_size < 2:
        raise ValueError(f"patch_size must be >= 2, got {patch_size}")
    if stride < 1:
        raise ValueError(f"stride must be >= 1, got {stride}")

    C, H, W = image.shape
    if C == 1:
        # Replicate grayscale so the 3-channel feature layout is preserved.
        image = image.expand(3, -1, -1)
    elif C > 3:
        # Keep only the first three channels.
        image = image[:3]
    elif C != 3:
        # C == 2 (or 0) has no sensible mapping onto the 3-channel layout.
        raise ValueError(f"unsupported channel count {C}; expected 1 or >= 3")

    if H < patch_size or W < patch_size:
        # Without this check, torch.stack([]) below fails with an opaque error.
        raise ValueError(
            f"image of size {H}x{W} is smaller than patch_size={patch_size}"
        )

    # Create every new tensor on the input's device so that torch.cat works
    # for non-CPU inputs.
    device = image.device
    rows = []
    centers = []

    for y in range(0, H - patch_size + 1, stride):
        for x in range(0, W - patch_size + 1, stride):
            patch = image[:, y:y + patch_size, x:x + patch_size]

            # Color: per-channel 8-bin histogram over [0, 1], normalized to a
            # distribution. The epsilon avoids 0/0 when every value is out of
            # range.
            hists = []
            for c in range(3):
                hist = torch.histc(patch[c].float(), bins=8, min=0.0, max=1.0)
                hists.append(hist / (hist.sum() + 1e-7))
            color_feat = torch.cat(hists)

            # Texture: statistics of first differences of the gray patch.
            gray = patch.mean(dim=0)
            dx = gray[:, 1:] - gray[:, :-1]
            dy = gray[1:, :] - gray[:-1, :]
            texture_feat = torch.stack([
                dx.mean(), dx.std(), dy.mean(), dy.std(),
                dx.abs().mean(), dy.abs().mean(),
                (dx ** 2).mean().sqrt(), (dy ** 2).mean().sqrt(),
            ]).float()

            # Spatial: normalized patch center plus quadratic terms and a bias.
            cx = (x + patch_size / 2) / W
            cy = (y + patch_size / 2) / H
            spatial_feat = torch.tensor(
                [cx, cy, cx ** 2, cy ** 2, cx * cy, 1.0], device=device
            )

            rows.append(torch.cat([color_feat, texture_feat, spatial_feat]))
            centers.append(torch.tensor([cx, cy], device=device))

    return torch.stack(rows), torch.stack(centers)
|
|
|
|
class MedicalImageGraphDataset:
    """
    Dataset that converts medical images to graphs.

    All graphs are built eagerly in ``__init__``. Each image becomes a graph
    where:
    - Nodes = image patches (features from ``extract_patch_features``)
    - Edges = k-NN connections between patch features
    - Triangles = 3-cliques for simplicial complex
    - Hyperedges = k-NN neighborhoods + triangles for hypergraph
    (Edge, triangle and hyperedge construction is delegated to
    ``build_topohyper_structure``.)

    Each item is a dict with keys 'features', 'sc', 'hg', 'edge_index' and
    'label' (a Python int).

    Args:
        images: (B, C, H, W) tensor of images. Any image whose max exceeds
            1.0 is assumed to be 0-255 and is divided by 255.
        labels: length-B tensor/sequence of scalar class labels
        patch_size: Patch extraction size
        stride: Patch extraction stride
        k: k-NN parameter for structure building
        max_samples: Maximum number of samples to use (None = all)

    Raises:
        ValueError: if, after applying ``max_samples``, images and labels
            differ in length.
    """

    def __init__(self, images, labels, patch_size=8, stride=6, k=6, max_samples=None):
        self.patch_size = patch_size
        self.stride = stride
        self.k = k

        if max_samples is not None:
            images = images[:max_samples]
            labels = labels[:max_samples]

        if len(images) != len(labels):
            # Fail early instead of raising IndexError mid-way or silently
            # ignoring extra labels.
            raise ValueError(
                f"images and labels must have the same length, "
                f"got {len(images)} and {len(labels)}"
            )

        self.graphs = []
        for image, label in zip(images, labels):
            img = image.float()
            # Heuristic: values above 1.0 are taken to mean 8-bit data. The
            # patch histograms expect values in [0, 1].
            if img.max() > 1.0:
                img = img / 255.0

            features, _positions = extract_patch_features(img, patch_size, stride)
            sc, hg, edge_index = build_topohyper_structure(features, k=k)

            self.graphs.append({
                'features': features,
                'sc': sc,
                'hg': hg,
                'edge_index': edge_index,
                # int() accepts 0-dim tensors, numpy scalars and Python
                # numbers alike, so the label is always a plain int.
                'label': int(label),
            })

    def __len__(self):
        return len(self.graphs)

    def __getitem__(self, idx):
        return self.graphs[idx]
|
|
|
|
| def load_medmnist_data(dataset_name='pathmnist', size=64, max_train=None, max_val=None, max_test=None): |
| """ |
| Load MedMNIST dataset and create graph datasets. |
| |
| Args: |
| dataset_name: Name of MedMNIST subset (e.g., 'pathmnist', 'dermamnist') |
| size: Image size (28 or 64) |
| max_train: Max training samples |
| max_val: Max validation samples |
| max_test: Max test samples |
| |
| Returns: |
| train_dataset, val_dataset, test_dataset, num_classes |
| """ |
| import medmnist |
| from medmnist import INFO |
| |
| info = INFO[dataset_name] |
| num_classes = len(info['label']) |
| DataClass = getattr(medmnist, info['python_class']) |
| |
| train_data = DataClass(split='train', download=True, size=size) |
| val_data = DataClass(split='val', download=True, size=size) |
| test_data = DataClass(split='test', download=True, size=size) |
| |
| def to_tensors(data): |
| images = torch.tensor(data.imgs).float() |
| if images.dim() == 3: |
| images = images.unsqueeze(1) |
| elif images.dim() == 4 and images.shape[-1] in [1, 3]: |
| images = images.permute(0, 3, 1, 2) |
| images = images / 255.0 |
| labels = torch.tensor(data.labels).squeeze() |
| return images, labels |
| |
| train_imgs, train_labels = to_tensors(train_data) |
| val_imgs, val_labels = to_tensors(val_data) |
| test_imgs, test_labels = to_tensors(test_data) |
| |
| print(f"Dataset: {dataset_name} ({num_classes} classes)") |
| print(f" Train: {len(train_imgs)}, Val: {len(val_imgs)}, Test: {len(test_imgs)}") |
| print(f" Image shape: {train_imgs.shape[1:]}") |
| |
| train_ds = MedicalImageGraphDataset(train_imgs, train_labels, max_samples=max_train) |
| val_ds = MedicalImageGraphDataset(val_imgs, val_labels, max_samples=max_val) |
| test_ds = MedicalImageGraphDataset(test_imgs, test_labels, max_samples=max_test) |
| |
| return train_ds, val_ds, test_ds, num_classes |
|
|