302 | 302 | "name": "stdout",
303 | 303 | "output_type": "stream",
304 | 304 | "text": [
305 |     | - "Train Epoch: 0 [14848/60000 (25%)]\tLoss: 0.405838\n",
306 |     | - "Train Epoch: 0 [30208/60000 (50%)]\tLoss: 0.206041\n",
307 |     | - "Train Epoch: 0 [45568/60000 (75%)]\tLoss: 0.144166\n"
    | 305 | + "Train Epoch: 0 [14848/60000 (25%)]\tLoss: 0.271775\n",
    | 306 | + "warning: Embedding dir exists, did you set global_step for add_embedding()?\n",
    | 307 | + "Train Epoch: 0 [30208/60000 (50%)]\tLoss: 0.175213\n",
    | 308 | + "warning: Embedding dir exists, did you set global_step for add_embedding()?\n",
    | 309 | + "Train Epoch: 0 [45568/60000 (75%)]\tLoss: 0.115128\n",
    | 310 | + "warning: Embedding dir exists, did you set global_step for add_embedding()?\n"
308 | 311 | ]
309 | 312 | }
310 | 313 | ],
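
The new warning lines come from TensorBoard's embedding writer: calling add_embedding repeatedly without a global_step makes every call write into the same log subdirectory. A minimal sketch of the fix, assuming a torch.utils.tensorboard SummaryWriter (tensorboardX has the same signature) and placeholder tensors standing in for the notebook's real features:

    import torch
    from torch.utils.tensorboard import SummaryWriter

    writer = SummaryWriter('runs/embedding_demo')      # hypothetical log directory

    for step in range(3):
        feats = torch.randn(100, 64)                   # stand-in for real activations
        labels = [str(i % 10) for i in range(100)]     # stand-in metadata, one entry per row
        # Passing global_step gives each call its own embedding directory,
        # which is exactly what the "Embedding dir exists" warning asks for.
        writer.add_embedding(feats, metadata=labels, global_step=step)

    writer.close()
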
382 | 385 | }
383 | 386 | ],
384 | 387 | "source": [
385 |     | - "vgg16 = models.vgg16() # download the pretrained model here\n",
    | 388 | + "vgg16 = models.vgg16(pretrained=True) # download the pretrained model here\n",
386 | 389 | "print(vgg16) # print the model"
387 | 390 | ]
388 | 391 | },
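
This change makes the cell actually download the ImageNet weights; models.vgg16() with no arguments builds a randomly initialized network, so the old comment was misleading. As a side note that goes beyond this notebook: torchvision 0.13+ deprecates pretrained=True in favor of an explicit weights argument, roughly:

    import torchvision.models as models

    # Equivalent of models.vgg16(pretrained=True) on torchvision >= 0.13
    vgg16 = models.vgg16(weights=models.VGG16_Weights.IMAGENET1K_V1)
    vgg16.eval()   # inference mode for the single-image experiments below
    print(vgg16)
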
401 | 404 | "source": [
402 | 405 | "transform_2 = transforms.Compose([\n",
403 | 406 | "    transforms.Resize(224), \n",
404 |     | - "    transforms.CenterCrop(224),\n",
    | 407 | + "    transforms.CenterCrop((224,224)),\n",
405 | 408 | "    transforms.ToTensor(),\n",
406 |     | - "    # convert RGB to BGR\n",
407 |     | - "    # from <https://github.com/mrzhu-cool/pix2pix-pytorch/blob/master/util.py>\n",
408 |     | - "    transforms.Lambda(lambda x: torch.index_select(x, 0, torch.LongTensor([2, 1, 0]))),\n",
409 |     | - "    transforms.Lambda(lambda x: x*255),\n",
410 |     | - "    transforms.Normalize(mean = [103.939, 116.779, 123.68],\n",
411 |     | - "                         std = [ 1, 1, 1 ]),\n",
    | 409 | + "    transforms.Normalize(mean=[0.485, 0.456, 0.406],\n",
    | 410 | + "                         std=[0.229, 0.224, 0.225])\n",
412 | 411 | "])"
413 | 412 | ]
414 | 413 | },
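
The new Normalize values are the standard torchvision ImageNet statistics, which is what the pretrained VGG16 expects; the old Caffe-style BGR / 0-255 pipeline is dropped. A sketch of how transform_2 would be applied to produce the vgg16_input used below (the image path is a hypothetical stand-in):

    from PIL import Image
    import torch
    from torchvision import transforms

    transform_2 = transforms.Compose([
        transforms.Resize(224),                      # resize the shorter edge to 224
        transforms.CenterCrop((224, 224)),
        transforms.ToTensor(),                       # HWC uint8 -> CHW float in [0, 1]
        transforms.Normalize(mean=[0.485, 0.456, 0.406],
                             std=[0.229, 0.224, 0.225]),  # standard ImageNet statistics
    ])

    img = Image.open('cat.jpg').convert('RGB')       # hypothetical input image
    vgg16_input = transform_2(img).unsqueeze(0)      # add batch dimension: (1, 3, 224, 224)
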
453 | 452 | "metadata": {},
454 | 453 | "outputs": [
455 | 454 | {
456 |     | - "name": "stdout",
457 |     | - "output_type": "stream",
458 |     | - "text": [
459 |     | - "(1, 1000) 931\n"
460 |     | - ]
    | 455 | + "data": {
    | 456 | + "text/plain": [
    | 457 | + "287"
    | 458 | + ]
    | 459 | + },
    | 460 | + "execution_count": 14,
    | 461 | + "metadata": {},
    | 462 | + "output_type": "execute_result"
461 | 463 | }
462 | 464 | ],
463 | 465 | "source": [
464 |     | - "raw_score = vgg16(vgg16_input)\n",
465 |     | - "raw_score_numpy = raw_score.data.numpy()\n",
466 |     | - "print(raw_score_numpy.shape, np.argmax(raw_score_numpy.ravel()))"
    | 466 | + "out = vgg16(vgg16_input)\n",
    | 467 | + "_, preds = torch.max(out.data, 1)\n",
    | 468 | + "label=preds.numpy()[0]\n",
    | 469 | + "label"
467 | 470 | ]
468 | 471 | },
469 | 472 | {
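
The rewritten cell takes the (1, 1000) score vector from the forward pass and keeps only the arg-max class index (287 in this output). A sketch of how that index could be turned into a human-readable ImageNet label, assuming torchvision >= 0.13 for the bundled category names and reusing vgg16_input from the transform step above (the notebook itself only shows the raw index):

    import torch
    import torchvision.models as models

    with torch.no_grad():                            # inference only, no gradients
        out = vgg16(vgg16_input)                     # (1, 1000) scores over ImageNet classes
    probs = torch.softmax(out, dim=1)
    conf, preds = torch.max(probs, dim=1)            # index of the highest-scoring class
    label = preds.item()                             # e.g. 287 in this run
    # With torchvision >= 0.13 the class names ship with the weights enum:
    categories = models.VGG16_Weights.IMAGENET1K_V1.meta["categories"]
    print(label, categories[label], round(conf.item(), 3))
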
487 | 490 | "cell_type": "markdown",
488 | 491 | "metadata": {},
489 | 492 | "source": [
490 |     | - "Open tensorboard and find the graphs tab to see the result"
    | 493 | + "Open tensorboard and find the graphs tab to see the detailed architecture of the vgg model"
491 | 494 | ]
492 | 495 | },
493 | 496 | {
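
The Graphs tab only shows the model if the graph was exported beforehand. A sketch of the add_graph call that would do this, with a hypothetical log directory and a dummy input of VGG16's expected shape (vgg16 is the pretrained model loaded earlier):

    import torch
    from torch.utils.tensorboard import SummaryWriter

    writer = SummaryWriter('runs/vgg16_graph')       # hypothetical log directory
    dummy_input = torch.rand(1, 3, 224, 224)         # one fake image in VGG16's input shape
    writer.add_graph(vgg16, dummy_input)             # exports the graph the Graphs tab renders
    writer.close()
    # Then launch: tensorboard --logdir runs
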