Skip to content

Commit d5b6da3

Browse files
authored
update baseball-node example to add single-sample prediction (#282)
* update baseball-node example to add single-sample prediction
1 parent b12c78f commit d5b6da3

File tree

6 files changed

+144
-94
lines changed

6 files changed

+144
-94
lines changed

baseball-node/client.js

+31-11
Original file line numberDiff line numberDiff line change
@@ -16,36 +16,50 @@
1616
*/
1717

1818
import io from 'socket.io-client';
19-
const evalTestButton = document.getElementById('eval-test-button');
19+
const predictContainer = document.getElementById('predictContainer');
20+
const predictButton = document.getElementById('predict-button');
2021

2122
const socket =
22-
io('http://localhost:8001',
23-
{reconnectionDelay: 300, reconnectionDelayMax: 300});
23+
io('http://localhost:8001',
24+
{ reconnectionDelay: 300, reconnectionDelayMax: 300 });
2425

2526
const BAR_WIDTH_PX = 300;
2627

27-
evalTestButton.onclick = () => {
28-
evalTestButton.textContent = 'Loading...';
29-
socket.emit('test_data', 'true');
28+
const testSample = [2.668, -114.333, -1.908, 4.786, 25.707, -45.21, 78, 0];
29+
30+
predictButton.onclick = () => {
31+
predictButton.disabled = true;
32+
socket.emit('predictSample', testSample);
3033
};
3134

35+
// functions to handle socket events
3236
socket.on('connect', () => {
33-
evalTestButton.style.display = 'block';
34-
evalTestButton.textContent = 'Eval Test';
37+
document.getElementById('trainingStatus').innerHTML = 'Training in Progress';
3538
});
3639

3740
socket.on('accuracyPerClass', (accPerClass) => {
3841
plotAccuracyPerClass(accPerClass);
3942
});
4043

44+
socket.on('trainingComplete', () => {
45+
document.getElementById('trainingStatus').innerHTML = 'Training Complete';
46+
document.getElementById('predictSample').innerHTML = '[' + testSample.join(', ') + ']';
47+
predictContainer.style.display = 'block';
48+
});
49+
50+
socket.on('predictResult', (result) => {
51+
plotPredictResult(result);
52+
});
53+
4154
socket.on('disconnect', () => {
42-
evalTestButton.style.display = 'block';
55+
document.getElementById('trainingStatus').innerHTML = '';
56+
predictContainer.style.display = 'none';
4357
document.getElementById('waiting-msg').style.display = 'block';
4458
document.getElementById('table').style.display = 'none';
4559
});
4660

61+
// functions to update display
4762
function plotAccuracyPerClass(accPerClass) {
48-
console.log(accPerClass);
4963
document.getElementById('table').style.display = 'block';
5064
document.getElementById('waiting-msg').style.display = 'none';
5165

@@ -75,7 +89,6 @@ function plotAccuracyPerClass(accPerClass) {
7589

7690
plotScoreBar(scores.training, scoreContainer);
7791
if (scores.validation) {
78-
document.getElementById('eval-test-button').style.display = 'none';
7992
plotScoreBar(scores.validation, scoreContainer, 'validation');
8093
}
8194
});
@@ -88,3 +101,10 @@ function plotScoreBar(score, container, className = '') {
88101
scoreDiv.innerHTML = (score * 100).toFixed(1);
89102
container.appendChild(scoreDiv);
90103
}
104+
105+
function plotPredictResult(result) {
106+
predictButton.textContent = 'Predict Pitch';
107+
predictButton.disabled = false;
108+
document.getElementById('predictResult').innerHTML = result;
109+
console.log(result);
110+
}

baseball-node/index.html

+15-76
Original file line numberDiff line numberDiff line change
@@ -2,96 +2,35 @@
22
<html>
33
<head>
44
<title>Pitch Training Accuracy</title>
5+
<link rel="stylesheet" href="styles.css">
56
</head>
67
<body>
78
<h3 id="waiting-msg">Waiting for server...</h3>
89
<div id="table">
9-
<h2 style="text-align:center;">Pitch accuracy (%)</h2>
10+
<h2>Pitch Accuracy By Class (%)</h2>
1011
<div id="legend">
1112
<div class="legend-item">
1213
<div class="score"></div>
13-
<div>Train set</div>
14+
<div>Training set</div>
1415
</div>
1516
<div class="legend-item">
1617
<div class="score validation"></div>
17-
<div>Live set</div>
18+
<div>Test set</div>
1819
</div>
1920
</div>
2021
<div id="table-rows"></div>
21-
<button id="eval-test-button">Eval Test</button>
22+
</div>
23+
<p>
24+
<hr>
25+
<p>
26+
<span style="font-size:16px" id="trainingStatus"></span>
27+
<p>
28+
<div id="predictContainer" style="font-size:16px;display:none">
29+
Sample sensor data: <span id="predictSample"></span>
30+
<button style="font-size:18px;padding:5px;margin-right:10px" id="predict-button">Predict Pitch</button>
31+
<p>
32+
Predicted Pitch Type: <span style="font-weight:bold" id="predictResult"></span>
2233
</div>
2334
<script src="dist/bundle.js"></script>
24-
<style>
25-
#table {
26-
width: 660px;
27-
display: none;
28-
}
29-
#table-rows {
30-
border-right: 2px solid #bbb;
31-
}
32-
#table .row {
33-
display: flex;
34-
align-items: center;
35-
margin: 25px 0;
36-
}
37-
#legend {
38-
position: absolute;
39-
}
40-
.legend-item {
41-
display: flex;
42-
align-items: center;
43-
margin-bottom: 20px;
44-
}
45-
46-
.legend-item .score {
47-
width: 30px;
48-
margin-right: 10px;
49-
}
50-
51-
.label {
52-
text-align: center;
53-
font-family: "Google Sans", sans-serif;
54-
font-size: 24px;
55-
color: #5f6368;
56-
line-height: 24px;
57-
font-weight: 500;
58-
}
59-
#table .label {
60-
margin-right: 20px;
61-
width: 360px;
62-
text-align: right;
63-
}
64-
#table .score {
65-
background-color: #0277bd;
66-
height: 30px;
67-
text-align: right;
68-
line-height: 30px;
69-
color: white;
70-
padding-right: 10px;
71-
box-sizing: border-box;
72-
}
73-
#table .score.validation {
74-
background-color: #ef6c00;
75-
}
76-
77-
html,
78-
body {
79-
font-family: Roboto, sans-serif;
80-
color: #5f6368;
81-
}
82-
83-
body {
84-
background-color: rgb(248, 249, 250);
85-
}
86-
87-
#accuracyCanvas > div {
88-
display: none;
89-
}
90-
91-
#eval-test-button {
92-
padding: 10px;
93-
font-size: 24px;
94-
}
95-
</style>
9635
</body>
9736
</html>

baseball-node/package.json

+3-3
Original file line numberDiff line numberDiff line change
@@ -11,11 +11,11 @@
1111
"clang-format": "~1.2.3",
1212
"mkdirp": "~0.5.1",
1313
"webpack": "~4.28.4",
14-
"webpack-cli": "^3.2.1",
15-
"webpack-dev-server": "~3.1.14"
14+
"webpack-cli": "^3.3.2",
15+
"webpack-dev-server": "~3.4.1"
1616
},
1717
"dependencies": {
18-
"@tensorflow/tfjs-node": "^1.0.1",
18+
"@tensorflow/tfjs-node": "^1.1.2",
1919
"argparse": "^1.0.10",
2020
"socket.io": "~2.2.0"
2121
}

baseball-node/pitch_type.js

+15
Original file line numberDiff line numberDiff line change
@@ -115,6 +115,20 @@ async function evaluate(useTestData) {
115115
return results;
116116
}
117117

118+
async function predictSample(sample) {
119+
console.log('calling predictSample on ', sample);
120+
let result = model.predict(tf.tensor(sample, [1,sample.length])).arraySync();
121+
console.log(result);
122+
var maxValue = 0;
123+
var predictedPitch = 7;
124+
for (var i = 0; i < NUM_PITCH_CLASSES; i++) {
125+
if (result[0][i] > maxValue) {
126+
predictedPitch = i;
127+
}
128+
}
129+
return pitchFromClassNum(predictedPitch);
130+
}
131+
118132
// Determines accuracy evaluation for a given pitch class by index:
119133
function calcPitchClassEval(pitchIndex, classSize, values) {
120134
// Output has 7 different class values for each pitch, offset based on
@@ -154,6 +168,7 @@ module.exports = {
154168
evaluate,
155169
model,
156170
pitchFromClassNum,
171+
predictSample,
157172
testValidationData,
158173
trainingData,
159174
TEST_DATA_LENGTH

baseball-node/server.js

+15-4
Original file line numberDiff line numberDiff line change
@@ -29,26 +29,37 @@ async function run() {
2929
const port = process.env.PORT || PORT;
3030
const server = http.createServer();
3131
const io = socketio(server);
32-
let useTestData = false;
32+
let useTestData = true;
3333

3434
server.listen(port, () => {
35-
console.log(` > Running socket on port: ${port}`);
35+
console.log(`Running socket on port: ${port}`);
3636
});
3737

3838
io.on('connection', (socket) => {
3939
socket.on('test_data', (value) => {
4040
useTestData = value === 'true' ? true : false;
4141
});
42+
43+
socket.on('predictSample', async (sample) => {
44+
console.log('received predict request');
45+
io.emit('predictResult', await pitch_type.predictSample(sample));
46+
});
4247
});
4348

4449
io.emit('accuracyPerClass', await pitch_type.evaluate(useTestData));
4550
await sleep(TIMEOUT_BETWEEN_EPOCHS_MS);
4651

47-
while (true) {
48-
await pitch_type.model.fitDataset(pitch_type.trainingData, {epochs: 1});
52+
let numTrainingIterations = 10;
53+
for (var i = 0; i < numTrainingIterations; i++) {
54+
console.log(`Training iteration : ${i + 1} / ${numTrainingIterations}`);
55+
await pitch_type.model.fitDataset(pitch_type.trainingData, { epochs: 1 });
4956
io.emit('accuracyPerClass', await pitch_type.evaluate(useTestData));
5057
await sleep(TIMEOUT_BETWEEN_EPOCHS_MS);
5158
}
59+
60+
io.emit('trainingComplete', true);
61+
console.log('training complete');
62+
5263
}
5364

5465
run();

baseball-node/styles.css

+65
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,65 @@
1+
#table {
2+
width: 660px;
3+
display: none;
4+
}
5+
#table-rows {
6+
border-right: 2px solid #bbb;
7+
}
8+
#table .row {
9+
display: flex;
10+
align-items: center;
11+
margin: 25px 0;
12+
}
13+
#legend {
14+
position: absolute;
15+
}
16+
.legend-item {
17+
display: flex;
18+
align-items: center;
19+
margin-bottom: 20px;
20+
}
21+
22+
.legend-item .score {
23+
width: 30px;
24+
margin-right: 10px;
25+
}
26+
27+
.label {
28+
text-align: center;
29+
font-family: "Google Sans", sans-serif;
30+
font-size: 24px;
31+
color: #5f6368;
32+
line-height: 24px;
33+
font-weight: 500;
34+
}
35+
#table .label {
36+
margin-right: 20px;
37+
width: 360px;
38+
text-align: right;
39+
}
40+
#table .score {
41+
background-color: #0277bd;
42+
height: 30px;
43+
text-align: right;
44+
line-height: 30px;
45+
color: white;
46+
padding-right: 10px;
47+
box-sizing: border-box;
48+
}
49+
#table .score.validation {
50+
background-color: #ef6c00;
51+
}
52+
53+
html,
54+
body {
55+
font-family: Roboto, sans-serif;
56+
color: #5f6368;
57+
}
58+
59+
body {
60+
background-color: rgb(248, 249, 250);
61+
}
62+
63+
#accuracyCanvas > div {
64+
display: none;
65+
}

0 commit comments

Comments
 (0)