
Are any of these functions worth adding here? #139

Closed
NeilGirdhar opened this issue Feb 8, 2025 · 9 comments

@NeilGirdhar
Contributor

I'm just wondering if any of these functions are worth adding to this project:

https://github.com/NeilGirdhar/tjax/blob/main/tjax/_src/math_tools.py

https://github.com/NeilGirdhar/tjax/blob/main/tjax/_src/leaky_integral.py

The leaky integral would need a lax.scan, which was deferred by the Array API people.

@lucascolley added the enhancement and new function labels on Feb 8, 2025
@crusaderky
Contributor

The functionality of divide_where is delivered by #14.

leaky_integrate sounds like something backed by some paper? At the very least, as an engineer with no specific domain knowledge I have no idea what it is meant to do. As such, it feels like it would fit naturally in scipy.

The intent of the other functions also seems quite obscure to me.

@NeilGirdhar
Contributor Author

The functionality of divide_where is delivered by #14.

I don't think so. While both functions can produce the same primals, divide_where prevents NaN on the cotangents.

leaky_integrate sounds like something backed by some paper?

Leaky integration is extremely common in machine learning. It's a special case of lfilter. But to implement lfilter in the Array API, you need lax.scan or something like it.
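
For concreteness, the recurrence I mean is roughly the following (an illustrative sketch with a plain exponential decay, not tjax's actual signature); the sequential carry is exactly why it needs lax.scan:

import jax
import jax.numpy as jnp

def leaky_integrate_sketch(xs, decay=0.9):
    # y[t] = decay * y[t - 1] + x[t], scanned along the leading axis of xs
    def step(carry, x):
        carry = decay * carry + x
        return carry, carry
    _, ys = jax.lax.scan(step, jnp.zeros(xs.shape[1:]), xs)
    return ys

print(leaky_integrate_sketch(jnp.ones(5)))  # approaches 1 / (1 - 0.9) = 10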

@crusaderky
Contributor

crusaderky commented Feb 10, 2025

I don't think so. While both functions can produce the same primals, divide_where prevents NaN on the cotangents.

Could you produce an example where the snippet below doesn't replicate the functionality of divide_where?

from scipy._lib._util import _lazywhere

def divide_where(dividend, divisor, where, otherwise):
    return _lazywhere(
        where,
        (dividend, divisor, otherwise),
        lambda n, d, o: n / d,
        f2=lambda n, d, o: o,
    )

It's a special case of lfilter.

To me lax.scan sounds like a useful primitive with a place in array-api-extra; leaky_integrate definitely doesn't.
I'm unsure, however, what an xpx implementation of lax.scan would look like for backends other than JAX. Cython for NumPy (which would be a major problem for vendoring), and cell-by-cell iteration in pure Python for all other backends? That could be very slow depending on the array's size along the target axis.
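
For concreteness, the cell-by-cell pure-Python path I have in mind would be roughly this (a hypothetical helper with illustrative names, not a proposal for the actual xpx API):

import numpy as np

def scan_fallback(f, init, xs, *, xp):
    # Slow path: apply f to one slice along axis 0 at a time, threading the
    # carry through a Python-level loop, then stack the per-step outputs.
    carry = init
    ys = []
    for i in range(xs.shape[0]):
        carry, y = f(carry, xs[i, ...])
        ys.append(y)
    return carry, xp.stack(ys)

final, ys = scan_fallback(lambda c, x: (c + x, c + x), np.float64(0.0), np.arange(5.0), xp=np)
print(ys)  # [ 0.  1.  3.  6. 10.]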

@NeilGirdhar
Contributor Author

NeilGirdhar commented Feb 10, 2025

Could you produce an example where the snippet below doesn't replicate the functionality of divide_where?

As I said above: "While both functions can produce the same primals, divide_where prevents NaN on the cotangents." For example,

from collections.abc import Callable
from typing import Any, cast

import jax.numpy as jnp
from array_api_compat import array_namespace
from array_api_compat.common._helpers import is_array_api_obj, is_jax_namespace
from jax import grad
from scipy._lib._util import _lazywhere
from tjax import divide_where

# Paste in apply_where from PR and make minor edits to get it working.

def divide_where_b(dividend: Any, divisor: Any, where: Any, otherwise: Any) -> Any:
    return apply_where(  # Had to use apply_where since _lazywhere doesn't support Jax arrays.
        where,
        lambda n, d, o: n / d,
        lambda n, d, o: o,
        dividend, divisor, otherwise,
    )

x = jnp.zeros(())

def f(x: Any) -> Any:
    return divide_where(x, x, where=x != 0.0, otherwise=jnp.ones(()))

def g(x: Any) -> Any:
    return divide_where_b(x, x, where=x != 0.0, otherwise=jnp.ones(()))

a = grad(f)(x)
b = grad(g)(x)
print(a, b)  # 0.0 nan

To me lax.scan sounds like a useful primitive with a place in array-api-extra;

I agree. Ideally that scan could be used to efficiently implement SciPy's filters for the Array API.

@crusaderky
Contributor

Right. So there is no difference unless you run inside an auto-differentiation tool, which AFAIK is exclusively offered by JAX, at least among array API compliant libraries. ndonnx could potentially offer one too in the future I guess. Dask could as a graph-based engine but such a feature would be firmly out of scope for the project.

Here's the problem: let's assume that array-api-extra offers both apply_where and divide_where. The only way to test that a downstream library (read: scipy) is using one or the other is with jax.grad. Would you expect all scipy functions to eventually gain a jax-only test that verifies their output when they're wrapped by jax.grad?
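
Concretely, I'd picture every such test looking roughly like this (a hypothetical jax-only sketch, reusing tjax's divide_where as the function under test):

import jax
import jax.numpy as jnp
import pytest
from tjax import divide_where

@pytest.mark.parametrize("x", [0.0, 1.0])
def test_divide_where_cotangent_is_finite(x):  # hypothetical jax-only test
    def f(x):
        return divide_where(x, x, where=x != 0.0, otherwise=jnp.ones(()))
    assert bool(jnp.isfinite(jax.grad(f)(jnp.asarray(x))))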

@rgommers
Member

an auto-differentiation tool, which AFAIK is exclusively offered by JAX,

And PyTorch, and MLX.

That said, I agree: auto-differentiation isn't in scope; that's a whole new level of complexity. It seems like a bridge too far to do in a framework-agnostic way at the moment.

@NeilGirdhar
Contributor Author

NeilGirdhar commented Feb 11, 2025

Right. So there is no difference unless you run inside an auto-differentiation tool, which AFAIK is exclusively offered by JAX, at least among array API compliant libraries.

PyTorch and TensorFlow are also automatic differentiation libraries.

I also disagree that the effect on the cotangents doesn't matter. The effects on primals and cotangents are both part of the API. If you implement something in a way that destroys or damages the cotangents, then that is a different API. Care should be taken to always produce the right ones, since they can have dramatic effects on the algorithms that use these functions.

Would you expect all scipy functions to eventually gain a jax-only test that verifies their output when they're wrapped by jax.grad?

It's not JAX-only. Most users of PyTorch, TensorFlow, and JAX are using these libraries as automatic differentiators to do machine learning.

And while I recognize that it's a bit more work, I think this library should be testing that the cotangents aren't damaged, because damaged cotangents will break normal uses of this library.

Otherwise, people will spend hours finding bugs in their code, report them here, and, if they provide fixes, add tests. It would be better from a person-hours standpoint to add the tests up front when it's likely that something may break.

Of course, you don't have to do that, but then we'll just be in the alternative world where users run into problems and submit PRs for you. I think that world is a net negative on person-hours spent.

Seems like a bridge too far to do in a framework-agnostic way at the moment.

I don't see why that would be the case. The tests may not be framework agnostic, but they're not always agnostic now anyway.

Protecting cotangents is not framework specific. It's just about being a little careful when writing algorithms. For example, in the case above, you need to mentally trace the cotangents (just as you mentally trace the primals). Then you'll see that the NaN cotangent is being mixed into the unit cotangent, which you can fix by not producing the NaN cotangent in the first place.
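
For example, the usual fix is the "double where" pattern; here is a rough sketch (not tjax's exact implementation):

import jax
import jax.numpy as jnp

def safe_divide_sketch(dividend, divisor, where, otherwise):
    # Sanitize the divisor on the masked-out elements *before* dividing, so the
    # division never produces NaN and neither does its cotangent.
    safe_divisor = jnp.where(where, divisor, 1.0)
    return jnp.where(where, dividend / safe_divisor, otherwise)

x = jnp.zeros(())
print(jax.grad(lambda x: safe_divide_sketch(x, x, x != 0.0, jnp.ones(())))(x))  # 0.0, not nan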

@lucascolley
Member

Thanks for thinking about this @NeilGirdhar; I'll close this as 'not planned' for now with the autodiff label. It sounds like we should return to this in the future if there is wider work on autodiff within the standard ecosystem, but this issue is maybe not the best place for that to start.

@lucascolley closed this as not planned on Feb 19, 2025
@NeilGirdhar
Contributor Author

Yeah, no worries, I think I got a bit carried away thinking of all the bugs I've worked through tracing cotangents. But I'll just keep any related tools in tjax for now.

However, as more and more algorithms get ported to the Array API, I imagine stop_gradient will be requested pretty early on:

from types import ModuleType
from typing import TypeVar

from array_api_compat import get_namespace, is_jax_array, is_torch_array

U = TypeVar("U")

def stop_gradient(x: U, *, xp: ModuleType | None = None) -> U:
    if xp is None:
        xp = get_namespace(x)  # resolved for API consistency; not otherwise needed below
    if is_jax_array(x):
        from jax.lax import stop_gradient as sg  # noqa: PLC0415
        return sg(x)
    if is_torch_array(x):
        from torch import Tensor  # noqa: PLC0415
        assert isinstance(x, Tensor)
        return x.detach()  # pyright: ignore
    return x
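
For what it's worth, under jax.grad the function above behaves as expected; a quick illustrative check:

import jax
import jax.numpy as jnp

# d/dx [stop_gradient(x) * x] = stop_gradient(x) evaluated at x = 2.0, i.e. 2.0 (not 2 * x = 4.0)
print(jax.grad(lambda x: stop_gradient(x) * x)(jnp.asarray(2.0)))  # 2.0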

No rush, of course.
