File tree Expand file tree Collapse file tree 1 file changed +26
-0
lines changed Expand file tree Collapse file tree 1 file changed +26
-0
lines changed Original file line number Diff line number Diff line change
1
+ # -*- coding: utf-8 -*-
2
+ from __future__ import print_function , division
3
+
4
+ import torch
5
+ import torch .nn as nn
6
+
7
+ class CombinedLoss (nn .Module ):
8
+ def __init__ (self , params , loss_dict ):
9
+ super (CombinedLoss , self ).__init__ ()
10
+ loss_names = params ['loss_type' ]
11
+ self .loss_weight = params ['loss_weight' ]
12
+ assert (len (loss_names ) == len (self .loss_weight ))
13
+ self .loss_list = []
14
+ for loss_name in loss_names :
15
+ if (loss_name in loss_dict ):
16
+ one_loss = loss_dict [loss_name ](params )
17
+ self .loss_list .append (one_loss )
18
+ else :
19
+ raise ValueError ("{0:} is not defined, or has not been added to the \
20
+ loss dictionary" .format (loss_name ))
21
+
22
+ def forward (self , loss_input_dict ):
23
+ loss_value = 0.0
24
+ for i in range (len (self .loss_list )):
25
+ loss_value = self .loss_weight [i ] + self .loss_list [i ](loss_input_dict )
26
+ return loss_value
You can’t perform that action at this time.
0 commit comments