@@ -125,7 +125,6 @@ def _accumulate_common(
125
125
if a1 != nd :
126
126
out = dpt .permute_dims (out , perm )
127
127
128
- final_ev = dpctl .SyclEvent ()
129
128
_manager = SequentialOrderManager [q ]
130
129
depends = _manager .submitted_events
131
130
if implemented_types :
@@ -144,12 +143,11 @@ def _accumulate_common(
144
143
_manager .add_event_pair (ht_e , acc_ev )
145
144
if not (orig_out is None or out is orig_out ):
146
145
# Copy the result from the temporary buffer back into the original output array
147
- ht_e_cpy , acc_ev = ti ._copy_usm_ndarray_into_usm_ndarray (
146
+ ht_e_cpy , cpy_e = ti ._copy_usm_ndarray_into_usm_ndarray (
148
147
src = out , dst = orig_out , sycl_queue = q , depends = [acc_ev ]
149
148
)
150
- _manager .add_event_pair (ht_e_cpy , acc_ev )
149
+ _manager .add_event_pair (ht_e_cpy , cpy_e )
151
150
out = orig_out
152
- final_ev = acc_ev
153
151
else :
154
152
if _dtype_supported (res_dt , res_dt ):
155
153
tmp = dpt .empty (
@@ -160,21 +158,21 @@ def _accumulate_common(
160
158
)
161
159
_manager .add_event_pair (ht_e_cpy , cpy_e )
162
160
if not include_initial :
163
- ht_e , final_ev = _accumulate_fn (
161
+ ht_e , acc_ev = _accumulate_fn (
164
162
src = tmp ,
165
163
trailing_dims_to_accumulate = 1 ,
166
164
dst = out ,
167
165
sycl_queue = q ,
168
166
depends = [cpy_e ],
169
167
)
170
168
else :
171
- ht_e , final_ev = _accumulate_include_initial_fn (
169
+ ht_e , acc_ev = _accumulate_include_initial_fn (
172
170
src = tmp ,
173
171
dst = out ,
174
172
sycl_queue = q ,
175
173
depends = [cpy_e ],
176
174
)
177
- _manager .add_event_pair (ht_e , final_ev )
175
+ _manager .add_event_pair (ht_e , acc_ev )
178
176
else :
179
177
buf_dt = _default_accumulation_type_fn (inp_dt , q )
180
178
tmp = dpt .empty (
@@ -190,25 +188,25 @@ def _accumulate_common(
190
188
if a1 != nd :
191
189
tmp_res = dpt .permute_dims (tmp_res , perm )
192
190
if not include_initial :
193
- ht_e , a_e = _accumulate_fn (
191
+ ht_e , acc_ev = _accumulate_fn (
194
192
src = tmp ,
195
193
trailing_dims_to_accumulate = 1 ,
196
194
dst = tmp_res ,
197
195
sycl_queue = q ,
198
196
depends = [cpy_e ],
199
197
)
200
198
else :
201
- ht_e , a_e = _accumulate_include_initial_fn (
199
+ ht_e , acc_ev = _accumulate_include_initial_fn (
202
200
src = tmp ,
203
201
dst = tmp_res ,
204
202
sycl_queue = q ,
205
203
depends = [cpy_e ],
206
204
)
207
- _manager .add_event_pair (ht_e , a_e )
208
- ht_e_cpy2 , final_ev = ti ._copy_usm_ndarray_into_usm_ndarray (
209
- src = tmp_res , dst = out , sycl_queue = q , depends = [a_e ]
205
+ _manager .add_event_pair (ht_e , acc_ev )
206
+ ht_e_cpy2 , cpy_e2 = ti ._copy_usm_ndarray_into_usm_ndarray (
207
+ src = tmp_res , dst = out , sycl_queue = q , depends = [acc_ev ]
210
208
)
211
- _manager .add_event_pair (ht_e_cpy2 , final_ev )
209
+ _manager .add_event_pair (ht_e_cpy2 , cpy_e2 )
212
210
213
211
if appended_axis :
214
212
out = dpt .squeeze (out )
0 commit comments