Skip to content
This repository was archived by the owner on Nov 18, 2024. It is now read-only.

Commit cf614ef

Browse files
committed
update check_pred
1 parent 40e80b8 commit cf614ef

5 files changed

+10
-31
lines changed

Diff for: other/check_pred/lstmconfusion_matrix.png

7.25 KB
Loading

Diff for: other/check_pred/metrics.ipynb

+10-31
Original file line numberDiff line numberDiff line change
@@ -1,47 +1,22 @@
11
{
22
"cells": [
3-
{
4-
"cell_type": "code",
5-
"execution_count": 1,
6-
"metadata": {},
7-
"outputs": [],
8-
"source": [
9-
"import pandas as pd\n",
10-
"import matplotlib.pyplot as plt"
11-
]
12-
},
133
{
144
"cell_type": "code",
155
"execution_count": 2,
166
"metadata": {},
177
"outputs": [],
188
"source": [
19-
"from sklearn.metrics import accuracy_score\n",
20-
"from sklearn.metrics import roc_auc_score, roc_curve"
21-
]
22-
},
23-
{
24-
"cell_type": "code",
25-
"execution_count": 7,
26-
"metadata": {},
27-
"outputs": [],
28-
"source": [
29-
"import numpy as np"
30-
]
31-
},
32-
{
33-
"cell_type": "code",
34-
"execution_count": 18,
35-
"metadata": {},
36-
"outputs": [],
37-
"source": [
9+
"import pandas as pd\n",
10+
"import matplotlib.pyplot as plt\n",
11+
"import numpy as np\n",
12+
"from sklearn.metrics import roc_auc_score, roc_curve, confusion_matrix, accuracy_score\n",
3813
"import seaborn as sns\n",
3914
"sns.set()"
4015
]
4116
},
4217
{
4318
"cell_type": "code",
44-
"execution_count": 19,
19+
"execution_count": 3,
4520
"metadata": {},
4621
"outputs": [],
4722
"source": [
@@ -53,7 +28,7 @@
5328
},
5429
{
5530
"cell_type": "code",
56-
"execution_count": 21,
31+
"execution_count": 5,
5732
"metadata": {},
5833
"outputs": [
5934
{
@@ -71,6 +46,10 @@
7146
" val_pred = pd.read_csv('../../logs/{}/val_pred.csv'.format(model)).iloc[:, 1]\n",
7247
" auc = roc_auc_score(train_y, val_pred)\n",
7348
" acc = accuracy_score(train_y, np.round(val_pred))\n",
49+
" cm = confusion_matrix(train_y, np.round(val_pred))\n",
50+
" sns.heatmap(cm, annot=True, cmap='Blues')\n",
51+
" plt.savefig(model + 'confusion_matrix.png')\n",
52+
" plt.clf()\n",
7453
" fpr, tpr, thresholds = roc_curve(train_y, val_pred)\n",
7554
" label='auc = {:.3f} acc={:.3f}'.format(auc, acc)\n",
7655
" plt.plot(fpr, tpr, label=label)\n",

Diff for: other/check_pred/resnet_1confusion_matrix.png

7.33 KB
Loading

Diff for: other/check_pred/resnet_2confusion_matrix.png

7.19 KB
Loading

Diff for: other/check_pred/wavenetconfusion_matrix.png

7.24 KB
Loading

0 commit comments

Comments
 (0)