Implement helper @as_jax_op to wrap JAX functions in PyTensor #537
Description
This blog post walks through the logic for three different examples: https://www.pymc-labs.com/blog-posts/jax-functions-in-pymc-3-quick-examples/ and shows that the logic is always the same:
Things that cannot be obtained automatically (or maybe they can?) and should therefore be opt-in, as in @as_op:
4. Input and output types
5. infer_shape