Computer Vision Demo

Practical introduction to Convolutional Neural Networks using PyTorch

Practical Implementation of Image Classification with PyTorch using Fashion MNIST Dataset

Created by Murilo Gustineli

Run the Computer Vision Demo in a colab notebook


Welcome to this practical implementation, where we delve into the world of computer vision using PyTorch.

Our goal is to develop a practical understanding of Convolutional Neural Networks (CNNs) by implementing a model that can classify different types of clothing. This demo is designed for learners at all levels, so don’t worry if some concepts seem new. We’ll walk through each step of the process, explaining the key ideas and code details along the way.

# PyTorch
import torch

# Helper libraries
import numpy as np
import matplotlib.pyplot as plt

# Check if cuda is available
cuda_availability = torch.cuda.is_available()
print(f"PyTorch version: {torch.__version__}")
print(f"Cuda availability: {cuda_availability}")

# Auto-reloading external modules
# See
%load_ext autoreload
%autoreload 2
PyTorch version: 2.1.2
Cuda availability: False

1. The Fashion MNIST dataset

In this demo, we use the Fashion MNIST dataset, a modern alternative to the classic MNIST dataset traditionally used for handwriting recognition.

Fashion MNIST comprises 70,000 grayscale images, each 28x28 pixels, distributed across 10 different clothing categories. These images are small, detailed, and varied enough to challenge our model while being simple enough for straightforward processing and quick training times.

Fashion MNIST sprite
Figure 1. Fashion-MNIST samples (by Zalando, MIT License).

1.1 Dataset Structure

The dataset is split into two parts:

  • Training Set: This includes train_images and train_labels. These are the arrays that our model will learn from. The model sees these images and their corresponding labels, adjusting its weights and biases to reduce classification error.

  • Test Set: This comprises test_images and test_labels. These are used to evaluate how well our model performs on data it has never seen before. This is crucial for understanding the model’s generalization capability.

1.2 Understanding the Data

Each image in the dataset is a 28x28 grayscale array. The raw pixel values range from 0 to 255, with 0 being black, 255 being white, and the various shades of gray in between. The labels are integers from 0 to 9, each representing a specific category of clothing.

Class Names:

The dataset doesn’t include the names of the clothing classes, so we will manually define them for clarity when visualizing our results. Here’s the mapping:

Label  Class
0      T-shirt/top
1      Trouser
2      Pullover
3      Dress
4      Coat
5      Sandal
6      Shirt
7      Sneaker
8      Bag
9      Ankle boot

In the following sections, we will load and preprocess this data, design and train a CNN, and finally evaluate its performance on the test set. Let’s get started!

# Defining class names
class_names = [
    'T-shirt/top', 'Trouser', 'Pullover', 'Dress', 'Coat',
    'Sandal', 'Shirt', 'Sneaker', 'Bag', 'Ankle boot',
]

2. Loading the dataset

Let’s load the Fashion MNIST dataset directly from PyTorch:

import torch
from torchvision import datasets, transforms
from torch.utils.data import DataLoader
from IPython import get_ipython

# Check if running on Google Colab
if 'google.colab' in str(get_ipython()):
    data_dir = '/content/FashionMNIST_data/'
else:
    data_dir = './FashionMNIST_data/'  # Relative path for local execution

# Define transformation
transform = transforms.Compose([
    transforms.ToTensor(),                # Convert to tensor and scale to [0, 1]
    transforms.Normalize((0.5,), (0.5,))  # Shift and scale to [-1, 1]
])

# Seed for reproducibility
torch.manual_seed(0)  # Any fixed integer works; this value is an arbitrary choice

# Create training and validation datasets
trainset = datasets.FashionMNIST(data_dir, download=True, train=True, transform=transform)
validset = datasets.FashionMNIST(data_dir, download=True, train=False, transform=transform)

# Create data loaders for our datasets
trainloader = DataLoader(trainset, batch_size=32, shuffle=True)
validloader = DataLoader(validset, batch_size=32, shuffle=True)

# Report split sizes
print('Training set has {} instances'.format(len(trainset)))
print('Validation set has {} instances'.format(len(validset)))
Training set has 60000 instances
Validation set has 10000 instances

Defining Transformations:

  • transforms.ToTensor(): Converts the images into PyTorch tensors and scales the pixel values to the range [0, 1].

  • transforms.Normalize((0.5,), (0.5,)): Normalizes the tensor images so that each pixel value is centered around 0 and falls within the range [-1, 1]. This normalization helps in stabilizing the learning process and often leads to faster convergence in deep learning models.
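
As a quick check of the arithmetic behind these two transforms, the sketch below applies the same two steps to raw 8-bit pixel values in plain Python. The helper name `normalize_pixel` is hypothetical; it mirrors the math rather than calling the torchvision transforms themselves.

```python
def normalize_pixel(raw, mean=0.5, std=0.5):
    """Mimic ToTensor (divide by 255) followed by Normalize((mean,), (std,))."""
    scaled = raw / 255.0           # ToTensor step: [0, 255] -> [0.0, 1.0]
    return (scaled - mean) / std   # Normalize step: [0.0, 1.0] -> [-1.0, 1.0]

# Black, a dark gray, and white
for raw in (0, 51, 255):
    print(raw, "->", round(normalize_pixel(raw), 3))
```

Black (0) maps to -1 and white (255) maps to 1, so the background of most Fashion MNIST images sits at -1 after preprocessing.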

2.1 Using DataLoader

The DataLoader in PyTorch provides batches of data, and you can iterate through these batches to collect all images and labels. This method is memory-efficient and is typically used when dealing with large datasets.

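def extract_images_labels(loader):
    images = []
    labels = []
    for batch in loader:
        b_images, b_labels = batch
        images.append(b_images)  # Collect the image tensors of this batch
        labels.append(b_labels)  # Collect the matching labels
    # Concatenate all batches along the batch dimension
    return torch.cat(images, dim=0), torch.cat(labels, dim=0)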

# Extracting train and test images and labels
train_images, train_labels = extract_images_labels(trainloader)
valid_images, valid_labels = extract_images_labels(validloader)

Important Notes:

  • This method will load the entire dataset into memory, which is fine for a small dataset like Fashion MNIST but might not be feasible for significantly larger datasets.

  • train_images and valid_images will be tensors containing the image data, and train_labels and valid_labels will be tensors containing the corresponding labels.

  • These tensors are primarily for visualization purposes. For actual model training, especially with larger datasets, it’s recommended to use trainloader and validloader directly within the training loop to leverage their efficiency and memory-friendly approach.
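
To make the idea of batching concrete, here is a conceptual sketch in plain Python of how a dataset is cut into batches. The function `iterate_batches` is a hypothetical stand-in, not the DataLoader implementation; the real DataLoader also handles shuffling, collating samples into tensors, and optional parallel loading.

```python
def iterate_batches(items, batch_size):
    """Yield consecutive slices of at most batch_size items."""
    for start in range(0, len(items), batch_size):
        yield items[start:start + batch_size]

samples = list(range(100))   # stand-in for 100 (image, label) pairs
batch_sizes = [len(batch) for batch in iterate_batches(samples, 32)]
print(batch_sizes)           # [32, 32, 32, 4]
```

Only one batch needs to be held at a time, which is why iterating over a loader scales to datasets that do not fit in memory. Note that the last batch can be smaller than the others.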

3. Exploring the data

Before training our model, it’s essential to understand the Fashion MNIST dataset’s format and structure. This understanding helps in effectively tailoring our model and preprocessing steps.

3.1 Understanding the Training Set

Size and Shape of Training Images: The shape of the train_images tensor provides insight into the number of images and their dimensions.

The shape is represented as (N, C, H, W), where:

  • N = Number of images
  • C = Number of color channels per image (For grayscale images, C = 1)
  • H = Height of each image in pixels
  • W = Width of each image in pixels

In our dataset, each image is a 28 x 28 pixel grayscale image, so the shape will be (60000, 1, 28, 28), indicating 60,000 images with 1 color channel and 28x28 pixel resolution.

train_images.shape  # Output expected: (60000, 1, 28, 28)
torch.Size([60000, 1, 28, 28])

The “rank” of a tensor refers to the number of dimensions it has. For images represented as tensors, you can find the rank using the .dim() method, which returns the number of dimensions in the tensor.

train_images.dim() # Output expected: 4
  • Training Labels: The total number of labels in the training set should match the number of images. Each label corresponds to a category of fashion item.
len(train_labels)   # Output expected: 60000
  • Label Range: Each label is an integer from 0 to 9, where each number corresponds to a specific category (like T-shirts, trousers, etc.).
train_labels        # Output example: tensor([9, 8, 3, ..., 7, 7, 6])
tensor([9, 8, 3,  ..., 7, 7, 6])

3.2 Understanding the Test Set

  • Size and Shape of Test Images: The test set should have a similar structure but with fewer images, typically used for evaluating the model’s performance.
valid_images.shape   # Output expected: (10000, 1, 28, 28)
torch.Size([10000, 1, 28, 28])
  • Test Labels: The test set contains labels corresponding to each image, used to verify the model’s predictions.
len(valid_labels)    # Output expected: 10000

4. Data Preprocessing

Before training our neural network, it’s crucial to preprocess the data. This involves scaling the pixel values of the images to a standard range, which helps the network learn more efficiently.

4.1 Understanding Pixel Values:

  • Each image in the Fashion MNIST dataset is represented in grayscale with pixel values ranging from 0 to 255.

  • We applied Scaling and Normalization methods so that each pixel value is centered around 0 and falls within the range [-1, 1].

  • This normalization is often used in deep learning models as it centers the data around 0, which can lead to faster convergence during training. It can also help mitigate issues caused by different lighting and contrast in images.

  • The value -1 represents black, 1 represents white, and the values in between represent various shades of gray.

Let’s inspect the first image in the training set displaying these pixel values:

# Plotting first image
plt.imshow(train_images[0][0], cmap='gray')


def get_text_color(value:float) -> str:
    """Returns 'white' for dark pixels and 'black' for light pixels."""
    return 'white' if value < 0.5 else 'black'

image_numpy = train_images[0][0].squeeze().numpy()
label = train_labels[0]

# Plotting the image
plt.imshow(image_numpy, cmap='gray')
plt.title(class_names[label], fontsize=16,)

# Overlaying the pixel values
for i in range(image_numpy.shape[0]):
    for j in range(image_numpy.shape[1]):
        plt.text(j, i, '{:.1f}'.format(image_numpy[i,j]), ha='center', va='center', color=get_text_color(image_numpy[i,j]))


4.2 Verifying the Data Format:

  • Before building the model, it’s a good practice to visualize the data to ensure it’s in the correct format. Displaying the first 25 images from the training set can help us confirm that the data is ready for model training.

  • Additionally, displaying the class name below each image ensures that the labels correspond correctly to the images:

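fig, axes = plt.subplots(5, 5, figsize=(10, 10))  # 5x5 grid, one axis per image
for i, ax in enumerate(axes.flat):
    ax.imshow(train_images[i][0], cmap='gray')
    ax.set(xticks=[], yticks=[], xlabel=class_names[train_labels[i]])  # Class name below each image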


5. Convolutional Neural Networks

In deep learning, especially for tasks like image classification, Convolutional Neural Networks (CNNs) are often the architecture of choice. CNNs are designed to automatically and adaptively learn spatial hierarchies of features from input images.

CNN Architecture
Figure 2. Understanding Convolutional Neural Network (CNN): A Complete Guide (by LearnOpenCV).

5.1 Hierarchical Compositionality in CNNs

Complex features are built from simpler ones.

In Convolutional Neural Networks (CNNs), the concept of hierarchical compositionality plays a pivotal role. This idea is based on how CNNs learn to recognize and interpret images through layers that understand increasingly complex features:

  • Low-Level Features: The initial layers of a CNN focus on simple, low-level features such as edges, colors, and basic textures.

  • Mid-Level Features: As the data progresses through the network, these basic features are combined to form mid-level features, like shapes and specific patterns.

  • High-Level Features: In the deeper layers, these combinations further evolve into high-level features that represent more complex aspects of the image, such as entire objects or significant parts of them.

Hierarchical Compositionality
Figure 3. CS 7643: Deep Learning (by Zsolt Kira, Georgia Institute of Technology).

This hierarchical approach allows CNNs to build a deep understanding of images from simple to complex, making them highly effective for tasks like image classification.
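
As a small illustration of a low-level feature, the sketch below slides a hand-picked vertical-edge kernel over a tiny image in plain Python. Both the image and the kernel are made up for this example, and `conv2d_valid` is a hypothetical helper; a trained CNN learns its own kernel weights instead of using fixed ones.

```python
def conv2d_valid(image, kernel):
    """Slide the kernel over the image with stride 1 and no padding
    (cross-correlation, which is what CNN layers compute)."""
    kh, kw = len(kernel), len(kernel[0])
    out_h = len(image) - kh + 1
    out_w = len(image[0]) - kw + 1
    return [
        [
            sum(image[r + i][c + j] * kernel[i][j]
                for i in range(kh) for j in range(kw))
            for c in range(out_w)
        ]
        for r in range(out_h)
    ]

# 4x6 image: dark left half (0), bright right half (1)
image = [[0, 0, 0, 1, 1, 1] for _ in range(4)]

# Kernel that responds to dark-to-bright transitions from left to right
kernel = [[-1, 0, 1],
          [-1, 0, 1],
          [-1, 0, 1]]

response = conv2d_valid(image, kernel)
for row in response:
    print(row)   # each row: [0, 3, 3, 0]
```

The response is zero over the flat regions and positive only where the window straddles the edge. Deeper layers combine many such responses into shapes and, eventually, object parts.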

5.2 CNN Architecture

Let’s construct a CNN suitable for the Fashion MNIST dataset using PyTorch’s torch.nn module. This CNN will consist of convolutional layers for feature extraction followed by fully connected layers for classification.

Here’s how we can structure our CNN:

import torch.nn as nn
import torch.nn.functional as F

class CNNClassifier(nn.Module):
    def __init__(self):
        super(CNNClassifier, self).__init__()
        # Convolutional layers using Sequential
        self.conv_layers = nn.Sequential(
            # First convolutional layer
            nn.Conv2d(1, 32, kernel_size=3, stride=1, padding=1),   # 32 filters, 3x3 kernel, stride 1, padding 1
            nn.ReLU(inplace=True),                                  # ReLU activation
            nn.MaxPool2d(kernel_size=2, stride=2),                  # 2x2 Max pooling with stride 2

            # Second convolutional layer
            nn.Conv2d(32, 64, kernel_size=3, stride=1, padding=1),  # 64 filters, 3x3 kernel, stride 1, padding 1
            nn.ReLU(inplace=True),                                  # ReLU activation
            nn.MaxPool2d(kernel_size=2, stride=2),                  # 2x2 Max pooling with stride 2
        )

        # Fully connected layers using Sequential
        self.fc_layers = nn.Sequential(
            nn.Flatten(),                                           # Flatten the output of conv layers
            nn.Linear(64 * 7 * 7, 128),                             # Fully connected layer with 128 neurons
            nn.ReLU(inplace=True),                                  # ReLU activation
            nn.Linear(128, 10)                                      # Output layer with 10 neurons for 10 classes
        )

    def forward(self, x):
        x = self.conv_layers(x)                                     # Pass through conv layers
        x = self.fc_layers(x)                                       # Pass through fully connected layers
        return x

# Create the model instance
model = CNNClassifier()

Architecture explanation:

  • Convolutional Layers: The model starts with two sets of convolutional layers, each followed by a ReLU activation and a max pooling layer. The first Conv2d layer takes a single-channel (grayscale) image and applies 32 filters. The second Conv2d layer increases the depth to 64 filters.

  • ReLU Activation: After each convolutional layer, a ReLU activation function is used. It introduces non-linearity, allowing the model to learn more complex patterns.

  • Max Pooling: Each max pooling layer (MaxPool2d) reduces the spatial dimensions of the feature map by half, helping in reducing the computation and controlling overfitting.

  • Fully Connected Layers: The output from the convolutional layers is flattened into a 1D vector and then passed through two fully connected layers. The first linear layer reduces the dimension to 128, and the second linear layer produces the final output corresponding to the 10 classes.

This architecture, structured using the nn.Sequential container, offers a clear and compact way to define a CNN in PyTorch. The model is now ready to be trained with the Fashion MNIST dataset.
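
The `64 * 7 * 7` input size of the first linear layer follows from the layer settings above. The sketch below traces the spatial size through the network using the standard output-size formula for convolution and pooling; the helper name `conv_out` is hypothetical.

```python
def conv_out(size, kernel, stride=1, padding=0):
    """Output length along one spatial axis:
    floor((size + 2*padding - kernel) / stride) + 1."""
    return (size + 2 * padding - kernel) // stride + 1

size = 28
size = conv_out(size, kernel=3, stride=1, padding=1)   # Conv2d(1, 32):  28 -> 28
size = conv_out(size, kernel=2, stride=2)              # MaxPool2d:      28 -> 14
size = conv_out(size, kernel=3, stride=1, padding=1)   # Conv2d(32, 64): 14 -> 14
size = conv_out(size, kernel=2, stride=2)              # MaxPool2d:      14 -> 7

flattened = 64 * size * size   # 64 channels, each 7x7
print(size, flattened)         # 7 3136
```

If you change the kernel size, padding, or number of pooling layers, rerun this arithmetic and update the first `nn.Linear` accordingly.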

5.3 Defining the Loss Function and Optimizer

Before training, the model requires a few additional settings, including an optimizer and a loss function:

  1. Optimizer: The optimizer is responsible for updating the model parameters based on the computed gradients. It’s crucial for the convergence of the training process.

  2. Loss Function: The loss function measures the discrepancy between the model’s predictions and the actual labels. During training, we aim to minimize this loss.

Here’s how you set up these components in PyTorch:

import torch.optim as optim

# Set the loss function
loss_fn = nn.CrossEntropyLoss()

# Set the optimizer
optimizer = optim.Adam(model.parameters(), lr=0.001)


  • Cross-Entropy Loss: We use nn.CrossEntropyLoss for multi-class classification. This loss function combines nn.LogSoftmax and nn.NLLLoss in one single class. It is suitable for classification tasks with C classes.

  • Adam Optimizer: optim.Adam is used as the optimizer. It’s a popular choice due to its effectiveness in handling sparse gradients and adapting the learning rate during training.

  • Learning Rate (lr): This is a hyperparameter that controls how much to change the model in response to the estimated error each time the model weights are updated. Here, it’s set to 0.001.
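
To see what the cross-entropy loss computes for a single sample, here is a plain-Python sketch of log-softmax followed by negative log-likelihood. The function name and the logits are made up for illustration; `nn.CrossEntropyLoss` additionally averages over the batch by default.

```python
import math

def cross_entropy(logits, target):
    """Negative log-softmax of the target class for one sample."""
    m = max(logits)   # subtract the max for numerical stability
    log_sum_exp = m + math.log(sum(math.exp(z - m) for z in logits))
    return log_sum_exp - logits[target]

confident_right = cross_entropy([5.0, 0.0, 0.0], target=0)
uniform_guess = cross_entropy([0.0] * 10, target=3)
confident_wrong = cross_entropy([5.0, 0.0, 0.0], target=1)

print(f"confident and right: {confident_right:.4f}")
print(f"uniform over 10:     {uniform_guess:.4f}")
print(f"confident and wrong: {confident_wrong:.4f}")
```

A model that spreads its probability evenly over 10 classes has a loss of ln(10), about 2.30, which is a useful reference point when reading the training log: values well below it mean the model is doing better than guessing.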

6. Training the Model

Having defined the model architecture and set up the loss function and optimizer, we now move to one of the most crucial stages in building a machine learning model – training.

Training a model in deep learning involves feeding it data, letting it make predictions, and then adjusting the model parameters (weights) based on the error in its predictions. This process is repeated iteratively and is essential for the model to learn from the data.

6.1 The Training Loop

The core of model training in PyTorch is the training loop. In each iteration of the loop, the model processes one batch: it makes predictions (a forward pass), calculates the error (loss), computes gradients (a backward pass), and updates its parameters. One full pass over the training set is called an epoch. Here’s a basic outline of what this process involves:

  1. Forward Pass: The model processes the input data and makes predictions.

  2. Compute Loss: The discrepancy between the model’s predictions and the actual labels is calculated using the loss function.

  3. Backward Pass: Backpropagation is performed to calculate the gradients of the loss with respect to each model parameter.

  4. Update Model Parameters: The optimizer updates the model parameters using the computed gradients.
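
The four steps can be sketched on a toy problem without PyTorch, so that the arithmetic behind each step is visible. This example is an assumption-laden simplification: a one-weight model, mean squared error instead of cross-entropy, a hand-derived gradient instead of autograd, and plain gradient descent instead of Adam.

```python
# Toy data generated by y = 2x; the model predicts w * x with a single weight w
xs = [1.0, 2.0, 3.0]
ys = [2.0, 4.0, 6.0]
w = 0.0
lr = 0.05   # arbitrary learning rate chosen for this toy problem

losses = []
for step in range(20):
    # 1. Forward pass
    preds = [w * x for x in xs]
    # 2. Compute loss (mean squared error)
    loss = sum((p - y) ** 2 for p, y in zip(preds, ys)) / len(xs)
    # 3. Backward pass: dLoss/dw derived by hand (loss.backward() automates this)
    grad = sum(2 * (p - y) * x for p, y, x in zip(preds, ys, xs)) / len(xs)
    # 4. Update the parameter (the role of optimizer.step())
    w -= lr * grad
    losses.append(loss)

print(f"w = {w:.4f}, first loss = {losses[0]:.4f}, last loss = {losses[-1]:.6f}")
```

The weight converges toward 2 and the loss shrinks at every step. The PyTorch loop follows the same pattern, with gradients for all parameters computed automatically.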

Training Function

Unlike TensorFlow’s Keras API, which provides high-level functions like model.fit and model.evaluate, PyTorch requires you to explicitly define the training loop and the evaluation process. This approach gives you more control and flexibility but also requires more code.

First, we define a function to encapsulate the training logic for one epoch. This function will handle the forward and backward passes, loss computation, and parameter updates:

def train_one_epoch(model, trainloader, optimizer, loss_fn, epoch_index, tb_writer):
    model.train()  # Set the model to training mode
    running_loss = 0.0
    total_batches = len(trainloader)
    log_interval = max(1, total_batches // 10)  # Log 10 times per epoch or at least once

    for i, data in enumerate(trainloader):
        # Every data instance is an input + label pair
        inputs, labels = data

        # Zero the gradients to ensure they aren't accumulated
        optimizer.zero_grad()

        # Forward pass
        outputs = model(inputs)          # Make predictions for this batch
        loss = loss_fn(outputs, labels)  # Compute the loss

        # Backward pass and optimization
        loss.backward()     # Compute the gradients
        optimizer.step()    # Adjust learning weights

        # Accumulate loss
        running_loss += loss.item()
        if i % log_interval == log_interval - 1:        # Adjusted logging condition
            average_loss = running_loss / log_interval  # Average loss per batch in this set
            print(f"  batch {i+1} loss: {average_loss}")
            tb_x = epoch_index * total_batches + i + 1
            tb_writer.add_scalar('Loss/train', average_loss, tb_x)
            running_loss = 0.0

    return average_loss
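
The logging schedule inside this function can be checked with a little arithmetic. The sketch below assumes the settings used in this demo (60,000 training images, batch size 32, the last partial batch kept) and reproduces only the batch counting, not the training itself.

```python
import math

dataset_size = 60000
batch_size = 32

# Number of batches per epoch when the last partial batch is kept
total_batches = math.ceil(dataset_size / batch_size)
log_interval = max(1, total_batches // 10)

# Batch numbers (1-based) at which the loop prints a loss
logged_at = [i + 1 for i in range(total_batches)
             if i % log_interval == log_interval - 1]

print(total_batches, log_interval)   # 1875 187
print(logged_at)
```

The last log happens at batch 1870, so the final five batches of each epoch are never reported, and the value returned by the function is the average over the last logged window rather than over the whole epoch.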

Validation Function

Next, we add a function to perform validation after each training epoch. This function will evaluate the model on a separate validation dataset:

import os

# Model validation
def validate(model, validloader, loss_fn):
    model.eval()  # Set the model to evaluation mode
    running_vloss = 0.0

    with torch.no_grad():  # Disable gradient computation
        for inputs, labels in validloader:
            outputs = model(inputs)
            vloss = loss_fn(outputs, labels)
            running_vloss += vloss.item()

    average_vloss = running_vloss / len(validloader)
    return average_vloss

Early Stopping

Early stopping is a technique that helps prevent overfitting by terminating training once the validation loss stops improving, which is a sign that the model has started to learn noise or irrelevant patterns in the training data.

class EarlyStopping:
    def __init__(self, patience=3, min_delta=0):
        """Early stops the training if validation loss doesn't improve after a given patience."""
        self.patience = patience
        self.min_delta = min_delta
        self.counter = 0
        self.best_loss = None
        self.early_stop = False

    def __call__(self, val_loss):
        if self.best_loss is None:
            self.best_loss = val_loss
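        elif val_loss > self.best_loss - self.min_delta:
            # No improvement of at least min_delta: count this epoch
            self.counter += 1
            if self.counter >= self.patience:
                self.early_stop = True
        else:
            # Sufficient improvement: record it and reset the counter
            self.best_loss = val_loss
            self.counter = 0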

# Initialize the early stopping object
early_stopping = EarlyStopping(patience=3, min_delta=0.01)
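
To see how patience and min_delta interact, the sketch below traces the rule over the validation losses reported in this demo's training run. `EarlyStoppingSketch` is a stand-alone restatement written for this trace; it assumes that an epoch without sufficient improvement increments the counter and that a sufficient improvement resets it.

```python
class EarlyStoppingSketch:
    """Stand-alone restatement of the patience / min_delta rule."""
    def __init__(self, patience=3, min_delta=0.0):
        self.patience = patience
        self.min_delta = min_delta
        self.counter = 0
        self.best_loss = None
        self.early_stop = False

    def __call__(self, val_loss):
        if self.best_loss is None:
            self.best_loss = val_loss
        elif val_loss > self.best_loss - self.min_delta:
            self.counter += 1            # not improved by at least min_delta
            self.early_stop = self.counter >= self.patience
        else:
            self.best_loss = val_loss    # improved enough: reset the counter
            self.counter = 0

# Validation loss after each epoch, taken from the training log of this demo
val_losses = [0.3254, 0.2774, 0.2518, 0.2676, 0.2436, 0.2568]

stopper = EarlyStoppingSketch(patience=3, min_delta=0.01)
stopped_at = None
for epoch, vloss in enumerate(val_losses, start=1):
    stopper(vloss)
    print(f"epoch {epoch}: valid {vloss:.4f}, counter {stopper.counter}")
    if stopper.early_stop:
        stopped_at = epoch
        break

print("stopped at epoch", stopped_at)   # stopped at epoch 6
```

Epoch 5 reaches a new minimum (0.2436), but it beats the previous best (0.2518) by less than min_delta, so it still counts toward patience. A larger min_delta therefore stops training sooner.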

Training and Validation Loop

With these functions in place, we can structure the overall training and validation loop:

# PyTorch TensorBoard support
from torch.utils.tensorboard import SummaryWriter
from datetime import datetime

def train_model(model, trainloader, validloader, optimizer, loss_fn, num_epochs=10, save_model=True):
    """
    Runs the training and validation loop for the given model.

    Args:
        model (nn.Module): The neural network model to be trained.
        trainloader (DataLoader): DataLoader for the training dataset.
        validloader (DataLoader): DataLoader for the validation dataset.
        optimizer (torch.optim.Optimizer): Optimizer for the model.
        loss_fn: Loss function to use for training.
        num_epochs (int): Number of epochs for training. Default is 10.
        save_model (bool): Flag to save the model if validation loss improves. Default is True.
    """
    timestamp = datetime.now().strftime('%Y%m%d_%H%M%S')
    writer = SummaryWriter('runs/fashion_trainer_{}'.format(timestamp))
    train_losses = []
    val_losses = []
    EPOCHS = num_epochs
    best_vloss = 1_000_000.

    for epoch in range(EPOCHS):
        print(f'EPOCH {epoch + 1}/{EPOCHS}')

        # Training
        avg_loss = train_one_epoch(model, trainloader, optimizer, loss_fn, epoch, writer)

        # Validation
        avg_vloss = validate(model, validloader, loss_fn)
        print(f'LOSS train: {avg_loss:.4f} -- valid: {avg_vloss:.4f}')

        # Log the running loss averaged per batch for both training and validation
        writer.add_scalars(
            'Training vs. Validation Loss',
            {'Training': avg_loss, 'Validation': avg_vloss},
            epoch + 1)
        writer.flush()

        # Append average losses after each epoch
        train_losses.append(avg_loss)
        val_losses.append(avg_vloss)

        # Early Stopping
        early_stopping(avg_vloss)
        if early_stopping.early_stop:
            print("Early stopping triggered")
            break

        # Check for improvement and save model
        if save_model and avg_vloss < best_vloss:
            # Ensure the directory exists
            saved_models_dir = './saved_models'
            if not os.path.exists(saved_models_dir):
                os.makedirs(saved_models_dir)
            # Save best model
            best_vloss = avg_vloss
            model_path = f"./saved_models/model_{epoch}_{timestamp}.pth"
            torch.save(model.state_dict(), model_path)
            print(f'Model saved to {model_path}')

    return train_losses, val_losses
# Train the model
train_losses, val_losses = train_model(model, trainloader, validloader, optimizer, loss_fn)
EPOCH 1/10
  batch 187 loss: 0.7815310352626331
  batch 374 loss: 0.48613907572101145
  batch 561 loss: 0.3999798647700784
  batch 748 loss: 0.36871147040217955
  batch 935 loss: 0.35516737118602437
  batch 1122 loss: 0.3617793916859091
  batch 1309 loss: 0.34854616754673384
  batch 1496 loss: 0.3176474018410884
  batch 1683 loss: 0.31534920716508824
  batch 1870 loss: 0.3031003690339665
LOSS train: 0.3031 -- valid: 0.3254
Model saved to ./saved_models/model_0_20240122_074537.pth
EPOCH 2/10
  batch 187 loss: 0.2633895138726834
  batch 374 loss: 0.25477837984335616
  batch 561 loss: 0.26777994118750414
  batch 748 loss: 0.2780310674944026
  batch 935 loss: 0.2514276590975211
  batch 1122 loss: 0.2670865686620302
  batch 1309 loss: 0.2500107698839775
  batch 1496 loss: 0.2612737203863534
  batch 1683 loss: 0.2562679691827871
  batch 1870 loss: 0.2597585876078848
LOSS train: 0.2598 -- valid: 0.2774
Model saved to ./saved_models/model_1_20240122_074537.pth
EPOCH 3/10
  batch 187 loss: 0.2182582863033774
  batch 374 loss: 0.21429858047933503
  batch 561 loss: 0.2213542611004516
  batch 748 loss: 0.20720045682899454
  batch 935 loss: 0.21772785568460423
  batch 1122 loss: 0.2048042686307494
  batch 1309 loss: 0.22424145604678017
  batch 1496 loss: 0.2157760307710757
  batch 1683 loss: 0.22659921588346282
  batch 1870 loss: 0.20506563242823683
LOSS train: 0.2051 -- valid: 0.2518
Model saved to ./saved_models/model_2_20240122_074537.pth
EPOCH 4/10
  batch 187 loss: 0.17509131173399042
  batch 374 loss: 0.18894203200458207
  batch 561 loss: 0.1780999759759973
  batch 748 loss: 0.18176765804662104
  batch 935 loss: 0.1851623724488651
  batch 1122 loss: 0.18986983400057345
  batch 1309 loss: 0.18469491857975562
  batch 1496 loss: 0.18138664159027332
  batch 1683 loss: 0.18155848991902754
  batch 1870 loss: 0.1737585065358463
LOSS train: 0.1738 -- valid: 0.2676
EPOCH 5/10
  batch 187 loss: 0.14732835041906903
  batch 374 loss: 0.15089740436947482
  batch 561 loss: 0.15200401314678677
  batch 748 loss: 0.1536160892742203
  batch 935 loss: 0.1623252102965738
  batch 1122 loss: 0.15252764867668483
  batch 1309 loss: 0.14172770584470287
  batch 1496 loss: 0.15746805861213786
  batch 1683 loss: 0.1480582723702817
  batch 1870 loss: 0.1564132237089748
LOSS train: 0.1564 -- valid: 0.2436
Model saved to ./saved_models/model_4_20240122_074537.pth
EPOCH 6/10
  batch 187 loss: 0.12108336666220411
  batch 374 loss: 0.12462331514984848
  batch 561 loss: 0.12107207644033559
  batch 748 loss: 0.12825290896596117
  batch 935 loss: 0.12207487767392938
  batch 1122 loss: 0.13691178718731206
  batch 1309 loss: 0.12671447862039276
  batch 1496 loss: 0.13837192698137804
  batch 1683 loss: 0.13607694819211003
  batch 1870 loss: 0.13199008475580715
LOSS train: 0.1320 -- valid: 0.2568
Early stopping triggered
# Plotting
def plot_loss_curves(train_losses, val_losses):
    fig, ax = plt.subplots(figsize=(6.4, 4.8), dpi=120)
    fig.suptitle("Training and Validation Loss Over Epochs", fontsize=14, weight='bold')
    ax.plot(train_losses, linewidth=2, label='Training Loss')
    ax.plot(val_losses, linewidth=2, label='Validation Loss')
    ax.set_xticks(np.arange(0, len(train_losses), 1))
    ax.grid(color='blue', linestyle='--', linewidth=1, alpha=0.2)
    ax.legend(loc="upper right")
    spines = ['top', 'right', 'bottom', 'left']
    for s in spines:
        ax.spines[s].set_visible(False)  # Assumed styling; the original loop body is not shown

plot_loss_curves(train_losses, val_losses)


7. Evaluate Accuracy in PyTorch

Next, we write a function that calculates accuracy and use it with our validation dataset.

Accuracy Calculation Function

def evaluate_accuracy(model, testloader):
    correct = 0
    total = 0
    with torch.no_grad():  # Disable gradient computation during inference
        for data in testloader:
            images, labels = data
            outputs = model(images)
            _, predicted = torch.max(outputs, 1)  # Index of the highest logit per image
            total += labels.size(0)
            correct += (predicted == labels).sum().item()

    accuracy = 100 * correct / total
    return accuracy

Using the function to evaluate Validation Accuracy:

model.eval()  # Set the model to evaluation mode
valid_accuracy = evaluate_accuracy(model, validloader)
print(f'\nValidation Accuracy: {valid_accuracy:.2f}%')
Validation Accuracy: 91.77%

Explanation of Accuracy in Image Classification

Accuracy is a key metric in image classification that quantifies how often a model correctly predicts the label of an image. It’s expressed as the percentage of test images correctly classified by the model.

Mathematically, accuracy is defined as the ratio of correct predictions to total predictions: $$ \text{Accuracy}=\frac{\text{Number of Correct Predictions}}{\text{Total Number of Predictions}} $$

Alternatively, using terms from a confusion matrix: $$ \text{Accuracy}=\frac{TP+TN}{TP+TN+FP+FN} $$


  • TP (True Positives): Images correctly identified as belonging to a class.
  • TN (True Negatives): Images correctly identified as not belonging to a class.
  • FP (False Positives): Images incorrectly identified as belonging to a class.
  • FN (False Negatives): Images incorrectly identified as not belonging to a class.

In multi-class settings, such as the Fashion MNIST dataset with 10 classes, accuracy is typically calculated as the ratio of correct predictions to total predictions. The use of TP, TN, FP, FN becomes more relevant in binary classification.
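
As a worked instance of the first formula, the sketch below computes accuracy for a handful of made-up labels in plain Python. The label values and the helper name `accuracy` are hypothetical.

```python
def accuracy(predicted, actual):
    """Percentage of positions where the predicted label equals the true label."""
    correct = sum(p == a for p, a in zip(predicted, actual))
    return 100 * correct / len(actual)

# Hypothetical class indices (0-9) for eight images
actual    = [9, 2, 1, 1, 6, 1, 4, 6]
predicted = [9, 2, 1, 1, 0, 1, 4, 4]

print(f"{accuracy(predicted, actual):.1f}%")   # 75.0%
```

Six of the eight predictions match, giving 75%. By the same formula, the 91.77% reported above on 10,000 validation images corresponds to 9,177 correct predictions.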

8. Making Predictions

After training your model in PyTorch, you can use it to make predictions on new data. PyTorch models output the logits directly, and you can apply a softmax function to convert these logits into probabilities for easier interpretation.

# Function to make predictions
def make_predictions(model, test_images):
    model.eval()  # Set the model to evaluation mode
    with torch.no_grad():  # Disable gradient tracking
        logits = model(test_images)
        probabilities = F.softmax(logits, dim=1)
        return probabilities

# Make predictions
predictions = make_predictions(model, valid_images)

8.1 Interpreting Predictions

Each prediction is an array of 10 elements when working with the Fashion MNIST dataset, corresponding to the model’s confidence for each of the 10 classes.

# Look at the first prediction
first_prediction = predictions[0]
print(first_prediction)

# Find the class with the highest confidence for the first prediction
predicted_class = torch.argmax(first_prediction)
print(f"Predicted class: {predicted_class.item()}")

# Compare with the actual label
actual_label = valid_labels[0]
print(f"Actual label: {actual_label}")
tensor([6.5018e-08, 1.0000e+00, 1.1627e-10, 3.9829e-08, 5.4289e-10, 1.7141e-13,
        4.4465e-09, 7.4519e-16, 3.9247e-12, 4.0176e-14])
Predicted class: 1
Actual label: 1

Key Points:

  • Softmax Transformation: Applying softmax to the logits converts them into probabilities, which sum up to 1. This makes the model’s outputs more interpretable as confidence scores for each class.

  • Disabling Gradient Tracking: Since we’re only making predictions and not training the model, we disable gradient tracking with torch.no_grad(), which reduces memory usage and speeds up computations.

  • Model Evaluation Mode: It’s crucial to set the model to evaluation mode (model.eval()) before making predictions to ensure layers like dropout and batch normalization work in inference mode.

  • Predictions: The predictions are tensors containing the probability of each class. The class with the highest probability is considered the model’s prediction.
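
The softmax step itself is short enough to write out by hand. The following plain-Python sketch uses made-up logits and a hypothetical `softmax` helper; it computes the same function as `F.softmax` for a single sample.

```python
import math

def softmax(logits):
    """Convert raw scores to probabilities; subtracting the max avoids overflow."""
    m = max(logits)
    exps = [math.exp(z - m) for z in logits]
    total = sum(exps)
    return [e / total for e in exps]

logits = [1.2, 8.5, -0.3, 2.0]   # hypothetical raw scores for four classes
probs = softmax(logits)

print([round(p, 4) for p in probs])
print("sum:", sum(probs))
print("same argmax:", probs.index(max(probs)) == logits.index(max(logits)))   # True
```

Softmax preserves the ordering of the scores, so the predicted class is the same whether you take the argmax of the logits or of the probabilities. That is why the accuracy function can work on the raw outputs directly.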

8.2 Verify Predictions

With the model trained, you can use it to make predictions about some images.

Let’s look at the 0th image, predictions, and prediction array. Correct prediction labels are blue and incorrect prediction labels are red. The number gives the percentage (out of 100) for the predicted label.

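import matplotlib.pyplot as plt

def plot_image(i, predictions_array, true_label, img, class_names):
    true_label, img = true_label[i], img[i]
    plt.grid(False)
    plt.xticks([])
    plt.yticks([])
    plt.imshow(img.squeeze(), cmap='gray')  # Assuming img is a PyTorch Tensor

    predicted_label = torch.argmax(predictions_array).item()
    if predicted_label == true_label:
        color = 'blue'  # Correct prediction
    else:
        color = 'red'   # Incorrect prediction

    # Predicted class, confidence in percent, and true class in parentheses
    plt.xlabel("{} {:2.0f}% ({})".format(
        class_names[predicted_label],
        100 * torch.max(predictions_array).item(),
        class_names[true_label]),
        color=color)

def plot_value_array(i, predictions_array, true_label, class_names):
    true_label = true_label[i]
    predictions_array = predictions_array.numpy()
    plt.grid(False)
    plt.xticks(range(10))
    plt.yticks([])
    thisplot = plt.bar(range(10), predictions_array, color="#777777")
    plt.ylim([0, 1])
    predicted_label = np.argmax(predictions_array)

    # Highlight the predicted class in red and the true class in blue
    thisplot[predicted_label].set_color('red')
    thisplot[true_label].set_color('blue')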

Now, let’s use these functions to visualize the model’s predictions. We’ll plot both the image and its prediction probability distribution:

# Example: Visualizing the first image and its prediction
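def plot_single_image_prediction(img_number):
    plt.figure(figsize=(6, 3))
    plt.subplot(1, 2, 1)  # Left panel: the image
    plot_image(img_number, predictions[img_number], valid_labels, valid_images, class_names)
    plt.subplot(1, 2, 2)  # Right panel: the class probabilities
    plot_value_array(img_number, predictions[img_number], valid_labels, class_names)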





Let’s plot several images with their predictions. Note that the model can be wrong even when very confident.

num_rows = 5
num_cols = 3
num_images = num_rows * num_cols
plt.figure(figsize=(2 * 2 * num_cols, 2 * num_rows))
for i in range(num_images):
    plt.subplot(num_rows, 2 * num_cols, 2 * i + 1)
    plot_image(i, predictions[i], valid_labels, valid_images, class_names)
    plt.subplot(num_rows, 2 * num_cols, 2 * i + 2)
    plot_value_array(i, predictions[i], valid_labels, class_names)


9. Loading a Trained Model

After training your model and saving its state, you can load the model for further inference, evaluation, or continued training. To load a saved model, you’ll need to:

  • Recreate the Model Architecture: Instantiate a new object of your model class. It’s important that this new model has the same architecture as the one you trained.

  • Load the Saved State Dictionary: Use torch.load() to load the saved state dictionary, and then load this state into your newly created model using model.load_state_dict().

  • Documentation: Saving and Loading models using PyTorch

Here’s an example:

import os

# Step 1: Instantiate the model
loaded_model = CNNClassifier()

# Step 2: Find the latest model file
model_dir = './saved_models'
model_files = [os.path.join(model_dir, f) for f in os.listdir(model_dir) if f.endswith('.pth')]
latest_model_file = max(model_files, key=os.path.getctime)

# Step 3: Load the saved state dictionary
loaded_model.load_state_dict(torch.load(latest_model_file))

# Step 4: Switch the model to evaluation mode for inference
loaded_model.eval()

print(f"Loaded model from {latest_model_file}")
Loaded model from ./saved_models/model_4_20240122_074537.pth

9.1 Test Loaded Model

We can use the loaded model for inference:

valid_accuracy = evaluate_accuracy(loaded_model, validloader)
print(f'\nValidation Accuracy: {valid_accuracy:.2f}%')
Validation Accuracy: 91.71%

10. Challenge: Can You Build a Better CNN?

As we conclude this demo, I want to leave you with a challenge that will not only test what you’ve learned but also push the boundaries of your skills in computer vision and deep learning.

Your Mission: Outperform My Model

I’ve walked you through building and training a Convolutional Neural Network (CNN) that achieved a 91% accuracy on the Fashion MNIST dataset. Now, it’s your turn to take the reins. Can you construct and train a CNN that surpasses this benchmark?

Tips for Improvement:

  • Experiment with Architecture: Try adding more convolutional layers, or adjust the number of neurons in the dense layers. Introduce layers like Dropout or Batch Normalization to see if they help in achieving better generalization.

  • Tune Hyperparameters: Play around with different learning rates, batch sizes, or optimization algorithms. Sometimes, small changes in these parameters can lead to significant improvements.

  • Data Augmentation: Use techniques like rotation, scaling, or horizontal flipping to artificially expand your training dataset. This can often help improve the robustness of your model.

  • Advanced Techniques: If you’re feeling adventurous, explore more advanced architectures like ResNets or Capsule Networks, or delve into newer regularization techniques.
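
As one possible starting point for the first tip, here is a sketch of the same two-block CNN with Batch Normalization and Dropout added. The class name `ImprovedCNN` and the dropout rate of 0.3 are assumptions made for illustration, and nothing here has been trained or shown to beat the model above; treat it as a template to experiment with.

```python
import torch
import torch.nn as nn

class ImprovedCNN(nn.Module):  # hypothetical name
    def __init__(self, dropout=0.3):  # dropout rate is an arbitrary starting value
        super().__init__()
        self.features = nn.Sequential(
            nn.Conv2d(1, 32, kernel_size=3, padding=1),
            nn.BatchNorm2d(32),          # normalize activations per channel
            nn.ReLU(inplace=True),
            nn.MaxPool2d(2),             # 28x28 -> 14x14
            nn.Conv2d(32, 64, kernel_size=3, padding=1),
            nn.BatchNorm2d(64),
            nn.ReLU(inplace=True),
            nn.MaxPool2d(2),             # 14x14 -> 7x7
        )
        self.classifier = nn.Sequential(
            nn.Flatten(),
            nn.Dropout(dropout),         # randomly zero activations during training
            nn.Linear(64 * 7 * 7, 128),
            nn.ReLU(inplace=True),
            nn.Dropout(dropout),
            nn.Linear(128, 10),
        )

    def forward(self, x):
        return self.classifier(self.features(x))

# Shape check with a random batch of four 28x28 grayscale images
candidate = ImprovedCNN()
candidate.eval()                         # dropout and batch norm in inference mode
with torch.no_grad():
    logits = candidate(torch.randn(4, 1, 28, 28))
print(logits.shape)                      # torch.Size([4, 10])
```

Because this model contains Dropout and Batch Normalization, calling `train()` before training and `eval()` before validation changes its behavior, which is not the case for the plain CNN used earlier.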

Share Your Results!

Once you’ve built your model, evaluate it on the validation set to see if you’ve managed to outdo the 91% accuracy.

Share your results, along with your unique approach and insights. This is a fantastic opportunity to engage with peers, exchange knowledge, and showcase your skills.

Murilo Gustineli
Computer Science at Georgia Tech

My research interests include deep learning, computer vision, and NLP