Skip to content

Commit c1adc04

Browse files
crusaderky and NeilGirdhar
authored and committed
ENH: New function default_dtype (#310)
* default_dtype * tweak comment in conftest
1 parent 9767271 commit c1adc04

File tree

5 files changed

+107
-15
lines changed

5 files changed

+107
-15
lines changed

docs/api-reference.md

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -12,6 +12,7 @@
1212
broadcast_shapes
1313
cov
1414
create_diagonal
15+
default_dtype
1516
expand_dims
1617
isclose
1718
kron

src/array_api_extra/__init__.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,7 @@
88
broadcast_shapes,
99
cov,
1010
create_diagonal,
11+
default_dtype,
1112
expand_dims,
1213
kron,
1314
nunique,
@@ -27,6 +28,7 @@
2728
"broadcast_shapes",
2829
"cov",
2930
"create_diagonal",
31+
"default_dtype",
3032
"expand_dims",
3133
"isclose",
3234
"kron",

src/array_api_extra/_lib/_funcs.py

Lines changed: 41 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,7 @@
44
import warnings
55
from collections.abc import Callable, Sequence
66
from types import ModuleType, NoneType
7-
from typing import cast, overload
7+
from typing import Literal, cast, overload
88

99
from ._at import at
1010
from ._utils import _compat, _helpers
@@ -16,7 +16,7 @@
1616
meta_namespace,
1717
ndindex,
1818
)
19-
from ._utils._typing import Array
19+
from ._utils._typing import Array, Device, DType
2020

2121
__all__ = [
2222
"apply_where",
@@ -438,6 +438,44 @@ def create_diagonal(
438438
return xp.reshape(diag, (*batch_dims, n, n))
439439

440440

441+
def default_dtype(
    xp: ModuleType,
    kind: Literal[
        "real floating", "complex floating", "integral", "indexing"
    ] = "real floating",
    *,
    device: Device | None = None,
) -> DType:
    """
    Return the default dtype of ``xp`` for the given kind and device.

    This is a convenience shorthand for
    ``xp.__array_namespace_info__().default_dtypes(device=device)[kind]``.

    Parameters
    ----------
    xp : array_namespace
        The standard-compatible namespace for which to get the default dtype.
    kind : {'real floating', 'complex floating', 'integral', 'indexing'}, optional
        The kind of dtype to return. Default is 'real floating'.
    device : Device, optional
        The device for which to get the default dtype. Default: current device.

    Returns
    -------
    dtype
        The default dtype for the given namespace, kind, and device.
    """
    defaults = xp.__array_namespace_info__().default_dtypes(device=device)
    try:
        return defaults[kind]
    except KeyError as exc:
        valid_kinds = ("real floating", "complex floating", "integral", "indexing")
        # A standard-compliant namespace always exposes exactly these four
        # kinds, so a KeyError can only mean the caller passed a bad `kind`.
        assert set(defaults) == set(valid_kinds), f"Non-compliant namespace: {defaults}"
        raise ValueError(f"Unknown kind '{kind}'. Expected one of {valid_kinds}.") from exc
478+
441479
def expand_dims(
442480
a: Array, /, *, axis: int | tuple[int, ...] = (0,), xp: ModuleType | None = None
443481
) -> Array:
@@ -728,9 +766,7 @@ def nunique(x: Array, /, *, xp: ModuleType | None = None) -> Array:
728766
x = xp.reshape(x, (-1,))
729767
x = xp.sort(x)
730768
mask = x != xp.roll(x, -1)
731-
default_int = xp.__array_namespace_info__().default_dtypes(
732-
device=_compat.device(x)
733-
)["integral"]
769+
default_int = default_dtype(xp, "integral", device=_compat.device(x))
734770
return xp.maximum(
735771
# Special cases:
736772
# - array is size 0

tests/conftest.py

Lines changed: 29 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -149,16 +149,8 @@ def xp(
149149

150150
if library.like(Backend.JAX):
151151
_setup_jax(library)
152-
153-
elif library == Backend.TORCH_GPU:
154-
import torch.cuda
155-
156-
if not torch.cuda.is_available():
157-
pytest.skip("no CUDA device available")
158-
xp.set_default_device("cuda")
159-
160-
elif library == Backend.TORCH: # CPU
161-
xp.set_default_device("cpu")
152+
elif library.like(Backend.TORCH):
153+
_setup_torch(library)
162154

163155
yield xp
164156

@@ -179,6 +171,24 @@ def _setup_jax(library: Backend) -> None:
179171
jax.config.update("jax_default_device", device)
180172

181173

174+
def _setup_torch(library: Backend) -> None:
    """Configure torch's default dtype and device for the test session."""
    import torch

    # float32 is already torch's default, but earlier tests or environment
    # variables may have changed it; reset so runs are reproducible.
    # TODO test both float32 and float64, like in scipy.
    torch.set_default_dtype(torch.float32)

    if library != Backend.TORCH_GPU:
        assert library == Backend.TORCH
        torch.set_default_device("cpu")
        return

    import torch.cuda

    if not torch.cuda.is_available():
        pytest.skip("no CUDA device available")
    torch.set_default_device("cuda")
191+
182192
@pytest.fixture(params=[Backend.DASK]) # Can select the test with `pytest -k dask`
183193
def da(
184194
request: pytest.FixtureRequest, monkeypatch: pytest.MonkeyPatch
@@ -201,6 +211,15 @@ def jnp(
201211
return xp
202212

203213

214+
@pytest.fixture(params=[Backend.TORCH, Backend.TORCH_GPU])
def torch(request: pytest.FixtureRequest) -> ModuleType:  # numpydoc ignore=PR01,RT01
    """Variant of the `xp` fixture that only yields torch."""
    mod = pytest.importorskip("torch")
    namespace = array_namespace(mod.empty(0))
    _setup_torch(request.param)
    return namespace
221+
222+
204223
@pytest.fixture
205224
def device(
206225
library: Backend, xp: ModuleType

tests/test_funcs.py

Lines changed: 34 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -17,6 +17,7 @@
1717
broadcast_shapes,
1818
cov,
1919
create_diagonal,
20+
default_dtype,
2021
expand_dims,
2122
isclose,
2223
kron,
@@ -517,6 +518,39 @@ def test_xp(self, xp: ModuleType):
517518
xp_assert_equal(y, xp.asarray([[1, 0], [0, 2]]))
518519

519520

521+
@pytest.mark.xfail_xp_backend(Backend.SPARSE, reason="no __array_namespace_info__")
class TestDefaultDType:
    def test_basic(self, xp: ModuleType):
        assert default_dtype(xp) == xp.empty(0).dtype

    def test_kind(self, xp: ModuleType):
        expected = {
            "real floating": xp.empty(0).dtype,
            "complex floating": (xp.empty(0) * 1j).dtype,
            "integral": xp.int64,
            "indexing": xp.int64,
        }
        for kind, dtype in expected.items():
            assert default_dtype(xp, kind) == dtype

        with pytest.raises(ValueError, match="Unknown kind"):
            _ = default_dtype(xp, "foo")  # type: ignore[arg-type] # pyright: ignore[reportArgumentType]

    def test_device(self, xp: ModuleType, device: Device):
        # Note: at the moment there are no known namespaces with
        # device-specific default dtypes.
        for dev in (None, device):
            assert default_dtype(xp, device=dev) == xp.empty(0).dtype

    def test_torch(self, torch: ModuleType):
        xp = torch
        for real, cplx in ((xp.float64, xp.complex128), (xp.float32, xp.complex64)):
            xp.set_default_dtype(real)
            assert default_dtype(xp) == real
            assert default_dtype(xp, "real floating") == real
            assert default_dtype(xp, "complex floating") == cplx
520554
class TestExpandDims:
521555
def test_single_axis(self, xp: ModuleType):
522556
"""Trivial case where xpx.expand_dims doesn't add anything to xp.expand_dims"""

0 commit comments

Comments
 (0)