where not following Python scalar promotion rules? #131

Closed
mdhaber opened this issue Mar 5, 2025 · 6 comments · Fixed by #132

@mdhaber

mdhaber commented Mar 5, 2025

According to the standard for where,

[Image: screenshot of the where specification in the array API standard]

array_api_strict.where does not seem to implement this rule.

import array_api_strict as xp
x = xp.asarray([1., 2.])
xp.where(x > 1.5, x, 0)
# TypeError: array_api_strict.float64 and array_api_strict.int64 cannot be type promoted together

@ev-br
Member

ev-br commented Mar 5, 2025

So the relevant wording from the spec is:

Using Python scalars (i.e., instances of bool, int, float, complex) together with arrays must be supported for:

array <op> scalar

... and scalar has a type and value compatible with the array data type:

a Python int or float for real-valued floating-point array data types.

So indeed, where is too strict here. PR welcome!
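
A minimal sketch of that rule in plain Python, written against dtype names rather than a real array library. The helper names (scalar_is_compatible, result_dtype_with_scalar) are hypothetical. Only the float row is quoted above; the bool, integer and complex rows are a reading of the rest of the same list in the standard and should be treated as assumptions. Under the rule, where(cond, array, scalar) simply keeps the array's dtype:

```python
# Hypothetical sketch of the "array <op> scalar" rule, using dtype names only.
# The bool/int/complex rows are assumptions based on the standard's list.
_INT_RANGES = {
    "int8": (-2**7, 2**7 - 1),
    "int16": (-2**15, 2**15 - 1),
    "int32": (-2**31, 2**31 - 1),
    "int64": (-2**63, 2**63 - 1),
    "uint8": (0, 2**8 - 1),
    "uint16": (0, 2**16 - 1),
    "uint32": (0, 2**32 - 1),
    "uint64": (0, 2**64 - 1),
}
_REAL_FLOATS = {"float32", "float64"}
_COMPLEX_FLOATS = {"complex64", "complex128"}


def scalar_is_compatible(value, dtype):
    """Return True if a Python scalar may be combined with an array of `dtype`."""
    # bool is a subclass of int in Python, so it has to be checked first.
    if isinstance(value, bool):
        return dtype == "bool"
    if isinstance(value, int):
        if dtype in _INT_RANGES:
            lo, hi = _INT_RANGES[dtype]
            return lo <= value <= hi  # the int must fit in the integer dtype
        return dtype in _REAL_FLOATS or dtype in _COMPLEX_FLOATS
    if isinstance(value, float):
        return dtype in _REAL_FLOATS or dtype in _COMPLEX_FLOATS
    if isinstance(value, complex):
        return dtype in _COMPLEX_FLOATS
    return False


def result_dtype_with_scalar(array_dtype, scalar):
    """Result dtype of where(cond, array, scalar): the scalar adopts the
    array's dtype, so the array's dtype is returned unchanged."""
    if not scalar_is_compatible(scalar, array_dtype):
        raise TypeError(
            f"Python {type(scalar).__name__} is not compatible with {array_dtype}"
        )
    return array_dtype


print(result_dtype_with_scalar("float64", 0))  # float64
```

With this rule, the first example in the report (a float64 array with the Python int 0) gives float64 instead of a TypeError, while result_dtype_with_scalar("int64", 1.5) still raises.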

@ev-br
Member

ev-br commented Mar 5, 2025

What I do find curious, though, is that array <op> scalar is more permissive than array <op> array:

In [9]: x = xp.asarray(1.)

In [10]: x + 0
Out[10]: Array(1., dtype=array_api_strict.float64)

In [11]: x + xp.asarray(0)
---------------------------------------------------------------------------
TypeError                                 Traceback (most recent call last)
...
TypeError: array_api_strict.float64 and array_api_strict.int64 cannot be type promoted together

This was definitely discussed and rejected (at a guess, because the spacing between adjacent floating-point numbers grows larger than 1 for values >> 1, so not every large integer can be represented exactly). Anyhow, it's covered by "A conforming implementation of the array API standard may support additional type promotion rules beyond those described in this specification."
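
The rationale above is only a guess, but its floating-point part is easy to check with the standard library. math.ulp gives the gap between a float and the next representable float64; at 2**53 that gap becomes 2, so neighbouring int64 values can collapse onto the same float64:

```python
import math

big = 2**53  # beyond this, float64 cannot represent every integer

gap_below = math.ulp(float(big - 1))  # spacing of float64 values just below 2**53
gap_at = math.ulp(float(big))         # spacing of float64 values at 2**53
collides = float(big + 1) == float(big)  # 2**53 + 1 rounds to 2**53

print(gap_below, gap_at, collides)  # 1.0 2.0 True
```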

@ev-br
Member

ev-br commented Mar 6, 2025

> PR welcome!

Well, it's easiest to just fix it myself, since I very recently worked on similar support for other binary functions and still remember the details.

gh-132 seems to do the trick.

ev-br added this to the 2.4 milestone Mar 6, 2025
@mdhaber
Author

mdhaber commented Mar 14, 2025

I also ran into:

import array_api_strict as xp
xp.where(xp.asarray([True, False]), 1., xp.asarray([2, 2], dtype=xp.float32))
# Array([1., 2.], dtype=array_api_strict.float64)  # should be float32

Will gh-132 also fix this?
Is there reason to believe the problem is isolated to where? I tested clip (very lightly) and it seems OK, and I notice that where was handled in a different PR from the rest.

@ev-br
Member

ev-br commented Mar 14, 2025

> Will #132 also fix this?

Yes. It basically makes where reuse scalar handling from binary functions.

In [1]: import array_api_strict as xp

In [2]: xp.where(xp.asarray([True, False]), 1., xp.asarray([2, 2], dtype=xp.float32))
Out[2]: Array([1., 2.], dtype=array_api_strict.float32)

In [3]: xp.__version__
Out[3]: '2.4.dev3+g33b4bf6.d20250302'
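
Roughly, reusing the scalar handling from binary functions means converting the Python scalar to the dtype of the array argument before the selection happens. The sketch below mimics that idea on top of NumPy. where_with_scalars is a hypothetical helper, not the code in gh-132, and it deliberately omits the value/kind compatibility check that a conforming implementation would also perform:

```python
import numpy as np

_PY_SCALARS = (bool, int, float, complex)


def where_with_scalars(condition, x1, x2):
    """Hypothetical sketch: where() that treats a Python scalar the way binary
    operators do, i.e. the scalar takes on the dtype of the other argument.
    No compatibility check is done here, so e.g. 1.5 with an int64 array
    would simply be cast to int64."""
    x1_scalar = isinstance(x1, _PY_SCALARS)
    x2_scalar = isinstance(x2, _PY_SCALARS)
    if x1_scalar and x2_scalar:
        # This sketch needs at least one array to take a dtype from.
        raise TypeError("at least one of x1 and x2 must be an array")
    if x1_scalar:
        x1 = np.asarray(x1, dtype=x2.dtype)  # scalar adopts x2's dtype
    elif x2_scalar:
        x2 = np.asarray(x2, dtype=x1.dtype)  # scalar adopts x1's dtype
    return np.where(condition, x1, x2)


cond = np.asarray([True, False])
y = np.asarray([2, 2], dtype=np.float32)
out = where_with_scalars(cond, 1.0, y)
print(out, out.dtype)  # [1. 2.] float32
```

Because the scalar is converted before np.where sees it, the float32 array decides the result dtype, which matches the Out[2] above.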

ev-br modified the milestones: 2.4, 2.3.1 Mar 20, 2025
@mdhaber
Author

mdhaber commented Mar 26, 2025

Thanks. I think I am able to remove all the test skips related to this issue now.
