|
17 | 17 | " 'epoch': args.epochs, # == 10\n", |
18 | 18 | " 'model_state_dict': model.state_dict(),\n", |
19 | 19 | " 'optimizer_state_dict': optimizer.state_dict()\n", |
20 | | - "}, './mnist-model.tar')\n", |
21 | | - "```\n", |
22 | | - "\n", |
23 | | - "(`.tar` extension for saved models is just a PyTorch convention)" |
| 20 | + "}, './mnist-model.pt')\n", |
| 21 | + "```" |
24 | 22 | ] |
25 | 23 | }, |
26 | 24 | { |
|
34 | 32 | }, |
35 | 33 | { |
36 | 34 | "cell_type": "code", |
37 | | - "execution_count": 4, |
| 35 | + "execution_count": 1, |
38 | 36 | "metadata": {}, |
39 | 37 | "outputs": [ |
40 | 38 | { |
41 | | - "data": { |
42 | | - "text/plain": [ |
43 | | - "dict_keys(['epoch', 'model_state_dict', 'optimizer_state_dict'])" |
44 | | - ] |
45 | | - }, |
46 | | - "execution_count": 4, |
47 | | - "metadata": {}, |
48 | | - "output_type": "execute_result" |
| 39 | + "name": "stdout", |
| 40 | + "output_type": "stream", |
| 41 | + "text": [ |
| 42 | + "<class 'dict'>\n", |
| 43 | + "dict_keys(['epoch', 'model_state_dict', 'optimizer_state_dict'])\n", |
| 44 | + "epoch = 10\n" |
| 45 | + ] |
49 | 46 | } |
50 | 47 | ], |
51 | 48 | "source": [ |
52 | 49 | "import torch\n", |
53 | 50 | "\n", |
54 | | - "saved_params = torch.load('./mnist-model.tar')\n", |
| 51 | + "model_state = torch.load('./mnist-model.pt')\n", |
55 | 52 | "\n", |
56 | | - "saved_params.keys()" |
| 53 | + "print(type(model_state))\n", |
| 54 | + "print(model_state.keys())\n", |
| 55 | + "print('epoch =', model_state['epoch'])" |
| 56 | + ] |
| 57 | + }, |
| 58 | + { |
| 59 | + "cell_type": "markdown", |
| 60 | + "metadata": {}, |
| 61 | + "source": [ |
| 62 | + "So the `torch.load()` function just reads back the dictionary that was passed to `torch.save()`, and for basic Python types it is not different from Python standard [pickle](https://docs.python.org/3.5/library/pickle.html) module (in fact, it *is* a pickle)." |
57 | 63 | ] |
58 | 64 | }, |
59 | 65 | { |
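The markdown cell added in this commit frames `torch.load()` as simply reading back the pickled dictionary. As a minimal sketch of that point (not part of the commit; the file name `./demo.pt` and the example dict are made up for illustration), a plain Python dict round-trips through `torch.save()`/`torch.load()` much like it does through the standard `pickle` module:

```python
# Not part of the commit: a small sketch of the save/load round-trip described
# in the markdown cell. torch.save()/torch.load() serialise and restore an
# arbitrary picklable Python object (here a plain dict), comparably to
# pickle.dumps()/pickle.loads(). The file name './demo.pt' is hypothetical.
import pickle

import torch

obj = {'epoch': 10, 'labels': ['cat', 'dog'], 'lr': 0.01}

# Serialise to disk and read back with PyTorch.
torch.save(obj, './demo.pt')
restored_torch = torch.load('./demo.pt')

# The same object round-tripped through the standard pickle module.
restored_pickle = pickle.loads(pickle.dumps(obj))

print(restored_torch == obj)   # True
print(restored_pickle == obj)  # True
```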