-
Notifications
You must be signed in to change notification settings - Fork 1.1k
/
Copy pathautoencoder-gpu.test.ts
76 lines (60 loc) · 1.92 KB
/
autoencoder-gpu.test.ts
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
import AutoencoderGPU from './autoencoder-gpu';
import { INeuralNetworkTrainOptions } from './neural-network';
/**
 * Shared fixture: the XOR truth table, one row per case, laid out as [a, b, a XOR b].
 * The autoencoder is trained to reproduce these rows; the tests below read column 2
 * (the XOR result) or the whole row back out of the trained network.
 */
const trainingData = [
  [0, 0, 0],
  [0, 1, 1],
  [1, 0, 1],
  [1, 1, 0],
];

// 3 -> [4, 2, 4] -> 3 topology; the 2-unit middle layer is presumably the
// encoded representation returned by `encode` — NOTE(review): confirm.
const xornet = new AutoencoderGPU<number[], number[]>({
  inputSize: 3,
  hiddenLayers: [4, 2, 4],
  outputSize: 3,
});

// Passed to training as `errorThresh`; every test re-asserts that the final
// training error reached it before checking the network's outputs.
const errorThresh = 0.0011;

const trainOptions: Partial<INeuralNetworkTrainOptions> = {
  errorThresh,
  iterations: 250000,
  learningRate: 0.1,
  log: (details) => console.log(details),
  logPeriod: 500,
};

// Train once at module load so all tests share the same trained network and
// its training result.
const result = xornet.train(trainingData, trainOptions);
test('denoise a data sample', async () => {
  expect(result.error).toBeLessThanOrEqual(errorThresh);

  // Each case pairs a truth-table row with the XOR value expected in column 2
  // of the denoised output.
  const cases: Array<[number[], number]> = [
    [[0, 0, 0], 0],
    [[0, 1, 1], 1],
    [[1, 0, 1], 1],
    [[1, 1, 0], 0],
  ];

  // Denoise every row first, then assert, so all network calls happen before
  // the first expectation.
  const predicted = cases.map(([row]) => Math.round(xornet.denoise(row)[2]));

  cases.forEach(([, expected], index) => {
    expect(predicted[index]).toBe(expected);
  });
});
test('encode and decode a data sample', async () => {
  expect(result.error).toBeLessThanOrEqual(errorThresh);

  // Round-trip every truth-table row, not only the first two, so a network
  // that reconstructs just part of the table cannot pass.
  const inputs = [
    [0, 0, 0],
    [0, 1, 1],
    [1, 0, 1],
    [1, 1, 0],
  ];

  for (const input of inputs) {
    const encoded = xornet.encode(input);
    const decoded = xornet.decode(encoded);
    // Round each of the first `input.length` outputs and compare the whole
    // vector, so a failure shows the full row instead of a single index.
    const rounded = input.map((_, i) => Math.round(decoded[i]));
    expect(rounded).toEqual(input);
  }
});
test('test a data sample for anomalies', async () => {
  expect(result.error).toBeLessThanOrEqual(errorThresh);

  // Assertion helper: the trained network must NOT flag the given sample as
  // anomalous at the 0.5 threshold used by this test.
  const expectNoAnomalies = (sample: number[]): void => {
    expect(xornet.likelyIncludesAnomalies(sample, 0.5)).toBe(false);
  };

  const knownGoodSamples = [
    [0, 0, 0],
    [0, 1, 1],
    [1, 0, 1],
    [1, 1, 0],
  ];
  for (const sample of knownGoodSamples) {
    expectNoAnomalies(sample);
  }
});