Skip to content

Commit de9cb53

Browse files
committed
commit
1 parent 8abcee5 commit de9cb53

File tree

2 files changed

+35
-2
lines changed

2 files changed

+35
-2
lines changed

generate_model.py

Lines changed: 19 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -4,9 +4,11 @@
44
from models import (cnn, C3DNet, resnet, ResNetV2, ResNeXt, ResNeXtV2, WideResNet, PreActResNet,
55
EfficientNet, DenseNet, ShuffleNet, ShuffleNetV2, SqueezeNet, MobileNet, MobileNetV2)
66

7+
from opts import parse_opts
78

89

9-
def get_cnn_model(cnn_name, model_depth, n_classes, in_channels, sample_size=96):
10+
11+
def main(cnn_name, model_depth, n_classes, in_channels, sample_size):
1012

1113
# simple CNN
1214
if cnn_name == 'cnn':
@@ -184,14 +186,29 @@ def get_cnn_model(cnn_name, model_depth, n_classes, in_channels, sample_size=96)
184186
override_params={'num_classes': n_classes},
185187
in_channels=in_channels)
186188

187-
188189
if torch.cuda.is_available():
189190
model.cuda()
190191

191192
return model
192193

193194

195+
if __name__ == '__main__':
196+
197+
parser = argparse.ArgumentParser()
198+
parser.add_argument('--manual_seed', default=1234, type=int, help='Mannual seed')
199+
parser.add_argument('--cnn_name', default='ResNet', type=str, help='cnn model names')
200+
parser.add_argument('--model_depth', default=101, type=str, help='model depth (18|34|50|101|152|200)')
201+
parser.add_argument('--n_classes', default=2, type=str, help='model output classes')
202+
parser.add_argument('--in_channels', default=1, type=str, help='model input channels (1|3)')
203+
parser.add_argument('--sample_size', default=128, type=str, help='image size')
204+
args = parser.parse_args()
194205

206+
model = main(cnn_name=args.cnn_name,
207+
model_depth=args.model_depth,
208+
n_classes=args.n_classes,
209+
in_channels=args.in_channels,
210+
sample_size=args.sample_sizes
211+
)
195212

196213

197214

opts.py

Lines changed: 16 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,16 @@
1+
import argparse
2+
3+
4+
def parse_opts():
5+
6+
parser = argparse.ArgumentParser()
7+
parser.add_argument('--manual_seed', default=1234, type=int, help='Mannual seed')
8+
parser.add_argument('--cnn_name', default='ResNet', type=str, help='cnn model names')
9+
parser.add_argument('--model_depth', default=101, type=str, help='model depth (18|34|50|101|152|200)')
10+
parser.add_argument('--n_classes', default=2, type=str, help='model output classes')
11+
parser.add_argument('--in_channels', default=1, type=str, help='model input channels (1|3)')
12+
parser.add_argument('--sample_size', default=128, type=str, help='image size')
13+
14+
args = parser.parse_args()
15+
16+
return args

0 commit comments

Comments
 (0)