Skip to content

Commit a91a413

Browse files
authored
vulkan: optimize coopmat2 dequant functions (#10855)
Change the code to do 16b loads when possible and extract the appropriate component late, so the code is effectively decoding a pair of elements and then selecting one. This can allow more commoning to happen in the compiler when neighboring elements are loaded.
1 parent e34c5af commit a91a413

File tree

1 file changed

+45
-25
lines changed

1 file changed

+45
-25
lines changed

ggml/src/ggml-vulkan/vulkan-shaders/dequant_funcs_cm2.comp

+45-25
Original file line numberDiff line numberDiff line change
@@ -10,9 +10,10 @@ float16_t dequantFuncQ4_0(const in decodeBufQ4_0 bl, const in uint blockCoords[2
1010
const float16_t d = bl.block.d;
1111
const uint idx = coordInBlock[1];
1212
const uint shift = (idx & 0x10) >> 2;
13-
uint32_t qs = unpack8(uint32_t(bl.block.qs[(idx & 0xE) >> 1]))[idx & 1];
13+
uint32_t qs = uint32_t(bl.block.qs[(idx & 0xE) >> 1]);
1414
qs >>= shift;
15-
qs &= 0xF;
15+
qs &= 0x0F0F;
16+
qs = unpack8(qs)[idx & 1];
1617
float16_t ret = (float16_t(qs) - float16_t(8)) * d;
1718
return ret;
1819
}
@@ -152,15 +153,17 @@ layout(buffer_reference, std430, buffer_reference_align = 16) buffer decodeBufQ4
152153
block_q4_K block;
153154
};
154155

156+
layout(buffer_reference, std430, buffer_reference_align = 16) buffer decodeBufQ4_K_packed16 {
157+
block_q4_K_packed16 block;
158+
};
159+
155160
float16_t dequantFuncQ4_K(const in decodeBufQ4_K bl, const in uint blockCoords[2], const in uint coordInBlock[2])
156161
{
162+
decodeBufQ4_K_packed16 bl16 = decodeBufQ4_K_packed16(bl);
157163
const uint idx = coordInBlock[1];
158-
const uint iqs = idx;
159164

160-
const uint n = iqs / 64; // 0,1,2,3
161-
const uint b = (iqs % 64) / 32; // 0,1
165+
const uint b = (idx & 0x20) >> 5; // 0,1
162166
const uint is = (idx & 0xE0) >> 5; // 0..7
163-
const uint qsi = n * 32 + (iqs % 32); // 0..127
164167

165168
const f16vec2 loadd = bl.block.d;
166169

@@ -184,9 +187,11 @@ float16_t dequantFuncQ4_K(const in decodeBufQ4_K bl, const in uint blockCoords[2
184187
const float16_t d = loadd.x * float16_t(sc);
185188
const float16_t m = loadd.y * float16_t(mbyte);
186189

187-
uint32_t dmask = 0xF << (b * 4);
190+
uint qs = uint32_t(bl16.block.qs[((idx & 0xC0) >> 2) + ((idx & 0x1E) >> 1)]);
191+
qs = (qs >> (b * 4)) & 0x0F0F;
192+
qs = unpack8(qs)[idx & 1];
188193

189-
float16_t ret = d * float16_t((bl.block.qs[qsi ] & dmask) >> (b * 4)) - m;
194+
float16_t ret = d * float16_t(qs) - m;
190195

191196
return ret;
192197
}
@@ -195,18 +200,19 @@ layout(buffer_reference, std430, buffer_reference_align = 16) buffer decodeBufQ5
195200
block_q5_K block;
196201
};
197202

203+
layout(buffer_reference, std430, buffer_reference_align = 16) buffer decodeBufQ5_K_packed16 {
204+
block_q5_K_packed16 block;
205+
};
206+
198207
float16_t dequantFuncQ5_K(const in decodeBufQ5_K bl, const in uint blockCoords[2], const in uint coordInBlock[2])
199208
{
209+
decodeBufQ5_K_packed16 bl16 = decodeBufQ5_K_packed16(bl);
200210
const uint idx = coordInBlock[1];
201-
const uint iqs = idx;
202211

203-
const uint n = iqs / 64; // 0,1,2,3
204-
const uint b = (iqs % 64) / 32; // 0,1
212+
const uint b = (idx & 0x20) >> 5; // 0,1
205213
const uint is = (idx & 0xE0) >> 5; // 0..7
206-
const uint qsi = n * 32 + (iqs % 32); // 0..127
207-
const uint qhi = (iqs % 32); // 0..31
208214

209-
const uint8_t hm = uint8_t(1 << (iqs / 32));
215+
const uint32_t hm = 0x0101 << is;
210216

211217
const f16vec2 loadd = bl.block.d;
212218

@@ -230,9 +236,15 @@ float16_t dequantFuncQ5_K(const in decodeBufQ5_K bl, const in uint blockCoords[2
230236
const float16_t d = loadd.x * float16_t(sc);
231237
const float16_t m = loadd.y * float16_t(mbyte);
232238

233-
uint32_t dmask = 0xF << (b * 4);
239+
uint qh = uint32_t(bl16.block.qh[(idx & 0x1E) >> 1]);
240+
qh = qh & hm;
241+
qh = unpack8(qh)[idx & 1];
234242

235-
float16_t ret = d * (float16_t((bl.block.qs[qsi ] & dmask) >> (b * 4)) + float16_t((bl.block.qh[qhi ] & hm) != 0 ? 16 : 0)) - m;
243+
uint qs = uint32_t(bl16.block.qs[((idx & 0xC0) >> 2) + ((idx & 0x1E) >> 1)]);
244+
qs = (qs >> (b * 4)) & 0x0F0F;
245+
qs = unpack8(qs)[idx & 1];
246+
247+
float16_t ret = d * (float16_t(qs) + (qh != 0 ? float16_t(16) : float16_t(0))) - m;
236248

237249
return ret;
238250
}
@@ -241,22 +253,30 @@ layout(buffer_reference, std430, buffer_reference_align = 2) buffer decodeBufQ6_
241253
block_q6_K block;
242254
};
243255

256+
layout(buffer_reference, std430, buffer_reference_align = 16) buffer decodeBufQ6_K_packed16 {
257+
block_q6_K_packed16 block;
258+
};
259+
244260
float16_t dequantFuncQ6_K(const in decodeBufQ6_K bl, const in uint blockCoords[2], const in uint coordInBlock[2])
245261
{
262+
decodeBufQ6_K_packed16 bl16 = decodeBufQ6_K_packed16(bl);
246263
const uint idx = coordInBlock[1];
247-
const uint iqs = idx;
248264

249-
const uint n = iqs / 128; // 0,1
250-
const uint b = (iqs % 128) / 64; // 0,1
251-
const uint is_b = (iqs % 32) / 16; // 0,1
252-
const uint qhshift = ((iqs % 128) / 32) * 2;// 0,2,4,6
253-
const uint is = 8 * n + qhshift + is_b; // 0..15
254-
const uint qsi = n * 64 + (iqs % 64); // 0..127
255-
const uint qhi = n * 32 + (iqs % 32); // 0..63
265+
const uint b = (idx & 0x40) >> 6; // 0,1
266+
const uint qhshift = (idx & 0x60) >> 4; // 0,2,4,6
267+
const uint is = (idx & 0xF0) >> 4; // 0..15
256268

257269
const float16_t dscale = bl.block.d * float16_t(bl.block.scales[is]);
258270

259-
float16_t ret = dscale * float16_t(int8_t(((bl.block.ql[qsi ] >> (b * 4)) & 0xF) | (((bl.block.qh[qhi ] >> qhshift) & 3) << 4)) - 32);
271+
uint ql = uint32_t(bl16.block.ql[((idx & 0x80) >> 2) + ((idx & 0x3E) >> 1)]);
272+
ql = (ql >> (b * 4)) & 0x0F0F;
273+
274+
uint qh = uint32_t(bl16.block.qh[((idx & 0x80) >> 3) + ((idx & 0x1E) >> 1)]);
275+
qh = ((qh >> qhshift) & 0x0303) << 4;
276+
277+
int q = unpack8(ql | qh)[idx & 1];
278+
279+
float16_t ret = dscale * float16_t(q - 32);
260280

261281
return ret;
262282
}

0 commit comments

Comments
 (0)