Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

add a test for count_nonzero #347

Merged
merged 2 commits into from
Mar 3, 2025
Merged

Conversation

ev-br
Copy link
Member

@ev-br ev-br commented Mar 3, 2025

Parrot the test from test_{argmin, argmax}

Parrot the test from test_{argmin, argmax}
@ev-br ev-br marked this pull request as draft March 3, 2025 16:36
On torch, work around count_nonzero not implemented for uints
On jax, there are problems with integers > iinfo(jnp.int32)
@ev-br
Copy link
Member Author

ev-br commented Mar 3, 2025

The test is useful for data-apis/array-api-compat#267.

There are several issues with the test itself:

  • on pytorch, count_nonzero is not implemented for unsigned integer dtypes. Not much we can do about here.
  • on jax, the strategy runs into JAX problems with the lack of int64:
In [2]: jnp.asarray([2147483648], dtype=jnp.int64)
<ipython-input-2-df19af8d03da>:1: UserWarning: Explicitly requested dtype <class 'jax.numpy.int64'> requested in asarray is not available, and will be truncated to dtype int32. To enable more dtypes, set the jax_enable_x64 configuration option or the JAX_ENABLE_X64 shell environment variable. See https://github.com/jax-ml/jax#current-gotchas for more.
  jnp.asarray([2147483648], dtype=jnp.int64)
---------------------------------------------------------------------------
OverflowError                             Traceback (most recent call last)
<ipython-input-2-df19af8d03da> in ?()
----> 1 jnp.asarray([2147483648], dtype=jnp.int64)

~/.conda/envs/array-api/lib/python3.11/site-packages/jax/_src/numpy/lax_numpy.py in ?(a, dtype, order, copy, device)
   5816                       "Consider using copy=None or copy=True instead.")
   5817   dtypes.check_user_dtype_supported(dtype, "asarray")
   5818   if dtype is not None:
   5819     dtype = dtypes.canonicalize_dtype(dtype, allow_extended_dtype=True)  # type: ignore[assignment]
-> 5820   return array(a, dtype=dtype, copy=bool(copy), order=order, device=device)

~/.conda/envs/array-api/lib/python3.11/site-packages/jax/_src/numpy/lax_numpy.py in ?(object, dtype, copy, order, ndmin, device)
   5634     # TODO(jakevdp): falling back to numpy here fails to overflow for lists
   5635     # containing large integers; see discussion in
   5636     # https://github.com/jax-ml/jax/pull/6047. More correct would be to call
   5637     # coerce_to_array on each leaf, but this may have performance implications.
-> 5638     out = np.asarray(object, dtype=dtype)
   5639   elif isinstance(object, Array):
   5640     assert object.aval is not None
   5641     out = _array_copy(object) if copy else object

OverflowError: Python integer 2147483648 out of bounds for int32

Explicitly limiting the range of elements runs into problems with floats: hypothesis howls about not exactly representable numbers and errors out.

So an ideal fix is to cook up a strategy which both limits the range of ints and does the right thing for floats.

For now, I think, I'm going to merge this test soon to unblock data-apis/array-api-compat#267 (which probably warrants a quick bugfix release of array-api-compat, so is better done sooner than later).

The flakiness shows up locally with --max-examples 10_000 or so, so let's see how annoying the problem is on the CI

@ev-br ev-br marked this pull request as ready for review March 3, 2025 21:58
@ev-br ev-br merged commit 0b89c52 into data-apis:master Mar 3, 2025
3 checks passed
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
None yet
Projects
None yet
Development

Successfully merging this pull request may close these issues.

1 participant