Skip to content

Commit cb3c9db

Browse files
authored
Include option for saving plots as files (#43)
* Add option for saving plots to file * Rename parameter and add annotation
1 parent 8afc95d commit cb3c9db

File tree

1 file changed

+52
-30
lines changed

1 file changed

+52
-30
lines changed

CLUEstering/CLUEstering.py

Lines changed: 52 additions & 30 deletions
Original file line numberDiff line numberDiff line change
@@ -883,17 +883,19 @@ def cluster_centroids(self) -> np.ndarray:
883883

884884
return centroids
885885

886-
def input_plotter(self, plot_title: str = '', title_size: float = 16,
887-
x_label: str = 'x', y_label: str = 'y', z_label: str = 'z',
888-
label_size: float = 16, pt_size: float = 1, pt_colour: str = 'b',
889-
grid: bool = True, grid_style: str = '--', grid_size: float = 0.2,
890-
x_ticks=None, y_ticks=None, z_ticks=None,
886+
def input_plotter(self, filepath: Union[str, None] = None, plot_title: str = '',
887+
title_size: float = 16, x_label: str = 'x', y_label: str = 'y',
888+
z_label: str = 'z', label_size: float = 16, pt_size: float = 1,
889+
pt_colour: str = 'b', grid: bool = True, grid_style: str = '--',
890+
grid_size: float = 0.2, x_ticks=None, y_ticks=None, z_ticks=None,
891891
**kwargs) -> None:
892892
"""
893893
Plots the points in input.
894894
895895
Parameters
896896
----------
897+
filepath : string, optional
898+
The path to the file where the plot should be saved.
897899
plot_title : string, optional
898900
Title of the plot.
899901
title_size : float, optional
@@ -958,7 +960,10 @@ def input_plotter(self, plot_title: str = '', title_size: float = 16,
958960
if y_ticks is not None:
959961
plt.yticks(y_ticks)
960962

961-
plt.show()
963+
if filepath is not None:
964+
plt.savefig(filepath)
965+
else:
966+
plt.show()
962967
elif self.clust_data.n_dim == 2:
963968
plt.scatter(self.coords[0],
964969
self.coords[1],
@@ -982,7 +987,10 @@ def input_plotter(self, plot_title: str = '', title_size: float = 16,
982987
if y_ticks is not None:
983988
plt.yticks(y_ticks)
984989

985-
plt.show()
990+
if filepath is not None:
991+
plt.savefig(filepath)
992+
else:
993+
plt.show()
986994
else:
987995
fig = plt.figure()
988996
ax_ = fig.add_subplot(projection='3d')
@@ -1012,14 +1020,17 @@ def input_plotter(self, plot_title: str = '', title_size: float = 16,
10121020
if z_ticks is not None:
10131021
ax_.set_zticks(z_ticks)
10141022

1015-
plt.show()
1016-
1017-
def cluster_plotter(self, plot_title: str = '', title_size: float = 16,
1018-
x_label: str = 'x', y_label: str = 'y', z_label: str = 'z',
1019-
label_size: float = 16, outl_size: float = 10, pt_size: float = 10,
1020-
seed_size: float = 25, grid: bool = True, grid_style: str = '--',
1021-
grid_size: float = 0.2, x_ticks=None, y_ticks=None, z_ticks=None,
1022-
**kwargs) -> None:
1023+
if filepath is not None:
1024+
plt.savefig(filepath)
1025+
else:
1026+
plt.show()
1027+
1028+
def cluster_plotter(self, filepath: Union[str, None] = None, plot_title: str = '',
1029+
title_size: float = 16, x_label: str = 'x', y_label: str = 'y',
1030+
z_label: str = 'z', label_size: float = 16, outl_size: float = 10,
1031+
pt_size: float = 10, seed_size: float = 25, grid: bool = True,
1032+
grid_style: str = '--', grid_size: float = 0.2, x_ticks=None,
1033+
y_ticks=None, z_ticks=None, **kwargs) -> None:
10231034
"""
10241035
Plots the clusters with a different colour for every cluster.
10251036
@@ -1028,30 +1039,32 @@ def cluster_plotter(self, plot_title: str = '', title_size: float = 16,
10281039
10291040
Parameters
10301041
----------
1042+
filepath : string, optional
1043+
The path to the file where the plot should be saved.
10311044
plot_title : string, optional
1032-
Title of the plot
1045+
Title of the plot.
10331046
title_size : float, optional
1034-
Size of the plot title
1047+
Size of the plot title.
10351048
x_label : string, optional
1036-
Label on x-axis
1049+
Label on x-axis.
10371050
y_label : string, optional
1038-
Label on y-axis
1051+
Label on y-axis.
10391052
z_label : string, optional
1040-
Label on z-axis
1053+
Label on z-axis.
10411054
label_size : int, optional
1042-
Fontsize of the axis labels
1055+
Fontsize of the axis labels.
10431056
outl_size : int, optional
1044-
Size of the outliers in the plot
1057+
Size of the outliers in the plot.
10451058
pt_size : int, optional
1046-
Size of the points in the plot
1059+
Size of the points in the plot.
10471060
seed_size : int, optional
1048-
Size of the seeds in the plot
1061+
Size of the seeds in the plot.
10491062
grid : bool, optional
1050-
f true displays grids in the plot
1063+
f true displays grids in the plot.
10511064
grid_style : string, optional
1052-
Style of the grid
1065+
Style of the grid.
10531066
grid_size : float, optional
1054-
Linewidth of the plot grid
1067+
Linewidth of the plot grid.
10551068
x_ticks : list, optional
10561069
List of ticks for the x axis.
10571070
y_ticks : list, optional
@@ -1105,7 +1118,10 @@ def cluster_plotter(self, plot_title: str = '', title_size: float = 16,
11051118
if y_ticks is not None:
11061119
plt.yticks(y_ticks)
11071120

1108-
plt.show()
1121+
if filepath is not None:
1122+
plt.savefig(filepath)
1123+
else:
1124+
plt.show()
11091125
elif self.clust_data.n_dim == 2:
11101126
data = {'x0': self.coords[0],
11111127
'x1': self.coords[1],
@@ -1140,7 +1156,10 @@ def cluster_plotter(self, plot_title: str = '', title_size: float = 16,
11401156
if y_ticks is not None:
11411157
plt.yticks(y_ticks)
11421158

1143-
plt.show()
1159+
if filepath is not None:
1160+
plt.savefig(filepath)
1161+
else:
1162+
plt.show()
11441163
else:
11451164
data = {'x0': self.coords[0],
11461165
'x1': self.coords[1],
@@ -1182,7 +1201,10 @@ def cluster_plotter(self, plot_title: str = '', title_size: float = 16,
11821201
if z_ticks is not None:
11831202
ax_.set_zticks(z_ticks)
11841203

1185-
plt.show()
1204+
if filepath is not None:
1205+
plt.savefig(filepath)
1206+
else:
1207+
plt.show()
11861208

11871209
def to_csv(self, output_folder: str, file_name: str) -> None:
11881210
"""

0 commit comments

Comments
 (0)