Classifying handwritten digits
This notebook shows how giotto-tda can be used to generate
topological features for image classification. We’ll be using the famous
MNIST dataset, which contains images of handwritten digits and is a
standard benchmark for testing new classification algorithms.
Figure 1: A few digits from the MNIST dataset. Figure reference: en.wikipedia.org/wiki/MNIST_database.
If you are looking at a static version of this notebook and would like to run its contents, head over to GitHub.
Useful references
A Topological “Reading” Lesson: Classification of MNIST using TDA by Adélie Garin and Guillaume Tauzin
The MNIST Database of Handwritten Digits by Yann LeCun, Corinna Cortes, and Christopher J.C. Burges
License: AGPLv3
Load the MNIST dataset
To get started, let’s fetch the MNIST dataset using one of
scikit-learn’s helper functions:
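from sklearn.datasets import fetch_openml
# as_frame=False returns NumPy arrays; recent scikit-learn versions return
# pandas objects by default, which breaks the row indexing used later on
X, y = fetch_openml("mnist_784", version=1, return_X_y=True, as_frame=False)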
By looking at the shapes of these arrays
print(f"X shape: {X.shape}, y shape: {y.shape}")
X shape: (70000, 784), y shape: (70000,)
we see that there are 70,000 images, each described by 784 features
that hold its pixel intensities (28 x 28 = 784). Let’s reshape the
feature vector of one image into a 28x28 array and visualise one of the
“8” digits using giotto-tda’s plotting API:
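import numpy as np
from gtda.plotting import plot_heatmap
# Index of the first sample labelled "8" (the labels are strings)
im8_idx = np.flatnonzero(y == "8")[0]
# Reshape its flat 784-pixel vector into a 28x28 image
img8 = X[im8_idx].reshape(28, 28)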
plot_heatmap(img8)
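To make the indexing and reshaping step above easier to check without downloading the dataset, here is a minimal sketch on synthetic data. The names `X_demo`, `y_demo`, `first8_idx` and `img_demo` are hypothetical stand-ins, not part of the notebook, and the random values only mimic the layout of the MNIST arrays (flat 784-pixel rows, string labels). It assumes NumPy's default row-major order for `reshape`.

```python
import numpy as np

# Synthetic stand-in for the MNIST arrays (hypothetical data, not real digits):
# 5 "images" of 784 pixel intensities each, with string labels.
rng = np.random.default_rng(0)
X_demo = rng.integers(0, 256, size=(5, 784)).astype(float)
y_demo = np.array(["3", "8", "1", "8", "0"])

# Index of the first sample labelled "8". The labels are strings,
# so the comparison is against "8" rather than the integer 8.
first8_idx = np.flatnonzero(y_demo == "8")[0]

# Reshape the flat 784-vector into a 28x28 grid (row-major by default).
img_demo = X_demo[first8_idx].reshape(28, 28)

# In row-major order, pixel (row r, column c) of the image is
# element 28 * r + c of the flat feature vector.
r, c = 3, 5
same_pixel = img_demo[r, c] == X_demo[first8_idx, 28 * r + c]

print(first8_idx, img_demo.shape, same_pixel)
```

Flattening `img_demo` again with `img_demo.ravel()` gives back the original feature vector, so the reshape only changes how the same 784 values are addressed.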