1
- {
2
- "cells" : [
3
- {
4
- "cell_type" : " code" ,
5
- "execution_count" : null ,
6
- "metadata" : {},
7
- "outputs" : [],
8
- "source" : [
9
- " import torch\n " ,
10
- " from torch.utils.data import DataLoader\n " ,
11
- " from torch.utils.data.dataset import Dataset\n " ,
12
- " from torchvision import datasets, transforms"
13
- ]
14
- },
15
- {
16
- "cell_type" : " markdown" ,
17
- "metadata" : {},
18
- "source" : [
19
- " # Squares"
20
- ]
21
- },
22
- {
23
- "cell_type" : " code" ,
24
- "execution_count" : null ,
25
- "metadata" : {},
26
- "outputs" : [],
27
- "source" : [
28
- " class SquareDataset(Dataset):\n " ,
29
- " def __init__(self, size):\n " ,
30
- " self.size = size\n " ,
31
- " self.X = torch.randint(255, (size, 9), dtype=torch.float)\n " ,
32
- " \n " ,
33
- " real_w = torch.tensor([[1,1,1,0,0,0,0,0,0],\n " ,
34
- " [0,0,0,1,1,1,0,0,0],\n " ,
35
- " [0,0,0,0,0,0,1,1,1]], \n " ,
36
- " dtype=torch.float)\n " ,
37
- " \n " ,
38
- " y = torch.argmax(self.X.mm(real_w.t()), 1)\n " ,
39
- " \n " ,
40
- " self.Y = torch.zeros(size, 3, dtype=torch.float) \\\n " ,
41
- " .scatter_(1, y.view(-1, 1), 1)\n " ,
42
- " \n " ,
43
- " def __getitem__(self, index):\n " ,
44
- " return (self.X[index], self.Y[index])\n " ,
45
- " \n " ,
46
- " def __len__(self):\n " ,
47
- " return self.size"
48
- ]
49
- },
50
- {
51
- "cell_type" : " code" ,
52
- "execution_count" : null ,
53
- "metadata" : {},
54
- "outputs" : [],
55
- "source" : [
56
- " squares = SquareDataset(256)\n " ,
57
- " print(squares[34])\n " ,
58
- " print(squares[254])\n " ,
59
- " print(squares[25])"
60
- ]
61
- },
62
- {
63
- "cell_type" : " code" ,
64
- "execution_count" : null ,
65
- "metadata" : {
66
- "scrolled" : false
67
- },
68
- "outputs" : [],
69
- "source" : [
70
- " dataloader = DataLoader(squares, batch_size=5)\n " ,
71
- " \n " ,
72
- " for batch, (X, Y) in enumerate(dataloader):\n " ,
73
- " print(X, '\\ n\\ n', Y)\n " ,
74
- " break"
75
- ]
76
- },
77
- {
78
- "cell_type" : " markdown" ,
79
- "metadata" : {},
80
- "source" : [
81
- " # Digits"
82
- ]
83
- },
84
- {
85
- "cell_type" : " code" ,
86
- "execution_count" : null ,
87
- "metadata" : {},
88
- "outputs" : [],
89
- "source" : [
90
- " digits = datasets.MNIST('data', train=True, download=True,\n " ,
91
- " transform=transforms.Compose([\n " ,
92
- " transforms.ToTensor(),\n " ,
93
- " transforms.Lambda(lambda x: x.view(28*28))\n " ,
94
- " ]),\n " ,
95
- " target_transform=transforms.Compose([\n " ,
96
- " transforms.Lambda(lambda y: \n " ,
97
- " torch.zeros(10, dtype=torch.float).scatter_(0, torch.tensor(y), value=1))\n " ,
98
- " ])\n " ,
99
- " )"
100
- ]
101
- },
102
- {
103
- "cell_type" : " code" ,
104
- "execution_count" : null ,
105
- "metadata" : {},
106
- "outputs" : [],
107
- "source" : [
108
- " dataloader = DataLoader(digits, batch_size=10, shuffle=True)\n " ,
109
- " \n " ,
110
- " for batch, (X, Y) in enumerate(dataloader):\n " ,
111
- " print(X, '\\ n\\ n', Y)\n " ,
112
- " break"
113
- ]
114
- },
115
- {
116
- "cell_type" : " code" ,
117
- "execution_count" : null ,
118
- "metadata" : {},
119
- "outputs" : [],
120
- "source" : []
121
- }
122
- ],
123
- "metadata" : {
124
- "kernelspec" : {
125
- "display_name" : " Python 3" ,
126
- "language" : " python" ,
127
- "name" : " python3"
128
- },
129
- "language_info" : {
130
- "codemirror_mode" : {
131
- "name" : " ipython" ,
132
- "version" : 3
133
- },
134
- "file_extension" : " .py" ,
135
- "mimetype" : " text/x-python" ,
136
- "name" : " python" ,
137
- "nbconvert_exporter" : " python" ,
138
- "pygments_lexer" : " ipython3" ,
139
- "version" : " 3.7.3"
140
- },
141
- "varInspector" : {
142
- "cols" : {
143
- "lenName" : 16 ,
144
- "lenType" : 16 ,
145
- "lenVar" : 40
146
- },
147
- "kernels_config" : {
148
- "python" : {
149
- "delete_cmd_postfix" : " " ,
150
- "delete_cmd_prefix" : " del " ,
151
- "library" : " var_list.py" ,
152
- "varRefreshCmd" : " print(var_dic_list())"
153
- },
154
- "r" : {
155
- "delete_cmd_postfix" : " ) " ,
156
- "delete_cmd_prefix" : " rm(" ,
157
- "library" : " var_list.r" ,
158
- "varRefreshCmd" : " cat(var_dic_list()) "
159
- }
160
- },
161
- "types_to_exclude" : [
162
- " module" ,
163
- " function" ,
164
- " builtin_function_or_method" ,
165
- " instance" ,
166
- " _Feature"
167
- ],
168
- "window_display" : false
169
- }
170
- },
171
- "nbformat" : 4 ,
172
- "nbformat_minor" : 2
173
- }
1
+ {"cells":[{"cell_type":"code","execution_count":3,"metadata":{},"outputs":[],"source":"import torch\nfrom torch.utils.data import DataLoader\nfrom torch.utils.data.dataset import Dataset\nfrom torchvision import datasets, transforms"},{"cell_type":"markdown","metadata":{},"outputs":[],"source":"# Dataset"},{"cell_type":"code","execution_count":6,"metadata":{},"outputs":[],"source":"class SquareDataset(Dataset):\n def __init__(self, size):\n self.size = size\n self.X = torch.randint(255, (size, 9), dtype=torch.float)\n\n real_w = torch.tensor([[1,1,1,0,0,0,0,0,0],\n [0,0,0,1,1,1,0,0,0],\n [0,0,0,0,0,0,1,1,1]], \n dtype=torch.float)\n\n y = torch.argmax(self.X.mm(real_w.t()), 1)\n \n self.Y = torch.zeros(size, 3, dtype=torch.float) \\\n .scatter_(1, y.view(-1, 1), 1)\n\n def __getitem__(self, index):\n return (self.X[index], self.Y[index])\n\n def __len__(self):\n return self.size"},{"cell_type":"code","execution_count":7,"metadata":{},"outputs":[],"source":"squares = SquareDataset(256)\nprint(squares[34])\nprint(squares[254])\nprint(squares[25])"},{"cell_type":"code","execution_count":8,"metadata":{},"outputs":[],"source":"dataloader = DataLoader(squares, batch_size=5)\n\nfor batch, (X, Y) in enumerate(dataloader):\n print(X, '\\n\\n', Y)\n break"},{"cell_type":"markdown","metadata":{},"outputs":[],"source":"# Digits\nTransforms!"},{"cell_type":"code","execution_count":4,"metadata":{},"outputs":[],"source":"digits = datasets.MNIST('data', train=True, download=True,\n transform=transforms.Compose([\n transforms.ToTensor(),\n transforms.Lambda(lambda x: x.view(28*28))\n ]),\n target_transform=transforms.Compose([\n transforms.Lambda(lambda y: \n torch.zeros(10, dtype=torch.float).scatter_(0, torch.tensor(y), value=1))\n ])\n )"},{"cell_type":"code","execution_count":5,"metadata":{},"outputs":[],"source":"dataloader = DataLoader(digits, batch_size=10, shuffle=True)\n\nfor batch, (X, Y) in enumerate(dataloader):\n print(X, '\\n\\n', Y)\n break"}],"nbformat":4,"nbformat_minor":2,"metadata":{"language_info":{"name":"python","codemirror_mode":{"name":"ipython","version":3}},"orig_nbformat":2,"file_extension":".py","mimetype":"text/x-python","name":"python","npconvert_exporter":"python","pygments_lexer":"ipython3","version":3}}
0 commit comments