@@ -52,6 +52,20 @@ def __call__(self, x, out=None, order="K"):
52
52
if not isinstance (x , dpt .usm_ndarray ):
53
53
raise TypeError (f"Expected dpctl.tensor.usm_ndarray, got { type (x )} " )
54
54
55
+ if order not in ["C" , "F" , "K" , "A" ]:
56
+ order = "K"
57
+ buf_dt , res_dt = _find_buf_dtype (
58
+ x .dtype , self .result_type_resolver_fn_ , x .sycl_device
59
+ )
60
+ if res_dt is None :
61
+ raise TypeError (
62
+ f"function '{ self .name_ } ' does not support input type "
63
+ f"({ x .dtype } ), "
64
+ "and the input could not be safely coerced to any "
65
+ "supported types according to the casting rule ''safe''."
66
+ )
67
+
68
+ orig_out = out
55
69
if out is not None :
56
70
if not isinstance (out , dpt .usm_ndarray ):
57
71
raise TypeError (
@@ -64,8 +78,21 @@ def __call__(self, x, out=None, order="K"):
64
78
f"Expected output shape is { x .shape } , got { out .shape } "
65
79
)
66
80
67
- if ti ._array_overlap (x , out ):
68
- raise TypeError ("Input and output arrays have memory overlap" )
81
+ if res_dt != out .dtype :
82
+ raise TypeError (
83
+ f"Output array of type { res_dt } is needed,"
84
+ f" got { out .dtype } "
85
+ )
86
+
87
+ if (
88
+ buf_dt is None
89
+ and ti ._array_overlap (x , out )
90
+ and not ti ._same_logical_tensors (x , out )
91
+ ):
92
+ # Allocate a temporary buffer to avoid memory overlapping.
93
+ # Note if `buf_dt` is not None, a temporary copy of `x` will be
94
+ # created, so the array overlap check isn't needed.
95
+ out = dpt .empty_like (out )
69
96
70
97
if (
71
98
dpctl .utils .get_execution_queue ((x .sycl_queue , out .sycl_queue ))
@@ -75,18 +102,6 @@ def __call__(self, x, out=None, order="K"):
75
102
"Input and output allocation queues are not compatible"
76
103
)
77
104
78
- if order not in ["C" , "F" , "K" , "A" ]:
79
- order = "K"
80
- buf_dt , res_dt = _find_buf_dtype (
81
- x .dtype , self .result_type_resolver_fn_ , x .sycl_device
82
- )
83
- if res_dt is None :
84
- raise TypeError (
85
- f"function '{ self .name_ } ' does not support input type "
86
- f"({ x .dtype } ), "
87
- "and the input could not be safely coerced to any "
88
- "supported types according to the casting rule ''safe''."
89
- )
90
105
exec_q = x .sycl_queue
91
106
if buf_dt is None :
92
107
if out is None :
@@ -96,17 +111,20 @@ def __call__(self, x, out=None, order="K"):
96
111
if order == "A" :
97
112
order = "F" if x .flags .f_contiguous else "C"
98
113
out = dpt .empty_like (x , dtype = res_dt , order = order )
99
- else :
100
- if res_dt != out .dtype :
101
- raise TypeError (
102
- f"Output array of type { res_dt } is needed,"
103
- f" got { out .dtype } "
104
- )
105
114
106
- ht , _ = self .unary_fn_ (x , out , sycl_queue = exec_q )
107
- ht .wait ()
115
+ ht_unary_ev , unary_ev = self .unary_fn_ (x , out , sycl_queue = exec_q )
116
+
117
+ if not (orig_out is None or orig_out is out ):
118
+ # Copy the out data from temporary buffer to original memory
119
+ ht_copy_ev , _ = ti ._copy_usm_ndarray_into_usm_ndarray (
120
+ src = out , dst = orig_out , sycl_queue = exec_q , depends = [unary_ev ]
121
+ )
122
+ ht_copy_ev .wait ()
123
+ out = orig_out
108
124
125
+ ht_unary_ev .wait ()
109
126
return out
127
+
110
128
if order == "K" :
111
129
buf = _empty_like_orderK (x , buf_dt )
112
130
else :
@@ -122,11 +140,6 @@ def __call__(self, x, out=None, order="K"):
122
140
out = _empty_like_orderK (buf , res_dt )
123
141
else :
124
142
out = dpt .empty_like (buf , dtype = res_dt , order = order )
125
- else :
126
- if buf_dt != out .dtype :
127
- raise TypeError (
128
- f"Output array of type { buf_dt } is needed, got { out .dtype } "
129
- )
130
143
131
144
ht , _ = self .unary_fn_ (buf , out , sycl_queue = exec_q , depends = [copy_ev ])
132
145
ht_copy_ev .wait ()
0 commit comments