Skip to content

Commit 75ce434

Browse files
authored
Convert inputs to tensor (#2108)
1 parent 3279000 commit 75ce434

File tree

1 file changed

+4
-3
lines changed

1 file changed

+4
-3
lines changed

tensorflow_addons/losses/triplet.py

Lines changed: 4 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -105,8 +105,8 @@ def triplet_semihard_loss(
105105
Returns:
106106
triplet_loss: float scalar with dtype of `y_pred`.
107107
"""
108-
109-
labels, embeddings = y_true, y_pred
108+
labels = tf.convert_to_tensor(y_true, name="labels")
109+
embeddings = tf.convert_to_tensor(y_pred, name="embeddings")
110110

111111
convert_to_float32 = (
112112
embeddings.dtype == tf.dtypes.float16 or embeddings.dtype == tf.dtypes.bfloat16
@@ -242,7 +242,8 @@ def triplet_hard_loss(
242242
Returns:
243243
triplet_loss: float scalar with dtype of `y_pred`.
244244
"""
245-
labels, embeddings = y_true, y_pred
245+
labels = tf.convert_to_tensor(y_true, name="labels")
246+
embeddings = tf.convert_to_tensor(y_pred, name="embeddings")
246247

247248
convert_to_float32 = (
248249
embeddings.dtype == tf.dtypes.float16 or embeddings.dtype == tf.dtypes.bfloat16

0 commit comments

Comments
 (0)