Skip to content

Commit be30ee8

Browse files
committed
Minor fixes to notebooks
1 parent 95b4121 commit be30ee8

File tree

3 files changed

+102
-109
lines changed

3 files changed

+102
-109
lines changed

README.md

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -21,6 +21,9 @@ Several example Jupyter notebooks are available.
2121
- [TSLime](examples/TSLime.ipynb)[^1]
2222
- [TSSaliency](examples/TSSaliency.ipynb)[^1]
2323
- [Time-series Individual Conditional Expectation (TSICE)](examples/TSICE.ipynb)[^1]
24+
- Full examples
25+
- [Energy load forecasting](examples/EnergyLoadForecasting.ipynb)[^1]
26+
- [Engine fault detection](examples/EngineFaultDetection.ipynb)[^1]
2427

2528
### Fairness metrics
2629

examples/EnergyLoadForecasting.ipynb

Lines changed: 45 additions & 54 deletions
Large diffs are not rendered by default.

examples/EngineFaultDetection.ipynb

Lines changed: 54 additions & 55 deletions
Original file line numberDiff line numberDiff line change
@@ -4,18 +4,19 @@
44
"cell_type": "markdown",
55
"source": [
66
"# Engine Fault Detection\n",
7+
"\n",
8+
"> These examples are originally from the [AI Explainability 360 Toolkit for Time-Series and Industrial Use Cases](https://ai-library-examples.github.io/aix4industries/books/intro.html#) book.\n",
9+
"\n",
10+
"\n",
711
"Diagnostics is essential for assets and equiment in industrial domain. Specifically, fault detection is required to avoid asset failures in the future. In this demo, we cover a use case on fault detection in an automotive subsystem using [FordA engine noise data](https://timeseriesclassification.com/description.php?Dataset=FordA).\n",
812
"\n",
913
"The FordA dataset is used in IEEE World Congress on Computational Intelligence (WCCI) competition in 2008. Each case/instance has 500 measurements of engine noise and the target is to diagnose whether a certain symptom (defect) exists or not in the automotive subsystem. Train and test data sets were collected in typical operating conditions, with minimal noise contamination. A classification model can be trained to detect fault for each test instance.\n",
1014
"\n",
11-
"For this demo, we use the transformer model trained in [Keras time series examples](https://keras.io/examples/timeseries/timeseries_transformer_classification/). The model can be downloaded from [Huggingface Hub](https://huggingface.co/keras-io/timeseries_transformer_classification). We use aix360's [Time Series Saliency Explainer (TSSaliencyExplainer)](https://github.com/Trusted-AI/AIX360/blob/master/aix360/algorithms/tssaliency/tssaliency.py) to explain how each time point or temporal segments influenced the Keras transformer model's prediction for a test instance. Further, we use aix360's [Time Series Local Interpretable Model-agnostic Explainer (TSLime)](https://github.com/Trusted-AI/AIX360/blob/master/aix360/algorithms/tslime/tslime.py) to explain how much (magnitude) each time point in the test instance influenced the Keras transformer model's prediction.\n",
15+
"For this demo, we use the transformer model trained in [Keras time series examples](https://keras.io/examples/timeseries/timeseries_transformer_classification/). The model can be downloaded from [Huggingface Hub](https://huggingface.co/keras-io/timeseries_transformer_classification). TrustyAI will use AIX360's [Time Series Saliency Explainer (TSSaliencyExplainer)](https://github.com/Trusted-AI/AIX360/blob/master/aix360/algorithms/tssaliency/tssaliency.py) implementation to explain how each time point or temporal segments influenced the Keras transformer model's prediction for a test instance. Further, we use [Time Series Local Interpretable Model-agnostic Explainer (TSLime)](https://github.com/Trusted-AI/AIX360/blob/master/aix360/algorithms/tslime/tslime.py) (also provided by AIX360) to explain how much (magnitude) each time point in the test instance influenced the Keras transformer model's prediction.\n",
1216
"\n",
1317
"\n",
1418
"For more algorithmic details on TSSaliency and TSLime, you can refer to [Time Series Saliency Explanation](efd:references:tssaliency) and [Time Series Local Interpretable Model-Agnostic Explanation](efd:references:tslime) sections respectively.\n",
1519
"\n",
16-
"To start this hands on demo, skip to [Instructions](efd:references:instructions).\n",
17-
"\n",
18-
"\n",
1920
"#### Time Series Saliency Explanation\n",
2021
"\n",
2122
"Time Series Saliency (TSSaliency) Explainer is a model agnostic saliency explainer for time series associate tasks. The saliency supports univariate and multivariate use cases. It explains temporal importance of different variates on the model prediction. TSSaliency incorporates an integrated gradient method for saliency estimation and, therefore, provides reasonable values for functions that are continuous, and differentiable almost everywhere. It may be ill-suited to certain types of ensemble models such as Random Forests. The saliency measure involves the notion of a base value. For example, the base value can be the constant signal with average value. The saliency measure is computed by integrating the model sensitivity over a trajectory from the base value to the time series signal. The TSSaliency explainer provides variate wise contributions to model prediction at a temporal resolution.\n",
@@ -33,7 +34,7 @@
3334
},
3435
{
3536
"cell_type": "code",
36-
"execution_count": 5,
37+
"execution_count": 4,
3738
"outputs": [],
3839
"source": [
3940
"import os\n",
@@ -48,8 +49,8 @@
4849
"import numpy as np\n",
4950
"import pandas as pd\n",
5051
"import matplotlib.pyplot as plt\n",
52+
"plt.style.use('../styles/material_rh.mplstyle')\n",
5153
"import tensorflow as tf\n",
52-
"from tensorflow import keras\n",
5354
"from sklearn.metrics import accuracy_score\n",
5455
"from huggingface_hub import from_pretrained_keras\n",
5556
"from trustyai.utils.extras.timeseries import tsFrame\n",
@@ -68,8 +69,8 @@
6869
"metadata": {
6970
"collapsed": false,
7071
"ExecuteTime": {
71-
"end_time": "2023-08-14T10:08:37.070741Z",
72-
"start_time": "2023-08-14T10:08:37.064694Z"
72+
"end_time": "2023-08-15T08:47:52.576236Z",
73+
"start_time": "2023-08-15T08:47:52.549992Z"
7374
}
7475
},
7576
"id": "f369d48f44e942e9"
@@ -90,7 +91,7 @@
9091
},
9192
{
9293
"cell_type": "code",
93-
"execution_count": 6,
94+
"execution_count": 5,
9495
"outputs": [],
9596
"source": [
9697
"input_length = 500\n",
@@ -99,8 +100,8 @@
99100
"metadata": {
100101
"collapsed": false,
101102
"ExecuteTime": {
102-
"end_time": "2023-08-14T10:08:41.778102Z",
103-
"start_time": "2023-08-14T10:08:40.963421Z"
103+
"end_time": "2023-08-15T08:47:56.600095Z",
104+
"start_time": "2023-08-15T08:47:55.885642Z"
104105
}
105106
},
106107
"id": "7bf0efcfb0c5e7ce"
@@ -111,7 +112,7 @@
111112
"If the above dataset url does not work, follow the below steps. Otherwise, skip to [Plot the dataset](efd:references:plot_dataset).\n",
112113
"\n",
113114
"- Install aeon library using `python -m pip install git+https://github.com/aeon-toolkit/aeon.git` to download the data using [aeon APIs](https://www.aeon-toolkit.org/en/latest/api_reference/auto_generated/aeon.datasets.load_classification.html#aeon.datasets.load_classification).\n",
114-
"- Paste the below code snippet into a cell in `engine_fault_detection.ipynb` in Jupyter lab and run the cell to download the data using `aeon` apis.\n"
115+
"- Uncomment the code below and run it to download the data using `aeon` apis.\n"
115116
],
116117
"metadata": {
117118
"collapsed": false
@@ -123,15 +124,15 @@
123124
"execution_count": null,
124125
"outputs": [],
125126
"source": [
126-
"from aeon.datasets import load_classification\n",
127-
"from sklearn.model_selection import train_test_split\n",
128-
"\n",
129-
"X, y, meta = load_classification(name=\"ArrowHead\")\n",
130-
"y = y.astype(int)\n",
131-
"y[y == -1] = 0\n",
132-
"x_train, x_test, y_train, y_test = train_test_split(\n",
133-
" X, y, test_size=0.3, random_state=22\n",
134-
" )"
127+
"# from aeon.datasets import load_classification\n",
128+
"# from sklearn.model_selection import train_test_split\n",
129+
"# \n",
130+
"# X, y, meta = load_classification(name=\"ArrowHead\")\n",
131+
"# y = y.astype(int)\n",
132+
"# y[y == -1] = 0\n",
133+
"# x_train, x_test, y_train, y_test = train_test_split(\n",
134+
"# X, y, test_size=0.3, random_state=22\n",
135+
"# )"
135136
],
136137
"metadata": {
137138
"collapsed": false
@@ -150,7 +151,7 @@
150151
},
151152
{
152153
"cell_type": "code",
153-
"execution_count": 8,
154+
"execution_count": 6,
154155
"outputs": [
155156
{
156157
"data": {
@@ -167,8 +168,8 @@
167168
"metadata": {
168169
"collapsed": false,
169170
"ExecuteTime": {
170-
"end_time": "2023-08-14T10:09:43.450290Z",
171-
"start_time": "2023-08-14T10:09:43.182595Z"
171+
"end_time": "2023-08-15T08:47:59.998174Z",
172+
"start_time": "2023-08-15T08:47:59.749653Z"
172173
}
173174
},
174175
"id": "df5ef0290fd35d11"
@@ -187,15 +188,15 @@
187188
},
188189
{
189190
"cell_type": "code",
190-
"execution_count": 9,
191+
"execution_count": 7,
191192
"outputs": [
192193
{
193194
"data": {
194195
"text/plain": "Fetching 6 files: 0%| | 0/6 [00:00<?, ?it/s]",
195196
"application/vnd.jupyter.widget-view+json": {
196197
"version_major": 2,
197198
"version_minor": 0,
198-
"model_id": "50f19527187e4f13935a3ffb646796c2"
199+
"model_id": "379f454dd25046278efe9d054e4a247d"
199200
}
200201
},
201202
"metadata": {},
@@ -344,8 +345,8 @@
344345
"metadata": {
345346
"collapsed": false,
346347
"ExecuteTime": {
347-
"end_time": "2023-08-14T10:10:21.171780Z",
348-
"start_time": "2023-08-14T10:10:09.126608Z"
348+
"end_time": "2023-08-15T08:48:10.402492Z",
349+
"start_time": "2023-08-15T08:48:05.897678Z"
349350
}
350351
},
351352
"id": "1e6f1dd548961a55"
@@ -362,7 +363,7 @@
362363
},
363364
{
364365
"cell_type": "code",
365-
"execution_count": 10,
366+
"execution_count": 8,
366367
"outputs": [
367368
{
368369
"name": "stdout",
@@ -387,8 +388,8 @@
387388
"metadata": {
388389
"collapsed": false,
389390
"ExecuteTime": {
390-
"end_time": "2023-08-14T10:10:56.257383Z",
391-
"start_time": "2023-08-14T10:10:52.955081Z"
391+
"end_time": "2023-08-15T08:48:15.415344Z",
392+
"start_time": "2023-08-15T08:48:14.018416Z"
392393
}
393394
},
394395
"id": "d188c9214ebd8e49"
@@ -406,7 +407,7 @@
406407
},
407408
{
408409
"cell_type": "code",
409-
"execution_count": 11,
410+
"execution_count": 9,
410411
"outputs": [],
411412
"source": [
412413
"tssaliency_explainer = TSSaliencyExplainer(model= functools.partial(binary_model.predict_proba, verbose = 0),\n",
@@ -419,8 +420,8 @@
419420
"metadata": {
420421
"collapsed": false,
421422
"ExecuteTime": {
422-
"end_time": "2023-08-14T10:11:28.362096Z",
423-
"start_time": "2023-08-14T10:11:28.347961Z"
423+
"end_time": "2023-08-15T08:48:38.905247Z",
424+
"start_time": "2023-08-15T08:48:38.887494Z"
424425
}
425426
},
426427
"id": "8b120733283852f7"
@@ -439,7 +440,7 @@
439440
},
440441
{
441442
"cell_type": "code",
442-
"execution_count": 13,
443+
"execution_count": 10,
443444
"outputs": [],
444445
"source": [
445446
"indx = 3\n",
@@ -453,18 +454,16 @@
453454
"metadata": {
454455
"collapsed": false,
455456
"ExecuteTime": {
456-
"end_time": "2023-08-14T10:12:12.464510Z",
457-
"start_time": "2023-08-14T10:12:08.116250Z"
457+
"end_time": "2023-08-15T08:48:45.490158Z",
458+
"start_time": "2023-08-15T08:48:41.346711Z"
458459
}
459460
},
460461
"id": "76625d814030e612"
461462
},
462463
{
463464
"cell_type": "markdown",
464465
"source": [
465-
"By Proposition 1 in \"section 3 Mukund Sundararajan et al. Axiomatic Attribution for Deep Networks\", sum of saliency scores (Integrated Gradient) and model prediction delta (`f(instance) - f(base)`) between the input instance and the base value should be zero. As the value is closer to zero, quality of explanation is high.\n",
466-
"\n",
467-
"Paste the below code snippet into a cell in `engine_fault_detection.ipynb` in Jupyter lab and run the cell."
466+
"By Proposition 1 in \"section 3 Mukund Sundararajan et al. Axiomatic Attribution for Deep Networks\", sum of saliency scores (Integrated Gradient) and model prediction delta (`f(instance) - f(base)`) between the input instance and the base value should be zero. As the value is closer to zero, quality of explanation is high."
468467
],
469468
"metadata": {
470469
"collapsed": false
@@ -473,7 +472,7 @@
473472
},
474473
{
475474
"cell_type": "code",
476-
"execution_count": 15,
475+
"execution_count": 11,
477476
"outputs": [
478477
{
479478
"name": "stdout",
@@ -499,8 +498,8 @@
499498
"metadata": {
500499
"collapsed": false,
501500
"ExecuteTime": {
502-
"end_time": "2023-08-14T10:13:03.282051Z",
503-
"start_time": "2023-08-14T10:13:03.279073Z"
501+
"end_time": "2023-08-15T08:48:53.575158Z",
502+
"start_time": "2023-08-15T08:48:53.569791Z"
504503
}
505504
},
506505
"id": "7dd63d0ae66b7ba1"
@@ -517,7 +516,7 @@
517516
},
518517
{
519518
"cell_type": "code",
520-
"execution_count": 16,
519+
"execution_count": 12,
521520
"outputs": [
522521
{
523522
"data": {
@@ -546,8 +545,8 @@
546545
"metadata": {
547546
"collapsed": false,
548547
"ExecuteTime": {
549-
"end_time": "2023-08-14T10:14:20.378459Z",
550-
"start_time": "2023-08-14T10:14:19.260418Z"
548+
"end_time": "2023-08-15T08:48:59.080207Z",
549+
"start_time": "2023-08-15T08:48:58.293923Z"
551550
}
552551
},
553552
"id": "2afb1fb85261595f"
@@ -568,7 +567,7 @@
568567
},
569568
{
570569
"cell_type": "code",
571-
"execution_count": 17,
570+
"execution_count": 13,
572571
"outputs": [],
573572
"source": [
574573
"from trustyai.explainers.extras.tslime import TSLimeExplainer\n",
@@ -589,8 +588,8 @@
589588
"metadata": {
590589
"collapsed": false,
591590
"ExecuteTime": {
592-
"end_time": "2023-08-14T10:16:39.345384Z",
593-
"start_time": "2023-08-14T10:16:38.149681Z"
591+
"end_time": "2023-08-15T08:49:07.430220Z",
592+
"start_time": "2023-08-15T08:49:07.284781Z"
594593
}
595594
},
596595
"id": "c7c21fa6a40c825b"
@@ -609,16 +608,16 @@
609608
},
610609
{
611610
"cell_type": "code",
612-
"execution_count": 18,
611+
"execution_count": 14,
613612
"outputs": [],
614613
"source": [
615614
"tslime_explanation = tslime_explainer.explain(ts_instance)"
616615
],
617616
"metadata": {
618617
"collapsed": false,
619618
"ExecuteTime": {
620-
"end_time": "2023-08-14T10:17:06.319310Z",
621-
"start_time": "2023-08-14T10:17:05.006496Z"
619+
"end_time": "2023-08-15T08:49:12.947385Z",
620+
"start_time": "2023-08-15T08:49:12.033660Z"
622621
}
623622
},
624623
"id": "cda568983ba7ed4f"
@@ -637,7 +636,7 @@
637636
},
638637
{
639638
"cell_type": "code",
640-
"execution_count": 19,
639+
"execution_count": 15,
641640
"outputs": [
642641
{
643642
"data": {
@@ -654,8 +653,8 @@
654653
"metadata": {
655654
"collapsed": false,
656655
"ExecuteTime": {
657-
"end_time": "2023-08-14T10:17:58.672815Z",
658-
"start_time": "2023-08-14T10:17:58.217571Z"
656+
"end_time": "2023-08-15T08:49:18.165139Z",
657+
"start_time": "2023-08-15T08:49:17.710159Z"
659658
}
660659
},
661660
"id": "7796e79a7b98206a"

0 commit comments

Comments
 (0)