Backpropagation is usually proved by using the chain rule on partial derivatives. However, the proof often gets hairy (see an example here). I am not blaming the author; I just find it hard to follow and confusing: there are indices everywhere, we need to keep track of the dimensions of the tensors, remember that partial derivatives store information in the “transpose of the dimensions”, etc.
Instead, we can use differential calculus. Tom Minka’s notes on differential calculus have been circulating for a long time, but looking at them felt daunting: there seem to be so many weird formulas to remember. Only recently did I go through Magnus and Neudecker, especially the free Chapter 18, which is a great intro and summary. The approach started to feel less like rote learning and more usable to me. The learning curve is steeper, but it is surely more powerful and less error-prone once one is used to it. Here I give a very small example to show that you need very little knowledge to work with small neural networks. Please read Chapter 18 for all the details; this post is not self-contained.
M&N motivate the differential calculus approach:
Now we can see the advantage of working with differentials rather than with derivatives. When we have an m × 1 vector y, which is a function of an n × 1 vector of variables x, say y = f(x), then the derivative is an m × n matrix, but the differential dy or df remains an m × 1 vector. […] The practical advantage of working with differentials is therefore that the order does not increase but always stays the same.
In our deep learning case, we have a loss (a scalar) and parameters that are vectors or matrices. We only care about the parameter updates given by the partial derivatives of that loss wrt those vectors or matrices, which are again vectors or matrices. But the intermediate computations in the naive approach involve partial derivatives of matrices wrt matrices :( Here those intermediate computations will be written only as “differentials”.
Suppose we have a one-hidden-layer MLP, and want to compute gradients of the MSE loss with respect to the parameters \(A\) and \(C\):
\[\begin{aligned} X_2 &= φ(X_1 A + b) \\ X_3 &= X_2 C + d \\ l &= ||X_3 - y||^2 \\ \end{aligned}\]
In terms of notation, \(X_1\) is the \(n×d_1\) input matrix, \(A\) a \(d_1×d_2\) weight matrix, \(b\) a bias (broadcast over the \(n\) rows), \(X_2\) the \(n×d_2\) matrix of hidden activations, \(C\) a \(d_2×1\) weight vector, \(d\) a bias, \(X_3\) the \(n×1\) output, and \(y\) the \(n×1\) vector of targets. The nonlinearity \(φ\) is applied elementwise, and \(φ'\) denotes its elementwise derivative.
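To make the shapes concrete, here is a minimal NumPy sketch of this setup, assuming toy sizes \(n=5\), \(d_1=3\), \(d_2=4\), \(φ = \tanh\), a row-vector bias \(b\) broadcast over rows, and a scalar bias \(d\); none of these choices matter for the derivation.

```python
import numpy as np

# A minimal sketch of the setup above. Sizes, seed, phi = tanh, and the
# bias shapes are assumptions made for illustration only.
rng = np.random.default_rng(0)
n, d1, d2 = 5, 3, 4

X1 = rng.normal(size=(n, d1))    # input, n x d1
A  = rng.normal(size=(d1, d2))   # first-layer weights
b  = rng.normal(size=(1, d2))    # first-layer bias, broadcast over the n rows
C  = rng.normal(size=(d2, 1))    # second-layer weights
d  = rng.normal()                # second-layer bias (a scalar here)
y  = rng.normal(size=(n, 1))     # targets, n x 1

X2 = np.tanh(X1 @ A + b)               # hidden activations, n x d2
X3 = X2 @ C + d                        # outputs, n x 1
l  = ((X3 - y).T @ (X3 - y)).item()    # the loss ||X3 - y||^2, a scalar
phi_prime = 1 - X2**2                  # phi'(X1 A + b) for phi = tanh, n x d2
print(l)
```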
Without justification, here are a few basic rules for computing differentials (for a constant \(K\), conformable matrices \(X\) and \(Y\), and an elementwise function \(φ\)):
\[\begin{aligned} dK &= 0 \\ d(X+Y) &= dX + dY \\ d(XY) &= (dX)Y + X(dY) \\ d(Xᵀ) &= (dX)ᵀ \\ d(φ(X)) &= φ'(X) \odot dX \\ \end{aligned}\]Here \(\odot\) denotes the elementwise (Hadamard) product.
We want to compute gradient updates to use gradient descent. It doesn’t really matter where we start; we can plug in expressions as we go. Let’s first compute \(dX_2\):
\[\begin{aligned} dX_2 &= φ'(X_1A+b) \odot d(X_1A+b) \\ &= φ'(X_1A+b) \odot (d(X_1A) + db) \\ &= φ'(X_1A+b) \odot ((dX_1)A + X_1dA + db) \\ &= φ' \odot ((dX_1)A + X_1dA + db) && \text{notation: $φ' := φ'(X_1A+b)$}\\ \end{aligned}\]We do not try to compute the derivatives of each component of \(X_2\) with respect to \(A\) here! We would need to introduce indices, which is what we’re trying to avoid. So we just keep our differentials, and we will plug them in / compose them later on, using differential calculus’ version of the chain rule.
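As a quick numerical sanity check of what this differential says: holding \(X_1\) and \(b\) fixed and perturbing \(A\) by a tiny \(dA\), the actual change in \(X_2\) should match \(φ' \odot (X_1 dA)\) up to second-order terms. A sketch under the same illustrative assumptions (\(φ = \tanh\), toy sizes):

```python
import numpy as np

# With X1 and b held fixed, X2(A + dA) - X2(A) ≈ phi' ⊙ (X1 dA) up to
# second-order terms in dA. Sizes, seed and phi = tanh are assumptions.
rng = np.random.default_rng(0)
n, d1, d2 = 5, 3, 4
X1 = rng.normal(size=(n, d1))
A  = rng.normal(size=(d1, d2))
b  = rng.normal(size=(1, d2))

X2 = np.tanh(X1 @ A + b)
phi_prime = 1 - X2**2                   # tanh' = 1 - tanh^2, evaluated at X1 A + b

dA  = 1e-6 * rng.normal(size=(d1, d2))  # a tiny perturbation of A
lhs = np.tanh(X1 @ (A + dA) + b) - X2   # actual change in X2
rhs = phi_prime * (X1 @ dA)             # first-order prediction from the differential
print(np.max(np.abs(lhs - rhs)))        # second order in dA, so tiny
```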
Similarly, then, we get \(dX_3 = (dX_2) C + X_2 dC + dd\).
\[\begin{aligned} l &= ||X_3 - y||^2 \\ \end{aligned}\]Let’s write \(Z=X_3-y\), an \(n×1\) vector. Then \(l = ||Z||^2 = ZᵀZ\). So \(dl = d(ZᵀZ) = d(Zᵀ)Z + ZᵀdZ = (dZ)ᵀZ + ZᵀdZ = 2ZᵀdZ\), since \((dZ)ᵀZ\) is a scalar and thus equals its own transpose \(ZᵀdZ\). Furthermore, we consider \(y\) as fixed (it holds the targets), so \(dy = 0\) and \(dZ=dX_3\).
M&N prove “identification theorems”. They tell us how to extract partial derivatives from differentials.
We have \(dl = 2ZᵀdZ = 2(X_3-y)ᵀdX_3\). From this expression, the identification theorem tells us that the partial derivative of \(l\) wrt \(X_3\) is \(2(X_3-y)ᵀ\). It is quite easy to remember: algebraically, “divide” by \(dX_3\) on both sides to get the partial derivative notation… Don’t do this mindlessly! You don’t want tensors of order higher than 2, otherwise you need to use more operators like vec.
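A small numerical sanity check of this identification, under the same toy assumptions: perturb \(X_3\) directly and compare the change in \(l\) with \(2(X_3-y)ᵀ dX_3\).

```python
import numpy as np

# Treat X3 itself as the free variable and verify that
# l(X3 + dX3) - l(X3) ≈ 2 (X3 - y)^T dX3. Sizes and seed are illustrative.
rng = np.random.default_rng(0)
n  = 5
X3 = rng.normal(size=(n, 1))
y  = rng.normal(size=(n, 1))

def loss(X3_):
    return ((X3_ - y).T @ (X3_ - y)).item()   # l = ||X3 - y||^2

dX3 = 1e-6 * rng.normal(size=(n, 1))          # a tiny perturbation
lhs = loss(X3 + dX3) - loss(X3)               # actual change in l
rhs = (2 * (X3 - y).T @ dX3).item()           # 2 (X3 - y)^T dX3
print(abs(lhs - rhs))                         # second order in dX3, so tiny
```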
If \(X_3\) were a (free) parameter, this would tell us how to update it. Here, though, it is a function of \(X_2\), \(C\) and \(d\). We can consider \(X_2\) and \(d\) fixed and compute the partial derivative of \(l\) wrt \(C\). Thus, setting \(dX_2 = 0\) and \(dd = 0\), we get \(dX_3 = (dX_2)C + X_2 dC + dd = X_2 dC\). Thus \(dl = 2(X_3 - y)ᵀX_2 dC\). Again we can identify \(\frac{\partial l}{\partial C} = 2(X_3-y)ᵀX_2\).
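This is already a usable gradient. Here is a sketch that checks \(\frac{\partial l}{\partial C} = 2(X_3-y)ᵀX_2\) against central finite differences, under the same illustrative assumptions (\(φ = \tanh\), toy sizes); note that with this convention the gradient is a \(1×d_2\) row vector, the transpose of \(C\)’s shape.

```python
import numpy as np

# Check dl/dC = 2 (X3 - y)^T X2 against central finite differences.
# All sizes, the seed, and phi = tanh are illustrative assumptions.
rng = np.random.default_rng(0)
n, d1, d2 = 5, 3, 4
X1 = rng.normal(size=(n, d1))
A  = rng.normal(size=(d1, d2))
b  = rng.normal(size=(1, d2))
C  = rng.normal(size=(d2, 1))
d  = rng.normal()
y  = rng.normal(size=(n, 1))

def loss(C_):
    X2 = np.tanh(X1 @ A + b)
    X3 = X2 @ C_ + d
    return ((X3 - y).T @ (X3 - y)).item()

X2 = np.tanh(X1 @ A + b)
X3 = X2 @ C + d
grad_C = 2 * (X3 - y).T @ X2              # 1 x d2 row vector, as identified above

eps = 1e-6
fd = np.zeros_like(C)                     # finite-difference gradient, shaped like C
for i in range(d2):
    E = np.zeros_like(C)
    E[i, 0] = eps
    fd[i, 0] = (loss(C + E) - loss(C - E)) / (2 * eps)
print(np.max(np.abs(grad_C.T - fd)))      # small, up to finite-difference error
```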
Now, to backpropagate further in the MLP, we consider \(C\) and \(d\) as fixed, that is, \(dX_3 = (dX_2) C\). We had \(dX_2 = φ' \odot ((dX_1)A + X_1dA + db)\). We want to compute the partial derivative of \(l\) wrt \(A\). The input \(X_1\) and the bias \(b\) are held fixed for that computation, so \(dX_1 = 0\) and \(db = 0\). So \(dX_2 = φ' \odot (X_1dA)\).
Substituting into \(dl\), we have
\[\begin{aligned} dl &= 2(X_3 - y)ᵀ (dX_2) C \\ &= 2(X_3 - y)ᵀ (φ' \odot (X_1dA)) C \\ \end{aligned}\]At this point we need to use the trace and its nice properties to pull the \(dA\) term out of the Hadamard’s claws. Note that a \(1×1\) matrix equals its trace: for any vectors \(u\) and \(w\) of the same size, \(uᵀw = tr(uᵀw)\). And it turns out that \(2(X_3-y)\) and \((φ' \odot (X_1dA)) C\) are both vectors of size \(n×1\).
\[\begin{aligned} dl &= 2(X_3 - y)ᵀ (φ' \odot (X_1dA)) C \\ &= tr(2(X_3 - y)ᵀ (φ' \odot (X_1dA))C) \\ \end{aligned}\]Then we can use the fact that \(tr(MN) = tr(NM)\) whenever \(M\) and \(Nᵀ\) have the same dimensions.
\[\begin{aligned} dl &= tr(2(X_3 - y)ᵀ (φ' \odot (X_1dA))C) \\ &= tr(2C(X_3 - y)ᵀ (φ' \odot (X_1dA))) \\ \end{aligned}\]But we haven’t yet solved the problem of \(dA\) being inside the Hadamard product.
Look at these matrices: \(U = 2(X_3-y)Cᵀ\) (so that \(Uᵀ = 2C(X_3-y)ᵀ\)), \(V = φ'\), and \(W = X_1 dA\), which let us write \(dl = tr(Uᵀ(V \odot W))\).
What do they have in common? They all have the same dimension, \(n × d_2\). Let’s try to simplify our expression by going back to the definitions:
\[\begin{aligned} tr(Uᵀ(V\odot W)) &= \sum_i (Uᵀ(V\odot W))_{ii} \\ &= \sum_i \sum_k (Uᵀ)_{ik} (V\odot W)_{ki} \\ &= \sum_i \sum_k (U)_{ki} (V)_{ki} (W)_{ki} \\ &= \sum_i \sum_k (V)_{ki} (U)_{ki} (W)_{ki} \\ &= \sum_i \sum_k (W)_{ki} (U)_{ki} (V)_{ki} \\ &= \ldots \\ \end{aligned}\]So we can permute those matrices however we want!
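A quick numerical check of this permutation identity with random matrices: every ordering gives the sum of the entries of \(U \odot V \odot W\).

```python
import numpy as np

# tr(U^T (V ⊙ W)) only depends on the elementwise product U ⊙ V ⊙ W,
# so U, V, W can be permuted freely. Shapes here are arbitrary.
rng = np.random.default_rng(0)
U, V, W = rng.normal(size=(3, 5, 4))   # three random matrices of the same shape

t1 = np.trace(U.T @ (V * W))
t2 = np.trace(W.T @ (V * U))
t3 = np.trace(V.T @ (U * W))
t4 = (U * V * W).sum()                 # the common value: sum of all entries
print(t1, t2, t3, t4)                  # all equal, up to floating point
```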
We want to put \(dA\) (which is in \(W\)) at the end of the trace to identify the partial derivatives. So, we write
\[\begin{aligned} dl &= tr(Uᵀ(V \odot W)) \\ &= tr(Wᵀ(V \odot U)) \\ &= tr((V \odot U)ᵀW) && \text{$tr(MᵀN) = tr(NᵀM)$} \\ &= tr((Vᵀ \odot Uᵀ)W) \\ &= tr((2 φ'ᵀ \odot (C(X_3-y)ᵀ)) X_1 dA) \\ \end{aligned}\]And by identifying, we get \(\frac{\partial l}{\partial A} = (2φ'ᵀ \odot (C(X_3-y)ᵀ)) X_1\), which has the correct dimension, \(d_2×d_1\): the transpose of \(A\)’s shape, as expected with the convention \(dl = tr\left(\frac{\partial l}{\partial A} dA\right)\).
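As a final sanity check, here is a sketch that compares this \(\frac{\partial l}{\partial A}\) with central finite differences, under the same illustrative assumptions (\(φ = \tanh\), toy sizes); since the formula gives a \(d_2×d_1\) matrix (the transpose of \(A\)’s shape), the finite-difference result is transposed before comparing.

```python
import numpy as np

# Check dl/dA = (2 phi'^T ⊙ (C (X3 - y)^T)) X1 against central finite
# differences. Sizes, the seed, and phi = tanh are illustrative assumptions.
rng = np.random.default_rng(0)
n, d1, d2 = 5, 3, 4
X1 = rng.normal(size=(n, d1))
A  = rng.normal(size=(d1, d2))
b  = rng.normal(size=(1, d2))
C  = rng.normal(size=(d2, 1))
d  = rng.normal()
y  = rng.normal(size=(n, 1))

def loss(A_):
    X2 = np.tanh(X1 @ A_ + b)
    X3 = X2 @ C + d
    return ((X3 - y).T @ (X3 - y)).item()

X2 = np.tanh(X1 @ A + b)
X3 = X2 @ C + d
phi_prime = 1 - X2**2                                # tanh' = 1 - tanh^2, n x d2
grad_A = (2 * phi_prime.T * (C @ (X3 - y).T)) @ X1   # d2 x d1, as identified above

eps = 1e-6
fd = np.zeros_like(A)                                # finite differences, shaped like A
for i in range(d1):
    for j in range(d2):
        E = np.zeros_like(A)
        E[i, j] = eps
        fd[i, j] = (loss(A + E) - loss(A - E)) / (2 * eps)
print(np.max(np.abs(grad_A - fd.T)))                 # small, up to finite-difference error
```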