"""Teachable Machine backend.

Exposes a small FastAPI service that lets a client:

* upload labelled example images (``/upload-sample``),
* train a LogisticRegression classifier on MobileNetV3 features (``/train``),
* classify a new image with the trained model (``/predict``).
"""

import io
import logging
import os
import pickle
import uuid
from typing import List

try:
    from fastapi import FastAPI, File, Form, UploadFile, HTTPException
    from fastapi.middleware.cors import CORSMiddleware
except ImportError as exc:
    raise ImportError(
        "FastAPI is required to run this application. Install it with 'pip install fastapi'."
    ) from exc

import numpy as np
import torch
import torchvision.models as models
import torchvision.transforms as transforms
from PIL import Image
from sklearn.linear_model import LogisticRegression
from sklearn.model_selection import train_test_split  # Used for accuracy scoring

logger = logging.getLogger(__name__)

# ── App Initialization ───────────────────────────────────────────────────────
app = FastAPI(title="Teachable Machine Backend")

# ── CORS Configuration ───────────────────────────────────────────────────────
# Enables file uploads and API calls from a frontend running on a different
# origin/port.
# NOTE(review): wildcard origins together with allow_credentials=True is
# discouraged by the FastAPI/Starlette CORS docs. If the frontend ever relies
# on cookies or auth headers, list the allowed origins explicitly instead.
app.add_middleware(
    CORSMiddleware,
    allow_origins=["*"],  # Allow requests from any origin
    allow_credentials=True,
    allow_methods=["*"],  # Allow all HTTP methods
    allow_headers=["*"],  # Allow all headers
)

DATASET_DIR = os.path.join(os.path.dirname(__file__), "dataset")
MODEL_PATH = os.path.join(os.path.dirname(__file__), "model.pkl")

# Below this many usable images no hold-out split is attempted; accuracy is
# then measured on the training data so tiny demo datasets still train.
MIN_IMAGES_FOR_HOLDOUT = 5
HOLDOUT_FRACTION = 0.2
SPLIT_SEED = 42

# ── Shared ML Setup (runs once at startup) ───────────────────────────────────
# Loading the backbone once here means every request reuses the same object in
# memory instead of rebuilding it each time.
device = torch.device("cpu")
backbone = models.mobilenet_v3_small(weights=models.MobileNet_V3_Small_Weights.DEFAULT)

# Replace the whole classification head with a pass-through so the network
# returns its pooled feature vector instead of ImageNet class scores.
# For MobileNetV3-Small that vector has 576 values (960 is the Large variant).
backbone.classifier = torch.nn.Identity()
backbone.eval()  # Disables dropout — we are inferring, not training the backbone

# These values MUST be identical during training and prediction.
# 224x224  = the input size the pretrained weights expect.
# mean/std = ImageNet dataset statistics the model was pre-trained on.
transform = transforms.Compose([
    transforms.Resize((224, 224)),
    transforms.ToTensor(),
    transforms.Normalize(
        mean=[0.485, 0.456, 0.406],
        std=[0.229, 0.224, 0.225],
    ),
])


# ── Helper: Extract features from a PIL image ────────────────────────────────
def extract_features(pil_image: Image.Image) -> np.ndarray:
    """Return the backbone's 1-D feature vector for one image.

    Used by both /train and /predict so preprocessing is guaranteed identical.

    Args:
        pil_image: Any PIL image; it is converted to RGB first, so RGBA and
            grayscale inputs are accepted.

    Returns:
        A 1-D numpy array of backbone features (576 values for
        MobileNetV3-Small), ready to be used as one sklearn sample.
    """
    image = pil_image.convert("RGB")
    batch = transform(image).unsqueeze(0)  # (3,224,224) → (1,3,224,224)
    with torch.no_grad():  # No gradients needed — saves memory & time
        features = backbone(batch)
    # Drop only the batch axis: (1, D) → (D,)
    return features.squeeze(0).numpy()


# ── Helper: request validation ───────────────────────────────────────────────
def _is_image_upload(file: UploadFile) -> bool:
    """Return True when the upload declares an ``image/*`` content type.

    ``content_type`` may be missing (None) on some clients; that is treated
    as "not an image" rather than crashing.
    """
    return (file.content_type or "").startswith("image/")


def _normalize_class_name(raw_name: str) -> str:
    """Clean a user-supplied class label and make sure it is a safe folder name.

    Raises:
        HTTPException(400): if the label is empty or could escape DATASET_DIR
            (path separators, ``.``/``..``, NUL bytes).
    """
    name = raw_name.strip().replace(" ", "_")
    if not name:
        raise HTTPException(status_code=400, detail="class_name cannot be empty.")
    # The label becomes a directory under DATASET_DIR, so it must be a single
    # plain path component.
    if name in {".", ".."} or any(ch in name for ch in ("/", "\\", "\x00")):
        raise HTTPException(
            status_code=400,
            detail="class_name must not contain path separators or be '.' / '..'.",
        )
    return name


# ── Health Check ─────────────────────────────────────────────────────────────
@app.get("/")
def health_check():
    """Report that the service is up."""
    return {"status": "Backend is running!"}


# ── Milestone 1: Upload images ───────────────────────────────────────────────
@app.post("/upload-sample")
async def upload_sample(
    class_name: str = Form(...),
    files: List[UploadFile] = File(...)
):
    """Save a batch of images under ``dataset/<class_name>/``.

    Each image is stored with a random UUID filename that keeps the original
    extension (``.jpg`` when the upload has none).

    Raises:
        HTTPException(400): empty/unsafe class name, no files, or any file
            whose content type is not ``image/*``.
    """
    class_name = _normalize_class_name(class_name)

    if not files:
        raise HTTPException(status_code=400, detail="At least one image file is required.")

    # Validate the whole batch BEFORE touching the disk so a bad file cannot
    # leave a half-saved upload behind.
    for file in files:
        if not _is_image_upload(file):
            raise HTTPException(
                status_code=400,
                detail=f"File '{file.filename}' is not an image. Only image files are accepted."
            )

    class_folder = os.path.join(DATASET_DIR, class_name)
    os.makedirs(class_folder, exist_ok=True)

    saved_files = []
    for file in files:
        extension = os.path.splitext(file.filename or "")[1] or ".jpg"
        random_filename = f"{uuid.uuid4()}{extension}"
        save_path = os.path.join(class_folder, random_filename)

        contents = await file.read()
        with open(save_path, "wb") as f:
            f.write(contents)

        saved_files.append(random_filename)

    return {
        "message": f"Uploaded {len(saved_files)} image(s) to class '{class_name}'",
        "class": class_name,
        "saved_files": saved_files
    }


# ── Milestone 1 Bonus: Dataset info ──────────────────────────────────────────
@app.get("/dataset-info")
def dataset_info():
    """Return the number of files stored per class folder and the grand total."""
    if not os.path.exists(DATASET_DIR):
        return {"classes": {}, "total_images": 0}

    summary = {}
    for class_name in os.listdir(DATASET_DIR):
        class_path = os.path.join(DATASET_DIR, class_name)
        if os.path.isdir(class_path):
            summary[class_name] = len(os.listdir(class_path))

    return {
        "classes": summary,
        "total_images": sum(summary.values())
    }


# ── Milestone 2: Train model ─────────────────────────────────────────────────
def _collect_features(classes):
    """Extract a feature vector for every readable image in the given classes.

    Returns:
        (X, y, skipped): list of feature vectors, the matching list of class
        labels, and the number of files that could not be processed.
    """
    X = []  # Feature vectors — one row per image
    y = []  # Labels — one entry per image, matched by index to X
    skipped = 0

    for class_name in classes:
        class_folder = os.path.join(DATASET_DIR, class_name)
        for filename in os.listdir(class_folder):
            image_path = os.path.join(class_folder, filename)
            try:
                # The context manager closes the file handle even on failure.
                with Image.open(image_path) as img:
                    features = extract_features(img)
            except Exception as exc:  # noqa: BLE001 - deliberate best effort
                # One corrupted image should not kill the whole training run.
                logger.warning("Skipping %s: %s", image_path, exc)
                skipped += 1
                continue
            X.append(features)
            y.append(class_name)

    return X, y, skipped


def _split_for_evaluation(X: np.ndarray, y: np.ndarray):
    """Split data into train/test parts without losing a class from training.

    A stratified 80/20 split is tried first; if the data is too small for
    stratification a plain split is used. Whenever the resulting training
    part would miss a class (or there are fewer than MIN_IMAGES_FOR_HOLDOUT
    images) the full data is used for both training and scoring.

    Returns:
        (X_train, X_test, y_train, y_test)
    """
    if len(X) >= MIN_IMAGES_FOR_HOLDOUT:
        try:
            X_train, X_test, y_train, y_test = train_test_split(
                X, y, test_size=HOLDOUT_FRACTION, random_state=SPLIT_SEED, stratify=y
            )
        except ValueError:
            # Raised when a class has a single image or the test fold is
            # smaller than the number of classes.
            X_train, X_test, y_train, y_test = train_test_split(
                X, y, test_size=HOLDOUT_FRACTION, random_state=SPLIT_SEED
            )
        if len(np.unique(y_train)) == len(np.unique(y)):
            return X_train, X_test, y_train, y_test

    return X, X, y, y


def _save_model_atomically(model_data: dict) -> None:
    """Write model_data to MODEL_PATH without exposing a half-written file.

    The pickle is written to a temporary sibling file and then moved into
    place with os.replace, so a concurrent /predict sees either the old or
    the new model, never a partial one.
    """
    tmp_path = f"{MODEL_PATH}.{uuid.uuid4().hex}.tmp"
    try:
        with open(tmp_path, "wb") as f:
            pickle.dump(model_data, f)
        os.replace(tmp_path, MODEL_PATH)
    finally:
        # Only still present if something failed before the replace.
        if os.path.exists(tmp_path):
            os.remove(tmp_path)


@app.post("/train")
def train_model():
    """Train a LogisticRegression classifier on the uploaded dataset.

    Scans ``dataset/``, extracts MobileNetV3 features from every image,
    fits the classifier and saves it (with its class list) to ``model.pkl``.

    Returns:
        A dict with the trained class names, the number of images used, the
        number of skipped files, the accuracy in percent and the model path.

    Raises:
        HTTPException(400): no dataset, fewer than two class folders, no
            readable images, or usable images in fewer than two classes.
    """
    # ── Step 1: Validate dataset exists ──────────────────────────────────────
    if not os.path.exists(DATASET_DIR):
        raise HTTPException(
            status_code=400,
            detail="No dataset found. Please upload images first."
        )

    classes = [
        d for d in os.listdir(DATASET_DIR)
        if os.path.isdir(os.path.join(DATASET_DIR, d))
    ]

    # The classifier learns to DISTINGUISH between classes, so it needs at
    # least two of them.
    if len(classes) < 2:
        raise HTTPException(
            status_code=400,
            detail=f"Need at least 2 classes to train. You currently have: {classes}"
        )

    # ── Step 2: Extract features from every image ────────────────────────────
    X, y, skipped = _collect_features(classes)

    if not X:
        raise HTTPException(
            status_code=400,
            detail="No valid images found in dataset."
        )

    X = np.array(X)  # Shape: (num_images, num_features)
    y = np.array(y)  # Shape: (num_images,)

    # Empty folders or folders full of unreadable files contribute nothing,
    # so re-check the class count on the data that will actually be used.
    usable_classes = np.unique(y).tolist()
    if len(usable_classes) < 2:
        raise HTTPException(
            status_code=400,
            detail=(
                "Need usable images in at least 2 classes to train. "
                f"Classes with usable images: {usable_classes}"
            )
        )

    # ── Step 3: Train the classifier ─────────────────────────────────────────
    X_train, X_test, y_train, y_test = _split_for_evaluation(X, y)

    # Why LogisticRegression?
    # The backbone already turned each image into a meaningful feature vector;
    # LogisticRegression only has to find the boundary between those vectors.
    # It trains quickly, works with very few images, and needs no GPU.
    # max_iter=1000 avoids ConvergenceWarning on small datasets.
    classifier = LogisticRegression(max_iter=1000)
    classifier.fit(X_train, y_train)

    accuracy = float(classifier.score(X_test, y_test))

    # ── Step 4: Save classifier + class list to disk ─────────────────────────
    # Store exactly the labels the classifier was fitted on, in its own order.
    trained_classes = [str(label) for label in classifier.classes_]
    _save_model_atomically({
        "classifier": classifier,
        "classes": trained_classes
    })

    return {
        "message": "Training complete!",
        "classes": trained_classes,
        "total_images": len(X),
        "skipped_images": skipped,
        "accuracy": round(accuracy * 100, 2),
        "model_saved_at": MODEL_PATH
    }


# ── Milestone 3: Predict endpoint ────────────────────────────────────────────
@app.post("/predict")
async def predict(file: UploadFile = File(...)):
    """Classify one uploaded image with the trained model.

    Returns:
        A dict with the predicted class, its confidence, and the confidence
        score of every class the model knows.

    Raises:
        HTTPException(400): no trained model yet, the upload is not declared
            as an image, or its bytes cannot be decoded as an image.
    """
    # ── Step 1: Load the saved model from disk ───────────────────────────────
    # The model is reloaded on every request because /train replaces
    # model.pkl; caching it at startup would keep serving the OLD model.
    # Opening directly (instead of checking existence first) avoids a race
    # with a concurrent retrain.
    # SECURITY NOTE: pickle.load executes arbitrary code from the file. This
    # is acceptable only because model.pkl is written by /train on this host;
    # never point MODEL_PATH at a file from an untrusted source.
    try:
        with open(MODEL_PATH, "rb") as f:
            model_data = pickle.load(f)
    except FileNotFoundError:
        raise HTTPException(
            status_code=400,
            detail="No trained model found. Please call /train first."
        ) from None

    classifier = model_data["classifier"]

    # ── Step 2: Validate the uploaded file is an image ───────────────────────
    if not _is_image_upload(file):
        raise HTTPException(
            status_code=400,
            detail=f"File '{file.filename}' is not an image. Only image files are accepted."
        )

    # ── Step 3: Decode the image and extract features ────────────────────────
    # file.read() gives raw bytes; BytesIO lets PIL read them like a file, so
    # no temp file is needed. extract_features() is the SAME function used
    # during training, which keeps preprocessing consistent.
    contents = await file.read()
    try:
        with Image.open(io.BytesIO(contents)) as image:
            features = extract_features(image)
    except (OSError, ValueError, Image.DecompressionBombError) as exc:
        # PIL raises OSError subclasses for unreadable/truncated image data.
        raise HTTPException(
            status_code=400,
            detail=f"File '{file.filename}' could not be read as an image."
        ) from exc

    # ── Step 4: Run prediction ───────────────────────────────────────────────
    # sklearn expects a 2D array: (number_of_samples, number_of_features).
    features_2d = features.reshape(1, -1)

    # predict() returns the winning label; str() turns the numpy string into
    # a plain Python str for the JSON response.
    predicted_class = str(classifier.predict(features_2d)[0])

    # predict_proba() returns one probability per class, summing to 1.0,
    # ordered like classifier.classes_.
    probabilities = classifier.predict_proba(features_2d)[0]

    # ── Step 5: Build a clean confidence scores dict ─────────────────────────
    # Pair each label in classifier.classes_ with its probability,
    # e.g. {"cat": 0.82, "dog": 0.18}. float() converts the numpy scalar to a
    # Python float and round(..., 4) keeps the output readable.
    confidence_scores = {
        str(cls): round(float(prob), 4)
        for cls, prob in zip(classifier.classes_, probabilities)
    }

    return {
        "predicted_class": predicted_class,
        "confidence": round(float(max(probabilities)), 4),
        "all_scores": confidence_scores
    }