1
- {"cells":[{"cell_type":"code","execution_count":3,"metadata":{},"outputs":[],"source":"import torch\nfrom torch.utils.data import DataLoader\nfrom torch.utils.data.dataset import Dataset\nfrom torchvision import datasets, transforms"},{"cell_type":"markdown","metadata":{},"outputs":[],"source":"# Dataset"},{"cell_type":"code","execution_count":6,"metadata":{},"outputs":[],"source":"class SquareDataset(Dataset):\n def __init__(self, size):\n self.size = size\n self.X = torch.randint(255, (size, 9), dtype=torch.float)\n\n real_w = torch.tensor([[1,1,1,0,0,0,0,0,0],\n [0,0,0,1,1,1,0,0,0],\n [0,0,0,0,0,0,1,1,1]], \n dtype=torch.float)\n\n y = torch.argmax(self.X.mm(real_w.t()), 1)\n \n self.Y = torch.zeros(size, 3, dtype=torch.float) \\\n .scatter_(1, y.view(-1, 1), 1)\n\n def __getitem__(self, index):\n return (self.X[index], self.Y[index])\n\n def __len__(self):\n return self.size"},{"cell_type":"code","execution_count":7,"metadata":{},"outputs":[],"source":"squares = SquareDataset(256)\nprint(squares[34])\nprint(squares[254])\nprint(squares[25])"},{"cell_type":"code","execution_count":8,"metadata":{},"outputs":[],"source":"dataloader = DataLoader(squares, batch_size=5)\n\nfor batch, (X, Y) in enumerate(dataloader):\n print(X, '\\n\\n', Y)\n break"},{"cell_type":"markdown","metadata":{},"outputs":[],"source":"# Digits\nTransforms!"},{"cell_type":"code","execution_count":4,"metadata":{},"outputs":[],"source":"digits = datasets.MNIST('data', train=True, download=True,\n transform=transforms.Compose([\n transforms.ToTensor(),\n transforms.Lambda(lambda x: x.view(28*28))\n ]),\n target_transform=transforms.Compose([\n transforms.Lambda(lambda y: \n torch.zeros(10, dtype=torch.float).scatter_(0, torch.tensor(y), value=1))\n ])\n )"},{"cell_type":"code","execution_count":5,"metadata":{},"outputs":[],"source":"dataloader = DataLoader(digits, batch_size=10, shuffle=True)\n\nfor batch, (X, Y) in enumerate(dataloader):\n print(X, '\\n\\n', Y)\n 
break"}],"nbformat":4,"nbformat_minor":2,"metadata":{"language_info":{"name":"python","codemirror_mode":{"name":"ipython","version":3}},"orig_nbformat":2,"file_extension":".py","mimetype":"text/x-python","name":"python","npconvert_exporter":"python","pygments_lexer":"ipython3","version":3}}
1
+ {
2
+ "cells" : [
3
+ {
4
+ "cell_type" : " code" ,
5
+ "execution_count" : 1 ,
6
+ "metadata" : {},
7
+ "outputs" : [],
8
+ "source" : [
9
+ " import torch\n " ,
10
+ " from torch.utils.data import DataLoader\n " ,
11
+ " from torch.utils.data.dataset import Dataset\n " ,
12
+ " from torchvision import datasets, transforms"
13
+ ]
14
+ },
15
+ {
16
+ "cell_type" : " markdown" ,
17
+ "metadata" : {},
18
+ "source" : [
19
+ " # Dataset"
20
+ ]
21
+ },
22
+ {
23
+ "cell_type" : " code" ,
24
+ "execution_count" : 2 ,
25
+ "metadata" : {},
26
+ "outputs" : [],
27
+ "source" : [
28
class SquareDataset(Dataset):
    """Synthetic dataset of random 9-value vectors labelled by their heaviest third.

    Each sample is a flat float vector of 9 values drawn uniformly from the
    integers 0..254 (the upper bound passed to ``torch.randint`` is exclusive).
    The vector is read as three consecutive groups of three values
    (indices 0-2, 3-5 and 6-8); the label is the one-hot encoding, of length 3,
    of the group whose values have the largest sum.

    Args:
        size: Number of samples to generate.
        seed: Optional integer seed for a private random generator. When given,
            the same ``size`` and ``seed`` always yield the same data, so the
            samples printed by the notebook are reproducible after a kernel
            restart. When ``None`` (the default) the global torch RNG is used,
            exactly as before.
    """

    def __init__(self, size, seed=None):
        self.size = size

        # A dedicated generator keeps seeding local: it does not disturb the
        # global RNG state that other cells may rely on.
        generator = None
        if seed is not None:
            generator = torch.Generator().manual_seed(seed)

        self.X = torch.randint(255, (size, 9), generator=generator, dtype=torch.float)

        # Row k of real_w selects the k-th group of three features, so
        # X @ real_w.T holds the three group sums for every sample.
        real_w = torch.tensor([[1, 1, 1, 0, 0, 0, 0, 0, 0],
                               [0, 0, 0, 1, 1, 1, 0, 0, 0],
                               [0, 0, 0, 0, 0, 0, 1, 1, 1]],
                              dtype=torch.float)

        y = torch.argmax(self.X.mm(real_w.t()), 1)

        # One-hot encode the winning group index into a (size, 3) float tensor.
        self.Y = torch.zeros(size, 3, dtype=torch.float) \
                      .scatter_(1, y.view(-1, 1), 1)

    def __getitem__(self, index):
        """Return the ``(features, one_hot_label)`` pair stored at ``index``."""
        return (self.X[index], self.Y[index])

    def __len__(self):
        """Return the number of samples requested at construction time."""
        return self.size
48
+ ]
49
+ },
50
+ {
51
+ "cell_type" : " code" ,
52
+ "execution_count" : 3 ,
53
+ "metadata" : {},
54
+ "outputs" : [
55
+ {
56
+ "name" : " stdout" ,
57
+ "output_type" : " stream" ,
58
+ "text" : [
59
+ " (tensor([ 54., 182., 47., 142., 200., 197., 220., 215., 33.]), tensor([0., 1., 0.]))\n " ,
60
+ " (tensor([198., 171., 26., 140., 28., 9., 205., 48., 113.]), tensor([1., 0., 0.]))\n " ,
61
+ " (tensor([ 64., 7., 167., 4., 9., 160., 169., 113., 214.]), tensor([0., 0., 1.]))\n "
62
+ ]
63
+ }
64
+ ],
65
+ "source" : [
66
squares = SquareDataset(256)

# Spot-check a few arbitrary samples; each one prints as a
# (features, one-hot label) tuple.
for sample_index in (34, 254, 25):
    print(squares[sample_index])
70
+ ]
71
+ },
72
+ {
73
+ "cell_type" : " code" ,
74
+ "execution_count" : 4 ,
75
+ "metadata" : {},
76
+ "outputs" : [
77
+ {
78
+ "name" : " stdout" ,
79
+ "output_type" : " stream" ,
80
+ "text" : [
81
+ " tensor([[152., 127., 155., 219., 81., 140., 112., 77., 102.],\n " ,
82
+ " [ 77., 58., 228., 164., 229., 155., 111., 223., 141.],\n " ,
83
+ " [106., 250., 87., 62., 105., 254., 0., 210., 136.],\n " ,
84
+ " [190., 108., 134., 204., 145., 251., 146., 171., 99.],\n " ,
85
+ " [ 88., 36., 190., 108., 122., 4., 231., 22., 70.]]) \n " ,
86
+ " \n " ,
87
+ " tensor([[0., 1., 0.],\n " ,
88
+ " [0., 1., 0.],\n " ,
89
+ " [1., 0., 0.],\n " ,
90
+ " [0., 1., 0.],\n " ,
91
+ " [0., 0., 1.]])\n "
92
+ ]
93
+ }
94
+ ],
95
+ "source" : [
96
dataloader = DataLoader(squares, batch_size=5)

# Show only the first mini-batch: the features, a blank gap, then the labels.
for batch, features_and_labels in enumerate(dataloader):
    X, Y = features_and_labels
    print(X, '\n\n', Y)
    break
101
+ ]
102
+ },
103
+ {
104
+ "cell_type" : " markdown" ,
105
+ "metadata" : {},
106
+ "source" : [
107
+ " # Digits\n " ,
108
+ " Transforms!"
109
+ ]
110
+ },
111
+ {
112
+ "cell_type" : " code" ,
113
+ "execution_count" : 5 ,
114
+ "metadata" : {},
115
+ "outputs" : [],
116
+ "source" : [
117
# Image pipeline: convert to a tensor, then flatten to a vector of 28*28 elements.
image_transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.Lambda(lambda x: x.view(28 * 28)),
])

# Label pipeline: turn the integer class into a length-10 one-hot float vector.
label_transform = transforms.Compose([
    transforms.Lambda(
        lambda y: torch.zeros(10, dtype=torch.float).scatter_(0, torch.tensor(y), value=1)
    ),
])

digits = datasets.MNIST(
    'data',
    train=True,
    download=True,
    transform=image_transform,
    target_transform=label_transform,
)
127
+ ]
128
+ },
129
+ {
130
+ "cell_type" : " code" ,
131
+ "execution_count" : 6 ,
132
+ "metadata" : {},
133
+ "outputs" : [
134
+ {
135
+ "name" : " stdout" ,
136
+ "output_type" : " stream" ,
137
+ "text" : [
138
+ " tensor([[0., 0., 0., ..., 0., 0., 0.],\n " ,
139
+ " [0., 0., 0., ..., 0., 0., 0.],\n " ,
140
+ " [0., 0., 0., ..., 0., 0., 0.],\n " ,
141
+ " ...,\n " ,
142
+ " [0., 0., 0., ..., 0., 0., 0.],\n " ,
143
+ " [0., 0., 0., ..., 0., 0., 0.],\n " ,
144
+ " [0., 0., 0., ..., 0., 0., 0.]]) \n " ,
145
+ " \n " ,
146
+ " tensor([[1., 0., 0., 0., 0., 0., 0., 0., 0., 0.],\n " ,
147
+ " [1., 0., 0., 0., 0., 0., 0., 0., 0., 0.],\n " ,
148
+ " [0., 0., 0., 1., 0., 0., 0., 0., 0., 0.],\n " ,
149
+ " [0., 0., 0., 0., 0., 0., 0., 1., 0., 0.],\n " ,
150
+ " [0., 0., 0., 0., 0., 0., 0., 0., 1., 0.],\n " ,
151
+ " [0., 0., 0., 0., 0., 1., 0., 0., 0., 0.],\n " ,
152
+ " [0., 0., 0., 1., 0., 0., 0., 0., 0., 0.],\n " ,
153
+ " [0., 0., 0., 0., 0., 0., 1., 0., 0., 0.],\n " ,
154
+ " [1., 0., 0., 0., 0., 0., 0., 0., 0., 0.],\n " ,
155
+ " [0., 0., 0., 0., 0., 0., 0., 1., 0., 0.]])\n "
156
+ ]
157
+ }
158
+ ],
159
+ "source" : [
160
dataloader = DataLoader(digits, batch_size=10, shuffle=True)

# Show only the first shuffled mini-batch: the flattened images, a blank gap,
# then the one-hot labels.
for batch, images_and_labels in enumerate(dataloader):
    X, Y = images_and_labels
    print(X, '\n\n', Y)
    break
165
+ ]
166
+ },
167
+ {
168
+ "cell_type" : " code" ,
169
+ "execution_count" : null ,
170
+ "metadata" : {},
171
+ "outputs" : [],
172
+ "source" : []
173
+ }
174
+ ],
175
+ "metadata" : {
176
+ "file_extension" : " .py" ,
177
+ "kernelspec" : {
178
+ "display_name" : " Python 3" ,
179
+ "language" : " python" ,
180
+ "name" : " python3"
181
+ },
182
+ "language_info" : {
183
+ "codemirror_mode" : {
184
+ "name" : " ipython" ,
185
+ "version" : 3
186
+ },
187
+ "file_extension" : " .py" ,
188
+ "mimetype" : " text/x-python" ,
189
+ "name" : " python" ,
190
+ "nbconvert_exporter" : " python" ,
191
+ "pygments_lexer" : " ipython3" ,
192
+ "version" : " 3.6.10"
193
+ },
194
+ "mimetype" : " text/x-python" ,
195
+ "name" : " python" ,
196
+ "nbconvert_exporter" : " python" ,
197
+ "pygments_lexer" : " ipython3" ,
198
+ "version" : 3
199
+ },
200
+ "nbformat" : 4 ,
201
+ "nbformat_minor" : 2
202
+ }
0 commit comments