From 58f5f92fa271ae2b5ecff0d53bbfb07ccbe02bf9 Mon Sep 17 00:00:00 2001
From: JefJ <26346574+jejon@users.noreply.github.com>
Date: Fri, 22 Nov 2024 17:03:12 +0100
Subject: [PATCH] =?UTF-8?q?=F0=9F=94=A7=20Fix=20dimension=20indexing=20in?=
 =?UTF-8?q?=20SoftmaxND=20+=20add=20unit=20tests=20for=20SoftmaxND=20and?=
 =?UTF-8?q?=20LogSoftmaxND?=
MIME-Version: 1.0
Content-Type: text/plain; charset=UTF-8
Content-Transfer-Encoding: 8bit

---
 src/landmarker/models/utils.py |  7 +++++--
 tests/test_models.py           | 35 ++++++++++++++++++++++++++++++++++
 2 files changed, 40 insertions(+), 2 deletions(-)

diff --git a/src/landmarker/models/utils.py b/src/landmarker/models/utils.py
index 34f4071..37cb1ed 100644
--- a/src/landmarker/models/utils.py
+++ b/src/landmarker/models/utils.py
@@ -5,10 +5,13 @@
 class SoftmaxND(nn.Module):
     def __init__(self, spatial_dims):
         super().__init__()
-        self.dim = (-2, -1) if spatial_dims == 2 else (-3, -2, -2)
+        self.dim = (-2, -1) if spatial_dims == 2 else (-3, -2, -1)
 
     def forward(self, x):
-        out = torch.exp(x - torch.max(x, dim=self.dim, keepdim=True)[0])
+        max_val = x
+        for d in self.dim:
+            max_val, _ = torch.max(max_val, dim=d, keepdim=True)
+        out = torch.exp(x - max_val)
         return out / torch.sum(out, dim=self.dim, keepdim=True)
 
 
diff --git a/tests/test_models.py b/tests/test_models.py
index 217ece4..4bcb9b4 100644
--- a/tests/test_models.py
+++ b/tests/test_models.py
@@ -15,6 +15,7 @@
     ProbSpatialConfigurationNet,
     SpatialConfigurationNet,
 )
+from landmarker.models.utils import LogSoftmaxND, SoftmaxND
 
 
 def test_original_spatial_configuration_net():
@@ -267,3 +268,37 @@ def test_coord_conv_layer_coord_channels_range():
     # check that the output values are within the range [-1, 1]
     assert out.shape == torch.Size([2, 5, 64, 64])
     assert (-1 <= out[:, 3:]).all() and (out[:, 3:] <= 1).all()
+
+
+def test_softmax_nd():
+    """Test the SoftmaxND class."""
+    # Test for 2D case
+    softmax_2d = SoftmaxND(spatial_dims=2)
+    x = torch.randn(1, 3, 4, 4)
+    output = softmax_2d(x)
+    assert output.shape == x.shape
+    assert torch.allclose(torch.sum(output, dim=(-2, -1)), torch.ones(1, 3))
+
+    # Test for 3D case
+    softmax_3d = SoftmaxND(spatial_dims=3)
+    x = torch.randn(1, 3, 4, 4, 4)
+    output = softmax_3d(x)
+    assert output.shape == x.shape
+    assert torch.allclose(torch.sum(output, dim=(-3, -2, -1)), torch.ones(1, 3))
+
+
+def test_log_softmax_nd():
+    """Test the LogSoftmaxND class."""
+    # Test for 2D case
+    log_softmax_2d = LogSoftmaxND(spatial_dims=2)
+    x = torch.randn(1, 3, 4, 4)
+    output = log_softmax_2d(x)
+    assert output.shape == x.shape
+    assert torch.allclose(torch.sum(torch.exp(output), dim=(-2, -1)), torch.ones(1, 3))
+
+    # Test for 3D case
+    log_softmax_3d = LogSoftmaxND(spatial_dims=3)
+    x = torch.randn(1, 3, 4, 4, 4)
+    output = log_softmax_3d(x)
+    assert output.shape == x.shape
+    assert torch.allclose(torch.sum(torch.exp(output), dim=(-3, -2, -1)), torch.ones(1, 3))
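
Reviewer note, not part of the commit: the old forward() was broken
independently of the (-3, -2, -2) typo (which reduced the -2 axis twice
and never reduced -1), because torch.max only accepts a single int for
dim, while torch.sum accepts a tuple of dims. Hence the per-dimension
loop in the fix. A minimal repro of the failure mode, assuming only that
PyTorch is installed:

    import torch

    x = torch.randn(1, 3, 4, 4)

    # torch.max reduces a single dimension; a tuple for dim is rejected.
    try:
        torch.max(x, dim=(-2, -1), keepdim=True)
    except TypeError as err:
        print("torch.max with tuple dim:", err)

    # torch.sum, by contrast, reduces over several dims at once.
    print(torch.sum(x, dim=(-2, -1), keepdim=True).shape)  # torch.Size([1, 3, 1, 1])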
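
On PyTorch >= 1.7 the same stabilised softmax can be written without the
loop, since torch.amax (unlike torch.max) reduces over a tuple of dims
in one call. This is a hypothetical alternative sketch, not what the
patch applies; SoftmaxNDAmax is an illustrative name:

    import torch
    from torch import nn

    class SoftmaxNDAmax(nn.Module):
        def __init__(self, spatial_dims):
            super().__init__()
            self.dim = (-2, -1) if spatial_dims == 2 else (-3, -2, -1)

        def forward(self, x):
            # Subtract the spatial max for numerical stability, then
            # normalise so the spatial dims of each channel sum to 1.
            out = torch.exp(x - torch.amax(x, dim=self.dim, keepdim=True))
            return out / torch.sum(out, dim=self.dim, keepdim=True)

The looped version in the patch stays compatible with older PyTorch
releases; the two forms should agree to within floating-point tolerance.
The new tests can be run in isolation with, e.g.,
pytest tests/test_models.py -k softmax, assuming pytest is the project's
test runner.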