4 | 4 | "cell_type": "markdown",
5 | 5 | "metadata": {},
6 | 6 | "source": [
7 | | - "# High-level Gluon Example"
| 7 | + "# MXNet/Gluon CNN example"
8 | 8 | ]
9 | 9 | },
10 | 10 | {
71 | 71 | "outputs": [],
72 | 72 | "source": [
73 | 73 | "def SymbolModule(n_classes=N_CLASSES):\n",
74 | | - "    sym = gluon.nn.Sequential()\n",
| 74 | + "    sym = gluon.nn.HybridSequential()\n",
75 | 75 | "    with sym.name_scope():\n",
76 | 76 | "        sym.add(gluon.nn.Conv2D(channels=50, kernel_size=3, padding=1, activation='relu'))\n",
77 | 77 | "        sym.add(gluon.nn.Conv2D(channels=50, kernel_size=3, padding=1))\n",
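Note on the `Sequential` → `HybridSequential` change above: only a `HybridBlock`-based network can be hybridized, which is what lets the later `sym.hybridize()` call trade imperative execution for a cached static graph. A minimal sketch of the pattern, with illustrative layer sizes rather than the notebook's exact architecture:

```python
import mxnet as mx
from mxnet import gluon, nd

# HybridSequential supports hybridize(): the network is traced once and then
# executed as a cached static graph instead of layer-by-layer Python calls.
net = gluon.nn.HybridSequential()
with net.name_scope():
    net.add(gluon.nn.Conv2D(channels=50, kernel_size=3, padding=1, activation='relu'))
    net.add(gluon.nn.Conv2D(channels=50, kernel_size=3, padding=1))
    net.add(gluon.nn.Flatten())
    net.add(gluon.nn.Dense(10))

net.initialize(mx.init.Xavier(), ctx=mx.cpu())
net.hybridize()  # later forward passes run through the compiled graph

out = net(nd.random.uniform(shape=(1, 3, 32, 32), ctx=mx.cpu()))
print(out.shape)  # (1, 10)
```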
121 | 121 | "Preparing test set...\n",
122 | 122 | "(50000, 3, 32, 32) (10000, 3, 32, 32) (50000,) (10000,)\n",
123 | 123 | "float32 float32 int32 int32\n",
124 | | - "CPU times: user 630 ms, sys: 588 ms, total: 1.22 s\n",
125 | | - "Wall time: 1.22 s\n"
| 124 | + "CPU times: user 708 ms, sys: 589 ms, total: 1.3 s\n",
| 125 | + "Wall time: 1.29 s\n"
126 | 126 | ]
127 | 127 | }
128 | 128 | ],
143 | 143 | "name": "stdout",
144 | 144 | "output_type": "stream",
145 | 145 | "text": [
146 | | - "CPU times: user 321 ms, sys: 392 ms, total: 713 ms\n",
147 | | - "Wall time: 876 ms\n"
| 146 | + "CPU times: user 345 ms, sys: 421 ms, total: 766 ms\n",
| 147 | + "Wall time: 768 ms\n"
148 | 148 | ]
149 | 149 | }
150 | 150 | ],
164 | 164 | "name": "stdout",
165 | 165 | "output_type": "stream",
166 | 166 | "text": [
167 | | - "CPU times: user 203 µs, sys: 128 µs, total: 331 µs\n",
168 | | - "Wall time: 337 µs\n"
| 167 | + "CPU times: user 683 µs, sys: 444 µs, total: 1.13 ms\n",
| 168 | + "Wall time: 406 µs\n"
169 | 169 | ]
170 | 170 | }
171 | 171 | ],
178 | 178 | "cell_type": "code",
179 | 179 | "execution_count": 9,
180 | 180 | "metadata": {},
| 181 | + "outputs": [],
| 182 | + "source": [
| 183 | + "train_loss = nd.zeros(1, ctx=ctx)"
| 184 | + ]
| 185 | + },
| 186 | + {
| 187 | + "cell_type": "code",
| 188 | + "execution_count": 10,
| 189 | + "metadata": {},
| 190 | + "outputs": [],
| 191 | + "source": [
| 192 | + "train_loss += nd.ones(1, ctx=ctx)"
| 193 | + ]
| 194 | + },
| 195 | + {
| 196 | + "cell_type": "code",
| 197 | + "execution_count": 11,
| 198 | + "metadata": {},
181 | 199 | "outputs": [
182 | 200 | {
183 | 201 | "name": "stdout",
184 | 202 | "output_type": "stream",
185 | 203 | "text": [
186 | | - "Epoch 0: loss: 1.8405\n",
187 | | - "Epoch 1: loss: 1.3773\n",
188 | | - "Epoch 2: loss: 1.1577\n",
189 | | - "Epoch 3: loss: 0.9811\n",
190 | | - "Epoch 4: loss: 0.8450\n",
191 | | - "Epoch 5: loss: 0.7354\n",
192 | | - "Epoch 6: loss: 0.6391\n",
193 | | - "Epoch 7: loss: 0.5559\n",
194 | | - "Epoch 8: loss: 0.4810\n",
195 | | - "Epoch 9: loss: 0.4157\n",
196 | | - "CPU times: user 1min 18s, sys: 15.3 s, total: 1min 34s\n",
197 | | - "Wall time: 1min 2s\n"
| 204 | + "Epoch 0: loss: 1.8314\n",
| 205 | + "Epoch 1: loss: 1.3397\n",
| 206 | + "Epoch 2: loss: 1.1221\n",
| 207 | + "Epoch 3: loss: 0.9576\n",
| 208 | + "Epoch 4: loss: 0.8261\n",
| 209 | + "Epoch 5: loss: 0.7215\n",
| 210 | + "Epoch 6: loss: 0.6226\n",
| 211 | + "Epoch 7: loss: 0.5389\n",
| 212 | + "Epoch 8: loss: 0.4729\n",
| 213 | + "Epoch 9: loss: 0.4072\n",
| 214 | + "CPU times: user 1min 5s, sys: 18 s, total: 1min 23s\n",
| 215 | + "Wall time: 56.6 s\n"
198 | 216 | ]
199 | 217 | }
200 | 218 | ],
201 | 219 | "source": [
202 | 220 | "%%time\n",
203 | | - "# Main training loop: 62s\n",
| 221 | + "sym.hybridize()\n",
204 | 222 | "for j in range(EPOCHS):\n",
205 | | - "    train_loss = 0.0\n",
| 223 | + "    train_loss = nd.zeros(1, ctx=ctx)\n",
206 | 224 | "    for data, target in yield_mb(x_train, y_train, BATCHSIZE, shuffle=True):\n",
207 | 225 | "        # Get samples\n",
208 | 226 | "        data = nd.array(data).as_in_context(ctx)\n",
215 | 233 | "        # Back-prop\n",
216 | 234 | "        loss.backward()\n",
217 | 235 | "        trainer.step(data.shape[0])\n",
218 | | - "        train_loss += nd.sum(loss).asscalar()\n",
219 | | - "    # Log\n",
220 | | - "    print('Epoch %3d: loss: %5.4f'%(j, train_loss/len(x_train)))"
| 236 | + "        train_loss += nd.sum(loss)\n",
| 237 | + "    # Log\n",
| 238 | + "    # Wait for all pending operations on the context to finish\n",
| 239 | + "    nd.waitall()\n",
| 240 | + "    print('Epoch %3d: loss: %5.4f'%(j, train_loss.asscalar()/len(x_train)))"
221 | 241 | ]
222 | 242 | },
223 | 243 | {
224 | 244 | "cell_type": "code",
225 | | - "execution_count": 10,
| 245 | + "execution_count": 12,
226 | 246 | "metadata": {},
227 | 247 | "outputs": [
228 | 248 | {
229 | 249 | "name": "stdout",
230 | 250 | "output_type": "stream",
231 | 251 | "text": [
232 | | - "CPU times: user 627 ms, sys: 73.1 ms, total: 700 ms\n",
233 | | - "Wall time: 453 ms\n"
| 252 | + "CPU times: user 382 ms, sys: 115 ms, total: 496 ms\n",
| 253 | + "Wall time: 429 ms\n"
234 | 254 | ]
235 | 255 | }
236 | 256 | ],
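The training-loop edits above follow one idea: MXNet executes NDArray operations asynchronously, so calling `.asscalar()` on every batch forces Python to block until the backend catches up, while accumulating the loss as an NDArray on the device defers that synchronization to once per epoch (`nd.waitall()` plus a single `.asscalar()`). A sketch of the pattern under assumed helpers (`net`, `trainer`, `loss_fn`, and the batch iterator stand in for the notebook's objects):

```python
from mxnet import nd, autograd

def train_one_epoch(net, trainer, loss_fn, batches, ctx, n_samples):
    # Keep the running loss on the device: += only enqueues an async op,
    # so the Python loop never stalls waiting for results.
    epoch_loss = nd.zeros(1, ctx=ctx)
    for data, target in batches:
        data = nd.array(data).as_in_context(ctx)
        target = nd.array(target).as_in_context(ctx)
        with autograd.record():
            output = net(data)
            loss = loss_fn(output, target)
        loss.backward()
        trainer.step(data.shape[0])
        epoch_loss += nd.sum(loss)  # stays asynchronous
        # By contrast, `epoch_loss += nd.sum(loss).asscalar()` would block
        # here on every single batch.
    nd.waitall()                    # synchronize once, after the whole epoch
    return epoch_loss.asscalar() / n_samples
```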
254 | 274 | },
255 | 275 | {
256 | 276 | "cell_type": "code",
257 | | - "execution_count": 11,
| 277 | + "execution_count": 13,
258 | 278 | "metadata": {},
259 | 279 | "outputs": [
260 | 280 | {
261 | 281 | "name": "stdout",
262 | 282 | "output_type": "stream",
263 | 283 | "text": [
264 | | - "Accuracy: 0.7661258012820513\n"
| 284 | + "Accuracy: 0.7675280448717948\n"
265 | 285 | ]
266 | 286 | }
267 | 287 | ],
273 | 293 | "metadata": {
274 | 294 | "anaconda-cloud": {},
275 | 295 | "kernelspec": {
276 | | - "display_name": "Python 3",
| 296 | + "display_name": "Python [default]",
277 | 297 | "language": "python",
278 | 298 | "name": "python3"
279 | 299 | },