Skip to content

Commit 81b8f58

Browse files
committed
updated computational graph visualization
1 parent 2f338aa commit 81b8f58

11 files changed

+2509
-152
lines changed

data.ipynb

+202-1
Original file line numberDiff line numberDiff line change
@@ -1 +1,202 @@
1-
{"cells":[{"cell_type":"code","execution_count":3,"metadata":{},"outputs":[],"source":"import torch\nfrom torch.utils.data import DataLoader\nfrom torch.utils.data.dataset import Dataset\nfrom torchvision import datasets, transforms"},{"cell_type":"markdown","metadata":{},"outputs":[],"source":"# Dataset"},{"cell_type":"code","execution_count":6,"metadata":{},"outputs":[],"source":"class SquareDataset(Dataset):\n def __init__(self, size):\n self.size = size\n self.X = torch.randint(255, (size, 9), dtype=torch.float)\n\n real_w = torch.tensor([[1,1,1,0,0,0,0,0,0],\n [0,0,0,1,1,1,0,0,0],\n [0,0,0,0,0,0,1,1,1]], \n dtype=torch.float)\n\n y = torch.argmax(self.X.mm(real_w.t()), 1)\n \n self.Y = torch.zeros(size, 3, dtype=torch.float) \\\n .scatter_(1, y.view(-1, 1), 1)\n\n def __getitem__(self, index):\n return (self.X[index], self.Y[index])\n\n def __len__(self):\n return self.size"},{"cell_type":"code","execution_count":7,"metadata":{},"outputs":[],"source":"squares = SquareDataset(256)\nprint(squares[34])\nprint(squares[254])\nprint(squares[25])"},{"cell_type":"code","execution_count":8,"metadata":{},"outputs":[],"source":"dataloader = DataLoader(squares, batch_size=5)\n\nfor batch, (X, Y) in enumerate(dataloader):\n print(X, '\\n\\n', Y)\n break"},{"cell_type":"markdown","metadata":{},"outputs":[],"source":"# Digits\nTransforms!"},{"cell_type":"code","execution_count":4,"metadata":{},"outputs":[],"source":"digits = datasets.MNIST('data', train=True, download=True,\n transform=transforms.Compose([\n transforms.ToTensor(),\n transforms.Lambda(lambda x: x.view(28*28))\n ]),\n target_transform=transforms.Compose([\n transforms.Lambda(lambda y: \n torch.zeros(10, dtype=torch.float).scatter_(0, torch.tensor(y), value=1))\n ])\n )"},{"cell_type":"code","execution_count":5,"metadata":{},"outputs":[],"source":"dataloader = DataLoader(digits, batch_size=10, shuffle=True)\n\nfor batch, (X, Y) in enumerate(dataloader):\n print(X, '\\n\\n', Y)\n break"}],"nbformat":4,"nbformat_minor":2,"metadata":{"language_info":{"name":"python","codemirror_mode":{"name":"ipython","version":3}},"orig_nbformat":2,"file_extension":".py","mimetype":"text/x-python","name":"python","npconvert_exporter":"python","pygments_lexer":"ipython3","version":3}}
1+
{
2+
"cells": [
3+
{
4+
"cell_type": "code",
5+
"execution_count": 1,
6+
"metadata": {},
7+
"outputs": [],
8+
"source": [
9+
"import torch\n",
10+
"from torch.utils.data import DataLoader\n",
11+
"from torch.utils.data.dataset import Dataset\n",
12+
"from torchvision import datasets, transforms"
13+
]
14+
},
15+
{
16+
"cell_type": "markdown",
17+
"metadata": {},
18+
"source": [
19+
"# Dataset"
20+
]
21+
},
22+
{
23+
"cell_type": "code",
24+
"execution_count": 2,
25+
"metadata": {},
26+
"outputs": [],
27+
"source": [
28+
"class SquareDataset(Dataset):\n",
29+
" def __init__(self, size):\n",
30+
" self.size = size\n",
31+
" self.X = torch.randint(255, (size, 9), dtype=torch.float)\n",
32+
"\n",
33+
" real_w = torch.tensor([[1,1,1,0,0,0,0,0,0],\n",
34+
" [0,0,0,1,1,1,0,0,0],\n",
35+
" [0,0,0,0,0,0,1,1,1]], \n",
36+
" dtype=torch.float)\n",
37+
"\n",
38+
" y = torch.argmax(self.X.mm(real_w.t()), 1)\n",
39+
" \n",
40+
" self.Y = torch.zeros(size, 3, dtype=torch.float) \\\n",
41+
" .scatter_(1, y.view(-1, 1), 1)\n",
42+
"\n",
43+
" def __getitem__(self, index):\n",
44+
" return (self.X[index], self.Y[index])\n",
45+
"\n",
46+
" def __len__(self):\n",
47+
" return self.size"
48+
]
49+
},
50+
{
51+
"cell_type": "code",
52+
"execution_count": 3,
53+
"metadata": {},
54+
"outputs": [
55+
{
56+
"name": "stdout",
57+
"output_type": "stream",
58+
"text": [
59+
"(tensor([ 54., 182., 47., 142., 200., 197., 220., 215., 33.]), tensor([0., 1., 0.]))\n",
60+
"(tensor([198., 171., 26., 140., 28., 9., 205., 48., 113.]), tensor([1., 0., 0.]))\n",
61+
"(tensor([ 64., 7., 167., 4., 9., 160., 169., 113., 214.]), tensor([0., 0., 1.]))\n"
62+
]
63+
}
64+
],
65+
"source": [
66+
"squares = SquareDataset(256)\n",
67+
"print(squares[34])\n",
68+
"print(squares[254])\n",
69+
"print(squares[25])"
70+
]
71+
},
72+
{
73+
"cell_type": "code",
74+
"execution_count": 4,
75+
"metadata": {},
76+
"outputs": [
77+
{
78+
"name": "stdout",
79+
"output_type": "stream",
80+
"text": [
81+
"tensor([[152., 127., 155., 219., 81., 140., 112., 77., 102.],\n",
82+
" [ 77., 58., 228., 164., 229., 155., 111., 223., 141.],\n",
83+
" [106., 250., 87., 62., 105., 254., 0., 210., 136.],\n",
84+
" [190., 108., 134., 204., 145., 251., 146., 171., 99.],\n",
85+
" [ 88., 36., 190., 108., 122., 4., 231., 22., 70.]]) \n",
86+
"\n",
87+
" tensor([[0., 1., 0.],\n",
88+
" [0., 1., 0.],\n",
89+
" [1., 0., 0.],\n",
90+
" [0., 1., 0.],\n",
91+
" [0., 0., 1.]])\n"
92+
]
93+
}
94+
],
95+
"source": [
96+
"dataloader = DataLoader(squares, batch_size=5)\n",
97+
"\n",
98+
"for batch, (X, Y) in enumerate(dataloader):\n",
99+
" print(X, '\\n\\n', Y)\n",
100+
" break"
101+
]
102+
},
103+
{
104+
"cell_type": "markdown",
105+
"metadata": {},
106+
"source": [
107+
"# Digits\n",
108+
"Transforms!"
109+
]
110+
},
111+
{
112+
"cell_type": "code",
113+
"execution_count": 5,
114+
"metadata": {},
115+
"outputs": [],
116+
"source": [
117+
"digits = datasets.MNIST('data', train=True, download=True,\n",
118+
" transform=transforms.Compose([\n",
119+
" transforms.ToTensor(),\n",
120+
" transforms.Lambda(lambda x: x.view(28*28))\n",
121+
" ]),\n",
122+
" target_transform=transforms.Compose([\n",
123+
" transforms.Lambda(lambda y: \n",
124+
" torch.zeros(10, dtype=torch.float).scatter_(0, torch.tensor(y), value=1))\n",
125+
" ])\n",
126+
" )"
127+
]
128+
},
129+
{
130+
"cell_type": "code",
131+
"execution_count": 6,
132+
"metadata": {},
133+
"outputs": [
134+
{
135+
"name": "stdout",
136+
"output_type": "stream",
137+
"text": [
138+
"tensor([[0., 0., 0., ..., 0., 0., 0.],\n",
139+
" [0., 0., 0., ..., 0., 0., 0.],\n",
140+
" [0., 0., 0., ..., 0., 0., 0.],\n",
141+
" ...,\n",
142+
" [0., 0., 0., ..., 0., 0., 0.],\n",
143+
" [0., 0., 0., ..., 0., 0., 0.],\n",
144+
" [0., 0., 0., ..., 0., 0., 0.]]) \n",
145+
"\n",
146+
" tensor([[1., 0., 0., 0., 0., 0., 0., 0., 0., 0.],\n",
147+
" [1., 0., 0., 0., 0., 0., 0., 0., 0., 0.],\n",
148+
" [0., 0., 0., 1., 0., 0., 0., 0., 0., 0.],\n",
149+
" [0., 0., 0., 0., 0., 0., 0., 1., 0., 0.],\n",
150+
" [0., 0., 0., 0., 0., 0., 0., 0., 1., 0.],\n",
151+
" [0., 0., 0., 0., 0., 1., 0., 0., 0., 0.],\n",
152+
" [0., 0., 0., 1., 0., 0., 0., 0., 0., 0.],\n",
153+
" [0., 0., 0., 0., 0., 0., 1., 0., 0., 0.],\n",
154+
" [1., 0., 0., 0., 0., 0., 0., 0., 0., 0.],\n",
155+
" [0., 0., 0., 0., 0., 0., 0., 1., 0., 0.]])\n"
156+
]
157+
}
158+
],
159+
"source": [
160+
"dataloader = DataLoader(digits, batch_size=10, shuffle=True)\n",
161+
"\n",
162+
"for batch, (X, Y) in enumerate(dataloader):\n",
163+
" print(X, '\\n\\n', Y)\n",
164+
" break"
165+
]
166+
},
167+
{
168+
"cell_type": "code",
169+
"execution_count": null,
170+
"metadata": {},
171+
"outputs": [],
172+
"source": []
173+
}
174+
],
175+
"metadata": {
176+
"file_extension": ".py",
177+
"kernelspec": {
178+
"display_name": "Python 3",
179+
"language": "python",
180+
"name": "python3"
181+
},
182+
"language_info": {
183+
"codemirror_mode": {
184+
"name": "ipython",
185+
"version": 3
186+
},
187+
"file_extension": ".py",
188+
"mimetype": "text/x-python",
189+
"name": "python",
190+
"nbconvert_exporter": "python",
191+
"pygments_lexer": "ipython3",
192+
"version": "3.6.10"
193+
},
194+
"mimetype": "text/x-python",
195+
"name": "python",
196+
"npconvert_exporter": "python",
197+
"pygments_lexer": "ipython3",
198+
"version": 3
199+
},
200+
"nbformat": 4,
201+
"nbformat_minor": 2
202+
}

digits.ipynb

+430-1
Large diffs are not rendered by default.

0 commit comments

Comments
 (0)