# Author: Liqian Huang
# Commit 862cbf6 ("Upload 4 files (#1)"), 8.55 kB
import gradio as gr
import torch
import torch.nn as nn
import torchvision
from torchvision.transforms import functional as F
from PIL import Image
import numpy as np
import requests
import os
import cv2
from pytorch_grad_cam import GradCAM
from pytorch_grad_cam.utils.model_targets import ClassifierOutputTarget
from pytorch_grad_cam.utils.image import show_cam_on_image
from torchvision.models import resnet18
# --------------------------------------------------------------------------
# 1. Model Definitions (We need to put all model architecture definitions here)
# --------------------------------------------------------------------------
# --- U-Net Model Definition ---
class DoubleConv(nn.Module):
    """Two stacked (3x3 conv -> BatchNorm -> ReLU) stages.

    Both convolutions use ``padding=1``, so height and width are preserved;
    only the channel count changes (in_channels -> mid -> out_channels).

    Args:
        in_channels: Channels of the input feature map.
        out_channels: Channels produced by the second convolution.
        mid_channels: Channels produced by the first convolution. Any falsy
            value (``None``, ``0``) falls back to ``out_channels``.
    """

    def __init__(self, in_channels, out_channels, mid_channels=None):
        super().__init__()
        hidden = mid_channels if mid_channels else out_channels
        # The attribute name ``double_conv`` and the layer order inside it form
        # the state_dict keys that saved U-Net checkpoints are loaded into, so
        # both must stay exactly as they are.
        self.double_conv = nn.Sequential(
            nn.Conv2d(in_channels, hidden, kernel_size=3, padding=1),
            nn.BatchNorm2d(hidden),
            nn.ReLU(inplace=True),
            nn.Conv2d(hidden, out_channels, kernel_size=3, padding=1),
            nn.BatchNorm2d(out_channels),
            nn.ReLU(inplace=True),
        )

    def forward(self, x):
        """Apply both conv stages to ``x``."""
        return self.double_conv(x)
class Down(nn.Module):
    """U-Net encoder step: 2x2 max-pool (halving H and W), then a DoubleConv."""

    def __init__(self, in_channels, out_channels):
        super().__init__()
        pool = nn.MaxPool2d(2)
        convs = DoubleConv(in_channels, out_channels)
        # The attribute name fixes the state_dict prefix ("maxpool_conv.1...."),
        # which must match the saved U-Net checkpoint.
        self.maxpool_conv = nn.Sequential(pool, convs)

    def forward(self, x):
        """Downsample ``x`` and convolve it."""
        return self.maxpool_conv(x)
class Up(nn.Module):
    """U-Net decoder step: upsample ``x1``, concatenate skip ``x2``, DoubleConv.

    Args:
        in_channels: Channels of the concatenated tensor ``[x2, x1]`` fed to the
            DoubleConv. In transposed-conv mode ``x1`` itself carries
            ``in_channels`` and is halved by the transposed convolution, so
            ``x2`` must then carry ``in_channels // 2``.
        out_channels: Channels produced by this step.
        bilinear: If True, upsample with parameter-free bilinear interpolation
            (and use ``in_channels // 2`` mid channels in the DoubleConv);
            otherwise use a learned 2x2, stride-2 transposed convolution.
    """

    def __init__(self, in_channels, out_channels, bilinear=True):
        super().__init__()
        if bilinear:
            self.up = nn.Upsample(scale_factor=2, mode='bilinear', align_corners=True)
            self.conv = DoubleConv(in_channels, out_channels, in_channels // 2)
        else:
            self.up = nn.ConvTranspose2d(in_channels, in_channels // 2, kernel_size=2, stride=2)
            self.conv = DoubleConv(in_channels, out_channels)

    def forward(self, x1, x2):
        """Upsample ``x1`` to ``x2``'s spatial size, concatenate, and convolve.

        Assumes NCHW tensors (dims 2 and 3 are height and width) — TODO confirm
        for any new caller.
        """
        x1 = self.up(x1)
        # If an encoder input had an odd H or W, MaxPool2d floored it, so the
        # upsampled map can be one pixel smaller than the skip connection.
        # Pad it (split as evenly as possible) so torch.cat does not fail.
        # When sizes already match this is skipped and behaviour is unchanged.
        diff_y = x2.size(2) - x1.size(2)
        diff_x = x2.size(3) - x1.size(3)
        if diff_y or diff_x:
            x1 = nn.functional.pad(
                x1,
                [diff_x // 2, diff_x - diff_x // 2, diff_y // 2, diff_y - diff_y // 2],
            )
        x = torch.cat([x2, x1], dim=1)
        return self.conv(x)
class UNet(nn.Module):
    """Four-level U-Net producing one score map per class.

    The output has ``n_classes`` channels at the input's spatial resolution,
    and no final activation is applied (process_image applies a sigmoid).
    Inputs whose H and W are divisible by 16 always align with the skip
    connections.

    Args:
        n_channels: Channels of the input image.
        n_classes: Channels of the output score map.
        bilinear: Passed to every ``Up`` block; also halves the deepest widths.
    """

    def __init__(self, n_channels=3, n_classes=1, bilinear=True):
        super().__init__()
        # Bilinear upsampling cannot reduce channels, so the deepest encoder
        # output and the decoder widths are divided by ``factor`` to keep the
        # concatenated widths consistent.
        factor = 2 if bilinear else 1
        # Attribute names below are the state_dict prefixes of the checkpoint.
        self.inc = DoubleConv(n_channels, 64)
        self.down1 = Down(64, 128)
        self.down2 = Down(128, 256)
        self.down3 = Down(256, 512)
        self.down4 = Down(512, 1024 // factor)
        self.up1 = Up(1024, 512 // factor, bilinear)
        self.up2 = Up(512, 256 // factor, bilinear)
        self.up3 = Up(256, 128 // factor, bilinear)
        self.up4 = Up(128, 64, bilinear)
        self.outc = nn.Conv2d(64, n_classes, 1)

    def forward(self, x):
        """Encode ``x``, decode with skip connections, project to classes."""
        features = [self.inc(x)]
        for encoder_step in (self.down1, self.down2, self.down3, self.down4):
            features.append(encoder_step(features[-1]))
        # The deepest feature starts the decoder; the rest are consumed as
        # skip connections from deepest to shallowest.
        y = features.pop()
        for decoder_step in (self.up1, self.up2, self.up3, self.up4):
            y = decoder_step(y, features.pop())
        return self.outc(y)
# --------------------------------------------------------------------------
# 2. Global Variables and Loading Functions
# --------------------------------------------------------------------------
# Inference runs on the CPU: free GPU resources on Hugging Face Spaces are
# limited and unstable.
device = torch.device("cpu")
# Define preprocessing
from torchvision import transforms

# Channel statistics shared by both pipelines below (the standard ImageNet
# values). Kept in one place so the two pipelines cannot silently drift apart.
IMAGENET_MEAN = (0.485, 0.456, 0.406)
IMAGENET_STD = (0.229, 0.224, 0.225)

# Preprocessing for the U-Net segmentation input: 256x256, tensor, normalised.
val_transforms = transforms.Compose([
    transforms.Resize((256, 256)),
    transforms.ToTensor(),
    transforms.Normalize(mean=IMAGENET_MEAN, std=IMAGENET_STD),
])

# Preprocessing for the ResNet-18 classifier input: 224x224, tensor, normalised.
classification_transforms = transforms.Compose([
    transforms.Resize((224, 224)),
    transforms.ToTensor(),
    transforms.Normalize(mean=IMAGENET_MEAN, std=IMAGENET_STD),
])
# Dress-style labels of the classifier head, in output order: entry i names
# logit/probability i of the classification model.
CLASS_NAMES = [
    "work_dress",
    "sling_dress",
    "ethnic_dress",
    "gown",
    "casual_dress",
    "party_dress",
    "formal_dress",
    "sports_dress",
    "shirt_dress",
    "resort_dress",
]
# Load all models at once
def load_models(
    segmentation_weights="unet_dress_segmentation.pth",
    classification_weights="best_cnn_model_LR_1e_4.pt",
):
    """Build the three pipeline models, load their weights, and set eval mode.

    Args:
        segmentation_weights: Path to the U-Net state_dict.
        classification_weights: Path to the ResNet-18 classifier state_dict.

    Returns:
        Tuple ``(detection_model, segmentation_model, classification_model)``,
        each moved to ``device`` and switched to eval mode.
    """
    # NOTE(review): torch.load unpickles the file, so only load trusted
    # checkpoints (torch >= 1.13 also supports weights_only=True).

    # COCO-pretrained Faster R-CNN. DEFAULT is the weight set the deprecated
    # ``pretrained=True`` flag selected; this matches the ``weights=`` API
    # already used for resnet18 below.
    detection_model = torchvision.models.detection.fasterrcnn_resnet50_fpn(
        weights=torchvision.models.detection.FasterRCNN_ResNet50_FPN_Weights.DEFAULT
    ).to(device)
    detection_model.eval()

    # U-Net dress segmenter (one output channel).
    segmentation_model = UNet(n_channels=3, n_classes=1).to(device)
    segmentation_model.load_state_dict(torch.load(segmentation_weights, map_location=device))
    segmentation_model.eval()

    # ResNet-18 style classifier with a head sized to CLASS_NAMES. The head is
    # replaced BEFORE moving to ``device`` so the new layer is moved as well.
    classification_model = resnet18(weights=None)
    classification_model.fc = nn.Linear(classification_model.fc.in_features, len(CLASS_NAMES))
    classification_model = classification_model.to(device)
    classification_model.load_state_dict(torch.load(classification_weights, map_location=device))
    classification_model.eval()

    return detection_model, segmentation_model, classification_model
# Load every model once at import time; process_image reuses these globals.
(
    detection_model,
    segmentation_model,
    classification_model,
) = load_models()
print("All models have been successfully loaded to the CPU.")
# --------------------------------------------------------------------------
# 3. Core Inference Function (Modified to return images and text)
# --------------------------------------------------------------------------
def _detect_person_box(pil_img, score_threshold=0.8):
    """Return the first person box above ``score_threshold`` as ints, or None.

    Uses the global Faster R-CNN ``detection_model``; label 1 is treated as
    "person". The first qualifying detection is taken (presumably the most
    confident, since torchvision orders detections by score).
    """
    img_tensor = F.to_tensor(pil_img).unsqueeze(0).to(device)
    with torch.no_grad():
        prediction = detection_model(img_tensor)[0]
    for box, label, score in zip(prediction["boxes"], prediction["labels"], prediction["scores"]):
        if label.item() == 1 and score.item() > score_threshold:
            x1, y1, x2, y2 = map(int, box.cpu().numpy())
            # int() truncates, so a sub-pixel box could yield an empty crop
            # that the resize transforms cannot handle; keep at least 1 px.
            return x1, y1, max(x2, x1 + 1), max(y2, y1 + 1)
    return None


def _segment_crop(crop_pil, threshold=0.5):
    """Predict a {0, 1} uint8 dress mask for ``crop_pil`` at its own size."""
    seg_input = val_transforms(crop_pil).unsqueeze(0).to(device)
    with torch.no_grad():
        mask_logits = segmentation_model(seg_input)
    mask = (torch.sigmoid(mask_logits) > threshold).squeeze().cpu().numpy().astype(np.uint8)
    # cv2.resize takes the target size as (width, height).
    return cv2.resize(mask, (crop_pil.width, crop_pil.height))


def _gradcam_overlay(class_input_tensor, rgb_img, class_index):
    """Overlay the Grad-CAM heatmap of ``class_index`` on ``rgb_img``.

    ``rgb_img`` is float32 in [0, 1] with shape (H, W, 3). The CAM comes back
    at the classifier's input resolution, so it is resized to (H, W) first:
    show_cam_on_image blends heatmap and image element-wise and needs them to
    share a shape.
    """
    target_layers = [classification_model.layer4[-1]]
    targets = [ClassifierOutputTarget(class_index)]
    # The context manager releases the hooks GradCAM registers on the model.
    with GradCAM(model=classification_model, target_layers=target_layers) as cam:
        grayscale_cam = cam(input_tensor=class_input_tensor, targets=targets)[0, :]
    height, width = rgb_img.shape[:2]
    grayscale_cam = cv2.resize(grayscale_cam, (width, height))
    return show_cam_on_image(rgb_img, grayscale_cam, use_rgb=True)


def process_image(input_image):
    """Run detection -> segmentation -> classification -> Grad-CAM on one image.

    Args:
        input_image: PIL image from the Gradio upload widget, or None when the
            user submits without uploading.

    Returns:
        A 4-tuple matching the Gradio outputs:
        (extracted dress PIL image, Grad-CAM overlay as a uint8 RGB array,
        {class name: probability} dict, segmentation mask PIL image).
        If there is no image or no person is detected, returns
        (None, None, <message string>, None).
    """
    if input_image is None:
        return None, None, "No image provided.", None
    original_pil_img = input_image.convert("RGB")

    # 1. Detection
    box = _detect_person_box(original_pil_img)
    if box is None:
        return None, None, "No person detected.", None

    # 2. Segmentation of the person crop
    person_crop_pil = original_pil_img.crop(box)
    mask = _segment_crop(person_crop_pil)

    # 3. Classification of the masked crop (background pixels zeroed)
    extracted_dress_np = np.array(person_crop_pil) * mask[..., None]
    extracted_dress_pil = Image.fromarray(extracted_dress_np)
    class_input_tensor = classification_transforms(extracted_dress_pil).unsqueeze(0).to(device)
    with torch.no_grad():
        probabilities = torch.softmax(classification_model(class_input_tensor), dim=1)[0]
    confidences = {name: float(p) for name, p in zip(CLASS_NAMES, probabilities.tolist())}
    predicted_index = int(probabilities.argmax())

    # 4. Grad-CAM explanation of the predicted class
    rgb_img_for_cam = (np.array(extracted_dress_pil) / 255.0).astype(np.float32)
    visualization = _gradcam_overlay(class_input_tensor, rgb_img_for_cam, predicted_index)

    mask_pil = Image.fromarray((mask * 255).astype(np.uint8))
    return extracted_dress_pil, visualization, confidences, mask_pil
# --------------------------------------------------------------------------
# 4. Create Gradio Interface
# --------------------------------------------------------------------------
title = "👗✨ FashionAI: Dress Analysis Pipeline ✨👗"
description = """
**An end-to-end Computer Vision Pipeline.**
Upload an image of a person wearing a dress. The AI will first detect the person, then segment the dress, classify its style, and finally show which part of the dress was most important for its decision.
\n*Built with PyTorch, torchvision, Gradio, and ❤️ by a DS405B student.*
"""

# One image in; four outputs in the order process_image returns them.
iface = gr.Interface(
    fn=process_image,
    inputs=gr.Image(type="pil", label="Upload Your Image"),
    outputs=[
        gr.Image(type="pil", label="Extracted Dress"),
        gr.Image(type="pil", label="Grad-CAM Explanation"),
        gr.Label(num_top_classes=3, label="Classification Probabilities"),
        gr.Image(type="pil", label="Segmentation Mask"),
    ],
    title=title,
    description=description,
    examples=[
        ["https://images.pexels.com/photos/1036627/pexels-photo-1036627.jpeg"],
        ["https://images.pexels.com/photos/1126993/pexels-photo-1126993.jpeg"],
        ["https://images.pexels.com/photos/985635/pexels-photo-985635.jpeg"],
    ],
)

# Launch the application only when run as a script, so importing this module
# (e.g. to exercise process_image) does not start a web server.
if __name__ == "__main__":
    iface.launch()