Skip to content

Commit dc71844

Browse files
authored
Merge pull request #143 from crusaderky/finfo_iinfo
MAINT: finfo() / iinfo() input/output review
2 parents d13ab1b + f5778f6 commit dc71844

File tree

3 files changed

+63
-11
lines changed

3 files changed

+63
-11
lines changed

array_api_strict/_data_type_functions.py

+22-4
Original file line numberDiff line numberDiff line change
@@ -100,7 +100,9 @@ def can_cast(from_: DType | Array, to: DType, /) -> bool:
100100

101101
# These are internal objects for the return types of finfo and iinfo, since
102102
# the NumPy versions contain extra data that isn't part of the spec.
103-
@dataclass
103+
# There should be no expectation for them to be comparable, hashable, or writeable.
104+
105+
@dataclass(frozen=True, eq=False)
104106
class finfo_object:
105107
bits: int
106108
# Note: The types of the float data here are float, whereas in NumPy they
@@ -111,22 +113,32 @@ class finfo_object:
111113
smallest_normal: float
112114
dtype: DType
113115

116+
__hash__ = NotImplemented
117+
114118

115-
@dataclass
119+
@dataclass(frozen=True, eq=False)
116120
class iinfo_object:
117121
bits: int
118122
max: int
119123
min: int
120124
dtype: DType
121125

126+
__hash__ = NotImplemented
127+
122128

123129
def finfo(type: DType | Array, /) -> finfo_object:
124130
"""
125131
Array API compatible wrapper for :py:func:`np.finfo <numpy.finfo>`.
126132
127133
See its docstring for more information.
128134
"""
129-
np_type = type._array if isinstance(type, Array) else type._np_dtype
135+
if isinstance(type, Array):
136+
np_type = type._dtype._np_dtype
137+
elif isinstance(type, DType):
138+
np_type = type._np_dtype
139+
else:
140+
raise TypeError(f"'type' must be a dtype or array, not {type!r}")
141+
130142
fi = np.finfo(np_type)
131143
# Note: The types of the float data here are float, whereas in NumPy they
132144
# are scalars of the corresponding float dtype.
@@ -146,7 +158,13 @@ def iinfo(type: DType | Array, /) -> iinfo_object:
146158
147159
See its docstring for more information.
148160
"""
149-
np_type = type._array if isinstance(type, Array) else type._np_dtype
161+
if isinstance(type, Array):
162+
np_type = type._dtype._np_dtype
163+
elif isinstance(type, DType):
164+
np_type = type._np_dtype
165+
else:
166+
raise TypeError(f"'type' must be a dtype or array, not {type!r}")
167+
150168
ii = np.iinfo(np_type)
151169
return iinfo_object(ii.bits, ii.max, ii.min, DType(ii.dtype))
152170

array_api_strict/_dtypes.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -35,7 +35,7 @@ def __eq__(self, other: object) -> builtins.bool:
3535
stacklevel=2,
3636
)
3737
if not isinstance(other, DType):
38-
return NotImplemented
38+
return False
3939
return self._np_dtype == other._np_dtype
4040

4141
def __hash__(self) -> int:

array_api_strict/tests/test_data_type_functions.py

+40-6
Original file line numberDiff line numberDiff line change
@@ -1,15 +1,12 @@
11
import warnings
22

3+
import numpy as np
34
import pytest
4-
55
from numpy.testing import assert_raises
6-
import numpy as np
76

87
from .._creation_functions import asarray
9-
from .._data_type_functions import astype, can_cast, isdtype, result_type
10-
from .._dtypes import (
11-
bool, int8, int16, uint8, float64, int64
12-
)
8+
from .._data_type_functions import astype, can_cast, finfo, iinfo, isdtype, result_type
9+
from .._dtypes import bool, float64, int8, int16, int64, uint8
1310
from .._flags import set_array_api_strict_flags
1411

1512

@@ -88,3 +85,40 @@ def test_result_type_py_scalars(api_version):
8885

8986
with pytest.raises(TypeError):
9087
result_type(int64, True)
88+
89+
90+
def test_finfo_iinfo_dtypelike():
91+
"""np.finfo() and np.iinfo() accept any DTypeLike.
92+
Array API only accepts Array | DType.
93+
"""
94+
match = "must be a dtype or array"
95+
with pytest.raises(TypeError, match=match):
96+
finfo("float64")
97+
with pytest.raises(TypeError, match=match):
98+
finfo(float)
99+
with pytest.raises(TypeError, match=match):
100+
iinfo("int8")
101+
with pytest.raises(TypeError, match=match):
102+
iinfo(int)
103+
104+
105+
def test_finfo_iinfo_wrap_output():
106+
"""Test that the finfo(...).dtype and iinfo(...).dtype
107+
are array-api-strict.DType objects; not numpy.dtype.
108+
"""
109+
# Note: array_api_strict.DType objects are not singletons
110+
assert finfo(float64).dtype == float64
111+
assert iinfo(int8).dtype == int8
112+
113+
114+
@pytest.mark.parametrize("func,arg", [(finfo, float64), (iinfo, int8)])
115+
def test_finfo_iinfo_output_assumptions(func, arg):
116+
"""There should be no expectation for the output of finfo()/iinfo()
117+
to be comparable, hashable, or writeable.
118+
"""
119+
obj = func(arg)
120+
assert obj != func(arg) # Defaut behaviour for custom classes
121+
with pytest.raises(TypeError):
122+
hash(obj)
123+
with pytest.raises(Exception, match="cannot assign"):
124+
obj.min = 0

0 commit comments

Comments
 (0)