Skip to content

Commit c6b00c7

Browse files
authored
Merge pull request #2011 from IntelPython/clean-up-accumulator-common-code
Technical debt clean-up in `_accumulation.py`
2 parents 5e1c87d + cd2bd1a commit c6b00c7

File tree

1 file changed

+11
-13
lines changed

1 file changed

+11
-13
lines changed

Diff for: dpctl/tensor/_accumulation.py

+11-13
Original file line numberDiff line numberDiff line change
@@ -125,7 +125,6 @@ def _accumulate_common(
125125
if a1 != nd:
126126
out = dpt.permute_dims(out, perm)
127127

128-
final_ev = dpctl.SyclEvent()
129128
_manager = SequentialOrderManager[q]
130129
depends = _manager.submitted_events
131130
if implemented_types:
@@ -144,12 +143,11 @@ def _accumulate_common(
144143
_manager.add_event_pair(ht_e, acc_ev)
145144
if not (orig_out is None or out is orig_out):
146145
# Copy the out data from temporary buffer to original memory
147-
ht_e_cpy, acc_ev = ti._copy_usm_ndarray_into_usm_ndarray(
146+
ht_e_cpy, cpy_e = ti._copy_usm_ndarray_into_usm_ndarray(
148147
src=out, dst=orig_out, sycl_queue=q, depends=[acc_ev]
149148
)
150-
_manager.add_event_pair(ht_e_cpy, acc_ev)
149+
_manager.add_event_pair(ht_e_cpy, cpy_e)
151150
out = orig_out
152-
final_ev = acc_ev
153151
else:
154152
if _dtype_supported(res_dt, res_dt):
155153
tmp = dpt.empty(
@@ -160,21 +158,21 @@ def _accumulate_common(
160158
)
161159
_manager.add_event_pair(ht_e_cpy, cpy_e)
162160
if not include_initial:
163-
ht_e, final_ev = _accumulate_fn(
161+
ht_e, acc_ev = _accumulate_fn(
164162
src=tmp,
165163
trailing_dims_to_accumulate=1,
166164
dst=out,
167165
sycl_queue=q,
168166
depends=[cpy_e],
169167
)
170168
else:
171-
ht_e, final_ev = _accumulate_include_initial_fn(
169+
ht_e, acc_ev = _accumulate_include_initial_fn(
172170
src=tmp,
173171
dst=out,
174172
sycl_queue=q,
175173
depends=[cpy_e],
176174
)
177-
_manager.add_event_pair(ht_e, final_ev)
175+
_manager.add_event_pair(ht_e, acc_ev)
178176
else:
179177
buf_dt = _default_accumulation_type_fn(inp_dt, q)
180178
tmp = dpt.empty(
@@ -190,25 +188,25 @@ def _accumulate_common(
190188
if a1 != nd:
191189
tmp_res = dpt.permute_dims(tmp_res, perm)
192190
if not include_initial:
193-
ht_e, a_e = _accumulate_fn(
191+
ht_e, acc_ev = _accumulate_fn(
194192
src=tmp,
195193
trailing_dims_to_accumulate=1,
196194
dst=tmp_res,
197195
sycl_queue=q,
198196
depends=[cpy_e],
199197
)
200198
else:
201-
ht_e, a_e = _accumulate_include_initial_fn(
199+
ht_e, acc_ev = _accumulate_include_initial_fn(
202200
src=tmp,
203201
dst=tmp_res,
204202
sycl_queue=q,
205203
depends=[cpy_e],
206204
)
207-
_manager.add_event_pair(ht_e, a_e)
208-
ht_e_cpy2, final_ev = ti._copy_usm_ndarray_into_usm_ndarray(
209-
src=tmp_res, dst=out, sycl_queue=q, depends=[a_e]
205+
_manager.add_event_pair(ht_e, acc_ev)
206+
ht_e_cpy2, cpy_e2 = ti._copy_usm_ndarray_into_usm_ndarray(
207+
src=tmp_res, dst=out, sycl_queue=q, depends=[acc_ev]
210208
)
211-
_manager.add_event_pair(ht_e_cpy2, final_ev)
209+
_manager.add_event_pair(ht_e_cpy2, cpy_e2)
212210

213211
if appended_axis:
214212
out = dpt.squeeze(out)

0 commit comments

Comments
 (0)