| import gradio as gr |
| import timm |
| import torch |
| import torch.nn as nn |
|
|
|
|
def change_num_input_channels(model, in_channels=1):
    """Adapt the first 3-channel convolution of ``model`` to ``in_channels`` inputs.

    The replacement kernel is the original kernel summed over its 3 input
    channels, divided by ``in_channels`` and repeated ``in_channels`` times.
    Feeding the adapted layer ``in_channels`` identical planes therefore gives
    the same output as feeding the original layer three copies of that plane.

    Only the first ``nn.Conv2d``/``nn.Conv3d`` (in ``model.modules()`` order)
    whose ``in_channels`` is 3 is modified. If there is none, the model is
    returned unchanged.

    Args:
        model: module to modify in place.
        in_channels: number of input channels the adapted layer should accept.

    Returns:
        The same ``model`` object, for chaining.

    Raises:
        ValueError: if ``in_channels`` is smaller than 1.
    """
    if in_channels < 1:
        raise ValueError(f"in_channels must be >= 1, got {in_channels}")
    for module in model.modules():
        if isinstance(module, (nn.Conv2d, nn.Conv3d)) and module.in_channels == 3:
            # Pure weight surgery: no reason to record it in the autograd graph.
            with torch.no_grad():
                averaged = module.weight.sum(1, keepdim=True) / in_channels
                repeats = [1] * averaged.ndim
                repeats[1] = in_channels  # tile along the input-channel axis only
                new_weight = averaged.repeat(repeats)
            module.in_channels = in_channels
            module.weight = nn.Parameter(new_weight)
            break
    return model
|
|
|
|
class Net2D(nn.Module):
    """EfficientNetV2-S (timm ``tf_efficientnetv2_s``) single-output regressor.

    The backbone's stem is adapted to 2 input channels, and the weights in
    ``weights`` are loaded into the assembled network. ``load_state_dict`` is
    called with its default ``strict=True``, so key or shape mismatches raise.

    Args:
        weights: state dict whose keys match this module's own
            (``backbone.*``, ``classifier.*``).
    """

    def __init__(self, weights):
        super().__init__()
        # num_classes=0 / global_pool="" make the backbone return spatial
        # feature maps; pooling and the head are applied explicitly below.
        self.backbone = timm.create_model(
            "tf_efficientnetv2_s", pretrained=False, global_pool="", num_classes=0
        )
        # Two input channels: predict_bone_age stacks the normalized image
        # with a constant +1/-1 sex-indicator plane.
        self.backbone = change_num_input_channels(self.backbone, 2)
        self.pool_layer = nn.AdaptiveAvgPool2d(1)
        self.dropout = nn.Dropout(0.2)
        # Take the feature width from the backbone rather than hard-coding
        # it (1280 for this architecture), so the head always matches.
        self.classifier = nn.Linear(self.backbone.num_features, 1)
        self.load_state_dict(weights)

    def forward(self, x):
        """Return one prediction per batch element.

        The output has shape ``(batch,)`` when the classifier has a single
        output, which is always the case as constructed here.
        """
        x = self.backbone(x)
        x = self.pool_layer(x).flatten(1)  # (batch, C, 1, 1) -> (batch, C)
        x = self.dropout(x)
        x = self.classifier(x)
        return x[:, 0] if x.size(1) == 1 else x
|
|
|
|
class Ensemble(nn.Module):
    """Mean-ensemble wrapper that runs every member on the same input and averages.

    Args:
        model_list: iterable of ``nn.Module`` members. They are held in an
            ``nn.ModuleList``, so ``.eval()``, ``.to()`` and ``state_dict``
            reach each member.
    """

    def __init__(self, model_list):
        super().__init__()
        self.model_list = nn.ModuleList(model_list)

    def forward(self, x):
        per_model = []
        for member in self.model_list:
            per_model.append(member(x))
        # New leading axis indexes the member; averaging collapses it.
        return torch.stack(per_model, dim=0).mean(dim=0)
|
|
|
|
# Checkpoint files whose "state_dict" entries become the ensemble members.
checkpoints = ["fold0.ckpt", "fold1.ckpt", "fold2.ckpt"]


def _strip_prefix(key, prefix="model."):
    """Remove ``prefix`` from the start of ``key`` only, never from its middle."""
    return key[len(prefix):] if key.startswith(prefix) else key


# NOTE: torch.load unpickles the file, so only load checkpoints you trust.
weights = [
    torch.load(ckpt, map_location=torch.device("cpu"))["state_dict"]
    for ckpt in checkpoints
]
# Checkpoint keys carry a leading "model." that Net2D's own parameter names
# lack. Strip just that prefix: str.replace would also mangle any key that
# happens to contain "model." further in.
weights = [{_strip_prefix(k): v for k, v in wt.items()} for wt in weights]
models = [Net2D(wt) for wt in weights]
ensemble = Ensemble(models).eval()
|
|
def _pluralize(count, unit):
    """Return ``"<count> <unit>"``, adding an "s" unless ``count`` is exactly 1."""
    return f"{count} {unit}" if count == 1 else f"{count} {unit}s"


def _preprocess_radiograph(Radiograph, Sex):
    """Build the 2-channel model input from a grayscale image and a sex index.

    Channel 0 is the image divided by its maximum and then mapped through
    ``(v - 0.5) * 2``. Channel 1 is a constant plane of +1 when ``Sex == 1``
    and -1 otherwise.
    """
    img = torch.from_numpy(Radiograph).float()
    # (H, W) -> (1, 1, H, W). Assumes a 2-D array, i.e. Gradio's
    # image_mode="L" output; TODO confirm for other input sources.
    img = img.unsqueeze(0).unsqueeze(0)
    peak = img.max()
    # Guard against an all-black image, which would otherwise divide by zero
    # and feed NaNs to the model.
    if peak > 0:
        img = img / peak
    img = (img - 0.5) * 2.0
    sex_plane = torch.full_like(img, 1.0 if Sex == 1 else -1.0)
    return torch.cat([img, sex_plane], dim=1)


def _format_bone_age(bone_age):
    """Render an age given in fractional years as "Y years, M months" text."""
    # Round to whole months before splitting, so 11.6 months reads "1 year"
    # rather than "12 months". Negative predictions are clamped to 0.
    total_months = max(round(bone_age * 12), 0)
    years, months = divmod(total_months, 12)
    if years == 0:
        return _pluralize(months, "month")
    if months == 0:
        return _pluralize(years, "year")
    return f"{_pluralize(years, 'year')}, {_pluralize(months, 'month')}"


def predict_bone_age(Radiograph, Sex, *, model=None):
    """Estimate bone age from a hand radiograph and return it as display text.

    The parameter names double as the Gradio input labels, so keep them.

    Args:
        Radiograph: 2-D grayscale numpy array (from the ``gr.Image`` input).
        Sex: selected radio index. 0 = "Male" and 1 = "Female" in the UI's
            choice order; 1 becomes a +1 indicator plane, 0 becomes -1.
        model: callable taking the (1, 2, H, W) float tensor and returning a
            tensor whose first element is the age in years. Defaults to the
            module-level ``ensemble``.

    Returns:
        A string such as ``"Estimated Bone Age: 10 years, 3 months"``.

    Raises:
        ValueError: if no image was provided or no sex was selected.
    """
    if model is None:
        model = ensemble
    if Radiograph is None:
        raise ValueError("Please upload a radiograph.")
    if Sex not in (0, 1):
        # Previously a missing selection was silently treated as "Male".
        raise ValueError("Please select the patient's sex.")
    img = _preprocess_radiograph(Radiograph, Sex)
    with torch.no_grad():
        bone_age = model(img)[0].item()
    return f"Estimated Bone Age: {_format_bone_age(bone_age)}"
|
|
|
|
# Inputs: a grayscale ("L") radiograph, and the patient's sex as a radio
# button. type="index" hands predict_bone_age the position of the selected
# choice (0 = "Male", 1 = "Female") instead of its text.
image = gr.Image(image_mode="L", shape=(512, 512))
sex = gr.Radio(choices=["Male", "Female"], type="index")

# Output: shows the text returned by predict_bone_age.
label = gr.Label(label="Result", show_label=True)

demo = gr.Interface(predict_bone_age, [image, sex], label)
|
|
|
|
if __name__ == "__main__":
    # Start the Gradio server only when run as a script; importing this
    # module still builds `demo` but does not launch it.
    demo.launch()
|
|
|
|
|
|