Understanding the gradients for the backward pass of Batch Normalization

I recently did an assignment for Stanford's CS231n course that required deriving the gradients for the backward pass of Batch Normalization. In this blog post, I will try to explain how I derived those gradients from scratch. Why is Batch Normalization used? Well, the original motivation is that it reduces the amount by which the distribution of hidden unit values shifts around during training (internal covariate shift), which makes it easier to train a deep network.


Forward Pass

In the forward pass, we calculate the mean and variance of the batch and normalize each feature to zero mean and unit variance. The normalized values are then scaled and shifted with the learnable parameters \(\gamma\) and \(\beta\), which get updated during training using the gradients calculated in the backward pass. The forward pass is relatively simple since it only requires standardizing the input features. Here, ‘m’ is the size of the mini-batch.

\[\begin{align} \mu_B &= \frac{1}{m}\sum_{i=1}^{m} x_i \\ \sigma_B^2 &= \frac{1}{m}\sum_{i=1}^{m} (x_i - \mu_B)^2 \\ \hat{x_i} &= \frac{x_i - \mu_B}{\sqrt{ \sigma_B^2 + \epsilon }} \\ y_i &= \gamma \hat{x_i} + \beta \end{align}\]

Implementation in Python :

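import numpy as np

def batchnorm_forward(x, gamma, beta, eps):

    # x: (N, D) mini-batch; gamma, beta: (D,) scale and shift; eps: small constant for numerical stability
    N, D = x.shape
    xmu = (1./N)*np.sum(x, axis=0)               # per-feature mean
    xvar = (1./N)*np.sum((x-xmu)**2, axis=0)     # per-feature (biased) variance
    inv_var = 1./np.sqrt(xvar+eps)               # 1/sqrt(var + eps), despite the name
    xhat = (x-xmu)*inv_var                       # normalized input
    out = gamma*xhat+beta                        # scale and shift
    cache = (xhat, xmu, inv_var, gamma)          # saved for the backward pass

    return out, cache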

We have stored the intermediate values that will be required in the backward pass.
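If you want to convince yourself that the forward pass does what the equations say, here is a small check. It is only a sketch: the input data, the seed and the batch size are made up for illustration, and the function is a condensed copy of the one above. With \(\gamma = 1\) and \(\beta = 0\), every output feature should come out with mean close to 0 and standard deviation close to 1.

```python
import numpy as np

def batchnorm_forward(x, gamma, beta, eps):
    # Condensed copy of the forward pass from the post.
    xmu = x.mean(axis=0)
    xvar = x.var(axis=0)  # np.var divides by N by default, matching the formula
    inv_var = 1.0 / np.sqrt(xvar + eps)
    xhat = (x - xmu) * inv_var
    return gamma * xhat + beta, (xhat, xmu, inv_var, gamma)

rng = np.random.default_rng(0)                  # arbitrary seed
x = 5.0 + 3.0 * rng.standard_normal((64, 4))    # made-up data, far from zero mean / unit variance
gamma, beta = np.ones(4), np.zeros(4)           # identity scale and shift

out, _ = batchnorm_forward(x, gamma, beta, eps=1e-5)
feature_means = out.mean(axis=0)
feature_stds = out.std(axis=0)
print("per-feature mean:", feature_means)
print("per-feature std: ", feature_stds)
```

The standard deviation is very slightly below 1 because of the \(\epsilon\) added under the square root.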


Backward Pass

In the backward pass, we calculate the gradients of the loss \(C\) with respect to the learnable parameters and the input. These gradients tell us how strongly each quantity affects the loss, and they are what gets used to update the parameters and improve the performance of the network. So, we will be calculating the following gradients: \(\frac{\partial C}{\partial x_i}\), \(\frac{\partial C}{\partial \gamma}\) and \(\frac{\partial C}{\partial \beta}\).

Implementation in Python :

def batchnorm_backward(dout, cache):

    # dout is the upstream derivative dC/dy, shape (N, D)
    xhat, xmu, inv_var, gamma = cache
    N, D = dout.shape
    dxhat = dout * gamma                     # dC/dxhat

    dbeta = np.sum(dout, axis=0)             # dC/dbeta: sum over the batch
    dgamma = np.sum(dout*xhat, axis=0)       # dC/dgamma: sum over the batch
    # dC/dx: the paths through xhat, the mean and the variance combined into one expression
    dx = (1./N)*inv_var*((N*dxhat) - np.sum(dxhat, axis=0) - xhat*np.sum(dxhat*xhat, axis=0))

    return dx, dgamma, dbeta
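A quick way to gain some trust in these formulas is a numerical gradient check with central differences. The sketch below is my own illustration, not part of the assignment code: `numerical_grad`, the random shapes, the seed and the scalar loss \(C = \sum_{i} y_i \cdot \text{dout}_i\) are all assumptions, chosen so that the upstream derivative is exactly `dout`. The two functions are condensed copies of the ones above.

```python
import numpy as np

def batchnorm_forward(x, gamma, beta, eps):
    xmu = x.mean(axis=0)
    inv_var = 1.0 / np.sqrt(x.var(axis=0) + eps)
    xhat = (x - xmu) * inv_var
    return gamma * xhat + beta, (xhat, xmu, inv_var, gamma)

def batchnorm_backward(dout, cache):
    xhat, xmu, inv_var, gamma = cache
    N = dout.shape[0]
    dxhat = dout * gamma
    dbeta = dout.sum(axis=0)
    dgamma = (dout * xhat).sum(axis=0)
    dx = inv_var / N * (N * dxhat - dxhat.sum(axis=0)
                        - xhat * (dxhat * xhat).sum(axis=0))
    return dx, dgamma, dbeta

def numerical_grad(f, a, h=1e-5):
    # Central-difference gradient of the scalar f() with respect to array a.
    # a is perturbed in place, one entry at a time, and restored afterwards.
    grad = np.zeros_like(a)
    for idx in np.ndindex(*a.shape):
        old = a[idx]
        a[idx] = old + h
        f_plus = f()
        a[idx] = old - h
        f_minus = f()
        a[idx] = old
        grad[idx] = (f_plus - f_minus) / (2 * h)
    return grad

rng = np.random.default_rng(1)          # arbitrary seed
x = rng.standard_normal((6, 3))
gamma = rng.standard_normal(3)
beta = rng.standard_normal(3)
dout = rng.standard_normal((6, 3))
eps = 1e-5

# Hypothetical scalar loss whose derivative with respect to the output is dout.
def loss():
    out, _ = batchnorm_forward(x, gamma, beta, eps)
    return np.sum(out * dout)

_, cache = batchnorm_forward(x, gamma, beta, eps)
dx, dgamma, dbeta = batchnorm_backward(dout, cache)

# Largest absolute difference between analytic and numerical gradients.
err_dx = np.max(np.abs(dx - numerical_grad(loss, x)))
err_dgamma = np.max(np.abs(dgamma - numerical_grad(loss, gamma)))
err_dbeta = np.max(np.abs(dbeta - numerical_grad(loss, beta)))
print(err_dx, err_dgamma, err_dbeta)
```

All three differences should be tiny compared to the gradients themselves if the analytic expressions are right.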

Now, let’s have some quality and intuitive calculus time to understand where the gradients in the implementation above come from. So, take your pen and paper and solve with me!!

\[\begin{align} \frac{\partial C}{\partial (\gamma \hat{x_i})} &= \frac{\partial C}{\partial y_i} \times \frac{\partial y_i}{\partial (\gamma \hat{x_i})} \\ &= \frac{\partial C}{\partial y_i} \times \frac{\partial (\gamma \hat{x_i} + \beta)}{\partial (\gamma \hat{x_i})} \\ &= \frac{\partial C}{\partial y_i} \end{align}\]
\[\begin{align} \frac{\partial C}{\partial \beta} &= \sum_{i=1}^m \frac{\partial C}{\partial y_i} \times \frac{\partial y_i}{\partial \beta} \\ &= \sum_{i=1}^m \frac{\partial C}{\partial y_i} \times \frac{\partial(\gamma \hat{x_i} + \beta )}{\partial \beta} \\ &= \sum_{i=1}^m \frac{\partial C}{\partial y_i} \end{align}\]
\[\begin{align} \frac{\partial C}{\partial \gamma} &= \sum_{i=1}^m \frac{\partial C}{\partial (\gamma \hat{x_i})} \times \frac{\partial (\gamma \hat{x_i})}{\partial \gamma} \\ &= \sum_{i=1}^m \frac{\partial C }{\partial y_i} \times \hat{x_i} \end{align}\]

We have calculated the gradients for the learnable parameters and now it’s time to calculate the input gradient. I had a very rough time working out this part, but I will try to explain it in an easy fashion:

\[\begin{align} \frac{\partial C}{\partial \hat{x_i}} &= \frac{\partial C}{\partial y_i} \times \frac{\partial y_i}{\partial \hat{x_i}} \\ &= \frac{\partial C}{\partial y_i} \times \gamma \end{align}\]

So, the gradient with respect to \(\hat{x_i}\) is the piece the chain rule needs in order to reach the input. We need to calculate \(\frac{\partial C}{\partial x_i}\). Since \(x_i\) affects the loss through three paths (directly through \(\hat{x_i}\), through the mean \(\mu_B\) and through the variance \(\sigma_B^2\)), the chain rule gives one term per path. So, the equation looks like this:

\[\begin{align} \frac{\partial C}{\partial x_i} &= \frac{\partial C}{\partial \hat{x_i}} \times \frac{\partial \hat{x_i}}{\partial x_i} + \frac{\partial C}{\partial \mu_B} \times \frac{\partial \mu_B}{\partial x_i} + \frac{\partial C}{\partial \sigma_B^2} \times \frac{\partial \sigma_B^2}{\partial x_i} \end{align}\]

Now, let’s calculate each component individually, except \(\frac{\partial C}{\partial \hat{x_i}}\), since this has already been calculated above. Refer back to the forward pass equations for the normalized inputs.

\[\begin{align} \frac{\partial \hat{x_i}}{\partial x_i} &= \frac{1}{\sqrt{ \sigma_B^2 + \epsilon }}\\ \frac{\partial \mu_B}{\partial x_i} &= \frac{1}{m}\\ \frac{\partial \sigma_B^2}{\partial x_i} &= 2(x_i - \mu_B)/m\\ \end{align}\]

Note that no summation symbol is present here, since these are derivatives with respect to a single training sample \(x_i\) in the mini-batch of size ‘m’. Now, let’s calculate the rest of the components needed for \(\frac{\partial C}{\partial x_i}\).

Because \(\mu_B\) and \(\sigma_B^2\) are shared by every sample in the batch, their gradients sum over all \(i\):

\[\begin{align} \frac{\partial C}{\partial \mu_B} &= \sum_{i=1}^{m} \frac{\partial C}{\partial \hat{x_i}} \times \frac{\partial \hat{x_i}}{\partial \mu_B} + \frac{\partial C}{\partial \sigma_B^2} \times \frac{\partial \sigma_B^2}{\partial \mu_B}\\ \frac{\partial \hat{x_i}}{\partial \mu_B} &= \frac{-1}{\sqrt{ \sigma_B^2 + \epsilon }}\\ \frac{\partial \sigma_B^2}{\partial \mu_B} &= \frac{1}{m}\sum_{i=1}^{m} -2(x_i - \mu_B)\\ \frac{\partial C}{\partial \sigma_B^2} &= \sum_{i=1}^{m} \frac{\partial C}{\partial \hat{x_i}} \times \frac{\partial \hat{x_i}}{\partial \sigma_B^2} \end{align}\]

Now, let’s calculate the value of \(\frac{\partial \hat{x_i}}{\partial \sigma_B^2}\), which can be calculated easily using the quotient rule (or just the power rule, since the numerator does not depend on \(\sigma_B^2\)).

\[\begin{align} \hat{x_i} &= \frac{x_i - \mu_B}{\sqrt{ \sigma_B^2 + \epsilon }} \\ \frac{\partial \hat{x_i}}{\partial \sigma_B^2} &= (-0.5) \times (x_i - \mu_B) \times (\sigma_B^2 + \epsilon)^{-1.5} \end{align}\]
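At this point all the pieces of the chain rule are on the table, so they can be assembled step by step in code. The sketch below is an illustration under my own assumptions (random data, arbitrary seed, variable names such as `dvar` and `dmu`); it sums over the batch wherever a quantity is shared by all samples, and then compares the staged result against the one-line expression used in `batchnorm_backward`.

```python
import numpy as np

rng = np.random.default_rng(2)          # arbitrary seed
m, D = 8, 3                             # made-up batch size and feature count
x = rng.standard_normal((m, D))
gamma = rng.standard_normal(D)
dout = rng.standard_normal((m, D))      # stands in for dC/dy
eps = 1e-5

# Forward quantities needed by the backward pass.
mu = x.mean(axis=0)
var = x.var(axis=0)
std = np.sqrt(var + eps)
xhat = (x - mu) / std

# Staged backward pass, one line per equation of the derivation.
dxhat = dout * gamma                                           # dC/dxhat_i
dvar = -0.5 * np.sum(dxhat * (x - mu), axis=0) * (var + eps) ** -1.5   # dC/dvar
dmu = np.sum(-dxhat / std, axis=0) + dvar * np.mean(-2.0 * (x - mu), axis=0)  # dC/dmu
dx_staged = dxhat / std + dmu / m + dvar * 2.0 * (x - mu) / m  # dC/dx_i

# Compact expression from the implementation earlier in the post.
dx_compact = (1.0 / m) / std * (m * dxhat - dxhat.sum(axis=0)
                                - xhat * (dxhat * xhat).sum(axis=0))

max_diff = np.max(np.abs(dx_staged - dx_compact))
print("largest difference:", max_diff)
```

The staged version is easier to read against the math, while the compact version avoids storing the intermediate gradients.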

Now, using the values of \(\frac{\partial C}{\partial \sigma_B^2}\) and \(\frac{\partial C}{\partial \mu_B}\), we can calculate the input gradient \(\frac{\partial C}{\partial x_i}\). I hope you can substitute the values into the above equations. Just try out the math, it’s real fun! After some vigorous calculation and several substitutions, we get the value of \(\frac{\partial C}{\partial x_i}\) as

\[\begin{align} \frac{\partial C}{\partial x_i} &= \frac{m\frac{\partial C}{\partial \hat{x_i}} - \sum_{j=1}^{m}\frac{\partial C}{\partial \hat{x_j}} - \hat{x_i}\sum_{j=1}^{m}\frac{\partial C}{\partial \hat{x_j}}\hat{x_j}}{m\sqrt{ \sigma_B^2 + \epsilon }} \end{align}\]
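In case you get stuck, here is one way the substitution can go. The shorthand \(g_i = \frac{\partial C}{\partial \hat{x_i}}\) and \(s = \sqrt{\sigma_B^2 + \epsilon}\) is introduced only for this sketch. It relies on two facts from the forward pass: \(x_i - \mu_B = s\,\hat{x_i}\), and \(\sum_{j=1}^{m}(x_j - \mu_B) = 0\), which makes the variance term inside \(\frac{\partial C}{\partial \mu_B}\) vanish.

\[\begin{align} \frac{\partial C}{\partial \sigma_B^2} &= \sum_{j=1}^{m} g_j \times (-0.5)(x_j - \mu_B)\, s^{-3} = -\frac{1}{2s^2}\sum_{j=1}^{m} g_j \hat{x_j} \\ \frac{\partial C}{\partial \mu_B} &= -\frac{1}{s}\sum_{j=1}^{m} g_j + \frac{\partial C}{\partial \sigma_B^2} \times \frac{1}{m}\sum_{j=1}^{m} -2(x_j - \mu_B) = -\frac{1}{s}\sum_{j=1}^{m} g_j \\ \frac{\partial C}{\partial x_i} &= \frac{g_i}{s} + \frac{1}{m}\frac{\partial C}{\partial \mu_B} + \frac{2(x_i - \mu_B)}{m}\frac{\partial C}{\partial \sigma_B^2} \\ &= \frac{g_i}{s} - \frac{1}{ms}\sum_{j=1}^{m} g_j - \frac{\hat{x_i}}{ms}\sum_{j=1}^{m} g_j \hat{x_j} \\ &= \frac{m\, g_i - \sum_{j=1}^{m} g_j - \hat{x_i}\sum_{j=1}^{m} g_j \hat{x_j}}{m\, s} \end{align}\]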

So, we have calculated the required gradients for the backward pass : \(\frac{\partial C}{\partial x_i}\), \(\frac{\partial C}{\partial \gamma}\) and \(\frac{\partial C}{\partial \beta}\).

This was my first attempt at presenting my knowledge in the form of a blog post. I hope y’all liked it!!

Written on February 7, 2019