Skip to content

Commit 3e2d68d

Browse files
committed
implemented 2d dice score and pixel normalization. Training loss is actually pretty good now
1 parent ead056e commit 3e2d68d

File tree

2 files changed

+3
-2
lines changed

2 files changed

+3
-2
lines changed

tutorial/train.py

+3-1
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,7 @@
77
import torch.optim as optim
88
from model import UNET
99
from utils import (load_checkpoint, save_checkpoint, get_loaders, check_accuracy, save_predictions_as_imgs)
10+
from utils import DiceLoss2D
1011

1112
# Hyperparameters
1213
LEARNING_RATE = 1e-4
@@ -47,7 +48,8 @@ def train_fn(loader, model, optimizer, loss_fn, scaler):
4748
def main():
4849

4950
model = UNET(in_channels=3, out_channels=1).to(device=DEVICE)
50-
loss_fn = nn.BCEWithLogitsLoss() # LOSS FUNCTION DEFINED HERE. Perhaps change it to dice score
51+
#loss_fn = nn.BCEWithLogitsLoss() # LOSS FUNCTION DEFINED HERE. Perhaps change it to dice score
52+
loss_fn = DiceLoss2D()
5153
optimizer = optim.Adam(model.parameters(), lr=LEARNING_RATE)
5254

5355
train_loader, val_loader = get_loaders(train_dir=TRAIN_IMG_DIR,

tutorial/utils.py

-1
Original file line numberDiff line numberDiff line change
@@ -114,7 +114,6 @@ class DiceLoss2D(nn.Module):
114114

115115
def __init__(self, skip_bg=False):
116116
super(DiceLoss2D, self).__init__()
117-
118117
self.skip_bg = skip_bg
119118

120119
def forward(self, inputs, target):

0 commit comments

Comments
 (0)