Skip to content

Commit ce29f58

Browse files
authored
[GraphBolt] Implement proper IndexSelectCSC for CPU. (#7670)
1 parent b5ee45f commit ce29f58

File tree

2 files changed

+90
-52
lines changed

2 files changed

+90
-52
lines changed

graphbolt/src/fused_csc_sampling_graph.cc

Lines changed: 13 additions & 39 deletions
Original file line numberDiff line numberDiff line change
@@ -19,6 +19,7 @@
1919
#include <vector>
2020

2121
#include "./expand_indptr.h"
22+
#include "./index_select.h"
2223
#include "./macro.h"
2324
#include "./random.h"
2425
#include "./shared_memory_helper.h"
@@ -293,48 +294,21 @@ c10::intrusive_ptr<FusedSampledSubgraph> FusedCSCSamplingGraph::InSubgraph(
293294
return ops::InSubgraph(indptr_, indices_, nodes, type_per_edge_);
294295
});
295296
}
296-
using namespace torch::indexing;
297-
const int32_t kDefaultGrainSize = 100;
298-
const auto num_seeds = nodes.size(0);
299-
torch::Tensor indptr = torch::empty({num_seeds + 1}, indptr_.dtype());
300-
std::vector<torch::Tensor> indices_arr(num_seeds);
301-
std::vector<torch::Tensor> edge_ids_arr(num_seeds);
302-
std::vector<torch::Tensor> type_per_edge_arr(num_seeds);
297+
std::vector<torch::Tensor> tensors{indices_};
298+
if (type_per_edge_.has_value()) {
299+
tensors.push_back(*type_per_edge_);
300+
}
303301

304-
AT_DISPATCH_INDEX_TYPES(
305-
indptr_.scalar_type(), "InSubgraph::indptr", ([&] {
306-
const auto indptr_data = indptr_.data_ptr<index_t>();
307-
auto out_indptr_data = indptr.data_ptr<index_t>();
308-
out_indptr_data[0] = 0;
309-
AT_DISPATCH_INDEX_TYPES(
310-
nodes.scalar_type(), "InSubgraph::nodes", ([&] {
311-
const auto nodes_data = nodes.data_ptr<index_t>();
312-
torch::parallel_for(
313-
0, num_seeds, kDefaultGrainSize,
314-
[&](size_t start, size_t end) {
315-
for (size_t i = start; i < end; ++i) {
316-
const auto node_id = nodes_data[i];
317-
const auto start_idx = indptr_data[node_id];
318-
const auto end_idx = indptr_data[node_id + 1];
319-
out_indptr_data[i + 1] = end_idx - start_idx;
320-
indices_arr[i] = indices_.slice(0, start_idx, end_idx);
321-
edge_ids_arr[i] = torch::arange(
322-
start_idx, end_idx, indptr_.scalar_type());
323-
if (type_per_edge_) {
324-
type_per_edge_arr[i] =
325-
type_per_edge_.value().slice(0, start_idx, end_idx);
326-
}
327-
}
328-
});
329-
}));
330-
}));
302+
auto [output_indptr, results] =
303+
ops::IndexSelectCSCBatched(indptr_, tensors, nodes, true, torch::nullopt);
304+
torch::optional<torch::Tensor> type_per_edge;
305+
if (type_per_edge_.has_value()) {
306+
type_per_edge = results.at(1);
307+
}
331308

332309
return c10::make_intrusive<FusedSampledSubgraph>(
333-
indptr.cumsum(0), torch::cat(indices_arr), torch::cat(edge_ids_arr),
334-
nodes, torch::arange(0, NumNodes()),
335-
type_per_edge_
336-
? torch::optional<torch::Tensor>{torch::cat(type_per_edge_arr)}
337-
: torch::nullopt);
310+
output_indptr, results.at(0), results.back(), nodes,
311+
torch::arange(0, NumNodes()), type_per_edge);
338312
}
339313

340314
/**

graphbolt/src/index_select.cc

Lines changed: 77 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,9 @@
88
#include <graphbolt/cuda_ops.h>
99
#include <graphbolt/fused_csc_sampling_graph.h>
1010

11+
#include <cstring>
12+
#include <numeric>
13+
1114
#include "./macro.h"
1215
#include "./utils.h"
1316

@@ -107,9 +110,9 @@ std::tuple<torch::Tensor, torch::Tensor> IndexSelectCSC(
107110
c10::DeviceType::CUDA, "IndexSelectCSCImpl",
108111
{ return IndexSelectCSCImpl(indptr, indices, nodes, output_size); });
109112
}
110-
sampling::FusedCSCSamplingGraph g(indptr, indices);
111-
const auto res = g.InSubgraph(nodes);
112-
return std::make_tuple(res->indptr, res->indices.value());
113+
auto [output_indptr, results] = IndexSelectCSCBatched(
114+
indptr, std::vector{indices}, nodes, false, output_size);
115+
return std::make_tuple(output_indptr, results.at(0));
113116
}
114117

115118
std::tuple<torch::Tensor, std::vector<torch::Tensor>> IndexSelectCSCBatched(
@@ -129,17 +132,78 @@ std::tuple<torch::Tensor, std::vector<torch::Tensor>> IndexSelectCSCBatched(
129132
indptr, indices_list, nodes, with_edge_ids, output_size);
130133
});
131134
}
135+
constexpr int kDefaultGrainSize = 128;
136+
const auto num_nodes = nodes.size(0);
137+
torch::Tensor output_indptr = torch::empty(
138+
{num_nodes + 1}, nodes.options().dtype(indptr.scalar_type()));
132139
std::vector<torch::Tensor> results;
133-
torch::Tensor output_indptr;
134-
torch::Tensor edge_ids;
135-
for (auto& indices : indices_list) {
136-
sampling::FusedCSCSamplingGraph g(indptr, indices);
137-
const auto res = g.InSubgraph(nodes);
138-
output_indptr = res->indptr;
139-
results.push_back(res->indices.value());
140-
edge_ids = res->original_edge_ids;
141-
}
142-
if (with_edge_ids) results.push_back(edge_ids);
140+
torch::optional<torch::Tensor> edge_ids;
141+
AT_DISPATCH_INDEX_TYPES(
142+
indptr.scalar_type(), "IndexSelectCSCBatched::indptr", ([&] {
143+
using indptr_t = index_t;
144+
const auto indptr_data = indptr.data_ptr<indptr_t>();
145+
auto out_indptr_data = output_indptr.data_ptr<indptr_t>();
146+
out_indptr_data[0] = 0;
147+
AT_DISPATCH_INDEX_TYPES(
148+
nodes.scalar_type(), "IndexSelectCSCBatched::nodes", ([&] {
149+
const auto nodes_data = nodes.data_ptr<index_t>();
150+
torch::parallel_for(
151+
0, num_nodes, kDefaultGrainSize,
152+
[&](int64_t begin, int64_t end) {
153+
for (int64_t i = begin; i < end; i++) {
154+
const auto node_id = nodes_data[i];
155+
const auto degree =
156+
indptr_data[node_id + 1] - indptr_data[node_id];
157+
out_indptr_data[i + 1] = degree;
158+
}
159+
});
160+
output_indptr = output_indptr.cumsum(0, indptr.scalar_type());
161+
out_indptr_data = output_indptr.data_ptr<indptr_t>();
162+
TORCH_CHECK(
163+
!output_size.has_value() ||
164+
out_indptr_data[num_nodes] == *output_size,
165+
"An incorrect output_size argument was provided.");
166+
output_size = out_indptr_data[num_nodes];
167+
for (const auto& indices : indices_list) {
168+
results.push_back(torch::empty(
169+
*output_size,
170+
nodes.options().dtype(indices.scalar_type())));
171+
}
172+
if (with_edge_ids) {
173+
edge_ids = torch::empty(
174+
*output_size, nodes.options().dtype(indptr.scalar_type()));
175+
}
176+
torch::parallel_for(
177+
0, num_nodes, kDefaultGrainSize,
178+
[&](int64_t begin, int64_t end) {
179+
for (int64_t i = begin; i < end; i++) {
180+
const auto output_offset = out_indptr_data[i];
181+
const auto numel = out_indptr_data[i + 1] - output_offset;
182+
const auto input_offset = indptr_data[nodes_data[i]];
183+
for (size_t tensor_id = 0;
184+
tensor_id < indices_list.size(); tensor_id++) {
185+
auto output = reinterpret_cast<std::byte*>(
186+
results[tensor_id].data_ptr());
187+
const auto input = reinterpret_cast<std::byte*>(
188+
indices_list[tensor_id].data_ptr());
189+
const auto element_size =
190+
indices_list[tensor_id].element_size();
191+
std::memcpy(
192+
output + output_offset * element_size,
193+
input + input_offset * element_size,
194+
element_size * numel);
195+
}
196+
if (edge_ids.has_value()) {
197+
auto output = edge_ids->data_ptr<indptr_t>();
198+
std::iota(
199+
output + output_offset,
200+
output + output_offset + numel, input_offset);
201+
}
202+
}
203+
});
204+
}));
205+
}));
206+
if (edge_ids) results.push_back(*edge_ids);
143207
return std::make_tuple(output_indptr, results);
144208
}
145209

0 commit comments

Comments (0)