-
Notifications
You must be signed in to change notification settings - Fork 1
/
Copy pathscatter_plot.py
65 lines (48 loc) · 1.8 KB
/
scatter_plot.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
"""
Draw a scatter plot comparing two courses, with one colored series per Hogwarts house
"""
import matplotlib.pyplot as plt
from argparse import ArgumentParser
from data_describer import HogwartsDataDescriber
def scatter_plot(plot: "plt.Axes",
                 df: "HogwartsDataDescriber",
                 course1: str,
                 course2: str) -> None:
    """
    Scatter plot for 2 courses, one series per house
    :param plot: object exposing a matplotlib-style ``scatter`` method
                 (typically a matplotlib Axes)
    :param df: HogwartsDataDescriber; must provide ``houses`` and ``colors``
               (iterated in lockstep) and column access by name, including
               a 'Hogwarts House' column
    :param course1: course 1 name, used for the x values
    :param course2: course 2 name, used for the y values
    :return: None
    """
    for house, color in zip(df.houses, df.colors):
        # select the rows of students belonging to this house; computed once
        # and reused for both axes
        in_house = df['Hogwarts House'] == house
        x = df[course1][in_house]
        y = df[course2][in_house]
        # label each series so a legend can be built from the artists
        # themselves instead of relying on drawing order
        plot.scatter(x, y, color=color, alpha=0.5, label=house)
def show_scatter_plot(csv_path: str, course1: str, course2: str):
    """
    Load a dataset and display a scatter plot of two of its courses
    :param csv_path: path handed to HogwartsDataDescriber.read_csv
    :param course1: course name; plotted on and used to label the x axis
    :param course2: course name; plotted on and used to label the y axis
    :return: None
    """
    dataset = HogwartsDataDescriber.read_csv(csv_path)

    # only the axes are needed; the figure handle is discarded
    axes = plt.subplots()[1]

    scatter_plot(axes, dataset, course1, course2)

    axes.set_xlabel(course1)
    axes.set_ylabel(course2)
    # legend entries are given as a plain list of house names
    axes.legend(dataset.houses)

    plt.show()
if __name__ == "__main__":
    # Command-line entry point: every option is a string with a default,
    # so the script can be run without any arguments.
    cli = ArgumentParser()

    # (flag, default value, help text)
    option_table = (
        ('--data_path',
         '../data/dataset_train.csv',
         'Path to dataset_train.csv file'),
        ('--course1',
         'Astronomy',
         'Name of the course for x axis'),
        ('--course2',
         'Defense Against the Dark Arts',
         'Name of the course for y axis'),
    )
    for flag, fallback, description in option_table:
        cli.add_argument(flag, type=str, default=fallback, help=description)

    options = cli.parse_args()
    show_scatter_plot(options.data_path, options.course1, options.course2)