File tree Expand file tree Collapse file tree 1 file changed +4
-3
lines changed Expand file tree Collapse file tree 1 file changed +4
-3
lines changed Original file line number Diff line number Diff line change @@ -105,8 +105,8 @@ def triplet_semihard_loss(
105
105
Returns:
106
106
triplet_loss: float scalar with dtype of `y_pred`.
107
107
"""
108
-
109
- labels , embeddings = y_true , y_pred
108
+ labels = tf . convert_to_tensor ( y_true , name = "labels" )
109
+ embeddings = tf . convert_to_tensor ( y_pred , name = "embeddings" )
110
110
111
111
convert_to_float32 = (
112
112
embeddings .dtype == tf .dtypes .float16 or embeddings .dtype == tf .dtypes .bfloat16
@@ -242,7 +242,8 @@ def triplet_hard_loss(
242
242
Returns:
243
243
triplet_loss: float scalar with dtype of `y_pred`.
244
244
"""
245
- labels , embeddings = y_true , y_pred
245
+ labels = tf .convert_to_tensor (y_true , name = "labels" )
246
+ embeddings = tf .convert_to_tensor (y_pred , name = "embeddings" )
246
247
247
248
convert_to_float32 = (
248
249
embeddings .dtype == tf .dtypes .float16 or embeddings .dtype == tf .dtypes .bfloat16
You can’t perform that action at this time.
0 commit comments