Skip to content

Commit 034ab01

Browse files
committed
Elementwise functions now check writable flag of destination
1 parent ccf66de commit 034ab01

File tree

1 file changed

+11
-0
lines changed

1 file changed

+11
-0
lines changed

dpctl/tensor/libtensor/source/elementwise_functions.hpp

Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -63,6 +63,10 @@ py_unary_ufunc(dpctl::tensor::usm_ndarray src,
6363
const contig_dispatchT &contig_dispatch_vector,
6464
const strided_dispatchT &strided_dispatch_vector)
6565
{
66+
if (!dst.is_writable()) {
67+
throw py::value_error("Output array is read-only.");
68+
}
69+
6670
int src_typenum = src.get_typenum();
6771
int dst_typenum = dst.get_typenum();
6872

@@ -306,6 +310,9 @@ std::pair<sycl::event, sycl::event> py_binary_ufunc(
306310
const contig_row_matrix_dispatchT
307311
&contig_row_matrix_broadcast_dispatch_table)
308312
{
313+
if (!dst.is_writable()) {
314+
throw py::value_error("Output array is read-only.");
315+
}
309316
// check type_nums
310317
int src1_typenum = src1.get_typenum();
311318
int src2_typenum = src2.get_typenum();
@@ -602,6 +609,10 @@ py_binary_inplace_ufunc(dpctl::tensor::usm_ndarray lhs,
602609
const contig_dispatchT &contig_dispatch_table,
603610
const strided_dispatchT &strided_dispatch_table)
604611
{
612+
if (!lhs.is_writable()) {
613+
throw py::value_error("Output array is read-only.");
614+
}
615+
605616
// check type_nums
606617
int rhs_typenum = rhs.get_typenum();
607618
int lhs_typenum = lhs.get_typenum();

0 commit comments

Comments
 (0)