Skip to content

Add an option to pass an existing figure to get_grid #55

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Open
wants to merge 5 commits into
base: master
Choose a base branch
from
Open
Changes from 2 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
28 changes: 15 additions & 13 deletions src/grid_strategy/_abc.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,7 @@ class GridStrategy(metaclass=ABCMeta):
def __init__(self, alignment="center"):
self.alignment = alignment

def get_grid(self, n):
def get_grid(self, n, figure=None):
"""
Return a list of axes designed according to the strategy.
Grid arrangements are tuples with the same length as the number of rows,
Expand All @@ -28,32 +28,34 @@ def get_grid(self, n):
x x x
x x
where each x would be a subplot.

If `figure` is None, creates a new figure.
"""

grid_arrangement = self.get_grid_arrangement(n)
return self.get_gridspec(grid_arrangement)
return self.get_gridspec(grid_arrangement, figure)

@classmethod
@abstractmethod
def get_grid_arrangement(cls, n): # pragma: nocover
pass

def get_gridspec(self, grid_arrangement):
def get_gridspec(self, grid_arrangement, figure=None):
nrows = len(grid_arrangement)
ncols = max(grid_arrangement)

# If it has justified alignment, will not be the same as the other alignments
if self.alignment == "justified":
return self._justified(nrows, grid_arrangement)
return self._justified(nrows, grid_arrangement, figure)
else:
return self._ragged(nrows, ncols, grid_arrangement)
return self._ragged(nrows, ncols, grid_arrangement, figure)

def _justified(self, nrows, grid_arrangement):
def _justified(self, nrows, grid_arrangement, figure=None):
ax_specs = []
num_small_cols = np.lcm.reduce(grid_arrangement)
gs = gridspec.GridSpec(
nrows, num_small_cols, figure=plt.figure(constrained_layout=True)
)
if figure is None:
figure = plt.figure(constrained_layout=True)
gs = gridspec.GridSpec(nrows, num_small_cols, figure=figure)
for r, row_cols in enumerate(grid_arrangement):
skip = num_small_cols // row_cols
for col in range(row_cols):
Expand All @@ -63,15 +65,15 @@ def _justified(self, nrows, grid_arrangement):
ax_specs.append(gs[r, s:e])
return ax_specs

def _ragged(self, nrows, ncols, grid_arrangement):
def _ragged(self, nrows, ncols, grid_arrangement, figure=None):
if len(set(grid_arrangement)) > 1:
col_width = 2
else:
col_width = 1

gs = gridspec.GridSpec(
nrows, ncols * col_width, figure=plt.figure(constrained_layout=True)
)
if figure is None:
figure = plt.figure(constrained_layout=True)
gs = gridspec.GridSpec(nrows, ncols * col_width, figure=figure)

ax_specs = []
for r, row_cols in enumerate(grid_arrangement):
Expand Down