Skip to content

Commit 37abd82

Browse files
authored
Merge pull request WZMIAOMIAO#540 from WZMIAOMIAO/dev
update comments
2 parents 87a31d3 + 3cd608f commit 37abd82

File tree

2 files changed

+7
-1
lines changed

2 files changed

+7
-1
lines changed

pytorch_classification/Test4_googlenet/model.py

+2
Original file line numberDiff line numberDiff line change
@@ -116,6 +116,8 @@ def __init__(self, in_channels, ch1x1, ch3x3red, ch3x3, ch5x5red, ch5x5, pool_pr
116116

117117
self.branch3 = nn.Sequential(
118118
BasicConv2d(in_channels, ch5x5red, kernel_size=1),
119+
# 在官方的实现中,其实是3x3的kernel并不是5x5,这里我也懒得改了,具体可以参考下面的issue
120+
# Please see https://github.com/pytorch/vision/issues/906 for details.
119121
BasicConv2d(ch5x5red, ch5x5, kernel_size=5, padding=2) # 保证输出大小等于输入大小
120122
)
121123

pytorch_classification/Test4_googlenet/train.py

+5-1
Original file line numberDiff line numberDiff line change
@@ -60,16 +60,20 @@ def main():
6060
# test_data_iter = iter(validate_loader)
6161
# test_image, test_label = test_data_iter.next()
6262

63+
net = GoogLeNet(num_classes=5, aux_logits=True, init_weights=True)
64+
# 如果要使用官方的预训练权重,注意是将权重载入官方的模型,不是我们自己实现的模型
65+
# 官方的模型中使用了bn层以及改了一些参数,不能混用
66+
# import torchvision
6367
# net = torchvision.models.googlenet(num_classes=5)
6468
# model_dict = net.state_dict()
69+
# # 预训练权重下载地址: https://download.pytorch.org/models/googlenet-1378be20.pth
6570
# pretrain_model = torch.load("googlenet.pth")
6671
# del_list = ["aux1.fc2.weight", "aux1.fc2.bias",
6772
# "aux2.fc2.weight", "aux2.fc2.bias",
6873
# "fc.weight", "fc.bias"]
6974
# pretrain_dict = {k: v for k, v in pretrain_model.items() if k not in del_list}
7075
# model_dict.update(pretrain_dict)
7176
# net.load_state_dict(model_dict)
72-
net = GoogLeNet(num_classes=5, aux_logits=True, init_weights=True)
7377
net.to(device)
7478
loss_function = nn.CrossEntropyLoss()
7579
optimizer = optim.Adam(net.parameters(), lr=0.0003)

0 commit comments

Comments
 (0)