From 27c3e8ee6e4b23a24ca2d6f1c9369215c4332e1c Mon Sep 17 00:00:00 2001 From: Mathieu Guillame-Bert Date: Wed, 26 May 2021 06:47:19 -0700 Subject: [PATCH] Internal change PiperOrigin-RevId: 375931435 --- CHANGELOG.md | 7 +++++ tensorflow_decision_forests/keras/core.py | 30 +++++++++++++++++-- .../keras/keras_test.py | 14 +++++++++ 3 files changed, 49 insertions(+), 2 deletions(-) diff --git a/CHANGELOG.md b/CHANGELOG.md index 3f9ca93d..7bda3679 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -2,6 +2,13 @@ ## 0.1.5 - ???? +### Features + +``` +- Raise an error of the number of classes is greater than 100 (can be disabled). +- Raise an error if the model's task does not match the `pd_dataframe_to_tf_dataset`'s task. +``` + ### Bug fix - Fix failure when input feature contains commas. diff --git a/tensorflow_decision_forests/keras/core.py b/tensorflow_decision_forests/keras/core.py index 161a1a27..bf2a9962 100644 --- a/tensorflow_decision_forests/keras/core.py +++ b/tensorflow_decision_forests/keras/core.py @@ -739,6 +739,16 @@ def fit(self, All other fields are filled as usual for `Keras.Mode.fit()`. """ + # If the dataset was created with "pd_dataframe_to_tf_dataset", ensure that + # the task is correctly set. + if hasattr(x, "_tfdf_task"): + dataset_task = getattr(x, "_tfdf_task") + if dataset_task != self._task: + raise ValueError( + f"The model's `task` attribute ({Task.Name(self._task)}) does " + "not match the `task` attribute passed to " + f"`pd_dataframe_to_tf_dataset` ({Task.Name(dataset_task)}).") + # Call "compile" if the user forgot to do so. if not self._is_compiled: self.compile() @@ -1005,7 +1015,8 @@ def _batch_size(inputs: Union[tf.Tensor, Dict[str, tf.Tensor]]) -> tf.Tensor: def pd_dataframe_to_tf_dataset( dataframe, label: Optional[str] = None, - task: Optional[TaskType] = Task.CLASSIFICATION) -> tf.data.Dataset: + task: Optional[TaskType] = Task.CLASSIFICATION, + max_num_classes: Optional[int] = 100) -> tf.data.Dataset: """Converts a Panda Dataframe into a TF Dataset. Details: @@ -1025,6 +1036,10 @@ def pd_dataframe_to_tf_dataset( dataframe: Pandas dataframe containing a training or evaluation dataset. label: Name of the label column. task: Target task of the dataset. + max_num_classes: Maximum number of classes for a classification task. A high + number of unique value / classes might indicate that the problem is a + regression or a ranking instead of a classification. Set to None to + disable checking the number of classes. Returns: A TensorFlow Dataset. @@ -1035,6 +1050,14 @@ def pd_dataframe_to_tf_dataset( if task == Task.CLASSIFICATION and label is not None: classification_classes = dataframe[label].unique().tolist() classification_classes.sort() + if len(classification_classes) > max_num_classes: + raise ValueError( + f"The number of unique classes ({len(classification_classes)}) " + f"exceeds max_num_classes ({max_num_classes}). A high number of " + "unique value / classes might indicate that the problem is a " + "regression or a ranking instead of a classification. If this " + "problem is effectively a classification problem, increase " + "`max_num_classes`.") dataframe[label] = dataframe[label].map(classification_classes.index) # Make sure tha missing values for string columns are not represented as @@ -1050,7 +1073,10 @@ def pd_dataframe_to_tf_dataset( tf_dataset = tf.data.Dataset.from_tensor_slices(dict(dataframe)) # The batch size does not impact the training of TF-DF. - return tf_dataset.batch(64) + tf_dataset = tf_dataset.batch(64) + + setattr(tf_dataset, "_tfdf_task", task) + return tf_dataset def yggdrasil_model_to_keras_model(src_path: str, dst_path: str): diff --git a/tensorflow_decision_forests/keras/keras_test.py b/tensorflow_decision_forests/keras/keras_test.py index ca51c227..03c6956e 100644 --- a/tensorflow_decision_forests/keras/keras_test.py +++ b/tensorflow_decision_forests/keras/keras_test.py @@ -1198,6 +1198,20 @@ def test_feature_with_comma(self): dataset = pd.DataFrame({"a,b": [0, 1, 2], "label": [0, 1, 2]}) model.fit(keras.pd_dataframe_to_tf_dataset(dataset, label="label")) + def test_error_too_much_classes(self): + dataframe = pd.DataFrame({"x": list(range(10)), "label": list(range(10))}) + with self.assertRaises(ValueError): + keras.pd_dataframe_to_tf_dataset( + dataframe, label="label", max_num_classes=5) + + def test_error_non_matching_task(self): + dataframe = pd.DataFrame({"x": list(range(10)), "label": list(range(10))}) + dataset = keras.pd_dataframe_to_tf_dataset( + dataframe, label="label", task=keras.Task.CLASSIFICATION) + model = keras.GradientBoostedTreesModel(task=keras.Task.REGRESSION) + with self.assertRaises(ValueError): + model.fit(dataset) + if __name__ == "__main__": tf.test.main()