Lesson 5: Image Classification Project
Train a CNN to classify real-world images with high accuracy.
Let's put everything we have learned about PyTorch and CNNs into practice! In this project, you will build and train a deep learning classifier on the famous **CIFAR-10** dataset.
The CIFAR-10 Dataset
The CIFAR-10 dataset contains 60,000 32x32 color images, split into 50,000 training images and 10,000 test images, spread evenly across 10 classes (6,000 images each): airplanes, automobiles, birds, cats, deer, dogs, frogs, horses, ships, and trucks. It is a classic benchmark for testing new computer vision architectures.
Loading Data with transforms.Compose
Before feeding images into our neural network, we must convert them into PyTorch tensors and then normalize their pixel values. We chain these preprocessing steps, in that order, using transforms.Compose:
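from torchvision import transforms

transform = transforms.Compose([
    # Convert a PIL image to a float tensor with values in [0, 1]
    transforms.ToTensor(),
    # Per channel: (x - mean) / std, mapping [0, 1] to [-1, 1]
    transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)),
])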
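To make the effect of the Normalize step concrete, here is a small arithmetic check in plain Python. It assumes the documented torchvision behavior: ToTensor scales 8-bit pixel values to the range [0, 1], and Normalize computes (x - mean) / std for each channel. The helper name `normalize_value` is made up for this illustration and is not part of torchvision.

```python
# Illustrative helper (not a torchvision function): applies the same
# per-value formula that transforms.Normalize uses, (x - mean) / std.
def normalize_value(x, mean=0.5, std=0.5):
    return (x - mean) / std

# ToTensor scales 8-bit pixels from [0, 255] to [0.0, 1.0].
darkest = 0 / 255
brightest = 255 / 255

print(normalize_value(darkest))    # -1.0
print(normalize_value(0.5))        # 0.0
print(normalize_value(brightest))  # 1.0
```

With a mean and standard deviation of 0.5, every channel ends up in the range [-1, 1], centered on zero.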
Batching Data with DataLoader
Feeding the entire dataset into a network at once is memory-prohibitive. We use a PyTorch DataLoader to shuffle the dataset and serve images in small, manageable batches (e.g., a batch size of 64).
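The batching behavior can be sketched as follows. To keep the example runnable without a download, it uses random tensors shaped like CIFAR-10 samples as a stand-in; that stand-in data, and the choice of 256 samples, are assumptions made purely for illustration.

```python
import torch
from torch.utils.data import DataLoader, TensorDataset

# Stand-in data (an assumption for illustration): 256 random "images" shaped
# like CIFAR-10 samples (3 channels, 32x32) with integer labels 0-9.
# With the real dataset you would pass a torchvision.datasets.CIFAR10
# object (created with your transform) to DataLoader instead.
images = torch.randn(256, 3, 32, 32)
labels = torch.randint(0, 10, (256,))
dataset = TensorDataset(images, labels)

# shuffle=True reorders the samples at the start of every epoch.
loader = DataLoader(dataset, batch_size=64, shuffle=True)

batch_images, batch_labels = next(iter(loader))
print(batch_images.shape)  # torch.Size([64, 3, 32, 32])
print(batch_labels.shape)  # torch.Size([64])
print(len(loader))         # 4 batches of 64 cover all 256 samples
```

Each iteration over the loader yields one (images, labels) pair, which is the unit the training loop works on.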
The Training Loop and optimizer.step
For each batch, we perform a forward pass to compute predictions, compute the loss, clear any existing gradients via optimizer.zero_grad(), run a backward pass via loss.backward(), and update the weights via optimizer.step(). Clearing the gradients matters because PyTorch accumulates them across backward passes by default.
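A single training step following that sequence might look like the sketch below. The model here is a deliberately tiny stand-in (one linear layer on random data), not the CNN this project asks for, and the learning rate of 1e-3 is an illustrative choice.

```python
import torch
from torch import nn

torch.manual_seed(0)

# Tiny stand-in model and batch (assumptions for illustration only):
# a single linear layer over flattened 3x32x32 inputs, 10 output classes.
model = nn.Sequential(nn.Flatten(), nn.Linear(3 * 32 * 32, 10))
criterion = nn.CrossEntropyLoss()
optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)

inputs = torch.randn(64, 3, 32, 32)
targets = torch.randint(0, 10, (64,))

weights_before = model[1].weight.detach().clone()

outputs = model(inputs)             # forward pass: raw class scores (logits)
loss = criterion(outputs, targets)  # compare predictions with true labels
optimizer.zero_grad()               # clear gradients left from earlier steps
loss.backward()                     # backward pass: compute new gradients
optimizer.step()                    # update weights using those gradients

weights_changed = not torch.equal(weights_before, model[1].weight)
print(f"loss: {loss.item():.4f}, weights changed: {weights_changed}")
```

In a full training loop, the five commented lines sit inside a loop over the DataLoader's batches, which in turn sits inside a loop over epochs.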
Project Tasks
To complete this project workshop, implement the following components inside the sandbox code panel:
- [ ] Task 1: Load the CIFAR-10 dataset using torchvision and apply the ToTensor and Normalize transforms.
- [ ] Task 2: Define a CNN class containing 2 convolutional layers, max pooling, and 2 fully connected layers.
- [ ] Task 3: Instantiate a CrossEntropyLoss criterion and an Adam optimizer.
- [ ] Task 4: Run a training loop for 2 epochs, printing the batch loss and calling optimizer.step() to update the weights.
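As a starting point for Task 2, one possible network layout is sketched below. The class name `SimpleCNN`, the channel counts (16 and 32), the 3x3 kernels with padding, and the hidden width of 128 are all illustrative choices, not values the project prescribes.

```python
import torch
from torch import nn
import torch.nn.functional as F


class SimpleCNN(nn.Module):
    """Hypothetical example: 2 conv layers, max pooling, 2 fully connected layers."""

    def __init__(self, num_classes=10):
        super().__init__()
        self.conv1 = nn.Conv2d(3, 16, kernel_size=3, padding=1)   # 3x32x32 -> 16x32x32
        self.conv2 = nn.Conv2d(16, 32, kernel_size=3, padding=1)  # 16x16x16 -> 32x16x16
        self.pool = nn.MaxPool2d(2, 2)                            # halves height and width
        self.fc1 = nn.Linear(32 * 8 * 8, 128)
        self.fc2 = nn.Linear(128, num_classes)

    def forward(self, x):
        x = self.pool(F.relu(self.conv1(x)))  # -> 16x16x16
        x = self.pool(F.relu(self.conv2(x)))  # -> 32x8x8
        x = torch.flatten(x, 1)               # flatten everything except the batch dimension
        x = F.relu(self.fc1(x))
        return self.fc2(x)                    # raw logits; CrossEntropyLoss applies log-softmax itself


model = SimpleCNN()
dummy_batch = torch.randn(4, 3, 32, 32)  # random stand-in for 4 CIFAR-10 images
logits = model(dummy_batch)
print(logits.shape)  # torch.Size([4, 10])
```

Passing a dummy batch through the model before training is a quick way to confirm that the input size of the first fully connected layer matches the flattened output of the convolutional layers.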