@@ -84,14 +84,14 @@ def __init__(self, model_type="VGG2", **kwargs):
84
84
dimensions = 128 ,
85
85
num_classes = self .hparams ["num_classes" ])
86
86
87
- ## add deprecated type for backward compatability
87
+ ## add deprecated type for backward compatibility
88
88
elif model_type == "VGG1_old" :
89
89
self .network = _VGG1 (in_channels = self .hparams ["num_in_channels" ],
90
90
cfg = "B" ,
91
91
dimensions = 128 ,
92
92
num_classes = self .hparams ["num_classes" ])
93
93
94
- ## add deprecated type for backward compatability
94
+ ## add deprecated type for backward compatibility
95
95
elif model_type == "VGG2_old" :
96
96
self .network = _VGG2 (in_channels = self .hparams ["num_in_channels" ],
97
97
cfg = "B" ,
@@ -123,7 +123,7 @@ def training_step(self, batch, batch_idx):
123
123
output_softmax = self .network (data )
124
124
loss = F .nll_loss (output_softmax , label )
125
125
126
- #calculate accuracy
126
+ #calculate accuracy
127
127
probabilities = torch .exp (output_softmax )
128
128
pred_labels = torch .argmax (probabilities , dim = 1 )
129
129
acc = self .accuracy (pred_labels , label )
@@ -174,7 +174,7 @@ def __init__(self, model_type="VGG2_regression", **kwargs):
174
174
175
175
# Define the regression model
176
176
if model_type == "VGG2_regression" :
177
- self .network = VGG2_regression (in_channels = self .hparams ["num_in_channels" ], cfg = "B" )
177
+ self .network = VGG2_regression (in_channels = self .hparams ["num_in_channels" ], cfg = "B" , cfg_MLP = "A" )
178
178
179
179
# Initialize metrics for regression model
180
180
self .mse = torchmetrics .MeanSquaredError () # MSE metric for regression
@@ -200,9 +200,9 @@ def configure_optimizers(self):
200
200
raise ValueError ("No optimizer specified in hparams" )
201
201
return optimizer
202
202
203
- def training_step (self , batch , batch_idx ):
203
+ def training_step (self , batch ):
204
204
data , target = batch
205
- output = self .network (data ) # Forward pass, only classification, no softmax
205
+ output = self .network (data ) # Forward pass, only one output
206
206
loss = F .mse_loss (output , target ) # L2 loss
207
207
208
208
# accuracy metrics for regression???
@@ -213,7 +213,7 @@ def training_step(self, batch, batch_idx):
213
213
214
214
return loss
215
215
216
- def validation_step (self , batch , batch_idx ):
216
+ def validation_step (self , batch ):
217
217
data , target = batch
218
218
output = self .network (data )
219
219
loss = F .mse_loss (output , target )
@@ -226,7 +226,7 @@ def validation_step(self, batch, batch_idx):
226
226
227
227
return loss
228
228
229
- def test_step (self , batch , batch_idx ):
229
+ def test_step (self , batch ):
230
230
data , target = batch
231
231
output = self .network (data )
232
232
loss = F .mse_loss (output , target )
0 commit comments