@@ -3968,6 +3968,7 @@ void mlir::torch::onnx_c::populateDefaultDomainGtoP(
         Value pastKey = operands[3];
         Value pastValue = operands[4];
         Value seqlensK = operands[5];
+        Value totalSequenceLength = operands[6];
         Value cosCache, sinCache;
         if (doRotary) {
           cosCache = operands[7];
@@ -4053,25 +4054,39 @@ void mlir::torch::onnx_c::populateDefaultDomainGtoP(
 
         Value qRotary = qInput, kRotary = kInput;
         if (doRotary) {
-          bool isFirstPrompt = false;
-          // TODO: add the rotary embedding support
+          // `totalSequenceLength` is a scalar tensor.
+          Value scalarTotalSeqLens = rewriter.create<Torch::AtenItemOp>(
+              loc, rewriter.getType<Torch::IntType>(), totalSequenceLength);
+          Value cstIntOne = rewriter.create<Torch::ConstantIntOp>(
+              binder.getLoc(), rewriter.getI64IntegerAttr(1));
+          Type boolTy = rewriter.getType<Torch::BoolType>();
+          Value condA = rewriter.create<Torch::AtenGtIntOp>(
+              loc, boolTy, cstSequenceLength, cstIntOne);
+          Value condB = rewriter.create<Torch::AtenNeIntOp>(
+              loc, boolTy, cstSequenceLength, scalarTotalSeqLens);
+          // if (sequence_length > 1 &&
+          //     sequence_length != total_sequence_length)
+          //   is_subsequent_prompt = true; // Subsequent prompt
+          Value isSubsequentPrompt = rewriter.create<Torch::Aten__And__BoolOp>(
+              loc, boolTy, condA, condB);
 
           // Generating position_ids for rotary_embedding as follows:
-          // if is_first_prompt:
-          //   pos_ids = torch.tensor([0], dtype=torch.int64)
-          // else:
+          // pos_ids_a = torch.zeros((batch_size, seq_len), dtype=torch.int64)
+          //
           //   total_seqlens = seqlens_k + 1
           //   past_seqlens = total_seqlens - sequence_length
           //   pos_ids = torch.arange(sequence_length,
           //                          dtype=torch.int64).repeat(batch_size, 1)
           //   pos_ids = pos_ids + past_seqlens.view(-1, 1)
           //   cond = pos_ids < total_seqlens.view(-1, 1)
           //   one_tensor = torch.tensor(1, dtype=torch.int64)
-          //   pos_ids = torch.where(cond, pos_ids, one_tensor)
-          SmallVector<int64_t> positionIdsSizeInt{1};
-          if (!isFirstPrompt)
-            positionIdsSizeInt = {batchSize, sequenceLength};
-
+          //   pos_ids_b = torch.where(cond, pos_ids, one_tensor)
+          //
+          // if subsequent_prompt:
+          //   pos_ids = pos_ids_b
+          // else:
+          //   pos_ids = pos_ids_a
+          SmallVector<int64_t> positionIdsSizeInt{batchSize, sequenceLength};
           Torch::ValueTensorType positionIdsType = Torch::ValueTensorType::get(
               context, positionIdsSizeInt,
               IntegerType::get(context, 64, IntegerType::Signed));
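
For reference, the position-id logic documented in the comment above can be checked against a small PyTorch sketch. This is a minimal sketch of the intended semantics, not the lowering itself; the helper name `gqa_position_ids` and the free-standing scalar arguments are illustrative assumptions:

    import torch

    def gqa_position_ids(seqlens_k: torch.Tensor, seq_len: int,
                         total_sequence_length: int,
                         batch_size: int) -> torch.Tensor:
        # Candidate A: first prompt, all position ids are zero.
        pos_ids_a = torch.zeros((batch_size, seq_len), dtype=torch.int64)

        # Candidate B: offset each row by its past sequence length.
        total_seqlens = seqlens_k.to(torch.int64) + 1
        past_seqlens = total_seqlens - seq_len
        pos_ids = torch.arange(seq_len, dtype=torch.int64).repeat(batch_size, 1)
        pos_ids = pos_ids + past_seqlens.view(-1, 1)
        cond = pos_ids < total_seqlens.view(-1, 1)
        one_tensor = torch.tensor(1, dtype=torch.int64)
        pos_ids_b = torch.where(cond, pos_ids, one_tensor)

        # Subsequent prompt iff seq_len > 1 and it differs from the total.
        is_subsequent_prompt = seq_len > 1 and seq_len != total_sequence_length
        return pos_ids_b if is_subsequent_prompt else pos_ids_a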
@@ -4087,93 +4102,106 @@ void mlir::torch::onnx_c::populateDefaultDomainGtoP(
               binder.getLoc(), rewriter.getType<Torch::FloatType>(),
               rewriter.getF64FloatAttr(1.0));
 
-          Value positionIds;
-          if (isFirstPrompt) {
-            positionIdsSizeInt = {1};
-            Value posIdsDataList = rewriter.create<Torch::PrimListConstructOp>(
-                loc,
-                rewriter.getType<Torch::ListType>(
-                    rewriter.getType<Torch::IntType>()),
-                SmallVector<Value>{cstIntZero});
-            positionIds = rewriter.create<Torch::AtenTensorOp>(
-                loc, positionIdsType, posIdsDataList, /*dtype=*/cstInt64Dtype,
-                /*layout=*/cstNone, /*requires_grad=*/cstFalse);
-          } else {
-            // Convert seqlens_k which is a tensor of type si32 to si64.
-            Torch::ValueTensorType seqLensKType =
-                cast<Torch::ValueTensorType>(seqlensK.getType());
-            seqlensK = rewriter.create<Torch::AtenToDtypeOp>(
-                loc,
-                seqLensKType.getWithSizesAndDtype(std::nullopt,
-                                                  rewriter.getI64Type()),
-                seqlensK, cstInt64Dtype, /*non_blocking=*/cstFalse,
-                /*copy=*/cstFalse, /*memory_format=*/cstNone);
-            Value cstIntOne = rewriter.create<Torch::ConstantIntOp>(
-                binder.getLoc(), rewriter.getI64IntegerAttr(1));
-            Value totalSeqLens = rewriter.create<Torch::AtenAddScalarOp>(
-                loc, seqlensK.getType(), /*self=*/seqlensK, /*other=*/cstIntOne,
-                /*alpha=*/cstIntOne);
-            Value pastSeqLens = rewriter.create<Torch::AtenSubScalarOp>(
-                loc, totalSeqLens.getType(), /*self=*/totalSeqLens,
-                /*other=*/cstSequenceLength, /*alpha=*/cstIntOne);
-            Torch::ValueTensorType initPosIdsType = Torch::ValueTensorType::get(
-                context, {sequenceLength},
-                IntegerType::get(context, 64, IntegerType::Signed));
-            Value initPosIds = rewriter.create<Torch::AtenArangeOp>(
-                loc, initPosIdsType, cstSequenceLength, cstInt64Dtype,
-                /*layout=*/cstNone,
-                /*device=*/cstNone, /*pin_memory=*/cstNone);
-            Value repeatValuesList =
-                rewriter.create<Torch::PrimListConstructOp>(
-                    binder.getLoc(),
-                    Torch::ListType::get(Torch::IntType::get(context)),
-                    llvm::SmallVector<Value>{cstBatchSize, cstIntOne});
-            positionIds = rewriter.create<Torch::AtenRepeatOp>(
-                loc, positionIdsType, initPosIds, /*repeats=*/repeatValuesList);
-
-            Value cstIntMinusOne = rewriter.create<Torch::ConstantIntOp>(
-                binder.getLoc(), rewriter.getI64IntegerAttr(-1));
-            Value viewSizeList = rewriter.create<Torch::PrimListConstructOp>(
-                binder.getLoc(),
-                Torch::ListType::get(Torch::IntType::get(context)),
-                llvm::SmallVector<Value>{cstIntMinusOne, cstIntOne});
-
-            Torch::ValueTensorType seqLensViewType =
-                Torch::ValueTensorType::get(
-                    context, llvm::SmallVector<int64_t>{batchSize, 1},
-                    IntegerType::get(context, 64, IntegerType::Signed));
-            pastSeqLens = rewriter.create<Torch::AtenViewOp>(
-                loc, seqLensViewType, pastSeqLens, viewSizeList);
-
-            positionIds = rewriter.create<Torch::AtenAddTensorOp>(
-                loc, positionIdsType, positionIds, pastSeqLens,
-                /*alpha=*/cstIntOne);
-
-            totalSeqLens = rewriter.create<Torch::AtenViewOp>(
-                loc, seqLensViewType, totalSeqLens, viewSizeList);
-            Value cond = rewriter.create<Torch::AtenLtTensorOp>(
-                loc,
-                positionIdsType.getWithSizesAndDtype(positionIdsType.getSizes(),
-                                                     rewriter.getI1Type()),
-                positionIds, totalSeqLens);
-
-            Value cstOneTensorDataList =
-                rewriter.create<Torch::PrimListConstructOp>(
-                    loc,
-                    rewriter.getType<Torch::ListType>(
-                        rewriter.getType<Torch::IntType>()),
-                    SmallVector<Value>{cstIntOne});
-            Value cstOneTensor = rewriter.create<Torch::AtenTensorOp>(
-                loc,
-                Torch::ValueTensorType::get(
-                    context, {},
-                    IntegerType::get(context, 64, IntegerType::Signed)),
-                cstOneTensorDataList, /*dtype=*/cstInt64Dtype,
-                /*layout=*/cstNone, /*requires_grad=*/cstFalse);
-
-            positionIds = rewriter.create<Torch::AtenWhereSelfOp>(
-                loc, positionIdsType, cond, positionIds, cstOneTensor);
-          }
+          Value positionIdsA, positionIdsB;
+
+          Value posIdsSizeList = rewriter.create<Torch::PrimListConstructOp>(
+              loc,
+              rewriter.getType<Torch::ListType>(
+                  rewriter.getType<Torch::IntType>()),
+              SmallVector<Value>{cstBatchSize, cstSequenceLength});
+          positionIdsA = rewriter.create<Torch::AtenZerosOp>(
+              loc, positionIdsType, /*size=*/posIdsSizeList,
+              /*dtype=*/cstInt64Dtype,
+              /*layout=*/cstNone, /*device=*/cstNone,
+              /*pin_memory=*/cstNone);
+
+          // Convert seqlens_k which is a tensor of type si32 to si64.
+          Torch::ValueTensorType seqLensKType =
+              cast<Torch::ValueTensorType>(seqlensK.getType());
+          seqlensK = rewriter.create<Torch::AtenToDtypeOp>(
+              loc,
+              seqLensKType.getWithSizesAndDtype(
+                  std::nullopt,
+                  rewriter.getIntegerType(/*width=*/64, /*isSigned=*/true)),
+              seqlensK, cstInt64Dtype, /*non_blocking=*/cstFalse,
+              /*copy=*/cstFalse, /*memory_format=*/cstNone);
+          Value totalSeqLens = rewriter.create<Torch::AtenAddScalarOp>(
+              loc, seqlensK.getType(), /*self=*/seqlensK, /*other=*/cstIntOne,
+              /*alpha=*/cstIntOne);
+          Value pastSeqLens = rewriter.create<Torch::AtenSubScalarOp>(
+              loc, totalSeqLens.getType(), /*self=*/totalSeqLens,
+              /*other=*/cstSequenceLength, /*alpha=*/cstIntOne);
+          Torch::ValueTensorType initPosIdsType = Torch::ValueTensorType::get(
+              context, {sequenceLength},
+              IntegerType::get(context, 64, IntegerType::Signed));
+          Value initPosIds = rewriter.create<Torch::AtenArangeOp>(
+              loc, initPosIdsType, cstSequenceLength, cstInt64Dtype,
+              /*layout=*/cstNone,
+              /*device=*/cstNone, /*pin_memory=*/cstNone);
+          Value repeatValuesList = rewriter.create<Torch::PrimListConstructOp>(
+              binder.getLoc(),
+              Torch::ListType::get(Torch::IntType::get(context)),
+              llvm::SmallVector<Value>{cstBatchSize, cstIntOne});
+          positionIdsB = rewriter.create<Torch::AtenRepeatOp>(
+              loc, positionIdsType, initPosIds, /*repeats=*/repeatValuesList);
+
+          Value cstIntMinusOne = rewriter.create<Torch::ConstantIntOp>(
+              binder.getLoc(), rewriter.getI64IntegerAttr(-1));
+          Value viewSizeList = rewriter.create<Torch::PrimListConstructOp>(
+              binder.getLoc(),
+              Torch::ListType::get(Torch::IntType::get(context)),
+              llvm::SmallVector<Value>{cstIntMinusOne, cstIntOne});
+
+          Torch::ValueTensorType seqLensViewType = Torch::ValueTensorType::get(
+              context, llvm::SmallVector<int64_t>{batchSize, 1},
+              IntegerType::get(context, 64, IntegerType::Signed));
+          pastSeqLens = rewriter.create<Torch::AtenViewOp>(
+              loc, seqLensViewType, pastSeqLens, viewSizeList);
+
+          positionIdsB = rewriter.create<Torch::AtenAddTensorOp>(
+              loc, positionIdsType, positionIdsB, pastSeqLens,
+              /*alpha=*/cstIntOne);
+
+          totalSeqLens = rewriter.create<Torch::AtenViewOp>(
+              loc, seqLensViewType, totalSeqLens, viewSizeList);
+          Value cond = rewriter.create<Torch::AtenLtTensorOp>(
+              loc,
+              positionIdsType.getWithSizesAndDtype(positionIdsType.getSizes(),
+                                                   rewriter.getI1Type()),
+              positionIdsB, totalSeqLens);
+
+          Value cstOneTensorDataList =
+              rewriter.create<Torch::PrimListConstructOp>(
+                  loc,
+                  rewriter.getType<Torch::ListType>(
+                      rewriter.getType<Torch::IntType>()),
+                  SmallVector<Value>{cstIntOne});
+          Value cstOneTensor = rewriter.create<Torch::AtenTensorOp>(
+              loc,
+              Torch::ValueTensorType::get(
+                  context, {},
+                  IntegerType::get(context, 64, IntegerType::Signed)),
+              cstOneTensorDataList, /*dtype=*/cstInt64Dtype,
+              /*layout=*/cstNone, /*requires_grad=*/cstFalse);
+
+          positionIdsB = rewriter.create<Torch::AtenWhereSelfOp>(
+              loc, positionIdsType, cond, positionIdsB, cstOneTensor);
+
+          isSubsequentPrompt = rewriter.create<Torch::AtenIntBoolOp>(
+              loc, rewriter.getType<Torch::IntType>(), isSubsequentPrompt);
+          isSubsequentPrompt = rewriter.create<Torch::AtenFullOp>(
+              loc,
+              Torch::ValueTensorType::get(context, positionIdsSizeInt,
+                                          rewriter.getI1Type()),
+              /*size=*/posIdsSizeList, /*fill_value=*/isSubsequentPrompt,
+              /*dtype=*/
+              rewriter.create<Torch::ConstantIntOp>(
+                  binder.getLoc(),
+                  rewriter.getI64IntegerAttr(
+                      (int)torch_upstream::ScalarType::Bool)),
+              /*layout=*/cstNone, /*device=*/cstNone, /*pin_memory=*/cstNone);
+          Value positionIds = rewriter.create<Torch::AtenWhereSelfOp>(
+              loc, positionIdsType, isSubsequentPrompt, positionIdsB,
+              positionIdsA);
 
           qRotary = rewriter.create<Torch::OnnxVariantAtenRotaryEmbeddingOp>(
               loc, qInput.getType(), qInput, positionIds, cosCache, sinCache,
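
Because `AtenWhereSelfOp` takes a tensor condition, the hunk above broadcasts the scalar `isSubsequentPrompt` flag to a full boolean tensor with `AtenFullOp` and then selects between the two position-id candidates, which keeps the rewrite free of explicit control flow. A hedged PyTorch equivalent of that final selection step (the function name is an illustrative assumption):

    import torch

    def select_position_ids(is_subsequent_prompt: bool,
                            pos_ids_b: torch.Tensor,
                            pos_ids_a: torch.Tensor) -> torch.Tensor:
        # Broadcast the scalar flag to a boolean mask of the same shape so
        # the choice becomes a data-flow `where` instead of an if/else.
        mask = torch.full(pos_ids_b.shape, is_subsequent_prompt,
                          dtype=torch.bool)
        return torch.where(mask, pos_ids_b, pos_ids_a)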