@@ -1310,18 +1310,20 @@ UnitGraph::CSRPtr UnitGraph::GetInCSR(bool inplace) const {
1310
1310
LOG (FATAL) << " The graph have restricted sparse format " <<
1311
1311
CodeToStr (formats_) << " , cannot create CSC matrix." ;
1312
1312
CSRPtr ret = in_csr_;
1313
+ // Prefers converting from COO since it is parallelized.
1314
+ // TODO(BarclayII): need benchmarking.
1313
1315
if (!in_csr_->defined ()) {
1314
- if (out_csr_->defined ()) {
1315
- const auto & newadj = aten::CSRTranspose (out_csr_->adj ());
1316
+ if (coo_->defined ()) {
1317
+ const auto & newadj = aten::COOToCSR (
1318
+ aten::COOTranspose (coo_->adj ()));
1316
1319
1317
1320
if (inplace)
1318
1321
*(const_cast <UnitGraph*>(this )->in_csr_ ) = CSR (meta_graph (), newadj);
1319
1322
else
1320
1323
ret = std::make_shared<CSR>(meta_graph (), newadj);
1321
1324
} else {
1322
- CHECK (coo_->defined ()) << " None of CSR, COO exist" ;
1323
- const auto & newadj = aten::COOToCSR (
1324
- aten::COOTranspose (coo_->adj ()));
1325
+ CHECK (out_csr_->defined ()) << " None of CSR, COO exist" ;
1326
+ const auto & newadj = aten::CSRTranspose (out_csr_->adj ());
1325
1327
1326
1328
if (inplace)
1327
1329
*(const_cast <UnitGraph*>(this )->in_csr_ ) = CSR (meta_graph (), newadj);
@@ -1339,17 +1341,19 @@ UnitGraph::CSRPtr UnitGraph::GetOutCSR(bool inplace) const {
1339
1341
LOG (FATAL) << " The graph have restricted sparse format " <<
1340
1342
CodeToStr (formats_) << " , cannot create CSR matrix." ;
1341
1343
CSRPtr ret = out_csr_;
1344
+ // Prefers converting from COO since it is parallelized.
1345
+ // TODO(BarclayII): need benchmarking.
1342
1346
if (!out_csr_->defined ()) {
1343
- if (in_csr_ ->defined ()) {
1344
- const auto & newadj = aten::CSRTranspose (in_csr_ ->adj ());
1347
+ if (coo_ ->defined ()) {
1348
+ const auto & newadj = aten::COOToCSR (coo_ ->adj ());
1345
1349
1346
1350
if (inplace)
1347
1351
*(const_cast <UnitGraph*>(this )->out_csr_ ) = CSR (meta_graph (), newadj);
1348
1352
else
1349
1353
ret = std::make_shared<CSR>(meta_graph (), newadj);
1350
1354
} else {
1351
- CHECK (coo_ ->defined ()) << " None of CSR, COO exist" ;
1352
- const auto & newadj = aten::COOToCSR (coo_ ->adj ());
1355
+ CHECK (in_csr_ ->defined ()) << " None of CSR, COO exist" ;
1356
+ const auto & newadj = aten::CSRTranspose (in_csr_ ->adj ());
1353
1357
1354
1358
if (inplace)
1355
1359
*(const_cast <UnitGraph*>(this )->out_csr_ ) = CSR (meta_graph (), newadj);
0 commit comments