We read every piece of feedback, and take your input very seriously.
To see all available qualifiers, see our documentation.
2 parents 7438604 + 1b295b5 commit 1902b25Copy full SHA for 1902b25
pytorch_object_detection/faster_rcnn/train_res50_fpn.py
@@ -17,7 +17,7 @@ def create_model(num_classes, load_pretrain_weights=True):
17
# 如果GPU显存很大可以设置比较大的batch_size就可以将norm_layer设置为普通的BatchNorm2d
18
# trainable_layers包括['layer4', 'layer3', 'layer2', 'layer1', 'conv1'], 5代表全部训练
19
# resnet50 imagenet weights url: https://download.pytorch.org/models/resnet50-0676ba61.pth
20
- backbone = resnet50_fpn_backbone(pretrain_path="resnet50.pth",
+ backbone = resnet50_fpn_backbone(pretrain_path="./backbone/resnet50.pth",
21
norm_layer=torch.nn.BatchNorm2d,
22
trainable_layers=3)
23
# 训练自己数据集时不要修改这里的91,修改的是传入的num_classes参数
0 commit comments