""" Data loading and preprocessing for medical image classification. Converts medical images to graphs: 1. Divide image into patches (8x8 with stride 6) 2. Extract features per patch (color histogram + texture + spatial position) 3. Build k-NN graph for structure construction 4. Create SimplicialComplex + Hypergraph via build_topohyper_structure() """ import torch import numpy as np from .structures import build_topohyper_structure def extract_patch_features(image, patch_size=8, stride=6): """ Extract patch-level features from a medical image. Features per patch (38-dim): - Color histogram: 24 bins (8 per channel for RGB, or 24 for grayscale mapped to 3) - Texture: 8 values (mean, std of gradients in x/y for 2 scales) - Spatial position: 6 values (normalized x, y, x^2, y^2, xy, 1) Args: image: (C, H, W) tensor, values in [0, 1] patch_size: Size of each patch stride: Stride between patches Returns: features: (N_patches, 38) tensor positions: (N_patches, 2) center positions """ if image.dim() == 2: image = image.unsqueeze(0) C, H, W = image.shape # Ensure 3 channels if C == 1: image = image.expand(3, -1, -1) elif C > 3: image = image[:3] patches = [] positions = [] for y in range(0, H - patch_size + 1, stride): for x in range(0, W - patch_size + 1, stride): patch = image[:, y:y + patch_size, x:x + patch_size] # Color histogram (8 bins per channel) color_feats = [] for c in range(3): hist = torch.histc(patch[c].float(), bins=8, min=0.0, max=1.0) hist = hist / (hist.sum() + 1e-7) color_feats.append(hist) color_feat = torch.cat(color_feats) # 24-dim # Texture features (gradients) gray = patch.mean(dim=0) # (H, W) dx = gray[:, 1:] - gray[:, :-1] dy = gray[1:, :] - gray[:-1, :] texture_feat = torch.tensor([ dx.mean(), dx.std(), dy.mean(), dy.std(), dx.abs().mean(), dy.abs().mean(), (dx ** 2).mean().sqrt(), (dy ** 2).mean().sqrt() ]) # Spatial features cx = (x + patch_size / 2) / W cy = (y + patch_size / 2) / H spatial_feat = torch.tensor([cx, cy, cx**2, cy**2, cx * cy, 1.0]) feat = torch.cat([color_feat, texture_feat, spatial_feat]) patches.append(feat) positions.append(torch.tensor([cx, cy])) features = torch.stack(patches) positions = torch.stack(positions) return features, positions class MedicalImageGraphDataset: """ Dataset that converts medical images to graphs. Each image becomes a graph where: - Nodes = image patches - Edges = k-NN connections between patch features - Triangles = 3-cliques for simplicial complex - Hyperedges = k-NN neighborhoods + triangles for hypergraph Args: images: (B, C, H, W) tensor of images labels: (B,) tensor of labels patch_size: Patch extraction size stride: Patch extraction stride k: k-NN parameter for structure building max_samples: Maximum number of samples to use (None = all) """ def __init__(self, images, labels, patch_size=8, stride=6, k=6, max_samples=None): self.patch_size = patch_size self.stride = stride self.k = k if max_samples is not None: images = images[:max_samples] labels = labels[:max_samples] self.graphs = [] for i in range(len(images)): img = images[i].float() if img.max() > 1.0: img = img / 255.0 features, positions = extract_patch_features(img, patch_size, stride) sc, hg, edge_index = build_topohyper_structure(features, k=k) self.graphs.append({ 'features': features, 'sc': sc, 'hg': hg, 'edge_index': edge_index, 'label': labels[i].item() if isinstance(labels[i], torch.Tensor) else int(labels[i]) }) def __len__(self): return len(self.graphs) def __getitem__(self, idx): return self.graphs[idx] def load_medmnist_data(dataset_name='pathmnist', size=64, max_train=None, max_val=None, max_test=None): """ Load MedMNIST dataset and create graph datasets. Args: dataset_name: Name of MedMNIST subset (e.g., 'pathmnist', 'dermamnist') size: Image size (28 or 64) max_train: Max training samples max_val: Max validation samples max_test: Max test samples Returns: train_dataset, val_dataset, test_dataset, num_classes """ import medmnist from medmnist import INFO info = INFO[dataset_name] num_classes = len(info['label']) DataClass = getattr(medmnist, info['python_class']) train_data = DataClass(split='train', download=True, size=size) val_data = DataClass(split='val', download=True, size=size) test_data = DataClass(split='test', download=True, size=size) def to_tensors(data): images = torch.tensor(data.imgs).float() if images.dim() == 3: # (N, H, W) -> (N, 1, H, W) images = images.unsqueeze(1) elif images.dim() == 4 and images.shape[-1] in [1, 3]: # (N, H, W, C) -> (N, C, H, W) images = images.permute(0, 3, 1, 2) images = images / 255.0 labels = torch.tensor(data.labels).squeeze() return images, labels train_imgs, train_labels = to_tensors(train_data) val_imgs, val_labels = to_tensors(val_data) test_imgs, test_labels = to_tensors(test_data) print(f"Dataset: {dataset_name} ({num_classes} classes)") print(f" Train: {len(train_imgs)}, Val: {len(val_imgs)}, Test: {len(test_imgs)}") print(f" Image shape: {train_imgs.shape[1:]}") train_ds = MedicalImageGraphDataset(train_imgs, train_labels, max_samples=max_train) val_ds = MedicalImageGraphDataset(val_imgs, val_labels, max_samples=max_val) test_ds = MedicalImageGraphDataset(test_imgs, test_labels, max_samples=max_test) return train_ds, val_ds, test_ds, num_classes