@@ -5660,28 +5660,38 @@ partial dictionary MLOpSupportLimits {
5660
5660
builder, input, weight, recurrentWeight, steps, hiddenSize, options) {
5661
5661
const batchSize = input.shape[1] ;
5662
5662
const inputSize = input.shape[2] ;
5663
- const numDirections = (options.direction == 'both' ? 2 : 1);
5663
+ const direction = options.direction || 'forward' ;
5664
+ const numDirections = (direction == 'both' ? 2 : 1);
5664
5665
let hiddenState = options.initialHiddenState;
5665
5666
let cellState = options.initialCellState;
5666
5667
5667
5668
if (!hiddenState) {
5668
- const desc = {dataType: 'float32' , shape: [numDirections, 1, hiddenSize] };
5669
- const totalSize = numDirections * hiddenSize;
5669
+ const desc = {
5670
+ dataType: 'float32' ,
5671
+ shape: [numDirections, batchSize, hiddenSize]
5672
+ };
5673
+ const totalSize = numDirections * batchSize * hiddenSize;
5670
5674
hiddenState = builder.constant(desc, new Float32Array(totalSize).fill(0));
5671
5675
}
5672
5676
5673
5677
if (!cellState) {
5674
- const desc = {dataType: 'float32' , shape: [numDirections, 1, hiddenSize] };
5675
- const totalSize = numDirections * hiddenSize;
5678
+ const desc = {
5679
+ dataType: 'float32' ,
5680
+ shape: [numDirections, batchSize, hiddenSize]
5681
+ };
5682
+ const totalSize = numDirections * batchSize * hiddenSize;
5676
5683
cellState = builder.constant(desc, new Float32Array(totalSize).fill(0));
5677
5684
}
5678
5685
5679
- let sequence = null;
5680
5686
let currentWeight = [];
5681
5687
let currentRecurrentWeight = [];
5682
5688
let currentBias = [];
5683
5689
let currentRecurrentBias = [];
5684
5690
let currentPeepholeWeight = [];
5691
+ let forwardSequence = null;
5692
+ let backwardSequence = null;
5693
+ let outputHidden = null;
5694
+ let outputCell = null;
5685
5695
5686
5696
for (let dir = 0; dir < numDirections; ++dir) {
5687
5697
currentWeight.push(squeeze(
@@ -5711,36 +5721,27 @@ partial dictionary MLOpSupportLimits {
5711
5721
builder.slice(
5712
5722
options.peepholeWeight, [dir, 0] , [1, 3 * hiddenSize] ))) :
5713
5723
null);
5714
- }
5715
-
5716
- for (let step = 0; step < steps; ++step) {
5717
- let currentHidden = [];
5718
- let currentCell = [];
5719
- let nextHidden = null;
5720
- let nextCell = null;
5721
5724
5722
- for (let dir = 0; dir < numDirections; ++dir) {
5723
- currentHidden.push(squeeze(
5724
- builder,
5725
- builder.slice(hiddenState, [dir, 0, 0] , [1, batchSize, hiddenSize] )));
5726
- currentCell.push(squeeze(
5727
- builder,
5728
- builder.slice(cellState, [dir, 0, 0] , [1, batchSize, hiddenSize] )));
5729
- }
5725
+ let currentHidden = squeeze(
5726
+ builder,
5727
+ builder.slice(hiddenState, [dir, 0, 0] , [1, batchSize, hiddenSize] ));
5728
+ let currentCell = squeeze(
5729
+ builder,
5730
+ builder.slice(cellState, [dir, 0, 0] , [1, batchSize, hiddenSize] ));
5730
5731
5731
- for (let dir = 0; dir < numDirections ; ++dir ) {
5732
- let slice =
5733
- (dir == 1 || options. direction == 'backward' ? steps - step - 1 : step);
5734
- let currentInput = squeeze(
5732
+ for (let step = 0; step < steps ; ++step ) {
5733
+ const slice =
5734
+ (dir == 1 || direction == 'backward' ? steps - step - 1 : step);
5735
+ const currentInput = squeeze(
5735
5736
builder,
5736
5737
builder.slice(input, [slice, 0, 0] , [1, batchSize, inputSize] ));
5737
5738
5738
- let results = builder.lstmCell(
5739
+ [currentHidden, currentCell] = builder.lstmCell(
5739
5740
currentInput,
5740
5741
currentWeight[dir] ,
5741
5742
currentRecurrentWeight[dir] ,
5742
- currentHidden[dir] ,
5743
- currentCell[dir] ,
5743
+ currentHidden,
5744
+ currentCell,
5744
5745
hiddenSize,
5745
5746
{
5746
5747
bias: currentBias[dir] ,
@@ -5750,27 +5751,58 @@ partial dictionary MLOpSupportLimits {
5750
5751
activations: options.activations
5751
5752
});
5752
5753
5753
- let output = builder.reshape(results[0] , [1, batchSize, hiddenSize] );
5754
- let cell = builder.reshape(results[1] , [1, batchSize, hiddenSize] );
5755
-
5756
- nextHidden =
5757
- (nextHidden ? builder.concat([nextHidden, output] , 0) : output);
5758
- nextCell = (nextCell ? builder.concat([nextCell, cell] , 0) : cell);
5754
+ if (options.returnSequence) {
5755
+ // Expand currentHidden of 2D([batchSize, hiddenSize] )
5756
+ // to 4D([steps, numDirections, batchSize, hiddenSize] )
5757
+ const expandedHiddenAs4D =
5758
+ builder.reshape(currentHidden, [1, 1, batchSize, hiddenSize] );
5759
+
5760
+ if (direction == 'forward' || (dir == 0 && direction == 'both' )) {
5761
+ forwardSequence = forwardSequence ?
5762
+ builder.concat([forwardSequence, expandedHiddenAs4D] , 0) :
5763
+ expandedHiddenAs4D;
5764
+ } else if (
5765
+ direction == 'backward' || (dir == 1 && direction == 'both' )) {
5766
+ backwardSequence = backwardSequence ?
5767
+ builder.concat([expandedHiddenAs4D, backwardSequence] , 0) :
5768
+ expandedHiddenAs4D;
5769
+ }
5770
+ }
5759
5771
}
5760
5772
5761
- hiddenState = nextHidden;
5762
- cellState = nextCell;
5773
+ // Expand currentHidden of 2D([batchSize, hiddenSize] )
5774
+ // to 3D([numDirections, batchSize, hiddenSize] )
5775
+ const expandedHiddenAs3D =
5776
+ builder.reshape(currentHidden, [1, batchSize, hiddenSize] );
5777
+ outputHidden = outputHidden ?
5778
+ builder.concat([outputHidden, expandedHiddenAs3D] , 0) :
5779
+ expandedHiddenAs3D;
5780
+
5781
+ // Expand currentCell of 2D([batchSize, hiddenSize] )
5782
+ // to 3D([numDirections, batchSize, hiddenSize] )
5783
+ const expandedCellAs3D =
5784
+ builder.reshape(currentCell, [1, batchSize, hiddenSize] );
5785
+ outputCell = outputCell ?
5786
+ builder.concat([outputCell, expandedCellAs3D] , 0) :
5787
+ expandedCellAs3D;
5788
+ }
5763
5789
5764
- if (options.returnSequence) {
5765
- nextHidden =
5766
- builder.reshape(nextHidden, [1, numDirections, batchSize, hiddenSize] );
5767
- sequence =
5768
- (sequence ? builder.concat([sequence, nextHidden] , 0) : nextHidden);
5790
+ if (options.returnSequence) {
5791
+ let outputSequence = null;
5792
+
5793
+ if (direction == 'forward' ) {
5794
+ outputSequence = forwardSequence;
5795
+ } else if (direction == 'backward' ) {
5796
+ outputSequence = backwardSequence;
5797
+ } else if (direction == 'both' ) {
5798
+ // Concat along axis 1 (numDirections dimension)
5799
+ outputSequence = builder.concat([forwardSequence, backwardSequence] , 1);
5769
5800
}
5770
- }
5771
5801
5772
- return (
5773
- sequence ? [hiddenState, cellState, sequence] : [hiddenState, cellState] );
5802
+ return [outputHidden, outputCell, outputSequence] ;
5803
+ } else {
5804
+ return [outputHidden, outputCell] ;
5805
+ }
5774
5806
}
5775
5807
</pre>
5776
5808
</details>
0 commit comments