Classifying handwritten digits
==============================
This notebook shows how ``giotto-tda`` can be used to generate
topological features for image classification. We’ll be using the famous
MNIST dataset, which contains images of handwritten digits and is a
standard benchmark for testing new classification algorithms.
Figure 1: A few digits from the MNIST dataset. Figure reference:
en.wikipedia.org/wiki/MNIST_database.

If you are looking at a static version of this notebook and would like
to run its contents, head over to
`GitHub `__.
Useful references
-----------------
- `A Topological “Reading” Lesson: Classification of MNIST using
  TDA `__ by Adélie Garin and
  Guillaume Tauzin
- `The MNIST Database of Handwritten
  Digits `__ by Yann LeCun, Corinna
  Cortes, and Christopher J.C. Burges

**License: AGPLv3**
Load the MNIST dataset
----------------------
To get started, let’s fetch the MNIST dataset using one of
``scikit-learn``\ ’s helper functions:
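.. code:: ipython3

    from sklearn.datasets import fetch_openml

    # as_frame=False returns NumPy arrays; recent scikit-learn versions
    # otherwise return pandas objects, which break the X[idx] indexing below
    X, y = fetch_openml("mnist_784", version=1, return_X_y=True, as_frame=False)
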
Let’s check the shapes of these arrays:

.. code:: ipython3

    print(f"X shape: {X.shape}, y shape: {y.shape}")

.. parsed-literal::

    X shape: (70000, 784), y shape: (70000,)

There are 70,000 images, and each image is a flattened vector of
784 = 28 × 28 features holding grayscale pixel intensities. Let’s reshape
one of these vectors back into a 28x28 array and visualise one of the “8”
digits using ``giotto-tda``\ ’s plotting API:

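.. code:: ipython3

    import numpy as np
    from gtda.plotting import plot_heatmap

    # Index of the first image labelled "8" (the labels are strings)
    im8_idx = np.flatnonzero(y == "8")[0]
    # Unflatten its 784 pixel values into a 28x28 image
    img8 = X[im8_idx].reshape(28, 28)
    plot_heatmap(img8)
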
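The reshape step relies on NumPy’s default row-major (C-order) layout: flat
pixel ``k`` of an image lands at row ``k // 28`` and column ``k % 28``. The
sketch below checks this on toy arrays (made-up values, not MNIST) shaped
like the arrays ``fetch_openml`` returns with ``as_frame=False``, and wraps
the label lookup in a small helper, ``first_image_of``, which is a
hypothetical name of our own and not part of ``giotto-tda`` or
``scikit-learn``. Note that the labels are compared as strings, as in the
cell above.

.. code:: ipython3

    import numpy as np

    def first_image_of(X, y, label, side=28):
        """Return the first row of X whose label equals `label`, as a side x side image.

        Hypothetical helper (not part of giotto-tda or scikit-learn).
        Raises IndexError if no label matches.
        """
        idx = np.flatnonzero(y == label)[0]
        # NumPy reshapes in row-major (C) order by default, so flat pixel k
        # ends up at row k // side, column k % side.
        return X[idx].reshape(side, side)

    # Toy stand-ins shaped like MNIST (made-up values, not real digits):
    # three flattened 784-pixel "images" and their string labels.
    X_toy = np.arange(3 * 784).reshape(3, 784)
    y_toy = np.array(["3", "8", "8"], dtype=object)

    img = first_image_of(X_toy, y_toy, "8")
    row, col = divmod(100, 28)  # flat pixel 100 -> (3, 16)
    print(img.shape, img[row, col] == X_toy[1, 100])

Because the helper only needs a 2D array of flattened images and a matching
1D label array, the same call should work on the real ``X`` and ``y`` loaded
above, e.g. ``first_image_of(X, y, "8")``.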