Spaces:
Runtime error
Runtime error
| from pathlib import Path | |
| import logging | |
| from datetime import datetime | |
| from tqdm import tqdm | |
| import numpy as np | |
| import torch | |
| import torchvision.transforms.functional as tF | |
| from torch.utils.data.dataloader import DataLoader | |
| from torchvision.datasets import ImageFolder | |
| from torch.utils.data import TensorDataset, Subset | |
| from torchmetrics.image.lpip import LearnedPerceptualImagePatchSimilarity as LPIPS | |
| from torchmetrics.functional import multiscale_structural_similarity_index_measure as mmssim | |
| from medical_diffusion.models.embedders.latent_embedders import VAE | |
# ----------------Settings --------------
# Evaluation knobs: images per batch, an optional cap on the number of images
# (None = use all), and an optional single class to restrict to (None = all classes).
batch_size = 100
max_samples = None  # set to None for all
target_class = None  # None for no specific class

# Output folder for the metric log; alternative experiment folders kept for reference.
# path_out = Path.cwd()/'results'/'MSIvsMSS_2'/'metrics'
# path_out = Path.cwd()/'results'/'AIROGS'/'metrics'
path_out = Path.cwd().joinpath('results', 'CheXpert', 'metrics')
path_out.mkdir(exist_ok=True, parents=True)

# Run on the GPU whenever torch can see one.
if torch.cuda.is_available():
    device = 'cuda'
else:
    device = 'cpu'
# ----------------- Logging -----------
# Log to the console and to a timestamped file inside path_out.
current_time = datetime.now().strftime("%Y_%m_%d_%H%M%S")
logger = logging.getLogger()
logging.basicConfig(level=logging.INFO)
# basicConfig() does nothing if the root logger already has handlers (e.g. installed
# by an imported library), which would leave the level at WARNING and drop every
# logger.info() below. Set the level explicitly so INFO records always reach the file.
logger.setLevel(logging.INFO)
# The summary messages contain a non-ASCII '±'; pin the encoding instead of relying
# on the platform's locale default.
logger.addHandler(logging.FileHandler(path_out/f'metrics_{current_time}.log', 'w', encoding='utf-8'))
# -------------- Helpers ---------------------
def pil2torch(x):
    """Convert a PIL image (or array-like) to a channel-first tensor without rescaling.

    In contrast to ToTensor(), this does not cast 0-255 to 0-1, so uint8 is preserved
    (required later, where batches are divided by 255 explicitly).

    Defined as a module-level function rather than a lambda so it can be pickled,
    which DataLoader worker processes need under the 'spawn' start method.

    Args:
        x: a PIL image or anything ``np.array`` accepts, laid out as (H, W, C) or,
           for single-channel data, (H, W).

    Returns:
        A tensor of shape (C, H, W) with the input's dtype; 2-D input gives (1, H, W).
    """
    arr = np.array(x)
    if arr.ndim == 2:
        # Single-channel data has no trailing channel axis; moving the last axis to
        # the front would transpose H and W instead of adding a channel dimension.
        arr = arr[..., None]
    return torch.as_tensor(arr).moveaxis(-1, 0)
# ---------------- Dataset/Dataloader ----------------
# Reference ("real") images that the model will reconstruct.
# NOTE(review): the active dataset is the Kather MSI/MSS pathology set, while path_out
# and the VAE checkpoint both refer to chest/CheXpert runs -- confirm this pairing is
# intended before trusting the logged metrics.
ds_real = ImageFolder('/mnt/hdd/datasets/pathology/kather_msi_mss_2/train/', transform=pil2torch)
# ds_real = ImageFolder('/mnt/hdd/datasets/eye/AIROGS/data_256x256_ref/', transform=pil2torch)
# ds_real = ImageFolder('/mnt/hdd/datasets/chest/CheXpert/ChecXpert-v10/reference_test/', transform=pil2torch)

# ---------- Limit Sample Size
# slice(None) keeps everything, so max_samples=None means "all samples".
ds_real.samples = ds_real.samples[slice(max_samples)]
# ImageFolder also exposes the labels as `targets` and the sample list as `imgs`;
# refresh both so they describe the same truncated subset as `samples`.
ds_real.targets = [label for _, label in ds_real.samples]
ds_real.imgs = ds_real.samples

# --------- Select specific class ------------
if target_class is not None:
    if target_class not in ds_real.class_to_idx:
        raise ValueError(
            f"Unknown target_class {target_class!r}; expected one of {sorted(ds_real.class_to_idx)}"
        )
    target_idx = ds_real.class_to_idx[target_class]
    ds_real = Subset(ds_real, [i for i, (_, label) in enumerate(ds_real.samples) if label == target_idx])

if len(ds_real) == 0:
    # Fail here with a clear message instead of with a less obvious error after the loop.
    raise ValueError("No samples selected; check the dataset path, max_samples and target_class.")

dm_real = DataLoader(ds_real, batch_size=batch_size, num_workers=8, shuffle=False, drop_last=False)
logger.info(f"Samples Real: {len(ds_real)}")
# --------------- Load Model ------------------
# Autoencoder whose outputs are scored against the real images below.
model = VAE.load_from_checkpoint('runs/2022_12_12_133315_chest_vaegan/last_vae.ckpt')
model.to(device)
# Alternative: the Stable Diffusion VAE from diffusers.
# from diffusers import StableDiffusionPipeline
# with open('auth_token.txt', 'r') as file:
#     auth_token = file.read()
# pipe = StableDiffusionPipeline.from_pretrained("CompVis/stable-diffusion-v1-4", torch_dtype=torch.float32, use_auth_token=auth_token)
# model = pipe.vae
# model.to(device)

# Switch to inference mode (affects layers such as dropout and batch-norm). Lightning's
# load_from_checkpoint does not do this itself, so without it the evaluated outputs --
# and therefore the metrics -- could reflect training-mode behaviour. Placed after the
# alternative above so it applies whichever model is selected.
model.eval()
# ------------- Init Metrics ----------------------
# LPIPS state accumulates across update() calls; compute() later returns the aggregate.
calc_lpips = LPIPS().to(device)

# --------------- Start Calculation -----------------
# Per-image MS-SSIM and MSE values, collected as scalar tensors and stacked afterwards.
mmssim_list, mse_list = [], []
for batch in tqdm(dm_real):
    # Element 0 of each ImageFolder batch holds the images (element 1 the labels).
    real = tF.normalize(batch[0].to(device) / 255, 0.5, 0.5)  # [0, 255] -> [-1, 1]
    with torch.no_grad():
        # Element 0 of the model output is taken as the generated image (presumably
        # the reconstruction), clipped to the input's [-1, 1] range.
        fake = model(real)[0].clamp(-1, 1)

    # -------------- LPIP -------------------
    calc_lpips.update(real, fake)  # expect input to be [-1, 1]

    # -------------- MS-SSIM + MSE -------------------
    # Scored pair by pair: with data_range unset, torchmetrics derives it from the
    # inputs, so one batched MS-SSIM call would not give the same per-image values.
    for r, f in zip(real, fake):
        r01 = (r + 1) / 2  # [-1, 1] -> [0, 1]
        f01 = (f + 1) / 2
        mmssim_list.append(mmssim(r01[None], f01[None], normalize='relu'))
        mse_list.append(torch.mean(torch.square(r01 - f01)))
# -------------- Summary -------------------
# Turn the per-image lists into 1-D tensors (rebinding the same module-level names).
mmssim_list = torch.stack(mmssim_list)
mse_list = torch.stack(mse_list)
# LPIPS is a distance (lower = more similar). The value logged as "LPIPS Score" is
# 1 - LPIPS, so for that number higher means more similar.
lpips = 1 - calc_lpips.compute()

report = (
    f"LPIPS Score: {lpips}",
    f"MS-SSIM: {mmssim_list.mean()} ± {mmssim_list.std()}",
    f"MSE: {mse_list.mean()} ± {mse_list.std()}",
)
for line in report:
    logger.info(line)