Matrix Inverse, Deriving the Gradient for the Backward Pass
Obtaining the gradient of the matrix inverse
Here, I will derive the gradients of a matrix inverse used for backpropagation in deep learning models. I will use tensor calculus and index notation - see my article The Tensor Calculus You Need for Deep Learning for more information.
Given $Y = X^{-1}$, we know from the definition of an inverse matrix that $XY = I$ (assuming $X$ is square and the inverse of $X$ exists). We convert this to index notation (the indices $i$, $j$ and $k$ must all range over the same size):
$$x_{ij} y_{jk} = \delta_{ik}$$
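First, we differentiate both sides of $x_{ij} y_{jk} = \delta_{ik}$ with respect to $x_{pq}$ and use the product rule. The right-hand side is constant, so its derivative is zero, and $\partial x_{ij} / \partial x_{pq} = \delta_{ip} \delta_{jq}$:
$$\frac{\partial x_{ij}}{\partial x_{pq}} y_{jk} + x_{ij} \frac{\partial y_{jk}}{\partial x_{pq}} = 0$$
$$x_{ij} \frac{\partial y_{jk}}{\partial x_{pq}} = -\delta_{ip} \delta_{jq} y_{jk} = -\delta_{ip} y_{qk}$$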
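We then multiply both sides by the inverse matrix $y_{mi}$. We must use a new free index $m$ for the first axis and contract on the second axis $i$, so that $y_{mi} x_{ij} = \delta_{mj}$:
$$y_{mi} x_{ij} \frac{\partial y_{jk}}{\partial x_{pq}} = -y_{mi} \delta_{ip} y_{qk}$$
$$\frac{\partial y_{mk}}{\partial x_{pq}} = -y_{mp} y_{qk}$$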
So the gradient of the inverse matrix with respect to the original matrix is an order-4 tensor in which every pair of elements of the inverse is multiplied together (and negated). This is similar to a Kronecker product, but we don’t “flatten” the result into 2 dimensions.
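As a quick numerical sanity check of this order-4 result, the sketch below compares it with the Jacobian that PyTorch's autograd computes for a small matrix. It assumes the index convention $\partial y_{mk} / \partial x_{pq} = -y_{mp} y_{qk}$ (the index names are illustrative), uses float64, and adds a multiple of the identity to keep the random matrix well conditioned. These choices are assumptions made for the example, not part of the derivation.

```python
import torch

torch.manual_seed(0)
K = 4

# Assumption: adding K * I keeps the random matrix well conditioned,
# so the comparison is not dominated by rounding error.
x = torch.rand((K, K), dtype=torch.float64) + K * torch.eye(K, dtype=torch.float64)
y = torch.linalg.inv(x)

# Autograd Jacobian with shape (K, K, K, K),
# indexed as jac[m, k, p, q] = d y_mk / d x_pq
jac = torch.autograd.functional.jacobian(torch.linalg.inv, x)

# Derived result: d y_mk / d x_pq = -y_mp * y_qk
expected = -torch.einsum("mp,qk->mkpq", y, y)

torch.testing.assert_close(jac, expected)
```

The `einsum` output keeps all four indices, which is the "unflattened" Kronecker-like structure described above.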
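Next, to obtain the gradient for backpropagation, we assume $Y$ is an input of a scalar function $l$ and that we are provided with the gradients of $l$ with respect to $Y$. Then we apply the chain rule:
$$\frac{\partial l}{\partial x_{pq}} = \frac{\partial l}{\partial y_{mk}} \frac{\partial y_{mk}}{\partial x_{pq}} = -y_{mp} \frac{\partial l}{\partial y_{mk}} y_{qk}$$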
And we can convert it back to matrix notation:
$$\frac{\partial l}{\partial X} = -Y^T \frac{\partial l}{\partial Y} Y^T$$
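One way to sanity-check the matrix form $\frac{\partial l}{\partial X} = -Y^T \frac{\partial l}{\partial Y} Y^T$ is the $1 \times 1$ case, where $X = [x]$ and $Y = [y]$ with $y = 1/x$. The transposes have no effect and the expression should reduce to the familiar scalar derivative of $1/x$:

$$\frac{\partial l}{\partial x} = -y \, \frac{\partial l}{\partial y} \, y = -\frac{1}{x^2} \frac{\partial l}{\partial y}, \qquad \text{which agrees with} \qquad \frac{dy}{dx} = \frac{d}{dx}\left(\frac{1}{x}\right) = -\frac{1}{x^2}$$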
PyTorch Implementation
We can check the result by comparing the equation above with PyTorch’s built-in autograd output:
import torch
# Create random input
torch.manual_seed(42)
K = 256
x = torch.rand((K, K), dtype=torch.float32, requires_grad=True)
dldy = torch.rand((K, K), dtype=torch.float32)
# Run forward and backward pass using built-in function
y = x.inverse()
y.backward(dldy)
# Calculate gradients using above equations
# Note: I use brackets to specify the order of the matmuls
# to be consistent with how PyTorch calculates it
dldx = -y.T @ (dldy @ y.T)
# Compare with PyTorch
torch.testing.assert_close(dldx, x.grad)
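The same formula can also be packaged as a hand-written backward pass. The following is a minimal sketch, assuming a custom `torch.autograd.Function` (the class name `MatrixInverse` is hypothetical and not part of PyTorch), float64 inputs, and a well-conditioned test matrix. It uses `torch.autograd.gradcheck` to compare the analytical backward against finite differences rather than against PyTorch's own inverse gradient.

```python
import torch


class MatrixInverse(torch.autograd.Function):
    """Hypothetical custom op: matrix inverse with a hand-written backward."""

    @staticmethod
    def forward(ctx, x):
        y = torch.linalg.inv(x)
        # The backward pass only needs the output Y, not the input X
        ctx.save_for_backward(y)
        return y

    @staticmethod
    def backward(ctx, dldy):
        (y,) = ctx.saved_tensors
        # dl/dX = -Y^T (dl/dY) Y^T
        return -y.T @ (dldy @ y.T)


torch.manual_seed(0)
K = 5

# Assumption: float64 and a diagonal shift keep the finite differences accurate
x = torch.rand((K, K), dtype=torch.float64) + K * torch.eye(K, dtype=torch.float64)
x.requires_grad_()

# gradcheck raises an error if the analytical and numerical gradients disagree
ok = torch.autograd.gradcheck(MatrixInverse.apply, (x,))
```

Saving only `y` in `forward` reflects the fact that the derived gradient depends on the inverse alone, so the input does not need to be kept for the backward pass.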
Next
Further examples of calculating gradients using tensor calculus and index notation can be found on the intro page.