Skip to content

Commit cb87b68

Browse files
committed
Implements in-place multiplication and subtraction
1 parent 034ab01 commit cb87b68

File tree

5 files changed

+341
-12
lines changed

5 files changed

+341
-12
lines changed

dpctl/tensor/_elementwise_funcs.py

Lines changed: 10 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -607,7 +607,11 @@
607607
the returned array is determined by the Type Promotion Rules.
608608
"""
609609
multiply = BinaryElementwiseFunc(
610-
"multiply", ti._multiply_result_type, ti._multiply, _multiply_docstring_
610+
"multiply",
611+
ti._multiply_result_type,
612+
ti._multiply,
613+
_multiply_docstring_,
614+
ti._multiply_inplace,
611615
)
612616

613617
# U25: ==== NEGATIVE (x)
@@ -786,7 +790,11 @@
786790
of the returned array is determined by the Type Promotion Rules.
787791
"""
788792
subtract = BinaryElementwiseFunc(
789-
"subtract", ti._subtract_result_type, ti._subtract, _subtract_docstring_
793+
"subtract",
794+
ti._subtract_result_type,
795+
ti._subtract,
796+
_subtract_docstring_,
797+
ti._subtract_inplace,
790798
)
791799

792800

dpctl/tensor/_usmarray.pyx

Lines changed: 4 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -1284,11 +1284,8 @@ cdef class usm_ndarray:
12841284
return self
12851285

12861286
def __imul__(self, other):
1287-
res = self.__mul__(other)
1288-
if res is NotImplemented:
1289-
return res
1290-
self.__setitem__(Ellipsis, res)
1291-
return self
1287+
from ._elementwise_funcs import multiply
1288+
return multiply.inplace(self, other)
12921289

12931290
def __ior__(self, other):
12941291
res = self.__or__(other)
@@ -1312,11 +1309,8 @@ cdef class usm_ndarray:
13121309
return self
13131310

13141311
def __isub__(self, other):
1315-
res = self.__sub__(other)
1316-
if res is NotImplemented:
1317-
return res
1318-
self.__setitem__(Ellipsis, res)
1319-
return self
1312+
from ._elementwise_funcs import subtract
1313+
return subtract.inplace(self, other)
13201314

13211315
def __itruediv__(self, other):
13221316
res = self.__truediv__(other)

dpctl/tensor/libtensor/include/kernels/elementwise_functions/multiply.hpp

Lines changed: 125 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -34,6 +34,7 @@
3434
#include "utils/type_utils.hpp"
3535

3636
#include "kernels/elementwise_functions/common.hpp"
37+
#include "kernels/elementwise_functions/common_inplace.hpp"
3738
#include <pybind11/pybind11.h>
3839

3940
namespace dpctl
@@ -371,6 +372,130 @@ struct MultiplyContigRowContigMatrixBroadcastFactory
371372
}
372373
};
373374

375+
template <typename argT, typename resT> struct MultiplyInplaceFunctor
376+
{
377+
378+
using supports_sg_loadstore = std::negation<
379+
std::disjunction<tu_ns::is_complex<argT>, tu_ns::is_complex<resT>>>;
380+
using supports_vec = std::negation<
381+
std::disjunction<tu_ns::is_complex<argT>, tu_ns::is_complex<resT>>>;
382+
383+
void operator()(resT &res, const argT &in)
384+
{
385+
res *= in;
386+
}
387+
388+
template <int vec_sz>
389+
void operator()(sycl::vec<resT, vec_sz> &res,
390+
const sycl::vec<argT, vec_sz> &in)
391+
{
392+
res *= in;
393+
}
394+
};
395+
396+
template <typename argT,
397+
typename resT,
398+
unsigned int vec_sz = 4,
399+
unsigned int n_vecs = 2>
400+
using MultiplyInplaceContigFunctor =
401+
elementwise_common::BinaryInplaceContigFunctor<
402+
argT,
403+
resT,
404+
MultiplyInplaceFunctor<argT, resT>,
405+
vec_sz,
406+
n_vecs>;
407+
408+
template <typename argT, typename resT, typename IndexerT>
409+
using MultiplyInplaceStridedFunctor =
410+
elementwise_common::BinaryInplaceStridedFunctor<
411+
argT,
412+
resT,
413+
IndexerT,
414+
MultiplyInplaceFunctor<argT, resT>>;
415+
416+
template <typename argT,
417+
typename resT,
418+
unsigned int vec_sz,
419+
unsigned int n_vecs>
420+
class multiply_inplace_contig_kernel;
421+
422+
template <typename argTy, typename resTy>
423+
sycl::event
424+
multiply_inplace_contig_impl(sycl::queue exec_q,
425+
size_t nelems,
426+
const char *arg_p,
427+
py::ssize_t arg_offset,
428+
char *res_p,
429+
py::ssize_t res_offset,
430+
const std::vector<sycl::event> &depends = {})
431+
{
432+
return elementwise_common::binary_inplace_contig_impl<
433+
argTy, resTy, MultiplyInplaceContigFunctor,
434+
multiply_inplace_contig_kernel>(exec_q, nelems, arg_p, arg_offset,
435+
res_p, res_offset, depends);
436+
}
437+
438+
template <typename fnT, typename T1, typename T2>
439+
struct MultiplyInplaceContigFactory
440+
{
441+
fnT get()
442+
{
443+
if constexpr (std::is_same_v<
444+
typename MultiplyOutputType<T1, T2>::value_type,
445+
void>)
446+
{
447+
fnT fn = nullptr;
448+
return fn;
449+
}
450+
else {
451+
fnT fn = multiply_inplace_contig_impl<T1, T2>;
452+
return fn;
453+
}
454+
}
455+
};
456+
457+
template <typename resT, typename argT, typename IndexerT>
458+
class multiply_inplace_strided_kernel;
459+
460+
template <typename argTy, typename resTy>
461+
sycl::event multiply_inplace_strided_impl(
462+
sycl::queue exec_q,
463+
size_t nelems,
464+
int nd,
465+
const py::ssize_t *shape_and_strides,
466+
const char *arg_p,
467+
py::ssize_t arg_offset,
468+
char *res_p,
469+
py::ssize_t res_offset,
470+
const std::vector<sycl::event> &depends,
471+
const std::vector<sycl::event> &additional_depends)
472+
{
473+
return elementwise_common::binary_inplace_strided_impl<
474+
argTy, resTy, MultiplyInplaceStridedFunctor,
475+
multiply_inplace_strided_kernel>(exec_q, nelems, nd, shape_and_strides,
476+
arg_p, arg_offset, res_p, res_offset,
477+
depends, additional_depends);
478+
}
479+
480+
template <typename fnT, typename T1, typename T2>
481+
struct MultiplyInplaceStridedFactory
482+
{
483+
fnT get()
484+
{
485+
if constexpr (std::is_same_v<
486+
typename MultiplyOutputType<T1, T2>::value_type,
487+
void>)
488+
{
489+
fnT fn = nullptr;
490+
return fn;
491+
}
492+
else {
493+
fnT fn = multiply_inplace_strided_impl<T1, T2>;
494+
return fn;
495+
}
496+
}
497+
};
498+
374499
} // namespace multiply
375500
} // namespace kernels
376501
} // namespace tensor

dpctl/tensor/libtensor/include/kernels/elementwise_functions/subtract.hpp

Lines changed: 124 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -385,6 +385,130 @@ struct SubtractContigRowContigMatrixBroadcastFactory
385385
}
386386
};
387387

388+
template <typename argT, typename resT> struct SubtractInplaceFunctor
389+
{
390+
391+
using supports_sg_loadstore = std::negation<
392+
std::disjunction<tu_ns::is_complex<argT>, tu_ns::is_complex<resT>>>;
393+
using supports_vec = std::negation<
394+
std::disjunction<tu_ns::is_complex<argT>, tu_ns::is_complex<resT>>>;
395+
396+
void operator()(resT &res, const argT &in)
397+
{
398+
res -= in;
399+
}
400+
401+
template <int vec_sz>
402+
void operator()(sycl::vec<resT, vec_sz> &res,
403+
const sycl::vec<argT, vec_sz> &in)
404+
{
405+
res -= in;
406+
}
407+
};
408+
409+
template <typename argT,
410+
typename resT,
411+
unsigned int vec_sz = 4,
412+
unsigned int n_vecs = 2>
413+
using SubtractInplaceContigFunctor =
414+
elementwise_common::BinaryInplaceContigFunctor<
415+
argT,
416+
resT,
417+
SubtractInplaceFunctor<argT, resT>,
418+
vec_sz,
419+
n_vecs>;
420+
421+
template <typename argT, typename resT, typename IndexerT>
422+
using SubtractInplaceStridedFunctor =
423+
elementwise_common::BinaryInplaceStridedFunctor<
424+
argT,
425+
resT,
426+
IndexerT,
427+
SubtractInplaceFunctor<argT, resT>>;
428+
429+
template <typename argT,
430+
typename resT,
431+
unsigned int vec_sz,
432+
unsigned int n_vecs>
433+
class subtract_inplace_contig_kernel;
434+
435+
template <typename argTy, typename resTy>
436+
sycl::event
437+
subtract_inplace_contig_impl(sycl::queue exec_q,
438+
size_t nelems,
439+
const char *arg_p,
440+
py::ssize_t arg_offset,
441+
char *res_p,
442+
py::ssize_t res_offset,
443+
const std::vector<sycl::event> &depends = {})
444+
{
445+
return elementwise_common::binary_inplace_contig_impl<
446+
argTy, resTy, SubtractInplaceContigFunctor,
447+
subtract_inplace_contig_kernel>(exec_q, nelems, arg_p, arg_offset,
448+
res_p, res_offset, depends);
449+
}
450+
451+
template <typename fnT, typename T1, typename T2>
452+
struct SubtractInplaceContigFactory
453+
{
454+
fnT get()
455+
{
456+
if constexpr (std::is_same_v<
457+
typename SubtractOutputType<T1, T2>::value_type,
458+
void>)
459+
{
460+
fnT fn = nullptr;
461+
return fn;
462+
}
463+
else {
464+
fnT fn = subtract_inplace_contig_impl<T1, T2>;
465+
return fn;
466+
}
467+
}
468+
};
469+
470+
template <typename resT, typename argT, typename IndexerT>
471+
class subtract_inplace_strided_kernel;
472+
473+
template <typename argTy, typename resTy>
474+
sycl::event subtract_inplace_strided_impl(
475+
sycl::queue exec_q,
476+
size_t nelems,
477+
int nd,
478+
const py::ssize_t *shape_and_strides,
479+
const char *arg_p,
480+
py::ssize_t arg_offset,
481+
char *res_p,
482+
py::ssize_t res_offset,
483+
const std::vector<sycl::event> &depends,
484+
const std::vector<sycl::event> &additional_depends)
485+
{
486+
return elementwise_common::binary_inplace_strided_impl<
487+
argTy, resTy, SubtractInplaceStridedFunctor,
488+
subtract_inplace_strided_kernel>(exec_q, nelems, nd, shape_and_strides,
489+
arg_p, arg_offset, res_p, res_offset,
490+
depends, additional_depends);
491+
}
492+
493+
template <typename fnT, typename T1, typename T2>
494+
struct SubtractInplaceStridedFactory
495+
{
496+
fnT get()
497+
{
498+
if constexpr (std::is_same_v<
499+
typename SubtractOutputType<T1, T2>::value_type,
500+
void>)
501+
{
502+
fnT fn = nullptr;
503+
return fn;
504+
}
505+
else {
506+
fnT fn = subtract_inplace_strided_impl<T1, T2>;
507+
return fn;
508+
}
509+
}
510+
};
511+
388512
} // namespace subtract
389513
} // namespace kernels
390514
} // namespace tensor

0 commit comments

Comments
 (0)