Chapter 14 Machine Learning Workflow
We assume you have loaded the following packages:
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
Below we load more packages as we introduce them.
14.1 Boston Housing Data
Here we demonstrate a small task: selecting the best linear regression model using validation. We use the Boston Housing dataset.
boston = pd.read_csv("../data/boston.csv.bz2", sep="\t")
boston.head(3)
## crim zn indus chas nox rm age dis rad tax ptratio black lstat medv
## 0 0.00632 18.0 2.31 0 0.538 6.575 65.2 4.0900 1 296 15.3 396.90 4.98 24.0
## 1 0.02731 0.0 7.07 0 0.469 6.421 78.9 4.9671 2 242 17.8 396.90 9.14 21.6
## 2 0.02729 0.0 7.07 0 0.469 7.185 61.1 4.9671 2 242 17.8 392.83 4.03 34.7
Our task is to predict the average house value medv as well as we can using all other features. We pick a subset of features:
= boston[["age", "rm", "zn", "medv"]] boston
And we add one more feature, namely \(\mathit{age}\times\mathit{rm}\):
"ageXrm"] = boston.age * boston.rm boston[
Now we have a dataset that looks like this:
boston.sample(3)
## age rm zn medv ageXrm
## 357 91.0 6.395 0.0 21.7 581.9450
## 337 59.6 5.895 0.0 18.5 351.3420
## 327 43.7 6.083 0.0 22.2 265.8271
First we demonstrate the workflow using the training-validation approach. We split the data into training and validation parts:
from sklearn.model_selection import train_test_split
y = boston.medv
X = boston[["age", "rm", "zn", "ageXrm"]]
yt, yv, Xt, Xv = train_test_split(y, X)
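Note that train_test_split splits every array it is given and returns the pieces in input order; this is why passing (y, X) yields yt, yv, Xt, Xv. A quick sanity check (a sketch, not part of the original workflow):

yt.shape, Xt.shape  # training labels and training features must have the same number of rows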
Now we can test different models in terms of \(R^2\):
from sklearn.linear_model import LinearRegression
m = LinearRegression()
_ = m.fit(Xt, yt)
m.score(Xv, yv)
This was the model with all variables. We can try other combinations of variables:
= boston[["age", "rm", "zn"]] # leave out age x rm X
LinearRegression()In a Jupyter environment, please rerun this cell to show the HTML representation or trust the notebook.
On GitHub, the HTML representation is unable to render, please try loading this page with nbviewer.org.
LinearRegression()
= train_test_split(y, X) yt, yv, Xt, Xv
LinearRegression()In a Jupyter environment, please rerun this cell to show the HTML representation or trust the notebook.
On GitHub, the HTML representation is unable to render, please try loading this page with nbviewer.org.
LinearRegression()
= m.fit(Xt, yt) _
LinearRegression()In a Jupyter environment, please rerun this cell to show the HTML representation or trust the notebook.
On GitHub, the HTML representation is unable to render, please try loading this page with nbviewer.org.
LinearRegression()
m.score(Xv, yv)
LinearRegression()In a Jupyter environment, please rerun this cell to show the HTML representation or trust the notebook.
On GitHub, the HTML representation is unable to render, please try loading this page with nbviewer.org.
LinearRegression()
and
= boston[["rm"]] # use only rm X
LinearRegression()In a Jupyter environment, please rerun this cell to show the HTML representation or trust the notebook.
On GitHub, the HTML representation is unable to render, please try loading this page with nbviewer.org.
LinearRegression()
= train_test_split(y, X) yt, yv, Xt, Xv
LinearRegression()In a Jupyter environment, please rerun this cell to show the HTML representation or trust the notebook.
On GitHub, the HTML representation is unable to render, please try loading this page with nbviewer.org.
LinearRegression()
= m.fit(Xt, yt) _
LinearRegression()In a Jupyter environment, please rerun this cell to show the HTML representation or trust the notebook.
On GitHub, the HTML representation is unable to render, please try loading this page with nbviewer.org.
LinearRegression()
m.score(Xv, yv)
LinearRegression()In a Jupyter environment, please rerun this cell to show the HTML representation or trust the notebook.
On GitHub, the HTML representation is unable to render, please try loading this page with nbviewer.org.
LinearRegression()
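The three comparisons above repeat the same pattern, so they can be condensed into a loop. A minimal sketch, assuming we fix random_state (the seed value is arbitrary) so that every candidate feature set is evaluated on the same rows:

candidates = [["age", "rm", "zn", "ageXrm"],
              ["age", "rm", "zn"],
              ["rm"]]
for features in candidates:
    # the same rows go to training/validation for every candidate
    Xt, Xv, yt, yv = train_test_split(boston[features], y, random_state=1)
    m = LinearRegression().fit(Xt, yt)
    print(features, m.score(Xv, yv))

Fixing the seed makes the comparison reproducible; with the random splits above, the scores vary somewhat from run to run.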
14.2 Categorization: image recognition
Here we analyze MNIST digits. This is a dataset of handwritten digits, widely used for computer vision tasks. sklearn contains a low-resolution sample of this dataset:
from sklearn.datasets import load_digits
mnist = load_digits()
X = mnist.data
y = mnist.target
This loads the dataset and extracts the design matrix X and the labels y. We can take a look at the data with
X.shape
y[:25]
The shape of X tells us that we have 1797 different digits, each of which contains 64 features. These features are pixels: each image consists of \(8\times8\) pixels, and in the design matrix X the images are flattened into 64 consecutive pixel features. A sample of the data looks like this:
X[7:9]
You can see many “0”-s (background) and numbers between “1” and “15”, denoting various pen intensities. We can easily plot these images; we just have to reshape them back into \(8\times8\) matrices. This is what that leads to:
X[7].reshape((8,8))
If you look at the matrix closely, you can see that it depicts the number “7”. This is much easier to see if we plot the result:
plt.imshow(X[7].reshape((8,8)), cmap="gray_r")
This is a low-resolution image of the number “7”. Here is a larger example, the first 30 digits:
for i in range(30):
    ax = plt.subplot(3, 10, i+1)
    _ = ax.imshow(X[i].reshape((8, 8)), cmap='gray_r')
    _ = ax.axis("off")
    _ = ax.set_title(f"A: {y[i]}")
_ = plt.show()
As you can see, the numbers are of low quality and a bit hard for us to recognize. The computer will nevertheless do very well: our brains are trained on high-quality images, not on low-quality images like these.
As a first step, let’s take an easy task and separate “0”-s and “1”-s. We’ll test a few different models in terms of how well they perform. Extract all “0”-s and “1”-s:
# find all 0,1-s
i = np.isin(y, [0, 1])
Xd01 = X[i]
yd01 = y[i]
Next, we do a training-validation split to validate our predictions:
Xt, Xv, yt, yv = train_test_split(Xd01, yd01)
Xt.shape, Xv.shape
## ((270, 64), (90, 64))
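By default, train_test_split reserves 25% of the cases for the second part. The proportion can also be set explicitly, for instance (a sketch; test_size=0.25 reproduces the default, and the random_state value is arbitrary):

Xt, Xv, yt, yv = train_test_split(Xd01, yd01, test_size=0.25, random_state=1)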
We can use logistic regression for this binary classification task, and after we fit the model, we compute accuracy:
from sklearn.linear_model import LogisticRegression
m = LogisticRegression()
_ = m.fit(Xt, yt)  # fit on 270 training cases
m.score(Xv, yv)  # validate on 90 validation cases
We got a perfect score, 1.0. This means that the model can perfectly distinguish between these two digits. Indeed, these digits are easy to distinguish: the pixel patterns look substantially different.
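For classifiers, the score method reports accuracy, i.e. the share of correct predictions. The same number can be computed by hand, as a small sketch:

np.mean(m.predict(Xv) == yv)  # fraction of correct predictions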
Let us try a more challenging task: distinguishing between “4”-s and “9”-s:
i = np.isin(y, [4, 9])
Xd01 = X[i]
yd01 = y[i]
Xt, Xv, yt, yv = train_test_split(Xd01, yd01)
_ = m.fit(Xt, yt)
m.score(Xv, yv)
The result is still ridiculously good, with only a single wrong prediction, as shown by the confusion matrix:
from sklearn.metrics import confusion_matrix
yhat = m.predict(Xv)
confusion_matrix(yv, yhat)
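In sklearn's convention, the rows of the confusion matrix are the true labels and the columns are the predicted labels, in the order given by m.classes_. If desired, the matrix can be wrapped into a labeled DataFrame, a small sketch:

pd.DataFrame(confusion_matrix(yv, yhat),
             index=["true " + str(c) for c in m.classes_],
             columns=["pred " + str(c) for c in m.classes_])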
The mis-categorized image is
iw = yhat != yv  # which predictions are wrong
plt.imshow(Xv[iw].reshape((8, 8)), cmap='gray_r')
plt.axis("off")
plt.show()
Indeed, even human eyes cannot tell what the digit is.
Logistic regression can handle more than two categories; this is called multinomial logit. So instead of distinguishing between just two types of digits, we can categorize all 10 different digits:
from warnings import simplefilter
from sklearn.exceptions import ConvergenceWarning
simplefilter("ignore", category=ConvergenceWarning)
Xt, Xv, yt, yv = train_test_split(X, y)
m = LogisticRegression()
_ = m.fit(Xt, yt)
m.score(Xv, yv)
The results are still very good, although we got more than a single case wrong. The confusion matrix is:
yhat = m.predict(Xv)
confusion_matrix(yv, yhat)
We can see that by far most cases are on the main diagonal: the model gets most cases right. The most problematic errors are mispredicting “8” as “1”.
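Beyond raw counts, per-digit precision and recall may be easier to read from a classification report, a standard sklearn utility; a minimal sketch:

from sklearn.metrics import classification_report
print(classification_report(yv, yhat))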
14.3 Training-validation-testing approach
We start by separating out the testing, or hold-out, data:
Xw, Xh, yw, yh = train_test_split(X, y)
Xw.shape
Now we do not touch the hold-out data until the very end. Instead, we split the work-data into training and validation parts:
Xt, Xv, yt, yv = train_test_split(Xw, yw)
Xt.shape
Now we can test different models on the training-validation data. We begin with logistic regression; C is the inverse of the regularization strength, so a huge value like 1e9 means essentially no regularization:
mLogistic = LogisticRegression(C=1e9)
_ = mLogistic.fit(Xt, yt)
mLogistic.score(Xv, yv)
The beauty of sklearn is that it is easy to try different models. Let’s try a single-nearest-neighbor classifier:
from sklearn.neighbors import KNeighborsClassifier
m1NN = KNeighborsClassifier(1)
_ = m1NN.fit(Xt, yt)
m1NN.score(Xv, yv)
This one achieved a very good score on the validation data. What about 5 nearest neighbors?
m5NN = KNeighborsClassifier(5)
_ = m5NN.fit(Xt, yt)
m5NN.score(Xv, yv)
The accuracy is almost as good as in the single-nearest-neighbor case. We can also try decision trees:
from sklearn.tree import DecisionTreeClassifier
mTree = DecisionTreeClassifier()
_ = mTree.fit(Xt, yt)
mTree.score(Xv, yv)
Trees were clearly inferior here.
Instead of a single training-validation split, we can use cross-validation:
from sklearn.model_selection import cross_val_score
cross_val_score(mLogistic, Xw, yw, cv=5).mean()
cross_val_score(m1NN, Xw, yw, cv=5).mean()
cross_val_score(m5NN, Xw, yw, cv=5).mean()
cross_val_score(mTree, Xw, yw, cv=5).mean()
Cross-validation basically replicates the training-validation split results: the best model again appears to be 1-NN, although its lead over 5-NN is tiny. We pick 1-NN as our preferred model.
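The four calls above can also be condensed into a loop; a minimal sketch (the dictionary keys are just labels for printing):

models = {"logistic": mLogistic, "1-NN": m1NN, "5-NN": m5NN, "tree": mTree}
for name, model in models.items():
    print(name, cross_val_score(model, Xw, yw, cv=5).mean())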
Finally, the hold-out data gives us the final performance measure:
m1NN.score(Xh, yh)
Now that we have computed the final model accuracy, we should not change the model any more.