How to Train a Neural Network in PyTorch


Updated July 25, 2023

Master the art of building and training neural networks using PyTorch, a popular Python library for deep learning. This comprehensive guide takes you through the process of creating and training a neural network from scratch, covering concepts such as data loading, model definition, optimization, and evaluation.

Definition of the Concept

Training a neural network in PyTorch involves several steps: loading data, defining a model, choosing a loss function and an optimizer, running a training loop, and evaluating the result. At its core, training a neural network is about adjusting the model’s parameters to minimize a loss function that measures the error between predicted outputs and actual labels.
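To make the idea of optimizing parameters to minimize error concrete, here is a minimal sketch that is not part of the tutorial’s MNIST code: it fits a single weight w so that w * x matches targets generated by y = 3x, using plain gradient descent on a mean squared error loss. The toy data, learning rate, and number of steps are arbitrary illustrative choices.

```python
import torch

# Toy data (illustrative assumption): the targets follow y = 3x exactly
x = torch.tensor([1.0, 2.0, 3.0, 4.0])
y = 3.0 * x

w = torch.tensor(0.0, requires_grad=True)  # the single trainable parameter
learning_rate = 0.01

for step in range(200):
    loss = ((w * x - y) ** 2).mean()  # mean squared error between predictions and targets
    loss.backward()                   # autograd computes d(loss)/dw into w.grad
    with torch.no_grad():
        w -= learning_rate * w.grad   # step against the gradient
    w.grad.zero_()                    # clear the gradient before the next step

print(round(w.item(), 3))  # approaches 3.0
```

The full training loop later in this guide follows the same pattern, except that an optimizer object performs the update and gradient reset for every parameter of the network.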

PyTorch is a deep learning library built around dynamic ("define-by-run") computation graphs: the graph of operations is recorded as your Python code executes, and PyTorch’s autograd engine uses that graph to compute gradients. This makes it easy to prototype quickly, use ordinary Python control flow inside models, and debug them with standard Python tools.
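To illustrate the "dynamic" part, here is a minimal hypothetical example: the function f uses an ordinary Python if statement, and autograd records only the operations that actually ran during this call before computing gradients from them.

```python
import torch

def f(x: torch.Tensor) -> torch.Tensor:
    # Plain Python control flow decides which operations end up in the graph
    if x.sum() > 0:
        return (x ** 2).sum()
    return (x ** 3).sum()

x = torch.tensor([1.0, 2.0, 3.0], requires_grad=True)
y = f(x)        # the graph for this particular call is built as the code runs
y.backward()    # autograd walks that graph backwards to get dy/dx

print(y.item())  # 14.0
print(x.grad)    # tensor([2., 4., 6.]), i.e. d/dx sum(x^2) = 2x
```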

Step-by-Step Explanation

Here are the steps involved in training a neural network using PyTorch:

1. Importing Libraries and Loading Data

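import torch
from torchvision import datasets, transforms

# Convert each image to a float tensor of shape (1, 28, 28) with values in [0, 1]
transform = transforms.Compose([transforms.ToTensor()])

# 60,000 training images and 10,000 test images, downloaded on first use
train_dataset = datasets.MNIST('~/.pytorch/MNIST_data/', train=True, download=True, transform=transform)
test_dataset = datasets.MNIST('~/.pytorch/MNIST_data/', train=False, download=True, transform=transform)

# Serve shuffled mini-batches of 64 images for training; test order does not matter
train_loader = torch.utils.data.DataLoader(train_dataset, batch_size=64, shuffle=True)
test_loader = torch.utils.data.DataLoader(test_dataset, batch_size=64, shuffle=False)
  • We start by importing the necessary libraries: torch, plus the datasets and transforms modules from torchvision.
  • Next, we define a transformation (transform) that will be applied to every image. In this case, transforms.ToTensor() converts each image to a tensor and scales its pixel values to the range [0, 1].
  • Then, we load the MNIST training and test splits with datasets.MNIST (train=True and train=False). The download=True argument tells torchvision to download the dataset if it’s not already present on your system.
  • Finally, each DataLoader wraps a dataset and yields mini-batches of 64 images together with their labels; shuffle=True reshuffles the training data at the start of every epoch.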
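Because downloading MNIST is not always convenient when experimenting, the sketch below substitutes random tensors shaped like MNIST (an assumption for illustration: 1 channel, 28 x 28 pixels, values in [0, 1) like the output of ToTensor()) to show what each batch from a DataLoader with batch_size=64 looks like, and why a fully connected model works with 784 = 28 x 28 inputs.

```python
import torch
from torch.utils.data import DataLoader, TensorDataset

# Hypothetical stand-in for MNIST: 256 random grayscale "images" with labels 0-9
fake_images = torch.rand(256, 1, 28, 28)
fake_labels = torch.randint(0, 10, (256,))
loader = DataLoader(TensorDataset(fake_images, fake_labels), batch_size=64, shuffle=True)

images, labels = next(iter(loader))
print(images.shape)  # torch.Size([64, 1, 28, 28])
print(labels.shape)  # torch.Size([64])

# Flatten each image into a 784-element vector, as a fully connected layer expects
flat = images.view(images.size(0), -1)
print(flat.shape)    # torch.Size([64, 784])
print(len(loader))   # 4 batches (256 / 64)
```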

2. Defining the Model Architecture

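class Net(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.fc1 = torch.nn.Linear(784, 128)    # 28 x 28 = 784 pixels in, 128 hidden units out
        self.relu = torch.nn.ReLU()
        self.dropout = torch.nn.Dropout(p=0.2)  # randomly zeroes 20% of activations during training
        self.fc2 = torch.nn.Linear(128, 10)     # one output score per digit class (0-9)

    def forward(self, x):
        # Flatten each image: (N, 1, 28, 28) -> (N, 784), since Linear expects flat vectors
        x = torch.flatten(x, start_dim=1)
        out = self.fc1(x)
        out = self.relu(out)
        out = self.dropout(out)
        out = self.fc2(out)
        return out  # raw scores (logits); CrossEntropyLoss applies log-softmax internally
  • We define a custom neural network class (Net) that inherits from torch.nn.Module.
  • Inside the class’s constructor (__init__ method), we create two fully connected layers, fc1 (784 inputs to 128 outputs) and fc2 (128 inputs to 10 outputs), plus a ReLU activation (relu) and a dropout layer (dropout) that randomly zeroes 20% of the hidden activations during training to reduce overfitting.
  • The forward method defines how our data flows through the network: each 28 x 28 image is flattened into a 784-element vector, passed through fc1, ReLU, and dropout, and finally through fc2, which produces one score per digit class.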
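As a quick sanity check on the architecture, the sketch below builds the same layer stack with torch.nn.Sequential (an alternative way to write the model, not the tutorial’s Net class), runs a random batch shaped like MNIST through it, and counts the trainable parameters.

```python
import torch
from torch import nn

# Same layers as Net, expressed with nn.Sequential (alternative style for illustration)
seq_model = nn.Sequential(
    nn.Flatten(),         # (N, 1, 28, 28) -> (N, 784)
    nn.Linear(784, 128),
    nn.ReLU(),
    nn.Dropout(p=0.2),
    nn.Linear(128, 10),   # one raw score (logit) per digit class
)

batch = torch.rand(64, 1, 28, 28)  # random stand-in for one batch of MNIST images
logits = seq_model(batch)
print(logits.shape)  # torch.Size([64, 10])

num_params = sum(p.numel() for p in seq_model.parameters())
print(num_params)    # 784*128 + 128 + 128*10 + 10 = 101770
```

nn.Sequential is convenient when data flows straight through the layers; a custom nn.Module subclass like Net is more flexible when forward needs branching or extra logic.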

3. Defining the Loss Function and Optimizer, and Training the Model

model = Net()
num_epochs = 5  # example value; adjust as needed

criterion = torch.nn.CrossEntropyLoss()
optimizer = torch.optim.Adam(model.parameters(), lr=0.001)

# Train for a specified number of epochs
for epoch in range(num_epochs):
    model.train()  # training mode: dropout is active
    running_loss = 0.0

    # Iterate through the training dataset one mini-batch at a time
    for inputs, labels in train_loader:
        # Zero gradients before backpropagation
        optimizer.zero_grad()

        # Forward pass
        outputs = model(inputs)

        # Calculate the loss, backpropagate, and update the weights
        loss = criterion(outputs, labels)
        loss.backward()
        optimizer.step()

        running_loss += loss.item()

    # Print the average loss per batch for this epoch
    print('Epoch %d: Average Loss %.3f' % (epoch + 1, running_loss / len(train_loader)))
  • We create an instance of the model, then define a loss function (CrossEntropyLoss) and an optimizer (Adam) with a learning rate of 0.001. num_epochs sets how many full passes over the training data we make.
  • Then, we loop over the epochs, and within each epoch over every mini-batch in train_loader.
  • Inside the loop, we perform a forward pass to calculate our network’s outputs, then compute the cross-entropy loss between predictions and labels.
  • Next, loss.backward() backpropagates this error to compute the gradient of the loss with respect to every model parameter, and optimizer.step() uses those gradients to update the weights.
  • Finally, we print out the average loss for each epoch.
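Note that the model’s forward method returns raw scores (logits) without a softmax layer. That is what CrossEntropyLoss expects, because it applies log-softmax internally. The sketch below uses made-up logits and labels to check that equivalence against torch.nn.functional.log_softmax followed by nll_loss.

```python
import torch
import torch.nn.functional as F

torch.manual_seed(0)
logits = torch.randn(4, 10)           # hypothetical raw model outputs for 4 images
labels = torch.tensor([3, 0, 9, 1])   # hypothetical true digit labels

ce = torch.nn.CrossEntropyLoss()(logits, labels)
manual = F.nll_loss(F.log_softmax(logits, dim=1), labels)

print(torch.allclose(ce, manual))  # True
```

Adding a softmax at the end of forward would therefore apply it twice and usually slows training down.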

4. Evaluating Model Performance

# Evaluate on test set
model.eval()  # evaluation mode: dropout is disabled
test_loss = 0.0
correct = 0

with torch.no_grad():
    for inputs, labels in test_loader:
        # Forward pass (no gradients)
        outputs = model(inputs)

        # criterion returns the batch mean, so weight it by the batch size
        loss = criterion(outputs, labels)
        test_loss += loss.item() * inputs.size(0)

        # The predicted class is the index of the highest score
        predicted = outputs.argmax(dim=1)
        correct += (predicted == labels).sum().item()

test_loss /= len(test_loader.dataset)
accuracy = correct / len(test_loader.dataset)

print('\nTest set: Average Loss %.3f | Accuracy %d/%d (%.2f%%)\n' % (
    test_loss,
    correct,
    len(test_loader.dataset),
    100. * accuracy))
  • After training our model, we evaluate its performance on the separate test dataset. model.eval() disables dropout, and torch.no_grad() turns off gradient tracking, which saves memory and computation.
  • We loop through test_loader, compute predictions for each batch of images using our trained network (model), and accumulate both the cross-entropy loss between these predictions and the ground truth labels and the number of correct predictions.
  • Then, we print out the average test set loss per image and the model’s accuracy.
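Two details of the evaluation step are easier to see in isolation: switching to eval mode changes how Dropout behaves, and accuracy comes from taking the index of the largest score in each row of logits. The tensors below are made-up values for illustration only.

```python
import torch

torch.manual_seed(0)
dropout = torch.nn.Dropout(p=0.2)
x = torch.ones(1, 8)

dropout.train()
train_out = dropout(x)
print(train_out)  # each entry is 0 (dropped) or 1.25 (kept and scaled by 1 / (1 - 0.2))

dropout.eval()
eval_out = dropout(x)
print(eval_out)   # dropout is a no-op in eval mode, so every entry stays 1

# Hypothetical logits for 4 images and 3 classes, with made-up labels
outputs = torch.tensor([[0.1, 2.0, -1.0],
                        [1.5, 0.2, 0.3],
                        [0.0, 0.1, 3.0],
                        [2.0, 1.0, 0.5]])
labels = torch.tensor([1, 0, 1, 0])

predicted = outputs.argmax(dim=1)  # index of the largest score per row: [1, 0, 2, 0]
correct = (predicted == labels).sum().item()
accuracy = correct / len(labels)
print(correct, accuracy)  # 3 0.75
```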

Code Explanation

The code provided is a basic example of how to train a neural network with PyTorch. It starts by loading the MNIST dataset using torchvision. The training loop iterates over each epoch, where for every mini-batch in the training data:

  1. Zero Gradients: Before each forward and backward pass, we reset all gradients to zero, because PyTorch accumulates (adds up) gradients across backward() calls by default.
  2. Forward Pass: We pass our input images through the network (model) and obtain output predictions.
  3. Backward Pass: Calling backward() on the loss computed from these predictions and the true labels propagates the error backwards through the network, computing the gradient of the loss with respect to each weight.
  4. Optimization: Finally, using an optimizer like Adam, we update model parameters to minimize loss.
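Why the zeroing step matters: each backward() call adds its result to a parameter’s .grad attribute instead of overwriting it. The sketch below uses a single hypothetical scalar parameter to show the effect.

```python
import torch

w = torch.tensor(2.0, requires_grad=True)  # hypothetical single parameter

# Two backward passes without resetting: each adds d(loss)/dw = 3 into w.grad
for _ in range(2):
    loss = 3.0 * w
    loss.backward()
accumulated = w.grad.item()
print(accumulated)  # 6.0

# Reset before the next pass, as optimizer.zero_grad() does for the parameters it manages
w.grad = None
loss = 3.0 * w
loss.backward()
fresh = w.grad.item()
print(fresh)        # 3.0
```

Without the reset, every optimizer step would use the sum of all previous gradients rather than the gradient of the current mini-batch.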

After training is complete, we use the trained network to make predictions on the separate test dataset, calculating the mean cross-entropy loss per test image and the overall accuracy, that is, the fraction of predictions that match the ground truth labels.
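Because CrossEntropyLoss returns the mean over a batch by default, averaging the test loss needs a little care when the last batch is smaller than the others (10,000 MNIST test images in batches of 64 leave a final batch of 16). One way to get an exact per-image average, sketched below on random stand-in data, is to sum the losses with reduction="sum" and divide by the dataset size; multiplying each batch mean by its batch size is equivalent.

```python
import torch
from torch.utils.data import DataLoader, TensorDataset

torch.manual_seed(0)
# Hypothetical test set of 100 samples: batches of 64 and 36, so the last batch is smaller
logits = torch.randn(100, 10)
labels = torch.randint(0, 10, (100,))
loader = DataLoader(TensorDataset(logits, labels), batch_size=64)

mean_ce = torch.nn.CrossEntropyLoss()                # returns the batch average
sum_ce = torch.nn.CrossEntropyLoss(reduction="sum")  # returns the batch total

mean_of_batch_means = sum(mean_ce(x, y).item() for x, y in loader) / len(loader)
exact_mean = sum(sum_ce(x, y).item() for x, y in loader) / len(loader.dataset)
reference = mean_ce(logits, labels).item()           # loss over all 100 samples at once

print(exact_mean, reference)  # these agree up to floating-point rounding
print(mean_of_batch_means)    # weights the 36-sample batch as much as the 64-sample one
```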
