From 2cc530b7b59cd5feca49aadc267dbd48ed37e1cc Mon Sep 17 00:00:00 2001 From: Gaurav Gupta <47334368+gaugup@users.noreply.github.com> Date: Wed, 20 Apr 2022 19:16:04 -0700 Subject: [PATCH] Add pre-defined cohorts in responsibleaidashboard-diabetes-decision-making.ipynb (#1252) * [WIP] Add pre-defined cohorts in responsibleaidashboard-diabetes-decision-making.ipynb Signed-off-by: Gaurav Gupta * Fix the e2e test for notebook Signed-off-by: Gaurav Gupta * Make _cohort.py module a public module (#1253) * Make _cohort.py a public module Signed-off-by: Gaurav Gupta * Add missing file Signed-off-by: Gaurav Gupta * Fix cohort namespace Signed-off-by: Gaurav Gupta * minor fix to url for responsibleai package in setup.py (#1260) * Counterfactual Chart: Correct target description according to task_type (#1261) * Counterfactual Chart: Correct target description according to task_type Signed-off-by: Gaurav Gupta * Change function name Signed-off-by: Gaurav Gupta * Make Range lowercase Signed-off-by: Gaurav Gupta * fix whitespace in values of adult census income dataset (#1263) * Add what-If scatter chart from highchart lib (#1262) * add whatIf scatter chart * widget test * what if local importance bar chart * fix * widget * fix tooltip * refactor * test * test * add a builddebug yarn command to build UX locally which can be debugged in browser e2e (#1265) * allow rai text insights to work with RAI dashboard (#1269) * remove duplicate code in explanation dashboard (#1266) * Individual causal style responsive (#1268) * add whatIf scatter chart * widget test * what if local importance bar chart * fix * widget * fix tooltip * refactor * test * test * Causal Style * Allow duplicating cohorts multiple times (#1274) * allow duplicating a cohort more than once * lintfix * Disable column header highlighting on hover in IndividualFeatureImportanceView (#1272) * disable column highlight on hover * lintfix * Rename new cohorts from "Unsaved" to "Temporary cohort" (#1273) * rename Unsaved to Temporary cohort * localize temp cohort * Counterfactual style refactor (#1275) * style refactor * test * test * test * fix * Don't change cursor on hover over cohort name * Fix (#1281) * fix cohort info styling (#1277) * fix readme link to fairness and interpretability example notebook (#1282) * add new RAI Utils package for common utilities shared across RAI packages (#1280) * Add ICE chart (#1283) * Fix * ice chart * ic * test * test * update docstring for explanation dashboard in regards to min number of rows (#1271) * make builds more reliable by adding retry logic to urlretrieve calls in notebooks (#1218) * upgrade pytest to 7.0.1, remove mock and updgrade pytest-mock to 3.6.1 (#1287) * remove deprecated codecov parameter (#1293) * Fix min/max special case in cohort filter creation with "in the range of" (#1279) * fix logic in the case that min or max are zero * lintfix * Rename 'Dashboard navigation' to 'Dashboard configuration' (#1291) * Rename 'Dashboard navigation' to 'Dashboard configuration' Signed-off-by: Gaurav Gupta * Notebook change Signed-off-by: Gaurav Gupta * Add raiutils to PR template (#1290) * fix heatmap bug (#1297) * Make "save and switch" work from cohort settings (#1276) * make save and switch work * fix naming * lintfix * adjustment according to Ilya's comment * lintfix * add retry logic to codecov step and only upload results for one python version (#1298) * add github action to release raiutils to pypi (#1294) * Add highchart for Dataset Explorer (#1286) * test * style * click * fix test * fix test * test * test * test * test * Update requirements-linting.txt to add flake8-pytest-style (#1296) * Fix sort abs (#1299) * Rename "base cohort" to "global cohort" (#1278) * change base cohort to global cohort * fix spelling * lintfix * fix codecov comment not appearing on PRs (#1302) * take absolute value of error calculation for regression scenario (#1301) * Limit individual feature importance selection to up to 5 (#1305) * update feature importance string * limit selection to up to 5 * add group count * remove message bar, show info icon instead * update e2e locator * fix E2E failure on feature importance * add ariaLabel for expand collapse button * add renderOnNewLayer props * Add error message for counterfactual panel (#1310) * add error message for counterfactual * update error message in camel case to fix build error * Add to_json() and from_json() methods to Cohort class (#1300) * Add to_json() and from_json() methods to Cohort class Signed-off-by: Gaurav Gupta * Address code review comments Signed-off-by: Gaurav Gupta * Fix linting Signed-off-by: Gaurav Gupta * Add a highchart heatmap helper class (#1307) * add highchart heatmap helper class * add erroneously deleted line back * Fix cohort setting string (#1304) * Fix string * remove none * name * test * Fix all data style (#1303) * Add a feature flag for the new model overview experience (#1306) * add feature flag for new model overview experience and turn it off by default * remove useless constructor * Clean up charts code (#1313) * clean up chart code * remove arg * Expand the counterfactual flyout to cover the full page (#1315) Signed-off-by: Gaurav Gupta * Bump minimist from 1.2.5 to 1.2.6 (#1292) * Bump minimist from 1.2.5 to 1.2.6 Bumps [minimist](https://github.com/substack/minimist) from 1.2.5 to 1.2.6. - [Release notes](https://github.com/substack/minimist/releases) - [Commits](https://github.com/substack/minimist/compare/1.2.5...1.2.6) --- updated-dependencies: - dependency-name: minimist dependency-type: indirect ... Signed-off-by: dependabot[bot] * minimist ^1.2.6 Co-authored-by: dependabot[bot] <49699333+dependabot[bot]@users.noreply.github.com> Co-authored-by: xuke444 <40614413+xuke444@users.noreply.github.com> Co-authored-by: Roman Lutz * fix random node download failures by upgrading to latest github action with retry logic (#1317) * Add dataset cohort table to new ModelOverview experience (#1314) * add only dataset cohort table, update wrapping code * lintfix * lintfix * build error fix * Add installation instructions for raiwidgets to README (#1320) * refactor RAIInsights into RAIInsightsBase class for basic functionality (#1284) * Fix what if counterfactual header and description text misaligned (#1316) * align * e2e * add clear temporary cohort button to error analysis (#1322) * Raise UserConfigValidationException in case no model but valid model serializer (#1325) Signed-off-by: Gaurav Gupta * Add test case for handling different types in causal (#1321) Signed-off-by: Gaurav Gupta * show shift to an empty cohort in tree view as an empty node (#1318) * Bug fixing (#1326) * Move chart description up and remove scroll bar * Change string * Add box outlier for dataset explorer (#1323) * add outlier for dataset explorer * name * update string when no datapoint selected (#1331) * Fix Big empty space for featureImportance chart (#1328) * legend * removed invalid test case * constant * Disable save as new cohort button if nothing is selected in error tree (#1327) * Add disaggregated analysis table/heatmap (#1332) * disaggregated analysis changes only * lintfix * Change warning message to user exception for model type and task type mismatch (#1330) * Change warning message to user exception for model type and task type mismatch Signed-off-by: Gaurav Gupta * Fix flake8 errors Signed-off-by: Gaurav Gupta * Change the counterfactual text color from black to grey (#1337) Signed-off-by: Gaurav Gupta * Limit each component description width up to 750px for readability (#1336) * limit description width up to 750px * export maxWidth from a common place * block empty cohort creation in RAI Dashboard (#1335) * Add warning message in cohort editor for invalid input value; Update 'Shift cohort' to 'Switch cohort' (#1339) * add error message for invalid value * update shift cohort to switch cohort * Rename counterfactual style files to confirm with *.styles.ts (#1338) Signed-off-by: Gaurav Gupta * Add disaggregated analysis table to Model Overview (#1341) * pull in changes for disaggregated analysis * add styles file * add textured NaN cells * module import for textured cells and grid y axis * lintfix * use combobox for dropdown rather than dropdown * lintfix * Rename causal style files to confirm with *.styles.ts (#1342) Signed-off-by: Gaurav Gupta * update responsibleai to interpret-community 0.25.0 (#1343) * All component title and descriptions should be aligned (#1346) * update Signed-off-by: vinutha karanth * lintfix Signed-off-by: vinutha karanth * Remove 5K limit blurb from local explanations tab (#1347) Signed-off-by: Gaurav Gupta * Sort features by default in counterfactual flyout (#1312) * Sort features by default in counterfactual flyout Signed-off-by: Gaurav Gupta * Fix failing tests Signed-off-by: Gaurav Gupta * attempt to fix test Signed-off-by: Gaurav Gupta * Remove check Signed-off-by: Gaurav Gupta * Bump moment from 2.28.0 to 2.29.2 (#1333) Bumps [moment](https://github.com/moment/moment) from 2.28.0 to 2.29.2. - [Release notes](https://github.com/moment/moment/releases) - [Changelog](https://github.com/moment/moment/blob/develop/CHANGELOG.md) - [Commits](https://github.com/moment/moment/compare/2.28.0...2.29.2) --- updated-dependencies: - dependency-name: moment dependency-type: direct:production ... Signed-off-by: dependabot[bot] Co-authored-by: dependabot[bot] <49699333+dependabot[bot]@users.noreply.github.com> Co-authored-by: Roman Lutz * Counterfactual flyout top section need to be moved to left & Error analysis move side content to align with description text (#1350) * update Signed-off-by: vinutha karanth * lintfix Signed-off-by: vinutha karanth * move the root all data statistics to ErrorReport and ErrorAnalysisData (#1344) * update error analysis documentation to clarify the error tree splits on errors even when other metrics are selected (#1349) Co-authored-by: Vinutha Karanth * update erroranalysis to 0.2.1 and remove some duplicate dependencies (#1334) * disable turbo checking for large amount of data (#1351) Signed-off-by: Ke Xu * force re-render when chart type changes (#1354) Signed-off-by: Ke Xu * move the root all data statistics to ErrorReport and ErrorAnalysisData (#1352) * Rename output column name in the counterfactual flyout (#1353) Signed-off-by: Gaurav Gupta * Show column chart for categorical feature in data explorer (#1355) * Show column chart for categorical feature in data explorer * address comments * update fluentui (#1356) Signed-off-by: Ke Xu * update code owner (#1308) * update code owner Signed-off-by: Ke Xu * remove dup Signed-off-by: Ke Xu Co-authored-by: Roman Lutz * update version to match studio (#1357) Signed-off-by: Ke Xu * alignment fixes (#1359) * Add charts for new model overview experience (#1348) * rename OverallTable to DisaggregatedMetricsTable and move to core-ui * Copy the ModelPerformanceTab into model-assessment and rename to ModelOverview * reference OverallTable again in fairness widget * refactor core chart component out into core-ui * refactor out core chart component into core-ui * lintfix * undo tsconfig.lib.json change * fix chartAndConfigsId in tests * lintfix * add table for cohort metrics and add dropdown metric selector, add new metrics * add new metrics * undo unwanted changes * fix casing * add superscript 2 for r-squared * update tests to reflect new metrics * lintfix * add feature flag * fix mae * fix mae calculation * first version of new model overview table * get probability distribution box plot to work * add feature flag for new model overview experience and turn it off by default * add highchart heatmap helper class * remove custom styling * add erroneously deleted line back * remove useless constructor * modularize model overview * show outliers in box plot, fix positioning * remove showmetricsummary * refactor heatmap code into a common class * add featureDropdownRef to allow focusing * add only dataset cohort table, update wrapping code * lintfix * lintfix * build error fix * add chart config flyout (in progress) * add chart config flyout (in progress) * address feedback, use finalized color * adjust feature selection to disable options if limit is reached, add axis config buttons * select all via dropdown * lintfix * refactor box plot calculations and rendering * add style file * textured NaN cells, grid labels on y axis * standardize box plots to use fences * fix merge issues * unify box plot tooltip formatting code, fix bar chart sizing issue * small fixes * rearrange feature dropdown * lintfix * remove commented out code * remove box plot tooltip customization * lintfix * add a few unit tests * unit tests for smaller utilities, localization fixes, consistent flyout flow with confirm/cancel buttons * lintfix * fix chart config flyout update * fix test case * rename files for lint * file rename for lint * release rai-core-flask 0.3.0 (#1361) * upgrade python version used with flask CI to fix segfault error (#1363) * release raiwidgets and responsibleai v0.18.0 (#1360) * fix two bugs (#1364) * Add pre-built cohort into adult census notebook (#1243) * [WIP] Add pre-built cohort into adult census notebook Signed-off-by: Gaurav Gupta * erroranalysis version bump in raiwidgets to 0.1.31 (#1245) * Make cohrtData empty list in case no pre-bdefined cohorts are injected (#1247) Signed-off-by: Gaurav Gupta * Simplify the train pipeline responsibleaidashboard-census-classification-model-debugging.ipynb (#1195) * Simplify the train pipeline responsibleaidashboard-census-classification-model-debugging.ipynb Signed-off-by: Gaurav Gupta * Address code review comments * Update notebooks/responsibleaidashboard/responsibleaidashboard-census-classification-model-debugging.ipynb Co-authored-by: Roman Lutz Co-authored-by: Roman Lutz Signed-off-by: Gaurav Gupta * Add regression test for pre-defined cohorts in raiwidgets (#1249) Signed-off-by: Gaurav Gupta * color (#1248) * Add feature importance box & bar chart (#1241) * refactor * build * build * temp * temp * temp * temp * box * cache * e2e * e2e * fix * e2e fix * e2e * fix e2e * widget * widget * fix * widget * e2e * e2e * e2e * test * test * PreBuilt cohorts UX changes (#1242) * Intial SDK implementation cohorts Signed-off-by: Gaurav Gupta * Add basic validationf for cohorts Signed-off-by: Gaurav Gupta * Add serialized version of cohort config to ResponsibleAiDashboard Signed-off-by: Gaurav Gupta * Add more tests cohorts Signed-off-by: Gaurav Gupta * fix broken builds due to pip upgrade which broke pip-tools (#1185) * refactor matrix filter and area state to be private static (#1179) * Change variable name Signed-off-by: Gaurav Gupta * Add more cohort filters Signed-off-by: Gaurav Gupta * Add cohort data to dashboard e2e Signed-off-by: Gaurav Gupta * Add more cohorts filters Signed-off-by: Gaurav Gupta * Document various data validation for cohorts Signed-off-by: Gaurav Gupta * Add new interfaces for pre-built cohort Signed-off-by: Gaurav Gupta * Add more cohort filters Signed-off-by: Gaurav Gupta * Add prebuilt cohort walking logic in UI and add more data validation scenarios Signed-off-by: Gaurav Gupta * Add basic data validation checks Signed-off-by: Gaurav Gupta * Add logic to translate the Index cohort filter Signed-off-by: Gaurav Gupta * Remove commented out code Signed-off-by: Gaurav Gupta * Add SDK validations for Index based cohort filter Signed-off-by: Gaurav Gupta * Add code for validating classification outcome Signed-off-by: Gaurav Gupta * Add error filter validations and add tests Signed-off-by: Gaurav Gupta * Add fake cohorts for regression dataset Signed-off-by: Gaurav Gupta * Add fake cohorts for multi-class classification dataset Signed-off-by: Gaurav Gupta * Add handling of regression filter Signed-off-by: Gaurav Gupta * Add support for classification outcome in UI Signed-off-by: Gaurav Gupta * Add validations for Predicted Y and True Y cohort filters Signed-off-by: Gaurav Gupta * Add UI code to handle prediced Y and true Y for pre-built cohort filters Signed-off-by: Gaurav Gupta * Add cohort validation with test data to raiwidgets Signed-off-by: Gaurav Gupta * Add tests for validating Predicted/True Y cohorts Signed-off-by: Gaurav Gupta * Add UI support for TrueY/PredictedY for classification Signed-off-by: Gaurav Gupta * Rename cohort_filter_list to cohort_list Signed-off-by: Gaurav Gupta * Rename UI varibles to match SDK Signed-off-by: Gaurav Gupta * Fix duplicate cohort name Signed-off-by: Gaurav Gupta * Add SDK cohorts to notebook Signed-off-by: Gaurav Gupta * Add dataset validations and add categorical features Signed-off-by: Gaurav Gupta * Add validations for categorical_features Signed-off-by: Gaurav Gupta * Fix sorted imports Signed-off-by: Gaurav Gupta * Add code for translating categorical values Signed-off-by: Gaurav Gupta * Move cohort processing to a separate file Signed-off-by: Gaurav Gupta * Fix code review comments Signed-off-by: Gaurav Gupta * Refactor cohort translated function into different small functions Signed-off-by: Gaurav Gupta * Change to lowercase for outcome Signed-off-by: Gaurav Gupta * Fix code review comments Signed-off-by: Gaurav Gupta * Refactor cohort_list validations and converge pytest common functions into fixtures Signed-off-by: Gaurav Gupta * Add conftest into raiwidgets tests Signed-off-by: Gaurav Gupta * Add validations for cohort list Signed-off-by: Gaurav Gupta * Add cohortData test Signed-off-by: Gaurav Gupta * Fix sorted imports Signed-off-by: Gaurav Gupta * isort fix Signed-off-by: Gaurav Gupta * Add UI unit tests for cohort translation Signed-off-by: Gaurav Gupta * Add more checks in UI uni test Signed-off-by: Gaurav Gupta * Add UI tests for regression cohorts Signed-off-by: Gaurav Gupta * REmove notebook change Signed-off-by: Gaurav Gupta * Fix typescript build Signed-off-by: Gaurav Gupta * Change cohort filter values so that cohort filters non-zero points Signed-off-by: Gaurav Gupta * Fix for empty cohort list Signed-off-by: Gaurav Gupta * Simplify the train pipeline responsibleaidashboard-census-classification-model-debugging.ipynb (#1195) * Simplify the train pipeline responsibleaidashboard-census-classification-model-debugging.ipynb Signed-off-by: Gaurav Gupta * Address code review comments * Update notebooks/responsibleaidashboard/responsibleaidashboard-census-classification-model-debugging.ipynb Co-authored-by: Roman Lutz Co-authored-by: Roman Lutz * Propagate error strings instead of raising exceptions Signed-off-by: Gaurav Gupta * Fix code issues Signed-off-by: Gaurav Gupta * Fix code review comments Signed-off-by: Gaurav Gupta * Fix code review comments Signed-off-by: Gaurav Gupta Co-authored-by: Ilya Matiach Co-authored-by: Roman Lutz * Make _cohort.py module a public module (#1253) * Make _cohort.py a public module Signed-off-by: Gaurav Gupta * Add missing file Signed-off-by: Gaurav Gupta * fix notebook build failures due to pywinpty dependency release failing in python 3.6 (#1257) * fix notebook build failures due to pywinpty dependency release failing in python 3.6 * build pywinpty from conda instead * add lowerbound * fixup * fixup * Add supported models and data types to README.md responsibleai (#1259) Signed-off-by: Gaurav Gupta * make getting-started notebook a markdown file showing APIs (#1223) * refactor tabs out of RAI dashboard into a separate component (#1256) * Add individual causal scatter chart (#1258) * temp * refactor * test * style fix * comment * minor fix to url for responsibleai package in setup.py (#1260) * Fix UX e2e tests and address code review comments Signed-off-by: Gaurav Gupta * Fix eslint Signed-off-by: Gaurav Gupta * Address review comments Signed-off-by: Gaurav Gupta * Reset the number of samples in test dataset Signed-off-by: Gaurav Gupta Co-authored-by: Ilya Matiach Co-authored-by: Roman Lutz Co-authored-by: Bo Zhang <71688188+zhb000@users.noreply.github.com> Signed-off-by: Gaurav Gupta * Change cohort name Signed-off-by: Gaurav Gupta Co-authored-by: Ilya Matiach Co-authored-by: Bo Zhang <71688188+zhb000@users.noreply.github.com> Co-authored-by: Roman Lutz Co-authored-by: tongy-msft <91754176+tongyu-microsoft@users.noreply.github.com> Co-authored-by: dependabot[bot] <49699333+dependabot[bot]@users.noreply.github.com> Co-authored-by: xuke444 <40614413+xuke444@users.noreply.github.com> Co-authored-by: Vinutha Karanth --- .../modelAssessmentDatasets.ts | 9 +- ...aidashboard-diabetes-decision-making.ipynb | 123 +++++++++++++++++- 2 files changed, 129 insertions(+), 3 deletions(-) diff --git a/apps/widget-e2e/src/describer/modelAssessment/modelAssessmentDatasets.ts b/apps/widget-e2e/src/describer/modelAssessment/modelAssessmentDatasets.ts index c6548aeb25..d62e13b7de 100644 --- a/apps/widget-e2e/src/describer/modelAssessment/modelAssessmentDatasets.ts +++ b/apps/widget-e2e/src/describer/modelAssessment/modelAssessmentDatasets.ts @@ -123,7 +123,14 @@ const modelAssessmentDatasets = { "s6" ], modelStatisticsData: { - cohortDropDownValues: ["All data"], + cohortDropDownValues: [ + "All data", + "Cohort Index", + "Cohort Predicted Y", + "Cohort True Y", + "Cohort Regression Error", + "Cohort Age and BMI" + ], defaultXAxis: "Error", defaultXAxisPanelValue: "Error", defaultYAxis: "Cohort", diff --git a/notebooks/responsibleaidashboard/responsibleaidashboard-diabetes-decision-making.ipynb b/notebooks/responsibleaidashboard/responsibleaidashboard-diabetes-decision-making.ipynb index 960585f6e9..bf2df40fab 100644 --- a/notebooks/responsibleaidashboard/responsibleaidashboard-diabetes-decision-making.ipynb +++ b/notebooks/responsibleaidashboard/responsibleaidashboard-diabetes-decision-making.ipynb @@ -2,6 +2,7 @@ "cells": [ { "cell_type": "markdown", + "id": "75018d5c", "metadata": {}, "source": [ "# Plan real-world action using counterfactual example analysis and causal analysis" @@ -9,6 +10,7 @@ }, { "cell_type": "markdown", + "id": "d4939847", "metadata": {}, "source": [ "This notebook demonstrates the use of the Responsible AI Toolbox to make decisions from diabetes progression data. It walks through the API calls necessary to create a widget with causal inferencing insights, then guides a visual analysis of the data." @@ -16,6 +18,7 @@ }, { "cell_type": "markdown", + "id": "231caa35", "metadata": {}, "source": [ "* [Launch Responsible AI Toolbox](#Launch-Responsible-AI-Toolbox)\n", @@ -28,6 +31,7 @@ }, { "cell_type": "markdown", + "id": "8cfa82d1", "metadata": {}, "source": [ "## Launch Responsible AI Toolbox" @@ -35,6 +39,7 @@ }, { "cell_type": "markdown", + "id": "789b30d1", "metadata": {}, "source": [ "The following section examines the code necessary to create the dataset. It then generates insights using the `responsibleai` API that can be visually analyzed." @@ -42,6 +47,7 @@ }, { "cell_type": "markdown", + "id": "3e43e464", "metadata": {}, "source": [ "### Train a Model\n", @@ -51,6 +57,7 @@ { "cell_type": "code", "execution_count": null, + "id": "a670ba8c", "metadata": {}, "outputs": [], "source": [ @@ -64,6 +71,7 @@ }, { "cell_type": "markdown", + "id": "a4f53194", "metadata": {}, "source": [ "First, load the diabetes dataset and specify the different types of features. Then, clean it and put it into a DataFrame with named columns." @@ -72,6 +80,7 @@ { "cell_type": "code", "execution_count": null, + "id": "479ad4f8", "metadata": {}, "outputs": [], "source": [ @@ -83,6 +92,7 @@ }, { "cell_type": "markdown", + "id": "c7cdd8ae", "metadata": {}, "source": [ "After loading and cleaning the data, split the datapoints into training and test sets. Assemble separate datasets for the training and test data." @@ -91,6 +101,7 @@ { "cell_type": "code", "execution_count": null, + "id": "4e02d132", "metadata": {}, "outputs": [], "source": [ @@ -105,6 +116,7 @@ }, { "cell_type": "markdown", + "id": "59853607", "metadata": {}, "source": [ "Train a nearest-neighbors classifier on the training data." @@ -113,6 +125,7 @@ { "cell_type": "code", "execution_count": null, + "id": "6612038f", "metadata": {}, "outputs": [], "source": [ @@ -122,6 +135,7 @@ }, { "cell_type": "markdown", + "id": "29805164", "metadata": {}, "source": [ "### Create Model and Data Insights" @@ -130,6 +144,7 @@ { "cell_type": "code", "execution_count": null, + "id": "c65f788f", "metadata": {}, "outputs": [], "source": [ @@ -139,6 +154,7 @@ }, { "cell_type": "markdown", + "id": "400de1d9", "metadata": {}, "source": [ "To use Responsible AI Toolbox, initialize a RAIInsights object upon which different components can be loaded.\n", @@ -149,6 +165,7 @@ { "cell_type": "code", "execution_count": null, + "id": "d965f769", "metadata": {}, "outputs": [], "source": [ @@ -158,6 +175,7 @@ }, { "cell_type": "markdown", + "id": "38fbbe06", "metadata": {}, "source": [ "Add the components of the toolbox that are focused on decision-making." @@ -166,6 +184,7 @@ { "cell_type": "code", "execution_count": null, + "id": "24567d8d", "metadata": {}, "outputs": [], "source": [ @@ -178,6 +197,7 @@ }, { "cell_type": "markdown", + "id": "571b2235", "metadata": {}, "source": [ "Once all the desired components have been loaded, compute insights on the test set." @@ -186,6 +206,7 @@ { "cell_type": "code", "execution_count": null, + "id": "a7dec636", "metadata": {}, "outputs": [], "source": [ @@ -194,6 +215,81 @@ }, { "cell_type": "markdown", + "id": "0ad206fd", + "metadata": {}, + "source": [ + "Compose some cohorts which can be injected into the `ResponsibleAIDashboard`." + ] + }, + { + "cell_type": "code", + "execution_count": 13, + "id": "7a039b34", + "metadata": {}, + "outputs": [], + "source": [ + "from raiwidgets.cohort import Cohort, CohortFilter, CohortFilterMethods\n", + "\n", + "# Cohort on age and bmi features in the dataset\n", + "cohort_filter_age = CohortFilter(\n", + " method=CohortFilterMethods.METHOD_LESS,\n", + " arg=[40],\n", + " column='age')\n", + "cohort_filter_bmi = CohortFilter(\n", + " method=CohortFilterMethods.METHOD_GREATER,\n", + " arg=[0],\n", + " column='bmi')\n", + " \n", + "user_cohort_age_and_bmi= Cohort(name='Cohort Age and BMI')\n", + "user_cohort_age_and_bmi.add_cohort_filter(cohort_filter_age)\n", + "user_cohort_age_and_bmi.add_cohort_filter(cohort_filter_bmi)\n", + "\n", + "# Cohort on index\n", + "cohort_filter_index = CohortFilter(\n", + " method=CohortFilterMethods.METHOD_LESS,\n", + " arg=[20],\n", + " column='Index')\n", + "\n", + "user_cohort_index = Cohort(name='Cohort Index')\n", + "user_cohort_index.add_cohort_filter(cohort_filter_index)\n", + "\n", + "# Cohort on predicted y values\n", + "cohort_filter_predicted_y = CohortFilter(\n", + " method=CohortFilterMethods.METHOD_LESS,\n", + " arg=[165.0],\n", + " column='Predicted Y')\n", + "\n", + "user_cohort_predicted_y = Cohort(name='Cohort Predicted Y')\n", + "user_cohort_predicted_y.add_cohort_filter(cohort_filter_predicted_y)\n", + "\n", + "# Cohort on true y values\n", + "cohort_filter_true_y = CohortFilter(\n", + " method=CohortFilterMethods.METHOD_GREATER,\n", + " arg=[45.0],\n", + " column='True Y')\n", + "\n", + "user_cohort_true_y = Cohort(name='Cohort True Y')\n", + "user_cohort_true_y.add_cohort_filter(cohort_filter_true_y)\n", + "\n", + "# Cohort on true y values\n", + "cohort_filter_regression_error = CohortFilter(\n", + " method=CohortFilterMethods.METHOD_GREATER,\n", + " arg=[20.0],\n", + " column='Error')\n", + "\n", + "user_cohort_regression_error = Cohort(name='Cohort Regression Error')\n", + "user_cohort_regression_error.add_cohort_filter(cohort_filter_regression_error)\n", + "\n", + "cohort_list = [user_cohort_age_and_bmi,\n", + " user_cohort_index,\n", + " user_cohort_predicted_y,\n", + " user_cohort_true_y,\n", + " user_cohort_regression_error]" + ] + }, + { + "cell_type": "markdown", + "id": "54a43b5c", "metadata": {}, "source": [ "Finally, visualize and explore the model insights. Use the resulting widget or follow the link to view this in a new tab." @@ -202,14 +298,16 @@ { "cell_type": "code", "execution_count": null, + "id": "ad84c884", "metadata": {}, "outputs": [], "source": [ - "ResponsibleAIDashboard(rai_insights)" + "ResponsibleAIDashboard(rai_insights, cohort_list=cohort_list)" ] }, { "cell_type": "markdown", + "id": "fb2ab57e", "metadata": {}, "source": [ "## Take Real-World Action" @@ -217,6 +315,7 @@ }, { "cell_type": "markdown", + "id": "84325421", "metadata": {}, "source": [ "### What-If Counterfactuals Analysis" @@ -224,6 +323,7 @@ }, { "cell_type": "markdown", + "id": "d292d247", "metadata": {}, "source": [ "Let's imagine that the diabetes progression scores predicted by the model are used to determine medical insurance rates. If the score is greater than 120, there is a higher rate. Patient 43's model score of 268.08 results in this increased rate, and they want to know how they should change their health to get a lower rate prediction from the model (leading to lower insurance price).\n", @@ -234,6 +334,7 @@ { "attachments": {}, "cell_type": "markdown", + "id": "d459156b", "metadata": {}, "source": [ "![What-If Counterfactuals component with datapoint 43 selected on the scatter plot with axes \"Predicted Y\" and \"Index\"](./img/regression-decision-making-1.png)" @@ -241,6 +342,7 @@ }, { "cell_type": "markdown", + "id": "d7b86696", "metadata": {}, "source": [ "What can Patient 43 do to create the desired change? The top ranked features bar plot shows that `bmi` and `s5` are the best to perturb to bring the model score within 120." @@ -249,6 +351,7 @@ { "attachments": {}, "cell_type": "markdown", + "id": "b16d1a6c", "metadata": {}, "source": [ "![Top-ranked features (descending) for datapoint 43 to perturb to reduce model prediction below 120: bmi, s5, s4, s3, age, bp, sex, s1, s2, s6](./img/regression-decision-making-2.png)" @@ -256,6 +359,7 @@ }, { "cell_type": "markdown", + "id": "709c3019", "metadata": {}, "source": [ "Let's see how that can be achieved. Change `bmi` to -0.04 and `s5` to -0.042 and see what the result is." @@ -264,6 +368,7 @@ { "attachments": {}, "cell_type": "markdown", + "id": "5faa62ea", "metadata": {}, "source": [ "![Counterfactual creation panel. BMI has been changed to -0.04 and s5 has been changed to -0.042](./img/regression-decision-making-3.png)" @@ -271,6 +376,7 @@ }, { "cell_type": "markdown", + "id": "a9f67339", "metadata": {}, "source": [ "As we can see, the model's prediction has dropped to 131.22. Thus, Patient 43 should work on reducing their [body mass index and serum triglycerides level](https://scikit-learn.org/stable/datasets/toy_dataset.html#diabetes-dataset) to bring the model score under the insurance threshold." @@ -279,6 +385,7 @@ { "attachments": {}, "cell_type": "markdown", + "id": "22d445d7", "metadata": {}, "source": [ "![Counterfactual of datapoint 43 selected on the counterfactuals scatter plot with axes \"Predicted Y\" and \"Index\". Predicted Y is 115.4](./img/regression-decision-making-4.png)" @@ -286,6 +393,7 @@ }, { "cell_type": "markdown", + "id": "b4f78fd8", "metadata": {}, "source": [ "Note that this result does not mean that reducing `bmi` and `s5` *causes* the diabetes progression score to go down. It simply decreases the model prediction. To investigate causal relationships, continue reading:" @@ -293,6 +401,7 @@ }, { "cell_type": "markdown", + "id": "b134cdb5", "metadata": {}, "source": [ "### Causal Analysis" @@ -300,6 +409,7 @@ }, { "cell_type": "markdown", + "id": "da76466d", "metadata": {}, "source": [ "Now suppose that a doctor wishes to know how to reduce the progression of diabetes in her patients. This can be explored in the Causal Inference component of the Responsible AI Toolbox.\n", @@ -310,6 +420,7 @@ { "attachments": {}, "cell_type": "markdown", + "id": "90b838d8", "metadata": {}, "source": [ "![Overall causal analysis table](./img/regression-decision-making-5.png)\n", @@ -318,6 +429,7 @@ }, { "cell_type": "markdown", + "id": "f6078481", "metadata": {}, "source": [ "Let's revisit Patient 43. Instead of simply reducing the model score, they've decided to focus on actually improving their health to manage their diabetes better. In the \"Individual causal what-if\" tab, it shows that decreasing his/her bmi to 0.05 reduces diabetes progression from 242 to 237.982." @@ -326,6 +438,7 @@ { "attachments": {}, "cell_type": "markdown", + "id": "93105414", "metadata": {}, "source": [ "![individual causal analysis table](./img/regression-decision-making-7.png)" @@ -333,6 +446,7 @@ }, { "cell_type": "markdown", + "id": "a6fa7384", "metadata": {}, "source": [ "To put that into a formal intervention policy, switch to the \"Treatment policy\" tab. This view helps build policies for future interventions. You can identify what parts of your sample experience the largest responses to changes in causal features, or treatments, and construct rules to define which future populations should be targeted for particular interventions." @@ -341,6 +455,7 @@ { "attachments": {}, "cell_type": "markdown", + "id": "d1af0772", "metadata": {}, "source": [ "![treatment_policy](./img/regression-decision-making-8.png)" @@ -348,6 +463,7 @@ }, { "cell_type": "markdown", + "id": "ac8025e4", "metadata": {}, "source": [ "Is that change the best overall treatment for them? Let's investigate different policies. Going back to the \"Treatment policy\" tab, we see that going with the above intervention of s2 feature outperforms perturbing that with a \"always increase\" intervention." @@ -356,6 +472,7 @@ { "attachments": {}, "cell_type": "markdown", + "id": "ce677d35", "metadata": {}, "source": [ "![image.png](./img/regression-decision-making-9.png)" @@ -363,6 +480,7 @@ }, { "cell_type": "markdown", + "id": "3355ea1c", "metadata": {}, "source": [ "Finally, you can see a list demonstrating which datapoints (patients) in the current data sample have the largest causal response to the selected treatment (s2 feature change), based on all features included in the estimated causal model." @@ -371,6 +489,7 @@ { "attachments": {}, "cell_type": "markdown", + "id": "3cb02322", "metadata": {}, "source": [ "![causal-table](./img/regression-decision-making-10.png)" @@ -393,7 +512,7 @@ "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", - "version": "3.7.6" + "version": "3.7.11" } }, "nbformat": 4,