Skip to content

Commit

Permalink
Vectorize light curve functions (#141)
Browse files Browse the repository at this point in the history
* starting to implement vectorization for light curves

* Adding tests and getting things working

* Updating docstrings

* remove manual vmaps from tutorial
  • Loading branch information
dfm authored Feb 27, 2024
1 parent 63c0a8a commit 10bdc61
Show file tree
Hide file tree
Showing 14 changed files with 105 additions and 22 deletions.
8 changes: 4 additions & 4 deletions docs/tutorials/transit.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -84,7 +84,7 @@
"# Compute a limb-darkened light curve for this orbit\n",
"t = np.linspace(-0.1, 0.1, 1000)\n",
"u = [0.1, 0.06] # Quadratic limb-darkening coefficients\n",
"light_curve = jax.vmap(limb_dark_light_curve(orbit, u))(t)\n",
"light_curve = limb_dark_light_curve(orbit, u)(t)\n",
"\n",
"# Plot the light curve\n",
"plt.figure(dpi=150)\n",
Expand Down Expand Up @@ -129,7 +129,7 @@
"orbit = TransitOrbit(\n",
" period=PERIOD, duration=DURATION, time_transit=T0, impact_param=B, radius=ROR\n",
")\n",
"y_true = jax.vmap(limb_dark_light_curve(orbit, U))(t)\n",
"y_true = limb_dark_light_curve(orbit, U)(t)\n",
"y = y_true + yerr * random.normal(size=len(t))\n",
"\n",
"# Let's see what the light curve looks like\n",
Expand Down Expand Up @@ -190,7 +190,7 @@
" orbit = TransitOrbit(\n",
" period=period, duration=duration, time_transit=t0, impact_param=b, radius=r\n",
" )\n",
" y_pred = jax.vmap(limb_dark_light_curve(orbit, u))(t)\n",
" y_pred = limb_dark_light_curve(orbit, u)(t)\n",
"\n",
" # Let's track the light curve\n",
" numpyro.deterministic(\"light_curve\", y_pred)\n",
Expand Down Expand Up @@ -505,7 +505,7 @@
" impact_param=inferred_b,\n",
" radius=inferred_r,\n",
")\n",
"y_model = jax.vmap(limb_dark_light_curve(orbit, inferred_u))(t)\n",
"y_model = limb_dark_light_curve(orbit, inferred_u)(t)\n",
"\n",
"fig, ax = plt.subplots(dpi=150)\n",
"\n",
Expand Down
4 changes: 4 additions & 0 deletions src/jaxoplanet/light_curves/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,4 @@
from jaxoplanet.light_curves import exposure_time as exposure_time
from jaxoplanet.light_curves.limb_dark import (
limb_dark_light_curve as limb_dark_light_curve,
)
Original file line number Diff line number Diff line change
@@ -1,11 +1,13 @@
from functools import wraps
from typing import Any, Optional, Protocol, TypeVar
from typing import Any, Optional, Union

import jax
import jax.numpy as jnp
import jpu.numpy as jnpu

from jaxoplanet import units
from jaxoplanet.light_curves.types import LightCurveFunc
from jaxoplanet.light_curves.utils import vectorize
from jaxoplanet.types import Array, Quantity
from jaxoplanet.units import unit_registry as ureg

Expand All @@ -14,27 +16,21 @@
except ImportError:
from jax import linear_util as lu # type: ignore

T = TypeVar("T", Array, Quantity, covariant=True)


class _LightCurveFunc(Protocol[T]):
def __call__(self, time: Quantity, *args: Any, **kwargs: Any) -> T: ...


@units.quantity_input(exposure_time=ureg.d)
def integrate(
func: _LightCurveFunc[T],
func: LightCurveFunc,
exposure_time: Optional[Quantity] = None,
order: int = 0,
num_samples: int = 7,
) -> _LightCurveFunc[T]:
) -> LightCurveFunc:
if exposure_time is None:
return func

if jnpu.ndim(exposure_time) != 0:
raise ValueError(
"The exposure time passed to 'integrate_exposure_time' has shape "
f"{jnpu.shape(exposure_time)}, but a scalar was expected; " # type: ignore
f"{jnpu.shape(exposure_time)}, but a scalar was expected; "
"To use exposure time integration with different exposures at different "
"times, manually 'vmap' or 'vectorize' the function"
)
Expand Down Expand Up @@ -65,13 +61,14 @@ def integrate(

@wraps(func)
@units.quantity_input(time=ureg.d)
def wrapped(time: Quantity, *args: Any, **kwargs: Any) -> T:
@vectorize
def wrapped(time: Quantity, *args: Any, **kwargs: Any) -> Union[Array, Quantity]:
if jnpu.ndim(time) != 0:
raise ValueError(
"The time passed to 'integrate_exposure_time' has shape "
f"{jnpu.shape(time)}, but a scalar was expected; " # type: ignore
"To use exposure time integration for an array of times, "
"manually 'vmap' or 'vectorize' the function"
f"{jnpu.shape(time)}, but a scalar was expected; "
"this shouldn't typically happen so please open an issue "
"on GitHub demonstrating the problem"
)

f = lu.wrap_init(jax.vmap(func, in_axes=(0,) + (None,) * len(args)))
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@

from jaxoplanet import units
from jaxoplanet.core.limb_dark import light_curve as _limb_dark_light_curve
from jaxoplanet.light_curves.utils import vectorize
from jaxoplanet.proto import LightCurveOrbit
from jaxoplanet.types import Array, Quantity
from jaxoplanet.units import unit_registry as ureg
Expand All @@ -20,13 +21,14 @@ def limb_dark_light_curve(
ld_u = jnp.array([])

@units.quantity_input(time=ureg.d)
@vectorize
def light_curve_impl(time: Quantity) -> Array:
if jnpu.ndim(time) != 0:
raise ValueError(
"The time passed to 'light_curve' has shape "
f"{jnpu.shape(time)}, but a scalar was expected; "
"To use exposure time integration for an array of times, "
"manually 'vmap' or 'vectorize' the function"
"this shouldn't typically happen so please open an issue "
"on GitHub demonstrating the problem"
)

# Evaluate the coordinates of the transiting body
Expand Down
3 changes: 3 additions & 0 deletions src/jaxoplanet/light_curves/types.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,3 @@
from collections.abc import Callable

LightCurveFunc = Callable
32 changes: 32 additions & 0 deletions src/jaxoplanet/light_curves/utils.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,32 @@
from functools import wraps
from typing import Any, Union

import jax
from jpu.core import is_quantity

from jaxoplanet.light_curves.types import LightCurveFunc
from jaxoplanet.types import Array, Quantity


def vectorize(func: LightCurveFunc) -> LightCurveFunc:
@wraps(func)
def wrapped(time: Quantity, *args: Any, **kwargs: Any) -> Union[Array, Quantity]:
if is_quantity(time):
time_magnitude = time.magnitude
time_units = time.units
else:
time_magnitude = time
time_units = None

def inner(time_magnitude: Array) -> Union[Array, Quantity]:
if time_units is None:
return func(time_magnitude, *args, **kwargs)
else:
return func(time_magnitude * time_units, *args, **kwargs)

for _ in time.shape:
inner = jax.vmap(inner)

return inner(time_magnitude)

return wrapped
Empty file added tests/__init__.py
Empty file.
Empty file added tests/core/__init__.py
Empty file.
Empty file added tests/experimental/__init__.py
Empty file.
Empty file.
Empty file added tests/light_curves/__init__.py
Empty file.
Original file line number Diff line number Diff line change
@@ -1,4 +1,3 @@
import jax
import jax.numpy as jnp

from jaxoplanet.light_curves import limb_dark_light_curve
Expand All @@ -22,5 +21,5 @@ def test_light_curve():

# Compute a limb-darkened light curve using jaxoplanet
t = jnp.linspace(-0.3, 0.3, 1000)
lc = jax.vmap(limb_dark_light_curve(orbit, params["u"]))(t)
lc = limb_dark_light_curve(orbit, params["u"])(t)
assert lc.shape == t.shape + (1,)
46 changes: 46 additions & 0 deletions tests/light_curves/utils_test.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,46 @@
import jax.numpy as jnp
import pytest

from jaxoplanet.light_curves.utils import vectorize
from jaxoplanet.units import quantity_input, unit_registry as ureg


@pytest.mark.parametrize("shape", [(), (10,), (10, 3)])
def test_vectorize_scalar(shape):
@quantity_input(time=ureg.day)
@vectorize
def lc(time):
assert time.shape == ()
return time.magnitude**2

time = jnp.ones(shape)
assert lc(time).shape == time.shape


@pytest.mark.parametrize("shape", [(), (10,), (10, 3)])
def test_vectorize_array(shape):
@vectorize
def lc(time):
assert time.shape == ()
return jnp.array([time, time**2, 1.0])

time = jnp.ones(shape)
assert lc(time).shape == time.shape + (3,)


@pytest.mark.parametrize("shape", [(), (10,), (10, 3)])
@pytest.mark.parametrize("out_units", [None, ureg.dimensionless, ureg.m])
def test_vectorize_quantity(shape, out_units):
@quantity_input(time=ureg.day)
@vectorize
def lc(time, units):
assert time.shape == ()
y = jnp.stack([time.magnitude, time.magnitude**2, 1.0])
if units is None:
return y
else:
return y * units

time = jnp.ones(shape)
assert lc(time, out_units).shape == time.shape + (3,)
assert lc(time, units=out_units).shape == time.shape + (3,)
Empty file added tests/orbits/__init__.py
Empty file.

0 comments on commit 10bdc61

Please sign in to comment.