
Commit 333fbdb

[WebGPU/JSEP] Support group query attention do_rotary attribute (#23524)
1 parent 773bb4f commit 333fbdb
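In effect, the old hard error for do_rotary is removed: when the attribute is set, the kernel now generates per-token position ids from seqlens_k and total_sequence_length and applies rotary embedding to the query and key heads before the usual attention path. As a rough, self-contained sketch of the non-interleaved rotation itself (illustrative only; the commit performs this on the GPU through the rotary-embedding program, not with this helper):

```ts
// Hedged sketch: rotate one head vector using the cos/sin cache rows for its
// position id. The first and second halves of the vector are paired, which is
// the non-interleaved (rotary_interleaved = 0) layout.
function rotateHalf(x: Float32Array, cosRow: Float32Array, sinRow: Float32Array): Float32Array {
  const half = x.length / 2;
  const out = new Float32Array(x.length);
  for (let i = 0; i < half; i++) {
    out[i] = x[i] * cosRow[i] - x[i + half] * sinRow[i];
    out[i + half] = x[i] * sinRow[i] + x[i + half] * cosRow[i];
  }
  return out;
}
```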

File tree

2 files changed: +118, -9 lines

js/web/lib/wasm/jsep/webgpu/ops/group-query-attention.ts

Lines changed: 117 additions & 8 deletions
@@ -3,12 +3,15 @@

 import { TensorView } from '../../tensor-view';
 import { createAttributeWithCacheKey } from '../attribute-with-cache-key';
-import { ComputeContext } from '../types';
+import { ComputeContext, ProgramInputTensorInfoDependency, ProgramUniform } from '../types';
+import { DataType } from '../../../wasm-common';

 import { applyAttention, AttentionMaskType, AttentionParameters, AttentionQkvFormat } from './attention';
 import { maybeTransposeToBNSHAndAddBias } from './multihead-attention';
 import { createSplitProgramInfo, SplitAttributes } from './split';
 import { createTransposeProgramInfo, TransposeAttributes } from './transpose';
+import { RotaryEmbeddingAttributes, createRotaryEmbeddingProgramInfo } from './rotary-embedding';
+import { inputVariable, outputVariable, ShaderHelper, UniformsArrayType } from './common';
 export interface GroupQueryAttentionAttributes {
   numHeads: number;
   kvNumHeads: number;
@@ -24,9 +27,6 @@ export const validateInputs = (
   inputs: readonly TensorView[],
   attributes: GroupQueryAttentionAttributes,
 ): AttentionParameters => {
-  if (attributes.doRotary) {
-    throw new Error('GroupQuerryAttention do_rotary attribute is not supported');
-  }
   if (attributes.doRotary && inputs.length <= 7) {
     throw new Error('cos_cache and sin_cache inputs are required if do_rotary is specified');
   }
@@ -35,6 +35,9 @@ export const validateInputs = (
   const value = inputs[2];
   const pastKey = inputs[3];
   const pastValue = inputs[4];
+  if (attributes.doRotary !== 0 && inputs.length <= 7) {
+    throw new Error('cos_cast and sin_cache are expected if do_rotary attribute is non-zero');
+  }
  if (attributes.localWindowSize !== -1) {
    throw new Error('Local attention is not supported');
  }
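For context, the guard above assumes the GroupQueryAttention contrib-op input order that this file indexes into elsewhere, with cos_cache and sin_cache at positions 7 and 8 (hence the `inputs.length <= 7` check). The enum below is only an illustrative summary under that assumption, not code from the file; the names are hypothetical.

```ts
// Illustrative summary of the assumed input order; only indices 7 and 8
// (cos_cache/sin_cache) are read directly by the code added in this commit.
const enum GqaInput {
  Query = 0,
  Key = 1,
  Value = 2,
  PastKey = 3,
  PastValue = 4,
  SeqLensK = 5,
  TotalSequenceLength = 6,
  CosCache = 7,
  SinCache = 8,
}
```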
@@ -238,6 +241,77 @@ const maybeTransposeToBNSH = (context: ComputeContext, input: TensorView, params
   return reshapedInput;
 };

+const generatePositionIdsProgramInfo = (
+  batchSize: number,
+  sequenceLength: number,
+  seqLens: TensorView,
+  totalSeqLen: TensorView,
+) => {
+  const outputDataType = DataType.int64;
+  const inputDependencies: ProgramInputTensorInfoDependency[] = ['type', 'type'];
+  const outputShape = [batchSize * sequenceLength];
+  const outputSize = batchSize * sequenceLength;
+  const programUniforms: ProgramUniform[] = [
+    { type: DataType.uint32, data: outputSize },
+    { type: DataType.uint32, data: sequenceLength },
+    { type: DataType.uint32, data: batchSize },
+  ];
+  const getShaderSource = (shaderHelper: ShaderHelper) => {
+    const seqLensInputHelper = inputVariable('seq_lens', seqLens.dataType, seqLens.dims);
+    const totalSeqLenInputHelper = inputVariable('total_seq_lens', totalSeqLen.dataType, totalSeqLen.dims);
+    const positionIdsHelper = outputVariable('pos_ids', outputDataType, outputShape);
+
+    const uniforms: UniformsArrayType = [
+      { name: 'output_size', type: 'u32' },
+      { name: 'sequence_length', type: 'u32' },
+      { name: 'batch_size', type: 'u32' },
+    ];
+
+    return `
+  ${shaderHelper.registerUniforms(uniforms).declareVariables(seqLensInputHelper, totalSeqLenInputHelper, positionIdsHelper)}
+  ${shaderHelper.mainStart()}
+    ${shaderHelper.guardAgainstOutOfBoundsWorkgroupSizes('uniforms.output_size')}
+    let total_sequence_length = u32(${totalSeqLenInputHelper.getByOffset('0')});
+    let is_subsequent_prompt = uniforms.sequence_length > 1 && uniforms.sequence_length != total_sequence_length;
+    let is_first_prompt = !is_subsequent_prompt && uniforms.sequence_length == total_sequence_length;
+    let batch_idx = global_idx / uniforms.sequence_length;
+    let sequence_idx = i32(global_idx % uniforms.sequence_length);
+    var pos_id: i32 = 0;
+    let seqlen = ${seqLensInputHelper.getByOffset('batch_idx')};
+    let total_seqlen = seqlen + 1;
+    if (is_first_prompt) {
+      if (sequence_idx < total_seqlen) {
+        pos_id = sequence_idx;
+      } else {
+        pos_id = 1;
+      }
+      ${positionIdsHelper.setByOffset('global_idx', 'pos_id')}
+    } else if (is_subsequent_prompt) {
+      let past_seqlen = total_seqlen - i32(uniforms.sequence_length);
+      if (past_seqlen + sequence_idx < total_seqlen) {
+        pos_id = past_seqlen + sequence_idx;
+      } else {
+        pos_id = 1;
+      }
+      ${positionIdsHelper.setByOffset('global_idx', 'pos_id')}
+    } else if (global_idx < uniforms.batch_size) {
+      ${positionIdsHelper.setByOffset('global_idx', 'seqlen')}
+    };
+  }
+`;
+  };
+  return {
+    name: 'GeneratePositionIds',
+    shaderCache: { hint: `${batchSize};${sequenceLength}`, inputDependencies },
+    getRunData: () => ({
+      outputs: [{ dims: outputShape, dataType: outputDataType }],
+      dispatchGroup: { x: Math.ceil(outputSize / 64 /* workgroup size */) },
+      programUniforms,
+    }),
+    getShaderSource,
+  };
+};
+
 export const groupQueryAttention = (context: ComputeContext, attributes: GroupQueryAttentionAttributes): void => {
   const params = validateInputs(context.inputs, attributes);
   if (context.inputs[0].dims.length === 5) {
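The WGSL above distinguishes a first prompt, a subsequent (chunked) prompt, and single-token generation when assigning position ids. A plain TypeScript transliteration of that per-element logic, added here only for readability and not part of the commit:

```ts
// seqLens holds the per-batch seqlens_k values; totalSeqLen is the scalar
// total_sequence_length input. Output dtype is int64, mirroring the shader.
function generatePositionIds(
  batchSize: number,
  sequenceLength: number,
  seqLens: Int32Array,
  totalSeqLen: number,
): BigInt64Array {
  const posIds = new BigInt64Array(batchSize * sequenceLength);
  const isSubsequentPrompt = sequenceLength > 1 && sequenceLength !== totalSeqLen;
  const isFirstPrompt = !isSubsequentPrompt && sequenceLength === totalSeqLen;
  for (let idx = 0; idx < posIds.length; idx++) {
    const batchIdx = Math.floor(idx / sequenceLength);
    const sequenceIdx = idx % sequenceLength;
    const totalSeqLenForBatch = seqLens[batchIdx] + 1;
    if (isFirstPrompt) {
      posIds[idx] = BigInt(sequenceIdx < totalSeqLenForBatch ? sequenceIdx : 1);
    } else if (isSubsequentPrompt) {
      const pastSeqLen = totalSeqLenForBatch - sequenceLength;
      posIds[idx] = BigInt(pastSeqLen + sequenceIdx < totalSeqLenForBatch ? pastSeqLen + sequenceIdx : 1);
    } else if (idx < batchSize) {
      // Token generation: sequence_length === 1, so the position id is the past length.
      posIds[idx] = BigInt(seqLens[batchIdx]);
    }
  }
  return posIds;
}
```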
@@ -268,22 +342,57 @@ export const groupQueryAttention = (context: ComputeContext, attributes: GroupQu
     !k && !v
       ? context.compute(createSplitProgramInfo([q], splitAttributes), { inputs: [q], outputs: [-1, -1, -1] })
       : [q, k!, v!];
-
+  let qRotary: TensorView | undefined;
+  let kRotary: TensorView | undefined;
+  if (attributes.doRotary) {
+    const posIds = context.compute(
+      generatePositionIdsProgramInfo(params.batchSize, params.sequenceLength, seqLens!, totalSequenceLengthInput!),
+      { inputs: [seqLens!, totalSequenceLengthInput!], outputs: [-1] },
+    )[0];
+    const cosCache = context.inputs[7];
+    const sinCache = context.inputs[8];
+    const qRotaryEmbeddingAttributes: RotaryEmbeddingAttributes = createAttributeWithCacheKey({
+      interleaved: attributes.rotaryInterleaved !== 0,
+      numHeads: params.numHeads,
+      rotaryEmbeddingDim: 0,
+      scale: attributes.scale,
+    });
+    const inputs = [query, posIds, cosCache, sinCache];
+    const outputs = [-1];
+    qRotary = context.compute(createRotaryEmbeddingProgramInfo(inputs, qRotaryEmbeddingAttributes), {
+      inputs,
+      outputs,
+    })[0];
+    inputs.splice(0, 1, key);
+    const kRotaryEmbeddingAttributes: RotaryEmbeddingAttributes = createAttributeWithCacheKey({
+      interleaved: attributes.rotaryInterleaved !== 0,
+      numHeads: params.kvNumHeads!,
+      rotaryEmbeddingDim: 0,
+      scale: attributes.scale,
+    });
+    kRotary = context.compute(createRotaryEmbeddingProgramInfo(inputs, kRotaryEmbeddingAttributes), {
+      inputs,
+      outputs,
+    })[0];
+  }
   const Q = maybeTransposeToBNSHAndAddBias(
     context,
     params.batchSize,
     params.numHeads,
     params.sequenceLength,
     params.headSize,
-    query,
+    attributes.doRotary ? qRotary! : query,
     undefined,
     0,
   );
+  const K = maybeTransposeToBNSH(context, attributes.doRotary ? kRotary! : key, params);
+  const V = maybeTransposeToBNSH(context, value, params);
+
   applyAttention(
     context,
     Q,
-    maybeTransposeToBNSH(context, key, params),
-    maybeTransposeToBNSH(context, value, params),
+    K,
+    V,
     undefined,
     undefined,
     pastKey,
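Note how the same rotary program is reused for the key tensor by splicing `key` into the first input slot and switching `numHeads` to `kvNumHeads`. When rotary_interleaved is non-zero, the rotation pairs adjacent elements instead of the two halves of the head vector; a hedged sketch of that variant, mirroring the rotateHalf sketch above and again purely illustrative:

```ts
// Interleaved (rotary_interleaved != 0) variant: elements (x[2i], x[2i+1]) are
// rotated as a pair using the cos/sin values for dimension i.
function rotateInterleaved(x: Float32Array, cosRow: Float32Array, sinRow: Float32Array): Float32Array {
  const out = new Float32Array(x.length);
  for (let i = 0; i < x.length / 2; i++) {
    out[2 * i] = x[2 * i] * cosRow[i] - x[2 * i + 1] * sinRow[i];
    out[2 * i + 1] = x[2 * i] * sinRow[i] + x[2 * i + 1] * cosRow[i];
  }
  return out;
}
```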

js/web/lib/wasm/jsep/webgpu/ops/rotary-embedding.ts

Lines changed: 1 addition & 1 deletion
@@ -75,7 +75,7 @@ const validateInputs = (inputs: readonly TensorView[], attributes: RotaryEmbeddi
   }
 };

-const createRotaryEmbeddingProgramInfo = (
+export const createRotaryEmbeddingProgramInfo = (
   inputs: readonly TensorView[],
   attributes: RotaryEmbeddingAttributes,
 ): ProgramInfo => {
