Ground-based search#
Show code cell source
# In order to run on all CPUs: the XLA flag is exported *before* jax is
# imported, so it is guaranteed to be in the environment when JAX sets up
# its host devices.
import os

os.environ["XLA_FLAGS"] = f"--xla_force_host_platform_device_count={os.cpu_count()}"

import jax
The dataset#
For this tutorial, we will generate a simulated dataset that consists of 5 nights of observations from a ground-based telescope.
import tinygp
import numpy as np
from nuance import core
# Observation layout: N nights of n evenly sampled points each
N = 5  # number of nights
n = 500  # points per night

# Injected transit signal (t0, D and P are reported in days further down)
true_params = dict(t0=0.3, D=38 / 60 / 24, depth=0.005, P=0.65)

# Positions at which the concatenated series is split back into nights
idxs = n * np.arange(1, N)

# Night number k covers the interval [k, k + 0.5]
times = [np.linspace(night, night + 0.5, n) for night in range(N)]
time = np.concatenate(times)

# Noise level (used as a standard deviation further down)
error = 1e-3
# variability
# Correlated variability: a single draw from an SHO Gaussian process evaluated
# on the concatenated time series of all nights.
kernel = tinygp.kernels.quasisep.SHO(
    np.pi / true_params["D"] / 3, 1, true_params["depth"] / 2
)
variability_gp = tinygp.GaussianProcess(kernel, np.concatenate(times), diag=0)
jax_key = jax.random.PRNGKey(0)
variability = variability_gp.sample(jax_key)
# Split the draw back into nights and add white noise. The noise is drawn with
# one value per data point (size=v.shape): a bare np.random.normal(0, error)
# returns a single scalar, which would only shift each night by a constant.
variabilities = [
    v + np.random.normal(0, error, size=v.shape)
    for v in np.split(variability, idxs)
]
# explanatory variables
# Quadratic airmass trend, starting from zero at the first point of each night
airmasses = [0.2 * (night_time - night_time.min()) ** 2 for night_time in times]

# Background series: one SHO-GP draw per night, seeded with the night index
bkgs = []
for night, night_time in enumerate(times):
    bkg_gp = tinygp.GaussianProcess(
        tinygp.kernels.quasisep.SHO(20, 1, 0.005), night_time, diag=(1e-4) ** 2
    )
    bkgs.append(bkg_gp.sample(jax.random.PRNGKey(night)))

# FWHM series: same recipe with a different kernel and a larger diagonal term
fwhms = []
for night, night_time in enumerate(times):
    fwhm_gp = tinygp.GaussianProcess(
        tinygp.kernels.quasisep.SHO(45, 1, 0.005), night_time, diag=(5e-4) ** 2
    )
    fwhms.append(fwhm_gp.sample(jax.random.PRNGKey(night)))
# transits
# Injected signal of each night, scaled to the true depth. The epoch is read
# from true_params rather than repeated as the literal 0.3, so the simulation
# and the "true" values reported later cannot drift apart. The loop variable
# is not called `time`, which avoids shadowing the concatenated time array.
transits = [
    true_params["depth"]
    * core.transit(night_time, true_params["t0"], true_params["D"], true_params["P"])
    for night_time in times
]
# systematics
# Every night gets its own random linear combination of its three
# explanatory series (airmass, bkg, fwhm)
systematics = [
    0.2 * np.random.normal(0, 0.9, 3) @ np.vstack([airmass, bkg, fwhm])
    for airmass, bkg, fwhm in zip(airmasses, bkgs, fwhms)
]

# fluxes
# Sum all components on top of a unit baseline ...
raw_fluxes = [
    systematic + noisy_variability + transit + 1.0
    for systematic, noisy_variability, transit in zip(
        systematics, variabilities, transits
    )
]
# ... then re-centre each night on its own median
fluxes = [night_flux - np.median(night_flux) + 1.0 for night_flux in raw_fluxes]
Let’s plot these observations
import matplotlib.pyplot as plt
# All nights in one wide panel, drawn as light-grey dots
all_times = np.concatenate(times)
all_fluxes = np.concatenate(fluxes)

plt.figure(figsize=(8, 2.5))
plt.plot(all_times, all_fluxes, ".", color="0.8")
_ = plt.ylim(0.98, 1.02)

Fitting the GP#
As usual, we start by defining a Gaussian Process (GP) that can model the covariance of the data, in this case mostly associated with stellar variability. We choose an SHO kernel that we fit on the global light curve
import tinygp
from nuance.core import gp_model
from nuance.utils import minimize
# The GP is fitted on all nights at once
time = np.concatenate(times)
flux = np.concatenate(fluxes)


def build_gp(params, time=time):
    """Return a zero-mean tinygp GP with an SHO kernel evaluated at ``time``.

    ``params`` is a dict of keyword arguments forwarded to the SHO kernel.
    ``time`` defaults to the concatenated time array defined just above.
    The diagonal term is ``error**2``.
    """
    sho_kernel = tinygp.kernels.quasisep.SHO(**params)
    return tinygp.GaussianProcess(sho_kernel, time, diag=error**2, mean=0.0)


# Starting point handed to the optimiser
initial_params = dict(omega=5.0, quality=5.0, sigma=0.02)

# mu is evaluated with the fitted parameters when plotting below;
# nll is the objective passed to minimize
mu, nll = gp_model(time, flux, build_gp)
gp_params = minimize(nll, initial_params)
Let's see the result
# Data as grey dots, with mu evaluated at the optimised parameters in black
fitted_mean = mu(gp_params)

plt.figure(figsize=(8, 2.5))
plt.plot(time, flux, ".", color="0.8")
_ = plt.plot(time, fitted_mean, color="k")

We note that we completely ignored the presence of systematic signals, using a zero-mean
GP that will naturally be non-optimal (but we only know that because we simulated the dataset).
Linear search(es)#
We can now use the kernel optimized in the last section to define the Nuance
object of each observation, and run each linear search separately. We can also build each observation design matrix X
based on the airmass, bkg, fwhm contemporaneous measurements.
from nuance.linear_search import linear_search
# Design matrix of each night: constant term, airmass, bkg and fwhm
Xs = [
    np.vstack([np.ones_like(night_time), airmass, bkg, fwhm])
    for night_time, airmass, bkg, fwhm in zip(times, airmasses, bkgs, fwhms)
]

# One GP per night, all sharing the parameters fitted on the full light curve
gps = [build_gp(gp_params, night_time) for night_time in times]

# The trial durations are the same for every night, so they are defined once
# (the periodic search below reuses them)
durations = np.linspace(0.01, 0.1, 5)

linear_searches = []
for night_time, night_flux, night_gp, night_X in zip(times, fluxes, gps, Xs):
    # Trial epochs: every time stamp of the night
    epochs = night_time.copy()
    # The design matrix is passed explicitly so that the linear search uses
    # the same linear model as the core.snr and core.separate_models calls
    # further down; without it, Xs was built but never used at this stage.
    linear_searches.append(
        linear_search(night_time, night_flux, gp=night_gp, X=night_X)(
            epochs, durations
        )
    )
100%|██████████| 510/510 [00:00<00:00, 1942.36it/s]
100%|██████████| 510/510 [00:00<00:00, 1874.86it/s]
100%|██████████| 510/510 [00:00<00:00, 1987.33it/s]
100%|██████████| 510/510 [00:00<00:00, 2015.66it/s]
100%|██████████| 510/510 [00:00<00:00, 2029.50it/s]
Periodic search#
Let’s combine these linear searches and perform the periodic search
from nuance.linear_search import combine_linear_searches
from nuance.periodic_search import periodic_search
# Trial periods
periods = np.linspace(0.2, 2.5, 5000)

# Merge the per-night linear searches into a single one
ls = combine_linear_searches(*linear_searches)

# SNR function built from all nights, their GPs and design matrices,
# wrapped in jax.jit
multi_night_snr = core.snr(times, fluxes, gp=gps, X=Xs)
snr_function = jax.jit(multi_night_snr)

# Build the periodic search, then evaluate it on the trial periods
ps_function = periodic_search(time, durations, ls, snr_function)
snr, params = ps_function(periods)
/opt/homebrew/Caskroom/miniforge/base/envs/nuance/lib/python3.10/site-packages/multiprocess/popen_fork.py:66: RuntimeWarning: os.fork() was called. os.fork() is incompatible with multithreaded code, and JAX is multithreaded, so this will likely lead to a deadlock.
self.pid = os.fork()
100%|██████████| 5000/5000 [00:02<00:00, 2308.22it/s]
Results#
Let’s look at the results
# Parameters of the trial period with the highest SNR
t0, D, P = params[np.argmax(snr)]
found = {"t0": t0, "D": D, "P": P}

# The header follows the order in which the values are printed (found first,
# then true). It previously read "true, found", which mislabelled both columns.
print("\tfound\ttrue")
for param in ["t0", "D", "P"]:
    print(f"{param}\t{found[param]:.3f}\t{true_params[param]:.3f} days")

# Periodogram, with the true and the recovered period marked
plt.axvline(true_params["P"], ls="--", c="k", alpha=0.2, label="true period")
plt.axvline(found["P"], ls="--", c="C0", alpha=0.2, label="found period")
plt.plot(periods, snr)
plt.xlabel("Period (days)")
plt.ylabel("SNR")
_ = plt.legend()
found true
t0 0.299 0.300 days
D 0.033 0.026 days
P 0.650 0.650 days

Looking at this periodogram, it seems we could also have detected an alias of the true transit period (see the second peak). However, if we plot the found signal against the true transit, we notice that both periods are actually very representative of the data we have, given most transits were actually not observed.
Show code cell source
best_params = params[np.argmax(snr)]

plt.figure(figsize=(10, 4.5))
plt.plot(time, flux, ".", c="0.8")

nights = zip(times, fluxes, gps, Xs, transits)
for i, (night_time, night_flux, night_gp, night_X, true_transit) in enumerate(nights):
    # The components get their own names: unpacking into `systematics` and
    # `variability` would overwrite the simulated arrays of the same name.
    systematics_model, transit_model, variability_model = core.separate_models(
        night_time, night_flux, gp=night_gp, X=night_X
    )(*best_params)

    # Legend entries are attached to the first night only, to avoid duplicates
    first_night = i == 0

    # Full model, transit included
    plt.plot(
        night_time,
        systematics_model + variability_model + transit_model,
        c="C4",
        lw=1,
    )
    # Model without the transit
    plt.plot(
        night_time,
        systematics_model + variability_model,
        c="k",
        lw=1,
        label="reconstructed" if first_night else None,
    )
    # Found and true transits, both shifted below the data
    plt.plot(
        night_time,
        transit_model + 1 - 0.02,
        c="C4",
        label="found transit" if first_night else None,
    )
    plt.plot(
        night_time,
        true_transit + 1 - 0.02,
        c="C0",
        label="true transit" if first_night else None,
    )

plt.ylim(0.97, 1.015)
plt.legend()
<matplotlib.legend.Legend at 0x17582e770>

We note that the transit we found has a shallower depth than the simulated one, but the fact that a detection was possible in such levels of correlated noise is satisfying.