Skip to content

Commit 8b99b25

Browse files
committed
Add grouping to ArgsortOp
The CUB routine allows using multiple items per thread. The Group parallel type is used to supply multiple elements. Scheduler updates will follow, and other similar ops will be updated in the same way.
1 parent b2253cc commit 8b99b25

20 files changed

+1173
-83
lines changed

CMakeLists.txt

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -193,6 +193,7 @@ list(APPEND NVFUSER_SRCS
193193
${NVFUSER_SRCS_DIR}/debug.cpp
194194
${NVFUSER_SRCS_DIR}/device_lower/analysis/bank_conflict.cpp
195195
${NVFUSER_SRCS_DIR}/device_lower/analysis/circular_buffer.cpp
196+
${NVFUSER_SRCS_DIR}/device_lower/analysis/default_val.cpp
196197
${NVFUSER_SRCS_DIR}/device_lower/analysis/device_version.cpp
197198
${NVFUSER_SRCS_DIR}/device_lower/analysis/divisible_split.cpp
198199
${NVFUSER_SRCS_DIR}/device_lower/analysis/fused_reduction.cpp

csrc/codegen.cpp

Lines changed: 18 additions & 35 deletions
Original file line numberDiff line numberDiff line change
@@ -1426,14 +1426,21 @@ class CudaKernelGenerator : private kir::ConstIrVisitor {
14261426
return std::ranges::find(sorted_ids, id) != sorted_ids.end();
14271427
});
14281428

1429-
// At this moment, we only support argsort on thread parallelized
1430-
// dimensions. No serial dimension is allowed either.
14311429
ParallelTypeBitmap sorted_parallel_types;
1430+
IterDomain* grouped_id = nullptr;
14321431
for (auto id : sorted_loop_ids) {
1433-
NVF_ERROR(
1434-
isParallelTypeThreadDim(id->getParallelType()),
1435-
"Argsort on non-thread dimension is not supported");
1436-
sorted_parallel_types.set(id->getParallelType());
1432+
if (isParallelTypeThreadDim(id->getParallelType())) {
1433+
sorted_parallel_types.set(id->getParallelType());
1434+
} else if (id->getParallelType() == ParallelType::Group) {
1435+
NVF_ERROR(
1436+
grouped_id == nullptr,
1437+
"Multiple grouped IDs not supported: ",
1438+
aop->toString());
1439+
grouped_id = id;
1440+
} else {
1441+
NVF_THROW(
1442+
"Invalid parallel type: ", id->toString(), " of ", aop->toString());
1443+
}
14371444
}
14381445

14391446
// TID parallel types must only be used for the sorted IDs with the static
@@ -1469,8 +1476,9 @@ class CudaKernelGenerator : private kir::ConstIrVisitor {
14691476
}
14701477
}
14711478

1472-
// TODO: support ITEMS_PER_THREAD > 1
1473-
constexpr int items_per_thread = 1;
1479+
const int64_t items_per_thread = grouped_id != nullptr
1480+
? grouped_id->extent()->evaluate().as<int64_t>()
1481+
: 1;
14741482

14751483
const auto input = aop->in()->as<kir::TensorIndex>();
14761484

@@ -1479,33 +1487,8 @@ class CudaKernelGenerator : private kir::ConstIrVisitor {
14791487

14801488
// Call the runtime argsort function
14811489
ArgumentBuilder func_args;
1482-
1483-
// The output tensor is assumed to be a register tensor, and thus
1484-
// its storage should always be available without predication
1485-
NVF_ERROR_EQ(
1486-
output->view()->getMemoryType(),
1487-
MemoryType::Local,
1488-
"Argsort output must be a Local tensor: ",
1489-
output->toString());
1490-
func_args.arg("*(int64_t(*)[")
1491-
.append(items_per_thread)
1492-
.append("])")
1493-
.append("(&")
1494-
.append(genInline(output))
1495-
.append(")");
1496-
1497-
NVF_ERROR(aop->predicate() != nullptr && aop->predicate()->hasValue());
1498-
// {pred ? input : (isDescending ? min : max)}
1499-
func_args.arg("{")
1500-
.append(genInline(aop->predicate()))
1501-
.append(" ? ")
1502-
.append(genInline(input))
1503-
.append(" : ")
1504-
.append(
1505-
aop->isDescending() ? getMinimumValue(input->dtype())
1506-
: getMaximumValue(input->dtype()))
1507-
.append("}");
1508-
1490+
func_args.arg(genVariableNameConvertAlignedArray(output));
1491+
func_args.arg(genVariableNameConvertAlignedArray(input));
15091492
func_args.arg(aop->isDescending() ? "true" : "false"); // descending flag
15101493
func_args.arg(genComputeBlockDim());
15111494

Lines changed: 57 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,57 @@
1+
// clang-format off
2+
/*
3+
* SPDX-FileCopyrightText: Copyright (c) 2025-present NVIDIA CORPORATION & AFFILIATES.
4+
* All rights reserved.
5+
* SPDX-License-Identifier: BSD-3-Clause
6+
*/
7+
// clang-format on
8+
9+
#include <device_lower/analysis/default_val.h>
10+
#include <fusion.h>
11+
#include <ir/internal_nodes.h>
12+
#include <ir/utils.h>
13+
#include <ops/utils.h>
14+
15+
namespace nvfuser {
16+
17+
// Walk every expression of the fusion once, letting the dispatch
// mechanism route each one to the matching handle() overload.
TensorDefaultVal::TensorDefaultVal(Fusion* fusion) {
  for (auto* expr : fusion->exprs()) {
    dispatch(expr);
  }
}
22+
23+
void TensorDefaultVal::handle(ArgsortOp* aop) {
  // Validation elsewhere guarantees the input is used exclusively by
  // this argsort op, so it is safe to register an initialization
  // value for it on behalf of this op.
  auto* inp_tv = ir_utils::getTvInput(aop);

  // Descending sorts get the type's minimum, ascending sorts the
  // maximum — presumably so padded elements sort to the end; confirm
  // against the codegen's predicate fallback.
  Val* default_val = aop->isDescending()
      ? ops::getMinimumValue(inp_tv->dtype())
      : ops::getMaximumValue(inp_tv->dtype());

  registerDefaultVal(inp_tv, default_val);
}
37+
38+
void TensorDefaultVal::registerDefaultVal(TensorView* tv, Val* val) {
  const auto [it, inserted] = default_val_map_.emplace(tv, val);
  if (inserted) {
    return;
  }
  // A tensor may be registered more than once, but only with an
  // equivalent default value.
  NVF_ERROR(
      it->second->sameAs(val),
      "Duplicate setting of default val for ",
      tv->toString(),
      ". ",
      it->second->toString(),
      " vs ",
      val->toString());
}
47+
48+
// Look up the default value registered for tv; nullptr means none
// was recorded.
Val* TensorDefaultVal::get(TensorView* tv) const {
  const auto it = default_val_map_.find(tv);
  return it == default_val_map_.end() ? nullptr : it->second;
}
56+
57+
} // namespace nvfuser
Lines changed: 35 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,35 @@
1+
// clang-format off
2+
/*
3+
* SPDX-FileCopyrightText: Copyright (c) 2025-present NVIDIA CORPORATION & AFFILIATES.
4+
* All rights reserved.
5+
* SPDX-License-Identifier: BSD-3-Clause
6+
*/
7+
// clang-format on
8+
#pragma once
9+
10+
#include <dispatch.h>
11+
12+
#include <unordered_map>
13+
14+
namespace nvfuser {
15+
16+
class Fusion;
17+
class Val;
18+
class TensorView;
19+
20+
class TensorDefaultVal : public OptOutDispatch {
21+
public:
22+
TensorDefaultVal(Fusion* fusion);
23+
24+
Val* get(TensorView* tv) const;
25+
26+
private:
27+
void handle(ArgsortOp* aop) final;
28+
29+
void registerDefaultVal(TensorView* tv, Val* val);
30+
31+
private:
32+
std::unordered_map<TensorView*, Val*> default_val_map_;
33+
};
34+
35+
} // namespace nvfuser

csrc/device_lower/analysis/fusion_info.h

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,7 @@
88
#pragma once
99

1010
#include <compute_at_map.h>
11+
#include <device_lower/analysis/default_val.h>
1112
#include <device_lower/analysis/fused_reduction.h>
1213
#include <device_lower/analysis/padded_parallel_dimensions.h>
1314
#include <device_lower/analysis/thread_predicate.h>
@@ -114,6 +115,11 @@ class FusionInfo {
114115

115116
FUSION_INFO_DEFINE_FUNCTIONS(IdModel, id_model, idModel);
116117

118+
FUSION_INFO_DEFINE_FUNCTIONS(
119+
TensorDefaultVal,
120+
tensor_default_val,
121+
tensorDefaultVal);
122+
117123
private:
118124
FUSION_INFO_DEFINE_FIELD(
119125
ConcretizedBroadcastDomains,
@@ -132,6 +138,8 @@ class FusionInfo {
132138
FUSION_INFO_DEFINE_FIELD(ComputeAtMap, ca_map);
133139

134140
FUSION_INFO_DEFINE_FIELD(IdModel, id_model);
141+
142+
FUSION_INFO_DEFINE_FIELD(TensorDefaultVal, tensor_default_val);
135143
};
136144

137145
#undef FUSION_INFO_DEFINE_FUNCTIONS

csrc/device_lower/lower2device.cpp

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -456,6 +456,10 @@ void GpuLower::analysis(Fusion* fusion) {
456456
replaceSymbolicSizes(fusion_);
457457
dumpExprsIfEnabled(fusion_->exprs(), "replaceSymbolicSizes");
458458

459+
// Does not need to be placed here as it has no dependency to any other
460+
// analysis.
461+
info().set(std::make_unique<TensorDefaultVal>(fusion_));
462+
459463
// New IterDomains may be created, so it is expected that generated
460464
code may use different variable names
461465
if (idModelOptions().buildIdModel()) {

csrc/device_lower/pass/allocation.cpp

Lines changed: 23 additions & 36 deletions
Original file line numberDiff line numberDiff line change
@@ -22,37 +22,6 @@
2222
namespace nvfuser {
2323

2424
namespace {
25-
// True if a given domain is a loop domain of a given tensor and its
26-
// loop is partitioned with respect to the memory type of the tensor
27-
bool isPartitionedLoop(const TensorView* tv, IterDomain* id) {
28-
// False if id is not a loop ID
29-
if (std::find(tv->getLoopDomain().begin(), tv->getLoopDomain().end(), id) ==
30-
tv->getLoopDomain().end()) {
31-
return false;
32-
}
33-
34-
// If the memory of this domain is partitioned with respect to the
35-
// parallel type of the domain, there's no allocation for the domain
36-
return ir_utils::isMemoryPartitionedAcross(
37-
tv->getMemoryType(), id->getParallelType());
38-
}
39-
40-
bool isSizeOneDomain(IterDomain* id) {
41-
return id->isBroadcast() || id->extent()->isOneInt();
42-
}
43-
44-
// True if a given domain of a tensor *may* require allocation
45-
bool mayRequireAllocation(const TensorView* tv, IterDomain* id) {
46-
// Conditions to consider:
47-
// - Fully partitioned
48-
// - Size one: Allocation is done based on the promotion ID, but as
49-
// long as the original ID has size one, its allocation should
50-
// remain size one.
51-
// - Reduction: Check the original ID, not the promotion, which may
52-
// be a reduction ID even though the original ID is not a reduction
53-
return !isPartitionedLoop(tv, id) && !isSizeOneDomain(id) &&
54-
!id->isReduction() && !id->isStride();
55-
}
5625

5726
// Get the allocation stride of a given allocation domain
5827
Val* getStrideOfGlobalMemoryTensor(TensorView* tv, int64_t alloc_dim) {
@@ -386,7 +355,7 @@ class AllocationDomainSetup : private kir::IrVisitor {
386355
std::vector<IterDomain*> actual_allocation_ids;
387356
std::vector<std::optional<bool>> actual_contiguity;
388357
for (auto [i, id] : enumerate(allocation_domains)) {
389-
if (mayRequireAllocation(tv, id)) {
358+
if (ir_utils::mayRequireAllocation(tv, id)) {
390359
actual_allocation_ids.push_back(id);
391360
actual_contiguity.push_back(contiguity.at(i));
392361
}
@@ -464,7 +433,7 @@ class AllocationDomainSetup : private kir::IrVisitor {
464433
auto allocation_domain = allocation_domains.at(dim);
465434
auto promotion_domain = promoted_allocation_domains.at(dim);
466435

467-
if (!mayRequireAllocation(tv, allocation_domain)) {
436+
if (!ir_utils::mayRequireAllocation(tv, allocation_domain)) {
468437
continue;
469438
}
470439

@@ -494,7 +463,7 @@ class AllocationDomainSetup : private kir::IrVisitor {
494463
for (const auto i : arange(allocation_domains.size())) {
495464
auto allocation_domain = allocation_domains.at(i);
496465
auto promotion_domain = promoted_allocation_domains.at(i);
497-
if (!mayRequireAllocation(tv, allocation_domain)) {
466+
if (!ir_utils::mayRequireAllocation(tv, allocation_domain)) {
498467
continue;
499468
}
500469
auto stride = strides.at(i);
@@ -760,7 +729,7 @@ class AllocationDomainSetup : private kir::IrVisitor {
760729
for (auto out : expr->outputs()) {
761730
auto it = equiv_domain_set.find(out->as<IterDomain>());
762731
if (it == equiv_domain_set.end() &&
763-
mayRequireAllocation(tv, out->as<IterDomain>())) {
732+
ir_utils::mayRequireAllocation(tv, out->as<IterDomain>())) {
764733
// missing dependency
765734
return std::nullopt;
766735
}
@@ -1277,7 +1246,25 @@ class AllocationInserter : public kir::ExprMutator {
12771246

12781247
auto out_tv = out->as<TensorView>();
12791248
auto default_val =
1280-
gpu_lower_->predicateElimination().getInitValue(out_tv);
1249+
FusionInfoGuard::current()->tensorDefaultVal().get(out_tv);
1250+
1251+
// Check if out_tv must also be initialized for predicate
1252+
// elimination. If so, the two initialization values must match
1253+
if (auto init_for_pred_elimination =
1254+
gpu_lower_->predicateElimination().getInitValue(out_tv)) {
1255+
if (default_val != nullptr) {
1256+
NVF_ERROR(
1257+
default_val->sameAs(init_for_pred_elimination),
1258+
"Conflicting default val for ",
1259+
out_tv->toString(),
1260+
". ",
1261+
default_val->toString(),
1262+
" vs ",
1263+
init_for_pred_elimination->toString());
1264+
} else {
1265+
default_val = init_for_pred_elimination;
1266+
}
1267+
}
12811268

12821269
Val* init = nullptr;
12831270
if (out_tv->dtype() == DataType::Float4_e2m1fn) {

0 commit comments

Comments
 (0)