# Spaces: Sleeping
import json
import logging

import cv2
import gradio as gr
import numpy as np
import spaces
import torch
import torch.nn as nn
from einops import rearrange
from pytorch_grad_cam import GradCAM
from pytorch_grad_cam.utils.model_targets import ClassifierOutputTarget
from skimage.exposure import match_histograms
from transformers import AutoModel
# IMPORTANT: with transformers >= 5.0, loading can fail with an AttributeError
# about `all_tied_weights_keys`. The preferred fix is to pin an older release in
# the Space's requirements.txt, for example:
#     transformers==4.44.2  # or 4.45.2, 4.46.0 - avoid v5.0+ until the model is updated
# If requirements.txt cannot be changed, or the error persists, the optional
# monkey-patch below is a fallback for transformers 5.0+.
def patch_tied_weights(cls):
    """Ensure ``cls`` exposes an ``all_tied_weights_keys`` attribute.

    If the class (or one of its bases) already has the attribute it is left
    untouched; otherwise an empty set is attached to the class itself.

    Args:
        cls: The class object to patch in place.

    Returns:
        The same class object, so the call can be used inline.
    """
    if hasattr(cls, "all_tied_weights_keys"):
        return cls
    # Empty set: no tied weights in this model
    cls.all_tied_weights_keys = set()
    return cls
# The patch above is applied to the crop model's class right after that model is
# loaded (see the `__main__` block at the bottom of this file). That call can be
# removed once transformers has been downgraded successfully.
class ModelForGradCAM(nn.Module):
    """Single-input wrapper around a model that also needs a ``female`` argument.

    The wrapped model is called as ``model(x, female, return_logits=True)``;
    this wrapper fixes ``female`` at construction time so that callers only
    have to supply ``x``.
    """

    def __init__(self, model, female):
        """Store the wrapped model and the fixed ``female`` argument."""
        super().__init__()
        self.model = model
        self.female = female

    def forward(self, x):
        """Return the wrapped model's output for ``x`` with ``return_logits=True``."""
        logits = self.model(x, self.female, return_logits=True)
        return logits
def convert_bone_age_to_string(bone_age: float) -> str:
    """Format a bone age given in months as a human-readable string.

    The value is split into whole years and months rounded to the nearest
    integer. A zero component is omitted when the other is non-zero, and each
    component is singular or plural as appropriate ("1 year, 1 month",
    "2 years, 5 months", "7 months", "3 years").

    Args:
        bone_age: Bone age in months.

    Returns:
        The formatted age string.
    """
    # bone_age in months
    years = round(bone_age // 12)
    months = round(bone_age - years * 12)
    # Rounding can push the remainder up to a full year (e.g. 23.7 months).
    if months == 12:
        years += 1
        months = 0

    # Pluralize each part on its own so that "1 year" is also used when a
    # month part follows (previously rendered as "1 years, 5 months").
    year_part = "1 year" if years == 1 else f"{years} years"
    month_part = "1 month" if months == 1 else f"{months} months"

    if years == 0:
        return month_part
    if months == 0:
        return year_part
    return f"{year_part}, {month_part}"
def predict_bone_age(Radiograph, Sex, Heatmap):
    """Predict bone age from a hand radiograph (Gradio callback).

    Uses the module-level globals ``crop_model``, ``model``, ``device``,
    ``ref_img`` and ``greulich_and_pyle_ages``.

    Args:
        Radiograph: Image array supplied by the UI, indexed as [row, column].
        Sex: Selected sex as an index; falsy selects the "male" reference
            ages, truthy the "female" ones.
        Heatmap: Truthy to overlay a Grad-CAM heatmap on the returned image.

    Returns:
        A 3-tuple ``(bone_age_text, closest_reference_ages_text, image)``.
        On any failure the tuple is
        ``("Error during prediction", <error message>, None)`` instead of an
        exception, so the UI can display the problem.
    """
    logger = logging.getLogger(__name__)
    try:
        # Fail early with a readable message when the form is incomplete.
        if Radiograph is None:
            raise ValueError("No radiograph was provided.")
        if Sex is None:
            raise ValueError("Sex must be selected.")

        # crop
        if crop_model is not None:
            crop_input = crop_model.preprocess(Radiograph)
            crop_input = torch.from_numpy(crop_input).float().to(device)
            crop_input = rearrange(crop_input, "h w -> 1 1 h w")
            img_shape = torch.tensor([Radiograph.shape[:2]]).to(device)
            with torch.inference_mode():
                box = crop_model(crop_input, img_shape=img_shape).to("cpu").numpy()
            # Slice bounds must be integers; cast defensively in case the
            # box comes back with a floating-point dtype.
            left, top, width, height = (int(v) for v in box[0])
            cropped = Radiograph[top : top + height, left : left + width]
            if cropped.size == 0:
                raise ValueError("The predicted crop box is empty.")
        else:
            # The startup code sets crop_model to None when it cannot be
            # loaded; fall back to the uncropped radiograph in that case.
            logger.warning("Crop model unavailable; using the uncropped radiograph.")
            cropped = Radiograph

        # histogram matching
        if ref_img is None:
            raise RuntimeError(
                "Reference image for histogram matching is not loaded."
            )
        matched = match_histograms(cropped, ref_img)

        x = model.preprocess(matched)
        x = torch.from_numpy(x).float().to(device)
        x = rearrange(x, "h w -> 1 1 h w")
        female = torch.tensor([Sex]).to(device)
        with torch.inference_mode():
            bone_age = model(x, female)[0].item()

        # get closest G&P ages
        gp_ages = greulich_and_pyle_ages["female" if Sex else "male"]
        order = np.argsort(np.abs(bone_age - gp_ages))
        closest1 = convert_bone_age_to_string(gp_ages[order[0]])
        closest2 = convert_bone_age_to_string(gp_ages[order[1]])
        bone_age_str = convert_bone_age_to_string(bone_age)

        # The model input, shown as-is or used as the heatmap background.
        base_image = cv2.cvtColor(
            x[0, 0].to("cpu").numpy().astype("uint8"), cv2.COLOR_GRAY2RGB
        )

        if Heatmap:
            # net1 and net2 to give good GradCAMs
            # net0 is bad for some reason
            model_grad_cam = ModelForGradCAM(model.net1, female)
            target_layers = [model_grad_cam.model.backbone.stages[-1]]
            targets = [ClassifierOutputTarget(round(bone_age))]
            with GradCAM(model=model_grad_cam, target_layers=target_layers) as cam:
                grayscale_cam = cam(input_tensor=x, targets=targets, eigen_smooth=True)
            heatmap = cv2.applyColorMap(
                (grayscale_cam[0] * 255).astype("uint8"), cv2.COLORMAP_JET
            )
            image_weight = 0.6
            # heatmap[..., ::-1] reverses the channel order (BGR -> RGB) so it
            # matches the RGB background before blending.
            blended = (1 - image_weight) * heatmap[..., ::-1] + image_weight * base_image
            grad_cam_image = blended.astype("uint8")
        else:
            grad_cam_image = base_image

        return (
            bone_age_str,
            f"The closest Greulich & Pyle bone ages are:\n 1) {closest1}\n 2) {closest2}",
            grad_cam_image,
        )
    except Exception as e:
        # UI boundary: keep the traceback in the logs and show the message to
        # the user rather than crashing the request.
        logger.exception("Bone age prediction failed")
        return "Error during prediction", str(e), None
# Gradio UI
# Input components. Both radios use type="index", so the callback receives the
# position of the selected option (e.g. "Male" -> 0, "Female" -> 1) rather
# than its label.
image = gr.Image(image_mode="L")
sex = gr.Radio(["Male", "Female"], type="index", label="Sex")
generate_heatmap = gr.Radio(["No", "Yes"], type="index", label="Generate Heatmap?")
# Output components, in the same order as the tuple returned by predict_bone_age.
label = gr.Label(show_label=False)
textbox = gr.Textbox(show_label=False)
grad_cam_image = gr.Image(image_mode="RGB", label="Heatmap / Image")
# `demo` is the app object launched from the `__main__` block.
with gr.Blocks() as demo:
    # Introductory text and usage disclaimer shown above the interface.
    gr.Markdown(
        """
        # Deep Learning Model for Pediatric Bone Age
        This model predicts the bone age from a single frontal view hand radiograph.
        There is also an option to output a heatmap over the radiograph to show regions where the model is focusing on
        to make its prediction. However, this takes extra computation and will increase the runtime.
        This model is for demonstration purposes only and has NOT been approved by any regulatory agency for clinical use. The user assumes
        any and all responsibility regarding their own use of this model and its outputs. Do NOT upload any images containing protected
        health information, as this demonstration is not compliant with patient privacy laws.
        **System by Simon**
        Contact: ithacks254@gmail.com
        Forensic age estimation support tool – AI-assisted only. Always verify with qualified expert.
        """
    )
    # NOTE(review): the models used by predict_bone_age are only loaded in the
    # `__main__` block further down; confirm that example caching does not
    # invoke the callback before they exist.
    gr.Interface(
        fn=predict_bone_age,
        inputs=[image, sex, generate_heatmap],
        outputs=[label, textbox, grad_cam_image],
        examples=[
            ["examples/2639.png", "Female", "Yes"],
            ["examples/10043.png", "Female", "No"],
            ["examples/8888.png", "Female", "Yes"],
        ],
        cache_examples=True,
    )
if __name__ == "__main__":
    # Script entry point: load the models and reference data that
    # predict_bone_age reads as module-level globals, then start the app.
    device = "cuda" if torch.cuda.is_available() else "cpu"
    print(f"Using device `{device}` ...")

    # The crop model is optional: on failure the app still starts with
    # crop_model set to None.
    try:
        crop_model = AutoModel.from_pretrained(
            "ianpan/bone-age-crop", trust_remote_code=True
        )
        # Apply monkey-patch if needed (fallback). The class is patched in
        # place, so nothing has to be assigned back to the instance.
        patch_tied_weights(type(crop_model))
    except Exception as e:
        print(f"Error loading crop model: {e}")
        crop_model = None  # Optional: continue without cropping if fails

    model = AutoModel.from_pretrained("ianpan/bone-age", trust_remote_code=True)

    # Explicit None check instead of relying on the truthiness of a module.
    if crop_model is not None:
        crop_model = crop_model.eval().to(device)
    model = model.eval().to(device)

    # Reference image for histogram matching, read as grayscale. cv2.imread
    # returns None for a missing or unreadable file instead of raising.
    ref_img_path = "ref_img.png"
    try:
        ref_img = cv2.imread(ref_img_path, cv2.IMREAD_GRAYSCALE)
    except Exception as e:
        print(f"Error loading ref_img: {e}")
        ref_img = None
    else:
        if ref_img is None:
            print(f"Error loading ref_img: ref_img.png not found at {ref_img_path}")

    # Reference bone ages, converted to arrays keyed as in the JSON file.
    with open("greulich_and_pyle_ages.json", "r", encoding="utf-8") as f:
        greulich_and_pyle_ages = json.load(f)["bone_ages"]
    greulich_and_pyle_ages = {
        k: np.asarray(v) for k, v in greulich_and_pyle_ages.items()
    }

    demo.launch(share=True)