Skip to content

Commit e06c9af

Browse files
Merge pull request #192 from Devsh-Graphics-Programming/new_wg_scan_test
Unit tests and benchmark for subgroup2 and workgroup2 stuff
2 parents 3a487ac + 0ba8eed commit e06c9af

24 files changed

+1835
-666
lines changed

11_FFT/app_resources/shader.comp.hlsl

Lines changed: 6 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -14,13 +14,13 @@ uint32_t3 glsl::gl_WorkGroupSize() { return uint32_t3(uint32_t(ConstevalParamete
1414

1515
struct SharedMemoryAccessor
1616
{
17-
template <typename IndexType, typename AccessType>
17+
template <typename AccessType, typename IndexType>
1818
void set(IndexType idx, AccessType value)
1919
{
2020
sharedmem[idx] = value;
2121
}
2222

23-
template <typename IndexType, typename AccessType>
23+
template <typename AccessType, typename IndexType>
2424
void get(IndexType idx, NBL_REF_ARG(AccessType) value)
2525
{
2626
value = sharedmem[idx];
@@ -44,14 +44,14 @@ struct Accessor
4444
}
4545

4646
// TODO: can't use our own BDA yet, because it doesn't support the types `workgroup::FFT` will invoke these templates with
47-
template <typename AccessType>
48-
void get(const uint32_t index, NBL_REF_ARG(AccessType) value)
47+
template <typename AccessType, typename IndexType>
48+
void get(const IndexType index, NBL_REF_ARG(AccessType) value)
4949
{
5050
value = vk::RawBufferLoad<AccessType>(address + index * sizeof(AccessType));
5151
}
5252

53-
template <typename AccessType>
54-
void set(const uint32_t index, const AccessType value)
53+
template <typename AccessType, typename IndexType>
54+
void set(const IndexType index, const AccessType value)
5555
{
5656
vk::RawBufferStore<AccessType>(address + index * sizeof(AccessType), value);
5757
}

23_ArithmeticUnitTest/app_resources/common.hlsl renamed to 23_Arithmetic2UnitTest/app_resources/common.hlsl

Lines changed: 7 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -1,15 +1,14 @@
11
#include "nbl/builtin/hlsl/cpp_compat.hlsl"
22
#include "nbl/builtin/hlsl/functional.hlsl"
33

4-
template<uint32_t kScanElementCount=1024*1024>
5-
struct Output
4+
struct PushConstantData
65
{
7-
NBL_CONSTEXPR_STATIC_INLINE uint32_t ScanElementCount = kScanElementCount;
8-
9-
uint32_t subgroupSize;
10-
uint32_t data[ScanElementCount];
6+
uint64_t pInputBuf;
7+
uint64_t pOutputBuf[8];
118
};
129

10+
namespace arithmetic
11+
{
1312
// Thanks to our unified HLSL/C++ STD lib we're able to remove a whole load of code
1413
template<typename T>
1514
struct bit_and : nbl::hlsl::bit_and<T>
@@ -92,5 +91,6 @@ struct ballot : nbl::hlsl::plus<T>
9291
static inline constexpr const char* name = "bitcount";
9392
#endif
9493
};
94+
}
9595

96-
#include "nbl/builtin/hlsl/subgroup/basic.hlsl"
96+
#include "nbl/builtin/hlsl/glsl_compat/subgroup_basic.hlsl"
Lines changed: 19 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,19 @@
1+
#include "common.hlsl"
2+
3+
using namespace nbl;
4+
using namespace hlsl;
5+
6+
[[vk::push_constant]] PushConstantData pc;
7+
8+
// Compile-time capability flags for this shader build.
// `shaderSubgroupArithmetic` is true when TEST_NATIVE is defined and false otherwise.
// NOTE(review): presumably the arithmetic templates this type is passed to use the flag to pick
// native vs. emulated subgroup arithmetic - confirm against the library headers.
struct device_capabilities
9+
{
10+
#ifdef TEST_NATIVE
11+
NBL_CONSTEXPR_STATIC_INLINE bool shaderSubgroupArithmetic = true;
12+
#else
13+
NBL_CONSTEXPR_STATIC_INLINE bool shaderSubgroupArithmetic = false;
14+
#endif
15+
};
16+
17+
#ifndef OPERATION
18+
#error "Define OPERATION!"
19+
#endif
Lines changed: 55 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,55 @@
1+
#pragma shader_stage(compute)
2+
3+
#define operation_t nbl::hlsl::OPERATION
4+
5+
#include "nbl/builtin/hlsl/glsl_compat/core.hlsl"
6+
#include "nbl/builtin/hlsl/glsl_compat/subgroup_basic.hlsl"
7+
#include "nbl/builtin/hlsl/subgroup2/arithmetic_portability.hlsl"
8+
#include "nbl/builtin/hlsl/subgroup2/arithmetic_params.hlsl"
9+
10+
#include "shaderCommon.hlsl"
11+
#include "nbl/builtin/hlsl/workgroup2/basic.hlsl"
12+
13+
// Alias template that resolves to whatever the SUBGROUP_CONFIG_T macro expands to.
// NOTE(review): SUBGROUP_CONFIG_T is not defined in this file; presumably it is supplied by the
// build and its expansion refers to `Binop` and `device_capabilities` - confirm.
template<class Binop, class device_capabilities>
14+
using params_t = SUBGROUP_CONFIG_T;
15+
16+
// Per-invocation value type: a uint32_t vector whose component count is the ItemsPerInvocation
// member of params_t instantiated with bit_and<uint32_t>'s base_t.
typedef vector<uint32_t, params_t<typename arithmetic::bit_and<uint32_t>::base_t, device_capabilities>::ItemsPerInvocation> type_t;
17+
18+
// Returns the flat index used to address this invocation's element in the input/output buffers:
// the workgroup's X id times WORKGROUP_SIZE, plus workgroup::SubgroupContiguousIndex().
uint32_t globalIndex()
19+
{
20+
return glsl::gl_WorkGroupID().x*WORKGROUP_SIZE+workgroup::SubgroupContiguousIndex();
21+
}
22+
23+
// Runs `operation_t` (i.e. nbl::hlsl::OPERATION) for one binary operation on `sourceVal` and
// stores the result into that operation's output buffer.
//   Binop     - test binop wrapper; supplies `base_t` (the operator handed to the library) and
//               `BindingIndex` (which entry of pc.pOutputBuf receives the result).
//   sourceVal - this invocation's input items.
template<class Binop>
24+
static void subtest(NBL_CONST_REF_ARG(type_t) sourceVal)
25+
{
26+
const uint64_t outputBufAddr = pc.pOutputBuf[Binop::BindingIndex];
27+
28+
// The condition is wrapped in parentheses because it contains a top-level comma inside the
// template argument list, which a function-like macro would otherwise treat as an argument
// separator. The statement is also explicitly terminated so it stays valid no matter what
// `assert` expands to.
assert((glsl::gl_SubgroupSize() == params_t<typename Binop::base_t, device_capabilities>::config_t::Size));
29+
30+
operation_t<params_t<typename Binop::base_t, device_capabilities> > func;
31+
type_t val = func(sourceVal);
32+
33+
// Each invocation writes its result at globalIndex() * sizeof(type_t); the last argument is the
// store alignment in bytes.
vk::RawBufferStore<type_t>(outputBufAddr + sizeof(type_t) * globalIndex(), val, sizeof(uint32_t));
34+
}
35+
36+
// Loads this invocation's input value from the buffer at pc.pInputBuf (at byte offset
// globalIndex() * sizeof(type_t)), runs subtest() once for each of the seven binary operations
// on that same value, and returns the loaded value.
type_t test()
37+
{
38+
const uint32_t idx = globalIndex();
39+
type_t sourceVal = vk::RawBufferLoad<type_t>(pc.pInputBuf + idx * sizeof(type_t));
40+
41+
subtest<arithmetic::bit_and<uint32_t> >(sourceVal);
42+
subtest<arithmetic::bit_xor<uint32_t> >(sourceVal);
43+
subtest<arithmetic::bit_or<uint32_t> >(sourceVal);
44+
subtest<arithmetic::plus<uint32_t> >(sourceVal);
45+
subtest<arithmetic::multiplies<uint32_t> >(sourceVal);
46+
subtest<arithmetic::minimum<uint32_t> >(sourceVal);
47+
subtest<arithmetic::maximum<uint32_t> >(sourceVal);
48+
return sourceVal;
49+
}
50+
51+
// Compute entry point: WORKGROUP_SIZE invocations along X per workgroup.
// Calls test() and discards its return value.
[numthreads(WORKGROUP_SIZE,1,1)]
52+
void main()
53+
{
54+
test();
55+
}
Lines changed: 74 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,74 @@
1+
#pragma shader_stage(compute)
2+
3+
#include "nbl/builtin/hlsl/glsl_compat/core.hlsl"
4+
#include "nbl/builtin/hlsl/glsl_compat/subgroup_basic.hlsl"
5+
#include "nbl/builtin/hlsl/subgroup2/arithmetic_portability.hlsl"
6+
#include "nbl/builtin/hlsl/workgroup2/arithmetic.hlsl"
7+
8+
using config_t = WORKGROUP_CONFIG_T;
9+
10+
#include "shaderCommon.hlsl"
11+
12+
typedef vector<uint32_t, config_t::ItemsPerInvocation_0> type_t;
13+
14+
// final (level 1/2) scan needs to fit in one subgroup exactly
// Workgroup-shared (`groupshared`) uint32_t scratch array; its element count is
// mpl::max_v over config_t::SharedScratchElementCount and 1.
// NOTE(review): presumably max_v keeps the count at 1 or more so the array is never zero-sized - confirm.
15+
groupshared uint32_t scratch[mpl::max_v<int16_t,config_t::SharedScratchElementCount,1>];
16+
17+
#include "../../common/include/WorkgroupDataAccessors.hlsl"
18+
19+
static ScratchProxy arithmeticAccessor;
20+
21+
// Functor that runs the workgroup OPERATION for one binary operation.
//   Binop               - test binop wrapper; supplies `base_t`, `type_t` and `BindingIndex`.
//   device_capabilities - capability struct forwarded to OPERATION.
template<class Binop, class device_capabilities>
22+
struct operation_t
23+
{
24+
using binop_base_t = typename Binop::base_t;
25+
using otype_t = typename Binop::type_t;
26+
27+
// workgroup reduction returns the value of the reduction
28+
// workgroup scans do not return anything, but use the data accessor to do the storing directly
29+
void operator()()
30+
{
31+
using data_proxy_t = PreloadedDataProxy<config_t::WorkgroupSizeLog2,config_t::VirtualWorkgroupSize,config_t::ItemsPerInvocation_0>;
32+
// the accessor is created from the input buffer address and this binop's output buffer address
data_proxy_t dataAccessor = data_proxy_t::create(pc.pInputBuf, pc.pOutputBuf[Binop::BindingIndex]);
33+
dataAccessor.preload();
34+
// when IS_REDUCTION is set, the value returned by the call below is captured in `value`
#if IS_REDUCTION
35+
otype_t value =
36+
#endif
37+
OPERATION<config_t,binop_base_t,device_capabilities>::template __call<data_proxy_t, ScratchProxy>(dataAccessor,arithmeticAccessor);
38+
// we barrier before because we alias the accessors for Binop
39+
arithmeticAccessor.workgroupExecutionAndMemoryBarrier();
40+
#if IS_REDUCTION
41+
// overwrite every preloaded slot with the reduction result before unload()
[unroll]
42+
for (uint32_t i = 0; i < data_proxy_t::PreloadedDataCount; i++)
43+
dataAccessor.preloaded[i] = value;
44+
#endif
45+
dataAccessor.unload();
46+
}
47+
};
48+
49+
50+
// Checks the expected subgroup size, then constructs and invokes the workgroup `operation_t`
// functor for one binary operation.
//   Binop - test binop wrapper forwarded to operation_t.
template<class Binop>
51+
static void subtest()
52+
{
53+
// explicitly terminated so the statement stays valid no matter what `assert` expands to
assert(glsl::gl_SubgroupSize() == config_t::SubgroupSize);
54+
55+
operation_t<Binop,device_capabilities> func;
56+
func();
57+
}
58+
59+
// Runs subtest() once for each of the seven binary operations, in the order listed below.
void test()
60+
{
61+
subtest<arithmetic::bit_and<uint32_t> >();
62+
subtest<arithmetic::bit_xor<uint32_t> >();
63+
subtest<arithmetic::bit_or<uint32_t> >();
64+
subtest<arithmetic::plus<uint32_t> >();
65+
subtest<arithmetic::multiplies<uint32_t> >();
66+
subtest<arithmetic::minimum<uint32_t> >();
67+
subtest<arithmetic::maximum<uint32_t> >();
68+
}
69+
70+
// Compute entry point: config_t::WorkgroupSize invocations along X per workgroup.
// Simply calls test().
[numthreads(config_t::WorkgroupSize,1,1)]
71+
void main()
72+
{
73+
test();
74+
}

0 commit comments

Comments
 (0)