-
Notifications
You must be signed in to change notification settings - Fork 31
/
Copy pathnn.js
96 lines (88 loc) · 2.2 KB
/
nn.js
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
// Daniel Shiffman
// Nature of Code
// https://github.com/nature-of-code/noc-syllabus-S19
/**
 * Small two-layer feed-forward network (TensorFlow.js) used as an evolvable "brain".
 *
 * Two construction forms are supported:
 *   new NeuralNetwork(numInputs, numHidden, numOutputs)         - builds a fresh model
 *   new NeuralNetwork(model, numInputs, numHidden, numOutputs)  - wraps an existing tf.Sequential
 */
class NeuralNetwork {
  /**
   * @param {tf.Sequential|number} a - An existing model, or the number of inputs.
   * @param {number} b - Number of inputs (model form) or hidden units (numeric form).
   * @param {number} c - Number of hidden units (model form) or outputs (numeric form).
   * @param {number} [d] - Number of outputs (model form only).
   */
  constructor(a, b, c, d) {
    if (a instanceof tf.Sequential) {
      this.model = a;
      this.num_inputs = b;
      this.num_hidden = c;
      this.num_outputs = d;
    } else {
      this.num_inputs = a;
      this.num_hidden = b;
      this.num_outputs = c;
      this.model = this.createModel();
    }
  }

  /**
   * Builds a new model: inputs -> dense(num_hidden, tanh) -> dense(num_outputs, tanh).
   * @returns {tf.Sequential} The freshly created model.
   */
  createModel() {
    const model = tf.sequential();
    const hidden = tf.layers.dense({
      inputShape: [this.num_inputs],
      units: this.num_hidden,
      activation: 'tanh',
    });
    const output = tf.layers.dense({
      units: this.num_outputs,
      activation: 'tanh',
    });
    model.add(hidden);
    model.add(output);
    return model;
  }

  /** Disposes the underlying model; this network should not be used afterwards. */
  dispose() {
    this.model.dispose();
  }

  /**
   * Saves the model through TensorFlow.js's `downloads://` IO handler.
   * @param {string} [name='vehicle-brain'] - Base name for the downloaded files.
   * @returns {*} Whatever `model.save` returns (a Promise in TensorFlow.js), so callers
   *   can await it or attach a rejection handler.
   */
  save(name = 'vehicle-brain') {
    return this.model.save(`downloads://${name}`);
  }

  /**
   * Runs one forward pass synchronously.
   * @param {number[]} input_array - A single input sample (expected length: num_inputs).
   * @returns {TypedArray} The output values read back with `dataSync()`.
   */
  predict(input_array) {
    // tidy() disposes the intermediate input/output tensors; dataSync() reads the
    // values synchronously, so a plain typed array (not a tensor) is returned.
    return tf.tidy(() => {
      const xs = tf.tensor([input_array]);
      const ys = this.model.predict(xs);
      return ys.dataSync();
    });
  }

  /**
   * Creates a new network with the same layer sizes and a copy of this network's
   * weights (used for neuro-evolution).
   * @returns {NeuralNetwork} The copy, wrapping its own newly created model.
   */
  copy() {
    return tf.tidy(() => {
      const modelCopy = this.createModel();
      // Hand the new model cloned tensors rather than this model's own weight tensors.
      const weights = this.model.getWeights().map((w) => w.clone());
      modelCopy.setWeights(weights);
      return new NeuralNetwork(modelCopy, this.num_inputs, this.num_hidden, this.num_outputs);
    });
  }

  /**
   * Rewrites every scalar weight of the model through a mutation function.
   * @param {number} rate - Passed to `mutationFn` as its second argument (for the default
   *   `mutateWeight`, the per-weight mutation probability).
   * @param {(value: number, rate: number) => number} [mutationFn=mutateWeight] - Returns the
   *   new value for one weight.
   */
  mutate(rate, mutationFn = mutateWeight) {
    tf.tidy(() => {
      const mutated = this.model.getWeights().map((w) => {
        // map() on the typed array yields a new array; the source tensor's data is not touched.
        const values = w.dataSync().map((value) => mutationFn(value, rate));
        return tf.tensor(values, w.shape);
      });
      this.model.setWeights(mutated);
    });
  }
}
/**
 * Default per-weight mutation used by NeuralNetwork#mutate.
 *
 * With probability `rate` (decided by `random(1) < rate`), returns `x` plus
 * `randomGaussian() * scale`; otherwise returns `x` unchanged. `random` and
 * `randomGaussian` are globals, presumably from p5.js (where randomGaussian() is
 * standard normal by default, making `scale` the noise's standard deviation).
 *
 * @param {number} x - Current weight value.
 * @param {number} rate - Probability of mutating this weight.
 * @param {number} [scale=0.5] - Multiplier applied to the Gaussian sample.
 * @returns {number} The possibly mutated weight.
 */
function mutateWeight(x, rate, scale = 0.5) {
  if (random(1) < rate) {
    return x + randomGaussian() * scale;
  }
  return x;
}