Hiren Nandigama's blog

Regularization: Intuition behind the math

Regularization in Machine Learning is talked about a lot, often described as solving the high-variance problem while also helping with feature selection. A model is said to have overfit its training data when the error on that data is very low but the error on unseen data is very high. To see why this is called high variance, start with the formula for variance:

$$\sigma^2 = \frac{1}{N}\sum_{i=1}^{N}(x_i - \mu)^2$$

This tells us that, if for some dataset the variance is high, then on average the individual data points are further away from the mean value.
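A quick sanity check of the formula in plain Python. The two small datasets are made up for illustration: both have mean 5, but the points in `spread` sit further from it, so its variance is larger. The standard library's `statistics.pvariance` (population variance, dividing by N) is used as a cross-check.

```python
from statistics import pvariance

def variance(xs):
    """Population variance: average squared distance of each point from the mean."""
    n = len(xs)
    mu = sum(xs) / n
    return sum((x - mu) ** 2 for x in xs) / n

tight = [4, 5, 6]    # points close to their mean of 5
spread = [1, 5, 9]   # same mean, points further away

print(variance(tight), variance(spread))      # 2/3 and 32/3
print(pvariance(tight), pvariance(spread))    # stdlib cross-check, same values
```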

If we were to extend similar thinking to model fitting, a model having high variance means, on average, the individual predictions are further away from the mean prediction.

$$\text{Variance}(x) = \mathbb{E}\left[\left(\hat{f}(x) - \mathbb{E}[\hat{f}(x)]\right)^2\right]$$

How can we have multiple predictions? We take our model and fit it to datasets randomly sampled from the original dataset. Intuitively, when models overfit their datasets, the spread of predictions is higher.

This is a problem because the model has generalized poorly: when presented with unseen data, its predictions are way off.
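The resampling idea can be sketched in plain Python (stdlib only). Everything concrete here is an assumption for illustration: the noisy y = 3.2x data, the sample size, the number of bootstrap rounds, and the use of a 1-nearest-neighbour regressor as a stand-in for a model that memorizes (overfits) its training data. The other model is the one-parameter line f̂(x) = wx used in this post. Each model is refit on bootstrap resamples, and the empirical Variance(x) of its predictions is averaged over a few query points.

```python
import random
from statistics import pvariance

random.seed(0)  # fixed seed so the run is reproducible

# Hypothetical training data (an assumption): y = 3.2x plus Gaussian noise with sd 1.
xs = [random.uniform(0, 2) for _ in range(30)]
data = [(x, 3.2 * x + random.gauss(0, 1.0)) for x in xs]

def fit_linear(sample):
    """Least-squares fit of f_hat(x) = w * x (no intercept); returns a predictor."""
    w = sum(x * y for x, y in sample) / sum(x * x for x, _ in sample)
    return lambda x: w * x

def fit_1nn(sample):
    """1-nearest-neighbour 'model': predicts the y of the closest training x,
    so it reproduces its training data exactly (a stand-in for an overfit model)."""
    return lambda x: min(sample, key=lambda p: abs(p[0] - x))[1]

def avg_prediction_variance(fit, data, x_grid, n_boot=300):
    """Refit on bootstrap resamples of `data`, then average over x_grid
    the variance of the predictions made at each query point."""
    preds = {x0: [] for x0 in x_grid}
    for _ in range(n_boot):
        sample = random.choices(data, k=len(data))  # resample with replacement
        model = fit(sample)
        for x0 in x_grid:
            preds[x0].append(model(x0))
    return sum(pvariance(p) for p in preds.values()) / len(x_grid)

x_grid = [0.25 * k for k in range(1, 8)]  # query points 0.25, 0.5, ..., 1.75
var_linear = avg_prediction_variance(fit_linear, data, x_grid)
var_1nn = avg_prediction_variance(fit_1nn, data, x_grid)
print(f"linear: {var_linear:.4f}   1-NN: {var_1nn:.4f}")
```

Exact numbers depend on the seed, but the memorizing 1-NN model should show a clearly larger spread of predictions than the one-parameter line.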

So, how do we solve this?

Regularization to the rescue

When a model overfits a dataset, the cost is very low, i.e., the parameters of our model are tuned too closely to the training data. Can we somehow nudge the cost function so that the parameters shift away from this minimum? We can, and it's called regularization. Let us take an example function:

$$f(x) = 3.2x$$

i.e., w = 3.2

[Figure: Desmos graph of f(x) = 3.2x]

Let the hypothesis function be

$$\hat{f}(x) = wx$$

When we train the model, the hypothesis/model parameter w ends up equal to 3.2. Let's plot the cost function over the two training points (1, 3.2) and (2, 6.4):

$$J(w) = \frac{1}{2}\left((w - 3.2)^2 + (2w - 6.4)^2\right)$$

Without regularization, the minimum lies at w = 3.2, where the cost function is 0. We have overfit our model on the dataset!

[Figure: Desmos graph of the cost function J(w)]
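A minimal sketch in Python (stdlib only) that evaluates this cost on a grid of candidate weights and picks the smallest. The two training points are read off the cost function; the grid range and step size are arbitrary choices for the sketch, not something from the post.

```python
# Training points implied by the cost function: (x, y) = (1, 3.2) and (2, 6.4).
DATA = [(1.0, 3.2), (2.0, 6.4)]

def cost(w):
    """J(w) = 1/2 * sum of squared errors of f_hat(x) = w * x over the two points."""
    return 0.5 * sum((w * x - y) ** 2 for x, y in DATA)

# Brute-force search over candidate weights from -1 to 6 in steps of 0.001
# (the range and step are arbitrary choices for this sketch).
grid = [i / 1000 for i in range(-1000, 6001)]
w_best = min(grid, key=cost)
print(w_best, cost(w_best))  # 3.2 0.0
```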

L1 Regularization

Keeping our hypothesis function the same, let's see what happens to the cost function when we add the L1 regularization term λ|w|

$$J(w) = \frac{1}{2}\left((w - 3.2)^2 + (2w - 6.4)^2\right) + \lambda|w|$$

Let's set λ = 1

The cost function shifts ever so slightly, and the value of w that minimizes it is now 3 instead of 3.2. This little nudge prevents the model from overfitting the training data.

[Figure: Desmos graph of the L1-regularized cost function with λ = 1]
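The same grid search, now with the L1 term λ|w| added and λ = 1, lands on w = 3. As before, this is a sketch: the grid range and step are our own choices, and the training points are read off the cost function.

```python
# Training points implied by the cost function: (1, 3.2) and (2, 6.4).
DATA = [(1.0, 3.2), (2.0, 6.4)]

def cost_l1(w, lam):
    """J(w) = 1/2 * sum of squared errors + lam * |w|."""
    return 0.5 * sum((w * x - y) ** 2 for x, y in DATA) + lam * abs(w)

# Candidate weights from -1 to 6 in steps of 0.001 (arbitrary choices).
grid = [i / 1000 for i in range(-1000, 6001)]
w_best = min(grid, key=lambda w: cost_l1(w, lam=1.0))
print(w_best)  # 3.0
```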

The hypothesis function now shifts from the true function, preventing overfitting.

[Figure: Desmos graph of the L1-regularized hypothesis against the true function]

L2 Regularization

In L2 regularization, only the regularization term differs from L1: it is now λw².

$$J(w) = \frac{1}{2}\left((w - 3.2)^2 + (2w - 6.4)^2\right) + \lambda w^2$$

Let's set λ = 1

Here too, the cost function shifts slightly, so that the value of w that minimizes it is now about 2.28571 (exactly 16/7). This nudge keeps the model parameter from settling at the value that fits the training data exactly.

[Figure: Desmos graph of the L2-regularized cost function with λ = 1]
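Swapping the penalty for λw² in the same grid-search sketch (λ = 1, with a grid range and step of our own choosing) recovers this minimum up to the grid resolution.

```python
# Training points implied by the cost function: (1, 3.2) and (2, 6.4).
DATA = [(1.0, 3.2), (2.0, 6.4)]

def cost_l2(w, lam):
    """J(w) = 1/2 * sum of squared errors + lam * w^2."""
    return 0.5 * sum((w * x - y) ** 2 for x, y in DATA) + lam * w ** 2

# Candidate weights from -1 to 6 in steps of 0.001 (arbitrary choices).
grid = [i / 1000 for i in range(-1000, 6001)]
w_best = min(grid, key=lambda w: cost_l2(w, lam=1.0))
print(w_best)  # 2.286, the grid point nearest 16/7
```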

The hypothesis function now shifts from the true function, preventing overfitting.

[Figure: Desmos graph of the L2-regularized hypothesis against the true function]

Lambda and "feature selection"

The λ we saw in L1 and L2 is the regularization strength. Increasing it makes the model simpler, i.e., it moves some parameters towards 0, a.k.a. feature selection. But there's a simple distinction here: L1 can push a parameter to exactly 0, whereas L2 only shrinks it towards 0.

Why is that?

Let's take the same cost function from above. For the minimum to sit exactly at w = 0 for some value of λ, the slope of the cost function must be ≥ 0 for w > 0 and ≤ 0 for w < 0. Only then is a slope of 0 possible at w = 0.

[Figure: Desmos graph of the L1-regularized cost function around w = 0]

Derivative of L1 regularized cost function

$$\frac{dL_1}{dw} = 5w - 16 + \lambda \cdot \operatorname{sign}(w)$$
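Where does 5w − 16 come from? Differentiating the unregularized part of the cost term by term with the chain rule (each square contributes 2 × inner expression × inner derivative, and the ½ cancels the 2):

$$
\begin{aligned}
\frac{d}{dw}\,\frac{1}{2}\left((w - 3.2)^2 + (2w - 6.4)^2\right)
  &= (w - 3.2) + 2(2w - 6.4) \\
  &= w - 3.2 + 4w - 12.8 \\
  &= 5w - 16
\end{aligned}
$$

The regularization term then adds λ·sign(w) for L1 (valid for w ≠ 0) and 2λw for L2.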

For sign(w) > 0 (i.e., w > 0):

$$\frac{dL_1}{dw} = 5w - 16 + \lambda$$

And for sign(w) < 0 (i.e., w < 0):

$$\frac{dL_1}{dw} = 5w - 16 - \lambda$$

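Because of the absolute value term, the cost function is not differentiable at w = 0. Approaching w = 0 from the left, the slope tends to −16 − λ; approaching from the right, it tends to −16 + λ. So the slope at w = 0 could be anywhere in the range

$$-16 - \lambda \;\le\; \text{slope at } w = 0 \;\le\; -16 + \lambda$$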

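For λ ≥ 16, this range includes 0, since −16 − λ < 0 ≤ −16 + λ. That means for λ ≥ 16 a slope of 0 is possible at w = 0: the cost falls towards w = 0 from the left and rises away from it on the right, so the minimum of the cost function is found at exactly w = 0. This is how, for large λ values, feature selection happens in L1!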

[Figure: Desmos graph of the L1-regularized cost function for a large λ]
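The slope argument can be checked numerically. This stdlib-only Python sketch estimates the one-sided slopes at w = 0 with finite differences (step h = 1e-6, an arbitrary choice) and grid-searches the minimizer for a handful of λ values we picked for illustration. For λ < 16 the minimizer stays positive; from λ = 16 on it is exactly 0.

```python
def cost_l1(w, lam):
    """L1-regularized cost from the post: 1/2((w - 3.2)^2 + (2w - 6.4)^2) + lam*|w|."""
    return 0.5 * ((w - 3.2) ** 2 + (2 * w - 6.4) ** 2) + lam * abs(w)

def one_sided_slopes_at_zero(lam, h=1e-6):
    """Finite-difference estimates of the slope just left and just right of w = 0.
    They should come out close to -16 - lam and -16 + lam."""
    left = (cost_l1(0.0, lam) - cost_l1(-h, lam)) / h
    right = (cost_l1(h, lam) - cost_l1(0.0, lam)) / h
    return left, right

# Candidate weights from -2 to 5 in steps of 0.001 (range and step are arbitrary).
grid = [i / 1000 for i in range(-2000, 5001)]

argmins = {}
for lam in [1, 8, 15, 16, 20, 40]:
    left, right = one_sided_slopes_at_zero(lam)
    zero_in_range = left <= 0 <= right
    argmins[lam] = min(grid, key=lambda w: cost_l1(w, lam))
    print(f"lambda={lam:>2}  slopes around 0: [{left:8.3f}, {right:8.3f}]  "
          f"0 in range: {zero_in_range}  argmin w = {argmins[lam]}")
```

For λ < 16 the grid minimizer matches (16 − λ)/5, which is what setting the w > 0 derivative 5w − 16 + λ to 0 gives.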

Derivative of L2 regularized cost function

In L2 regularization, the derivative of the cost function is

$$\frac{dL_2}{dw} = 5w - 16 + 2\lambda w$$

Let's set the derivative to 0 to find the minimum:

$$5w - 16 + 2\lambda w = 0$$

Solve for w:

$$w = \frac{16}{5 + 2\lambda}$$

In this case, w = 16/(5 + 2λ) is never exactly 0 for any finite λ; it only approaches 0 as λ tends to infinity. Hence, L2 regularization shrinks parameters towards 0 but rarely pushes them to exactly 0.
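To see this numerically, a small stdlib-only Python sketch evaluates the closed form w = 16/(5 + 2λ) for increasingly large λ (values chosen arbitrarily) and checks that the derivative 5w − 16 + 2λw vanishes there.

```python
def w_l2(lam):
    """Minimizer of the L2-regularized cost, from 5w - 16 + 2*lam*w = 0."""
    return 16 / (5 + 2 * lam)

def slope_l2(w, lam):
    """Derivative of the L2-regularized cost: 5w - 16 + 2*lam*w."""
    return 5 * w - 16 + 2 * lam * w

for lam in [0, 1, 10, 100, 1_000, 1_000_000]:
    w = w_l2(lam)
    print(f"lambda={lam:>9,}  w={w:.8f}  slope at w={slope_l2(w, lam):+.1e}")
```

Even at λ = 1,000,000 the weight is about 8 × 10⁻⁶: tiny, but not 0, whereas the L1 minimizer already hits exactly 0 at λ = 16.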

With Gratitude
To my family and friends for their support

Sources