Spaces:
Runtime error
Runtime error
| from pathlib import Path | |
| import torch | |
| from torchvision import utils | |
| import math | |
| from medical_diffusion.models.pipelines import DiffusionPipeline | |
| import numpy as np | |
| from PIL import Image | |
| import time | |
def chunks(lst, n):
    """Yield successive n-sized chunks from lst.

    Args:
        lst: A sliceable sequence with a length (e.g. list, tuple, range, str).
        n: Maximum chunk size; must be a positive integer.

    Yields:
        Consecutive slices ``lst[i:i + n]``. The final slice may be shorter
        than ``n``. Nothing is yielded when ``lst`` is empty.

    Raises:
        ValueError: If ``n`` is not positive. Because this is a generator
            function, the error surfaces when iteration starts.
    """
    if n <= 0:
        # A negative step makes range() empty, which would silently drop
        # every element instead of reporting the bad chunk size.
        raise ValueError(f"chunk size n must be positive, got {n}")
    for start in range(0, len(lst), n):
        yield lst[start:start + n]
# ------------ Load Model ------------
# Use the GPU when one is available; otherwise fall back to the CPU so that
# moving the pipeline does not fail on hosts without CUDA (sampling on the
# CPU will be much slower).
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

# Alternative loader: DiffusionPipeline.load_best_checkpoint(path_run_dir)
pipeline = DiffusionPipeline.load_from_checkpoint('runs/2022_12_12_171357_chest_diffusion/last.ckpt')
pipeline.to(device)
if __name__ == "__main__":
    # Generate a synthetic image dataset: for every sampling-step count and
    # every class label, draw n_samples images from the module-level
    # `pipeline` and write them as PNG files into one folder per class.
    #
    # Sample counts used with other label sets, kept for reference:
    #   {'NRG':0, 'RG':1} 3270, {'MSIH':0, 'nonMSIH':1} 9979,
    #   {'No_Cardiomegaly':0, 'Cardiomegaly':1} 7869
    labels = {'No_Cardiomegaly': 0, 'Cardiomegaly': 1}
    n_samples = 7869     # images generated per class
    sample_batch = 200   # images requested per pipeline.sample() call
    cfg = 1              # forwarded as guidance_scale

    for steps in [50, 100, 150, 200, 250]:
        for name, label in labels.items():
            # Output roots used for other datasets:
            #   /mnt/hdd/datasets/pathology/kather_msi_mss_2/synthetic_data/diffusion2_{steps}/
            #   /mnt/hdd/datasets/eye/AIROGS/data_generated_diffusion
            path_out = Path(f'/mnt/hdd/datasets/chest/CheXpert/ChecXpert-v10/generated_diffusion3_{steps}')/name
            path_out.mkdir(parents=True, exist_ok=True)

            # --------- Generate Samples -------------------
            # The seed is reset for every (steps, class) combination, so each
            # run starts from the same torch RNG state.
            torch.manual_seed(0)
            counter = 0  # running file index across all batches of this class

            # Only the chunk length is needed, so a range is chunked directly
            # instead of materialising a list of indices.
            for chunk in chunks(range(n_samples), sample_batch):
                batch_size = len(chunk)
                condition = torch.tensor([label]*batch_size, device=device) if label is not None else None
                # un_cond might be None, or 1-condition, or a specific label.
                un_cond = torch.tensor([1-label]*batch_size, device=device) if label is not None else None
                # Alternative sample shape used elsewhere: (4, 64, 64)
                results = pipeline.sample(batch_size, (8, 32, 32), guidance_scale=cfg, condition=condition, un_cond=un_cond, steps=steps)
                results = results.cpu().numpy()

                # --------- Save result ----------------
                for image in results:
                    image = image.clip(-1, 1)  # or (image-image.min())/(image.max()-image.min())
                    image = (image + 1) / 2 * 255  # [-1, 1] -> [0, 1] -> [0, 255]
                    # Move axis 0 (presumably the channel axis) to the end for PIL.
                    image = np.moveaxis(image, 0, -1)
                    # Round before casting: a bare astype() truncates, which
                    # biases every pixel downwards and reaches 255 only for
                    # an exact 1.0 input.
                    image = np.rint(image).astype(np.uint8)
                    if image.shape[-1] == 1:
                        # Drop a trailing singleton axis so PIL gets a 2-D array.
                        image = np.squeeze(image, axis=-1)
                    Image.fromarray(image).convert("RGB").save(path_out/f'fake_{counter}.png')
                    counter += 1

            # NOTE(review): indentation was lost in the source; this cleanup is
            # assumed to run once per class after all batches - confirm.
            torch.cuda.empty_cache()
            time.sleep(3)