Skip to content

Commit 0228bc5

Browse files
shun-linseanpmorgan
authored andcommitted
Adding TimeStopping to tfa.callback (#757)
* added tutorial and code for TimeStopping
1 parent bf8a809 commit 0228bc5

File tree

5 files changed

+305
-2
lines changed

5 files changed

+305
-2
lines changed

docs/tutorials/time_stopping.ipynb

Lines changed: 236 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,236 @@
1+
{
2+
"cells": [
3+
{
4+
"cell_type": "markdown",
5+
"metadata": {},
6+
"source": [
7+
"##### Copyright 2019 The TensorFlow Authors."
8+
]
9+
},
10+
{
11+
"cell_type": "code",
12+
"execution_count": 8,
13+
"metadata": {},
14+
"outputs": [],
15+
"source": [
16+
"#@title Licensed under the Apache License, Version 2.0 (the \"License\");\n",
17+
"# you may not use this file except in compliance with the License.\n",
18+
"# You may obtain a copy of the License at\n",
19+
"#\n",
20+
"# https://www.apache.org/licenses/LICENSE-2.0\n",
21+
"#\n",
22+
"# Unless required by applicable law or agreed to in writing, software\n",
23+
"# distributed under the License is distributed on an \"AS IS\" BASIS,\n",
24+
"# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n",
25+
"# See the License for the specific language governing permissions and\n",
26+
"# limitations under the License."
27+
]
28+
},
29+
{
30+
"cell_type": "markdown",
31+
"metadata": {},
32+
"source": [
33+
"# TensorFlow Addons Callbacks: TimeStopping"
34+
]
35+
},
36+
{
37+
"cell_type": "markdown",
38+
"metadata": {},
39+
"source": [
40+
"<table class=\"tfo-notebook-buttons\" align=\"left\">\n",
41+
" <td>\n",
42+
" <a target=\"_blank\" href=\"https://www.tensorflow.org/addons/tutorials/time_stopping\"><img src=\"https://www.tensorflow.org/images/tf_logo_32px.png\" />View on TensorFlow.org</a>\n",
43+
" </td>\n",
44+
" <td>\n",
45+
" <a target=\"_blank\" href=\"https://colab.research.google.com/github/tensorflow/addons/blob/master/docs/tutorials/time_stopping.ipynb\"><img src=\"https://www.tensorflow.org/images/colab_logo_32px.png\" />Run in Google Colab</a>\n",
46+
" </td>\n",
47+
" <td>\n",
48+
" <a target=\"_blank\" href=\"https://github.com/tensorflow/addons/blob/master/docs/tutorials/time_stopping.ipynb\"><img src=\"https://www.tensorflow.org/images/GitHub-Mark-32px.png\" />View source on GitHub</a>\n",
49+
" </td>\n",
50+
" <td>\n",
51+
" <a href=\"https://storage.googleapis.com/tensorflow_docs/addons/docs/tutorials/time_stopping.ipynb\"><img src=\"https://www.tensorflow.org/images/download_logo_32px.png\" />Download notebook</a>\n",
52+
" </td>\n",
53+
"</table>"
54+
]
55+
},
56+
{
57+
"cell_type": "markdown",
58+
"metadata": {},
59+
"source": [
60+
"## Overview\n",
61+
"This notebook will demonstrate how to use TimeStopping Callback in TensorFlow Addons."
62+
]
63+
},
64+
{
65+
"cell_type": "markdown",
66+
"metadata": {},
67+
"source": [
68+
"## Setup"
69+
]
70+
},
71+
{
72+
"cell_type": "code",
73+
"execution_count": 9,
74+
"metadata": {},
75+
"outputs": [],
76+
"source": [
77+
"!pip install -q --no-deps tensorflow-addons~=0.6"
78+
]
79+
},
80+
{
81+
"cell_type": "code",
82+
"execution_count": 10,
83+
"metadata": {},
84+
"outputs": [],
85+
"source": [
86+
"try:\n",
87+
" # %tensorflow_version only exists in Colab.\n",
88+
" %tensorflow_version 2.x\n",
89+
"except Exception:\n",
90+
" pass"
91+
]
92+
},
93+
{
94+
"cell_type": "code",
95+
"execution_count": 11,
96+
"metadata": {},
97+
"outputs": [],
98+
"source": [
99+
"import tensorflow as tf\n",
100+
"import tensorflow_addons as tfa\n",
101+
"\n",
102+
"import tensorflow.keras as keras\n",
103+
"from tensorflow.keras.datasets import mnist\n",
104+
"from tensorflow.keras.models import Sequential\n",
105+
"from tensorflow.keras.layers import Dense, Dropout, Flatten"
106+
]
107+
},
108+
{
109+
"cell_type": "markdown",
110+
"metadata": {},
111+
"source": [
112+
"## Import and Normalize Data"
113+
]
114+
},
115+
{
116+
"cell_type": "code",
117+
"execution_count": 12,
118+
"metadata": {},
119+
"outputs": [],
120+
"source": [
121+
"# the data, split between train and test sets\n",
122+
"(x_train, y_train), (x_test, y_test) = mnist.load_data()\n",
123+
"# normalize data\n",
124+
"x_train, x_test = x_train / 255.0, x_test / 255.0"
125+
]
126+
},
127+
{
128+
"cell_type": "markdown",
129+
"metadata": {},
130+
"source": [
131+
"## Build Simple MNIST CNN Model"
132+
]
133+
},
134+
{
135+
"cell_type": "code",
136+
"execution_count": 13,
137+
"metadata": {},
138+
"outputs": [],
139+
"source": [
140+
"# build the model using the Sequential API\n",
141+
"model = Sequential()\n",
142+
"model.add(Flatten(input_shape=(28, 28)))\n",
143+
"model.add(Dense(128, activation='relu'))\n",
144+
"model.add(Dropout(0.2))\n",
145+
"model.add(Dense(10, activation='softmax'))\n",
146+
"\n",
147+
"model.compile(optimizer='adam',\n",
148+
" loss = 'sparse_categorical_crossentropy',\n",
149+
" metrics=['accuracy'])"
150+
]
151+
},
152+
{
153+
"cell_type": "markdown",
154+
"metadata": {},
155+
"source": [
156+
"## Simple TimeStopping Usage"
157+
]
158+
},
159+
{
160+
"cell_type": "code",
161+
"execution_count": 14,
162+
"metadata": {
163+
"scrolled": true
164+
},
165+
"outputs": [
166+
{
167+
"name": "stdout",
168+
"output_type": "stream",
169+
"text": [
170+
"Train on 60000 samples, validate on 10000 samples\n",
171+
"Epoch 1/100\n",
172+
"60000/60000 [==============================] - 2s 28us/sample - loss: 0.3357 - accuracy: 0.9033 - val_loss: 0.1606 - val_accuracy: 0.9533\n",
173+
"Epoch 2/100\n",
174+
"60000/60000 [==============================] - 1s 23us/sample - loss: 0.1606 - accuracy: 0.9525 - val_loss: 0.1104 - val_accuracy: 0.9669\n",
175+
"Epoch 3/100\n",
176+
"60000/60000 [==============================] - 1s 24us/sample - loss: 0.1185 - accuracy: 0.9645 - val_loss: 0.0949 - val_accuracy: 0.9704\n",
177+
"Epoch 4/100\n",
178+
"60000/60000 [==============================] - 1s 25us/sample - loss: 0.0954 - accuracy: 0.9713 - val_loss: 0.0854 - val_accuracy: 0.9740\n",
179+
"Timed stopping at epoch 4 after training for 0:00:05\n"
180+
]
181+
},
182+
{
183+
"data": {
184+
"text/plain": [
185+
"<tensorflow.python.keras.callbacks.History at 0x110af0ef0>"
186+
]
187+
},
188+
"execution_count": 14,
189+
"metadata": {},
190+
"output_type": "execute_result"
191+
}
192+
],
193+
"source": [
194+
"# initialize TimeStopping callback \n",
195+
"time_stopping_callback = tfa.callbacks.TimeStopping(seconds=5, verbose=1)\n",
196+
"\n",
197+
"# train the model with tqdm_callback\n",
198+
"# make sure to set verbose = 0 to disable\n",
199+
"# the default progress bar.\n",
200+
"model.fit(x_train, y_train,\n",
201+
" batch_size=64,\n",
202+
" epochs=100,\n",
203+
" callbacks=[time_stopping_callback],\n",
204+
" validation_data=(x_test, y_test))"
205+
]
206+
},
207+
{
208+
"cell_type": "code",
209+
"execution_count": null,
210+
"metadata": {},
211+
"outputs": [],
212+
"source": []
213+
}
214+
],
215+
"metadata": {
216+
"kernelspec": {
217+
"display_name": "Python 3",
218+
"language": "python",
219+
"name": "python3"
220+
},
221+
"language_info": {
222+
"codemirror_mode": {
223+
"name": "ipython",
224+
"version": 3
225+
},
226+
"file_extension": ".py",
227+
"mimetype": "text/x-python",
228+
"name": "python",
229+
"nbconvert_exporter": "python",
230+
"pygments_lexer": "ipython3",
231+
"version": "3.6.2"
232+
}
233+
},
234+
"nbformat": 4,
235+
"nbformat_minor": 2
236+
}

tensorflow_addons/callbacks/BUILD

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,7 @@ py_library(
66
name = "callbacks",
77
srcs = [
88
"__init__.py",
9+
"time_stopping.py",
910
"tqdm_progress_bar.py",
1011
],
1112
deps = [

tensorflow_addons/callbacks/README.md

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -3,14 +3,15 @@
33
## Maintainers
44
| Submodule | Maintainers | Contact Info |
55
|:---------- |:------------- |:--------------|
6+
| time_stopping | @shun-lin | [email protected] |
67
| tqdm_progress_bar | @shun-lin | [email protected] |
78

89
## Contents
910
| Submodule | Callback | Reference |
1011
|:----------------------- |:-------------------|:---------------|
12+
| time_stopping | TimeStopping | N/A |
1113
| tqdm_progress_bar | TQDMProgressBar | https://tqdm.github.io/ |
1214

13-
1415
## Contribution Guidelines
1516
#### Standard API
1617
In order to conform with the current API standard, all callbacks

tensorflow_addons/callbacks/__init__.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -18,4 +18,5 @@
1818
from __future__ import division
1919
from __future__ import print_function
2020

21-
from tensorflow_addons.callbacks.tqdm_progress_bar import TQDMProgressBar
21+
from tensorflow_addons.callbacks.time_stopping import TimeStopping
22+
from tensorflow_addons.callbacks.tqdm_progress_bar import TQDMProgressBar
Lines changed: 64 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,64 @@
1+
# Copyright 2019 The TensorFlow Authors. All Rights Reserved.
2+
#
3+
# Licensed under the Apache License, Version 2.0 (the "License");
4+
# you may not use this file except in compliance with the License.
5+
# You may obtain a copy of the License at
6+
#
7+
# http://www.apache.org/licenses/LICENSE-2.0
8+
#
9+
# Unless required by applicable law or agreed to in writing, software
10+
# distributed under the License is distributed on an "AS IS" BASIS,
11+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
# See the License for the specific language governing permissions and
13+
# limitations under the License.
14+
# ==============================================================================
15+
"""Callback that stops training when a specified amount of time has passed."""
16+
17+
from __future__ import absolute_import, division, print_function
18+
19+
import datetime
20+
import time
21+
22+
import tensorflow as tf
23+
from tensorflow.keras.callbacks import Callback
24+
25+
26+
@tf.keras.utils.register_keras_serializable(package='Addons')
27+
class TimeStopping(Callback):
28+
"""Stop training when a specified amount of time has passed.
29+
30+
Args:
31+
seconds: maximum amount of time before stopping.
32+
Defaults to 86400 (1 day).
33+
verbose: verbosity mode. Defaults to 0.
34+
"""
35+
36+
def __init__(self, seconds=86400, verbose=0):
37+
super(TimeStopping, self).__init__()
38+
39+
self.seconds = seconds
40+
self.verbose = verbose
41+
42+
def on_train_begin(self, logs=None):
43+
self.stopping_time = time.time() + self.seconds
44+
45+
def on_epoch_end(self, epoch, logs={}):
46+
if time.time() >= self.stopping_time:
47+
self.model.stop_training = True
48+
self.stopped_epoch = epoch
49+
50+
def on_train_end(self, logs=None):
51+
if self.verbose > 0:
52+
formatted_time = datetime.timedelta(seconds=self.seconds)
53+
msg = 'Timed stopping at epoch {} after training for {}'.format(
54+
self.stopped_epoch + 1, formatted_time)
55+
print(msg)
56+
57+
def get_config(self):
58+
config = {
59+
'seconds': self.seconds,
60+
'verbose': self.verbose,
61+
}
62+
63+
base_config = super(TimeStopping, self).get_config()
64+
return dict(list(base_config.items()) + list(config.items()))

0 commit comments

Comments
 (0)