@@ -3,12 +3,15 @@
 
 import { TensorView } from '../../tensor-view';
 import { createAttributeWithCacheKey } from '../attribute-with-cache-key';
-import { ComputeContext } from '../types';
+import { ComputeContext, ProgramInputTensorInfoDependency, ProgramUniform } from '../types';
+import { DataType } from '../../../wasm-common';
 
 import { applyAttention, AttentionMaskType, AttentionParameters, AttentionQkvFormat } from './attention';
 import { maybeTransposeToBNSHAndAddBias } from './multihead-attention';
 import { createSplitProgramInfo, SplitAttributes } from './split';
 import { createTransposeProgramInfo, TransposeAttributes } from './transpose';
+import { RotaryEmbeddingAttributes, createRotaryEmbeddingProgramInfo } from './rotary-embedding';
+import { inputVariable, outputVariable, ShaderHelper, UniformsArrayType } from './common';
 export interface GroupQueryAttentionAttributes {
   numHeads: number;
   kvNumHeads: number;
@@ -24,9 +27,6 @@ export const validateInputs = (
   inputs: readonly TensorView[],
   attributes: GroupQueryAttentionAttributes,
 ): AttentionParameters => {
-  if (attributes.doRotary) {
-    throw new Error('GroupQuerryAttention do_rotary attribute is not supported');
-  }
   if (attributes.doRotary && inputs.length <= 7) {
     throw new Error('cos_cache and sin_cache inputs are required if do_rotary is specified');
   }
@@ -35,6 +35,9 @@ export const validateInputs = (
   const value = inputs[2];
   const pastKey = inputs[3];
   const pastValue = inputs[4];
+  if (attributes.doRotary !== 0 && inputs.length <= 7) {
+    throw new Error('cos_cache and sin_cache are expected if do_rotary attribute is non-zero');
+  }
   if (attributes.localWindowSize !== -1) {
     throw new Error('Local attention is not supported');
   }
@@ -238,6 +241,77 @@ const maybeTransposeToBNSH = (context: ComputeContext, input: TensorView, params
   return reshapedInput;
 };
 
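+// Computes per-token position ids (int64, flattened to batch_size * sequence_length)
+// from the per-batch seq_lens tensor and the scalar total sequence length; the result
+// feeds the position_ids input of the rotary-embedding programs below.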
+const generatePositionIdsProgramInfo = (
+  batchSize: number,
+  sequenceLength: number,
+  seqLens: TensorView,
+  totalSeqLen: TensorView,
+) => {
+  const outputDataType = DataType.int64;
+  const inputDependencies: ProgramInputTensorInfoDependency[] = ['type', 'type'];
+  const outputShape = [batchSize * sequenceLength];
+  const outputSize = batchSize * sequenceLength;
+  const programUniforms: ProgramUniform[] = [
+    { type: DataType.uint32, data: outputSize },
+    { type: DataType.uint32, data: sequenceLength },
+    { type: DataType.uint32, data: batchSize },
+  ];
+  const getShaderSource = (shaderHelper: ShaderHelper) => {
+    const seqLensInputHelper = inputVariable('seq_lens', seqLens.dataType, seqLens.dims);
+    const totalSeqLenInputHelper = inputVariable('total_seq_lens', totalSeqLen.dataType, totalSeqLen.dims);
+    const positionIdsHelper = outputVariable('pos_ids', outputDataType, outputShape);
+
+    const uniforms: UniformsArrayType = [
+      { name: 'output_size', type: 'u32' },
+      { name: 'sequence_length', type: 'u32' },
+      { name: 'batch_size', type: 'u32' },
+    ];
+
+    return `
+  ${shaderHelper.registerUniforms(uniforms).declareVariables(seqLensInputHelper, totalSeqLenInputHelper, positionIdsHelper)}
+  ${shaderHelper.mainStart()}
+    ${shaderHelper.guardAgainstOutOfBoundsWorkgroupSizes('uniforms.output_size')}
+    let total_sequence_length = u32(${totalSeqLenInputHelper.getByOffset('0')});
+    let is_subsequent_prompt = uniforms.sequence_length > 1 && uniforms.sequence_length != total_sequence_length;
+    let is_first_prompt = !is_subsequent_prompt && uniforms.sequence_length == total_sequence_length;
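+    // First prompt: the whole sequence is new. Subsequent prompt: new tokens are
+    // appended to an existing KV cache. Otherwise this is single-token generation.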
+    let batch_idx = global_idx / uniforms.sequence_length;
+    let sequence_idx = i32(global_idx % uniforms.sequence_length);
+    var pos_id: i32 = 0;
+    let seqlen = ${seqLensInputHelper.getByOffset('batch_idx')};
+    let total_seqlen = seqlen + 1;
+    if (is_first_prompt) {
+      if (sequence_idx < total_seqlen) {
+        pos_id = sequence_idx;
+      } else {
+        pos_id = 1;
+      }
+      ${positionIdsHelper.setByOffset('global_idx', 'pos_id')}
+    } else if (is_subsequent_prompt) {
+      let past_seqlen = total_seqlen - i32(uniforms.sequence_length);
+      if (past_seqlen + sequence_idx < total_seqlen) {
+        pos_id = past_seqlen + sequence_idx;
+      } else {
+        pos_id = 1;
+      }
+      ${positionIdsHelper.setByOffset('global_idx', 'pos_id')}
+    } else if (global_idx < uniforms.batch_size) {
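+      // Token generation: a single new token per batch, whose position is the
+      // past sequence length stored in seq_lens.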
+      ${positionIdsHelper.setByOffset('global_idx', 'seqlen')}
+    };
+  }
+  `;
+  };
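+  // Cache key depends only on the output geometry; one invocation per output
+  // element, dispatched in workgroups of 64.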
+  return {
+    name: 'GeneratePositionIds',
+    shaderCache: { hint: `${batchSize};${sequenceLength}`, inputDependencies },
+    getRunData: () => ({
+      outputs: [{ dims: outputShape, dataType: outputDataType }],
+      dispatchGroup: { x: Math.ceil(outputSize / 64 /* workgroup size */) },
+      programUniforms,
+    }),
+    getShaderSource,
+  };
+};
+
 export const groupQueryAttention = (context: ComputeContext, attributes: GroupQueryAttentionAttributes): void => {
   const params = validateInputs(context.inputs, attributes);
   if (context.inputs[0].dims.length === 5) {
@@ -268,22 +342,57 @@ export const groupQueryAttention = (context: ComputeContext, attributes: GroupQu
     !k && !v
       ? context.compute(createSplitProgramInfo([q], splitAttributes), { inputs: [q], outputs: [-1, -1, -1] })
       : [q, k!, v!];
-
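+  // With do_rotary set, rotate Q and K before attention, using GPU-generated
+  // position ids derived from seq_lens and total_sequence_length.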
+  let qRotary: TensorView | undefined;
+  let kRotary: TensorView | undefined;
+  if (attributes.doRotary) {
+    const posIds = context.compute(
+      generatePositionIdsProgramInfo(params.batchSize, params.sequenceLength, seqLens!, totalSequenceLengthInput!),
+      { inputs: [seqLens!, totalSequenceLengthInput!], outputs: [-1] },
+    )[0];
+    const cosCache = context.inputs[7];
+    const sinCache = context.inputs[8];
+    const qRotaryEmbeddingAttributes: RotaryEmbeddingAttributes = createAttributeWithCacheKey({
+      interleaved: attributes.rotaryInterleaved !== 0,
+      numHeads: params.numHeads,
+      rotaryEmbeddingDim: 0,
+      scale: attributes.scale,
+    });
+    const inputs = [query, posIds, cosCache, sinCache];
+    const outputs = [-1];
+    qRotary = context.compute(createRotaryEmbeddingProgramInfo(inputs, qRotaryEmbeddingAttributes), {
+      inputs,
+      outputs,
+    })[0];
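+    // Reuse the same inputs array for K: swap the query tensor for the key tensor.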
+    inputs.splice(0, 1, key);
+    const kRotaryEmbeddingAttributes: RotaryEmbeddingAttributes = createAttributeWithCacheKey({
+      interleaved: attributes.rotaryInterleaved !== 0,
+      numHeads: params.kvNumHeads!,
+      rotaryEmbeddingDim: 0,
+      scale: attributes.scale,
+    });
+    kRotary = context.compute(createRotaryEmbeddingProgramInfo(inputs, kRotaryEmbeddingAttributes), {
+      inputs,
+      outputs,
+    })[0];
+  }
   const Q = maybeTransposeToBNSHAndAddBias(
     context,
     params.batchSize,
     params.numHeads,
     params.sequenceLength,
     params.headSize,
-    query,
+    attributes.doRotary ? qRotary! : query,
     undefined,
     0,
   );
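+  // K and V go through the same BNSH transpose; K comes from the rotary output
+  // when do_rotary is set.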
+  const K = maybeTransposeToBNSH(context, attributes.doRotary ? kRotary! : key, params);
+  const V = maybeTransposeToBNSH(context, value, params);
+
   applyAttention(
     context,
     Q,
-    maybeTransposeToBNSH(context, key, params),
-    maybeTransposeToBNSH(context, value, params),
+    K,
+    V,
     undefined,
     undefined,
     pastKey,