Skip to content

Commit 6bfbe18

Browse files
authored
[simple-object-detection] A few improvements (#224)
- Add a command line flag for switching between tfjs-node and tfjs-node-gpu - Change README.md accordingly - Add live demo link to README.md - Remove unnecessary tensor creation, which causes problems because tfjs-node/tfjs-node-gpu is imported
1 parent eb46148 commit 6bfbe18

File tree

4 files changed

+62
-57
lines changed

4 files changed

+62
-57
lines changed

simple-object-detection/README.md

+8-13
Original file line numberDiff line numberDiff line change
@@ -35,7 +35,7 @@ are too many training examples generated. In the meantime, having a large
3535
number of training examples benefits the accuracy of the model after
3636
training. The default number of examples is 2000. You can adjust the number
3737
of examples by using the `--numExamples` flag of the `yarn train` command.
38-
For example, the hosted model is trained with the 10000 examples, using
38+
For example, the hosted model is trained with the 20000 examples, using
3939
the command line:
4040

4141
```sh
@@ -52,18 +52,13 @@ See `train.js` for other adjustable parameters.
5252
Note that by default, the model is trained using the CPU version of tfjs-node.
5353
If you machine is equipped with a CUDA(R) GPU, you may switch to using
5454
tfjs-node-gpu, which will significantly shorten the training time. Specifically,
55-
in `package.json`, change the dependency `tfjs-node` to `tfjs-node-gpu`. Then,
56-
in `train.js`, change the line
55+
add the `--gpu` flag to the command above, i.e.,
5756

58-
```js
59-
require('@tensorflow/tfjs-node');
60-
```
61-
62-
to
63-
64-
```js
65-
require('@tensorflow/tfjs-node-gpu');
57+
```sh
58+
yarn train --gpu \
59+
--numExamples 20000 \
60+
--initialTransferEpochs 100 \
61+
--fineTuningEpochs 200
6662
```
6763

68-
TODO(cais): Add the link below.
69-
[See this example live!](./README.md)
64+
[See this example live!](https://storage.googleapis.com/tfjs-examples/simple-object-detection/dist/index.html)

simple-object-detection/package.json

+3-1
Original file line numberDiff line numberDiff line change
@@ -9,7 +9,7 @@
99
"node": ">=8.9.0"
1010
},
1111
"dependencies": {
12-
"@tensorflow/tfjs-node": "0.2.1",
12+
"@tensorflow/tfjs": "0.14.2",
1313
"argparse": "^1.0.10",
1414
"canvas": "2.0.1",
1515
"node-fetch": "2.2.1"
@@ -22,6 +22,8 @@
2222
"train": "node train.js"
2323
},
2424
"devDependencies": {
25+
"@tensorflow/tfjs-node": "0.2.1",
26+
"@tensorflow/tfjs-node-gpu": "0.2.1",
2527
"babel-core": "^6.26.3",
2628
"babel-plugin-transform-runtime": "~6.23.0",
2729
"babel-polyfill": "~6.26.0",

simple-object-detection/train.js

+13-13
Original file line numberDiff line numberDiff line change
@@ -22,18 +22,6 @@ const argparse = require('argparse');
2222
const canvas = require('canvas');
2323
const tf = require('@tensorflow/tfjs');
2424
const synthesizer = require('./synthetic_images');
25-
const fetch = require('node-fetch');
26-
27-
// To train the model using CUDA/CuDNN,
28-
// 1) Make sure you have a CUDA-enabled GPU on your system.
29-
// 2) Install the necessary NVIDIA driver, CUDA toolkit and CuDNN library.
30-
// 3) Change the "@tensorflow/tfjs-node" dependency to
31-
// "@tensorflow/tfjs-node-gpu" in package.json.
32-
// 4) Change the following line to:
33-
// require('@tensorflow/tfjs-node-gpu');
34-
require('@tensorflow/tfjs-node');
35-
36-
global.fetch = fetch;
3725

3826
const CANVAS_SIZE = 224; // Matches the input size of MobileNet.
3927

@@ -47,7 +35,7 @@ const topLayerName =
4735
// Used to scale the first column (0-1 shape indicator) of `yTrue`
4836
// in order to ensure balanced contributions to the final loss value
4937
// from shape and bounding-box predictions.
50-
const LABEL_MULTIPLIER = tf.tensor1d([CANVAS_SIZE, 1, 1, 1, 1]);
38+
const LABEL_MULTIPLIER = [CANVAS_SIZE, 1, 1, 1, 1];
5139

5240
/**
5341
* Custom loss function for object detection.
@@ -155,6 +143,10 @@ async function buildObjectDetectionModel() {
155143
const numLines = 10;
156144

157145
const parser = new argparse.ArgumentParser();
146+
parser.addArgument('--gpu', {
147+
action: 'storeTrue',
148+
help: "Use tfjs-node-gpu for training (required CUDA and CuDNN)"
149+
});
158150
parser.addArgument(
159151
'--numExamples',
160152
{type: 'int', defaultValue: 2000, help: 'Number of training exapmles'});
@@ -181,6 +173,14 @@ async function buildObjectDetectionModel() {
181173
});
182174
const args = parser.parseArgs();
183175

176+
if (args.gpu) {
177+
console.log('Training using GPU.');
178+
require('@tensorflow/tfjs-node-gpu');
179+
} else {
180+
console.log('Training using CPU.');
181+
require('@tensorflow/tfjs-node');
182+
}
183+
184184
const modelSaveURL = 'file://./dist/object_detection_model';
185185

186186
const tBegin = tf.util.now();

simple-object-detection/yarn.lock

+38-30
Original file line numberDiff line numberDiff line change
@@ -688,38 +688,51 @@
688688
version "1.1.0"
689689
resolved "https://registry.yarnpkg.com/@protobufjs/utf8/-/utf8-1.1.0.tgz#a777360b5b39a1a2e5106f8e858f2fd2d060c570"
690690

691-
"@tensorflow/[email protected].1":
692-
version "0.7.1"
693-
resolved "https://registry.yarnpkg.com/@tensorflow/tfjs-converter/-/tfjs-converter-0.7.1.tgz#af33b89f34fdfbdfbe8d9361c5542d7db514364d"
694-
integrity sha512-kMQqM3GI5bBl6YQ98jCJNu49NvAWGeI3iVxO3ZqOYtN90lb/+3dSBelDo2LHFXc8jnJHpFOwkFPSZlCVrVGRag==
691+
"@tensorflow/[email protected].2":
692+
version "0.7.2"
693+
resolved "https://registry.yarnpkg.com/@tensorflow/tfjs-converter/-/tfjs-converter-0.7.2.tgz#49e578f71eb82d821af05176754c3452b42cfe9c"
694+
integrity sha512-m46mtaF57x2NcxlNUKdJOCUp3ZSJU9bp9MzyEQ0Iz1bW2kKIxx1DDRjuP0fAeHX5H5Mh/tWIHB9yK6NwLz+aQQ==
695695
dependencies:
696696
"@types/long" "~3.0.32"
697697
protobufjs "~6.8.6"
698698

699-
"@tensorflow/[email protected].2":
700-
version "0.14.2"
701-
resolved "https://registry.yarnpkg.com/@tensorflow/tfjs-core/-/tfjs-core-0.14.2.tgz#b1ee6af0d893782a1c3b2c988091c36c42f1d7a4"
702-
integrity sha512-VVbcu6H3ioKCkfkep/gQASfzPnQt3C5v+4ppH9pQ6Lf0lD+l3NMuMJYxa8Wjac1TfiWhFEX58bJvhpMfTGsUlg==
699+
"@tensorflow/[email protected].5":
700+
version "0.14.5"
701+
resolved "https://registry.yarnpkg.com/@tensorflow/tfjs-core/-/tfjs-core-0.14.5.tgz#17c3beeec31c4cd92b0f79a5ef30c4975a11e408"
702+
integrity sha512-CSUgKuC17J1Ylr1s6iD1k2/tJr9lD16sUEjtzJbtiuTYCELOwujGK/1htunA7o3BwLuU7aqEI92MoKElEKa7qA==
703703
dependencies:
704704
"@types/seedrandom" "2.4.27"
705705
"@types/webgl-ext" "0.0.30"
706706
"@types/webgl2" "0.0.4"
707707
seedrandom "2.4.3"
708708

709-
"@tensorflow/[email protected].4":
710-
version "0.1.4"
711-
resolved "https://registry.yarnpkg.com/@tensorflow/tfjs-data/-/tfjs-data-0.1.4.tgz#4732a2505c5d7d0d8ba0458f5b92837fd02ba442"
712-
integrity sha512-YwaWNZnJj++QFEHQ1AKLqn2fvnmSp1X6CJ5YL5XJhq+m8P0AoouW9IpumCgO6WSjnD1M83/cVGZXzDIgJ4IlLg==
709+
"@tensorflow/[email protected].7":
710+
version "0.1.7"
711+
resolved "https://registry.yarnpkg.com/@tensorflow/tfjs-data/-/tfjs-data-0.1.7.tgz#8a4e43313b3d63cdfab719c0c1c47ced2ef321e3"
712+
integrity sha512-RENjeBdBLq7GS9594kQx2GbM0WQV16VfxzzB0j2sq5vJh9GZQi2DB5Emq2LqZWs5rSeh7PDHZylGOn/ve6f8PA==
713713
dependencies:
714714
"@types/node-fetch" "^2.1.2"
715715
node-fetch "~2.1.2"
716716
seedrandom "~2.4.3"
717-
utf8 "~2.1.2"
718717

719-
"@tensorflow/[email protected]":
720-
version "0.9.1"
721-
resolved "https://registry.yarnpkg.com/@tensorflow/tfjs-layers/-/tfjs-layers-0.9.1.tgz#685b741cf5c42aa9b2621485ffafdcc43c8db699"
722-
integrity sha512-TZTi59E3rGdVTJCf/AEX1arigPp0426FJLGfiKzTXp3skEuubwDM9XlwiXWcB5+l+Pjvwg4FMNXwEyajXIxX2w==
718+
"@tensorflow/[email protected]":
719+
version "0.9.2"
720+
resolved "https://registry.yarnpkg.com/@tensorflow/tfjs-layers/-/tfjs-layers-0.9.2.tgz#f5c1918d1a9660096f259cd1f99f59e689b41b69"
721+
integrity sha512-peB824cEXRBy5IgZPIodd8zpQ/54VGOYbR+zY+Q1Le7v3Np05EoDcL8Z98MtpBHo6jOM7b/3Lf2zjfJVv2qxJA==
722+
723+
"@tensorflow/[email protected]":
724+
version "0.2.1"
725+
resolved "https://registry.yarnpkg.com/@tensorflow/tfjs-node-gpu/-/tfjs-node-gpu-0.2.1.tgz#f8843ff3dcdab3f00b642375e08e9705bd726039"
726+
integrity sha512-S43gOH0wA3qyv7Lc8MsQtoO/eFQNnRXi5TJmwhR4HsBUXzaJKzzwF8kOI9G22V2JBOZbsOkOgJ2Rfl9DsT2skg==
727+
dependencies:
728+
"@tensorflow/tfjs" "~0.14.1"
729+
adm-zip "^0.4.11"
730+
bindings "~1.3.0"
731+
https-proxy-agent "^2.2.1"
732+
node-fetch "^2.3.0"
733+
progress "^2.0.0"
734+
rimraf "^2.6.2"
735+
tar "^4.4.6"
723736

724737
"@tensorflow/[email protected]":
725738
version "0.2.1"
@@ -735,15 +748,15 @@
735748
rimraf "^2.6.2"
736749
tar "^4.4.6"
737750

738-
"@tensorflow/tfjs@~0.14.1":
739-
version "0.14.1"
740-
resolved "https://registry.yarnpkg.com/@tensorflow/tfjs/-/tfjs-0.14.1.tgz#d9b198c26eb1f154c5c27228800c9df7630d2cbf"
741-
integrity sha512-fo68B4FNh/JBYAjArlLFsx2etE1kxCHG11bzsy01b2jZLrM9HE+jgmiNUGuHqN7aLu/CVqB76wDAtlBT4AnPsg==
751+
"@tensorflow/tfjs@0.14.2", "@tensorflow/tfjs@~0.14.1":
752+
version "0.14.2"
753+
resolved "https://registry.yarnpkg.com/@tensorflow/tfjs/-/tfjs-0.14.2.tgz#f38fa572286dadfe981c219f5639defd586c20c4"
754+
integrity sha512-d+kBdhn3L/BOIwwc44V1lUrs0O5s49ujhYXVHT9Hs6y3yq+OqPK10am16H1fNcxeMn12/3gGphebglObTD0/Sg==
742755
dependencies:
743-
"@tensorflow/tfjs-converter" "0.7.1"
744-
"@tensorflow/tfjs-core" "0.14.2"
745-
"@tensorflow/tfjs-data" "0.1.4"
746-
"@tensorflow/tfjs-layers" "0.9.1"
756+
"@tensorflow/tfjs-converter" "0.7.2"
757+
"@tensorflow/tfjs-core" "0.14.5"
758+
"@tensorflow/tfjs-data" "0.1.7"
759+
"@tensorflow/tfjs-layers" "0.9.2"
747760

748761
"@types/long@^4.0.0":
749762
version "4.0.0"
@@ -5420,11 +5433,6 @@ user-home@^2.0.0:
54205433
dependencies:
54215434
os-homedir "^1.0.0"
54225435

5423-
utf8@~2.1.2:
5424-
version "2.1.2"
5425-
resolved "https://registry.yarnpkg.com/utf8/-/utf8-2.1.2.tgz#1fa0d9270e9be850d9b05027f63519bf46457d96"
5426-
integrity sha1-H6DZJw6b6FDZsFAn9jUZv0ZFfZY=
5427-
54285436
util-deprecate@^1.0.1, util-deprecate@~1.0.1:
54295437
version "1.0.2"
54305438
resolved "https://registry.yarnpkg.com/util-deprecate/-/util-deprecate-1.0.2.tgz#450d4dc9fa70de732762fbd2d4a28981419a0ccf"

0 commit comments

Comments
 (0)