Skip to content

Commit d3ce80e

Browse files
Merge pull request #1281 from IntelPython/unary_out_overlap
Created a temporary copy in case of overlap for unary function
2 parents a6d16f2 + 03a46e1 commit d3ce80e

17 files changed

+294
-52
lines changed

.github/workflows/generate-coverage.yaml

+1-1
Original file line numberDiff line numberDiff line change
@@ -79,7 +79,7 @@ jobs:
7979
- name: Install dpctl dependencies
8080
shell: bash -l {0}
8181
run: |
82-
pip install numpy cython setuptools pytest pytest-cov scikit-build cmake coverage[toml]
82+
pip install numpy cython"<3" setuptools pytest pytest-cov scikit-build cmake coverage[toml]
8383
8484
- name: Build dpctl with coverage
8585
shell: bash -l {0}

.github/workflows/generate-docs.yml

+1-1
Original file line numberDiff line numberDiff line change
@@ -49,7 +49,7 @@ jobs:
4949
if: ${{ !github.event.pull_request || github.event.action != 'closed' }}
5050
shell: bash -l {0}
5151
run: |
52-
pip install numpy cython setuptools scikit-build cmake sphinx sphinx_rtd_theme pydot graphviz sphinxcontrib-programoutput sphinxcontrib-googleanalytics
52+
pip install numpy cython"<3" setuptools scikit-build cmake sphinx sphinx_rtd_theme pydot graphviz sphinxcontrib-programoutput sphinxcontrib-googleanalytics
5353
- name: Checkout repo
5454
uses: actions/checkout@v3
5555
with:

.github/workflows/os-llvm-sycl-build.yml

+1-1
Original file line numberDiff line numberDiff line change
@@ -108,7 +108,7 @@ jobs:
108108
- name: Install dpctl dependencies
109109
shell: bash -l {0}
110110
run: |
111-
pip install numpy cython setuptools pytest scikit-build cmake
111+
pip install numpy cython"<3" setuptools pytest scikit-build cmake
112112
113113
- name: Checkout repo
114114
uses: actions/checkout@v3

conda-recipe/meta.yaml

+1-1
Original file line numberDiff line numberDiff line change
@@ -20,7 +20,7 @@ requirements:
2020
- cmake >=3.21
2121
- ninja
2222
- git
23-
- cython
23+
- cython <3
2424
- python
2525
- scikit-build
2626
- numpy

dpctl/tensor/_elementwise_common.py

+35-22
Original file line numberDiff line numberDiff line change
@@ -52,6 +52,15 @@ def __call__(self, x, out=None, order="K"):
5252
if not isinstance(x, dpt.usm_ndarray):
5353
raise TypeError(f"Expected dpctl.tensor.usm_ndarray, got {type(x)}")
5454

55+
if order not in ["C", "F", "K", "A"]:
56+
order = "K"
57+
buf_dt, res_dt = _find_buf_dtype(
58+
x.dtype, self.result_type_resolver_fn_, x.sycl_device
59+
)
60+
if res_dt is None:
61+
raise RuntimeError
62+
63+
orig_out = out
5564
if out is not None:
5665
if not isinstance(out, dpt.usm_ndarray):
5766
raise TypeError(
@@ -64,8 +73,21 @@ def __call__(self, x, out=None, order="K"):
6473
f"Expected output shape is {x.shape}, got {out.shape}"
6574
)
6675

67-
if ti._array_overlap(x, out):
68-
raise TypeError("Input and output arrays have memory overlap")
76+
if res_dt != out.dtype:
77+
raise TypeError(
78+
f"Output array of type {res_dt} is needed,"
79+
f" got {out.dtype}"
80+
)
81+
82+
if (
83+
buf_dt is None
84+
and ti._array_overlap(x, out)
85+
and not ti._same_logical_tensors(x, out)
86+
):
87+
# Allocate a temporary buffer to avoid memory overlapping.
88+
# Note if `buf_dt` is not None, a temporary copy of `x` will be
89+
# created, so the array overlap check isn't needed.
90+
out = dpt.empty_like(out)
6991

7092
if (
7193
dpctl.utils.get_execution_queue((x.sycl_queue, out.sycl_queue))
@@ -75,13 +97,6 @@ def __call__(self, x, out=None, order="K"):
7597
"Input and output allocation queues are not compatible"
7698
)
7799

78-
if order not in ["C", "F", "K", "A"]:
79-
order = "K"
80-
buf_dt, res_dt = _find_buf_dtype(
81-
x.dtype, self.result_type_resolver_fn_, x.sycl_device
82-
)
83-
if res_dt is None:
84-
raise RuntimeError
85100
exec_q = x.sycl_queue
86101
if buf_dt is None:
87102
if out is None:
@@ -91,17 +106,20 @@ def __call__(self, x, out=None, order="K"):
91106
if order == "A":
92107
order = "F" if x.flags.f_contiguous else "C"
93108
out = dpt.empty_like(x, dtype=res_dt, order=order)
94-
else:
95-
if res_dt != out.dtype:
96-
raise TypeError(
97-
f"Output array of type {res_dt} is needed,"
98-
f" got {out.dtype}"
99-
)
100109

101-
ht, _ = self.unary_fn_(x, out, sycl_queue=exec_q)
102-
ht.wait()
110+
ht_unary_ev, unary_ev = self.unary_fn_(x, out, sycl_queue=exec_q)
111+
112+
if not (orig_out is None or orig_out is out):
113+
# Copy the out data from temporary buffer to original memory
114+
ht_copy_ev, _ = ti._copy_usm_ndarray_into_usm_ndarray(
115+
src=out, dst=orig_out, sycl_queue=exec_q, depends=[unary_ev]
116+
)
117+
ht_copy_ev.wait()
118+
out = orig_out
103119

120+
ht_unary_ev.wait()
104121
return out
122+
105123
if order == "K":
106124
buf = _empty_like_orderK(x, buf_dt)
107125
else:
@@ -117,11 +135,6 @@ def __call__(self, x, out=None, order="K"):
117135
out = _empty_like_orderK(buf, res_dt)
118136
else:
119137
out = dpt.empty_like(buf, dtype=res_dt, order=order)
120-
else:
121-
if buf_dt != out.dtype:
122-
raise TypeError(
123-
f"Output array of type {buf_dt} is needed, got {out.dtype}"
124-
)
125138

126139
ht, _ = self.unary_fn_(buf, out, sycl_queue=exec_q, depends=[copy_ev])
127140
ht_copy_ev.wait()

dpctl/tensor/libtensor/include/utils/memory_overlap.hpp

+47
Original file line numberDiff line numberDiff line change
@@ -100,6 +100,53 @@ struct MemoryOverlap
100100
}
101101
};
102102

103+
struct SameLogicalTensors
104+
{
105+
bool operator()(dpctl::tensor::usm_ndarray ar1,
106+
dpctl::tensor::usm_ndarray ar2) const
107+
{
108+
// Same ndim
109+
int nd1 = ar1.get_ndim();
110+
if (nd1 != ar2.get_ndim())
111+
return false;
112+
113+
// Same dtype
114+
int tn1 = ar1.get_typenum();
115+
if (tn1 != ar2.get_typenum())
116+
return false;
117+
118+
// Same pointer
119+
const char *ar1_data = ar1.get_data();
120+
const char *ar2_data = ar2.get_data();
121+
122+
if (ar1_data != ar2_data)
123+
return false;
124+
125+
// Same shape and strides
126+
const py::ssize_t *ar1_shape = ar1.get_shape_raw();
127+
const py::ssize_t *ar2_shape = ar2.get_shape_raw();
128+
129+
if (!std::equal(ar1_shape, ar1_shape + nd1, ar2_shape))
130+
return false;
131+
132+
// Same shape and strides
133+
auto const &ar1_strides = ar1.get_strides_vector();
134+
auto const &ar2_strides = ar2.get_strides_vector();
135+
136+
auto ar1_beg_it = std::begin(ar1_strides);
137+
auto ar1_end_it = std::end(ar1_strides);
138+
139+
auto ar2_beg_it = std::begin(ar2_strides);
140+
141+
if (!std::equal(ar1_beg_it, ar1_end_it, ar2_beg_it))
142+
return false;
143+
144+
// all checks passed: arrays are logical views
145+
// into the same memory
146+
return true;
147+
}
148+
};
149+
103150
} // namespace overlap
104151
} // namespace tensor
105152
} // namespace dpctl

dpctl/tensor/libtensor/source/elementwise_functions.hpp

+3-1
Original file line numberDiff line numberDiff line change
@@ -128,7 +128,9 @@ py_unary_ufunc(dpctl::tensor::usm_ndarray src,
128128

129129
// check memory overlap
130130
auto const &overlap = dpctl::tensor::overlap::MemoryOverlap();
131-
if (overlap(src, dst)) {
131+
auto const &same_logical_tensors =
132+
dpctl::tensor::overlap::SameLogicalTensors();
133+
if (overlap(src, dst) && !same_logical_tensors(src, dst)) {
132134
throw py::value_error("Arrays index overlapping segments of memory");
133135
}
134136

dpctl/tensor/libtensor/source/tensor_py.cpp

+10
Original file line numberDiff line numberDiff line change
@@ -60,6 +60,7 @@ using dpctl::tensor::c_contiguous_strides;
6060
using dpctl::tensor::f_contiguous_strides;
6161

6262
using dpctl::tensor::overlap::MemoryOverlap;
63+
using dpctl::tensor::overlap::SameLogicalTensors;
6364

6465
using dpctl::tensor::py_internal::copy_usm_ndarray_into_usm_ndarray;
6566

@@ -338,6 +339,15 @@ PYBIND11_MODULE(_tensor_impl, m)
338339
"Determines if the memory regions indexed by each array overlap",
339340
py::arg("array1"), py::arg("array2"));
340341

342+
auto same_logical_tensors = [](dpctl::tensor::usm_ndarray x1,
343+
dpctl::tensor::usm_ndarray x2) -> bool {
344+
auto const &same_logical_tensors = SameLogicalTensors();
345+
return same_logical_tensors(x1, x2);
346+
};
347+
m.def("_same_logical_tensors", same_logical_tensors,
348+
"Determines if the memory regions indexed by each array are the same",
349+
py::arg("array1"), py::arg("array2"));
350+
341351
m.def("_place", &py_place, "", py::arg("dst"), py::arg("cumsum"),
342352
py::arg("axis_start"), py::arg("axis_end"), py::arg("rhs"),
343353
py::arg("sycl_queue"), py::arg("depends") = py::list());

dpctl/tests/_numpy_warnings.py

+28
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,28 @@
1+
# Data Parallel Control (dpctl)
2+
#
3+
# Copyright 2023 Intel Corporation
4+
#
5+
# Licensed under the Apache License, Version 2.0 (the "License");
6+
# you may not use this file except in compliance with the License.
7+
# You may obtain a copy of the License at
8+
#
9+
# http://www.apache.org/licenses/LICENSE-2.0
10+
#
11+
# Unless required by applicable law or agreed to in writing, software
12+
# distributed under the License is distributed on an "AS IS" BASIS,
13+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14+
# See the License for the specific language governing permissions and
15+
# limitations under the License.
16+
17+
import numpy
18+
import pytest
19+
20+
21+
@pytest.fixture
22+
def suppress_invalid_numpy_warnings():
23+
# invalid: treatment for invalid floating-point operation
24+
# (result is not an expressible number, typically indicates
25+
# that a NaN was produced)
26+
old_settings = numpy.seterr(invalid="ignore")
27+
yield
28+
numpy.seterr(**old_settings) # reset to default

dpctl/tests/conftest.py

+8-1
Original file line numberDiff line numberDiff line change
@@ -26,8 +26,15 @@
2626
invalid_filter,
2727
valid_filter,
2828
)
29+
from _numpy_warnings import suppress_invalid_numpy_warnings
2930

3031
sys.path.append(os.path.join(os.path.dirname(__file__), "helper"))
3132

3233
# common fixtures
33-
__all__ = ["check", "device_selector", "invalid_filter", "valid_filter"]
34+
__all__ = [
35+
"check",
36+
"device_selector",
37+
"invalid_filter",
38+
"suppress_invalid_numpy_warnings",
39+
"valid_filter",
40+
]

dpctl/tests/elementwise/test_abs.py

+23-1
Original file line numberDiff line numberDiff line change
@@ -22,7 +22,7 @@
2222
import dpctl.tensor as dpt
2323
from dpctl.tests.helper import get_queue_or_skip, skip_if_dtype_not_supported
2424

25-
from .utils import _all_dtypes, _usm_types
25+
from .utils import _all_dtypes, _no_complex_dtypes, _usm_types
2626

2727

2828
@pytest.mark.parametrize("dtype", _all_dtypes)
@@ -113,3 +113,25 @@ def test_abs_complex(dtype):
113113
np.testing.assert_allclose(
114114
dpt.asnumpy(Y), expected_Y, atol=tol, rtol=tol
115115
)
116+
117+
118+
@pytest.mark.parametrize("dtype", _no_complex_dtypes)
119+
def test_abs_out_overlap(dtype):
120+
q = get_queue_or_skip()
121+
skip_if_dtype_not_supported(dtype, q)
122+
123+
X = dpt.linspace(0, 35, 60, dtype=dtype, sycl_queue=q)
124+
X = dpt.reshape(X, (3, 5, 4))
125+
126+
Xnp = dpt.asnumpy(X)
127+
Ynp = np.abs(Xnp, out=Xnp)
128+
129+
Y = dpt.abs(X, out=X)
130+
assert Y is X
131+
assert np.allclose(dpt.asnumpy(X), Xnp)
132+
133+
Ynp = np.abs(Xnp, out=Xnp[::-1])
134+
Y = dpt.abs(X, out=X[::-1])
135+
assert Y is not X
136+
assert np.allclose(dpt.asnumpy(X), Xnp)
137+
assert np.allclose(dpt.asnumpy(Y), Ynp)

dpctl/tests/elementwise/test_exp.py

+23
Original file line numberDiff line numberDiff line change
@@ -145,3 +145,26 @@ def test_exp_strided(dtype):
145145
atol=tol,
146146
rtol=tol,
147147
)
148+
149+
150+
@pytest.mark.parametrize("dtype", ["f2", "f4", "f8", "c8", "c16"])
151+
def test_exp_out_overlap(dtype):
152+
q = get_queue_or_skip()
153+
skip_if_dtype_not_supported(dtype, q)
154+
155+
X = dpt.linspace(0, 1, 15, dtype=dtype, sycl_queue=q)
156+
X = dpt.reshape(X, (3, 5))
157+
158+
Xnp = dpt.asnumpy(X)
159+
Ynp = np.exp(Xnp, out=Xnp)
160+
161+
Y = dpt.exp(X, out=X)
162+
tol = 8 * dpt.finfo(Y.dtype).resolution
163+
assert Y is X
164+
assert_allclose(dpt.asnumpy(X), Xnp, atol=tol, rtol=tol)
165+
166+
Ynp = np.exp(Xnp, out=Xnp[::-1])
167+
Y = dpt.exp(X, out=X[::-1])
168+
assert Y is not X
169+
assert_allclose(dpt.asnumpy(X), Xnp, atol=tol, rtol=tol)
170+
assert_allclose(dpt.asnumpy(Y), Ynp, atol=tol, rtol=tol)

0 commit comments

Comments
 (0)