Spaces:
Runtime error
Runtime error
| import torch.utils.data as data | |
| import torch | |
| from torch import nn | |
| from pathlib import Path | |
| from torchvision import transforms as T | |
| import pandas as pd | |
| from PIL import Image | |
| from medical_diffusion.data.augmentation.augmentations_2d import Normalize, ToTensor16bit | |
class SimpleDataset2D(data.Dataset):
    """Generic 2D image dataset reading image files located below a root folder.

    Items are addressed by paths relative to ``path_root``. When no explicit
    ``item_pointers`` are given, ``path_root`` is crawled recursively for files
    ending in ``.{crawler_ext}``.

    ``__getitem__`` returns ``{'uid': <file stem>, 'source': <transformed image>}``.
    """

    def __init__(
        self,
        path_root,
        item_pointers=None,
        crawler_ext='tif',  # other options are ['jpg', 'jpeg', 'png', 'tiff'],
        transform=None,
        image_resize=None,
        augment_horizontal_flip=False,
        augment_vertical_flip=False,
        image_crop=None,
    ):
        """
        Args:
            path_root: Root directory; item paths are resolved relative to it.
            item_pointers: Optional sequence of relative item paths (``str`` or
                ``Path``). ``None`` or an empty sequence triggers a crawl of
                ``path_root``.
            crawler_ext: File extension (without the dot) searched by the crawler.
            transform: Callable applied to every loaded image. If ``None``, a
                default pipeline is built from the options below.
            image_resize: Passed to ``T.Resize`` when not ``None``.
            augment_horizontal_flip: Adds ``T.RandomHorizontalFlip`` when true.
            augment_vertical_flip: Adds ``T.RandomVerticalFlip`` when true.
            image_crop: Passed to ``T.CenterCrop`` when not ``None``.
        """
        super().__init__()
        self.path_root = Path(path_root)
        self.crawler_ext = crawler_ext

        # A ``None`` default avoids the shared-mutable-default pitfall. ``len()`` (rather than
        # truthiness) keeps array-like pointer containers such as numpy arrays working.
        if item_pointers is not None and len(item_pointers) > 0:
            self.item_pointers = item_pointers
        else:
            self.item_pointers = self.run_item_crawler(self.path_root, self.crawler_ext)

        if transform is None:
            self.transform = T.Compose([
                T.Resize(image_resize) if image_resize is not None else nn.Identity(),
                T.RandomHorizontalFlip() if augment_horizontal_flip else nn.Identity(),
                T.RandomVerticalFlip() if augment_vertical_flip else nn.Identity(),
                T.CenterCrop(image_crop) if image_crop is not None else nn.Identity(),
                T.ToTensor(),
                # Alternatives kept from the original pipeline (e.g. for 16-bit inputs):
                # T.Lambda(lambda x: torch.cat([x]*3) if x.shape[0]==1 else x),
                # ToTensor16bit(),
                # Normalize(), # [0, 1.0]
                # T.ConvertImageDtype(torch.float),
                # WARNING: mean and std are not the target values but the values to subtract
                # and divide by: [0, 1] -> ([0, 1] - 0.5) / 0.5 -> [-1, 1]
                T.Normalize(mean=0.5, std=0.5),
            ])
        else:
            self.transform = transform

    def __len__(self):
        """Return the number of item pointers."""
        return len(self.item_pointers)

    def __getitem__(self, index):
        """Load, transform and return the item at ``index``.

        Returns:
            dict with ``'uid'`` (stem of the relative path) and ``'source'``
            (the transformed image).
        """
        # Coerce to Path so user-supplied string pointers support ``.stem`` as well.
        rel_path_item = Path(self.item_pointers[index])
        path_item = self.path_root / rel_path_item
        img = self.load_item(path_item)
        return {'uid': rel_path_item.stem, 'source': self.transform(img)}

    def load_item(self, path_item):
        """Open the image at ``path_item`` and return it converted to RGB."""
        # NOTE: Only cv2 supports 16-bit RGB images:
        #   cv2.imread(str(path_item), cv2.IMREAD_UNCHANGED)
        # The context manager releases the file handle; convert() returns a loaded copy,
        # so the result stays valid after the file is closed.
        with Image.open(path_item) as img:
            return img.convert('RGB')

    @classmethod
    def run_item_crawler(cls, path_root, extension, **kwargs):
        """Return paths relative to ``path_root`` of every ``*.{extension}`` file below it."""
        root = Path(path_root)
        return [path.relative_to(root) for path in root.rglob(f'*.{extension}')]

    def get_weights(self):
        """Return list of class-weights for WeightedSampling (``None``: no weighting)."""
        return None
class AIROGSDataset(SimpleDataset2D):
    """Dataset driven by ``<path_root.parent>/train_labels.csv`` (indexed by ``challenge_id``).

    Images are read from ``<path_root>/<challenge_id>.jpg``; the ``class`` column
    is mapped to an integer target via ``LABEL_TO_INT``.
    """

    # 'class' value -> integer target. Original author's counts: RG = 3270, NRG = 98172.
    LABEL_TO_INT = {'NRG': 0, 'RG': 1}

    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
        self.labels = pd.read_csv(self.path_root.parent / 'train_labels.csv', index_col='challenge_id')

    def __len__(self):
        """Return the number of rows in the label table."""
        return len(self.labels)

    def __getitem__(self, index):
        """Return ``{'source': transformed image, 'target': int label}`` for row ``index``."""
        uid = self.labels.index[index]
        path_item = self.path_root / f'{uid}.jpg'
        img = self.load_item(path_item)
        # Positional lookup of the same row the uid came from.
        target = self.LABEL_TO_INT[self.labels['class'].iloc[index]]
        # return {'uid':uid, 'source': self.transform(img), 'target':target}
        return {'source': self.transform(img), 'target': target}

    def get_weights(self):
        """Return one inverse-class-frequency weight per sample, in dataset order."""
        classes = self.labels['class']
        # e.g. {'NRG': 1.03, 'RG': 31.02}
        weight_per_class = (1 / classes.value_counts(normalize=True)).to_dict()
        # Iterate the column directly instead of materialising one row Series per sample;
        # an unseen value (e.g. NaN) still raises KeyError as before.
        return [weight_per_class[target] for target in classes]

    @classmethod
    def run_item_crawler(cls, path_root, extension, **kwargs):
        """Overwrite to speed up as paths are determined by .csv file anyway"""
        return []
class MSIvsMSS_Dataset(SimpleDataset2D):
    """Folder-labelled MSI vs. MSS dataset (https://doi.org/10.5281/zenodo.2530835).

    The target is derived from the name of each image's parent folder via
    ``LABEL_TO_INT``.
    """

    LABEL_TO_INT = {'MSIMUT': 0, 'MSS': 1}

    def __getitem__(self, index):
        """Return ``{'uid', 'source', 'target'}`` for the item at ``index``."""
        # Coerce to Path so string pointers support ``.stem`` as well.
        rel_path_item = Path(self.item_pointers[index])
        path_item = self.path_root / rel_path_item
        img = self.load_item(path_item)
        uid = rel_path_item.stem
        target = self.LABEL_TO_INT[path_item.parent.name]
        return {'uid': uid, 'source': self.transform(img), 'target': target}
class MSIvsMSS_2_Dataset(SimpleDataset2D):
    """Folder-labelled MSI-H vs. non-MSI-H dataset (https://doi.org/10.5281/zenodo.3832231).

    The target is derived from the name of each image's parent folder via
    ``LABEL_TO_INT``.
    """

    # patients with MSI-H = MSIH; patients with MSI-L and MSS = NonMSIH
    LABEL_TO_INT = {'MSIH': 0, 'nonMSIH': 1}

    def __getitem__(self, index):
        """Return ``{'source', 'target'}`` for the item at ``index``."""
        # Coerce to Path so string pointers are handled like crawled Path pointers.
        rel_path_item = Path(self.item_pointers[index])
        path_item = self.path_root / rel_path_item
        img = self.load_item(path_item)
        target = self.LABEL_TO_INT[path_item.parent.name]
        # To also return the id: {'uid': rel_path_item.stem, 'source': ..., 'target': ...}
        return {'source': self.transform(img), 'target': target}
class CheXpert_Dataset(SimpleDataset2D):
    """CheXpert frontal views with the 'Cardiomegaly' column as target.

    ``path_root.name`` selects the label file ``<path_root.parent>/<name>.csv``;
    only rows with ``Frontal/Lateral == 'Frontal'`` are kept.
    Target mapping: -1 (uncertain) -> 0, 0 (negative) -> 1, 1 (positive) -> 2,
    NA (not reported) -> 3.
    """

    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
        mode = self.path_root.name
        labels = pd.read_csv(self.path_root.parent / f'{mode}.csv', index_col='Path')
        self.labels = labels.loc[labels['Frontal/Lateral'] == 'Frontal'].copy()
        # NOTE(review): drops the first 20 characters of every CSV path, presumably a
        # dataset-root prefix such as 'CheXpert-v1.0-small/' -- confirm against the CSV.
        self.labels.index = self.labels.index.str[20:]
        # Affects 1 case, must be "female" to match stats in publication
        self.labels.loc[self.labels['Sex'] == 'Unknown', 'Sex'] = 'Female'
        # TODO: Find better solution -- fills NaN in *every* column with 2 (NA label -> target 3).
        self.labels.fillna(2, inplace=True)
        str_2_int = {'Sex': {'Male': 0, 'Female': 1}, 'Frontal/Lateral': {'Frontal': 0, 'Lateral': 1}, 'AP/PA': {'AP': 0, 'PA': 1}}
        self.labels.replace(str_2_int, inplace=True)

    def __len__(self):
        """Return the number of (frontal) label rows."""
        return len(self.labels)

    def __getitem__(self, index):
        """Return ``{'uid', 'source', 'target'}``; ``target`` is a long tensor in [0, 3]."""
        rel_path_item = self.labels.index[index]
        path_item = self.path_root / rel_path_item
        img = self.load_item(path_item)
        uid = str(rel_path_item)
        # Positional lookup of the same row: unambiguous even if a path appeared twice.
        cardiomegaly = self.labels['Cardiomegaly'].iloc[index]
        target = torch.tensor(cardiomegaly + 1, dtype=torch.long)
        return {'uid': uid, 'source': self.transform(img), 'target': target}

    @classmethod
    def run_item_crawler(cls, path_root, extension, **kwargs):
        """Overwrite to speed up as paths are determined by .csv file anyway"""
        return []
class CheXpert_2_Dataset(SimpleDataset2D):
    """CheXpert variant reading labels from ``<path_root>/labels/`` and images from ``<path_root>/data/``.

    ``labels/cheXPert_label.csv`` (indexed by ``Path`` and ``Image Index``) is
    restricted to ``fold == 'train'``; the frontal 'Cardiomegaly' column of
    ``labels/train.csv`` is joined in as ``Cardiomegaly_true`` (uncertain/NA -> 2).
    """

    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
        # Note: 1 and -1 (uncertain) cases count as positives (1), 0 and NA count as negatives (0)
        labels = pd.read_csv(self.path_root / 'labels/cheXPert_label.csv', index_col=['Path', 'Image Index'])
        labels = labels.loc[labels['fold'] == 'train'].copy()
        labels = labels.drop(labels='fold', axis=1)

        labels2 = pd.read_csv(self.path_root / 'labels/train.csv', index_col='Path')
        labels2 = labels2.loc[labels2['Frontal/Lateral'] == 'Frontal'].copy()
        labels2 = labels2[['Cardiomegaly', ]].copy()
        labels2[(labels2 < 0) | labels2.isna()] = 2  # 0 = Negative, 1 = Positive, 2 = Uncertain

        # 'Cardiomegaly' already exists in ``labels``, so the joined column becomes 'Cardiomegaly_true'.
        labels = labels.join(labels2['Cardiomegaly'], on=["Path", ], rsuffix='_true')
        # labels = labels[labels['Cardiomegaly_true']!=2]
        self.labels = labels

    def __len__(self):
        """Return the number of label rows."""
        return len(self.labels)

    def __getitem__(self, index):
        """Return ``{'source': transformed image, 'target': int Cardiomegaly label}``."""
        _, image_index = self.labels.index[index]
        path_item = self.path_root / 'data' / f'{image_index:06}.png'
        img = self.load_item(path_item)
        # Positional lookup of the same row; avoids a MultiIndex label lookup per item.
        target = int(self.labels['Cardiomegaly'].iloc[index])
        # To also return the id: {'uid': image_index, 'source': ..., 'target': ...}
        return {'source': self.transform(img), 'target': target}

    @classmethod
    def run_item_crawler(cls, path_root, extension, **kwargs):
        """Overwrite to speed up as paths are determined by .csv file anyway"""
        return []

    def get_weights(self):
        """Return one inverse-class-frequency weight per sample, in dataset order."""
        targets = self.labels['Cardiomegaly']
        # e.g. {2.0: 1.2, 1.0: 8.2, 0.0: 24.3}
        weight_per_class = (1 / targets.value_counts(normalize=True)).to_dict()
        # Iterate the column once instead of one MultiIndex .loc lookup per sample;
        # an unseen value (e.g. NaN) still raises KeyError as before.
        return [weight_per_class[target] for target in targets]