ELI5: Gradient Descent

Kantajit Shaw
4 min read · May 10, 2020


In the age of data science, when everyone is crazy about machine learning, we tend to forget the basics. We have loads of data, fancy new GPUs, and ML/DL libraries at our disposal. Who cares about the underlying linear algebra and calculus?

Actually, we should care. Deep learning should not stay a black box forever, where you just feed in some data and keep tweaking values until it gets the answer right.

Image reference: xkcd (https://imgs.xkcd.com/comics/machine_learning.png)

Backpropagation is the first step toward understanding how we teach a machine to learn.

This is too DARK. Bleh!

Let’s start with the interesting stuff already.

Image reference: Wikipedia

Imagine we are playing a game of darts. You throw a dart and it lands far from the bullseye. You take a deep breath, work out how your throw needs to change next time, and throw again. Working out how to adjust, based on how far off you were, is the job of backpropagation (ignore the deep breath part). In a neural network, that means computing how much each parameter contributed to the error.

You look at how far off your shot was because you want to aim better next time. Similarly, in machine learning algorithms, we compute how far our prediction is from the original value, or ground truth. So how do we correct this garbage predictor? That is where gradient descent comes in.

Change of Scene: We are at a hill station going down the hill.

We are looking for a place to settle. We measure the slope to see which direction goes downhill, then we take a step in that direction. We keep taking steps until we reach flat ground. This long, repetitive part is gradient descent.

Now that we have the bird's-eye view, let's get technical (a tad bit mathematical).

Loss function:

This is how we compute how far our predicted value is from the original value. Now, I know what you are thinking. A simple difference between the predicted and original value should do just fine, right?

Actually no. The loss function should satisfy two properties (two fancy mathematical terms).

  1. It should be differentiable.
  2. It should be convex.

Differentiable: In simple terms, the function should have a derivative (slope) at each point. We want to avoid reaching a point where we can't compute the slope, because the slope is what tells us which way to step.

Image reference: Wikipedia

Convex: Going back to the hill analogy, we want to reach the lowest flat ground. If the function is non-convex, we might get stuck in a dip somewhere in the middle (a local minimum), thinking we have already reached the lowest point.

Image Reference: ResearchGate
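
To make the "getting stuck" idea concrete, here is a minimal sketch. The function f below is a made-up non-convex curve with two dips, and the names (f, slope, walk_downhill) and the step size are assumptions chosen only for illustration. Walking downhill from two different starting points ends in two different dips, and only one of them is the true lowest point.

```python
# Hypothetical non-convex function with two dips (not from the article).
def f(w):
    return w**4 - 3 * w**2 + w

def slope(w):
    # Derivative of f: d/dw (w^4 - 3w^2 + w) = 4w^3 - 6w + 1
    return 4 * w**3 - 6 * w + 1

def walk_downhill(w, step_size=0.01, steps=500):
    # Repeatedly take a small step against the slope.
    for _ in range(steps):
        w -= step_size * slope(w)
    return w

from_left = walk_downhill(-2.0)   # settles in the deeper dip (w is negative)
from_right = walk_downhill(2.0)   # settles in the shallower dip (w is positive)

print(from_left, f(from_left))
print(from_right, f(from_right))
```

Starting on the right, we stop at a point where the slope is zero but the value of f is still higher than in the other dip. With a convex function there is only one dip, so this cannot happen.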

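There are plenty of loss functions that satisfy both properties, e.g. Mean Squared Error. Mean Absolute Error is also popular: it is convex, but it is not differentiable at the single point where the error is exactly zero, which in practice is handled by picking a slope of zero there.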
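
As a rough sketch, the two loss functions named above can be written in a few lines of plain Python. The function names and the sample numbers are made up for illustration.

```python
def mean_squared_error(y_true, y_pred):
    # Average of the squared differences.
    return sum((t - p) ** 2 for t, p in zip(y_true, y_pred)) / len(y_true)

def mean_absolute_error(y_true, y_pred):
    # Average of the absolute differences.
    return sum(abs(t - p) for t, p in zip(y_true, y_pred)) / len(y_true)

# Made-up ground truth and predictions: the errors are 1, 0 and -3.
y_true = [3.0, 5.0, 7.0]
y_pred = [2.0, 5.0, 10.0]

mse = mean_squared_error(y_true, y_pred)   # (1 + 0 + 9) / 3
mae = mean_absolute_error(y_true, y_pred)  # (1 + 0 + 3) / 3
print(mse, mae)
```

Notice how squaring punishes the one big miss (3 becomes 9) much more than the absolute value does.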

Gradient Descent:

Suppose we plot our loss function, which satisfies the two properties above, against the model's parameters (the values we adjust during training). We will get a nice graph like this.

Image Reference: Wikipedia

The lowest point is where our predictions are closest to the actual values. At first, when the model is not trained, we start somewhere on the curve. Then we calculate the derivative (slope) at that point and take a step in the downhill direction, which is opposite to the sign of the slope. We keep doing this until we reach the lowest point, where the slope is zero and the loss function is at its minimum (where our model now knows how to predict with minimum error).
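
The loop described above can be sketched like this. It assumes a toy model with a single parameter w that predicts y = w * x, and uses Mean Squared Error as the loss; the model, the data, and all names here are assumptions for illustration, not something prescribed by the article.

```python
def cost(w, xs, ys):
    # Mean squared error of the toy model y = w * x.
    return sum((w * x - y) ** 2 for x, y in zip(xs, ys)) / len(xs)

def gradient(w, xs, ys):
    # Slope of the cost with respect to w:
    # d/dw mean((w*x - y)^2) = mean(2 * (w*x - y) * x)
    return sum(2 * (w * x - y) * x for x, y in zip(xs, ys)) / len(xs)

def gradient_descent(xs, ys, w=0.0, learning_rate=0.05, steps=100):
    for _ in range(steps):
        # Step against the slope: positive slope -> decrease w, and vice versa.
        w -= learning_rate * gradient(w, xs, ys)
    return w

# Made-up data generated from y = 2x, so the best w is 2.
xs = [1.0, 2.0, 3.0, 4.0]
ys = [2.0, 4.0, 6.0, 8.0]

w = gradient_descent(xs, ys)
print(w, cost(w, xs, ys))
```

We start at w = 0 (an untrained model), and each step moves w a little closer to 2, where the slope and the cost are both essentially zero.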

A few more points:

  1. Learning rate: It determines how large a step we take. With a smaller learning rate, it takes longer to reach the lowest point. With a larger learning rate, we might overstep the lowest point and keep bouncing around it, or even move further away from it.
  2. Cost function: During the learning phase, we deal with multiple examples so that the algorithm generalizes. The cost function is the average of the losses over all the examples.
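
The effect of the learning rate can be seen in a small experiment. This sketch assumes the simplest possible bowl-shaped curve, f(w) = w * w, whose slope is 2 * w and whose lowest point is at w = 0. The three learning rates are arbitrary values picked to show the three behaviours.

```python
def descend(learning_rate, w=10.0, steps=20):
    # Gradient descent on f(w) = w * w, whose slope is 2 * w.
    for _ in range(steps):
        w -= learning_rate * 2 * w
    return w

too_small = descend(0.01)  # creeps toward 0, still far away after 20 steps
about_right = descend(0.4) # gets very close to 0
too_large = descend(1.1)   # jumps over 0 every step and ends up further away

print(too_small, about_right, too_large)
```

All three runs start from the same point and take the same number of steps; only the step size differs.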

I hope this helps.

Reference:

https://www.coursera.org/learn/neural-networks-deep-learning


Written by Kantajit Shaw

Deep learning enthusiast, interested in Computer Vision and Natural Language Processing problems.
