Skip to content

Commit

Permalink
Updating some docstrings
Browse files Browse the repository at this point in the history
  • Loading branch information
dfm committed Feb 28, 2024
1 parent 10bdc61 commit b23ada6
Show file tree
Hide file tree
Showing 13 changed files with 69 additions and 24 deletions.
2 changes: 2 additions & 0 deletions docs/conf.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@
]

autoapi_dirs = ["../src"]
autoapi_ignore = ["*/experimental/*", "*_version*", "*/types*"]
autoapi_options = [
"members",
"undoc-members",
Expand All @@ -23,6 +24,7 @@
"special-members",
# "imported-members",
]
suppress_warnings = ["autoapi.python_import_resolution"]

myst_enable_extensions = ["dollarmath", "colon_fence"]
source_suffix = {
Expand Down
2 changes: 1 addition & 1 deletion docs/tutorials.md
Original file line number Diff line number Diff line change
Expand Up @@ -18,8 +18,8 @@ directly.
```{toctree}
:maxdepth: 1
tutorials/getting-started.ipynb
tutorials/autodiff.ipynb
tutorials/orbits.ipynb
tutorials/transit.ipynb
tutorials/rv.ipynb
tutorials/starry.ipynb
Expand Down
4 changes: 2 additions & 2 deletions docs/tutorials/autodiff.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -225,7 +225,7 @@
],
"metadata": {
"kernelspec": {
"display_name": "exo4jax",
"display_name": "Python 3 (ipykernel)",
"language": "python",
"name": "python3"
},
Expand All @@ -243,5 +243,5 @@
}
},
"nbformat": 4,
"nbformat_minor": 2
"nbformat_minor": 4
}
2 changes: 2 additions & 0 deletions docs/tutorials/core-from-scratch.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,8 @@
"id": "0",
"metadata": {},
"source": [
"(core-from-scratch)=\n",
"\n",
"# Jaxoplanet core from scratch\n",
"\n",
"Inspired by the [autodidax tutorial](https://jax.readthedocs.io/en/latest/autodidax.html) from the JAX documentation, in this tutorial we work through implementing some of the core `jaxoplanet` functionality from scratch, to demonstrate and discuss the choices made within the depths of the codebase.\n",
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,8 @@
"id": "0",
"metadata": {},
"source": [
"(getting-started)=\n",
"\n",
"# Getting Started"
]
},
Expand Down
8 changes: 5 additions & 3 deletions docs/tutorials/rv.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,8 @@
"cell_type": "markdown",
"metadata": {},
"source": [
"(rv)=\n",
"\n",
"# Radial Velocities Fitting"
]
},
Expand Down Expand Up @@ -253,7 +255,7 @@
],
"metadata": {
"kernelspec": {
"display_name": "jaxoplanet_docs",
"display_name": "Python 3 (ipykernel)",
"language": "python",
"name": "python3"
},
Expand All @@ -267,9 +269,9 @@
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.10.13"
"version": "3.10.10"
}
},
"nbformat": 4,
"nbformat_minor": 2
"nbformat_minor": 4
}
6 changes: 4 additions & 2 deletions docs/tutorials/starry.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,10 @@
"cell_type": "markdown",
"metadata": {},
"source": [
"(starry)=\n",
"\n",
"# Starry light curve\n",
"\n",
"```{warning}\n",
"Experimental features!\n",
"```\n",
Expand All @@ -12,8 +16,6 @@
"Notebook under construction!\n",
"```\n",
"\n",
"# Starry light curve\n",
"\n",
"*jaxoplanet* aims to match the features of starry, a framework to compute the light curves of systems made of of non-uniform spherical bodies. In this small tutorial we demonstrate some of these features.\n",
"\n",
"\n",
Expand Down
2 changes: 2 additions & 0 deletions src/jaxoplanet/__init__.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,5 @@
__all__ = ["core", "light_curves", "orbits", "units"]

from jaxoplanet import (
core as core,
light_curves as light_curves,
Expand Down
7 changes: 7 additions & 0 deletions src/jaxoplanet/core/kepler.py
Original file line number Diff line number Diff line change
@@ -1,10 +1,17 @@
"""This module provides the core functionality to solve Kepler's equation in JAX. For
more details, see the :ref:`core-from-scratch` tutorial.
"""

__all__ = ["kepler"]

import jax
import jax.numpy as jnp
from jax.interpreters import ad

from jaxoplanet.types import Array


@jax.jit
def kepler(M: Array, ecc: Array) -> tuple[Array, Array]:
"""Solve Kepler's equation to compute the true anomaly
Expand Down
24 changes: 24 additions & 0 deletions src/jaxoplanet/core/limb_dark.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,9 @@
"""This module provides the functions needed to compute a limb darkened light curve as
described by `Agol et al. (2020) <https://arxiv.org/abs/1908.03222>`_.
"""

__all__ = ["light_curve"]

from functools import partial
from typing import Callable

Expand All @@ -12,6 +18,24 @@

@partial(jax.jit, static_argnames=("order",))
def light_curve(u: Array, b: Array, r: Array, *, order: int = 10):
"""Compute the light curve for arbitrary polynomial limb darkening
See `Agol et al. (2020) <https://arxiv.org/abs/1908.03222>`_ for more technical
details. Unlike in that paper, here we don't evaluate all the solution vector
integrals in closed form. Instead, for all but the lowest order terms, we numerically
integrate the relevant 1D line integral using Gauss-Legendre quadrature at fixed
order ``order``.
Args:
u (Array): The coefficients of the polynomial limb darkening model
b (Array): The center-to-center distance between the occultor and the occulted
body
r (Array): The radius ratio between the occultor and the occulted body
order (int): The quadrature order to use when numerically computing the the 1D
line integral
"""

u = jnp.asarray(u)
assert u.ndim == 1
if u.shape[0] == 0:
Expand Down
2 changes: 2 additions & 0 deletions src/jaxoplanet/light_curves/__init__.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,5 @@
"""This module contains models for computing and transforming light curve models"""

from jaxoplanet.light_curves import exposure_time as exposure_time
from jaxoplanet.light_curves.limb_dark import (
limb_dark_light_curve as limb_dark_light_curve,
Expand Down
15 changes: 15 additions & 0 deletions src/jaxoplanet/light_curves/utils.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,5 @@
__all__ = ["vectorize"]

from functools import wraps
from typing import Any, Union

Expand All @@ -9,6 +11,19 @@


def vectorize(func: LightCurveFunc) -> LightCurveFunc:
"""Vectorize a scalar light curve function to work with array inputs
Like ``jax.numpy.vectorize``, this automatically wraps a function which operates on a
scalar to handle array inputs. Unlike that function, this handles ``Quantity`` inputs
and outputs, but it only broadcasts the first input (``time``).
Args:
func: A function which takes a scalar ``Quantity`` time as the first input
Returns:
An updated function which can operate on ``Quantity`` times of any shape
"""

@wraps(func)
def wrapped(time: Quantity, *args: Any, **kwargs: Any) -> Union[Array, Quantity]:
if is_quantity(time):
Expand Down
17 changes: 1 addition & 16 deletions src/jaxoplanet/utils.py
Original file line number Diff line number Diff line change
@@ -1,22 +1,7 @@
from functools import wraps
from typing import Optional, Union
__all__ = ["get_dtype_eps", "zero_safe_sqrt"]

import jax
import jax.numpy as jnp
from jax.typing import ArrayLike


@wraps(jnp.where)
def where(
condition: ArrayLike,
x: Optional[ArrayLike],
y: Optional[ArrayLike],
*,
size: Optional[int] = None,
fill_value: Optional[Union[jax.Array, tuple[ArrayLike]]] = None,
) -> jax.Array:
"""A properly typed version of jnp.where"""
return jnp.where(condition, x, y, size=size, fill_value=fill_value) # type: ignore


def get_dtype_eps(x):
Expand Down

0 comments on commit b23ada6

Please sign in to comment.