@@ -178,7 +178,8 @@ class GenericFilteringFlag
178
178
{
179
179
public:
180
180
181
- void apply (GenericFilteringBase& s, SmallSpan<const DataType> input, SmallSpan<OutputDataType> output, SmallSpan<const FlagType> flag)
181
+ void apply (GenericFilteringBase& s, SmallSpan<const DataType> input,
182
+ SmallSpan<OutputDataType> output, SmallSpan<const FlagType> flag)
182
183
{
183
184
const Int32 nb_item = input.size ();
184
185
if (output.size () != nb_item)
@@ -264,7 +265,13 @@ class GenericFilteringIf
264
265
{
265
266
public:
266
267
267
- template <typename SelectLambda, typename InputIterator, typename OutputIterator>
268
+ /* !
269
+ * \brief Applique le filtre.
270
+ *
271
+ * Si \a InPlace est vrai, alors OutputIterator vaut InputIterator et on
272
+ * met à jour directement \a input_iter.
273
+ */
274
+ template <bool InPlace, typename SelectLambda, typename InputIterator, typename OutputIterator>
268
275
void apply (GenericFilteringBase& s, Int32 nb_item, InputIterator input_iter, OutputIterator output_iter,
269
276
const SelectLambda& select_lambda, const TraceInfo& trace_info)
270
277
{
@@ -282,15 +289,26 @@ class GenericFilteringIf
282
289
cudaStream_t stream = impl::CudaUtils::toNativeStream (queue);
283
290
// Premier appel pour connaitre la taille pour l'allocation
284
291
int * nb_out_ptr = nullptr ;
285
- ARCANE_CHECK_CUDA (::cub::DeviceSelect::If (nullptr , temp_storage_size,
286
- input_iter, output_iter, nb_out_ptr, nb_item,
287
- select_lambda, stream));
292
+ if constexpr (InPlace)
293
+ ARCANE_CHECK_CUDA (::cub::DeviceSelect::If (nullptr , temp_storage_size,
294
+ input_iter, nb_out_ptr, nb_item,
295
+ select_lambda, stream));
296
+ else
297
+ ARCANE_CHECK_CUDA (::cub::DeviceSelect::If (nullptr , temp_storage_size,
298
+ input_iter, output_iter, nb_out_ptr, nb_item,
299
+ select_lambda, stream));
288
300
289
301
s._allocateTemporaryStorage (temp_storage_size);
290
302
nb_out_ptr = s._getDeviceNbOutPointer ();
291
- ARCANE_CHECK_CUDA (::cub::DeviceSelect::If (s.m_algo_storage .address (), temp_storage_size,
292
- input_iter, output_iter, nb_out_ptr, nb_item,
293
- select_lambda, stream));
303
+ if constexpr (InPlace)
304
+ ARCANE_CHECK_CUDA (::cub::DeviceSelect::If (s.m_algo_storage .address (), temp_storage_size,
305
+ input_iter, nb_out_ptr, nb_item,
306
+ select_lambda, stream));
307
+ else
308
+ ARCANE_CHECK_CUDA (::cub::DeviceSelect::If (s.m_algo_storage .address (), temp_storage_size,
309
+ input_iter, output_iter, nb_out_ptr, nb_item,
310
+ select_lambda, stream));
311
+
294
312
s._copyDeviceNbOutToHostNbOut ();
295
313
} break ;
296
314
#endif
@@ -300,9 +318,11 @@ class GenericFilteringIf
300
318
// Premier appel pour connaitre la taille pour l'allocation
301
319
hipStream_t stream = impl::HipUtils::toNativeStream (queue);
302
320
int * nb_out_ptr = nullptr ;
321
+ // NOTE: il n'y a pas de version spécifique de 'select' en-place.
322
+ // A priori il est possible que \a input_iter et \a output_iter
323
+ // aient la même valeur.
303
324
ARCANE_CHECK_HIP (rocprim::select (nullptr , temp_storage_size, input_iter, output_iter,
304
325
nb_out_ptr, nb_item, select_lambda, stream));
305
-
306
326
s._allocateTemporaryStorage (temp_storage_size);
307
327
nb_out_ptr = s._getDeviceNbOutPointer ();
308
328
ARCANE_CHECK_HIP (rocprim::select (s.m_algo_storage .address (), temp_storage_size, input_iter, output_iter,
@@ -318,7 +338,7 @@ class GenericFilteringIf
318
338
case eExecutionPolicy::Thread:
319
339
if (nb_item > 500 ) {
320
340
MultiThreadAlgo scanner;
321
- Int32 v = scanner.doFilter (launch_info.loopRunInfo (), nb_item, input_iter, output_iter, select_lambda);
341
+ Int32 v = scanner.doFilter <InPlace> (launch_info.loopRunInfo (), nb_item, input_iter, output_iter, select_lambda);
322
342
s.m_host_nb_out_storage [0 ] = v;
323
343
break ;
324
344
}
@@ -433,6 +453,7 @@ class GenericFilterer
433
453
* Filtre tous les éléments de \a input pour lesquels \a select_lambda vaut \a true et
434
454
* remplit \a output avec les valeurs filtrées. \a output doit avoir une taille assez
435
455
* grande pour contenir tous les éléments filtrés.
456
+ * Les zones mémoire associées à \a input et \a output ne doivent pas se chevaucher.
436
457
*
437
458
* \a select_lambda doit avoir un opérateur `ARCCORE_HOST_DEVICE bool operator()(const DataType& v) const`.
438
459
*
@@ -468,13 +489,36 @@ class GenericFilterer
468
489
const Int32 nb_value = input.size ();
469
490
if (output.size () != nb_value)
470
491
ARCANE_FATAL (" Sizes are not equals: input={0} output={1}" , nb_value, output.size ());
492
+ if (input.data () == output.data ())
493
+ ARCANE_FATAL (" Input and Output are the same. Use in place overload instead" );
494
+ _setCalled ();
495
+ if (_checkEmpty (nb_value))
496
+ return ;
497
+ impl::GenericFilteringBase* base_ptr = this ;
498
+ impl::GenericFilteringIf gf;
499
+ gf.apply <false >(*base_ptr, nb_value, input.data (), output.data (), select_lambda, trace_info);
500
+ }
471
501
502
+ /* !
503
+ * \brief Applique un filtre en place.
504
+ *
505
+ * Cette méthode est identique à applyIf(SmallSpan<const DataType>, SmallSpan<DataType>,
506
+ * const SelectLambda&, const TraceInfo& trace_info) mais les valeurs filtrées sont
507
+ * directement recopié dans le tableau \a input_output.
508
+ */
509
+ template <typename DataType, typename SelectLambda>
510
+ void applyIf (SmallSpan<DataType> input_output, const SelectLambda& select_lambda,
511
+ const TraceInfo& trace_info = TraceInfo())
512
+ {
513
+ const Int32 nb_value = input_output.size ();
514
+ if (nb_value <= 0 )
515
+ return ;
472
516
_setCalled ();
473
517
if (_checkEmpty (nb_value))
474
518
return ;
475
519
impl::GenericFilteringBase* base_ptr = this ;
476
520
impl::GenericFilteringIf gf;
477
- gf.apply (*base_ptr, nb_value, input .data (), output .data (), select_lambda, trace_info);
521
+ gf.apply < true > (*base_ptr, nb_value, input_output .data (), input_output .data (), select_lambda, trace_info);
478
522
}
479
523
480
524
/* !
@@ -484,6 +528,8 @@ class GenericFilterer
484
528
* SmallSpan<DataType> output, const SelectLambda& select_lambda) mais permet de spécifier un
485
529
* itérateur \a input_iter pour l'entrée et \a output_iter pour la sortie.
486
530
* Le nombre d'entité en entrée est donné par \a nb_value.
531
+ *
532
+ * Les zones mémoire associées à \a input_iter et \a output_iter ne doivent pas se chevaucher.
487
533
*/
488
534
template <typename InputIterator, typename OutputIterator, typename SelectLambda>
489
535
void applyIf (Int32 nb_value, InputIterator input_iter, OutputIterator output_iter,
@@ -494,7 +540,7 @@ class GenericFilterer
494
540
return ;
495
541
impl::GenericFilteringBase* base_ptr = this ;
496
542
impl::GenericFilteringIf gf;
497
- gf.apply (*base_ptr, nb_value, input_iter, output_iter, select_lambda, trace_info);
543
+ gf.apply < false > (*base_ptr, nb_value, input_iter, output_iter, select_lambda, trace_info);
498
544
}
499
545
500
546
/* !
@@ -527,6 +573,8 @@ class GenericFilterer
527
573
* filterer.applyWithIndex(input.size(), select_lambda, setter_lambda);
528
574
* Int32 nb_out = filterer.nbOutputElement();
529
575
* \endcode
576
+ *
577
+ * Les zones mémoire associées aux valeurs d'entrée et de sortie ne doivent pas se chevaucher.
530
578
*/
531
579
template <typename SelectLambda, typename SetterLambda>
532
580
void applyWithIndex (Int32 nb_value, const SelectLambda& select_lambda,
@@ -539,7 +587,7 @@ class GenericFilterer
539
587
impl::GenericFilteringIf gf;
540
588
impl::IndexIterator input_iter;
541
589
impl::SetterLambdaIterator<SetterLambda> out (setter_lambda);
542
- gf.apply (*base_ptr, nb_value, input_iter, out, select_lambda, trace_info);
590
+ gf.apply < false > (*base_ptr, nb_value, input_iter, out, select_lambda, trace_info);
543
591
}
544
592
545
593
// ! Nombre d'éléments en sortie.
0 commit comments