@@ -45,10 +45,10 @@ static void Main(string[] args)
45
45
46
46
private class Model : NN . Module
47
47
{
48
- private NN . Module conv1 = Conv2D ( 1 , 10 , 5 ) ;
49
- private NN . Module conv2 = Conv2D ( 10 , 20 , 5 ) ;
50
- private NN . Module fc1 = Linear ( 320 , 50 ) ;
51
- private NN . Module fc2 = Linear ( 50 , 10 ) ;
48
+ private NN . Conv2D conv1 = Conv2D ( 1 , 10 , 5 ) ;
49
+ private NN . Conv2D conv2 = Conv2D ( 10 , 20 , 5 ) ;
50
+ private NN . Linear fc1 = Linear ( 320 , 50 ) ;
51
+ private NN . Linear fc2 = Linear ( 50 , 10 ) ;
52
52
53
53
public Model ( )
54
54
{
@@ -58,22 +58,22 @@ public Model()
58
58
RegisterModule ( fc2 ) ;
59
59
}
60
60
61
- public override TorchTensor Forward ( TorchTensor input )
61
+ public TorchTensor Forward ( TorchTensor input )
62
62
{
63
63
using ( var l11 = conv1 . Forward ( input ) )
64
- using ( var l12 = MaxPool2D ( l11 , kernelSize : new long [ ] { 2 } ) )
64
+ using ( var l12 = MaxPool2D ( l11 , kernelSize : new long [ ] { 2 } ) )
65
65
using ( var l13 = Relu ( l12 ) )
66
66
67
67
using ( var l21 = conv2 . Forward ( l13 ) )
68
- using ( var l22 = FeatureDropout ( l21 ) )
69
- using ( var l23 = MaxPool2D ( l22 , kernelSize : new long [ ] { 2 } ) )
68
+ using ( var l22 = FeatureAlphaDropout ( l21 ) )
69
+ using ( var l23 = MaxPool2D ( l22 , kernelSize : new long [ ] { 2 } ) )
70
70
using ( var l24 = Relu ( l23 ) )
71
71
72
72
using ( var x = l24 . View ( new long [ ] { - 1 , 320 } ) )
73
73
74
74
using ( var l31 = fc1 . Forward ( x ) )
75
75
using ( var l32 = Relu ( l31 ) )
76
- using ( var l33 = Dropout ( l32 , IsTraining ( ) ) )
76
+ using ( var l33 = Dropout ( l32 ) )
77
77
78
78
using ( var l41 = fc2 . Forward ( l33 ) )
79
79
@@ -82,7 +82,7 @@ public override TorchTensor Forward(TorchTensor input)
82
82
}
83
83
84
84
private static void Train (
85
- NN . Module model ,
85
+ Model model ,
86
86
NN . Optimizer optimizer ,
87
87
Loss loss ,
88
88
IEnumerable < ( TorchTensor , TorchTensor ) > dataLoader ,
@@ -119,7 +119,7 @@ private static void Train(
119
119
}
120
120
121
121
private static void Test (
122
- NN . Module model ,
122
+ Model model ,
123
123
Loss loss ,
124
124
IEnumerable < ( TorchTensor , TorchTensor ) > dataLoader ,
125
125
long size )
0 commit comments