Skip to content

Commit 5cbbca7

Browse files
committed
combination of loss
1 parent 83b6cf6 commit 5cbbca7

File tree

1 file changed

+26
-0
lines changed

1 file changed

+26
-0
lines changed

pymic/loss/seg/combined.py

Lines changed: 26 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,26 @@
1+
# -*- coding: utf-8 -*-
2+
from __future__ import print_function, division
3+
4+
import torch
5+
import torch.nn as nn
6+
7+
class CombinedLoss(nn.Module):
8+
def __init__(self, params, loss_dict):
9+
super(CombinedLoss, self).__init__()
10+
loss_names = params['loss_type']
11+
self.loss_weight = params['loss_weight']
12+
assert (len(loss_names) == len(self.loss_weight))
13+
self.loss_list = []
14+
for loss_name in loss_names:
15+
if(loss_name in loss_dict):
16+
one_loss = loss_dict[loss_name](params)
17+
self.loss_list.append(one_loss)
18+
else:
19+
raise ValueError("{0:} is not defined, or has not been added to the \
20+
loss dictionary".format(loss_name))
21+
22+
def forward(self, loss_input_dict):
23+
loss_value = 0.0
24+
for i in range(len(self.loss_list)):
25+
loss_value = self.loss_weight[i] + self.loss_list[i](loss_input_dict)
26+
return loss_value

0 commit comments

Comments
 (0)