Commit b98aae8

Update GQA lowering and add lit test for GQA with rotary embedding
1 parent bdb23be · commit b98aae8

4 files changed (+258 −145 lines)

lib/Conversion/TorchOnnxToTorch/DefaultDomainGtoP.cpp

Lines changed: 125 additions & 97 deletions
@@ -3968,6 +3968,7 @@ void mlir::torch::onnx_c::populateDefaultDomainGtoP(
       Value pastKey = operands[3];
       Value pastValue = operands[4];
       Value seqlensK = operands[5];
+      Value totalSequenceLength = operands[6];
       Value cosCache, sinCache;
       if (doRotary) {
         cosCache = operands[7];
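
For context, the operand indices above follow the input order of the ONNX Runtime contrib op GroupQueryAttention. A minimal sketch of that assumed layout (indices 0–2 and sin_cache are inferred from the surrounding lowering, not shown in this hunk):

# Assumed GroupQueryAttention input order (ONNX Runtime contrib op);
# cos_cache/sin_cache are only consumed when the do_rotary attribute is set.
GQA_INPUTS = [
    "query",                  # operands[0]
    "key",                    # operands[1]
    "value",                  # operands[2]
    "past_key",               # operands[3]
    "past_value",             # operands[4]
    "seqlens_k",              # operands[5]
    "total_sequence_length",  # operands[6]
    "cos_cache",              # operands[7]
    "sin_cache",              # operands[8]
]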
@@ -4053,25 +4054,39 @@ void mlir::torch::onnx_c::populateDefaultDomainGtoP(

       Value qRotary = qInput, kRotary = kInput;
       if (doRotary) {
-        bool isFirstPrompt = false;
-        // TODO: add the rotary embedding support
+        // `totalSequenceLength` is a scalar tensor.
+        Value scalarTotalSeqLens = rewriter.create<Torch::AtenItemOp>(
+            loc, rewriter.getType<Torch::IntType>(), totalSequenceLength);
+        Value cstIntOne = rewriter.create<Torch::ConstantIntOp>(
+            binder.getLoc(), rewriter.getI64IntegerAttr(1));
+        Type boolTy = rewriter.getType<Torch::BoolType>();
+        Value condA = rewriter.create<Torch::AtenGtIntOp>(
+            loc, boolTy, cstSequenceLength, cstIntOne);
+        Value condB = rewriter.create<Torch::AtenNeIntOp>(
+            loc, boolTy, cstSequenceLength, scalarTotalSeqLens);
+        // if (sequence_length > 1 && sequence_length !=
+        //     total_sequence_length)
+        //   is_subsequent_prompt = true; // Subsequent prompt
+        Value isSubsequentPrompt = rewriter.create<Torch::Aten__And__BoolOp>(
+            loc, boolTy, condA, condB);

         // Generating position_ids for rotary_embedding as follows:
-        // if is_first_prompt:
-        //   pos_ids = torch.tensor([0], dtype=torch.int64)
-        // else:
+        // pos_ids_a = torch.zeros((batch_size, seq_len), dtype=torch.int64)
+        //
         // total_seqlens = seqlens_k + 1
         // past_seqlens = total_seqlens - sequence_length
         // pos_ids = torch.arange(sequence_length,
         //                        dtype=torch.int64).repeat(batch_size, 1)
         // pos_ids = pos_ids + past_seqlens.view(-1, 1)
         // cond = pos_ids < total_seqlens.view(-1, 1)
         // one_tensor = torch.tensor(1, dtype=torch.int64)
-        // pos_ids = torch.where(cond, pos_ids, one_tensor)
-        SmallVector<int64_t> positionIdsSizeInt{1};
-        if (!isFirstPrompt)
-          positionIdsSizeInt = {batchSize, sequenceLength};
-
+        // pos_ids_b = torch.where(cond, pos_ids, one_tensor)
+        //
+        // if subsequent_prompt:
+        //   pos_ids = pos_ids_b
+        // else:
+        //   pos_ids = pos_ids_a
+        SmallVector<int64_t> positionIdsSizeInt{batchSize, sequenceLength};
         Torch::ValueTensorType positionIdsType = Torch::ValueTensorType::get(
             context, positionIdsSizeInt,
             IntegerType::get(context, 64, IntegerType::Signed));
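
The comment block above maps one-to-one onto the ops emitted in the next hunk. As a plain-PyTorch sketch of the same computation (the function name and argument spelling are illustrative, not part of the lowering):

import torch

def gqa_position_ids(seqlens_k, batch_size, sequence_length,
                     total_sequence_length):
    # Subsequent (non-first) prompt: more than one new token, but not the
    # whole sequence.
    is_subsequent_prompt = (sequence_length > 1
                            and sequence_length != total_sequence_length)

    # Candidate A: all-zero position ids.
    pos_ids_a = torch.zeros((batch_size, sequence_length), dtype=torch.int64)

    # Candidate B: per-batch position offsets derived from the KV-cache
    # lengths, clamped to 1 where they run past the total length.
    total_seqlens = seqlens_k.to(torch.int64) + 1
    past_seqlens = total_seqlens - sequence_length
    pos_ids = torch.arange(sequence_length,
                           dtype=torch.int64).repeat(batch_size, 1)
    pos_ids = pos_ids + past_seqlens.view(-1, 1)
    cond = pos_ids < total_seqlens.view(-1, 1)
    pos_ids_b = torch.where(cond, pos_ids,
                            torch.tensor(1, dtype=torch.int64))

    return pos_ids_b if is_subsequent_prompt else pos_ids_a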
@@ -4087,93 +4102,106 @@ void mlir::torch::onnx_c::populateDefaultDomainGtoP(
             binder.getLoc(), rewriter.getType<Torch::FloatType>(),
             rewriter.getF64FloatAttr(1.0));

-        Value positionIds;
-        if (isFirstPrompt) {
-          positionIdsSizeInt = {1};
-          Value posIdsDataList = rewriter.create<Torch::PrimListConstructOp>(
-              loc,
-              rewriter.getType<Torch::ListType>(
-                  rewriter.getType<Torch::IntType>()),
-              SmallVector<Value>{cstIntZero});
-          positionIds = rewriter.create<Torch::AtenTensorOp>(
-              loc, positionIdsType, posIdsDataList, /*dtype=*/cstInt64Dtype,
-              /*layout=*/cstNone, /*requires_grad=*/cstFalse);
-        } else {
-          // Convert seqlens_k which is a tensor of type si32 to si64.
-          Torch::ValueTensorType seqLensKType =
-              cast<Torch::ValueTensorType>(seqlensK.getType());
-          seqlensK = rewriter.create<Torch::AtenToDtypeOp>(
-              loc,
-              seqLensKType.getWithSizesAndDtype(std::nullopt,
-                                                rewriter.getI64Type()),
-              seqlensK, cstInt64Dtype, /*non_blocking=*/cstFalse,
-              /*copy=*/cstFalse, /*memory_format=*/cstNone);
-          Value cstIntOne = rewriter.create<Torch::ConstantIntOp>(
-              binder.getLoc(), rewriter.getI64IntegerAttr(1));
-          Value totalSeqLens = rewriter.create<Torch::AtenAddScalarOp>(
-              loc, seqlensK.getType(), /*self=*/seqlensK, /*other=*/cstIntOne,
-              /*alpha=*/cstIntOne);
-          Value pastSeqLens = rewriter.create<Torch::AtenSubScalarOp>(
-              loc, totalSeqLens.getType(), /*self=*/totalSeqLens,
-              /*other=*/cstSequenceLength, /*alpha=*/cstIntOne);
-          Torch::ValueTensorType initPosIdsType = Torch::ValueTensorType::get(
-              context, {sequenceLength},
-              IntegerType::get(context, 64, IntegerType::Signed));
-          Value initPosIds = rewriter.create<Torch::AtenArangeOp>(
-              loc, initPosIdsType, cstSequenceLength, cstInt64Dtype,
-              /*layout=*/cstNone,
-              /*device=*/cstNone, /*pin_memory=*/cstNone);
-          Value repeatValuesList =
-              rewriter.create<Torch::PrimListConstructOp>(
-                  binder.getLoc(),
-                  Torch::ListType::get(Torch::IntType::get(context)),
-                  llvm::SmallVector<Value>{cstBatchSize, cstIntOne});
-          positionIds = rewriter.create<Torch::AtenRepeatOp>(
-              loc, positionIdsType, initPosIds, /*repeats=*/repeatValuesList);
-
-          Value cstIntMinusOne = rewriter.create<Torch::ConstantIntOp>(
-              binder.getLoc(), rewriter.getI64IntegerAttr(-1));
-          Value viewSizeList = rewriter.create<Torch::PrimListConstructOp>(
-              binder.getLoc(),
-              Torch::ListType::get(Torch::IntType::get(context)),
-              llvm::SmallVector<Value>{cstIntMinusOne, cstIntOne});
-
-          Torch::ValueTensorType seqLensViewType =
-              Torch::ValueTensorType::get(
-                  context, llvm::SmallVector<int64_t>{batchSize, 1},
-                  IntegerType::get(context, 64, IntegerType::Signed));
-          pastSeqLens = rewriter.create<Torch::AtenViewOp>(
-              loc, seqLensViewType, pastSeqLens, viewSizeList);
-
-          positionIds = rewriter.create<Torch::AtenAddTensorOp>(
-              loc, positionIdsType, positionIds, pastSeqLens,
-              /*alpha=*/cstIntOne);
-
-          totalSeqLens = rewriter.create<Torch::AtenViewOp>(
-              loc, seqLensViewType, totalSeqLens, viewSizeList);
-          Value cond = rewriter.create<Torch::AtenLtTensorOp>(
-              loc,
-              positionIdsType.getWithSizesAndDtype(positionIdsType.getSizes(),
-                                                   rewriter.getI1Type()),
-              positionIds, totalSeqLens);
-
-          Value cstOneTensorDataList =
-              rewriter.create<Torch::PrimListConstructOp>(
-                  loc,
-                  rewriter.getType<Torch::ListType>(
-                      rewriter.getType<Torch::IntType>()),
-                  SmallVector<Value>{cstIntOne});
-          Value cstOneTensor = rewriter.create<Torch::AtenTensorOp>(
-              loc,
-              Torch::ValueTensorType::get(
-                  context, {},
-                  IntegerType::get(context, 64, IntegerType::Signed)),
-              cstOneTensorDataList, /*dtype=*/cstInt64Dtype,
-              /*layout=*/cstNone, /*requires_grad=*/cstFalse);
-
-          positionIds = rewriter.create<Torch::AtenWhereSelfOp>(
-              loc, positionIdsType, cond, positionIds, cstOneTensor);
-        }
+        Value positionIdsA, positionIdsB;
+
+        Value posIdsSizeList = rewriter.create<Torch::PrimListConstructOp>(
+            loc,
+            rewriter.getType<Torch::ListType>(
+                rewriter.getType<Torch::IntType>()),
+            SmallVector<Value>{cstBatchSize, cstSequenceLength});
+        positionIdsA = rewriter.create<Torch::AtenZerosOp>(
+            loc, positionIdsType, /*size=*/posIdsSizeList,
+            /*dtype=*/cstInt64Dtype,
+            /*layout=*/cstNone, /*device=*/cstNone,
+            /*pin_memory=*/cstNone);
+
+        // Convert seqlens_k which is a tensor of type si32 to si64.
+        Torch::ValueTensorType seqLensKType =
+            cast<Torch::ValueTensorType>(seqlensK.getType());
+        seqlensK = rewriter.create<Torch::AtenToDtypeOp>(
+            loc,
+            seqLensKType.getWithSizesAndDtype(
+                std::nullopt,
+                rewriter.getIntegerType(/*width=*/64, /*isSigned=*/true)),
+            seqlensK, cstInt64Dtype, /*non_blocking=*/cstFalse,
+            /*copy=*/cstFalse, /*memory_format=*/cstNone);
+        Value totalSeqLens = rewriter.create<Torch::AtenAddScalarOp>(
+            loc, seqlensK.getType(), /*self=*/seqlensK, /*other=*/cstIntOne,
+            /*alpha=*/cstIntOne);
+        Value pastSeqLens = rewriter.create<Torch::AtenSubScalarOp>(
+            loc, totalSeqLens.getType(), /*self=*/totalSeqLens,
+            /*other=*/cstSequenceLength, /*alpha=*/cstIntOne);
+        Torch::ValueTensorType initPosIdsType = Torch::ValueTensorType::get(
+            context, {sequenceLength},
+            IntegerType::get(context, 64, IntegerType::Signed));
+        Value initPosIds = rewriter.create<Torch::AtenArangeOp>(
+            loc, initPosIdsType, cstSequenceLength, cstInt64Dtype,
+            /*layout=*/cstNone,
+            /*device=*/cstNone, /*pin_memory=*/cstNone);
+        Value repeatValuesList = rewriter.create<Torch::PrimListConstructOp>(
+            binder.getLoc(),
+            Torch::ListType::get(Torch::IntType::get(context)),
+            llvm::SmallVector<Value>{cstBatchSize, cstIntOne});
+        positionIdsB = rewriter.create<Torch::AtenRepeatOp>(
+            loc, positionIdsType, initPosIds, /*repeats=*/repeatValuesList);
+
+        Value cstIntMinusOne = rewriter.create<Torch::ConstantIntOp>(
+            binder.getLoc(), rewriter.getI64IntegerAttr(-1));
+        Value viewSizeList = rewriter.create<Torch::PrimListConstructOp>(
+            binder.getLoc(),
+            Torch::ListType::get(Torch::IntType::get(context)),
+            llvm::SmallVector<Value>{cstIntMinusOne, cstIntOne});
+
+        Torch::ValueTensorType seqLensViewType = Torch::ValueTensorType::get(
+            context, llvm::SmallVector<int64_t>{batchSize, 1},
+            IntegerType::get(context, 64, IntegerType::Signed));
+        pastSeqLens = rewriter.create<Torch::AtenViewOp>(
+            loc, seqLensViewType, pastSeqLens, viewSizeList);
+
+        positionIdsB = rewriter.create<Torch::AtenAddTensorOp>(
+            loc, positionIdsType, positionIdsB, pastSeqLens,
+            /*alpha=*/cstIntOne);
+
+        totalSeqLens = rewriter.create<Torch::AtenViewOp>(
+            loc, seqLensViewType, totalSeqLens, viewSizeList);
+        Value cond = rewriter.create<Torch::AtenLtTensorOp>(
+            loc,
+            positionIdsType.getWithSizesAndDtype(positionIdsType.getSizes(),
+                                                 rewriter.getI1Type()),
+            positionIdsB, totalSeqLens);
+
+        Value cstOneTensorDataList =
+            rewriter.create<Torch::PrimListConstructOp>(
+                loc,
+                rewriter.getType<Torch::ListType>(
+                    rewriter.getType<Torch::IntType>()),
+                SmallVector<Value>{cstIntOne});
+        Value cstOneTensor = rewriter.create<Torch::AtenTensorOp>(
+            loc,
+            Torch::ValueTensorType::get(
+                context, {},
+                IntegerType::get(context, 64, IntegerType::Signed)),
+            cstOneTensorDataList, /*dtype=*/cstInt64Dtype,
+            /*layout=*/cstNone, /*requires_grad=*/cstFalse);
+
+        positionIdsB = rewriter.create<Torch::AtenWhereSelfOp>(
+            loc, positionIdsType, cond, positionIdsB, cstOneTensor);
+
+        isSubsequentPrompt = rewriter.create<Torch::AtenIntBoolOp>(
+            loc, rewriter.getType<Torch::IntType>(), isSubsequentPrompt);
+        isSubsequentPrompt = rewriter.create<Torch::AtenFullOp>(
+            loc,
+            Torch::ValueTensorType::get(context, positionIdsSizeInt,
+                                        rewriter.getI1Type()),
+            /*size=*/posIdsSizeList, /*fill_value=*/isSubsequentPrompt,
+            /*dtype=*/
+            rewriter.create<Torch::ConstantIntOp>(
+                binder.getLoc(), rewriter.getI64IntegerAttr(
+                                     (int)torch_upstream::ScalarType::Bool)),
+            /*layout=*/cstNone, /*device=*/cstNone, /*pin_memory=*/cstNone);
+        Value positionIds = rewriter.create<Torch::AtenWhereSelfOp>(
+            loc, positionIdsType, isSubsequentPrompt, positionIdsB,
+            positionIdsA);

         qRotary = rewriter.create<Torch::OnnxVariantAtenRotaryEmbeddingOp>(
             loc, qInput.getType(), qInput, positionIds, cosCache, sinCache,
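
Since AtenWhereSelfOp takes a tensor condition, the scalar isSubsequentPrompt flag is materialized as a full boolean tensor (AtenIntBoolOp + AtenFullOp) instead of emitting data-dependent control flow. A hedged PyTorch equivalent of that final select (names are illustrative):

import torch

def select_position_ids(is_subsequent_prompt: bool,
                        pos_ids_b: torch.Tensor,
                        pos_ids_a: torch.Tensor) -> torch.Tensor:
    # Broadcast the scalar flag to a (batch, seq_len) boolean mask, then
    # select elementwise between the two precomputed candidates.
    mask = torch.full(pos_ids_b.shape, int(is_subsequent_prompt),
                      dtype=torch.bool)
    return torch.where(mask, pos_ids_b, pos_ids_a)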

projects/pt1/e2e_testing/xfail_sets.py

Lines changed: 3 additions & 0 deletions
@@ -3777,6 +3777,9 @@
     "ScaledDotProductAttentionSameDynamicModule_basic",
     "ScaledDotProductAttentionSameModule_basic",
     "ScaledDotProductAttentionGQAModule_basic",
+    "AtenSymConstrainRangeForSize_basic",
+    "AtenSymConstrainRange_basic",
+    "Aten_AssertScalar_basic",
 }

 ONNX_TOSA_CRASHING_SET = {
