1
1
from collections .abc import Mapping
2
2
from functools import lru_cache
3
- from typing import NamedTuple , Tuple , Union
3
+ from typing import Any , NamedTuple , Sequence , Tuple , Union
4
4
from warnings import warn
5
5
6
6
from . import _array_module as xp
@@ -48,8 +48,8 @@ class EqualityMapping(Mapping):
48
48
See https://data-apis.org/array-api/latest/API_specification/data_types.html#data-type-objects
49
49
"""
50
50
51
- def __init__ (self , mapping : Mapping ):
52
- keys = list ( mapping . keys ())
51
+ def __init__ (self , key_value_pairs : Sequence [ Tuple [ Any , Any ]] ):
52
+ keys = [ k for k , _ in key_value_pairs ]
53
53
for i , key in enumerate (keys ):
54
54
if not (key == key ): # specifically checking __eq__, not __neq__
55
55
raise ValueError ("Key {key!r} does not have equality with itself" )
@@ -58,23 +58,26 @@ def __init__(self, mapping: Mapping):
58
58
for other_key in other_keys :
59
59
if key == other_key :
60
60
raise ValueError ("Key {key!r} has equality with key {other_key!r}" )
61
- self ._mapping = mapping
61
+ self ._key_value_pairs = key_value_pairs
62
62
63
63
def __getitem__ (self , key ):
64
- for k , v in self ._mapping . items () :
64
+ for k , v in self ._key_value_pairs :
65
65
if key == k :
66
66
return v
67
67
else :
68
68
raise KeyError (f"{ key !r} not found" )
69
69
70
70
def __iter__ (self ):
71
- return iter ( self ._mapping )
71
+ return ( k for k , _ in self ._key_value_pairs )
72
72
73
73
def __len__ (self ):
74
- return len (self ._mapping )
74
+ return len (self ._key_value_pairs )
75
+
76
+ def __str__ (self ):
77
+ return "{" + ", " .join (f"{ k !r} : { v !r} " for k , v in self ._key_value_pairs ) + "}"
75
78
76
79
def __repr__ (self ):
77
- return f"EqualityMapping({ self . _mapping !r } )"
80
+ return f"EqualityMapping({ self } )"
78
81
79
82
80
83
_uint_names = ("uint8" , "uint16" , "uint32" , "uint64" )
@@ -92,15 +95,15 @@ def __repr__(self):
92
95
bool_and_all_int_dtypes = (xp .bool ,) + all_int_dtypes
93
96
94
97
95
- dtype_to_name = EqualityMapping ({ getattr (xp , name ): name for name in _dtype_names } )
98
+ dtype_to_name = EqualityMapping ([( getattr (xp , name ), name ) for name in _dtype_names ] )
96
99
97
100
98
101
dtype_to_scalars = EqualityMapping (
99
- {
100
- xp .bool : [bool ],
101
- ** { d : [int ] for d in all_int_dtypes } ,
102
- ** { d : [int , float ] for d in float_dtypes } ,
103
- }
102
+ [
103
+ ( xp .bool , [bool ]) ,
104
+ * [( d , [int ]) for d in all_int_dtypes ] ,
105
+ * [( d , [int , float ]) for d in float_dtypes ] ,
106
+ ]
104
107
)
105
108
106
109
@@ -134,35 +137,30 @@ class MinMax(NamedTuple):
134
137
135
138
136
139
dtype_ranges = EqualityMapping (
137
- {
138
- xp .int8 : MinMax (- 128 , + 127 ),
139
- xp .int16 : MinMax (- 32_768 , + 32_767 ),
140
- xp .int32 : MinMax (- 2_147_483_648 , + 2_147_483_647 ),
141
- xp .int64 : MinMax (- 9_223_372_036_854_775_808 , + 9_223_372_036_854_775_807 ),
142
- xp .uint8 : MinMax (0 , + 255 ),
143
- xp .uint16 : MinMax (0 , + 65_535 ),
144
- xp .uint32 : MinMax (0 , + 4_294_967_295 ),
145
- xp .uint64 : MinMax (0 , + 18_446_744_073_709_551_615 ),
146
- xp .float32 : MinMax (- 3.4028234663852886e38 , 3.4028234663852886e38 ),
147
- xp .float64 : MinMax (- 1.7976931348623157e308 , 1.7976931348623157e308 ),
148
- }
140
+ [
141
+ ( xp .int8 , MinMax (- 128 , + 127 ) ),
142
+ ( xp .int16 , MinMax (- 32_768 , + 32_767 ) ),
143
+ ( xp .int32 , MinMax (- 2_147_483_648 , + 2_147_483_647 ) ),
144
+ ( xp .int64 , MinMax (- 9_223_372_036_854_775_808 , + 9_223_372_036_854_775_807 ) ),
145
+ ( xp .uint8 , MinMax (0 , + 255 ) ),
146
+ ( xp .uint16 , MinMax (0 , + 65_535 ) ),
147
+ ( xp .uint32 , MinMax (0 , + 4_294_967_295 ) ),
148
+ ( xp .uint64 , MinMax (0 , + 18_446_744_073_709_551_615 ) ),
149
+ ( xp .float32 , MinMax (- 3.4028234663852886e38 , 3.4028234663852886e38 ) ),
150
+ ( xp .float64 , MinMax (- 1.7976931348623157e308 , 1.7976931348623157e308 ) ),
151
+ ]
149
152
)
150
153
151
154
dtype_nbits = EqualityMapping (
152
- {
153
- ** {d : 8 for d in [xp .int8 , xp .uint8 ]},
154
- ** {d : 16 for d in [xp .int16 , xp .uint16 ]},
155
- ** {d : 32 for d in [xp .int32 , xp .uint32 , xp .float32 ]},
156
- ** {d : 64 for d in [xp .int64 , xp .uint64 , xp .float64 ]},
157
- }
155
+ [(d , 8 ) for d in [xp .int8 , xp .uint8 ]]
156
+ + [(d , 16 ) for d in [xp .int16 , xp .uint16 ]]
157
+ + [(d , 32 ) for d in [xp .int32 , xp .uint32 , xp .float32 ]]
158
+ + [(d , 64 ) for d in [xp .int64 , xp .uint64 , xp .float64 ]]
158
159
)
159
160
160
161
161
162
dtype_signed = EqualityMapping (
162
- {
163
- ** {d : True for d in int_dtypes },
164
- ** {d : False for d in uint_dtypes },
165
- }
163
+ [(d , True ) for d in int_dtypes ] + [(d , False ) for d in uint_dtypes ]
166
164
)
167
165
168
166
@@ -186,54 +184,51 @@ class MinMax(NamedTuple):
186
184
default_uint = xp .uint64
187
185
188
186
189
- _numeric_promotions = {
187
+ _numeric_promotions = [
190
188
# ints
191
- (xp .int8 , xp .int8 ): xp .int8 ,
192
- (xp .int8 , xp .int16 ): xp .int16 ,
193
- (xp .int8 , xp .int32 ): xp .int32 ,
194
- (xp .int8 , xp .int64 ): xp .int64 ,
195
- (xp .int16 , xp .int16 ): xp .int16 ,
196
- (xp .int16 , xp .int32 ): xp .int32 ,
197
- (xp .int16 , xp .int64 ): xp .int64 ,
198
- (xp .int32 , xp .int32 ): xp .int32 ,
199
- (xp .int32 , xp .int64 ): xp .int64 ,
200
- (xp .int64 , xp .int64 ): xp .int64 ,
189
+ (( xp .int8 , xp .int8 ), xp .int8 ) ,
190
+ (( xp .int8 , xp .int16 ), xp .int16 ) ,
191
+ (( xp .int8 , xp .int32 ), xp .int32 ) ,
192
+ (( xp .int8 , xp .int64 ), xp .int64 ) ,
193
+ (( xp .int16 , xp .int16 ), xp .int16 ) ,
194
+ (( xp .int16 , xp .int32 ), xp .int32 ) ,
195
+ (( xp .int16 , xp .int64 ), xp .int64 ) ,
196
+ (( xp .int32 , xp .int32 ), xp .int32 ) ,
197
+ (( xp .int32 , xp .int64 ), xp .int64 ) ,
198
+ (( xp .int64 , xp .int64 ), xp .int64 ) ,
201
199
# uints
202
- (xp .uint8 , xp .uint8 ): xp .uint8 ,
203
- (xp .uint8 , xp .uint16 ): xp .uint16 ,
204
- (xp .uint8 , xp .uint32 ): xp .uint32 ,
205
- (xp .uint8 , xp .uint64 ): xp .uint64 ,
206
- (xp .uint16 , xp .uint16 ): xp .uint16 ,
207
- (xp .uint16 , xp .uint32 ): xp .uint32 ,
208
- (xp .uint16 , xp .uint64 ): xp .uint64 ,
209
- (xp .uint32 , xp .uint32 ): xp .uint32 ,
210
- (xp .uint32 , xp .uint64 ): xp .uint64 ,
211
- (xp .uint64 , xp .uint64 ): xp .uint64 ,
200
+ (( xp .uint8 , xp .uint8 ), xp .uint8 ) ,
201
+ (( xp .uint8 , xp .uint16 ), xp .uint16 ) ,
202
+ (( xp .uint8 , xp .uint32 ), xp .uint32 ) ,
203
+ (( xp .uint8 , xp .uint64 ), xp .uint64 ) ,
204
+ (( xp .uint16 , xp .uint16 ), xp .uint16 ) ,
205
+ (( xp .uint16 , xp .uint32 ), xp .uint32 ) ,
206
+ (( xp .uint16 , xp .uint64 ), xp .uint64 ) ,
207
+ (( xp .uint32 , xp .uint32 ), xp .uint32 ) ,
208
+ (( xp .uint32 , xp .uint64 ), xp .uint64 ) ,
209
+ (( xp .uint64 , xp .uint64 ), xp .uint64 ) ,
212
210
# ints and uints (mixed sign)
213
- (xp .int8 , xp .uint8 ): xp .int16 ,
214
- (xp .int8 , xp .uint16 ): xp .int32 ,
215
- (xp .int8 , xp .uint32 ): xp .int64 ,
216
- (xp .int16 , xp .uint8 ): xp .int16 ,
217
- (xp .int16 , xp .uint16 ): xp .int32 ,
218
- (xp .int16 , xp .uint32 ): xp .int64 ,
219
- (xp .int32 , xp .uint8 ): xp .int32 ,
220
- (xp .int32 , xp .uint16 ): xp .int32 ,
221
- (xp .int32 , xp .uint32 ): xp .int64 ,
222
- (xp .int64 , xp .uint8 ): xp .int64 ,
223
- (xp .int64 , xp .uint16 ): xp .int64 ,
224
- (xp .int64 , xp .uint32 ): xp .int64 ,
211
+ (( xp .int8 , xp .uint8 ), xp .int16 ) ,
212
+ (( xp .int8 , xp .uint16 ), xp .int32 ) ,
213
+ (( xp .int8 , xp .uint32 ), xp .int64 ) ,
214
+ (( xp .int16 , xp .uint8 ), xp .int16 ) ,
215
+ (( xp .int16 , xp .uint16 ), xp .int32 ) ,
216
+ (( xp .int16 , xp .uint32 ), xp .int64 ) ,
217
+ (( xp .int32 , xp .uint8 ), xp .int32 ) ,
218
+ (( xp .int32 , xp .uint16 ), xp .int32 ) ,
219
+ (( xp .int32 , xp .uint32 ), xp .int64 ) ,
220
+ (( xp .int64 , xp .uint8 ), xp .int64 ) ,
221
+ (( xp .int64 , xp .uint16 ), xp .int64 ) ,
222
+ (( xp .int64 , xp .uint32 ), xp .int64 ) ,
225
223
# floats
226
- (xp .float32 , xp .float32 ): xp .float32 ,
227
- (xp .float32 , xp .float64 ): xp .float64 ,
228
- (xp .float64 , xp .float64 ): xp .float64 ,
229
- }
230
- promotion_table = EqualityMapping (
231
- {
232
- (xp .bool , xp .bool ): xp .bool ,
233
- ** _numeric_promotions ,
234
- ** {(d2 , d1 ): res for (d1 , d2 ), res in _numeric_promotions .items ()},
235
- }
236
- )
224
+ ((xp .float32 , xp .float32 ), xp .float32 ),
225
+ ((xp .float32 , xp .float64 ), xp .float64 ),
226
+ ((xp .float64 , xp .float64 ), xp .float64 ),
227
+ ]
228
+ _numeric_promotions += [((d2 , d1 ), res ) for (d1 , d2 ), res in _numeric_promotions ]
229
+ _promotion_table = list (set (_numeric_promotions ))
230
+ _promotion_table .insert (0 , ((xp .bool , xp .bool ), xp .bool ))
231
+ promotion_table = EqualityMapping (_promotion_table )
237
232
238
233
239
234
def result_type (* dtypes : DataType ):
0 commit comments