From 93cabfb8fac41f466a32ba1dc818572dd5a09a1e Mon Sep 17 00:00:00 2001 From: Rizmy Abdulla Date: Sat, 18 Jan 2025 11:43:14 +0530 Subject: [PATCH] Fix bug with LSTM model initialization Fixes #948 Initialize the `equations` property in the `LSTM` class constructor in `src/recurrent/lstm.ts`. * Ensure the `equations` property is properly populated before it is accessed in the `getEquation` method of the `LSTM` class. * Initialize the `equations` property in the `LSTMTimeStep` class constructor in `src/recurrent/lstm-time-step.ts`. * Ensure the `equations` property is properly populated before it is accessed in the `getEquation` method of the `LSTMTimeStep` class. * Add test cases in `src/recurrent/lstm-time-step.test.ts` to verify that the `equations` property is properly initialized and populated. * Add end-to-end test cases in `src/recurrent/lstm-time-step.end-to-end.test.ts` to verify that the LSTM model correctly trains and produces the expected output values. --- .../lstm-time-step.end-to-end.test.ts | 24 +++++++++++++++++++ src/recurrent/lstm-time-step.test.ts | 22 +++++++++++++++++ src/recurrent/lstm-time-step.ts | 10 ++++++++ src/recurrent/lstm.ts | 10 ++++++++ 4 files changed, 66 insertions(+) diff --git a/src/recurrent/lstm-time-step.end-to-end.test.ts b/src/recurrent/lstm-time-step.end-to-end.test.ts index ecd2f160c..5a8b1cbb4 100644 --- a/src/recurrent/lstm-time-step.end-to-end.test.ts +++ b/src/recurrent/lstm-time-step.end-to-end.test.ts @@ -24,4 +24,28 @@ describe('LSTMTimeStep', () => { expect(net.run([[1], [0.001]])[0]).toBeGreaterThan(0.9); expect(net.run([[1], [1]])[0]).toBeLessThan(0.1); }); + + it('can learn a simple pattern', () => { + const net = new LSTMTimeStep({ + inputSize: 2, + hiddenLayers: [4, 2], + outputSize: 1, + }); + const trainingData = [ + { input: [0, 0], output: [0] }, + { input: [0, 1], output: [1] }, + { input: [1, 0], output: [1] }, + { input: [1, 1], output: [0] }, + ]; + const errorThresh = 0.005; + const iterations = 100; + const 
status = net.train(trainingData, { iterations, errorThresh }); + expect( + status.error <= errorThresh || status.iterations <= iterations + ).toBeTruthy(); + expect(net.run([0, 0])[0]).toBeCloseTo(0, 1); + expect(net.run([0, 1])[0]).toBeCloseTo(1, 1); + expect(net.run([1, 0])[0]).toBeCloseTo(1, 1); + expect(net.run([1, 1])[0]).toBeCloseTo(0, 1); + }); }); diff --git a/src/recurrent/lstm-time-step.test.ts b/src/recurrent/lstm-time-step.test.ts index 5a59f77e4..8a2475515 100644 --- a/src/recurrent/lstm-time-step.test.ts +++ b/src/recurrent/lstm-time-step.test.ts @@ -143,4 +143,26 @@ describe('LSTMTimeStep', () => { expect(equation.states[25].forwardFn.name).toBe('multiplyElement'); }); }); + + describe('equations property', () => { + it('should initialize equations property in the constructor', () => { + const lstmTimeStep = new LSTMTimeStep({}); + expect(lstmTimeStep.equations).toBeInstanceOf(Array); + }); + + it('should populate equations property before accessing in getEquation method', () => { + const lstmTimeStep = new LSTMTimeStep({}); + const equation = new Equation(); + const inputMatrix = new Matrix(3, 1); + const previousResult = new Matrix(3, 1); + const hiddenLayer = getHiddenLSTMLayer(3, 3); + const result = lstmTimeStep.getEquation( + equation, + inputMatrix, + previousResult, + hiddenLayer + ); + expect(result).toBeInstanceOf(Matrix); + }); + }); }); diff --git a/src/recurrent/lstm-time-step.ts b/src/recurrent/lstm-time-step.ts index c7af5f4be..f720d4821 100644 --- a/src/recurrent/lstm-time-step.ts +++ b/src/recurrent/lstm-time-step.ts @@ -5,6 +5,13 @@ import { RNNTimeStep } from './rnn-time-step'; import { IRNNHiddenLayer } from './rnn'; export class LSTMTimeStep extends RNNTimeStep { + equations: Equation[]; + + constructor(options: any) { + super(options); + this.equations = []; + } + getHiddenLayer(hiddenSize: number, prevSize: number): IRNNHiddenLayer { return getHiddenLSTMLayer(hiddenSize, prevSize); } @@ -15,6 +22,9 @@ export class LSTMTimeStep 
extends RNNTimeStep { previousResult: Matrix, hiddenLayer: IRNNHiddenLayer ): Matrix { + if (!this.equations) { + this.equations = []; + } return getLSTMEquation( equation, inputMatrix, diff --git a/src/recurrent/lstm.ts b/src/recurrent/lstm.ts index 3d2cb39c0..d178aae9b 100644 --- a/src/recurrent/lstm.ts +++ b/src/recurrent/lstm.ts @@ -19,6 +19,13 @@ export interface ILSTMHiddenLayer extends IRNNHiddenLayer { } export class LSTM extends RNN { + equations: Equation[]; + + constructor(options: any) { + super(options); + this.equations = []; + } + getHiddenLayer(hiddenSize: number, prevSize: number): IRNNHiddenLayer { return getHiddenLSTMLayer(hiddenSize, prevSize); } @@ -29,6 +36,9 @@ export class LSTM extends RNN { previousResult: Matrix, hiddenLayer: IRNNHiddenLayer ): Matrix { + if (!this.equations) { + this.equations = []; + } return getLSTMEquation( equation, inputMatrix,