@@ -56,35 +56,35 @@ private class Model : NN.Module
56
56
public Model ( int numClasses )
57
57
{
58
58
features = Sequential (
59
- Conv2D ( 3 , 64 , kernelSize : 3 , stride : 2 , padding : 1 ) ,
60
- Relu ( inPlace : true ) ,
61
- MaxPool2D ( kernelSize : new long [ ] { 2 } ) ,
62
- Conv2D ( 64 , 192 , kernelSize : 3 , padding : 1 ) ,
63
- Relu ( inPlace : true ) ,
64
- MaxPool2D ( kernelSize : new long [ ] { 2 } ) ,
65
- Conv2D ( 192 , 384 , kernelSize : 3 , padding : 1 ) ,
66
- Relu ( inPlace : true ) ,
67
- Conv2D ( 384 , 256 , kernelSize : 3 , padding : 1 ) ,
68
- Relu ( inPlace : true ) ,
69
- Conv2D ( 256 , 256 , kernelSize : 3 , padding : 1 ) ,
70
- Relu ( inPlace : true ) ,
71
- MaxPool2D ( kernelSize : new long [ ] { 2 } ) ) ;
59
+ ( "c1" , Conv2D ( 3 , 64 , kernelSize : 3 , stride : 2 , padding : 1 ) ) ,
60
+ ( "r1" , Relu ( inPlace : true ) ) ,
61
+ ( "mp1" , MaxPool2D ( kernelSize : new long [ ] { 2 } ) ) ,
62
+ ( "c2" , Conv2D ( 64 , 192 , kernelSize : 3 , padding : 1 ) ) ,
63
+ ( "r2" , Relu ( inPlace : true ) ) ,
64
+ ( "mp2" , MaxPool2D ( kernelSize : new long [ ] { 2 } ) ) ,
65
+ ( "c3" , Conv2D ( 192 , 384 , kernelSize : 3 , padding : 1 ) ) ,
66
+ ( "r3" , Relu ( inPlace : true ) ) ,
67
+ ( "c4" , Conv2D ( 384 , 256 , kernelSize : 3 , padding : 1 ) ) ,
68
+ ( "r4" , Relu ( inPlace : true ) ) ,
69
+ ( "c5" , Conv2D ( 256 , 256 , kernelSize : 3 , padding : 1 ) ) ,
70
+ ( "r5" , Relu ( inPlace : true ) ) ,
71
+ ( "mp3" , MaxPool2D ( kernelSize : new long [ ] { 2 } ) ) ) ;
72
72
73
73
avgPool = AdaptiveAvgPool2D ( new long [ ] { 2 , 2 } ) ;
74
74
75
75
classifier = Sequential (
76
- Dropout ( ) ,
77
- Linear ( 256 * 2 * 2 , 4096 ) ,
78
- Relu ( inPlace : true ) ,
79
- Dropout ( ) ,
80
- Linear ( 4096 , 4096 ) ,
81
- Relu ( inPlace : true ) ,
82
- Linear ( 4096 , numClasses )
76
+ ( "d1" , Dropout ( ) ) ,
77
+ ( "l1" , Linear ( 256 * 2 * 2 , 4096 ) ) ,
78
+ ( "r1" , Relu ( inPlace : true ) ) ,
79
+ ( "d2" , Dropout ( ) ) ,
80
+ ( "l2" , Linear ( 4096 , 4096 ) ) ,
81
+ ( "r3" , Relu ( inPlace : true ) ) ,
82
+ ( "l3" , Linear ( 4096 , numClasses ) )
83
83
) ;
84
84
85
- RegisterModule ( features ) ;
86
- RegisterModule ( avgPool ) ;
87
- RegisterModule ( classifier ) ;
85
+ RegisterModule ( "features" , features ) ;
86
+ RegisterModule ( "avg" , avgPool ) ;
87
+ RegisterModule ( "classify" , classifier ) ;
88
88
}
89
89
90
90
public TorchTensor Forward ( TorchTensor input )
0 commit comments