diff --git a/src/sparcscore/ml/plmodels.py b/src/sparcscore/ml/plmodels.py index 9a2b755..8d22035 100644 --- a/src/sparcscore/ml/plmodels.py +++ b/src/sparcscore/ml/plmodels.py @@ -84,14 +84,14 @@ def __init__(self, model_type="VGG2", **kwargs): dimensions=128, num_classes=self.hparams["num_classes"]) - ## add deprecated type for backward compatability + ## add deprecated type for backward compatibility elif model_type == "VGG1_old": self.network = _VGG1(in_channels=self.hparams["num_in_channels"], cfg = "B", dimensions=128, num_classes=self.hparams["num_classes"]) - ## add deprecated type for backward compatability + ## add deprecated type for backward compatibility elif model_type == "VGG2_old": self.network = _VGG2(in_channels=self.hparams["num_in_channels"], cfg = "B", @@ -123,7 +123,7 @@ def training_step(self, batch, batch_idx): output_softmax = self.network(data) loss = F.nll_loss(output_softmax, label) - #calculate accuracy + #calculate accuracy probabilities = torch.exp(output_softmax) pred_labels = torch.argmax(probabilities, dim=1) acc = self.accuracy(pred_labels, label) @@ -174,7 +174,7 @@ def __init__(self, model_type="VGG2_regression", **kwargs): # Define the regression model if model_type == "VGG2_regression": - self.network = VGG2_regression(in_channels=self.hparams["num_in_channels"], cfg="B") + self.network = VGG2_regression(in_channels=self.hparams["num_in_channels"], cfg="B", cfg_MLP="A") # Initialize metrics for regression model self.mse = torchmetrics.MeanSquaredError() # MSE metric for regression @@ -200,9 +200,9 @@ def configure_optimizers(self): raise ValueError("No optimizer specified in hparams") return optimizer - def training_step(self, batch, batch_idx): + def training_step(self, batch): data, target = batch - output = self.network(data) # Forward pass, only classification, no softmax + output = self.network(data) # Forward pass, only one output loss = F.mse_loss(output, target) # L2 loss # accuracy metrics for regression??? @@ -213,7 +213,7 @@ def training_step(self, batch, batch_idx): return loss - def validation_step(self, batch, batch_idx): + def validation_step(self, batch): data, target = batch output = self.network(data) loss = F.mse_loss(output, target) @@ -226,7 +226,7 @@ def validation_step(self, batch, batch_idx): return loss - def test_step(self, batch, batch_idx): + def test_step(self, batch): data, target = batch output = self.network(data) loss = F.mse_loss(output, target)