@@ -91,6 +91,64 @@ def test_cumulative_sum(x, data):
91
91
idx = out_idx .raw , out = out_val ,
92
92
expected = expected )
93
93
94
+
95
+
96
+ @pytest .mark .min_version ("2024.12" )
97
+ @pytest .mark .unvectorized
98
+ @given (
99
+ x = hh .arrays (
100
+ dtype = hh .numeric_dtypes ,
101
+ shape = hh .shapes (min_dims = 1 )),
102
+ data = st .data (),
103
+ )
104
+ def test_cumulative_prod (x , data ):
105
+ axes = st .integers (- x .ndim , x .ndim - 1 )
106
+ if x .ndim == 1 :
107
+ axes = axes | st .none ()
108
+ axis = data .draw (axes , label = 'axis' )
109
+ _axis , = sh .normalize_axis (axis , x .ndim )
110
+ dtype = data .draw (kwarg_dtypes (x .dtype ))
111
+ include_initial = data .draw (st .booleans (), label = "include_initial" )
112
+
113
+ kw = data .draw (
114
+ hh .specified_kwargs (
115
+ ("axis" , axis , None ),
116
+ ("dtype" , dtype , None ),
117
+ ("include_initial" , include_initial , False ),
118
+ ),
119
+ label = "kw" ,
120
+ )
121
+
122
+ out = xp .cumulative_prod (x , ** kw )
123
+
124
+ expected_shape = list (x .shape )
125
+ if include_initial :
126
+ expected_shape [_axis ] += 1
127
+ expected_shape = tuple (expected_shape )
128
+ ph .assert_shape ("cumulative_prod" , out_shape = out .shape , expected = expected_shape )
129
+
130
+ expected_dtype = dh .accumulation_result_dtype (x .dtype , dtype )
131
+ if expected_dtype is None :
132
+ # If a default uint cannot exist (i.e. in PyTorch which doesn't support
133
+ # uint32 or uint64), we skip testing the output dtype.
134
+ # See https://github.com/data-apis/array-api-tests/issues/106
135
+ if x .dtype in dh .uint_dtypes :
136
+ assert dh .is_int_dtype (out .dtype ) # sanity check
137
+ else :
138
+ ph .assert_dtype ("cumulative_prod" , in_dtype = x .dtype , out_dtype = out .dtype , expected = expected_dtype )
139
+
140
+ scalar_type = dh .get_scalar_type (out .dtype )
141
+
142
+ for x_idx , out_idx , in iter_indices (x .shape , expected_shape , skip_axes = _axis ):
143
+ #x_arr = x[x_idx.raw]
144
+ out_arr = out [out_idx .raw ]
145
+
146
+ if include_initial :
147
+ ph .assert_scalar_equals ("cumulative_prod" , type_ = scalar_type , idx = out_idx .raw , out = out_arr [0 ], expected = 1 )
148
+
149
+ #TODO: add value testing of cumulative_prod
150
+
151
+
94
152
def kwarg_dtypes (dtype : DataType ) -> st .SearchStrategy [Optional [DataType ]]:
95
153
dtypes = [d2 for d1 , d2 in dh .promotion_table if d1 == dtype ]
96
154
dtypes = [d for d in dtypes if not isinstance (d , _UndefinedStub )]
0 commit comments