-
Notifications
You must be signed in to change notification settings - Fork 17
/
Copy pathfor_each.hpp
116 lines (97 loc) · 2.94 KB
/
for_each.hpp
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
// SPDX-FileCopyrightText: Intel Corporation
//
// SPDX-License-Identifier: BSD-3-Clause
#pragma once
#include <algorithm>
#include <execution>
#include <type_traits>
#include <utility>
#include <dr/concepts/concepts.hpp>
#include <dr/detail/logger.hpp>
#include <dr/detail/onedpl_direct_iterator.hpp>
#include <dr/detail/ranges_shim.hpp>
#include <dr/detail/sycl_utils.hpp>
#include <dr/mp/global.hpp>
namespace dr::mp {
// NOTE(review): this concept does not yet work as intended; the root cause
// has not been identified. TODO investigate before relying on it.
//
// A dual-vector range is a distributed range whose segment list can be
// indexed and whose segments expose an is_compute() member. The check is
// unevaluated: it only requires that the expression be well-formed.
template <typename R>
concept dual_vector_range =
    dr::distributed_range<R> &&
    requires(R &candidate) {
      dr::ranges::segments(candidate)[0].is_compute();
    };
void for_each(dual_vector_range auto &&dr, auto op) {
partial_for_each(dr, op);
partial_for_each(dr, op);
}
/// Collective single pass over a dual-vector range.
///
/// For every segment owned by this rank (rank(segment) == default_comm().rank()):
/// if the segment reports is_compute(), `op` is applied to each of its local
/// elements, using a SYCL parallel_for when mp::use_sycl() is true and
/// rng::for_each otherwise. swap_state() is then called on the segment
/// whether or not it was processed. The pass ends with barrier(). An empty
/// range returns immediately, without a barrier.
void partial_for_each(dual_vector_range auto &&dr, auto op) {
  dr::drlog.debug(dr::logger::for_each, "partial_for_each: parallel execution\n");
  if (rng::empty(dr)) {
    return;
  }

  auto owned_here = [](const auto &candidate) {
    return dr::ranges::rank(candidate) == default_comm().rank();
  };

  for (auto &segment : dr::ranges::segments(dr) | rng::views::filter(owned_here)) {
    if (segment.is_compute()) {
      // View the segment's elements through rank-local iterators.
      auto local_first = dr::ranges::local(rng::begin(segment));
      auto local_range =
          rng::subrange(local_first, local_first + rng::distance(segment));

      if (!mp::use_sycl()) {
        dr::drlog.debug(" using cpu\n");
        rng::for_each(local_range, op);
      } else {
        dr::drlog.debug(" using sycl\n");
        assert(rng::distance(local_range) > 0);
#ifdef SYCL_LANGUAGE_VERSION
        dr::__detail::parallel_for(
            dr::mp::sycl_queue(), sycl::range<1>(rng::distance(local_range)),
            [base = rng::begin(local_range), op](auto i) { op(base[i]); })
            .wait();
#else
        // SYCL requested but this translation unit was not compiled as SYCL.
        assert(false);
#endif
      }
    }
    // Every local segment flips state after this pass, processed or not.
    segment.swap_state();
  }
  barrier();
}
/// Collective for_each on a distributed range.
///
/// Applies `op` to every element of this rank's local segments, then
/// synchronizes all ranks with barrier(). Each segment is processed with a
/// SYCL parallel_for when mp::use_sycl() is true, otherwise with
/// rng::for_each. Asserts that `dr` is aligned. An empty range returns
/// immediately, without a barrier.
void for_each(dr::distributed_range auto &&dr, auto op) {
  dr::drlog.debug(dr::logger::for_each, "for_each: parallel execution\n");
  if (rng::empty(dr)) {
    return;
  }
  assert(aligned(dr));

  for (const auto &local_segment : local_segments(dr)) {
    if (!mp::use_sycl()) {
      dr::drlog.debug(" using cpu\n");
      rng::for_each(local_segment, op);
      continue;
    }

    dr::drlog.debug(" using sycl\n");
    assert(rng::distance(local_segment) > 0);
#ifdef SYCL_LANGUAGE_VERSION
    dr::__detail::parallel_for(
        dr::mp::sycl_queue(), sycl::range<1>(rng::distance(local_segment)),
        [base = rng::begin(local_segment), op](auto i) { op(base[i]); })
        .wait();
#else
    // SYCL requested but this translation unit was not compiled as SYCL.
    assert(false);
#endif
  }
  barrier();
}
/// Collective for_each over the iterator pair [first, last) of a distributed
/// range. The pair is wrapped in a subrange and forwarded to the
/// range-based mp::for_each overload.
template <dr::distributed_iterator DI>
void for_each(DI first, DI last, auto op) {
  return mp::for_each(rng::subrange(first, last), op);
}
/// Collective for_each over the `n` elements of a distributed range that
/// start at `first`.
/// @param first  distributed iterator to the first element to visit
/// @param n      number of elements to visit
/// @param op     callable applied to each element
/// @return an iterator `n` positions past `first`
template <dr::distributed_iterator DI, std::integral I>
DI for_each_n(DI first, I n, auto op) {
  DI stop = rng::next(first, n);
  mp::for_each(first, stop, op);
  return stop;
}
} // namespace dr::mp