diff --git a/src/sparcscore/ml/plmodels.py b/src/sparcscore/ml/plmodels.py index 3809c16..e3cb158 100644 --- a/src/sparcscore/ml/plmodels.py +++ b/src/sparcscore/ml/plmodels.py @@ -205,7 +205,7 @@ def configure_optimizers(self): def training_step(self, batch): data, target = batch print("target training ", target) - target = target.unsqueeze(0) # Add dimension for regression + target = target.unsqueeze(1) # Add dimension for regression output = self.network(data) # Forward pass, only one output loss = F.mse_loss(output, target) # L2 loss @@ -219,7 +219,7 @@ def training_step(self, batch): def validation_step(self, batch): data, target = batch - target = target.unsqueeze(0) + target = target.unsqueeze(1) print("target val ", target) output = self.network(data)