import math
import os

import cv2
import gradio as gr
import numpy as np
import requests
import torch
import torch.nn as nn
import torchvision
from PIL import Image
from pytorch_grad_cam import GradCAM
from pytorch_grad_cam.utils.image import show_cam_on_image
from pytorch_grad_cam.utils.model_targets import ClassifierOutputTarget
from torchvision import transforms
from torchvision.models import resnet18
from torchvision.transforms import functional as F

# --------------------------------------------------------------------------
# 1. Model definitions
#    All architectures needed to rebuild the saved checkpoints live here.
#    Module attribute names must not change: they are the state_dict keys.
# --------------------------------------------------------------------------


# --- U-Net model definition ---
class DoubleConv(nn.Module):
    """(Conv3x3 -> BatchNorm -> ReLU) applied twice; padding=1 keeps H and W."""

    def __init__(self, in_channels, out_channels, mid_channels=None):
        super().__init__()
        if not mid_channels:
            mid_channels = out_channels
        self.double_conv = nn.Sequential(
            nn.Conv2d(in_channels, mid_channels, kernel_size=3, padding=1),
            nn.BatchNorm2d(mid_channels),
            nn.ReLU(inplace=True),
            nn.Conv2d(mid_channels, out_channels, kernel_size=3, padding=1),
            nn.BatchNorm2d(out_channels),
            nn.ReLU(inplace=True),
        )

    def forward(self, x):
        return self.double_conv(x)


class Down(nn.Module):
    """Encoder step: 2x2 max-pool (halves H and W) followed by a DoubleConv."""

    def __init__(self, in_channels, out_channels):
        super().__init__()
        self.maxpool_conv = nn.Sequential(nn.MaxPool2d(2), DoubleConv(in_channels, out_channels))

    def forward(self, x):
        return self.maxpool_conv(x)


class Up(nn.Module):
    """Decoder step: upsample ``x1`` 2x, concatenate with skip tensor ``x2``, DoubleConv."""

    def __init__(self, in_channels, out_channels, bilinear=True):
        super().__init__()
        if bilinear:
            self.up = nn.Upsample(scale_factor=2, mode="bilinear", align_corners=True)
            self.conv = DoubleConv(in_channels, out_channels, in_channels // 2)
        else:
            self.up = nn.ConvTranspose2d(in_channels, in_channels // 2, kernel_size=2, stride=2)
            self.conv = DoubleConv(in_channels, out_channels)

    def forward(self, x1, x2):
        x1 = self.up(x1)
        # When an encoder input had an odd H or W, the upsampled tensor is one pixel
        # smaller than the skip connection; pad it so torch.cat does not fail.
        # For inputs divisible by 16 (e.g. the 256x256 used below) this is a no-op.
        diff_y = x2.size(2) - x1.size(2)
        diff_x = x2.size(3) - x1.size(3)
        if diff_x or diff_y:
            x1 = nn.functional.pad(
                x1, [diff_x // 2, diff_x - diff_x // 2, diff_y // 2, diff_y - diff_y // 2]
            )
        x = torch.cat([x2, x1], dim=1)
        return self.conv(x)


class UNet(nn.Module):
    """Four-level U-Net producing ``n_classes`` logit maps at input resolution."""

    def __init__(self, n_channels=3, n_classes=1, bilinear=True):
        super().__init__()
        factor = 2 if bilinear else 1
        self.inc = DoubleConv(n_channels, 64)
        self.down1 = Down(64, 128)
        self.down2 = Down(128, 256)
        self.down3 = Down(256, 512)
        self.down4 = Down(512, 1024 // factor)
        self.up1 = Up(1024, 512 // factor, bilinear)
        self.up2 = Up(512, 256 // factor, bilinear)
        self.up3 = Up(256, 128 // factor, bilinear)
        self.up4 = Up(128, 64, bilinear)
        self.outc = nn.Conv2d(64, n_classes, 1)

    def forward(self, x):
        x1 = self.inc(x)
        x2 = self.down1(x1)
        x3 = self.down2(x2)
        x4 = self.down3(x3)
        x5 = self.down4(x4)
        x = self.up1(x5, x4)
        x = self.up2(x, x3)
        x = self.up3(x, x2)
        x = self.up4(x, x1)
        return self.outc(x)


# --------------------------------------------------------------------------
# 2. Global variables and loading functions
# --------------------------------------------------------------------------

# Use CPU, as free GPU resources on Hugging Face Spaces are limited and unstable.
device = torch.device("cpu")

# Preprocessing for the segmentation U-Net (fixed 256x256 input, ImageNet stats).
val_transforms = transforms.Compose([
    transforms.Resize((256, 256)),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
])

# Preprocessing for the ResNet-18 style classifier (fixed 224x224 input).
classification_transforms = transforms.Compose([
    transforms.Resize((224, 224)),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
])

# [Important] Class names, in the index order of the classifier's output layer.
CLASS_NAMES = [
    'work_dress', 'sling_dress', 'ethnic_dress', 'gown', 'casual_dress',
    'party_dress', 'formal_dress', 'sports_dress', 'shirt_dress', 'resort_dress',
]

# Label id the detector uses for "person" (COCO category 1).
PERSON_LABEL = 1
# Minimum detector score for a person box to be used.
DETECTION_SCORE_THRESHOLD = 0.8
# Sigmoid probability above which a pixel is counted as "dress".
MASK_THRESHOLD = 0.5


def load_models():
    """Build the three pipeline models, load their weights and put them in eval mode.

    Returns:
        tuple: ``(detection_model, segmentation_model, classification_model)``.

    Raises:
        FileNotFoundError: if a local checkpoint file is missing.
    """
    # Object detector: COCO-pretrained Faster R-CNN. Uses the weights enum,
    # consistent with resnet18(weights=...) below; pretrained=True is deprecated.
    detection_weights = torchvision.models.detection.FasterRCNN_ResNet50_FPN_Weights.DEFAULT
    detection_model = torchvision.models.detection.fasterrcnn_resnet50_fpn(
        weights=detection_weights
    ).to(device)
    detection_model.eval()

    # Dress segmentation model (single-channel mask logits).
    segmentation_model = UNet(n_channels=3, n_classes=1).to(device)
    segmentation_model.load_state_dict(
        torch.load("unet_dress_segmentation.pth", map_location=device)
    )
    segmentation_model.eval()

    # Style classifier. Replace the head *before* .to(device) so the new fc
    # layer ends up on the same device as the rest of the network.
    classification_model = resnet18(weights=None)
    classification_model.fc = nn.Linear(classification_model.fc.in_features, len(CLASS_NAMES))
    classification_model = classification_model.to(device)
    classification_model.load_state_dict(
        torch.load("best_cnn_model_LR_1e_4.pt", map_location=device)
    )
    classification_model.eval()

    return detection_model, segmentation_model, classification_model


detection_model, segmentation_model, classification_model = load_models()
print("All models have been successfully loaded to the CPU.")


# --------------------------------------------------------------------------
# 3. Core inference function (returns images and text)
# --------------------------------------------------------------------------

def _box_to_crop(box, width, height):
    """Convert a float ``(x1, y1, x2, y2)`` box to integer PIL crop coordinates.

    The box is expanded outward to whole pixels and clamped to the image.
    It always covers at least one pixel in each direction, so later resizing
    never receives an empty crop.
    """
    x1, y1, x2, y2 = (float(v) for v in box)
    left = min(max(int(math.floor(x1)), 0), width - 1)
    top = min(max(int(math.floor(y1)), 0), height - 1)
    right = max(min(int(math.ceil(x2)), width), left + 1)
    bottom = max(min(int(math.ceil(y2)), height), top + 1)
    return left, top, right, bottom


def _detect_person(image_tensor, width, height):
    """Return the crop box of the first confident person detection, or ``None``."""
    with torch.no_grad():
        prediction = detection_model(image_tensor)[0]
    keep = (prediction["labels"] == PERSON_LABEL) & (
        prediction["scores"] > DETECTION_SCORE_THRESHOLD
    )
    person_boxes = prediction["boxes"][keep]
    if person_boxes.shape[0] == 0:
        return None
    # torchvision's NMS output is ordered by decreasing score, so the first
    # surviving box is the most confident person.
    return _box_to_crop(person_boxes[0].tolist(), width, height)


def _segment_dress(person_crop):
    """Predict a binary (0/1, uint8) dress mask at the person crop's resolution."""
    seg_input = val_transforms(person_crop).unsqueeze(0).to(device)
    with torch.no_grad():
        mask_logits = segmentation_model(seg_input)
    mask = (torch.sigmoid(mask_logits) > MASK_THRESHOLD).squeeze().cpu().numpy().astype(np.uint8)
    # Nearest-neighbour keeps the resized mask strictly binary.
    return cv2.resize(
        mask, (person_crop.width, person_crop.height), interpolation=cv2.INTER_NEAREST
    )


def _classify(dress_image):
    """Run the style classifier; return ``(input_tensor, class_probabilities)``."""
    input_tensor = classification_transforms(dress_image).unsqueeze(0).to(device)
    with torch.no_grad():
        logits = classification_model(input_tensor)
    probabilities = torch.softmax(logits, dim=1)[0]
    return input_tensor, probabilities


def _explain(input_tensor, class_index, dress_image):
    """Overlay a Grad-CAM heatmap for ``class_index`` on ``dress_image``.

    Returns a uint8 RGB array with the same height and width as ``dress_image``.
    """
    cam = GradCAM(model=classification_model, target_layers=[classification_model.layer4[-1]])
    targets = [ClassifierOutputTarget(class_index)]
    grayscale_cam = cam(input_tensor=input_tensor, targets=targets)[0, :]
    # The CAM comes back at the classifier's input resolution (224x224), but the
    # overlay image is the full-size crop. Stretch the CAM back to the crop size,
    # mirroring the stretch done by classification_transforms' Resize. Otherwise
    # show_cam_on_image fails on a shape mismatch.
    grayscale_cam = cv2.resize(grayscale_cam, dress_image.size)
    rgb_img_for_cam = np.asarray(dress_image, dtype=np.float32) / 255.0
    return show_cam_on_image(rgb_img_for_cam, grayscale_cam, use_rgb=True)


def process_image(input_image):
    """Run detection -> segmentation -> classification -> Grad-CAM on one image.

    Args:
        input_image: PIL image from the Gradio input, or ``None`` if nothing was uploaded.

    Returns:
        tuple: ``(extracted_dress, gradcam_overlay, confidences, mask_image)``.
        ``confidences`` maps each class name to its softmax probability. When
        there is no image, or no person is found, the images are ``None`` and
        the third element is a status message.
    """
    if input_image is None:
        return None, None, "Please upload an image.", None

    original_pil_img = input_image.convert("RGB")
    img_tensor = F.to_tensor(original_pil_img).unsqueeze(0).to(device)

    # 1. Detection
    crop_box = _detect_person(img_tensor, original_pil_img.width, original_pil_img.height)
    if crop_box is None:
        return None, None, "No person detected.", None

    # 2. Segmentation
    person_crop_pil = original_pil_img.crop(crop_box)
    mask = _segment_dress(person_crop_pil)

    # 3. Classification: zero out every non-dress pixel, then classify.
    extracted_dress_np = np.asarray(person_crop_pil) * mask[..., np.newaxis]
    extracted_dress_pil = Image.fromarray(extracted_dress_np)
    class_input_tensor, probabilities = _classify(extracted_dress_pil)
    confidences = {name: float(p) for name, p in zip(CLASS_NAMES, probabilities)}

    # 4. Grad-CAM for the top predicted class
    predicted_index = int(probabilities.argmax())
    visualization = _explain(class_input_tensor, predicted_index, extracted_dress_pil)

    mask_image = Image.fromarray((mask * 255).astype(np.uint8))
    return extracted_dress_pil, visualization, confidences, mask_image


# --------------------------------------------------------------------------
# 4. Create the Gradio interface
# --------------------------------------------------------------------------
title = "👗✨ FashionAI: Dress Analysis Pipeline ✨👗"
description = """
**An end-to-end Computer Vision Pipeline.**
Upload an image of a person wearing a dress. The AI will first detect the person, then segment the dress, classify its style, and finally show which part of the dress was most important for its decision.
\n*Built with PyTorch, torchvision, Gradio, and ❤️ by a DS405B student.*
"""

iface = gr.Interface(
    fn=process_image,
    inputs=gr.Image(type="pil", label="Upload Your Image"),
    outputs=[
        gr.Image(type="pil", label="Extracted Dress"),
        gr.Image(type="pil", label="Grad-CAM Explanation"),
        gr.Label(num_top_classes=3, label="Classification Probabilities"),
        gr.Image(type="pil", label="Segmentation Mask"),
    ],
    title=title,
    description=description,
    examples=[
        ["https://images.pexels.com/photos/1036627/pexels-photo-1036627.jpeg"],
        ["https://images.pexels.com/photos/1126993/pexels-photo-1126993.jpeg"],
        ["https://images.pexels.com/photos/985635/pexels-photo-985635.jpeg"],
    ],
)

# Launch the application
iface.launch()