Most modern convolutional neural networks (CNNs) are built on the same principles: alternating convolution and max-pooling layers followed by a small number of fully connected layers. However, max-pooling can simply be replaced by a convolutional layer with increased stride without loss in accuracy on several image recognition benchmarks.

Since the dimensionality reduction can be performed by a strided convolution rather than by max-pooling, we can replace the pooling layers, which are present in practically all modern CNNs used for object recognition, with standard convolutional layers with stride two.

What is Downsampling?

Why do we downsample feature maps using pooling layers? Why not remove the max-pooling layers and keep fairly large feature maps all the way up? Let’s look at the convolutional base of a model built without them:

import tensorflow as tf
import numpy as np
import matplotlib.pyplot as plt




Such a model isn’t conducive to learning a spatial hierarchy of features. The 3 × 3 windows in the third layer will only contain information coming from 7 × 7 windows in the initial input.
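
The 7 × 7 figure can be checked with the usual receptive-field recurrence. This is a minimal sketch: the function name `receptive_field` and the two layer lists are illustrative assumptions, not part of the original model.

```python
def receptive_field(layers):
    """Return the side length of the input window seen by one output unit.

    `layers` is a list of (kernel_size, stride) pairs applied in order.
    Pooling layers are described the same way as convolutions.
    """
    rf = 1    # a single unit covers one input position before any layer
    jump = 1  # spacing, in input pixels, between neighbouring units
    for kernel_size, stride in layers:
        rf += (kernel_size - 1) * jump
        jump *= stride
    return rf


# Three unpadded 3x3 convolutions, no downsampling (hypothetical layout).
no_downsampling = [(3, 1), (3, 1), (3, 1)]

# The same convolutions with a 2x2, stride-2 pooling layer between them.
with_pooling = [(3, 1), (2, 2), (3, 1), (2, 2), (3, 1)]

print(receptive_field(no_downsampling))  # 7
print(receptive_field(with_pooling))     # 18
```

With two pooling layers in between, the same three convolutions see an 18 × 18 window instead of a 7 × 7 one.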

The high-level patterns learned by the convnet will still be very small with regard to the initial input, which may not be enough to learn to classify digits. We need the features from the last convolution layer to contain information about the totality of the input.

Pooling vs Strided Convolution

The final feature map has 22 × 22 × 64 = 30,976 total coefficients per sample. This is huge. If you were to flatten it to stick a Dense layer of size 512 on top, that layer would have 15.8 million parameters. This is far too large for such a small model and would result in intense overfitting.
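
The arithmetic behind these numbers can be reproduced in a few lines. The sketch assumes a 28 × 28 input and three unpadded 3 × 3 convolutions with stride 1, which is consistent with the 22 × 22 feature map above; the helper name `conv_output_size` is hypothetical.

```python
def conv_output_size(size, kernel_size, stride=1, padding=0):
    """Spatial output size of a convolution along one dimension."""
    return (size + 2 * padding - kernel_size) // stride + 1


side = 28  # assumed input height and width
for _ in range(3):
    side = conv_output_size(side, kernel_size=3)  # 28 -> 26 -> 24 -> 22

channels = 64
flattened = side * side * channels

dense_units = 512
# One weight per (input coefficient, unit) pair, plus one bias per unit.
dense_params = flattened * dense_units + dense_units

print(side, flattened, dense_params)  # 22 30976 15860224
```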

The reason to use downsampling is to reduce the number of feature-map coefficients to process, as well as to induce spatial filter hierarchies by making successive convolution layers look at increasingly large windows in terms of the fraction of the original input they cover.

Note that max pooling isn’t the only way you can achieve downsampling. You can also use strides in the prior convolution layer, or you can use average pooling instead of max pooling, where each local input patch is transformed by taking the average value of each channel over the patch, rather than the max.
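
To make the difference between the two pooling variants concrete, here is a small pure-Python sketch on a single-channel 4 × 4 feature map. The helper `pool2d` and the example values are made up for illustration.

```python
def pool2d(feature_map, size, reduce_fn):
    """Apply non-overlapping size x size pooling to a 2D list."""
    rows, cols = len(feature_map), len(feature_map[0])
    pooled = []
    for r in range(0, rows - size + 1, size):
        out_row = []
        for c in range(0, cols - size + 1, size):
            patch = [feature_map[r + i][c + j]
                     for i in range(size) for j in range(size)]
            out_row.append(reduce_fn(patch))
        pooled.append(out_row)
    return pooled


def mean(values):
    return sum(values) / len(values)


feature_map = [
    [0, 0, 1, 0],
    [0, 8, 0, 0],
    [0, 0, 0, 0],
    [4, 0, 0, 0],
]

max_pooled = pool2d(feature_map, 2, max)
avg_pooled = pool2d(feature_map, 2, mean)

print(max_pooled)  # [[8, 1], [4, 0]]
print(avg_pooled)  # [[2.0, 0.25], [1.0, 0.0]]
```

Both outputs are 2 × 2, but the strong activation of 8 survives intact under max pooling and is diluted to 2.0 under average pooling.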

Strided Convolution

So far we have assumed that the center tiles of the convolution windows are all contiguous. But the distance between two successive windows is a parameter of the convolution, called its stride, which defaults to 1. It’s possible to have strided convolutions, that is, convolutions with a stride higher than 1.

For example, consider the patches extracted by a 3 × 3 convolution with stride 2 over a 5 × 5 input (without padding): the windows start at rows and columns 0 and 2, which gives a 2 × 2 output.

Using stride 2 means the width and height of the feature map are downsampled by a factor of 2, in addition to any changes induced by border effects. In small classification models strided convolutions are used less often than max pooling, although they can come in handy for some types of models.
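
The 5 × 5 case can be worked through directly. The following is a minimal pure-Python sketch of an unpadded convolution with a configurable stride; the function `conv2d`, the input values, and the all-ones kernel are illustrative assumptions.

```python
def conv2d(image, kernel, stride=1):
    """Unpadded 2D convolution (cross-correlation) on square 2D lists."""
    k = len(kernel)
    n = len(image)
    starts = range(0, n - k + 1, stride)  # top-left corner of each window
    return [
        [
            sum(image[r + i][c + j] * kernel[i][j]
                for i in range(k) for j in range(k))
            for c in starts
        ]
        for r in starts
    ]


image = [[r * 5 + c for c in range(5)] for r in range(5)]  # 5x5 input
kernel = [[1] * 3 for _ in range(3)]                       # 3x3 sum kernel

dense = conv2d(image, kernel, stride=1)    # 3x3 output
strided = conv2d(image, kernel, stride=2)  # 2x2 output

print(len(dense), len(strided))  # 3 2
print(strided)                   # [[54, 72], [144, 162]]
```

With a fixed kernel, the stride-2 output is exactly the stride-1 output sampled at every other position, which is why the strided version does roughly a quarter of the work on larger inputs.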

The advantage of the strided convolution layer is that it can learn properties that a pooling layer cannot: pooling is a fixed operation, whereas the weights of a convolution are learned.

On the other hand, pooling is a cheaper operation than convolution, both in terms of the amount of computation that you need to do and the number of parameters that you need to store (a pooling layer has no parameters). There are cases where one of them is a better choice than the other.
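
The parameter cost is easy to quantify. As a rough sketch, assuming a 3 × 3 convolution that maps 64 channels to 64 channels (the channel counts are chosen only for illustration):

```python
def conv_params(kernel_size, in_channels, out_channels, use_bias=True):
    """Number of trainable parameters in a standard 2D convolution."""
    weights = kernel_size * kernel_size * in_channels * out_channels
    biases = out_channels if use_bias else 0
    return weights + biases


strided_conv_params = conv_params(3, 64, 64)  # learned downsampling
pooling_params = 0                            # pooling has nothing to learn

print(strided_conv_params, pooling_params)  # 36928 0
```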

When Is Strided Convolution Better than Pooling?

The first layer in the ResNet uses convolution with strides. This layer by itself significantly reduces the amount of computation that has to be done by the network in the subsequent layers.

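from tensorflow.keras.layers import (Input, ZeroPadding2D, Conv2D,
                                     BatchNormalization, Activation, MaxPooling2D)
from tensorflow.keras.initializers import glorot_uniform

def ResNet50(input_shape = (64, 64, 3), classes = 2):
    # Define the input as a tensor with shape input_shape
    X_input = Input(input_shape)

    # Zero-Padding: 3 pixels on each side so the 7x7 kernel also covers the borders
    X = ZeroPadding2D((3, 3))(X_input)

    # Stage 1: the 7x7 convolution with stride 2 convolves and downsamples in one step
    X = Conv2D(64, (7, 7), strides = (2, 2), name = 'conv1', kernel_initializer = glorot_uniform(seed=0))(X)
    X = BatchNormalization(axis = 3, name = 'bn_conv1')(X)
    X = Activation('relu')(X)
    # 3x3 max pooling with stride 2 roughly halves the resolution again
    X = MaxPooling2D((3, 3), strides=(2, 2))(X)

    # (excerpt: the remaining stages of the network are not shown)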

This layer compresses three stacked 3×3 convolutions into a single 7×7 convolution, which has exactly the same receptive field as the three layers (even though it is less powerful in terms of what it can learn).

At the same time, this layer applies stride=2 that downsamples the image. Because this first layer in ResNet does convolution and downsampling at the same time, the operation becomes significantly cheaper computationally. 

If you use stride=1 and pooling for downsampling instead, you end up with a convolution that does 4 times more computation, plus the extra computation for the pooling layer that follows. The same trick was used in SqueezeNet and some other neural network architectures.
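
The factor of 4 follows from counting output positions. The sketch below counts multiply-accumulate operations for the 7 × 7 convolution on the 64 × 64 × 3 input used above; the helper `conv_macs` is a hypothetical name, and bias additions are ignored.

```python
def conv_macs(input_size, in_channels, out_channels,
              kernel_size, stride, padding):
    """Return (multiply-accumulates, output side) for a square 2D convolution."""
    out = (input_size + 2 * padding - kernel_size) // stride + 1
    macs_per_output = kernel_size * kernel_size * in_channels
    return out * out * out_channels * macs_per_output, out


strided_macs, strided_out = conv_macs(64, 3, 64, 7, stride=2, padding=3)
dense_macs, dense_out = conv_macs(64, 3, 64, 7, stride=1, padding=3)

print(strided_out, dense_out)     # 32 64
print(dense_macs / strided_macs)  # 4.0
```

Halving the output side in both dimensions leaves a quarter of the output positions, so the stride-1 variant costs four times as much before any pooling is applied.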

When Is Pooling Better than Strided Convolution?

In ResNet, a 1×1 convolution is placed in the skip connection in the few places where downsampling is applied to the feature map. This convolution layer can make gradient propagation harder, because the skip connection is no longer an identity mapping.

One of the major changes in FishNet is that it gets rid of the convolutions in the residual connections and replaces them with pooling and simple upscales/identities/concatenations. This addresses the problem with gradient propagation in very deep networks.

In essence, max-pooling (or any kind of pooling) is a fixed operation, and replacing it with a strided convolution can also be seen as learning the pooling operation, which increases the model’s expressiveness. The downside is that it also increases the number of trainable parameters, but this is rarely a real problem on today’s hardware.


For downsampling feature maps, max pooling tends to work better than strided convolution or average pooling. In a nutshell, the reason is that features tend to encode the spatial presence of some pattern or concept over the different tiles of the feature map, and it’s more informative to look at the maximal presence of different features than at their average presence.

So the most reasonable subsampling strategy is to first produce dense maps of features (via unstrided convolutions) and then look at the maximal activation of the features over small patches, rather than looking at sparser windows of the inputs (via strided convolutions) or averaging input patches, which could cause you to miss or dilute feature-presence information.
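
A toy one-dimensional example illustrates the argument. It assumes a fixed feature detector whose dense (stride-1) responses are given; applying that same detector with stride 2 is then equivalent to keeping every other response. A learned strided convolution could adapt its weights, so this is only an illustration of the failure mode, and the response values are made up.

```python
# Dense responses of a hypothetical feature detector along one row.
responses = [0, 9, 0, 0, 0, 7, 0, 0]

# Max pooling over windows of 2 keeps the strongest response in each window.
max_pooled = [max(responses[i:i + 2]) for i in range(0, len(responses), 2)]

# Stride 2 with the same detector only evaluates positions 0, 2, 4, 6.
strided = responses[::2]

# Average pooling keeps some evidence but dilutes it.
avg_pooled = [sum(responses[i:i + 2]) / 2 for i in range(0, len(responses), 2)]

print(max_pooled)  # [9, 0, 7, 0]
print(strided)     # [0, 0, 0, 0]
print(avg_pooled)  # [4.5, 0.0, 3.5, 0.0]
```

All three outputs have the same length, but only max pooling preserves both peaks at full strength; the strided sampling happens to miss them entirely.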
