Skip to main content
Ctrl+K
Logo image

Get started

  • Installation
  • Motivation
  • Examples
    • Single transit search
    • Periodic transit search
    • Multi-planetary search
    • Combined search
  • Citation

Tutorials

  • GP optimization
  • Fast injection-recovery
  • Ground-based search
  • TESS search
  • Exocomet search

Reference

  • Stellar parameters
  • Model templates
  • Hardware acceleration
  • API
    • core
    • linear search
    • periodic search
    • Star
  • Repository
  • Suggest edit
  • Open issue
  • .ipynb

Multi-planetary search

Contents

  • Generating the data
  • The linear search
  • The periodic search
    • First planet
    • Second search

Multi-planetary search#

Show code cell source Hide code cell source
# in order to run on all CPUs
import os
import jax

jax.config.update("jax_enable_x64", True)
os.environ["XLA_FLAGS"] = f"--xla_force_host_platform_device_count={os.cpu_count()}"

In this notebook, we use nuance to search for the transiting exoplanets with the following parameters:

truth = [
    {"epoch": 0.5, "duration": 0.03, "period": 1.8, "depth": 3e-3},
    {"epoch": 0.2, "duration": 0.03, "period": 0.8, "depth": 1e-3},
]

Generating the data#

Let generate this transit signals together with correlated noise

import tinygp
import numpy as np
import matplotlib.pyplot as plt
from nuance import core

# measurements
time = np.linspace(0, 6, 1000)
error = 0.0005
transits = [
    p["depth"] * core.transit(time, p["epoch"], p["duration"], p["period"])
    for p in truth
]
diff_transit_flux = np.sum(transits, 0) + 1.0

# GP
kernel = tinygp.kernels.quasisep.SHO(
    np.pi / (truth[1]["duration"] * 5), 8.0, truth[1]["depth"] / 2
)
gp = tinygp.GaussianProcess(kernel, time, diag=error**2, mean=0.0)
flux = diff_transit_flux + gp.sample(jax.random.PRNGKey(40))

plt.figure(None, (8.5, 3))
plt.plot(time, flux, ".", c="0.8")
plt.ylabel("flux")
plt.xlabel("time")

plt.tight_layout()
../_images/2913b1cdc38572bb4c8c5a08e4abf7a074c72608aa15a8bb21ecd4acfca169a9.png

The linear search#

We start by performing the linear_search over all times (considered as potential transit epochs) and on a wide range of durations.

from nuance.linear_search import linear_search

epochs = time.copy()
durations = np.linspace(0.01, 0.2, 15)

ls = linear_search(time, flux, gp=gp)(epochs, durations)

Note

Notice how we reused the same Gaussian Process used to generate the data. In practice, a kernel need to be chosen and its hyper-parameters optimized. See the GP optimization tutorial

Whatever the number of planets we will look for, this step needs to be done only once.

The periodic search#

First planet#

We can now perform the periodic_search

import matplotlib.pyplot as plt
from nuance.periodic_search import periodic_search

periods = np.linspace(0.3, 4.0, 4000)

snr_function = jax.jit(core.snr(time, flux, gp=gp))
ps_function = periodic_search(epochs, durations, ls, snr_function)
snr, params = ps_function(periods)

t0, D, P = params[np.argmax(snr)]

print(
    "\n".join(
        [f"{n}: {v:.3f}" for n, v in zip(["epoch", "duration", "period"], [t0, D, P])]
    )
)
/home/docs/checkouts/readthedocs.org/user_builds/nuance/envs/stable/lib/python3.10/site-packages/multiprocess/popen_fork.py:66: RuntimeWarning: os.fork() was called. os.fork() is incompatible with multithreaded code, and JAX is multithreaded, so this will likely lead to a deadlock.
  self.pid = os.fork()
epoch: 0.503
duration: 0.024
period: 1.798

and plot the SNR periodogram

Show code cell source Hide code cell source
from nuance import utils

fig = plt.figure(figsize=(8.5, 4))

linear, found, noise = core.separate_models(time, flux, gp=gp)(t0, D, P)

transit_prams = truth[0]

ax = plt.subplot(121, xlabel="periods", ylabel="SNR")
ax.axvline(transit_prams["period"], c="0.8", ls="-", label="true")
ax.plot(periods, snr)
ax.legend()

ax = plt.subplot(222, xlabel="time", ylabel="flux")
ax.plot(time, flux, ".", c="0.8")
ax.plot(time, found + 1, c="k", label="found")
ax.plot(time, transits[0] + 1, c="C0", label="true")
ax.legend()

ax = plt.subplot(224, xlabel="time", ylabel="flux", xlim=(-0.1, 0.1))
phi = utils.phase(time, t0, P)
detrended = flux - noise - linear
plt.plot(phi, detrended, ".", c=".8")
bx, by, be = utils.binn_time(phi, detrended, bins=10 / 60 / 24)
plt.errorbar(bx, by, yerr=be, fmt=".", c="k")

plt.tight_layout()
/home/docs/checkouts/readthedocs.org/user_builds/nuance/envs/stable/lib/python3.10/site-packages/jax/_src/core.py:691: FutureWarning: unhashable type: <class 'jax._src.interpreters.partial_eval.DynamicJaxprTracer'>. Attempting to hash a tracer will lead to an error in a future JAX release.
  warnings.warn(
../_images/dada14f488c7b3d1e6b7abd4b72013981123400c07142126805f2dc844cb6651.png

This corresponds to the first transit signal injected.

Second search#

To look for a second planetary candidate, we can mask the current candidate with

from tinygp import GaussianProcess

phase = utils.phase(epochs, t0, P)
mask = np.abs(phase) > 2 * D
masked_gp = GaussianProcess(gp.kernel, time[mask], diag=gp.variance[mask])
masked_ls = np.array(ls)[:, mask]

and perform the second periodic search

Note

Again, we only have to perform another periodic search, reusing the products of the linear search without loss of sensitivity

snr_function = jax.jit(core.snr(time[mask], flux[mask], gp=masked_gp))
ps_function = periodic_search(epochs[mask], durations, masked_ls, snr_function)
snr, params = ps_function(periods)

t0, D, P = params[np.argmax(snr)]

print(
    "\n".join(
        [f"{n}: {v:.3f}" for n, v in zip(["epoch", "duration", "period"], [t0, D, P])]
    )
)
/home/docs/checkouts/readthedocs.org/user_builds/nuance/envs/stable/lib/python3.10/site-packages/multiprocess/popen_fork.py:66: RuntimeWarning: os.fork() was called. os.fork() is incompatible with multithreaded code, and JAX is multithreaded, so this will likely lead to a deadlock.
  self.pid = os.fork()
/home/docs/checkouts/readthedocs.org/user_builds/nuance/envs/stable/lib/python3.10/site-packages/multiprocess/popen_fork.py:66: RuntimeWarning: os.fork() was called. os.fork() is incompatible with multithreaded code, and JAX is multithreaded, so this will likely lead to a deadlock.
  self.pid = os.fork()
epoch: 0.204
duration: 0.024
period: 0.800
Show code cell source Hide code cell source
from nuance import utils

fig = plt.figure(figsize=(8.5, 4))

linear, found, noise = core.separate_models(time[mask], flux[mask], gp=masked_gp)(
    t0, D, P
)

transit_prams = truth[1]

ax = plt.subplot(121, xlabel="periods", ylabel="SNR")
ax.axvline(transit_prams["period"], c="0.8", ls="-", label="true")
ax.plot(periods, snr)
ax.legend()

ax = plt.subplot(222, xlabel="time", ylabel="flux")
ax.plot(time[mask], flux[mask], ".", c="0.8")
ax.plot(time[mask], found + 1, c="k", label="found")
ax.plot(time, transits[1] + 1, c="C0", label="true")
ax.legend()

ax = plt.subplot(224, xlabel="time", ylabel="flux", xlim=(-0.1, 0.1))
phi = utils.phase(time[mask], t0, P)
detrended = flux[mask] - noise - linear
plt.plot(phi, detrended, ".", c=".8")
bx, by, be = utils.binn_time(phi, detrended, bins=10 / 60 / 24)
plt.errorbar(bx, by, yerr=be, fmt=".", c="k")

plt.tight_layout()
/home/docs/checkouts/readthedocs.org/user_builds/nuance/envs/stable/lib/python3.10/site-packages/jax/_src/core.py:691: FutureWarning: unhashable type: <class 'jax._src.interpreters.partial_eval.DynamicJaxprTracer'>. Attempting to hash a tracer will lead to an error in a future JAX release.
  warnings.warn(
../_images/81716c217d3b6421816b5106caf5269bf06715e70711c2d4923e11a8b575afdf.png

which is the second planetary transit injected.

previous

Periodic transit search

next

Combined search

Contents
  • Generating the data
  • The linear search
  • The periodic search
    • First planet
    • Second search

By Lionel Garcia

© Copyright 2023, Lionel Garcia.