diff --git a/array_api_strict/_data_type_functions.py b/array_api_strict/_data_type_functions.py index c3c8462..e318724 100644 --- a/array_api_strict/_data_type_functions.py +++ b/array_api_strict/_data_type_functions.py @@ -100,7 +100,9 @@ def can_cast(from_: DType | Array, to: DType, /) -> bool: # These are internal objects for the return types of finfo and iinfo, since # the NumPy versions contain extra data that isn't part of the spec. -@dataclass +# There should be no expectation for them to be comparable, hashable, or writeable. + +@dataclass(frozen=True, eq=False) class finfo_object: bits: int # Note: The types of the float data here are float, whereas in NumPy they @@ -111,14 +113,18 @@ class finfo_object: smallest_normal: float dtype: DType + __hash__ = NotImplemented + -@dataclass +@dataclass(frozen=True, eq=False) class iinfo_object: bits: int max: int min: int dtype: DType + __hash__ = NotImplemented + def finfo(type: DType | Array, /) -> finfo_object: """ @@ -126,7 +132,13 @@ def finfo(type: DType | Array, /) -> finfo_object: See its docstring for more information. """ - np_type = type._array if isinstance(type, Array) else type._np_dtype + if isinstance(type, Array): + np_type = type._dtype._np_dtype + elif isinstance(type, DType): + np_type = type._np_dtype + else: + raise TypeError(f"'type' must be a dtype or array, not {type!r}") + fi = np.finfo(np_type) # Note: The types of the float data here are float, whereas in NumPy they # are scalars of the corresponding float dtype. @@ -146,7 +158,13 @@ def iinfo(type: DType | Array, /) -> iinfo_object: See its docstring for more information. """ - np_type = type._array if isinstance(type, Array) else type._np_dtype + if isinstance(type, Array): + np_type = type._dtype._np_dtype + elif isinstance(type, DType): + np_type = type._np_dtype + else: + raise TypeError(f"'type' must be a dtype or array, not {type!r}") + ii = np.iinfo(np_type) return iinfo_object(ii.bits, ii.max, ii.min, DType(ii.dtype)) diff --git a/array_api_strict/_dtypes.py b/array_api_strict/_dtypes.py index 7bed828..564db5a 100644 --- a/array_api_strict/_dtypes.py +++ b/array_api_strict/_dtypes.py @@ -35,7 +35,7 @@ def __eq__(self, other: object) -> builtins.bool: stacklevel=2, ) if not isinstance(other, DType): - return NotImplemented + return False return self._np_dtype == other._np_dtype def __hash__(self) -> int: diff --git a/array_api_strict/tests/test_data_type_functions.py b/array_api_strict/tests/test_data_type_functions.py index 919c0b4..7f24920 100644 --- a/array_api_strict/tests/test_data_type_functions.py +++ b/array_api_strict/tests/test_data_type_functions.py @@ -1,15 +1,12 @@ import warnings +import numpy as np import pytest - from numpy.testing import assert_raises -import numpy as np from .._creation_functions import asarray -from .._data_type_functions import astype, can_cast, isdtype, result_type -from .._dtypes import ( - bool, int8, int16, uint8, float64, int64 -) +from .._data_type_functions import astype, can_cast, finfo, iinfo, isdtype, result_type +from .._dtypes import bool, float64, int8, int16, int64, uint8 from .._flags import set_array_api_strict_flags @@ -88,3 +85,40 @@ def test_result_type_py_scalars(api_version): with pytest.raises(TypeError): result_type(int64, True) + + +def test_finfo_iinfo_dtypelike(): + """np.finfo() and np.iinfo() accept any DTypeLike. + Array API only accepts Array | DType. + """ + match = "must be a dtype or array" + with pytest.raises(TypeError, match=match): + finfo("float64") + with pytest.raises(TypeError, match=match): + finfo(float) + with pytest.raises(TypeError, match=match): + iinfo("int8") + with pytest.raises(TypeError, match=match): + iinfo(int) + + +def test_finfo_iinfo_wrap_output(): + """Test that the finfo(...).dtype and iinfo(...).dtype + are array-api-strict.DType objects; not numpy.dtype. + """ + # Note: array_api_strict.DType objects are not singletons + assert finfo(float64).dtype == float64 + assert iinfo(int8).dtype == int8 + + +@pytest.mark.parametrize("func,arg", [(finfo, float64), (iinfo, int8)]) +def test_finfo_iinfo_output_assumptions(func, arg): + """There should be no expectation for the output of finfo()/iinfo() + to be comparable, hashable, or writeable. + """ + obj = func(arg) + assert obj != func(arg) # Defaut behaviour for custom classes + with pytest.raises(TypeError): + hash(obj) + with pytest.raises(Exception, match="cannot assign"): + obj.min = 0