Classifying handwritten digits

This notebook shows how giotto-tda can be used to generate topological features for image classification. We’ll be using the famous MNIST dataset, which contains images of handwritten digits and is a standard benchmark for testing new classification algorithms.

Figure 1: A few digits from the MNIST dataset. Figure reference:

If you are looking at a static version of this notebook and would like to run its contents, head over to GitHub.

License: AGPLv3

Load the MNIST dataset

To get started, let’s fetch the MNIST dataset using one of scikit-learn’s helper functions:

from sklearn.datasets import fetch_openml

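# as_frame=False returns NumPy arrays, which the indexing and reshaping below rely on
X, y = fetch_openml("mnist_784", version=1, return_X_y=True, as_frame=False)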

By looking at the shapes of these arrays

print(f"X shape: {X.shape}, y shape: {y.shape}")
X shape: (70000, 784), y shape: (70000,)

we see that there are 70,000 images, each described by 784 features that hold its pixel intensities. Let’s reshape one of these feature vectors into a 28x28 array and visualise one of the “8” digits using giotto-tda’s plotting API:

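import numpy as np
from gtda.plotting import plot_heatmap

# The labels are strings, so pick the first sample labelled "8"
im8_idx = np.flatnonzero(y == "8")[0]
# Recover the 28x28 pixel grid from the flat 784-feature vector
img8 = X[im8_idx].reshape(28, 28)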
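
The selection and reshaping steps above can be checked without downloading MNIST. The following is a minimal sketch on a small synthetic stand-in; the names `X_demo`, `y_demo`, `first_8` and `img_demo` are hypothetical and not part of the notebook. It assumes the 784 features are stored row by row, which is what `reshape(28, 28)` implies.

```python
import numpy as np

# Hypothetical stand-in for MNIST: 3 "images" of 784 pixels with string labels
rng = np.random.default_rng(0)
X_demo = rng.integers(0, 256, size=(3, 784))
y_demo = np.array(["5", "8", "8"])

# The labels are strings, so compare against "8" rather than the integer 8
first_8 = np.flatnonzero(y_demo == "8")[0]

# reshape fills in row-major order: flat pixel k lands at row k // 28, column k % 28
img_demo = X_demo[first_8].reshape(28, 28)

k = 100
assert img_demo[k // 28, k % 28] == X_demo[first_8, k]
print(first_8, img_demo.shape)
```

Comparing the labels against the integer `8` instead of the string `"8"` would match nothing, which is why the string comparison matters here.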