Skip to content

Commit 10bdc61

Browse files
authored
Vectorize light curve functions (#141)
* starting to implement vectorization for light curves * Adding tests and getting things working * Updating docstrings * remove manual vmaps from tutorial
1 parent 63c0a8a commit 10bdc61

File tree

14 files changed

+105
-22
lines changed

14 files changed

+105
-22
lines changed

docs/tutorials/transit.ipynb

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -84,7 +84,7 @@
8484
"# Compute a limb-darkened light curve for this orbit\n",
8585
"t = np.linspace(-0.1, 0.1, 1000)\n",
8686
"u = [0.1, 0.06] # Quadratic limb-darkening coefficients\n",
87-
"light_curve = jax.vmap(limb_dark_light_curve(orbit, u))(t)\n",
87+
"light_curve = limb_dark_light_curve(orbit, u)(t)\n",
8888
"\n",
8989
"# Plot the light curve\n",
9090
"plt.figure(dpi=150)\n",
@@ -129,7 +129,7 @@
129129
"orbit = TransitOrbit(\n",
130130
" period=PERIOD, duration=DURATION, time_transit=T0, impact_param=B, radius=ROR\n",
131131
")\n",
132-
"y_true = jax.vmap(limb_dark_light_curve(orbit, U))(t)\n",
132+
"y_true = limb_dark_light_curve(orbit, U)(t)\n",
133133
"y = y_true + yerr * random.normal(size=len(t))\n",
134134
"\n",
135135
"# Let's see what the light curve looks like\n",
@@ -190,7 +190,7 @@
190190
" orbit = TransitOrbit(\n",
191191
" period=period, duration=duration, time_transit=t0, impact_param=b, radius=r\n",
192192
" )\n",
193-
" y_pred = jax.vmap(limb_dark_light_curve(orbit, u))(t)\n",
193+
" y_pred = limb_dark_light_curve(orbit, u)(t)\n",
194194
"\n",
195195
" # Let's track the light curve\n",
196196
" numpyro.deterministic(\"light_curve\", y_pred)\n",
@@ -505,7 +505,7 @@
505505
" impact_param=inferred_b,\n",
506506
" radius=inferred_r,\n",
507507
")\n",
508-
"y_model = jax.vmap(limb_dark_light_curve(orbit, inferred_u))(t)\n",
508+
"y_model = limb_dark_light_curve(orbit, inferred_u)(t)\n",
509509
"\n",
510510
"fig, ax = plt.subplots(dpi=150)\n",
511511
"\n",
Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,4 @@
1+
from jaxoplanet.light_curves import exposure_time as exposure_time
2+
from jaxoplanet.light_curves.limb_dark import (
3+
limb_dark_light_curve as limb_dark_light_curve,
4+
)

src/jaxoplanet/exposure_time.py renamed to src/jaxoplanet/light_curves/exposure_time.py

Lines changed: 11 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -1,11 +1,13 @@
11
from functools import wraps
2-
from typing import Any, Optional, Protocol, TypeVar
2+
from typing import Any, Optional, Union
33

44
import jax
55
import jax.numpy as jnp
66
import jpu.numpy as jnpu
77

88
from jaxoplanet import units
9+
from jaxoplanet.light_curves.types import LightCurveFunc
10+
from jaxoplanet.light_curves.utils import vectorize
911
from jaxoplanet.types import Array, Quantity
1012
from jaxoplanet.units import unit_registry as ureg
1113

@@ -14,27 +16,21 @@
1416
except ImportError:
1517
from jax import linear_util as lu # type: ignore
1618

17-
T = TypeVar("T", Array, Quantity, covariant=True)
18-
19-
20-
class _LightCurveFunc(Protocol[T]):
21-
def __call__(self, time: Quantity, *args: Any, **kwargs: Any) -> T: ...
22-
2319

2420
@units.quantity_input(exposure_time=ureg.d)
2521
def integrate(
26-
func: _LightCurveFunc[T],
22+
func: LightCurveFunc,
2723
exposure_time: Optional[Quantity] = None,
2824
order: int = 0,
2925
num_samples: int = 7,
30-
) -> _LightCurveFunc[T]:
26+
) -> LightCurveFunc:
3127
if exposure_time is None:
3228
return func
3329

3430
if jnpu.ndim(exposure_time) != 0:
3531
raise ValueError(
3632
"The exposure time passed to 'integrate_exposure_time' has shape "
37-
f"{jnpu.shape(exposure_time)}, but a scalar was expected; " # type: ignore
33+
f"{jnpu.shape(exposure_time)}, but a scalar was expected; "
3834
"To use exposure time integration with different exposures at different "
3935
"times, manually 'vmap' or 'vectorize' the function"
4036
)
@@ -65,13 +61,14 @@ def integrate(
6561

6662
@wraps(func)
6763
@units.quantity_input(time=ureg.d)
68-
def wrapped(time: Quantity, *args: Any, **kwargs: Any) -> T:
64+
@vectorize
65+
def wrapped(time: Quantity, *args: Any, **kwargs: Any) -> Union[Array, Quantity]:
6966
if jnpu.ndim(time) != 0:
7067
raise ValueError(
7168
"The time passed to 'integrate_exposure_time' has shape "
72-
f"{jnpu.shape(time)}, but a scalar was expected; " # type: ignore
73-
"To use exposure time integration for an array of times, "
74-
"manually 'vmap' or 'vectorize' the function"
69+
f"{jnpu.shape(time)}, but a scalar was expected; "
70+
"this shouldn't typically happen so please open an issue "
71+
"on GitHub demonstrating the problem"
7572
)
7673

7774
f = lu.wrap_init(jax.vmap(func, in_axes=(0,) + (None,) * len(args)))

src/jaxoplanet/light_curves.py renamed to src/jaxoplanet/light_curves/limb_dark.py

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,7 @@
66

77
from jaxoplanet import units
88
from jaxoplanet.core.limb_dark import light_curve as _limb_dark_light_curve
9+
from jaxoplanet.light_curves.utils import vectorize
910
from jaxoplanet.proto import LightCurveOrbit
1011
from jaxoplanet.types import Array, Quantity
1112
from jaxoplanet.units import unit_registry as ureg
@@ -20,13 +21,14 @@ def limb_dark_light_curve(
2021
ld_u = jnp.array([])
2122

2223
@units.quantity_input(time=ureg.d)
24+
@vectorize
2325
def light_curve_impl(time: Quantity) -> Array:
2426
if jnpu.ndim(time) != 0:
2527
raise ValueError(
2628
"The time passed to 'light_curve' has shape "
2729
f"{jnpu.shape(time)}, but a scalar was expected; "
28-
"To use exposure time integration for an array of times, "
29-
"manually 'vmap' or 'vectorize' the function"
30+
"this shouldn't typically happen so please open an issue "
31+
"on GitHub demonstrating the problem"
3032
)
3133

3234
# Evaluate the coordinates of the transiting body

src/jaxoplanet/light_curves/types.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,3 @@
1+
from collections.abc import Callable
2+
3+
LightCurveFunc = Callable

src/jaxoplanet/light_curves/utils.py

Lines changed: 32 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,32 @@
1+
from functools import wraps
2+
from typing import Any, Union
3+
4+
import jax
5+
from jpu.core import is_quantity
6+
7+
from jaxoplanet.light_curves.types import LightCurveFunc
8+
from jaxoplanet.types import Array, Quantity
9+
10+
11+
def vectorize(func: LightCurveFunc) -> LightCurveFunc:
12+
@wraps(func)
13+
def wrapped(time: Quantity, *args: Any, **kwargs: Any) -> Union[Array, Quantity]:
14+
if is_quantity(time):
15+
time_magnitude = time.magnitude
16+
time_units = time.units
17+
else:
18+
time_magnitude = time
19+
time_units = None
20+
21+
def inner(time_magnitude: Array) -> Union[Array, Quantity]:
22+
if time_units is None:
23+
return func(time_magnitude, *args, **kwargs)
24+
else:
25+
return func(time_magnitude * time_units, *args, **kwargs)
26+
27+
for _ in time.shape:
28+
inner = jax.vmap(inner)
29+
30+
return inner(time_magnitude)
31+
32+
return wrapped

tests/__init__.py

Whitespace-only changes.

tests/core/__init__.py

Whitespace-only changes.

tests/experimental/__init__.py

Whitespace-only changes.

tests/experimental/starry/__init__.py

Whitespace-only changes.

0 commit comments

Comments
 (0)