Gravitational wave detection¶
Christopher Bresten and Jae-Hun Jung propose including topological features in a CNN classifier for gravitational wave detection. Adapted from their article, this notebook showcases an application of ideas from the Topology of time series example.
If you are looking at a static version of this notebook and would like to run its contents, head over to GitHub and download the source.
Useful references¶
Topology of time series, in which the Takens embedding technique used here is explained in detail and illustrated via simple examples.
Detection of gravitational waves using topological data analysis and convolutional neural network: An improved approach by Christopher Bresten and Jae-Hun Jung. We thank Christopher Bresten for sharing the code and data used in the article.
See also¶
Topology in time series forecasting, in which the Takens embedding technique is used in time series forecasting tasks by using sliding windows.
Topological feature extraction using VietorisRipsPersistence and PersistenceEntropy for a quick introduction to general topological feature extraction in giotto-tda.
License: AGPLv3
Motivation¶
The videos below show different representations of the gravitational waves that we aim to detect. Our goal is to pick out the “chirp” signal of two colliding black holes from a very noisy background.
from IPython.display import YouTubeVideo
YouTubeVideo("Y3eR49ogsF0", width=600, height=400)
YouTubeVideo("QyDcTbR-kEA", width=600, height=400)
Generate the data¶
In the article, the authors create a synthetic training set as follows:
Generate gravitational wave signals that correspond to non-spinning binary black hole mergers.
Generate a noisy time series and embed a gravitational wave signal with probability 0.5 at a random time.
The result is a set of time series of the form
\[s = g + \epsilon \frac{1}{R} \xi\]
where \(g\) is a gravitational wave signal from the reference set, \(\xi\) is Gaussian noise, \(\epsilon=10^{-19}\) scales the noise amplitude to the signal, and \(R \in (0.075, 0.65)\) is a parameter that controls the signal-to-noise ratio (SNR).
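The recipe above can be sketched in a few lines of NumPy. This is only an illustration and not the authors' code: it assumes the noise term enters as \(\epsilon \xi / R\), it replaces the reference waveforms with a toy chirp, and the names `toy_chirp` and `make_noisy_series` are hypothetical.

```python
import numpy as np

EPSILON = 1e-19  # noise amplitude scale quoted in the text


def toy_chirp(n_steps=512):
    """Hypothetical stand-in for a reference waveform: a sinusoid whose
    frequency and amplitude both grow over time."""
    t = np.linspace(0.0, 1.0, n_steps)
    return EPSILON * t**2 * np.sin(2 * np.pi * (5 * t + 20 * t**2))


def make_noisy_series(n_series, n_steps, r, rng, p_signal=0.5):
    """Return (series, labels); label 1 means a chirp was embedded."""
    g = toy_chirp()
    series = np.empty((n_series, n_steps))
    labels = np.zeros(n_series, dtype=int)
    for i in range(n_series):
        # Gaussian noise scaled by epsilon / R (assumed form of the noise term)
        s = EPSILON / r * rng.standard_normal(n_steps)
        if rng.random() < p_signal:
            # Embed the chirp at a random position that keeps it inside the series
            start = rng.integers(0, n_steps - len(g) + 1)
            s[start:start + len(g)] += g
            labels[i] = 1
        series[i] = s
    return series, labels


rng = np.random.default_rng(seed=0)
series, labels = make_noisy_series(n_series=200, n_steps=2048, r=0.65, rng=rng)
print(series.shape, labels.mean())
```

Under this assumed form, a smaller \(R\) inflates the noise relative to the embedded signal, which matches the role of \(R\) as the parameter controlling the SNR.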
Constant signal-to-noise ratio¶
As a warm-up, let’s generate some noisy signals with a constant SNR of 17.98. As shown in Table 1 of the article, this corresponds to an \(R\) value of 0.65. By picking the upper end of the interval, we place ourselves in a favorable scenario and can thus get a sense of the best possible performance of our time series classifier. We pick a small number of samples to keep the computations fast; in practice one would scale this up by 1-2 orders of magnitude, as done in the original article.
from data.generate_datasets import make_gravitational_waves
from pathlib import Path
R = 0.65
n_signals = 100
DATA = Path("./data")
noisy_signals, gw_signals, labels = make_gravitational_waves(
    path_to_data=DATA, n_signals=n_signals, r_min=R, r_max=R, n_snr_values=1
)
print(f"Number of noisy signals: {len(noisy_signals)}")
print(f"Number of timesteps per series: {len(noisy_signals[0])}")
Number of noisy signals: 100
Number of timesteps per series: 8692
Next let’s visualise the two different types of time series that we wish to classify: one that is pure noise vs. one that is composed of noise plus an embedded gravitational wave signal:
import numpy as np
from plotly.subplots import make_subplots
import plotly.graph_objects as go
# get the index corresponding to the first pure noise time series
background_idx = np.argmin(labels)
# get the index corresponding to the first noise + gravitational wave time series
signal_idx = np.argmax(labels)
ts_noise = noisy_signals[background_idx]
ts_background = noisy_signals[signal_idx]
ts_signal = gw_signals[signal_idx]
fig = make_subplots(rows=1, cols=2)
fig.add_trace(
    go.Scatter(x=list(range(len(ts_noise))), y=ts_noise, mode="lines", name="noise"),
    row=1,
    col=1,
)
fig.add_trace(
    go.Scatter(
        x=list(range(len(ts_background))),
        y=ts_background,
        mode="lines",
        name="background",
    ),
    row=1,
    col=2,
)
fig.add_trace(
    go.Scatter(x=list(range(len(ts_signal))), y=ts_signal, mode="lines", name="signal"),
    row=1,
    col=2,
)
fig.show()
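The index lookup above relies on a NumPy detail: `np.argmin` and `np.argmax` return the first occurrence of the minimum or maximum. For a binary label array, this gives the first pure-noise series and the first series containing a signal. A small self-contained check, using a made-up label array rather than the one returned by `make_gravitational_waves`:

```python
import numpy as np

# Made-up binary labels: 0 = pure noise, 1 = noise + gravitational wave
toy_labels = np.array([1, 1, 0, 1, 0, 0])

first_noise_idx = int(np.argmin(toy_labels))   # first position holding a 0
first_signal_idx = int(np.argmax(toy_labels))  # first position holding a 1
print(first_noise_idx, first_signal_idx)  # 2 0
```

If every label were the same, both calls would return 0, so it is worth confirming that both classes are present before plotting.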