Skip to content

Commit a26c531

Browse files
committed
more banging on the tensorflow api
1 parent 1fa1697 commit a26c531

File tree

2 files changed

+9
-2
lines changed

2 files changed

+9
-2
lines changed

src/net/check_model.js

Lines changed: 8 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,12 @@
11
import * as tf from '@tensorflow/tfjs-node-gpu';
22

33
const model = await tf.loadLayersModel('file:///dev/shm/brain/model.json');
4-
model.loadWeights()
54

6-
console.log('prediction', model.predict(tf.tensor([ Array(400) ])));
5+
// obscure looking tf operations taken from
6+
// https://github.com/prouhard/tfjs-mountaincar/blob/master/src/js/model.js#L69-L74
7+
const pred = model.predict(tf.tensor([ Array(400) ]));
8+
const sigmoid = tf.sigmoid(pred);
9+
const probs = tf.div(sigmoid, tf.sum(sigmoid));
10+
const end = tf.multinomial(probs, 1).dataSync()[0] - 1;
11+
12+
console.log('probs', end);

src/net/model.js

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -11,5 +11,6 @@ export const createModel = (inputVectorSize) => {
1111

1212
// Prepare the model for training: Specify the loss and the optimizer.
1313
model.compile({loss: 'meanSquaredError', optimizer: 'sgd'});
14+
model.summary();
1415
return model;
1516
}

0 commit comments

Comments
 (0)