Linear Regression: A Tale of a Transform

What is linear regression?

In machine learning, we learn (fit) a model using only a small sample of the data, called the training set. The model is then used to make predictions about similar but unseen data. Linear regression is a model that represents data using a straight line. It got its name for historical reasons; in plain language it simply means line fitting.

But why do we need to learn, or fit, a model (in our case, a line)? Because a model lets us approximate the data using only a few quantities (parameters) and draw conclusions that might not be evident from looking at the bulk of the data.
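To make this concrete, here is a minimal sketch (my own illustration, with made-up toy numbers) of how five noisy data points can be summarized by just two parameters, a slope and an intercept, using NumPy's polyfit:

```python
import numpy as np

# Toy data: a few noisy points that roughly follow y = 2x + 1
x = np.array([0.0, 1.0, 2.0, 3.0, 4.0])
y = np.array([1.1, 2.9, 5.2, 6.8, 9.1])

# Fit a degree-1 polynomial, i.e. a straight line, to the points
slope, intercept = np.polyfit(x, y, deg=1)

# The whole dataset is now approximated by just two numbers
print(slope, intercept)
```

Those two numbers are enough to predict y for any new x, which is exactly the kind of conclusion the raw table of points does not show at a glance.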

In the rest of the post I will show you why linear regression counts as machine learning, when it is appropriate to use it, and how it actually works.

Most introductions to linear regression use matrix algebra. In this post, however, I will use simple algebra and calculus to derive all the equations. Conveniently, Python has a powerful symbolic math library called SymPy that lets us do these calculations just as we would with pen and paper.
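As a taste of that pen-and-paper style, here is a small sketch (my own example, not part of the post's derivation) where SymPy symbolically minimizes a sum of squared errors for a line through three toy points, by setting both partial derivatives to zero:

```python
import sympy as sp

# Unknown slope m and intercept b of the line y = m*x + b
m, b = sp.symbols('m b')

# Three toy points, chosen so they lie exactly on y = 2x + 1
points = [(0, 1), (1, 3), (2, 5)]

# Sum of squared errors between the line and the data
error = sum((y - (m * x + b))**2 for x, y in points)

# Differentiate, set to zero, and solve -- pure calculus, done symbolically
solution = sp.solve([sp.diff(error, m), sp.diff(error, b)], [m, b])
print(solution)
```

Because the points lie exactly on a line, SymPy recovers m = 2 and b = 1; with real, noisy data the same procedure gives the best-fitting line.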

The data we will use here comes from CMU's Body Fat dataset, stored in a CSV file. To work with this data inside the notebook, I use the Pandas library.

First I import the Pandas library, then use its read_csv() function to load the data, and finally print the first few rows.

In [1]:
import pandas as pn
In [104]:
# Read data and print some rows
dataframe = pn.read_csv('bodyfat.csv', sep="\s+")
dataframe.head()
Out[104]:
Density Body fat (%) Age (years) Weight (lbs) Height (inches) Neck (cm) Chest (cm) Abdomen (cm) Hip (cm) Thigh (cm) Knee (cm) Ankle (cm) Biceps (cm) Forearm (cm) Wrist (cm)
0 1.0708 12.3 23 154.25 67.75 36.2 93.1 85.2 94.5 59.0 37.3 21.9 32.0 27.4 17.1
1 1.0853 6.1 22 173.25 72.25 38.5 93.6 83.0 98.7 58.7 37.3 23.4 30.5 28.9 18.2
2 1.0414 25.3 22 154.00 66.25 34.0 95.8 87.9 99.2 59.6 38.9 24.0 28.8 25.2 16.6
3 1.0751 10.4 26 184.75 72.25 37.4 101.8 86.4 101.2 60.1 37.3 22.8 32.4 29.4 18.2
4 1.0340 28.7 24 184.25 71.25 34.4 97.3 100.0 101.9 63.2 42.2 24.0 32.2 27.7 17.7

As we can see from the first five rows, there are a lot of columns. It would be interesting to see how the column values are related to each other. A good approach is to use a scatter-plot matrix, which visualizes the pairwise relationships between columns; data visualization is a crucial step when working with machine learning algorithms. The code below extracts a subset of the data, columns 3 through 7 (Weight, Height, Neck, Chest, and Abdomen), and visualizes the relationship between every pair of columns from this subset using the Plotly library.

In [3]:
import numpy as np
# Select columns 3 through 7 (Weight, Height, Neck, Chest, Abdomen)
df = pn.DataFrame(dataframe, columns=list(dataframe.columns.values)[3:8])
In [4]:
import plotly as py
import plotly.figure_factory as ff
import plotly.graph_objs as go
py.offline.init_notebook_mode()

fig = ff.create_scatterplotmatrix(df, diag='histogram', height=800, width=800)
py.offline.iplot(fig)
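As a numerical complement to the scatter-plot matrix (an addition of mine, not part of the original notebook), pandas can also quantify how strongly each pair of columns is related via a correlation matrix. The sketch below uses a tiny hypothetical frame built from the first five rows of two of the columns shown above; in the notebook itself you would simply call corr() on df:

```python
import pandas as pd

# Hypothetical miniature frame using values from the table above;
# in the notebook, df.corr() computes this for all selected columns
small = pd.DataFrame({
    'Weight (lbs)': [154.25, 173.25, 154.00, 184.75, 184.25],
    'Height (inches)': [67.75, 72.25, 66.25, 72.25, 71.25],
})

# Pearson correlation between every pair of columns
corr = small.corr()
print(corr)
```

A value close to 1 for a pair of columns means the scatter plot for that pair will look nearly linear, which is exactly what makes a column a good candidate for linear regression.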