Linear Layer, Deriving the Gradient for the Backward Pass
Deriving the gradient for the backward pass for the linear layer using tensor calculus
The linear layer, a.k.a. dense layer or fully-connected layer, is everywhere in deep learning and forms the foundation of most neural networks. PyTorch defines the linear layer as:

$$Y = X A^\top + \hat{b}$$

Whereby the tensors and their shapes are:
- Input $X$: (∗, in_features)
- Weights $A$: (out_features, in_features)
- Bias $\hat{b}$: (out_features)
- Output $Y$: (∗, out_features)

The “∗” means any number of dimensions. PyTorch is a bit unusual in that it takes the transpose of the weight matrix $A$ before multiplying it with the input $X$. The hat over $\hat{b}$ indicates it’s a vector.
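As a quick sanity check of these shapes and the transpose convention, we can apply `nn.Linear` to a small input and reproduce its output manually (the sizes here are arbitrary, chosen just for illustration):

```python
import torch
from torch import nn

# Arbitrary sizes for illustration
batch, in_features, out_features = 4, 3, 2
linear = nn.Linear(in_features, out_features, bias=True)

x = torch.rand(batch, in_features)   # shape (*, in_features)
y = linear(x)                        # shape (*, out_features)

# PyTorch multiplies by the transpose of the weight matrix internally
y_manual = x @ linear.weight.T + linear.bias
assert y.shape == (batch, out_features)
assert torch.allclose(y, y_manual)
```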
In this article, we will derive the gradients used for backpropagation through the linear layer, i.e. the gradients computed when calling Y.backward() on the output tensor Y. We will use index notation to express the function more precisely and tensor calculus to calculate the gradients.
To derive the gradient for each input, we proceed as follows:
- Translate the function into index notation.
- Calculate the derivative of the output with respect to that input.
- Use the chain rule to determine the gradient of the loss with respect to each of the inputs of the function. To do this, we first must assume the output $Y$ is a dependency of a downstream scalar loss function $L$ and that we are provided with the gradient of $L$ with respect to the output, $\partial L / \partial Y$.
Using Index Notation
To keep it simple, we are going to assume there is only one “∗” dimension, and it’s easy enough to extend the reasoning to more. Let’s introduce a new tensor $W = A^\top$ so we can ignore the transpose for a bit, and we can express the linear layer using index notation like so:

$$Y_{ij} = X_{ik} W_{kj} + 1_i b_j$$

In the first term, the index $k$ is repeated and not used on the left-hand side of the equation, meaning it’s a dummy index and we sum over the multiplied components of $X$ and $W$. In the second term, the one-tensor $1_i$ is used to broadcast the vector $\hat{b}$ into a matrix so it has the correct shape to be added to the first term.
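The index-notation form maps directly onto `torch.einsum`, which makes the dummy index and the broadcast explicit. A small sketch with arbitrary sizes, writing $W = A^\top$ so the transpose is already absorbed into $W$:

```python
import torch

M, K, N = 4, 3, 2      # arbitrary sizes for illustration
X = torch.rand(M, K)
W = torch.rand(K, N)   # W = A^T, transpose already absorbed
b = torch.rand(N)
ones = torch.ones(M)   # the "one-tensor" 1_i

# Y_ij = X_ik W_kj + 1_i b_j: k is a dummy index and is summed over,
# and the outer product 1_i b_j broadcasts the bias vector to a matrix.
Y = torch.einsum('ik,kj->ij', X, W) + torch.einsum('i,j->ij', ones, b)
assert torch.allclose(Y, X @ W + b)
```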
Gradients of all the inputs
Next, we want to calculate the gradient of the loss with respect to each of the inputs $X$, $A$, and $\hat{b}$.

The gradient of $X$
Let’s first obtain the derivative with respect to $X$. We need to remember to use new free indices ($a$ and $b$) for the derivative operator:

$$\frac{\partial Y_{ij}}{\partial X_{ab}} = \frac{\partial}{\partial X_{ab}} \left( X_{ik} W_{kj} + 1_i b_j \right)$$

The second term with the bias is independent of $X$ and so that’s zero. In the first term, $W_{kj}$ is just a factor and so can be moved outside the derivative operator:

$$\frac{\partial Y_{ij}}{\partial X_{ab}} = W_{kj} \frac{\partial X_{ik}}{\partial X_{ab}}$$

From the rules of tensor calculus, we know that the derivative of a variable with itself equals a product of Kronecker deltas:

$$\frac{\partial X_{ik}}{\partial X_{ab}} = \delta_{ia} \delta_{kb}$$

We can then contract the index $k$ to obtain:

$$\frac{\partial Y_{ij}}{\partial X_{ab}} = W_{kj} \delta_{ia} \delta_{kb} = \delta_{ia} W_{bj}$$
This is an order-4 tensor, i.e. a tensor with 4 dimensions, and so can’t be expressed using matrix notation. However, the tensor is only non-zero when $i = a$ due to the Kronecker delta. Fortunately, the gradient used for backpropagation is a lower-order tensor, and we can use matrix notation. To do that, let’s first assume $Y$ is an input of a scalar loss function $L$ and we are provided with the gradient of $L$ with respect to $Y$. Then, to derive the gradient for backpropagation, we apply the chain rule:

$$\frac{\partial L}{\partial X_{ab}} = \frac{\partial L}{\partial Y_{ij}} \frac{\partial Y_{ij}}{\partial X_{ab}} = \frac{\partial L}{\partial Y_{ij}} \delta_{ia} W_{bj} = \frac{\partial L}{\partial Y_{aj}} W_{bj}$$

We now want to convert this to matrix notation; however, we are summing over the second axis of both components, so this can’t be represented directly as a matrix multiplication. In index notation, $W_{bj} = [W^\top]_{jb}$, because each term is just an element of a tensor. So we can get:

$$\frac{\partial L}{\partial X_{ab}} = \frac{\partial L}{\partial Y_{aj}} [W^\top]_{jb}$$

We can then convert it to matrix notation like so (the square brackets indicate taking the elements of the matrix):

$$\left[ \frac{\partial L}{\partial X} \right]_{ab} = \left[ \frac{\partial L}{\partial Y} W^\top \right]_{ab} \quad \Rightarrow \quad \frac{\partial L}{\partial X} = \frac{\partial L}{\partial Y} W^\top = \frac{\partial L}{\partial Y} A$$

Therefore, the gradient is calculated by multiplying the gradient of the loss with respect to the output with the weights $A$ (recall that $W^\top = A$).
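The index manipulation can be checked numerically: summing over the second axis of both tensors with einsum gives the same result as multiplying by $W^\top$. A small sketch with arbitrary sizes:

```python
import torch

M, K, N = 4, 3, 2        # arbitrary sizes for illustration
dY = torch.rand(M, N)    # gradient of the loss w.r.t. the output Y
W = torch.rand(K, N)     # W = A^T

# dL/dX_ab = dL/dY_aj * W_bj: both tensors are summed over their
# second axis, which is the same as multiplying by W^T (i.e. by A).
dX = torch.einsum('aj,bj->ab', dY, W)
assert torch.allclose(dX, dY @ W.T)
assert dX.shape == (M, K)
```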
Gradient of $A$

We use the same procedure as above. First, obtain the derivative with respect to $W$ (we use new free indices $a$ and $b$ for the derivative operator):

$$\frac{\partial Y_{ij}}{\partial W_{ab}} = X_{ik} \frac{\partial W_{kj}}{\partial W_{ab}} = X_{ik} \delta_{ka} \delta_{jb} = X_{ia} \delta_{jb}$$

As $W = A^\top$, we know that $\partial Y_{ij} / \partial A_{ab} = \partial Y_{ij} / \partial W_{ba}$. So, if we continue with the above:

$$\frac{\partial Y_{ij}}{\partial A_{ab}} = X_{ib} \delta_{ja}$$

Then, we obtain the backpropagated gradient by assuming a downstream loss $L$ consumes the output $Y$:

$$\frac{\partial L}{\partial A_{ab}} = \frac{\partial L}{\partial Y_{ij}} \frac{\partial Y_{ij}}{\partial A_{ab}} = \frac{\partial L}{\partial Y_{ij}} X_{ib} \delta_{ja} = \frac{\partial L}{\partial Y_{ia}} X_{ib}$$
We now want to convert this to matrix notation; however, we are summing over the first axis of both components, so this can’t be represented directly as a matrix multiplication. We use the same trick we used in the previous section and take the transpose of $\partial L / \partial Y$ to swap the index order, and then convert this to matrix notation:

$$\frac{\partial L}{\partial A_{ab}} = \left[ \left( \frac{\partial L}{\partial Y} \right)^\top \right]_{ai} X_{ib} \quad \Rightarrow \quad \frac{\partial L}{\partial A} = \left( \frac{\partial L}{\partial Y} \right)^\top X$$

Therefore, the gradient is calculated by taking the transpose of the gradient of the loss with respect to the output and multiplying it with the input $X$.
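The same transpose trick can be verified numerically: summing over the first axis of both tensors equals matrix-multiplying the transposed output gradient with the input (arbitrary sizes, for illustration):

```python
import torch

M, K, N = 4, 3, 2        # arbitrary sizes for illustration
dY = torch.rand(M, N)    # gradient of the loss w.r.t. the output Y
X = torch.rand(M, K)     # input

# dL/dA_ab = dL/dY_ia * X_ib: both tensors are summed over their
# first axis, which equals transposing dL/dY before the matmul.
dA = torch.einsum('ia,ib->ab', dY, X)
assert torch.allclose(dA, dY.T @ X)
assert dA.shape == (N, K)
```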
Gradient of $\hat{b}$

First, we calculate the derivative of the output with respect to the bias (using a new free index $a$):

$$\frac{\partial Y_{ij}}{\partial b_a} = 1_i \frac{\partial b_j}{\partial b_a}$$

Again, using the rules of tensor calculus, we know that the derivative of a variable with itself equals a Kronecker delta:

$$\frac{\partial b_j}{\partial b_a} = \delta_{ja} \quad \Rightarrow \quad \frac{\partial Y_{ij}}{\partial b_a} = 1_i \delta_{ja}$$

Finally, we derive the gradient used for backpropagation by assuming a downstream loss $L$ consumes the output $Y$:

$$\frac{\partial L}{\partial b_a} = \frac{\partial L}{\partial Y_{ij}} \frac{\partial Y_{ij}}{\partial b_a} = \frac{\partial L}{\partial Y_{ij}} 1_i \delta_{ja} = \frac{\partial L}{\partial Y_{ia}} 1_i$$
Simply put, this means we sum the gradient over the batch dimension $i$ to obtain the backpropagated gradient of the bias:

$$\frac{\partial L}{\partial b_a} = \sum_i \frac{\partial L}{\partial Y_{ia}}$$

Or, in matrix notation, you can express this using a vector filled with ones:

$$\frac{\partial L}{\partial \hat{b}} = \left( \frac{\partial L}{\partial Y} \right)^\top \vec{1}$$
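This equivalence is easy to sketch numerically: summing the output gradient over the batch dimension matches the product with a ones vector (arbitrary sizes, for illustration):

```python
import torch

M, N = 4, 2              # arbitrary sizes for illustration
dY = torch.rand(M, N)    # gradient of the loss w.r.t. the output Y
ones = torch.ones(M)     # vector filled with ones

# dL/db_a = sum_i dL/dY_ia: summing over the batch dimension i
# is the same as multiplying the transposed gradient by a ones vector.
db = dY.sum(axis=0)
assert torch.allclose(db, dY.T @ ones)
assert db.shape == (N,)
```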
Comparison with PyTorch
We can numerically show that the above results are correct by comparing them to the output of PyTorch using this code:
```python
import torch
from torch import nn

M, K, N = 128, 256, 512
torch.manual_seed(42)

# Create tensors
linear = nn.Linear(in_features=K, out_features=N, bias=True)
x = torch.rand((M, K), requires_grad=True)
y_grad = torch.rand((M, N))

# Run forward and backward pass
y = linear(x)
y.backward(y_grad)

# Calculate gradients using above equations
x_grad = y_grad @ linear.weight
a_grad = y_grad.T @ x
b_grad = y_grad.sum(axis=0)

# Check against PyTorch
torch.testing.assert_close(x_grad, x.grad)
torch.testing.assert_close(a_grad, linear.weight.grad)
torch.testing.assert_close(b_grad, linear.bias.grad)
```
Next
If you would like to read more about calculating gradients using tensor calculus and index notation, please have a look at the series introduction or The Tensor Calculus You Need for Deep Learning.