11import pathlib
2- from typing import List , Union
2+ from typing import Any , Callable , List , Optional , Union
33
44import matplotlib
55import matplotlib .pyplot as plt
1010from .load_data_ import axl_filename
1111from .result_set import ResultSet
1212
13- titleType = List [ str ]
13+ titleType = str
1414namesType = List [str ]
1515dataType = List [List [Union [int , float ]]]
1616
@@ -25,8 +25,11 @@ def _violinplot(
2525 self ,
2626 data : dataType ,
2727 names : namesType ,
28- title : titleType = None ,
29- ax : matplotlib .axes .SubplotBase = None ,
28+ title : Optional [titleType ] = None ,
29+ ax : Optional [matplotlib .axes .Axes ] = None ,
30+ get_figure : Callable [
31+ [matplotlib .axes .Axes ], Union [matplotlib .figure .Figure , Any , None ]
32+ ] = lambda ax : ax .get_figure (),
3033 ) -> matplotlib .figure .Figure :
3134 """For making violinplots."""
3235
@@ -35,7 +38,11 @@ def _violinplot(
3538 else :
3639 ax = ax
3740
38- figure = ax .get_figure ()
41+ figure = get_figure (ax )
42+ if not isinstance (figure , matplotlib .figure .Figure ):
43+ raise RuntimeError (
44+ "get_figure unexpectedly returned a non-figure object"
45+ )
3946 width = max (self .num_players / 3 , 12 )
4047 height = width / 2
4148 spacing = 4
@@ -50,7 +57,7 @@ def _violinplot(
5057 )
5158 ax .set_xticks (positions )
5259 ax .set_xticklabels (names , rotation = 90 )
53- ax .set_xlim ([ 0 , spacing * (self .num_players + 1 )] )
60+ ax .set_xlim (( 0 , spacing * (self .num_players + 1 )) )
5461 ax .tick_params (axis = "both" , which = "both" , labelsize = 8 )
5562 if title :
5663 ax .set_title (title )
@@ -76,7 +83,9 @@ def _boxplot_xticks_labels(self):
7683 return [str (n ) for n in self .result_set .ranked_names ]
7784
7885 def boxplot (
79- self , title : titleType = None , ax : matplotlib .axes .SubplotBase = None
86+ self ,
87+ title : Optional [titleType ] = None ,
88+ ax : Optional [matplotlib .axes .Axes ] = None ,
8089 ) -> matplotlib .figure .Figure :
8190 """For the specific mean score boxplot."""
8291 data = self ._boxplot_dataset
@@ -98,7 +107,9 @@ def _winplot_dataset(self):
98107 return wins , ranked_names
99108
100109 def winplot (
101- self , title : titleType = None , ax : matplotlib .axes .SubplotBase = None
110+ self ,
111+ title : Optional [titleType ] = None ,
112+ ax : Optional [matplotlib .axes .Axes ] = None ,
102113 ) -> matplotlib .figure .Figure :
103114 """Plots the distributions for the number of wins for each strategy."""
104115
@@ -126,7 +137,9 @@ def _sdv_plot_dataset(self):
126137 return diffs , ranked_names
127138
128139 def sdvplot (
129- self , title : titleType = None , ax : matplotlib .axes .SubplotBase = None
140+ self ,
141+ title : Optional [titleType ] = None ,
142+ ax : Optional [matplotlib .axes .Axes ] = None ,
130143 ) -> matplotlib .figure .Figure :
131144 """Score difference violin plots to visualize the distributions of how
132145 players attain their payoffs."""
@@ -143,7 +156,9 @@ def _lengthplot_dataset(self):
143156 ]
144157
145158 def lengthplot (
146- self , title : titleType = None , ax : matplotlib .axes .SubplotBase = None
159+ self ,
160+ title : Optional [titleType ] = None ,
161+ ax : Optional [matplotlib .axes .Axes ] = None ,
147162 ) -> matplotlib .figure .Figure :
148163 """For the specific match length boxplot."""
149164 data = self ._lengthplot_dataset
@@ -174,9 +189,12 @@ def _payoff_heatmap(
174189 self ,
175190 data : dataType ,
176191 names : namesType ,
177- title : titleType = None ,
178- ax : matplotlib .axes .SubplotBase = None ,
192+ title : Optional [ titleType ] = None ,
193+ ax : Optional [ matplotlib .axes .Axes ] = None ,
179194 cmap : str = "viridis" ,
195+ get_figure : Callable [
196+ [matplotlib .axes .Axes ], Union [matplotlib .figure .Figure , Any , None ]
197+ ] = lambda ax : ax .get_figure (),
180198 ) -> matplotlib .figure .Figure :
181199 """Generic heatmap plot"""
182200
@@ -185,7 +203,11 @@ def _payoff_heatmap(
185203 else :
186204 ax = ax
187205
188- figure = ax .get_figure ()
206+ figure = get_figure (ax )
207+ if not isinstance (figure , matplotlib .figure .Figure ):
208+ raise RuntimeError (
209+ "get_figure unexpectedly returned a non-figure object"
210+ )
189211 width = max (self .num_players / 4 , 12 )
190212 height = width
191213 figure .set_size_inches (width , height )
@@ -202,15 +224,19 @@ def _payoff_heatmap(
202224 return figure
203225
204226 def pdplot (
205- self , title : titleType = None , ax : matplotlib .axes .SubplotBase = None
227+ self ,
228+ title : Optional [titleType ] = None ,
229+ ax : Optional [matplotlib .axes .Axes ] = None ,
206230 ) -> matplotlib .figure .Figure :
207231 """Payoff difference heatmap to visualize the distributions of how
208232 players attain their payoffs."""
209233 matrix , names = self ._pdplot_dataset
210234 return self ._payoff_heatmap (matrix , names , title = title , ax = ax )
211235
212236 def payoff (
213- self , title : titleType = None , ax : matplotlib .axes .SubplotBase = None
237+ self ,
238+ title : Optional [titleType ] = None ,
239+ ax : Optional [matplotlib .axes .Axes ] = None ,
214240 ) -> matplotlib .figure .Figure :
215241 """Payoff heatmap to visualize the distributions of how
216242 players attain their payoffs."""
@@ -223,9 +249,12 @@ def payoff(
223249 def stackplot (
224250 self ,
225251 eco ,
226- title : titleType = None ,
252+ title : Optional [ titleType ] = None ,
227253 logscale : bool = True ,
228- ax : matplotlib .axes .SubplotBase = None ,
254+ ax : Optional [matplotlib .axes .Axes ] = None ,
255+ get_figure : Callable [
256+ [matplotlib .axes .Axes ], Union [matplotlib .figure .Figure , Any , None ]
257+ ] = lambda ax : ax .get_figure (),
229258 ) -> matplotlib .figure .Figure :
230259
231260 populations = eco .population_sizes
@@ -235,7 +264,11 @@ def stackplot(
235264 else :
236265 ax = ax
237266
238- figure = ax .get_figure ()
267+ figure = get_figure (ax )
268+ if not isinstance (figure , matplotlib .figure .Figure ):
269+ raise RuntimeError (
270+ "get_figure unexpectedly returned a non-figure object"
271+ )
239272 turns = range (len (populations ))
240273 pops = [
241274 [populations [iturn ][ir ] for iturn in turns ]
@@ -247,7 +280,7 @@ def stackplot(
247280 ax .yaxis .set_label_position ("right" )
248281 ax .yaxis .labelpad = 25.0
249282
250- ax .set_ylim ([ 0.0 , 1.0 ] )
283+ ax .set_ylim (( 0.0 , 1.0 ) )
251284 ax .set_ylabel ("Relative population size" )
252285 ax .set_xlabel ("Turn" )
253286 if title is not None :
0 commit comments