1
1
from __future__ import annotations
2
2
3
- from functools import wraps as _wraps
3
+ from functools import reduce as _reduce , wraps as _wraps
4
4
from builtins import all as _builtin_all , any as _builtin_any
5
5
6
6
from ..common import _aliases
@@ -124,43 +124,35 @@ def _fix_promotion(x1, x2, only_scalar=True):
124
124
125
125
126
126
def result_type(*arrays_and_dtypes: Union[array, Dtype, bool, int, float, complex]) -> Dtype:
    """Return the dtype produced by promoting *arrays_and_dtypes* together.

    Parameters
    ----------
    *arrays_and_dtypes
        Arrays, dtypes, and/or Python scalars to promote.  At least one
        array or dtype must be present: Python scalars alone cannot
        anchor a result dtype.

    Returns
    -------
    Dtype
        The promoted dtype.

    Raises
    ------
    ValueError
        If no arguments are given, or if every argument is a Python scalar.
    """
    num = len(arrays_and_dtypes)

    if num == 0:
        raise ValueError("At least one array or dtype must be provided")

    if num == 1:
        # Single argument: return it if it already is a dtype, else its dtype.
        x = arrays_and_dtypes[0]
        if isinstance(x, torch.dtype):
            return x
        return x.dtype

    # Promotion among Python scalars alone is undefined.  Checking here, for
    # every num >= 2, makes the binary case consistent with the n-ary one:
    # previously only the num > 2 path rejected all-scalar input, while
    # result_type(1, 2.0) fell through to torch.result_type.
    if _builtin_all(isinstance(x, _py_scalars) for x in arrays_and_dtypes):
        raise ValueError("At least one array or dtype must be provided")

    if num == 2:
        # Fast path: promote the pair directly, skipping the reduce machinery.
        x, y = arrays_and_dtypes
        return _result_type(x, y)

    # n-ary case: left fold of pairwise promotion over the arguments.
    return _reduce(_result_type, arrays_and_dtypes)
149
+ def _result_type (x , y ):
150
+ if not (isinstance (x , _py_scalars ) or isinstance (y , _py_scalars )):
151
+ xdt = x .dtype if not isinstance (x , torch .dtype ) else x
152
+ ydt = y .dtype if not isinstance (y , torch .dtype ) else y
161
153
162
- if (xdt , ydt ) in _promotion_table :
163
- return _promotion_table [xdt , ydt ]
154
+ if (xdt , ydt ) in _promotion_table :
155
+ return _promotion_table [xdt , ydt ]
164
156
165
157
# This doesn't result_type(dtype, dtype) for non-array API dtypes
166
158
# because torch.result_type only accepts tensors. This does however, allow
@@ -169,6 +161,7 @@ def result_type(*arrays_and_dtypes: Union[array, Dtype, bool, int, float, comple
169
161
y = torch .tensor ([], dtype = y ) if isinstance (y , torch .dtype ) else y
170
162
return torch .result_type (x , y )
171
163
164
+
172
165
def can_cast (from_ : Union [Dtype , array ], to : Dtype , / ) -> bool :
173
166
if not isinstance (from_ , torch .dtype ):
174
167
from_ = from_ .dtype
0 commit comments