Skip to content

Commit 74a79ea

Browse files
committed
RegressionModel
1 parent 840813b commit 74a79ea

File tree

1 file changed

+8
-8
lines changed

1 file changed

+8
-8
lines changed

src/sparcscore/ml/plmodels.py

Lines changed: 8 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -84,14 +84,14 @@ def __init__(self, model_type="VGG2", **kwargs):
8484
dimensions=128,
8585
num_classes=self.hparams["num_classes"])
8686

87-
## add deprecated type for backward compatability
87+
## add deprecated type for backward compatibility
8888
elif model_type == "VGG1_old":
8989
self.network = _VGG1(in_channels=self.hparams["num_in_channels"],
9090
cfg = "B",
9191
dimensions=128,
9292
num_classes=self.hparams["num_classes"])
9393

94-
## add deprecated type for backward compatability
94+
## add deprecated type for backward compatibility
9595
elif model_type == "VGG2_old":
9696
self.network = _VGG2(in_channels=self.hparams["num_in_channels"],
9797
cfg = "B",
@@ -123,7 +123,7 @@ def training_step(self, batch, batch_idx):
123123
output_softmax = self.network(data)
124124
loss = F.nll_loss(output_softmax, label)
125125

126-
#calculate accuracy
126+
#calculate accuracy
127127
probabilities = torch.exp(output_softmax)
128128
pred_labels = torch.argmax(probabilities, dim=1)
129129
acc = self.accuracy(pred_labels, label)
@@ -174,7 +174,7 @@ def __init__(self, model_type="VGG2_regression", **kwargs):
174174

175175
# Define the regression model
176176
if model_type == "VGG2_regression":
177-
self.network = VGG2_regression(in_channels=self.hparams["num_in_channels"], cfg="B")
177+
self.network = VGG2_regression(in_channels=self.hparams["num_in_channels"], cfg="B", cfg_MLP="A")
178178

179179
# Initialize metrics for regression model
180180
self.mse = torchmetrics.MeanSquaredError() # MSE metric for regression
@@ -200,9 +200,9 @@ def configure_optimizers(self):
200200
raise ValueError("No optimizer specified in hparams")
201201
return optimizer
202202

203-
def training_step(self, batch, batch_idx):
203+
def training_step(self, batch):
204204
data, target = batch
205-
output = self.network(data) # Forward pass, only classification, no softmax
205+
output = self.network(data) # Forward pass, only one output
206206
loss = F.mse_loss(output, target) # L2 loss
207207

208208
# accuracy metrics for regression???
@@ -213,7 +213,7 @@ def training_step(self, batch, batch_idx):
213213

214214
return loss
215215

216-
def validation_step(self, batch, batch_idx):
216+
def validation_step(self, batch):
217217
data, target = batch
218218
output = self.network(data)
219219
loss = F.mse_loss(output, target)
@@ -226,7 +226,7 @@ def validation_step(self, batch, batch_idx):
226226

227227
return loss
228228

229-
def test_step(self, batch, batch_idx):
229+
def test_step(self, batch):
230230
data, target = batch
231231
output = self.network(data)
232232
loss = F.mse_loss(output, target)

0 commit comments

Comments
 (0)