Skip to content

Commit 5a12e1d

Browse files
yashk2810Copybara-Service
authored andcommitted
Switching to categorical_crossentropy
PiperOrigin-RevId: 234160929
1 parent 8efbd46 commit 5a12e1d

File tree

1 file changed

+6
-5
lines changed

1 file changed

+6
-5
lines changed

site/en/r2/tutorials/distribute/keras.ipynb

Lines changed: 6 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -190,7 +190,7 @@
190190
"num_train_examples = ds_info.splits['train'].num_examples\n",
191191
"num_test_examples = ds_info.splits['test'].num_examples\n",
192192
"\n",
193-
"BUFFER_SIZE = num_train_examples\n",
193+
"BUFFER_SIZE = 10000\n",
194194
"BATCH_SIZE = 64"
195195
]
196196
},
@@ -217,7 +217,8 @@
217217
"def scale(image, label):\n",
218218
" image = tf.cast(image, tf.float32)\n",
219219
" image /= 255\n",
220-
" return image, label"
220+
" \n",
221+
" return image, label[..., tf.newaxis]"
221222
]
222223
},
223224
{
@@ -318,10 +319,10 @@
318319
" tf.keras.layers.Dense(64, activation='relu'),\n",
319320
" tf.keras.layers.Dense(10, activation='softmax')\n",
320321
" ])\n",
321-
" # TODO(yashkatariya): Add accuracy when b/122371345 is fixed.\n",
322+
"\n",
322323
" model.compile(loss='sparse_categorical_crossentropy',\n",
323-
" optimizer=tf.keras.optimizers.Adam())\n",
324-
" #metrics=['accuracy'])"
324+
" optimizer=tf.keras.optimizers.Adam(),\n",
325+
" metrics=['accuracy'])"
325326
]
326327
},
327328
{

0 commit comments

Comments
 (0)