-
Notifications
You must be signed in to change notification settings - Fork 14
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Vectorize light curve functions (#141)
* starting to implement vectorization for light curves * Adding tests and getting things working * Updating docstrings * remove manual vmaps from tutorial
- Loading branch information
Showing
14 changed files
with
105 additions
and
22 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,4 @@ | ||
from jaxoplanet.light_curves import exposure_time as exposure_time | ||
from jaxoplanet.light_curves.limb_dark import ( | ||
limb_dark_light_curve as limb_dark_light_curve, | ||
) |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,3 @@ | ||
from collections.abc import Callable | ||
|
||
LightCurveFunc = Callable |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,32 @@ | ||
from functools import wraps | ||
from typing import Any, Union | ||
|
||
import jax | ||
from jpu.core import is_quantity | ||
|
||
from jaxoplanet.light_curves.types import LightCurveFunc | ||
from jaxoplanet.types import Array, Quantity | ||
|
||
|
||
def vectorize(func: LightCurveFunc) -> LightCurveFunc: | ||
@wraps(func) | ||
def wrapped(time: Quantity, *args: Any, **kwargs: Any) -> Union[Array, Quantity]: | ||
if is_quantity(time): | ||
time_magnitude = time.magnitude | ||
time_units = time.units | ||
else: | ||
time_magnitude = time | ||
time_units = None | ||
|
||
def inner(time_magnitude: Array) -> Union[Array, Quantity]: | ||
if time_units is None: | ||
return func(time_magnitude, *args, **kwargs) | ||
else: | ||
return func(time_magnitude * time_units, *args, **kwargs) | ||
|
||
for _ in time.shape: | ||
inner = jax.vmap(inner) | ||
|
||
return inner(time_magnitude) | ||
|
||
return wrapped |
Empty file.
Empty file.
Empty file.
Empty file.
Empty file.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,46 @@ | ||
import jax.numpy as jnp | ||
import pytest | ||
|
||
from jaxoplanet.light_curves.utils import vectorize | ||
from jaxoplanet.units import quantity_input, unit_registry as ureg | ||
|
||
|
||
@pytest.mark.parametrize("shape", [(), (10,), (10, 3)]) | ||
def test_vectorize_scalar(shape): | ||
@quantity_input(time=ureg.day) | ||
@vectorize | ||
def lc(time): | ||
assert time.shape == () | ||
return time.magnitude**2 | ||
|
||
time = jnp.ones(shape) | ||
assert lc(time).shape == time.shape | ||
|
||
|
||
@pytest.mark.parametrize("shape", [(), (10,), (10, 3)]) | ||
def test_vectorize_array(shape): | ||
@vectorize | ||
def lc(time): | ||
assert time.shape == () | ||
return jnp.array([time, time**2, 1.0]) | ||
|
||
time = jnp.ones(shape) | ||
assert lc(time).shape == time.shape + (3,) | ||
|
||
|
||
@pytest.mark.parametrize("shape", [(), (10,), (10, 3)]) | ||
@pytest.mark.parametrize("out_units", [None, ureg.dimensionless, ureg.m]) | ||
def test_vectorize_quantity(shape, out_units): | ||
@quantity_input(time=ureg.day) | ||
@vectorize | ||
def lc(time, units): | ||
assert time.shape == () | ||
y = jnp.stack([time.magnitude, time.magnitude**2, 1.0]) | ||
if units is None: | ||
return y | ||
else: | ||
return y * units | ||
|
||
time = jnp.ones(shape) | ||
assert lc(time, out_units).shape == time.shape + (3,) | ||
assert lc(time, units=out_units).shape == time.shape + (3,) |
Empty file.