@@ -883,17 +883,19 @@ def cluster_centroids(self) -> np.ndarray:
883883
884884 return centroids
885885
886- def input_plotter (self , plot_title : str = '' , title_size : float = 16 ,
887- x_label : str = 'x' , y_label : str = 'y ' , z_label : str = 'z ' ,
888- label_size : float = 16 , pt_size : float = 1 , pt_colour : str = 'b' ,
889- grid : bool = True , grid_style : str = '--' , grid_size : float = 0.2 ,
890- x_ticks = None , y_ticks = None , z_ticks = None ,
886+ def input_plotter (self , filepath : Union [ str , None ] = None , plot_title : str = '' ,
887+ title_size : float = 16 , x_label : str = 'x ' , y_label : str = 'y ' ,
888+ z_label : str = 'z' , label_size : float = 16 , pt_size : float = 1 ,
889+ pt_colour : str = 'b' , grid : bool = True , grid_style : str = '--' ,
890+ grid_size : float = 0.2 , x_ticks = None , y_ticks = None , z_ticks = None ,
891891 ** kwargs ) -> None :
892892 """
893893 Plots the points in input.
894894
895895 Parameters
896896 ----------
897+ filepath : string, optional
898+ The path to the file where the plot should be saved.
897899 plot_title : string, optional
898900 Title of the plot.
899901 title_size : float, optional
@@ -958,7 +960,10 @@ def input_plotter(self, plot_title: str = '', title_size: float = 16,
958960 if y_ticks is not None :
959961 plt .yticks (y_ticks )
960962
961- plt .show ()
963+ if filepath is not None :
964+ plt .savefig (filepath )
965+ else :
966+ plt .show ()
962967 elif self .clust_data .n_dim == 2 :
963968 plt .scatter (self .coords [0 ],
964969 self .coords [1 ],
@@ -982,7 +987,10 @@ def input_plotter(self, plot_title: str = '', title_size: float = 16,
982987 if y_ticks is not None :
983988 plt .yticks (y_ticks )
984989
985- plt .show ()
990+ if filepath is not None :
991+ plt .savefig (filepath )
992+ else :
993+ plt .show ()
986994 else :
987995 fig = plt .figure ()
988996 ax_ = fig .add_subplot (projection = '3d' )
@@ -1012,14 +1020,17 @@ def input_plotter(self, plot_title: str = '', title_size: float = 16,
10121020 if z_ticks is not None :
10131021 ax_ .set_zticks (z_ticks )
10141022
1015- plt .show ()
1016-
1017- def cluster_plotter (self , plot_title : str = '' , title_size : float = 16 ,
1018- x_label : str = 'x' , y_label : str = 'y' , z_label : str = 'z' ,
1019- label_size : float = 16 , outl_size : float = 10 , pt_size : float = 10 ,
1020- seed_size : float = 25 , grid : bool = True , grid_style : str = '--' ,
1021- grid_size : float = 0.2 , x_ticks = None , y_ticks = None , z_ticks = None ,
1022- ** kwargs ) -> None :
1023+ if filepath is not None :
1024+ plt .savefig (filepath )
1025+ else :
1026+ plt .show ()
1027+
1028+ def cluster_plotter (self , filepath : Union [str , None ] = None , plot_title : str = '' ,
1029+ title_size : float = 16 , x_label : str = 'x' , y_label : str = 'y' ,
1030+ z_label : str = 'z' , label_size : float = 16 , outl_size : float = 10 ,
1031+ pt_size : float = 10 , seed_size : float = 25 , grid : bool = True ,
1032+ grid_style : str = '--' , grid_size : float = 0.2 , x_ticks = None ,
1033+ y_ticks = None , z_ticks = None , ** kwargs ) -> None :
10231034 """
10241035 Plots the clusters with a different colour for every cluster.
10251036
@@ -1028,30 +1039,32 @@ def cluster_plotter(self, plot_title: str = '', title_size: float = 16,
10281039
10291040 Parameters
10301041 ----------
1042+ filepath : string, optional
1043+ The path to the file where the plot should be saved.
10311044 plot_title : string, optional
1032- Title of the plot
1045+ Title of the plot.
10331046 title_size : float, optional
1034- Size of the plot title
1047+ Size of the plot title.
10351048 x_label : string, optional
1036- Label on x-axis
1049+ Label on x-axis.
10371050 y_label : string, optional
1038- Label on y-axis
1051+ Label on y-axis.
10391052 z_label : string, optional
1040- Label on z-axis
1053+ Label on z-axis.
10411054 label_size : int, optional
1042- Fontsize of the axis labels
1055+ Fontsize of the axis labels.
10431056 outl_size : int, optional
1044- Size of the outliers in the plot
1057+ Size of the outliers in the plot.
10451058 pt_size : int, optional
1046- Size of the points in the plot
1059+ Size of the points in the plot.
10471060 seed_size : int, optional
1048- Size of the seeds in the plot
1061+ Size of the seeds in the plot.
10491062 grid : bool, optional
1050- f true displays grids in the plot
1063+ f true displays grids in the plot.
10511064 grid_style : string, optional
1052- Style of the grid
1065+ Style of the grid.
10531066 grid_size : float, optional
1054- Linewidth of the plot grid
1067+ Linewidth of the plot grid.
10551068 x_ticks : list, optional
10561069 List of ticks for the x axis.
10571070 y_ticks : list, optional
@@ -1105,7 +1118,10 @@ def cluster_plotter(self, plot_title: str = '', title_size: float = 16,
11051118 if y_ticks is not None :
11061119 plt .yticks (y_ticks )
11071120
1108- plt .show ()
1121+ if filepath is not None :
1122+ plt .savefig (filepath )
1123+ else :
1124+ plt .show ()
11091125 elif self .clust_data .n_dim == 2 :
11101126 data = {'x0' : self .coords [0 ],
11111127 'x1' : self .coords [1 ],
@@ -1140,7 +1156,10 @@ def cluster_plotter(self, plot_title: str = '', title_size: float = 16,
11401156 if y_ticks is not None :
11411157 plt .yticks (y_ticks )
11421158
1143- plt .show ()
1159+ if filepath is not None :
1160+ plt .savefig (filepath )
1161+ else :
1162+ plt .show ()
11441163 else :
11451164 data = {'x0' : self .coords [0 ],
11461165 'x1' : self .coords [1 ],
@@ -1182,7 +1201,10 @@ def cluster_plotter(self, plot_title: str = '', title_size: float = 16,
11821201 if z_ticks is not None :
11831202 ax_ .set_zticks (z_ticks )
11841203
1185- plt .show ()
1204+ if filepath is not None :
1205+ plt .savefig (filepath )
1206+ else :
1207+ plt .show ()
11861208
11871209 def to_csv (self , output_folder : str , file_name : str ) -> None :
11881210 """
0 commit comments