Skip to content

Commit d8ef89c

Browse files
Keenutsbaldurk
andauthored
[SPIR-V] Add payload to OpEmitMeshTasksEXT (#7485)
This commit fixes the missing payload parameter for the OpEmitMeshTasksEXT instruction. Errors such as the passed variable storage class or type are already tested. Fixes #7082 Co-Authored-by: baldurk <[email protected]> Co-authored-by: baldurk <[email protected]>
1 parent dc59ed0 commit d8ef89c

File tree

2 files changed

+19
-3
lines changed

2 files changed

+19
-3
lines changed

tools/clang/lib/SPIRV/SpirvEmitter.cpp

Lines changed: 4 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -13021,15 +13021,16 @@ void SpirvEmitter::processDispatchMesh(const CallExpr *callExpr) {
1302113021
: spv::StorageClass::Output;
1302213022
auto *payloadArg = doExpr(args[3]);
1302313023
bool isValid = false;
13024-
const VarDecl *param = nullptr;
13024+
SpirvInstruction *param = nullptr;
1302513025
if (const auto *implCastExpr = dyn_cast<CastExpr>(args[3])) {
1302613026
if (const auto *arg = dyn_cast<DeclRefExpr>(implCastExpr->getSubExpr())) {
1302713027
if (const auto *paramDecl = dyn_cast<VarDecl>(arg->getDecl())) {
1302813028
if (paramDecl->hasAttr<HLSLGroupSharedAttr>()) {
1302913029
isValid = declIdMapper.createPayloadStageVars(
1303013030
sigPoint, sc, paramDecl, /*asInput=*/false, paramDecl->getType(),
1303113031
"out.var", &payloadArg);
13032-
param = paramDecl;
13032+
param =
13033+
declIdMapper.getDeclEvalInfo(paramDecl, paramDecl->getLocation());
1303313034
}
1303413035
}
1303513036
}
@@ -13046,7 +13047,7 @@ void SpirvEmitter::processDispatchMesh(const CallExpr *callExpr) {
1304613047

1304713048
if (featureManager.isExtensionEnabled(Extension::EXT_mesh_shader)) {
1304813049
// for EXT_mesh_shader, create opEmitMeshTasksEXT.
13049-
spvBuilder.createEmitMeshTasksEXT(threadX, threadY, threadZ, loc, nullptr,
13050+
spvBuilder.createEmitMeshTasksEXT(threadX, threadY, threadZ, loc, param,
1305013051
range);
1305113052
} else {
1305213053
// for NV_mesh_shader, set TaskCountNV = threadX * threadY * threadZ.
Lines changed: 15 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,15 @@
1+
// RUN: %dxc -E main -T as_6_8 -spirv %s -E main -fspv-target-env=vulkan1.1spirv1.4 | FileCheck %s
2+
3+
struct S {
4+
uint a;
5+
};
6+
7+
groupshared S s;
8+
// CHECK: %s = OpVariable {{.*}} TaskPayloadWorkgroupEXT
9+
10+
[numthreads(1, 1, 1)]
11+
void main()
12+
{
13+
// CHECK: OpEmitMeshTasksEXT %uint_1 %uint_1 %uint_1 %s
14+
DispatchMesh(1, 1, 1, s);
15+
}

0 commit comments

Comments
 (0)