Skip to content

Commit 8f52625

Browse files
author
--global
committed
comment
1 parent acf1128 commit 8f52625

File tree

1 file changed

+15
-8
lines changed

1 file changed

+15
-8
lines changed

csrc/fusion_segmenter.cpp

Lines changed: 15 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -1753,18 +1753,23 @@ void eraseInputDistinctRootDomains(Fusion* fusion) {
17531753

17541754
for (auto tv : ir_utils::filterByType<TensorView>(fusion->inputs())) {
17551755
// Create a new logical domain and replacement TensorDomain.
1756-
// Given an logical domain, create a new IterDomain.
1757-
// Otherwise, clone the previous IterDomain
17581756
std::vector<IterDomain*> new_logical_domain;
1757+
1758+
// Ignore reduction ids for new tensordomain.
17591759
auto logical = TensorDomain::noReductions(tv->getLogicalDomain());
17601760
new_logical_domain.reserve(logical.size());
17611761

17621762
// Does the logical domain contain all concrete sized extents?
1763-
bool tv_is_concrete =
1764-
std::all_of(logical.begin(), logical.end(), [](IterDomain* id) {
1765-
return id->extent()->isConstScalar();
1766-
});
1763+
bool tv_is_concrete = true;
1764+
for (auto id : logical) {
1765+
if (!id->extent()->isConstScalar()) {
1766+
tv_is_concrete = false;
1767+
break;
1768+
}
1769+
}
17671770

1771+
// Given an rfactor IterDomain, create a new IterDomain.
1772+
// Otherwise, clone the previous IterDomain
17681773
for (const auto& id : logical) {
17691774
if (id->isRFactorProduct()) {
17701775
// Create new symbolic extents for logical iterDomains
@@ -1784,13 +1789,15 @@ void eraseInputDistinctRootDomains(Fusion* fusion) {
17841789
TensorDomain* new_td = IrBuilder::create<TensorDomain>(new_logical_domain);
17851790
TransformReplay::selfReplay(tv->domain(), new_td, true);
17861791
if (!tv->domain()->hasAllocation()) {
1792+
// The default contiguity for new_td is false. `selfReplay` does not
1793+
// replay contiguity when no allocation domain is present.
17871794
const std::vector<std::optional<bool>> old_contiguity =
17881795
tv->domain()->contiguity();
17891796
std::vector<std::optional<bool>> no_red_contiguity;
17901797
no_red_contiguity.reserve(old_contiguity.size());
1791-
for (const auto& [alloc_id, contiguity] :
1798+
for (const auto& [id, contiguity] :
17921799
zip(tv->getLogicalDomain(), old_contiguity)) {
1793-
if (alloc_id->isReduction()) {
1800+
if (id->isReduction()) {
17941801
continue;
17951802
}
17961803
no_red_contiguity.push_back(contiguity);

0 commit comments

Comments
 (0)