Update GQA lowering and add lit test for GQA with rotary embedding
vivekkhandelwal1 committed Feb 10, 2025
1 parent bdb23be commit b98aae8
Showing 4 changed files with 258 additions and 145 deletions.
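The substance of this commit is the position-id computation for the rotary-embedding path of the ONNX GroupQueryAttention lowering. As a reading aid for the diff below, here is a minimal PyTorch sketch of what the new lowering computes, transcribed from the pseudocode comments embedded in the patch; gqa_position_ids is an illustrative name, not part of the patch:

import torch

def gqa_position_ids(seqlens_k, batch_size, sequence_length,
                     total_sequence_length):
    # Subsequent prompt iff more than one token is fed and the new tokens
    # do not cover the whole sequence (condition computed in the lowering).
    is_subsequent_prompt = (sequence_length > 1 and
                            sequence_length != total_sequence_length)

    # Branch A: all-zero position ids.
    pos_ids_a = torch.zeros((batch_size, sequence_length), dtype=torch.int64)

    # Branch B: a per-row arange offset by each row's past sequence length.
    total_seqlens = seqlens_k.to(torch.int64) + 1
    past_seqlens = total_seqlens - sequence_length
    pos_ids = torch.arange(sequence_length,
                           dtype=torch.int64).repeat(batch_size, 1)
    pos_ids = pos_ids + past_seqlens.view(-1, 1)
    cond = pos_ids < total_seqlens.view(-1, 1)
    one_tensor = torch.tensor(1, dtype=torch.int64)
    pos_ids_b = torch.where(cond, pos_ids, one_tensor)

    return pos_ids_b if is_subsequent_prompt else pos_ids_a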
222 changes: 125 additions & 97 deletions lib/Conversion/TorchOnnxToTorch/DefaultDomainGtoP.cpp
@@ -3968,6 +3968,7 @@ void mlir::torch::onnx_c::populateDefaultDomainGtoP(
Value pastKey = operands[3];
Value pastValue = operands[4];
Value seqlensK = operands[5];
Value totalSequenceLength = operands[6];
Value cosCache, sinCache;
if (doRotary) {
cosCache = operands[7];
@@ -4053,25 +4054,39 @@ void mlir::torch::onnx_c::populateDefaultDomainGtoP(

Value qRotary = qInput, kRotary = kInput;
if (doRotary) {
bool isFirstPrompt = false;
// TODO: add the rotary embedding support
// `totalSequenceLength` is a scalar tensor.
Value scalarTotalSeqLens = rewriter.create<Torch::AtenItemOp>(
loc, rewriter.getType<Torch::IntType>(), totalSequenceLength);
Value cstIntOne = rewriter.create<Torch::ConstantIntOp>(
binder.getLoc(), rewriter.getI64IntegerAttr(1));
Type boolTy = rewriter.getType<Torch::BoolType>();
Value condA = rewriter.create<Torch::AtenGtIntOp>(
loc, boolTy, cstSequenceLength, cstIntOne);
Value condB = rewriter.create<Torch::AtenNeIntOp>(
loc, boolTy, cstSequenceLength, scalarTotalSeqLens);
// if (sequence_length > 1 && sequence_length !=
//     total_sequence_length)
//   is_subsequent_prompt = true; // Subsequent prompt
Value isSubsequentPrompt = rewriter.create<Torch::Aten__And__BoolOp>(
loc, boolTy, condA, condB);

// Generating position_ids for rotary_embedding as follows:
// if is_first_prompt:
// pos_ids = torch.tensor([0], dtype=torch.int64)
// else:
// pos_ids_a = torch.zeros((batch_size, seq_len), dtype=torch.int64)
//
// total_seqlens = seqlens_k + 1
// past_seqlens = total_seqlens - sequence_length
// pos_ids = torch.arange(sequence_length,
// dtype=torch.int64).repeat(batch_size, 1)
// pos_ids = pos_ids + past_seqlens.view(-1, 1)
// cond = pos_ids < total_seqlens.view(-1, 1)
// one_tensor = torch.tensor(1, dtype=torch.int64)
// pos_ids = torch.where(cond, pos_ids, one_tensor)
SmallVector<int64_t> positionIdsSizeInt{1};
if (!isFirstPrompt)
positionIdsSizeInt = {batchSize, sequenceLength};

// pos_ids_b = torch.where(cond, pos_ids, one_tensor)
//
// if subsequent_prompt:
// pos_ids = pos_ids_b
// else:
// pos_ids = pos_ids_a
SmallVector<int64_t> positionIdsSizeInt{batchSize, sequenceLength};
Torch::ValueTensorType positionIdsType = Torch::ValueTensorType::get(
context, positionIdsSizeInt,
IntegerType::get(context, 64, IntegerType::Signed));
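To make the two regimes concrete, a hedged usage example of the gqa_position_ids sketch above, with arbitrarily chosen sizes:

# Subsequent prompt: 3 new tokens on top of 4 cached ones (total 7),
# so branch B yields the absolute positions of the new tokens.
print(gqa_position_ids(torch.tensor([6], dtype=torch.int32),
                       batch_size=1, sequence_length=3,
                       total_sequence_length=7))
# tensor([[4, 5, 6]])

# First prompt: the new tokens cover the whole sequence, so branch A
# (all zeros) is selected.
print(gqa_position_ids(torch.tensor([3], dtype=torch.int32),
                       batch_size=1, sequence_length=4,
                       total_sequence_length=4))
# tensor([[0, 0, 0, 0]])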
@@ -4087,93 +4102,106 @@ void mlir::torch::onnx_c::populateDefaultDomainGtoP(
binder.getLoc(), rewriter.getType<Torch::FloatType>(),
rewriter.getF64FloatAttr(1.0));

Value positionIds;
if (isFirstPrompt) {
positionIdsSizeInt = {1};
Value posIdsDataList = rewriter.create<Torch::PrimListConstructOp>(
loc,
rewriter.getType<Torch::ListType>(
rewriter.getType<Torch::IntType>()),
SmallVector<Value>{cstIntZero});
positionIds = rewriter.create<Torch::AtenTensorOp>(
loc, positionIdsType, posIdsDataList, /*dtype=*/cstInt64Dtype,
/*layout=*/cstNone, /*requires_grad=*/cstFalse);
} else {
// Convert seqlens_k which is a tensor of type si32 to si64.
Torch::ValueTensorType seqLensKType =
cast<Torch::ValueTensorType>(seqlensK.getType());
seqlensK = rewriter.create<Torch::AtenToDtypeOp>(
loc,
seqLensKType.getWithSizesAndDtype(std::nullopt,
rewriter.getI64Type()),
seqlensK, cstInt64Dtype, /*non_blocking=*/cstFalse,
/*copy=*/cstFalse, /*memory_format=*/cstNone);
Value cstIntOne = rewriter.create<Torch::ConstantIntOp>(
binder.getLoc(), rewriter.getI64IntegerAttr(1));
Value totalSeqLens = rewriter.create<Torch::AtenAddScalarOp>(
loc, seqlensK.getType(), /*self=*/seqlensK, /*other=*/cstIntOne,
/*alpha=*/cstIntOne);
Value pastSeqLens = rewriter.create<Torch::AtenSubScalarOp>(
loc, totalSeqLens.getType(), /*self=*/totalSeqLens,
/*other=*/cstSequenceLength, /*alpha=*/cstIntOne);
Torch::ValueTensorType initPosIdsType = Torch::ValueTensorType::get(
context, {sequenceLength},
IntegerType::get(context, 64, IntegerType::Signed));
Value initPosIds = rewriter.create<Torch::AtenArangeOp>(
loc, initPosIdsType, cstSequenceLength, cstInt64Dtype,
/*layout=*/cstNone,
/*device=*/cstNone, /*pin_memory=*/cstNone);
Value repeatValuesList =
rewriter.create<Torch::PrimListConstructOp>(
binder.getLoc(),
Torch::ListType::get(Torch::IntType::get(context)),
llvm::SmallVector<Value>{cstBatchSize, cstIntOne});
positionIds = rewriter.create<Torch::AtenRepeatOp>(
loc, positionIdsType, initPosIds, /*repeats=*/repeatValuesList);

Value cstIntMinusOne = rewriter.create<Torch::ConstantIntOp>(
binder.getLoc(), rewriter.getI64IntegerAttr(-1));
Value viewSizeList = rewriter.create<Torch::PrimListConstructOp>(
binder.getLoc(),
Torch::ListType::get(Torch::IntType::get(context)),
llvm::SmallVector<Value>{cstIntMinusOne, cstIntOne});

Torch::ValueTensorType seqLensViewType =
Torch::ValueTensorType::get(
context, llvm::SmallVector<int64_t>{batchSize, 1},
IntegerType::get(context, 64, IntegerType::Signed));
pastSeqLens = rewriter.create<Torch::AtenViewOp>(
loc, seqLensViewType, pastSeqLens, viewSizeList);

positionIds = rewriter.create<Torch::AtenAddTensorOp>(
loc, positionIdsType, positionIds, pastSeqLens,
/*alpha=*/cstIntOne);

totalSeqLens = rewriter.create<Torch::AtenViewOp>(
loc, seqLensViewType, totalSeqLens, viewSizeList);
Value cond = rewriter.create<Torch::AtenLtTensorOp>(
loc,
positionIdsType.getWithSizesAndDtype(positionIdsType.getSizes(),
rewriter.getI1Type()),
positionIds, totalSeqLens);

Value cstOneTensorDataList =
rewriter.create<Torch::PrimListConstructOp>(
loc,
rewriter.getType<Torch::ListType>(
rewriter.getType<Torch::IntType>()),
SmallVector<Value>{cstIntOne});
Value cstOneTensor = rewriter.create<Torch::AtenTensorOp>(
loc,
Torch::ValueTensorType::get(
context, {},
IntegerType::get(context, 64, IntegerType::Signed)),
cstOneTensorDataList, /*dtype=*/cstInt64Dtype,
/*layout=*/cstNone, /*requires_grad=*/cstFalse);

positionIds = rewriter.create<Torch::AtenWhereSelfOp>(
loc, positionIdsType, cond, positionIds, cstOneTensor);
}
Value positionIdsA, positionIdsB;

Value posIdsSizeList = rewriter.create<Torch::PrimListConstructOp>(
loc,
rewriter.getType<Torch::ListType>(
rewriter.getType<Torch::IntType>()),
SmallVector<Value>{cstBatchSize, cstSequenceLength});
positionIdsA = rewriter.create<Torch::AtenZerosOp>(
loc, positionIdsType, /*size=*/posIdsSizeList,
/*dtype=*/cstInt64Dtype,
/*layout=*/cstNone, /*device=*/cstNone,
/*pin_memory=*/cstNone);

// Convert seqlens_k which is a tensor of type si32 to si64.
Torch::ValueTensorType seqLensKType =
cast<Torch::ValueTensorType>(seqlensK.getType());
seqlensK = rewriter.create<Torch::AtenToDtypeOp>(
loc,
seqLensKType.getWithSizesAndDtype(
std::nullopt,
rewriter.getIntegerType(/*width=*/64, /*isSigned=*/true)),
seqlensK, cstInt64Dtype, /*non_blocking=*/cstFalse,
/*copy=*/cstFalse, /*memory_format=*/cstNone);
Value totalSeqLens = rewriter.create<Torch::AtenAddScalarOp>(
loc, seqlensK.getType(), /*self=*/seqlensK, /*other=*/cstIntOne,
/*alpha=*/cstIntOne);
Value pastSeqLens = rewriter.create<Torch::AtenSubScalarOp>(
loc, totalSeqLens.getType(), /*self=*/totalSeqLens,
/*other=*/cstSequenceLength, /*alpha=*/cstIntOne);
Torch::ValueTensorType initPosIdsType = Torch::ValueTensorType::get(
context, {sequenceLength},
IntegerType::get(context, 64, IntegerType::Signed));
Value initPosIds = rewriter.create<Torch::AtenArangeOp>(
loc, initPosIdsType, cstSequenceLength, cstInt64Dtype,
/*layout=*/cstNone,
/*device=*/cstNone, /*pin_memory=*/cstNone);
Value repeatValuesList = rewriter.create<Torch::PrimListConstructOp>(
binder.getLoc(),
Torch::ListType::get(Torch::IntType::get(context)),
llvm::SmallVector<Value>{cstBatchSize, cstIntOne});
positionIdsB = rewriter.create<Torch::AtenRepeatOp>(
loc, positionIdsType, initPosIds, /*repeats=*/repeatValuesList);

Value cstIntMinusOne = rewriter.create<Torch::ConstantIntOp>(
binder.getLoc(), rewriter.getI64IntegerAttr(-1));
Value viewSizeList = rewriter.create<Torch::PrimListConstructOp>(
binder.getLoc(),
Torch::ListType::get(Torch::IntType::get(context)),
llvm::SmallVector<Value>{cstIntMinusOne, cstIntOne});

Torch::ValueTensorType seqLensViewType = Torch::ValueTensorType::get(
context, llvm::SmallVector<int64_t>{batchSize, 1},
IntegerType::get(context, 64, IntegerType::Signed));
pastSeqLens = rewriter.create<Torch::AtenViewOp>(
loc, seqLensViewType, pastSeqLens, viewSizeList);

positionIdsB = rewriter.create<Torch::AtenAddTensorOp>(
loc, positionIdsType, positionIdsB, pastSeqLens,
/*alpha=*/cstIntOne);

totalSeqLens = rewriter.create<Torch::AtenViewOp>(
loc, seqLensViewType, totalSeqLens, viewSizeList);
Value cond = rewriter.create<Torch::AtenLtTensorOp>(
loc,
positionIdsType.getWithSizesAndDtype(positionIdsType.getSizes(),
rewriter.getI1Type()),
positionIdsB, totalSeqLens);

Value cstOneTensorDataList =
rewriter.create<Torch::PrimListConstructOp>(
loc,
rewriter.getType<Torch::ListType>(
rewriter.getType<Torch::IntType>()),
SmallVector<Value>{cstIntOne});
Value cstOneTensor = rewriter.create<Torch::AtenTensorOp>(
loc,
Torch::ValueTensorType::get(
context, {},
IntegerType::get(context, 64, IntegerType::Signed)),
cstOneTensorDataList, /*dtype=*/cstInt64Dtype,
/*layout=*/cstNone, /*requires_grad=*/cstFalse);

positionIdsB = rewriter.create<Torch::AtenWhereSelfOp>(
loc, positionIdsType, cond, positionIdsB, cstOneTensor);

isSubsequentPrompt = rewriter.create<Torch::AtenIntBoolOp>(
loc, rewriter.getType<Torch::IntType>(), isSubsequentPrompt);
isSubsequentPrompt = rewriter.create<Torch::AtenFullOp>(
loc,
Torch::ValueTensorType::get(context, positionIdsSizeInt,
rewriter.getI1Type()),
/*size=*/posIdsSizeList, /*fill_value=*/isSubsequentPrompt,
/*dtype=*/
rewriter.create<Torch::ConstantIntOp>(
binder.getLoc(), rewriter.getI64IntegerAttr(
(int)torch_upstream::ScalarType::Bool)),
/*layout=*/cstNone, /*device=*/cstNone, /*pin_memory=*/cstNone);
Value positionIds = rewriter.create<Torch::AtenWhereSelfOp>(
loc, positionIdsType, isSubsequentPrompt, positionIdsB,
positionIdsA);

qRotary = rewriter.create<Torch::OnnxVariantAtenRotaryEmbeddingOp>(
loc, qInput.getType(), qInput, positionIds, cosCache, sinCache,
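The position ids computed above feed Torch::OnnxVariantAtenRotaryEmbeddingOp together with the cos/sin caches (the hunk is truncated here). For orientation, the usual non-interleaved rotary formula that such ops implement can be sketched as below; this is a hedged illustration under an assumed (batch, seq, heads, head_dim) layout and half-split caches, not the definition of the torch-mlir op:

import torch

def apply_rotary(x, pos_ids, cos_cache, sin_cache):
    # x: (batch, seq, heads, head_dim); caches: (max_pos, head_dim // 2).
    cos = cos_cache[pos_ids].unsqueeze(2)  # (batch, seq, 1, head_dim // 2)
    sin = sin_cache[pos_ids].unsqueeze(2)
    x1, x2 = x.chunk(2, dim=-1)  # split the head dimension into halves
    # Rotate each (x1, x2) pair by the cached angle for its position.
    return torch.cat((x1 * cos - x2 * sin, x2 * cos + x1 * sin), dim=-1)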
3 changes: 3 additions & 0 deletions projects/pt1/e2e_testing/xfail_sets.py
@@ -3777,6 +3777,9 @@
"ScaledDotProductAttentionSameDynamicModule_basic",
"ScaledDotProductAttentionSameModule_basic",
"ScaledDotProductAttentionGQAModule_basic",
"AtenSymConstrainRangeForSize_basic",
"AtenSymConstrainRange_basic",
"Aten_AssertScalar_basic",
}

ONNX_TOSA_CRASHING_SET = {