|
22 | 22 | namespace nvfuser {
|
23 | 23 |
|
24 | 24 | namespace {
|
25 |
| -// True if a given domain is a loop domain of a given tensor and its |
26 |
| -// loop is partitioned with respect to the memory type of the tensor |
27 |
| -bool isPartitionedLoop(const TensorView* tv, IterDomain* id) { |
28 |
| - // False if id is not a loop ID |
29 |
| - if (std::find(tv->getLoopDomain().begin(), tv->getLoopDomain().end(), id) == |
30 |
| - tv->getLoopDomain().end()) { |
31 |
| - return false; |
32 |
| - } |
33 |
| - |
34 |
| - // If the memory of this domain is partitioned with respect to the |
35 |
| - // parallel type of the domain, there's no allocation for the domain |
36 |
| - return ir_utils::isMemoryPartitionedAcross( |
37 |
| - tv->getMemoryType(), id->getParallelType()); |
38 |
| -} |
39 |
| - |
40 |
| -bool isSizeOneDomain(IterDomain* id) { |
41 |
| - return id->isBroadcast() || id->extent()->isOneInt(); |
42 |
| -} |
43 |
| - |
44 |
| -// True if a given domain of a tensor *may* require allocation |
45 |
| -bool mayRequireAllocation(const TensorView* tv, IterDomain* id) { |
46 |
| - // Conditions to consider: |
47 |
| - // - Fully partitioned |
48 |
| - // - Size one: Allocation is done based on the promotion ID, but as |
49 |
| - // long as the original ID has size one, its allocation should |
50 |
| - // remain size one. |
51 |
| - // - Reduction: Check the original ID, not the promotion, which may |
52 |
| - // be a reduction ID even though the original ID is not a reduction |
53 |
| - return !isPartitionedLoop(tv, id) && !isSizeOneDomain(id) && |
54 |
| - !id->isReduction() && !id->isStride(); |
55 |
| -} |
56 | 25 |
|
57 | 26 | // Get the allocation stride of a given allocation domain
|
58 | 27 | Val* getStrideOfGlobalMemoryTensor(TensorView* tv, int64_t alloc_dim) {
|
@@ -386,7 +355,7 @@ class AllocationDomainSetup : private kir::IrVisitor {
|
386 | 355 | std::vector<IterDomain*> actual_allocation_ids;
|
387 | 356 | std::vector<std::optional<bool>> actual_contiguity;
|
388 | 357 | for (auto [i, id] : enumerate(allocation_domains)) {
|
389 |
| - if (mayRequireAllocation(tv, id)) { |
| 358 | + if (ir_utils::mayRequireAllocation(tv, id)) { |
390 | 359 | actual_allocation_ids.push_back(id);
|
391 | 360 | actual_contiguity.push_back(contiguity.at(i));
|
392 | 361 | }
|
@@ -464,7 +433,7 @@ class AllocationDomainSetup : private kir::IrVisitor {
|
464 | 433 | auto allocation_domain = allocation_domains.at(dim);
|
465 | 434 | auto promotion_domain = promoted_allocation_domains.at(dim);
|
466 | 435 |
|
467 |
| - if (!mayRequireAllocation(tv, allocation_domain)) { |
| 436 | + if (!ir_utils::mayRequireAllocation(tv, allocation_domain)) { |
468 | 437 | continue;
|
469 | 438 | }
|
470 | 439 |
|
@@ -494,7 +463,7 @@ class AllocationDomainSetup : private kir::IrVisitor {
|
494 | 463 | for (const auto i : arange(allocation_domains.size())) {
|
495 | 464 | auto allocation_domain = allocation_domains.at(i);
|
496 | 465 | auto promotion_domain = promoted_allocation_domains.at(i);
|
497 |
| - if (!mayRequireAllocation(tv, allocation_domain)) { |
| 466 | + if (!ir_utils::mayRequireAllocation(tv, allocation_domain)) { |
498 | 467 | continue;
|
499 | 468 | }
|
500 | 469 | auto stride = strides.at(i);
|
@@ -760,7 +729,7 @@ class AllocationDomainSetup : private kir::IrVisitor {
|
760 | 729 | for (auto out : expr->outputs()) {
|
761 | 730 | auto it = equiv_domain_set.find(out->as<IterDomain>());
|
762 | 731 | if (it == equiv_domain_set.end() &&
|
763 |
| - mayRequireAllocation(tv, out->as<IterDomain>())) { |
| 732 | + ir_utils::mayRequireAllocation(tv, out->as<IterDomain>())) { |
764 | 733 | // missing dependency
|
765 | 734 | return std::nullopt;
|
766 | 735 | }
|
@@ -1277,7 +1246,25 @@ class AllocationInserter : public kir::ExprMutator {
|
1277 | 1246 |
|
1278 | 1247 | auto out_tv = out->as<TensorView>();
|
1279 | 1248 | auto default_val =
|
1280 |
| - gpu_lower_->predicateElimination().getInitValue(out_tv); |
| 1249 | + FusionInfoGuard::current()->tensorDefaultVal().get(out_tv); |
| 1250 | + |
| 1251 | + // Check if out_tv must also be initialized for predicate |
| 1252 | + // elimination. If so, the two initialization values must match |
| 1253 | + if (auto init_for_pred_elimination = |
| 1254 | + gpu_lower_->predicateElimination().getInitValue(out_tv)) { |
| 1255 | + if (default_val != nullptr) { |
| 1256 | + NVF_ERROR( |
| 1257 | + default_val->sameAs(init_for_pred_elimination), |
| 1258 | + "Conflicting default val for ", |
| 1259 | + out_tv->toString(), |
| 1260 | + ". ", |
| 1261 | + default_val->toString(), |
| 1262 | + " vs ", |
| 1263 | + init_for_pred_elimination->toString()); |
| 1264 | + } else { |
| 1265 | + default_val = init_for_pred_elimination; |
| 1266 | + } |
| 1267 | + } |
1281 | 1268 |
|
1282 | 1269 | Val* init = nullptr;
|
1283 | 1270 | if (out_tv->dtype() == DataType::Float4_e2m1fn) {
|
|
0 commit comments