Skip to content

Commit 840813b

Browse files
committed
small changes to the VGG regression model
1 parent 72ee7d6 commit 840813b

File tree

1 file changed

+9
-9
lines changed

1 file changed

+9
-9
lines changed

src/sparcscore/ml/models.py

Lines changed: 9 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -55,7 +55,7 @@ def make_layers(self, cfg, in_channels, batch_norm = True):
5555
i +=1
5656
return nn.Sequential(OrderedDict(layers))
5757

58-
def make_layers_MLP(self, cfg_MLP, cfg, single_output = False):
58+
def make_layers_MLP(self, cfg_MLP, cfg, regression = False):
5959
"""
6060
Create sequential models layers according to the chosen configuration provided in
6161
cfg for the MLP.
@@ -77,13 +77,13 @@ def make_layers_MLP(self, cfg_MLP, cfg, single_output = False):
7777
for out_features in cfg_MLP:
7878
if out_features == "M":
7979
layers += [(f"MLP_relu{i}", nn.ReLU(True)), (f"MLP_dropout{i}", nn.Dropout())]
80-
i+=1
80+
i += 1
8181
else:
8282
linear = (f"MLP_linear{i}", nn.Linear(in_features, out_features))
8383
layers += [linear]
8484
in_features = out_features
8585

86-
if single_output:
86+
if regression: # if regression is True, make the final layer a single output
8787
linear = (f"MLP_linear_final", nn.Linear(in_features, 1))
8888
layers += [linear]
8989
else:
@@ -112,7 +112,7 @@ class VGG1(VGGBase):
112112
Instance of VGGBase with the model architecture 1.
113113
"""
114114
def __init__(self,
115-
cfg = "B",
115+
cfg = "B", # default configuration
116116
cfg_MLP = "A",
117117
dimensions = 196,
118118
in_channels = 1,
@@ -167,17 +167,15 @@ class VGG2_regression(VGGBase):
167167
"""
168168
def __init__(self,
169169
cfg = "B",
170-
cfg_MLP = "B",
171-
dimensions = 196,
172-
in_channels = 1,
173-
num_classes = 2,
170+
cfg_MLP = "A",
171+
in_channels = 1,
174172
):
175173

176174
super(VGG2_regression, self).__init__()
177175

178176
self.norm = nn.BatchNorm2d(in_channels)
179177
self.features = self.make_layers(self.cfgs[cfg], in_channels)
180-
self.classifier = self.make_layers_MLP(self.cfgs_MLP[cfg_MLP], self.cfgs[cfg], single_output = True)
178+
self.classifier = self.make_layers_MLP(self.cfgs_MLP[cfg_MLP], self.cfgs[cfg], regression=True) # regression is set to True to make the final layer a single output
181179

182180
def vgg(cfg, in_channels, **kwargs):
183181
model = VGG2_regression(self.make_layers(self.cfgs[cfg], in_channels), **kwargs)
@@ -186,7 +184,9 @@ def vgg(cfg, in_channels, **kwargs):
186184
def forward(self, x):
187185
x = self.norm(x)
188186
x = self.features(x)
187+
189188
x = torch.flatten(x, 1)
189+
190190
x = self.classifier(x)
191191
return x
192192

0 commit comments

Comments
 (0)