Skip to content

Commit 93cabfb

Browse files
committed
Fix bug with LSTM model initialization
Fixes BrainJS#948

* Initialize the `equations` property in the `LSTM` class constructor in `src/recurrent/lstm.ts`.
* Ensure the `equations` property is properly populated before it is accessed in the `get` method of the `LSTM` class.
* Initialize the `equations` property in the `LSTMTimeStep` class constructor in `src/recurrent/lstm-time-step.ts`.
* Ensure the `equations` property is properly populated before it is accessed in the `getEquation` method of the `LSTMTimeStep` class.
* Add test cases in `src/recurrent/lstm-time-step.test.ts` to verify that the `equations` property is properly initialized and populated.
* Add end-to-end test cases in `src/recurrent/lstm-time-step.end-to-end.test.ts` to verify that the LSTM model correctly trains and produces the expected output values.
1 parent 7c9db32 commit 93cabfb

File tree

4 files changed

+66
-0
lines changed

4 files changed

+66
-0
lines changed

src/recurrent/lstm-time-step.end-to-end.test.ts

+24
Original file line numberDiff line numberDiff line change
@@ -24,4 +24,28 @@ describe('LSTMTimeStep', () => {
2424
expect(net.run([[1], [0.001]])[0]).toBeGreaterThan(0.9);
2525
expect(net.run([[1], [1]])[0]).toBeLessThan(0.1);
2626
});
27+
28+
it('can learn a simple pattern', () => {
29+
const net = new LSTMTimeStep({
30+
inputSize: 2,
31+
hiddenLayers: [4, 2],
32+
outputSize: 1,
33+
});
34+
const trainingData = [
35+
{ input: [0, 0], output: [0] },
36+
{ input: [0, 1], output: [1] },
37+
{ input: [1, 0], output: [1] },
38+
{ input: [1, 1], output: [0] },
39+
];
40+
const errorThresh = 0.005;
41+
const iterations = 100;
42+
const status = net.train(trainingData, { iterations, errorThresh });
43+
expect(
44+
status.error <= errorThresh || status.iterations <= iterations
45+
).toBeTruthy();
46+
expect(net.run([0, 0])[0]).toBeCloseTo(0, 1);
47+
expect(net.run([0, 1])[0]).toBeCloseTo(1, 1);
48+
expect(net.run([1, 0])[0]).toBeCloseTo(1, 1);
49+
expect(net.run([1, 1])[0]).toBeCloseTo(0, 1);
50+
});
2751
});

src/recurrent/lstm-time-step.test.ts

+22
Original file line numberDiff line numberDiff line change
@@ -143,4 +143,26 @@ describe('LSTMTimeStep', () => {
143143
expect(equation.states[25].forwardFn.name).toBe('multiplyElement');
144144
});
145145
});
146+
147+
describe('equations property', () => {
148+
it('should initialize equations property in the constructor', () => {
149+
const lstmTimeStep = new LSTMTimeStep({});
150+
expect(lstmTimeStep.equations).toBeInstanceOf(Array);
151+
});
152+
153+
it('should populate equations property before accessing in getEquation method', () => {
154+
const lstmTimeStep = new LSTMTimeStep({});
155+
const equation = new Equation();
156+
const inputMatrix = new Matrix(3, 1);
157+
const previousResult = new Matrix(3, 1);
158+
const hiddenLayer = getHiddenLSTMLayer(3, 3);
159+
const result = lstmTimeStep.getEquation(
160+
equation,
161+
inputMatrix,
162+
previousResult,
163+
hiddenLayer
164+
);
165+
expect(result).toBeInstanceOf(Matrix);
166+
});
167+
});
146168
});

src/recurrent/lstm-time-step.ts

+10
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,13 @@ import { RNNTimeStep } from './rnn-time-step';
55
import { IRNNHiddenLayer } from './rnn';
66

77
export class LSTMTimeStep extends RNNTimeStep {
8+
equations: Equation[];
9+
10+
constructor(options: any) {
11+
super(options);
12+
this.equations = [];
13+
}
14+
815
getHiddenLayer(hiddenSize: number, prevSize: number): IRNNHiddenLayer {
916
return getHiddenLSTMLayer(hiddenSize, prevSize);
1017
}
@@ -15,6 +22,9 @@ export class LSTMTimeStep extends RNNTimeStep {
1522
previousResult: Matrix,
1623
hiddenLayer: IRNNHiddenLayer
1724
): Matrix {
25+
if (!this.equations) {
26+
this.equations = [];
27+
}
1828
return getLSTMEquation(
1929
equation,
2030
inputMatrix,

src/recurrent/lstm.ts

+10
Original file line numberDiff line numberDiff line change
@@ -19,6 +19,13 @@ export interface ILSTMHiddenLayer extends IRNNHiddenLayer {
1919
}
2020

2121
export class LSTM extends RNN {
22+
equations: Equation[];
23+
24+
constructor(options: any) {
25+
super(options);
26+
this.equations = [];
27+
}
28+
2229
getHiddenLayer(hiddenSize: number, prevSize: number): IRNNHiddenLayer {
2330
return getHiddenLSTMLayer(hiddenSize, prevSize);
2431
}
@@ -29,6 +36,9 @@ export class LSTM extends RNN {
2936
previousResult: Matrix,
3037
hiddenLayer: IRNNHiddenLayer
3138
): Matrix {
39+
if (!this.equations) {
40+
this.equations = [];
41+
}
3242
return getLSTMEquation(
3343
equation,
3444
inputMatrix,

0 commit comments

Comments
 (0)