|
1 | 1 | {
|
2 | 2 | "cells": [
|
3 |
| - { |
4 |
| - "cell_type": "code", |
5 |
| - "execution_count": 1, |
6 |
| - "metadata": {}, |
7 |
| - "outputs": [], |
8 |
| - "source": [ |
9 |
| - "import pandas as pd\n", |
10 |
| - "import matplotlib.pyplot as plt" |
11 |
| - ] |
12 |
| - }, |
13 | 3 | {
|
14 | 4 | "cell_type": "code",
|
15 | 5 | "execution_count": 2,
|
16 | 6 | "metadata": {},
|
17 | 7 | "outputs": [],
|
18 | 8 | "source": [
|
19 |
| - "from sklearn.metrics import accuracy_score\n", |
20 |
| - "from sklearn.metrics import roc_auc_score, roc_curve" |
21 |
| - ] |
22 |
| - }, |
23 |
| - { |
24 |
| - "cell_type": "code", |
25 |
| - "execution_count": 7, |
26 |
| - "metadata": {}, |
27 |
| - "outputs": [], |
28 |
| - "source": [ |
29 |
| - "import numpy as np" |
30 |
| - ] |
31 |
| - }, |
32 |
| - { |
33 |
| - "cell_type": "code", |
34 |
| - "execution_count": 18, |
35 |
| - "metadata": {}, |
36 |
| - "outputs": [], |
37 |
| - "source": [ |
| 9 | + "import pandas as pd\n", |
| 10 | + "import matplotlib.pyplot as plt\n", |
| 11 | + "import numpy as np\n", |
| 12 | + "from sklearn.metrics import roc_auc_score, roc_curve, confusion_matrix, accuracy_score\n", |
38 | 13 | "import seaborn as sns\n",
|
39 | 14 | "sns.set()"
|
40 | 15 | ]
|
41 | 16 | },
|
42 | 17 | {
|
43 | 18 | "cell_type": "code",
|
44 |
| - "execution_count": 19, |
| 19 | + "execution_count": 3, |
45 | 20 | "metadata": {},
|
46 | 21 | "outputs": [],
|
47 | 22 | "source": [
|
|
53 | 28 | },
|
54 | 29 | {
|
55 | 30 | "cell_type": "code",
|
56 |
| - "execution_count": 21, |
| 31 | + "execution_count": 5, |
57 | 32 | "metadata": {},
|
58 | 33 | "outputs": [
|
59 | 34 | {
|
|
71 | 46 | " val_pred = pd.read_csv('../../logs/{}/val_pred.csv'.format(model)).iloc[:, 1]\n",
|
72 | 47 | " auc = roc_auc_score(train_y, val_pred)\n",
|
73 | 48 | " acc = accuracy_score(train_y, np.round(val_pred))\n",
|
| 49 | + " cm = confusion_matrix(train_y, np.round(val_pred))\n", |
| 50 | + " sns.heatmap(cm, annot=True, cmap='Blues')\n", |
| 51 | + " plt.savefig(model + 'confusion_matrix.png')\n", |
| 52 | + " plt.clf()\n", |
74 | 53 | " fpr, tpr, thresholds = roc_curve(train_y, val_pred)\n",
|
75 | 54 | " label='auc = {:.3f} acc={:.3f}'.format(auc, acc)\n",
|
76 | 55 | " plt.plot(fpr, tpr, label=label)\n",
|
|
0 commit comments