Hiren Nandigama's blog

Feature Scaling: Intuition behind the math

I originally posted this article on Notion; I'm sharing it here to keep everything in one place.

We’ve all been asked to scale our features to speed up training and make the cost decrease steadily. Ever wondered why scaling plays such an important role during model training? Let’s find out!

💭 This blog assumes that you are familiar with the concepts of the cost function and the gradient descent algorithm.

In multivariate linear regression we deal with n>=2 features. These features represent independent variables that have an effect on the dependent variable. Depending on the dataset, each of the features could have a different range of values. For example:

| area in sq. ft. (x1) | no. of bedrooms (x2) | price in $ (y) |
| --- | --- | --- |
| 1800 | 3 | 5000 |
| 1500 | 2 | 4500 |
| 2500 | 5 | 5600 |

The cost function for this data, with parameters θ0, θ1, θ2, would look something like this, assuming we use batch gradient descent with a batch size of 3 (the full dataset):

J(θ0, θ1, θ2) = (1/6)·[(θ0 + θ1·1800 + θ2·3 − 5000)² + (θ0 + θ1·1500 + θ2·2 − 4500)² + (θ0 + θ1·2500 + θ2·5 − 5600)²]

If we simplify the above equation, all we have is a sum of quadratic terms, each of the form

(θ0 + θ1·x1 + θ2·x2 − y)²
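To make this concrete, here is a minimal Python sketch (my own, not from the original post) that evaluates this cost on the three-row dataset above; the 1/6 factor is 1/(2m) with m = 3 examples.

```python
# Illustrative sketch: mean squared error cost for the three-example
# housing dataset, using the 1/(2m) convention so m = 3 gives the 1/6 factor.
rows = [(1800, 3, 5000), (1500, 2, 4500), (2500, 5, 5600)]  # (x1, x2, y)

def cost(t0, t1, t2):
    m = len(rows)
    return sum((t0 + t1 * x1 + t2 * x2 - y) ** 2 for x1, x2, y in rows) / (2 * m)
```

At θ = (0, 0, 0), for instance, this reduces to the mean of the squared targets divided by two.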

In our dataset, x1 takes relatively large values compared to x2. This makes the shape of the cost function J such that θ1 corresponds to the narrow axis and θ2 to the long axis. The plot looks something like this, where the x-axis represents θ1 and the y-axis represents θ2:

feature-scaling-1 https://www.desmos.com/3d/wthv10wjki

It’s clear that small changes in θ1 have a big impact on J, while even relatively large changes in θ2 have only a small impact. The rate of change of J along the direction of each parameter, i.e., the gradient components, is given by

∂J/∂θ0 = 2·(θ0 + θ1·x1 + θ2·x2 − y)
∂J/∂θ1 = 2·(θ0 + θ1·x1 + θ2·x2 − y)·x1
∂J/∂θ2 = 2·(θ0 + θ1·x1 + θ2·x2 − y)·x2

🧠 If f is a multivariate function of x1 and x2, the partial derivative of f w.r.t. x1 is just a way of asking how much f changes with x1 while keeping x2 constant.
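The scale effect shows up directly if we code these components up (a hypothetical helper of my own, not from the post): the error factor is shared, so only the trailing x1 or x2 decides the relative size.

```python
# Sketch: gradient of a single example's squared error w.r.t. (θ0, θ1, θ2).
# The error term is shared across all three components; only the trailing
# x1 / x2 factor differs, so feature scale alone sets their relative size.
def gradient(t0, t1, t2, x1, x2, y):
    err = t0 + t1 * x1 + t2 * x2 - y
    return (2 * err, 2 * err * x1, 2 * err * x2)

g0, g1, g2 = gradient(0.0, 0.0, 0.0, 1800, 3, 5000)
# For this first row, g1 is x1/x2 = 600 times larger in magnitude than g2.
```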

Visually, ∂J/∂θ1 looks like this, keeping θ2 constant, where the plane cuts the 3-D surface:

feature-scaling-2 https://www.desmos.com/3d/yojzjugcw5

Visually, ∂J/∂θ2 looks like this, keeping θ1 constant, where the plane cuts the 3-D surface:

feature-scaling-3 https://www.desmos.com/3d/kwp9c9zzrm

∂J/∂θ1 is larger because the rate of change of J is higher w.r.t. small changes in θ1. Conversely, ∂J/∂θ2 is smaller because the rate of change of J is lower w.r.t. changes in θ2. Visually, projecting onto a 2D axis, we can see that along the narrow axis θ1 (red), a small change has a huge impact on J (vertical axis). Along the long axis θ2 (black), to change J by the same amount, we need a relatively large change in θ2.

feature-scaling-4 https://www.desmos.com/calculator/i6y8fwmd3m

Let’s plug this information into our parameter update step using gradient descent

θ1 = θ1 − α·∂J/∂θ1
θ2 = θ2 − α·∂J/∂θ2

Faster convergence means having a relatively larger learning rate α. But we run into an issue with this: if α is too large, θ1 will overshoot, since ∂J/∂θ1 is also large. If, to counteract this, we choose a relatively smaller learning rate α, θ2 will update very slowly, leading to very slow convergence overall.
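This dilemma is easy to reproduce. The sketch below (my own illustrative code; the specific α thresholds depend on the data) runs batch gradient descent on the raw dataset: an α large enough to move θ2 at a useful pace already overshoots along the θ1 direction and blows the cost up, while a safe α converges only very slowly.

```python
# Sketch: batch gradient descent on the raw (unscaled) housing data.
rows = [(1800, 3, 5000), (1500, 2, 4500), (2500, 5, 5600)]  # (x1, x2, y)

def cost(theta):
    m = len(rows)
    return sum((theta[0] + theta[1] * x1 + theta[2] * x2 - y) ** 2
               for x1, x2, y in rows) / (2 * m)

def step(theta, alpha):
    # One batch gradient descent update over all m examples.
    m = len(rows)
    g = [0.0, 0.0, 0.0]
    for x1, x2, y in rows:
        err = theta[0] + theta[1] * x1 + theta[2] * x2 - y
        g[0] += err / m
        g[1] += err * x1 / m
        g[2] += err * x2 / m
    return [t - alpha * gi for t, gi in zip(theta, g)]

def run(alpha, iters=50):
    theta = [0.0, 0.0, 0.0]
    for _ in range(iters):
        theta = step(theta, alpha)
    return cost(theta)

# With these numbers, alpha = 1e-6 diverges (the cost explodes along θ1),
# while alpha = 1e-7 keeps the cost decreasing, but only very slowly.
```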

That is why we need to scale our features. Scaling helps us choose a good learning rate without having to worry about its uneven effect on each of the parameters. This leads to faster & more stable convergence.

feature-scaling-5 https://www.desmos.com/calculator/axjjlwtgyy
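One common way to scale, sketched below with my own helper names, is z-score standardization: subtract each feature's mean and divide by its standard deviation, so both features end up with mean 0 and standard deviation 1 and the bowl of J becomes comparably wide along both axes.

```python
# Sketch of z-score standardization (one common scaling choice).
rows = [(1800, 3, 5000), (1500, 2, 4500), (2500, 5, 5600)]  # (x1, x2, y)

def standardize(values):
    # Shift to mean 0, then divide by the (population) standard deviation.
    mean = sum(values) / len(values)
    std = (sum((v - mean) ** 2 for v in values) / len(values)) ** 0.5
    return [(v - mean) / std for v in values]

x1_scaled = standardize([x1 for x1, _, _ in rows])
x2_scaled = standardize([x2 for _, x2, _ in rows])
# Both columns now have mean 0 and standard deviation 1, so a single
# learning rate moves θ1 and θ2 at comparable speeds.
```

Libraries expose the same transform ready-made, e.g. scikit-learn's `StandardScaler`.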

With Gratitude
To my family and friends for their support

Sources
Partial derivatives, introduction by Khan Academy