
Commit 3794cbc

Merge pull request #1244 from IntelPython/inplace-matrix-row-variant
Optimized in-place operators for rows and matrices
2 parents 21b2767 + 47b2921 commit 3794cbc

17 files changed: +433 −50 lines changed

dpctl/tensor/libtensor/include/kernels/elementwise_functions/add.hpp

+56 −3

@@ -218,7 +218,7 @@ template <typename fnT, typename T1, typename T2> struct AddTypeMapFactory
 };
 
 template <typename T1, typename T2, typename resT, typename IndexerT>
-class add_strided_strided_kernel;
+class add_strided_kernel;
 
 template <typename argTy1, typename argTy2>
 sycl::event add_strided_impl(sycl::queue exec_q,
@@ -235,8 +235,7 @@ sycl::event add_strided_impl(sycl::queue exec_q,
                              const std::vector<sycl::event> &additional_depends)
 {
     return elementwise_common::binary_strided_impl<
-        argTy1, argTy2, AddOutputType, AddStridedFunctor,
-        add_strided_strided_kernel>(
+        argTy1, argTy2, AddOutputType, AddStridedFunctor, add_strided_kernel>(
         exec_q, nelems, nd, shape_and_strides, arg1_p, arg1_offset, arg2_p,
         arg2_offset, res_p, res_offset, depends, additional_depends);
 }
@@ -480,6 +479,60 @@ struct AddInplaceStridedFactory
     }
 };
 
+template <typename argT, typename resT>
+class add_inplace_row_matrix_broadcast_sg_krn;
+
+template <typename argT, typename resT>
+using AddInplaceRowMatrixBroadcastingFunctor =
+    elementwise_common::BinaryInplaceRowMatrixBroadcastingFunctor<
+        argT,
+        resT,
+        AddInplaceFunctor<argT, resT>>;
+
+template <typename argT, typename resT>
+sycl::event add_inplace_row_matrix_broadcast_impl(
+    sycl::queue exec_q,
+    std::vector<sycl::event> &host_tasks,
+    size_t n0,
+    size_t n1,
+    const char *vec_p, // typeless pointer to (n1,) contiguous row
+    py::ssize_t vec_offset,
+    char *mat_p, // typeless pointer to (n0, n1) C-contiguous matrix
+    py::ssize_t mat_offset,
+    const std::vector<sycl::event> &depends = {})
+{
+    return elementwise_common::binary_inplace_row_matrix_broadcast_impl<
+        argT, resT, AddInplaceRowMatrixBroadcastingFunctor,
+        add_inplace_row_matrix_broadcast_sg_krn>(exec_q, host_tasks, n0, n1,
+                                                 vec_p, vec_offset, mat_p,
+                                                 mat_offset, depends);
+}
+
+template <typename fnT, typename T1, typename T2>
+struct AddInplaceRowMatrixBroadcastFactory
+{
+    fnT get()
+    {
+        using resT = typename AddOutputType<T1, T2>::value_type;
+        if constexpr (!std::is_same_v<resT, T2>) {
+            fnT fn = nullptr;
+            return fn;
+        }
+        else {
+            if constexpr (dpctl::tensor::type_utils::is_complex<T1>::value ||
+                          dpctl::tensor::type_utils::is_complex<T2>::value)
+            {
+                fnT fn = nullptr;
+                return fn;
+            }
+            else {
+                fnT fn = add_inplace_row_matrix_broadcast_impl<T1, T2>;
+                return fn;
+            }
+        }
+    }
+};
+
 } // namespace add
 } // namespace kernels
 } // namespace tensor
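
For context, a minimal host-side sketch (not part of the diff) of how the new add_inplace_row_matrix_broadcast_impl entry point could be called directly for float data. The include path, queue setup, and sizes are assumptions; in dpctl this path is normally reached through the Python-level in-place operators rather than invoked by hand, and the factory above returns this function pointer only when no type casting is needed and the types are not complex (otherwise it returns nullptr, presumably so the caller falls back to a generic path).

// Hedged sketch: direct invocation of add_inplace_row_matrix_broadcast_impl
// for float data. Include path and sizes are assumptions, not part of the PR.
#include <sycl/sycl.hpp>
#include <vector>

#include "kernels/elementwise_functions/add.hpp" // assumed include path

namespace add_ns = dpctl::tensor::kernels::add;

void add_row_to_matrix_inplace(sycl::queue q)
{
    const size_t n0 = 1024; // rows of the C-contiguous matrix
    const size_t n1 = 256;  // columns, and length of the row vector

    float *mat = sycl::malloc_device<float>(n0 * n1, q);
    float *row = sycl::malloc_device<float>(n1, q);
    q.fill(mat, 1.0f, n0 * n1).wait();
    q.fill(row, 2.0f, n1).wait();

    std::vector<sycl::event> host_tasks{};
    sycl::event comp_ev =
        add_ns::add_inplace_row_matrix_broadcast_impl<float, float>(
            q, host_tasks, n0, n1,
            reinterpret_cast<const char *>(row), 0, // row pointer, offset
            reinterpret_cast<char *>(mat), 0,       // matrix pointer, offset
            {});                                    // no dependencies

    comp_ev.wait();                // matrix now holds 3.0f everywhere
    sycl::event::wait(host_tasks); // lets the temporary padded row be freed

    sycl::free(row, q);
    sycl::free(mat, q);
}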

dpctl/tensor/libtensor/include/kernels/elementwise_functions/common_inplace.hpp

+138

@@ -191,6 +191,60 @@ struct BinaryInplaceStridedFunctor
     }
 };
 
+template <typename argT, typename resT, typename BinaryOperatorT>
+struct BinaryInplaceRowMatrixBroadcastingFunctor
+{
+private:
+    const argT *padded_vec;
+    resT *mat;
+    size_t n_elems;
+    size_t n1;
+
+public:
+    BinaryInplaceRowMatrixBroadcastingFunctor(const argT *row_tp,
+                                              resT *mat_tp,
+                                              size_t n_elems_in_mat,
+                                              size_t n_elems_in_row)
+        : padded_vec(row_tp), mat(mat_tp), n_elems(n_elems_in_mat),
+          n1(n_elems_in_row)
+    {
+    }
+
+    void operator()(sycl::nd_item<1> ndit) const
+    {
+        BinaryOperatorT op{};
+        static_assert(BinaryOperatorT::supports_sg_loadstore::value);
+
+        auto sg = ndit.get_sub_group();
+        size_t gid = ndit.get_global_linear_id();
+
+        std::uint8_t sgSize = sg.get_local_range()[0];
+        size_t base = gid - sg.get_local_id()[0];
+
+        if (base + sgSize < n_elems) {
+            using in_ptrT =
+                sycl::multi_ptr<const argT,
+                                sycl::access::address_space::global_space>;
+            using res_ptrT =
+                sycl::multi_ptr<resT,
+                                sycl::access::address_space::global_space>;
+
+            const argT vec_el = sg.load(in_ptrT(&padded_vec[base % n1]));
+            resT mat_el = sg.load(res_ptrT(&mat[base]));
+
+            op(mat_el, vec_el);
+
+            sg.store(res_ptrT(&mat[base]), mat_el);
+        }
+        else {
+            for (size_t k = base + sg.get_local_id()[0]; k < n_elems;
+                 k += sgSize) {
+                op(mat[k], padded_vec[k % n1]);
+            }
+        }
+    }
+};
+
 // Typedefs for function pointers
 
 typedef sycl::event (*binary_inplace_contig_impl_fn_ptr_t)(
@@ -214,6 +268,17 @@ typedef sycl::event (*binary_inplace_strided_impl_fn_ptr_t)(
     const std::vector<sycl::event> &,
     const std::vector<sycl::event> &);
 
+typedef sycl::event (*binary_inplace_row_matrix_broadcast_impl_fn_ptr_t)(
+    sycl::queue,
+    std::vector<sycl::event> &,
+    size_t,
+    size_t,
+    const char *,
+    py::ssize_t,
+    char *,
+    py::ssize_t,
+    const std::vector<sycl::event> &);
+
 template <typename argTy,
           typename resTy,
           template <typename T1, typename T2, unsigned int vs, unsigned int nv>
@@ -289,6 +354,79 @@ binary_inplace_strided_impl(sycl::queue exec_q,
     return comp_ev;
 }
 
+template <typename argT,
+          typename resT,
+          template <typename T1, typename T3>
+          class BinaryInplaceRowMatrixBroadcastFunctorT,
+          template <typename T1, typename T3>
+          class kernel_name>
+sycl::event binary_inplace_row_matrix_broadcast_impl(
+    sycl::queue exec_q,
+    std::vector<sycl::event> &host_tasks,
+    size_t n0,
+    size_t n1,
+    const char *vec_p, // typeless pointer to (n1,) contiguous row
+    py::ssize_t vec_offset,
+    char *mat_p, // typeless pointer to (n0, n1) C-contiguous matrix
+    py::ssize_t mat_offset,
+    const std::vector<sycl::event> &depends = {})
+{
+    const argT *vec = reinterpret_cast<const argT *>(vec_p) + vec_offset;
+    resT *mat = reinterpret_cast<resT *>(mat_p) + mat_offset;
+
+    const auto &dev = exec_q.get_device();
+    const auto &sg_sizes = dev.get_info<sycl::info::device::sub_group_sizes>();
+    // Get device-specific kernel info max_sub_group_size
+    size_t max_sgSize =
+        *(std::max_element(std::begin(sg_sizes), std::end(sg_sizes)));
+
+    size_t n1_padded = n1 + max_sgSize;
+    argT *padded_vec = sycl::malloc_device<argT>(n1_padded, exec_q);
+
+    if (padded_vec == nullptr) {
+        throw std::runtime_error("Could not allocate memory on the device");
+    }
+    sycl::event make_padded_vec_ev = exec_q.submit([&](sycl::handler &cgh) {
+        cgh.depends_on(depends); // ensure vec contains actual data
+        cgh.parallel_for({n1_padded}, [=](sycl::id<1> id) {
+            auto i = id[0];
+            padded_vec[i] = vec[i % n1];
+        });
+    });
+
+    // sub-group spans work-items [I, I + sgSize)
+    // base = ndit.get_global_linear_id() - sg.get_local_id()[0]
+    // Generically, sg.load( &mat[base]) may load arrays from
+    // different rows of mat. The start corresponds to row (base / n0)
+    // We read sg.load(&padded_vec[(base / n0)]). The vector is padded to
+    // ensure that reads are accessible
+
+    size_t lws = 64;
+
+    sycl::event comp_ev = exec_q.submit([&](sycl::handler &cgh) {
+        cgh.depends_on(make_padded_vec_ev);
+
+        auto lwsRange = sycl::range<1>(lws);
+        size_t n_elems = n0 * n1;
+        size_t n_groups = (n_elems + lws - 1) / lws;
+        auto gwsRange = sycl::range<1>(n_groups * lws);
+
+        cgh.parallel_for<class kernel_name<argT, resT>>(
+            sycl::nd_range<1>(gwsRange, lwsRange),
+            BinaryInplaceRowMatrixBroadcastFunctorT<argT, resT>(padded_vec, mat,
+                                                                n_elems, n1));
+    });
+
+    sycl::event tmp_cleanup_ev = exec_q.submit([&](sycl::handler &cgh) {
+        cgh.depends_on(comp_ev);
+        sycl::context ctx = exec_q.get_context();
+        cgh.host_task([ctx, padded_vec]() { sycl::free(padded_vec, ctx); });
+    });
+    host_tasks.push_back(tmp_cleanup_ev);
+
+    return comp_ev;
+}
+
 } // namespace elementwise_common
 } // namespace kernels
 } // namespace tensor
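
The heart of the new path is the padded copy of the row: a sub-group whose first work-item sits at flattened index base reads the row starting at column base % n1, and padding the row with max_sgSize extra wrapped-around elements guarantees that the contiguous sub-group-wide read stays in bounds even when the sub-group straddles a row boundary. Below is a self-contained sketch of that trick, independent of dpctl; every name and size in it is illustrative, and it uses a plain per-lane indexed add where the functor above uses sg.load/sg.store on multi_ptr to issue a single contiguous access per operand.

// Standalone illustration of the padded-row broadcast trick (assumed names).
#include <sycl/sycl.hpp>

#include <algorithm>
#include <cstdio>

int main()
{
    sycl::queue q;
    constexpr size_t n0 = 8, n1 = 37; // n1 deliberately not a sub-group multiple
    const size_t n_elems = n0 * n1;

    float *mat = sycl::malloc_shared<float>(n_elems, q);
    float *row = sycl::malloc_shared<float>(n1, q);
    for (size_t i = 0; i < n_elems; ++i)
        mat[i] = 1.0f;
    for (size_t j = 0; j < n1; ++j)
        row[j] = static_cast<float>(j);

    const auto sg_sizes =
        q.get_device().get_info<sycl::info::device::sub_group_sizes>();
    const size_t max_sg = *std::max_element(sg_sizes.begin(), sg_sizes.end());

    // Pad the row so a full sub-group read starting at any column stays in
    // bounds; padded[i] == row[i % n1] by construction.
    float *padded = sycl::malloc_device<float>(n1 + max_sg, q);
    q.parallel_for(sycl::range<1>{n1 + max_sg}, [=](sycl::id<1> id) {
         size_t i = id[0];
         padded[i] = row[i % n1];
     }).wait();

    constexpr size_t lws = 64;
    const size_t n_groups = (n_elems + lws - 1) / lws;
    q.parallel_for(
         sycl::nd_range<1>{sycl::range<1>{n_groups * lws}, sycl::range<1>{lws}},
         [=](sycl::nd_item<1> it) {
             auto sg = it.get_sub_group();
             size_t lane = sg.get_local_id()[0];
             size_t base = it.get_global_linear_id() - lane;
             size_t k = base + lane;
             if (k < n_elems) {
                 // padded[(base % n1) + lane] == row[(base + lane) % n1],
                 // so no per-lane wrap handling is needed.
                 mat[k] += padded[(base % n1) + lane];
             }
         })
        .wait();

    // Row 1, column 3 started at 1.0f and received row[3] == 3.0f.
    std::printf("mat[n1 + 3] = %.1f (expected 4.0)\n", mat[n1 + 3]);

    sycl::free(padded, q);
    sycl::free(row, q);
    sycl::free(mat, q);
    return 0;
}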

dpctl/tensor/libtensor/include/kernels/elementwise_functions/equal.hpp

+4 −4

@@ -201,7 +201,7 @@ template <typename fnT, typename T1, typename T2> struct EqualTypeMapFactory
 };
 
 template <typename T1, typename T2, typename resT, typename IndexerT>
-class equal_strided_strided_kernel;
+class equal_strided_kernel;
 
 template <typename argTy1, typename argTy2>
 sycl::event
@@ -220,9 +220,9 @@ equal_strided_impl(sycl::queue exec_q,
 {
     return elementwise_common::binary_strided_impl<
         argTy1, argTy2, EqualOutputType, EqualStridedFunctor,
-        equal_strided_strided_kernel>(
-        exec_q, nelems, nd, shape_and_strides, arg1_p, arg1_offset, arg2_p,
-        arg2_offset, res_p, res_offset, depends, additional_depends);
+        equal_strided_kernel>(exec_q, nelems, nd, shape_and_strides, arg1_p,
+                              arg1_offset, arg2_p, arg2_offset, res_p,
+                              res_offset, depends, additional_depends);
 }
 
 template <typename fnT, typename T1, typename T2> struct EqualStridedFactory

dpctl/tensor/libtensor/include/kernels/elementwise_functions/floor_divide.hpp

+2 −2

@@ -235,7 +235,7 @@ struct FloorDivideTypeMapFactory
 };
 
 template <typename T1, typename T2, typename resT, typename IndexerT>
-class floor_divide_strided_strided_kernel;
+class floor_divide_strided_kernel;
 
 template <typename argTy1, typename argTy2>
 sycl::event
@@ -254,7 +254,7 @@ floor_divide_strided_impl(sycl::queue exec_q,
 {
     return elementwise_common::binary_strided_impl<
         argTy1, argTy2, FloorDivideOutputType, FloorDivideStridedFunctor,
-        floor_divide_strided_strided_kernel>(
+        floor_divide_strided_kernel>(
         exec_q, nelems, nd, shape_and_strides, arg1_p, arg1_offset, arg2_p,
         arg2_offset, res_p, res_offset, depends, additional_depends);
 }

dpctl/tensor/libtensor/include/kernels/elementwise_functions/greater.hpp

+2 −2

@@ -255,7 +255,7 @@ template <typename fnT, typename T1, typename T2> struct GreaterTypeMapFactory
 };
 
 template <typename T1, typename T2, typename resT, typename IndexerT>
-class greater_strided_strided_kernel;
+class greater_strided_kernel;
 
 template <typename argTy1, typename argTy2>
 sycl::event
@@ -289,7 +289,7 @@ greater_strided_impl(sycl::queue exec_q,
         resTy *res_tp = reinterpret_cast<resTy *>(res_p);
 
         cgh.parallel_for<
-            greater_strided_strided_kernel<argTy1, argTy2, resTy, IndexerT>>(
+            greater_strided_kernel<argTy1, argTy2, resTy, IndexerT>>(
             {nelems}, GreaterStridedFunctor<argTy1, argTy2, resTy, IndexerT>(
                           arg1_tp, arg2_tp, res_tp, indexer));
     });

dpctl/tensor/libtensor/include/kernels/elementwise_functions/greater_equal.hpp

+3 −3

@@ -261,7 +261,7 @@ struct GreaterEqualTypeMapFactory
 };
 
 template <typename T1, typename T2, typename resT, typename IndexerT>
-class greater_equal_strided_strided_kernel;
+class greater_equal_strided_kernel;
 
 template <typename argTy1, typename argTy2>
 sycl::event
@@ -295,8 +295,8 @@ greater_equal_strided_impl(sycl::queue exec_q,
         const argTy2 *arg2_tp = reinterpret_cast<const argTy2 *>(arg2_p);
         resTy *res_tp = reinterpret_cast<resTy *>(res_p);
 
-        cgh.parallel_for<greater_equal_strided_strided_kernel<argTy1, argTy2,
-                                                              resTy, IndexerT>>(
+        cgh.parallel_for<
+            greater_equal_strided_kernel<argTy1, argTy2, resTy, IndexerT>>(
             {nelems},
             GreaterEqualStridedFunctor<argTy1, argTy2, resTy, IndexerT>(
                 arg1_tp, arg2_tp, res_tp, indexer));

dpctl/tensor/libtensor/include/kernels/elementwise_functions/less.hpp

+2 −3

@@ -253,7 +253,7 @@ template <typename fnT, typename T1, typename T2> struct LessTypeMapFactory
 };
 
 template <typename T1, typename T2, typename resT, typename IndexerT>
-class less_strided_strided_kernel;
+class less_strided_kernel;
 
 template <typename argTy1, typename argTy2>
 sycl::event
@@ -286,8 +286,7 @@ less_strided_impl(sycl::queue exec_q,
         const argTy2 *arg2_tp = reinterpret_cast<const argTy2 *>(arg2_p);
         resTy *res_tp = reinterpret_cast<resTy *>(res_p);
 
-        cgh.parallel_for<
-            less_strided_strided_kernel<argTy1, argTy2, resTy, IndexerT>>(
+        cgh.parallel_for<less_strided_kernel<argTy1, argTy2, resTy, IndexerT>>(
             {nelems}, LessStridedFunctor<argTy1, argTy2, resTy, IndexerT>(
                           arg1_tp, arg2_tp, res_tp, indexer));
     });

dpctl/tensor/libtensor/include/kernels/elementwise_functions/less_equal.hpp

+2 −2

@@ -256,7 +256,7 @@ template <typename fnT, typename T1, typename T2> struct LessEqualTypeMapFactory
 };
 
 template <typename T1, typename T2, typename resT, typename IndexerT>
-class less_equal_strided_strided_kernel;
+class less_equal_strided_kernel;
 
 template <typename argTy1, typename argTy2>
 sycl::event
@@ -290,7 +290,7 @@ less_equal_strided_impl(sycl::queue exec_q,
         resTy *res_tp = reinterpret_cast<resTy *>(res_p);
 
         cgh.parallel_for<
-            less_equal_strided_strided_kernel<argTy1, argTy2, resTy, IndexerT>>(
+            less_equal_strided_kernel<argTy1, argTy2, resTy, IndexerT>>(
             {nelems}, LessEqualStridedFunctor<argTy1, argTy2, resTy, IndexerT>(
                           arg1_tp, arg2_tp, res_tp, indexer));
     });
