Skip to content

Commit c4f972b

Browse files
Merge remote-tracking branch 'origin/master' into elementwise-floor-ceil-trunc
2 parents 59fb1b5 + d3ce80e commit c4f972b

18 files changed

+304
-57
lines changed

.github/workflows/generate-coverage.yaml

+1-1
Original file line numberDiff line numberDiff line change
@@ -79,7 +79,7 @@ jobs:
7979
- name: Install dpctl dependencies
8080
shell: bash -l {0}
8181
run: |
82-
pip install numpy cython setuptools pytest pytest-cov scikit-build cmake coverage[toml]
82+
pip install numpy cython"<3" setuptools pytest pytest-cov scikit-build cmake coverage[toml]
8383
8484
- name: Build dpctl with coverage
8585
shell: bash -l {0}

.github/workflows/generate-docs.yml

+1-1
Original file line numberDiff line numberDiff line change
@@ -49,7 +49,7 @@ jobs:
4949
if: ${{ !github.event.pull_request || github.event.action != 'closed' }}
5050
shell: bash -l {0}
5151
run: |
52-
pip install numpy cython setuptools scikit-build cmake sphinx sphinx_rtd_theme pydot graphviz sphinxcontrib-programoutput sphinxcontrib-googleanalytics
52+
pip install numpy cython"<3" setuptools scikit-build cmake sphinx sphinx_rtd_theme pydot graphviz sphinxcontrib-programoutput sphinxcontrib-googleanalytics
5353
- name: Checkout repo
5454
uses: actions/checkout@v3
5555
with:

.github/workflows/os-llvm-sycl-build.yml

+1-1
Original file line numberDiff line numberDiff line change
@@ -108,7 +108,7 @@ jobs:
108108
- name: Install dpctl dependencies
109109
shell: bash -l {0}
110110
run: |
111-
pip install numpy cython setuptools pytest scikit-build cmake
111+
pip install numpy cython"<3" setuptools pytest scikit-build cmake
112112
113113
- name: Checkout repo
114114
uses: actions/checkout@v3

conda-recipe/meta.yaml

+1-1
Original file line numberDiff line numberDiff line change
@@ -20,7 +20,7 @@ requirements:
2020
- cmake >=3.21
2121
- ninja
2222
- git
23-
- cython
23+
- cython <3
2424
- python
2525
- scikit-build
2626
- numpy

dpctl/tensor/_copy_utils.py

+5
Original file line numberDiff line numberDiff line change
@@ -213,6 +213,11 @@ def _copy_same_shape(dst, src):
213213
"""Assumes src and dst have the same shape."""
214214
# check that memory regions do not overlap
215215
if ti._array_overlap(dst, src):
216+
if src._pointer == dst._pointer and (
217+
src is dst
218+
or (src.strides == dst.strides and src.dtype == dst.dtype)
219+
):
220+
return
216221
_copy_overlapping(src=src, dst=dst)
217222
return
218223

dpctl/tensor/_elementwise_common.py

+40-27
Original file line numberDiff line numberDiff line change
@@ -52,6 +52,20 @@ def __call__(self, x, out=None, order="K"):
5252
if not isinstance(x, dpt.usm_ndarray):
5353
raise TypeError(f"Expected dpctl.tensor.usm_ndarray, got {type(x)}")
5454

55+
if order not in ["C", "F", "K", "A"]:
56+
order = "K"
57+
buf_dt, res_dt = _find_buf_dtype(
58+
x.dtype, self.result_type_resolver_fn_, x.sycl_device
59+
)
60+
if res_dt is None:
61+
raise TypeError(
62+
f"function '{self.name_}' does not support input type "
63+
f"({x.dtype}), "
64+
"and the input could not be safely coerced to any "
65+
"supported types according to the casting rule ''safe''."
66+
)
67+
68+
orig_out = out
5569
if out is not None:
5670
if not isinstance(out, dpt.usm_ndarray):
5771
raise TypeError(
@@ -64,8 +78,21 @@ def __call__(self, x, out=None, order="K"):
6478
f"Expected output shape is {x.shape}, got {out.shape}"
6579
)
6680

67-
if ti._array_overlap(x, out):
68-
raise TypeError("Input and output arrays have memory overlap")
81+
if res_dt != out.dtype:
82+
raise TypeError(
83+
f"Output array of type {res_dt} is needed,"
84+
f" got {out.dtype}"
85+
)
86+
87+
if (
88+
buf_dt is None
89+
and ti._array_overlap(x, out)
90+
and not ti._same_logical_tensors(x, out)
91+
):
92+
# Allocate a temporary buffer to avoid memory overlapping.
93+
# Note if `buf_dt` is not None, a temporary copy of `x` will be
94+
# created, so the array overlap check isn't needed.
95+
out = dpt.empty_like(out)
6996

7097
if (
7198
dpctl.utils.get_execution_queue((x.sycl_queue, out.sycl_queue))
@@ -75,18 +102,6 @@ def __call__(self, x, out=None, order="K"):
75102
"Input and output allocation queues are not compatible"
76103
)
77104

78-
if order not in ["C", "F", "K", "A"]:
79-
order = "K"
80-
buf_dt, res_dt = _find_buf_dtype(
81-
x.dtype, self.result_type_resolver_fn_, x.sycl_device
82-
)
83-
if res_dt is None:
84-
raise TypeError(
85-
f"function '{self.name_}' does not support input type "
86-
f"({x.dtype}), "
87-
"and the input could not be safely coerced to any "
88-
"supported types according to the casting rule ''safe''."
89-
)
90105
exec_q = x.sycl_queue
91106
if buf_dt is None:
92107
if out is None:
@@ -96,17 +111,20 @@ def __call__(self, x, out=None, order="K"):
96111
if order == "A":
97112
order = "F" if x.flags.f_contiguous else "C"
98113
out = dpt.empty_like(x, dtype=res_dt, order=order)
99-
else:
100-
if res_dt != out.dtype:
101-
raise TypeError(
102-
f"Output array of type {res_dt} is needed,"
103-
f" got {out.dtype}"
104-
)
105114

106-
ht, _ = self.unary_fn_(x, out, sycl_queue=exec_q)
107-
ht.wait()
115+
ht_unary_ev, unary_ev = self.unary_fn_(x, out, sycl_queue=exec_q)
116+
117+
if not (orig_out is None or orig_out is out):
118+
# Copy the out data from temporary buffer to original memory
119+
ht_copy_ev, _ = ti._copy_usm_ndarray_into_usm_ndarray(
120+
src=out, dst=orig_out, sycl_queue=exec_q, depends=[unary_ev]
121+
)
122+
ht_copy_ev.wait()
123+
out = orig_out
108124

125+
ht_unary_ev.wait()
109126
return out
127+
110128
if order == "K":
111129
buf = _empty_like_orderK(x, buf_dt)
112130
else:
@@ -122,11 +140,6 @@ def __call__(self, x, out=None, order="K"):
122140
out = _empty_like_orderK(buf, res_dt)
123141
else:
124142
out = dpt.empty_like(buf, dtype=res_dt, order=order)
125-
else:
126-
if buf_dt != out.dtype:
127-
raise TypeError(
128-
f"Output array of type {buf_dt} is needed, got {out.dtype}"
129-
)
130143

131144
ht, _ = self.unary_fn_(buf, out, sycl_queue=exec_q, depends=[copy_ev])
132145
ht_copy_ev.wait()

dpctl/tensor/libtensor/include/utils/memory_overlap.hpp

+47
Original file line numberDiff line numberDiff line change
@@ -100,6 +100,53 @@ struct MemoryOverlap
100100
}
101101
};
102102

103+
struct SameLogicalTensors
104+
{
105+
bool operator()(dpctl::tensor::usm_ndarray ar1,
106+
dpctl::tensor::usm_ndarray ar2) const
107+
{
108+
// Same ndim
109+
int nd1 = ar1.get_ndim();
110+
if (nd1 != ar2.get_ndim())
111+
return false;
112+
113+
// Same dtype
114+
int tn1 = ar1.get_typenum();
115+
if (tn1 != ar2.get_typenum())
116+
return false;
117+
118+
// Same pointer
119+
const char *ar1_data = ar1.get_data();
120+
const char *ar2_data = ar2.get_data();
121+
122+
if (ar1_data != ar2_data)
123+
return false;
124+
125+
// Same shape
126+
const py::ssize_t *ar1_shape = ar1.get_shape_raw();
127+
const py::ssize_t *ar2_shape = ar2.get_shape_raw();
128+
129+
if (!std::equal(ar1_shape, ar1_shape + nd1, ar2_shape))
130+
return false;
131+
132+
// Same strides
133+
auto const &ar1_strides = ar1.get_strides_vector();
134+
auto const &ar2_strides = ar2.get_strides_vector();
135+
136+
auto ar1_beg_it = std::begin(ar1_strides);
137+
auto ar1_end_it = std::end(ar1_strides);
138+
139+
auto ar2_beg_it = std::begin(ar2_strides);
140+
141+
if (!std::equal(ar1_beg_it, ar1_end_it, ar2_beg_it))
142+
return false;
143+
144+
// all checks passed: arrays are logical views
145+
// into the same memory
146+
return true;
147+
}
148+
};
149+
103150
} // namespace overlap
104151
} // namespace tensor
105152
} // namespace dpctl

dpctl/tensor/libtensor/source/elementwise_functions.hpp

+3-1
Original file line numberDiff line numberDiff line change
@@ -128,7 +128,9 @@ py_unary_ufunc(dpctl::tensor::usm_ndarray src,
128128

129129
// check memory overlap
130130
auto const &overlap = dpctl::tensor::overlap::MemoryOverlap();
131-
if (overlap(src, dst)) {
131+
auto const &same_logical_tensors =
132+
dpctl::tensor::overlap::SameLogicalTensors();
133+
if (overlap(src, dst) && !same_logical_tensors(src, dst)) {
132134
throw py::value_error("Arrays index overlapping segments of memory");
133135
}
134136

dpctl/tensor/libtensor/source/tensor_py.cpp

+10
Original file line numberDiff line numberDiff line change
@@ -60,6 +60,7 @@ using dpctl::tensor::c_contiguous_strides;
6060
using dpctl::tensor::f_contiguous_strides;
6161

6262
using dpctl::tensor::overlap::MemoryOverlap;
63+
using dpctl::tensor::overlap::SameLogicalTensors;
6364

6465
using dpctl::tensor::py_internal::copy_usm_ndarray_into_usm_ndarray;
6566

@@ -338,6 +339,15 @@ PYBIND11_MODULE(_tensor_impl, m)
338339
"Determines if the memory regions indexed by each array overlap",
339340
py::arg("array1"), py::arg("array2"));
340341

342+
auto same_logical_tensors = [](dpctl::tensor::usm_ndarray x1,
343+
dpctl::tensor::usm_ndarray x2) -> bool {
344+
auto const &same_logical_tensors = SameLogicalTensors();
345+
return same_logical_tensors(x1, x2);
346+
};
347+
m.def("_same_logical_tensors", same_logical_tensors,
348+
"Determines if the memory regions indexed by each array are the same",
349+
py::arg("array1"), py::arg("array2"));
350+
341351
m.def("_place", &py_place, "", py::arg("dst"), py::arg("cumsum"),
342352
py::arg("axis_start"), py::arg("axis_end"), py::arg("rhs"),
343353
py::arg("sycl_queue"), py::arg("depends") = py::list());

dpctl/tests/_numpy_warnings.py

+28
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,28 @@
1+
# Data Parallel Control (dpctl)
2+
#
3+
# Copyright 2023 Intel Corporation
4+
#
5+
# Licensed under the Apache License, Version 2.0 (the "License");
6+
# you may not use this file except in compliance with the License.
7+
# You may obtain a copy of the License at
8+
#
9+
# http://www.apache.org/licenses/LICENSE-2.0
10+
#
11+
# Unless required by applicable law or agreed to in writing, software
12+
# distributed under the License is distributed on an "AS IS" BASIS,
13+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14+
# See the License for the specific language governing permissions and
15+
# limitations under the License.
16+
17+
import numpy
18+
import pytest
19+
20+
21+
@pytest.fixture
22+
def suppress_invalid_numpy_warnings():
23+
# invalid: treatment for invalid floating-point operation
24+
# (result is not an expressible number, typically indicates
25+
# that a NaN was produced)
26+
old_settings = numpy.seterr(invalid="ignore")
27+
yield
28+
numpy.seterr(**old_settings) # restore the previously active settings

dpctl/tests/conftest.py

+8-1
Original file line numberDiff line numberDiff line change
@@ -26,8 +26,15 @@
2626
invalid_filter,
2727
valid_filter,
2828
)
29+
from _numpy_warnings import suppress_invalid_numpy_warnings
2930

3031
sys.path.append(os.path.join(os.path.dirname(__file__), "helper"))
3132

3233
# common fixtures
33-
__all__ = ["check", "device_selector", "invalid_filter", "valid_filter"]
34+
__all__ = [
35+
"check",
36+
"device_selector",
37+
"invalid_filter",
38+
"suppress_invalid_numpy_warnings",
39+
"valid_filter",
40+
]

dpctl/tests/elementwise/test_abs.py

+23-1
Original file line numberDiff line numberDiff line change
@@ -22,7 +22,7 @@
2222
import dpctl.tensor as dpt
2323
from dpctl.tests.helper import get_queue_or_skip, skip_if_dtype_not_supported
2424

25-
from .utils import _all_dtypes, _usm_types
25+
from .utils import _all_dtypes, _no_complex_dtypes, _usm_types
2626

2727

2828
@pytest.mark.parametrize("dtype", _all_dtypes)
@@ -113,3 +113,25 @@ def test_abs_complex(dtype):
113113
np.testing.assert_allclose(
114114
dpt.asnumpy(Y), expected_Y, atol=tol, rtol=tol
115115
)
116+
117+
118+
@pytest.mark.parametrize("dtype", _no_complex_dtypes)
119+
def test_abs_out_overlap(dtype):
120+
q = get_queue_or_skip()
121+
skip_if_dtype_not_supported(dtype, q)
122+
123+
X = dpt.linspace(0, 35, 60, dtype=dtype, sycl_queue=q)
124+
X = dpt.reshape(X, (3, 5, 4))
125+
126+
Xnp = dpt.asnumpy(X)
127+
Ynp = np.abs(Xnp, out=Xnp)
128+
129+
Y = dpt.abs(X, out=X)
130+
assert Y is X
131+
assert np.allclose(dpt.asnumpy(X), Xnp)
132+
133+
Ynp = np.abs(Xnp, out=Xnp[::-1])
134+
Y = dpt.abs(X, out=X[::-1])
135+
assert Y is not X
136+
assert np.allclose(dpt.asnumpy(X), Xnp)
137+
assert np.allclose(dpt.asnumpy(Y), Ynp)

dpctl/tests/elementwise/test_exp.py

+23
Original file line numberDiff line numberDiff line change
@@ -145,3 +145,26 @@ def test_exp_strided(dtype):
145145
atol=tol,
146146
rtol=tol,
147147
)
148+
149+
150+
@pytest.mark.parametrize("dtype", ["f2", "f4", "f8", "c8", "c16"])
151+
def test_exp_out_overlap(dtype):
152+
q = get_queue_or_skip()
153+
skip_if_dtype_not_supported(dtype, q)
154+
155+
X = dpt.linspace(0, 1, 15, dtype=dtype, sycl_queue=q)
156+
X = dpt.reshape(X, (3, 5))
157+
158+
Xnp = dpt.asnumpy(X)
159+
Ynp = np.exp(Xnp, out=Xnp)
160+
161+
Y = dpt.exp(X, out=X)
162+
tol = 8 * dpt.finfo(Y.dtype).resolution
163+
assert Y is X
164+
assert_allclose(dpt.asnumpy(X), Xnp, atol=tol, rtol=tol)
165+
166+
Ynp = np.exp(Xnp, out=Xnp[::-1])
167+
Y = dpt.exp(X, out=X[::-1])
168+
assert Y is not X
169+
assert_allclose(dpt.asnumpy(X), Xnp, atol=tol, rtol=tol)
170+
assert_allclose(dpt.asnumpy(Y), Ynp, atol=tol, rtol=tol)

0 commit comments

Comments
 (0)