@@ -8,6 +8,9 @@
 #include <graphbolt/cuda_ops.h>
 #include <graphbolt/fused_csc_sampling_graph.h>
 
+#include <cstring>
+#include <numeric>
+
 #include "./macro.h"
 #include "./utils.h"
 
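(The new <cstring> and <numeric> includes are for the std::memcpy and std::iota calls that the CPU path introduced below relies on.)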
@@ -107,9 +110,9 @@ std::tuple<torch::Tensor, torch::Tensor> IndexSelectCSC(
         c10::DeviceType::CUDA, "IndexSelectCSCImpl",
         { return IndexSelectCSCImpl(indptr, indices, nodes, output_size); });
   }
-  sampling::FusedCSCSamplingGraph g(indptr, indices);
-  const auto res = g.InSubgraph(nodes);
-  return std::make_tuple(res->indptr, res->indices.value());
+  auto [output_indptr, results] = IndexSelectCSCBatched(
+      indptr, std::vector{indices}, nodes, false, output_size);
+  return std::make_tuple(output_indptr, results.at(0));
 }
 
 std::tuple<torch::Tensor, std::vector<torch::Tensor>> IndexSelectCSCBatched(
@@ -129,17 +132,78 @@ std::tuple<torch::Tensor, std::vector<torch::Tensor>> IndexSelectCSCBatched(
               indptr, indices_list, nodes, with_edge_ids, output_size);
         });
   }
+  constexpr int kDefaultGrainSize = 128;
+  const auto num_nodes = nodes.size(0);
+  torch::Tensor output_indptr = torch::empty(
+      {num_nodes + 1}, nodes.options().dtype(indptr.scalar_type()));
   std::vector<torch::Tensor> results;
-  torch::Tensor output_indptr;
-  torch::Tensor edge_ids;
-  for (auto& indices : indices_list) {
-    sampling::FusedCSCSamplingGraph g(indptr, indices);
-    const auto res = g.InSubgraph(nodes);
-    output_indptr = res->indptr;
-    results.push_back(res->indices.value());
-    edge_ids = res->original_edge_ids;
-  }
-  if (with_edge_ids) results.push_back(edge_ids);
+  torch::optional<torch::Tensor> edge_ids;
+  AT_DISPATCH_INDEX_TYPES(
+      indptr.scalar_type(), "IndexSelectCSCBatched::indptr", ([&] {
+        using indptr_t = index_t;
+        const auto indptr_data = indptr.data_ptr<indptr_t>();
+        auto out_indptr_data = output_indptr.data_ptr<indptr_t>();
+        out_indptr_data[0] = 0;
+        AT_DISPATCH_INDEX_TYPES(
+            nodes.scalar_type(), "IndexSelectCSCBatched::nodes", ([&] {
+              const auto nodes_data = nodes.data_ptr<index_t>();
+              torch::parallel_for(
+                  0, num_nodes, kDefaultGrainSize,
+                  [&](int64_t begin, int64_t end) {
+                    for (int64_t i = begin; i < end; i++) {
+                      const auto node_id = nodes_data[i];
+                      const auto degree =
+                          indptr_data[node_id + 1] - indptr_data[node_id];
+                      out_indptr_data[i + 1] = degree;
+                    }
+                  });
+              output_indptr = output_indptr.cumsum(0, indptr.scalar_type());
+              out_indptr_data = output_indptr.data_ptr<indptr_t>();
+              TORCH_CHECK(
+                  !output_size.has_value() ||
+                      out_indptr_data[num_nodes] == *output_size,
+                  "An incorrect output_size argument was provided.");
+              output_size = out_indptr_data[num_nodes];
+              for (const auto& indices : indices_list) {
+                results.push_back(torch::empty(
+                    *output_size,
+                    nodes.options().dtype(indices.scalar_type())));
+              }
+              if (with_edge_ids) {
+                edge_ids = torch::empty(
+                    *output_size, nodes.options().dtype(indptr.scalar_type()));
+              }
+              torch::parallel_for(
+                  0, num_nodes, kDefaultGrainSize,
+                  [&](int64_t begin, int64_t end) {
+                    for (int64_t i = begin; i < end; i++) {
+                      const auto output_offset = out_indptr_data[i];
+                      const auto numel = out_indptr_data[i + 1] - output_offset;
+                      const auto input_offset = indptr_data[nodes_data[i]];
+                      for (size_t tensor_id = 0;
+                           tensor_id < indices_list.size(); tensor_id++) {
+                        auto output = reinterpret_cast<std::byte*>(
+                            results[tensor_id].data_ptr());
+                        const auto input = reinterpret_cast<std::byte*>(
+                            indices_list[tensor_id].data_ptr());
+                        const auto element_size =
+                            indices_list[tensor_id].element_size();
+                        std::memcpy(
+                            output + output_offset * element_size,
+                            input + input_offset * element_size,
+                            element_size * numel);
+                      }
+                      if (edge_ids.has_value()) {
+                        auto output = edge_ids->data_ptr<indptr_t>();
+                        std::iota(
+                            output + output_offset,
+                            output + output_offset + numel, input_offset);
+                      }
+                    }
+                  });
+            }));
+      }));
+  if (edge_ids) results.push_back(*edge_ids);
   return std::make_tuple(output_indptr, results);
 }
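For context, the CPU path added above follows a common two-pass gather pattern: pass one writes each picked node's degree and prefix-sums it into the output indptr, which also yields the total output size so every output tensor is allocated exactly once; pass two copies each neighborhood as one contiguous block. Below is a minimal, single-threaded sketch of that pattern in plain C++; the names (gather_csc_rows, GatherResult) are illustrative only and not part of the graphbolt API.

#include <cstdint>
#include <cstring>
#include <numeric>
#include <vector>

// Result of gathering the neighborhoods of the selected nodes.
struct GatherResult {
  std::vector<int64_t> indptr;    // output indptr, size nodes.size() + 1
  std::vector<int64_t> indices;   // concatenated neighborhoods
  std::vector<int64_t> edge_ids;  // source positions of the copied edges
};

GatherResult gather_csc_rows(
    const std::vector<int64_t>& indptr, const std::vector<int64_t>& indices,
    const std::vector<int64_t>& nodes) {
  GatherResult out;
  const auto num_nodes = static_cast<int64_t>(nodes.size());
  // Pass 1: record each selected node's degree, then turn the degree array
  // into an indptr via an inclusive prefix sum (out.indptr[0] stays 0).
  out.indptr.assign(num_nodes + 1, 0);
  for (int64_t i = 0; i < num_nodes; i++) {
    out.indptr[i + 1] = indptr[nodes[i] + 1] - indptr[nodes[i]];
  }
  std::partial_sum(out.indptr.begin(), out.indptr.end(), out.indptr.begin());
  // The total output size is now known, so allocate each output once.
  const auto output_size = out.indptr[num_nodes];
  out.indices.resize(output_size);
  out.edge_ids.resize(output_size);
  // Pass 2: copy each neighborhood as one contiguous block and record the
  // original edge ids as a simple iota over the source offsets.
  for (int64_t i = 0; i < num_nodes; i++) {
    const auto output_offset = out.indptr[i];
    const auto numel = out.indptr[i + 1] - output_offset;
    const auto input_offset = indptr[nodes[i]];
    std::memcpy(
        out.indices.data() + output_offset, indices.data() + input_offset,
        numel * sizeof(int64_t));
    std::iota(
        out.edge_ids.begin() + output_offset,
        out.edge_ids.begin() + output_offset + numel, input_offset);
  }
  return out;
}

int main() {
  // Toy CSC graph with 4 nodes; gather the neighborhoods of nodes 2 and 0.
  const std::vector<int64_t> indptr{0, 2, 3, 6, 7};
  const std::vector<int64_t> indices{1, 2, 0, 0, 1, 3, 2};
  const auto res = gather_csc_rows(indptr, indices, {2, 0});
  // res.indptr == {0, 3, 5}; res.indices == {0, 1, 3, 1, 2};
  // res.edge_ids == {3, 4, 5, 0, 1}.
}

The committed version additionally parallelizes both passes with torch::parallel_for and performs the copy through type-erased byte pointers (std::byte* plus element_size), so the same memcpy loop serves indices tensors of any dtype.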