@@ -43,6 +43,7 @@ def main(_argv):
43
43
bbox_tensors = []
44
44
for i , fm in enumerate (feature_maps ):
45
45
bbox_tensor = decode_train (fm , NUM_CLASS , STRIDES , ANCHORS , i )
46
+ bbox_tensors .append (fm )
46
47
bbox_tensors .append (bbox_tensor )
47
48
model = tf .keras .Model (input_layer , bbox_tensors )
48
49
else :
@@ -51,13 +52,15 @@ def main(_argv):
51
52
bbox_tensors = []
52
53
for i , fm in enumerate (feature_maps ):
53
54
bbox_tensor = decode_train (fm , NUM_CLASS , STRIDES , ANCHORS , i )
55
+ bbox_tensors .append (fm )
54
56
bbox_tensors .append (bbox_tensor )
55
57
model = tf .keras .Model (input_layer , bbox_tensors )
56
58
elif FLAGS .model == 'yolov4' :
57
59
feature_maps = YOLOv4 (input_layer , NUM_CLASS )
58
60
bbox_tensors = []
59
61
for i , fm in enumerate (feature_maps ):
60
62
bbox_tensor = decode_train (fm , NUM_CLASS , STRIDES , ANCHORS , i , XYSCALE )
63
+ bbox_tensors .append (fm )
61
64
bbox_tensors .append (bbox_tensor )
62
65
model = tf .keras .Model (input_layer , bbox_tensors )
63
66
@@ -89,8 +92,7 @@ def train_step(image_data, target):
89
92
# optimizing process
90
93
for i in range (3 ):
91
94
conv , pred = pred_result [i * 2 ], pred_result [i * 2 + 1 ]
92
- # loss_items = compute_loss(pred, conv, target[i][0], target[i][1], STRIDES=STRIDES, NUM_CLASS=NUM_CLASS, IOU_LOSS_THRESH=IOU_LOSS_THRESH, i=i)
93
- loss_items = compute_loss (pred , conv , * target [i ], i )
95
+ loss_items = compute_loss (pred , conv , target [i ][0 ], target [i ][1 ], STRIDES = STRIDES , NUM_CLASS = NUM_CLASS , IOU_LOSS_THRESH = IOU_LOSS_THRESH , i = i )
94
96
giou_loss += loss_items [0 ]
95
97
conf_loss += loss_items [1 ]
96
98
prob_loss += loss_items [2 ]
@@ -129,7 +131,7 @@ def test_step(image_data, target):
129
131
# optimizing process
130
132
for i in range (3 ):
131
133
conv , pred = pred_result [i * 2 ], pred_result [i * 2 + 1 ]
132
- loss_items = compute_loss (pred , conv , * target [i ], STRIDES = STRIDES , NUM_CLASS = NUM_CLASS , IOU_LOSS_THRESH = IOU_LOSS_THRESH , i = i )
134
+ loss_items = compute_loss (pred , conv , target [i ][ 0 ], target [ i ][ 1 ], STRIDES = STRIDES , NUM_CLASS = NUM_CLASS , IOU_LOSS_THRESH = IOU_LOSS_THRESH , i = i )
133
135
giou_loss += loss_items [0 ]
134
136
conf_loss += loss_items [1 ]
135
137
prob_loss += loss_items [2 ]
0 commit comments