Understanding ResNet50: A Deep Dive with PyTorch

Introduction

In the realm of deep learning and computer vision, convolutional neural networks (CNNs) play a pivotal role in tasks such as image classification, object detection, and segmentation. Among these architectures, ResNet, short for Residual Network, has stood out for its remarkable performance and ability to train very deep networks.

In this blog post, we’ll delve into the details of ResNet50, a specific variant of the ResNet architecture, and implement it from scratch using PyTorch. By the end, you’ll have a solid understanding of ResNet50 and the practical skills to implement it on your own.

Understanding ResNet50

ResNet50 is the 50-layer variant of ResNet: an initial convolution, 48 convolutions spread across 16 bottleneck blocks (three per block), and a final fully connected layer. The key innovation introduced by ResNet is residual learning: rather than learning a desired mapping H(x) directly, each block learns the residual F(x) = H(x) - x and outputs F(x) + x, which makes very deep networks far easier to train. This is achieved through skip connections, or shortcuts, that let the gradient flow directly through the network.
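
In code, the idea reduces to a block that returns F(x) + x instead of F(x). Here is a minimal sketch of that pattern on its own; the Residual wrapper below is purely illustrative and not part of ResNet50 itself:

import torch
import torch.nn as nn

# Illustrative residual wrapper: adds the input back onto the output of any
# sub-module f whose output shape matches its input shape.
class Residual(nn.Module):
    def __init__(self, f):
        super().__init__()
        self.f = f

    def forward(self, x):
        # The block only has to learn the residual F(x);
        # the identity passes through the addition untouched.
        return self.f(x) + x

block = Residual(nn.Sequential(nn.Linear(8, 8), nn.ReLU(), nn.Linear(8, 8)))
print(block(torch.randn(2, 8)).shape)  # torch.Size([2, 8])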

The network opens with a 7x7 convolution and a max pooling layer (the "stem"), followed by four stages of bottleneck residual blocks. Each bottleneck block applies a 1x1 convolution to reduce the channel dimension, a 3x3 convolution, and a 1x1 convolution to expand the channels again, with batch normalization and ReLU activations in between; the skip connection adds the block's input to its output, so a block can cheaply learn an identity mapping when that is what's needed. For a 224x224 input, the feature map shrinks from 56x56 after the stem to 28x28, 14x14, and finally 7x7 across the four stages. The last layers are a global average pooling layer and a fully connected layer for classification.

Implementing ResNet50 in PyTorch

Let’s now implement ResNet50 from scratch using PyTorch. We’ll define the architecture in a modular way, making it easier to understand and modify.

import torch
import torch.nn as nn

# Define the bottleneck block with skip connection (ResNet50/101/152 use this block)
class Bottleneck(nn.Module):
    expansion = 4  # each block outputs 4x its base channel width

    def __init__(self, in_channels, out_channels, stride=1):
        super(Bottleneck, self).__init__()

        # 1x1 convolution: reduce the channel dimension
        self.conv1 = nn.Conv2d(in_channels, out_channels, kernel_size=1, bias=False)
        self.bn1 = nn.BatchNorm2d(out_channels)

        # 3x3 convolution: spatial processing (carries the stride when downsampling)
        self.conv2 = nn.Conv2d(out_channels, out_channels, kernel_size=3, stride=stride, padding=1, bias=False)
        self.bn2 = nn.BatchNorm2d(out_channels)

        # 1x1 convolution: expand the channels back out
        self.conv3 = nn.Conv2d(out_channels, self.expansion * out_channels, kernel_size=1, bias=False)
        self.bn3 = nn.BatchNorm2d(self.expansion * out_channels)

        self.relu = nn.ReLU(inplace=True)

        # Skip connection: project the input with a 1x1 convolution
        # when its shape no longer matches the block's output
        self.shortcut = nn.Sequential()
        if stride != 1 or in_channels != self.expansion * out_channels:
            self.shortcut = nn.Sequential(
                nn.Conv2d(in_channels, self.expansion * out_channels, kernel_size=1, stride=stride, bias=False),
                nn.BatchNorm2d(self.expansion * out_channels)
            )

    def forward(self, x):
        identity = x

        out = self.conv1(x)
        out = self.bn1(out)
        out = self.relu(out)

        out = self.conv2(out)
        out = self.bn2(out)
        out = self.relu(out)

        out = self.conv3(out)
        out = self.bn3(out)

        out += self.shortcut(identity)
        out = self.relu(out)

        return out

# Define the ResNet50 model
class ResNet50(nn.Module):
    def __init__(self, num_classes=1000):
        super(ResNet50, self).__init__()

        self.in_channels = 64

        # Initial convolutional layer
        self.conv1 = nn.Conv2d(3, 64, kernel_size=7, stride=2, padding=3, bias=False)
        self.bn1 = nn.BatchNorm2d(64)
        self.relu = nn.ReLU(inplace=True)
        self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1)

        # Residual stages: [3, 4, 6, 3] bottleneck blocks, as in the original paper
        self.layer1 = self._make_layer(Bottleneck, 64, 3, stride=1)
        self.layer2 = self._make_layer(Bottleneck, 128, 4, stride=2)
        self.layer3 = self._make_layer(Bottleneck, 256, 6, stride=2)
        self.layer4 = self._make_layer(Bottleneck, 512, 3, stride=2)

        # Global average pooling and fully connected layer
        self.avgpool = nn.AdaptiveAvgPool2d((1, 1))
        self.fc = nn.Linear(512 * Bottleneck.expansion, num_classes)  # 2048 -> num_classes

    def _make_layer(self, block, out_channels, blocks, stride):
        # The first block of each stage may downsample (stride > 1) and widen
        # the channels; the remaining blocks keep both dimensions fixed.
        layers = []
        layers.append(block(self.in_channels, out_channels, stride))
        self.in_channels = out_channels * block.expansion
        for _ in range(1, blocks):
            layers.append(block(self.in_channels, out_channels, stride=1))
        return nn.Sequential(*layers)

    def forward(self, x):
        x = self.conv1(x)
        x = self.bn1(x)
        x = self.relu(x)
        x = self.maxpool(x)

        x = self.layer1(x)
        x = self.layer2(x)
        x = self.layer3(x)
        x = self.layer4(x)

        x = self.avgpool(x)
        x = torch.flatten(x, 1)
        x = self.fc(x)

        return x

# Create ResNet50 model
resnet50 = ResNet50()

# Display the model architecture
print(resnet50)

This implementation follows the structure of ResNet50, with the Bottleneck module serving as the fundamental building block. The ResNet50 class defines the overall architecture: the initial convolutional stem, the four residual stages with [3, 4, 6, 3] blocks, and the final classification head.
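
As a quick sanity check, we can push a dummy batch through the model and confirm both the output shape and the parameter count, which should land at roughly 25.6 million, the figure commonly reported for ResNet50:

# Dummy forward pass: one 224x224 RGB image
x = torch.randn(1, 3, 224, 224)
logits = resnet50(x)
print(logits.shape)  # torch.Size([1, 1000])

# Total parameters
num_params = sum(p.numel() for p in resnet50.parameters())
print(f"{num_params:,} parameters")  # roughly 25.6 million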

Conclusion

In this blog post, we’ve explored the ResNet50 architecture and implemented it from scratch using PyTorch. ResNet50’s deep structure and skip connections have made it a go-to choice for various computer vision tasks. Understanding the inner workings of such architectures is crucial for building and customizing deep learning models.

Feel free to experiment with the code, modify hyperparameters, or integrate it into your projects to leverage the power of ResNet50 for image classification tasks. Happy coding!
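
If you'd rather start from pretrained ImageNet weights than train from scratch, torchvision ships a reference ResNet50 with the same architecture. A typical fine-tuning setup (assuming torchvision 0.13 or newer, which introduced the weights enum used below) swaps the final layer for one sized to your dataset:

import torch.nn as nn
from torchvision.models import resnet50, ResNet50_Weights

# Load torchvision's ResNet50 with ImageNet-pretrained weights
model = resnet50(weights=ResNet50_Weights.IMAGENET1K_V2)

# Replace the 1000-class head with one for your own dataset (e.g. 10 classes)
model.fc = nn.Linear(model.fc.in_features, 10)

# Optionally freeze the backbone and train only the new head
for name, param in model.named_parameters():
    if not name.startswith("fc"):
        param.requires_grad = False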