@@ -293,6 +293,13 @@ def overlay(self):
293
293
"""The current overlay """
294
294
return self ._overlay
295
295
296
+ @overlay .setter
297
+ def overlay (self , img ):
298
+ if img is None :
299
+ self ._remove_overlay ()
300
+ else :
301
+ self .set_overlay (img )
302
+
296
303
@property
297
304
def threshold (self ):
298
305
"""The current data display threshold """
@@ -382,16 +389,7 @@ def set_overlay(self, data, affine=None, threshold=None, cmap='viridis',
382
389
383
390
# we already have a plotted overlay
384
391
if self ._overlay is not None :
385
- # remove all images + cross hair lines
386
- for nn , im in enumerate (self ._overlay ._ims ):
387
- im .remove ()
388
- for line in self ._overlay ._crosshairs [nn ].values ():
389
- line .remove ()
390
- # remove the fourth axis, if it was created for the overlay
391
- if (self ._overlay .n_volumes > 1 and len (self ._overlay ._axes ) > 3
392
- and self .n_volumes == 1 ):
393
- a = self ._axes .pop (- 1 )
394
- a .remove ()
392
+ self ._remove_overlay ()
395
393
396
394
axes = self ._axes
397
395
o_n_volumes = int (np .prod (data .shape [3 :]))
@@ -401,6 +399,9 @@ def set_overlay(self, data, affine=None, threshold=None, cmap='viridis',
401
399
# 4D underlay, 3D overlay
402
400
elif o_n_volumes < self .n_volumes and o_n_volumes == 1 :
403
401
axes = axes [:- 1 ]
402
+ # 4D underlay, 4D overlay
403
+ elif o_n_volumes > 1 and self .n_volumes > 1 :
404
+ raise TypeError ('Cannot set 4D overlay on top of 4D underlay' )
404
405
405
406
# mask array for provided threshold
406
407
self ._overlay = self .__class__ (data , affine = affine , axes = axes )
@@ -416,6 +417,21 @@ def set_overlay(self, data, affine=None, threshold=None, cmap='viridis',
416
417
cross ['vert' ].set_visible (False )
417
418
self ._overlay ._draw ()
418
419
420
+ def _remove_overlay (self ):
421
+ """ Removes current overlay image + associated axes """
422
+ # remove all images + cross hair lines
423
+ for nn , im in enumerate (self ._overlay ._ims ):
424
+ im .remove ()
425
+ for line in self ._overlay ._crosshairs [nn ].values ():
426
+ line .remove ()
427
+ # remove the fourth axis, if it was created for the overlay
428
+ if (self ._overlay .n_volumes > 1 and len (self ._overlay ._axes ) > 3
429
+ and self .n_volumes == 1 ):
430
+ a = self ._axes .pop (- 1 )
431
+ a .remove ()
432
+
433
+ self ._overlay = None
434
+
419
435
def link_to (self , other ):
420
436
"""Link positional changes between two canvases
421
437
0 commit comments