Shedding some light on the causes behind the CUDA out of memory error, and an example of how to reduce your memory footprint by 80% with a few lines of code in PyTorch
In this first part, I will explain how a deep learning model that uses a few hundred MB for its parameters can crash a GPU with more than 10 GB of memory during its training!
So where does this need for memory come from? Below I present the two main high-level reasons why deep learning training needs to store information:
information necessary to backpropagate the error (gradients of the loss w.r.t. the activations)
information necessary to compute the gradient of the model parameters
If there is one thing you should take away from this article, it is this:
As a rule of thumb, each layer with learnable parameters will need to store its input until the backward pass.
This means that every batchnorm, convolution and dense layer will store its input until it has computed the gradient of its parameters.
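To make this rule of thumb tangible, here is a minimal sketch that counts how many bytes autograd keeps for the backward pass. It assumes a recent PyTorch (1.10 or later) where `torch.autograd.graph.saved_tensors_hooks` is available; the toy model and the `pack`/`unpack` names are illustrative and are not part of the ResNet18 experiment below.

```python
import torch
from torch import nn

saved_bytes = 0

def pack(tensor):
    # Called each time autograd saves a tensor for the backward pass.
    global saved_bytes
    saved_bytes += tensor.numel() * tensor.element_size()
    return tensor

def unpack(tensor):
    # Called when the backward pass reads the saved tensor back.
    return tensor

# Hypothetical toy model, only here to illustrate the idea.
model = nn.Sequential(nn.Linear(256, 256), nn.ReLU(), nn.Linear(256, 10))
x = torch.randn(64, 256)

with torch.autograd.graph.saved_tensors_hooks(pack, unpack):
    loss = model(x).sum()
loss.backward()

input_bytes = x.numel() * x.element_size()
print(f"input: {input_bytes} B, saved for backward: {saved_bytes} B")
```

The count is only indicative: a tensor saved by two different operations is counted twice, and saved weights are included alongside the activations.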
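Now even some layers without any learnable parameters need to store some data! This is because we need to backpropagate the error back to the input, and we do this thanks to the chain rule:
Chain rule (L being the loss and a_i being the activations of the layer i):
$$\frac{\partial L}{\partial a_i} = \frac{\partial L}{\partial a_{i+1}} \cdot \frac{\partial a_{i+1}}{\partial a_i}$$
The culprit in this equation is the derivative of the layer's output w.r.t. its input, $\partial a_{i+1} / \partial a_i$. Depending on the layer, it will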
be dependent on the parameters of the layer (dense, convolution…)
be dependent only on the output of the layer (sigmoid activation, whose derivative is σ(x)(1 − σ(x)))
be dependent on the values of the inputs (e.g. MaxPool, ReLU…)
For example, if we take a ReLU activation layer, the minimum information we need is the sign of the input.
Different implementations can look like:
We store the whole input layer
We store a binary mask of the signs (that takes less memory)
We check if the output is stored by the next layer. If so, we get the sign info from there and we don’t need to store additional data
Maybe some other smart optimization I haven’t thought of…
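The second option (storing a mask instead of the whole input) can be sketched as a custom autograd function. This is a minimal illustration, not how PyTorch implements ReLU internally; the class name `MaskedReLU` is hypothetical. Note that a PyTorch boolean tensor uses 1 byte per element, so the mask is 4 times smaller than an FP32 input rather than 32 times smaller.

```python
import torch

class MaskedReLU(torch.autograd.Function):
    """ReLU that saves only a boolean mask of the positive inputs."""

    @staticmethod
    def forward(ctx, x):
        mask = x > 0                 # bool tensor: 1 byte per element
        ctx.save_for_backward(mask)  # the input itself is not saved
        return x.clamp(min=0)

    @staticmethod
    def backward(ctx, grad_output):
        (mask,) = ctx.saved_tensors
        # The gradient flows only where the input was positive.
        return grad_output * mask

x = torch.randn(4, 5, requires_grad=True)
y = MaskedReLU.apply(x)
y.sum().backward()
```

The result and its gradient should match `torch.relu`, while the tensor kept for the backward pass is the mask rather than the FP32 input.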
Now let’s take a closer look at a concrete example: The ResNet18!
We are going to look at the memory allocated on the GPU at specific times of the training iteration:
At the beginning of the forward pass of each module
At the end of the forward pass of each module
At the end of the backward pass of each module
(Full code and Github repo available here)
The logger looks like this:
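A logger of this kind can be sketched with module hooks and `torch.cuda.memory_allocated()`. The version below is only an approximation under stated assumptions: the helper names (`allocated_mb`, `attach_memory_hooks`) are hypothetical, the model is a small stand-in for the ResNet stem, and on a machine without a GPU every reading is simply 0.

```python
import torch
from torch import nn

def allocated_mb():
    # GPU memory currently allocated by tensors, in MB (0.0 without CUDA).
    if torch.cuda.is_available():
        return torch.cuda.memory_allocated() / 2**20
    return 0.0

def attach_memory_hooks(model, log):
    """Record (module name, hook kind, allocated MB) at each hook call."""
    def make_hook(name, kind):
        def hook(*args):
            log.append((name, kind, allocated_mb()))
        return hook

    handles = []
    for module in model.modules():
        name = type(module).__name__
        handles.append(module.register_forward_pre_hook(make_hook(name, "pre")))
        handles.append(module.register_forward_hook(make_hook(name, "fwd")))
        handles.append(module.register_full_backward_hook(make_hook(name, "bwd")))
    return handles

device = "cuda" if torch.cuda.is_available() else "cpu"
# Hypothetical stand-in for the first layers of a ResNet.
model = nn.Sequential(
    nn.Conv2d(3, 8, kernel_size=3, padding=1),
    nn.BatchNorm2d(8),
    nn.ReLU(),
    nn.MaxPool2d(2),
).to(device)
x = torch.randn(4, 3, 16, 16, device=device)

log = []
handles = attach_memory_hooks(model, log)
model(x).sum().backward()
for handle in handles:
    handle.remove()

previous = None
for name, kind, mb in log:
    diff = "" if previous is None or mb == previous else f"{mb - previous:+.1f}"
    print(name, kind, diff)
    previous = mb
```

One caveat of this sketch: full backward hooks may not work on modules that modify their input in place (such as `ReLU(inplace=True)`), so a model using them might need adapting first.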
Then we can look at the memory consumption for the resnet18 (from the torchvision.models) with the following code:
Memory consumption during one training iteration of a ResNet18
A few things to observe:
The memory keeps increasing during the forward pass and then starts decreasing during the backward pass
The slope is pretty steep at the beginning and then flattens:
→ The activations become lighter and lighter as we go deeper into the network
We have a maximum memory of about 2500 MB
Optional: the next section digs deeper into the shape of the plot
Let’s try to understand why memory usage is higher in the first layers.
For this, I display the memory impact in MB of each layer and analyse it.
How to read the log:
The indentation levels represent the parent/submodules relationship (e.g. the ResNet is the root torch.nn.Module)
On one line we see:
→The name of the Module
→The hook concerned. (pre: before the forward pass, fwd: at the end of the forward pass, bwd: at the end of the backward pass)
→The GPU memory difference with the previous line, if there is any (in megabytes)
→Some comments made by me :)
ResNet pre                       # <- shape of the input (128, 3, 224, 224)
  Conv2d pre
  Conv2d fwd 392.0               # <- shape of the output (128, 64, 112, 112)
  BatchNorm2d pre
  BatchNorm2d fwd 392.0
  ReLU pre
  ReLU fwd
  MaxPool2d pre
  MaxPool2d fwd 294.0            # <- shape of the output (128, 64, 56, 56)
  Sequential pre
    BasicBlock pre
      Conv2d pre
      Conv2d fwd 98.0            # <-- (128, 64, 56, 56)
      BatchNorm2d pre
      BatchNorm2d fwd 98.0
      ReLU pre
      ReLU fwd
      Conv2d pre
      Conv2d fwd 98.0
      BatchNorm2d pre
      BatchNorm2d fwd 98.0
      ReLU pre
      ReLU fwd
    BasicBlock fwd
...
...
ResNet fwd                       # <-- End of the forward pass
  Linear bwd 2.0                 # <-- Beginning of the backward pass
...
...
      BatchNorm2d bwd -98.0
      Conv2d bwd -98.0
  MaxPool2d bwd 98.0
  ReLU bwd 98.0
  BatchNorm2d bwd -392.0
  Conv2d bwd -784.0              # <-- End of the backward pass
The input shape of the first Conv2d layer is (128, 3, 224, 224): batch size 128, 3 channels, image dimensions 224 x 224
The layer is: Conv2d(3, 64, kernel_size=(7, 7), stride=(2, 2), padding=(3, 3), bias=False)
The output shape of the layer is (128, 64, 112, 112): batch size 128, 64 channels, image dimensions 112 x 112
The additional allocation size for the output is:
(128 x 64 x 112 x 112 x 4) / 2**20 = 392 MB
(NB: the factor 4 comes from the storage of each number in 4 bytes as FP32, the division comes from the fact that 1 MB = 2**20 B)
Note also that this additional memory will not be freed once we move on to the next layers
For the first convolution of the first BasicBlock, the activations have gone through a max-pooling, which divided their height and width by 2.
The conv layer preserves the dimensions: Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
The additional memory allocated is:
(128 x 64 x 56 x 56 x 4) / 2**20 = 98 MB (=392/4)
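The two computations above can be reproduced with a few lines of plain Python. The helper names (`conv_out_size`, `activation_mb`) are illustrative; the output-size formula is the standard one for a convolution without dilation.

```python
def conv_out_size(size, kernel, stride, padding):
    # Spatial output size of a convolution (no dilation).
    return (size + 2 * padding - kernel) // stride + 1

def activation_mb(batch, channels, height, width, bytes_per_value=4):
    # Memory of one activation tensor in MB (FP32 by default, 1 MB = 2**20 B).
    return batch * channels * height * width * bytes_per_value / 2**20

# First conv of the network: 7x7 kernel, stride 2, padding 3 on 224 x 224 images.
first_size = conv_out_size(224, kernel=7, stride=2, padding=3)
first_conv_mb = activation_mb(128, 64, first_size, first_size)

# Conv inside the first BasicBlock: 3x3 kernel, stride 1, padding 1 on 56 x 56 maps.
block_size = conv_out_size(56, kernel=3, stride=1, padding=1)
block_conv_mb = activation_mb(128, 64, block_size, block_size)

print(first_size, first_conv_mb)   # 112 392.0
print(block_size, block_conv_mb)   # 56 98.0
```

Halving the height and the width divides the activation size by 4, which is why the curve flattens so quickly after the first layers.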
Next, I will present two ideas and their implementation in PyTorch to divide the footprint of the ResNet by 5 in 4 lines of code :)
The idea behind gradient checkpointing is pretty simple:
If I need some data that I have computed once, I don’t need to store it: I can compute it again
So basically instead of storing all the layers’ inputs, I will store a few checkpoints along the way during the forward pass, and when I need some input that I haven’t stored I’ll just recompute it from the last checkpoint.
Plus it’s really easy to implement in PyTorch, especially if you have an nn.Sequential module. To apply it, I changed line 9 of the log function as below:
And since it takes an instance of nn.Sequential, I created it as such:
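As an illustration of the idea, here is a minimal sketch built on `torch.utils.checkpoint.checkpoint_sequential`. It is not the exact change made to the logger: the toy model and the choice of 2 segments are assumptions, and passing `use_reentrant=False` assumes a recent PyTorch release that accepts this argument.

```python
import torch
from torch import nn
from torch.utils.checkpoint import checkpoint_sequential

torch.manual_seed(0)
# Hypothetical toy model; checkpoint_sequential expects an nn.Sequential
# (or a list of modules) so that it can split it into segments.
model = nn.Sequential(
    nn.Linear(32, 64), nn.ReLU(),
    nn.Linear(64, 64), nn.ReLU(),
    nn.Linear(64, 10),
)
x = torch.randn(16, 32)

# Baseline: every intermediate activation is kept until the backward pass.
model(x).sum().backward()
baseline_grads = [p.grad.clone() for p in model.parameters()]
model.zero_grad()

# Checkpointed: activations inside a checkpointed segment are dropped after
# the forward pass and recomputed from the segment input during backward.
out = checkpoint_sequential(model, 2, x, use_reentrant=False)
out.sum().backward()
ckpt_grads = [p.grad.clone() for p in model.parameters()]

same = all(torch.allclose(a, b) for a, b in zip(baseline_grads, ckpt_grads))
print("same gradients:", same)
```

The gradients are the same in both cases; the trade is extra forward computation during the backward pass in exchange for fewer stored activations.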
The idea behind mixed-precision training is the following:
If we store every number on 2 bytes instead of 4: we’ll use half the memory
→But then the training doesn’t converge…
To fix this, different techniques are combined (loss scaling, master weight copy, casting to FP32 for some layers…).
The implementation of mixed-precision training can be subtle, and if you want to know more, I encourage you to visit the resources at the end of the article.
Thankfully everything has been beautifully automated in PyTorch's automatic mixed precision (AMP) module!
So with only a couple of changes we can get some nice memory optimization (check lines 6, 7, 14, 15)
Then we can combine both into the following:
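For reference, a typical AMP training step looks roughly like the sketch below. It is a generic pattern rather than the exact code of this experiment: the toy model, data and learning rate are assumptions, and AMP is only enabled when a GPU is present, so the same lines also run (in plain FP32) on CPU. Recent releases also expose the same tools under `torch.amp`.

```python
import torch
from torch import nn

device = "cuda" if torch.cuda.is_available() else "cpu"
use_amp = device == "cuda"  # fall back to plain FP32 without a GPU

# Hypothetical toy model and data.
model = nn.Linear(32, 10).to(device)
optimizer = torch.optim.SGD(model.parameters(), lr=0.1)
scaler = torch.cuda.amp.GradScaler(enabled=use_amp)

x = torch.randn(16, 32, device=device)
target = torch.randint(0, 10, (16,), device=device)

optimizer.zero_grad()
# Change 1: run the forward pass and the loss under autocast, so that
# eligible operations use FP16 while sensitive ones stay in FP32.
with torch.autocast(device_type=device, enabled=use_amp):
    loss = nn.functional.cross_entropy(model(x), target)

# Change 2: scale the loss before backward to avoid FP16 gradient
# underflow, then let the scaler unscale and apply the optimizer step.
scaler.scale(loss).backward()
scaler.step(optimizer)
scaler.update()
print(loss.item())
```

The two changes are the `autocast` context around the forward pass and the `GradScaler` calls around the backward pass and the optimizer step.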
Memory consumption comparison of the optimization methods with the baseline
Here are the main facts to observe:
AMP: The overall shape is the same, but we use less memory
Checkpointing: We can see that the model does not accumulate memory during the forward pass
Below is the maximum memory footprint of each iteration, and we can see how we divided the overall footprint of the baseline by 5.
Maximum memory consumption for each training iteration
Some notes on the results:
We only looked at the memory savings
To have a better comparison, we need to look at two additional metrics: training speed and training accuracy
My intuition on this would be:
Checkpointing is slower than the baseline and achieves the same accuracy
AMP is faster than the baseline and achieves a lower accuracy
To be confirmed in the next episode…
References:
Why is so much memory needed for deep neural networks?
Fitting larger networks into memory
Mixed-Precision Training of Deep Neural Networks
Explore Gradient-Checkpointing in PyTorch