Skip to content

Commit 578e437

Browse files
committed
MAINT: finfo() / iinfo() input/output review
1 parent ea5deb1 commit 578e437

File tree

3 files changed

+61
-11
lines changed

3 files changed

+61
-11
lines changed

array_api_strict/_data_type_functions.py

+20-4
Original file line numberDiff line numberDiff line change
@@ -100,7 +100,7 @@ def can_cast(from_: DType | Array, to: DType, /) -> bool:
100100

101101
# These are internal objects for the return types of finfo and iinfo, since
102102
# the NumPy versions contain extra data that isn't part of the spec.
103-
@dataclass
103+
@dataclass(frozen=True, eq=False)
104104
class finfo_object:
105105
bits: int
106106
# Note: The types of the float data here are float, whereas in NumPy they
@@ -111,22 +111,32 @@ class finfo_object:
111111
smallest_normal: float
112112
dtype: DType
113113

114+
__hash__ = NotImplemented
114115

115-
@dataclass
116+
117+
@dataclass(frozen=True, eq=False)
116118
class iinfo_object:
117119
bits: int
118120
max: int
119121
min: int
120122
dtype: DType
121123

124+
__hash__ = NotImplemented
125+
122126

123127
def finfo(type: DType | Array, /) -> finfo_object:
124128
"""
125129
Array API compatible wrapper for :py:func:`np.finfo <numpy.finfo>`.
126130
127131
See its docstring for more information.
128132
"""
129-
np_type = type._array if isinstance(type, Array) else type._np_dtype
133+
if isinstance(type, Array):
134+
np_type = type._dtype._np_dtype
135+
elif isinstance(type, DType):
136+
np_type = type._np_dtype
137+
else:
138+
raise TypeError(f"'type' must be a dtype or array, not {type!r}")
139+
130140
fi = np.finfo(np_type)
131141
# Note: The types of the float data here are float, whereas in NumPy they
132142
# are scalars of the corresponding float dtype.
@@ -146,7 +156,13 @@ def iinfo(type: DType | Array, /) -> iinfo_object:
146156
147157
See its docstring for more information.
148158
"""
149-
np_type = type._array if isinstance(type, Array) else type._np_dtype
159+
if isinstance(type, Array):
160+
np_type = type._dtype._np_dtype
161+
elif isinstance(type, DType):
162+
np_type = type._np_dtype
163+
else:
164+
raise TypeError(f"'type' must be a dtype or array, not {type!r}")
165+
150166
ii = np.iinfo(np_type)
151167
return iinfo_object(ii.bits, ii.max, ii.min, DType(ii.dtype))
152168

array_api_strict/_dtypes.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -35,7 +35,7 @@ def __eq__(self, other: object) -> builtins.bool:
3535
stacklevel=2,
3636
)
3737
if not isinstance(other, DType):
38-
return NotImplemented
38+
return False
3939
return self._np_dtype == other._np_dtype
4040

4141
def __hash__(self) -> int:

array_api_strict/tests/test_data_type_functions.py

+40-6
Original file line numberDiff line numberDiff line change
@@ -1,15 +1,12 @@
11
import warnings
22

3+
import numpy as np
34
import pytest
4-
55
from numpy.testing import assert_raises
6-
import numpy as np
76

87
from .._creation_functions import asarray
9-
from .._data_type_functions import astype, can_cast, isdtype, result_type
10-
from .._dtypes import (
11-
bool, int8, int16, uint8, float64, int64
12-
)
8+
from .._data_type_functions import astype, can_cast, finfo, iinfo, isdtype, result_type
9+
from .._dtypes import DType, bool, float64, int8, int16, int64, uint8
1310
from .._flags import set_array_api_strict_flags
1411

1512

@@ -88,3 +85,40 @@ def test_result_type_py_scalars(api_version):
8885

8986
with pytest.raises(TypeError):
9087
result_type(int64, True)
88+
89+
90+
def test_finfo_iinfo_dtypelike():
91+
"""np.finfo() and np.iinfo() accept any DTypeLike.
92+
Array API only accepts Array | DType.
93+
"""
94+
match = "must be a dtype or array"
95+
with pytest.raises(TypeError, match=match):
96+
finfo("float64")
97+
with pytest.raises(TypeError, match=match):
98+
finfo(float)
99+
with pytest.raises(TypeError, match=match):
100+
iinfo("int8")
101+
with pytest.raises(TypeError, match=match):
102+
iinfo(int)
103+
104+
105+
def test_finfo_iinfo_wrap_output():
106+
"""Test that the finfo(...).dtype and iinfo(...).dtype
107+
are array-api-strict.DType objects; not numpy.dtype.
108+
"""
109+
# Note: array_api_strict.DType objects are not singletons
110+
assert finfo(float64).dtype == float64
111+
assert iinfo(int8).dtype == int8
112+
113+
114+
@pytest.mark.parametrize("func,arg", [(finfo, float64), (iinfo, int8)])
115+
def test_finfo_iinfo_output_assumptions(func, arg):
116+
"""There should be no expectation for the output of finfo()/iinfo()
117+
to be comparable, hashable, or writeable.
118+
"""
119+
obj = func(arg)
120+
assert obj != func(arg) # Defaut behaviour for custom classes
121+
with pytest.raises(TypeError):
122+
hash(obj)
123+
with pytest.raises(Exception, match="cannot assign"):
124+
obj.min = 0

0 commit comments

Comments
 (0)