@@ -5660,28 +5660,38 @@ partial dictionary MLOpSupportLimits {
56605660 builder, input, weight, recurrentWeight, steps, hiddenSize, options) {
56615661 const batchSize = input.shape[1] ;
56625662 const inputSize = input.shape[2] ;
5663- const numDirections = (options.direction == 'both' ? 2 : 1);
5663+ const direction = options.direction || 'forward' ;
5664+ const numDirections = (direction == 'both' ? 2 : 1);
56645665 let hiddenState = options.initialHiddenState;
56655666 let cellState = options.initialCellState;
56665667
56675668 if (!hiddenState) {
5668- const desc = {dataType: 'float32' , shape: [numDirections, 1, hiddenSize] };
5669- const totalSize = numDirections * hiddenSize;
5669+ const desc = {
5670+ dataType: 'float32' ,
5671+ shape: [numDirections, batchSize, hiddenSize]
5672+ };
5673+ const totalSize = numDirections * batchSize * hiddenSize;
56705674 hiddenState = builder.constant(desc, new Float32Array(totalSize).fill(0));
56715675 }
56725676
56735677 if (!cellState) {
5674- const desc = {dataType: 'float32' , shape: [numDirections, 1, hiddenSize] };
5675- const totalSize = numDirections * hiddenSize;
5678+ const desc = {
5679+ dataType: 'float32' ,
5680+ shape: [numDirections, batchSize, hiddenSize]
5681+ };
5682+ const totalSize = numDirections * batchSize * hiddenSize;
56765683 cellState = builder.constant(desc, new Float32Array(totalSize).fill(0));
56775684 }
56785685
5679- let sequence = null;
56805686 let currentWeight = [];
56815687 let currentRecurrentWeight = [];
56825688 let currentBias = [];
56835689 let currentRecurrentBias = [];
56845690 let currentPeepholeWeight = [];
5691+ let forwardSequence = null;
5692+ let backwardSequence = null;
5693+ let outputHidden = null;
5694+ let outputCell = null;
56855695
56865696 for (let dir = 0; dir < numDirections; ++dir) {
56875697 currentWeight.push(squeeze(
@@ -5711,36 +5721,27 @@ partial dictionary MLOpSupportLimits {
57115721 builder.slice(
57125722 options.peepholeWeight, [dir, 0] , [1, 3 * hiddenSize] ))) :
57135723 null);
5714- }
5715-
5716- for (let step = 0; step < steps; ++step) {
5717- let currentHidden = [];
5718- let currentCell = [];
5719- let nextHidden = null;
5720- let nextCell = null;
57215724
5722- for (let dir = 0; dir < numDirections; ++dir) {
5723- currentHidden.push(squeeze(
5724- builder,
5725- builder.slice(hiddenState, [dir, 0, 0] , [1, batchSize, hiddenSize] )));
5726- currentCell.push(squeeze(
5727- builder,
5728- builder.slice(cellState, [dir, 0, 0] , [1, batchSize, hiddenSize] )));
5729- }
5725+ let currentHidden = squeeze(
5726+ builder,
5727+ builder.slice(hiddenState, [dir, 0, 0] , [1, batchSize, hiddenSize] ));
5728+ let currentCell = squeeze(
5729+ builder,
5730+ builder.slice(cellState, [dir, 0, 0] , [1, batchSize, hiddenSize] ));
57305731
5731- for (let dir = 0; dir < numDirections ; ++dir ) {
5732- let slice =
5733- (dir == 1 || options. direction == 'backward' ? steps - step - 1 : step);
5734- let currentInput = squeeze(
5732+ for (let step = 0; step < steps ; ++step ) {
5733+ const slice =
5734+ (dir == 1 || direction == 'backward' ? steps - step - 1 : step);
5735+ const currentInput = squeeze(
57355736 builder,
57365737 builder.slice(input, [slice, 0, 0] , [1, batchSize, inputSize] ));
57375738
5738- let results = builder.lstmCell(
5739+ [currentHidden, currentCell] = builder.lstmCell(
57395740 currentInput,
57405741 currentWeight[dir] ,
57415742 currentRecurrentWeight[dir] ,
5742- currentHidden[dir] ,
5743- currentCell[dir] ,
5743+ currentHidden,
5744+ currentCell,
57445745 hiddenSize,
57455746 {
57465747 bias: currentBias[dir] ,
@@ -5750,27 +5751,58 @@ partial dictionary MLOpSupportLimits {
57505751 activations: options.activations
57515752 });
57525753
5753- let output = builder.reshape(results[0] , [1, batchSize, hiddenSize] );
5754- let cell = builder.reshape(results[1] , [1, batchSize, hiddenSize] );
5755-
5756- nextHidden =
5757- (nextHidden ? builder.concat([nextHidden, output] , 0) : output);
5758- nextCell = (nextCell ? builder.concat([nextCell, cell] , 0) : cell);
5754+ if (options.returnSequence) {
5755+ // Expand currentHidden of 2D([batchSize, hiddenSize] )
5756+ // to 4D([steps, numDirections, batchSize, hiddenSize] )
5757+ const expandedHiddenAs4D =
5758+ builder.reshape(currentHidden, [1, 1, batchSize, hiddenSize] );
5759+
5760+ if (direction == 'forward' || (dir == 0 && direction == 'both' )) {
5761+ forwardSequence = forwardSequence ?
5762+ builder.concat([forwardSequence, expandedHiddenAs4D] , 0) :
5763+ expandedHiddenAs4D;
5764+ } else if (
5765+ direction == 'backward' || (dir == 1 && direction == 'both' )) {
5766+ backwardSequence = backwardSequence ?
5767+ builder.concat([expandedHiddenAs4D, backwardSequence] , 0) :
5768+ expandedHiddenAs4D;
5769+ }
5770+ }
57595771 }
57605772
5761- hiddenState = nextHidden;
5762- cellState = nextCell;
5773+ // Expand currentHidden of 2D([batchSize, hiddenSize] )
5774+ // to 3D([numDirections, batchSize, hiddenSize] )
5775+ const expandedHiddenAs3D =
5776+ builder.reshape(currentHidden, [1, batchSize, hiddenSize] );
5777+ outputHidden = outputHidden ?
5778+ builder.concat([outputHidden, expandedHiddenAs3D] , 0) :
5779+ expandedHiddenAs3D;
5780+
5781+ // Expand currentCell of 2D([batchSize, hiddenSize] )
5782+ // to 3D([numDirections, batchSize, hiddenSize] )
5783+ const expandedCellAs3D =
5784+ builder.reshape(currentCell, [1, batchSize, hiddenSize] );
5785+ outputCell = outputCell ?
5786+ builder.concat([outputCell, expandedCellAs3D] , 0) :
5787+ expandedCellAs3D;
5788+ }
57635789
5764- if (options.returnSequence) {
5765- nextHidden =
5766- builder.reshape(nextHidden, [1, numDirections, batchSize, hiddenSize] );
5767- sequence =
5768- (sequence ? builder.concat([sequence, nextHidden] , 0) : nextHidden);
5790+ if (options.returnSequence) {
5791+ let outputSequence = null;
5792+
5793+ if (direction == 'forward' ) {
5794+ outputSequence = forwardSequence;
5795+ } else if (direction == 'backward' ) {
5796+ outputSequence = backwardSequence;
5797+ } else if (direction == 'both' ) {
5798+ // Concat along axis 1 (numDirections dimension)
5799+ outputSequence = builder.concat([forwardSequence, backwardSequence] , 1);
57695800 }
5770- }
57715801
5772- return (
5773- sequence ? [hiddenState, cellState, sequence] : [hiddenState, cellState] );
5802+ return [outputHidden, outputCell, outputSequence] ;
5803+ } else {
5804+ return [outputHidden, outputCell] ;
5805+ }
57745806 }
57755807 </pre>
57765808</details>
0 commit comments