Skip to content

Commit f07059c

Browse files
authored
Add weight prepacking to LSTM kernel (microsoft#5305)
1 parent 11c194c commit f07059c

File tree

7 files changed

+213
-91
lines changed

7 files changed

+213
-91
lines changed

onnxruntime/core/providers/cpu/rnn/deep_cpu_gru.cc

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -272,7 +272,7 @@ Status DeepCpuGruOp::ComputeImpl(OpKernelContext& context) const {
272272
int batch_size = gsl::narrow<int>(X_shape[1]);
273273
int input_size = gsl::narrow<int>(X_shape[2]);
274274

275-
auto status = ValidateCommonRnnInputs(X, W, R, B, 3, sequence_lens, initial_h, num_directions_, hidden_size_);
275+
auto status = ValidateCommonRnnInputs(X, W.Shape(), R.Shape(), B, 3, sequence_lens, initial_h, num_directions_, hidden_size_);
276276
ORT_RETURN_IF_ERROR(status);
277277

278278
// GRU outputs are optional but must be in the same order

onnxruntime/core/providers/cpu/rnn/deep_cpu_lstm.cc

Lines changed: 96 additions & 35 deletions
Original file line numberDiff line numberDiff line change
@@ -211,8 +211,8 @@ class UniDirectionalLstm {
211211
const ActivationFuncs::Entry& activation_func_h, float clip, concurrency::ThreadPool* thread_pool);
212212

213213
void Compute(const gsl::span<const T>& inputs, const gsl::span<const int>& sequence_lengths, int num_directions,
214-
const gsl::span<const T>& input_weights, const gsl::span<const T>& recurrent_weights,
215-
gsl::span<T>& outputs, gsl::span<T>& final_hidden_state, gsl::span<T>& final_cell_state);
214+
const GemmWeights<T>& input_weights, const GemmWeights<T>& recurrent_weights, gsl::span<T>& outputs,
215+
gsl::span<T>& final_hidden_state, gsl::span<T>& final_cell_state);
216216

217217
~UniDirectionalLstm() = default;
218218

@@ -290,20 +290,74 @@ class UniDirectionalLstm {
290290

291291
} // namespace detail
292292

293+
Status DeepCpuLstmOp::TryPackWeights(const Tensor& weights, PackedWeights& packed_weights, bool& is_packed) {
294+
const auto& shape = weights.Shape();
295+
if (shape.NumDimensions() != 3) {
296+
return Status::OK();
297+
}
298+
299+
// weights: [num_directions, 4*hidden_size, input_size]
300+
// recurrence weights: [num_directions, 4*hidden_size, hidden_size]
301+
const size_t N = static_cast<size_t>(shape[1]);
302+
const size_t K = static_cast<size_t>(shape[2]);
303+
304+
if ((shape[0] != num_directions_) || (N != static_cast<size_t>(hidden_size_ * 4))) {
305+
return Status::OK();
306+
}
307+
308+
const size_t packed_weights_size = MlasGemmPackBSize(N, K);
309+
if (packed_weights_size == 0) {
310+
return Status::OK();
311+
}
312+
313+
auto alloc = Info().GetAllocator(0, OrtMemTypeDefault);
314+
auto* packed_weights_data = alloc->Alloc(SafeInt<size_t>(packed_weights_size) * num_directions_);
315+
packed_weights.buffer_ = BufferUniquePtr(packed_weights_data, BufferDeleter(alloc));
316+
packed_weights.weights_size_ = packed_weights_size;
317+
packed_weights.shape_ = shape;
318+
319+
const auto* weights_data = weights.Data<float>();
320+
for (int i = 0; i < num_directions_; i++) {
321+
MlasGemmPackB(CblasTrans, N, K, weights_data, K, packed_weights_data);
322+
packed_weights_data = static_cast<uint8_t*>(packed_weights_data) + packed_weights_size;
323+
weights_data += N * K;
324+
}
325+
326+
is_packed = true;
327+
return Status::OK();
328+
}
329+
330+
#if !defined(USE_MKLML_FOR_BLAS)
331+
Status DeepCpuLstmOp::PrePack(const Tensor& tensor, int input_idx, bool& is_packed) {
332+
is_packed = false;
333+
334+
if (tensor.IsDataType<float>()) {
335+
if (input_idx == 1) {
336+
return TryPackWeights(tensor, packed_W_, is_packed);
337+
} else if (input_idx == 2) {
338+
return TryPackWeights(tensor, packed_R_, is_packed);
339+
}
340+
}
341+
342+
return Status::OK();
343+
}
344+
#endif
345+
293346
Status DeepCpuLstmOp::Compute(OpKernelContext* context) const {
294347
const Tensor& X = *context->Input<Tensor>(0); // inputs. [seq_length, batch_size, input_size]
295348

296349
Status status;
297350
// auto& logger = context->Logger();
298351

299-
if (X.IsDataType<float>())
352+
if (X.IsDataType<float>()) {
300353
status = ComputeImpl<float>(*context);
301-
else if (X.IsDataType<double>()) {
354+
} else if (X.IsDataType<double>()) {
302355
/* Need to update all the helpers to support double...
303356
status = ComputeImpl<double>(*context); */
304357
ORT_NOT_IMPLEMENTED("LSTM operator does not support double yet");
305-
} else
358+
} else {
306359
ORT_THROW("Invalid data type for LSTM operator of ", X.DataType());
360+
}
307361

308362
return status;
309363
}
@@ -322,8 +376,10 @@ Status DeepCpuLstmOp::ComputeImpl(OpKernelContext& context) const {
322376
auto& logger = context.Logger();
323377

324378
const Tensor& X = *context.Input<Tensor>(0); // inputs. [seq_length, batch_size, input_size]
325-
const Tensor& W = *context.Input<Tensor>(1); // weights. [num_directions, 4*hidden_size, input_size]
326-
const Tensor& R = *context.Input<Tensor>(2); // recurrence weights. [num_directions, 4*hidden_size, hidden_size]
379+
const Tensor* W = packed_W_.buffer_ ? nullptr : context.Input<Tensor>(1);
380+
// weights. [num_directions, 4*hidden_size, input_size]
381+
const Tensor* R = packed_R_.buffer_ ? nullptr : context.Input<Tensor>(2);
382+
// recurrence weights. [num_directions, 4*hidden_size, hidden_size]
327383

328384
// optional
329385
const Tensor* B = context.Input<Tensor>(3); // bias. [num_directions, 8*hidden_size]
@@ -332,13 +388,16 @@ Status DeepCpuLstmOp::ComputeImpl(OpKernelContext& context) const {
332388
const Tensor* initial_c = context.Input<Tensor>(6); // initial cell. [num_directions, batch_size, hidden_size]
333389
const Tensor* P = context.Input<Tensor>(7); // peephole weights. [num_directions, 3*hidden_size]
334390

335-
auto& X_shape = X.Shape();
391+
const auto& X_shape = X.Shape();
336392

337393
int seq_length = gsl::narrow<int>(X_shape[0]);
338394
int batch_size = gsl::narrow<int>(X_shape[1]);
339395
int input_size = gsl::narrow<int>(X_shape[2]);
340396

341-
Status status = ValidateInputs(X, W, R, B, sequence_lens, initial_h, initial_c, P, batch_size);
397+
const auto& W_shape = (W != nullptr) ? W->Shape() : packed_W_.shape_;
398+
const auto& R_shape = (R != nullptr) ? R->Shape() : packed_R_.shape_;
399+
400+
Status status = ValidateInputs(X, W_shape, R_shape, B, sequence_lens, initial_h, initial_c, P, batch_size);
342401
ORT_RETURN_IF_ERROR(status);
343402

344403
// LSTM outputs are optional but must be in the same order
@@ -370,8 +429,9 @@ Status DeepCpuLstmOp::ComputeImpl(OpKernelContext& context) const {
370429
status = context.GetTempSpaceAllocator(&alloc);
371430
ORT_RETURN_IF_ERROR(status);
372431

373-
gsl::span<const T> input_weights = W.DataAsSpan<T>();
374-
gsl::span<const T> recurrent_weights = R.DataAsSpan<T>();
432+
const auto* input_weights = (W != nullptr) ? W->Data<T>() : nullptr;
433+
const auto* recurrent_weights = (R != nullptr) ? R->Data<T>() : nullptr;
434+
375435
gsl::span<const T> bias = B != nullptr ? B->DataAsSpan<T>() : gsl::span<const T>();
376436
gsl::span<const T> peephole_weights = P != nullptr ? P->DataAsSpan<T>() : gsl::span<const T>();
377437

@@ -381,8 +441,9 @@ Status DeepCpuLstmOp::ComputeImpl(OpKernelContext& context) const {
381441
const size_t bias_size_per_direction = 8 * hidden_size_;
382442
const size_t peephole_weights_size_per_direction = 3 * hidden_size_;
383443

384-
gsl::span<const T> input_weights_1 = input_weights.subspan(0, input_weights_size_per_direction);
385-
gsl::span<const T> recurrent_weights_1 = recurrent_weights.subspan(0, hidden_weights_size_per_direction);
444+
GemmWeights<T> input_weights_1(0, input_weights, input_weights_size_per_direction, packed_W_);
445+
GemmWeights<T> recurrent_weights_1(0, recurrent_weights, hidden_weights_size_per_direction, packed_R_);
446+
386447
gsl::span<const T> bias_1 = bias.empty() ? bias : bias.subspan(0, bias_size_per_direction);
387448
gsl::span<const T> peephole_weights_1 =
388449
peephole_weights.empty() ? peephole_weights : peephole_weights.subspan(0, peephole_weights_size_per_direction);
@@ -427,11 +488,10 @@ Status DeepCpuLstmOp::ComputeImpl(OpKernelContext& context) const {
427488
gsl::span<T> last_cell_1 = last_cell.subspan(0, last_cell_size_per_direction);
428489

429490
if (direction_ == Direction::kBidirectional) {
491+
GemmWeights<T> input_weights_2(1, input_weights, input_weights_size_per_direction, packed_W_);
492+
GemmWeights<T> recurrent_weights_2(1, recurrent_weights, hidden_weights_size_per_direction, packed_R_);
493+
430494
// spans for second direction
431-
gsl::span<const T> input_weights_2 =
432-
input_weights.subspan(input_weights_size_per_direction, input_weights_size_per_direction);
433-
gsl::span<const T> hidden_weights_2 =
434-
recurrent_weights.subspan(hidden_weights_size_per_direction, hidden_weights_size_per_direction);
435495
gsl::span<const T> bias_2 = bias.empty() ? bias : bias.subspan(bias_size_per_direction, bias_size_per_direction);
436496
gsl::span<const T> peephole_weights_2 =
437497
peephole_weights.empty() ? peephole_weights : peephole_weights.subspan(peephole_weights_size_per_direction, peephole_weights_size_per_direction);
@@ -459,8 +519,8 @@ Status DeepCpuLstmOp::ComputeImpl(OpKernelContext& context) const {
459519

460520
fw.Compute(input, sequence_lens_span, num_directions_, input_weights_1, recurrent_weights_1, output_1,
461521
hidden_output_1, last_cell_1);
462-
bw.Compute(input, sequence_lens_span, num_directions_, input_weights_2, hidden_weights_2, output_2, hidden_output_2,
463-
last_cell_2);
522+
bw.Compute(input, sequence_lens_span, num_directions_, input_weights_2, recurrent_weights_2, output_2,
523+
hidden_output_2, last_cell_2);
464524
} else {
465525
detail::UniDirectionalLstm<T> fw(alloc, logger, seq_length, batch_size, input_size, hidden_size_, direction_,
466526
input_forget_, bias_1, peephole_weights_1, initial_hidden_1, initial_cell_1,
@@ -481,11 +541,11 @@ Status DeepCpuLstmOp::ComputeImpl(OpKernelContext& context) const {
481541
return Status::OK();
482542
}
483543

484-
Status DeepCpuLstmOp::ValidateInputs(const Tensor& X, const Tensor& W, const Tensor& R, const Tensor* B,
485-
const Tensor* sequence_lens, const Tensor* initial_h, const Tensor* initial_c,
486-
const Tensor* P, int batch_size) const {
544+
Status DeepCpuLstmOp::ValidateInputs(const Tensor& X, const TensorShape& W_shape, const TensorShape& R_shape,
545+
const Tensor* B, const Tensor* sequence_lens, const Tensor* initial_h,
546+
const Tensor* initial_c, const Tensor* P, int batch_size) const {
487547
auto status =
488-
rnn::detail::ValidateCommonRnnInputs(X, W, R, B, 4, sequence_lens, initial_h, num_directions_, hidden_size_);
548+
rnn::detail::ValidateCommonRnnInputs(X, W_shape, R_shape, B, 4, sequence_lens, initial_h, num_directions_, hidden_size_);
489549
ORT_RETURN_IF_ERROR(status);
490550

491551
if (initial_c != nullptr) {
@@ -680,8 +740,8 @@ void UniDirectionalLstm<T>::LoadBias(const gsl::span<const T>& WbRb_values) {
680740
template <typename T>
681741
void UniDirectionalLstm<T>::Compute(const gsl::span<const T>& inputs_arg,
682742
const gsl::span<const int>& sequence_lengths_arg, const int num_directions,
683-
const gsl::span<const T>& input_weights,
684-
const gsl::span<const T>& recurrent_weights, gsl::span<T>& outputs,
743+
const GemmWeights<T>& input_weights, const GemmWeights<T>& recurrent_weights,
744+
gsl::span<T>& outputs,
685745
gsl::span<T>& final_hidden_state, gsl::span<T>& final_cell_state) {
686746
// copy spans (just T* and size, not data in span) as we may change them
687747
gsl::span<const T> inputs = inputs_arg;
@@ -736,9 +796,9 @@ void UniDirectionalLstm<T>::Compute(const gsl::span<const T>& inputs_arg,
736796
const int total_rows = max_sequence_length * batch_size_;
737797

738798
// apply the weights to all the inputs and save to output_IOFC
739-
ComputeGemm(total_rows, hidden_size_x4, input_size_, alpha, inputs.cbegin(), inputs.cend(), input_size_,
740-
input_weights.cbegin(), input_weights.cend(), // W[iofc]
741-
input_size_, beta, output_iofc_.begin(), output_iofc_.end(), hidden_size_x4, thread_pool_);
799+
ComputeGemm(total_rows, hidden_size_x4, input_size_, alpha, inputs.cbegin(), inputs.cend(),
800+
input_weights,
801+
beta, output_iofc_.begin(), output_iofc_.end(), hidden_size_x4, thread_pool_);
742802

743803
DumpMatrix("Xt*(W[iofc]^T)", output_iofc_.data(), total_rows, hidden_size_x4);
744804

@@ -783,10 +843,10 @@ void UniDirectionalLstm<T>::Compute(const gsl::span<const T>& inputs_arg,
783843

784844
// calculate Xt*(W[iofc]^T) + Ht-t*R[iofc]
785845
// Do it sequentially to avoid nested parallelism
786-
ComputeGemm(local_fused_hidden_rows, hidden_size_x4, hidden_size_, alpha, previous_state,
787-
previous_state_end, // Ht-1
788-
hidden_size_, recurrent_weights.cbegin(), recurrent_weights.cend(), // R[iofc]
789-
hidden_size_, beta, step_out_IOFC, output_iofc_.end(), // input contains Xt*(W[iofc]^T)
846+
ComputeGemm(local_fused_hidden_rows, hidden_size_x4, hidden_size_, alpha,
847+
previous_state, previous_state_end, // Ht-1
848+
recurrent_weights, // R[iofc]
849+
beta, step_out_IOFC, output_iofc_.end(), // input contains Xt*(W[iofc]^T)
790850
hidden_size_x4, nullptr);
791851

792852
DumpMatrix("Xt*(W[iofc]^T) + Ht-t*R[iofc]" + row_str, &*step_out_IOFC, local_fused_hidden_rows, hidden_size_x4);
@@ -861,9 +921,10 @@ void UniDirectionalLstm<T>::Compute(const gsl::span<const T>& inputs_arg,
861921
span_T_iter step_out_IOFC = output_iofc_.begin() + (step * batch_size_) * hidden_size_x4;
862922

863923
// calculate Xt*(W[iofc]^T) + Ht-t*R[iofc]
864-
ComputeGemm(batch_size_, hidden_size_x4, hidden_size_, alpha, previous_state, previous_state_end, // Ht-1
865-
hidden_size_, recurrent_weights.cbegin(), recurrent_weights.cend(), // R[iofc]
866-
hidden_size_, beta, step_out_IOFC, output_iofc_.end(), // input contains Xt*(W[iofc]^T)
924+
ComputeGemm(batch_size_, hidden_size_x4, hidden_size_, alpha,
925+
previous_state, previous_state_end, // Ht-1
926+
recurrent_weights, // R[iofc]
927+
beta, step_out_IOFC, output_iofc_.end(), // input contains Xt*(W[iofc]^T)
867928
hidden_size_x4, thread_pool_);
868929

869930
span_T_iter batched_output;

onnxruntime/core/providers/cpu/rnn/deep_cpu_lstm.h

Lines changed: 10 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -50,17 +50,22 @@ class DeepCpuLstmOp final : public OpKernel {
5050
activation_func_betas);
5151
}
5252

53+
#if !defined(USE_MKLML_FOR_BLAS)
54+
Status PrePack(const Tensor& tensor, int input_idx, bool& is_packed) override;
55+
#endif
5356
Status Compute(OpKernelContext* context) const override;
5457

5558
~DeepCpuLstmOp() override = default;
5659

5760
private:
61+
Status TryPackWeights(const Tensor& weights, rnn::detail::PackedWeights& packed_weights, bool& is_packed);
62+
5863
template <typename T>
5964
Status ComputeImpl(OpKernelContext& context) const;
6065

6166
Status ValidateInputs(const Tensor& X,
62-
const Tensor& W,
63-
const Tensor& R,
67+
const TensorShape& W,
68+
const TensorShape& R,
6469
const Tensor* B,
6570
const Tensor* sequence_lens,
6671
const Tensor* initial_h,
@@ -75,6 +80,9 @@ class DeepCpuLstmOp final : public OpKernel {
7580
float clip_;
7681
bool input_forget_ = false;
7782

83+
rnn::detail::PackedWeights packed_W_;
84+
rnn::detail::PackedWeights packed_R_;
85+
7886
rnn::detail::ActivationFuncs activation_funcs_;
7987

8088
};

onnxruntime/core/providers/cpu/rnn/rnn.cc

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -119,7 +119,7 @@ Status RNN<float>::Compute(OpKernelContext* ctx) const {
119119
int64_t batch_size = X.Shape()[1];
120120
int64_t input_size = X.Shape()[2];
121121

122-
auto status = rnn::detail::ValidateCommonRnnInputs(X, W, R, B, 1, sequence_lens, initial_h,
122+
auto status = rnn::detail::ValidateCommonRnnInputs(X, W.Shape(), R.Shape(), B, 1, sequence_lens, initial_h,
123123
num_directions, hidden_size_);
124124
ORT_RETURN_IF_ERROR(status);
125125

onnxruntime/core/providers/cpu/rnn/rnn_helpers.cc

Lines changed: 2 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -24,17 +24,15 @@ namespace detail {
2424
using namespace ::onnxruntime::common;
2525

2626
Status ValidateCommonRnnInputs(const Tensor& X,
27-
const Tensor& W,
28-
const Tensor& R,
27+
const TensorShape& W_shape,
28+
const TensorShape& R_shape,
2929
const Tensor* B,
3030
int WRB_dim_1_multipler,
3131
const Tensor* sequence_lens,
3232
const Tensor* initial_h,
3333
int64_t num_directions,
3434
int64_t hidden_size) {
3535
auto& X_shape = X.Shape();
36-
auto& W_shape = W.Shape();
37-
auto& R_shape = R.Shape();
3836

3937
int64_t seq_length = X_shape[0];
4038
int64_t batch_size = X_shape[1];

0 commit comments

Comments
 (0)