Commit 69ac603

state_dict checkpoint before tensor storage serialization part
1 parent 6869565 commit 69ac603

1 file changed: +22 -8 lines changed

load/state_dict.ipynb

@@ -6,15 +6,18 @@
 "source": [
 "# PyTorch model (de)serialization\n",
 "\n",
+"At the top level, serialization in PyTorch has two methods, `torch.save()` and `torch.load()`, implemented in [torch/serialization.py](https://github.com/pytorch/pytorch/blob/master/torch/serialization.py).\n",
+"\n",
 "## Saving the model\n",
 "\n",
-"In this example we will explore the serialization and deserialization of PyTorch model. We'll use the [MNIST model](https://github.com/pytorch/examples/tree/master/mnist) from previous examples, augmented with `torch.save()` call at the end.\n",
+"Below we will explore the serialization and deserialization of PyTorch model.\n",
+"We'll use the [MNIST model](https://github.com/pytorch/examples/tree/master/mnist) from PyTorch examples, augmented with `torch.save()` call at the end.\n",
 "\n",
 "We save the trained model like this:\n",
 "\n",
 "```python\n",
 "torch.save({\n",
-" 'epoch': args.epochs, # == 10\n",
+" 'epoch': args.epochs, # == 10\n",
 " 'model_state_dict': model.state_dict(),\n",
 " 'optimizer_state_dict': optimizer.state_dict()\n",
 "}, './mnist-model.pt')\n",
@@ -59,7 +62,7 @@
 "cell_type": "markdown",
 "metadata": {},
 "source": [
-"So the `torch.load()` function just reads back the dictionary that was passed to `torch.save()`, and for basic Python types it is not different from Python standard [`pickle`](https://docs.python.org/3.5/library/pickle.html) module (in fact, it *is* a pickle). The most interesting part here are the model's and optimizer's parameters, as returned from [`torch.nn.Module.state_dict()`](https://pytorch.org/docs/stable/nn.html#torch.nn.Module.state_dict) method. Let's take a closer look."
+"So the `torch.load()` function just reads back the dictionary that was passed to `torch.save()`, and for basic Python types it is not different from Python standard [pickle](https://docs.python.org/3.5/library/pickle.html) module (in fact, it *is* a pickle). The most interesting part here are the model's and optimizer's parameters, as returned from [`torch.nn.Module.state_dict()`](https://pytorch.org/docs/stable/nn.html#torch.nn.Module.state_dict) method. Let's take a closer look."
 ]
 },
 {
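
Note on the hunk above: the statement that basic Python types come back exactly as they would from `pickle` can be illustrated with the standard library alone. This is a minimal sketch, assuming a checkpoint dictionary that holds only plain Python values. The `checkpoint` contents are hypothetical, and the real checkpoint in the notebook also holds the model and optimizer state dicts.

```python
import io
import pickle

# Hypothetical checkpoint holding only basic Python types (no tensors).
checkpoint = {"epoch": 10, "note": "mnist", "lr_history": [0.01, 0.005]}

# Serialize into an in-memory buffer instead of a file on disk.
buffer = io.BytesIO()
pickle.dump(checkpoint, buffer)

# Rewind and read the dictionary back.
buffer.seek(0)
restored = pickle.load(buffer)

# The round trip yields an equal, but distinct, dictionary object.
assert restored == checkpoint
assert restored is not checkpoint
```

The sketch deliberately avoids `torch` so that it runs anywhere. It says nothing about how tensors inside a state dict are stored.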
@@ -140,7 +143,7 @@
 "cell_type": "markdown",
 "metadata": {},
 "source": [
-"Remember, that after the model instantiation its parameters are initialized with random values, e.g."
+"Remember that after the model instantiation its parameters are initialized with random values, e.g."
 ]
 },
 {
@@ -208,11 +211,22 @@
 ]
 },
 {
-"cell_type": "code",
-"execution_count": null,
+"cell_type": "markdown",
 "metadata": {},
-"outputs": [],
-"source": []
+"source": [
+"## Serialization across devices\n",
+"\n",
+"PyTorch documentation has a [good tutorial](https://pytorch.org/tutorials/beginner/saving_loading_models.html#saving-loading-model-across-devices) on that."
+]
+},
+{
+"cell_type": "markdown",
+"metadata": {},
+"source": [
+"## Tensor serialization\n",
+"\n",
+"The model and optimizer serialization in PyTorch is built on the standard Python [pickle](https://docs.python.org/3.5/library/pickle.html) functionality - except for the tensor storage itself. That part is implemented in "
+]
 }
 ],
 "metadata": {

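Note on the new "Tensor serialization" cell: the idea of pickling an object graph while keeping bulk data outside the pickle stream can be sketched with the standard `pickle` hooks `persistent_id` and `persistent_load`. The sketch below is not taken from PyTorch's source. `FakeStorage`, `StoragePickler`, `StorageUnpickler`, and the `blobs` side table are hypothetical names used only for illustration. Whether and how PyTorch uses these hooks should be checked against `torch/serialization.py`.

```python
import io
import pickle


class FakeStorage:
    """Hypothetical stand-in for a tensor's raw data buffer."""

    def __init__(self, key, data):
        self.key = key
        self.data = bytes(data)


class StoragePickler(pickle.Pickler):
    """Pickler that diverts FakeStorage payloads into a side table."""

    def __init__(self, file, blobs):
        super().__init__(file)
        self.blobs = blobs  # side table: key -> raw bytes

    def persistent_id(self, obj):
        if isinstance(obj, FakeStorage):
            # Keep the raw bytes out of the pickle stream; record only a tag.
            self.blobs[obj.key] = obj.data
            return ("storage", obj.key)
        return None  # every other object is pickled as usual


class StorageUnpickler(pickle.Unpickler):
    """Unpickler that rebuilds FakeStorage objects from the side table."""

    def __init__(self, file, blobs):
        super().__init__(file)
        self.blobs = blobs

    def persistent_load(self, pid):
        tag, key = pid
        if tag != "storage":
            raise pickle.UnpicklingError(f"unknown persistent id: {pid!r}")
        return FakeStorage(key, self.blobs[key])


state = {"epoch": 10, "weight": FakeStorage("w0", b"RAW-TENSOR-BYTES")}

blobs = {}
buffer = io.BytesIO()
StoragePickler(buffer, blobs).dump(state)

buffer.seek(0)
restored = StorageUnpickler(buffer, blobs).load()

# The payload travelled through the side table, not the pickle stream.
assert b"RAW-TENSOR-BYTES" not in buffer.getvalue()
assert restored["weight"].data == b"RAW-TENSOR-BYTES"
```

With this split, the pickle stream describes the structure (dictionary keys, scalar values, references to storages), while the raw bytes can be written separately in whatever layout suits them.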