diff --git a/README.md b/README.md deleted file mode 100644 index 8b650b68b7bb174814384ec0cd18615f21d9966b..0000000000000000000000000000000000000000 --- a/README.md +++ /dev/null @@ -1,13 +0,0 @@ ---- -title: Deep Learning Model for Pediatric Bone Age -emoji: 💻 -colorFrom: red -colorTo: blue -sdk: gradio -sdk_version: 3.8.2 -app_file: app.py -pinned: false -license: apache-2.0 ---- - -Check out the configuration reference at https://huggingface.co/docs/hub/spaces-config-reference diff --git a/app.py b/app.py index 02b282ae0b4f1468d2279df3f7d6f260b1a0de1b..42167d3624c463847a3fe861181e85a35463dd5f 100644 --- a/app.py +++ b/app.py @@ -1,106 +1,168 @@ +import cv2 import gradio as gr -import timm -import torch +import json +import numpy as np +import torch import torch.nn as nn +from einops import rearrange +from importlib import import_module +from pytorch_grad_cam import GradCAM +from pytorch_grad_cam.utils.model_targets import ClassifierOutputTarget +from skimage.exposure import match_histograms +from skp.utils import load_model_from_config, load_kfold_ensemble_as_list -def change_num_input_channels(model, in_channels=1): - """ - Assumes number of input channels in model is 3. - """ - for i, m in enumerate(model.modules()): - if isinstance(m, (nn.Conv2d,nn.Conv3d)) and m.in_channels == 3: - m.in_channels = in_channels - # First, sum across channels - W = m.weight.sum(1, keepdim=True) - # Then, divide by number of channels - W = W / in_channels - # Then, repeat by number of channels - size = [1] * W.ndim - size[1] = in_channels - W = W.repeat(size) - m.weight = nn.Parameter(W) - break - return model - - -class Net2D(nn.Module): - - def __init__(self, weights): - super().__init__() - self.backbone = timm.create_model("tf_efficientnetv2_s", pretrained=False, global_pool="", num_classes=0) - self.backbone = change_num_input_channels(self.backbone, 2) - self.pool_layer = nn.AdaptiveAvgPool2d(1) - self.dropout = nn.Dropout(0.2) - self.classifier = nn.Linear(1280, 1) - self.load_state_dict(weights) - - def forward(self, x): - x = self.backbone(x) - x = self.pool_layer(x).view(x.size(0), -1) - x = self.dropout(x) - x = self.classifier(x) - return x[:, 0] if x.size(1) == 1 else x - -class Ensemble(nn.Module): - - def __init__(self, model_list): +class ModelForGradCAM(nn.Module): + def __init__(self, model): super().__init__() - self.model_list = nn.ModuleList(model_list) - - def forward(self, x): - return torch.stack([model(x) for model in self.model_list]).mean(0) + self.model = model + def forward(self, x): + return self.model({"x": x})["logits1"] -checkpoints = ["fold0.ckpt", "fold1.ckpt", "fold2.ckpt"] -weights = [torch.load(ckpt, map_location=torch.device("cpu"))["state_dict"] for ckpt in checkpoints] -weights = [{k.replace("model.", "") : v for k, v in wt.items()} for wt in weights] -models = [Net2D(wt) for wt in weights] -ensemble = Ensemble(models).eval() -def predict_bone_age(Radiograph, Sex): - img = torch.from_numpy(Radiograph) - img = img.unsqueeze(0).unsqueeze(0) - img = img / img.max() - img = img - 0.5 - img = img * 2.0 - if Sex == 1: - img = torch.cat([img, torch.zeros_like(img) + 1], dim=1) - else: - img = torch.cat([img, torch.zeros_like(img) - 1], dim=1) - with torch.no_grad(): - bone_age = ensemble(img.float())[0].item() - total_months = bone_age * 12 - years = int(total_months // 12) - months = total_months - years * 12 +def convert_bone_age_to_string(bone_age: float): + # bone_age in months + years = round(bone_age // 12) + months = bone_age - (years * 12) months = round(months) - if months == 12: + + if months == 12: years += 1 months = 0 + if years == 0: str_output = f"{months} months" if months != 1 else "1 month" else: - months = round(months) if months == 0: str_output = f"{years} years" if years != 1 else "1 year" else: - str_output = f"{years} years, {months} months" if months != 1 else f"{years} years, 1 month" - return f"Estimated Bone Age: {str_output}" + str_output = ( + f"{years} years, {months} months" + if months != 1 + else f"{years} years, 1 month" + ) + return str_output + + +device = "cuda" if torch.cuda.is_available() else "cpu" + +cfg_crop = import_module("skp.configs.boneage.cfg_crop_simple_resize").cfg +crop_model = load_model_from_config( + cfg_crop, weights_path="crop.pt", device=device, eval_mode=True +) +cfg = import_module("skp.configs.boneage.cfg_female_channel_reg_cls_match_hist").cfg +cfg.backbone = "convnextv2_tiny" -image = gr.Image(shape=(512, 512), image_mode="L") +model_list = load_kfold_ensemble_as_list( + cfg, [f"net{i}.pt" for i in range(3)], device=device, eval_mode=True +) + +ref_img = rearrange(cv2.imread("ref_img.png", 0), "h w -> h w 1 ") + +with open("greulich_and_pyle_ages.json", "r") as f: + greulich_and_pyle_ages = json.load(f)["bone_ages"] + +greulich_and_pyle_ages = {k: np.asarray(v) for k, v in greulich_and_pyle_ages.items()} + +model_grad_cam = ModelForGradCAM(model_list[0]) +target_layers = [model_grad_cam.model.backbone.stages[-1]] + +def predict_bone_age(Radiograph, Sex): + x0 = rearrange(Radiograph, "h w -> h w 1") + x = cfg_crop.val_transforms(image=x0)["image"] + x = torch.from_numpy(x) + x = rearrange(x, "h w c -> 1 c h w") + # crop + with torch.inference_mode(): + box = crop_model({"x": x.to(device).float()}, return_loss=False)["logits"][ + 0 + ].cpu() + box[[0, 2]] = box[[0, 2]] * x0.shape[1] + box[[1, 3]] = box[[1, 3]] * x0.shape[0] + box = box.numpy().astype("int") + x, y, w, h = box + x0 = x0[y : y + h, x : x + w] + # histogram matching + x0 = match_histograms(x0, ref_img) + x = cfg.val_transforms(image=x0)["image"] + # create image channel for female/male + ch = np.zeros_like(x) + if Sex: # 0- male, 1- female + ch[...] = 255 + x = np.concatenate([x, ch], axis=-1) + x = torch.from_numpy(x) + x = rearrange(x, "h w c -> 1 c h w") + with torch.inference_mode(): + bone_age = [] + for each_model in model_list: + pred = each_model({"x": x.to(device).float()}, return_loss=False)[ + "logits1" + ][0].cpu() + pred = (pred.softmax(0) * torch.arange(240)).sum().numpy() + bone_age.append(pred) + bone_age = np.mean(bone_age) + + gp_ages = greulich_and_pyle_ages["female" if Sex else "male"] + diffs_gp = np.abs(bone_age - gp_ages) + diffs_gp = np.argsort(diffs_gp) + closest1 = gp_ages[diffs_gp[0]] + closest2 = gp_ages[diffs_gp[1]] + + bone_age_str = convert_bone_age_to_string(bone_age) + closest1 = convert_bone_age_to_string(closest1) + closest2 = convert_bone_age_to_string(closest2) + + targets = [ClassifierOutputTarget(round(bone_age))] + with GradCAM(model=model_grad_cam, target_layers=target_layers) as cam: + grayscale_cam = cam(input_tensor=x.to(device).float(), targets=targets, eigen_smooth=True) + + heatmap = cv2.applyColorMap((grayscale_cam[0] * 255).astype("uint8"), cv2.COLORMAP_JET) + image = cv2.cvtColor(x[0, 0].cpu().numpy().astype("uint8"), cv2.COLOR_GRAY2RGB) + image_weight = 0.6 + grad_cam_image = (1 - image_weight) * heatmap[..., ::-1] + image_weight * image + grad_cam_image = grad_cam_image.astype("uint8") + + return f"Predicted bone age: {bone_age_str}\n\nThe closest Greulich & Pyle bone ages are:\n 1) {closest1}\n 2) {closest2}", grad_cam_image + + +image = gr.Image(image_mode="L") sex = gr.Radio(["Male", "Female"], type="index") -label = gr.Label(show_label=True, label="Result") +textbox = gr.Textbox(show_label=True, label="Result") +grad_cam_image = gr.Image(image_mode="RGB", label="Heatmap") -demo = gr.Interface( - fn=predict_bone_age, - inputs=[image, sex], - outputs=label, - ) +with gr.Blocks() as demo: + gr.Markdown( + """ + # Deep Learning Model for Pediatric Bone Age + This model predicts the bone age from a single frontal view hand radiograph. + The model was trained on the publicly available + [RSNA Pediatric Bone Age Challenge](https://www.rsna.org/rsnai/ai-image-challenge/rsna-pediatric-bone-age-challenge-2017) dataset. + The model achieves a mean absolute error of 4.26 months on the original test set comprising 200 multi-annotated hand radiographs, + which is competitive with [top solutions](https://pubs.rsna.org/doi/10.1148/radiol.2018180736) from the original challenge. -if __name__ == "__main__": - demo.launch() + This model is for demonstration purposes only and has NOT been approved by any regulatory agency for clinical use. The user assumes + any and all responsibility regarding their own use of this model and its outputs. Do NOT upload any images containing protected + health information, as this demonstration is not compliant with patient privacy laws. + Created by: Ian Pan, + + Last updated: December 15, 2024 + """ + ) + gr.Interface( + fn=predict_bone_age, + inputs=[image, sex], + outputs=[textbox, grad_cam_image], + examples=[ + ["examples/2639.png", "Female"], + ["examples/10043.png", "Female"], + ["examples/8888.png", "Female"], + ], + ) +if __name__ == "__main__": + demo.launch(share=True) diff --git a/crop.pt b/crop.pt new file mode 120000 index 0000000000000000000000000000000000000000..6bfa1e73669739e9ea355f0d2d2d5beeaec589a8 --- /dev/null +++ b/crop.pt @@ -0,0 +1 @@ +../../experiments/boneage/boneage.cfg_crop_simple_resize/8b59fed7/fold0/checkpoints/last.ckpt \ No newline at end of file diff --git a/examples/10043.png b/examples/10043.png new file mode 100644 index 0000000000000000000000000000000000000000..3e643eff1c943a9250bf8d710bed2f87d6ea875e Binary files /dev/null and b/examples/10043.png differ diff --git a/examples/2639.png b/examples/2639.png new file mode 100644 index 0000000000000000000000000000000000000000..5db8b3708a1baeb42f6379c86b815cb69296ce8f Binary files /dev/null and b/examples/2639.png differ diff --git a/examples/8888.png b/examples/8888.png new file mode 100644 index 0000000000000000000000000000000000000000..f4e481f67626cccfec6ab1ba8224e74e520f5087 Binary files /dev/null and b/examples/8888.png differ diff --git a/fold0.ckpt b/fold0.ckpt deleted file mode 100644 index 2ae3bfe9981c5bbcef069ebd8987abd23051d056..0000000000000000000000000000000000000000 --- a/fold0.ckpt +++ /dev/null @@ -1,3 +0,0 @@ -version https://git-lfs.github.com/spec/v1 -oid sha256:2db6d3fb26a05b916341574c83683017e4a04a1c0df8fda4a97ad2314b33f109 -size 81642981 diff --git a/fold1.ckpt b/fold1.ckpt deleted file mode 100644 index ccf676658ed55303ac5115ff70d598afc584fac2..0000000000000000000000000000000000000000 --- a/fold1.ckpt +++ /dev/null @@ -1,3 +0,0 @@ -version https://git-lfs.github.com/spec/v1 -oid sha256:8c806c2ccd21cb4f1d1102e86d8716ed67583f561d4eea6a1761ac4f9bf6a60b -size 81642981 diff --git a/fold2.ckpt b/fold2.ckpt deleted file mode 100644 index e3637fe5070531f25d0ab9fc857771482fed0689..0000000000000000000000000000000000000000 --- a/fold2.ckpt +++ /dev/null @@ -1,3 +0,0 @@ -version https://git-lfs.github.com/spec/v1 -oid sha256:cabdc105bb4c3239d1a57ceaaca4306096a017763c1ec1d23adacf6d8c0713ab -size 81642981 diff --git a/greulich_and_pyle_ages.json b/greulich_and_pyle_ages.json new file mode 100644 index 0000000000000000000000000000000000000000..f69e01efdf79554c594a41fc2bec886dd29a6dd0 --- /dev/null +++ b/greulich_and_pyle_ages.json @@ -0,0 +1,65 @@ +{ + "bone_ages": { + "female": [ + 0, + 3, + 6, + 9, + 12, + 15, + 18, + 24, + 36, + 42, + 50, + 60, + 69, + 82, + 94, + 106, + 120, + 132, + 144, + 156, + 162, + 168, + 180, + 192, + 204, + 216 + ], + "male": [ + 0, + 3, + 6, + 9, + 12, + 15, + 18, + 24, + 30, + 32, + 36, + 42, + 48, + 54, + 60, + 72, + 84, + 96, + 108, + 120, + 132, + 138, + 150, + 156, + 162, + 168, + 180, + 192, + 204, + 216, + 228 + ] + } +} \ No newline at end of file diff --git a/net0.pt b/net0.pt new file mode 120000 index 0000000000000000000000000000000000000000..b6512f34f772baeefb7939dabc2e878481132901 --- /dev/null +++ b/net0.pt @@ -0,0 +1 @@ +../../experiments/boneage/boneage.cfg_female_channel_reg_cls_match_hist/fa77ff59/fold0/checkpoints/last.ckpt \ No newline at end of file diff --git a/net1.pt b/net1.pt new file mode 120000 index 0000000000000000000000000000000000000000..651fb29138ff7dd283ef08a13181a550f87eb1d6 --- /dev/null +++ b/net1.pt @@ -0,0 +1 @@ +../../experiments/boneage/boneage.cfg_female_channel_reg_cls_match_hist/fa77ff59/fold1/checkpoints/last.ckpt \ No newline at end of file diff --git a/net2.pt b/net2.pt new file mode 120000 index 0000000000000000000000000000000000000000..3872ac864077d596e20ca7f97a2b22f0714935f1 --- /dev/null +++ b/net2.pt @@ -0,0 +1 @@ +../../experiments/boneage/boneage.cfg_female_channel_reg_cls_match_hist/fa77ff59/fold2/checkpoints/last.ckpt \ No newline at end of file diff --git a/ref_img.png b/ref_img.png new file mode 100644 index 0000000000000000000000000000000000000000..1293dd0c8decdb5253a8d6c3cba1146974a919ae Binary files /dev/null and b/ref_img.png differ diff --git a/requirements.txt b/requirements.txt index d718df33065b4f002e4a996d27cfadf96b3aa242..1da3f5b3cb4bab4d7824b4dcb5bdef54e973649d 100644 --- a/requirements.txt +++ b/requirements.txt @@ -1,6 +1,7 @@ -gradio==3.45.0 -numpy -omegaconf +albumentations +einops +grad-cam +gradio +scikit-image timm -torch - +torch \ No newline at end of file diff --git a/skp/__pycache__/utils.cpython-312.pyc b/skp/__pycache__/utils.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..395a48a9fbe53be6715c4680be0f47d1fb373449 Binary files /dev/null and b/skp/__pycache__/utils.cpython-312.pyc differ diff --git a/skp/configs/__init__.py b/skp/configs/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..43fd69fd8beb502cef0a51352e19f8c3a8203aba --- /dev/null +++ b/skp/configs/__init__.py @@ -0,0 +1,21 @@ +from types import SimpleNamespace + + +class Config(SimpleNamespace): + + def __getattribute__(self, value): + # If attribute not specified in config, + # return None instead of raise error + try: + return super().__getattribute__(value) + except AttributeError: + return None + + def __str__(self): + # pretty print + string = ["config"] + string.append("=" * len(string[0])) + longest_param_name = max([len(k) for k in [*self.__dict__]]) + for k, v in self.__dict__.items(): + string.append(f"{k.ljust(longest_param_name)} : {v}") + return "\n".join(string) diff --git a/skp/configs/__pycache__/__init__.cpython-312.pyc b/skp/configs/__pycache__/__init__.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..cc534c9895984167fb14dd750f464ffe838fa0eb Binary files /dev/null and b/skp/configs/__pycache__/__init__.cpython-312.pyc differ diff --git a/skp/configs/__pycache__/base.cpython-312.pyc b/skp/configs/__pycache__/base.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..af4d19f42f8ba303beddd630ab8c99deda402e47 Binary files /dev/null and b/skp/configs/__pycache__/base.cpython-312.pyc differ diff --git a/skp/configs/base.py b/skp/configs/base.py new file mode 100644 index 0000000000000000000000000000000000000000..fe3f4df0ce9ad686bcfb438cf620c9791b4cdcd8 --- /dev/null +++ b/skp/configs/base.py @@ -0,0 +1,21 @@ +from types import SimpleNamespace + + +class Config(SimpleNamespace): + + def __getattribute__(self, value): + # If attribute not specified in config, + # return None instead of raise error + try: + return super().__getattribute__(value) + except AttribuateError: + return None + + def __str__(self): + # pretty print + string = ["config"] + string.append("=" * len(string[0])) + longest_param_name = max([len(k) for k in [*self.__dict__]]) + for k, v in self.__dict__.items(): + string.append(f"{k.ljust(longest_param_name)} : {v}") + return "\n".join(string) diff --git a/skp/configs/boneage/__pycache__/cfg_baseline.cpython-312.pyc b/skp/configs/boneage/__pycache__/cfg_baseline.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..4aeaa53c562441168e2bfbab00d4df7a33f5ee72 Binary files /dev/null and b/skp/configs/boneage/__pycache__/cfg_baseline.cpython-312.pyc differ diff --git a/skp/configs/boneage/__pycache__/cfg_crop.cpython-312.pyc b/skp/configs/boneage/__pycache__/cfg_crop.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..2948aa8c977b9eb1a460b6d3e02b82f55f5828f6 Binary files /dev/null and b/skp/configs/boneage/__pycache__/cfg_crop.cpython-312.pyc differ diff --git a/skp/configs/boneage/__pycache__/cfg_crop_simple_resize.cpython-312.pyc b/skp/configs/boneage/__pycache__/cfg_crop_simple_resize.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..dfc4166b60f0f5beef38007f5d9e5c8ac2733313 Binary files /dev/null and b/skp/configs/boneage/__pycache__/cfg_crop_simple_resize.cpython-312.pyc differ diff --git a/skp/configs/boneage/__pycache__/cfg_female_channel.cpython-312.pyc b/skp/configs/boneage/__pycache__/cfg_female_channel.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..c8cb2fae2f91a574e269d7bef92a20d37afefe35 Binary files /dev/null and b/skp/configs/boneage/__pycache__/cfg_female_channel.cpython-312.pyc differ diff --git a/skp/configs/boneage/__pycache__/cfg_female_channel_MIL.cpython-312.pyc b/skp/configs/boneage/__pycache__/cfg_female_channel_MIL.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..5b41cc0aa4fbea9268c5e4c6e01e8d27c5d3b1a2 Binary files /dev/null and b/skp/configs/boneage/__pycache__/cfg_female_channel_MIL.cpython-312.pyc differ diff --git a/skp/configs/boneage/__pycache__/cfg_female_channel_MIL_lstm.cpython-312.pyc b/skp/configs/boneage/__pycache__/cfg_female_channel_MIL_lstm.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..3293f39c6bc97bf4a1528360d7ae72fe22f15cf1 Binary files /dev/null and b/skp/configs/boneage/__pycache__/cfg_female_channel_MIL_lstm.cpython-312.pyc differ diff --git a/skp/configs/boneage/__pycache__/cfg_female_channel_MIL_transformer.cpython-312.pyc b/skp/configs/boneage/__pycache__/cfg_female_channel_MIL_transformer.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..2273ec5ba58d0427a3412f72dd581c209b544361 Binary files /dev/null and b/skp/configs/boneage/__pycache__/cfg_female_channel_MIL_transformer.cpython-312.pyc differ diff --git a/skp/configs/boneage/__pycache__/cfg_female_channel_reg_cls.cpython-312.pyc b/skp/configs/boneage/__pycache__/cfg_female_channel_reg_cls.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..34d28fc4ce6c1957deeff7dd0e8fd7ebabf7bfc6 Binary files /dev/null and b/skp/configs/boneage/__pycache__/cfg_female_channel_reg_cls.cpython-312.pyc differ diff --git a/skp/configs/boneage/__pycache__/cfg_female_channel_reg_cls_clip_outliers_aug.cpython-312.pyc b/skp/configs/boneage/__pycache__/cfg_female_channel_reg_cls_clip_outliers_aug.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..39fe6e7bf780b7e7622bd5ec79c7105cf1f886fb Binary files /dev/null and b/skp/configs/boneage/__pycache__/cfg_female_channel_reg_cls_clip_outliers_aug.cpython-312.pyc differ diff --git a/skp/configs/boneage/__pycache__/cfg_female_channel_reg_cls_match_hist.cpython-312.pyc b/skp/configs/boneage/__pycache__/cfg_female_channel_reg_cls_match_hist.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..3baf4a0809438807ae0725c17345f3a374ae23ff Binary files /dev/null and b/skp/configs/boneage/__pycache__/cfg_female_channel_reg_cls_match_hist.cpython-312.pyc differ diff --git a/skp/configs/boneage/__pycache__/cfg_female_channel_with_cls.cpython-312.pyc b/skp/configs/boneage/__pycache__/cfg_female_channel_with_cls.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..b9c450f72e1e86f410f832749c9cbd88f8d643b1 Binary files /dev/null and b/skp/configs/boneage/__pycache__/cfg_female_channel_with_cls.cpython-312.pyc differ diff --git a/skp/configs/boneage/__pycache__/cfg_female_channel_with_cls_clip_outliers.cpython-312.pyc b/skp/configs/boneage/__pycache__/cfg_female_channel_with_cls_clip_outliers.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..ac899d38d58f793657baaf0400fe90a71cd785d9 Binary files /dev/null and b/skp/configs/boneage/__pycache__/cfg_female_channel_with_cls_clip_outliers.cpython-312.pyc differ diff --git a/skp/configs/boneage/cfg_baseline.py b/skp/configs/boneage/cfg_baseline.py new file mode 100644 index 0000000000000000000000000000000000000000..463617750ccfce938add060cdc182e499e8237a5 --- /dev/null +++ b/skp/configs/boneage/cfg_baseline.py @@ -0,0 +1,117 @@ +import albumentations as A +import cv2 + +from skp.configs import Config + + +cfg = Config() +cfg.neptune_mode = "async" + +cfg.save_dir = "/home/ian/projects/SKP/experiments/boneage/" +cfg.project = "gradientecho/SKP" + +cfg.task = "classification" + +cfg.model = "classification.net2d_var_embed" +cfg.backbone = "tf_efficientnetv2_s" +cfg.embed_num_classes = 2 +cfg.embed_dim = 32 +cfg.pretrained = True +cfg.num_input_channels = 1 +cfg.pool = "gem" +cfg.pool_params = {"p": 3} +cfg.dropout = 0.1 +cfg.num_classes = 1 +cfg.normalization = "-1_1" +cfg.normalization_params = {"min": 0, "max": 255} +cfg.backbone_img_size = False + +cfg.fold = 0 +cfg.dataset = "simple2d" +cfg.data_dir = "/mnt/stor/datasets/bone-age/cropped_train_plus_valid/" +cfg.annotations_file = "/mnt/stor/datasets/bone-age/train_plus_valid_kfold.csv" +cfg.inputs = "imgfile0" +cfg.targets = ["bone_age_years"] +cfg.vars = "female" +cfg.cv2_load_flag = cv2.IMREAD_GRAYSCALE +cfg.num_workers = 16 +cfg.pin_memory = True +cfg.persistent_workers = True +cfg.sampler = "IterationBasedSampler" +cfg.num_iterations_per_epoch = 1000 + +cfg.loss = "classification.L1Loss" +cfg.loss_params = {} + +cfg.batch_size = 32 +cfg.num_epochs = 10 +cfg.optimizer = "AdamW" +cfg.optimizer_params = {"lr": 3e-4} + +cfg.scheduler = "LinearWarmupCosineAnnealingLR" +cfg.scheduler_params = {"pct_start": 0.1, "div_factor": 100, "final_div_factor": 1_000} +cfg.scheduler_interval = "step" + +cfg.val_batch_size = cfg.batch_size * 2 +cfg.metrics = ["classification.MAE", "classification.MSE"] +cfg.val_metric = "mae_mean" +cfg.val_track = "min" + +cfg.image_height = 512 +cfg.image_width = 512 + +resize_transforms = [ + A.LongestMaxSize(max_size=cfg.image_height, p=1), + A.PadIfNeeded( + min_height=cfg.image_height, + min_width=cfg.image_width, + border_mode=cv2.BORDER_CONSTANT, + p=1, + ), +] + +cfg.train_transforms = A.Compose( + resize_transforms + + [ + A.VerticalFlip(p=0.5), + A.HorizontalFlip(p=0.5), + A.SomeOf( + [ + A.ShiftScaleRotate( + shift_limit=0.2, + scale_limit=0.0, + rotate_limit=0, + border_mode=cv2.BORDER_CONSTANT, + p=1, + ), + A.ShiftScaleRotate( + shift_limit=0.0, + scale_limit=0.2, + rotate_limit=0, + border_mode=cv2.BORDER_CONSTANT, + p=1, + ), + A.ShiftScaleRotate( + shift_limit=0.0, + scale_limit=0.0, + rotate_limit=30, + border_mode=cv2.BORDER_CONSTANT, + p=1, + ), + A.GaussianBlur(p=1), + A.GaussNoise(p=1), + A.RandomBrightnessContrast( + contrast_limit=0.3, brightness_limit=0.0, p=1 + ), + A.RandomBrightnessContrast( + contrast_limit=0.0, brightness_limit=0.3, p=1 + ), + ], + n=3, + p=0.9, + replace=False, + ), + ] +) + +cfg.val_transforms = A.Compose(resize_transforms) diff --git a/skp/configs/boneage/cfg_crop.py b/skp/configs/boneage/cfg_crop.py new file mode 100644 index 0000000000000000000000000000000000000000..87f9ad2af4c79e0827755ba083dc851b1cacb8e5 --- /dev/null +++ b/skp/configs/boneage/cfg_crop.py @@ -0,0 +1,123 @@ +import albumentations as A +import cv2 + +from skp.configs import Config + + +cfg = Config() +cfg.neptune_mode = "async" + +cfg.save_dir = "/home/ian/projects/SKP/experiments/boneage/" +cfg.project = "gradientecho/SKP" + +cfg.task = "classification" + +cfg.model = "classification.net2d" +cfg.backbone = "mobilenetv3_small_100" +cfg.pretrained = True +cfg.num_input_channels = 1 +cfg.pool = "gem" +cfg.pool_params = {"p": 3} +cfg.dropout = 0.1 +cfg.num_classes = 4 +cfg.normalization = "-1_1" +cfg.normalization_params = {"min": 0, "max": 255} +cfg.backbone_img_size = False +cfg.model_activation_fn = "sigmoid" + +cfg.fold = 0 +cfg.dataset = "crop2d" +cfg.data_dir = "/mnt/stor/datasets/bone-age/train/" +cfg.annotations_file = ( + "/mnt/stor/datasets/bone-age/train_with_bounding_box_crop_coords_kfold.csv" +) +cfg.inputs = "imgfile" +cfg.targets = ["x1", "y1", "w", "h"] +cfg.normalize_crop_coords = True +cfg.cv2_load_flag = cv2.IMREAD_GRAYSCALE +cfg.num_workers = 16 +cfg.pin_memory = True +cfg.persistent_workers = True +cfg.sampler = "IterationBasedSampler" +cfg.num_iterations_per_epoch = 100 + +cfg.loss = "classification.L1Loss" +cfg.loss_params = {} + +cfg.batch_size = 16 +cfg.num_epochs = 10 +cfg.optimizer = "AdamW" +cfg.optimizer_params = {"lr": 3e-4} + +cfg.scheduler = "LinearWarmupCosineAnnealingLR" +cfg.scheduler_params = {"pct_start": 0.1, "div_factor": 100, "final_div_factor": 1_000} +cfg.scheduler_interval = "step" + +cfg.val_batch_size = cfg.batch_size * 2 +cfg.metrics = ["classification.MAE", "classification.MSE"] +cfg.val_metric = "mae_mean" +cfg.val_track = "min" + +cfg.image_height = 512 +cfg.image_width = 512 + +bbox_params = A.BboxParams(format="coco") +resize_transforms = [ + A.LongestMaxSize(max_size=cfg.image_height, p=1), + A.PadIfNeeded( + min_height=cfg.image_height, + min_width=cfg.image_width, + border_mode=cv2.BORDER_CONSTANT, + p=1, + ), +] + +cfg.train_transforms = A.Compose( + resize_transforms + + [ + A.VerticalFlip(p=0.5), + A.HorizontalFlip(p=0.5), + A.SomeOf( + [ + A.ShiftScaleRotate( + shift_limit=0.2, + scale_limit=0.0, + rotate_limit=0, + border_mode=cv2.BORDER_CONSTANT, + p=1, + ), + A.ShiftScaleRotate( + shift_limit=0.0, + scale_limit=0.2, + rotate_limit=0, + border_mode=cv2.BORDER_CONSTANT, + p=1, + ), + A.ShiftScaleRotate( + shift_limit=0.0, + scale_limit=0.0, + rotate_limit=30, + border_mode=cv2.BORDER_CONSTANT, + p=1, + ), + A.GaussianBlur(p=1), + A.GaussNoise(p=1), + A.RandomBrightnessContrast( + contrast_limit=0.3, brightness_limit=0.0, p=1 + ), + A.RandomBrightnessContrast( + contrast_limit=0.0, brightness_limit=0.3, p=1 + ), + ], + n=3, + p=0.9, + replace=False, + ), + ], + bbox_params=bbox_params, +) + +cfg.val_transforms = A.Compose( + resize_transforms, + bbox_params=bbox_params, +) diff --git a/skp/configs/boneage/cfg_crop_simple_resize.py b/skp/configs/boneage/cfg_crop_simple_resize.py new file mode 100644 index 0000000000000000000000000000000000000000..d09be3268fa382ebbeafab517fbdb34f73d654b1 --- /dev/null +++ b/skp/configs/boneage/cfg_crop_simple_resize.py @@ -0,0 +1,117 @@ +import albumentations as A +import cv2 + +from skp.configs import Config + + +cfg = Config() +cfg.neptune_mode = "async" + +cfg.save_dir = "/home/ian/projects/SKP/experiments/boneage/" +cfg.project = "gradientecho/SKP" + +cfg.task = "classification" + +cfg.model = "classification.net2d" +cfg.backbone = "mobilenetv3_small_100" +cfg.pretrained = True +cfg.num_input_channels = 1 +cfg.pool = "gem" +cfg.pool_params = {"p": 3} +cfg.dropout = 0.1 +cfg.num_classes = 4 +cfg.normalization = "-1_1" +cfg.normalization_params = {"min": 0, "max": 255} +cfg.backbone_img_size = False +cfg.model_activation_fn = "sigmoid" + +cfg.fold = 0 +cfg.dataset = "crop2d" +cfg.data_dir = "/mnt/stor/datasets/bone-age/train/" +cfg.annotations_file = ( + "/mnt/stor/datasets/bone-age/train_with_bounding_box_crop_coords_kfold.csv" +) +cfg.inputs = "imgfile" +cfg.targets = ["x1", "y1", "w", "h"] +cfg.normalize_crop_coords = True +cfg.cv2_load_flag = cv2.IMREAD_GRAYSCALE +cfg.num_workers = 16 +cfg.pin_memory = True +cfg.persistent_workers = True +cfg.sampler = "IterationBasedSampler" +cfg.num_iterations_per_epoch = 200 + +cfg.loss = "classification.L1Loss" +cfg.loss_params = {} + +cfg.batch_size = 16 +cfg.num_epochs = 10 +cfg.optimizer = "AdamW" +cfg.optimizer_params = {"lr": 3e-4} + +cfg.scheduler = "LinearWarmupCosineAnnealingLR" +cfg.scheduler_params = {"pct_start": 0.1, "div_factor": 100, "final_div_factor": 1_000} +cfg.scheduler_interval = "step" + +cfg.val_batch_size = cfg.batch_size * 2 +cfg.metrics = ["classification.MAE", "classification.MSE"] +cfg.val_metric = "mae_mean" +cfg.val_track = "min" + +cfg.image_height = 512 +cfg.image_width = 512 + +bbox_params = A.BboxParams(format="coco") +resize_transforms = [ + A.Resize(height=cfg.image_height, width=cfg.image_width, p=1) +] + +cfg.train_transforms = A.Compose( + resize_transforms + + [ + A.VerticalFlip(p=0.5), + A.HorizontalFlip(p=0.5), + A.SomeOf( + [ + A.ShiftScaleRotate( + shift_limit=0.2, + scale_limit=0.0, + rotate_limit=0, + border_mode=cv2.BORDER_CONSTANT, + p=1, + ), + A.ShiftScaleRotate( + shift_limit=0.0, + scale_limit=0.2, + rotate_limit=0, + border_mode=cv2.BORDER_CONSTANT, + p=1, + ), + A.ShiftScaleRotate( + shift_limit=0.0, + scale_limit=0.0, + rotate_limit=30, + border_mode=cv2.BORDER_CONSTANT, + p=1, + ), + A.GaussianBlur(p=1), + A.GaussNoise(p=1), + A.RandomBrightnessContrast( + contrast_limit=0.3, brightness_limit=0.0, p=1 + ), + A.RandomBrightnessContrast( + contrast_limit=0.0, brightness_limit=0.3, p=1 + ), + ], + n=3, + p=0.9, + replace=False, + ), + ], + bbox_params=bbox_params, +) + +cfg.val_transforms = A.Compose( + resize_transforms, + bbox_params=bbox_params, +) diff --git a/skp/configs/boneage/cfg_female_channel.py b/skp/configs/boneage/cfg_female_channel.py new file mode 100644 index 0000000000000000000000000000000000000000..0484323d8a4e4bbf42129885acba465ccecfbb21 --- /dev/null +++ b/skp/configs/boneage/cfg_female_channel.py @@ -0,0 +1,114 @@ +import albumentations as A +import cv2 + +from skp.configs import Config + + +cfg = Config() +cfg.neptune_mode = "async" + +cfg.save_dir = "/home/ian/projects/SKP/experiments/boneage/" +cfg.project = "gradientecho/SKP" + +cfg.task = "classification" + +cfg.model = "classification.net2d" +cfg.backbone = "tf_efficientnetv2_s" +cfg.pretrained = True +cfg.num_input_channels = 2 +cfg.pool = "gem" +cfg.pool_params = {"p": 3} +cfg.dropout = 0.1 +cfg.num_classes = 1 +cfg.normalization = "-1_1" +cfg.normalization_params = {"min": 0, "max": 255} +cfg.backbone_img_size = False + +cfg.fold = 0 +cfg.dataset = "boneage.female_channel" +cfg.data_dir = "/mnt/stor/datasets/bone-age/cropped_train_plus_valid/" +cfg.annotations_file = "/mnt/stor/datasets/bone-age/train_plus_valid_kfold.csv" +cfg.inputs = "imgfile0" +cfg.targets = ["bone_age_years"] +cfg.cv2_load_flag = cv2.IMREAD_GRAYSCALE +cfg.num_workers = 16 +cfg.pin_memory = True +cfg.persistent_workers = True +cfg.sampler = "IterationBasedSampler" +cfg.num_iterations_per_epoch = 1000 + +cfg.loss = "classification.L1Loss" +cfg.loss_params = {} + +cfg.batch_size = 32 +cfg.num_epochs = 10 +cfg.optimizer = "AdamW" +cfg.optimizer_params = {"lr": 3e-4} + +cfg.scheduler = "LinearWarmupCosineAnnealingLR" +cfg.scheduler_params = {"pct_start": 0.1, "div_factor": 100, "final_div_factor": 1_000} +cfg.scheduler_interval = "step" + +cfg.val_batch_size = cfg.batch_size * 2 +cfg.metrics = ["classification.MAE", "classification.MSE"] +cfg.val_metric = "mae_mean" +cfg.val_track = "min" + +cfg.image_height = 512 +cfg.image_width = 512 + +resize_transforms = [ + A.LongestMaxSize(max_size=cfg.image_height, p=1), + A.PadIfNeeded( + min_height=cfg.image_height, + min_width=cfg.image_width, + border_mode=cv2.BORDER_CONSTANT, + p=1, + ), +] + +cfg.train_transforms = A.Compose( + resize_transforms + + [ + A.VerticalFlip(p=0.5), + A.HorizontalFlip(p=0.5), + A.SomeOf( + [ + A.ShiftScaleRotate( + shift_limit=0.2, + scale_limit=0.0, + rotate_limit=0, + border_mode=cv2.BORDER_CONSTANT, + p=1, + ), + A.ShiftScaleRotate( + shift_limit=0.0, + scale_limit=0.2, + rotate_limit=0, + border_mode=cv2.BORDER_CONSTANT, + p=1, + ), + A.ShiftScaleRotate( + shift_limit=0.0, + scale_limit=0.0, + rotate_limit=30, + border_mode=cv2.BORDER_CONSTANT, + p=1, + ), + A.GaussianBlur(p=1), + A.GaussNoise(p=1), + A.RandomBrightnessContrast( + contrast_limit=0.3, brightness_limit=0.0, p=1 + ), + A.RandomBrightnessContrast( + contrast_limit=0.0, brightness_limit=0.3, p=1 + ), + ], + n=3, + p=0.9, + replace=False, + ), + ] +) + +cfg.val_transforms = A.Compose(resize_transforms) diff --git a/skp/configs/boneage/cfg_female_channel_MIL.py b/skp/configs/boneage/cfg_female_channel_MIL.py new file mode 100644 index 0000000000000000000000000000000000000000..ba82b33a4859f6df82ed3ab89d25e87f6719a875 --- /dev/null +++ b/skp/configs/boneage/cfg_female_channel_MIL.py @@ -0,0 +1,113 @@ +import albumentations as A +import cv2 + +from skp.configs import Config + + +cfg = Config() +cfg.neptune_mode = "async" + +cfg.save_dir = "/home/ian/projects/SKP/experiments/boneage/" +cfg.project = "gradientecho/SKP" + +cfg.task = "classification" + +cfg.model = "MIL.net2d_basic_attn" +cfg.backbone = "tf_efficientnetv2_s" +cfg.pretrained = True +cfg.num_input_channels = 2 +cfg.pool = "gem" +cfg.pool_params = {"p": 3} +cfg.dropout = 0.1 +cfg.num_classes = 1 +cfg.attn_dropout = 0.0 +cfg.attn_version = "v1" +cfg.normalization = "-1_1" +cfg.normalization_params = {"min": 0, "max": 255} +cfg.backbone_img_size = False + +cfg.fold = 0 +cfg.dataset = "boneage.female_channel_grid_patch" +cfg.data_dir = "/mnt/stor/datasets/bone-age/cropped_train_plus_valid/" +cfg.annotations_file = "/mnt/stor/datasets/bone-age/train_plus_valid_kfold.csv" +cfg.inputs = "imgfile0" +cfg.targets = ["bone_age_years"] +cfg.cv2_load_flag = cv2.IMREAD_GRAYSCALE +cfg.patch_size = 224 +cfg.patch_num_rows = 5 +cfg.patch_num_cols = 3 +cfg.num_workers = 16 +cfg.pin_memory = True +cfg.persistent_workers = True +cfg.sampler = "IterationBasedSampler" +cfg.num_iterations_per_epoch = 1000 + +cfg.loss = "classification.L1Loss" +cfg.loss_params = {} + +cfg.batch_size = 16 +cfg.num_epochs = 10 +cfg.optimizer = "AdamW" +cfg.optimizer_params = {"lr": 3e-4} + +cfg.scheduler = "LinearWarmupCosineAnnealingLR" +cfg.scheduler_params = {"pct_start": 0.1, "div_factor": 100, "final_div_factor": 1_000} +cfg.scheduler_interval = "step" + +cfg.val_batch_size = cfg.batch_size * 2 +cfg.metrics = ["classification.MAE", "classification.MSE"] +cfg.val_metric = "mae_mean" +cfg.val_track = "min" + +cfg.image_height = 560 +cfg.image_width = cfg.image_height # not used + +resize_transforms = [ + A.LongestMaxSize(max_size=cfg.image_height, p=1), +] + +cfg.train_transforms = A.Compose( + resize_transforms + + [ + A.VerticalFlip(p=0.5), + A.HorizontalFlip(p=0.5), + A.SomeOf( + [ + A.ShiftScaleRotate( + shift_limit=0.2, + scale_limit=0.0, + rotate_limit=0, + border_mode=cv2.BORDER_CONSTANT, + p=1, + ), + A.ShiftScaleRotate( + shift_limit=0.0, + scale_limit=0.2, + rotate_limit=0, + border_mode=cv2.BORDER_CONSTANT, + p=1, + ), + A.ShiftScaleRotate( + shift_limit=0.0, + scale_limit=0.0, + rotate_limit=30, + border_mode=cv2.BORDER_CONSTANT, + p=1, + ), + A.GaussianBlur(p=1), + A.GaussNoise(p=1), + A.RandomBrightnessContrast( + contrast_limit=0.3, brightness_limit=0.0, p=1 + ), + A.RandomBrightnessContrast( + contrast_limit=0.0, brightness_limit=0.3, p=1 + ), + ], + n=3, + p=0.9, + replace=False, + ), + ] +) + +cfg.val_transforms = A.Compose(resize_transforms) diff --git a/skp/configs/boneage/cfg_female_channel_MIL_lstm.py b/skp/configs/boneage/cfg_female_channel_MIL_lstm.py new file mode 100644 index 0000000000000000000000000000000000000000..383af1bb401d2ebdb312d66f2ca6186db8efe02b --- /dev/null +++ b/skp/configs/boneage/cfg_female_channel_MIL_lstm.py @@ -0,0 +1,116 @@ +import albumentations as A +import cv2 + +from skp.configs import Config + + +cfg = Config() +cfg.neptune_mode = "async" + +cfg.save_dir = "/home/ian/projects/SKP/experiments/boneage/" +cfg.project = "gradientecho/SKP" + +cfg.task = "classification" + +cfg.model = "MIL.net2d_attn" +cfg.backbone = "tf_efficientnetv2_s" +cfg.pretrained = True +cfg.num_input_channels = 2 +cfg.pool = "gem" +cfg.pool_params = {"p": 3} +cfg.dropout = 0.1 +cfg.num_classes = 1 +cfg.add_lstm = True +cfg.lstm_dropout = 0.0 +cfg.lstm_num_layers = 1 +cfg.attn_dropout = 0.0 +cfg.attn_version = "v1" +cfg.normalization = "-1_1" +cfg.normalization_params = {"min": 0, "max": 255} +cfg.backbone_img_size = False + +cfg.fold = 0 +cfg.dataset = "boneage.female_channel_grid_patch" +cfg.data_dir = "/mnt/stor/datasets/bone-age/cropped_train_plus_valid/" +cfg.annotations_file = "/mnt/stor/datasets/bone-age/train_plus_valid_kfold.csv" +cfg.inputs = "imgfile0" +cfg.targets = ["bone_age_years"] +cfg.cv2_load_flag = cv2.IMREAD_GRAYSCALE +cfg.patch_size = 224 +cfg.patch_num_rows = 5 +cfg.patch_num_cols = 3 +cfg.num_workers = 16 +cfg.pin_memory = True +cfg.persistent_workers = True +cfg.sampler = "IterationBasedSampler" +cfg.num_iterations_per_epoch = 1000 + +cfg.loss = "classification.L1Loss" +cfg.loss_params = {} + +cfg.batch_size = 16 +cfg.num_epochs = 10 +cfg.optimizer = "AdamW" +cfg.optimizer_params = {"lr": 3e-4} + +cfg.scheduler = "LinearWarmupCosineAnnealingLR" +cfg.scheduler_params = {"pct_start": 0.1, "div_factor": 100, "final_div_factor": 1_000} +cfg.scheduler_interval = "step" + +cfg.val_batch_size = cfg.batch_size * 2 +cfg.metrics = ["classification.MAE", "classification.MSE"] +cfg.val_metric = "mae_mean" +cfg.val_track = "min" + +cfg.image_height = 560 +cfg.image_width = cfg.image_height # not used + +resize_transforms = [ + A.LongestMaxSize(max_size=cfg.image_height, p=1), +] + +cfg.train_transforms = A.Compose( + resize_transforms + + [ + A.VerticalFlip(p=0.5), + A.HorizontalFlip(p=0.5), + A.SomeOf( + [ + A.ShiftScaleRotate( + shift_limit=0.2, + scale_limit=0.0, + rotate_limit=0, + border_mode=cv2.BORDER_CONSTANT, + p=1, + ), + A.ShiftScaleRotate( + shift_limit=0.0, + scale_limit=0.2, + rotate_limit=0, + border_mode=cv2.BORDER_CONSTANT, + p=1, + ), + A.ShiftScaleRotate( + shift_limit=0.0, + scale_limit=0.0, + rotate_limit=30, + border_mode=cv2.BORDER_CONSTANT, + p=1, + ), + A.GaussianBlur(p=1), + A.GaussNoise(p=1), + A.RandomBrightnessContrast( + contrast_limit=0.3, brightness_limit=0.0, p=1 + ), + A.RandomBrightnessContrast( + contrast_limit=0.0, brightness_limit=0.3, p=1 + ), + ], + n=3, + p=0.9, + replace=False, + ), + ] +) + +cfg.val_transforms = A.Compose(resize_transforms) diff --git a/skp/configs/boneage/cfg_female_channel_MIL_transformer.py b/skp/configs/boneage/cfg_female_channel_MIL_transformer.py new file mode 100644 index 0000000000000000000000000000000000000000..ef91f3ebf8fb2a7cfcc0d50e5a9b16f7a399350d --- /dev/null +++ b/skp/configs/boneage/cfg_female_channel_MIL_transformer.py @@ -0,0 +1,117 @@ +import albumentations as A +import cv2 + +from skp.configs import Config + + +cfg = Config() +cfg.neptune_mode = "async" + +cfg.save_dir = "/home/ian/projects/SKP/experiments/boneage/" +cfg.project = "gradientecho/SKP" + +cfg.task = "classification" + +cfg.model = "MIL.net2d_attn" +cfg.backbone = "tf_efficientnetv2_s" +cfg.pretrained = True +cfg.num_input_channels = 2 +cfg.pool = "gem" +cfg.pool_params = {"p": 3} +cfg.dropout = 0.1 +cfg.num_classes = 1 +cfg.reduce_feature_dim = 256 +cfg.add_transformer = True +cfg.transformer_dropout = 0.0 +cfg.transformer_num_layers = 1 +cfg.attn_dropout = 0.0 +cfg.attn_version = "v1" +cfg.normalization = "-1_1" +cfg.normalization_params = {"min": 0, "max": 255} +cfg.backbone_img_size = False + +cfg.fold = 0 +cfg.dataset = "boneage.female_channel_grid_patch" +cfg.data_dir = "/mnt/stor/datasets/bone-age/cropped_train_plus_valid/" +cfg.annotations_file = "/mnt/stor/datasets/bone-age/train_plus_valid_kfold.csv" +cfg.inputs = "imgfile0" +cfg.targets = ["bone_age_years"] +cfg.cv2_load_flag = cv2.IMREAD_GRAYSCALE +cfg.patch_size = 224 +cfg.patch_num_rows = 5 +cfg.patch_num_cols = 3 +cfg.num_workers = 16 +cfg.pin_memory = True +cfg.persistent_workers = True +cfg.sampler = "IterationBasedSampler" +cfg.num_iterations_per_epoch = 1000 + +cfg.loss = "classification.L1Loss" +cfg.loss_params = {} + +cfg.batch_size = 16 +cfg.num_epochs = 10 +cfg.optimizer = "AdamW" +cfg.optimizer_params = {"lr": 3e-4} + +cfg.scheduler = "LinearWarmupCosineAnnealingLR" +cfg.scheduler_params = {"pct_start": 0.1, "div_factor": 100, "final_div_factor": 1_000} +cfg.scheduler_interval = "step" + +cfg.val_batch_size = cfg.batch_size * 2 +cfg.metrics = ["classification.MAE", "classification.MSE"] +cfg.val_metric = "mae_mean" +cfg.val_track = "min" + +cfg.image_height = 560 +cfg.image_width = cfg.image_height # not used + +resize_transforms = [ + A.LongestMaxSize(max_size=cfg.image_height, p=1), +] + +cfg.train_transforms = A.Compose( + resize_transforms + + [ + A.VerticalFlip(p=0.5), + A.HorizontalFlip(p=0.5), + A.SomeOf( + [ + A.ShiftScaleRotate( + shift_limit=0.2, + scale_limit=0.0, + rotate_limit=0, + border_mode=cv2.BORDER_CONSTANT, + p=1, + ), + A.ShiftScaleRotate( + shift_limit=0.0, + scale_limit=0.2, + rotate_limit=0, + border_mode=cv2.BORDER_CONSTANT, + p=1, + ), + A.ShiftScaleRotate( + shift_limit=0.0, + scale_limit=0.0, + rotate_limit=30, + border_mode=cv2.BORDER_CONSTANT, + p=1, + ), + A.GaussianBlur(p=1), + A.GaussNoise(p=1), + A.RandomBrightnessContrast( + contrast_limit=0.3, brightness_limit=0.0, p=1 + ), + A.RandomBrightnessContrast( + contrast_limit=0.0, brightness_limit=0.3, p=1 + ), + ], + n=3, + p=0.9, + replace=False, + ), + ] +) + +cfg.val_transforms = A.Compose(resize_transforms) diff --git a/skp/configs/boneage/cfg_female_channel_reg_cls.py b/skp/configs/boneage/cfg_female_channel_reg_cls.py new file mode 100644 index 0000000000000000000000000000000000000000..b624b2b579f7da470940bc9074dec0e8c02e8d06 --- /dev/null +++ b/skp/configs/boneage/cfg_female_channel_reg_cls.py @@ -0,0 +1,115 @@ +import albumentations as A +import cv2 + +from skp.configs import Config + + +cfg = Config() +cfg.neptune_mode = "async" + +cfg.save_dir = "/home/ian/projects/SKP/experiments/boneage/" +cfg.project = "gradientecho/SKP" + +cfg.task = "classification" + +cfg.model = "classification.net2d_multihead" +cfg.backbone = "tf_efficientnetv2_s" +cfg.pretrained = True +cfg.num_input_channels = 2 +cfg.pool = "gem" +cfg.pool_params = {"p": 3} +cfg.dropout = 0.1 +cfg.num_classes = [1, 240] +cfg.num_heads = 2 +cfg.normalization = "-1_1" +cfg.normalization_params = {"min": 0, "max": 255} +cfg.backbone_img_size = False + +cfg.fold = 0 +cfg.dataset = "boneage.female_channel" +cfg.data_dir = "/mnt/stor/datasets/bone-age/cropped_train_plus_valid/" +cfg.annotations_file = "/mnt/stor/datasets/bone-age/train_plus_valid_kfold.csv" +cfg.inputs = "imgfile0" +cfg.targets = ["bone_age"] +cfg.cv2_load_flag = cv2.IMREAD_GRAYSCALE +cfg.num_workers = 16 +cfg.pin_memory = True +cfg.persistent_workers = True +cfg.sampler = "IterationBasedSampler" +cfg.num_iterations_per_epoch = 1000 + +cfg.loss = "classification.DoubleL1Loss" +cfg.loss_params = {"reg_weight": 1.0, "cls_weight": 0.4} + +cfg.batch_size = 32 +cfg.num_epochs = 10 +cfg.optimizer = "AdamW" +cfg.optimizer_params = {"lr": 3e-4} + +cfg.scheduler = "LinearWarmupCosineAnnealingLR" +cfg.scheduler_params = {"pct_start": 0.1, "div_factor": 100, "final_div_factor": 1_000} +cfg.scheduler_interval = "step" + +cfg.val_batch_size = cfg.batch_size * 2 +cfg.metrics = ["classification.DoubleMAE"] +cfg.val_metric = "mae_reg" +cfg.val_track = "min" + +cfg.image_height = 512 +cfg.image_width = 512 + +resize_transforms = [ + A.LongestMaxSize(max_size=cfg.image_height, p=1), + A.PadIfNeeded( + min_height=cfg.image_height, + min_width=cfg.image_width, + border_mode=cv2.BORDER_CONSTANT, + p=1, + ), +] + +cfg.train_transforms = A.Compose( + resize_transforms + + [ + A.VerticalFlip(p=0.5), + A.HorizontalFlip(p=0.5), + A.SomeOf( + [ + A.ShiftScaleRotate( + shift_limit=0.2, + scale_limit=0.0, + rotate_limit=0, + border_mode=cv2.BORDER_CONSTANT, + p=1, + ), + A.ShiftScaleRotate( + shift_limit=0.0, + scale_limit=0.2, + rotate_limit=0, + border_mode=cv2.BORDER_CONSTANT, + p=1, + ), + A.ShiftScaleRotate( + shift_limit=0.0, + scale_limit=0.0, + rotate_limit=30, + border_mode=cv2.BORDER_CONSTANT, + p=1, + ), + A.GaussianBlur(p=1), + A.GaussNoise(p=1), + A.RandomBrightnessContrast( + contrast_limit=0.3, brightness_limit=0.0, p=1 + ), + A.RandomBrightnessContrast( + contrast_limit=0.0, brightness_limit=0.3, p=1 + ), + ], + n=3, + p=0.9, + replace=False, + ), + ] +) + +cfg.val_transforms = A.Compose(resize_transforms) diff --git a/skp/configs/boneage/cfg_female_channel_reg_cls_clip_outliers_aug.py b/skp/configs/boneage/cfg_female_channel_reg_cls_clip_outliers_aug.py new file mode 100644 index 0000000000000000000000000000000000000000..58112b59cd4329d7e1cb79237392c96fc6e48172 --- /dev/null +++ b/skp/configs/boneage/cfg_female_channel_reg_cls_clip_outliers_aug.py @@ -0,0 +1,119 @@ +import albumentations as A +import cv2 + +from skp.configs import Config + + +cfg = Config() +cfg.neptune_mode = "async" + +cfg.save_dir = "/home/ian/projects/SKP/experiments/boneage/" +cfg.project = "gradientecho/SKP" + +cfg.task = "classification" + +cfg.model = "classification.net2d_multihead" +cfg.backbone = "tf_efficientnetv2_s" +cfg.pretrained = True +cfg.num_input_channels = 2 +cfg.pool = "gem" +cfg.pool_params = {"p": 3} +cfg.dropout = 0.1 +cfg.num_classes = [1, 240] +cfg.num_heads = 2 +cfg.normalization = "-1_1" +cfg.normalization_params = {"min": 0, "max": 255} +cfg.backbone_img_size = False + +cfg.fold = 0 +cfg.dataset = "boneage.female_channel" +cfg.data_dir = "/mnt/stor/datasets/bone-age/cropped_train_plus_valid/" +cfg.annotations_file = "/mnt/stor/datasets/bone-age/train_plus_valid_kfold.csv" +cfg.inputs = "imgfile0" +cfg.targets = ["bone_age"] +cfg.cv2_load_flag = cv2.IMREAD_GRAYSCALE +cfg.num_workers = 16 +cfg.clip_outlier_pixels_and_rescale = True +cfg.clip_as_data_aug = True +cfg.clip_proba = 0.5 +cfg.clip_bounds = (1, 99) +cfg.pin_memory = True +cfg.persistent_workers = True +cfg.sampler = "IterationBasedSampler" +cfg.num_iterations_per_epoch = 1000 + +cfg.loss = "classification.DoubleL1Loss" +cfg.loss_params = {"reg_weight": 1.0, "cls_weight": 0.4} + +cfg.batch_size = 32 +cfg.num_epochs = 10 +cfg.optimizer = "AdamW" +cfg.optimizer_params = {"lr": 3e-4} + +cfg.scheduler = "LinearWarmupCosineAnnealingLR" +cfg.scheduler_params = {"pct_start": 0.1, "div_factor": 100, "final_div_factor": 1_000} +cfg.scheduler_interval = "step" + +cfg.val_batch_size = cfg.batch_size * 2 +cfg.metrics = ["classification.DoubleMAE"] +cfg.val_metric = "mae_reg" +cfg.val_track = "min" + +cfg.image_height = 512 +cfg.image_width = 512 + +resize_transforms = [ + A.LongestMaxSize(max_size=cfg.image_height, p=1), + A.PadIfNeeded( + min_height=cfg.image_height, + min_width=cfg.image_width, + border_mode=cv2.BORDER_CONSTANT, + p=1, + ), +] + +cfg.train_transforms = A.Compose( + resize_transforms + + [ + A.VerticalFlip(p=0.5), + A.HorizontalFlip(p=0.5), + A.SomeOf( + [ + A.ShiftScaleRotate( + shift_limit=0.2, + scale_limit=0.0, + rotate_limit=0, + border_mode=cv2.BORDER_CONSTANT, + p=1, + ), + A.ShiftScaleRotate( + shift_limit=0.0, + scale_limit=0.2, + rotate_limit=0, + border_mode=cv2.BORDER_CONSTANT, + p=1, + ), + A.ShiftScaleRotate( + shift_limit=0.0, + scale_limit=0.0, + rotate_limit=30, + border_mode=cv2.BORDER_CONSTANT, + p=1, + ), + A.GaussianBlur(p=1), + A.GaussNoise(p=1), + A.RandomBrightnessContrast( + contrast_limit=0.3, brightness_limit=0.0, p=1 + ), + A.RandomBrightnessContrast( + contrast_limit=0.0, brightness_limit=0.3, p=1 + ), + ], + n=3, + p=0.9, + replace=False, + ), + ] +) + +cfg.val_transforms = A.Compose(resize_transforms) diff --git a/skp/configs/boneage/cfg_female_channel_reg_cls_match_hist.py b/skp/configs/boneage/cfg_female_channel_reg_cls_match_hist.py new file mode 100644 index 0000000000000000000000000000000000000000..5d180ddf54357145e1a6455d7a7af4de3c000692 --- /dev/null +++ b/skp/configs/boneage/cfg_female_channel_reg_cls_match_hist.py @@ -0,0 +1,116 @@ +import albumentations as A +import cv2 + +from skp.configs import Config + + +cfg = Config() +cfg.neptune_mode = "async" + +cfg.save_dir = "/home/ian/projects/SKP/experiments/boneage/" +cfg.project = "gradientecho/SKP" + +cfg.task = "classification" + +cfg.model = "classification.net2d_multihead" +cfg.backbone = "tf_efficientnetv2_s" +cfg.pretrained = True +cfg.num_input_channels = 2 +cfg.pool = "gem" +cfg.pool_params = {"p": 3} +cfg.dropout = 0.1 +cfg.num_classes = [1, 240] +cfg.num_heads = 2 +cfg.normalization = "-1_1" +cfg.normalization_params = {"min": 0, "max": 255} +cfg.backbone_img_size = False + +cfg.fold = 0 +cfg.dataset = "boneage.female_channel_match_hist" +cfg.data_dir = "/mnt/stor/datasets/bone-age/cropped_train_plus_valid/" +cfg.annotations_file = "/mnt/stor/datasets/bone-age/train_plus_valid_kfold.csv" +cfg.ref_image_match_hist = "/mnt/stor/datasets/bone-age/reference_cropped_image_for_histogram_matching.png" +cfg.inputs = "imgfile0" +cfg.targets = ["bone_age"] +cfg.cv2_load_flag = cv2.IMREAD_GRAYSCALE +cfg.num_workers = 16 +cfg.pin_memory = True +cfg.persistent_workers = True +cfg.sampler = "IterationBasedSampler" +cfg.num_iterations_per_epoch = 1000 + +cfg.loss = "classification.DoubleL1Loss" +cfg.loss_params = {"reg_weight": 1.0, "cls_weight": 0.4} + +cfg.batch_size = 32 +cfg.num_epochs = 10 +cfg.optimizer = "AdamW" +cfg.optimizer_params = {"lr": 3e-4} + +cfg.scheduler = "LinearWarmupCosineAnnealingLR" +cfg.scheduler_params = {"pct_start": 0.1, "div_factor": 100, "final_div_factor": 1_000} +cfg.scheduler_interval = "step" + +cfg.val_batch_size = cfg.batch_size * 2 +cfg.metrics = ["classification.DoubleMAE"] +cfg.val_metric = "mae_reg" +cfg.val_track = "min" + +cfg.image_height = 512 +cfg.image_width = 512 + +resize_transforms = [ + A.LongestMaxSize(max_size=cfg.image_height, p=1), + A.PadIfNeeded( + min_height=cfg.image_height, + min_width=cfg.image_width, + border_mode=cv2.BORDER_CONSTANT, + p=1, + ), +] + +cfg.train_transforms = A.Compose( + resize_transforms + + [ + A.VerticalFlip(p=0.5), + A.HorizontalFlip(p=0.5), + A.SomeOf( + [ + A.ShiftScaleRotate( + shift_limit=0.2, + scale_limit=0.0, + rotate_limit=0, + border_mode=cv2.BORDER_CONSTANT, + p=1, + ), + A.ShiftScaleRotate( + shift_limit=0.0, + scale_limit=0.2, + rotate_limit=0, + border_mode=cv2.BORDER_CONSTANT, + p=1, + ), + A.ShiftScaleRotate( + shift_limit=0.0, + scale_limit=0.0, + rotate_limit=30, + border_mode=cv2.BORDER_CONSTANT, + p=1, + ), + A.GaussianBlur(p=1), + A.GaussNoise(p=1), + A.RandomBrightnessContrast( + contrast_limit=0.3, brightness_limit=0.0, p=1 + ), + A.RandomBrightnessContrast( + contrast_limit=0.0, brightness_limit=0.3, p=1 + ), + ], + n=3, + p=0.9, + replace=False, + ), + ] +) + +cfg.val_transforms = A.Compose(resize_transforms) diff --git a/skp/configs/boneage/cfg_female_channel_with_cls.py b/skp/configs/boneage/cfg_female_channel_with_cls.py new file mode 100644 index 0000000000000000000000000000000000000000..5967c95ee99c0b9cf54a05371f7273cdc9ad1365 --- /dev/null +++ b/skp/configs/boneage/cfg_female_channel_with_cls.py @@ -0,0 +1,115 @@ +import albumentations as A +import cv2 + +from skp.configs import Config + + +cfg = Config() +cfg.neptune_mode = "async" + +cfg.save_dir = "/home/ian/projects/SKP/experiments/boneage/" +cfg.project = "gradientecho/SKP" + +cfg.task = "classification" + +cfg.model = "classification.net2d_multihead" +cfg.backbone = "tf_efficientnetv2_s" +cfg.pretrained = True +cfg.num_input_channels = 2 +cfg.pool = "gem" +cfg.pool_params = {"p": 3} +cfg.dropout = 0.1 +cfg.num_classes = [1, 24] +cfg.num_heads = 2 +cfg.normalization = "-1_1" +cfg.normalization_params = {"min": 0, "max": 255} +cfg.backbone_img_size = False + +cfg.fold = 0 +cfg.dataset = "boneage.female_channel_with_cls" +cfg.data_dir = "/mnt/stor/datasets/bone-age/cropped_train_plus_valid/" +cfg.annotations_file = "/mnt/stor/datasets/bone-age/train_plus_valid_kfold.csv" +cfg.inputs = "imgfile0" +cfg.targets = ["bone_age_years", "bone_age_categorical"] +cfg.cv2_load_flag = cv2.IMREAD_GRAYSCALE +cfg.num_workers = 16 +cfg.pin_memory = True +cfg.persistent_workers = True +cfg.sampler = "IterationBasedSampler" +cfg.num_iterations_per_epoch = 1000 + +cfg.loss = "classification.L1CELoss" +cfg.loss_params = {"l1_weight": 1.0, "ce_weight": 0.2} + +cfg.batch_size = 32 +cfg.num_epochs = 10 +cfg.optimizer = "AdamW" +cfg.optimizer_params = {"lr": 3e-4} + +cfg.scheduler = "LinearWarmupCosineAnnealingLR" +cfg.scheduler_params = {"pct_start": 0.1, "div_factor": 100, "final_div_factor": 1_000} +cfg.scheduler_interval = "step" + +cfg.val_batch_size = cfg.batch_size * 2 +cfg.metrics = ["classification.MAE_Accuracy"] +cfg.val_metric = "mae_mean" +cfg.val_track = "min" + +cfg.image_height = 512 +cfg.image_width = 512 + +resize_transforms = [ + A.LongestMaxSize(max_size=cfg.image_height, p=1), + A.PadIfNeeded( + min_height=cfg.image_height, + min_width=cfg.image_width, + border_mode=cv2.BORDER_CONSTANT, + p=1, + ), +] + +cfg.train_transforms = A.Compose( + resize_transforms + + [ + A.VerticalFlip(p=0.5), + A.HorizontalFlip(p=0.5), + A.SomeOf( + [ + A.ShiftScaleRotate( + shift_limit=0.2, + scale_limit=0.0, + rotate_limit=0, + border_mode=cv2.BORDER_CONSTANT, + p=1, + ), + A.ShiftScaleRotate( + shift_limit=0.0, + scale_limit=0.2, + rotate_limit=0, + border_mode=cv2.BORDER_CONSTANT, + p=1, + ), + A.ShiftScaleRotate( + shift_limit=0.0, + scale_limit=0.0, + rotate_limit=30, + border_mode=cv2.BORDER_CONSTANT, + p=1, + ), + A.GaussianBlur(p=1), + A.GaussNoise(p=1), + A.RandomBrightnessContrast( + contrast_limit=0.3, brightness_limit=0.0, p=1 + ), + A.RandomBrightnessContrast( + contrast_limit=0.0, brightness_limit=0.3, p=1 + ), + ], + n=3, + p=0.9, + replace=False, + ), + ] +) + +cfg.val_transforms = A.Compose(resize_transforms) diff --git a/skp/configs/boneage/cfg_female_channel_with_cls_clip_outliers.py b/skp/configs/boneage/cfg_female_channel_with_cls_clip_outliers.py new file mode 100644 index 0000000000000000000000000000000000000000..a5b4913dc6e2792925985f8375db19b9c9a7596d --- /dev/null +++ b/skp/configs/boneage/cfg_female_channel_with_cls_clip_outliers.py @@ -0,0 +1,117 @@ +import albumentations as A +import cv2 + +from skp.configs import Config + + +cfg = Config() +cfg.neptune_mode = "async" + +cfg.save_dir = "/home/ian/projects/SKP/experiments/boneage/" +cfg.project = "gradientecho/SKP" + +cfg.task = "classification" + +cfg.model = "classification.net2d_multihead" +cfg.backbone = "tf_efficientnetv2_s" +cfg.pretrained = True +cfg.num_input_channels = 2 +cfg.pool = "gem" +cfg.pool_params = {"p": 3} +cfg.dropout = 0.1 +cfg.num_classes = [1, 24] +cfg.num_heads = 2 +cfg.normalization = "-1_1" +cfg.normalization_params = {"min": 0, "max": 255} +cfg.backbone_img_size = False + +cfg.fold = 0 +cfg.dataset = "boneage.female_channel_with_cls" +cfg.data_dir = "/mnt/stor/datasets/bone-age/cropped_train_plus_valid/" +cfg.annotations_file = "/mnt/stor/datasets/bone-age/train_plus_valid_kfold.csv" +cfg.inputs = "imgfile0" +cfg.targets = ["bone_age_years", "bone_age_categorical"] +cfg.cv2_load_flag = cv2.IMREAD_GRAYSCALE +cfg.num_workers = 16 +cfg.clip_outlier_pixels_and_rescale = True +cfg.clip_bounds = (1, 99) +cfg.pin_memory = True +cfg.persistent_workers = True +cfg.sampler = "IterationBasedSampler" +cfg.num_iterations_per_epoch = 1000 + +cfg.loss = "classification.L1CELoss" +cfg.loss_params = {"l1_weight": 1.0, "ce_weight": 0.2} + +cfg.batch_size = 32 +cfg.num_epochs = 10 +cfg.optimizer = "AdamW" +cfg.optimizer_params = {"lr": 3e-4} + +cfg.scheduler = "LinearWarmupCosineAnnealingLR" +cfg.scheduler_params = {"pct_start": 0.1, "div_factor": 100, "final_div_factor": 1_000} +cfg.scheduler_interval = "step" + +cfg.val_batch_size = cfg.batch_size * 2 +cfg.metrics = ["classification.MAE_Accuracy"] +cfg.val_metric = "mae_mean" +cfg.val_track = "min" + +cfg.image_height = 512 +cfg.image_width = 512 + +resize_transforms = [ + A.LongestMaxSize(max_size=cfg.image_height, p=1), + A.PadIfNeeded( + min_height=cfg.image_height, + min_width=cfg.image_width, + border_mode=cv2.BORDER_CONSTANT, + p=1, + ), +] + +cfg.train_transforms = A.Compose( + resize_transforms + + [ + A.VerticalFlip(p=0.5), + A.HorizontalFlip(p=0.5), + A.SomeOf( + [ + A.ShiftScaleRotate( + shift_limit=0.2, + scale_limit=0.0, + rotate_limit=0, + border_mode=cv2.BORDER_CONSTANT, + p=1, + ), + A.ShiftScaleRotate( + shift_limit=0.0, + scale_limit=0.2, + rotate_limit=0, + border_mode=cv2.BORDER_CONSTANT, + p=1, + ), + A.ShiftScaleRotate( + shift_limit=0.0, + scale_limit=0.0, + rotate_limit=30, + border_mode=cv2.BORDER_CONSTANT, + p=1, + ), + A.GaussianBlur(p=1), + A.GaussNoise(p=1), + A.RandomBrightnessContrast( + contrast_limit=0.3, brightness_limit=0.0, p=1 + ), + A.RandomBrightnessContrast( + contrast_limit=0.0, brightness_limit=0.3, p=1 + ), + ], + n=3, + p=0.9, + replace=False, + ), + ] +) + +cfg.val_transforms = A.Compose(resize_transforms) diff --git a/skp/configs/boneage/cfg_female_channel_with_cls_clip_outliers_aug.py b/skp/configs/boneage/cfg_female_channel_with_cls_clip_outliers_aug.py new file mode 100644 index 0000000000000000000000000000000000000000..15b1628125863bd4d13f6b7399dd6e386c7ef06c --- /dev/null +++ b/skp/configs/boneage/cfg_female_channel_with_cls_clip_outliers_aug.py @@ -0,0 +1,119 @@ +import albumentations as A +import cv2 + +from skp.configs import Config + + +cfg = Config() +cfg.neptune_mode = "async" + +cfg.save_dir = "/home/ian/projects/SKP/experiments/boneage/" +cfg.project = "gradientecho/SKP" + +cfg.task = "classification" + +cfg.model = "classification.net2d_multihead" +cfg.backbone = "tf_efficientnetv2_s" +cfg.pretrained = True +cfg.num_input_channels = 2 +cfg.pool = "gem" +cfg.pool_params = {"p": 3} +cfg.dropout = 0.1 +cfg.num_classes = [1, 24] +cfg.num_heads = 2 +cfg.normalization = "-1_1" +cfg.normalization_params = {"min": 0, "max": 255} +cfg.backbone_img_size = False + +cfg.fold = 0 +cfg.dataset = "boneage.female_channel_with_cls" +cfg.data_dir = "/mnt/stor/datasets/bone-age/cropped_train_plus_valid/" +cfg.annotations_file = "/mnt/stor/datasets/bone-age/train_plus_valid_kfold.csv" +cfg.inputs = "imgfile0" +cfg.targets = ["bone_age_years", "bone_age_categorical"] +cfg.cv2_load_flag = cv2.IMREAD_GRAYSCALE +cfg.num_workers = 16 +cfg.clip_outlier_pixels_and_rescale = True +cfg.clip_as_data_aug = True +cfg.clip_proba = 0.5 +cfg.clip_bounds = (1, 99) +cfg.pin_memory = True +cfg.persistent_workers = True +cfg.sampler = "IterationBasedSampler" +cfg.num_iterations_per_epoch = 1000 + +cfg.loss = "classification.L1CELoss" +cfg.loss_params = {"l1_weight": 1.0, "ce_weight": 0.2} + +cfg.batch_size = 32 +cfg.num_epochs = 10 +cfg.optimizer = "AdamW" +cfg.optimizer_params = {"lr": 3e-4} + +cfg.scheduler = "LinearWarmupCosineAnnealingLR" +cfg.scheduler_params = {"pct_start": 0.1, "div_factor": 100, "final_div_factor": 1_000} +cfg.scheduler_interval = "step" + +cfg.val_batch_size = cfg.batch_size * 2 +cfg.metrics = ["classification.MAE_Accuracy"] +cfg.val_metric = "mae_mean" +cfg.val_track = "min" + +cfg.image_height = 512 +cfg.image_width = 512 + +resize_transforms = [ + A.LongestMaxSize(max_size=cfg.image_height, p=1), + A.PadIfNeeded( + min_height=cfg.image_height, + min_width=cfg.image_width, + border_mode=cv2.BORDER_CONSTANT, + p=1, + ), +] + +cfg.train_transforms = A.Compose( + resize_transforms + + [ + A.VerticalFlip(p=0.5), + A.HorizontalFlip(p=0.5), + A.SomeOf( + [ + A.ShiftScaleRotate( + shift_limit=0.2, + scale_limit=0.0, + rotate_limit=0, + border_mode=cv2.BORDER_CONSTANT, + p=1, + ), + A.ShiftScaleRotate( + shift_limit=0.0, + scale_limit=0.2, + rotate_limit=0, + border_mode=cv2.BORDER_CONSTANT, + p=1, + ), + A.ShiftScaleRotate( + shift_limit=0.0, + scale_limit=0.0, + rotate_limit=30, + border_mode=cv2.BORDER_CONSTANT, + p=1, + ), + A.GaussianBlur(p=1), + A.GaussNoise(p=1), + A.RandomBrightnessContrast( + contrast_limit=0.3, brightness_limit=0.0, p=1 + ), + A.RandomBrightnessContrast( + contrast_limit=0.0, brightness_limit=0.3, p=1 + ), + ], + n=3, + p=0.9, + replace=False, + ), + ] +) + +cfg.val_transforms = A.Compose(resize_transforms) diff --git a/skp/models/MIL/__pycache__/net2d_attn.cpython-312.pyc b/skp/models/MIL/__pycache__/net2d_attn.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..e9cb1284a146e880e30ed01d331dd17f3c0e01ac Binary files /dev/null and b/skp/models/MIL/__pycache__/net2d_attn.cpython-312.pyc differ diff --git a/skp/models/MIL/__pycache__/net2d_basic_attn.cpython-312.pyc b/skp/models/MIL/__pycache__/net2d_basic_attn.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..e9f706289de7df85963b10f18e8157aa362df304 Binary files /dev/null and b/skp/models/MIL/__pycache__/net2d_basic_attn.cpython-312.pyc differ diff --git a/skp/models/MIL/net2d_attn.py b/skp/models/MIL/net2d_attn.py new file mode 100644 index 0000000000000000000000000000000000000000..07a4ca6894f91cc03323734fa503a0a3f39167b8 --- /dev/null +++ b/skp/models/MIL/net2d_attn.py @@ -0,0 +1,286 @@ +""" +2D model for multiple instance learning (MIL) +Performs attention over bag of features (i.e., attention-weighted mean of features) +Option to add LSTM or Transformer before attention aggregation +Uses timm backbones +""" + +import re +import torch +import torch.nn as nn + +from einops import rearrange +from timm import create_model +from typing import Dict, Optional, Tuple + +from skp.configs.base import Config +from skp.models.modules import FeatureReduction +from skp.models.pooling import get_pool_layer + + +class Attention(nn.Module): + """ + Given a batch containing bags of features (B, N, D), + generate attention scores over the features in a bag, N, + and perform an attention-weighted mean of the features (B, D) + """ + + def __init__(self, embed_dim: int, dropout: float = 0.0, version: str = "v1"): + super().__init__() + version = version.lower() + if version == "v1": + self.mlp = nn.Sequential( + nn.Tanh(), nn.Dropout(dropout), nn.Linear(embed_dim, 1) + ) + elif version == "v2": + self.mlp = nn.Sequential( + nn.Linear(embed_dim, embed_dim), + nn.Tanh(), + nn.Dropout(dropout), + nn.Linear(embed_dim, 1), + ) + + def forward(self, x: torch.Tensor) -> Tuple[torch.Tensor]: + a = self.mlp(x) + a = a.softmax(dim=1) + x = (x * a).sum(dim=1) + return x, a + + +class BiLSTM(nn.Module): + def __init__(self, embed_dim: int, dropout: float = 0.0, num_layers: int = 1): + super().__init__() + self.lstm = nn.LSTM( + input_size=embed_dim, + hidden_size=embed_dim // 2, + num_layers=num_layers, + bias=True, + batch_first=True, + dropout=dropout, + bidirectional=True, + ) + + def forward(self, x: torch.Tensor) -> torch.Tensor: + x, _ = self.lstm(x) + return x + + +class Transformer(nn.Module): + def __init__( + self, + embed_dim: int, + dropout: float = 0.0, + num_layers: int = 1, + nhead: int = 16, + activation: str = "gelu", + ): + super().__init__() + encoder_layer = nn.TransformerEncoderLayer( + d_model=embed_dim, + nhead=nhead, + dim_feedforward=embed_dim, + dropout=dropout, + activation=activation, + batch_first=True, + norm_first=False, + bias=True, + ) + self.T = nn.TransformerEncoder(encoder_layer, num_layers=num_layers) + + def forward( + self, x: torch.Tensor, mask: Optional[torch.Tensor] = None + ) -> torch.Tensor: + return self.T(x, mask=mask) + + +class Net(nn.Module): + def __init__(self, cfg: Config): + super().__init__() + self.cfg = cfg + backbone_args = { + "pretrained": self.cfg.pretrained, + "num_classes": 0, + "global_pool": "", + "features_only": self.cfg.features_only, + "in_chans": self.cfg.num_input_channels, + } + if self.cfg.backbone_img_size: + # some models require specifying image size (e.g., coatnet) + if "efficientvit" in self.cfg.backbone: + backbone_args["img_size"] = self.cfg.image_height + else: + backbone_args["img_size"] = ( + self.cfg.image_height, + self.cfg.image_width, + ) + self.backbone = create_model(self.cfg.backbone, **backbone_args) + # get feature dim by passing sample through net + self.feature_dim = self.backbone( + torch.randn( + ( + 2, + self.cfg.num_input_channels, + self.cfg.image_height, + self.cfg.image_width, + ) + ) + ).size( + -1 if "xcit" in self.cfg.backbone else 1 + ) # xcit models are channels-last + + self.feature_dim = self.feature_dim * (2 if self.cfg.pool == "catavgmax" else 1) + self.pooling = get_pool_layer(self.cfg, dim=2) + + if isinstance(self.cfg.reduce_feature_dim, int): + self.backbone = nn.Sequential( + self.backbone, + FeatureReduction(self.feature_dim, self.cfg.reduce_feature_dim), + ) + self.feature_dim = self.cfg.reduce_feature_dim + + if self.cfg.add_lstm: + self.pre_attn = BiLSTM( + embed_dim=self.feature_dim, + dropout=self.cfg.lstm_dropout or 0.0, + num_layers=self.cfg.lstm_num_layers or 1, + ) + elif self.cfg.add_transformer: + self.pre_attn = Transformer( + embed_dim=self.feature_dim, + dropout=self.cfg.transformer_dropout or 0.0, + num_layers=self.cfg.transformer_num_layers or 1, + nhead=self.cfg.transformer_nhead or 16, + activation=self.cfg.transformer_act or "gelu", + ) + else: + self.pre_attn = nn.Identity() + + self.attn = Attention( + self.feature_dim, + dropout=self.cfg.attn_dropout, + version=self.cfg.attn_version or "v1", + ) + self.dropout = nn.Dropout(p=self.cfg.dropout) + self.linear = nn.Linear(self.feature_dim, self.cfg.num_classes) + + if self.cfg.load_pretrained_backbone: + print( + f"Loading pretrained backbone from {self.cfg.load_pretrained_backbone} ..." + ) + weights = torch.load( + self.cfg.load_pretrained_backbone, + map_location=lambda storage, loc: storage, + )["state_dict"] + # Replace model prefix as this does not exist in Net + weights = {re.sub(r"^model.", "", k): v for k, v in weights.items()} + # Get backbone only + weights = { + re.sub(r"^backbone.", "", k): v + for k, v in weights.items() + if "backbone" in k + } + self.backbone.load_state_dict(weights) + + self.criterion = None + + self.backbone_frozen = False + if self.cfg.freeze_backbone: + self.freeze_backbone() + + def normalize(self, x: torch.Tensor) -> torch.Tensor: + if self.cfg.normalization == "-1_1": + mini, maxi = ( + self.cfg.normalization_params["min"], + self.cfg.normalization_params["max"], + ) + x = x - mini + x = x / (maxi - mini) + x = x - 0.5 + x = x * 2.0 + elif self.cfg.normalization == "0_1": + mini, maxi = ( + self.cfg.normalization_params["min"], + self.cfg.normalization_params["max"], + ) + x = x - mini + x = x / (maxi - mini) + elif self.cfg.normalization == "mean_sd": + mean, sd = ( + self.cfg.normalization_params["mean"], + self.cfg.normalization_params["sd"], + ) + x = (x - mean) / sd + elif self.cfg.normalization == "per_channel_mean_sd": + mean, sd = ( + self.cfg.normalization_params["mean"], + self.cfg.normalization_params["sd"], + ) + assert len(mean) == len(sd) == x.size(1) + mean, sd = torch.tensor(mean).unsqueeze(0), torch.tensor(sd).unsqueeze(0) + for i in range(x.ndim - 2): + mean, sd = mean.unsqueeze(-1), sd.unsqueeze(-1) + x = (x - mean) / sd + elif self.cfg.normalization == "none": + x = x + return x + + def forward( + self, + batch: Dict, + return_loss: bool = False, + return_features: bool = False, + return_attn_scores: bool = False, + ) -> Dict[str, torch.Tensor]: + x = batch["x"] + y = batch.get("y", None) + + if return_loss: + assert y is not None + + b, n = x.shape[:2] + x = rearrange(x, "b n c h w -> (b n) c h w") + features = self.extract_features(x, normalize=True) + features = rearrange(features, "(b n) d -> b n d", b=b, n=n) + if isinstance(self.pre_attn, Transformer): + features = self.pre_attn(features, mask=batch.get("mask", None)) + else: + features = self.pre_attn(features) + features, attn_scores = self.attn(features) + + if self.cfg.multisample_dropout: + logits = torch.stack( + [self.linear(self.dropout(features)) for _ in range(5)] + ).mean(0) + else: + logits = self.linear(self.dropout(features)) + + if self.cfg.model_activation_fn == "sigmoid": + logits = logits.sigmoid() + elif self.cfg.model_activation_fn == "softmax": + logits = logits.softmax(dim=1) + + out = {"logits": logits} + if return_features: + out["features"] = features + if return_attn_scores: + out["attn_scores"] = attn_scores + if return_loss: + loss = self.criterion(out, batch) + if isinstance(loss, dict): + out.update(loss) + else: + out["loss"] = loss + + return out + + def extract_features(self, x: torch.Tensor, normalize: bool = True) -> torch.Tensor: + x = self.normalize(x) if normalize else x + return self.pooling(self.backbone(x)) + + def freeze_backbone(self) -> None: + for param in self.backbone.parameters(): + param.requires_grad = False + self.backbone_frozen = True + + def set_criterion(self, loss: nn.Module) -> None: + self.criterion = loss diff --git a/skp/models/MIL/net2d_basic_attn.py b/skp/models/MIL/net2d_basic_attn.py new file mode 100644 index 0000000000000000000000000000000000000000..9f5805d69f31be457a4b73557e4f21375c929c53 --- /dev/null +++ b/skp/models/MIL/net2d_basic_attn.py @@ -0,0 +1,284 @@ +""" +2D model for multiple instance learning (MIL) +Performs attention over bag of features (i.e., attention-weighted mean of features) +Uses timm backbones +""" + +import re +import torch +import torch.nn as nn + +from einops import rearrange +from timm import create_model +from typing import Dict, Optional, Tuple + +from skp.configs.base import Config +from skp.models.modules import FeatureReduction +from skp.models.pooling import get_pool_layer + + +class Attention(nn.Module): + """ + Given a batch containing bags of features (B, N, D), + generate attention scores over the features in a bag, N, + and perform an attention-weighted mean of the features (B, D) + """ + + def __init__(self, embed_dim: int, dropout: float = 0.0, version: str = "v1"): + super().__init__() + version = version.lower() + if version == "v1": + self.mlp = nn.Sequential( + nn.Tanh(), nn.Dropout(dropout), nn.Linear(embed_dim, 1) + ) + elif version == "v2": + self.mlp = nn.Sequential( + nn.Linear(embed_dim, embed_dim), + nn.Tanh(), + nn.Dropout(dropout), + nn.Linear(embed_dim, 1), + ) + + def forward(self, x: torch.Tensor) -> Tuple[torch.Tensor]: + a = self.mlp(x) + a = a.softmax(dim=1) + x = (x * a).sum(dim=1) + return x, a + + +class BiLSTM(nn.Module): + def __init__(self, embed_dim: int, dropout: float = 0.0, num_layers: int = 1): + super().__init__() + self.lstm = nn.LSTM( + input_size=embed_dim, + hidden_size=embed_dim // 2, + num_layers=num_layers, + bias=True, + batch_first=True, + dropout=dropout, + bidirectional=True, + ) + + def forward(self, x: torch.Tensor) -> torch.Tensor: + x, _ = self.lstm(x) + return x + + +class Transformer(nn.Module): + def __init__( + self, + embed_dim: int, + dropout: float = 0.0, + num_layers: int = 1, + nheads: int = 16, + activation: str = "gelu", + ): + super().__init__() + encoder_layer = nn.TransformerEncoderLayer( + d_model=embed_dim, + dim_feedforward=embed_dim, + dropout=dropout, + activation=activation, + batch_first=True, + norm_first=False, + bias=True, + ) + self.T = nn.TransformerEncoder(encoder_layer, num_layers=num_layers) + + def forward( + self, x: torch.Tensor, mask: Optional[torch.Tensor] = None + ) -> torch.Tensor: + return self.T(x, mask=mask) + + +class Net(nn.Module): + def __init__(self, cfg: Config): + super().__init__() + self.cfg = cfg + backbone_args = { + "pretrained": self.cfg.pretrained, + "num_classes": 0, + "global_pool": "", + "features_only": self.cfg.features_only, + "in_chans": self.cfg.num_input_channels, + } + if self.cfg.backbone_img_size: + # some models require specifying image size (e.g., coatnet) + if "efficientvit" in self.cfg.backbone: + backbone_args["img_size"] = self.cfg.image_height + else: + backbone_args["img_size"] = ( + self.cfg.image_height, + self.cfg.image_width, + ) + self.backbone = create_model(self.cfg.backbone, **backbone_args) + # get feature dim by passing sample through net + self.feature_dim = self.backbone( + torch.randn( + ( + 2, + self.cfg.num_input_channels, + self.cfg.image_height, + self.cfg.image_width, + ) + ) + ).size( + -1 if "xcit" in self.cfg.backbone else 1 + ) # xcit models are channels-last + + self.feature_dim = self.feature_dim * (2 if self.cfg.pool == "catavgmax" else 1) + self.pooling = get_pool_layer(self.cfg, dim=2) + + if isinstance(self.cfg.reduce_feature_dim, int): + self.backbone = nn.Sequential( + self.backbone, + FeatureReduction(self.feature_dim, self.cfg.reduce_feature_dim), + ) + self.feature_dim = self.cfg.reduce_feature_dim + + if self.cfg.add_lstm: + self.pre_attn = BiLSTM( + embed_dim=self.feature_dim, + dropout=self.cfg.lstm_dropout or 0.0, + num_layers=self.cfg.lstm_num_layers or 1, + ) + elif self.cfg.add_transformer: + self.pre_attn = Transformer( + embed_dim=self.feature_dim, + dropout=self.transformer_dropout or 0.0, + num_layers=self.transformer_num_layers or 1, + nheads=self.transformer_nheads or 16, + activation=self.transformer_act or "gelu", + ) + else: + self.pre_attn = nn.Identity() + + self.attn = Attention( + self.feature_dim, + dropout=self.cfg.attn_dropout, + version=self.cfg.attn_version or "v1", + ) + self.dropout = nn.Dropout(p=self.cfg.dropout) + self.linear = nn.Linear(self.feature_dim, self.cfg.num_classes) + + if self.cfg.load_pretrained_backbone: + print( + f"Loading pretrained backbone from {self.cfg.load_pretrained_backbone} ..." + ) + weights = torch.load( + self.cfg.load_pretrained_backbone, + map_location=lambda storage, loc: storage, + )["state_dict"] + # Replace model prefix as this does not exist in Net + weights = {re.sub(r"^model.", "", k): v for k, v in weights.items()} + # Get backbone only + weights = { + re.sub(r"^backbone.", "", k): v + for k, v in weights.items() + if "backbone" in k + } + self.backbone.load_state_dict(weights) + + self.criterion = None + + self.backbone_frozen = False + if self.cfg.freeze_backbone: + self.freeze_backbone() + + def normalize(self, x: torch.Tensor) -> torch.Tensor: + if self.cfg.normalization == "-1_1": + mini, maxi = ( + self.cfg.normalization_params["min"], + self.cfg.normalization_params["max"], + ) + x = x - mini + x = x / (maxi - mini) + x = x - 0.5 + x = x * 2.0 + elif self.cfg.normalization == "0_1": + mini, maxi = ( + self.cfg.normalization_params["min"], + self.cfg.normalization_params["max"], + ) + x = x - mini + x = x / (maxi - mini) + elif self.cfg.normalization == "mean_sd": + mean, sd = ( + self.cfg.normalization_params["mean"], + self.cfg.normalization_params["sd"], + ) + x = (x - mean) / sd + elif self.cfg.normalization == "per_channel_mean_sd": + mean, sd = ( + self.cfg.normalization_params["mean"], + self.cfg.normalization_params["sd"], + ) + assert len(mean) == len(sd) == x.size(1) + mean, sd = torch.tensor(mean).unsqueeze(0), torch.tensor(sd).unsqueeze(0) + for i in range(x.ndim - 2): + mean, sd = mean.unsqueeze(-1), sd.unsqueeze(-1) + x = (x - mean) / sd + elif self.cfg.normalization == "none": + x = x + return x + + def forward( + self, + batch: Dict, + return_loss: bool = False, + return_features: bool = False, + return_attn_scores: bool = False, + ) -> Dict[str, torch.Tensor]: + x = batch["x"] + y = batch.get("y", None) + + if return_loss: + assert y is not None + + b, n = x.shape[:2] + x = rearrange(x, "b n c h w -> (b n) c h w") + features = self.extract_features(x, normalize=True) + features = rearrange(features, "(b n) d -> b n d", b=b, n=n) + if isinstance(self.pre_attn, Transformer): + features = self.pre_attn(features, mask=batch.get("mask", None)) + else: + features = self.pre_attn(features) + features, attn_scores = self.attn(features) + + if self.cfg.multisample_dropout: + logits = torch.stack( + [self.linear(self.dropout(features)) for _ in range(5)] + ).mean(0) + else: + logits = self.linear(self.dropout(features)) + + if self.cfg.model_activation_fn == "sigmoid": + logits = logits.sigmoid() + elif self.cfg.model_activation_fn == "softmax": + logits = logits.softmax(dim=1) + + out = {"logits": logits} + if return_features: + out["features"] = features + if return_attn_scores: + out["attn_scores"] = attn_scores + if return_loss: + loss = self.criterion(out, batch) + if isinstance(loss, dict): + out.update(loss) + else: + out["loss"] = loss + + return out + + def extract_features(self, x: torch.Tensor, normalize: bool = True) -> torch.Tensor: + x = self.normalize(x) if normalize else x + return self.pooling(self.backbone(x)) + + def freeze_backbone(self) -> None: + for param in self.backbone.parameters(): + param.requires_grad = False + self.backbone_frozen = True + + def set_criterion(self, loss: nn.Module) -> None: + self.criterion = loss diff --git a/skp/models/__pycache__/modules.cpython-312.pyc b/skp/models/__pycache__/modules.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..a13f49863a974fe6c77eba8f0ded53c8b2b2d8d0 Binary files /dev/null and b/skp/models/__pycache__/modules.cpython-312.pyc differ diff --git a/skp/models/__pycache__/pooling.cpython-312.pyc b/skp/models/__pycache__/pooling.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..eca8a326027a9e3a3a9a7827d5bbf1db4255e75f Binary files /dev/null and b/skp/models/__pycache__/pooling.cpython-312.pyc differ diff --git a/skp/models/classification/__pycache__/net2d.cpython-312.pyc b/skp/models/classification/__pycache__/net2d.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..7f1af440170ead776ecdb0910cb9c074f22d2e31 Binary files /dev/null and b/skp/models/classification/__pycache__/net2d.cpython-312.pyc differ diff --git a/skp/models/classification/__pycache__/net2d_multihead.cpython-312.pyc b/skp/models/classification/__pycache__/net2d_multihead.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..fe662cd483af93c5531a49d6e7241183e777918d Binary files /dev/null and b/skp/models/classification/__pycache__/net2d_multihead.cpython-312.pyc differ diff --git a/skp/models/classification/__pycache__/net2d_var_embed.cpython-312.pyc b/skp/models/classification/__pycache__/net2d_var_embed.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..c702b99a3e11c6b7f9803ded63e36b4fd0951028 Binary files /dev/null and b/skp/models/classification/__pycache__/net2d_var_embed.cpython-312.pyc differ diff --git a/skp/models/classification/net2d.py b/skp/models/classification/net2d.py new file mode 100644 index 0000000000000000000000000000000000000000..c12cc6c87abf68d9b80dd3b746959fe95b2a94fe --- /dev/null +++ b/skp/models/classification/net2d.py @@ -0,0 +1,172 @@ +""" +Simple model for 2D classification (or regression) +Uses timm for backbones +""" + +import re +import torch +import torch.nn as nn + +from timm import create_model +from typing import Dict + +from skp.configs.base import Config +from skp.models.modules import FeatureReduction +from skp.models.pooling import get_pool_layer + + +class Net(nn.Module): + def __init__(self, cfg: Config): + super().__init__() + self.cfg = cfg + backbone_args = { + "pretrained": self.cfg.pretrained, + "num_classes": 0, + "global_pool": "", + "features_only": self.cfg.features_only, + "in_chans": self.cfg.num_input_channels, + } + if self.cfg.backbone_img_size: + # some models require specifying image size (e.g., coatnet) + if "efficientvit" in self.cfg.backbone: + backbone_args["img_size"] = self.cfg.image_height + else: + backbone_args["img_size"] = ( + self.cfg.image_height, + self.cfg.image_width, + ) + self.backbone = create_model(self.cfg.backbone, **backbone_args) + # get feature dim by passing sample through net + self.feature_dim = self.backbone( + torch.randn( + ( + 2, + self.cfg.num_input_channels, + self.cfg.image_height, + self.cfg.image_width, + ) + ) + ).size( + -1 if "xcit" in self.cfg.backbone else 1 + ) # xcit models are channels-last + + self.feature_dim = self.feature_dim * (2 if self.cfg.pool == "catavgmax" else 1) + self.pooling = get_pool_layer(self.cfg, dim=2) + + if isinstance(self.cfg.reduce_feature_dim, int): + self.backbone = nn.Sequential( + self.backbone, + FeatureReduction(self.feature_dim, self.cfg.reduce_feature_dim), + ) + self.feature_dim = self.cfg.reduce_feature_dim + + self.dropout = nn.Dropout(p=self.cfg.dropout) + self.linear = nn.Linear(self.feature_dim, self.cfg.num_classes) + + if self.cfg.load_pretrained_backbone: + print( + f"Loading pretrained backbone from {self.cfg.load_pretrained_backbone} ..." + ) + weights = torch.load( + self.cfg.load_pretrained_backbone, + map_location=lambda storage, loc: storage, + )["state_dict"] + # Replace model prefix as this does not exist in Net + weights = {re.sub(r"^model.", "", k): v for k, v in weights.items()} + # Get backbone only + weights = { + re.sub(r"^backbone.", "", k): v + for k, v in weights.items() + if "backbone" in k + } + self.backbone.load_state_dict(weights) + + self.criterion = None + + self.backbone_frozen = False + if self.cfg.freeze_backbone: + self.freeze_backbone() + + def normalize(self, x: torch.Tensor) -> torch.Tensor: + if self.cfg.normalization == "-1_1": + mini, maxi = ( + self.cfg.normalization_params["min"], + self.cfg.normalization_params["max"], + ) + x = x - mini + x = x / (maxi - mini) + x = x - 0.5 + x = x * 2.0 + elif self.cfg.normalization == "0_1": + mini, maxi = ( + self.cfg.normalization_params["min"], + self.cfg.normalization_params["max"], + ) + x = x - mini + x = x / (maxi - mini) + elif self.cfg.normalization == "mean_sd": + mean, sd = ( + self.cfg.normalization_params["mean"], + self.cfg.normalization_params["sd"], + ) + x = (x - mean) / sd + elif self.cfg.normalization == "per_channel_mean_sd": + mean, sd = ( + self.cfg.normalization_params["mean"], + self.cfg.normalization_params["sd"], + ) + assert len(mean) == len(sd) == x.size(1) + mean, sd = torch.tensor(mean).unsqueeze(0), torch.tensor(sd).unsqueeze(0) + for i in range(x.ndim - 2): + mean, sd = mean.unsqueeze(-1), sd.unsqueeze(-1) + x = (x - mean) / sd + elif self.cfg.normalization == "none": + x = x + return x + + def forward( + self, batch: Dict, return_loss: bool = False, return_features: bool = False + ) -> Dict[str, torch.Tensor]: + x = batch["x"] + y = batch.get("y", None) + + if return_loss: + assert y is not None + + features = self.extract_features(x, normalize=True) + + if self.cfg.multisample_dropout: + logits = torch.stack( + [self.linear(self.dropout(features)) for _ in range(5)] + ).mean(0) + else: + logits = self.linear(self.dropout(features)) + + if self.cfg.model_activation_fn == "sigmoid": + logits = logits.sigmoid() + elif self.cfg.model_activation_fn == "softmax": + logits = logits.softmax(dim=1) + + out = {"logits": logits} + if return_features: + out["features"] = features + if return_loss: + loss = self.criterion(out, batch) + if isinstance(loss, dict): + out.update(loss) + else: + out["loss"] = loss + + return out + + def extract_features(self, x: torch.Tensor, normalize: bool = True) -> torch.Tensor: + x = self.normalize(x) if normalize else x + return self.pooling(self.backbone(x)) + + def freeze_backbone(self) -> None: + for param in self.backbone.parameters(): + param.requires_grad = False + self.backbone_frozen = True + + def set_criterion(self, loss: nn.Module) -> None: + self.criterion = loss diff --git a/skp/models/classification/net2d_multihead.py b/skp/models/classification/net2d_multihead.py new file mode 100644 index 0000000000000000000000000000000000000000..42340be9193826d071c9bbe79e67dbdd23befafe --- /dev/null +++ b/skp/models/classification/net2d_multihead.py @@ -0,0 +1,176 @@ +""" +Simple model for 2D classification (or regression) with multiple heads +Uses timm for backbones +""" + +import re +import torch +import torch.nn as nn + +from collections.abc import Sequence +from timm import create_model +from typing import Dict + +from skp.configs.base import Config +from skp.models.modules import FeatureReduction +from skp.models.pooling import get_pool_layer + + +class Net(nn.Module): + def __init__(self, cfg: Config): + super().__init__() + self.cfg = cfg + assert ( + isinstance(self.cfg.num_classes, Sequence) + and len(self.cfg.num_classes) == self.cfg.num_heads + ), f"cfg.num_classes should be sequence of length {self.cfg.num_heads} corresponding to each head" + backbone_args = { + "pretrained": self.cfg.pretrained, + "num_classes": 0, + "global_pool": "", + "features_only": self.cfg.features_only, + "in_chans": self.cfg.num_input_channels, + } + if self.cfg.backbone_img_size: + # some models require specifying image size (e.g., coatnet) + if "efficientvit" in self.cfg.backbone: + backbone_args["img_size"] = self.cfg.image_height + else: + backbone_args["img_size"] = ( + self.cfg.image_height, + self.cfg.image_width, + ) + self.backbone = create_model(self.cfg.backbone, **backbone_args) + # get feature dim by passing sample through net + self.feature_dim = self.backbone( + torch.randn( + ( + 2, + self.cfg.num_input_channels, + self.cfg.image_height, + self.cfg.image_width, + ) + ) + ).size( + -1 if "xcit" in self.cfg.backbone else 1 + ) # xcit models are channels-last + + self.feature_dim = self.feature_dim * (2 if self.cfg.pool == "catavgmax" else 1) + self.pooling = get_pool_layer(self.cfg, dim=2) + + if isinstance(self.cfg.reduce_feature_dim, int): + self.backbone = nn.Sequential( + self.backbone, + FeatureReduction(self.feature_dim, self.cfg.reduce_feature_dim), + ) + self.feature_dim = self.cfg.reduce_feature_dim + + self.dropout = nn.Dropout(p=self.cfg.dropout) + self.linear = nn.ModuleList() + for i in range(self.cfg.num_heads): + self.linear.append(nn.Linear(self.feature_dim, self.cfg.num_classes[i])) + + if self.cfg.load_pretrained_backbone: + print( + f"Loading pretrained backbone from {self.cfg.load_pretrained_backbone} ..." + ) + weights = torch.load( + self.cfg.load_pretrained_backbone, + map_location=lambda storage, loc: storage, + )["state_dict"] + # Replace model prefix as this does not exist in Net + weights = {re.sub(r"^model.", "", k): v for k, v in weights.items()} + # Get backbone only + weights = { + re.sub(r"^backbone.", "", k): v + for k, v in weights.items() + if "backbone" in k + } + self.backbone.load_state_dict(weights) + + self.criterion = None + + self.backbone_frozen = False + if self.cfg.freeze_backbone: + self.freeze_backbone() + + def normalize(self, x: torch.Tensor) -> torch.Tensor: + if self.cfg.normalization == "-1_1": + mini, maxi = ( + self.cfg.normalization_params["min"], + self.cfg.normalization_params["max"], + ) + x = x - mini + x = x / (maxi - mini) + x = x - 0.5 + x = x * 2.0 + elif self.cfg.normalization == "0_1": + mini, maxi = ( + self.cfg.normalization_params["min"], + self.cfg.normalization_params["max"], + ) + x = x - mini + x = x / (maxi - mini) + elif self.cfg.normalization == "mean_sd": + mean, sd = ( + self.cfg.normalization_params["mean"], + self.cfg.normalization_params["sd"], + ) + x = (x - mean) / sd + elif self.cfg.normalization == "per_channel_mean_sd": + mean, sd = ( + self.cfg.normalization_params["mean"], + self.cfg.normalization_params["sd"], + ) + assert len(mean) == len(sd) == x.size(1) + mean, sd = torch.tensor(mean).unsqueeze(0), torch.tensor(sd).unsqueeze(0) + for i in range(x.ndim - 2): + mean, sd = mean.unsqueeze(-1), sd.unsqueeze(-1) + x = (x - mean) / sd + elif self.cfg.normalization == "none": + x = x + return x + + def forward( + self, batch: Dict, return_loss: bool = False, return_features: bool = False + ) -> Dict[str, torch.Tensor]: + x = batch["x"] + y = batch.get("y", None) + + if return_loss: + assert y is not None + + features = self.extract_features(x, normalize=True) + + out = {} + for head_idx, each_head in enumerate(self.linear): + if self.cfg.multisample_dropout: + logits = torch.stack( + [each_head(self.dropout(features)) for _ in range(5)] + ).mean(0) + else: + logits = each_head(self.dropout(features)) + out[f"logits{head_idx}"] = logits + + if return_features: + out["features"] = features + if return_loss: + loss = self.criterion(out, batch) + if isinstance(loss, dict): + out.update(loss) + else: + out["loss"] = loss + + return out + + def extract_features(self, x: torch.Tensor, normalize: bool = True) -> torch.Tensor: + x = self.normalize(x) if normalize else x + return self.pooling(self.backbone(x)) + + def freeze_backbone(self) -> None: + for param in self.backbone.parameters(): + param.requires_grad = False + self.backbone_frozen = True + + def set_criterion(self, loss: nn.Module) -> None: + self.criterion = loss diff --git a/skp/models/classification/net2d_multihead_var_embed.py b/skp/models/classification/net2d_multihead_var_embed.py new file mode 100644 index 0000000000000000000000000000000000000000..c38c1da9d402b5f737b9137377fbf8802918599d --- /dev/null +++ b/skp/models/classification/net2d_multihead_var_embed.py @@ -0,0 +1,186 @@ +""" +Simple model for 2D classification (or regression) with multiple heads +Incorporates embedding of non-image features +Uses timm for backbones +""" + +import re +import torch +import torch.nn as nn + +from collections.abc import Sequence +from timm import create_model +from typing import Dict + +from skp.configs.base import Config +from skp.models.modules import FeatureReduction +from skp.models.pooling import get_pool_layer + + +class Net(nn.Module): + def __init__(self, cfg: Config): + super().__init__() + self.cfg = cfg + assert ( + isinstance(self.cfg.num_classes, Sequence) + and len(self.cfg.num_classes) == self.cfg.num_heads + ), f"cfg.num_classes should be sequence of length {self.cfg.num_heads} corresponding to each head" + backbone_args = { + "pretrained": self.cfg.pretrained, + "num_classes": 0, + "global_pool": "", + "features_only": self.cfg.features_only, + "in_chans": self.cfg.num_input_channels, + } + if self.cfg.backbone_img_size: + # some models require specifying image size (e.g., coatnet) + if "efficientvit" in self.cfg.backbone: + backbone_args["img_size"] = self.cfg.image_height + else: + backbone_args["img_size"] = ( + self.cfg.image_height, + self.cfg.image_width, + ) + self.backbone = create_model(self.cfg.backbone, **backbone_args) + # get feature dim by passing sample through net + self.feature_dim = self.backbone( + torch.randn( + ( + 2, + self.cfg.num_input_channels, + self.cfg.image_height, + self.cfg.image_width, + ) + ) + ).size( + -1 if "xcit" in self.cfg.backbone else 1 + ) # xcit models are channels-last + + self.feature_dim = self.feature_dim * (2 if self.cfg.pool == "catavgmax" else 1) + self.pooling = get_pool_layer(self.cfg, dim=2) + + if isinstance(self.cfg.reduce_feature_dim, int): + self.backbone = nn.Sequential( + self.backbone, + FeatureReduction(self.feature_dim, self.cfg.reduce_feature_dim), + ) + self.feature_dim = self.cfg.reduce_feature_dim + + self.embed = nn.Embedding(self.cfg.embed_num_classes, self.cfg.embed_dim) + # allows for interaction between elements of image feature vector and embedding + self.mlp = nn.Linear(self.feature_dim + self.cfg.embed_dim, self.feature_dim) + self.dropout = nn.Dropout(p=self.cfg.dropout) + self.linear = nn.ModuleList() + for i in range(self.cfg.num_heads): + self.linear.append(nn.Linear(self.feature_dim, self.cfg.num_classes[i])) + + + if self.cfg.load_pretrained_backbone: + print( + f"Loading pretrained backbone from {self.cfg.load_pretrained_backbone} ..." + ) + weights = torch.load( + self.cfg.load_pretrained_backbone, + map_location=lambda storage, loc: storage, + )["state_dict"] + # Replace model prefix as this does not exist in Net + weights = {re.sub(r"^model.", "", k): v for k, v in weights.items()} + # Get backbone only + weights = { + re.sub(r"^backbone.", "", k): v + for k, v in weights.items() + if "backbone" in k + } + self.backbone.load_state_dict(weights) + + self.criterion = None + + self.backbone_frozen = False + if self.cfg.freeze_backbone: + self.freeze_backbone() + + def normalize(self, x: torch.Tensor) -> torch.Tensor: + if self.cfg.normalization == "-1_1": + mini, maxi = ( + self.cfg.normalization_params["min"], + self.cfg.normalization_params["max"], + ) + x = x - mini + x = x / (maxi - mini) + x = x - 0.5 + x = x * 2.0 + elif self.cfg.normalization == "0_1": + mini, maxi = ( + self.cfg.normalization_params["min"], + self.cfg.normalization_params["max"], + ) + x = x - mini + x = x / (maxi - mini) + elif self.cfg.normalization == "mean_sd": + mean, sd = ( + self.cfg.normalization_params["mean"], + self.cfg.normalization_params["sd"], + ) + x = (x - mean) / sd + elif self.cfg.normalization == "per_channel_mean_sd": + mean, sd = ( + self.cfg.normalization_params["mean"], + self.cfg.normalization_params["sd"], + ) + assert len(mean) == len(sd) == x.size(1) + mean, sd = torch.tensor(mean).unsqueeze(0), torch.tensor(sd).unsqueeze(0) + for i in range(x.ndim - 2): + mean, sd = mean.unsqueeze(-1), sd.unsqueeze(-1) + x = (x - mean) / sd + elif self.cfg.normalization == "none": + x = x + return x + + def forward( + self, batch: Dict, return_loss: bool = False, return_features: bool = False + ) -> Dict[str, torch.Tensor]: + x = batch["x"] + y = batch.get("y", None) + var = batch["var"] + + if return_loss: + assert y is not None + + features = self.extract_features(x, var, normalize=True) + + out = {} + for head_idx, each_head in enumerate(self.linear): + if self.cfg.multisample_dropout: + logits = torch.stack( + [each_head(self.dropout(features)) for _ in range(5)] + ).mean(0) + else: + logits = each_head(self.dropout(features)) + out[f"logits{head_idx}"] = logits + + if return_features: + out["features"] = features + if return_loss: + loss = self.criterion(out, batch) + if isinstance(loss, dict): + out.update(loss) + else: + out["loss"] = loss + + return out + + def extract_features(self, x: torch.Tensor, var: torch.Tensor, normalize: bool = True) -> torch.Tensor: + x = self.normalize(x) if normalize else x + var = self.embed(var) + feat = self.pooling(self.backbone(x)) + feat = torch.cat([feat, var], dim=1) + feat = self.mlp(feat) + return feat + + def freeze_backbone(self) -> None: + for param in self.backbone.parameters(): + param.requires_grad = False + self.backbone_frozen = True + + def set_criterion(self, loss: nn.Module) -> None: + self.criterion = loss diff --git a/skp/models/classification/net2d_var_embed.py b/skp/models/classification/net2d_var_embed.py new file mode 100644 index 0000000000000000000000000000000000000000..9d62a4ce60dd16a7fbefcbd76d22d6e8faf8637b --- /dev/null +++ b/skp/models/classification/net2d_var_embed.py @@ -0,0 +1,178 @@ +""" +Simple model for 2D classification (or regression) +Incorporates embedding of non-image features +Uses timm for backbones +""" + +import re +import torch +import torch.nn as nn + +from timm import create_model +from typing import Dict + +from skp.configs.base import Config +from skp.models.modules import FeatureReduction +from skp.models.pooling import get_pool_layer + + +class Net(nn.Module): + def __init__(self, cfg: Config): + super().__init__() + self.cfg = cfg + backbone_args = { + "pretrained": self.cfg.pretrained, + "num_classes": 0, + "global_pool": "", + "features_only": self.cfg.features_only, + "in_chans": self.cfg.num_input_channels, + } + if self.cfg.backbone_img_size: + # some models require specifying image size (e.g., coatnet) + if "efficientvit" in self.cfg.backbone: + backbone_args["img_size"] = self.cfg.image_height + else: + backbone_args["img_size"] = ( + self.cfg.image_height, + self.cfg.image_width, + ) + self.backbone = create_model(self.cfg.backbone, **backbone_args) + # get feature dim by passing sample through net + self.feature_dim = self.backbone( + torch.randn( + ( + 2, + self.cfg.num_input_channels, + self.cfg.image_height, + self.cfg.image_width, + ) + ) + ).size( + -1 if "xcit" in self.cfg.backbone else 1 + ) # xcit models are channels-last + + self.feature_dim = self.feature_dim * (2 if self.cfg.pool == "catavgmax" else 1) + self.pooling = get_pool_layer(self.cfg, dim=2) + + if isinstance(self.cfg.reduce_feature_dim, int): + self.backbone = nn.Sequential( + self.backbone, + FeatureReduction(self.feature_dim, self.cfg.reduce_feature_dim), + ) + self.feature_dim = self.cfg.reduce_feature_dim + + self.embed = nn.Embedding(self.cfg.embed_num_classes, self.cfg.embed_dim) + # allows for interaction between elements of image feature vector and embedding + self.mlp = nn.Linear(self.feature_dim + self.cfg.embed_dim, self.feature_dim) + self.dropout = nn.Dropout(p=self.cfg.dropout) + self.linear = nn.Linear(self.feature_dim, self.cfg.num_classes) + + if self.cfg.load_pretrained_backbone: + print( + f"Loading pretrained backbone from {self.cfg.load_pretrained_backbone} ..." + ) + weights = torch.load( + self.cfg.load_pretrained_backbone, + map_location=lambda storage, loc: storage, + )["state_dict"] + # Replace model prefix as this does not exist in Net + weights = {re.sub(r"^model.", "", k): v for k, v in weights.items()} + # Get backbone only + weights = { + re.sub(r"^backbone.", "", k): v + for k, v in weights.items() + if "backbone" in k + } + self.backbone.load_state_dict(weights) + + self.criterion = None + + self.backbone_frozen = False + if self.cfg.freeze_backbone: + self.freeze_backbone() + + def normalize(self, x: torch.Tensor) -> torch.Tensor: + if self.cfg.normalization == "-1_1": + mini, maxi = ( + self.cfg.normalization_params["min"], + self.cfg.normalization_params["max"], + ) + x = x - mini + x = x / (maxi - mini) + x = x - 0.5 + x = x * 2.0 + elif self.cfg.normalization == "0_1": + mini, maxi = ( + self.cfg.normalization_params["min"], + self.cfg.normalization_params["max"], + ) + x = x - mini + x = x / (maxi - mini) + elif self.cfg.normalization == "mean_sd": + mean, sd = ( + self.cfg.normalization_params["mean"], + self.cfg.normalization_params["sd"], + ) + x = (x - mean) / sd + elif self.cfg.normalization == "per_channel_mean_sd": + mean, sd = ( + self.cfg.normalization_params["mean"], + self.cfg.normalization_params["sd"], + ) + assert len(mean) == len(sd) == x.size(1) + mean, sd = torch.tensor(mean).unsqueeze(0), torch.tensor(sd).unsqueeze(0) + for i in range(x.ndim - 2): + mean, sd = mean.unsqueeze(-1), sd.unsqueeze(-1) + x = (x - mean) / sd + elif self.cfg.normalization == "none": + x = x + return x + + def forward( + self, batch: Dict, return_loss: bool = False, return_features: bool = False + ) -> Dict[str, torch.Tensor]: + x = batch["x"] + y = batch.get("y", None) + var = batch["var"] + + if return_loss: + assert y is not None + + features = self.extract_features(x, var, normalize=True) + + if self.cfg.multisample_dropout: + logits = torch.stack( + [self.linear(self.dropout(features)) for _ in range(5)] + ).mean(0) + else: + logits = self.linear(self.dropout(features)) + + out = {"logits": logits} + if return_features: + out["features"] = features + if return_loss: + loss = self.criterion(out, batch) + if isinstance(loss, dict): + out.update(loss) + else: + out["loss"] = loss + + return out + + def extract_features( + self, x: torch.Tensor, var: torch.Tensor, normalize: bool = True + ) -> torch.Tensor: + x = self.normalize(x) if normalize else x + var = self.embed(var) + feat = self.pooling(self.backbone(x)) + feat = torch.cat([feat, var], dim=1) + feat = self.mlp(feat) + return feat + + def freeze_backbone(self) -> None: + for param in self.backbone.parameters(): + param.requires_grad = False + self.backbone_frozen = True + + def set_criterion(self, loss: nn.Module) -> None: + self.criterion = loss diff --git a/skp/models/modules.py b/skp/models/modules.py new file mode 100644 index 0000000000000000000000000000000000000000..7c0514a13c82e9205e83727dc7b184c2d01caef0 --- /dev/null +++ b/skp/models/modules.py @@ -0,0 +1,32 @@ +""" +Contains commonly used neural net modules. +""" + +import math +import torch +import torch.nn as nn + + +class FeatureReduction(nn.Module): + """ + Reduce feature dimensionality + Intended use is after the last layer of the neural net backbone, before pooling + Grouped convolution is used to reduce # of extra parameters + """ + + def __init__(self, feature_dim: int, reduce_feature_dim: int): + super().__init__() + groups = math.gcd(feature_dim, reduce_feature_dim) + self.reduce = nn.Conv2d( + feature_dim, + reduce_feature_dim, + groups=groups, + kernel_size=1, + stride=1, + bias=False, + ) + self.bn = nn.BatchNorm2d(reduce_feature_dim) + self.act = nn.ReLU() + + def forward(self, x: torch.Tensor) -> torch.Tensor: + return self.act(self.bn(self.reduce(x))) diff --git a/skp/models/pooling.py b/skp/models/pooling.py new file mode 100644 index 0000000000000000000000000000000000000000..5b43560b78136e6148afe1c14fe343df9adfa137 --- /dev/null +++ b/skp/models/pooling.py @@ -0,0 +1,150 @@ +import torch +import torch.nn as nn +import torch.nn.functional as F + +from timm.layers import SelectAdaptivePool2d + +from skp.configs.base import Config + + +class GeM(nn.Module): + def __init__( + self, p: int = 3, eps: float = 1e-6, dim: int = 2, flatten: bool = True + ): + super().__init__() + self.p = nn.Parameter(torch.ones(1) * p) + self.eps = eps + assert dim in {2, 3}, f"dim must be one of [2, 3], not {dim}" + self.dim = dim + if self.dim == 2: + self.func = F.adaptive_avg_pool2d + elif self.dim == 3: + self.func = F.adaptive_avg_pool3d + self.flatten = nn.Flatten(1) if flatten else nn.Identity() + + def forward(self, x: torch.Tensor) -> torch.Tensor: + # assumes x.shape is (n, c, [t], h, w) + x = self.func(x.clamp(min=self.eps).pow(self.p), output_size=1).pow( + 1.0 / self.p + ) + return self.flatten(x) + + +def adaptive_avgmax_pool3d(x: torch.Tensor, output_size: int = 1): + x_avg = F.adaptive_avg_pool3d(x, output_size) + x_max = F.adaptive_max_pool3d(x, output_size) + return 0.5 * (x_avg + x_max) + + +def adaptive_catavgmax_pool3d(x: torch.Tensor, output_size: int = 1): + x_avg = F.adaptive_avg_pool3d(x, output_size) + x_max = F.adaptive_max_pool3d(x, output_size) + return torch.cat((x_avg, x_max), 1) + + +def select_adaptive_pool3d(x: torch.Tensor, pool_type: str, output_size: int = 1) -> torch.Tensor: + """Selectable global pooling function with dynamic input kernel size""" + if pool_type == "avg": + x = F.adaptive_avg_pool3d(x, output_size) + elif pool_type == "avgmax": + x = adaptive_avgmax_pool3d(x, output_size) + elif pool_type == "catavgmax": + x = adaptive_catavgmax_pool3d(x, output_size) + elif pool_type == "max": + x = F.adaptive_max_pool3d(x, output_size) + else: + assert False, "Invalid pool type: %s" % pool_type + return x + + +class FastAdaptiveAvgPool3d(nn.Module): + def __init__(self, flatten: bool = False): + super(FastAdaptiveAvgPool3d, self).__init__() + self.flatten = flatten + + def forward(self, x: torch.Tensor) -> torch.Tensor: + return x.mean((2, 3, 4), keepdim=not self.flatten) + + +class AdaptiveAvgMaxPool3d(nn.Module): + def __init__(self, output_size: int = 1): + super(AdaptiveAvgMaxPool3d, self).__init__() + self.output_size = output_size + + def forward(self, x: torch.Tensor) -> torch.Tensor: + return adaptive_avgmax_pool3d(x, self.output_size) + + +class AdaptiveCatAvgMaxPool3d(nn.Module): + def __init__(self, output_size: int = 1): + super(AdaptiveCatAvgMaxPool3d, self).__init__() + self.output_size = output_size + + def forward(self, x: torch.Tensor) -> torch.Tensor: + return adaptive_catavgmax_pool3d(x, self.output_size) + + +class SelectAdaptivePool3d(nn.Module): + """Selectable global pooling layer with dynamic input kernel size""" + + def __init__(self, output_size: int = 1, pool_type: str = "fast", flatten: bool = False): + super(SelectAdaptivePool3d, self).__init__() + self.pool_type = ( + pool_type or "" + ) # convert other falsy values to empty string for consistent TS typing + self.flatten = nn.Flatten(1) if flatten else nn.Identity() + if pool_type == "": + self.pool = nn.Identity() # pass through + elif pool_type == "fast": + assert output_size == 1 + self.pool = FastAdaptiveAvgPool3d(flatten) + self.flatten = nn.Identity() + elif pool_type == "avg": + self.pool = nn.AdaptiveAvgPool3d(output_size) + elif pool_type == "avgmax": + self.pool = AdaptiveAvgMaxPool3d(output_size) + elif pool_type == "catavgmax": + self.pool = AdaptiveCatAvgMaxPool3d(output_size) + elif pool_type == "max": + self.pool = nn.AdaptiveMaxPool3d(output_size) + else: + assert False, "Invalid pool type: %s" % pool_type + + def is_identity(self) -> bool: + return not self.pool_type + + def forward(self, x: torch.Tensor) -> torch.Tensor: + x = self.pool(x) + x = self.flatten(x) + return x + + def __repr__(self): + return ( + self.__class__.__name__ + + " (" + + "pool_type=" + + self.pool_type + + ", flatten=" + + str(self.flatten) + + ")" + ) + + +def get_pool_layer(cfg: Config, dim: int) -> nn.Module: + assert cfg.pool in [ + "avg", + "max", + "fast", + "avgmax", + "catavgmax", + "gem", + "" + ], f"{cfg.pool} is not a valid pooling layer" + params = cfg.pool_params or {} + if cfg.pool == "gem": + return GeM(**params, dim=dim) + else: + if dim == 2: + return SelectAdaptivePool2d(pool_type=cfg.pool, flatten=True) + elif dim == 3: + return SelectAdaptivePool3d(pool_type=cfg.pool, flatten=True) diff --git a/skp/utils.py b/skp/utils.py new file mode 100644 index 0000000000000000000000000000000000000000..f8b94ce9621dd7c8c2ac7af22e681688247e7e5b --- /dev/null +++ b/skp/utils.py @@ -0,0 +1,49 @@ +import re +import torch + +from skp.configs import Config +from importlib import import_module +from typing import Dict, Optional, Sequence + + +def load_weights_from_path(path: str) -> Dict[str, torch.Tensor]: + w = torch.load(path, map_location=lambda storage, loc: storage, weights_only=True)[ + "state_dict" + ] + w = { + re.sub(r"^model.", "", k): v + for k, v in w.items() + if k.startswith("model.") and "criterion" not in k + } + return w + + +def load_model_from_config( + cfg: Config, + weights_path: Optional[str] = None, + device: str = "cpu", + eval_mode: bool = True, +) -> torch.nn.Module: + model = import_module(f"skp.models.{cfg.model}").Net(cfg) + if weights_path: + weights = load_weights_from_path(weights_path) + model.load_state_dict(weights) + model = model.to(device).train(mode=not eval_mode) + return model + + +def load_kfold_ensemble_as_list( + cfg: Config, + weights_paths: Sequence[str], + device: str = "cpu", + eval_mode: bool = True, +) -> torch.nn.ModuleList: + # multiple folds for the same model + # does not work for ensembling different types of models + # assumes that trained weights are available + # otherwise why would you load multiple of the same model randomly initialized + model_list = torch.nn.ModuleList() + for each_weight in weights_paths: + model = load_model_from_config(cfg, each_weight, device, eval_mode) + model_list.append(model) + return model_list