Skip to content

Commit 5ad2bde

Browse files
committed
accuracy metric fixed
1 parent 1d27056 commit 5ad2bde

18 files changed

+3354
-1865
lines changed

analysis/Demo Analysis.ipynb

Lines changed: 83 additions & 75 deletions
Large diffs are not rendered by default.

analysis/Untitled.ipynb

Lines changed: 399 additions & 0 deletions
Large diffs are not rendered by default.

analysis/Untitled1.ipynb

Lines changed: 68 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,68 @@
1+
{
2+
"cells": [
3+
{
4+
"cell_type": "code",
5+
"execution_count": 18,
6+
"metadata": {},
7+
"outputs": [],
8+
"source": [
9+
"import numpy as np"
10+
]
11+
},
12+
{
13+
"cell_type": "code",
14+
"execution_count": 26,
15+
"metadata": {},
16+
"outputs": [],
17+
"source": [
18+
"a = np.array([1, 2, 2]) \n",
19+
"b = np.array([2, 2, 2])"
20+
]
21+
},
22+
{
23+
"cell_type": "code",
24+
"execution_count": 29,
25+
"metadata": {},
26+
"outputs": [
27+
{
28+
"data": {
29+
"text/plain": [
30+
"0.5"
31+
]
32+
},
33+
"execution_count": 29,
34+
"metadata": {},
35+
"output_type": "execute_result"
36+
}
37+
],
38+
"source": [
39+
"def equal_weight_acc(truth, predictions): \n",
40+
" indices = [truth == x for x in set(truth)]\n",
41+
" return np.mean([(truth[i] == predictions[i]).mean() for i in indices])\n",
42+
"\n",
43+
"equal_weight_acc(a, b)"
44+
]
45+
}
46+
],
47+
"metadata": {
48+
"kernelspec": {
49+
"display_name": "Python (jordan_env)",
50+
"language": "python",
51+
"name": "jordan_env"
52+
},
53+
"language_info": {
54+
"codemirror_mode": {
55+
"name": "ipython",
56+
"version": 3
57+
},
58+
"file_extension": ".py",
59+
"mimetype": "text/x-python",
60+
"name": "python",
61+
"nbconvert_exporter": "python",
62+
"pygments_lexer": "ipython3",
63+
"version": "3.7.4"
64+
}
65+
},
66+
"nbformat": 4,
67+
"nbformat_minor": 2
68+
}

analysis/graphs.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -92,14 +92,14 @@ def plot_boxplots(df, save = False):
9292
plt.figure(figsize = (10, 5))
9393
sns.set_context("poster")
9494
sns.barplot(data = df, x = "penalty", y = "train_acc", capsize=.2)
95-
plt.xlabel("Gating Loss Penalty")
95+
plt.xlabel("Loss Parameter Φ")
9696
plt.ylabel("Training Accuracy")
9797
plt.ylim([0.8, 1.0])
9898
plt.savefig("graphs/train_boxplot.svg")
9999

100100
plt.figure(figsize = (10, 5))
101101
sns.barplot(data = df, x = "penalty", y = "final_acc", capsize=.2)
102-
plt.xlabel("Gating Loss Penalty")
102+
plt.xlabel("Loss Parameter Φ")
103103
plt.ylim([0.8, 1.0])
104104
plt.ylabel("Validation Accuracy")
105105
plt.savefig("graphs/val_boxplot.svg")

0 commit comments

Comments
 (0)