@@ -4331,20 +4331,26 @@ partial dictionary MLOpSupportLimits {
43314331 builder, input, weight, recurrentWeight, steps, hiddenSize, options) {
43324332 const batchSize = input.shape[1] ;
43334333 const inputSize = input.shape[2] ;
4334- const numDirections = (options.direction == 'both' ? 2 : 1);
4334+ const direction = options.direction || 'forward' ;
4335+ const numDirections = (direction == 'both' ? 2 : 1);
43354336 let hiddenState = options.initialHiddenState;
43364337
43374338 if (!hiddenState) {
4338- const desc = {dataType: 'float32' , shape: [numDirections, 1, hiddenSize] };
4339- const totalSize = numDirections * hiddenSize;
4339+ const desc = {
4340+ dataType: 'float32' ,
4341+ shape: [numDirections, batchSize, hiddenSize]
4342+ };
4343+ const totalSize = numDirections * batchSize * hiddenSize;
43404344 hiddenState = builder.constant(desc, new Float32Array(totalSize).fill(0));
43414345 }
43424346
4343- let sequence = null;
43444347 let currentWeight = [];
43454348 let currentRecurrentWeight = [];
43464349 let currentBias = [];
43474350 let currentRecurrentBias = [];
4351+ let forwardSequence = null;
4352+ let backwardSequence = null;
4353+ let outputHidden = null;
43484354
43494355 for (let dir = 0; dir < numDirections; ++dir) {
43504356 currentWeight.push(squeeze(
@@ -4367,57 +4373,75 @@ partial dictionary MLOpSupportLimits {
43674373 builder.slice(
43684374 options.recurrentBias, [dir, 0] , [1, 3 * hiddenSize] ))) :
43694375 null);
4370- }
4371-
4372- for (let step = 0; step < steps; ++step) {
4373- let currentHidden = [];
4374- let currentOutput = null;
4375-
4376- for (let dir = 0; dir < numDirections; ++dir) {
4377- currentHidden.push(squeeze(
4378- builder,
4379- builder.slice(hiddenState, [dir, 0, 0] , [1, batchSize, hiddenSize] )));
4380- }
4376+ let currentHidden = squeeze(
4377+ builder,
4378+ builder.slice(hiddenState, [dir, 0, 0] , [1, batchSize, hiddenSize] ));
43814379
4382- for (let dir = 0; dir < numDirections ; ++dir ) {
4383- let slice =
4384- (dir == 1 || options. direction == 'backward' ? steps - step - 1 : step);
4385- let currentInput = squeeze(
4380+ for (let step = 0; step < steps ; ++step ) {
4381+ const slice =
4382+ (dir == 1 || direction == 'backward' ? steps - step - 1 : step);
4383+ const currentInput = squeeze(
43864384 builder,
43874385 builder.slice(input, [slice, 0, 0] , [1, batchSize, inputSize] ));
43884386
4389- let result = builder.reshape(
4390- builder.gruCell(
4391- currentInput,
4392- currentWeight[dir] ,
4393- currentRecurrentWeight[dir] ,
4394- currentHidden[dir] ,
4395- hiddenSize,
4396- {
4397- bias: currentBias[dir] ,
4398- recurrentBias: currentRecurrentBias[dir] ,
4399- resetAfter: options.resetAfter,
4400- layout: options.layout,
4401- activations: options.activations
4402- }),
4403- [1, batchSize, hiddenSize] );
4404-
4405- currentOutput =
4406- (currentOutput ? builder.concat([currentOutput, result] , 0) : result);
4407- }
4387+ currentHidden = builder.gruCell(
4388+ currentInput,
4389+ currentWeight[dir] ,
4390+ currentRecurrentWeight[dir] ,
4391+ currentHidden,
4392+ hiddenSize,
4393+ {
4394+ bias: currentBias[dir] ,
4395+ recurrentBias: currentRecurrentBias[dir] ,
4396+ resetAfter: options.resetAfter,
4397+ layout: options.layout,
4398+ activations: options.activations
4399+ });
44084400
4409- hiddenState = currentOutput;
4401+ if (options.returnSequence) {
4402+ // Expand currentHidden of 2D([batchSize, hiddenSize] )
4403+ // to 4D([steps, numDirections, batchSize, hiddenSize] )
4404+ const expandedHiddenAs4D =
4405+ builder.reshape(currentHidden, [1, 1, batchSize, hiddenSize] );
44104406
4411- if (options.returnSequence) {
4412- currentOutput = builder.reshape(
4413- currentOutput, [1, numDirections, batchSize, hiddenSize] );
4414- sequence =
4415- (sequence ? builder.concat([sequence, currentOutput] , 0) :
4416- currentOutput);
4407+ if (direction == 'forward' || (dir == 0 && direction == 'both' )) {
4408+ forwardSequence = forwardSequence ?
4409+ builder.concat([forwardSequence, expandedHiddenAs4D] , 0) :
4410+ expandedHiddenAs4D;
4411+ } else if (
4412+ direction == 'backward' || (dir == 1 && direction == 'both' )) {
4413+ backwardSequence = backwardSequence ?
4414+ builder.concat([expandedHiddenAs4D, backwardSequence] , 0) :
4415+ expandedHiddenAs4D;
4416+ }
4417+ }
44174418 }
4419+
4420+ // Expand currentHidden of 2D([batchSize, hiddenSize] )
4421+ // to 3D([numDirections, batchSize, hiddenSize] )
4422+ const expandedHiddenAs3D =
4423+ builder.reshape(currentHidden, [1, batchSize, hiddenSize] );
4424+ outputHidden = outputHidden ?
4425+ builder.concat([outputHidden, expandedHiddenAs3D] , 0) :
4426+ expandedHiddenAs3D;
44184427 }
44194428
4420- return (sequence ? [hiddenState, sequence] : [hiddenState] );
4429+ if (options.returnSequence) {
4430+ let outputSequence = null;
4431+
4432+ if (direction == 'forward' ) {
4433+ outputSequence = forwardSequence;
4434+ } else if (direction == 'backward' ) {
4435+ outputSequence = backwardSequence;
4436+ } else if (direction == 'both' ) {
4437+ // Concat along axis 1 (numDirections dimension)
4438+ outputSequence = builder.concat([forwardSequence, backwardSequence] , 1);
4439+ }
4440+
4441+ return [outputHidden, outputSequence] ;
4442+ } else {
4443+ return [outputHidden] ;
4444+ }
44214445 }
44224446 </pre>
44234447</details>
0 commit comments