Improvements to device support #70

Open · asmeurer opened this issue Oct 18, 2024 · 7 comments

Comments

@asmeurer (Member)

#59 added support for basic devices. Some improvements that could be made:

  • Factor out the device checking logic (into a decorator? see the sketch after this list)
  • Add devices that don't support certain dtypes (non-support for float64/complex128 being the most common)
  • Allow the user to manually create custom "devices" with certain properties (probably via the flags API).
  • Better testing for device support. This ideally should go in array-api-tests, but currently it doesn't test devices at all, so it would be useful to have some basic tests here.
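
For the first bullet, one possible shape for the factored-out check. This is a minimal sketch: the decorator name, the device attribute access, and the error message are all illustrative, not existing array-api-strict API.

```python
import functools

def _requires_same_device(func):
    """Raise if the positional array arguments sit on different devices."""
    @functools.wraps(func)
    def wrapper(*args, **kwargs):
        devices = {a.device for a in args if hasattr(a, "device")}
        if len(devices) > 1:
            raise ValueError(
                f"Expected all input arrays to be on the same device, got {devices}"
            )
        return func(*args, **kwargs)
    return wrapper
```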
@ogrisel commented Oct 21, 2024

I tried to run the scikit-learn tests with array-api-strict on the non-default device and got some failures that do not happen with PyTorch on a non-default device. So I suspect something fishy is happening with the new devices in array-api-strict, but I have not investigated the root cause yet.

scikit-learn/scikit-learn#30090

@betatim (Member) commented Oct 21, 2024

I'll take a look and investigate that failure further.

@lucyleeow (Contributor) commented Mar 28, 2025

Commenting here because it partially relates to the OP but is directly related to @ogrisel's comment #70 (comment).

Looking at https://data-apis.org/array-api/latest/design_topics/device_support.html#device-support, the spec is very light in this area: it only specifies local control over the device on which data is allocated, and the syntax for device assignment. I agree that it makes sense for strict not to allow mixing of devices, but there are two different options here:

  1. no mixing of devices
  2. only allow the default device (note NumPy itself only allows a default device)

The current behaviour seems a bit contradictory. This works:

from array_api_strict import asarray, Device

asarray([1,2,3], device=Device('device1'))

this fails:

asarray([asarray([1,2], device=Device('device1')), asarray([1,2], device=Device('device1'))])

due to this check (in __array__):

if self._device != CPU_DEVICE:
    raise RuntimeError(f"Can not convert array on the '{self._device}' device to a Numpy array.")

This seems to be the only place where such a check happens (np.asarray on a list of Array objects calls __array__ on each element, which is what triggers it). It's probably the cause of the behaviour @ogrisel noticed, where some tests on non-default devices fail with strict but not with PyTorch on a non-default device.

(I think we should agree on a behaviour and make it consistent)

@ev-br

(ref: noticed while working on #134)

@ev-br (Member) commented Mar 28, 2025

First and foremost, work on device support is very welcome, @lucyleeow!

Two quick notes:

  1. Neither __array__ nor nested asarray calls are standard-compliant. #118 ("asarray is not strict enough") tracks the latter in array-api-compat, and I just sent a PR to potentially clarify the spec: array-api#917 ("docs: clarify that behavior when provided a nested sequence to asarray is unspecified").

  2. My gut feeling is that allowing multiple dummy devices in array-api-strict is more useful than allowing only a single CPU_DEVICE, if only to simplify testing device support in array-api-tests (the last bullet in the OP).
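
For instance, array-api-tests could parametrize over the dummy devices. A minimal sketch, assuming the extra devices are named 'device1' (used earlier in this thread) and 'device2' (an assumption), and that Device instances compare equal by name:

```python
import pytest
import array_api_strict as xp
from array_api_strict import Device

# None means "default device" per the standard; 'device2' is assumed to exist.
@pytest.mark.parametrize("device", [None, Device('device1'), Device('device2')])
def test_asarray_device(device):
    a = xp.asarray([1, 2, 3], device=device)
    if device is not None:
        # Arrays should report the device they were allocated on.
        assert a.device == device
```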

@lucascolley (Member)
Agreed, having different devices is useful for testing.

asarray([asarray([1,2], device=Device('device1')), asarray([1,2], device=Device('device1'))])

is more idiomatically written as

xp.stack((xp.asarray([1, 2], device=Device('device1')), xp.asarray([1, 2], device=Device('device1'))))

which should not fail.
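
A quick way to check that expectation (a sketch; it assumes, per the standard's device semantics, that stack allocates its result on the inputs' device):

```python
import array_api_strict as xp
from array_api_strict import Device

parts = [xp.asarray([1, 2], device=Device('device1')) for _ in range(2)]
stacked = xp.stack(parts)
print(stacked.device)  # expected: device1, since every input lives on device1
```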

@betatim (Member) commented Mar 31, 2025

The idea behind the "multi device support" in array-api-strict is to help projects test the behaviour of using more than one device, without needing actual special hardware (like a CUDA GPU).

The decision I made was to allow arrays on the default device (CPU_DEVICE) to be converted back to NumPy arrays. The reasoning was that this was already possible (so no breaking change is introduced) and that there are libraries that allow it (so it is useful to be able to test it). The non-default devices don't allow conversion to NumPy arrays, because there are libraries that don't allow that, and since there was no existing behaviour we were free to choose.
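
Concretely, the behaviour described above looks like this (a sketch based on the __array__ check quoted earlier in the thread):

```python
import numpy as np
import array_api_strict as xp
from array_api_strict import Device

cpu_arr = xp.asarray([1, 2, 3])  # allocated on the default device, CPU_DEVICE
np.asarray(cpu_arr)              # works: default-device arrays convert to NumPy

dev_arr = xp.asarray([1, 2, 3], device=Device('device1'))
np.asarray(dev_arr)              # raises RuntimeError via the __array__ check above
```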

I think the error you get when you write asarray([asarray([1,2], device=Device('device1')), asarray([1,2], device=Device('device1'))]) is a bit weird and hard to figure out; at least I was puzzled by it. Maybe we can improve it to help users, and point them to what Lucas suggested as the way to write this? Though I was also not sure what the intended outcome of that expression is: an array on the default device, or on device1?

@lucyleeow (Contributor) commented Mar 31, 2025

Thanks for the background @betatim, I think your decisions all sound reasonable.

I'm afraid the context is very boring: while working on #134, I wanted to amend test_indexing_arrays to also test device support. That test unfortunately has:

a_idx_loop = asarray([a[idx[i]] for i in range(idx.shape[0])])

which I think we now agree should have been written as xp.stack([a[idx[i]] for i in range(idx.shape[0])]).

I agree that the error is confusing. Indeed, this was also the error raised in #134, though that one is an actual bug to be fixed.

> I tried to run the scikit-learn tests with array-api-strict on the non-default device and I got some failures that do not happen with PyTorch and non-default device.

I am now curious which tests failed with array-api-strict but passed with PyTorch, @ogrisel?
