Skip to content

Commit aaa6ba2

Browse files
arvnds authored and copybara-github committed
Internal change
PiperOrigin-RevId: 375688815
1 parent 6f33d91 commit aaa6ba2

File tree

2 files changed

+46
-11
lines changed

2 files changed

+46
-11
lines changed

tensorflow_decision_forests/keras/core.py

Lines changed: 11 additions & 4 deletions
Original file line number | Diff line number | Diff line change
@@ -588,10 +588,6 @@ def train_step(self, data):
588588
logging.info("Collect training examples.\nFeatures: %s\nLabel: %s",
589589
train_x, train_y)
590590

591-
if len(train_y.shape) != 1:
592-
raise ValueError(
593-
"Expecting label of rank 1. Got {} instead.".format(train_y))
594-
595591
if self._preprocessing is not None:
596592
train_x = self._preprocessing(train_x)
597593
if self._verbose:
@@ -621,6 +617,16 @@ def train_step(self, data):
621617
f"The training label tensor is expected to be a tensor. Got {train_y}"
622618
" instead.")
623619

620+
if len(train_y.shape) != 1:
621+
if self._verbose:
622+
logging.info("Squeezing labels to [batch_size] from [batch_size, 1].")
623+
train_y = tf.squeeze(train_y, axis=1)
624+
625+
if len(train_y.shape) != 1:
626+
raise ValueError(
627+
"Labels can either be passed in as [batch_size, 1] or [batch_size]. "
628+
"Invalid shape %s." % train_y.shape)
629+
624630
# List the input features and their semantics.
625631
assert self._semantics is None, "The model is already trained"
626632
self._semantics = tf_core.infer_semantic(
@@ -750,6 +756,7 @@ def fit(self,
750756
# end of the epoch. This may fail in case any of the `on_train_batch_*`
751757
# callbacks calls `evaluate()` before the end of the 1st epoch.
752758
self._train_on_evaluate = True
759+
753760
try:
754761
history = super(CoreModel, self).fit(
755762
x=x, y=y, epochs=1, callbacks=callbacks, **kwargs)

tensorflow_decision_forests/keras/keras_test.py

Lines changed: 35 additions & 7 deletions
Original file line number | Diff line number | Diff line change
@@ -22,7 +22,7 @@
2222
import os
2323
import shutil
2424
import subprocess
25-
from typing import List, Tuple, Any, Optional
25+
from typing import List, Tuple, Any, Optional, Type
2626

2727
from absl import flags
2828
from absl import logging
@@ -830,7 +830,9 @@ def _synthetic_train_and_test(
830830
test_numerical: Optional[bool] = False,
831831
test_multidimensional_numerical: Optional[bool] = False,
832832
test_categorical: Optional[bool] = False,
833-
test_categorical_set: Optional[bool] = False):
833+
test_categorical_set: Optional[bool] = False,
834+
label_shape: Optional[int] = None,
835+
fit_raises: Optional[Type[Exception]] = None):
834836
"""Trains a model on a synthetic dataset."""
835837

836838
train_path = os.path.join(self.get_temp_dir(), "train.rio.gz")
@@ -868,12 +870,13 @@ def _synthetic_train_and_test(
868870
popen.wait()
869871

870872
feature_spec = {}
873+
label_shape = [label_shape] if label_shape else []
871874
if task == keras.Task.CLASSIFICATION:
872-
feature_spec["LABEL"] = tf.io.FixedLenFeature([], tf.int64)
875+
feature_spec["LABEL"] = tf.io.FixedLenFeature(label_shape, tf.int64)
873876
elif task == keras.Task.REGRESSION:
874-
feature_spec["LABEL"] = tf.io.FixedLenFeature([], tf.float32)
877+
feature_spec["LABEL"] = tf.io.FixedLenFeature(label_shape, tf.float32)
875878
elif task == keras.Task.RANKING:
876-
feature_spec["LABEL"] = tf.io.FixedLenFeature([], tf.float32)
879+
feature_spec["LABEL"] = tf.io.FixedLenFeature(label_shape, tf.float32)
877880
feature_spec["GROUP"] = tf.io.FixedLenFeature([], tf.string)
878881
else:
879882
assert False
@@ -964,8 +967,16 @@ def on_epoch_end(self, epoch, logs=None):
964967
self.evaluation = model.evaluate(test_dataset)
965968

966969
callback = _TestEvalCallback()
967-
history = model.fit(train_dataset, validation_data=test_dataset,
968-
callbacks=[callback])
970+
history = None
971+
if fit_raises is not None:
972+
with self.assertRaises(fit_raises):
973+
model.fit(
974+
train_dataset, validation_data=test_dataset, callbacks=[callback])
975+
else:
976+
history = model.fit(
977+
train_dataset, validation_data=test_dataset, callbacks=[callback])
978+
if history is None:
979+
return
969980
model.summary()
970981

971982
train_evaluation = model.evaluate(train_dataset)
@@ -991,6 +1002,23 @@ def test_synthetic_classification_numerical(self):
9911002
self._synthetic_train_and_test(
9921003
keras.Task.CLASSIFICATION, 0.8, 0.72, test_numerical=True)
9931004

1005+
def test_synthetic_classification_squeeze_label(self):
1006+
self._synthetic_train_and_test(
1007+
keras.Task.CLASSIFICATION,
1008+
0.8,
1009+
0.72,
1010+
test_numerical=True,
1011+
label_shape=1)
1012+
1013+
def test_synthetic_classification_squeeze_label_invalid_shape(self):
1014+
self._synthetic_train_and_test(
1015+
keras.Task.CLASSIFICATION,
1016+
0.8,
1017+
0.72,
1018+
test_numerical=True,
1019+
label_shape=2,
1020+
fit_raises=ValueError)
1021+
9941022
def test_synthetic_classification_categorical(self):
9951023
self._synthetic_train_and_test(
9961024
keras.Task.CLASSIFICATION, 0.95, 0.70, test_categorical=True)

0 commit comments

Comments
 (0)