Skip to content

Commit a487af0

Browse files
authored
Add argmax and argmax2d activations
1 parent b77cf5a commit a487af0

File tree

1 file changed

+14
-0
lines changed

1 file changed

+14
-0
lines changed

segmentation_models_pytorch/base/modules.py

+14
Original file line numberDiff line numberDiff line change
@@ -64,6 +64,16 @@ def forward(self, x):
6464
return x * self.cSE(x) + x * self.sSE(x)
6565

6666

67+
class ArgMax(nn.Module):
68+
69+
def __init__(self, dim=None):
70+
super().__init__()
71+
self.dim = dim
72+
73+
def forward(self, x):
74+
return torch.argmax(x, dim=dim)
75+
76+
6777
class Activation(nn.Module):
6878

6979
def __init__(self, name, **params):
@@ -80,6 +90,10 @@ def __init__(self, name, **params):
8090
self.activation = nn.Softmax(**params)
8191
elif name == 'logsoftmax':
8292
self.activation = nn.LogSoftmax(**params)
93+
elif name == 'argmax':
94+
self.activation = ArgMax(**params)
95+
elif name == 'argmax2d':
96+
self.activation = ArgMax(dim=1, **params)
8397
elif callable(name):
8498
self.activation = name(**params)
8599
else:

0 commit comments

Comments
 (0)