Commit a65bedc

Merge pull request #257 from BrainJS/256-fix-to-function
fix: Resolve `toFunction`
2 parents 671d175 + a82714a

2 files changed: +66 -16 lines


src/neural-network.js (3 additions, 3 deletions)

@@ -885,11 +885,11 @@ export default class NeuralNetwork {
       case 'sigmoid':
         return `1/(1+1/Math.exp(${result.join('')}))`;
       case 'relu':
-        return `var sum = ${result.join('')};(sum < 0 ? 0 : sum);`;
+        return `(${result.join('')} < 0 ? 0 : ${result.join('')})`;
       case 'leaky-relu':
-        return `var sum = ${result.join('')};(sum < 0 ? 0 : 0.01 * sum);`;
+        return `(${result.join('')} < 0 ? 0 : 0.01 * ${result.join('')})`;
       case 'tanh':
-        return `Math.tanh(${result.join('')});`;
+        return `Math.tanh(${result.join('')})`;
       default:
         throw new Error('unknown activation type ' + activation);
     }
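Judging from the diff, each string returned here is spliced into a larger generated expression, so the old relu, leaky-relu, and tanh snippets broke the output: a `var` declaration plus a trailing semicolon is a statement sequence, which is invalid in expression position. The sketch below illustrates the failure mode; the weighted-sum string and its weights are made up for the example, not taken from the library's actual codegen.

// Illustrative only: 'x0*0.5 + x1*0.3' stands in for the per-node
// weighted-sum string that toFunction builds; the weights are invented.
const node = 'x0*0.5 + x1*0.3';

// Old form: a statement sequence. Splicing it where an expression is
// expected (e.g. inside an array literal) throws a SyntaxError:
//   new Function('x0', 'x1', `return [var sum = ${node};(sum < 0 ? 0 : sum);];`)

// New form: a pure expression, so it nests cleanly.
const relu = `(${node} < 0 ? 0 : ${node})`;
const fn = new Function('x0', 'x1', `return [${relu}];`);
console.log(fn(1, -2)); // [ 0 ]  (1*0.5 + -2*0.3 = -0.1, clamped to 0)

The expression form evaluates the weighted sum twice per node (once in the comparison, once in the selected branch), trading a little duplicated arithmetic for a snippet that composes safely inside larger expressions.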

test/base/to-function.js (63 additions, 13 deletions)

@@ -2,18 +2,68 @@ import assert from 'assert';
 import NeuralNetwork from '../../src/neural-network';

 describe('.toFunction()', () => {
-  const originalNet = new NeuralNetwork();
-  const xorTrainingData = [
-    {input: [0, 0], output: [0]},
-    {input: [0, 1], output: [1]},
-    {input: [1, 0], output: [1]},
-    {input: [1, 1], output: [0]}];
-  originalNet.train(xorTrainingData);
-  const xor = originalNet.toFunction();
-  it('runs same as original network', () => {
-    assert.deepEqual(xor([0, 0])[0].toFixed(6), originalNet.run([0, 0])[0].toFixed(6));
-    assert.deepEqual(xor([0, 1])[0].toFixed(6), originalNet.run([0, 1])[0].toFixed(6));
-    assert.deepEqual(xor([1, 0])[0].toFixed(6), originalNet.run([1, 0])[0].toFixed(6));
-    assert.deepEqual(xor([1, 1])[0].toFixed(6), originalNet.run([1, 1])[0].toFixed(6));
+  describe('sigmoid activation', () => {
+    const originalNet = new NeuralNetwork();
+    const xorTrainingData = [
+      {input: [0, 0], output: [0]},
+      {input: [0, 1], output: [1]},
+      {input: [1, 0], output: [1]},
+      {input: [1, 1], output: [0]}];
+    originalNet.train(xorTrainingData);
+    const xor = originalNet.toFunction();
+    it('runs same as original network', () => {
+      assert.deepEqual(xor([0, 0])[0].toFixed(6), originalNet.run([0, 0])[0].toFixed(6));
+      assert.deepEqual(xor([0, 1])[0].toFixed(6), originalNet.run([0, 1])[0].toFixed(6));
+      assert.deepEqual(xor([1, 0])[0].toFixed(6), originalNet.run([1, 0])[0].toFixed(6));
+      assert.deepEqual(xor([1, 1])[0].toFixed(6), originalNet.run([1, 1])[0].toFixed(6));
+    });
+  });
+  describe('relu activation', () => {
+    const originalNet = new NeuralNetwork({ activation: 'relu' });
+    const xorTrainingData = [
+      {input: [0, 0], output: [0]},
+      {input: [0, 1], output: [1]},
+      {input: [1, 0], output: [1]},
+      {input: [1, 1], output: [0]}];
+    originalNet.train(xorTrainingData);
+    const xor = originalNet.toFunction();
+    it('runs same as original network', () => {
+      assert.deepEqual(xor([0, 0])[0].toFixed(6), originalNet.run([0, 0])[0].toFixed(6));
+      assert.deepEqual(xor([0, 1])[0].toFixed(6), originalNet.run([0, 1])[0].toFixed(6));
+      assert.deepEqual(xor([1, 0])[0].toFixed(6), originalNet.run([1, 0])[0].toFixed(6));
+      assert.deepEqual(xor([1, 1])[0].toFixed(6), originalNet.run([1, 1])[0].toFixed(6));
+    });
+  });
+  describe('leaky-relu activation', () => {
+    const originalNet = new NeuralNetwork({ activation: 'leaky-relu' });
+    const xorTrainingData = [
+      {input: [0, 0], output: [0]},
+      {input: [0, 1], output: [1]},
+      {input: [1, 0], output: [1]},
+      {input: [1, 1], output: [0]}];
+    originalNet.train(xorTrainingData);
+    const xor = originalNet.toFunction();
+    it('runs same as original network', () => {
+      assert.deepEqual(xor([0, 0])[0].toFixed(6), originalNet.run([0, 0])[0].toFixed(6));
+      assert.deepEqual(xor([0, 1])[0].toFixed(6), originalNet.run([0, 1])[0].toFixed(6));
+      assert.deepEqual(xor([1, 0])[0].toFixed(6), originalNet.run([1, 0])[0].toFixed(6));
+      assert.deepEqual(xor([1, 1])[0].toFixed(6), originalNet.run([1, 1])[0].toFixed(6));
+    });
+  });
+  describe('tanh activation', () => {
+    const originalNet = new NeuralNetwork({ activation: 'tanh' });
+    const xorTrainingData = [
+      {input: [0, 0], output: [0]},
+      {input: [0, 1], output: [1]},
+      {input: [1, 0], output: [1]},
+      {input: [1, 1], output: [0]}];
+    originalNet.train(xorTrainingData);
+    const xor = originalNet.toFunction();
+    it('runs same as original network', () => {
+      assert.deepEqual(xor([0, 0])[0].toFixed(6), originalNet.run([0, 0])[0].toFixed(6));
+      assert.deepEqual(xor([0, 1])[0].toFixed(6), originalNet.run([0, 1])[0].toFixed(6));
+      assert.deepEqual(xor([1, 0])[0].toFixed(6), originalNet.run([1, 0])[0].toFixed(6));
+      assert.deepEqual(xor([1, 1])[0].toFixed(6), originalNet.run([1, 1])[0].toFixed(6));
+    });
   });
 });
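The four suites differ only in the activation option; each checks that the standalone function produced by toFunction agrees with run to six decimal places on all XOR inputs. For reference, a minimal usage sketch outside the test harness; the import path mirrors the test file, and an application would import the published brain.js package instead.

import NeuralNetwork from '../../src/neural-network';

const net = new NeuralNetwork({ activation: 'relu' });
net.train([
  {input: [0, 0], output: [0]},
  {input: [0, 1], output: [1]},
  {input: [1, 0], output: [1]},
  {input: [1, 1], output: [0]},
]);

// toFunction() compiles the trained weights into a standalone function;
// xor.toString() yields plain JavaScript with no brain.js dependency.
const xor = net.toFunction();
console.log(xor([0, 1])[0].toFixed(6)); // matches net.run([0, 1])[0].toFixed(6)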
