Skip to content

Commit a6658bb

Browse files
committed
rename the saved mnist model to .pt
1 parent e7fd534 commit a6658bb

File tree

2 files changed

+21
-15
lines changed

2 files changed

+21
-15
lines changed
File renamed without changes.

load/state_dict.ipynb

Lines changed: 21 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -17,10 +17,8 @@
1717
" 'epoch': args.epochs, # == 10\n",
1818
" 'model_state_dict': model.state_dict(),\n",
1919
" 'optimizer_state_dict': optimizer.state_dict()\n",
20-
"}, './mnist-model.tar')\n",
21-
"```\n",
22-
"\n",
23-
"(`.tar` extension for saved models is just a PyTorch convention)"
20+
"}, './mnist-model.pt')\n",
21+
"```"
2422
]
2523
},
2624
{
@@ -34,26 +32,34 @@
3432
},
3533
{
3634
"cell_type": "code",
37-
"execution_count": 4,
35+
"execution_count": 1,
3836
"metadata": {},
3937
"outputs": [
4038
{
41-
"data": {
42-
"text/plain": [
43-
"dict_keys(['epoch', 'model_state_dict', 'optimizer_state_dict'])"
44-
]
45-
},
46-
"execution_count": 4,
47-
"metadata": {},
48-
"output_type": "execute_result"
39+
"name": "stdout",
40+
"output_type": "stream",
41+
"text": [
42+
"<class 'dict'>\n",
43+
"dict_keys(['epoch', 'model_state_dict', 'optimizer_state_dict'])\n",
44+
"epoch = 10\n"
45+
]
4946
}
5047
],
5148
"source": [
5249
"import torch\n",
5350
"\n",
54-
"saved_params = torch.load('./mnist-model.tar')\n",
51+
"model_state = torch.load('./mnist-model.pt')\n",
5552
"\n",
56-
"saved_params.keys()"
53+
"print(type(model_state))\n",
54+
"print(model_state.keys())\n",
55+
"print('epoch =', model_state['epoch'])"
56+
]
57+
},
58+
{
59+
"cell_type": "markdown",
60+
"metadata": {},
61+
"source": [
62+
"So the `torch.load()` function just reads back the dictionary that was passed to `torch.save()`, and for basic Python types it is not different from Python standard [pickle](https://docs.python.org/3.5/library/pickle.html) module (in fact, it *is* a pickle)."
5763
]
5864
},
5965
{

0 commit comments

Comments
 (0)