Are any of these functions worth adding here? #139
The functionality of
Other functions also look quite obscure to me as to their intent.
I don't think so. While both functions can produce the same primals, divide_where prevents NaN on the cotangents.
Leaky integration is extremely common in machine learning. It's a special case of
Could you produce an example where the snippet below doesn't replicate the functionality of divide_where?

from scipy._lib._util import _lazywhere

def divide_where(dividend, divisor, where, otherwise):
    return _lazywhere(
        where,
        (dividend, divisor, otherwise),
        lambda n, d, o: n / d,
        f2=lambda n, d, o: o,
    )
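As a quick sanity check of the snippet on the primal side (my own illustration, assuming NumPy inputs, which _lazywhere supports), it behaves like a guarded division:

import numpy as np

x = np.array([0.0, 2.0, 4.0])
# Where the divisor is zero, fall back to `otherwise` instead of dividing.
print(divide_where(x, x, where=x != 0.0, otherwise=np.ones_like(x)))  # [1. 1. 1.]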
To me
As I said above: "While both functions can produce the same primals, divide_where prevents NaN on the cotangents." For example:

from collections.abc import Callable
from typing import Any, cast

import jax.numpy as jnp
from array_api_compat import array_namespace
from array_api_compat.common._helpers import is_array_api_obj, is_jax_namespace
from jax import grad
from scipy._lib._util import _lazywhere
from tjax import divide_where

# Paste in apply_where from the PR and make minor edits to get it working.

def divide_where_b(dividend: Any, divisor: Any, where: Any, otherwise: Any) -> Any:
    # Had to use apply_where since _lazywhere doesn't support Jax arrays.
    return apply_where(
        where,
        lambda n, d, o: n / d,
        lambda n, d, o: o,
        dividend, divisor, otherwise,
    )

x = jnp.zeros(())

def f(x: Any) -> Any:
    return divide_where(x, x, where=x != 0.0, otherwise=jnp.ones(()))

def g(x: Any) -> Any:
    return divide_where_b(x, x, where=x != 0.0, otherwise=jnp.ones(()))

a = grad(f)(x)
b = grad(g)(x)
print(a, b)  # 0.0 nan
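(For reference, a minimal reproduction of the mechanism, not part of the thread: jnp.where evaluates both branches, so the discarded 0/0 still reaches the backward pass, where the zero cotangent it receives is multiplied by an infinite partial derivative, and 0 * inf is NaN.)

import jax.numpy as jnp
from jax import grad

def h(x):
    # Both branches are computed; `where` only selects afterwards.
    return jnp.where(x != 0.0, x / x, 1.0)

print(grad(h)(jnp.zeros(())))  # nan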
I agree. Ideally that
Right. So there is no difference unless you run inside an auto-differentiation tool, which AFAIK is exclusively offered by JAX, at least among array API compliant libraries. ndonnx could potentially offer one too in the future, I guess. Dask could as a graph-based engine, but such a feature would be firmly out of scope for the project. Here's the problem: let's assume that array-api-extra offers both
And PyTorch, and MLX. That said, I agree: auto-differentiation isn't in scope, that's a whole new level of complexity. Seems like a bridge too far to do in a framework-agnostic way at the moment.
Pytorch and Tensorflow are automatic differentiation libraries. I also disagree that the effect on cotangents doesn't matter. The effects on primals and cotangents are part of the API. If you implement something in a way that destroys or damages the cotangents, then that is a different API. Care should be taken to always produce the right ones, since they can have dramatic effects on algorithms that use the functions.
It's not Jax-only. Most users of Pytorch, Tensorflow, and Jax are using these libraries as automatic differentiators to do machine learning. And while I recognize that it's a bit more work, I think this library should test that the cotangents aren't damaged, because damaged cotangents will break normal uses of this library. That means people will spend hours finding bugs in their code, report them here, and, if they provide fixes, add tests. It would be better from a person-hours standpoint to add the tests in the first place when it's likely that something may break. Of course, you don't have to do that, but then we'll just be in the alternative world where users run into problems and submit PRs for you. I think that world is a net negative on person-hours spent.
I don't see why that would be the case. The tests may not be framework-agnostic, but they're not always agnostic now anyway. Protecting cotangents is not framework-specific. It's just about being a little bit careful when you're writing algorithms. For example, in the above case, you need to mentally trace the cotangents (just like you mentally trace primals). Then you'll see that the NaN cotangent is being mixed into the unit cotangent, which you can fix by not producing the NaN cotangent in the first place.
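(For reference, the usual fix in JAX is the "double where" trick: make the untaken branch's inputs safe before the operation that would produce the NaN. A minimal sketch, not tjax's divide_where:)

import jax.numpy as jnp
from jax import grad

def safe_divide_where(dividend, divisor, where, otherwise):
    # First `where`: substitute a harmless divisor on the untaken branch so
    # the division never produces NaN or inf there.
    safe_divisor = jnp.where(where, divisor, 1.0)
    # Second `where`: select the real quotient or the fallback value.
    return jnp.where(where, dividend / safe_divisor, otherwise)

x = jnp.zeros(())
print(grad(lambda x: safe_divide_where(x, x, where=x != 0.0, otherwise=jnp.ones(())))(x))  # 0.0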
Thanks for thinking about this @NeilGirdhar, I'll close this as 'not planned' for now with the autodiff label. It sounds like we should return to this in the future if there is wider work on autodiff within the standard ecosystem, but this issue is maybe not the best place for that to start.
Yeah, no worries, I think I got a bit carried away thinking of all the bugs I've worked through tracing cotangents. But I'll just keep any related tools in tjax for now. However, as more and more algorithms get ported to the Array API, I imagine something like the stop_gradient below may become worth adding:

from types import ModuleType
from typing import TypeVar

from array_api_compat import get_namespace
from array_api_compat.common._helpers import is_jax_namespace, is_torch_namespace

U = TypeVar("U")

def stop_gradient(x: U, *, xp: ModuleType | None = None) -> U:
    if xp is None:
        xp = get_namespace(x)
    if is_jax_namespace(xp):
        from jax.lax import stop_gradient as sg  # noqa: PLC0415
        return sg(x)
    if is_torch_namespace(xp):
        from torch import Tensor  # noqa: PLC0415
        assert isinstance(x, Tensor)
        return x.detach()  # pyright: ignore
    return x

No rush, of course.
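(A minimal usage sketch of my own, assuming the definition above plus a recent JAX and array_api_compat: grad treats the stopped factor as a constant.)

import jax.numpy as jnp
from jax import grad

# d/dx [x * stop_gradient(x)] evaluates to stop_gradient(x), not 2 * x.
print(grad(lambda x: x * stop_gradient(x))(jnp.asarray(3.0)))  # 3.0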
I'm just wondering if any of these functions are worth adding to this project:
https://github.com/NeilGirdhar/tjax/blob/main/tjax/_src/math_tools.py
https://github.com/NeilGirdhar/tjax/blob/main/tjax/_src/leaky_integral.py
The leaky integral would need a lax.scan, which was deferred by the Array API people.
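(For context, a minimal sketch of a generic leaky integrator in JAX, my own illustration rather than tjax's leaky_integral: each output depends on the previous one, which is why a scan-style primitive, absent from the array API standard, is needed.)

import jax.numpy as jnp
from jax import lax

def leaky_integrate(xs, decay):
    # y_t = decay * y_{t-1} + x_t over the leading axis: each step depends on
    # the previous output, so a sequential scan (or a Python loop) is required.
    def step(y, x):
        y = decay * y + x
        return y, y
    _, ys = lax.scan(step, jnp.zeros(xs.shape[1:]), xs)
    return ys

print(leaky_integrate(jnp.ones(5), decay=0.9))  # ~[1.  1.9  2.71  3.439  4.0951]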