   "source": [
    "# PyTorch model (de)serialization\n",
    "\n",
+   "At the top level, serialization in PyTorch has two methods, `torch.save()` and `torch.load()`, implemented in [torch/serialization.py](https://github.com/pytorch/pytorch/blob/master/torch/serialization.py).\n",
+   "\n",
    "## Saving the model\n",
    "\n",
-   "In this example we will explore the serialization and deserialization of PyTorch model. We'll use the [MNIST model](https://github.com/pytorch/examples/tree/master/mnist) from previous examples, augmented with `torch.save()` call at the end.\n",
+   "Below we will explore the serialization and deserialization of a PyTorch model.\n",
+   "We'll use the [MNIST model](https://github.com/pytorch/examples/tree/master/mnist) from the PyTorch examples, augmented with a `torch.save()` call at the end.\n",
    "\n",
    "We save the trained model like this:\n",
    "\n",
    "```python\n",
    "torch.save({\n",
-   "    'epoch': args.epochs, # == 10\n",
+   "    'epoch': args.epochs, # == 10\n",
    "    'model_state_dict': model.state_dict(),\n",
    "    'optimizer_state_dict': optimizer.state_dict()\n",
    "}, './mnist-model.pt')\n",
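As a minimal, self-contained sketch of the save-and-load round trip above (a tiny stand-in `nn.Linear` in place of the actual MNIST network, and a literal epoch count in place of `args.epochs` — both assumptions for illustration):

```python
import torch
import torch.nn as nn
import torch.optim as optim

# Tiny stand-in for the MNIST model; any nn.Module works the same way here.
model = nn.Linear(4, 2)
optimizer = optim.SGD(model.parameters(), lr=0.01)

# Save a checkpoint dictionary, mirroring the snippet above.
torch.save({
    'epoch': 10,  # stand-in for args.epochs
    'model_state_dict': model.state_dict(),
    'optimizer_state_dict': optimizer.state_dict(),
}, './mnist-model.pt')

# torch.load() returns the very same dictionary.
checkpoint = torch.load('./mnist-model.pt')
print(sorted(checkpoint))  # ['epoch', 'model_state_dict', 'optimizer_state_dict']
```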

   "cell_type": "markdown",
   "metadata": {},
   "source": [
-   "So the `torch.load()` function just reads back the dictionary that was passed to `torch.save()`, and for basic Python types it is not different from Python standard [`pickle`](https://docs.python.org/3.5/library/pickle.html) module (in fact, it *is* a pickle). The most interesting part here are the model's and optimizer's parameters, as returned from [`torch.nn.Module.state_dict()`](https://pytorch.org/docs/stable/nn.html#torch.nn.Module.state_dict) method. Let's take a closer look."
+   "So the `torch.load()` function just reads back the dictionary that was passed to `torch.save()`; for basic Python types it is no different from Python's standard [pickle](https://docs.python.org/3.5/library/pickle.html) module (in fact, it *is* a pickle). The most interesting part here is the model's and optimizer's parameters, as returned from the [`torch.nn.Module.state_dict()`](https://pytorch.org/docs/stable/nn.html#torch.nn.Module.state_dict) method. Let's take a closer look."
   ]
  },
  {
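To see what `state_dict()` actually returns, here is a sketch with a toy module (a hypothetical `nn.Linear`, not the MNIST model, which would simply have more entries): an ordered mapping from parameter names to tensors.

```python
import torch.nn as nn

model = nn.Linear(4, 2)  # toy module for illustration
state = model.state_dict()

# Each entry maps a parameter (or buffer) name to its tensor.
for name, tensor in state.items():
    print(name, tuple(tensor.shape))
# weight (2, 4)
# bias (2,)
```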

   "cell_type": "markdown",
   "metadata": {},
   "source": [
-   "Remember, that after the model instantiation its parameters are initialized with random values, e.g."
+   "Remember that after model instantiation its parameters are initialized with random values, e.g."
   ]
  },
  {
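A short sketch of that point, again with a stand-in `nn.Linear` rather than the MNIST model: two fresh instances start from independent random parameters, and only agree after one's state dict is loaded into the other.

```python
import torch
import torch.nn as nn

# Two freshly constructed models start from independent random parameters...
a = nn.Linear(4, 2)
b = nn.Linear(4, 2)
print(torch.equal(a.weight, b.weight))  # False (with overwhelming probability)

# ...and only agree after one state dict is loaded into the other.
b.load_state_dict(a.state_dict())
print(torch.equal(a.weight, b.weight))  # True
```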

   ]
  },
  {
-   "cell_type": "code",
-   "execution_count": null,
+   "cell_type": "markdown",
    "metadata": {},
-   "outputs": [],
-   "source": []
+   "source": [
+    "## Serialization across devices\n",
+    "\n",
+    "PyTorch documentation has a [good tutorial](https://pytorch.org/tutorials/beginner/saving_loading_models.html#saving-loading-model-across-devices) on that."
+   ]
+  },
+  {
+   "cell_type": "markdown",
+   "metadata": {},
+   "source": [
+    "## Tensor serialization\n",
+    "\n",
+    "The model and optimizer serialization in PyTorch is built on the standard Python [pickle](https://docs.python.org/3.5/library/pickle.html) functionality, except for the tensor storage itself. That part is implemented in "
+   ]
  }
 ],
 "metadata": {
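The cross-device tutorial linked above centers on the `map_location` argument of `torch.load()`; a minimal CPU-only sketch of the idea (the interesting cases involve checkpoints saved from a GPU):

```python
import torch

t = torch.arange(6.0).reshape(2, 3)
torch.save(t, './tensor.pt')

# map_location remaps tensor storages at load time; 'cpu' is safe even for
# checkpoints that were saved from a GPU.
loaded = torch.load('./tensor.pt', map_location='cpu')
print(loaded.device)  # cpu
```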