|
|
| import gradio as gr |
| import torch |
| import torch.nn as nn |
| import torchvision |
| from torchvision.transforms import functional as F |
| from PIL import Image |
| import numpy as np |
| import requests |
| import os |
| import cv2 |
| from pytorch_grad_cam import GradCAM |
| from pytorch_grad_cam.utils.model_targets import ClassifierOutputTarget |
| from pytorch_grad_cam.utils.image import show_cam_on_image |
| from torchvision.models import resnet18 |
|
|
| |
| |
| |
|
|
| |
class DoubleConv(nn.Module):
    """Two stacked 3x3 convolutions, each followed by batch norm and ReLU.

    Args:
        in_channels: Channel count of the incoming feature map.
        out_channels: Channel count produced by the second convolution.
        mid_channels: Channel count between the two convolutions. Any falsy
            value (including the default ``None``) means "use ``out_channels``".
    """

    def __init__(self, in_channels, out_channels, mid_channels=None):
        super().__init__()
        hidden = mid_channels if mid_channels else out_channels
        stages = []
        for src, dst in ((in_channels, hidden), (hidden, out_channels)):
            stages.extend([
                nn.Conv2d(src, dst, kernel_size=3, padding=1),
                nn.BatchNorm2d(dst),
                nn.ReLU(inplace=True),
            ])
        # The position of each layer inside the Sequential determines the
        # state-dict keys (double_conv.0 ... double_conv.5), so keep the order.
        self.double_conv = nn.Sequential(*stages)

    def forward(self, x):
        """Run ``x`` through both conv/BN/ReLU stages."""
        return self.double_conv(x)
|
|
class Down(nn.Module):
    """Encoder stage: a 2x2 max-pool followed by a ``DoubleConv``.

    Args:
        in_channels: Channel count of the incoming feature map.
        out_channels: Channel count produced by the ``DoubleConv``.
    """

    def __init__(self, in_channels, out_channels):
        super().__init__()
        pool = nn.MaxPool2d(2)
        convs = DoubleConv(in_channels, out_channels)
        self.maxpool_conv = nn.Sequential(pool, convs)

    def forward(self, x):
        """Pool ``x`` and apply the double convolution."""
        return self.maxpool_conv(x)
|
|
class Up(nn.Module):
    """Decoder stage: upsample, merge with the skip connection, then ``DoubleConv``.

    Args:
        in_channels: Channel count of the concatenated (skip + upsampled) map
            that is fed to the ``DoubleConv``.
        out_channels: Channel count produced by the ``DoubleConv``.
        bilinear: When true, upsample with fixed bilinear interpolation (the
            channel count is unchanged by upsampling); otherwise use a
            transposed convolution that halves the channel count.
    """

    def __init__(self, in_channels, out_channels, bilinear=True):
        super().__init__()
        if bilinear:
            self.up = nn.Upsample(scale_factor=2, mode='bilinear', align_corners=True)
            self.conv = DoubleConv(in_channels, out_channels, in_channels // 2)
        else:
            self.up = nn.ConvTranspose2d(in_channels, in_channels // 2, kernel_size=2, stride=2)
            self.conv = DoubleConv(in_channels, out_channels)

    def forward(self, x1, x2):
        """Upsample ``x1``, align it to ``x2`` spatially, concatenate and convolve.

        Args:
            x1: Feature map coming from the deeper layer (upsampled here).
            x2: Skip-connection feature map from the encoder.

        Returns:
            The output of ``self.conv`` applied to ``cat([x2, x1], dim=1)``.
        """
        x1 = self.up(x1)
        # Pooling floors odd sizes, so after a 2x upsample x1 can be smaller
        # than the skip connection; pad it so the concat below cannot fail.
        # When the sizes already match this is skipped entirely.
        diff_y = x2.size(2) - x1.size(2)
        diff_x = x2.size(3) - x1.size(3)
        if diff_y or diff_x:
            x1 = nn.functional.pad(
                x1,
                [diff_x // 2, diff_x - diff_x // 2, diff_y // 2, diff_y - diff_y // 2],
            )
        x = torch.cat([x2, x1], dim=1)
        return self.conv(x)
|
|
class UNet(nn.Module):
    """U-Net built from ``DoubleConv``, ``Down`` and ``Up`` stages.

    Args:
        n_channels: Channel count of the input image.
        n_classes: Channel count of the output produced by the final 1x1
            convolution.
        bilinear: Forwarded to every ``Up`` stage; it also halves the channel
            widths around the bottleneck when true.
    """

    def __init__(self, n_channels=3, n_classes=1, bilinear=True):
        super().__init__()
        factor = 2 if bilinear else 1

        # Encoder path.
        self.inc = DoubleConv(n_channels, 64)
        self.down1 = Down(64, 128)
        self.down2 = Down(128, 256)
        self.down3 = Down(256, 512)
        self.down4 = Down(512, 1024 // factor)

        # Decoder path.
        self.up1 = Up(1024, 512 // factor, bilinear)
        self.up2 = Up(512, 256 // factor, bilinear)
        self.up3 = Up(256, 128 // factor, bilinear)
        self.up4 = Up(128, 64, bilinear)

        # 1x1 convolution mapping the 64 decoder channels to n_classes.
        self.outc = nn.Conv2d(64, n_classes, 1)

    def forward(self, x):
        """Return the output of ``outc`` for input ``x``."""
        skip1 = self.inc(x)
        skip2 = self.down1(skip1)
        skip3 = self.down2(skip2)
        skip4 = self.down3(skip3)
        bottleneck = self.down4(skip4)

        # Each decoder stage consumes the matching encoder output, deepest first.
        decoded = self.up1(bottleneck, skip4)
        decoded = self.up2(decoded, skip3)
        decoded = self.up3(decoded, skip2)
        decoded = self.up4(decoded, skip1)
        return self.outc(decoded)
|
|
| |
| |
| |
|
|
| |
# Every model and tensor in this file is placed on this device.
device = torch.device("cpu")

from torchvision import transforms

# Per-channel normalization statistics shared by both preprocessing pipelines.
_NORM_MEAN = [0.485, 0.456, 0.406]
_NORM_STD = [0.229, 0.224, 0.225]


def _make_preprocess(side):
    """Return a resize -> to-tensor -> normalize pipeline for ``side`` x ``side`` inputs."""
    return transforms.Compose([
        transforms.Resize((side, side)),
        transforms.ToTensor(),
        transforms.Normalize(mean=_NORM_MEAN, std=_NORM_STD),
    ])


# 256x256 pipeline and 224x224 pipeline, respectively.
val_transforms = _make_preprocess(256)
classification_transforms = _make_preprocess(224)

# Class labels, indexed by classifier output position.
CLASS_NAMES = [
    'work_dress',
    'sling_dress',
    'ethnic_dress',
    'gown',
    'casual_dress',
    'party_dress',
    'formal_dress',
    'sports_dress',
    'shirt_dress',
    'resort_dress',
]
|
|
| |
def load_models(
    segmentation_weights="unet_dress_segmentation.pth",
    classification_weights="best_cnn_model_LR_1e_4.pt",
):
    """Build the three models, load their weights and switch them to eval mode.

    Args:
        segmentation_weights: Path of the state dict loaded into the ``UNet``.
        classification_weights: Path of the state dict loaded into the
            ResNet-18 classifier.

    Returns:
        A ``(detection_model, segmentation_model, classification_model)``
        tuple; every model is on ``device`` and in eval mode.

    Errors raised by ``torch.load`` / ``load_state_dict`` (missing file,
    mismatching keys) are not caught here.
    """
    # `pretrained=True` is the deprecated spelling of the COCO_V1 weights; use
    # the `weights=` API, as the classifier below already does.
    detection_weights = torchvision.models.detection.FasterRCNN_ResNet50_FPN_Weights.COCO_V1
    detection_model = torchvision.models.detection.fasterrcnn_resnet50_fpn(
        weights=detection_weights
    ).to(device)
    detection_model.eval()

    segmentation_model = UNet(n_channels=3, n_classes=1).to(device)
    # NOTE(review): torch.load unpickles the file; only load trusted checkpoints.
    segmentation_model.load_state_dict(torch.load(segmentation_weights, map_location=device))
    segmentation_model.eval()

    classification_model = resnet18(weights=None)
    # Size the head from CLASS_NAMES so the two cannot drift apart.
    classification_model.fc = nn.Linear(classification_model.fc.in_features, len(CLASS_NAMES))
    classification_model.load_state_dict(torch.load(classification_weights, map_location=device))
    # Move to the device only after the new head is attached, otherwise the
    # replacement `fc` layer would be left behind on the default device.
    classification_model = classification_model.to(device)
    classification_model.eval()

    return detection_model, segmentation_model, classification_model
|
|
# Load all three models once, when the module is executed.
_loaded_models = load_models()
detection_model, segmentation_model, classification_model = _loaded_models
print("All models have been successfully loaded to the CPU.")
|
|
| |
| |
| |
|
|
# Detector label treated as "person" and the minimum score a detection needs.
_PERSON_LABEL = 1
_PERSON_SCORE_THRESHOLD = 0.8


def _find_person_box(image_tensor):
    """Return the first qualifying person box as ``(x1, y1, x2, y2)`` ints, or ``None``."""
    with torch.no_grad():
        detections = detection_model(image_tensor)[0]
    candidates = zip(detections['boxes'], detections['labels'], detections['scores'])
    for box, label, score in candidates:
        if label.item() == _PERSON_LABEL and score.item() > _PERSON_SCORE_THRESHOLD:
            # int() truncates the float coordinates.
            return tuple(int(value) for value in box.cpu().numpy())
    return None


def _segment_dress(person_crop):
    """Return a uint8 0/1 mask with the same height and width as ``person_crop``."""
    seg_input = val_transforms(person_crop).unsqueeze(0).to(device)
    with torch.no_grad():
        mask_logits = segmentation_model(seg_input)
    mask = (torch.sigmoid(mask_logits) > 0.5).squeeze().cpu().numpy().astype(np.uint8)
    # The mask is predicted at the segmentation input size; bring it back to
    # the crop's size (cv2 expects (width, height)).
    return cv2.resize(mask, (person_crop.width, person_crop.height))


def _grad_cam_overlay(input_tensor, class_index, rgb_uint8):
    """Return the Grad-CAM heatmap for ``class_index`` blended over ``rgb_uint8``."""
    target_layers = [classification_model.layer4[-1]]
    targets = [ClassifierOutputTarget(class_index)]
    # The context manager releases the hooks GradCAM registers on the model,
    # so they do not pile up across requests.
    with GradCAM(model=classification_model, target_layers=target_layers) as cam:
        heatmap = cam(input_tensor=input_tensor, targets=targets)[0, :]
    # The heatmap comes back at the classifier input size while the image has
    # the crop's size; blending them directly fails, so resize the heatmap.
    height, width = rgb_uint8.shape[:2]
    heatmap = cv2.resize(heatmap, (width, height))
    base_image = (rgb_uint8 / 255.0).astype(np.float32)
    return show_cam_on_image(base_image, heatmap, use_rgb=True)


def process_image(input_image):
    """Run detection, segmentation, classification and Grad-CAM on one image.

    Args:
        input_image: A PIL image, or ``None`` when nothing was uploaded.

    Returns:
        A 4-tuple ``(extracted_dress, grad_cam, confidences, mask)``:
        the masked person crop (PIL image), the Grad-CAM overlay (array
        returned by ``show_cam_on_image``), a ``{class name: probability}``
        dict, and the segmentation mask as a PIL image. When there is no
        input or no usable person detection, the images are ``None`` and the
        third element is a message string instead.
    """
    if input_image is None:
        return None, None, "No image provided.", None

    original_pil_img = input_image.convert("RGB")
    img_tensor = F.to_tensor(original_pil_img).unsqueeze(0).to(device)

    box = _find_person_box(img_tensor)
    if box is None:
        return None, None, "No person detected.", None

    x1, y1, x2, y2 = box
    if x2 <= x1 or y2 <= y1:
        # Truncation to ints can collapse a very thin box to zero area, which
        # cannot be cropped and resized.
        return None, None, "No person detected.", None

    person_crop_pil = original_pil_img.crop((x1, y1, x2, y2))
    mask_resized = _segment_dress(person_crop_pil)

    # Zero out every pixel outside the mask (mask broadcasts over channels).
    extracted_dress_np = np.array(person_crop_pil) * mask_resized[..., np.newaxis]
    extracted_dress_pil = Image.fromarray(extracted_dress_np)

    class_input_tensor = classification_transforms(extracted_dress_pil).unsqueeze(0).to(device)
    with torch.no_grad():
        output_logits = classification_model(class_input_tensor)
    probabilities = torch.softmax(output_logits, dim=1)[0]
    confidences = dict(zip(CLASS_NAMES, probabilities.tolist()))
    predicted_index = int(probabilities.argmax())

    visualization = _grad_cam_overlay(class_input_tensor, predicted_index, extracted_dress_np)
    mask_image = Image.fromarray((mask_resized * 255).astype(np.uint8))
    return extracted_dress_pil, visualization, confidences, mask_image
|
|
|
|
| |
| |
| |
# Heading and Markdown blurb passed to the Gradio interface.
title = "👗✨ FashionAI: Dress Analysis Pipeline ✨👗"
description = """
**An end-to-end Computer Vision Pipeline.**
Upload an image of a person wearing a dress. The AI will first detect the person, then segment the dress, classify its style, and finally show which part of the dress was most important for its decision.
\n*Built with PyTorch, torchvision, Gradio, and ❤️ by a DS405B student.*
"""

# Sample images offered as examples in the interface.
_example_urls = (
    "https://images.pexels.com/photos/1036627/pexels-photo-1036627.jpeg",
    "https://images.pexels.com/photos/1126993/pexels-photo-1126993.jpeg",
    "https://images.pexels.com/photos/985635/pexels-photo-985635.jpeg",
)

# Output components, in the same order as process_image's return tuple.
_result_panels = [
    gr.Image(type="pil", label="Extracted Dress"),
    gr.Image(type="pil", label="Grad-CAM Explanation"),
    gr.Label(num_top_classes=3, label="Classification Probabilities"),
    gr.Image(type="pil", label="Segmentation Mask"),
]

iface = gr.Interface(
    fn=process_image,
    inputs=gr.Image(type="pil", label="Upload Your Image"),
    outputs=_result_panels,
    title=title,
    description=description,
    examples=[[url] for url in _example_urls],
)

# Start the web app.
iface.launch()
|
|