Skip to content

Commit d5c2512

Browse files
Merge pull request #1882 from arcaneframework/dev/gg-add-in-place-filtering
Add specific implementation of 'GenericFiltering' when input and output are the same
2 parents 4dfafd6 + 993dbe0 commit d5c2512

File tree

3 files changed

+67
-21
lines changed

3 files changed

+67
-21
lines changed

arcane/src/arcane/accelerator/GenericFilterer.h

+61-13
Original file line numberDiff line numberDiff line change
@@ -178,7 +178,8 @@ class GenericFilteringFlag
178178
{
179179
public:
180180

181-
void apply(GenericFilteringBase& s, SmallSpan<const DataType> input, SmallSpan<OutputDataType> output, SmallSpan<const FlagType> flag)
181+
void apply(GenericFilteringBase& s, SmallSpan<const DataType> input,
182+
SmallSpan<OutputDataType> output, SmallSpan<const FlagType> flag)
182183
{
183184
const Int32 nb_item = input.size();
184185
if (output.size() != nb_item)
@@ -264,7 +265,13 @@ class GenericFilteringIf
264265
{
265266
public:
266267

267-
template <typename SelectLambda, typename InputIterator, typename OutputIterator>
268+
/*!
269+
* \brief Applique le filtre.
270+
*
271+
* Si \a InPlace est vrai, alors OutputIterator vaut InputIterator et on
272+
* met à jour directement \a input_iter.
273+
*/
274+
template <bool InPlace, typename SelectLambda, typename InputIterator, typename OutputIterator>
268275
void apply(GenericFilteringBase& s, Int32 nb_item, InputIterator input_iter, OutputIterator output_iter,
269276
const SelectLambda& select_lambda, const TraceInfo& trace_info)
270277
{
@@ -282,15 +289,26 @@ class GenericFilteringIf
282289
cudaStream_t stream = impl::CudaUtils::toNativeStream(queue);
283290
// Premier appel pour connaitre la taille pour l'allocation
284291
int* nb_out_ptr = nullptr;
285-
ARCANE_CHECK_CUDA(::cub::DeviceSelect::If(nullptr, temp_storage_size,
286-
input_iter, output_iter, nb_out_ptr, nb_item,
287-
select_lambda, stream));
292+
if constexpr (InPlace)
293+
ARCANE_CHECK_CUDA(::cub::DeviceSelect::If(nullptr, temp_storage_size,
294+
input_iter, nb_out_ptr, nb_item,
295+
select_lambda, stream));
296+
else
297+
ARCANE_CHECK_CUDA(::cub::DeviceSelect::If(nullptr, temp_storage_size,
298+
input_iter, output_iter, nb_out_ptr, nb_item,
299+
select_lambda, stream));
288300

289301
s._allocateTemporaryStorage(temp_storage_size);
290302
nb_out_ptr = s._getDeviceNbOutPointer();
291-
ARCANE_CHECK_CUDA(::cub::DeviceSelect::If(s.m_algo_storage.address(), temp_storage_size,
292-
input_iter, output_iter, nb_out_ptr, nb_item,
293-
select_lambda, stream));
303+
if constexpr (InPlace)
304+
ARCANE_CHECK_CUDA(::cub::DeviceSelect::If(s.m_algo_storage.address(), temp_storage_size,
305+
input_iter, nb_out_ptr, nb_item,
306+
select_lambda, stream));
307+
else
308+
ARCANE_CHECK_CUDA(::cub::DeviceSelect::If(s.m_algo_storage.address(), temp_storage_size,
309+
input_iter, output_iter, nb_out_ptr, nb_item,
310+
select_lambda, stream));
311+
294312
s._copyDeviceNbOutToHostNbOut();
295313
} break;
296314
#endif
@@ -300,9 +318,11 @@ class GenericFilteringIf
300318
// Premier appel pour connaitre la taille pour l'allocation
301319
hipStream_t stream = impl::HipUtils::toNativeStream(queue);
302320
int* nb_out_ptr = nullptr;
321+
// NOTE: il n'y a pas de version spécifique de 'select' en-place.
322+
// A priori il est possible que \a input_iter et \a output_iter
323+
// aient la même valeur.
303324
ARCANE_CHECK_HIP(rocprim::select(nullptr, temp_storage_size, input_iter, output_iter,
304325
nb_out_ptr, nb_item, select_lambda, stream));
305-
306326
s._allocateTemporaryStorage(temp_storage_size);
307327
nb_out_ptr = s._getDeviceNbOutPointer();
308328
ARCANE_CHECK_HIP(rocprim::select(s.m_algo_storage.address(), temp_storage_size, input_iter, output_iter,
@@ -318,7 +338,7 @@ class GenericFilteringIf
318338
case eExecutionPolicy::Thread:
319339
if (nb_item > 500) {
320340
MultiThreadAlgo scanner;
321-
Int32 v = scanner.doFilter(launch_info.loopRunInfo(), nb_item, input_iter, output_iter, select_lambda);
341+
Int32 v = scanner.doFilter<InPlace>(launch_info.loopRunInfo(), nb_item, input_iter, output_iter, select_lambda);
322342
s.m_host_nb_out_storage[0] = v;
323343
break;
324344
}
@@ -433,6 +453,7 @@ class GenericFilterer
433453
* Filtre tous les éléments de \a input pour lesquels \a select_lambda vaut \a true et
434454
* remplit \a output avec les valeurs filtrées. \a output doit avoir une taille assez
435455
* grande pour contenir tous les éléments filtrés.
456+
* Les zones mémoire associées à \a input et \a output ne doivent pas se chevaucher.
436457
*
437458
* \a select_lambda doit avoir un opérateur `ARCCORE_HOST_DEVICE bool operator()(const DataType& v) const`.
438459
*
@@ -468,13 +489,36 @@ class GenericFilterer
468489
const Int32 nb_value = input.size();
469490
if (output.size() != nb_value)
470491
ARCANE_FATAL("Sizes are not equals: input={0} output={1}", nb_value, output.size());
492+
if (input.data() == output.data())
493+
ARCANE_FATAL("Input and Output are the same. Use in place overload instead");
494+
_setCalled();
495+
if (_checkEmpty(nb_value))
496+
return;
497+
impl::GenericFilteringBase* base_ptr = this;
498+
impl::GenericFilteringIf gf;
499+
gf.apply<false>(*base_ptr, nb_value, input.data(), output.data(), select_lambda, trace_info);
500+
}
471501

502+
/*!
503+
* \brief Applique un filtre en place.
504+
*
505+
* Cette méthode est identique à applyIf(SmallSpan<const DataType>, SmallSpan<DataType>,
506+
* const SelectLambda&, const TraceInfo& trace_info) mais les valeurs filtrées sont
507+
* directement recopié dans le tableau \a input_output.
508+
*/
509+
template <typename DataType, typename SelectLambda>
510+
void applyIf(SmallSpan<DataType> input_output, const SelectLambda& select_lambda,
511+
const TraceInfo& trace_info = TraceInfo())
512+
{
513+
const Int32 nb_value = input_output.size();
514+
if (nb_value <= 0)
515+
return;
472516
_setCalled();
473517
if (_checkEmpty(nb_value))
474518
return;
475519
impl::GenericFilteringBase* base_ptr = this;
476520
impl::GenericFilteringIf gf;
477-
gf.apply(*base_ptr, nb_value, input.data(), output.data(), select_lambda, trace_info);
521+
gf.apply<true>(*base_ptr, nb_value, input_output.data(), input_output.data(), select_lambda, trace_info);
478522
}
479523

480524
/*!
@@ -484,6 +528,8 @@ class GenericFilterer
484528
* SmallSpan<DataType> output, const SelectLambda& select_lambda) mais permet de spécifier un
485529
* itérateur \a input_iter pour l'entrée et \a output_iter pour la sortie.
486530
* Le nombre d'entité en entrée est donné par \a nb_value.
531+
*
532+
* Les zones mémoire associées à \a input_iter et \a output_iter ne doivent pas se chevaucher.
487533
*/
488534
template <typename InputIterator, typename OutputIterator, typename SelectLambda>
489535
void applyIf(Int32 nb_value, InputIterator input_iter, OutputIterator output_iter,
@@ -494,7 +540,7 @@ class GenericFilterer
494540
return;
495541
impl::GenericFilteringBase* base_ptr = this;
496542
impl::GenericFilteringIf gf;
497-
gf.apply(*base_ptr, nb_value, input_iter, output_iter, select_lambda, trace_info);
543+
gf.apply<false>(*base_ptr, nb_value, input_iter, output_iter, select_lambda, trace_info);
498544
}
499545

500546
/*!
@@ -527,6 +573,8 @@ class GenericFilterer
527573
* filterer.applyWithIndex(input.size(), select_lambda, setter_lambda);
528574
* Int32 nb_out = filterer.nbOutputElement();
529575
* \endcode
576+
*
577+
* Les zones mémoire associées aux valeurs d'entrée et de sortie ne doivent pas se chevaucher.
530578
*/
531579
template <typename SelectLambda, typename SetterLambda>
532580
void applyWithIndex(Int32 nb_value, const SelectLambda& select_lambda,
@@ -539,7 +587,7 @@ class GenericFilterer
539587
impl::GenericFilteringIf gf;
540588
impl::IndexIterator input_iter;
541589
impl::SetterLambdaIterator<SetterLambda> out(setter_lambda);
542-
gf.apply(*base_ptr, nb_value, input_iter, out, select_lambda, trace_info);
590+
gf.apply<false>(*base_ptr, nb_value, input_iter, out, select_lambda, trace_info);
543591
}
544592

545593
//! Nombre d'éléments en sortie.

arcane/src/arcane/accelerator/MultiThreadAlgo.h

+4-5
Original file line numberDiff line numberDiff line change
@@ -130,7 +130,7 @@ class MultiThreadAlgo
130130
Arcane::arcaneParallelFor(0, nb_block, run_info, final_sum_func);
131131
}
132132

133-
template <typename InputIterator, typename OutputIterator, typename SelectLambda>
133+
template <bool InPlace, typename InputIterator, typename OutputIterator, typename SelectLambda>
134134
Int32 doFilter(ForLoopRunInfo run_info, Int32 nb_value,
135135
InputIterator input, OutputIterator output,
136136
SelectLambda select_lambda)
@@ -215,10 +215,9 @@ class MultiThreadAlgo
215215
}
216216
};
217217

218-
// Pour l'instant il est possible que l'entrée et la sortie
219-
// se chevauchent. Dans ce cas on fait le remplissage en séquentiel.
220-
const bool may_input_and_output_overlap = true;
221-
if (may_input_and_output_overlap)
218+
// Si l'entrée et la sortie sont les mêmes, on fait le remplissage en séquentiel.
219+
// TODO: faire en parallèle.
220+
if (InPlace)
222221
filter_func(0, nb_block);
223222
else
224223
Arcane::arcaneParallelFor(0, nb_block, run_info, filter_func);

arcane/src/arcane/materials/IncrementalComponentModifier_Accelerator.cc

+2-3
Original file line numberDiff line numberDiff line change
@@ -362,14 +362,13 @@ _removeItemsInGroup(ItemGroup cells, SmallSpan<const Int32> removed_ids)
362362

363363
// Lors de l'application du filtre, le tableau d'entrée et de sortie
364364
// est le même (c'est normalement supporté par le GenericFilterer).
365-
SmallSpan<const Int32> input_ids(items_local_id);
366-
SmallSpan<Int32> output_ids_view(items_local_id);
365+
SmallSpan<Int32> input_ids(items_local_id);
367366
SmallSpan<const bool> filtered_cells(m_work_info.removedCells());
368367
Accelerator::GenericFilterer filterer(m_queue);
369368
auto select_filter = [=] ARCCORE_HOST_DEVICE(Int32 local_id) -> bool {
370369
return !filtered_cells[local_id];
371370
};
372-
filterer.applyIf(input_ids, output_ids_view, select_filter, A_FUNCINFO);
371+
filterer.applyIf(input_ids, select_filter, A_FUNCINFO);
373372

374373
Int32 current_nb_item = items_local_id.size();
375374
Int32 nb_remaining = filterer.nbOutputElement();

0 commit comments

Comments
 (0)