# Provenance (from the hosting page): commit 816502d by chaoshengt
# Commit message: "Add data module: medical image to graph conversion pipeline"
# File size at that commit: 6.51 kB
"""
Data loading and preprocessing for medical image classification.
Converts medical images to graphs:
1. Divide image into patches (8x8 with stride 6)
2. Extract features per patch (color histogram + texture + spatial position)
3. Build k-NN graph for structure construction
4. Create SimplicialComplex + Hypergraph via build_topohyper_structure()
"""
import torch
import numpy as np
from .structures import build_topohyper_structure
def extract_patch_features(image, patch_size=8, stride=6):
    """
    Extract patch-level features from a medical image.

    Features per patch (38-dim):
    - Color histogram: 24 values (8 bins over [0, 1] for each of 3 channels,
      each channel's histogram normalized to sum to ~1; a single-channel
      image is replicated to 3 channels first)
    - Texture: 8 values from the horizontal (dx) and vertical (dy) finite
      differences of the channel-averaged patch, in this order:
      mean(dx), std(dx), mean(dy), std(dy), mean|dx|, mean|dy|,
      rms(dx), rms(dy)
    - Spatial position: 6 values (cx, cy, cx^2, cy^2, cx*cy, 1), where
      (cx, cy) is the patch center divided by image width / height

    Args:
        image: (C, H, W) or (H, W) tensor, values in [0, 1]. Any dtype is
            accepted; the image is cast to float32 before processing.
            Channels beyond the third are dropped.
        patch_size: Size of each (square) patch
        stride: Stride between patches

    Returns:
        features: (N_patches, 38) float32 tensor
        positions: (N_patches, 2) float32 tensor of normalized (cx, cy)
            patch centers

        Both tensors live on the same device as ``image``. If the image is
        smaller than ``patch_size`` in either dimension, N_patches is 0 and
        empty tensors of shape (0, 38) and (0, 2) are returned.

    Raises:
        ValueError: If the image has 0 or 2 channels (cannot be mapped to 3).
    """
    feature_dim = 24 + 8 + 6

    if image.dim() == 2:
        image = image.unsqueeze(0)
    # Cast once up front: mean()/std() below are undefined for integer dtypes.
    image = image.float()
    C, H, W = image.shape

    # Ensure exactly 3 channels
    if C == 1:
        image = image.expand(3, -1, -1)
    elif C >= 3:
        image = image[:3]
    else:
        raise ValueError(
            f"extract_patch_features expects 1 or >= 3 channels, got {C}"
        )

    device = image.device
    dtype = image.dtype

    features = []
    centers = []
    for y in range(0, H - patch_size + 1, stride):
        cy = (y + patch_size / 2) / H
        for x in range(0, W - patch_size + 1, stride):
            cx = (x + patch_size / 2) / W
            patch = image[:, y:y + patch_size, x:x + patch_size]

            # Color histogram (8 bins per channel), normalized per channel.
            color_feats = []
            for c in range(3):
                hist = torch.histc(patch[c], bins=8, min=0.0, max=1.0)
                color_feats.append(hist / (hist.sum() + 1e-7))
            color_feat = torch.cat(color_feats)  # 24-dim

            # Texture features from first-order finite differences.
            gray = patch.mean(dim=0)  # (patch_size, patch_size)
            dx = gray[:, 1:] - gray[:, :-1]
            dy = gray[1:, :] - gray[:-1, :]
            # torch.stack keeps the values on the image's device/dtype
            # (torch.tensor on a list of 0-dim tensors does not).
            texture_feat = torch.stack([
                dx.mean(), dx.std(), dy.mean(), dy.std(),
                dx.abs().mean(), dy.abs().mean(),
                (dx ** 2).mean().sqrt(), (dy ** 2).mean().sqrt(),
            ])

            # Spatial features, created on the same device so cat() works.
            spatial_feat = torch.tensor(
                [cx, cy, cx ** 2, cy ** 2, cx * cy, 1.0],
                dtype=dtype, device=device,
            )

            features.append(torch.cat([color_feat, texture_feat, spatial_feat]))
            centers.append((cx, cy))

    if not features:
        # Image smaller than one patch: return well-shaped empty outputs
        # instead of letting torch.stack fail on an empty list.
        return image.new_zeros((0, feature_dim)), image.new_zeros((0, 2))

    positions = torch.tensor(centers, dtype=dtype, device=device)
    return torch.stack(features), positions
class MedicalImageGraphDataset:
    """
    In-memory dataset of graphs built from medical images.

    Every image is converted eagerly, at construction time, into one dict
    stored in ``self.graphs`` with the keys:
    - 'features': per-patch feature tensor from extract_patch_features()
    - 'sc', 'hg', 'edge_index': the three values returned by
      build_topohyper_structure(features, k=k)
    - 'label': the sample's label as a Python scalar

    Args:
        images: (B, C, H, W) tensor of images
        labels: (B,) tensor of labels
        patch_size: Patch extraction size
        stride: Patch extraction stride
        k: k-NN parameter for structure building
        max_samples: Maximum number of samples to use (None = all)
    """

    def __init__(self, images, labels, patch_size=8, stride=6, k=6, max_samples=None):
        self.patch_size = patch_size
        self.stride = stride
        self.k = k

        if max_samples is not None:
            images, labels = images[:max_samples], labels[:max_samples]

        self.graphs = []
        for index, raw_image in enumerate(images):
            self.graphs.append(self._build_graph(raw_image, labels, index))

    def _build_graph(self, raw_image, labels, index):
        """Turn one image (plus its label at ``labels[index]``) into a graph dict."""
        pixels = raw_image.float()
        # Values above 1.0 are taken to mean a 0-255 intensity range.
        if pixels.max() > 1.0:
            pixels = pixels / 255.0

        node_features, _centers = extract_patch_features(
            pixels, self.patch_size, self.stride
        )
        complex_, hypergraph, edges = build_topohyper_structure(node_features, k=self.k)

        raw_label = labels[index]
        if isinstance(raw_label, torch.Tensor):
            label_value = raw_label.item()
        else:
            label_value = int(raw_label)

        return {
            'features': node_features,
            'sc': complex_,
            'hg': hypergraph,
            'edge_index': edges,
            'label': label_value,
        }

    def __len__(self):
        """Number of graphs held by the dataset."""
        return len(self.graphs)

    def __getitem__(self, idx):
        """Return the graph dict stored at position ``idx``."""
        return self.graphs[idx]
def _medmnist_split_to_tensors(data, limit=None):
    """Convert one MedMNIST split object into ``(images, labels)`` tensors.

    Reads ``data.imgs`` and ``data.labels``. When ``limit`` is given, both are
    sliced to ``[:limit]`` *before* the float conversion so that only the
    requested samples are materialized as float32.
    """
    raw_images, raw_labels = data.imgs, data.labels
    if limit is not None:
        raw_images = raw_images[:limit]
        raw_labels = raw_labels[:limit]

    images = torch.tensor(raw_images).float()
    if images.dim() == 3:  # (N, H, W) -> (N, 1, H, W)
        images = images.unsqueeze(1)
    elif images.dim() == 4 and images.shape[-1] in (1, 3):  # (N, H, W, C) -> (N, C, H, W)
        images = images.permute(0, 3, 1, 2)
    # NOTE(review): assumes raw pixel values are in 0-255 — confirm for the subset used.
    images = images / 255.0

    labels = torch.tensor(raw_labels)
    # Drop only a trailing singleton axis, e.g. (N, 1) -> (N,). A bare
    # squeeze() would also collapse N == 1 into a 0-dim tensor, which can
    # no longer be sliced or indexed per sample.
    if labels.dim() > 1 and labels.shape[-1] == 1:
        labels = labels.squeeze(-1)
    return images, labels


def load_medmnist_data(dataset_name='pathmnist', size=64, max_train=None, max_val=None,
                       max_test=None, patch_size=8, stride=6, k=6):
    """
    Load a MedMNIST dataset and create graph datasets for its three splits.

    The splits are downloaded if needed (``download=True``), converted to
    float tensors scaled by 1/255, and each wrapped in a
    MedicalImageGraphDataset, which builds all graphs eagerly.

    Args:
        dataset_name: Name of MedMNIST subset (e.g., 'pathmnist', 'dermamnist');
            must be a key of ``medmnist.INFO``
        size: Image size passed to the MedMNIST dataset class (e.g., 28 or 64)
        max_train: Max training samples (None = all)
        max_val: Max validation samples (None = all)
        max_test: Max test samples (None = all)
        patch_size: Patch extraction size forwarded to the graph datasets
        stride: Patch extraction stride forwarded to the graph datasets
        k: k-NN parameter forwarded to the graph datasets

    Returns:
        train_dataset, val_dataset, test_dataset, num_classes

    Raises:
        KeyError: If ``dataset_name`` is not present in ``medmnist.INFO``.
    """
    # Imported lazily so the rest of the module works without medmnist installed.
    import medmnist
    from medmnist import INFO

    info = INFO[dataset_name]
    num_classes = len(info['label'])
    DataClass = getattr(medmnist, info['python_class'])

    train_data = DataClass(split='train', download=True, size=size)
    val_data = DataClass(split='val', download=True, size=size)
    test_data = DataClass(split='test', download=True, size=size)

    # Truncation happens inside the conversion, so a small max_* never forces
    # the whole split through the uint8 -> float32 conversion.
    train_imgs, train_labels = _medmnist_split_to_tensors(train_data, max_train)
    val_imgs, val_labels = _medmnist_split_to_tensors(val_data, max_val)
    test_imgs, test_labels = _medmnist_split_to_tensors(test_data, max_test)

    # Report the full split sizes (before any max_* truncation).
    print(f"Dataset: {dataset_name} ({num_classes} classes)")
    print(f" Train: {len(train_data.imgs)}, Val: {len(val_data.imgs)}, Test: {len(test_data.imgs)}")
    print(f" Image shape: {train_imgs.shape[1:]}")

    # The tensors are already truncated, so max_samples is not passed again
    # (a second slice would double-apply a negative limit).
    train_ds = MedicalImageGraphDataset(train_imgs, train_labels,
                                        patch_size=patch_size, stride=stride, k=k)
    val_ds = MedicalImageGraphDataset(val_imgs, val_labels,
                                      patch_size=patch_size, stride=stride, k=k)
    test_ds = MedicalImageGraphDataset(test_imgs, test_labels,
                                       patch_size=patch_size, stride=stride, k=k)
    return train_ds, val_ds, test_ds, num_classes