@@ -84,14 +84,14 @@ def __init__(self, model_type="VGG2", **kwargs):
8484 dimensions = 128 ,
8585 num_classes = self .hparams ["num_classes" ])
8686
87- ## add deprecated type for backward compatability
87+ ## add deprecated type for backward compatibility
8888 elif model_type == "VGG1_old" :
8989 self .network = _VGG1 (in_channels = self .hparams ["num_in_channels" ],
9090 cfg = "B" ,
9191 dimensions = 128 ,
9292 num_classes = self .hparams ["num_classes" ])
9393
94- ## add deprecated type for backward compatability
94+ ## add deprecated type for backward compatibility
9595 elif model_type == "VGG2_old" :
9696 self .network = _VGG2 (in_channels = self .hparams ["num_in_channels" ],
9797 cfg = "B" ,
@@ -123,7 +123,7 @@ def training_step(self, batch, batch_idx):
123123 output_softmax = self .network (data )
124124 loss = F .nll_loss (output_softmax , label )
125125
126- #calculate accuracy
126+ #calculate accuracy
127127 probabilities = torch .exp (output_softmax )
128128 pred_labels = torch .argmax (probabilities , dim = 1 )
129129 acc = self .accuracy (pred_labels , label )
@@ -174,7 +174,7 @@ def __init__(self, model_type="VGG2_regression", **kwargs):
174174
175175 # Define the regression model
176176 if model_type == "VGG2_regression" :
177- self .network = VGG2_regression (in_channels = self .hparams ["num_in_channels" ], cfg = "B" )
177+ self .network = VGG2_regression (in_channels = self .hparams ["num_in_channels" ], cfg = "B" , cfg_MLP = "A" )
178178
179179 # Initialize metrics for regression model
180180 self .mse = torchmetrics .MeanSquaredError () # MSE metric for regression
@@ -200,9 +200,9 @@ def configure_optimizers(self):
200200 raise ValueError ("No optimizer specified in hparams" )
201201 return optimizer
202202
203- def training_step (self , batch , batch_idx ):
203+ def training_step (self , batch ):
204204 data , target = batch
205- output = self .network (data ) # Forward pass, only classification, no softmax
205+ output = self .network (data ) # Forward pass, only one output
206206 loss = F .mse_loss (output , target ) # L2 loss
207207
208208 # accuracy metrics for regression???
@@ -213,7 +213,7 @@ def training_step(self, batch, batch_idx):
213213
214214 return loss
215215
216- def validation_step (self , batch , batch_idx ):
216+ def validation_step (self , batch ):
217217 data , target = batch
218218 output = self .network (data )
219219 loss = F .mse_loss (output , target )
@@ -226,7 +226,7 @@ def validation_step(self, batch, batch_idx):
226226
227227 return loss
228228
229- def test_step (self , batch , batch_idx ):
229+ def test_step (self , batch ):
230230 data , target = batch
231231 output = self .network (data )
232232 loss = F .mse_loss (output , target )
0 commit comments