Skip to content

Commit 8fafffa

Browse files
committed
revert workaround as it is already fixed in latest keras.
1 parent 98b6714 commit 8fafffa

File tree

2 files changed

+2
-29
lines changed

2 files changed

+2
-29
lines changed

metric.py

Lines changed: 0 additions & 27 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,3 @@
1-
import tensorflow as tf
21
from keras import backend as K
32

43
smooth = 1
@@ -38,32 +37,6 @@ def dice_loss_strict(y_true, y_pred):
3837
return -dice_strict(y_true, y_pred)
3938

4039

41-
def _to_tensor(x, dtype):
42-
x = tf.convert_to_tensor(x)
43-
if x.dtype != dtype:
44-
x = tf.cast(x, dtype)
45-
return x
46-
47-
48-
def _tf_bce(output, target, from_logits=False):
49-
"""Workaround for keras bug with latest tensorflow"""
50-
51-
# Note: tf.nn.softmax_cross_entropy_with_logits
52-
# expects logits, Keras expects probabilities.
53-
if not from_logits:
54-
# transform back to logits
55-
epsilon = _to_tensor(K.epsilon(), output.dtype.base_dtype)
56-
output = tf.clip_by_value(output, epsilon, 1 - epsilon)
57-
output = tf.log(output / (1 - output))
58-
return tf.nn.sigmoid_cross_entropy_with_logits(labels=output, logits=target)
59-
60-
61-
def bce(y_true, y_pred):
62-
# Workaround for shape bug.
63-
y_true.set_shape(y_pred.get_shape())
64-
return K.mean(_tf_bce(y_pred, y_true), axis=-1)
65-
66-
6740
# Sanity check loss functions..
6841
if __name__ == "'__main__":
6942
import numpy as np

model.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -13,7 +13,7 @@
1313
from keras.models import Model
1414

1515
from keras.optimizers import Adam
16-
from metric import dice_loss, dice, bce
16+
from metric import dice_loss, dice
1717
from data import DataManager
1818

1919

@@ -86,7 +86,7 @@ def build_model(optimizer=None):
8686

8787
model = Model(input=inputs, output=[conv10, aux])
8888
model.compile(optimizer=optimizer,
89-
loss={'main_output': dice_loss, 'aux_output': bce},
89+
loss={'main_output': dice_loss, 'aux_output': 'binary_crossentropy'},
9090
metrics={'main_output': dice, 'aux_output': 'acc'},
9191
loss_weights={'main_output': 1, 'aux_output': 0.5})
9292

0 commit comments

Comments
 (0)