Skip to content

Commit a3843bb

Browse files
committed
boxed module handles
1 parent 614c660 commit a3843bb

18 files changed

+445
-264
lines changed

.editorconfig

+5
Original file line numberDiff line numberDiff line change
@@ -15,6 +15,11 @@ max_line_length = 120
1515
indent_style = space
1616
indent_size = 2
1717

18+
# C++
19+
[*.{cpp,h,hpp,c}]
20+
indent_style = space
21+
indent_size = 4
22+
1823
# XML config files
1924
[*.{config,nuspec,resx}]
2025
indent_style = space

src/Examples/AlexNet.cs

+6-6
Original file line numberDiff line numberDiff line change
@@ -47,13 +47,13 @@ static void Main(string[] args)
4747
}
4848
}
4949

50-
private class Model : NN.Module
50+
private class Model : CustomModule
5151
{
52-
private readonly NN.Sequential features;
53-
private readonly NN.AdaptiveAvgPool2D avgPool;
54-
private readonly NN.Sequential classifier;
52+
private readonly Sequential features;
53+
private readonly AdaptiveAvgPool2D avgPool;
54+
private readonly Sequential classifier;
5555

56-
public Model(int numClasses)
56+
public Model(int numClasses)
5757
{
5858
features = Sequential(
5959
("c1", Conv2D(3, 64, kernelSize: 3, stride: 2, padding: 1)),
@@ -87,7 +87,7 @@ public Model(int numClasses)
8787
RegisterModule ("classify", classifier);
8888
}
8989

90-
public TorchTensor Forward(TorchTensor input)
90+
public override TorchTensor Forward(TorchTensor input)
9191
{
9292
using (var f = features.Forward(input))
9393
using (var avg = avgPool.Forward(f))

src/Examples/MNIST.cs

+6-6
Original file line numberDiff line numberDiff line change
@@ -43,12 +43,12 @@ static void Main(string[] args)
4343
}
4444
}
4545

46-
private class Model : NN.Module
46+
private class Model : CustomModule
4747
{
48-
private NN.Conv2D conv1 = Conv2D(1, 10, 5);
49-
private NN.Conv2D conv2 = Conv2D(10, 20, 5);
50-
private NN.Linear fc1 = Linear(320, 50);
51-
private NN.Linear fc2 = Linear(50, 10);
48+
private Conv2D conv1 = Conv2D(1, 10, 5);
49+
private Conv2D conv2 = Conv2D(10, 20, 5);
50+
private Linear fc1 = Linear(320, 50);
51+
private Linear fc2 = Linear(50, 10);
5252

5353
public Model()
5454
{
@@ -58,7 +58,7 @@ public Model()
5858
RegisterModule("lin2", fc2);
5959
}
6060

61-
public TorchTensor Forward(TorchTensor input)
61+
public override TorchTensor Forward(TorchTensor input)
6262
{
6363
using (var l11 = conv1.Forward(input))
6464
using (var l12 = MaxPool2D (l11, kernelSize: new long[]{ 2 }))

0 commit comments

Comments
 (0)