Layer Normalization, Deriving the Gradient for the Backward Pass

Obtaining the gradient of the layer normalization layer

This post explains how to calculate the gradients of layer normalization for backpropagation using tensor calculus and index notation. It is part of a series on differentiating and calculating gradients in deep learning. This example is long and involved, but it brings together the concepts presented throughout the series. If you have not done so already, work through the earlier examples first.

PyTorch defines the layer normalization operation for an input matrix $X$, with shape batch size ($B$) by hidden size ($H$), as:

$$y=\frac{x-\mathrm{E}[x]}{\sqrt{\operatorname{Var}[x]+\epsilon}} * \gamma+\beta$$

Here, the mean $\mathrm{E}[x]$ and variance $\operatorname{Var}[x]$ are calculated over the hidden dimension for each sample in the batch. The variance uses the biased estimator, dividing by $H$ rather than $H-1$. $\gamma$ and $\beta$ are learnable weight vectors whose length equals the hidden size. $\epsilon$ is a small constant added for numerical stability, and PyTorch's default is $10^{-5}$.
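As a sanity check, here is a minimal sketch that evaluates this formula directly and compares it with `torch.nn.functional.layer_norm`. The shapes, seed and float64 dtype are arbitrary choices made for the illustration.

```python
import torch
import torch.nn.functional as F

torch.manual_seed(0)
B, H, eps = 4, 8, 1e-05
x = torch.rand(B, H, dtype=torch.float64)
gamma = torch.rand(H, dtype=torch.float64)
beta = torch.rand(H, dtype=torch.float64)

# Mean and biased variance over the hidden dimension, one value per sample
mean = x.mean(dim=1, keepdim=True)
var = ((x - mean) ** 2).mean(dim=1, keepdim=True)

# gamma and beta broadcast across the batch dimension
y_manual = (x - mean) / torch.sqrt(var + eps) * gamma + beta
y_torch = F.layer_norm(x, [H], gamma, beta, eps=eps)

torch.testing.assert_close(y_manual, y_torch)
```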

As shown previously, we can represent this using index notation:

$$\begin{aligned} m_{b} & =\frac{1}{H} \mathbf{1}_{h} x_{b h} \\ v_{b} & =\frac{1}{H} \mathbf{1}_{h}\left(x_{b h}-\mathbf{1}_{h} m_{b}\right)^{2} \\ y_{b h} & =\frac{x_{b h}-\mathbf{1}_{h} m_{b}}{\sqrt{v_{b}+\epsilon}} \gamma_{h}+\mathbf{1}_{b} \beta_{h} \end{aligned}$$

To make the problem more manageable, we are going to define additional intermediate tensor functions $\mu_{bh}$ and $\sigma_b$:

$$\begin{aligned} m_{b} & =\frac{1}{H} \mathbf{1}_{h} x_{b h} \\ \mu_{b h} & =x_{b h}-\mathbf{1}_{h} m_{b} \\ v_{b} & =\frac{1}{H} \mathbf{1}_{h} \mu_{b h}^{2} \\ \sigma_{b} & =\sqrt{v_{b}+\epsilon} \\ y_{b h} & =\frac{\mu_{b h}}{\sigma_{b}} \gamma_{h}+\mathbf{1}_{b} \beta_{h} \end{aligned}$$
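The index notation maps almost one-to-one onto `torch.einsum`. The sketch below transcribes the intermediate functions literally and checks the result against `torch.nn.functional.layer_norm`. It uses explicit vectors of ones for $\mathbf{1}_h$ and $\mathbf{1}_b$, which is purely an illustrative choice, since `mean` and broadcasting do the same job more cheaply. An index that is absent from an einsum output is summed over.

```python
import torch
import torch.nn.functional as F

torch.manual_seed(0)
B, H, eps = 4, 8, 1e-05
x = torch.rand(B, H, dtype=torch.float64)
gamma = torch.rand(H, dtype=torch.float64)
beta = torch.rand(H, dtype=torch.float64)
ones_b = torch.ones(B, dtype=torch.float64)  # plays the role of 1_b
ones_h = torch.ones(H, dtype=torch.float64)  # plays the role of 1_h

m = torch.einsum('h,bh->b', ones_h, x) / H        # m_b = (1/H) 1_h x_bh
mu = x - torch.einsum('h,b->bh', ones_h, m)       # mu_bh = x_bh - 1_h m_b
v = torch.einsum('h,bh->b', ones_h, mu ** 2) / H  # v_b = (1/H) 1_h mu_bh^2
sigma = torch.sqrt(v + eps)                       # sigma_b = sqrt(v_b + eps)
# y_bh = (mu_bh / sigma_b) gamma_h + 1_b beta_h
y = torch.einsum('bh,b,h->bh', mu, 1 / sigma, gamma) + torch.einsum('b,h->bh', ones_b, beta)

torch.testing.assert_close(y, F.layer_norm(x, [H], gamma, beta, eps=eps))
```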

The tensor functions above have the following dependency graph:

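[Figure: dependency graph of the tensor functions: $x \to m$, $x \to \mu$, $m \to \mu$, $\mu \to v$, $v \to \sigma$, and $\mu$, $\sigma$, $\gamma$, $\beta \to y$.]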

Gradient of Weights

Let’s start with the easier gradients, those with respect to $\gamma$ and $\beta$. First, the derivatives of $y_{bh}$:

$$\begin{aligned} & \frac{\partial y_{b h}}{\partial \gamma_{q}}=\frac{\mu_{b h}}{\sigma_{b}} \delta_{h q} \\ & \frac{\partial y_{b h}}{\partial \beta_{q}}=\mathbf{1}_{b} \delta_{h q} \end{aligned}$$

Next, we find the backpropagated gradients. The batch index $b$ does not appear on the left-hand side, so it is summed over:

$$\begin{aligned} \frac{\partial l}{\partial \gamma_{q}} & =\frac{\partial l}{\partial y_{b h}} \frac{\partial y_{b h}}{\partial \gamma_{q}}=\frac{\partial l}{\partial y_{b q}} \frac{\mu_{b q}}{\sigma_{b}} \\ \frac{\partial l}{\partial \beta_{q}} & =\frac{\partial l}{\partial y_{b h}} \frac{\partial y_{b h}}{\partial \beta_{q}}=\frac{\partial l}{\partial y_{b q}} \mathbf{1}_{b} \end{aligned}$$
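The weight gradients can be checked against autograd in a few lines of PyTorch. In this sketch `torch.einsum` contracts the index $b$. The random `dldy` tensor is an assumed stand-in for the upstream gradient $\partial l / \partial y_{bh}$.

```python
import torch
import torch.nn.functional as F

torch.manual_seed(0)
B, H, eps = 4, 8, 1e-05
x = torch.rand(B, H, dtype=torch.float64)
gamma = torch.rand(H, dtype=torch.float64, requires_grad=True)
beta = torch.rand(H, dtype=torch.float64, requires_grad=True)
dldy = torch.rand(B, H, dtype=torch.float64)  # stand-in for dl/dy

# Reference gradients from autograd
F.layer_norm(x, [H], gamma, beta, eps=eps).backward(dldy)

mu = x - x.mean(dim=1, keepdim=True)
sigma = torch.sqrt((mu ** 2).mean(dim=1) + eps)

# dl/dgamma_q = (dl/dy_bq) mu_bq / sigma_b, summed over b
dldgamma = torch.einsum('bq,bq,b->q', dldy, mu, 1 / sigma)
# dl/dbeta_q = (dl/dy_bq) 1_b, summed over b
dldbeta = torch.einsum('bq,b->q', dldy, torch.ones(B, dtype=torch.float64))

torch.testing.assert_close(dldgamma, gamma.grad)
torch.testing.assert_close(dldbeta, beta.grad)
```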

Gradient of Input $X$

Directly calculating the derivative of $y_{bh}$ with respect to $x_{pq}$ is complex, and the result is an order-4 tensor. However, we don't need to construct this tensor in full. Instead, we can backpropagate the loss through each intermediate tensor function in turn, which simplifies the process. The loss is a scalar, so each backpropagated gradient has the same shape as the tensor it is taken with respect to. That makes it at most an order-2 tensor.

To accomplish this, we'll start at the end of the dependency graph. At each intermediate stage, we calculate the Jacobian tensor and then the backpropagated gradient. The goal is to obtain an expression for $\partial l / \partial x_{pq}$ in terms of $\partial l / \partial y_{pq}$.
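To make the point about tensor orders concrete, the sketch below builds the full order-4 Jacobian $\partial y_{bh} / \partial x_{pq}$ for a tiny input using `torch.autograd.functional.jacobian`. It contracts that Jacobian with a random stand-in for $\partial l / \partial y_{bh}$ and compares the result with `torch.autograd.functional.vjp`, which produces the same order-2 gradient without ever materializing the Jacobian. The sketch also shows that the Jacobian vanishes between different samples, which is where the $\delta_{bp}$ factors in the following steps come from.

```python
import torch
import torch.nn.functional as F
from torch.autograd.functional import jacobian, vjp

torch.manual_seed(0)
B, H = 3, 5
x = torch.rand(B, H, dtype=torch.float64)
gamma = torch.rand(H, dtype=torch.float64)
beta = torch.rand(H, dtype=torch.float64)
dldy = torch.rand(B, H, dtype=torch.float64)  # stand-in for dl/dy

def f(x):
    return F.layer_norm(x, [H], gamma, beta)

# Full Jacobian dy_bh/dx_pq: an order-4 tensor of shape (B, H, B, H)
J = jacobian(f, x)
print(J.shape)  # torch.Size([3, 5, 3, 5])

# Samples are normalized independently, so entries with b != p are zero
assert torch.allclose(J[0, :, 1, :], torch.zeros(H, H, dtype=torch.float64))

# Contracting with dl/dy_bh gives the order-2 gradient dl/dx_pq ...
dldx_full = torch.einsum('bh,bhpq->pq', dldy, J)
# ... which a vector-Jacobian product computes without forming J
_, dldx_vjp = vjp(f, x, dldy)
torch.testing.assert_close(dldx_full, dldx_vjp)
```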

Gradient of $\sigma$

The derivative of $y_{bh}$ with respect to $\sigma_p$:

$$\begin{aligned} \frac{\partial y_{b h}}{\partial \sigma_{p}} & =\frac{\partial}{\partial \sigma_{p}}\left(\mu_{b h} \sigma_{b}^{-1} \gamma_{h}+\mathbf{1}_{b} \beta_{h}\right) \\ & =-\mu_{b h} \sigma_{b}^{-2} \gamma_{h} \delta_{b p} \end{aligned}$$

And the backpropagated gradient:

$$\begin{aligned} \frac{\partial l}{\partial \sigma_{p}} & =\frac{\partial l}{\partial y_{b h}} \frac{\partial y_{b h}}{\partial \sigma_{p}} \\ & =\frac{\partial l}{\partial y_{b h}}\left(-\mu_{b h} \sigma_{b}^{-2} \gamma_{h} \delta_{b p}\right) \\ & =-\frac{\partial l}{\partial y_{p h}} \mu_{p h} \sigma_{p}^{-2} \gamma_{h} \end{aligned}$$
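Each backpropagated gradient can be checked in isolation. The approach is to rebuild the forward pass from the intermediate functions and ask autograd to keep the gradient of the intermediate tensor with `retain_grad()`. Here is a sketch for $\partial l / \partial \sigma_p$, where a random `dldy` stands in for the upstream gradient.

```python
import torch

torch.manual_seed(0)
B, H, eps = 4, 8, 1e-05
x = torch.rand(B, H, dtype=torch.float64, requires_grad=True)
gamma = torch.rand(H, dtype=torch.float64)
beta = torch.rand(H, dtype=torch.float64)
dldy = torch.rand(B, H, dtype=torch.float64)  # stand-in for dl/dy

# Forward pass through the intermediate functions
m = x.mean(dim=1)
mu = x - m.unsqueeze(1)
v = (mu ** 2).mean(dim=1)
sigma = torch.sqrt(v + eps)
sigma.retain_grad()  # keep autograd's dl/dsigma on this non-leaf tensor
y = mu / sigma.unsqueeze(1) * gamma + beta
y.backward(dldy)

# dl/dsigma_p = -(dl/dy_ph) mu_ph sigma_p^-2 gamma_h, summed over h
mu_d, sigma_d = mu.detach(), sigma.detach()
dldsigma = -torch.einsum('ph,ph,h->p', dldy, mu_d, gamma) / sigma_d ** 2
torch.testing.assert_close(dldsigma, sigma.grad)
```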

Gradient of $v$

The derivative of $\sigma_b$ with respect to $v_p$:

$$\begin{aligned} \frac{\partial \sigma_{b}}{\partial v_{p}} & =\frac{\partial}{\partial v_{p}}\left[\left(v_{b}+\epsilon\right)^{0.5}\right] \\ & =\frac{1}{2}\left(v_{b}+\epsilon\right)^{-0.5} \delta_{b p} \\ & =\frac{\delta_{b p}}{2 \sigma_{b}} \end{aligned}$$

The backpropagated gradient:

$$\begin{aligned} \frac{\partial l}{\partial v_{p}} & =\frac{\partial l}{\partial \sigma_{b}} \frac{\partial \sigma_{b}}{\partial v_{p}} \\ & =\frac{\partial l}{\partial \sigma_{b}} \frac{\delta_{b p}}{2 \sigma_{b}} \\ & =\frac{\partial l}{\partial \sigma_{p}} \frac{1}{2 \sigma_{p}} \end{aligned}$$

Substituting in $\partial l / \partial \sigma_p$ from the previous step:

$$\begin{aligned} \frac{\partial l}{\partial v_{p}} & =\frac{\partial l}{\partial \sigma_{p}} \frac{1}{2 \sigma_{p}} \\ & =\left(-\frac{\partial l}{\partial y_{p h}} \mu_{p h} \sigma_{p}^{-2} \gamma_{h}\right) \frac{1}{2 \sigma_{p}} \\ & =-\frac{\partial l}{\partial y_{p h}} \frac{\mu_{p h} \gamma_{h}}{2 \sigma_{p}^{3}} \end{aligned}$$
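Here is a sketch that checks the substituted expression for $\partial l / \partial v_p$ against autograd. It rebuilds the forward pass from the intermediate functions and uses `retain_grad()` to keep the gradient of $v$. The random `dldy` is a stand-in for the upstream gradient.

```python
import torch

torch.manual_seed(0)
B, H, eps = 4, 8, 1e-05
x = torch.rand(B, H, dtype=torch.float64, requires_grad=True)
gamma = torch.rand(H, dtype=torch.float64)
beta = torch.rand(H, dtype=torch.float64)
dldy = torch.rand(B, H, dtype=torch.float64)  # stand-in for dl/dy

# Forward pass through the intermediate functions
m = x.mean(dim=1)
mu = x - m.unsqueeze(1)
v = (mu ** 2).mean(dim=1)
v.retain_grad()  # keep autograd's dl/dv on this non-leaf tensor
sigma = torch.sqrt(v + eps)
y = mu / sigma.unsqueeze(1) * gamma + beta
y.backward(dldy)

# dl/dv_p = -(dl/dy_ph) mu_ph gamma_h / (2 sigma_p^3), summed over h
mu_d, sigma_d = mu.detach(), sigma.detach()
dldv = -torch.einsum('ph,ph,h->p', dldy, mu_d, gamma) / (2 * sigma_d ** 3)
torch.testing.assert_close(dldv, v.grad)
```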

Gradient of $\mu$

The function $\mu_{bh}$ is consumed by two functions, $v_b$ and $y_{bh}$, so we need to differentiate both of them with respect to $\mu$. First, the derivative of $v_b$ with respect to $\mu_{pq}$:

$$\begin{aligned} \frac{\partial v_{b}}{\partial \mu_{p q}} & =\frac{\partial}{\partial \mu_{p q}}\left(\frac{1}{H} \mathbf{1}_{h} \mu_{b h}^{2}\right) \\ & =\frac{2}{H} \mathbf{1}_{h} \mu_{b h} \delta_{b p} \delta_{h q} \\ & =\frac{2}{H} \mu_{b q} \delta_{b p} \end{aligned}$$

Then the derivative of $y_{bh}$ with respect to $\mu_{pq}$:

$$\begin{aligned} \frac{\partial y_{b h}}{\partial \mu_{p q}} & =\frac{\partial}{\partial \mu_{p q}}\left(\frac{\mu_{b h}}{\sigma_{b}} \gamma_{h}+\mathbf{1}_{b} \beta_{h}\right) \\ & =\frac{\gamma_{h}}{\sigma_{b}} \delta_{b p} \delta_{h q} \end{aligned}$$

When applying the chain rule to obtain the backpropagated gradient, we need to include contributions from both functions:

$$\begin{aligned} \frac{\partial l}{\partial \mu_{p q}} & =\frac{\partial l}{\partial v_{b}} \frac{\partial v_{b}}{\partial \mu_{p q}}+\frac{\partial l}{\partial y_{b h}} \frac{\partial y_{b h}}{\partial \mu_{p q}} \\ & =\frac{\partial l}{\partial v_{b}}\left(\frac{2}{H} \mu_{b q} \delta_{b p}\right)+\frac{\partial l}{\partial y_{b h}}\left(\frac{\gamma_{h}}{\sigma_{b}} \delta_{b p} \delta_{h q}\right) \\ & =\frac{\partial l}{\partial v_{p}} \frac{2 \mu_{p q}}{H}+\frac{\partial l}{\partial y_{p q}} \frac{\gamma_{q}}{\sigma_{p}} \end{aligned}$$
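To see why both contributions are needed, this sketch compares the two terms above with autograd's gradient for $\mu$, which is kept with `retain_grad()`. It also shows that the path through $y_{bh}$ alone does not match. The random `dldy` is a stand-in for the upstream gradient.

```python
import torch

torch.manual_seed(0)
B, H, eps = 4, 8, 1e-05
x = torch.rand(B, H, dtype=torch.float64, requires_grad=True)
gamma = torch.rand(H, dtype=torch.float64)
beta = torch.rand(H, dtype=torch.float64)
dldy = torch.rand(B, H, dtype=torch.float64)  # stand-in for dl/dy

# Forward pass through the intermediate functions
m = x.mean(dim=1)
mu = x - m.unsqueeze(1)
mu.retain_grad()  # autograd sums the gradients from both consumers of mu
v = (mu ** 2).mean(dim=1)
sigma = torch.sqrt(v + eps)
y = mu / sigma.unsqueeze(1) * gamma + beta
y.backward(dldy)

mu_d, sigma_d = mu.detach(), sigma.detach()
dldv = -torch.einsum('ph,ph,h->p', dldy, mu_d, gamma) / (2 * sigma_d ** 3)

# Contribution through v_b and contribution through y_bh
via_v = dldv.unsqueeze(1) * 2 * mu_d / H
via_y = dldy * gamma / sigma_d.unsqueeze(1)

torch.testing.assert_close(via_v + via_y, mu.grad)
# The path through y_bh on its own is not enough
assert not torch.allclose(via_y, mu.grad)
```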

Gradient of $m$

The derivative of $\mu_{bh}$ with respect to $m_p$:

$$\begin{aligned} \frac{\partial \mu_{b h}}{\partial m_{p}} & =\frac{\partial}{\partial m_{p}}\left(x_{b h}-\mathbf{1}_{h} m_{b}\right) \\ & =-\mathbf{1}_{h} \delta_{b p} \end{aligned}$$

And the backpropagated gradient:

$$\begin{aligned} \frac{\partial l}{\partial m_{p}} & =\frac{\partial l}{\partial \mu_{b h}} \frac{\partial \mu_{b h}}{\partial m_{p}} \\ & =\frac{\partial l}{\partial \mu_{b h}}\left(-\mathbf{1}_{h} \delta_{b p}\right) \\ & =-\frac{\partial l}{\partial \mu_{p h}} \mathbf{1}_{h} \end{aligned}$$

Substituting $\partial l / \partial \mu_{ph}$ from the previous step:

$$\begin{aligned} \frac{\partial l}{\partial m_{p}} & =-\mathbf{1}_{h}\left(\frac{\partial l}{\partial v_{p}} \frac{2 \mu_{p h}}{H}+\frac{\partial l}{\partial y_{p h}} \frac{\gamma_{h}}{\sigma_{p}}\right) \\ & =-\frac{\partial l}{\partial v_{p}} \frac{2}{H}\left(\mathbf{1}_{h} \mu_{p h}\right)-\frac{\partial l}{\partial y_{p h}} \frac{\gamma_{h}}{\sigma_{p}} \end{aligned}$$

The first term contains the sum $\mathbf{1}_h \mu_{ph}$. It equals zero because the deviations of a sample from its own mean sum to zero:

$$\begin{aligned} \mathbf{1}_{h} \mu_{p h} & =\mathbf{1}_{h}\left(x_{p h}-\mathbf{1}_{h} m_{p}\right) \\ & =\mathbf{1}_{h} x_{p h}-\mathbf{1}_{h} \mathbf{1}_{h} m_{p} \\ & =H m_{p}-H m_{p} \\ & =0 \end{aligned}$$

The expression therefore simplifies to:

$$\frac{\partial l}{\partial m_{p}}=-\frac{\partial l}{\partial y_{p h}} \frac{\gamma_{h}}{\sigma_{p}}$$
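Here is a sketch that checks both the vanishing sum $\mathbf{1}_h \mu_{ph}$ and the simplified $\partial l / \partial m_p$ against autograd. It rebuilds the forward pass from the intermediate functions, and the random `dldy` is a stand-in for the upstream gradient.

```python
import torch

torch.manual_seed(0)
B, H, eps = 4, 8, 1e-05
x = torch.rand(B, H, dtype=torch.float64, requires_grad=True)
gamma = torch.rand(H, dtype=torch.float64)
beta = torch.rand(H, dtype=torch.float64)
dldy = torch.rand(B, H, dtype=torch.float64)  # stand-in for dl/dy

# Forward pass through the intermediate functions
m = x.mean(dim=1)
m.retain_grad()  # keep autograd's dl/dm on this non-leaf tensor
mu = x - m.unsqueeze(1)
v = (mu ** 2).mean(dim=1)
sigma = torch.sqrt(v + eps)
y = mu / sigma.unsqueeze(1) * gamma + beta
y.backward(dldy)

mu_d, sigma_d = mu.detach(), sigma.detach()
# 1_h mu_ph: deviations from each sample's mean sum to zero (up to rounding)
torch.testing.assert_close(mu_d.sum(dim=1), torch.zeros(B, dtype=torch.float64))

# dl/dm_p = -(dl/dy_ph) gamma_h / sigma_p, summed over h
dldm = -torch.einsum('ph,h->p', dldy, gamma) / sigma_d
torch.testing.assert_close(dldm, m.grad)
```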

Gradient of $x$

Finally, we move on to $x_{bh}$. Two functions, $m_b$ and $\mu_{bh}$, consume $x_{bh}$, so we need to consider both. First, the derivative of $m_b$ with respect to $x_{pq}$:

$$\begin{aligned} \frac{\partial m_{b}}{\partial x_{p q}} & =\frac{\partial}{\partial x_{p q}}\left(\frac{1}{H} \mathbf{1}_{h} x_{b h}\right) \\ & =\frac{\mathbf{1}_{q}}{H} \delta_{b p} \end{aligned}$$

Next, the derivative of $\mu_{bh}$ with respect to $x_{pq}$. Here $m_b$ is treated as a separate input, because its dependence on $x$ is accounted for through the first function:

$$\begin{aligned} \frac{\partial \mu_{b h}}{\partial x_{p q}} & =\frac{\partial}{\partial x_{p q}}\left(x_{b h}-\mathbf{1}_{h} m_{b}\right) \\ & =\delta_{b p} \delta_{h q} \end{aligned}$$

We then use the chain rule to obtain the backpropagated gradient, combining the contributions from both functions:

$$\begin{aligned} \frac{\partial l}{\partial x_{p q}} & =\frac{\partial l}{\partial m_{b}} \frac{\partial m_{b}}{\partial x_{p q}}+\frac{\partial l}{\partial \mu_{b h}} \frac{\partial \mu_{b h}}{\partial x_{p q}} \\ & =\left(-\frac{\partial l}{\partial y_{b h}} \frac{\gamma_{h}}{\sigma_{b}}\right)\left(\frac{\mathbf{1}_{q}}{H} \delta_{b p}\right)+\left(\frac{\partial l}{\partial v_{b}} \frac{2 \mu_{b h}}{H}+\frac{\partial l}{\partial y_{b h}} \frac{\gamma_{h}}{\sigma_{b}}\right)\left(\delta_{b p} \delta_{h q}\right) \\ & =-\frac{\mathbf{1}_{q}}{H} \frac{\partial l}{\partial y_{p h}} \frac{\gamma_{h}}{\sigma_{p}}+\frac{2}{H} \frac{\partial l}{\partial v_{p}} \mu_{p q}+\frac{\partial l}{\partial y_{p q}} \frac{\gamma_{q}}{\sigma_{p}} \end{aligned}$$

The goal is an expression for $\partial l / \partial x_{pq}$ in terms of $\partial l / \partial y_{pq}$. To get there, we substitute $\partial l / \partial v_p$ using the previously derived expression and rearrange the terms to obtain the final result:

$$\begin{aligned} \frac{\partial l}{\partial x_{p q}} & =-\frac{\mathbf{1}_{q}}{H} \frac{\partial l}{\partial y_{p h}} \frac{\gamma_{h}}{\sigma_{p}}+\frac{2}{H}\left(-\frac{\partial l}{\partial y_{p h}} \frac{\mu_{p h} \gamma_{h}}{2 \sigma_{p}^{3}}\right) \mu_{p q}+\frac{\partial l}{\partial y_{p q}} \frac{\gamma_{q}}{\sigma_{p}} \\ & =\frac{\partial l}{\partial y_{p q}} \frac{\gamma_{q}}{\sigma_{p}}-\frac{\partial l}{\partial y_{p h}} \frac{\gamma_{h}}{H}\left(\frac{\mathbf{1}_{q}}{\sigma_{p}}+\frac{\mu_{p h} \mu_{p q}}{\sigma_{p}^{3}}\right) \end{aligned}$$
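The final expression can also be regrouped into a form that is convenient to vectorize. Write $\hat{x}_{ph} = \mu_{ph} / \sigma_p$ and $g_{ph} = \frac{\partial l}{\partial y_{ph}} \gamma_h$, and use $\mu_{ph}\mu_{pq}/\sigma_p^3 = \hat{x}_{ph}\hat{x}_{pq}/\sigma_p$. The result is then equivalent to:

$$\frac{\partial l}{\partial x_{p q}}=\frac{1}{\sigma_{p}}\left(g_{p q}-\frac{\mathbf{1}_{q}}{H} g_{p h}-\frac{\hat{x}_{p q}}{H} g_{p h} \hat{x}_{p h}\right)$$

The sketch below implements this regrouped form and compares it with autograd. The helper name `layer_norm_grad_x` is hypothetical, and `dldy` is a random stand-in for the upstream gradient.

```python
import torch
import torch.nn.functional as F

def layer_norm_grad_x(dldy, x, gamma, eps=1e-05):
    """Hypothetical helper: dl/dx for a (B, H) input via the regrouped formula."""
    mu = x - x.mean(dim=1, keepdim=True)                           # mu_ph
    sigma = torch.sqrt((mu ** 2).mean(dim=1, keepdim=True) + eps)  # sigma_p, shape (B, 1)
    xhat = mu / sigma                                              # xhat_ph
    g = dldy * gamma                                               # g_ph
    return (g
            - g.mean(dim=1, keepdim=True)                  # (1/H) 1_q g_ph
            - xhat * (g * xhat).mean(dim=1, keepdim=True)  # (1/H) xhat_pq g_ph xhat_ph
            ) / sigma

torch.manual_seed(0)
B, H = 4, 8
x = torch.rand(B, H, dtype=torch.float64, requires_grad=True)
gamma = torch.rand(H, dtype=torch.float64)
beta = torch.rand(H, dtype=torch.float64)
dldy = torch.rand(B, H, dtype=torch.float64)  # stand-in for dl/dy

F.layer_norm(x, [H], gamma, beta).backward(dldy)
dldx = layer_norm_grad_x(dldy, x.detach(), gamma)
torch.testing.assert_close(dldx, x.grad)
```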

Conclusion

Bringing the results together:

$$\begin{gathered} \frac{\partial l}{\partial \gamma_{q}}=\frac{\partial l}{\partial y_{b q}} \frac{\mu_{b q}}{\sigma_{b}} \\ \frac{\partial l}{\partial \beta_{q}}=\frac{\partial l}{\partial y_{b q}} \mathbf{1}_{b} \\ \frac{\partial l}{\partial x_{p q}}=\frac{\partial l}{\partial y_{p q}} \frac{\gamma_{q}}{\sigma_{p}}-\frac{\partial l}{\partial y_{p h}} \frac{\gamma_{h}}{H}\left(\frac{\mathbf{1}_{q}}{\sigma_{p}}+\frac{\mu_{p h} \mu_{p q}}{\sigma_{p}^{3}}\right) \end{gathered}$$

PyTorch Implementation

We can numerically check the above result by implementing the equations in PyTorch and comparing the result to the built-in PyTorch function:

```python
import torch

# Create random inputs
torch.manual_seed(42)
B, H = 128, 256
eps = 1e-05
x = torch.rand((B, H), dtype=torch.float32, requires_grad=True)
gamma = torch.rand(H, dtype=torch.float32, requires_grad=True)
beta = torch.rand(H, dtype=torch.float32, requires_grad=True)
dldy = torch.rand((B, H), dtype=torch.float32)  # stand-in for the upstream gradient dl/dy

# Run forward and backward pass using built-in function
y = torch.nn.functional.layer_norm(x, [H], gamma, beta, eps=eps)
y.backward(dldy)

# Calculate gradients using above equations (no autograd tracking needed)
with torch.no_grad():
    m = x.mean(axis=1)             # m_b
    mu = x - m.unsqueeze(1)        # mu_bh
    v = torch.mean(mu**2, axis=1)  # v_b
    sigma = torch.sqrt(v + eps)    # sigma_b

    dldgamma = torch.einsum('bq,bq,b->q', [dldy, mu, 1/sigma])
    dldbeta = dldy.sum(axis=0)

    dldx = (
        # dl/dy_pq gamma_q / sigma_p
        dldy*gamma.unsqueeze(0) / sigma.unsqueeze(1)
        # -(1/H) dl/dy_ph gamma_h 1_q / sigma_p
        - 1/H * torch.einsum('ph,h,p->p', [dldy, gamma, 1/sigma]).unsqueeze(1)
        # -(1/H) dl/dy_ph gamma_h mu_ph mu_pq / sigma_p^3
        - 1/H * mu * torch.einsum('ph,h,ph,p->p', [dldy, gamma, mu, sigma**(-3)]).unsqueeze(1)
    )

# Compare against PyTorch
torch.testing.assert_close(dldgamma, gamma.grad)
torch.testing.assert_close(dldbeta, beta.grad)
torch.testing.assert_close(dldx, x.grad)
```

We can also implement the layer as a custom `torch.autograd.Function`:

```python
import torch

class LayerNormManual(torch.autograd.Function):
    @staticmethod
    def forward(ctx, x: torch.Tensor, gamma: torch.Tensor, beta: torch.Tensor):
        eps = 1e-05
        assert x.dim() == 2
        m = x.mean(axis=1)             # m_b
        mu = x - m.unsqueeze(1)        # mu_bh
        v = torch.mean(mu**2, axis=1)  # v_b
        sigma = torch.sqrt(v + eps)    # sigma_b

        y = (mu/sigma.unsqueeze(1))*gamma.unsqueeze(0) + beta.unsqueeze(0)

        # Save everything backward needs, including gamma; otherwise backward
        # would silently read a global variable named gamma instead
        ctx.save_for_backward(mu, sigma, gamma)

        return y

    @staticmethod
    def backward(ctx, dldy):
        mu, sigma, gamma = ctx.saved_tensors
        H = mu.shape[1]

        dldgamma = torch.einsum('bq,bq,b->q', [dldy, mu, 1/sigma])
        dldbeta = dldy.sum(axis=0)

        dldx = (
            dldy*gamma.unsqueeze(0) / sigma.unsqueeze(1)
            - 1/H * torch.einsum('ph,h,p->p', [dldy, gamma, 1/sigma]).unsqueeze(1)
            - 1/H * mu * torch.einsum('ph,h,ph,p->p', [dldy, gamma, mu, sigma**(-3)]).unsqueeze(1)
        )

        # One gradient per forward input, in the same order
        return dldx, dldgamma, dldbeta
```

PyTorch also provides `torch.autograd.gradcheck`, which estimates the gradient of a function using finite differences and checks whether it matches the gradient computed by the backward function. We can use it to assert that the layer is correct:

```python
torch.manual_seed(42)
B, H = 32, 64
x = torch.rand((B, H), dtype=torch.float64, requires_grad=True)
gamma = torch.rand(H, dtype=torch.float64, requires_grad=True)
beta = torch.rand(H, dtype=torch.float64, requires_grad=True)
# Returns True if the analytical and numerical gradients match, raises otherwise
torch.autograd.gradcheck(LayerNormManual.apply, (x, gamma, beta), eps=1e-6, atol=0.1, rtol=0.1)
```

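Notice that the input tensors now use float64. The check fails with float32, most likely because finite-difference estimates with such a small step are too imprecise in single precision. The `gradcheck` documentation likewise notes that its defaults are designed for double-precision inputs.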
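To illustrate the precision issue, here is a sketch that is independent of layer normalization. It estimates the derivative of $x^2$ at $x = 1$ with a central difference, using the same step ($10^{-6}$) as the `eps` passed to `gradcheck` above. In float32, $1 \pm 10^{-6}$ is only a handful of representable steps away from 1, so most of the significant digits cancel in the subtraction.

```python
import torch

def central_difference(f, x, h=1e-6):
    # Symmetric finite-difference estimate of df/dx with step h
    return (f(x + h) - f(x - h)) / (2 * h)

def square(x):
    return x ** 2  # exact derivative at x = 1 is 2

err32 = abs(central_difference(square, torch.tensor(1.0, dtype=torch.float32)).item() - 2.0)
err64 = abs(central_difference(square, torch.tensor(1.0, dtype=torch.float64)).item() - 2.0)
print(f"float32 error: {err32:.1e}, float64 error: {err64:.1e}")
```

In float32 the estimate is off by a few hundredths, while the float64 error is many orders of magnitude smaller.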

Next

Further examples of calculating gradients using tensor calculus and index notation can be found on the intro page.

This post is copyrighted by Josh Levy-kramer.