diff --git a/nibabel/spatialimages.py b/nibabel/spatialimages.py index a602d62d8f..ad2c5ab2cf 100644 --- a/nibabel/spatialimages.py +++ b/nibabel/spatialimages.py @@ -589,9 +589,14 @@ def __getitem__(self, idx): "slicing image array data with `img.dataobj[slice]` or " "`img.get_fdata()[slice]`") - def orthoview(self): + def orthoview(self, overlay=None, **kwargs): """Plot the image using OrthoSlicer3D + Parameters + ---------- + overlay : ``spatialimage`` instance + Image to be plotted as overlay. Default: None + Returns ------- viewer : instance of OrthoSlicer3D @@ -603,8 +608,12 @@ def orthoview(self): consider using viewer.show() (equivalently plt.show()) to show the figure. """ - return OrthoSlicer3D(self.dataobj, self.affine, - title=self.get_filename()) + ortho = OrthoSlicer3D(self.dataobj, self.affine, + title=self.get_filename()) + if overlay is not None: + ortho.set_overlay(overlay, **kwargs) + + return ortho def as_reoriented(self, ornt): """Apply an orientation change and return a new image diff --git a/nibabel/viewers.py b/nibabel/viewers.py index 5d48665780..ba404529a0 100644 --- a/nibabel/viewers.py +++ b/nibabel/viewers.py @@ -69,6 +69,9 @@ def __init__(self, data, affine=None, axes=None, title=None): self._title = title self._closed = False self._cross = True + self._overlay = None + self._threshold = None + self._alpha = 1 data = np.asanyarray(data) if data.ndim < 3: @@ -285,6 +288,150 @@ def clim(self, clim): self._clim = tuple(clim) self.draw() + @property + def overlay(self): + """The current overlay """ + return self._overlay + + @overlay.setter + def overlay(self, img): + if img is None: + self._remove_overlay() + else: + self.set_overlay(img) + + @property + def threshold(self): + """The current data display threshold """ + return self._threshold + + @threshold.setter + def threshold(self, threshold): + # mask data array + if threshold is not None: + self._data = np.ma.masked_array(np.asarray(self._data), + np.asarray(self._data) <= threshold) + self._threshold = float(threshold) + else: + self._data = np.asarray(self._data) + self._threshold = threshold + + # update current volume data w/masked array and re-draw everything + if self._data.ndim > 3: + self._current_vol_data = self._data[..., self._data_idx[3]] + else: + self._current_vol_data = self._data + self._set_position(None, None, None, notify=False) + + @property + def alpha(self): + """ The current alpha (transparency) value """ + return self._alpha + + @alpha.setter + def alpha(self, alpha): + alpha = float(alpha) + if alpha > 1 or alpha < 0: + raise ValueError('alpha must be between 0 and 1') + for im in self._ims: + im.set_alpha(alpha) + self._alpha = alpha + self.draw() + + def set_overlay(self, data, affine=None, threshold=None, cmap='viridis', + alpha=0.7): + """ Sets `data` as overlay for currently plotted image + + Parameters + ---------- + data : array-like + The data that will be overlayed on the slicer. Should have 3+ + dimensions. + affine : array-like or None, optional + Affine transform for the provided data. This is used to determine + how the data should be sliced for plotting into the sagittal, + coronal, and axial view axes. If this does not match the currently + plotted slicer the provided data will be resampled. + threshold : float or None, optional + Threshold for overlay data; values below this threshold will not + be displayed. Default: None + cmap : str, optional + The Colormap instance or registered colormap name used to map + scalar data to colors. Default: 'viridis' + alpha : [0, 1] float, optional + Set the alpha value used for blending. Default: 0.7 + """ + if affine is None: + try: # did we get an image? + affine = data.affine + data = data.dataobj + except AttributeError: + pass + + # check that we have sufficient information to match the overlays + if affine is None and data.shape[:3] != self._data.shape[:3]: + raise ValueError('Provided `data` do not match shape of ' + 'underlay and no `affine` matrix was ' + 'provided. Please provide an `affine` matrix ' + 'or resample first three dims of `data` to {}' + .format(self._data.shape[:3])) + + # we need to resample the provided data to the already-plotted data + if not np.allclose(affine, self._affine): + from .processing import resample_from_to + from .nifti1 import Nifti1Image + target_shape = self._data.shape[:3] + data.shape[3:] + # we can't just use SpatialImage because we need an image type + # where the spatial axes are _always_ first + data = resample_from_to(Nifti1Image(data, affine), + (target_shape, self._affine)).dataobj + affine = self._affine + + # we already have a plotted overlay + if self._overlay is not None: + self._remove_overlay() + + axes = self._axes + o_n_volumes = int(np.prod(data.shape[3:])) + # 3D underlay, 4D overlay + if o_n_volumes > self.n_volumes and self.n_volumes == 1: + axes += [axes[0].figure.add_subplot(224)] + # 4D underlay, 3D overlay + elif o_n_volumes < self.n_volumes and o_n_volumes == 1: + axes = axes[:-1] + # 4D underlay, 4D overlay + elif o_n_volumes > 1 and self.n_volumes > 1: + raise TypeError('Cannot set 4D overlay on top of 4D underlay') + + # mask array for provided threshold + self._overlay = self.__class__(data, affine=affine, axes=axes) + self._overlay.threshold = threshold + + # set transparency and new cmap + self._overlay.cmap = cmap + self._overlay.alpha = alpha + + # no double cross-hairs (they get confused when we have linked orthos) + for cross in self._overlay._crosshairs: + cross['horiz'].set_visible(False) + cross['vert'].set_visible(False) + self._overlay._draw() + + def _remove_overlay(self): + """ Removes current overlay image + associated axes """ + # remove all images + cross hair lines + for nn, im in enumerate(self._overlay._ims): + im.remove() + for line in self._overlay._crosshairs[nn].values(): + line.remove() + # remove the fourth axis, if it was created for the overlay + if (self._overlay.n_volumes > 1 and len(self._overlay._axes) > 3 + and self.n_volumes == 1): + a = self._axes.pop(-1) + a.remove() + + self._overlay = None + def link_to(self, other): """Link positional changes between two canvases @@ -412,7 +559,7 @@ def _set_position(self, x, y, z, notify=True): idx = [slice(None)] * len(self._axes) for ii in range(3): idx[self._order[ii]] = self._data_idx[ii] - vdata = self._data[tuple(idx)].ravel() + vdata = np.asarray(self._data[tuple(idx)].ravel()) vdata = np.concatenate((vdata, [vdata[-1]])) self._volume_ax_objs['patch'].set_x(self._data_idx[3] - 0.5) self._volume_ax_objs['step'].set_ydata(vdata)