Skip to content

Commit 3681f87

Browse files
committed
Internal change
PiperOrigin-RevId: 374448559
1 parent 6094ec0 commit 3681f87

File tree

10 files changed

+300
-44
lines changed

10 files changed

+300
-44
lines changed

CHANGELOG.md

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,11 @@
11
# Changelog
22

3+
## 0.1.2 - 2021-05-18
4+
5+
### Features
6+
7+
- Inference engines: QuickScorer Extended and Pred
8+
39
## 0.1.1 - 2021-05-17
410

511
### Features

configure/setup.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -20,7 +20,7 @@
2020
from setuptools.command.install import install
2121
from setuptools.dist import Distribution
2222

23-
_VERSION = "0.1.1"
23+
_VERSION = "0.1.2"
2424

2525
with open("README.md", "r", encoding="utf-8") as fh:
2626
long_description = fh.read()

documentation/known_issues.md

Lines changed: 51 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -6,7 +6,9 @@ TensorFlow and Keras is new, and some issues are expected -- we are trying to
66
fix them as quickly as possible.
77

88
See also the
9-
[known issues of Yggdrasil Decision Forests](https://github.com/google/yggdrasil-decision-forests/documentation/known_issues.md).
9+
[known issues of Yggdrasil Decision Forests](https://github.com/google/yggdrasil-decision-forests/documentation/known_issues.md)
10+
and the [migration guide](migration.md) for behavior that is different from
11+
other algorithms.
1012

1113
## Windows Pip package is not available
1214

@@ -33,3 +35,51 @@ an error complaining about tensor shape.
3335

3436
- *Solution #2:* Wrapps your preprocessing function into another function that
3537
[squeeze](https://www.tensorflow.org/api_docs/python/tf/squeeze) its inputs.
38+
39+
## No support for TF distribution strategies.
40+
41+
TF-DF does not yet support distribution strategies or datasets that do not fit
42+
in memory. This is because the classical decision forest training algorithms
43+
already implemented require the entire dataset to be available in memory.
44+
45+
**Workaround**
46+
47+
* Downsample your dataset. A rule of thumb is that TF-DF training
48+
uses 4 bytes per input dimension, so a dataset with 100 million examples and 10
49+
numerical/categorical features would be 4 GB in memory.
50+
51+
* Train a manual ensemble on slices of the dataset, i.e. train N models on N
52+
slices of data, and average the predictions.
53+
54+
## No support for training callbacks.
55+
56+
Training callbacks will not get the expected metrics passed to on_epoch_end
57+
since TF-DF algorithms are trained for only one epoch, and the validation
58+
data is evaluated before the model is trained. Evaluation callbacks are
59+
supported.
60+
61+
**Workaround**
62+
63+
By design TF-DF algorithms train for only one epoch, so callbacks that override
64+
on_epoch_end can be instantiated and called manually with the metrics from
65+
model.evaluate(). Specifically:
66+
67+
```diff {.bad}
68+
- cb = tf.keras.callbacks.Callback()
69+
- model.fit(train_ds, validation_data=val_ds, callbacks=[cb])
70+
```
71+
72+
```diff {.good}
73+
+ model.fit(train_ds)
74+
+ cb.on_epoch_end(epoch=1, logs=model.evaluate(val_ds, ...))
75+
```
76+
77+
## No support for GPU / TPU.
78+
79+
TF-DF does not support GPU or TPU training. Compiling with AVX instructions,
80+
however, may speed up serving.
81+
82+
## No support for [model_to_estimator](https://www.tensorflow.org/api_docs/python/tf/keras/estimator/model_to_estimator)
83+
84+
TF-DF does not implement the APIs required to convert a trained/untrained model
85+
to the estimator format.

documentation/tutorials/advanced_colab.ipynb

Lines changed: 72 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -116,7 +116,31 @@
116116
"try:\n",
117117
" from wurlitzer import sys_pipes\n",
118118
"except:\n",
119-
" from colabtools.googlelog import CaptureLog as sys_pipes"
119+
" from colabtools.googlelog import CaptureLog as sys_pipes\n",
120+
"\n",
121+
"from IPython.core.magic import register_line_magic\n",
122+
"from IPython.display import Javascript"
123+
]
124+
},
125+
{
126+
"cell_type": "code",
127+
"execution_count": null,
128+
"metadata": {
129+
"id": "XAWSjWrQmVE0"
130+
},
131+
"outputs": [],
132+
"source": [
133+
"#@title View results with a max cell height.\n",
134+
"\n",
135+
"\n",
136+
"# Some of the model training logs can cover the full\n",
137+
"# screen if not compressed to a smaller viewport.\n",
138+
"# This magic allows setting a max height for a cell.\n",
139+
"@register_line_magic\n",
140+
"def set_cell_height(size):\n",
141+
" display(\n",
142+
" Javascript(\"google.colab.output.setIframeHeight(0, true, {maxHeight: \" +\n",
143+
" str(size) + \"})\"))"
120144
]
121145
},
122146
{
@@ -179,6 +203,8 @@
179203
},
180204
"outputs": [],
181205
"source": [
206+
"%set_cell_height 300\n",
207+
"\n",
182208
"model.summary()"
183209
]
184210
},
@@ -191,6 +217,30 @@
191217
"Remark the multiple variable importances with name `MEAN_DECREASE_IN_*`."
192218
]
193219
},
220+
{
221+
"cell_type": "markdown",
222+
"metadata": {
223+
"id": "xTwmx8A0c4TU"
224+
},
225+
"source": [
226+
"## Plotting the model\n",
227+
"\n",
228+
"Next, we plot our model.\n",
229+
"\n",
230+
"A Random Forest is a large model (this model has 300 trees and ~5k nodes; see the summary above). Therefore, we will only plot the first tree, and limit the nodes to depth 3."
231+
]
232+
},
233+
{
234+
"cell_type": "code",
235+
"execution_count": null,
236+
"metadata": {
237+
"id": "ZRTrXDz_dIAQ"
238+
},
239+
"outputs": [],
240+
"source": [
241+
"tfdf.model_plotter.plot_model_in_colab(model, tree_idx=0, max_depth=3)"
242+
]
243+
},
194244
{
195245
"cell_type": "markdown",
196246
"metadata": {
@@ -590,12 +640,33 @@
590640
"inspector = tfdf.inspector.make_inspector(yggdrasil_model_path)\n",
591641
"print(\"Input features:\", inspector.features())"
592642
]
643+
},
644+
{
645+
"cell_type": "markdown",
646+
"metadata": {
647+
"id": "muW1hgmotx8J"
648+
},
649+
"source": [
650+
"And of course, you can plot the model :)"
651+
]
652+
},
653+
{
654+
"cell_type": "code",
655+
"execution_count": null,
656+
"metadata": {
657+
"id": "bqahDVg3t1xM"
658+
},
659+
"outputs": [],
660+
"source": [
661+
"tfdf.model_plotter.plot_model_in_colab(manual_model)"
662+
]
593663
}
594664
],
595665
"metadata": {
596666
"colab": {
597667
"collapsed_sections": [],
598668
"name": "advanced_colab.ipynb",
669+
"provenance": [],
599670
"toc_visible": true
600671
},
601672
"kernelspec": {

documentation/tutorials/beginner_colab.ipynb

Lines changed: 94 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -161,7 +161,31 @@
161161
"try:\n",
162162
" from wurlitzer import sys_pipes\n",
163163
"except:\n",
164-
" from colabtools.googlelog import CaptureLog as sys_pipes"
164+
" from colabtools.googlelog import CaptureLog as sys_pipes\n",
165+
"\n",
166+
"from IPython.core.magic import register_line_magic\n",
167+
"from IPython.display import Javascript"
168+
]
169+
},
170+
{
171+
"cell_type": "code",
172+
"execution_count": null,
173+
"metadata": {
174+
"id": "2AhqJz3VmQM-"
175+
},
176+
"outputs": [],
177+
"source": [
178+
"#@title View results with a max cell height.\n",
179+
"\n",
180+
"\n",
181+
"# Some of the model training logs can cover the full\n",
182+
"# screen if not compressed to a smaller viewport.\n",
183+
"# This magic allows setting a max height for a cell.\n",
184+
"@register_line_magic\n",
185+
"def set_cell_height(size):\n",
186+
" display(\n",
187+
" Javascript(\"google.colab.output.setIframeHeight(0, true, {maxHeight: \" +\n",
188+
" str(size) + \"})\"))"
165189
]
166190
},
167191
{
@@ -350,6 +374,8 @@
350374
},
351375
"outputs": [],
352376
"source": [
377+
"%set_cell_height 300\n",
378+
"\n",
353379
"# Specify the model.\n",
354380
"model_1 = tfdf.keras.RandomForestModel()\n",
355381
"\n",
@@ -463,6 +489,47 @@
463489
"model_1.save(\"/tmp/my_saved_model\")"
464490
]
465491
},
492+
{
493+
"cell_type": "markdown",
494+
"metadata": {
495+
"id": "6-8R02_SXpbq"
496+
},
497+
"source": [
498+
"## Plot the model\n",
499+
"\n",
500+
"Plotting a decision tree and following the first branches helps learning about decision forests. In some cases, plotting a model can even be used for debugging.\n",
501+
"\n",
502+
"Because of the difference in the way they are trained, some models are more interresting to plan than others. Because of the noise injected during training and the depth of the trees, plotting Random Forest is less informative than plotting a CART or the first tree of a Gradient Boosted Tree.\n",
503+
"\n",
504+
"Never the less, let's plot the first tree of our Random Forest model:"
505+
]
506+
},
507+
{
508+
"cell_type": "code",
509+
"execution_count": null,
510+
"metadata": {
511+
"id": "KUIxf8N6Yjl0"
512+
},
513+
"outputs": [],
514+
"source": [
515+
"tfdf.model_plotter.plot_model_in_colab(model_1, tree_idx=0, max_depth=3)"
516+
]
517+
},
518+
{
519+
"cell_type": "markdown",
520+
"metadata": {
521+
"id": "cPcL_hDnY7Zy"
522+
},
523+
"source": [
524+
"The root node on the left contains the first condition (`bill_depth_mm \u003e= 16.55`), number of examples (240) and label distribution (the red-blue-green bar).\n",
525+
"\n",
526+
"Examples that evaluates true to `bill_depth_mm \u003e= 16.55` are branched to the green path. The other ones are branched to the red path.\n",
527+
"\n",
528+
"The deeper the node, the more `pure` they become i.e. the label distribution is biased toward a subset of classes. \n",
529+
"\n",
530+
"**Note:** Over the mouse on top of the plot for details."
531+
]
532+
},
466533
{
467534
"cell_type": "markdown",
468535
"metadata": {
@@ -498,6 +565,7 @@
498565
},
499566
"outputs": [],
500567
"source": [
568+
"%set_cell_height 300\n",
501569
"model_1.summary()"
502570
]
503571
},
@@ -597,6 +665,7 @@
597665
},
598666
"outputs": [],
599667
"source": [
668+
"%set_cell_height 150\n",
600669
"model_1.make_inspector().training_logs()"
601670
]
602671
},
@@ -716,15 +785,22 @@
716785
"id": "xmzvuI78voD4"
717786
},
718787
"source": [
719-
"The description of the learning algorithms and their hyper-parameters are also available in the [API reference](https://www.tensorflow.org/decision_forests/api_docs/python/tfdf/keras/RandomForestModel) and builtin help:\n",
720-
"\n",
721-
"```\n",
788+
"The description of the learning algorithms and their hyper-parameters are also available in the [API reference](https://www.tensorflow.org/decision_forests/api_docs/python/tfdf) and builtin help:"
789+
]
790+
},
791+
{
792+
"cell_type": "code",
793+
"execution_count": null,
794+
"metadata": {
795+
"id": "2hONToBav4DE"
796+
},
797+
"outputs": [],
798+
"source": [
722799
"# help works anywhere.\n",
723800
"help(tfdf.keras.RandomForestModel)\n",
724801
"\n",
725802
"# ? only works in ipython or notebooks, it usually opens on a separate panel.\n",
726-
"tfdf.keras.RandomForestModel?\n",
727-
"```"
803+
"tfdf.keras.RandomForestModel?"
728804
]
729805
},
730806
{
@@ -814,6 +890,8 @@
814890
},
815891
"outputs": [],
816892
"source": [
893+
"%set_cell_height 300\n",
894+
"\n",
817895
"feature_1 = tfdf.keras.FeatureUsage(name=\"year\", semantic=tfdf.keras.FeatureSemantic.CATEGORICAL)\n",
818896
"feature_2 = tfdf.keras.FeatureUsage(name=\"bill_length_mm\")\n",
819897
"feature_3 = tfdf.keras.FeatureUsage(name=\"sex\")\n",
@@ -971,6 +1049,8 @@
9711049
},
9721050
"outputs": [],
9731051
"source": [
1052+
"%set_cell_height 300\n",
1053+
"\n",
9741054
"body_mass_g = tf.keras.layers.Input(shape=(1,), name=\"body_mass_g\")\n",
9751055
"body_mass_kg = body_mass_g / 1000.0\n",
9761056
"\n",
@@ -1086,6 +1166,8 @@
10861166
},
10871167
"outputs": [],
10881168
"source": [
1169+
"%set_cell_height 300\n",
1170+
"\n",
10891171
"# Configure the model.\n",
10901172
"model_7 = tfdf.keras.RandomForestModel(task = tfdf.keras.Task.REGRESSION)\n",
10911173
"\n",
@@ -1161,6 +1243,8 @@
11611243
},
11621244
"outputs": [],
11631245
"source": [
1246+
"%set_cell_height 200\n",
1247+
"\n",
11641248
"archive_path = tf.keras.utils.get_file(\"letor.zip\",\n",
11651249
" \"https://download.microsoft.com/download/E/7/E/E7EABEF1-4C7B-4E31-ACE5-73927950ED5E/Letor.zip\",\n",
11661250
" extract=True)\n",
@@ -1266,6 +1350,8 @@
12661350
},
12671351
"outputs": [],
12681352
"source": [
1353+
"%set_cell_height 400\n",
1354+
"\n",
12691355
"model_8 = tfdf.keras.GradientBoostedTreesModel(\n",
12701356
" task=tfdf.keras.Task.RANKING,\n",
12711357
" ranking_group=\"group\",\n",
@@ -1299,6 +1385,8 @@
12991385
},
13001386
"outputs": [],
13011387
"source": [
1388+
"%set_cell_height 400\n",
1389+
"\n",
13021390
"model_8.summary()"
13031391
]
13041392
}

0 commit comments

Comments
 (0)