Skip to content

Commit 77d802f

Browse files
authored
[jena-weather] Add simpleRNN model (#222)
1 parent 6bfbe18 commit 77d802f

File tree

2 files changed

+20
-1
lines changed

2 files changed

+20
-1
lines changed

jena-weather/models.js

+19
Original file line numberDiff line numberDiff line change
@@ -116,6 +116,23 @@ function buildMLPModel(inputShape, kernelRegularizer, dropoutRate) {
116116
return model;
117117
}
118118

119+
/**
120+
* Build a simpleRNN-based model for the temperature-prediction problem.
121+
*
122+
* @param {tf.Shape} inputShape Input shape (without the batch dimenson).
123+
* @returns {tf.Model} A TensorFlow.js model consisting of a simpleRNN layer.
124+
*/
125+
function buildSimpleRNNModel(inputShape) {
126+
const model = tf.sequential();
127+
const rnnUnits = 32;
128+
model.add(tf.layers.simpleRNN({
129+
units: rnnUnits,
130+
inputShape
131+
}));
132+
model.add(tf.layers.dense({units: 1}));
133+
return model;
134+
}
135+
119136
/**
120137
* Build a GRU model for the temperature-prediction problem.
121138
*
@@ -163,6 +180,8 @@ export function buildModel(modelType, numTimeSteps, numFeatures) {
163180
const regularizer = null;
164181
const dropoutRate = 0.25;
165182
model = buildMLPModel(inputShape, regularizer, dropoutRate);
183+
} else if (modelType === 'simpleRNN') {
184+
model = buildSimpleRNNModel(inputShape);
166185
} else if (modelType === 'gru') {
167186
model = buildGRUModel(inputShape);
168187
// TODO(cais): Add gru-dropout with recurrentDropout.

jena-weather/train-rnn.js

+1-1
Original file line numberDiff line numberDiff line change
@@ -39,7 +39,7 @@ function parseArguments() {
3939
parser.addArgument('--modelType', {
4040
type: 'string',
4141
defaultValue: 'gru',
42-
optionStrings: ['gru'],
42+
optionStrings: ['baseline', 'gru', 'simpleRNN'],
4343
// TODO(cais): Add more model types, e.g., gru with recurrent dropout.
4444
help: 'Type of the model to train. Use "baseline" to compute the ' +
4545
'commonsense baseline prediction error.'

0 commit comments

Comments
 (0)