Skip to content

Add device kwarg support to can_cast and result_type #691

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Closed
wants to merge 2 commits into from
Closed
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
25 changes: 22 additions & 3 deletions src/array_api_stubs/_draft/data_type_functions.py
Original file line number Diff line number Diff line change
Expand Up @@ -60,7 +60,13 @@ def astype(
"""


def can_cast(from_: Union[dtype, array], to: dtype, /) -> bool:
def can_cast(
from_: Union[dtype, array],
to: dtype,
/,
*,
device: Optional[device] = None,
) -> bool:
"""
Determines if one data type can be cast to another data type according :ref:`type-promotion` rules.

Expand All @@ -70,11 +76,16 @@ def can_cast(from_: Union[dtype, array], to: dtype, /) -> bool:
input data type or array from which to cast.
to: dtype
desired data type.
device: Optional[device]
device on which to perform a cast. If ``device`` is ``None``, the function must apply :ref:`type-promotion` rules irrespective of device capabilities. If ``device`` is a device object, the function must determine whether a cast can be performed on the specified device. Default: ``None``.

.. note::
While specification-conforming array libraries are expected to support all data types included in this specification, array libraries may support devices which do not have full data type support. Accordingly, the inclusion of a ``device`` keyword argument allows downstream array API consumers to test casting support for particular devices.

Returns
-------
out: bool
``True`` if the cast can occur according to :ref:`type-promotion` rules; otherwise, ``False``.
``True`` if the cast can occur; otherwise, ``False``.
"""


Expand Down Expand Up @@ -206,7 +217,10 @@ def isdtype(
"""


def result_type(*arrays_and_dtypes: Union[array, dtype]) -> dtype:
def result_type(
*arrays_and_dtypes: Union[array, dtype],
device: Optional[device] = None,
) -> dtype:
"""
Returns the dtype that results from applying the type promotion rules (see :ref:`type-promotion`) to the arguments.

Expand All @@ -217,6 +231,11 @@ def result_type(*arrays_and_dtypes: Union[array, dtype]) -> dtype:
----------
arrays_and_dtypes: Union[array, dtype]
an arbitrary number of input arrays and/or dtypes.
device: Optional[device]
device on which to apply type promotion rules. If ``device`` is ``None``, the function must apply :ref:`type-promotion` rules irrespective of device capabilities. If ``device`` is a device object, the function must apply type promotion rules with respect to the specified device. Default: ``None``.

.. note::
While specification-conforming array libraries are expected to support all data types included in this specification, array libraries may support devices which do not have full data type support. Accordingly, the inclusion of a ``device`` keyword argument allows downstream array API consumers to apply type promotions with respect to particular devices.

Returns
-------
Expand Down