{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Evaluate the Model\n",
    "\n",
    "Load the trained ResNet-18 checkpoint from `./checkpoint/ResNet18.pth` and report its loss and accuracy on the CIFAR-10 test set."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {},
   "outputs": [],
   "source": [
    "from collections import OrderedDict\n",
    "\n",
    "import torch\n",
    "import torch.backends.cudnn as cudnn\n",
    "import torch.nn as nn\n",
    "import torch.nn.functional as F\n",
    "import torch.optim as optim\n",
    "import torchvision\n",
    "import torchvision.transforms as transforms\n",
    "\n",
    "from resnet import ResNet18_with_name"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Configuration\n",
    "DATA_DIR = './data'\n",
    "CHECKPOINT_PATH = './checkpoint/ResNet18.pth'\n",
    "\n",
    "# Normalization constants shared by the train and test transforms below.\n",
    "# Keep them identical to the values the checkpoint was trained with.\n",
    "CIFAR10_MEAN = (0.4914, 0.4822, 0.4465)\n",
    "CIFAR10_STD = (0.2023, 0.1994, 0.2010)\n",
    "\n",
    "device = 'cuda' if torch.cuda.is_available() else 'cpu'"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Data Preprocessing\n",
    "\n",
    "CIFAR-10 is downloaded to `DATA_DIR` if it is missing. Only `testloader` is used by the evaluation below."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Files already downloaded and verified\n",
      "Files already downloaded and verified\n"
     ]
    }
   ],
   "source": [
    "transform_train = transforms.Compose([\n",
    "    transforms.RandomCrop(32, padding=4),\n",
    "    transforms.RandomHorizontalFlip(),\n",
    "    transforms.ToTensor(),\n",
    "    transforms.Normalize(CIFAR10_MEAN, CIFAR10_STD),\n",
    "])\n",
    "\n",
    "transform_test = transforms.Compose([\n",
    "    transforms.ToTensor(),\n",
    "    transforms.Normalize(CIFAR10_MEAN, CIFAR10_STD),\n",
    "])\n",
    "\n",
    "# trainset/trainloader are not used by the evaluation below.\n",
    "trainset = torchvision.datasets.CIFAR10(\n",
    "    root=DATA_DIR, train=True, download=True, transform=transform_train)\n",
    "trainloader = torch.utils.data.DataLoader(\n",
    "    trainset, batch_size=128, shuffle=True, num_workers=4, pin_memory=True)\n",
    "\n",
    "testset = torchvision.datasets.CIFAR10(\n",
    "    root=DATA_DIR, train=False, download=True, transform=transform_test)\n",
    "testloader = torch.utils.data.DataLoader(\n",
    "    testset, batch_size=100, shuffle=False, num_workers=4, pin_memory=True)\n",
    "\n",
    "classes = ('plane', 'car', 'bird', 'cat', 'deer', 'dog', 'frog', 'horse', 'ship', 'truck')"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Load the Checkpoint"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "metadata": {},
   "outputs": [],
   "source": [
    "net, model_name = ResNet18_with_name()\n",
    "\n",
    "# map_location lets a checkpoint saved on a GPU be loaded on a CPU-only machine.\n",
    "checkpoint = torch.load(CHECKPOINT_PATH, map_location=device, weights_only=True)\n",
    "\n",
    "# Keys saved from a wrapped model (e.g. nn.DataParallel) carry a 'module.' prefix;\n",
    "# strip it so they match the unwrapped network's parameter names.\n",
    "state_dict = OrderedDict(\n",
    "    (key.removeprefix('module.'), value) for key, value in checkpoint['net'].items()\n",
    ")\n",
    "net.load_state_dict(state_dict)\n",
    "net = net.to(device)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Evaluate the Model"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "metadata": {},
   "outputs": [],
   "source": [
    "criterion = nn.CrossEntropyLoss()\n",
    "\n",
    "\n",
    "def test(epoch):\n",
    "    \"\"\"Evaluate `net` on `testloader`, print the results and return them.\n",
    "\n",
    "    Args:\n",
    "        epoch: label shown in the printed line; it does not affect the evaluation.\n",
    "\n",
    "    Returns:\n",
    "        Tuple (mean_loss, accuracy) where mean_loss is the per-sample mean\n",
    "        cross-entropy and accuracy is a percentage.\n",
    "\n",
    "    Raises:\n",
    "        ValueError: if `testloader` yields no samples.\n",
    "    \"\"\"\n",
    "    net.eval()\n",
    "    loss_sum = 0.0\n",
    "    correct = 0\n",
    "    total = 0\n",
    "    with torch.no_grad():\n",
    "        for inputs, targets in testloader:\n",
    "            inputs, targets = inputs.to(device), targets.to(device)\n",
    "            outputs = net(inputs)\n",
    "            batch_size = targets.size(0)\n",
    "            # criterion averages over the batch; re-weight by batch size so a\n",
    "            # smaller final batch does not skew the dataset-level mean.\n",
    "            loss_sum += criterion(outputs, targets).item() * batch_size\n",
    "            correct += outputs.argmax(dim=1).eq(targets).sum().item()\n",
    "            total += batch_size\n",
    "\n",
    "    if total == 0:\n",
    "        raise ValueError('testloader yielded no samples')\n",
    "\n",
    "    mean_loss = loss_sum / total\n",
    "    accuracy = 100. * correct / total\n",
    "    print(f'Epoch: {epoch}, Test Loss: {mean_loss}, Accuracy: {accuracy}%')\n",
    "    return mean_loss, accuracy"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Epoch: 0, Test Loss: 0.2659624905884266, Accuracy: 94.24%\n"
     ]
    }
   ],
   "source": [
    "test_loss, test_accuracy = test(0)"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.10.12"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}