Skip to content

Commit 1fce760

Browse files
committed
implementing 2d dice loss instead of binary cross-entropy loss.
1 parent 0716e8c commit 1fce760

File tree

1 file changed

+76
-0
lines changed

1 file changed

+76
-0
lines changed

tutorial/utils.py

Lines changed: 76 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,10 @@
44
import numpy as np
55
from dataset import UCLSegmentation
66

7+
import logging
8+
from torch import nn
9+
log = logging.getLogger(__name__)
10+
711

812
def save_checkpoint(state, filename='my_checkpoint.pth.tar'):
913
print('=> Saving checkpoint')
@@ -95,3 +99,75 @@ def get_loaders(
9599
)
96100

97101
return train_loader, val_loader
102+
103+
104+
import logging
105+
106+
import torch
107+
from torch import nn
108+
109+
log = logging.getLogger(__name__)
110+
111+
112+
class DiceLoss(nn.Module):
113+
def __init__(self, weight=None, size_average=True):
114+
super(DiceLoss, self).__init__()
115+
116+
def forward(self, inputs: torch.Tensor, targets: torch.Tensor, smooth=1e-4):
117+
# comment out if your model contains a sigmoid or equivalent activation layer
118+
inputs = torch.sigmoid(inputs)
119+
120+
# flatten label and prediction tensors
121+
inputs = inputs.view(-1)
122+
targets = targets.view(-1)
123+
124+
intersection = (inputs * targets).sum()
125+
dice = (2.0 * intersection + smooth) / (inputs.sum() + targets.sum() + smooth)
126+
127+
return 1 - dice
128+
129+
130+
class DiceLoss2D(nn.Module):
131+
"""Originally implemented by Cong Gao."""
132+
133+
def __init__(self, skip_bg=False):
134+
super(DiceLoss2D, self).__init__()
135+
136+
self.skip_bg = skip_bg
137+
138+
def forward(self, inputs, target):
139+
# Add this to numerator and denominator to avoid divide by zero when nothing is segmented
140+
# and ground truth is also empty (denominator term).
141+
# Also allow a Dice of 1 (-1) for this case (both terms).
142+
eps = 1.0e-4
143+
144+
if self.skip_bg:
145+
# numerator of Dice, for each class except class 0 (background)
146+
numerators = 2 * torch.sum(target[:, 1:] * inputs[:, 1:], dim=(2, 3)) + eps
147+
148+
# denominator of Dice, for each class except class 0 (background)
149+
denominators = (
150+
torch.sum(target[:, 1:] * target[:, 1:, :, :], dim=(2, 3))
151+
+ torch.sum(inputs[:, 1:] * inputs[:, 1:], dim=(2, 3))
152+
+ eps
153+
)
154+
155+
# minus one to exclude the background class
156+
num_classes = inputs.shape[1] - 1
157+
else:
158+
# numerator of Dice, for each class
159+
numerators = 2 * torch.sum(target * inputs, dim=(2, 3)) + eps
160+
161+
# denominator of Dice, for each class
162+
denominators = torch.sum(target * target, dim=(2, 3)) + torch.sum(inputs * inputs, dim=(2, 3)) + eps
163+
164+
num_classes = inputs.shape[1]
165+
166+
# Dice coefficients for each image in the batch, for each class
167+
dices = 1 - (numerators / denominators)
168+
169+
# compute average Dice score for each image in the batch
170+
avg_dices = torch.sum(dices, dim=1) / num_classes
171+
172+
# compute average over the batch
173+
return torch.mean(avg_dices)

0 commit comments

Comments
 (0)