Skip to content

Commit ed23bfa

Browse files
committed
Use key-value tuples for EqualityMapping as opposed to dicts
Drops internal use of hashing via dicts for dtype helpers
1 parent 4d1891b commit ed23bfa

File tree

2 files changed

+86
-83
lines changed

2 files changed

+86
-83
lines changed

Diff for: array_api_tests/dtype_helpers.py

+75-80
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
from collections.abc import Mapping
22
from functools import lru_cache
3-
from typing import NamedTuple, Tuple, Union
3+
from typing import Any, NamedTuple, Sequence, Tuple, Union
44
from warnings import warn
55

66
from . import _array_module as xp
@@ -48,8 +48,8 @@ class EqualityMapping(Mapping):
4848
See https://data-apis.org/array-api/latest/API_specification/data_types.html#data-type-objects
4949
"""
5050

51-
def __init__(self, mapping: Mapping):
52-
keys = list(mapping.keys())
51+
def __init__(self, key_value_pairs: Sequence[Tuple[Any, Any]]):
52+
keys = [k for k, _ in key_value_pairs]
5353
for i, key in enumerate(keys):
5454
if not (key == key): # specifically checking __eq__, not __neq__
5555
raise ValueError("Key {key!r} does not have equality with itself")
@@ -58,23 +58,26 @@ def __init__(self, mapping: Mapping):
5858
for other_key in other_keys:
5959
if key == other_key:
6060
raise ValueError("Key {key!r} has equality with key {other_key!r}")
61-
self._mapping = mapping
61+
self._key_value_pairs = key_value_pairs
6262

6363
def __getitem__(self, key):
64-
for k, v in self._mapping.items():
64+
for k, v in self._key_value_pairs:
6565
if key == k:
6666
return v
6767
else:
6868
raise KeyError(f"{key!r} not found")
6969

7070
def __iter__(self):
71-
return iter(self._mapping)
71+
return (k for k, _ in self._key_value_pairs)
7272

7373
def __len__(self):
74-
return len(self._mapping)
74+
return len(self._key_value_pairs)
75+
76+
def __str__(self):
77+
return "{" + ", ".join(f"{k!r}: {v!r}" for k, v in self._key_value_pairs) + "}"
7578

7679
def __repr__(self):
77-
return f"EqualityMapping({self._mapping!r})"
80+
return f"EqualityMapping({self})"
7881

7982

8083
_uint_names = ("uint8", "uint16", "uint32", "uint64")
@@ -92,15 +95,15 @@ def __repr__(self):
9295
bool_and_all_int_dtypes = (xp.bool,) + all_int_dtypes
9396

9497

95-
dtype_to_name = EqualityMapping({getattr(xp, name): name for name in _dtype_names})
98+
dtype_to_name = EqualityMapping([(getattr(xp, name), name) for name in _dtype_names])
9699

97100

98101
dtype_to_scalars = EqualityMapping(
99-
{
100-
xp.bool: [bool],
101-
**{d: [int] for d in all_int_dtypes},
102-
**{d: [int, float] for d in float_dtypes},
103-
}
102+
[
103+
(xp.bool, [bool]),
104+
*[(d, [int]) for d in all_int_dtypes],
105+
*[(d, [int, float]) for d in float_dtypes],
106+
]
104107
)
105108

106109

@@ -134,35 +137,30 @@ class MinMax(NamedTuple):
134137

135138

136139
dtype_ranges = EqualityMapping(
137-
{
138-
xp.int8: MinMax(-128, +127),
139-
xp.int16: MinMax(-32_768, +32_767),
140-
xp.int32: MinMax(-2_147_483_648, +2_147_483_647),
141-
xp.int64: MinMax(-9_223_372_036_854_775_808, +9_223_372_036_854_775_807),
142-
xp.uint8: MinMax(0, +255),
143-
xp.uint16: MinMax(0, +65_535),
144-
xp.uint32: MinMax(0, +4_294_967_295),
145-
xp.uint64: MinMax(0, +18_446_744_073_709_551_615),
146-
xp.float32: MinMax(-3.4028234663852886e38, 3.4028234663852886e38),
147-
xp.float64: MinMax(-1.7976931348623157e308, 1.7976931348623157e308),
148-
}
140+
[
141+
(xp.int8, MinMax(-128, +127)),
142+
(xp.int16, MinMax(-32_768, +32_767)),
143+
(xp.int32, MinMax(-2_147_483_648, +2_147_483_647)),
144+
(xp.int64, MinMax(-9_223_372_036_854_775_808, +9_223_372_036_854_775_807)),
145+
(xp.uint8, MinMax(0, +255)),
146+
(xp.uint16, MinMax(0, +65_535)),
147+
(xp.uint32, MinMax(0, +4_294_967_295)),
148+
(xp.uint64, MinMax(0, +18_446_744_073_709_551_615)),
149+
(xp.float32, MinMax(-3.4028234663852886e38, 3.4028234663852886e38)),
150+
(xp.float64, MinMax(-1.7976931348623157e308, 1.7976931348623157e308)),
151+
]
149152
)
150153

151154
dtype_nbits = EqualityMapping(
152-
{
153-
**{d: 8 for d in [xp.int8, xp.uint8]},
154-
**{d: 16 for d in [xp.int16, xp.uint16]},
155-
**{d: 32 for d in [xp.int32, xp.uint32, xp.float32]},
156-
**{d: 64 for d in [xp.int64, xp.uint64, xp.float64]},
157-
}
155+
[(d, 8) for d in [xp.int8, xp.uint8]]
156+
+ [(d, 16) for d in [xp.int16, xp.uint16]]
157+
+ [(d, 32) for d in [xp.int32, xp.uint32, xp.float32]]
158+
+ [(d, 64) for d in [xp.int64, xp.uint64, xp.float64]]
158159
)
159160

160161

161162
dtype_signed = EqualityMapping(
162-
{
163-
**{d: True for d in int_dtypes},
164-
**{d: False for d in uint_dtypes},
165-
}
163+
[(d, True) for d in int_dtypes] + [(d, False) for d in uint_dtypes]
166164
)
167165

168166

@@ -186,54 +184,51 @@ class MinMax(NamedTuple):
186184
default_uint = xp.uint64
187185

188186

189-
_numeric_promotions = {
187+
_numeric_promotions = [
190188
# ints
191-
(xp.int8, xp.int8): xp.int8,
192-
(xp.int8, xp.int16): xp.int16,
193-
(xp.int8, xp.int32): xp.int32,
194-
(xp.int8, xp.int64): xp.int64,
195-
(xp.int16, xp.int16): xp.int16,
196-
(xp.int16, xp.int32): xp.int32,
197-
(xp.int16, xp.int64): xp.int64,
198-
(xp.int32, xp.int32): xp.int32,
199-
(xp.int32, xp.int64): xp.int64,
200-
(xp.int64, xp.int64): xp.int64,
189+
((xp.int8, xp.int8), xp.int8),
190+
((xp.int8, xp.int16), xp.int16),
191+
((xp.int8, xp.int32), xp.int32),
192+
((xp.int8, xp.int64), xp.int64),
193+
((xp.int16, xp.int16), xp.int16),
194+
((xp.int16, xp.int32), xp.int32),
195+
((xp.int16, xp.int64), xp.int64),
196+
((xp.int32, xp.int32), xp.int32),
197+
((xp.int32, xp.int64), xp.int64),
198+
((xp.int64, xp.int64), xp.int64),
201199
# uints
202-
(xp.uint8, xp.uint8): xp.uint8,
203-
(xp.uint8, xp.uint16): xp.uint16,
204-
(xp.uint8, xp.uint32): xp.uint32,
205-
(xp.uint8, xp.uint64): xp.uint64,
206-
(xp.uint16, xp.uint16): xp.uint16,
207-
(xp.uint16, xp.uint32): xp.uint32,
208-
(xp.uint16, xp.uint64): xp.uint64,
209-
(xp.uint32, xp.uint32): xp.uint32,
210-
(xp.uint32, xp.uint64): xp.uint64,
211-
(xp.uint64, xp.uint64): xp.uint64,
200+
((xp.uint8, xp.uint8), xp.uint8),
201+
((xp.uint8, xp.uint16), xp.uint16),
202+
((xp.uint8, xp.uint32), xp.uint32),
203+
((xp.uint8, xp.uint64), xp.uint64),
204+
((xp.uint16, xp.uint16), xp.uint16),
205+
((xp.uint16, xp.uint32), xp.uint32),
206+
((xp.uint16, xp.uint64), xp.uint64),
207+
((xp.uint32, xp.uint32), xp.uint32),
208+
((xp.uint32, xp.uint64), xp.uint64),
209+
((xp.uint64, xp.uint64), xp.uint64),
212210
# ints and uints (mixed sign)
213-
(xp.int8, xp.uint8): xp.int16,
214-
(xp.int8, xp.uint16): xp.int32,
215-
(xp.int8, xp.uint32): xp.int64,
216-
(xp.int16, xp.uint8): xp.int16,
217-
(xp.int16, xp.uint16): xp.int32,
218-
(xp.int16, xp.uint32): xp.int64,
219-
(xp.int32, xp.uint8): xp.int32,
220-
(xp.int32, xp.uint16): xp.int32,
221-
(xp.int32, xp.uint32): xp.int64,
222-
(xp.int64, xp.uint8): xp.int64,
223-
(xp.int64, xp.uint16): xp.int64,
224-
(xp.int64, xp.uint32): xp.int64,
211+
((xp.int8, xp.uint8), xp.int16),
212+
((xp.int8, xp.uint16), xp.int32),
213+
((xp.int8, xp.uint32), xp.int64),
214+
((xp.int16, xp.uint8), xp.int16),
215+
((xp.int16, xp.uint16), xp.int32),
216+
((xp.int16, xp.uint32), xp.int64),
217+
((xp.int32, xp.uint8), xp.int32),
218+
((xp.int32, xp.uint16), xp.int32),
219+
((xp.int32, xp.uint32), xp.int64),
220+
((xp.int64, xp.uint8), xp.int64),
221+
((xp.int64, xp.uint16), xp.int64),
222+
((xp.int64, xp.uint32), xp.int64),
225223
# floats
226-
(xp.float32, xp.float32): xp.float32,
227-
(xp.float32, xp.float64): xp.float64,
228-
(xp.float64, xp.float64): xp.float64,
229-
}
230-
promotion_table = EqualityMapping(
231-
{
232-
(xp.bool, xp.bool): xp.bool,
233-
**_numeric_promotions,
234-
**{(d2, d1): res for (d1, d2), res in _numeric_promotions.items()},
235-
}
236-
)
224+
((xp.float32, xp.float32), xp.float32),
225+
((xp.float32, xp.float64), xp.float64),
226+
((xp.float64, xp.float64), xp.float64),
227+
]
228+
_numeric_promotions += [((d2, d1), res) for (d1, d2), res in _numeric_promotions]
229+
_promotion_table = list(set(_numeric_promotions))
230+
_promotion_table.insert(0, ((xp.bool, xp.bool), xp.bool))
231+
promotion_table = EqualityMapping(_promotion_table)
237232

238233

239234
def result_type(*dtypes: DataType):

Diff for: array_api_tests/meta/test_equality_mapping.py

+11-3
Original file line numberDiff line numberDiff line change
@@ -5,7 +5,7 @@
55

66
def test_raises_on_distinct_eq_key():
77
with pytest.raises(ValueError):
8-
EqualityMapping({float("nan"): "foo"})
8+
EqualityMapping([(float("nan"), "value")])
99

1010

1111
def test_raises_on_indistinct_eq_keys():
@@ -20,10 +20,18 @@ def __hash__(self):
2020
return self._hash
2121

2222
with pytest.raises(ValueError):
23-
EqualityMapping({AlwaysEq(0): "foo", AlwaysEq(1): "bar"})
23+
EqualityMapping([(AlwaysEq(0), "value1"), (AlwaysEq(1), "value2")])
2424

2525

2626
def test_key_error():
27-
mapping = EqualityMapping({"foo": "bar"})
27+
mapping = EqualityMapping([("key", "value")])
2828
with pytest.raises(KeyError):
2929
mapping["nonexistent key"]
30+
31+
32+
def test_iter():
33+
mapping = EqualityMapping([("key", "value")])
34+
it = iter(mapping)
35+
assert next(it) == "key"
36+
with pytest.raises(StopIteration):
37+
next(it)

0 commit comments

Comments
 (0)