Spaces:
Runtime error
Runtime error
| from pathlib import Path | |
| from datetime import datetime | |
| import torch | |
| from torch.utils.data import ConcatDataset | |
| from pytorch_lightning.trainer import Trainer | |
| from pytorch_lightning.callbacks import EarlyStopping, ModelCheckpoint | |
| from medical_diffusion.data.datamodules import SimpleDataModule | |
| from medical_diffusion.data.datasets import AIROGSDataset, MSIvsMSS_2_Dataset, CheXpert_2_Dataset | |
| from medical_diffusion.models.embedders.latent_embedders import VQVAE, VQGAN, VAE, VAEGAN | |
# Select the 'file_system' sharing strategy for tensors passed between
# processes (e.g. DataLoader workers) before any training code runs.
import torch.multiprocessing

torch.multiprocessing.set_sharing_strategy("file_system")
if __name__ == "__main__":
    # Script entry point: builds a CheXpert_2_Dataset data module, fits a VAE
    # latent embedder with PyTorch Lightning and writes checkpoints into a
    # timestamped directory below ./runs.

    # --------------- Settings --------------------
    current_time = datetime.now().strftime("%Y_%m_%d_%H%M%S")
    path_run_dir = Path.cwd() / 'runs' / str(current_time)
    path_run_dir.mkdir(parents=True, exist_ok=True)

    # Pick the accelerator once and hand it to the Trainer below, so the
    # script also starts on machines without CUDA instead of failing on a
    # hard-coded GPU request.
    use_gpu = torch.cuda.is_available()
    accelerator = 'gpu' if use_gpu else 'cpu'
    devices = [0] if use_gpu else 1

    # ------------ Load Data ----------------
    # ds_1 = AIROGSDataset( # 256x256
    #     crawler_ext='jpg',
    #     augment_horizontal_flip=True,
    #     augment_vertical_flip=True,
    #     # path_root='/home/gustav/Documents/datasets/AIROGS/dataset',
    #     path_root='/mnt/hdd/datasets/eye/AIROGS/data_256x256',
    # )

    # ds_2 = MSIvsMSS_2_Dataset( # 512x512
    #     # image_resize=256,
    #     crawler_ext='jpg',
    #     augment_horizontal_flip=True,
    #     augment_vertical_flip=True,
    #     # path_root='/home/gustav/Documents/datasets/Kather_2/train'
    #     path_root='/mnt/hdd/datasets/pathology/kather_msi_mss_2/train/'
    # )

    ds_3 = CheXpert_2_Dataset(  # 256x256
        # image_resize=128,
        augment_horizontal_flip=False,
        augment_vertical_flip=False,
        # path_root = '/home/gustav/Documents/datasets/CheXpert/preprocessed_tianyu'
        path_root='/mnt/hdd/datasets/chest/CheXpert/ChecXpert-v10/preprocessed_tianyu'
    )

    # ds = ConcatDataset([ds_1, ds_2, ds_3])

    dm = SimpleDataModule(
        ds_train=ds_3,
        batch_size=8,
        # num_workers=0,
        pin_memory=True
    )

    # ------------ Initialize Model ------------
    model = VAE(
        in_channels=3,
        out_channels=3,
        emb_channels=8,
        spatial_dims=2,
        hid_chs=[64, 128, 256, 512],
        kernel_sizes=[3, 3, 3, 3],
        strides=[1, 2, 2, 2],
        deep_supervision=1,
        use_attention='none',
        loss=torch.nn.MSELoss,
        # optimizer_kwargs={'lr':1e-6},
        embedding_loss_weight=1e-6
    )

    # model.load_pretrained(Path.cwd()/'runs/2022_12_01_183752_patho_vae/last.ckpt', strict=True)

    # model = VAEGAN(
    #     in_channels=3,
    #     out_channels=3,
    #     emb_channels=8,
    #     spatial_dims=2,
    #     hid_chs = [ 64, 128, 256, 512],
    #     deep_supervision=1,
    #     use_attention= 'none',
    #     start_gan_train_step=-1,
    #     embedding_loss_weight=1e-6
    # )

    # model.vqvae.load_pretrained(Path.cwd()/'runs/2022_11_25_082209_chest_vae/last.ckpt')
    # model.load_pretrained(Path.cwd()/'runs/2022_11_25_232957_patho_vaegan/last.ckpt')

    # model = VQVAE(
    #     in_channels=3,
    #     out_channels=3,
    #     emb_channels=4,
    #     num_embeddings = 8192,
    #     spatial_dims=2,
    #     hid_chs = [64, 128, 256, 512],
    #     embedding_loss_weight=1,
    #     beta=1,
    #     loss = torch.nn.L1Loss,
    #     deep_supervision=1,
    #     use_attention = 'none',
    # )

    # model = VQGAN(
    #     in_channels=3,
    #     out_channels=3,
    #     emb_channels=4,
    #     num_embeddings = 8192,
    #     spatial_dims=2,
    #     hid_chs = [64, 128, 256, 512],
    #     embedding_loss_weight=1,
    #     beta=1,
    #     start_gan_train_step=-1,
    #     pixel_loss = torch.nn.L1Loss,
    #     deep_supervision=1,
    #     use_attention='none',
    # )

    # model.vqvae.load_pretrained(Path.cwd()/'runs/2022_12_13_093727_patho_vqvae/last.ckpt')

    # -------------- Training Initialization ---------------
    # NOTE(review): "train/L1" is presumably a metric logged by the model
    # during training - confirm against the model's logging code.
    to_monitor = "train/L1"  # "val/loss"
    min_max = "min"
    save_and_sample_every = 50

    # Built for convenience but only active when added to `callbacks` below.
    early_stopping = EarlyStopping(
        monitor=to_monitor,
        min_delta=0.0,  # minimum change in the monitored quantity to qualify as an improvement
        patience=30,  # number of checks with no improvement
        mode=min_max
    )
    checkpointing = ModelCheckpoint(
        dirpath=str(path_run_dir),  # dirpath
        monitor=to_monitor,
        every_n_train_steps=save_and_sample_every,
        save_last=True,
        save_top_k=5,
        mode=min_max,
    )
    trainer = Trainer(
        accelerator=accelerator,
        devices=devices,
        # precision=16,
        # amp_backend='apex',
        # amp_level='O2',
        # gradient_clip_val=0.5,
        default_root_dir=str(path_run_dir),
        callbacks=[checkpointing],
        # callbacks=[checkpointing, early_stopping],
        enable_checkpointing=True,
        check_val_every_n_epoch=1,
        log_every_n_steps=save_and_sample_every,
        auto_lr_find=False,
        # limit_train_batches=1000,
        limit_val_batches=0,  # 0 = disable validation - Note: Early Stopping no longer available
        min_epochs=100,
        max_epochs=1001,
        num_sanity_val_steps=2,
    )

    # ---------------- Execute Training ----------------
    trainer.fit(model, datamodule=dm)

    # ------------- Save path to best model -------------
    model.save_best_checkpoint(trainer.logger.log_dir, checkpointing.best_model_path)