|
1 | 1 | { |
2 | 2 | "cells": [ |
3 | | - { |
4 | | - "cell_type": "code", |
5 | | - "execution_count": 1, |
6 | | - "metadata": {}, |
7 | | - "outputs": [], |
8 | | - "source": [ |
9 | | - "import pandas as pd\n", |
10 | | - "import matplotlib.pyplot as plt" |
11 | | - ] |
12 | | - }, |
13 | 3 | { |
14 | 4 | "cell_type": "code", |
15 | 5 | "execution_count": 2, |
16 | 6 | "metadata": {}, |
17 | 7 | "outputs": [], |
18 | 8 | "source": [ |
19 | | - "from sklearn.metrics import accuracy_score\n", |
20 | | - "from sklearn.metrics import roc_auc_score, roc_curve" |
21 | | - ] |
22 | | - }, |
23 | | - { |
24 | | - "cell_type": "code", |
25 | | - "execution_count": 7, |
26 | | - "metadata": {}, |
27 | | - "outputs": [], |
28 | | - "source": [ |
29 | | - "import numpy as np" |
30 | | - ] |
31 | | - }, |
32 | | - { |
33 | | - "cell_type": "code", |
34 | | - "execution_count": 18, |
35 | | - "metadata": {}, |
36 | | - "outputs": [], |
37 | | - "source": [ |
| 9 | + "import pandas as pd\n", |
| 10 | + "import matplotlib.pyplot as plt\n", |
| 11 | + "import numpy as np\n", |
| 12 | + "from sklearn.metrics import roc_auc_score, roc_curve, confusion_matrix, accuracy_score\n", |
38 | 13 | "import seaborn as sns\n", |
39 | 14 | "sns.set()" |
40 | 15 | ] |
41 | 16 | }, |
42 | 17 | { |
43 | 18 | "cell_type": "code", |
44 | | - "execution_count": 19, |
| 19 | + "execution_count": 3, |
45 | 20 | "metadata": {}, |
46 | 21 | "outputs": [], |
47 | 22 | "source": [ |
|
53 | 28 | }, |
54 | 29 | { |
55 | 30 | "cell_type": "code", |
56 | | - "execution_count": 21, |
| 31 | + "execution_count": 5, |
57 | 32 | "metadata": {}, |
58 | 33 | "outputs": [ |
59 | 34 | { |
|
71 | 46 | " val_pred = pd.read_csv('../../logs/{}/val_pred.csv'.format(model)).iloc[:, 1]\n", |
72 | 47 | " auc = roc_auc_score(train_y, val_pred)\n", |
73 | 48 | " acc = accuracy_score(train_y, np.round(val_pred))\n", |
| 49 | + " cm = confusion_matrix(train_y, np.round(val_pred))\n", |
| 50 | + " sns.heatmap(cm, annot=True, cmap='Blues')\n", |
| 51 | + " plt.savefig(model + 'confusion_matrix.png')\n", |
| 52 | + " plt.clf()\n", |
74 | 53 | " fpr, tpr, thresholds = roc_curve(train_y, val_pred)\n", |
75 | 54 | " label='auc = {:.3f} acc={:.3f}'.format(auc, acc)\n", |
76 | 55 | " plt.plot(fpr, tpr, label=label)\n", |
|
0 commit comments