@@ -3968,6 +3968,7 @@ void mlir::torch::onnx_c::populateDefaultDomainGtoP(
         Value pastKey = operands[3];
         Value pastValue = operands[4];
         Value seqlensK = operands[5];
+        Value totalSequenceLength = operands[6];
         Value cosCache, sinCache;
         if (doRotary) {
           cosCache = operands[7];
@@ -4053,25 +4054,39 @@ void mlir::torch::onnx_c::populateDefaultDomainGtoP(
 
         Value qRotary = qInput, kRotary = kInput;
         if (doRotary) {
-          bool isFirstPrompt = false;
-          // TODO: add the rotary embedding support
+          // `totalSequenceLength` is a scalar tensor.
+          Value scalarTotalSeqLens = rewriter.create<Torch::AtenItemOp>(
+              loc, rewriter.getType<Torch::IntType>(), totalSequenceLength);
+          Value cstIntOne = rewriter.create<Torch::ConstantIntOp>(
+              binder.getLoc(), rewriter.getI64IntegerAttr(1));
+          Type boolTy = rewriter.getType<Torch::BoolType>();
+          Value condA = rewriter.create<Torch::AtenGtIntOp>(
+              loc, boolTy, cstSequenceLength, cstIntOne);
+          Value condB = rewriter.create<Torch::AtenNeIntOp>(
+              loc, boolTy, cstSequenceLength, scalarTotalSeqLens);
+          // if (sequence_length > 1 && sequence_length !=
+          //     total_sequence_length)
+          //   is_subsequent_prompt = true; // Subsequent prompt
+          Value isSubsequentPrompt = rewriter.create<Torch::Aten__And__BoolOp>(
+              loc, boolTy, condA, condB);
 
           // Generating position_ids for rotary_embedding as follows:
-          // if is_first_prompt:
-          //   pos_ids = torch.tensor([0], dtype=torch.int64)
-          // else:
+          //   pos_ids_a = torch.zeros((batch_size, seq_len), dtype=torch.int64)
+          //
           //   total_seqlens = seqlens_k + 1
           //   past_seqlens = total_seqlens - sequence_length
           //   pos_ids = torch.arange(sequence_length,
           //                          dtype=torch.int64).repeat(batch_size, 1)
           //   pos_ids = pos_ids + past_seqlens.view(-1, 1)
           //   cond = pos_ids < total_seqlens.view(-1, 1)
           //   one_tensor = torch.tensor(1, dtype=torch.int64)
-          //   pos_ids = torch.where(cond, pos_ids, one_tensor)
-          SmallVector<int64_t> positionIdsSizeInt{1};
-          if (!isFirstPrompt)
-            positionIdsSizeInt = {batchSize, sequenceLength};
-
+          //   pos_ids_b = torch.where(cond, pos_ids, one_tensor)
+          //
+          // if subsequent_prompt:
+          //   pos_ids = pos_ids_b
+          // else:
+          //   pos_ids = pos_ids_a
+          SmallVector<int64_t> positionIdsSizeInt{batchSize, sequenceLength};
           Torch::ValueTensorType positionIdsType = Torch::ValueTensorType::get(
               context, positionIdsSizeInt,
               IntegerType::get(context, 64, IntegerType::Signed));
@@ -4087,93 +4102,106 @@ void mlir::torch::onnx_c::populateDefaultDomainGtoP(
               binder.getLoc(), rewriter.getType<Torch::FloatType>(),
               rewriter.getF64FloatAttr(1.0));
 
-          Value positionIds;
-          if (isFirstPrompt) {
-            positionIdsSizeInt = {1};
-            Value posIdsDataList = rewriter.create<Torch::PrimListConstructOp>(
-                loc,
-                rewriter.getType<Torch::ListType>(
-                    rewriter.getType<Torch::IntType>()),
-                SmallVector<Value>{cstIntZero});
-            positionIds = rewriter.create<Torch::AtenTensorOp>(
-                loc, positionIdsType, posIdsDataList, /*dtype=*/cstInt64Dtype,
-                /*layout=*/cstNone, /*requires_grad=*/cstFalse);
-          } else {
-            // Convert seqlens_k which is a tensor of type si32 to si64.
-            Torch::ValueTensorType seqLensKType =
-                cast<Torch::ValueTensorType>(seqlensK.getType());
-            seqlensK = rewriter.create<Torch::AtenToDtypeOp>(
-                loc,
-                seqLensKType.getWithSizesAndDtype(std::nullopt,
-                                                  rewriter.getI64Type()),
-                seqlensK, cstInt64Dtype, /*non_blocking=*/cstFalse,
-                /*copy=*/cstFalse, /*memory_format=*/cstNone);
-            Value cstIntOne = rewriter.create<Torch::ConstantIntOp>(
-                binder.getLoc(), rewriter.getI64IntegerAttr(1));
-            Value totalSeqLens = rewriter.create<Torch::AtenAddScalarOp>(
-                loc, seqlensK.getType(), /*self=*/seqlensK, /*other=*/cstIntOne,
-                /*alpha=*/cstIntOne);
-            Value pastSeqLens = rewriter.create<Torch::AtenSubScalarOp>(
-                loc, totalSeqLens.getType(), /*self=*/totalSeqLens,
-                /*other=*/cstSequenceLength, /*alpha=*/cstIntOne);
-            Torch::ValueTensorType initPosIdsType = Torch::ValueTensorType::get(
-                context, {sequenceLength},
-                IntegerType::get(context, 64, IntegerType::Signed));
-            Value initPosIds = rewriter.create<Torch::AtenArangeOp>(
-                loc, initPosIdsType, cstSequenceLength, cstInt64Dtype,
-                /*layout=*/cstNone,
-                /*device=*/cstNone, /*pin_memory=*/cstNone);
-            Value repeatValuesList =
-                rewriter.create<Torch::PrimListConstructOp>(
-                    binder.getLoc(),
-                    Torch::ListType::get(Torch::IntType::get(context)),
-                    llvm::SmallVector<Value>{cstBatchSize, cstIntOne});
-            positionIds = rewriter.create<Torch::AtenRepeatOp>(
-                loc, positionIdsType, initPosIds, /*repeats=*/repeatValuesList);
-
-            Value cstIntMinusOne = rewriter.create<Torch::ConstantIntOp>(
-                binder.getLoc(), rewriter.getI64IntegerAttr(-1));
-            Value viewSizeList = rewriter.create<Torch::PrimListConstructOp>(
-                binder.getLoc(),
-                Torch::ListType::get(Torch::IntType::get(context)),
-                llvm::SmallVector<Value>{cstIntMinusOne, cstIntOne});
-
-            Torch::ValueTensorType seqLensViewType =
-                Torch::ValueTensorType::get(
-                    context, llvm::SmallVector<int64_t>{batchSize, 1},
-                    IntegerType::get(context, 64, IntegerType::Signed));
-            pastSeqLens = rewriter.create<Torch::AtenViewOp>(
-                loc, seqLensViewType, pastSeqLens, viewSizeList);
-
-            positionIds = rewriter.create<Torch::AtenAddTensorOp>(
-                loc, positionIdsType, positionIds, pastSeqLens,
-                /*alpha=*/cstIntOne);
-
-            totalSeqLens = rewriter.create<Torch::AtenViewOp>(
-                loc, seqLensViewType, totalSeqLens, viewSizeList);
-            Value cond = rewriter.create<Torch::AtenLtTensorOp>(
-                loc,
-                positionIdsType.getWithSizesAndDtype(positionIdsType.getSizes(),
-                                                     rewriter.getI1Type()),
-                positionIds, totalSeqLens);
-
-            Value cstOneTensorDataList =
-                rewriter.create<Torch::PrimListConstructOp>(
-                    loc,
-                    rewriter.getType<Torch::ListType>(
-                        rewriter.getType<Torch::IntType>()),
-                    SmallVector<Value>{cstIntOne});
-            Value cstOneTensor = rewriter.create<Torch::AtenTensorOp>(
-                loc,
-                Torch::ValueTensorType::get(
-                    context, {},
-                    IntegerType::get(context, 64, IntegerType::Signed)),
-                cstOneTensorDataList, /*dtype=*/cstInt64Dtype,
-                /*layout=*/cstNone, /*requires_grad=*/cstFalse);
-
-            positionIds = rewriter.create<Torch::AtenWhereSelfOp>(
-                loc, positionIdsType, cond, positionIds, cstOneTensor);
-          }
+          Value positionIdsA, positionIdsB;
+
+          Value posIdsSizeList = rewriter.create<Torch::PrimListConstructOp>(
+              loc,
+              rewriter.getType<Torch::ListType>(
+                  rewriter.getType<Torch::IntType>()),
+              SmallVector<Value>{cstBatchSize, cstSequenceLength});
+          positionIdsA = rewriter.create<Torch::AtenZerosOp>(
+              loc, positionIdsType, /*size=*/posIdsSizeList,
+              /*dtype=*/cstInt64Dtype,
+              /*layout=*/cstNone, /*device=*/cstNone,
+              /*pin_memory=*/cstNone);
+
+          // Convert seqlens_k which is a tensor of type si32 to si64.
+          Torch::ValueTensorType seqLensKType =
+              cast<Torch::ValueTensorType>(seqlensK.getType());
+          seqlensK = rewriter.create<Torch::AtenToDtypeOp>(
+              loc,
+              seqLensKType.getWithSizesAndDtype(
+                  std::nullopt,
+                  rewriter.getIntegerType(/*width=*/64, /*isSigned=*/true)),
+              seqlensK, cstInt64Dtype, /*non_blocking=*/cstFalse,
+              /*copy=*/cstFalse, /*memory_format=*/cstNone);
+          Value totalSeqLens = rewriter.create<Torch::AtenAddScalarOp>(
+              loc, seqlensK.getType(), /*self=*/seqlensK, /*other=*/cstIntOne,
+              /*alpha=*/cstIntOne);
+          Value pastSeqLens = rewriter.create<Torch::AtenSubScalarOp>(
+              loc, totalSeqLens.getType(), /*self=*/totalSeqLens,
+              /*other=*/cstSequenceLength, /*alpha=*/cstIntOne);
+          Torch::ValueTensorType initPosIdsType = Torch::ValueTensorType::get(
+              context, {sequenceLength},
+              IntegerType::get(context, 64, IntegerType::Signed));
+          Value initPosIds = rewriter.create<Torch::AtenArangeOp>(
+              loc, initPosIdsType, cstSequenceLength, cstInt64Dtype,
+              /*layout=*/cstNone,
+              /*device=*/cstNone, /*pin_memory=*/cstNone);
+          Value repeatValuesList = rewriter.create<Torch::PrimListConstructOp>(
+              binder.getLoc(),
+              Torch::ListType::get(Torch::IntType::get(context)),
+              llvm::SmallVector<Value>{cstBatchSize, cstIntOne});
+          positionIdsB = rewriter.create<Torch::AtenRepeatOp>(
+              loc, positionIdsType, initPosIds, /*repeats=*/repeatValuesList);
+
+          Value cstIntMinusOne = rewriter.create<Torch::ConstantIntOp>(
+              binder.getLoc(), rewriter.getI64IntegerAttr(-1));
+          Value viewSizeList = rewriter.create<Torch::PrimListConstructOp>(
+              binder.getLoc(),
+              Torch::ListType::get(Torch::IntType::get(context)),
+              llvm::SmallVector<Value>{cstIntMinusOne, cstIntOne});
+
+          Torch::ValueTensorType seqLensViewType = Torch::ValueTensorType::get(
+              context, llvm::SmallVector<int64_t>{batchSize, 1},
+              IntegerType::get(context, 64, IntegerType::Signed));
+          pastSeqLens = rewriter.create<Torch::AtenViewOp>(
+              loc, seqLensViewType, pastSeqLens, viewSizeList);
+
+          positionIdsB = rewriter.create<Torch::AtenAddTensorOp>(
+              loc, positionIdsType, positionIdsB, pastSeqLens,
+              /*alpha=*/cstIntOne);
+
+          totalSeqLens = rewriter.create<Torch::AtenViewOp>(
+              loc, seqLensViewType, totalSeqLens, viewSizeList);
+          Value cond = rewriter.create<Torch::AtenLtTensorOp>(
+              loc,
+              positionIdsType.getWithSizesAndDtype(positionIdsType.getSizes(),
+                                                   rewriter.getI1Type()),
+              positionIdsB, totalSeqLens);
+
+          Value cstOneTensorDataList =
+              rewriter.create<Torch::PrimListConstructOp>(
+                  loc,
+                  rewriter.getType<Torch::ListType>(
+                      rewriter.getType<Torch::IntType>()),
+                  SmallVector<Value>{cstIntOne});
+          Value cstOneTensor = rewriter.create<Torch::AtenTensorOp>(
+              loc,
+              Torch::ValueTensorType::get(
+                  context, {},
+                  IntegerType::get(context, 64, IntegerType::Signed)),
+              cstOneTensorDataList, /*dtype=*/cstInt64Dtype,
+              /*layout=*/cstNone, /*requires_grad=*/cstFalse);
+
+          positionIdsB = rewriter.create<Torch::AtenWhereSelfOp>(
+              loc, positionIdsType, cond, positionIdsB, cstOneTensor);
+
+          isSubsequentPrompt = rewriter.create<Torch::AtenIntBoolOp>(
+              loc, rewriter.getType<Torch::IntType>(), isSubsequentPrompt);
+          isSubsequentPrompt = rewriter.create<Torch::AtenFullOp>(
+              loc,
+              Torch::ValueTensorType::get(context, positionIdsSizeInt,
+                                          rewriter.getI1Type()),
+              /*size=*/posIdsSizeList, /*fill_value=*/isSubsequentPrompt,
+              /*dtype=*/
+              rewriter.create<Torch::ConstantIntOp>(
+                  binder.getLoc(), rewriter.getI64IntegerAttr(
+                      (int)torch_upstream::ScalarType::Bool)),
+              /*layout=*/cstNone, /*device=*/cstNone, /*pin_memory=*/cstNone);
+          Value positionIds = rewriter.create<Torch::AtenWhereSelfOp>(
+              loc, positionIdsType, isSubsequentPrompt, positionIdsB,
+              positionIdsA);
 
           qRotary = rewriter.create<Torch::OnnxVariantAtenRotaryEmbeddingOp>(
              loc, qInput.getType(), qInput, positionIds, cosCache, sinCache,
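
Note: the position-ids selection that the new lowering builds out of tensor ops can be read back as plain PyTorch. The sketch below is illustrative only; the helper name make_position_ids and its signature are not part of the patch, and it mirrors the pseudocode in the diff comments (arange shifted by past lengths, out-of-range slots clamped to one, then a select between the two candidate id tensors).

import torch

def make_position_ids(seqlens_k, batch_size, sequence_length,
                      total_sequence_length):
    # "Subsequent prompt": more than one new token, but fewer than the
    # total sequence accumulated in the KV cache (condA && condB above).
    is_subsequent_prompt = (sequence_length > 1 and
                            sequence_length != total_sequence_length)

    # pos_ids_a: all zeros, used for the first prompt (AtenZerosOp).
    pos_ids_a = torch.zeros((batch_size, sequence_length), dtype=torch.int64)

    # pos_ids_b: per-batch arange offset by each entry's past length, with
    # out-of-range slots replaced by 1 (the AtenWhereSelfOp in the hunk).
    total_seqlens = seqlens_k.to(torch.int64) + 1
    past_seqlens = total_seqlens - sequence_length
    pos_ids = torch.arange(sequence_length,
                           dtype=torch.int64).repeat(batch_size, 1)
    pos_ids = pos_ids + past_seqlens.view(-1, 1)
    cond = pos_ids < total_seqlens.view(-1, 1)
    pos_ids_b = torch.where(cond, pos_ids,
                            torch.tensor(1, dtype=torch.int64))

    # The lowering has no scalar control flow, so it broadcasts the flag to
    # a boolean tensor (AtenFullOp) and selects with where(); a plain `if`
    # expresses the same choice here.
    return pos_ids_b if is_subsequent_prompt else pos_ids_a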