|
4 | 4 | from models import (cnn, C3DNet, resnet, ResNetV2, ResNeXt, ResNeXtV2, WideResNet, PreActResNet,
|
5 | 5 | EfficientNet, DenseNet, ShuffleNet, ShuffleNetV2, SqueezeNet, MobileNet, MobileNetV2)
|
6 | 6 |
|
| 7 | +from opts import parse_opts |
7 | 8 |
|
8 | 9 |
|
9 |
| -def get_cnn_model(cnn_name, model_depth, n_classes, in_channels, sample_size=96): |
| 10 | + |
| 11 | +def main(cnn_name, model_depth, n_classes, in_channels, sample_size): |
10 | 12 |
|
11 | 13 | # simple CNN
|
12 | 14 | if cnn_name == 'cnn':
|
@@ -184,14 +186,29 @@ def get_cnn_model(cnn_name, model_depth, n_classes, in_channels, sample_size=96)
|
184 | 186 | override_params={'num_classes': n_classes},
|
185 | 187 | in_channels=in_channels)
|
186 | 188 |
|
187 |
| - |
188 | 189 | if torch.cuda.is_available():
|
189 | 190 | model.cuda()
|
190 | 191 |
|
191 | 192 | return model
|
192 | 193 |
|
193 | 194 |
|
| 195 | +if __name__ == '__main__': |
| 196 | + |
| 197 | + parser = argparse.ArgumentParser() |
| 198 | + parser.add_argument('--manual_seed', default=1234, type=int, help='Mannual seed') |
| 199 | + parser.add_argument('--cnn_name', default='ResNet', type=str, help='cnn model names') |
| 200 | + parser.add_argument('--model_depth', default=101, type=str, help='model depth (18|34|50|101|152|200)') |
| 201 | + parser.add_argument('--n_classes', default=2, type=str, help='model output classes') |
| 202 | + parser.add_argument('--in_channels', default=1, type=str, help='model input channels (1|3)') |
| 203 | + parser.add_argument('--sample_size', default=128, type=str, help='image size') |
| 204 | + args = parser.parse_args() |
194 | 205 |
|
| 206 | + model = main(cnn_name=args.cnn_name, |
| 207 | + model_depth=args.model_depth, |
| 208 | + n_classes=args.n_classes, |
| 209 | + in_channels=args.in_channels, |
| 210 | + sample_size=args.sample_sizes |
| 211 | + ) |
195 | 212 |
|
196 | 213 |
|
197 | 214 |
|
0 commit comments