vulkan: Implement grouped query attention in the coopmat2 FA shader
When adjacent batches of Q share the same batches of K/V, batch them into
the same workgroup. For example, when:
dst(128,32,1,1) = FA(q(128,1,32,1), k(128,16640,8,1), v(128,16640,8,1))
previously we would run 32 workgroups computing 1 result each; now we will
run 8 workgroups computing 4 results each.
This doesn't directly translate to better performance (at least when you have
>=32 SMs), but in a subsequent change I'll enable split_k which will scale much
better with 4x fewer workgroups.
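
The workgroup arithmetic in the example above can be sketched on the host side as follows. This is only an illustration: `gqa_ratio` and `grouped_workgroups` are hypothetical helper names, not functions from this commit, and the sketch assumes the Q head count is an exact multiple of the K/V head count.

```cpp
#include <cassert>
#include <cstdint>

// Hypothetical helper (not from the commit): how many adjacent Q heads
// share one K/V head. Assumes n_q_heads is a multiple of n_kv_heads.
uint32_t gqa_ratio(uint32_t n_q_heads, uint32_t n_kv_heads) {
    assert(n_kv_heads != 0 && n_q_heads % n_kv_heads == 0);
    return n_q_heads / n_kv_heads;
}

// Hypothetical helper: workgroups along the head dimension once the Q heads
// that share a K/V head are batched into a single workgroup.
uint32_t grouped_workgroups(uint32_t n_q_heads, uint32_t n_kv_heads) {
    return n_q_heads / gqa_ratio(n_q_heads, n_kv_heads);
}
```

For q(128,1,32,1) and k(128,16640,8,1) there are 32 Q heads and 8 K/V heads, so `gqa_ratio(32, 8)` is 4 and `grouped_workgroups(32, 8)` is 8, matching the 8 workgroups computing 4 results each described above.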
@@ -103,6 +105,28 @@ ACC_TYPE Max(const in uint32_t row, const in uint32_t col, const in ACC_TYPE ele
 #define DECODEFUNC
 #endif
 
+// Store the output when doing grouped query attention.
+// Rows index by Q's dimension 2, and the first N rows are valid.
+D_TYPE perElemOpGqaStore(const in uint32_t r, const in uint32_t c, const in D_TYPE elem, const in uint32_t o_offset, const in uint32_t iq2, const in uint32_t N)
+{
+    if (r < N && c < D) {
+        uint32_t offset = (iq2 + r) * D + c;
+        data_o[o_offset + offset] = D_TYPE(elem);
+    }
+    return elem;
+}
+
+// Load the slope matrix, indexed by Q's dimension 2.
+ACC_TYPE perElemOpComputeSlope(const in uint32_t r, const in uint32_t c, const in ACC_TYPE elem, const in uint32_t iq2)
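
To make the indexing in perElemOpGqaStore easier to follow, here is a rough CPU emulation of the guarded store. It is a sketch under stated assumptions, not the shader's actual code path: `gqa_store_tile`, `tile`, `tile_rows`, and `tile_cols` are hypothetical names, the tile is assumed to be row-major, and `float` stands in for D_TYPE. Only `r`, `c`, `o_offset`, `iq2`, `N`, and `D` correspond to names in the diff.

```cpp
#include <cstdint>
#include <vector>

// Hypothetical CPU emulation of the per-element store from the diff.
// Each tile row r maps to Q head (iq2 + r); rows at or beyond N and columns
// at or beyond D are padding and are skipped.
void gqa_store_tile(std::vector<float>& data_o,
                    const std::vector<float>& tile, // row-major, tile_rows x tile_cols
                    uint32_t tile_rows, uint32_t tile_cols,
                    uint32_t D, uint32_t o_offset, uint32_t iq2, uint32_t N) {
    for (uint32_t r = 0; r < tile_rows; ++r) {
        for (uint32_t c = 0; c < tile_cols; ++c) {
            if (r < N && c < D) {
                // Same offset computation as in perElemOpGqaStore.
                uint32_t offset = (iq2 + r) * D + c;
                data_o[o_offset + offset] = tile[r * tile_cols + c];
            }
        }
    }
}
```

With D = 4, iq2 = 2, and N = 2, a 4x8 tile writes only output elements 8 through 15 (heads 2 and 3); the padded rows and columns never touch data_o.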