1
- {
2
- "cells" : [
3
- {
4
- "cell_type" : " code" ,
5
- "execution_count" : null ,
6
- "metadata" : {},
7
- "outputs" : [],
8
- "source" : [
9
- " import torch\n " ,
10
- " import torch.nn as nn\n " ,
11
- " import torch.nn.functional as F\n " ,
12
- " from utils.draw import draw_digits\n " ,
13
- " from torch.utils.data import DataLoader\n " ,
14
- " from torch.optim.lr_scheduler import StepLR\n " ,
15
- " from torchvision import datasets, transforms"
16
- ]
17
- },
18
- {
19
- "cell_type" : " code" ,
20
- "execution_count" : null ,
21
- "metadata" : {},
22
- "outputs" : [],
23
- "source" : [
24
- " class CNN(nn.Module):\n " ,
25
- " def __init__(self):\n " ,
26
- " super(CNN, self).__init__()\n " ,
27
- " self.conv1 = nn.Conv2d(1, 10, kernel_size=5)\n " ,
28
- " self.conv2 = nn.Conv2d(10, 20, kernel_size=5)\n " ,
29
- " self.conv2_drop = nn.Dropout2d()\n " ,
30
- " self.fc1 = nn.Linear(320, 50)\n " ,
31
- " self.fc2 = nn.Linear(50, 10)\n " ,
32
- " \n " ,
33
- " def forward(self, x):\n " ,
34
- " x = x.view(-1, 1, 28, 28)\n " ,
35
- " x = F.relu(F.max_pool2d(self.conv1(x), 2))\n " ,
36
- " x = F.relu(F.max_pool2d(self.conv2_drop(self.conv2(x)), 2))\n " ,
37
- " x = x.view(-1, 320)\n " ,
38
- " x = F.relu(self.fc1(x))\n " ,
39
- " x = F.dropout(x, training=self.training)\n " ,
40
- " x = self.fc2(x)\n " ,
41
- " return F.softmax(x, dim=1)"
42
- ]
43
- },
44
- {
45
- "cell_type" : " code" ,
46
- "execution_count" : null ,
47
- "metadata" : {},
48
- "outputs" : [],
49
- "source" : [
50
- " digits = datasets.MNIST('data', download=True,\n " ,
51
- " transform=transforms.Compose([\n " ,
52
- " transforms.ToTensor(),\n " ,
53
- " transforms.Lambda(lambda x: x.reshape(28*28))\n " ,
54
- " ]),\n " ,
55
- " target_transform=transforms.Compose([\n " ,
56
- " transforms.Lambda(lambda y: \n " ,
57
- " torch.zeros(10, dtype=torch.float).scatter_(0, torch.tensor(y), value=1))\n " ,
58
- " ])\n " ,
59
- " )"
60
- ]
61
- },
62
- {
63
- "cell_type" : " code" ,
64
- "execution_count" : null ,
65
- "metadata" : {},
66
- "outputs" : [],
67
- "source" : [
68
- " draw_digits(digits)"
69
- ]
70
- },
71
- {
72
- "cell_type" : " code" ,
73
- "execution_count" : null ,
74
- "metadata" : {},
75
- "outputs" : [],
76
- "source" : [
77
- " torch.cuda.is_available()"
78
- ]
79
- },
80
- {
81
- "cell_type" : " code" ,
82
- "execution_count" : null ,
83
- "metadata" : {},
84
- "outputs" : [],
85
- "source" : [
86
- " device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')\n " ,
87
- " \n " ,
88
- " # Use the nn package to define our model and loss function.\n " ,
89
- " model = CNN()\n " ,
90
- " model = model.to(device)\n " ,
91
- " \n " ,
92
- " cost = torch.nn.BCELoss()\n " ,
93
- " \n " ,
94
- " # optimizer which Tensors it should update.\n " ,
95
- " learning_rate = 1e-3\n " ,
96
- " optimizer = torch.optim.Adam(model.parameters(), lr=learning_rate)\n " ,
97
- " scheduler = StepLR(optimizer, 3)\n " ,
98
- " \n " ,
99
- " # dataset!\n " ,
100
- " dataloader = DataLoader(digits, batch_size=64, num_workers=0, pin_memory=True)\n " ,
101
- " \n " ,
102
- " epochs = 10"
103
- ]
104
- },
105
- {
106
- "cell_type" : " code" ,
107
- "execution_count" : null ,
108
- "metadata" : {
109
- "scrolled" : false
110
- },
111
- "outputs" : [],
112
- "source" : [
113
- " for t in range(epochs):\n " ,
114
- " print('\\ nepoch {}, (lr: {:>.1e})'.format(t, scheduler.get_lr()[0]))\n " ,
115
- " print('-------------------------------')\n " ,
116
- " for batch, (X, Y) in enumerate(dataloader):\n " ,
117
- " X, Y = X.to(device), Y.to(device)\n " ,
118
- " optimizer.zero_grad()\n " ,
119
- " pred = model(X)\n " ,
120
- " loss = cost(pred, Y)\n " ,
121
- " loss.backward()\n " ,
122
- " optimizer.step()\n " ,
123
- " \n " ,
124
- " if batch % 100 == 0:\n " ,
125
- " print('loss: {:>10f} [{:>5d}/{:>5d}]'.format(loss.item(), batch * len(X), len(dataloader.dataset)))\n " ,
126
- " \n " ,
127
- " # step scheduler\n " ,
128
- " scheduler.step()"
129
- ]
130
- },
131
- {
132
- "cell_type" : " markdown" ,
133
- "metadata" : {},
134
- "source" : [
135
- " # Does it work???"
136
- ]
137
- },
138
- {
139
- "cell_type" : " code" ,
140
- "execution_count" : null ,
141
- "metadata" : {},
142
- "outputs" : [],
143
- "source" : [
144
- " test_data = datasets.MNIST('data', train=False, download=True,\n " ,
145
- " transform=transforms.Compose([\n " ,
146
- " transforms.ToTensor(),\n " ,
147
- " transforms.Lambda(lambda x: x.reshape(28*28))\n " ,
148
- " ]),\n " ,
149
- " target_transform=transforms.Compose([\n " ,
150
- " transforms.Lambda(lambda y: \n " ,
151
- " torch.zeros(10, dtype=torch.float).scatter_(0, torch.tensor(y), value=1))\n " ,
152
- " ])\n " ,
153
- " )\n " ,
154
- " test_loader = DataLoader(digits, batch_size=64, num_workers=0, pin_memory=True)"
155
- ]
156
- },
157
- {
158
- "cell_type" : " code" ,
159
- "execution_count" : null ,
160
- "metadata" : {},
161
- "outputs" : [],
162
- "source" : [
163
- " model.eval()\n " ,
164
- " test_loss = 0\n " ,
165
- " correct = 0\n " ,
166
- " with torch.no_grad():\n " ,
167
- " for batch, (X, Y) in enumerate(test_loader):\n " ,
168
- " X, Y = X.to(device), Y.to(device)\n " ,
169
- " pred = model(X)\n " ,
170
- " \n " ,
171
- " test_loss += cost(pred, Y).item()\n " ,
172
- " correct += (pred.argmax(1) == Y.argmax(1)).type(torch.float).sum().item()\n " ,
173
- " \n " ,
174
- " test_loss /= len(dataloader.dataset)\n " ,
175
- " correct /= len(dataloader.dataset)\n " ,
176
- " print('Test Error:')\n " ,
177
- " print('acc: {:>0.1f}%, avg loss: {:>8f}'.format(100*correct, test_loss))"
178
- ]
179
- },
180
- {
181
- "cell_type" : " code" ,
182
- "execution_count" : null ,
183
- "metadata" : {},
184
- "outputs" : [],
185
- "source" : [
186
- " len(test_data)"
187
- ]
188
- },
189
- {
190
- "cell_type" : " markdown" ,
191
- "metadata" : {},
192
- "source" : [
193
- " # Saving Things!"
194
- ]
195
- },
196
- {
197
- "cell_type" : " code" ,
198
- "execution_count" : null ,
199
- "metadata" : {},
200
- "outputs" : [],
201
- "source" : [
202
- " import torch.onnx as onnx\n " ,
203
- " \n " ,
204
- " # create dummy variable to traverse graph\n " ,
205
- " x = torch.randint(255, (1, 28*28), dtype=torch.float).to(device) / 255\n " ,
206
- " onnx.export(model, x, 'superfile.onnx')\n " ,
207
- " print('Saved onnx model to \" superfile.onnx\" ')"
208
- ]
209
- },
210
- {
211
- "cell_type" : " code" ,
212
- "execution_count" : null ,
213
- "metadata" : {},
214
- "outputs" : [],
215
- "source" : []
216
- }
217
- ],
218
- "metadata" : {
219
- "kernelspec" : {
220
- "display_name" : " Python 3" ,
221
- "language" : " python" ,
222
- "name" : " python3"
223
- },
224
- "language_info" : {
225
- "codemirror_mode" : {
226
- "name" : " ipython" ,
227
- "version" : 3
228
- },
229
- "file_extension" : " .py" ,
230
- "mimetype" : " text/x-python" ,
231
- "name" : " python" ,
232
- "nbconvert_exporter" : " python" ,
233
- "pygments_lexer" : " ipython3" ,
234
- "version" : " 3.7.3"
235
- },
236
- "varInspector" : {
237
- "cols" : {
238
- "lenName" : 16 ,
239
- "lenType" : 16 ,
240
- "lenVar" : 40
241
- },
242
- "kernels_config" : {
243
- "python" : {
244
- "delete_cmd_postfix" : " " ,
245
- "delete_cmd_prefix" : " del " ,
246
- "library" : " var_list.py" ,
247
- "varRefreshCmd" : " print(var_dic_list())"
248
- },
249
- "r" : {
250
- "delete_cmd_postfix" : " ) " ,
251
- "delete_cmd_prefix" : " rm(" ,
252
- "library" : " var_list.r" ,
253
- "varRefreshCmd" : " cat(var_dic_list()) "
254
- }
255
- },
256
- "types_to_exclude" : [
257
- " module" ,
258
- " function" ,
259
- " builtin_function_or_method" ,
260
- " instance" ,
261
- " _Feature"
262
- ],
263
- "window_display" : false
264
- }
265
- },
266
- "nbformat" : 4 ,
267
- "nbformat_minor" : 2
268
- }
1
+ {"cells":[{"cell_type":"code","execution_count":null,"metadata":{},"outputs":[],"source":"import torch\nimport torch.nn as nn\nimport torch.nn.functional as F\nfrom utils.draw import draw_digits\nfrom torch.utils.data import DataLoader\nfrom torch.optim.lr_scheduler import StepLR\nfrom torchvision import datasets, transforms"},{"cell_type":"code","execution_count":null,"metadata":{},"outputs":[],"source":"class CNN(nn.Module):\n def __init__(self):\n super(CNN, self).__init__()\n self.conv1 = nn.Conv2d(1, 10, kernel_size=5)\n self.conv2 = nn.Conv2d(10, 20, kernel_size=5)\n self.conv2_drop = nn.Dropout2d()\n self.fc1 = nn.Linear(320, 50)\n self.fc2 = nn.Linear(50, 10)\n\n def forward(self, x):\n x = x.view(-1, 1, 28, 28)\n x = F.relu(F.max_pool2d(self.conv1(x), 2))\n x = F.relu(F.max_pool2d(self.conv2_drop(self.conv2(x)), 2))\n x = x.view(-1, 320)\n x = F.relu(self.fc1(x))\n x = F.dropout(x, training=self.training)\n x = self.fc2(x)\n return F.softmax(x, dim=1)"},{"cell_type":"code","execution_count":null,"metadata":{},"outputs":[],"source":"digits = datasets.MNIST('data', download=True,\n transform=transforms.Compose([\n transforms.ToTensor(),\n transforms.Lambda(lambda x: x.reshape(28*28))\n ]),\n target_transform=transforms.Compose([\n transforms.Lambda(lambda y: \n torch.zeros(10, dtype=torch.float).scatter_(0, torch.tensor(y), value=1))\n ])\n )"},{"cell_type":"code","execution_count":null,"metadata":{},"outputs":[],"source":"draw_digits(digits)"},{"cell_type":"code","execution_count":null,"metadata":{},"outputs":[],"source":"torch.cuda.is_available()"},{"cell_type":"code","execution_count":null,"metadata":{},"outputs":[],"source":"device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')\n\n# Use the nn package to define our model and loss function.\nmodel = CNN()\nmodel = model.to(device)\n\ncost = torch.nn.BCELoss()\n\n# optimizer which Tensors it should update.\nlearning_rate = 1e-3\noptimizer = torch.optim.Adam(model.parameters(), lr=learning_rate)\nscheduler = StepLR(optimizer, 3)\n\n# dataset!\ndataloader = DataLoader(digits, batch_size=64, num_workers=0, pin_memory=True)\n\nepochs = 10"},{"cell_type":"code","execution_count":null,"metadata":{"scrolled":false},"outputs":[],"source":"for t in range(epochs):\n print('\\nepoch {}, (lr: {:>.1e})'.format(t, scheduler.get_lr()[0]))\n print('-------------------------------')\n for batch, (X, Y) in enumerate(dataloader):\n X, Y = X.to(device), Y.to(device)\n optimizer.zero_grad()\n pred = model(X)\n loss = cost(pred, Y)\n loss.backward()\n optimizer.step()\n \n if batch % 100 == 0:\n print('loss: {:>10f} [{:>5d}/{:>5d}]'.format(loss.item(), batch * len(X), len(dataloader.dataset)))\n \n # step scheduler\n scheduler.step()"},{"cell_type":"markdown","metadata":{},"source":["# Does it work???"]},{"cell_type":"code","execution_count":null,"metadata":{},"outputs":[],"source":"test_data = datasets.MNIST('data', train=False, download=True,\n transform=transforms.Compose([\n transforms.ToTensor(),\n transforms.Lambda(lambda x: x.reshape(28*28))\n ]),\n target_transform=transforms.Compose([\n transforms.Lambda(lambda y: \n torch.zeros(10, dtype=torch.float).scatter_(0, torch.tensor(y), value=1))\n ])\n )\ntest_loader = DataLoader(digits, batch_size=64, num_workers=0, pin_memory=True)"},{"cell_type":"code","execution_count":null,"metadata":{},"outputs":[],"source":"model.eval()\ntest_loss = 0\ncorrect = 0\nwith torch.no_grad():\n for batch, (X, Y) in enumerate(test_loader):\n X, Y = X.to(device), Y.to(device)\n pred = model(X)\n\n test_loss += cost(pred, Y).item()\n correct += (pred.argmax(1) == Y.argmax(1)).type(torch.float).sum().item()\n\ntest_loss /= len(dataloader.dataset)\ncorrect /= len(dataloader.dataset)\nprint('Test Error:')\nprint('acc: {:>0.1f}%, avg loss: {:>8f}'.format(100*correct, test_loss))"},{"cell_type":"code","execution_count":null,"metadata":{},"outputs":[],"source":"len(test_data)"},{"cell_type":"markdown","metadata":{},"source":["# Saving Things!"]},{"cell_type":"code","execution_count":null,"metadata":{},"outputs":[],"source":"import torch.onnx as onnx\n\n# create dummy variable to traverse graph\nx = torch.randint(255, (1, 28*28), dtype=torch.float).to(device) / 255\nonnx.export(model, x, 'superfile.onnx')\nprint('Saved onnx model to \"superfile.onnx\"')"}],"nbformat":4,"nbformat_minor":2,"metadata":{"language_info":{"name":"python","codemirror_mode":{"name":"ipython","version":3}},"orig_nbformat":2,"file_extension":".py","mimetype":"text/x-python","name":"python","npconvert_exporter":"python","pygments_lexer":"ipython3","version":3}}
0 commit comments