Skip to content

Commit 05c53ca

Browse files
authored
[Performance] Prefer parallelized conversion to CSC from COO instead of transposing CSR (dmlc#2793)
* fix coo2csr speed * add comments
1 parent 86229d4 commit 05c53ca

File tree

1 file changed

+13
-9
lines changed

1 file changed

+13
-9
lines changed

src/graph/unit_graph.cc

Lines changed: 13 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -1310,18 +1310,20 @@ UnitGraph::CSRPtr UnitGraph::GetInCSR(bool inplace) const {
13101310
LOG(FATAL) << "The graph have restricted sparse format " <<
13111311
CodeToStr(formats_) << ", cannot create CSC matrix.";
13121312
CSRPtr ret = in_csr_;
1313+
// Prefers converting from COO since it is parallelized.
1314+
// TODO(BarclayII): need benchmarking.
13131315
if (!in_csr_->defined()) {
1314-
if (out_csr_->defined()) {
1315-
const auto& newadj = aten::CSRTranspose(out_csr_->adj());
1316+
if (coo_->defined()) {
1317+
const auto& newadj = aten::COOToCSR(
1318+
aten::COOTranspose(coo_->adj()));
13161319

13171320
if (inplace)
13181321
*(const_cast<UnitGraph*>(this)->in_csr_) = CSR(meta_graph(), newadj);
13191322
else
13201323
ret = std::make_shared<CSR>(meta_graph(), newadj);
13211324
} else {
1322-
CHECK(coo_->defined()) << "None of CSR, COO exist";
1323-
const auto& newadj = aten::COOToCSR(
1324-
aten::COOTranspose(coo_->adj()));
1325+
CHECK(out_csr_->defined()) << "None of CSR, COO exist";
1326+
const auto& newadj = aten::CSRTranspose(out_csr_->adj());
13251327

13261328
if (inplace)
13271329
*(const_cast<UnitGraph*>(this)->in_csr_) = CSR(meta_graph(), newadj);
@@ -1339,17 +1341,19 @@ UnitGraph::CSRPtr UnitGraph::GetOutCSR(bool inplace) const {
13391341
LOG(FATAL) << "The graph have restricted sparse format " <<
13401342
CodeToStr(formats_) << ", cannot create CSR matrix.";
13411343
CSRPtr ret = out_csr_;
1344+
// Prefers converting from COO since it is parallelized.
1345+
// TODO(BarclayII): need benchmarking.
13421346
if (!out_csr_->defined()) {
1343-
if (in_csr_->defined()) {
1344-
const auto& newadj = aten::CSRTranspose(in_csr_->adj());
1347+
if (coo_->defined()) {
1348+
const auto& newadj = aten::COOToCSR(coo_->adj());
13451349

13461350
if (inplace)
13471351
*(const_cast<UnitGraph*>(this)->out_csr_) = CSR(meta_graph(), newadj);
13481352
else
13491353
ret = std::make_shared<CSR>(meta_graph(), newadj);
13501354
} else {
1351-
CHECK(coo_->defined()) << "None of CSR, COO exist";
1352-
const auto& newadj = aten::COOToCSR(coo_->adj());
1355+
CHECK(in_csr_->defined()) << "None of CSR, COO exist";
1356+
const auto& newadj = aten::CSRTranspose(in_csr_->adj());
13531357

13541358
if (inplace)
13551359
*(const_cast<UnitGraph*>(this)->out_csr_) = CSR(meta_graph(), newadj);

0 commit comments

Comments
 (0)