Commit fabfeee

committed
Wrote the state_dict tutorial up to the load_state_dict() part. Did not go into details of tensor serialization in C++ etc.
1 parent a6658bb commit fabfeee

1 file changed

+136
-2
lines changed


load/state_dict.ipynb

Lines changed: 136 additions & 2 deletions
@@ -27,7 +27,7 @@
2727
"source": [
2828
"## Loading PyTorch model\n",
2929
"\n",
30-
"Here we'll assume that we already have the file with the saved MNIST model with all default hyperparameters and trained for 10 epochs. Let's load it."
30+
"Here we'll assume that we already have a file containing the saved MNIST model, trained for 10 epochs with all default hyperparameters. Loading it is simple:"
3131
]
3232
},
3333
{
@@ -59,7 +59,141 @@
5959
"cell_type": "markdown",
6060
"metadata": {},
6161
"source": [
62-
"So the `torch.load()` function just reads back the dictionary that was passed to `torch.save()`, and for basic Python types it is not different from Python standard [pickle](https://docs.python.org/3.5/library/pickle.html) module (in fact, it *is* a pickle)."
62+
"So the `torch.load()` function just reads back the dictionary that was passed to `torch.save()`, and for basic Python types it is no different from Python's standard [`pickle`](https://docs.python.org/3.5/library/pickle.html) module (in fact, it *is* a pickle). The most interesting parts here are the model's and optimizer's parameters, as returned by the [`torch.nn.Module.state_dict()`](https://pytorch.org/docs/stable/nn.html#torch.nn.Module.state_dict) method. Let's take a closer look."
63+
]
64+
},
65+
{
66+
"cell_type": "code",
67+
"execution_count": 2,
68+
"metadata": {},
69+
"outputs": [
70+
{
71+
"name": "stdout",
72+
"output_type": "stream",
73+
"text": [
74+
"model_params: <class 'collections.OrderedDict'> \n",
75+
"\n",
76+
"conv1.weight: <class 'torch.Tensor'> torch.Size([10, 1, 5, 5])\n",
77+
" conv1.bias: <class 'torch.Tensor'> torch.Size([10])\n",
78+
"conv2.weight: <class 'torch.Tensor'> torch.Size([20, 10, 5, 5])\n",
79+
" conv2.bias: <class 'torch.Tensor'> torch.Size([20])\n",
80+
" fc1.weight: <class 'torch.Tensor'> torch.Size([50, 320])\n",
81+
" fc1.bias: <class 'torch.Tensor'> torch.Size([50])\n",
82+
" fc2.weight: <class 'torch.Tensor'> torch.Size([10, 50])\n",
83+
" fc2.bias: <class 'torch.Tensor'> torch.Size([10])\n",
84+
"\n",
85+
" conv1.bias: tensor([ 0.0272, -0.0762, -0.0617, 0.0235, 0.1745, 0.0320, 0.0871, 0.0674,\n",
86+
" -0.0222, -0.0541])\n"
87+
]
88+
}
89+
],
90+
"source": [
91+
"model_params = model_state['model_state_dict']\n",
92+
"print(\"model_params:\", type(model_params), \"\\n\")\n",
93+
"\n",
94+
"for (key, val) in model_params.items():\n",
95+
" print(\"%12s: %s %s\" % (key, type(val), val.size()))\n",
96+
" \n",
97+
"print(\"\\n%12s: %s\" % (\"conv1.bias\", model_params['conv1.bias']))"
98+
]
99+
},
100+
{
101+
"cell_type": "markdown",
102+
"metadata": {},
103+
"source": [
104+
"That is, `.state_dict()` produces an `OrderedDict` of tensors, keyed by the names of the layers and their parameters (e.g. `conv1.weight`).\n",
105+
"\n",
106+
"Now we need to populate the actual model's parameters (on CUDA or CPU) with that data. For that, we use the [`torch.nn.Module.load_state_dict()`](https://pytorch.org/docs/stable/nn.html#torch.nn.Module.load_state_dict) method. Unfortunately, it won't recreate the model's topology for us, so we have to build it explicitly using the code from the [MNIST example](https://github.com/pytorch/examples/blob/master/mnist/main.py):"
107+
]
108+
},
109+
{
110+
"cell_type": "code",
111+
"execution_count": 3,
112+
"metadata": {},
113+
"outputs": [],
114+
"source": [
115+
"import torch.nn as nn\n",
116+
"import torch.nn.functional as F\n",
117+
"\n",
118+
"class Net(nn.Module):\n",
119+
" def __init__(self):\n",
120+
" super(Net, self).__init__()\n",
121+
" self.conv1 = nn.Conv2d(1, 10, kernel_size=5)\n",
122+
" self.conv2 = nn.Conv2d(10, 20, kernel_size=5)\n",
123+
" self.conv2_drop = nn.Dropout2d()\n",
124+
" self.fc1 = nn.Linear(320, 50)\n",
125+
" self.fc2 = nn.Linear(50, 10)\n",
126+
"\n",
127+
" def forward(self, x):\n",
128+
" x = F.relu(F.max_pool2d(self.conv1(x), 2))\n",
129+
" x = F.relu(F.max_pool2d(self.conv2_drop(self.conv2(x)), 2))\n",
130+
" x = x.view(-1, 320)\n",
131+
" x = F.relu(self.fc1(x))\n",
132+
" x = F.dropout(x, training=self.training)\n",
133+
" x = self.fc2(x)\n",
134+
" return F.log_softmax(x, dim=1)\n",
135+
"\n",
136+
"model = Net()"
137+
]
138+
},
139+
{
140+
"cell_type": "markdown",
141+
"metadata": {},
142+
"source": [
143+
"Remember that all parameters are initially set to random values, e.g."
144+
]
145+
},
146+
{
147+
"cell_type": "code",
148+
"execution_count": 4,
149+
"metadata": {},
150+
"outputs": [
151+
{
152+
"data": {
153+
"text/plain": [
154+
"Parameter containing:\n",
155+
"tensor([ 0.1578, 0.1650, -0.1272, -0.1976, 0.0318, -0.1246, -0.0474, -0.0620,\n",
156+
" 0.1829, -0.1198], requires_grad=True)"
157+
]
158+
},
159+
"execution_count": 4,
160+
"metadata": {},
161+
"output_type": "execute_result"
162+
}
163+
],
164+
"source": [
165+
"model.conv1.bias"
166+
]
167+
},
168+
{
169+
"cell_type": "markdown",
170+
"metadata": {},
171+
"source": [
172+
"Now we can populate them with data from the file:"
173+
]
174+
},
175+
{
176+
"cell_type": "code",
177+
"execution_count": 5,
178+
"metadata": {},
179+
"outputs": [
180+
{
181+
"data": {
182+
"text/plain": [
183+
"Parameter containing:\n",
184+
"tensor([ 0.0272, -0.0762, -0.0617, 0.0235, 0.1745, 0.0320, 0.0871, 0.0674,\n",
185+
" -0.0222, -0.0541], requires_grad=True)"
186+
]
187+
},
188+
"execution_count": 5,
189+
"metadata": {},
190+
"output_type": "execute_result"
191+
}
192+
],
193+
"source": [
194+
"model.load_state_dict(model_params)\n",
195+
"\n",
196+
"model.conv1.bias"
63197
]
64198
},
65199
{
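
The save/load round trip that the notebook cells in this diff walk through can be condensed into one self-contained sketch. It is only an illustration under stated assumptions: `TinyNet` is a hypothetical one-layer stand-in for the tutorial's `Net`, and an in-memory buffer replaces the checkpoint file on disk. Only the `model_state_dict` key is taken from the notebook.

```python
import io

import torch
import torch.nn as nn


# Hypothetical stand-in for the tutorial's Net; one layer keeps the sketch short.
class TinyNet(nn.Module):
    def __init__(self):
        super().__init__()
        self.fc = nn.Linear(4, 2)

    def forward(self, x):
        return self.fc(x)


# Model whose (randomly initialized) parameters we want to preserve.
source = TinyNet()

# Save a checkpoint dictionary to an in-memory buffer instead of a file.
buffer = io.BytesIO()
torch.save({"model_state_dict": source.state_dict()}, buffer)
buffer.seek(0)

# torch.load() reads back the dictionary that was passed to torch.save().
checkpoint = torch.load(buffer)
model_params = checkpoint["model_state_dict"]

# load_state_dict() does not rebuild the topology, so construct the model
# first and then copy the saved tensors into it by key name.
restored = TinyNet()
restored.load_state_dict(model_params)

keys = list(model_params.keys())
same_weights = torch.equal(restored.fc.weight, source.fc.weight)
same_bias = torch.equal(restored.fc.bias, source.fc.bias)
```

The keys follow the same `<attribute>.<parameter>` pattern as `conv1.weight` in the notebook output, and the restored parameters compare equal to the originals even though `restored` started from different random values.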
