diff --git a/src/grid_strategy/_abc.py b/src/grid_strategy/_abc.py index 796cca5..e570e8b 100644 --- a/src/grid_strategy/_abc.py +++ b/src/grid_strategy/_abc.py @@ -18,7 +18,7 @@ class GridStrategy(metaclass=ABCMeta): def __init__(self, alignment="center"): self.alignment = alignment - def get_grid(self, n): + def get_grid(self, n, figure=None): """ Return a list of axes designed according to the strategy. Grid arrangements are tuples with the same length as the number of rows, @@ -28,32 +28,35 @@ def get_grid(self, n): x x x x x where each x would be a subplot. + + If `figure` is None, creates a new figure. """ grid_arrangement = self.get_grid_arrangement(n) - return self.get_gridspec(grid_arrangement) + return self.get_gridspec(grid_arrangement, figure) @classmethod @abstractmethod def get_grid_arrangement(cls, n): # pragma: nocover pass - def get_gridspec(self, grid_arrangement): + def get_gridspec(self, grid_arrangement, figure=None): + if figure is None: + figure = plt.figure(constrained_layout=True) + nrows = len(grid_arrangement) ncols = max(grid_arrangement) # If it has justified alignment, will not be the same as the other alignments if self.alignment == "justified": - return self._justified(nrows, grid_arrangement) + return self._justified(nrows, grid_arrangement, figure) else: - return self._ragged(nrows, ncols, grid_arrangement) + return self._ragged(nrows, ncols, grid_arrangement, figure) - def _justified(self, nrows, grid_arrangement): + def _justified(self, nrows, grid_arrangement, figure): ax_specs = [] num_small_cols = np.lcm.reduce(grid_arrangement) - gs = gridspec.GridSpec( - nrows, num_small_cols, figure=plt.figure(constrained_layout=True) - ) + gs = gridspec.GridSpec(nrows, num_small_cols, figure=figure) for r, row_cols in enumerate(grid_arrangement): skip = num_small_cols // row_cols for col in range(row_cols): @@ -63,15 +66,13 @@ def _justified(self, nrows, grid_arrangement): ax_specs.append(gs[r, s:e]) return ax_specs - def _ragged(self, nrows, ncols, grid_arrangement): + def _ragged(self, nrows, ncols, grid_arrangement, figure): if len(set(grid_arrangement)) > 1: col_width = 2 else: col_width = 1 - gs = gridspec.GridSpec( - nrows, ncols * col_width, figure=plt.figure(constrained_layout=True) - ) + gs = gridspec.GridSpec(nrows, ncols * col_width, figure=figure) ax_specs = [] for r, row_cols in enumerate(grid_arrangement): diff --git a/tests/test_grids.py b/tests/test_grids.py index f28a80a..27fc839 100644 --- a/tests/test_grids.py +++ b/tests/test_grids.py @@ -1,5 +1,5 @@ import pytest -from unittest import mock +from unittest.mock import Mock, patch, sentinel from grid_strategy.strategies import SquareStrategy @@ -17,29 +17,18 @@ def __eq__(self, other): return self.rows == other.rows and self.cols == other.cols -class GridSpecMock: - def __init__(self, nrows, ncols, *args, **kwargs): - self._nrows_ = nrows - self._ncols_ = ncols - - self._args_ = args - self._kwargs_ = kwargs - - def __getitem__(self, key_tup): - return SpecValue(*key_tup, self) - - @pytest.fixture def gridspec_mock(): - class Figure: - pass + with patch("grid_strategy._abc.gridspec.GridSpec", new_callable=Mock) as g: + g.return_value.__getitem__ = lambda self, key_tup: SpecValue(*key_tup, self) + yield g - def figure(*args, **kwargs): - return Figure() - with mock.patch(f"grid_strategy._abc.gridspec.GridSpec", new=GridSpecMock) as g: - with mock.patch(f"grid_strategy._abc.plt.figure", new=figure): - yield g +@pytest.fixture +def plt_figure_mock(): + with patch("grid_strategy._abc.plt.figure", new_callable=Mock) as f: + f.return_value = sentinel.new_figure + yield f @pytest.mark.parametrize( @@ -68,10 +57,25 @@ def figure(*args, **kwargs): ("left", 2, [(0, slice(0, 1)), (0, slice(1, 2))]), ], ) -def test_square_spec(gridspec_mock, align, n, exp_specs): +@pytest.mark.parametrize("figure_passed", [True, False]) +def test_square_spec( + gridspec_mock, plt_figure_mock, align, n, exp_specs, figure_passed +): + if figure_passed: + user_figure = sentinel.user_figure + else: + user_figure = None + ss = SquareStrategy(align) + act = ss.get_grid(n, figure=user_figure) - act = ss.get_grid(n) exp = [SpecValue(*spec) for spec in exp_specs] - assert act == exp + + args, kwargs = gridspec_mock.call_args + if figure_passed: + plt_figure_mock.assert_not_called() + assert kwargs["figure"] is sentinel.user_figure + else: + plt_figure_mock.assert_called_once_with(constrained_layout=True) + assert kwargs["figure"] is sentinel.new_figure