Commit 701c05b

No need to call _copy_overlapping if src and dst address same memory
```
In [1]: import dpctl.tensor as dpt, dpctl, dpctl.utils

In [2]: n, m = 8 * 540, 8 * 960

In [3]: a = dpt.ones((m, n))

In [4]: b = dpt.zeros((m, n))

In [5]: b_s = dpt.zeros((m, n+2))

In [6]: with dpctl.utils.onetrace_enabled():
   ...:     b_s[:,:-2] += a
   ...:
Device Timeline (queue: 0x556080b9cea0): zeCommandListAppendMemoryCopy(H2D)[48 bytes]<4.1> [ns] = 16946404661 (append) 16952292497 (submit) 16952613747 (start) 16952623538 (end)
Device Timeline (queue: 0x556080b9cea0): dpctl::tensor::kernels::add::add_inplace_strided_kernel<float, float, dpctl::tensor::offset_utils::TwoOffsets_StridedIndexer>[SIMD32 {64800; 1; 1} {512; 1; 1}]<5.1> [ns] = 17017855801 (append) 17018342202 (submit) 17019138920 (start) 17030770482 (end)
```

Earlier, two more copy operations were being performed as well.
1 parent f032154 commit 701c05b

1 file changed: +5 -0 lines changed
dpctl/tensor/_copy_utils.py

@@ -213,6 +213,11 @@ def _copy_same_shape(dst, src):
     """Assumes src and dst have the same shape."""
     # check that memory regions do not overlap
     if ti._array_overlap(dst, src):
+        if src._pointer == dst._pointer and (
+            src is dst
+            or (src.strides == dst.strides and src.dtype == dst.dtype)
+        ):
+            return
         _copy_overlapping(src=src, dst=dst)
         return
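
The early-return condition added above can be illustrated without a SYCL device. The following is a minimal sketch using NumPy arrays as stand-ins for `usm_ndarray`, not the dpctl implementation: the helpers `data_pointer` and `is_noop_copy` are hypothetical names introduced here, and `__array_interface__["data"][0]` is assumed to play the role of `_pointer`. Note that NumPy strides are expressed in bytes, whereas dpctl strides count elements; the comparison logic is unaffected. As in `_copy_same_shape`, both arrays are assumed to have the same shape.

```python
import numpy as np


def data_pointer(arr):
    """Address of the first element (NumPy stand-in for usm_ndarray._pointer)."""
    return arr.__array_interface__["data"][0]


def is_noop_copy(dst, src):
    """Return True when copying src into dst cannot change anything.

    Assumes src and dst have the same shape. The copy is a no-op when both
    arrays start at the same address and either are the same object or
    traverse memory identically (same strides) with the same element type.
    """
    if data_pointer(src) != data_pointer(dst):
        return False
    return src is dst or (src.strides == dst.strides and src.dtype == dst.dtype)


base = np.zeros((4, 6), dtype=np.float32)

view = base[:, :-2]                  # analogous to b_s[:, :-2] in the commit message
same_view = base[:, :-2]             # distinct Python object, identical memory layout
shifted = base[:, 2:]                # overlaps view, but starts at a different address
square = base[:4, :4]
transposed = square.T                # same start address, different strides
reinterpreted = view.view(np.int32)  # same address and strides, different dtype

print(is_noop_copy(same_view, view))      # copy can be skipped
print(is_noop_copy(shifted, view))        # overlapping copy is still required
print(is_noop_copy(transposed, square))   # element order differs
print(is_noop_copy(reinterpreted, view))  # values would need conversion
```

The stride and dtype checks matter because a matching start address alone does not make the copy redundant: a transposed view reorders elements, and a reinterpreted view would require a type conversion on assignment. Only when every condition holds is each element copied onto itself, which is presumably the situation that arises when the result of `b_s[:,:-2] += a` is assigned back to the same view.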
