Skip to content

Commit 8b84556

Browse files
committed
修正:一个严重的错误
1 parent 6fff3a7 commit 8b84556

File tree

1 file changed

+23
-20
lines changed

1 file changed

+23
-20
lines changed

chapter4/4.2.2-tensorboardx.ipynb

Lines changed: 23 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -302,9 +302,12 @@
302302
"name": "stdout",
303303
"output_type": "stream",
304304
"text": [
305-
"Train Epoch: 0 [14848/60000 (25%)]\tLoss: 0.405838\n",
306-
"Train Epoch: 0 [30208/60000 (50%)]\tLoss: 0.206041\n",
307-
"Train Epoch: 0 [45568/60000 (75%)]\tLoss: 0.144166\n"
305+
"Train Epoch: 0 [14848/60000 (25%)]\tLoss: 0.271775\n",
306+
"warning: Embedding dir exists, did you set global_step for add_embedding()?\n",
307+
"Train Epoch: 0 [30208/60000 (50%)]\tLoss: 0.175213\n",
308+
"warning: Embedding dir exists, did you set global_step for add_embedding()?\n",
309+
"Train Epoch: 0 [45568/60000 (75%)]\tLoss: 0.115128\n",
310+
"warning: Embedding dir exists, did you set global_step for add_embedding()?\n"
308311
]
309312
}
310313
],
@@ -382,7 +385,7 @@
382385
}
383386
],
384387
"source": [
385-
"vgg16 = models.vgg16() # 这里下载预训练好的模型\n",
388+
"vgg16 = models.vgg16(pretrained=True) # 这里下载预训练好的模型\n",
386389
"print(vgg16) # 打印一下这个模型"
387390
]
388391
},
@@ -401,14 +404,10 @@
401404
"source": [
402405
"transform_2 = transforms.Compose([\n",
403406
" transforms.Resize(224), \n",
404-
" transforms.CenterCrop(224),\n",
407+
" transforms.CenterCrop((224,224)),\n",
405408
" transforms.ToTensor(),\n",
406-
" # convert RGB to BGR\n",
407-
" # from <https://github.com/mrzhu-cool/pix2pix-pytorch/blob/master/util.py>\n",
408-
" transforms.Lambda(lambda x: torch.index_select(x, 0, torch.LongTensor([2, 1, 0]))),\n",
409-
" transforms.Lambda(lambda x: x*255),\n",
410-
" transforms.Normalize(mean = [103.939, 116.779, 123.68],\n",
411-
" std = [ 1, 1, 1 ]),\n",
409+
" transforms.Normalize(mean=[0.485, 0.456, 0.406],\n",
410+
" std=[0.229, 0.224, 0.225])\n",
412411
"])"
413412
]
414413
},
@@ -453,17 +452,21 @@
453452
"metadata": {},
454453
"outputs": [
455454
{
456-
"name": "stdout",
457-
"output_type": "stream",
458-
"text": [
459-
"(1, 1000) 931\n"
460-
]
455+
"data": {
456+
"text/plain": [
457+
"287"
458+
]
459+
},
460+
"execution_count": 14,
461+
"metadata": {},
462+
"output_type": "execute_result"
461463
}
462464
],
463465
"source": [
464-
"raw_score = vgg16(vgg16_input)\n",
465-
"raw_score_numpy = raw_score.data.numpy()\n",
466-
"print(raw_score_numpy.shape, np.argmax(raw_score_numpy.ravel()))"
466+
"out = vgg16(vgg16_input)\n",
467+
"_, preds = torch.max(out.data, 1)\n",
468+
"label=preds.numpy()[0]\n",
469+
"label"
467470
]
468471
},
469472
{
@@ -487,7 +490,7 @@
487490
"cell_type": "markdown",
488491
"metadata": {},
489492
"source": [
490-
"打开tensorboard找到graphs 看看效果吧"
493+
"打开tensorboard找到graphs就可以看到vgg模型具体的架构了"
491494
]
492495
},
493496
{

0 commit comments

Comments
 (0)