|
161 | 161 | "try:\n",
|
162 | 162 | " from wurlitzer import sys_pipes\n",
|
163 | 163 | "except:\n",
|
164 |
| - " from colabtools.googlelog import CaptureLog as sys_pipes" |
| 164 | + " from colabtools.googlelog import CaptureLog as sys_pipes\n", |
| 165 | + "\n", |
| 166 | + "from IPython.core.magic import register_line_magic\n", |
| 167 | + "from IPython.display import Javascript" |
| 168 | + ] |
| 169 | + }, |
| 170 | + { |
| 171 | + "cell_type": "code", |
| 172 | + "execution_count": null, |
| 173 | + "metadata": { |
| 174 | + "id": "2AhqJz3VmQM-" |
| 175 | + }, |
| 176 | + "outputs": [], |
| 177 | + "source": [ |
| 178 | + "#@title View results with a max cell height.\n", |
| 179 | + "\n", |
| 180 | + "\n", |
| 181 | + "# Some of the model training logs can cover the full\n", |
| 182 | + "# screen if not compressed to a smaller viewport.\n", |
| 183 | + "# This magic allows setting a max height for a cell.\n", |
| 184 | + "@register_line_magic\n", |
| 185 | + "def set_cell_height(size):\n", |
| 186 | + " display(\n", |
| 187 | + " Javascript(\"google.colab.output.setIframeHeight(0, true, {maxHeight: \" +\n", |
| 188 | + " str(size) + \"})\"))" |
165 | 189 | ]
|
166 | 190 | },
|
167 | 191 | {
|
|
350 | 374 | },
|
351 | 375 | "outputs": [],
|
352 | 376 | "source": [
|
| 377 | + "%set_cell_height 300\n", |
| 378 | + "\n", |
353 | 379 | "# Specify the model.\n",
|
354 | 380 | "model_1 = tfdf.keras.RandomForestModel()\n",
|
355 | 381 | "\n",
|
|
463 | 489 | "model_1.save(\"/tmp/my_saved_model\")"
|
464 | 490 | ]
|
465 | 491 | },
|
| 492 | + { |
| 493 | + "cell_type": "markdown", |
| 494 | + "metadata": { |
| 495 | + "id": "6-8R02_SXpbq" |
| 496 | + }, |
| 497 | + "source": [ |
| 498 | + "## Plot the model\n", |
| 499 | + "\n", |
| 500 | + "Plotting a decision tree and following the first branches helps learning about decision forests. In some cases, plotting a model can even be used for debugging.\n", |
| 501 | + "\n", |
| 502 | + "Because of the difference in the way they are trained, some models are more interresting to plan than others. Because of the noise injected during training and the depth of the trees, plotting Random Forest is less informative than plotting a CART or the first tree of a Gradient Boosted Tree.\n", |
| 503 | + "\n", |
| 504 | + "Never the less, let's plot the first tree of our Random Forest model:" |
| 505 | + ] |
| 506 | + }, |
| 507 | + { |
| 508 | + "cell_type": "code", |
| 509 | + "execution_count": null, |
| 510 | + "metadata": { |
| 511 | + "id": "KUIxf8N6Yjl0" |
| 512 | + }, |
| 513 | + "outputs": [], |
| 514 | + "source": [ |
| 515 | + "tfdf.model_plotter.plot_model_in_colab(model_1, tree_idx=0, max_depth=3)" |
| 516 | + ] |
| 517 | + }, |
| 518 | + { |
| 519 | + "cell_type": "markdown", |
| 520 | + "metadata": { |
| 521 | + "id": "cPcL_hDnY7Zy" |
| 522 | + }, |
| 523 | + "source": [ |
| 524 | + "The root node on the left contains the first condition (`bill_depth_mm \u003e= 16.55`), number of examples (240) and label distribution (the red-blue-green bar).\n", |
| 525 | + "\n", |
| 526 | + "Examples that evaluates true to `bill_depth_mm \u003e= 16.55` are branched to the green path. The other ones are branched to the red path.\n", |
| 527 | + "\n", |
| 528 | + "The deeper the node, the more `pure` they become i.e. the label distribution is biased toward a subset of classes. \n", |
| 529 | + "\n", |
| 530 | + "**Note:** Over the mouse on top of the plot for details." |
| 531 | + ] |
| 532 | + }, |
466 | 533 | {
|
467 | 534 | "cell_type": "markdown",
|
468 | 535 | "metadata": {
|
|
498 | 565 | },
|
499 | 566 | "outputs": [],
|
500 | 567 | "source": [
|
| 568 | + "%set_cell_height 300\n", |
501 | 569 | "model_1.summary()"
|
502 | 570 | ]
|
503 | 571 | },
|
|
597 | 665 | },
|
598 | 666 | "outputs": [],
|
599 | 667 | "source": [
|
| 668 | + "%set_cell_height 150\n", |
600 | 669 | "model_1.make_inspector().training_logs()"
|
601 | 670 | ]
|
602 | 671 | },
|
|
716 | 785 | "id": "xmzvuI78voD4"
|
717 | 786 | },
|
718 | 787 | "source": [
|
719 |
| - "The description of the learning algorithms and their hyper-parameters are also available in the [API reference](https://www.tensorflow.org/decision_forests/api_docs/python/tfdf/keras/RandomForestModel) and builtin help:\n", |
720 |
| - "\n", |
721 |
| - "```\n", |
| 788 | + "The description of the learning algorithms and their hyper-parameters are also available in the [API reference](https://www.tensorflow.org/decision_forests/api_docs/python/tfdf) and builtin help:" |
| 789 | + ] |
| 790 | + }, |
| 791 | + { |
| 792 | + "cell_type": "code", |
| 793 | + "execution_count": null, |
| 794 | + "metadata": { |
| 795 | + "id": "2hONToBav4DE" |
| 796 | + }, |
| 797 | + "outputs": [], |
| 798 | + "source": [ |
722 | 799 | "# help works anywhere.\n",
|
723 | 800 | "help(tfdf.keras.RandomForestModel)\n",
|
724 | 801 | "\n",
|
725 | 802 | "# ? only works in ipython or notebooks, it usually opens on a separate panel.\n",
|
726 |
| - "tfdf.keras.RandomForestModel?\n", |
727 |
| - "```" |
| 803 | + "tfdf.keras.RandomForestModel?" |
728 | 804 | ]
|
729 | 805 | },
|
730 | 806 | {
|
|
814 | 890 | },
|
815 | 891 | "outputs": [],
|
816 | 892 | "source": [
|
| 893 | + "%set_cell_height 300\n", |
| 894 | + "\n", |
817 | 895 | "feature_1 = tfdf.keras.FeatureUsage(name=\"year\", semantic=tfdf.keras.FeatureSemantic.CATEGORICAL)\n",
|
818 | 896 | "feature_2 = tfdf.keras.FeatureUsage(name=\"bill_length_mm\")\n",
|
819 | 897 | "feature_3 = tfdf.keras.FeatureUsage(name=\"sex\")\n",
|
|
971 | 1049 | },
|
972 | 1050 | "outputs": [],
|
973 | 1051 | "source": [
|
| 1052 | + "%set_cell_height 300\n", |
| 1053 | + "\n", |
974 | 1054 | "body_mass_g = tf.keras.layers.Input(shape=(1,), name=\"body_mass_g\")\n",
|
975 | 1055 | "body_mass_kg = body_mass_g / 1000.0\n",
|
976 | 1056 | "\n",
|
|
1086 | 1166 | },
|
1087 | 1167 | "outputs": [],
|
1088 | 1168 | "source": [
|
| 1169 | + "%set_cell_height 300\n", |
| 1170 | + "\n", |
1089 | 1171 | "# Configure the model.\n",
|
1090 | 1172 | "model_7 = tfdf.keras.RandomForestModel(task = tfdf.keras.Task.REGRESSION)\n",
|
1091 | 1173 | "\n",
|
|
1161 | 1243 | },
|
1162 | 1244 | "outputs": [],
|
1163 | 1245 | "source": [
|
| 1246 | + "%set_cell_height 200\n", |
| 1247 | + "\n", |
1164 | 1248 | "archive_path = tf.keras.utils.get_file(\"letor.zip\",\n",
|
1165 | 1249 | " \"https://download.microsoft.com/download/E/7/E/E7EABEF1-4C7B-4E31-ACE5-73927950ED5E/Letor.zip\",\n",
|
1166 | 1250 | " extract=True)\n",
|
|
1266 | 1350 | },
|
1267 | 1351 | "outputs": [],
|
1268 | 1352 | "source": [
|
| 1353 | + "%set_cell_height 400\n", |
| 1354 | + "\n", |
1269 | 1355 | "model_8 = tfdf.keras.GradientBoostedTreesModel(\n",
|
1270 | 1356 | " task=tfdf.keras.Task.RANKING,\n",
|
1271 | 1357 | " ranking_group=\"group\",\n",
|
|
1299 | 1385 | },
|
1300 | 1386 | "outputs": [],
|
1301 | 1387 | "source": [
|
| 1388 | + "%set_cell_height 400\n", |
| 1389 | + "\n", |
1302 | 1390 | "model_8.summary()"
|
1303 | 1391 | ]
|
1304 | 1392 | }
|
|
0 commit comments