[ONNX] Add Onnx->Torch lowering for GroupQueryAttention op #4006

Open · wants to merge 5 commits into base: main
389 changes: 389 additions & 0 deletions lib/Conversion/TorchOnnxToTorch/ComMicrosoftDomain.cpp
@@ -63,4 +63,393 @@ void mlir::torch::onnx_c::populateComMicrosoftDomain(
cstRotaryEmbeddingDim, cstScale);
return success();
});
patterns.onOp(
"GroupQueryAttention", 1,
[](OpBinder binder, ConversionPatternRewriter &rewriter) {
SmallVector<Value> operands;
SmallVector<Type> resultTypes;
int64_t doRotary, kvNumHeads, localWindowSize, numHeads,
rotaryInterleaved, smoothSoftmax;
float scale, softcap;
if (binder.tensorOperandsList(operands))
return rewriter.notifyMatchFailure(binder.op,
"operands bind failure");

if (binder.tensorResultTypes(resultTypes))
return rewriter.notifyMatchFailure(binder.op,
"result types bind failure");

if (resultTypes.size() != 3)
return rewriter.notifyMatchFailure(binder.op,
"expected 3 result types");

if (binder.s64IntegerAttr(doRotary, "do_rotary") ||
binder.s64IntegerAttr(kvNumHeads, "kv_num_heads") ||
binder.s64IntegerAttr(localWindowSize, "local_window_size", -1) ||
binder.s64IntegerAttr(numHeads, "num_heads") ||
binder.s64IntegerAttr(rotaryInterleaved, "rotary_interleaved") ||
binder.f32FloatAttr(scale, "scale") ||
binder.s64IntegerAttr(smoothSoftmax, "smooth_softmax") ||
binder.f32FloatAttr(softcap, "softcap"))
return rewriter.notifyMatchFailure(binder.op,
"op attributes bind failure");

// This lowering expects the number of input operands to be either 7 or 9,
// depending on the `do_rotary` attribute. If it's false, there can be 7
// operands, but if it's true there have to be 9, including cos_cache and
// sin_cache for the rotary embedding.
if (!((operands.size() == 9) || (!doRotary && operands.size() == 7)))
return rewriter.notifyMatchFailure(
binder.op, "Unimplemented: excepted input operands to be either "
"7 or 9 based on the `do_rotary` attribute");

if (kvNumHeads == 0)
return rewriter.notifyMatchFailure(
binder.op,
"kv_num_heads is a required attribute and should be non-zero");

if (localWindowSize != -1)
return rewriter.notifyMatchFailure(
binder.op,
"Unimplemented: local_window_size attribute is not supported, "
"hence it should have default value equal to -1");

if (numHeads == 0)
return rewriter.notifyMatchFailure(
binder.op,
"num_heads is a required attribute and should be non-zero");

if (smoothSoftmax != 0)
return rewriter.notifyMatchFailure(
binder.op,
"Unimplemented: smooth_softmax attribute is not supported, hence "
"it should have default value equal to 0");

if (softcap != 0.0f)
return rewriter.notifyMatchFailure(
binder.op, "Unimplemented: softcap attribute is not supported, "
"hence it should have default value equal to 0.0");

Location loc = binder.getLoc();
MLIRContext *context = binder.op->getContext();
Value query = operands[0];
Value key = operands[1];
Value value = operands[2];
Value pastKey = operands[3];
Value pastValue = operands[4];
Value seqlensK = operands[5];
Value totalSequenceLength = operands[6];
Value cosCache, sinCache;
if (doRotary) {
cosCache = operands[7];
sinCache = operands[8];
}

Torch::ValueTensorType queryType =
cast<Torch::ValueTensorType>(query.getType());
if (!(queryType.hasSizes() && queryType.areAllSizesKnown()))
return rewriter.notifyMatchFailure(
binder.op,
"Expected `query` input to have statically known sizes");

SmallVector<int64_t> queryDims{queryType.getSizes()};
int64_t batchSize = queryDims[0];
int64_t sequenceLength = queryDims[1];
int64_t hiddenSize = queryDims[2];
int64_t headSize = hiddenSize / numHeads;

Value cstBatchSize = rewriter.create<Torch::ConstantIntOp>(
binder.getLoc(), rewriter.getI64IntegerAttr(batchSize));
Value cstSequenceLength = rewriter.create<Torch::ConstantIntOp>(
binder.getLoc(), rewriter.getI64IntegerAttr(sequenceLength));
Value cstHiddenSize = rewriter.create<Torch::ConstantIntOp>(
binder.getLoc(), rewriter.getI64IntegerAttr(hiddenSize));
Value cstHeadSize = rewriter.create<Torch::ConstantIntOp>(
binder.getLoc(), rewriter.getI64IntegerAttr(headSize));
Value cstNumHeads = rewriter.create<Torch::ConstantIntOp>(
binder.getLoc(), rewriter.getI64IntegerAttr(numHeads));
Value cstKVNumHeads = rewriter.create<Torch::ConstantIntOp>(
binder.getLoc(), rewriter.getI64IntegerAttr(kvNumHeads));

// Reshape Query, Key and Value as follows:
// Query: (batch_size, sequence_length, hidden_size)
Collaborator
I guess the reshape for key and value is from https://github.com/microsoft/onnxruntime/blob/65008cbb7393b121400a40dd8af4cc93d506918f/onnxruntime/contrib_ops/cpu/bert/group_query_attention.cc#L85-L88

  std::vector<int64_t> present_k_shape({static_cast<int64_t>(batch_size), static_cast<int64_t>(kv_num_heads_), static_cast<int64_t>(present_kv_seqlen), static_cast<int64_t>(head_size)});
  std::vector<int64_t> present_v_shape({static_cast<int64_t>(batch_size), static_cast<int64_t>(kv_num_heads_), static_cast<int64_t>(present_kv_seqlen), static_cast<int64_t>(head_size)});
  Tensor* present_k = context->Output(1, present_k_shape);
  Tensor* present_v = context->Output(2, present_v_shape);

But I didn't find the query reshape code in onnx. Could you point me to where it comes from?
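For illustration, a minimal PyTorch sketch (example shapes are hypothetical) of the head-split reshapes described in the comments below, written the way this lowering emits them, i.e. as a direct reshape:

  import torch

  # Hypothetical example shapes.
  batch, seq, num_heads, kv_num_heads, head_size = 2, 8, 4, 2, 16
  query = torch.randn(batch, seq, num_heads * head_size)
  key = torch.randn(batch, seq, kv_num_heads * head_size)
  value = torch.randn(batch, seq, kv_num_heads * head_size)

  # Direct reshapes to the head-split layout, mirroring the AtenReshapeOps below.
  q = query.reshape(batch, num_heads, seq, head_size)
  k = key.reshape(batch, kv_num_heads, seq, head_size)
  v = value.reshape(batch, kv_num_heads, seq, head_size)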

// -> (batch_size, num_heads, sequence_length, head_size)
// Key: (batch_size, kv_sequence_length, kv_hidden_size)
// -> (batch_size, kv_num_heads, sequence_length, head_size)
// Value: (batch_size, kv_sequence_length, kv_hidden_size)
// -> (batch_size, kv_num_heads, sequence_length, head_size)

// Reshaping query.
SmallVector<int64_t, 6> queryReshapeSizesInt{batchSize, numHeads,
Collaborator
Why 6? There are only 4 values here; SmallVector<int64_t, 4> would suffice.
sequenceLength, headSize};
Value queryReshapeSizesList =
rewriter.create<Torch::PrimListConstructOp>(
binder.getLoc(),
Torch::ListType::get(Torch::IntType::get(query.getContext())),
llvm::SmallVector<Value>{cstBatchSize, cstNumHeads,
cstSequenceLength, cstHeadSize});
Value qInput = rewriter.create<Torch::AtenReshapeOp>(
loc,
queryType.getWithSizesAndDtype(queryReshapeSizesInt,
queryType.getOptionalDtype()),
query, queryReshapeSizesList);

// Reshaping key.
SmallVector<int64_t, 6> kvReshapeSizesInt{batchSize, kvNumHeads,
Collaborator
Same here.
sequenceLength, headSize};
Value kvReshapeSizesList = rewriter.create<Torch::PrimListConstructOp>(
binder.getLoc(),
Torch::ListType::get(Torch::IntType::get(query.getContext())),
llvm::SmallVector<Value>{cstBatchSize, cstKVNumHeads,
cstSequenceLength, cstHeadSize});
Torch::ValueTensorType keyType =
cast<Torch::ValueTensorType>(key.getType());
Value kInput = rewriter.create<Torch::AtenReshapeOp>(
loc,
keyType.getWithSizesAndDtype(kvReshapeSizesInt,
keyType.getOptionalDtype()),
key, kvReshapeSizesList);

// Reshaping value.
Torch::ValueTensorType valueType =
cast<Torch::ValueTensorType>(value.getType());
Value vInput = rewriter.create<Torch::AtenReshapeOp>(
loc,
valueType.getWithSizesAndDtype(kvReshapeSizesInt,
valueType.getOptionalDtype()),
value, kvReshapeSizesList);

Value cstNone = rewriter.create<Torch::ConstantNoneOp>(loc);
Value cstFalse = rewriter.create<Torch::ConstantBoolOp>(loc, false);

Value qRotary = qInput, kRotary = kInput;
if (doRotary) {
// `totalSequenceLength` is a scalar tensor.
Value scalarTotalSeqLens = rewriter.create<Torch::AtenItemOp>(
loc, rewriter.getType<Torch::IntType>(), totalSequenceLength);
Value cstIntOne = rewriter.create<Torch::ConstantIntOp>(
binder.getLoc(), rewriter.getI64IntegerAttr(1));
Type boolTy = rewriter.getType<Torch::BoolType>();
Value condA = rewriter.create<Torch::AtenGtIntOp>(
loc, boolTy, cstSequenceLength, cstIntOne);
Value condB = rewriter.create<Torch::AtenNeIntOp>(
loc, boolTy, cstSequenceLength, scalarTotalSeqLens);
// if (sequence_length > 1 && sequence_length !=
// total_sequence_length)
// is_subsequent_prompt = false; // Subsequent prompt
Value isSubsequentPrompt = rewriter.create<Torch::Aten__And__BoolOp>(
Collaborator
Why flip is_first_prompt to isSubsequentPrompt?
loc, boolTy, condA, condB);

// Generating position_ids for rotary_embedding as follows:
// pos_ids_a = torch.zeros((batch_size, seq_len), dtype=torch.int64)
Collaborator
It's a complicated op. Better to provide a complete runnable PyTorch script to demonstrate the whole algorithm.
Something like this https://colab.research.google.com/drive/1-Z4Mbfbt9v60-25sbSYmJusEJYtA8ItO?usp=sharing or nonzero.py
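For reference, a minimal runnable PyTorch sketch of the position_ids computation spelled out in the comment below (example values are hypothetical; it assumes unpacked QKV):

  import torch

  # Hypothetical example values.
  batch_size, sequence_length = 2, 4
  total_sequence_length = 7
  seqlens_k = torch.tensor([6, 3], dtype=torch.int32)

  # First prompt: all-zero position ids.
  pos_ids_a = torch.zeros((batch_size, sequence_length), dtype=torch.int64)

  # Subsequent prompt: offset by the per-batch past sequence lengths.
  total_seqlens = seqlens_k.to(torch.int64) + 1
  past_seqlens = total_seqlens - sequence_length
  pos_ids = torch.arange(sequence_length, dtype=torch.int64).repeat(batch_size, 1)
  pos_ids = pos_ids + past_seqlens.view(-1, 1)
  cond = pos_ids < total_seqlens.view(-1, 1)
  one_tensor = torch.tensor(1, dtype=torch.int64)
  pos_ids_b = torch.where(cond, pos_ids, one_tensor)

  subsequent_prompt = sequence_length > 1 and sequence_length != total_sequence_length
  pos_ids = pos_ids_b if subsequent_prompt else pos_ids_a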

//
// total_seqlens = seqlens_k + 1
// past_seqlens = total_seqlens - sequence_length
// pos_ids = torch.arange(sequence_length,
// dtype=torch.int64).repeat(batch_size, 1)
// pos_ids = pos_ids + past_seqlens.view(-1, 1)
// cond = pos_ids < total_seqlens.view(-1, 1)
// one_tensor = torch.tensor(1, dtype=torch.int64)
// pos_ids_b = torch.where(cond, pos_ids, one_tensor)
//
// if subsequent_prompt:
// pos_ids = pos_ids_b
// else:
// pos_ids = pos_ids_a
SmallVector<int64_t> positionIdsSizeInt{batchSize, sequenceLength};
Torch::ValueTensorType positionIdsType = Torch::ValueTensorType::get(
context, positionIdsSizeInt,
IntegerType::get(context, 64, IntegerType::Signed));
Value cstInt64Dtype = rewriter.create<Torch::ConstantIntOp>(
binder.getLoc(), rewriter.getI64IntegerAttr(
(int)torch_upstream::ScalarType::Long));

Value cstInterleaved = rewriter.create<Torch::ConstantIntOp>(
binder.getLoc(), rewriter.getI64IntegerAttr(rotaryInterleaved));
Value cstIntZero = rewriter.create<Torch::ConstantIntOp>(
binder.getLoc(), rewriter.getI64IntegerAttr(0));
Value cstFloatOne = rewriter.create<Torch::ConstantFloatOp>(
binder.getLoc(), rewriter.getType<Torch::FloatType>(),
rewriter.getF64FloatAttr(1.0));

Value positionIdsA, positionIdsB;

Value posIdsSizeList = rewriter.create<Torch::PrimListConstructOp>(
loc,
rewriter.getType<Torch::ListType>(
rewriter.getType<Torch::IntType>()),
SmallVector<Value>{cstBatchSize, cstSequenceLength});
positionIdsA = rewriter.create<Torch::AtenZerosOp>(
loc, positionIdsType, /*size=*/posIdsSizeList,
/*dtype=*/cstInt64Dtype,
/*layout=*/cstNone, /*device=*/cstNone,
/*pin_memory=*/cstNone);

// Convert seqlens_k which is a tensor of type si32 to si64.
Torch::ValueTensorType seqLensKType =
cast<Torch::ValueTensorType>(seqlensK.getType());
seqlensK = rewriter.create<Torch::AtenToDtypeOp>(
loc,
seqLensKType.getWithSizesAndDtype(
std::nullopt,
rewriter.getIntegerType(/*width=*/64, /*isSigned=*/true)),
seqlensK, cstInt64Dtype, /*non_blocking=*/cstFalse,
/*copy=*/cstFalse, /*memory_format=*/cstNone);
Value totalSeqLens = rewriter.create<Torch::AtenAddScalarOp>(
loc, seqlensK.getType(), /*self=*/seqlensK, /*other=*/cstIntOne,
/*alpha=*/cstIntOne);
Value pastSeqLens = rewriter.create<Torch::AtenSubScalarOp>(
loc, totalSeqLens.getType(), /*self=*/totalSeqLens,
/*other=*/cstSequenceLength, /*alpha=*/cstIntOne);
Torch::ValueTensorType initPosIdsType = Torch::ValueTensorType::get(
context, {sequenceLength},
IntegerType::get(context, 64, IntegerType::Signed));
Value initPosIds = rewriter.create<Torch::AtenArangeOp>(
loc, initPosIdsType, cstSequenceLength, cstInt64Dtype,
/*layout=*/cstNone,
/*device=*/cstNone, /*pin_memory=*/cstNone);
Value repeatValuesList = rewriter.create<Torch::PrimListConstructOp>(
binder.getLoc(),
Torch::ListType::get(Torch::IntType::get(context)),
llvm::SmallVector<Value>{cstBatchSize, cstIntOne});
positionIdsB = rewriter.create<Torch::AtenRepeatOp>(
loc, positionIdsType, initPosIds, /*repeats=*/repeatValuesList);

Value cstIntMinusOne = rewriter.create<Torch::ConstantIntOp>(
binder.getLoc(), rewriter.getI64IntegerAttr(-1));
Value viewSizeList = rewriter.create<Torch::PrimListConstructOp>(
binder.getLoc(),
Torch::ListType::get(Torch::IntType::get(context)),
llvm::SmallVector<Value>{cstIntMinusOne, cstIntOne});

Torch::ValueTensorType seqLensViewType = Torch::ValueTensorType::get(
context, llvm::SmallVector<int64_t>{batchSize, 1},
IntegerType::get(context, 64, IntegerType::Signed));
pastSeqLens = rewriter.create<Torch::AtenViewOp>(
loc, seqLensViewType, pastSeqLens, viewSizeList);

positionIdsB = rewriter.create<Torch::AtenAddTensorOp>(
loc, positionIdsType, positionIdsB, pastSeqLens,
/*alpha=*/cstIntOne);

totalSeqLens = rewriter.create<Torch::AtenViewOp>(
loc, seqLensViewType, totalSeqLens, viewSizeList);
Value cond = rewriter.create<Torch::AtenLtTensorOp>(
loc,
positionIdsType.getWithSizesAndDtype(positionIdsType.getSizes(),
rewriter.getI1Type()),
positionIdsB, totalSeqLens);

Value cstOneTensorDataList =
rewriter.create<Torch::PrimListConstructOp>(
loc,
rewriter.getType<Torch::ListType>(
rewriter.getType<Torch::IntType>()),
SmallVector<Value>{cstIntOne});
Value cstOneTensor = rewriter.create<Torch::AtenTensorOp>(
loc,
Torch::ValueTensorType::get(
context, {},
IntegerType::get(context, 64, IntegerType::Signed)),
cstOneTensorDataList, /*dtype=*/cstInt64Dtype,
/*layout=*/cstNone, /*requires_grad=*/cstFalse);

positionIdsB = rewriter.create<Torch::AtenWhereSelfOp>(
loc, positionIdsType, cond, positionIdsB, cstOneTensor);

Collaborator
To align with the onnx implementation, it would be better to add comments:
// Initialize separate buffers for rotary embeddings
// Assume packed_qkv == false, qkv is always unpacked.

isSubsequentPrompt = rewriter.create<Torch::AtenIntBoolOp>(
loc, rewriter.getType<Torch::IntType>(), isSubsequentPrompt);
isSubsequentPrompt = rewriter.create<Torch::AtenFullOp>(
loc,
Torch::ValueTensorType::get(context, positionIdsSizeInt,
rewriter.getI1Type()),
/*size=*/posIdsSizeList, /*fill_value=*/isSubsequentPrompt,
/*dtype=*/
rewriter.create<Torch::ConstantIntOp>(
binder.getLoc(), rewriter.getI64IntegerAttr(
(int)torch_upstream::ScalarType::Bool)),
/*layout=*/cstNone, /*device=*/cstNone, /*pin_memory=*/cstNone);
Value positionIds = rewriter.create<Torch::AtenWhereSelfOp>(
loc, positionIdsType, isSubsequentPrompt, positionIdsB,
positionIdsA);

qRotary = rewriter.create<Torch::OnnxVariantRotaryEmbeddingOp>(
Collaborator
// Run rotary embedding for Q and K
loc, qInput.getType(), qInput, positionIds, cosCache, sinCache,
cstInterleaved, /*is_packed_batching=*/cstIntZero,
/*num_heads=*/cstIntZero, /*rotary_embedding_dim=*/cstIntZero,
/*scale=*/cstFloatOne);

kRotary = rewriter.create<Torch::OnnxVariantRotaryEmbeddingOp>(
loc, qInput.getType(), kInput, positionIds, cosCache, sinCache,
cstInterleaved, /*is_packed_batching=*/cstIntZero,
/*num_heads=*/cstIntZero, /*rotary_embedding_dim=*/cstIntZero,
/*scale=*/cstFloatOne);
}

// Do attention.
Collaborator
// Compute the attention score and apply the score to V
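For reference, a hedged PyTorch sketch of the attention step emitted below (hypothetical shapes; enable_gqa requires PyTorch 2.5 or later):

  import torch
  import torch.nn.functional as F

  # Hypothetical shapes matching the head-split layout used above.
  batch, seq, num_heads, kv_num_heads, head_size = 2, 8, 4, 2, 16
  q = torch.randn(batch, num_heads, seq, head_size)
  k = torch.randn(batch, kv_num_heads, seq, head_size)
  v = torch.randn(batch, kv_num_heads, seq, head_size)

  # enable_gqa broadcasts the kv heads across the query head groups.
  attn = F.scaled_dot_product_attention(
      q, k, v, attn_mask=None, dropout_p=0.0, is_causal=False,
      scale=None, enable_gqa=True)

  # Reshape back to (batch_size, sequence_length, hidden_size).
  out = attn.reshape(batch, seq, num_heads * head_size)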

Value cstEnableGQA = rewriter.create<Torch::ConstantBoolOp>(loc, true);
Value cstFloatZero = rewriter.create<Torch::ConstantFloatOp>(
binder.getLoc(), rewriter.getType<Torch::FloatType>(),
rewriter.getF64FloatAttr(0.0));
Value cstScale = cstNone;
if (scale != 0.0f)
cstScale = rewriter.create<Torch::ConstantFloatOp>(
binder.getLoc(), rewriter.getType<Torch::FloatType>(),
rewriter.getF64FloatAttr(scale));
Value attention =
rewriter.create<Torch::AtenScaledDotProductAttentionOp>(
loc, qRotary.getType(), qRotary, kRotary, vInput,
/*attn_mask=*/cstNone,
/*dropout_p=*/cstFloatZero, /*is_causal=*/cstFalse, cstScale,
cstEnableGQA);
// Reshaping the attention result from:
// (batch_size, num_heads, sequence_length, head_size)
// -> (batch_size, sequence_length, hidden_size)
Value attentionResultSizesList =
rewriter.create<Torch::PrimListConstructOp>(
binder.getLoc(),
Torch::ListType::get(
Torch::IntType::get(attention.getContext())),
llvm::SmallVector<Value>{cstBatchSize, cstSequenceLength,
cstHiddenSize});
attention = rewriter.create<Torch::AtenReshapeOp>(
loc, resultTypes[0], attention, attentionResultSizesList);

// Compute 2nd and 3rd result: present_key, present_value.
// present_key = torch.cat([past_key, key], dim=2) or past_key
// present_value = torch.cat([past_value, value], dim=2) or past_value
Value presentKey = pastKey, presentValue = pastValue;
if (!llvm::equal(
cast<Torch::ValueTensorType>(pastKey.getType()).getSizes(),
cast<Torch::ValueTensorType>(resultTypes[1]).getSizes())) {
Value cstConcatDim = rewriter.create<Torch::ConstantIntOp>(
binder.getLoc(), rewriter.getI64IntegerAttr(2));
Type kvListElemType = keyType.getWithSizesAndDtype(
/*optionalSizes=*/std::nullopt,
/*optionalDtype=*/nullptr);
Type kvListType = Torch::ListType::get(kvListElemType);
Value keyList = rewriter.create<Torch::PrimListConstructOp>(
loc, kvListType, SmallVector<Value>{pastKey, kRotary});
presentKey = rewriter.create<Torch::AtenCatOp>(loc, resultTypes[1],
keyList, cstConcatDim);
}

if (!llvm::equal(
cast<Torch::ValueTensorType>(pastValue.getType()).getSizes(),
cast<Torch::ValueTensorType>(resultTypes[2]).getSizes())) {
Value cstConcatDim = rewriter.create<Torch::ConstantIntOp>(
binder.getLoc(), rewriter.getI64IntegerAttr(2));
Type kvListElemType = keyType.getWithSizesAndDtype(
/*optionalSizes=*/std::nullopt,
/*optionalDtype=*/nullptr);
Type kvListType = Torch::ListType::get(kvListElemType);
Value valueList = rewriter.create<Torch::PrimListConstructOp>(
loc, kvListType, SmallVector<Value>{pastValue, vInput});
presentValue = rewriter.create<Torch::AtenCatOp>(
loc, resultTypes[2], valueList, cstConcatDim);
}

rewriter.replaceOp(binder.op, {attention, presentKey, presentValue});
return success();
});
}
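For reference, a hedged PyTorch sketch (hypothetical shapes) of the present_key/present_value update at the end of the pattern: the past cache is returned unchanged when its shape already matches the result type, otherwise the new keys/values are concatenated along the sequence dimension:

  import torch

  # Hypothetical cache shapes: (batch_size, kv_num_heads, kv_seq_len, head_size).
  batch, kv_num_heads, head_size = 2, 2, 16
  past_key = torch.randn(batch, kv_num_heads, 6, head_size)
  key = torch.randn(batch, kv_num_heads, 8, head_size)

  # Concatenate along the sequence dimension (dim=2).
  present_key = torch.cat([past_key, key], dim=2)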