Skip to content

Commit b49d29c

Browse files
committed
implementing 2d dice loss instead of binary cross-entropy loss.
1 parent 1fce760 commit b49d29c

File tree

1 file changed

+2
-19
lines changed

1 file changed

+2
-19
lines changed

tutorial/utils.py

+2-19
Original file line numberDiff line numberDiff line change
@@ -109,24 +109,6 @@ def get_loaders(
109109
log = logging.getLogger(__name__)
110110

111111

112-
class DiceLoss(nn.Module):
113-
def __init__(self, weight=None, size_average=True):
114-
super(DiceLoss, self).__init__()
115-
116-
def forward(self, inputs: torch.Tensor, targets: torch.Tensor, smooth=1e-4):
117-
# comment out if your model contains a sigmoid or equivalent activation layer
118-
inputs = torch.sigmoid(inputs)
119-
120-
# flatten label and prediction tensors
121-
inputs = inputs.view(-1)
122-
targets = targets.view(-1)
123-
124-
intersection = (inputs * targets).sum()
125-
dice = (2.0 * intersection + smooth) / (inputs.sum() + targets.sum() + smooth)
126-
127-
return 1 - dice
128-
129-
130112
class DiceLoss2D(nn.Module):
131113
"""Originally implemented by Cong Gao."""
132114

@@ -170,4 +152,5 @@ def forward(self, inputs, target):
170152
avg_dices = torch.sum(dices, dim=1) / num_classes
171153

172154
# compute average over the batch
173-
return torch.mean(avg_dices)
155+
return torch.mean(avg_dices)
156+

0 commit comments

Comments
 (0)