Skip to content

Commit 27c3e8e

Browse files
achoumcopybara-github
authored and committed
Internal change
PiperOrigin-RevId: 375931435
1 parent 48faf05 commit 27c3e8e

File tree

3 files changed

+49
-2
lines changed

3 files changed

+49
-2
lines changed

CHANGELOG.md

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,13 @@
22

33
## 0.1.5 - ????
44

5+
### Features
6+
7+
```
8+
- Raise an error if the number of classes is greater than 100 (can be disabled).
9+
- Raise an error if the model's task does not match the `pd_dataframe_to_tf_dataset`'s task.
10+
```
11+
512
### Bug fix
613

714
- Fix failure when input feature contains commas.

tensorflow_decision_forests/keras/core.py

Lines changed: 28 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -739,6 +739,16 @@ def fit(self,
739739
All other fields are filled as usual for `Keras.Model.fit()`.
740740
"""
741741

742+
# If the dataset was created with "pd_dataframe_to_tf_dataset", ensure that
743+
# the task is correctly set.
744+
if hasattr(x, "_tfdf_task"):
745+
dataset_task = getattr(x, "_tfdf_task")
746+
if dataset_task != self._task:
747+
raise ValueError(
748+
f"The model's `task` attribute ({Task.Name(self._task)}) does "
749+
"not match the `task` attribute passed to "
750+
f"`pd_dataframe_to_tf_dataset` ({Task.Name(dataset_task)}).")
751+
742752
# Call "compile" if the user forgot to do so.
743753
if not self._is_compiled:
744754
self.compile()
@@ -1005,7 +1015,8 @@ def _batch_size(inputs: Union[tf.Tensor, Dict[str, tf.Tensor]]) -> tf.Tensor:
10051015
def pd_dataframe_to_tf_dataset(
10061016
dataframe,
10071017
label: Optional[str] = None,
1008-
task: Optional[TaskType] = Task.CLASSIFICATION) -> tf.data.Dataset:
1018+
task: Optional[TaskType] = Task.CLASSIFICATION,
1019+
max_num_classes: Optional[int] = 100) -> tf.data.Dataset:
10091020
"""Converts a Pandas DataFrame into a TF Dataset.
10101021
10111022
Details:
@@ -1025,6 +1036,10 @@ def pd_dataframe_to_tf_dataset(
10251036
dataframe: Pandas dataframe containing a training or evaluation dataset.
10261037
label: Name of the label column.
10271038
task: Target task of the dataset.
1039+
max_num_classes: Maximum number of classes for a classification task. A high
1040+
number of unique values / classes might indicate that the problem is a
1041+
regression or a ranking instead of a classification. Set to None to
1042+
disable checking the number of classes.
10281043
10291044
Returns:
10301045
A TensorFlow Dataset.
@@ -1035,6 +1050,14 @@ def pd_dataframe_to_tf_dataset(
10351050
if task == Task.CLASSIFICATION and label is not None:
10361051
classification_classes = dataframe[label].unique().tolist()
10371052
classification_classes.sort()
1053+
if len(classification_classes) > max_num_classes:
1054+
raise ValueError(
1055+
f"The number of unique classes ({len(classification_classes)}) "
1056+
f"exceeds max_num_classes ({max_num_classes}). A high number of "
1057+
"unique value / classes might indicate that the problem is a "
1058+
"regression or a ranking instead of a classification. If this "
1059+
"problem is effectively a classification problem, increase "
1060+
"`max_num_classes`.")
10381061
dataframe[label] = dataframe[label].map(classification_classes.index)
10391062

10401063
# Make sure that missing values for string columns are not represented as
@@ -1050,7 +1073,10 @@ def pd_dataframe_to_tf_dataset(
10501073
tf_dataset = tf.data.Dataset.from_tensor_slices(dict(dataframe))
10511074

10521075
# The batch size does not impact the training of TF-DF.
1053-
return tf_dataset.batch(64)
1076+
tf_dataset = tf_dataset.batch(64)
1077+
1078+
setattr(tf_dataset, "_tfdf_task", task)
1079+
return tf_dataset
10541080

10551081

10561082
def yggdrasil_model_to_keras_model(src_path: str, dst_path: str):

tensorflow_decision_forests/keras/keras_test.py

Lines changed: 14 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1198,6 +1198,20 @@ def test_feature_with_comma(self):
11981198
dataset = pd.DataFrame({"a,b": [0, 1, 2], "label": [0, 1, 2]})
11991199
model.fit(keras.pd_dataframe_to_tf_dataset(dataset, label="label"))
12001200

1201+
def test_error_too_much_classes(self):
1202+
dataframe = pd.DataFrame({"x": list(range(10)), "label": list(range(10))})
1203+
with self.assertRaises(ValueError):
1204+
keras.pd_dataframe_to_tf_dataset(
1205+
dataframe, label="label", max_num_classes=5)
1206+
1207+
def test_error_non_matching_task(self):
1208+
dataframe = pd.DataFrame({"x": list(range(10)), "label": list(range(10))})
1209+
dataset = keras.pd_dataframe_to_tf_dataset(
1210+
dataframe, label="label", task=keras.Task.CLASSIFICATION)
1211+
model = keras.GradientBoostedTreesModel(task=keras.Task.REGRESSION)
1212+
with self.assertRaises(ValueError):
1213+
model.fit(dataset)
1214+
12011215

12021216
if __name__ == "__main__":
12031217
tf.test.main()

0 commit comments

Comments
 (0)