Skip to content

Commit 718a2d9

Browse files
committed
simplify NN modules to avoid storing information in .NET
1 parent b267466 commit 718a2d9

21 files changed

+484
-559
lines changed

src/Examples/AlexNet.cs

+11-10
Original file line number | Diff line number | Diff line change
@@ -49,9 +49,9 @@ static void Main(string[] args)
4949

5050
private class Model : NN.Module
5151
{
52-
private readonly NN.Module features;
53-
private readonly NN.Module avgPool;
54-
private readonly NN.Module classifier;
52+
private readonly NN.Sequential features;
53+
private readonly NN.AdaptiveAvgPool2D avgPool;
54+
private readonly NN.Sequential classifier;
5555

5656
public Model(int numClasses)
5757
{
@@ -70,23 +70,24 @@ public Model(int numClasses)
7070
Relu(inPlace: true),
7171
MaxPool2D(kernelSize: new long[] { 2 }));
7272

73-
avgPool = AdaptiveAvgPool2D(2, 2);
73+
avgPool = AdaptiveAvgPool2D(new long[] { 2, 2 });
7474

7575
classifier = Sequential(
76-
Dropout(IsTraining()),
76+
Dropout(),
7777
Linear(256 * 2 * 2, 4096),
7878
Relu(inPlace: true),
79-
Dropout(IsTraining()),
79+
Dropout(),
8080
Linear(4096, 4096),
8181
Relu(inPlace: true),
8282
Linear(4096, numClasses)
8383
);
8484

85-
RegisterModule(features);
85+
RegisterModule (features);
86+
RegisterModule (avgPool);
8687
RegisterModule(classifier);
8788
}
8889

89-
public override TorchTensor Forward(TorchTensor input)
90+
public TorchTensor Forward(TorchTensor input)
9091
{
9192
using (var f = features.Forward(input))
9293
using (var avg = avgPool.Forward(f))
@@ -97,7 +98,7 @@ public override TorchTensor Forward(TorchTensor input)
9798
}
9899

99100
private static void Train(
100-
NN.Module model,
101+
Model model,
101102
NN.Optimizer optimizer,
102103
Loss loss,
103104
IEnumerable<(TorchTensor, TorchTensor)> dataLoader,
@@ -140,7 +141,7 @@ private static void Train(
140141
}
141142

142143
private static void Test(
143-
NN.Module model,
144+
Model model,
144145
Loss loss,
145146
IEnumerable<(TorchTensor, TorchTensor)> dataLoader,
146147
long size)

src/Examples/MNIST.cs

+11-11
Original file line number | Diff line number | Diff line change
@@ -45,10 +45,10 @@ static void Main(string[] args)
4545

4646
private class Model : NN.Module
4747
{
48-
private NN.Module conv1 = Conv2D(1, 10, 5);
49-
private NN.Module conv2 = Conv2D(10, 20, 5);
50-
private NN.Module fc1 = Linear(320, 50);
51-
private NN.Module fc2 = Linear(50, 10);
48+
private NN.Conv2D conv1 = Conv2D(1, 10, 5);
49+
private NN.Conv2D conv2 = Conv2D(10, 20, 5);
50+
private NN.Linear fc1 = Linear(320, 50);
51+
private NN.Linear fc2 = Linear(50, 10);
5252

5353
public Model()
5454
{
@@ -58,22 +58,22 @@ public Model()
5858
RegisterModule(fc2);
5959
}
6060

61-
public override TorchTensor Forward(TorchTensor input)
61+
public TorchTensor Forward(TorchTensor input)
6262
{
6363
using (var l11 = conv1.Forward(input))
64-
using (var l12 = MaxPool2D(l11, kernelSize: new long[]{ 2 }))
64+
using (var l12 = MaxPool2D (l11, kernelSize: new long[]{ 2 }))
6565
using (var l13 = Relu(l12))
6666

6767
using (var l21 = conv2.Forward(l13))
68-
using (var l22 = FeatureDropout(l21))
69-
using (var l23 = MaxPool2D(l22, kernelSize: new long[] { 2 }))
68+
using (var l22 = FeatureAlphaDropout(l21))
69+
using (var l23 = MaxPool2D (l22, kernelSize: new long[] { 2 }))
7070
using (var l24 = Relu(l23))
7171

7272
using (var x = l24.View(new long[] { -1, 320 }))
7373

7474
using (var l31 = fc1.Forward(x))
7575
using (var l32 = Relu(l31))
76-
using (var l33 = Dropout(l32, IsTraining()))
76+
using (var l33 = Dropout(l32))
7777

7878
using (var l41 = fc2.Forward(l33))
7979

@@ -82,7 +82,7 @@ public override TorchTensor Forward(TorchTensor input)
8282
}
8383

8484
private static void Train(
85-
NN.Module model,
85+
Model model,
8686
NN.Optimizer optimizer,
8787
Loss loss,
8888
IEnumerable<(TorchTensor, TorchTensor)> dataLoader,
@@ -119,7 +119,7 @@ private static void Train(
119119
}
120120

121121
private static void Test(
122-
NN.Module model,
122+
Model model,
123123
Loss loss,
124124
IEnumerable<(TorchTensor, TorchTensor)> dataLoader,
125125
long size)

0 commit comments

Comments (0)