Skip to content

Commit 5f85e73

Browse files
authored
Bugfix: Fix emulation error of gru by 'backward' and 'both' direction options (#803)
* Fix emulation error of gru by 'backward' and 'both' direction options * format JavaScript emulation of gru operation
1 parent 293d074 commit 5f85e73

File tree

1 file changed

+70
-46
lines changed

1 file changed

+70
-46
lines changed

index.bs

+70-46
Original file line numberDiff line numberDiff line change
@@ -4331,20 +4331,26 @@ partial dictionary MLOpSupportLimits {
43314331
builder, input, weight, recurrentWeight, steps, hiddenSize, options) {
43324332
const batchSize = input.shape[1];
43334333
const inputSize = input.shape[2];
4334-
const numDirections = (options.direction == 'both' ? 2 : 1);
4334+
const direction = options.direction || 'forward';
4335+
const numDirections = (direction == 'both' ? 2 : 1);
43354336
let hiddenState = options.initialHiddenState;
43364337

43374338
if (!hiddenState) {
4338-
const desc = {dataType: 'float32', shape: [numDirections, 1, hiddenSize]};
4339-
const totalSize = numDirections * hiddenSize;
4339+
const desc = {
4340+
dataType: 'float32',
4341+
shape: [numDirections, batchSize, hiddenSize]
4342+
};
4343+
const totalSize = numDirections * batchSize * hiddenSize;
43404344
hiddenState = builder.constant(desc, new Float32Array(totalSize).fill(0));
43414345
}
43424346

4343-
let sequence = null;
43444347
let currentWeight = [];
43454348
let currentRecurrentWeight = [];
43464349
let currentBias = [];
43474350
let currentRecurrentBias = [];
4351+
let forwardSequence = null;
4352+
let backwardSequence = null;
4353+
let outputHidden = null;
43484354

43494355
for (let dir = 0; dir < numDirections; ++dir) {
43504356
currentWeight.push(squeeze(
@@ -4367,57 +4373,75 @@ partial dictionary MLOpSupportLimits {
43674373
builder.slice(
43684374
options.recurrentBias, [dir, 0], [1, 3 * hiddenSize]))) :
43694375
null);
4370-
}
4371-
4372-
for (let step = 0; step < steps; ++step) {
4373-
let currentHidden = [];
4374-
let currentOutput = null;
4375-
4376-
for (let dir = 0; dir < numDirections; ++dir) {
4377-
currentHidden.push(squeeze(
4378-
builder,
4379-
builder.slice(hiddenState, [dir, 0, 0], [1, batchSize, hiddenSize])));
4380-
}
4376+
let currentHidden = squeeze(
4377+
builder,
4378+
builder.slice(hiddenState, [dir, 0, 0], [1, batchSize, hiddenSize]));
43814379

4382-
for (let dir = 0; dir < numDirections; ++dir) {
4383-
let slice =
4384-
(dir == 1 || options.direction == 'backward' ? steps - step - 1 : step);
4385-
let currentInput = squeeze(
4380+
for (let step = 0; step < steps; ++step) {
4381+
const slice =
4382+
(dir == 1 || direction == 'backward' ? steps - step - 1 : step);
4383+
const currentInput = squeeze(
43864384
builder,
43874385
builder.slice(input, [slice, 0, 0], [1, batchSize, inputSize]));
43884386

4389-
let result = builder.reshape(
4390-
builder.gruCell(
4391-
currentInput,
4392-
currentWeight[dir],
4393-
currentRecurrentWeight[dir],
4394-
currentHidden[dir],
4395-
hiddenSize,
4396-
{
4397-
bias: currentBias[dir],
4398-
recurrentBias: currentRecurrentBias[dir],
4399-
resetAfter: options.resetAfter,
4400-
layout: options.layout,
4401-
activations: options.activations
4402-
}),
4403-
[1, batchSize, hiddenSize]);
4404-
4405-
currentOutput =
4406-
(currentOutput ? builder.concat([currentOutput, result], 0) : result);
4407-
}
4387+
currentHidden = builder.gruCell(
4388+
currentInput,
4389+
currentWeight[dir],
4390+
currentRecurrentWeight[dir],
4391+
currentHidden,
4392+
hiddenSize,
4393+
{
4394+
bias: currentBias[dir],
4395+
recurrentBias: currentRecurrentBias[dir],
4396+
resetAfter: options.resetAfter,
4397+
layout: options.layout,
4398+
activations: options.activations
4399+
});
44084400

4409-
hiddenState = currentOutput;
4401+
if (options.returnSequence) {
4402+
// Expand currentHidden of 2D([batchSize, hiddenSize])
4403+
// to 4D([steps, numDirections, batchSize, hiddenSize])
4404+
const expandedHiddenAs4D =
4405+
builder.reshape(currentHidden, [1, 1, batchSize, hiddenSize]);
44104406

4411-
if (options.returnSequence) {
4412-
currentOutput = builder.reshape(
4413-
currentOutput, [1, numDirections, batchSize, hiddenSize]);
4414-
sequence =
4415-
(sequence ? builder.concat([sequence, currentOutput], 0) :
4416-
currentOutput);
4407+
if (direction == 'forward' || (dir == 0 && direction == 'both')) {
4408+
forwardSequence = forwardSequence ?
4409+
builder.concat([forwardSequence, expandedHiddenAs4D], 0) :
4410+
expandedHiddenAs4D;
4411+
} else if (
4412+
direction == 'backward' || (dir == 1 && direction == 'both')) {
4413+
backwardSequence = backwardSequence ?
4414+
builder.concat([expandedHiddenAs4D, backwardSequence], 0) :
4415+
expandedHiddenAs4D;
4416+
}
4417+
}
44174418
}
4419+
4420+
// Expand currentHidden of 2D([batchSize, hiddenSize])
4421+
// to 3D([numDirections, batchSize, hiddenSize])
4422+
const expandedHiddenAs3D =
4423+
builder.reshape(currentHidden, [1, batchSize, hiddenSize]);
4424+
outputHidden = outputHidden ?
4425+
builder.concat([outputHidden, expandedHiddenAs3D], 0) :
4426+
expandedHiddenAs3D;
44184427
}
44194428

4420-
return (sequence ? [hiddenState, sequence] : [hiddenState]);
4429+
if (options.returnSequence) {
4430+
let outputSequence = null;
4431+
4432+
if (direction == 'forward') {
4433+
outputSequence = forwardSequence;
4434+
} else if (direction == 'backward') {
4435+
outputSequence = backwardSequence;
4436+
} else if (direction == 'both') {
4437+
// Concat along axis 1 (numDirections dimension)
4438+
outputSequence = builder.concat([forwardSequence, backwardSequence], 1);
4439+
}
4440+
4441+
return [outputHidden, outputSequence];
4442+
} else {
4443+
return [outputHidden];
4444+
}
44214445
}
44224446
</pre>
44234447
</details>

0 commit comments

Comments
 (0)