|
| 1 | +/** |
| 2 | +* @license |
| 3 | +* Copyright 2018 Google LLC. All Rights Reserved. |
| 4 | +* |
| 5 | +* Licensed under the Apache License, Version 2.0 (the "License"); |
| 6 | +* you may not use this file except in compliance with the License. |
| 7 | +* You may obtain a copy of the License at |
| 8 | +* |
| 9 | +* http:// www.apache.org/licenses/LICENSE-2.0 |
| 10 | +* |
| 11 | +* Unless required by applicable law or agreed to in writing, software |
| 12 | +* distributed under the License is distributed on an "AS IS" BASIS, |
| 13 | +* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. |
| 14 | +* See the License for the specific language governing permissions and |
| 15 | +* limitations under the License. |
| 16 | +* ============================================================================== |
| 17 | +*/ |
| 18 | + |
| 19 | +// This tiny example illustrates how little code is necessary build / |
| 20 | +// train / predict from a model in TensorFlow.js. Edit this code |
| 21 | +// and refresh the index.html to quickly explore the API. |
| 22 | + |
| 23 | +// Tiny TFJS train / predict example. |
| 24 | +async function myFirstTfjs() { |
| 25 | + // Create a simple model. |
| 26 | + const model = tf.sequential(); |
| 27 | + model.add(tf.layers.dense({units: 1, inputShape: [1]})); |
| 28 | + // Prepare the model for training: Specify the loss and the optimizer. |
| 29 | + model.compile({loss: 'meanSquaredError', |
| 30 | + optimizer: 'sgd', |
| 31 | + useBias: 'true'}); |
| 32 | + // Generate some synthetic data for training. (y = 2x - 1) |
| 33 | + const xs = tf.tensor2d([-1, 0, 1, 2, 3, 4], [6, 1]); |
| 34 | + const ys = tf.tensor2d([-3, -1, 1, 3, 5, 7], [6, 1]); |
| 35 | + // Train the model using the data. |
| 36 | + await model.fit(xs, ys, {epochs: 250}); |
| 37 | + // Use the model to do inference on a data point the model hasn't seen. |
| 38 | + // Perform within a tf.tidy block to perform cleanup on intermediate tensors. |
| 39 | + tf.tidy(() => { |
| 40 | + // Should print approximately 39. |
| 41 | + document.getElementById('micro_out_div').innerText += model.predict( |
| 42 | + tf.tensor2d([20], [1, 1])); |
| 43 | + }); |
| 44 | + // Manually clean up the memory for these variables. |
| 45 | + xs.dispose(); |
| 46 | + ys.dispose(); |
| 47 | +} |
| 48 | + |
| 49 | +myFirstTfjs(); |
0 commit comments