asadullahdogarr commited on
Commit
46fba1d
Β·
verified Β·
1 Parent(s): 2a15a86

Upload 3 files

Browse files
Files changed (3) hide show
  1. Dockerfile +24 -0
  2. main.py +327 -0
  3. requirements.txt +9 -0
Dockerfile ADDED
@@ -0,0 +1,24 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Use an official Python runtime
2
+ FROM python:3.10-slim
3
+
4
+ # Set up a new user named "user" with user ID 1000
5
+ # (Required for Hugging Face Spaces to prevent permission errors)
6
+ RUN useradd -m -u 1000 user
7
+ USER user
8
+
9
+ # Set environmental variables
10
+ ENV PATH="/home/user/.local/bin:$PATH"
11
+ WORKDIR /home/user/app
12
+
13
+ # Copy requirements and install
14
+ COPY --chown=user requirements.txt .
15
+ RUN pip install --no-cache-dir -r requirements.txt
16
+
17
+ # Copy your backend code
18
+ COPY --chown=user . .
19
+
20
+ # Hugging Face Spaces route web traffic to port 7860
21
+ EXPOSE 7860
22
+
23
+ # Start the FastAPI app on port 7860
24
+ CMD ["uvicorn", "main:app", "--host", "0.0.0.0", "--port", "7860"]
main.py ADDED
@@ -0,0 +1,327 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ import uuid
3
+ import pickle
4
+ from typing import List
5
+
6
+ try:
7
+ from fastapi import FastAPI, File, Form, UploadFile, HTTPException
8
+ from fastapi.middleware.cors import CORSMiddleware
9
+ except ImportError as exc:
10
+ raise ImportError(
11
+ "FastAPI is required to run this application. Install it with 'pip install fastapi'."
12
+ ) from exc
13
+
14
+ import torch
15
+ import torchvision.models as models
16
+ import torchvision.transforms as transforms
17
+ from PIL import Image
18
+ import numpy as np
19
+ from sklearn.linear_model import LogisticRegression
20
+ from sklearn.model_selection import train_test_split # Added for accuracy scoring
21
+ import io
22
+
23
+ # ── App Initialization ───────────────────────────────────────────────────────
24
+ app = FastAPI(title="Teachable Machine Backend")
25
+
26
+ # ── CORS Configuration ───────────────────────────────────────────────────────
27
+ # Enables file uploads and API calls from frontend (running on different origin/port)
28
+ app.add_middleware(
29
+ CORSMiddleware,
30
+ allow_origins=["*"], # Allow requests from any origin
31
+ allow_credentials=True,
32
+ allow_methods=["*"], # Allow all HTTP methods
33
+ allow_headers=["*"], # Allow all headers
34
+ )
35
+
36
+ DATASET_DIR = os.path.join(os.path.dirname(__file__), "dataset")
37
+ MODEL_PATH = os.path.join(os.path.dirname(__file__), "model.pkl")
38
+
39
+
40
+ # ── Shared ML Setup (runs once at startup) ───────────────────────────────────
41
+ # Loading the model once here means every request reuses the same object in
42
+ # memory instead of reloading it from disk each time β€” much faster.
43
+ device = torch.device("cpu")
44
+
45
+ backbone = models.mobilenet_v3_small(weights=models.MobileNet_V3_Small_Weights.DEFAULT)
46
+ # Remove the final classifier layer β€” we only want feature extraction.
47
+ # The 960 numbers it outputs describe the image content without predicting a category.
48
+ backbone.classifier = torch.nn.Identity()
49
+ backbone.eval() # Disables dropout β€” we are inferring, not training the backbone
50
+
51
+ # These values MUST be identical during training and prediction.
52
+ # 224x224 = the size MobileNetV3 was designed for.
53
+ # mean/std = ImageNet dataset statistics the model was pre-trained on.
54
+ transform = transforms.Compose([
55
+ transforms.Resize((224, 224)),
56
+ transforms.ToTensor(),
57
+ transforms.Normalize(
58
+ mean=[0.485, 0.456, 0.406],
59
+ std=[0.229, 0.224, 0.225]
60
+ )
61
+ ])
62
+
63
+
64
+ # ── Helper: Extract features from a PIL image ────────────────────────────────
65
+ def extract_features(pil_image: Image.Image) -> np.ndarray:
66
+ """
67
+ Passes an image through MobileNetV3 and returns a 960-number feature vector.
68
+ Used by both /train and /predict to guarantee identical preprocessing.
69
+ """
70
+ image = pil_image.convert("RGB") # Handles RGBA/grayscale images safely
71
+ tensor = transform(image)
72
+ tensor = tensor.unsqueeze(0) # (3,224,224) β†’ (1,3,224,224) β€” adds batch dim
73
+
74
+ with torch.no_grad(): # No gradients needed β€” saves memory & time
75
+ features = backbone(tensor)
76
+
77
+ return features.squeeze().numpy() # (1,960) β†’ (960,) numpy array for sklearn
78
+
79
+
80
+ # ── Health Check ─────────────────────────────────────────────────────────────
81
+ @app.get("/")
82
+ def health_check():
83
+ return {"status": "Backend is running!"}
84
+
85
+
86
+ # ── Milestone 1: Upload images ───────────────────────────────────────────────
87
+ @app.post("/upload-sample")
88
+ async def upload_sample(
89
+ class_name: str = Form(...),
90
+ files: List[UploadFile] = File(...)
91
+ ):
92
+ """
93
+ Accepts a class label + a batch of images.
94
+ Saves each image into dataset/<class_name>/ with a random UUID filename.
95
+ """
96
+ class_name = class_name.strip().replace(" ", "_")
97
+
98
+ if not class_name:
99
+ raise HTTPException(status_code=400, detail="class_name cannot be empty.")
100
+
101
+ class_folder = os.path.join(DATASET_DIR, class_name)
102
+ os.makedirs(class_folder, exist_ok=True)
103
+
104
+ if not files:
105
+ raise HTTPException(status_code=400, detail="At least one image file is required.")
106
+
107
+ saved_files = []
108
+
109
+ for file in files:
110
+ if not file.content_type.startswith("image/"):
111
+ raise HTTPException(
112
+ status_code=400,
113
+ detail=f"File '{file.filename}' is not an image. Only image files are accepted."
114
+ )
115
+
116
+ extension = os.path.splitext(file.filename)[1] or ".jpg"
117
+ random_filename = f"{uuid.uuid4()}{extension}"
118
+ save_path = os.path.join(class_folder, random_filename)
119
+
120
+ contents = await file.read()
121
+ with open(save_path, "wb") as f:
122
+ f.write(contents)
123
+
124
+ saved_files.append(random_filename)
125
+
126
+ return {
127
+ "message": f"Uploaded {len(saved_files)} image(s) to class '{class_name}'",
128
+ "class": class_name,
129
+ "saved_files": saved_files
130
+ }
131
+
132
+
133
+ # ── Milestone 1 Bonus: Dataset info ─────────────────────────────────────────
134
+ @app.get("/dataset-info")
135
+ def dataset_info():
136
+ if not os.path.exists(DATASET_DIR):
137
+ return {"classes": {}, "total_images": 0}
138
+
139
+ summary = {}
140
+ for class_name in os.listdir(DATASET_DIR):
141
+ class_path = os.path.join(DATASET_DIR, class_name)
142
+ if os.path.isdir(class_path):
143
+ summary[class_name] = len(os.listdir(class_path))
144
+
145
+ return {
146
+ "classes": summary,
147
+ "total_images": sum(summary.values())
148
+ }
149
+
150
+
151
+ # ── Milestone 2: Train model ─────────────────────────────────────────────────
152
+ @app.post("/train")
153
+ def train_model():
154
+ """
155
+ Scans dataset/, extracts MobileNetV3 features from every image,
156
+ trains a LogisticRegression classifier, and saves it to model.pkl.
157
+ """
158
+
159
+ # ── Step 1: Validate dataset exists ──────────────────────────────────────
160
+ if not os.path.exists(DATASET_DIR):
161
+ raise HTTPException(
162
+ status_code=400,
163
+ detail="No dataset found. Please upload images first."
164
+ )
165
+
166
+ classes = [
167
+ d for d in os.listdir(DATASET_DIR)
168
+ if os.path.isdir(os.path.join(DATASET_DIR, d))
169
+ ]
170
+
171
+ # Classifier needs at least 2 classes β€” it learns to DISTINGUISH between them.
172
+ # With only 1 class there is nothing to distinguish.
173
+ if len(classes) < 2:
174
+ raise HTTPException(
175
+ status_code=400,
176
+ detail=f"Need at least 2 classes to train. You currently have: {classes}"
177
+ )
178
+
179
+ X = [] # Feature vectors β€” one row per image
180
+ y = [] # Labels β€” one entry per image, matched by index to X
181
+
182
+ # ── Step 2: Extract features from every image ────────────────────────────
183
+ for class_name in classes:
184
+ class_folder = os.path.join(DATASET_DIR, class_name)
185
+ image_files = os.listdir(class_folder)
186
+
187
+ if len(image_files) == 0:
188
+ continue # Skip empty class folders silently
189
+
190
+ for filename in image_files:
191
+ image_path = os.path.join(class_folder, filename)
192
+ try:
193
+ img = Image.open(image_path)
194
+ features = extract_features(img)
195
+ X.append(features)
196
+ y.append(class_name)
197
+ except Exception as e:
198
+ # One corrupted image should not kill the whole training run
199
+ print(f"Skipping {filename}: {e}")
200
+ continue
201
+
202
+ if len(X) == 0:
203
+ raise HTTPException(
204
+ status_code=400,
205
+ detail="No valid images found in dataset."
206
+ )
207
+
208
+ X = np.array(X) # Shape: (num_images, 960)
209
+ y = np.array(y) # Shape: (num_images,)
210
+
211
+ # ── Step 3: Train the classifier ─────────────────────────────────────────
212
+
213
+ # NEW: Split the data to calculate a real accuracy metric.
214
+ # We added a safety net: if there are fewer than 5 images total, we test
215
+ # on the training data so it doesn't crash during a live presentation.
216
+ if len(X) >= 5:
217
+ X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.2, random_state=42)
218
+ else:
219
+ X_train, X_test, y_train, y_test = X, X, y, y
220
+
221
+ # Why LogisticRegression?
222
+ # MobileNetV3 already converted images into meaningful 960-number vectors.
223
+ # LogisticRegression just finds the boundary between those vectors.
224
+ # It trains in under a second, works with very few images, and needs no GPU.
225
+ # max_iter=1000 prevents ConvergenceWarning on small datasets.
226
+ classifier = LogisticRegression(max_iter=1000)
227
+ classifier.fit(X_train, y_train)
228
+
229
+ # Calculate overall accuracy
230
+ accuracy = classifier.score(X_test, y_test)
231
+
232
+ # ── Step 4: Save classifier + class list to disk ─────────────────────────
233
+ # We save classes explicitly so the /predict endpoint can map
234
+ # numeric outputs back to human-readable label names.
235
+ model_data = {
236
+ "classifier": classifier,
237
+ "classes": classes
238
+ }
239
+ with open(MODEL_PATH, "wb") as f:
240
+ pickle.dump(model_data, f)
241
+
242
+ return {
243
+ "message": "Training complete!",
244
+ "classes": classes,
245
+ "total_images": len(X),
246
+ "accuracy": round(accuracy * 100, 2), # Returned safely to the frontend!
247
+ "model_saved_at": MODEL_PATH
248
+ }
249
+
250
+
251
+ # ── Milestone 3: Predict endpoint ────────────────────────────────────────────
252
+ @app.post("/predict")
253
+ async def predict(file: UploadFile = File(...)):
254
+ """
255
+ Accepts a single image, runs it through MobileNetV3 + the trained
256
+ LogisticRegression classifier, and returns the predicted class
257
+ with a confidence score for every class.
258
+ """
259
+
260
+ # ── Step 1: Check model exists ────────────────────────────────────────────
261
+ # If the user hits /predict before ever running /train, model.pkl won't
262
+ # exist yet. We catch this early with a clear message instead of a crash.
263
+ if not os.path.exists(MODEL_PATH):
264
+ raise HTTPException(
265
+ status_code=400,
266
+ detail="No trained model found. Please call /train first."
267
+ )
268
+
269
+ # ── Step 2: Validate the uploaded file is an image ────────────────────────
270
+ if not file.content_type.startswith("image/"):
271
+ raise HTTPException(
272
+ status_code=400,
273
+ detail=f"File '{file.filename}' is not an image. Only image files are accepted."
274
+ )
275
+
276
+ # ── Step 3: Load the saved model from disk ────────────────────────────────
277
+ # We reload model.pkl on every prediction request.
278
+ # Why not load it once at startup like the backbone?
279
+ # Because model.pkl gets replaced every time /train is called.
280
+ # If we cached it at startup, predictions would use the OLD model
281
+ # even after the user retrains β€” a subtle but serious bug.
282
+ with open(MODEL_PATH, "rb") as f:
283
+ model_data = pickle.load(f)
284
+
285
+ classifier = model_data["classifier"]
286
+ classes = model_data["classes"]
287
+
288
+ # ── Step 4: Read and decode the uploaded image ────────────────────────────
289
+ # file.read() gives us raw bytes. We wrap them in BytesIO so PIL
290
+ # can treat the bytes like a file on disk β€” no temp file needed.
291
+ contents = await file.read()
292
+ image = Image.open(io.BytesIO(contents))
293
+
294
+ # ── Step 5: Extract features using the SAME function used during training ─
295
+ # This is the most important consistency rule in the whole project.
296
+ # If training used 224x224 + ImageNet normalization, prediction MUST too.
297
+ # extract_features() guarantees this since both phases call the same code.
298
+ features = extract_features(image)
299
+
300
+ # ── Step 6: Run prediction ────────────────────────────────────────────────
301
+ # features is shape (960,) β€” we reshape to (1, 960) because sklearn
302
+ # expects a 2D array: (number_of_samples, number_of_features)
303
+ features_2d = features.reshape(1, -1)
304
+
305
+ # predict() returns the winning class label e.g. ["cat"]
306
+ predicted_class = classifier.predict(features_2d)[0]
307
+
308
+ # predict_proba() returns confidence scores for ALL classes e.g. [0.82, 0.18]
309
+ # Each number = how confident the model is that this image belongs to that class.
310
+ # They always sum to 1.0 (100%).
311
+ probabilities = classifier.predict_proba(features_2d)[0]
312
+
313
+ # ── Step 7: Build a clean confidence scores dict ──────────────────────────
314
+ # zip(classes, probabilities) pairs each class name with its score:
315
+ # e.g. {"cat": 0.82, "dog": 0.18}
316
+ # round(..., 4) keeps it readable: 0.8173 instead of 0.81734521938...
317
+ # float() converts numpy float32 β†’ Python float so JSON can serialize it
318
+ confidence_scores = {
319
+ cls: round(float(prob), 4)
320
+ for cls, prob in zip(classifier.classes_, probabilities)
321
+ }
322
+
323
+ return {
324
+ "predicted_class": predicted_class,
325
+ "confidence": round(float(max(probabilities)), 4),
326
+ "all_scores": confidence_scores
327
+ }
requirements.txt ADDED
@@ -0,0 +1,9 @@
 
 
 
 
 
 
 
 
 
 
1
+ fastapi
2
+ uvicorn[standard]
3
+ python-multipart
4
+ torch
5
+ torchvision
6
+ scikit-learn
7
+ Pillow
8
+ streamlit
9
+ requests