@@ -59,95 +59,92 @@ def test_xp(self, xp: ModuleType):
59
59
xp_assert_equal (actual , expected )
60
60
61
61
62
- @pytest .mark .xfail_xp_backend (Backend .SPARSE , reason = "no isdtype" )
63
- @pytest .mark .parametrize (
64
- ("dtype" , "b" , "defined" ),
65
- [
66
- # Well-defined cases of dtype promotion from Python scalar to Array
67
- # bool vs. bool
68
- ("bool" , True , True ),
69
- # int vs. xp.*int*, xp.float*, xp.complex*
70
- ("int16" , 1 , True ),
71
- ("uint8" , 1 , True ),
72
- ("float32" , 1 , True ),
73
- ("float64" , 1 , True ),
74
- ("complex64" , 1 , True ),
75
- ("complex128" , 1 , True ),
76
- # float vs. xp.float, xp.complex
77
- ("float32" , 1.0 , True ),
78
- ("float64" , 1.0 , True ),
79
- ("complex64" , 1.0 , True ),
80
- ("complex128" , 1.0 , True ),
81
- # complex vs. xp.complex
82
- ("complex64" , 1.0j , True ),
83
- ("complex128" , 1.0j , True ),
84
- # Undefined cases
85
- ("bool" , 1 , False ),
86
- ("int64" , 1.0 , False ),
87
- ("float64" , 1.0j , False ),
88
- ],
89
- )
90
- def test_asarrays_array_vs_scalar (
91
- dtype : str , b : int | float | complex , defined : bool , xp : ModuleType
92
- ):
93
- a = xp .asarray (1 , dtype = getattr (xp , dtype ))
94
-
95
- xa , xb = asarrays (a , b , xp )
96
- assert xa .dtype == a .dtype
97
- if defined :
98
- assert xb .dtype == a .dtype
99
- else :
100
- assert xb .dtype == xp .asarray (b ).dtype
101
-
102
- xbr , xar = asarrays (b , a , xp )
103
- assert xar .dtype == xa .dtype
104
- assert xbr .dtype == xb .dtype
105
-
106
-
107
- def test_asarrays_scalar_vs_scalar (xp : ModuleType ):
108
- a , b = asarrays (1 , 2.2 , xp = xp )
109
- assert a .dtype == xp .asarray (1 ).dtype # Default dtype
110
- assert b .dtype == xp .asarray (2.2 ).dtype # Default dtype; not broadcasted
111
-
112
-
113
- ALL_TYPES = (
114
- "int8" ,
115
- "int16" ,
116
- "int32" ,
117
- "int64" ,
118
- "uint8" ,
119
- "uint16" ,
120
- "uint32" ,
121
- "uint64" ,
122
- "float32" ,
123
- "float64" ,
124
- "complex64" ,
125
- "complex128" ,
126
- "bool" ,
127
- )
128
-
129
-
130
- @pytest .mark .parametrize ("a_type" , ALL_TYPES )
131
- @pytest .mark .parametrize ("b_type" , ALL_TYPES )
132
- def test_asarrays_array_vs_array (a_type : str , b_type : str , xp : ModuleType ):
133
- """
134
- Test that when both inputs of asarray are already Array API objects,
135
- they are returned unchanged.
136
- """
137
- a = xp .asarray (1 , dtype = getattr (xp , a_type ))
138
- b = xp .asarray (1 , dtype = getattr (xp , b_type ))
139
- xa , xb = asarrays (a , b , xp )
140
- assert xa .dtype == a .dtype
141
- assert xb .dtype == b .dtype
142
-
143
-
144
- @pytest .mark .parametrize ("dtype" , [np .float64 , np .complex128 ])
145
- def test_asarrays_numpy_generics (dtype : type ):
146
- """
147
- Test special case of np.float64 and np.complex128,
148
- which are subclasses of float and complex.
149
- """
150
- a = dtype (0 )
151
- xa , xb = asarrays (a , 0 , xp = np )
152
- assert xa .dtype == dtype
153
- assert xb .dtype == dtype
62
+ class TestAsArrays :
63
+ @pytest .mark .xfail_xp_backend (Backend .SPARSE , reason = "no isdtype" )
64
+ @pytest .mark .parametrize (
65
+ ("dtype" , "b" , "defined" ),
66
+ [
67
+ # Well-defined cases of dtype promotion from Python scalar to Array
68
+ # bool vs. bool
69
+ ("bool" , True , True ),
70
+ # int vs. xp.*int*, xp.float*, xp.complex*
71
+ ("int16" , 1 , True ),
72
+ ("uint8" , 1 , True ),
73
+ ("float32" , 1 , True ),
74
+ ("float64" , 1 , True ),
75
+ ("complex64" , 1 , True ),
76
+ ("complex128" , 1 , True ),
77
+ # float vs. xp.float, xp.complex
78
+ ("float32" , 1.0 , True ),
79
+ ("float64" , 1.0 , True ),
80
+ ("complex64" , 1.0 , True ),
81
+ ("complex128" , 1.0 , True ),
82
+ # complex vs. xp.complex
83
+ ("complex64" , 1.0j , True ),
84
+ ("complex128" , 1.0j , True ),
85
+ # Undefined cases
86
+ ("bool" , 1 , False ),
87
+ ("int64" , 1.0 , False ),
88
+ ("float64" , 1.0j , False ),
89
+ ],
90
+ )
91
+ def test_array_vs_scalar (
92
+ self , dtype : str , b : int | float | complex , defined : bool , xp : ModuleType
93
+ ):
94
+ a = xp .asarray (1 , dtype = getattr (xp , dtype ))
95
+
96
+ xa , xb = asarrays (a , b , xp )
97
+ assert xa .dtype == a .dtype
98
+ if defined :
99
+ assert xb .dtype == a .dtype
100
+ else :
101
+ assert xb .dtype == xp .asarray (b ).dtype
102
+
103
+ xbr , xar = asarrays (b , a , xp )
104
+ assert xar .dtype == xa .dtype
105
+ assert xbr .dtype == xb .dtype
106
+
107
+ def test_scalar_vs_scalar (self , xp : ModuleType ):
108
+ a , b = asarrays (1 , 2.2 , xp = xp )
109
+ assert a .dtype == xp .asarray (1 ).dtype # Default dtype
110
+ assert b .dtype == xp .asarray (2.2 ).dtype # Default dtype; not broadcasted
111
+
112
+ ALL_TYPES : tuple [str , ...] = (
113
+ "int8" ,
114
+ "int16" ,
115
+ "int32" ,
116
+ "int64" ,
117
+ "uint8" ,
118
+ "uint16" ,
119
+ "uint32" ,
120
+ "uint64" ,
121
+ "float32" ,
122
+ "float64" ,
123
+ "complex64" ,
124
+ "complex128" ,
125
+ "bool" ,
126
+ )
127
+
128
+ @pytest .mark .parametrize ("a_type" , ALL_TYPES )
129
+ @pytest .mark .parametrize ("b_type" , ALL_TYPES )
130
+ def test_array_vs_array (self , a_type : str , b_type : str , xp : ModuleType ):
131
+ """
132
+ Test that when both inputs of asarray are already Array API objects,
133
+ they are returned unchanged.
134
+ """
135
+ a = xp .asarray (1 , dtype = getattr (xp , a_type ))
136
+ b = xp .asarray (1 , dtype = getattr (xp , b_type ))
137
+ xa , xb = asarrays (a , b , xp )
138
+ assert xa .dtype == a .dtype
139
+ assert xb .dtype == b .dtype
140
+
141
+ @pytest .mark .parametrize ("dtype" , [np .float64 , np .complex128 ])
142
+ def test_numpy_generics (self , dtype : type ):
143
+ """
144
+ Test special case of np.float64 and np.complex128,
145
+ which are subclasses of float and complex.
146
+ """
147
+ a = dtype (0 )
148
+ xa , xb = asarrays (a , 0 , xp = np )
149
+ assert xa .dtype == dtype
150
+ assert xb .dtype == dtype
0 commit comments