Skip to content

Commit 293d074

Browse files
authored
Bugfix: Fix emulation error of lstm by 'backward' and 'both' direction options (#802)
* Fix emulation error of lstm by 'backward' and 'both' direction options * format JavaScript emulation of lstm operation
1 parent 17a44e2 commit 293d074

File tree

1 file changed

+76
-44
lines changed

1 file changed

+76
-44
lines changed

index.bs

Lines changed: 76 additions & 44 deletions
Original file line numberDiff line numberDiff line change
@@ -5660,28 +5660,38 @@ partial dictionary MLOpSupportLimits {
56605660
builder, input, weight, recurrentWeight, steps, hiddenSize, options) {
56615661
const batchSize = input.shape[1];
56625662
const inputSize = input.shape[2];
5663-
const numDirections = (options.direction == 'both' ? 2 : 1);
5663+
const direction = options.direction || 'forward';
5664+
const numDirections = (direction == 'both' ? 2 : 1);
56645665
let hiddenState = options.initialHiddenState;
56655666
let cellState = options.initialCellState;
56665667

56675668
if (!hiddenState) {
5668-
const desc = {dataType: 'float32', shape: [numDirections, 1, hiddenSize]};
5669-
const totalSize = numDirections * hiddenSize;
5669+
const desc = {
5670+
dataType: 'float32',
5671+
shape: [numDirections, batchSize, hiddenSize]
5672+
};
5673+
const totalSize = numDirections * batchSize * hiddenSize;
56705674
hiddenState = builder.constant(desc, new Float32Array(totalSize).fill(0));
56715675
}
56725676

56735677
if (!cellState) {
5674-
const desc = {dataType: 'float32', shape: [numDirections, 1, hiddenSize]};
5675-
const totalSize = numDirections * hiddenSize;
5678+
const desc = {
5679+
dataType: 'float32',
5680+
shape: [numDirections, batchSize, hiddenSize]
5681+
};
5682+
const totalSize = numDirections * batchSize * hiddenSize;
56765683
cellState = builder.constant(desc, new Float32Array(totalSize).fill(0));
56775684
}
56785685

5679-
let sequence = null;
56805686
let currentWeight = [];
56815687
let currentRecurrentWeight = [];
56825688
let currentBias = [];
56835689
let currentRecurrentBias = [];
56845690
let currentPeepholeWeight = [];
5691+
let forwardSequence = null;
5692+
let backwardSequence = null;
5693+
let outputHidden = null;
5694+
let outputCell = null;
56855695

56865696
for (let dir = 0; dir < numDirections; ++dir) {
56875697
currentWeight.push(squeeze(
@@ -5711,36 +5721,27 @@ partial dictionary MLOpSupportLimits {
57115721
builder.slice(
57125722
options.peepholeWeight, [dir, 0], [1, 3 * hiddenSize]))) :
57135723
null);
5714-
}
5715-
5716-
for (let step = 0; step < steps; ++step) {
5717-
let currentHidden = [];
5718-
let currentCell = [];
5719-
let nextHidden = null;
5720-
let nextCell = null;
57215724

5722-
for (let dir = 0; dir < numDirections; ++dir) {
5723-
currentHidden.push(squeeze(
5724-
builder,
5725-
builder.slice(hiddenState, [dir, 0, 0], [1, batchSize, hiddenSize])));
5726-
currentCell.push(squeeze(
5727-
builder,
5728-
builder.slice(cellState, [dir, 0, 0], [1, batchSize, hiddenSize])));
5729-
}
5725+
let currentHidden = squeeze(
5726+
builder,
5727+
builder.slice(hiddenState, [dir, 0, 0], [1, batchSize, hiddenSize]));
5728+
let currentCell = squeeze(
5729+
builder,
5730+
builder.slice(cellState, [dir, 0, 0], [1, batchSize, hiddenSize]));
57305731

5731-
for (let dir = 0; dir < numDirections; ++dir) {
5732-
let slice =
5733-
(dir == 1 || options.direction == 'backward' ? steps - step - 1 : step);
5734-
let currentInput = squeeze(
5732+
for (let step = 0; step < steps; ++step) {
5733+
const slice =
5734+
(dir == 1 || direction == 'backward' ? steps - step - 1 : step);
5735+
const currentInput = squeeze(
57355736
builder,
57365737
builder.slice(input, [slice, 0, 0], [1, batchSize, inputSize]));
57375738

5738-
let results = builder.lstmCell(
5739+
[currentHidden, currentCell] = builder.lstmCell(
57395740
currentInput,
57405741
currentWeight[dir],
57415742
currentRecurrentWeight[dir],
5742-
currentHidden[dir],
5743-
currentCell[dir],
5743+
currentHidden,
5744+
currentCell,
57445745
hiddenSize,
57455746
{
57465747
bias: currentBias[dir],
@@ -5750,27 +5751,58 @@ partial dictionary MLOpSupportLimits {
57505751
activations: options.activations
57515752
});
57525753

5753-
let output = builder.reshape(results[0], [1, batchSize, hiddenSize]);
5754-
let cell = builder.reshape(results[1], [1, batchSize, hiddenSize]);
5755-
5756-
nextHidden =
5757-
(nextHidden ? builder.concat([nextHidden, output], 0) : output);
5758-
nextCell = (nextCell ? builder.concat([nextCell, cell], 0) : cell);
5754+
if (options.returnSequence) {
5755+
// Expand currentHidden of 2D([batchSize, hiddenSize])
5756+
// to 4D([steps, numDirections, batchSize, hiddenSize])
5757+
const expandedHiddenAs4D =
5758+
builder.reshape(currentHidden, [1, 1, batchSize, hiddenSize]);
5759+
5760+
if (direction == 'forward' || (dir == 0 && direction == 'both')) {
5761+
forwardSequence = forwardSequence ?
5762+
builder.concat([forwardSequence, expandedHiddenAs4D], 0) :
5763+
expandedHiddenAs4D;
5764+
} else if (
5765+
direction == 'backward' || (dir == 1 && direction == 'both')) {
5766+
backwardSequence = backwardSequence ?
5767+
builder.concat([expandedHiddenAs4D, backwardSequence], 0) :
5768+
expandedHiddenAs4D;
5769+
}
5770+
}
57595771
}
57605772

5761-
hiddenState = nextHidden;
5762-
cellState = nextCell;
5773+
// Expand currentHidden of 2D([batchSize, hiddenSize])
5774+
// to 3D([numDirections, batchSize, hiddenSize])
5775+
const expandedHiddenAs3D =
5776+
builder.reshape(currentHidden, [1, batchSize, hiddenSize]);
5777+
outputHidden = outputHidden ?
5778+
builder.concat([outputHidden, expandedHiddenAs3D], 0) :
5779+
expandedHiddenAs3D;
5780+
5781+
// Expand currentCell of 2D([batchSize, hiddenSize])
5782+
// to 3D([numDirections, batchSize, hiddenSize])
5783+
const expandedCellAs3D =
5784+
builder.reshape(currentCell, [1, batchSize, hiddenSize]);
5785+
outputCell = outputCell ?
5786+
builder.concat([outputCell, expandedCellAs3D], 0) :
5787+
expandedCellAs3D;
5788+
}
57635789

5764-
if (options.returnSequence) {
5765-
nextHidden =
5766-
builder.reshape(nextHidden, [1, numDirections, batchSize, hiddenSize]);
5767-
sequence =
5768-
(sequence ? builder.concat([sequence, nextHidden], 0) : nextHidden);
5790+
if (options.returnSequence) {
5791+
let outputSequence = null;
5792+
5793+
if (direction == 'forward') {
5794+
outputSequence = forwardSequence;
5795+
} else if (direction == 'backward') {
5796+
outputSequence = backwardSequence;
5797+
} else if (direction == 'both') {
5798+
// Concat along axis 1 (numDirections dimension)
5799+
outputSequence = builder.concat([forwardSequence, backwardSequence], 1);
57695800
}
5770-
}
57715801

5772-
return (
5773-
sequence ? [hiddenState, cellState, sequence] : [hiddenState, cellState]);
5802+
return [outputHidden, outputCell, outputSequence];
5803+
} else {
5804+
return [outputHidden, outputCell];
5805+
}
57745806
}
57755807
</pre>
57765808
</details>

0 commit comments

Comments
 (0)