
Backpropagation and Multivariable Calculus

A quick intro on backpropagation and multivariable calculus for deep learning


This post offers a concise overview of multivariable calculus and backpropagation to help you derive the gradients used in backpropagation in deep learning. This is a whistle-stop tour. For a more in-depth description, refer to the excellent article The Matrix Calculus You Need For Deep Learning. The next post demonstrates how to combine these techniques with tensor calculus to derive gradients for any tensor function.

Backpropagation

In deep learning, when training using batches of data, the model’s weights are adjusted based on the calculated loss $l$. To update the weights, we first need to calculate the gradient of the loss $l$ with respect to all weights of the model, which tells us how to adjust the weights to reduce the loss for the current batch of data. Auto-differentiation, or backpropagation, is the most popular algorithm for calculating such gradients.

Taking PyTorch as an example, one would execute the backward() function on the loss to determine the gradients relative to variables. PyTorch accomplishes this by tracking each operation (forward function) that contributes to the loss calculation. Every forward function has a corresponding backward function, and these backward functions are run in reverse order to the forward functions to compute the gradients.

Frameworks like PyTorch have predefined backward functions that typically suffice for most users. However, for those seeking a deeper understanding or needing to implement a custom operation, it’s necessary to understand how to define and manipulate these backward functions.

In general, for any function $f$ that takes $M$ inputs ($x$’s) and produces $N$ outputs ($y$’s):

$$f: x_1, x_2, \ldots, x_M \mapsto y_1, y_2, \ldots, y_N$$

Then there is an associated “backward” function $g$, which takes $M+N$ inputs and produces $M$ outputs:

$$g: x_1, x_2, \ldots, x_M, \frac{\partial l}{\partial y_1}, \frac{\partial l}{\partial y_2}, \ldots, \frac{\partial l}{\partial y_N} \mapsto \frac{\partial l}{\partial x_1}, \frac{\partial l}{\partial x_2}, \ldots, \frac{\partial l}{\partial x_M}$$

The inputs are the inputs of $f$ and the gradient of the loss $l$ with respect to each output of $f$. The outputs are the gradient of the loss $l$ with respect to each input of $f$.

The backward function $g$, also known as the vector-Jacobian product (VJP), calculates the backpropagated gradients. More on this later. For example, in PyTorch, you can define an auto-differentiable operation by providing a forward and backward function pair like so (by implementing torch.autograd.Function):

import torch

class MatrixMultiplication(torch.autograd.Function):
    @staticmethod
    def forward(ctx, x: torch.Tensor, w: torch.Tensor) -> torch.Tensor:
        ctx.save_for_backward(x, w)
        y = x @ w
        return y

    @staticmethod
    def backward(ctx, dldy: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]:
        x, w = ctx.saved_tensors
        dldx = dldy @ w.T
        dldw = x.T @ dldy
        return dldx, dldw

Don’t worry about how the gradients are calculated - we will get onto that later. Importantly, each input of the forward function has a corresponding gradient output in the backward function, and each output of the forward function has a corresponding gradient input in the backward function. The shapes of the tensors should also correspond with each other, e.g. x.shape == dldx.shape, w.shape == dldw.shape and y.shape == dldy.shape. The forward inputs are also captured using the ctx.save_for_backward function and retrieved in the backward function via the ctx.saved_tensors attribute.
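To build some confidence in those backward formulas before we derive them, here is a small numerical sanity check of dldx = dldy @ w.T and dldw = x.T @ dldy, written in plain Python (lists of lists rather than tensors, so it runs without PyTorch). All matrix values and the upstream gradient are arbitrary made-up numbers.

```python
# Numerical sanity check of the matmul backward formulas:
# dldx = dldy @ w^T and dldw = x^T @ dldy.

def matmul(a, b):
    """Multiply two matrices represented as lists of lists."""
    return [[sum(a[i][k] * b[k][j] for k in range(len(b)))
             for j in range(len(b[0]))] for i in range(len(a))]

def transpose(m):
    return [list(row) for row in zip(*m)]

x = [[1.0, 2.0], [3.0, 4.0]]      # input
w = [[0.5, -1.0], [2.0, 0.25]]    # weights
dldy = [[1.0, 0.5], [-0.5, 2.0]]  # pretend upstream gradient dl/dy

# A scalar "loss" whose gradient with respect to y is exactly dldy:
# l = element-wise sum of dldy * (x @ w).
def loss(x_, w_):
    y = matmul(x_, w_)
    return sum(dldy[i][j] * y[i][j] for i in range(2) for j in range(2))

# Analytic gradients, as computed by the backward function.
dldx = matmul(dldy, transpose(w))
dldw = matmul(transpose(x), dldy)

# Finite-difference check of one entry of dldx.
eps = 1e-6
x_pert = [row[:] for row in x]
x_pert[0][1] += eps
numeric = (loss(x_pert, w) - loss(x, w)) / eps
assert abs(numeric - dldx[0][1]) < 1e-4  # dldx[0][1] == 2.125
```

The same finite-difference trick works for any entry of dldx or dldw, and is essentially what PyTorch’s gradcheck utility automates.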

Multivariable Calculus

Most functions in deep learning have multiple inputs and outputs. A function that has multiple scalar inputs $f(u, v)$ is equivalent to a function that has a vector input $f(\hat{x})$, where $u$ and $v$ are the components of $\hat{x}$. This is known as a multivariable function, and it is said that “$f$ takes a vector” to indicate it has a vector input.

If a function has multiple scalar outputs, these can also be collected into a vector, e.g. $f: x \mapsto y_1, y_2$ is equivalent to a function with a vector output $f: x \mapsto \hat{y}$ and is said to be multivalued or vector-valued. Another way of looking at it is that a multivalued function is equivalent to multiple functions stacked into a vector $\hat{y} = (y_1, y_2)^T$.

This article is concerned with multivariable multivalued functions, but as that’s such a mouthful, we will call them multivariable functions.

If we want to calculate the derivative of a multivariable function, we need to consider the gradient of each output with respect to each input. If a function $f$ has $M$ inputs ($x$) and $N$ outputs ($y$), there are $M \times N$ gradients associated with the function, which can be denoted $\partial y_i / \partial x_j$. Importantly, $i$ and $j$ should enumerate over all scalar outputs and inputs, respectively. It doesn’t matter if the scalars are arranged into a vector, matrix or tensor - the process remains the same.

A useful tool in calculus is the chain rule. Let $y$ be the output of a function which takes $M$ scalar inputs $u_1, \ldots, u_M$, all of which depend on $x$. The derivative of $y$ with respect to $x$ can then be shown to be:

$$\frac{\partial y(u_1, \ldots, u_M)}{\partial x} = \sum_{i=1}^{M} \frac{\partial y}{\partial u_i} \frac{\partial u_i}{\partial x}$$

In other words, the derivative of $y$ with respect to $x$ is the weighted sum of all of $x$’s contributions to the change in $y$. It is assumed that $y$ is not directly a function of $x$; it is only a function of $x$ through the intermediate functions. For example, $y = u_1(x) u_2(x) + x^2$ isn’t valid - you need to substitute the last term so that $y = u_1(x) u_2(x) + u_3(x)$ with $u_3 = x^2$.
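To make this concrete, the substituted example can be checked numerically. The sketch below picks arbitrary intermediate functions $u_1 = x^2$, $u_2 = 3x$, $u_3 = x^2$ (so $y = 3x^3 + x^2$) and confirms that the chain rule agrees with differentiating $y$ directly:

```python
# Chain rule check for y = u1(x)*u2(x) + u3(x)
# with u1 = x^2, u2 = 3x, u3 = x^2 (arbitrary choices).

x = 2.0
u1, u2, u3 = x**2, 3 * x, x**2

# Partial derivatives of y with respect to each intermediate.
dy_du1, dy_du2, dy_du3 = u2, u1, 1.0
# Derivatives of each intermediate with respect to x.
du1_dx, du2_dx, du3_dx = 2 * x, 3.0, 2 * x

# Chain rule: sum the contribution of every path from x to y.
dy_dx = dy_du1 * du1_dx + dy_du2 * du2_dx + dy_du3 * du3_dx

# Compare with differentiating y = 3x^3 + x^2 directly: dy/dx = 9x^2 + 2x.
assert dy_dx == 9 * x**2 + 2 * x
```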

The $u$ variables could be part of data structures such as a vector, matrix or tensor. The above chain rule considers the constituent scalars, and the data structures do not affect the rule. For example, given the function $y(\hat{u})$ with $\hat{u} = (u_1, \ldots, u_M)^T$, the chain rule above is still valid.

If the function has multiple data structures as inputs, we must consider all inputs in the chain rule. For example, given the function $y(\hat{a}, \hat{b})$, the chain rule would be ($i$ and $j$ iterate over the full lengths of the respective vectors):

$$\frac{\partial y(\hat{a}, \hat{b})}{\partial x} = \sum_i \frac{\partial y}{\partial a_i} \frac{\partial a_i}{\partial x} + \sum_j \frac{\partial y}{\partial b_j} \frac{\partial b_j}{\partial x}$$
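The same kind of numerical check works here. The sketch below assumes arbitrary intermediates $\hat{a} = (x, 2x)^T$ and $\hat{b} = (x^2)$ feeding $y(\hat{a}, \hat{b}) = a_1 a_2 + b_1$, so $y = 3x^2$ and the two sums should total $6x$:

```python
# Two-input chain rule check: a = (x, 2x), b = (x^2,),
# y(a, b) = a1*a2 + b1 = 3x^2 (arbitrary example functions).

x = 3.0
a = [x, 2 * x]
b = [x * x]

dy_da = [a[1], a[0]]   # dy/da1 = a2, dy/da2 = a1
da_dx = [1.0, 2.0]
dy_db = [1.0]
db_dx = [2 * x]

# Sum contributions from both input data structures.
dy_dx = (sum(g * d for g, d in zip(dy_da, da_dx))
         + sum(g * d for g, d in zip(dy_db, db_dx)))

# y = 3x^2 directly, so dy/dx = 6x.
assert dy_dx == 6 * x
```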

Throughout this article, we use the partial derivative notation $\partial y / \partial x$ instead of the “normal” derivative notation $\mathrm{d}y / \mathrm{d}x$. This is because the functions are always assumed to be multivariable, and we don’t know the relationship between those variables without further context. Partial derivatives can be reinterpreted as normal derivatives with further context.

If the chain rule above can be applied to functions that use any data structure, why do we need vector, matrix or tensor calculus? Because using matrices and tensors can greatly simplify the algebra.

In vector calculus, when a function takes a vector $\hat{x}$ of length $M$ and is vector-valued with an output $\hat{y}$ of length $N$, the derivatives $\partial y_i / \partial x_j$ can be arranged into an $N$ by $M$ matrix known as the Jacobian matrix or just Jacobian:

$$\frac{\partial \hat{y}}{\partial \hat{x}} = \begin{pmatrix} \frac{\partial y_1}{\partial x_1} & \cdots & \frac{\partial y_1}{\partial x_M} \\ \vdots & & \vdots \\ \frac{\partial y_N}{\partial x_1} & \cdots & \frac{\partial y_N}{\partial x_M} \end{pmatrix}$$

Note that above, we are using what’s known as numerator layout, whereby the inputs are enumerated horizontally and the outputs vertically. Confusingly, other authors might use denominator layout or mixed-layout conventions, which enumerate them differently. Some authors also add a transpose to the vector in the denominator, e.g. $\frac{\partial \hat{y}}{\partial \hat{x}^T}$. This is purely notational and indicates that the $x$ inputs go horizontally in the Jacobian. Essentially, if you are using a consistent layout, then $\frac{\partial \hat{y}}{\partial \hat{x}} = \frac{\partial \hat{y}}{\partial \hat{x}^T}$.

The Jacobian provides a convenient data structure to group the gradients. It simplifies the application of the chain rule by utilizing matrix multiplication, which aggregates all contributions of $\hat{x}$ to each output of $\hat{y}$ through a weighted sum:

$$\frac{\partial \hat{y}(\hat{u})}{\partial \hat{x}} = \frac{\partial \hat{y}}{\partial \hat{u}} \frac{\partial \hat{u}}{\partial \hat{x}}$$
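One way to see why this works is with linear maps, where the Jacobians are simply the matrices themselves. The sketch below assumes $\hat{u} = A\hat{x}$ and $\hat{y} = B\hat{u}$ with arbitrary $2 \times 2$ matrices, so the Jacobian chain rule reduces to the matrix product $BA$, and each entry matches the scalar chain rule:

```python
# Jacobian chain rule as a matrix product: u = A x, y = B u,
# hence dy/dx = (dy/du)(du/dx) = B @ A (numerator layout).

def matmul(a, b):
    return [[sum(a[i][k] * b[k][j] for k in range(len(b)))
             for j in range(len(b[0]))] for i in range(len(a))]

A = [[1.0, 2.0], [0.0, 1.0]]   # Jacobian du/dx
B = [[3.0, 0.0], [1.0, -1.0]]  # Jacobian dy/du

dy_dx = matmul(B, A)

# Check one entry with the scalar chain rule:
# dy1/dx2 = sum_k (dy1/du_k)(du_k/dx2) = 3*2 + 0*1 = 6.
assert dy_dx[0][1] == 6.0
```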


It seems logical to discuss matrix functions next; however, since the derivative of one matrix with respect to another is a tensor, we must first introduce tensors to proceed further. Following the examples, the next section explores the application of multivariable calculus to backpropagation, and then Part 2 introduces tensors and tensor calculus.

Example: Sum

Let’s consider a “sum” operator which adds together two variables:

$$y(x_1, x_2) = x_1 + x_2$$

This is the same as:

$$\begin{aligned} \hat{x} &= \binom{x_1}{x_2} \\ y(\hat{x}) &= x_1 + x_2 \end{aligned}$$

To calculate the gradient of $y$ with respect to $\hat{x}$:

$$\frac{\partial y}{\partial \hat{x}} = \begin{pmatrix} \frac{\partial y}{\partial x_1} & \frac{\partial y}{\partial x_2} \end{pmatrix} = \begin{pmatrix} 1 & 1 \end{pmatrix}$$

Example: Broadcast

For another example, let’s consider two different functions that copy a value $x$:

$$\begin{aligned} y_1(x) &= x \\ y_2(x) &= x \end{aligned}$$

We can then collect these functions into an array:

$$\hat{y}(x) = \binom{x}{x}$$

To calculate the gradient of $\hat{y}$ with respect to $x$:

$$\frac{\partial \hat{y}}{\partial x} = \binom{\frac{\partial y_1}{\partial x}}{\frac{\partial y_2}{\partial x}} = \binom{1}{1}$$
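Notice that these two Jacobians are each other’s transpose, and in backpropagation the pairing shows up directly: the VJP multiplies the incoming gradient by the Jacobian, so the backward of a sum broadcasts the gradient, and the backward of a broadcast sums it. A minimal plain-Python sketch with made-up gradient values:

```python
# Sum: y = x1 + x2. Incoming dl/dy is a scalar; multiplying by the
# Jacobian (1 1) broadcasts it to both inputs.
dldy = 3.0
dldx = (dldy * 1.0, dldy * 1.0)
assert dldx == (3.0, 3.0)

# Broadcast: y = (x, x). Incoming dl/dy is a vector; multiplying by the
# Jacobian (1 1)^T sums its components into a scalar.
dldy_vec = (2.0, 5.0)
dldx_scalar = dldy_vec[0] * 1.0 + dldy_vec[1] * 1.0
assert dldx_scalar == 7.0
```

This sum/broadcast duality is why gradients accumulate whenever a value is used in more than one place during the forward pass.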

Application to Backpropagation

Here, we apply multivariable calculus to derive the backward graph. For example, provided with a layer in a model which executes:

$$\begin{aligned} u(\hat{a}) &= \hat{a} \cdot \hat{a} \\ v(\hat{a}, b) &= b|\hat{a}| \\ y(u, v) &= u + v \\ z(u) &= 2u \end{aligned}$$

We can visually represent the dependencies, known as the forward graph:

Forward graph

The forward graph represents many sub-graphs, one for each input and output combination. For example, if we consider the dependency of $y$ on $\hat{a}$, we would obtain the sub-graph:

sub-graph

Different input and output combinations use the same forward functions, which allows us to reuse computation, e.g. we only compute $u$ once as it can be reused when computing $y$ and $z$.

To obtain the backward graph, first, we assume that a downstream scalar function $l$ consumes all outputs - in our case, $l(y, z)$. We don’t need to know the function, just that the outputs are consumed by $l$. So, let’s redraw the forward graph with the single output $l$.

forward-graph-with-loss

We also assume we are provided with backpropagated gradients for those outputs, $\partial l / \partial y$ and $\partial l / \partial z$. Then, we need to apply the chain rule for each function to find the dependencies of the gradients. For example, if we focus on $u$, it is consumed by two functions, $y$ and $z$, both of which are consumed by $l$. So:

$$\frac{\partial l(y(u), z(u))}{\partial u} = \frac{\partial l}{\partial y} \frac{\partial y}{\partial u} + \frac{\partial l}{\partial z} \frac{\partial z}{\partial u}$$

This means that $\partial l / \partial u$ has a dependency on $\partial l / \partial y$ and $\partial l / \partial z$, which is the reverse of the dependency in the forward graph. This pattern holds for all the functions in the graph, and so to obtain the backward graph, we reverse all the dependencies. For example:

backward-graph

In the diagram, we are using a prime to represent the backpropagated gradients, e.g. $y' = \partial l / \partial y$. There is also a dashed line between $l$ and $y'$, as we don’t know the function that maps from $l$ to $y'$; all we need are the values of the backpropagated gradients $y'$ and $z'$.

Again, the backward graph comprises many sub-graphs which map the inputs to every output. The sub-graphs can overlap, leading to computation reuse, which makes backpropagation efficient.
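As a concrete check, we can run the example graph backward numerically. The sketch below assumes made-up values for $\hat{a}$, $b$ and the given backpropagated gradients $y'$ and $z'$, and uses the scalar derivatives $\partial u / \partial a_i = 2a_i$, $\partial v / \partial a_i = b a_i / |\hat{a}|$ and $\partial v / \partial b = |\hat{a}|$:

```python
# Backward pass for the example graph: u = a.a, v = b*|a|,
# y = u + v, z = 2u, with assumed gradients y' and z'.
import math

a = [3.0, 4.0]
b = 2.0
yp, zp = 1.0, 0.5        # y' = dl/dy and z' = dl/dz, assumed given

norm_a = math.sqrt(sum(ai * ai for ai in a))   # |a| = 5

# u feeds both y and z, so its gradient accumulates both paths.
up = yp * 1.0 + zp * 2.0    # u' = y' dy/du + z' dz/du
vp = yp * 1.0               # v' = y' dy/dv
bp = vp * norm_a            # b' = v' dv/db
ap = [up * 2.0 * ai + vp * b * ai / norm_a for ai in a]  # a' per component

assert up == 2.0 and vp == 1.0 and bp == 5.0
```

Note how the gradient for $u$ sums the contributions arriving from $y'$ and $z'$ - exactly the reversed dependencies in the backward graph.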

Next

We have taken a quick tour of multivariable calculus and backpropagation. Next, Part 2 introduces tensors and tensor calculus.

This post is copyrighted by Josh Levy-Kramer.