Skip to content

Commit 02a1115

Browse files
committed
[ET-VK] Add coop shader for int8 linear
Pull Request resolved: #10304

Title says it all!

## Changes

Add some utility functions to `ComputeGraph` to get the device name and look for substrings within the device name. Apply the co-operative shader for vector * matrix computations, except on Adreno 702, for which it performs worse as determined by experimentation.

ghstack-source-id: 279884141
Differential Revision: [D73279548](https://our.internmc.facebook.com/intern/diff/D73279548/)
1 parent a17f098 commit 02a1115

File tree

7 files changed

+196
-2
lines changed

7 files changed

+196
-2
lines changed

backends/vulkan/runtime/graph/ComputeGraph.cpp

+5
Original file line numberDiff line numberDiff line change
@@ -179,6 +179,11 @@ utils::GPUMemoryLayout ComputeGraph::suggested_memory_layout(
179179
return utils::kChannelsPacked;
180180
}
181181

182+
bool ComputeGraph::device_name_contains(const char* substr) {
183+
return context_->adapter_ptr()->device_name().find(substr) !=
184+
std::string::npos;
185+
}
186+
182187
void ComputeGraph::check_no_active_value_ptrs() {
183188
VK_CHECK_COND(
184189
values_in_use_ == 0,

backends/vulkan/runtime/graph/ComputeGraph.h

+9
Original file line numberDiff line numberDiff line change
@@ -443,6 +443,15 @@ class ComputeGraph final {
443443
utils::GPUMemoryLayout suggested_memory_layout(
444444
const std::vector<int64_t>& sizes);
445445

446+
inline bool device_is_adreno() {
447+
return context_->adapter_ptr()->device_type() == vkapi::DeviceType::ADRENO;
448+
}
449+
const std::string& device_name() {
450+
return context()->adapter_ptr()->device_name();
451+
}
452+
453+
bool device_name_contains(const char* substr);
454+
446455
//
447456
// Graph Building
448457
//
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,122 @@
1+
/*
2+
* Copyright (c) Meta Platforms, Inc. and affiliates.
3+
* All rights reserved.
4+
*
5+
* This source code is licensed under the BSD-style license found in the
6+
* LICENSE file in the root directory of this source tree.
7+
*/
8+
9+
#version 450 core
10+
11+
#define PRECISION ${PRECISION}
12+
13+
#define T ${buffer_scalar_type(DTYPE)}
14+
#define VEC4_T ${buffer_gvec_type(DTYPE, 4)}
15+
16+
#define TILE_ROWS ${TILE_ROWS}
17+
18+
#define NGROUPS 8
19+
#define NWORKERS 8
20+
21+
${define_required_extensions(DTYPE)}
22+
23+
$if WEIGHT_STORAGE == "buffer":
24+
${define_required_extensions("int8")}
25+
26+
#extension GL_EXT_control_flow_attributes : require
27+
28+
layout(std430) buffer;
29+
30+
${layout_declare_tensor(B, "w", "t_out", DTYPE, OUT_STORAGE, is_scalar_array=False)}
31+
${layout_declare_tensor(B, "r", "t_in", DTYPE, IN_STORAGE, is_scalar_array=False)}
32+
${layout_declare_tensor(B, "r", "t_weight", "int8", WEIGHT_STORAGE, is_scalar_array=False)}
33+
${layout_declare_tensor(B, "r", "t_scales", DTYPE, SCALES_STORAGE, is_scalar_array=False)}
34+
35+
layout(push_constant) uniform restrict Block {
36+
ivec4 out_sizes;
37+
ivec4 in_sizes;
38+
ivec4 weight_sizes;
39+
};
40+
41+
layout(local_size_x_id = 0, local_size_y_id = 1, local_size_z_id = 2) in;
42+
43+
shared VEC4_T partial_c[NGROUPS][NWORKERS][TILE_ROWS];
44+
45+
void main() {
46+
const uint out_row = gl_GlobalInvocationID.y * TILE_ROWS;
47+
const uint out_col = gl_GlobalInvocationID.x << 2;
48+
49+
const int gid = int(gl_LocalInvocationID.x); // group id
50+
const int wid = int(gl_LocalInvocationID.z); // worker id
51+
52+
if (out_col >= out_sizes.x || out_row >= out_sizes.y) {
53+
return;
54+
}
55+
56+
VEC4_T a[TILE_ROWS];
57+
VEC4_T b[4];
58+
VEC4_T local_c[TILE_ROWS];
59+
60+
[[unroll]] for (int i = 0; i < TILE_ROWS; ++i) {
61+
local_c[i] = VEC4_T(0.0);
62+
}
63+
64+
$if SCALES_STORAGE == "buffer":
65+
const VEC4_T scales = VEC4_T(t_scales[out_col >> 2]);
66+
$else:
67+
const VEC4_T scales = VEC4_T(texelFetch(t_scales, ivec2(out_col >> 2, 0), 0));
68+
69+
for (int pos = 4 * wid; pos < in_sizes.x; pos += (4 * NWORKERS)) {
70+
// Preload t_weight
71+
[[unroll]] for (int i = 0; i < 4; i++) {
72+
$if WEIGHT_STORAGE == "buffer":
73+
b[i] = t_weight[((pos + i) * weight_sizes.x + out_col) >> 2];
74+
$else:
75+
b[i] = VEC4_T(texelFetch(t_weight, ivec2(out_col >> 2, pos + i), 0));
76+
}
77+
// Preload t_in
78+
for (int i = 0; i < TILE_ROWS; i++) {
79+
$if IN_STORAGE == "buffer":
80+
a[i] = t_in[((out_row + i) * in_sizes.x + pos) >> 2];
81+
$else:
82+
a[i] = VEC4_T(texelFetch(t_in, ivec3(pos >> 2, out_row + i, 0), 0));
83+
}
84+
85+
// Accumulate partial output
86+
[[unroll]] for (int i = 0; i < TILE_ROWS; ++i) {
87+
local_c[i] += a[i].x * b[0] +
88+
a[i].y * b[1] +
89+
a[i].z * b[2] +
90+
a[i].w * b[3];
91+
}
92+
}
93+
94+
[[unroll]] for (int i = 0; i < TILE_ROWS; ++i) {
95+
partial_c[gid][wid][i] = local_c[i];
96+
}
97+
98+
memoryBarrierShared();
99+
barrier();
100+
101+
if (wid != 0) {
102+
return;
103+
}
104+
105+
VEC4_T c[TILE_ROWS];
106+
107+
for (int row = 0; row < TILE_ROWS; ++row) {
108+
c[row] = VEC4_T(0.0);
109+
[[unroll]] for (int worker = 0; worker < NWORKERS; ++worker) {
110+
c[row] += partial_c[gid][worker][row];
111+
}
112+
}
113+
114+
[[unroll]] for (int i = 0; i < TILE_ROWS; ++i) {
115+
$if OUT_STORAGE == "buffer":
116+
if (out_row + i < out_sizes.y) {
117+
t_out[((out_row + i) * out_sizes.x + out_col) >> 2] = c[i] * scales;
118+
}
119+
$else:
120+
imageStore(t_out, ivec3(out_col >> 2, out_row + i, 0), c[i] * scales);
121+
}
122+
}
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,28 @@
1+
# Copyright (c) Meta Platforms, Inc. and affiliates.
2+
# All rights reserved.
3+
#
4+
# This source code is licensed under the BSD-style license found in the
5+
# LICENSE file in the root directory of this source tree.
6+
7+
q_8w_linear_coop:
8+
parameter_names_with_default_values:
9+
DTYPE: float
10+
IN_STORAGE: texture3d
11+
OUT_STORAGE: texture3d
12+
WEIGHT_STORAGE: texture2d
13+
SCALES_STORAGE: texture2d
14+
TILE_ROWS: 4
15+
generate_variant_forall:
16+
TILE_ROWS:
17+
- VALUE: 1
18+
SUFFIX: o4x1
19+
shader_variants:
20+
- NAME: q_8w_linear_coop_texture3d_texture3d_texture2d_texture2d_float
21+
- NAME: q_8w_linear_coop_buffer_buffer_texture2d_texture2d_float
22+
IN_STORAGE: buffer
23+
OUT_STORAGE: buffer
24+
- NAME: q_8w_linear_coop_buffer_buffer_buffer_buffer_float
25+
IN_STORAGE: buffer
26+
OUT_STORAGE: buffer
27+
WEIGHT_STORAGE: buffer
28+
SCALES_STORAGE: buffer

backends/vulkan/runtime/graph/ops/impl/QuantizedLinearInt8.cpp

+20-2
Original file line numberDiff line numberDiff line change
@@ -142,6 +142,7 @@ void add_q_8w_linear_node(
142142

143143
void add_q_8w_linear_tiled_node(
144144
ComputeGraph& graph,
145+
const bool use_coop_algorithm,
145146
const ValueRef mat1,
146147
const ValueRef q_mat2_data,
147148
const ValueRef scales_data,
@@ -168,7 +169,8 @@ void add_q_8w_linear_tiled_node(
168169
ValueRef scales =
169170
prepack_standard(graph, scales_data, scales_storage, utils::kWidthPacked);
170171

171-
std::string kernel_name = "q_8w_linear_tiled";
172+
std::string kernel_name =
173+
use_coop_algorithm ? "q_8w_linear_coop" : "q_8w_linear_tiled";
172174
kernel_name.reserve(kShaderNameReserve);
173175
add_storage_type_suffix(kernel_name, graph.storage_type_of(out));
174176
add_storage_type_suffix(kernel_name, graph.storage_type_of(mat1));
@@ -197,6 +199,9 @@ void add_q_8w_linear_tiled_node(
197199
global_wg_size[1] = global_wg_size[1] / out_tile_nrows;
198200

199201
utils::uvec3 local_wg_size{64, 1, 1};
202+
if (use_coop_algorithm) {
203+
local_wg_size = {8, 1, 8};
204+
}
200205

201206
graph.execute_nodes().emplace_back(new DispatchNode(
202207
graph,
@@ -257,13 +262,26 @@ bool can_use_tiled_impl(
257262
return true;
258263
}
259264

265+
bool can_use_coop_impl(ComputeGraph& graph, const ValueRef mat1) {
266+
// Do not use coop algorithm for Adreno 702; manual experimentation shows that
267+
// it performs worse than the tiled algorithm.
268+
// TODO(ssjia): Determine a more robust heuristic to determine when the coop
269+
// algorithm should be used, instead of depending on specific device identity.
270+
if (graph.device_is_adreno() && graph.device_name_contains("702")) {
271+
return false;
272+
}
273+
// Check that the computation is vector * matrix
274+
return (graph.size_at<int>(-2, mat1) == 1);
275+
}
276+
260277
void weight_int8pack_mm(
261278
ComputeGraph& graph,
262279
const std::vector<ValueRef>& args) {
263280
check_q_8w_linear_args(graph, args[0], args[1], args[2], args[3]);
264281
if (can_use_tiled_impl(graph, args[0], args[1], args[2], args[3])) {
282+
bool use_coop_algorithm = can_use_coop_impl(graph, args[0]);
265283
return add_q_8w_linear_tiled_node(
266-
graph, args[0], args[1], args[2], args[3]);
284+
graph, use_coop_algorithm, args[0], args[1], args[2], args[3]);
267285
}
268286
return add_q_8w_linear_node(graph, args[0], args[1], args[2], args[3]);
269287
}

backends/vulkan/runtime/vk_api/Adapter.h

+9
Original file line numberDiff line numberDiff line change
@@ -122,6 +122,15 @@ class Adapter final {
122122
return physical_device_.timestamp_period;
123123
}
124124

125+
// Device Identity
126+
inline const std::string& device_name() const {
127+
return physical_device_.device_name;
128+
}
129+
130+
inline vkapi::DeviceType device_type() const {
131+
return physical_device_.device_type;
132+
}
133+
125134
// Queue Management
126135

127136
Queue request_queue();

backends/vulkan/test/op_tests/cases.py

+3
Original file line numberDiff line numberDiff line change
@@ -152,6 +152,9 @@ def get_linear_inputs():
152152
@register_test_suite("aten._weight_int8pack_mm.default")
153153
def get_weight_int8pack_mm_inputs():
154154
MKN_list = [
155+
[1, 480, 256],
156+
[1, 1024, 1024],
157+
[1, 1024, 256],
155158
[3, 480, 256],
156159
[6, 480, 256],
157160
[6, 256, 1024],

0 commit comments

Comments (0)