Liqian Huang committed on
Commit
862cbf6
·
verified ·
1 Parent(s): 46c60bc

Upload 4 files (#1)

Browse files

- Upload 4 files (a1feacb6bfccef21b49dd523f7fb9721ea5ae71f)

app.py ADDED
@@ -0,0 +1,182 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+
2
+ import gradio as gr
3
+ import torch
4
+ import torch.nn as nn
5
+ import torchvision
6
+ from torchvision.transforms import functional as F
7
+ from PIL import Image
8
+ import numpy as np
9
+ import requests
10
+ import os
11
+ import cv2
12
+ from pytorch_grad_cam import GradCAM
13
+ from pytorch_grad_cam.utils.model_targets import ClassifierOutputTarget
14
+ from pytorch_grad_cam.utils.image import show_cam_on_image
15
+ from torchvision.models import resnet18
16
+
17
+ # --------------------------------------------------------------------------
18
+ # 1. Model Definitions (We need to put all model architecture definitions here)
19
+ # --------------------------------------------------------------------------
20
+
21
+ # --- U-Net Model Definition ---
22
class DoubleConv(nn.Module):
    """Two stacked 3x3 convolutions, each followed by BatchNorm and ReLU.

    Args:
        in_channels: Channel count of the incoming feature map.
        out_channels: Channel count produced by the second convolution.
        mid_channels: Channel count between the two convolutions; when it is
            omitted (or otherwise falsy) ``out_channels`` is used instead.
    """

    def __init__(self, in_channels, out_channels, mid_channels=None):
        super().__init__()
        hidden = mid_channels or out_channels
        stages = [
            nn.Conv2d(in_channels, hidden, kernel_size=3, padding=1),
            nn.BatchNorm2d(hidden),
            nn.ReLU(inplace=True),
            nn.Conv2d(hidden, out_channels, kernel_size=3, padding=1),
            nn.BatchNorm2d(out_channels),
            nn.ReLU(inplace=True),
        ]
        # The attribute name and the layer order determine the state_dict
        # keys ("double_conv.0.weight", ...), so both are kept as they were.
        self.double_conv = nn.Sequential(*stages)

    def forward(self, x):
        """Apply both conv/norm/activation stages to ``x``."""
        return self.double_conv(x)
28
+
29
class Down(nn.Module):
    """Encoder stage: a 2x2 max-pool followed by a ``DoubleConv``.

    Args:
        in_channels: Channel count of the incoming feature map.
        out_channels: Channel count produced by the ``DoubleConv``.
    """

    def __init__(self, in_channels, out_channels):
        super().__init__()
        pool = nn.MaxPool2d(2)
        conv = DoubleConv(in_channels, out_channels)
        # Attribute name and ordering are kept so state_dict keys are unchanged.
        self.maxpool_conv = nn.Sequential(pool, conv)

    def forward(self, x):
        """Pool ``x`` and run it through the convolution pair."""
        return self.maxpool_conv(x)
32
+
33
class Up(nn.Module):
    """Decoder stage: upsample, concatenate with the skip tensor, then convolve.

    Args:
        in_channels: Channel count after concatenating the skip tensor with
            the upsampled tensor.
        out_channels: Channel count produced by the stage.
        bilinear: When true, upsampling is a parameter-free bilinear
            ``nn.Upsample`` and the ``DoubleConv`` uses ``in_channels // 2``
            as its intermediate width; otherwise a ``ConvTranspose2d`` halves
            the channel count while doubling the resolution.
    """

    def __init__(self, in_channels, out_channels, bilinear=True):
        super().__init__()
        if bilinear:
            self.up = nn.Upsample(scale_factor=2, mode='bilinear', align_corners=True)
            self.conv = DoubleConv(in_channels, out_channels, in_channels // 2)
        else:
            self.up = nn.ConvTranspose2d(in_channels, in_channels // 2, kernel_size=2, stride=2)
            self.conv = DoubleConv(in_channels, out_channels)

    def forward(self, x1, x2):
        """Upsample ``x1``, align it with the skip tensor ``x2`` and fuse them.

        Args:
            x1: Feature map coming from the deeper (lower-resolution) stage.
            x2: Skip-connection feature map from the encoder.

        Returns:
            The output of ``self.conv`` applied to ``cat([x2, x1], dim=1)``.
        """
        x1 = self.up(x1)

        # A 2x pool followed by a 2x upsample loses one pixel whenever the
        # encoder map had an odd height/width, which would make torch.cat
        # fail. Pad x1 up to x2's spatial size; when the sizes already match
        # (e.g. inputs divisible by 16) nothing is changed.
        diff_y = x2.size(2) - x1.size(2)
        diff_x = x2.size(3) - x1.size(3)
        if diff_y or diff_x:
            # nn.functional is spelled out because this file binds the name
            # ``F`` to torchvision.transforms.functional.
            x1 = nn.functional.pad(
                x1,
                [diff_x // 2, diff_x - diff_x // 2,
                 diff_y // 2, diff_y - diff_y // 2],
            )

        return self.conv(torch.cat([x2, x1], dim=1))
39
+
40
class UNet(nn.Module):
    """U-Net built from ``DoubleConv``, ``Down`` and ``Up`` stages.

    Args:
        n_channels: Channel count of the input image.
        n_classes: Channel count of the output map (one logit map per class).
        bilinear: Forwarded to every ``Up`` stage; also halves the width of
            the bottleneck and of the first three decoder outputs.
    """

    def __init__(self, n_channels=3, n_classes=1, bilinear=True):
        super().__init__()
        factor = 2 if bilinear else 1

        # Encoder. Attribute names are part of the checkpoint's key names.
        self.inc = DoubleConv(n_channels, 64)
        self.down1 = Down(64, 128)
        self.down2 = Down(128, 256)
        self.down3 = Down(256, 512)
        self.down4 = Down(512, 1024 // factor)

        # Decoder.
        self.up1 = Up(1024, 512 // factor, bilinear)
        self.up2 = Up(512, 256 // factor, bilinear)
        self.up3 = Up(256, 128 // factor, bilinear)
        self.up4 = Up(128, 64, bilinear)

        # 1x1 projection to the requested number of output channels.
        self.outc = nn.Conv2d(64, n_classes, 1)

    def forward(self, x):
        """Run the encoder, then the decoder with skip connections.

        Returns the raw output of the final 1x1 convolution (no activation
        is applied here).
        """
        skip1 = self.inc(x)
        skip2 = self.down1(skip1)
        skip3 = self.down2(skip2)
        skip4 = self.down3(skip3)
        out = self.down4(skip4)

        decoder = (self.up1, self.up2, self.up3, self.up4)
        skips = (skip4, skip3, skip2, skip1)
        for stage, skip in zip(decoder, skips):
            out = stage(out, skip)

        return self.outc(out)
44
+
45
+ # --------------------------------------------------------------------------
46
+ # 2. Global Variables and Loading Functions
47
+ # --------------------------------------------------------------------------
48
+
49
+ # Use CPU, as free GPU resources on Hugging Face Spaces are limited and unstable
50
device = torch.device('cpu')

# Define preprocessing
from torchvision import transforms


def _resize_and_normalize(side):
    """Build the Resize -> ToTensor -> Normalize pipeline for a square input."""
    return transforms.Compose([
        transforms.Resize((side, side)),
        transforms.ToTensor(),
        transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
    ])


# Preprocessing applied to the person crop before the segmentation model.
val_transforms = _resize_and_normalize(256)
# Preprocessing applied to the extracted dress before the classifier.
classification_transforms = _resize_and_normalize(224)

# [Important] Your class names
# The position of each name is used as the classifier's output index.
CLASS_NAMES = [
    'work_dress',
    'sling_dress',
    'ethnic_dress',
    'gown',
    'casual_dress',
    'party_dress',
    'formal_dress',
    'sports_dress',
    'shirt_dress',
    'resort_dress',
]
70
+
71
+ # Load all models at once
72
def load_models():
    """Build the three models used by the pipeline and put them in eval mode.

    Returns:
        A ``(detection_model, segmentation_model, classification_model)``
        tuple, all moved to the module-level ``device``.

    Raises:
        FileNotFoundError: Propagated from ``torch.load`` when a checkpoint
            file is not present in the working directory.
    """
    # Load the object detection model.
    # ``weights="DEFAULT"`` replaces the deprecated ``pretrained=True`` flag
    # and matches the ``weights=`` style already used for resnet18 below.
    detection_model = torchvision.models.detection.fasterrcnn_resnet50_fpn(weights="DEFAULT").to(device)
    detection_model.eval()

    # Load the segmentation model.
    # NOTE(review): torch.load unpickles the checkpoint; that is acceptable
    # only because these files ship with the app. Do not point this at
    # untrusted files without ``weights_only=True``.
    segmentation_model = UNet(n_channels=3, n_classes=1)
    segmentation_model.load_state_dict(torch.load("unet_dress_segmentation.pth", map_location=device))
    segmentation_model.to(device)
    segmentation_model.eval()

    # Load the classification model.
    # The head is sized from CLASS_NAMES so the two cannot drift apart, and
    # the model is moved to the device only after the head is swapped so the
    # new layer ends up on the same device as the backbone.
    classification_model = resnet18(weights=None)
    classification_model.fc = nn.Linear(classification_model.fc.in_features, len(CLASS_NAMES))
    classification_model.load_state_dict(torch.load("best_cnn_model_LR_1e_4.pt", map_location=device))
    classification_model.to(device)
    classification_model.eval()

    return detection_model, segmentation_model, classification_model


# Models are loaded once at import time and shared by every request.
detection_model, segmentation_model, classification_model = load_models()
print("All models have been successfully loaded to the CPU.")
92
+
93
+ # --------------------------------------------------------------------------
94
+ # 3. Core Inference Function (Modified to return images and text)
95
+ # --------------------------------------------------------------------------
96
+
97
# Label id that this pipeline treats as "person" in the detector output.
PERSON_LABEL = 1
# Minimum detector score for a box to be accepted.
DETECTION_SCORE_THRESHOLD = 0.8


def process_image(input_image):
    """Run detection, segmentation, classification and Grad-CAM on one image.

    Args:
        input_image: A PIL image, or ``None`` when the caller supplied no
            image.

    Returns:
        A 4-tuple ``(extracted_dress, grad_cam, label_output, mask)``:

        * ``extracted_dress`` - PIL image of the person crop with everything
          outside the predicted mask set to zero.
        * ``grad_cam`` - array returned by ``show_cam_on_image`` for the
          predicted class, at the crop's resolution.
        * ``label_output`` - dict mapping each name in ``CLASS_NAMES`` to its
          softmax probability.
        * ``mask`` - PIL image of the binary mask scaled to 0/255.

        When there is no input image, or no usable person box, the three
        image slots are ``None`` and ``label_output`` is a message string.
    """
    # The UI can invoke the callback without an image; fail soft instead of
    # raising AttributeError on ``None.convert``.
    if input_image is None:
        return None, None, "No image provided.", None

    original_pil_img = input_image.convert("RGB")
    img_tensor = F.to_tensor(original_pil_img).unsqueeze(0).to(device)

    # 1. Detection
    with torch.no_grad():
        detections = detection_model(img_tensor)[0]

    person_boxes = [
        box.cpu().numpy()
        for box, label, score in zip(detections['boxes'], detections['labels'], detections['scores'])
        if label.item() == PERSON_LABEL and score.item() > DETECTION_SCORE_THRESHOLD
    ]
    if not person_boxes:
        return None, None, "No person detected.", None

    # Only the first qualifying box is processed.
    x1, y1, x2, y2 = map(int, person_boxes[0])
    # Truncating to int can collapse a sub-pixel box to zero width/height,
    # which would make the resize steps below fail.
    if x2 <= x1 or y2 <= y1:
        return None, None, "No person detected.", None

    # 2. Segmentation
    person_crop_pil = original_pil_img.crop((x1, y1, x2, y2))
    person_crop_np = np.array(person_crop_pil)
    crop_size = (person_crop_pil.width, person_crop_pil.height)  # (w, h) as cv2 expects
    seg_input_tensor = val_transforms(person_crop_pil).unsqueeze(0).to(device)
    with torch.no_grad():
        mask_logits = segmentation_model(seg_input_tensor)
    mask_pred = torch.sigmoid(mask_logits) > 0.5
    mask_np = mask_pred.squeeze().cpu().numpy().astype(np.uint8)
    # Bring the mask from the network's resolution back to the crop's.
    mask_resized = cv2.resize(mask_np, crop_size)

    # 3. Classification
    # Broadcasting the 0/1 mask over the channel axis zeroes non-dress pixels.
    extracted_dress_np = person_crop_np * mask_resized[:, :, None]
    extracted_dress_pil = Image.fromarray(extracted_dress_np)
    class_input_tensor = classification_transforms(extracted_dress_pil).unsqueeze(0).to(device)
    with torch.no_grad():
        output_logits = classification_model(class_input_tensor)
    probabilities = torch.softmax(output_logits, dim=1)[0]
    confidences = {name: float(prob) for name, prob in zip(CLASS_NAMES, probabilities)}
    predicted_index = int(probabilities.argmax())

    # 4. Grad-CAM
    # Gradients are required here, so this part is deliberately outside
    # torch.no_grad(). The context manager releases the hooks GradCAM
    # registers on the shared classifier, so they do not pile up per request.
    targets = [ClassifierOutputTarget(predicted_index)]
    with GradCAM(model=classification_model, target_layers=[classification_model.layer4[-1]]) as cam:
        grayscale_cam = cam(input_tensor=class_input_tensor, targets=targets)[0, :]
    # The CAM comes back at the classifier's input resolution, while the
    # image it is overlaid on is crop-sized; resize the CAM so the two
    # arrays have matching shapes for show_cam_on_image.
    grayscale_cam = cv2.resize(grayscale_cam, crop_size)
    rgb_img_for_cam = (np.array(extracted_dress_pil) / 255.0).astype(np.float32)
    visualization = show_cam_on_image(rgb_img_for_cam, grayscale_cam, use_rgb=True)

    # Return results
    mask_image = Image.fromarray((mask_resized * 255).astype(np.uint8))
    return extracted_dress_pil, visualization, confidences, mask_image
151
+
152
+
153
+ # --------------------------------------------------------------------------
154
+ # 4. Create Gradio Interface
155
+ # --------------------------------------------------------------------------
156
title = "👗✨ FashionAI: Dress Analysis Pipeline ✨👗"
description = """
**An end-to-end Computer Vision Pipeline.**
Upload an image of a person wearing a dress. The AI will first detect the person, then segment the dress, classify its style, and finally show which part of the dress was most important for its decision.
\n*Built with PyTorch, torchvision, Gradio, and ❤️ by a DS405B student.*
"""

# Remote sample images offered as one-click examples in the UI.
_EXAMPLE_URLS = (
    "https://images.pexels.com/photos/1036627/pexels-photo-1036627.jpeg",
    "https://images.pexels.com/photos/1126993/pexels-photo-1126993.jpeg",
    "https://images.pexels.com/photos/985635/pexels-photo-985635.jpeg",
)

# One component per value returned by process_image, in the same order.
_image_input = gr.Image(type="pil", label="Upload Your Image")
_result_components = [
    gr.Image(type="pil", label="Extracted Dress"),
    gr.Image(type="pil", label="Grad-CAM Explanation"),
    gr.Label(num_top_classes=3, label="Classification Probabilities"),
    gr.Image(type="pil", label="Segmentation Mask"),
]

iface = gr.Interface(
    fn=process_image,
    inputs=_image_input,
    outputs=_result_components,
    title=title,
    description=description,
    examples=[[url] for url in _EXAMPLE_URLS],
)

# Launch the application
iface.launch()
best_cnn_model_LR_1e_4.pt ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:83e862c222c0e306a9840a5838087f64e88f14d44d2fecaf7c19265b49aa34cd
3
+ size 786432
requirements.txt ADDED
@@ -0,0 +1,10 @@
 
 
 
 
 
 
 
 
 
 
 
1
+ torch
2
+ torchvision
3
+ gradio
4
+ opencv-python-headless
5
+ matplotlib
6
+ Pillow
7
+ requests
8
+ grad-cam
9
+ albumentations
10
+ scikit-learn
unet_dress_segmentation.pth ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:496b8a110af8ee89d70bee273af16745798d316f48849c53663cabc84febffa7
3
+ size 1048576