@@ -211,8 +211,8 @@ class UniDirectionalLstm {
                          const ActivationFuncs::Entry& activation_func_h, float clip, concurrency::ThreadPool* thread_pool);

   void Compute(const gsl::span<const T>& inputs, const gsl::span<const int>& sequence_lengths, int num_directions,
-               const gsl::span<const T>& input_weights, const gsl::span<const T>& recurrent_weights,
-               gsl::span<T>& outputs, gsl::span<T>& final_hidden_state, gsl::span<T>& final_cell_state);
+               const GemmWeights<T>& input_weights, const GemmWeights<T>& recurrent_weights, gsl::span<T>& outputs,
+               gsl::span<T>& final_hidden_state, gsl::span<T>& final_cell_state);

   ~UniDirectionalLstm() = default;

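The `GemmWeights<T>` type used in the new signature is defined elsewhere in the change and is not shown in this diff. As a rough, hypothetical sketch of what the call sites here and in `ComputeImpl` below seem to require (member names assumed, not confirmed by the diff), it only needs to carry either the raw per-direction weight pointer or the matching slice of a pre-packed MLAS buffer:

```cpp
// Hypothetical sketch only -- the real GemmWeights<T> definition is not part of this diff.
// It selects, per direction, between the raw weight data and the MlasGemmPackB output.
template <typename T>
struct GemmWeights {
  GemmWeights(int direction, const T* weights_data, size_t weights_size_per_direction,
              const PackedWeights& packed_weights) {
    if (packed_weights.buffer_) {
      is_prepacked_ = true;
      buffer_ = static_cast<const uint8_t*>(packed_weights.buffer_.get()) +
                packed_weights.weights_size_ * direction;
    } else {
      is_prepacked_ = false;
      buffer_ = weights_data + weights_size_per_direction * direction;
    }
  }

  bool is_prepacked_{false};
  const void* buffer_{nullptr};  // either the raw T* weights or the packed block for this direction
};
```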
@@ -290,20 +290,74 @@ class UniDirectionalLstm {

 }  // namespace detail

+Status DeepCpuLstmOp::TryPackWeights(const Tensor& weights, PackedWeights& packed_weights, bool& is_packed) {
+  const auto& shape = weights.Shape();
+  if (shape.NumDimensions() != 3) {
+    return Status::OK();
+  }
+
+  // weights: [num_directions, 4*hidden_size, input_size]
+  // recurrence weights: [num_directions, 4*hidden_size, hidden_size]
+  const size_t N = static_cast<size_t>(shape[1]);
+  const size_t K = static_cast<size_t>(shape[2]);
+
+  if ((shape[0] != num_directions_) || (N != static_cast<size_t>(hidden_size_ * 4))) {
+    return Status::OK();
+  }
+
+  const size_t packed_weights_size = MlasGemmPackBSize(N, K);
+  if (packed_weights_size == 0) {
+    return Status::OK();
+  }
+
+  auto alloc = Info().GetAllocator(0, OrtMemTypeDefault);
+  auto* packed_weights_data = alloc->Alloc(SafeInt<size_t>(packed_weights_size) * num_directions_);
+  packed_weights.buffer_ = BufferUniquePtr(packed_weights_data, BufferDeleter(alloc));
+  packed_weights.weights_size_ = packed_weights_size;
+  packed_weights.shape_ = shape;
+
+  const auto* weights_data = weights.Data<float>();
+  for (int i = 0; i < num_directions_; i++) {
+    MlasGemmPackB(CblasTrans, N, K, weights_data, K, packed_weights_data);
+    packed_weights_data = static_cast<uint8_t*>(packed_weights_data) + packed_weights_size;
+    weights_data += N * K;
+  }
+
+  is_packed = true;
+  return Status::OK();
+}
+
+#if !defined(USE_MKLML_FOR_BLAS)
+Status DeepCpuLstmOp::PrePack(const Tensor& tensor, int input_idx, bool& is_packed) {
+  is_packed = false;
+
+  if (tensor.IsDataType<float>()) {
+    if (input_idx == 1) {
+      return TryPackWeights(tensor, packed_W_, is_packed);
+    } else if (input_idx == 2) {
+      return TryPackWeights(tensor, packed_R_, is_packed);
+    }
+  }
+
+  return Status::OK();
+}
+#endif
+

 Status DeepCpuLstmOp::Compute(OpKernelContext* context) const {
   const Tensor& X = *context->Input<Tensor>(0);  // inputs. [seq_length, batch_size, input_size]

   Status status;
   // auto& logger = context->Logger();

-  if (X.IsDataType<float>())
+  if (X.IsDataType<float>()) {
     status = ComputeImpl<float>(*context);
-  else if (X.IsDataType<double>()) {
+  } else if (X.IsDataType<double>()) {
     /* Need to update all the helpers to support double...
     status = ComputeImpl<double>(*context); */
     ORT_NOT_IMPLEMENTED("LSTM operator does not support double yet");
-  } else
+  } else {
     ORT_THROW("Invalid data type for LSTM operator of ", X.DataType());
+  }

   return status;
 }
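`TryPackWeights` packs each direction's `[4*hidden_size, input_size]` (or `[4*hidden_size, hidden_size]`) slice separately into one contiguous allocation of `num_directions * MlasGemmPackBSize(N, K)` bytes. A minimal standalone sketch of that layout, using the same MLAS calls as the function above (the helper name, std::vector storage, and include path are illustrative assumptions):

```cpp
#include <cstdint>
#include <vector>

#include "core/mlas/inc/mlas.h"  // assumed include path for MlasGemmPackBSize / MlasGemmPackB

// Illustrative helper: pack a [num_directions, N, K] weight tensor direction by direction.
// Direction d's packed block starts at offset d * per_direction_bytes in the returned buffer.
std::vector<uint8_t> PackPerDirection(const float* weights, size_t num_directions, size_t N, size_t K,
                                      size_t& per_direction_bytes) {
  per_direction_bytes = MlasGemmPackBSize(N, K);  // 0 => packing unsupported for this shape; caller falls back
  std::vector<uint8_t> packed;
  if (per_direction_bytes == 0) return packed;
  packed.resize(per_direction_bytes * num_directions);
  for (size_t d = 0; d < num_directions; ++d) {
    // The weights act as B^T in the later GEMM, hence CblasTrans with leading dimension K.
    MlasGemmPackB(CblasTrans, N, K, weights + d * N * K, K, packed.data() + d * per_direction_bytes);
  }
  return packed;
}
```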
@@ -322,8 +376,10 @@ Status DeepCpuLstmOp::ComputeImpl(OpKernelContext& context) const {
   auto& logger = context.Logger();

   const Tensor& X = *context.Input<Tensor>(0);  // inputs. [seq_length, batch_size, input_size]
-  const Tensor& W = *context.Input<Tensor>(1);  // weights. [num_directions, 4*hidden_size, input_size]
-  const Tensor& R = *context.Input<Tensor>(2);  // recurrence weights. [num_directions, 4*hidden_size, hidden_size]
+  const Tensor* W = packed_W_.buffer_ ? nullptr : context.Input<Tensor>(1);
+  // weights. [num_directions, 4*hidden_size, input_size]
+  const Tensor* R = packed_R_.buffer_ ? nullptr : context.Input<Tensor>(2);
+  // recurrence weights. [num_directions, 4*hidden_size, hidden_size]

   // optional
   const Tensor* B = context.Input<Tensor>(3);  // bias. [num_directions, 8*hidden_size]
@@ -332,13 +388,16 @@ Status DeepCpuLstmOp::ComputeImpl(OpKernelContext& context) const {
   const Tensor* initial_c = context.Input<Tensor>(6);  // initial cell. [num_directions, batch_size, hidden_size]
   const Tensor* P = context.Input<Tensor>(7);          // peephole weights. [num_directions, 3*hidden_size]

-  auto& X_shape = X.Shape();
+  const auto& X_shape = X.Shape();

   int seq_length = gsl::narrow<int>(X_shape[0]);
   int batch_size = gsl::narrow<int>(X_shape[1]);
   int input_size = gsl::narrow<int>(X_shape[2]);

-  Status status = ValidateInputs(X, W, R, B, sequence_lens, initial_h, initial_c, P, batch_size);
+  const auto& W_shape = (W != nullptr) ? W->Shape() : packed_W_.shape_;
+  const auto& R_shape = (R != nullptr) ? R->Shape() : packed_R_.shape_;
+
+  Status status = ValidateInputs(X, W_shape, R_shape, B, sequence_lens, initial_h, initial_c, P, batch_size);
   ORT_RETURN_IF_ERROR(status);

   // LSTM outputs are optional but must be in the same order
@@ -370,8 +429,9 @@ Status DeepCpuLstmOp::ComputeImpl(OpKernelContext& context) const {
   status = context.GetTempSpaceAllocator(&alloc);
   ORT_RETURN_IF_ERROR(status);

-  gsl::span<const T> input_weights = W.DataAsSpan<T>();
-  gsl::span<const T> recurrent_weights = R.DataAsSpan<T>();
+  const auto* input_weights = (W != nullptr) ? W->Data<T>() : nullptr;
+  const auto* recurrent_weights = (R != nullptr) ? R->Data<T>() : nullptr;
+
   gsl::span<const T> bias = B != nullptr ? B->DataAsSpan<T>() : gsl::span<const T>();
   gsl::span<const T> peephole_weights = P != nullptr ? P->DataAsSpan<T>() : gsl::span<const T>();

@@ -381,8 +441,9 @@ Status DeepCpuLstmOp::ComputeImpl(OpKernelContext& context) const {
   const size_t bias_size_per_direction = 8 * hidden_size_;
   const size_t peephole_weights_size_per_direction = 3 * hidden_size_;

-  gsl::span<const T> input_weights_1 = input_weights.subspan(0, input_weights_size_per_direction);
-  gsl::span<const T> recurrent_weights_1 = recurrent_weights.subspan(0, hidden_weights_size_per_direction);
+  GemmWeights<T> input_weights_1(0, input_weights, input_weights_size_per_direction, packed_W_);
+  GemmWeights<T> recurrent_weights_1(0, recurrent_weights, hidden_weights_size_per_direction, packed_R_);
+
   gsl::span<const T> bias_1 = bias.empty() ? bias : bias.subspan(0, bias_size_per_direction);
   gsl::span<const T> peephole_weights_1 =
       peephole_weights.empty() ? peephole_weights : peephole_weights.subspan(0, peephole_weights_size_per_direction);
@@ -427,11 +488,10 @@ Status DeepCpuLstmOp::ComputeImpl(OpKernelContext& context) const {
   gsl::span<T> last_cell_1 = last_cell.subspan(0, last_cell_size_per_direction);

   if (direction_ == Direction::kBidirectional) {
+    GemmWeights<T> input_weights_2(1, input_weights, input_weights_size_per_direction, packed_W_);
+    GemmWeights<T> recurrent_weights_2(1, recurrent_weights, hidden_weights_size_per_direction, packed_R_);
+
     // spans for second direction
-    gsl::span<const T> input_weights_2 =
-        input_weights.subspan(input_weights_size_per_direction, input_weights_size_per_direction);
-    gsl::span<const T> hidden_weights_2 =
-        recurrent_weights.subspan(hidden_weights_size_per_direction, hidden_weights_size_per_direction);
     gsl::span<const T> bias_2 = bias.empty() ? bias : bias.subspan(bias_size_per_direction, bias_size_per_direction);
     gsl::span<const T> peephole_weights_2 =
         peephole_weights.empty() ? peephole_weights : peephole_weights.subspan(peephole_weights_size_per_direction, peephole_weights_size_per_direction);
@@ -459,8 +519,8 @@ Status DeepCpuLstmOp::ComputeImpl(OpKernelContext& context) const {

     fw.Compute(input, sequence_lens_span, num_directions_, input_weights_1, recurrent_weights_1, output_1,
                hidden_output_1, last_cell_1);
-    bw.Compute(input, sequence_lens_span, num_directions_, input_weights_2, hidden_weights_2, output_2, hidden_output_2,
-               last_cell_2);
+    bw.Compute(input, sequence_lens_span, num_directions_, input_weights_2, recurrent_weights_2, output_2,
+               hidden_output_2, last_cell_2);
   } else {
     detail::UniDirectionalLstm<T> fw(alloc, logger, seq_length, batch_size, input_size, hidden_size_, direction_,
                                      input_forget_, bias_1, peephole_weights_1, initial_hidden_1, initial_cell_1,
@@ -481,11 +541,11 @@ Status DeepCpuLstmOp::ComputeImpl(OpKernelContext& context) const {
   return Status::OK();
 }

-Status DeepCpuLstmOp::ValidateInputs(const Tensor& X, const Tensor& W, const Tensor& R, const Tensor* B,
-                                     const Tensor* sequence_lens, const Tensor* initial_h, const Tensor* initial_c,
-                                     const Tensor* P, int batch_size) const {
+Status DeepCpuLstmOp::ValidateInputs(const Tensor& X, const TensorShape& W_shape, const TensorShape& R_shape,
+                                     const Tensor* B, const Tensor* sequence_lens, const Tensor* initial_h,
+                                     const Tensor* initial_c, const Tensor* P, int batch_size) const {
   auto status =
-      rnn::detail::ValidateCommonRnnInputs(X, W, R, B, 4, sequence_lens, initial_h, num_directions_, hidden_size_);
+      rnn::detail::ValidateCommonRnnInputs(X, W_shape, R_shape, B, 4, sequence_lens, initial_h, num_directions_, hidden_size_);
   ORT_RETURN_IF_ERROR(status);

   if (initial_c != nullptr) {
@@ -680,8 +740,8 @@ void UniDirectionalLstm<T>::LoadBias(const gsl::span<const T>& WbRb_values) {
 template <typename T>
 void UniDirectionalLstm<T>::Compute(const gsl::span<const T>& inputs_arg,
                                     const gsl::span<const int>& sequence_lengths_arg, const int num_directions,
-                                    const gsl::span<const T>& input_weights,
-                                    const gsl::span<const T>& recurrent_weights, gsl::span<T>& outputs,
+                                    const GemmWeights<T>& input_weights, const GemmWeights<T>& recurrent_weights,
+                                    gsl::span<T>& outputs,
                                     gsl::span<T>& final_hidden_state, gsl::span<T>& final_cell_state) {
   // copy spans (just T* and size, not data in span) as we may change them
   gsl::span<const T> inputs = inputs_arg;
@@ -736,9 +796,9 @@ void UniDirectionalLstm<T>::Compute(const gsl::span<const T>& inputs_arg,
   const int total_rows = max_sequence_length * batch_size_;

   // apply the weights to all the inputs and save to output_IOFC
-  ComputeGemm(total_rows, hidden_size_x4, input_size_, alpha, inputs.cbegin(), inputs.cend(), input_size_,
-              input_weights.cbegin(), input_weights.cend(),  // W[iofc]
-              input_size_, beta, output_iofc_.begin(), output_iofc_.end(), hidden_size_x4, thread_pool_);
+  ComputeGemm(total_rows, hidden_size_x4, input_size_, alpha, inputs.cbegin(), inputs.cend(),
+              input_weights,
+              beta, output_iofc_.begin(), output_iofc_.end(), hidden_size_x4, thread_pool_);

   DumpMatrix("Xt*(W[iofc]^T)", output_iofc_.data(), total_rows, hidden_size_x4);

@@ -783,10 +843,10 @@ void UniDirectionalLstm<T>::Compute(const gsl::span<const T>& inputs_arg,

       // calculate Xt*(W[iofc]^T) + Ht-t*R[iofc]
       // Do it sequentially to avoid nested parallelism
-      ComputeGemm(local_fused_hidden_rows, hidden_size_x4, hidden_size_, alpha, previous_state,
-                  previous_state_end,  // Ht-1
-                  hidden_size_, recurrent_weights.cbegin(), recurrent_weights.cend(),  // R[iofc]
-                  hidden_size_, beta, step_out_IOFC, output_iofc_.end(),  // input contains Xt*(W[iofc]^T)
+      ComputeGemm(local_fused_hidden_rows, hidden_size_x4, hidden_size_, alpha,
+                  previous_state, previous_state_end,  // Ht-1
+                  recurrent_weights,  // R[iofc]
+                  beta, step_out_IOFC, output_iofc_.end(),  // input contains Xt*(W[iofc]^T)
                   hidden_size_x4, nullptr);

       DumpMatrix("Xt*(W[iofc]^T) + Ht-t*R[iofc]" + row_str, &*step_out_IOFC, local_fused_hidden_rows, hidden_size_x4);
@@ -861,9 +921,10 @@ void UniDirectionalLstm<T>::Compute(const gsl::span<const T>& inputs_arg,
     span_T_iter step_out_IOFC = output_iofc_.begin() + (step * batch_size_) * hidden_size_x4;

     // calculate Xt*(W[iofc]^T) + Ht-t*R[iofc]
-    ComputeGemm(batch_size_, hidden_size_x4, hidden_size_, alpha, previous_state, previous_state_end,  // Ht-1
-                hidden_size_, recurrent_weights.cbegin(), recurrent_weights.cend(),  // R[iofc]
-                hidden_size_, beta, step_out_IOFC, output_iofc_.end(),  // input contains Xt*(W[iofc]^T)
+    ComputeGemm(batch_size_, hidden_size_x4, hidden_size_, alpha,
+                previous_state, previous_state_end,  // Ht-1
+                recurrent_weights,  // R[iofc]
+                beta, step_out_IOFC, output_iofc_.end(),  // input contains Xt*(W[iofc]^T)
                 hidden_size_x4, thread_pool_);

     span_T_iter batched_output;
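The `ComputeGemm` overload that consumes `GemmWeights<T>` is not shown in this diff. As a hedged sketch of the dispatch these call sites imply (the helper is hypothetical; the packed-B `MlasGemm` overload and the `GemmWeights` members are assumptions, not confirmed by the diff):

```cpp
// Hypothetical sketch of a ComputeGemm overload taking GemmWeights<T>; the real helper lives elsewhere.
template <typename ConstIter, typename Iter, typename T>
void ComputeGemm(int M, int N, int K, float alpha,
                 ConstIter A_begin, ConstIter /*A_end*/,      // activations, row-major M x K
                 const GemmWeights<T>& weights,               // W[iofc] or R[iofc] for one direction
                 float beta, Iter C_begin, Iter /*C_end*/,    // accumulated into, row-major M x N
                 int ldc, concurrency::ThreadPool* thread_pool) {
  const float* A = &*A_begin;
  float* C = &*C_begin;
  if (weights.is_prepacked_) {
    // Assumption: MLAS exposes an SGEMM overload that accepts the MlasGemmPackB output directly.
    MlasGemm(CblasNoTrans, static_cast<size_t>(M), static_cast<size_t>(N), static_cast<size_t>(K),
             alpha, A, static_cast<size_t>(K), weights.buffer_,
             beta, C, static_cast<size_t>(ldc), thread_pool);
  } else {
    // Unpacked fall-back: the ordinary transposed-B SGEMM path used before this change.
    MlasGemm(CblasNoTrans, CblasTrans, static_cast<size_t>(M), static_cast<size_t>(N), static_cast<size_t>(K),
             alpha, A, static_cast<size_t>(K), static_cast<const float*>(weights.buffer_), static_cast<size_t>(K),
             beta, C, static_cast<size_t>(ldc), thread_pool);
  }
}
```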