Cross-Entropy Loss (Softmax) Gradient Used In Deep Learning
Deriving the gradient of the cross-entropy loss (softmax followed by the negative log-likelihood loss)
Cross-entropy is a common loss used for classification tasks in deep learning - including transformers. It is defined as the softmax function followed by the negative log-likelihood loss. Here, I will walk through how to derive the gradient of the cross-entropy loss used for the backward pass when training a model. I will use tensor calculus and index notation - see my article The Tensor Calculus You Need for Deep Learning for more information.
Say we have an input vector of logits $x_i$ and a target class index $t$. We can define the cross-entropy loss using index notation:

$$s_i = \frac{e^{x_i}}{\sum_k e^{x_k}}$$

$$c = -\ln(s_t)$$
The first equation calculates the softmax: for every value in $x$ we divide its exponent $e^{x_i}$ by the sum of the exponents of all values. The second equation calculates the negative log-likelihood loss of the softmax: $t$ is the index corresponding to the correct target label. Note that $t$ is a constant and not a free or dummy index.
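As a quick sanity check (not from the original derivation), we can compute the two equations above by hand for a single made-up sample and compare against PyTorch's fused `F.cross_entropy`; the values here are arbitrary:

```python
import torch
import torch.nn.functional as F

# Hypothetical logits for a 4-class problem and a target label
x = torch.tensor([1.0, 2.0, 0.5, -1.0])
t = 2  # index of the correct class

# Softmax: exponentiate and normalise by the sum of exponents
s = torch.exp(x) / torch.exp(x).sum()

# Negative log-likelihood of the softmax at the target index
c = -torch.log(s[t])

# Matches PyTorch's built-in cross-entropy for a batch of one
assert torch.isclose(c, F.cross_entropy(x.unsqueeze(0), torch.tensor([t])))
```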
First, we need to derive the Jacobian tensor of the softmax function. Let's start with the denominator in the softmax:

$$\frac{\partial}{\partial x_j} \sum_k e^{x_k} = \sum_k \delta_{kj}\, e^{x_k} = e^{x_j}$$
Notice that we can drop the summation and the Kronecker delta: as $j$ is a free index, $\delta_{kj}$ selects the single term $k = j$, and $\delta_{jj}$ (no summation) always equals 1.
To differentiate the softmax tensor $s_i$ with respect to $x_j$, we use the quotient rule and simplify the expression by reusing the definition of the softmax:

$$\frac{\partial s_i}{\partial x_j} = \frac{\delta_{ij}\, e^{x_i} \sum_k e^{x_k} - e^{x_i} e^{x_j}}{\left(\sum_k e^{x_k}\right)^2} = \delta_{ij}\, s_i - s_i s_j = s_i\left(\delta_{ij} - s_j\right)$$
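A small numerical check of this step (my addition, not part of the original article): the analytic Jacobian $s_i(\delta_{ij} - s_j)$ is `diag(s) - outer(s, s)`, which should match the Jacobian PyTorch computes via autograd:

```python
import torch

torch.manual_seed(0)
x = torch.rand(5, requires_grad=True)

# Analytic Jacobian from the derivation: ds_i/dx_j = s_i * (delta_ij - s_j)
s = torch.softmax(x, dim=0)
jac_analytic = torch.diag(s) - torch.outer(s, s)

# Jacobian of the softmax computed by autograd
jac_autograd = torch.autograd.functional.jacobian(
    lambda v: torch.softmax(v, dim=0), x)

torch.testing.assert_close(jac_analytic, jac_autograd)
```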
And we move on to the negative log-likelihood loss:

$$\frac{\partial c}{\partial s_i} = -\frac{\partial}{\partial s_i} \ln(s_t) = -\frac{\delta_{it}}{s_t}$$
Note that as $t$ is a constant, not a dummy index, the expression $-\frac{\delta_{it}}{s_t}$ is $-\frac{1}{s_t}$ when $i = t$ and zero otherwise (it is not a summation).
Putting the two expressions together with the chain rule (summing over the dummy index $i$, which the delta collapses to $i = t$) to get the complete gradient:

$$\frac{\partial c}{\partial x_j} = \frac{\partial c}{\partial s_i}\frac{\partial s_i}{\partial x_j} = -\frac{\delta_{it}}{s_t}\, s_i\left(\delta_{ij} - s_j\right) = -\frac{s_t\left(\delta_{tj} - s_j\right)}{s_t} = s_j - \delta_{tj}$$
It’s interesting to note that, because every value is normalised by the softmax denominator $\sum_k e^{x_k}$, all logits have a non-zero gradient even if they do not correspond to the true label.
Then, with $l$ denoting the final scalar loss, deriving the backpropagated gradient is trivial:

$$\frac{\partial l}{\partial x_j} = \frac{\partial l}{\partial c}\frac{\partial c}{\partial x_j} = \frac{\partial l}{\partial c}\left(s_j - \delta_{tj}\right)$$
It might be the case that we start backpropagation at the cross-entropy loss itself; in that case $l = c$ and $\frac{\partial l}{\partial c} = 1$.
PyTorch Implementation
We can check the result by comparing the equation above with PyTorch’s built-in autograd output. I have generalised the above equation for a batch of size $N$, with one target index $t_n$ per sample:

$$\frac{\partial l}{\partial x_{nj}} = \frac{\partial l}{\partial c_n}\left(s_{nj} - \delta_{t_n j}\right)$$
```python
import torch
from torch import nn
import torch.nn.functional as F

# Create random inputs
torch.manual_seed(42)
N = 128
num_classes = 256
x = torch.rand((N, num_classes), dtype=torch.float32, requires_grad=True)
# high is exclusive, so this samples targets from 0 to num_classes - 1
target = torch.randint(high=num_classes, size=(N,))
dldc = torch.rand((N,), dtype=torch.float32)

# Run forward and backward pass using the built-in function
loss_layer = nn.CrossEntropyLoss(reduction='none')
c = loss_layer(x, target)
c.backward(dldc)

# Calculate the gradient using the equation above
s = torch.softmax(x, dim=1)
dldx = dldc.unsqueeze(1) * (s - F.one_hot(target, num_classes=num_classes))

# Compare with PyTorch
torch.testing.assert_close(dldx, x.grad)
```
Next
Further examples of calculating gradients using tensor calculus and index notation can be found on the intro page.