Skip to content

Commit 3e910b5

Browse files
committed
similar config string thing but for subgroups
1 parent 1c56eb0 commit 3e910b5

File tree

2 files changed

+32
-29
lines changed

2 files changed

+32
-29
lines changed

23_Arithmetic2UnitTest/app_resources/testSubgroup.comp.hlsl

Lines changed: 15 additions & 14 deletions
Original file line number | Diff line number | Diff line change
@@ -5,28 +5,29 @@
55
#include "nbl/builtin/hlsl/glsl_compat/core.hlsl"
66
#include "nbl/builtin/hlsl/glsl_compat/subgroup_basic.hlsl"
77
#include "nbl/builtin/hlsl/subgroup2/arithmetic_portability.hlsl"
8+
#include "nbl/builtin/hlsl/subgroup2/arithmetic_params.hlsl"
89

910
#include "shaderCommon.hlsl"
1011
#include "nbl/builtin/hlsl/workgroup2/basic.hlsl"
1112

12-
typedef vector<uint32_t, ITEMS_PER_INVOCATION> type_t;
13+
template<class Binop, class device_capabilities>
14+
using params_t = SUBGROUP_CONFIG_T;
15+
16+
typedef vector<uint32_t, params_t<typename arithmetic::bit_and<uint32_t>::base_t, device_capabilities>::ItemsPerInvocation> type_t;
1317

1418
uint32_t globalIndex()
1519
{
1620
return glsl::gl_WorkGroupID().x*WORKGROUP_SIZE+workgroup::SubgroupContiguousIndex();
1721
}
1822

19-
template<class Binop, uint32_t N>
23+
template<class Binop>
2024
static void subtest(NBL_CONST_REF_ARG(type_t) sourceVal)
2125
{
22-
using config_t = subgroup2::Configuration<SUBGROUP_SIZE_LOG2>;
23-
using params_t = subgroup2::ArithmeticParams<config_t, typename Binop::base_t, N, device_capabilities>;
24-
2526
const uint64_t outputBufAddr = pc.pOutputBuf[Binop::BindingIndex];
2627

27-
assert(glsl::gl_SubgroupSize() == 1u<<SUBGROUP_SIZE_LOG2)
28+
assert(glsl::gl_SubgroupSize() == params_t<typename Binop::base_t, device_capabilities>::config_t::Size)
2829

29-
operation_t<params_t> func;
30+
operation_t<params_t<typename Binop::base_t, device_capabilities> > func;
3031
type_t val = func(sourceVal);
3132

3233
vk::RawBufferStore<type_t>(outputBufAddr + sizeof(type_t) * globalIndex(), val, sizeof(uint32_t));
@@ -37,13 +38,13 @@ type_t test()
3738
const uint32_t idx = globalIndex();
3839
type_t sourceVal = vk::RawBufferLoad<type_t>(pc.pInputBuf + idx * sizeof(type_t));
3940

40-
subtest<arithmetic::bit_and<uint32_t>, ITEMS_PER_INVOCATION>(sourceVal);
41-
subtest<arithmetic::bit_xor<uint32_t>, ITEMS_PER_INVOCATION>(sourceVal);
42-
subtest<arithmetic::bit_or<uint32_t>, ITEMS_PER_INVOCATION>(sourceVal);
43-
subtest<arithmetic::plus<uint32_t>, ITEMS_PER_INVOCATION>(sourceVal);
44-
subtest<arithmetic::multiplies<uint32_t>, ITEMS_PER_INVOCATION>(sourceVal);
45-
subtest<arithmetic::minimum<uint32_t>, ITEMS_PER_INVOCATION>(sourceVal);
46-
subtest<arithmetic::maximum<uint32_t>, ITEMS_PER_INVOCATION>(sourceVal);
41+
subtest<arithmetic::bit_and<uint32_t> >(sourceVal);
42+
subtest<arithmetic::bit_xor<uint32_t> >(sourceVal);
43+
subtest<arithmetic::bit_or<uint32_t> >(sourceVal);
44+
subtest<arithmetic::plus<uint32_t> >(sourceVal);
45+
subtest<arithmetic::multiplies<uint32_t> >(sourceVal);
46+
subtest<arithmetic::minimum<uint32_t> >(sourceVal);
47+
subtest<arithmetic::maximum<uint32_t> >(sourceVal);
4748
return sourceVal;
4849
}
4950

23_Arithmetic2UnitTest/main.cpp

Lines changed: 17 additions & 15 deletions
Original file line number | Diff line number | Diff line change
@@ -2,6 +2,7 @@
22
#include "nbl/application_templates/MonoAssetManagerAndBuiltinResourceApplication.hpp"
33
#include "app_resources/common.hlsl"
44
#include "nbl/builtin/hlsl/workgroup2/arithmetic_config.hlsl"
5+
#include "nbl/builtin/hlsl/subgroup2/arithmetic_params.hlsl"
56

67
using namespace nbl;
78
using namespace core;
@@ -186,7 +187,7 @@ class Workgroup2ScanTestApp final : public application_templates::BasicMultiQueu
186187
for (auto subgroupSize = MinSubgroupSize; subgroupSize <= MaxSubgroupSize; subgroupSize *= 2u)
187188
{
188189
const uint8_t subgroupSizeLog2 = hlsl::findMSB(subgroupSize);
189-
for (uint32_t workgroupSize = 64; workgroupSize <= MaxWorkgroupSize; workgroupSize *= 2u)
190+
for (uint32_t workgroupSize = subgroupSize; workgroupSize <= MaxWorkgroupSize; workgroupSize *= 2u)
190191
{
191192
// make sure renderdoc captures everything for debugging
192193
m_api->startCapture();
@@ -198,12 +199,12 @@ class Workgroup2ScanTestApp final : public application_templates::BasicMultiQueu
198199
uint32_t itemsPerWG = workgroupSize * itemsPerInvocation;
199200
m_logger->log("Testing Items per Invocation %u", ILogger::ELL_INFO, itemsPerInvocation);
200201
bool passed = true;
201-
//passed = runTest<emulatedReduction, false>(subgroupTestSource, elementCount, subgroupSizeLog2, workgroupSize, bool(useNative), itemsPerWG, itemsPerInvocation) && passed;
202-
//logTestOutcome(passed, itemsPerWG);
203-
//passed = runTest<emulatedScanInclusive, false>(subgroupTestSource, elementCount, subgroupSizeLog2, workgroupSize, bool(useNative), itemsPerWG, itemsPerInvocation) && passed;
204-
//logTestOutcome(passed, itemsPerWG);
205-
//passed = runTest<emulatedScanExclusive, false>(subgroupTestSource, elementCount, subgroupSizeLog2, workgroupSize, bool(useNative), itemsPerWG, itemsPerInvocation) && passed;
206-
//logTestOutcome(passed, itemsPerWG);
202+
passed = runTest<emulatedReduction, false>(subgroupTestSource, elementCount, subgroupSizeLog2, workgroupSize, bool(useNative), itemsPerWG, itemsPerInvocation) && passed;
203+
logTestOutcome(passed, itemsPerWG);
204+
passed = runTest<emulatedScanInclusive, false>(subgroupTestSource, elementCount, subgroupSizeLog2, workgroupSize, bool(useNative), itemsPerWG, itemsPerInvocation) && passed;
205+
logTestOutcome(passed, itemsPerWG);
206+
passed = runTest<emulatedScanExclusive, false>(subgroupTestSource, elementCount, subgroupSizeLog2, workgroupSize, bool(useNative), itemsPerWG, itemsPerInvocation) && passed;
207+
logTestOutcome(passed, itemsPerWG);
207208

208209
hlsl::workgroup2::SArithmeticConfiguration wgConfig;
209210
wgConfig.init(hlsl::findMSB(workgroupSize), subgroupSizeLog2, itemsPerInvocation);
@@ -331,24 +332,25 @@ class Workgroup2ScanTestApp final : public application_templates::BasicMultiQueu
331332
}
332333
else
333334
{
334-
const std::string definitions[4] = {
335+
hlsl::subgroup2::SArithmeticParams sgParams;
336+
sgParams.init(subgroupSizeLog2, itemsPerInvoc);
337+
338+
const std::string definitions[3] = {
335339
"subgroup2::" + arith_name,
336340
std::to_string(workgroupSize),
337-
std::to_string(itemsPerInvoc),
338-
std::to_string(subgroupSizeLog2)
341+
sgParams.getParamTemplateStructString()
339342
};
340343

341-
const IShaderCompiler::SMacroDefinition defines[5] = {
344+
const IShaderCompiler::SMacroDefinition defines[4] = {
342345
{ "OPERATION", definitions[0] },
343346
{ "WORKGROUP_SIZE", definitions[1] },
344-
{ "ITEMS_PER_INVOCATION", definitions[2] },
345-
{ "SUBGROUP_SIZE_LOG2", definitions[3] },
347+
{ "SUBGROUP_CONFIG_T", definitions[2] },
346348
{ "TEST_NATIVE", "1" }
347349
};
348350
if (useNative)
349-
options.preprocessorOptions.extraDefines = { defines, defines + 5 };
350-
else
351351
options.preprocessorOptions.extraDefines = { defines, defines + 4 };
352+
else
353+
options.preprocessorOptions.extraDefines = { defines, defines + 3 };
352354

353355
overriddenUnspecialized = compiler->compileToSPIRV((const char*)source->getContent()->getPointer(), options);
354356
}

0 commit comments

Comments
 (0)