4
4
import numpy as np
5
5
from dataset import UCLSegmentation
6
6
7
+ import logging
8
+ from torch import nn
9
+ log = logging .getLogger (__name__ )
10
+
7
11
8
12
def save_checkpoint (state , filename = 'my_checkpoint.pth.tar' ):
9
13
print ('=> Saving checkpoint' )
@@ -95,3 +99,75 @@ def get_loaders(
95
99
)
96
100
97
101
return train_loader , val_loader
102
+
103
+
104
+ import logging
105
+
106
+ import torch
107
+ from torch import nn
108
+
109
+ log = logging .getLogger (__name__ )
110
+
111
+
112
+ class DiceLoss (nn .Module ):
113
+ def __init__ (self , weight = None , size_average = True ):
114
+ super (DiceLoss , self ).__init__ ()
115
+
116
+ def forward (self , inputs : torch .Tensor , targets : torch .Tensor , smooth = 1e-4 ):
117
+ # comment out if your model contains a sigmoid or equivalent activation layer
118
+ inputs = torch .sigmoid (inputs )
119
+
120
+ # flatten label and prediction tensors
121
+ inputs = inputs .view (- 1 )
122
+ targets = targets .view (- 1 )
123
+
124
+ intersection = (inputs * targets ).sum ()
125
+ dice = (2.0 * intersection + smooth ) / (inputs .sum () + targets .sum () + smooth )
126
+
127
+ return 1 - dice
128
+
129
+
130
+ class DiceLoss2D (nn .Module ):
131
+ """Originally implemented by Cong Gao."""
132
+
133
+ def __init__ (self , skip_bg = False ):
134
+ super (DiceLoss2D , self ).__init__ ()
135
+
136
+ self .skip_bg = skip_bg
137
+
138
+ def forward (self , inputs , target ):
139
+ # Add this to numerator and denominator to avoid divide by zero when nothing is segmented
140
+ # and ground truth is also empty (denominator term).
141
+ # Also allow a Dice of 1 (-1) for this case (both terms).
142
+ eps = 1.0e-4
143
+
144
+ if self .skip_bg :
145
+ # numerator of Dice, for each class except class 0 (background)
146
+ numerators = 2 * torch .sum (target [:, 1 :] * inputs [:, 1 :], dim = (2 , 3 )) + eps
147
+
148
+ # denominator of Dice, for each class except class 0 (background)
149
+ denominators = (
150
+ torch .sum (target [:, 1 :] * target [:, 1 :, :, :], dim = (2 , 3 ))
151
+ + torch .sum (inputs [:, 1 :] * inputs [:, 1 :], dim = (2 , 3 ))
152
+ + eps
153
+ )
154
+
155
+ # minus one to exclude the background class
156
+ num_classes = inputs .shape [1 ] - 1
157
+ else :
158
+ # numerator of Dice, for each class
159
+ numerators = 2 * torch .sum (target * inputs , dim = (2 , 3 )) + eps
160
+
161
+ # denominator of Dice, for each class
162
+ denominators = torch .sum (target * target , dim = (2 , 3 )) + torch .sum (inputs * inputs , dim = (2 , 3 )) + eps
163
+
164
+ num_classes = inputs .shape [1 ]
165
+
166
+ # Dice coefficients for each image in the batch, for each class
167
+ dices = 1 - (numerators / denominators )
168
+
169
+ # compute average Dice score for each image in the batch
170
+ avg_dices = torch .sum (dices , dim = 1 ) / num_classes
171
+
172
+ # compute average over the batch
173
+ return torch .mean (avg_dices )
0 commit comments