Skip to content

Commit 48faf05

Browse files
achoumcopybara-github
authored andcommitted
Internal change
PiperOrigin-RevId: 375732248
1 parent aaa6ba2 commit 48faf05

File tree

4 files changed

+34
-15
lines changed

4 files changed

+34
-15
lines changed

CHANGELOG.md

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,12 @@
11
# Changelog
22

3+
## 0.1.5 - ????
4+
5+
### Bug fix
6+
7+
- Fix failure when input feature contains commas.
8+
9+
310
## 0.1.4 - 2021-05-21
411

512
### Features

tensorflow_decision_forests/keras/core.py

Lines changed: 8 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -234,8 +234,8 @@ class AdvancedArguments(NamedTuple):
234234
yggdrasil_training_config: Yggdrasil Decision Forests training
235235
configuration. Expose a few extra hyper-parameters.
236236
yggdrasil_deployment_config: Configuration of the computing resources used
237-
to train the model e.g. number of threads. Does not impact the model
238-
quality.
237+
to train the model e.g. number of threads. Does not impact the model
238+
quality.
239239
"""
240240

241241
infer_prediction_signature: Optional[bool] = True
@@ -368,7 +368,7 @@ def __init__(self,
368368

369369
if self._temp_directory is None:
370370
self._temp_directory = tempfile.mkdtemp()
371-
logging.info("Using %s are temporary training directory",
371+
logging.info("Using %s as temporary training directory",
372372
self._temp_directory)
373373

374374
if (self._task == Task.RANKING) != (ranking_group is not None):
@@ -745,7 +745,7 @@ def fit(self,
745745

746746
if "epochs" in kwargs:
747747
if kwargs["epochs"] != 1:
748-
raise ValueError("all decision forests algorithms train with only 1 "+
748+
raise ValueError("all decision forests algorithms train with only 1 " +
749749
"epoch, epochs={} given".format(kwargs["epochs"]))
750750
del kwargs["epochs"] # Not needed since we force it to 1 below.
751751

@@ -774,11 +774,10 @@ def evaluate(self, *args, **kwargs):
774774
775775
Args:
776776
*args: Passed to `keras.Model.evaluate`.
777-
**kwargs: Passed to `keras.Model.evaluate`.
778-
779-
Scalar test loss (if the model has a single output and no metrics) or list
780-
of scalars (if the model has multiple outputs and/or metrics). See details
781-
in `keras.Model.evaluate`.
777+
**kwargs: Passed to `keras.Model.evaluate`. Scalar test loss (if the
778+
model has a single output and no metrics) or list of scalars (if the
779+
model has multiple outputs and/or metrics). See details in
780+
`keras.Model.evaluate`.
782781
"""
783782
if self._train_on_evaluate:
784783
if not self._is_trained.numpy():

tensorflow_decision_forests/keras/keras_test.py

Lines changed: 10 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -382,7 +382,7 @@ def build_model(signature: Signature, dataset: Dataset, **args) -> models.Model:
382382
return model
383383

384384

385-
class TFDFInKerasTest(parameterized.TestCase, tf.test.TestCase):
385+
class TFDFTest(parameterized.TestCase, tf.test.TestCase):
386386

387387
def _check_adult_model(self,
388388
model,
@@ -984,10 +984,10 @@ def on_epoch_end(self, epoch, logs=None):
984984
test_evaluation = model.evaluate(test_dataset)
985985
logging.info("Test evaluation: %s", test_evaluation)
986986
val_evaluation = [history.history[key][0] for key in val_keys]
987-
logging.info("Validation evaluation in training "
988-
"(validation_data=test_dataset): %s", val_evaluation)
989-
logging.info("Callback evaluation (test_dataset): %s",
990-
callback.evaluation)
987+
logging.info(
988+
"Validation evaluation in training "
989+
"(validation_data=test_dataset): %s", val_evaluation)
990+
logging.info("Callback evaluation (test_dataset): %s", callback.evaluation)
991991

992992
# The training evaluation is capped by the ratio of missing value (5%).
993993
if compare is not None:
@@ -1193,6 +1193,11 @@ def processor(x):
11931193
def test_get_all_models(self):
11941194
print(keras.get_all_models())
11951195

1196+
def test_feature_with_comma(self):
1197+
model = keras.GradientBoostedTreesModel()
1198+
dataset = pd.DataFrame({"a,b": [0, 1, 2], "label": [0, 1, 2]})
1199+
model.fit(keras.pd_dataframe_to_tf_dataset(dataset, label="label"))
1200+
11961201

11971202
if __name__ == "__main__":
11981203
tf.test.main()

tensorflow_decision_forests/tensorflow/core.py

Lines changed: 9 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -508,7 +508,15 @@ def train(input_ids: List[str],
508508
def _input_key_to_id(model_id: str, key: str) -> str:
509509
"""Gets the name of the feature accumulator resource."""
510510

511-
return model_id + "_" + key
511+
# Escape the commas that are used to separate the column resource id.
512+
# Those IDs have not impact to the final model, but they should be unique and
513+
# not contain commas.
514+
#
515+
# Turn the character '|' into an escape symbol.
516+
input_id = model_id + "_" + key.replace("|", "||").replace(",", "|c")
517+
if "," in input_id:
518+
raise ValueError(f"Internal error: Found comma in input_id {input_id}")
519+
return input_id
512520

513521

514522
def combine_tensors_and_semantics(

0 commit comments

Comments
 (0)