@@ -4331,20 +4331,26 @@ partial dictionary MLOpSupportLimits {
4331
4331
builder, input, weight, recurrentWeight, steps, hiddenSize, options) {
4332
4332
const batchSize = input.shape[1] ;
4333
4333
const inputSize = input.shape[2] ;
4334
- const numDirections = (options.direction == 'both' ? 2 : 1);
4334
+ const direction = options.direction || 'forward' ;
4335
+ const numDirections = (direction == 'both' ? 2 : 1);
4335
4336
let hiddenState = options.initialHiddenState;
4336
4337
4337
4338
if (!hiddenState) {
4338
- const desc = {dataType: 'float32' , shape: [numDirections, 1, hiddenSize] };
4339
- const totalSize = numDirections * hiddenSize;
4339
+ const desc = {
4340
+ dataType: 'float32' ,
4341
+ shape: [numDirections, batchSize, hiddenSize]
4342
+ };
4343
+ const totalSize = numDirections * batchSize * hiddenSize;
4340
4344
hiddenState = builder.constant(desc, new Float32Array(totalSize).fill(0));
4341
4345
}
4342
4346
4343
- let sequence = null;
4344
4347
let currentWeight = [];
4345
4348
let currentRecurrentWeight = [];
4346
4349
let currentBias = [];
4347
4350
let currentRecurrentBias = [];
4351
+ let forwardSequence = null;
4352
+ let backwardSequence = null;
4353
+ let outputHidden = null;
4348
4354
4349
4355
for (let dir = 0; dir < numDirections; ++dir) {
4350
4356
currentWeight.push(squeeze(
@@ -4367,57 +4373,75 @@ partial dictionary MLOpSupportLimits {
4367
4373
builder.slice(
4368
4374
options.recurrentBias, [dir, 0] , [1, 3 * hiddenSize] ))) :
4369
4375
null);
4370
- }
4371
-
4372
- for (let step = 0; step < steps; ++step) {
4373
- let currentHidden = [];
4374
- let currentOutput = null;
4375
-
4376
- for (let dir = 0; dir < numDirections; ++dir) {
4377
- currentHidden.push(squeeze(
4378
- builder,
4379
- builder.slice(hiddenState, [dir, 0, 0] , [1, batchSize, hiddenSize] )));
4380
- }
4376
+ let currentHidden = squeeze(
4377
+ builder,
4378
+ builder.slice(hiddenState, [dir, 0, 0] , [1, batchSize, hiddenSize] ));
4381
4379
4382
- for (let dir = 0; dir < numDirections ; ++dir ) {
4383
- let slice =
4384
- (dir == 1 || options. direction == 'backward' ? steps - step - 1 : step);
4385
- let currentInput = squeeze(
4380
+ for (let step = 0; step < steps ; ++step ) {
4381
+ const slice =
4382
+ (dir == 1 || direction == 'backward' ? steps - step - 1 : step);
4383
+ const currentInput = squeeze(
4386
4384
builder,
4387
4385
builder.slice(input, [slice, 0, 0] , [1, batchSize, inputSize] ));
4388
4386
4389
- let result = builder.reshape(
4390
- builder.gruCell(
4391
- currentInput,
4392
- currentWeight[dir] ,
4393
- currentRecurrentWeight[dir] ,
4394
- currentHidden[dir] ,
4395
- hiddenSize,
4396
- {
4397
- bias: currentBias[dir] ,
4398
- recurrentBias: currentRecurrentBias[dir] ,
4399
- resetAfter: options.resetAfter,
4400
- layout: options.layout,
4401
- activations: options.activations
4402
- }),
4403
- [1, batchSize, hiddenSize] );
4404
-
4405
- currentOutput =
4406
- (currentOutput ? builder.concat([currentOutput, result] , 0) : result);
4407
- }
4387
+ currentHidden = builder.gruCell(
4388
+ currentInput,
4389
+ currentWeight[dir] ,
4390
+ currentRecurrentWeight[dir] ,
4391
+ currentHidden,
4392
+ hiddenSize,
4393
+ {
4394
+ bias: currentBias[dir] ,
4395
+ recurrentBias: currentRecurrentBias[dir] ,
4396
+ resetAfter: options.resetAfter,
4397
+ layout: options.layout,
4398
+ activations: options.activations
4399
+ });
4408
4400
4409
- hiddenState = currentOutput;
4401
+ if (options.returnSequence) {
4402
+ // Expand currentHidden of 2D([batchSize, hiddenSize] )
4403
+ // to 4D([steps, numDirections, batchSize, hiddenSize] )
4404
+ const expandedHiddenAs4D =
4405
+ builder.reshape(currentHidden, [1, 1, batchSize, hiddenSize] );
4410
4406
4411
- if (options.returnSequence) {
4412
- currentOutput = builder.reshape(
4413
- currentOutput, [1, numDirections, batchSize, hiddenSize] );
4414
- sequence =
4415
- (sequence ? builder.concat([sequence, currentOutput] , 0) :
4416
- currentOutput);
4407
+ if (direction == 'forward' || (dir == 0 && direction == 'both' )) {
4408
+ forwardSequence = forwardSequence ?
4409
+ builder.concat([forwardSequence, expandedHiddenAs4D] , 0) :
4410
+ expandedHiddenAs4D;
4411
+ } else if (
4412
+ direction == 'backward' || (dir == 1 && direction == 'both' )) {
4413
+ backwardSequence = backwardSequence ?
4414
+ builder.concat([expandedHiddenAs4D, backwardSequence] , 0) :
4415
+ expandedHiddenAs4D;
4416
+ }
4417
+ }
4417
4418
}
4419
+
4420
+ // Expand currentHidden of 2D([batchSize, hiddenSize] )
4421
+ // to 3D([numDirections, batchSize, hiddenSize] )
4422
+ const expandedHiddenAs3D =
4423
+ builder.reshape(currentHidden, [1, batchSize, hiddenSize] );
4424
+ outputHidden = outputHidden ?
4425
+ builder.concat([outputHidden, expandedHiddenAs3D] , 0) :
4426
+ expandedHiddenAs3D;
4418
4427
}
4419
4428
4420
- return (sequence ? [hiddenState, sequence] : [hiddenState] );
4429
+ if (options.returnSequence) {
4430
+ let outputSequence = null;
4431
+
4432
+ if (direction == 'forward' ) {
4433
+ outputSequence = forwardSequence;
4434
+ } else if (direction == 'backward' ) {
4435
+ outputSequence = backwardSequence;
4436
+ } else if (direction == 'both' ) {
4437
+ // Concat along axis 1 (numDirections dimension)
4438
+ outputSequence = builder.concat([forwardSequence, backwardSequence] , 1);
4439
+ }
4440
+
4441
+ return [outputHidden, outputSequence] ;
4442
+ } else {
4443
+ return [outputHidden] ;
4444
+ }
4421
4445
}
4422
4446
</pre>
4423
4447
</details>
0 commit comments