@@ -191,6 +191,60 @@ struct BinaryInplaceStridedFunctor
     }
 };
 
+template <typename argT, typename resT, typename BinaryOperatorT>
+struct BinaryInplaceRowMatrixBroadcastingFunctor
+{
+private:
+    const argT *padded_vec;
+    resT *mat;
+    size_t n_elems;
+    size_t n1;
+
+public:
+    BinaryInplaceRowMatrixBroadcastingFunctor(const argT *row_tp,
+                                              resT *mat_tp,
+                                              size_t n_elems_in_mat,
+                                              size_t n_elems_in_row)
+        : padded_vec(row_tp), mat(mat_tp), n_elems(n_elems_in_mat),
+          n1(n_elems_in_row)
+    {
+    }
+
+    void operator()(sycl::nd_item<1> ndit) const
+    {
+        BinaryOperatorT op{};
+        static_assert(BinaryOperatorT::supports_sg_loadstore::value);
+
+        auto sg = ndit.get_sub_group();
+        size_t gid = ndit.get_global_linear_id();
+
+        std::uint8_t sgSize = sg.get_local_range()[0];
+        size_t base = gid - sg.get_local_id()[0];
+
+        if (base + sgSize < n_elems) {
+            using in_ptrT =
+                sycl::multi_ptr<const argT,
+                                sycl::access::address_space::global_space>;
+            using res_ptrT =
+                sycl::multi_ptr<resT,
+                                sycl::access::address_space::global_space>;
+
+            const argT vec_el = sg.load(in_ptrT(&padded_vec[base % n1]));
+            resT mat_el = sg.load(res_ptrT(&mat[base]));
+
+            op(mat_el, vec_el);
+
+            sg.store(res_ptrT(&mat[base]), mat_el);
+        }
+        else {
+            for (size_t k = base + sg.get_local_id()[0]; k < n_elems;
+                 k += sgSize) {
+                op(mat[k], padded_vec[k % n1]);
+            }
+        }
+    }
+};
+
 // Typedefs for function pointers
 
 typedef sycl::event (*binary_inplace_contig_impl_fn_ptr_t)(
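For context on the hunk above: the new functor only assumes that BinaryOperatorT is default-constructible, advertises sub-group load/store support through a supports_sg_loadstore member type, and updates its first argument in place; the static_assert in operator() rejects operator types that do not opt in. A minimal sketch of an operator meeting that contract is shown below (the name AddInplaceOpSketch is hypothetical and for illustration only; the concrete operators are defined in the per-operation headers).

// Illustrative sketch, not part of this change.
#include <type_traits>

template <typename argT, typename resT> struct AddInplaceOpSketch
{
    // opts into the sg.load/sg.store fast path used by
    // BinaryInplaceRowMatrixBroadcastingFunctor
    using supports_sg_loadstore = std::true_type;

    // in-place update of the matrix element with the broadcast row element
    void operator()(resT &res, const argT &in) const { res += in; }
};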
@@ -214,6 +268,17 @@ typedef sycl::event (*binary_inplace_strided_impl_fn_ptr_t)(
     const std::vector<sycl::event> &,
     const std::vector<sycl::event> &);
 
+typedef sycl::event (*binary_inplace_row_matrix_broadcast_impl_fn_ptr_t)(
+    sycl::queue,
+    std::vector<sycl::event> &,
+    size_t,
+    size_t,
+    const char *,
+    py::ssize_t,
+    char *,
+    py::ssize_t,
+    const std::vector<sycl::event> &);
+
 template <typename argTy,
           typename resTy,
           template <typename T1, typename T2, unsigned int vs, unsigned int nv>
@@ -289,6 +354,79 @@ binary_inplace_strided_impl(sycl::queue exec_q,
     return comp_ev;
 }
 
+template <typename argT,
+          typename resT,
+          template <typename T1, typename T3>
+          class BinaryInplaceRowMatrixBroadcastFunctorT,
+          template <typename T1, typename T3>
+          class kernel_name>
+sycl::event binary_inplace_row_matrix_broadcast_impl(
+    sycl::queue exec_q,
+    std::vector<sycl::event> &host_tasks,
+    size_t n0,
+    size_t n1,
+    const char *vec_p, // typeless pointer to (n1,) contiguous row
+    py::ssize_t vec_offset,
+    char *mat_p, // typeless pointer to (n0, n1) C-contiguous matrix
+    py::ssize_t mat_offset,
+    const std::vector<sycl::event> &depends = {})
+{
+    const argT *vec = reinterpret_cast<const argT *>(vec_p) + vec_offset;
+    resT *mat = reinterpret_cast<resT *>(mat_p) + mat_offset;
+
+    const auto &dev = exec_q.get_device();
+    const auto &sg_sizes = dev.get_info<sycl::info::device::sub_group_sizes>();
+    // Get device-specific kernel info max_sub_group_size
+    size_t max_sgSize =
+        *(std::max_element(std::begin(sg_sizes), std::end(sg_sizes)));
+
+    size_t n1_padded = n1 + max_sgSize;
+    argT *padded_vec = sycl::malloc_device<argT>(n1_padded, exec_q);
+
+    if (padded_vec == nullptr) {
+        throw std::runtime_error("Could not allocate memory on the device");
+    }
+    sycl::event make_padded_vec_ev = exec_q.submit([&](sycl::handler &cgh) {
+        cgh.depends_on(depends); // ensure vec contains actual data
+        cgh.parallel_for({n1_padded}, [=](sycl::id<1> id) {
+            auto i = id[0];
+            padded_vec[i] = vec[i % n1];
+        });
+    });
+
+    // sub-group spans work-items [I, I + sgSize)
+    // base = ndit.get_global_linear_id() - sg.get_local_id()[0]
+    // Generically, sg.load(&mat[base]) may load elements spanning
+    // different rows of mat. The span starts in row (base / n1) at
+    // column (base % n1), so we read sg.load(&padded_vec[base % n1]).
+    // The vector is padded by max_sgSize elements to keep such reads in bounds.
+
+    size_t lws = 64;
+
+    sycl::event comp_ev = exec_q.submit([&](sycl::handler &cgh) {
+        cgh.depends_on(make_padded_vec_ev);
+
+        auto lwsRange = sycl::range<1>(lws);
+        size_t n_elems = n0 * n1;
+        size_t n_groups = (n_elems + lws - 1) / lws;
+        auto gwsRange = sycl::range<1>(n_groups * lws);
+
+        cgh.parallel_for<class kernel_name<argT, resT>>(
+            sycl::nd_range<1>(gwsRange, lwsRange),
+            BinaryInplaceRowMatrixBroadcastFunctorT<argT, resT>(padded_vec, mat,
+                                                                n_elems, n1));
+    });
+
+    sycl::event tmp_cleanup_ev = exec_q.submit([&](sycl::handler &cgh) {
+        cgh.depends_on(comp_ev);
+        sycl::context ctx = exec_q.get_context();
+        cgh.host_task([ctx, padded_vec]() { sycl::free(padded_vec, ctx); });
+    });
+    host_tasks.push_back(tmp_cleanup_ev);
+
+    return comp_ev;
+}
+
 } // namespace elementwise_common
 } // namespace kernels
 } // namespace tensor
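As a usage sketch (hypothetical wiring, assuming the AddInplaceOpSketch operator from the earlier note; namespace qualification and the per-operation dispatch tables are omitted, and the real per-operation headers may differ), the new entry point would be instantiated with a concrete functor alias and a kernel-name template, and the resulting function could then be registered through the binary_inplace_row_matrix_broadcast_impl_fn_ptr_t typedef added above.

// Illustrative only; the names below are placeholders, not part of this change.
template <typename argT, typename resT>
using AddInplaceRowMatrixBroadcastingFunctorSketch =
    BinaryInplaceRowMatrixBroadcastingFunctor<argT, resT,
                                              AddInplaceOpSketch<argT, resT>>;

// Incomplete class template used only to name the SYCL kernel.
template <typename argT, typename resT>
class add_inplace_row_matrix_broadcast_sketch_krn;

template <typename argT, typename resT>
sycl::event add_inplace_row_matrix_broadcast_sketch(
    sycl::queue exec_q,
    std::vector<sycl::event> &host_tasks,
    size_t n0,
    size_t n1,
    const char *vec_p,
    py::ssize_t vec_offset,
    char *mat_p,
    py::ssize_t mat_offset,
    const std::vector<sycl::event> &depends = {})
{
    // Dispatch to the generic row/matrix broadcast implementation added above.
    return binary_inplace_row_matrix_broadcast_impl<
        argT, resT, AddInplaceRowMatrixBroadcastingFunctorSketch,
        add_inplace_row_matrix_broadcast_sketch_krn>(
        exec_q, host_tasks, n0, n1, vec_p, vec_offset, mat_p, mat_offset,
        depends);
}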