Classifying handwritten digits
==============================

This notebook shows how giotto-tda can be used to generate topological features for image classification. We’ll be using the famous MNIST dataset, which contains images of handwritten digits and is a standard benchmark for testing new classification algorithms.

Figure 1: A few digits from the MNIST dataset. Figure reference: en.wikipedia.org/wiki/MNIST_database.

If you are looking at a static version of this notebook and would like to run its contents, head over to GitHub.

Useful references
-----------------

- A Topological “Reading” Lesson: Classification of MNIST using TDA, by Adélie Garin and Guillaume Tauzin
- The MNIST Database of Handwritten Digits, by Yann LeCun, Corinna Cortes, and Christopher J.C. Burges

**License: AGPLv3**

Load the MNIST dataset
----------------------

To get started, let’s fetch the MNIST dataset using one of scikit-learn’s helper functions:

.. code:: ipython3

    from sklearn.datasets import fetch_openml

    # as_frame=False returns NumPy arrays rather than a pandas DataFrame
    X, y = fetch_openml("mnist_784", version=1, return_X_y=True, as_frame=False)

By looking at the shapes of these arrays

.. code:: ipython3

    print(f"X shape: {X.shape}, y shape: {y.shape}")

.. parsed-literal::

    X shape: (70000, 784), y shape: (70000,)

we see that there are 70,000 images, where each image has 784 features that represent pixel intensity. Let’s reshape the feature vector of a single image to a 28x28 array and visualise one of the “8” digits using giotto-tda’s plotting API:

.. code:: ipython3

    import numpy as np
    from gtda.plotting import plot_heatmap

    # Index of the first image labelled "8" (the labels are strings)
    im8_idx = np.flatnonzero(y == "8")[0]
    img8 = X[im8_idx].reshape(28, 28)
    plot_heatmap(img8)
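The selection and reshape step can be tried without downloading the dataset. The following is a minimal sketch that uses a small random array as a stand-in for MNIST. The names ``X_demo``, ``y_demo``, ``first8_idx`` and ``img_demo`` are hypothetical and only mirror the shapes and label type described above.

.. code:: ipython3

    import numpy as np

    # Synthetic stand-in for MNIST (an assumption for illustration only):
    # five flattened 28x28 "images" with string labels.
    rng = np.random.default_rng(0)
    X_demo = rng.integers(0, 256, size=(5, 784)).astype(float)
    y_demo = np.array(["3", "8", "1", "8", "0"])

    # np.flatnonzero returns every matching index; take the first one
    # so that a single row of 784 values is selected.
    first8_idx = np.flatnonzero(y_demo == "8")[0]

    # One row of 784 values reshapes to 28 x 28 in row-major order,
    # so pixel (r, c) comes from flat position 28 * r + c.
    img_demo = X_demo[first8_idx].reshape(28, 28)
    print(first8_idx, img_demo.shape)

Selecting a single index before reshaping matters because 784 values fill a 28x28 array exactly, whereas a selection of several rows would not fit that shape.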