@@ -9271,6 +9271,10 @@ SpirvEmitter::processIntrinsicCallExpr(const CallExpr *callExpr) {
92719271 case hlsl::IntrinsicOp::IOP_QuadReadLaneAt:
92729272 retVal = processWaveQuadWideShuffle(callExpr, hlslOpcode);
92739273 break;
9274+ case hlsl::IntrinsicOp::IOP_QuadAny:
9275+ case hlsl::IntrinsicOp::IOP_QuadAll:
9276+ retVal = processWaveQuadAnyAll(callExpr, hlslOpcode);
9277+ break;
92749278 case hlsl::IntrinsicOp::IOP_abort:
92759279 case hlsl::IntrinsicOp::IOP_GetRenderTargetSampleCount:
92769280 case hlsl::IntrinsicOp::IOP_GetRenderTargetSamplePosition: {
@@ -10233,6 +10237,53 @@ SpirvEmitter::processWaveQuadWideShuffle(const CallExpr *callExpr,
1023310237 opcode, retType, spv::Scope::Subgroup, {value, target}, srcLoc);
1023410238}
1023510239
10240+ SpirvInstruction *SpirvEmitter::processWaveQuadAnyAll(const CallExpr *callExpr,
10241+ hlsl::IntrinsicOp op) {
10242+ // Signatures:
10243+ // bool QuadAny(bool localValue)
10244+ // bool QuadAll(bool localValue)
10245+ assert(callExpr->getNumArgs() == 1);
10246+ assert(op == hlsl::IntrinsicOp::IOP_QuadAny ||
10247+ op == hlsl::IntrinsicOp::IOP_QuadAll);
10248+ featureManager.requestTargetEnv(SPV_ENV_VULKAN_1_1, "Wave Operation",
10249+ callExpr->getExprLoc());
10250+
10251+ auto *predicate = doExpr(callExpr->getArg(0));
10252+ const auto srcLoc = callExpr->getExprLoc();
10253+
10254+ if (!featureManager.isExtensionEnabled(Extension::KHR_quad_control)) {
10255+ // We can't use QuadAny/QuadAll, so implement them using QuadSwap. We
10256+ // will read the value at each quad invocation, then combine them.
10257+
10258+ spv::Op reducer = op == hlsl::IntrinsicOp::IOP_QuadAny
10259+ ? spv::Op::OpLogicalOr
10260+ : spv::Op::OpLogicalAnd;
10261+
10262+ SpirvInstruction *result = predicate;
10263+
10264+ for (size_t i = 0; i < 3; i++) {
10265+ SpirvInstruction *invocationValue = spvBuilder.createGroupNonUniformOp(
10266+ spv::Op::OpGroupNonUniformQuadSwap, astContext.BoolTy,
10267+ spv::Scope::Subgroup,
10268+ {predicate, spvBuilder.getConstantInt(astContext.UnsignedIntTy,
10269+ llvm::APInt(32, i))},
10270+ srcLoc);
10271+ result = spvBuilder.createBinaryOp(reducer, astContext.BoolTy, result,
10272+ invocationValue, srcLoc);
10273+ }
10274+
10275+ return result;
10276+ }
10277+
10278+ spv::Op opcode = op == hlsl::IntrinsicOp::IOP_QuadAny
10279+ ? spv::Op::OpGroupNonUniformQuadAnyKHR
10280+ : spv::Op::OpGroupNonUniformQuadAllKHR;
10281+
10282+ return spvBuilder.createGroupNonUniformOp(opcode, astContext.BoolTy,
10283+ llvm::Optional<spv::Scope>(),
10284+ {predicate}, srcLoc);
10285+ }
10286+
1023610287SpirvInstruction *
1023710288SpirvEmitter::processWaveActiveAllEqual(const CallExpr *callExpr) {
1023810289 assert(callExpr->getNumArgs() == 1);
0 commit comments