Skip to content

Commit 614c660

Browse files
committed
get more tests passing
1 parent 718a2d9 commit 614c660

File tree

11 files changed

+336
-339
lines changed

11 files changed

+336
-339
lines changed

src/Examples/AlexNet.cs

+23-23
Original file line numberDiff line numberDiff line change
@@ -56,35 +56,35 @@ private class Model : NN.Module
5656
public Model(int numClasses)
5757
{
5858
features = Sequential(
59-
Conv2D(3, 64, kernelSize: 3, stride: 2, padding: 1),
60-
Relu(inPlace: true),
61-
MaxPool2D(kernelSize: new long[] { 2 }),
62-
Conv2D(64, 192, kernelSize: 3, padding: 1),
63-
Relu(inPlace: true),
64-
MaxPool2D(kernelSize: new long[] { 2 }),
65-
Conv2D(192, 384, kernelSize: 3, padding: 1),
66-
Relu(inPlace: true),
67-
Conv2D(384, 256, kernelSize: 3, padding: 1),
68-
Relu(inPlace: true),
69-
Conv2D(256, 256, kernelSize: 3, padding: 1),
70-
Relu(inPlace: true),
71-
MaxPool2D(kernelSize: new long[] { 2 }));
59+
("c1", Conv2D(3, 64, kernelSize: 3, stride: 2, padding: 1)),
60+
("r1", Relu(inPlace: true)),
61+
("mp1", MaxPool2D(kernelSize: new long[] { 2 })),
62+
("c2", Conv2D(64, 192, kernelSize: 3, padding: 1)),
63+
("r2", Relu(inPlace: true)),
64+
("mp2", MaxPool2D(kernelSize: new long[] { 2 })),
65+
("c3", Conv2D(192, 384, kernelSize: 3, padding: 1)),
66+
("r3", Relu(inPlace: true)),
67+
("c4", Conv2D(384, 256, kernelSize: 3, padding: 1)),
68+
("r4", Relu(inPlace: true)),
69+
("c5", Conv2D(256, 256, kernelSize: 3, padding: 1)),
70+
("r5", Relu(inPlace: true)),
71+
("mp3", MaxPool2D(kernelSize: new long[] { 2 })));
7272

7373
avgPool = AdaptiveAvgPool2D(new long[] { 2, 2 });
7474

7575
classifier = Sequential(
76-
Dropout(),
77-
Linear(256 * 2 * 2, 4096),
78-
Relu(inPlace: true),
79-
Dropout(),
80-
Linear(4096, 4096),
81-
Relu(inPlace: true),
82-
Linear(4096, numClasses)
76+
("d1", Dropout()),
77+
("l1", Linear(256 * 2 * 2, 4096)),
78+
("r1", Relu(inPlace: true)),
79+
("d2", Dropout()),
80+
("l2", Linear(4096, 4096)),
81+
("r3", Relu(inPlace: true)),
82+
("l3", Linear(4096, numClasses))
8383
);
8484

85-
RegisterModule (features);
86-
RegisterModule (avgPool);
87-
RegisterModule(classifier);
85+
RegisterModule ("features", features);
86+
RegisterModule ("avg", avgPool);
87+
RegisterModule ("classify", classifier);
8888
}
8989

9090
public TorchTensor Forward(TorchTensor input)

src/Examples/MNIST.cs

+4-4
Original file line numberDiff line numberDiff line change
@@ -52,10 +52,10 @@ private class Model : NN.Module
5252

5353
public Model()
5454
{
55-
RegisterModule(conv1);
56-
RegisterModule(conv2);
57-
RegisterModule(fc1);
58-
RegisterModule(fc2);
55+
RegisterModule("conv1", conv1);
56+
RegisterModule("conv2", conv2);
57+
RegisterModule("lin1", fc1);
58+
RegisterModule("lin2", fc2);
5959
}
6060

6161
public TorchTensor Forward(TorchTensor input)

0 commit comments

Comments
 (0)