2020import torch
2121import torch .nn .functional as F
2222
23- def _extend_centers_gpu (neighbors , meds , isneighbor , shape , n_iter = 200 ,
24- device = torch .device ("cpu" )):
23+ def _extend_centers_gpu (neighbors , meds , isneighbor , shape , n_iter = 200 , device = torch .device ("cpu" )):
2524 """Runs diffusion on GPU to generate flows for training images or quality control.
2625
2726 Args:
@@ -30,41 +29,51 @@ def _extend_centers_gpu(neighbors, meds, isneighbor, shape, n_iter=200,
3029 isneighbor (torch.Tensor): Valid neighbor boolean 9 x pixels.
3130 shape (tuple): Shape of the tensor.
3231 n_iter (int, optional): Number of iterations. Defaults to 200.
33- device (torch.device, optional): Device to run the computation on. Defaults to torch.device("cpu").
32+ device (torch.device, optional): Device to run on. Defaults to torch.device("cpu").
3433
3534 Returns:
36- torch.Tensor: Generated flows.
37-
35+ np.ndarray: Generated flows.
3836 """
39- if torch .prod (torch .tensor (shape )) > 4e7 or device .type == "mps" :
40- T = torch .zeros (shape , dtype = torch .float , device = device )
37+
38+ dtype = torch .float32 if torch .prod (torch .tensor (shape )) > 4e7 or device .type == "mps" else torch .float64
39+ T_flat = torch .zeros (np .prod (shape ), dtype = dtype , device = device )
40+
41+ ndim = len (shape )
42+ Ly , Lx = shape [- 2 :]
43+ # speed up with flattened inds
44+ if ndim == 2 :
45+ Ly , Lx = shape
46+ flat_neighbors = (neighbors [0 ] * Lx + neighbors [1 ]).long ()
47+ flat_meds = (meds [:, 0 ] * Lx + meds [:, 1 ]).long ()
4148 else :
42- T = torch .zeros (shape , dtype = torch .double , device = device )
43-
49+ flat_neighbors = (neighbors [0 ] * (Ly * Lx ) + neighbors [1 ] * Lx + neighbors [2 ]).long ()
50+ flat_meds = (meds [:, 0 ] * (Ly * Lx ) + meds [:, 1 ] * Lx + meds [:, 2 ]).long ()
51+
52+ flat_center = flat_neighbors [0 ]
53+
54+ nneigh = flat_neighbors .shape [0 ]
4455 for i in range (n_iter ):
45- T [tuple (meds .T )] += 1
46- Tneigh = T [tuple (neighbors )]
47- Tneigh *= isneighbor
48- T [tuple (neighbors [:, 0 ])] = Tneigh .mean (axis = 0 )
49- del meds , isneighbor , Tneigh
50-
51- if T .ndim == 2 :
52- grads = T [neighbors [0 , [2 , 1 , 4 , 3 ]], neighbors [1 , [2 , 1 , 4 , 3 ]]]
53- del neighbors
56+ T_flat [flat_meds ] += 1
57+ Tneigh = T_flat [flat_neighbors ]
58+ T_flat [flat_center ] = (Tneigh * isneighbor ).sum (dim = 0 ) / nneigh
59+ del flat_meds , neighbors , meds , isneighbor , Tneigh
60+
61+ if ndim == 2 :
62+ grads = T_flat [flat_neighbors [[2 , 1 , 4 , 3 ]]]
5463 dy = grads [0 ] - grads [1 ]
5564 dx = grads [2 ] - grads [3 ]
5665 del grads
57- mu_torch = np .stack ((dy .cpu ().squeeze ( 0 ), dx .cpu ().squeeze ( 0 )), axis = - 2 )
66+ mu = np .stack ((dy .cpu ().numpy ( ), dx .cpu ().numpy ( )), axis = 0 )
5867 else :
59- grads = T [tuple (neighbors [:, 1 :])]
60- del neighbors
68+ grads = T_flat [flat_neighbors [1 :]]
6169 dz = grads [0 ] - grads [1 ]
6270 dy = grads [2 ] - grads [3 ]
6371 dx = grads [4 ] - grads [5 ]
6472 del grads
65- mu_torch = np .stack (
66- (dz .cpu ().squeeze (0 ), dy .cpu ().squeeze (0 ), dx .cpu ().squeeze (0 )), axis = - 2 )
67- return mu_torch
73+ mu = np .stack ((dz .cpu ().numpy (), dy .cpu ().numpy (), dx .cpu ().numpy ()), axis = 0 )
74+
75+ return mu
76+
6877
6978def center_of_mass (mask ):
7079 yi , xi = np .nonzero (mask )
@@ -78,6 +87,7 @@ def center_of_mass(mask):
7887
7988 return ymean , xmean
8089
90+
8191def get_centers (masks , slices ):
8292 centers = [center_of_mass (masks [slices [i ]]== (i + 1 )) for i in range (len (slices ))]
8393 centers = np .array ([np .array ([centers [i ][0 ] + slices [i ][0 ].start , centers [i ][1 ] + slices [i ][1 ].start ])
@@ -105,7 +115,7 @@ def masks_to_flows_gpu(masks, device=torch.device("cpu"), niter=None):
105115 meds_p are cell centers.
106116 """
107117 if device is None :
108- device = torch .device ('cuda' ) if torch .cuda .is_available () else torch .device ('mps' ) if torch .backends .mps .is_available () else None
118+ device = torch .device ('cuda' ) if torch .cuda .is_available () else torch .device ('mps' ) if torch .backends .mps .is_available () else torch . device ( 'cpu' )
109119
110120 if masks .max () > 0 :
111121 Ly0 , Lx0 = masks .shape
@@ -120,6 +130,7 @@ def masks_to_flows_gpu(masks, device=torch.device("cpu"), niter=None):
120130 y = y .int ()
121131 x = x .int ()
122132 neighbors = torch .zeros ((2 , 9 , y .shape [0 ]), dtype = torch .int , device = device )
133+
123134 yxi = [[0 , - 1 , 1 , 0 , 0 , - 1 , - 1 , 1 , 1 ], [0 , 0 , 0 , - 1 , 1 , - 1 , 1 , - 1 , 1 ]]
124135 for i in range (9 ):
125136 neighbors [0 , i ] = y + yxi [0 ][i ]
@@ -139,6 +150,7 @@ def masks_to_flows_gpu(masks, device=torch.device("cpu"), niter=None):
139150 ### run diffusion
140151 n_iter = 2 * ext .max () if niter is None else niter
141152 mu = _extend_centers_gpu (neighbors , meds_p , isneighbor , shape , n_iter = n_iter ,
153+
142154 device = device )
143155 mu = mu .astype ("float64" )
144156
@@ -151,22 +163,24 @@ def masks_to_flows_gpu(masks, device=torch.device("cpu"), niter=None):
151163 else :
152164 # no masks, return empty flows
153165 mu0 = np .zeros ((2 , masks .shape [0 ], masks .shape [1 ]))
154- return mu0
166+ slices = None
167+
168+ return mu0 , slices
155169
156- def masks_to_flows_gpu_3d (masks , device = None , niter = None ):
170+ def masks_to_flows_gpu_3d (masks , device = torch . device ( 'cpu' ) , niter = None ):
157171 """Convert masks to flows using diffusion from center pixel.
158172
159173 Args:
160174 masks (int, 2D or 3D array): Labelled masks. 0=NO masks; 1,2,...=mask labels.
161- device (torch.device, optional): The device to run the computation on. Defaults to None .
175+ device (torch.device, optional): The device to run the computation on. Defaults to torch.device('cpu') .
162176 niter (int, optional): Number of iterations for the diffusion process. Defaults to None.
163177
164178 Returns:
165179 np.ndarray: A 4D array representing the flows for each pixel in Z, X, and Y.
166180
167181 """
168182 if device is None :
169- device = torch .device ('cuda' ) if torch .cuda .is_available () else torch .device ('mps' ) if torch .backends .mps .is_available () else None
183+ device = torch .device ('cuda' ) if torch .cuda .is_available () else torch .device ('mps' ) if torch .backends .mps .is_available () else torch . device ( 'cpu' )
170184
171185 Lz0 , Ly0 , Lx0 = masks .shape
172186 Lz , Ly , Lx = Lz0 + 2 , Ly0 + 2 , Lx0 + 2
@@ -185,7 +199,7 @@ def masks_to_flows_gpu_3d(masks, device=None, niter=None):
185199 # get mask centers
186200 slices = find_objects (masks )
187201
188- centers = np .zeros ((masks .max (), 3 ), " int" )
202+ centers = torch .zeros ((masks .max (), 3 ), dtype = torch . int , device = device )
189203 for i , si in enumerate (slices ):
190204 if si is not None :
191205 sz , sy , sx = si
@@ -205,7 +219,7 @@ def masks_to_flows_gpu_3d(masks, device=None, niter=None):
205219 centers [i , 2 ] = xmed + sx .start
206220
207221 # get neighbor validator (not all neighbors are in same mask)
208- neighbor_masks = masks_padded [tuple ( neighbors ) ]
222+ neighbor_masks = masks_padded [neighbors [ 0 ], neighbors [ 1 ], neighbors [ 2 ] ]
209223 isneighbor = neighbor_masks == neighbor_masks [0 ]
210224 ext = np .array (
211225 [[sz .stop - sz .start + 1 , sy .stop - sy .start + 1 , sx .stop - sx .start + 1 ]
@@ -258,7 +272,7 @@ def labels_to_flows(labels, files=None, device=None, redo_flows=False, niter=Non
258272 iterator = trange if nimg > 1 else range
259273 for n in iterator (nimg ):
260274 labels [n ][0 ] = fastremap .renumber (labels [n ][0 ], in_place = True )[0 ]
261- vecn = masks_to_flows_gpu (labels [n ][0 ].astype (int ), device = device , niter = niter )
275+ vecn = masks_to_flows_gpu (labels [n ][0 ].astype (int ), device = device , niter = niter )[ 0 ]
262276
263277 # concatenate labels, distance transform, vector flows, heat (boundary and mask are computed in augmentations)
264278 flow = np .concatenate ((labels [n ], labels [n ] > 0.5 , vecn ),
@@ -299,13 +313,11 @@ def flow_error(maski, dP_net, device=None):
299313 return
300314
301315 # flows predicted from estimated masks
302- dP_masks = masks_to_flows_gpu (maski , device = device )
303- # difference between predicted flows vs mask flows
304- flow_errors = np .zeros (maski .max ())
305- for i in range (dP_masks .shape [0 ]):
306- flow_errors += mean ((dP_masks [i ] - dP_net [i ] / 5. )** 2 , maski ,
307- index = np .arange (1 ,
308- maski .max () + 1 ))
316+ dP_masks , slices = masks_to_flows_gpu (maski , device = device )
317+
318+ # assign flow error to each mask, as mean squared error between predicted flows and flows from masks
319+ err = ((dP_masks - dP_net / 5. )** 2 ).sum (axis = 0 )
320+ flow_errors = np .array ([err [slc [0 ], slc [1 ]][maski [slc [0 ], slc [1 ]] == (j + 1 )].mean () for j , slc in enumerate (slices )])
309321
310322 return flow_errors , dP_masks
311323
@@ -357,7 +369,8 @@ def steps_interp(dP, inds, niter, device=torch.device("cpu")):
357369 for t in range (niter ):
358370 dPt = torch .nn .functional .grid_sample (im , pt , align_corners = False )
359371 for k in range (ndim ): #clamp the final pixel locations
360- pt [..., k ] = torch .clamp (pt [..., k ] + dPt [:, k ], - 1. , 1. )
372+ pt [..., k ] += dPt [:, k ]
373+ torch .clamp_ (pt [..., k ], - 1. , 1. )
361374
362375 #undo the normalization from before, reverse order of operations
363376 pt += 1
@@ -444,10 +457,12 @@ def mem_info():
444457
445458 merrors , _ = flow_error (masks , flows , device0 )
446459 badi = 1 + (merrors > threshold ).nonzero ()[0 ]
447- masks [np .isin (masks , badi )] = 0
460+ fastremap .mask (masks , badi , in_place = True )
461+ fastremap .renumber (masks , in_place = True )
448462 return masks
449463
450464
465+
451466def max_pool1d (h , kernel_size = 5 , axis = 1 , out = None ):
452467 """ memory efficient max_pool thanks to Mark Kittisopikul
453468
@@ -536,16 +551,27 @@ def get_masks_torch(pt, inds, shape0, rpad=20, max_size_fraction=0.4):
536551 dynamics_logger .warning ("no seeds found in get_masks_torch - no masks found." )
537552 return np .zeros (shape0 , dtype = "uint16" )
538553
539- npts = h1 [tuple ( seeds1 . T ) ]
554+ npts = h1 [seeds1 [:, 0 ], seeds1 [:, 1 ]] if ndim == 2 else h1 [ seeds1 [:, 0 ], seeds1 [:, 1 ], seeds1 [:, 2 ] ]
540555 isort1 = npts .argsort ()
541556 seeds1 = seeds1 [isort1 ]
542557
543558 n_seeds = len (seeds1 )
544- h_slc = torch .zeros ((n_seeds , * [11 ]* ndim ), device = seeds1 .device )
545- for k in range (n_seeds ):
546- slc = tuple ([slice (seeds1 [k ][j ]- 5 , seeds1 [k ][j ]+ 6 ) for j in range (ndim )])
547- h_slc [k ] = h1 [slc ]
559+ # speed up with flattened inds
560+ offset_t = torch .arange (- 5 , 6 , device = seeds1 .device )
561+ inds_t = torch .meshgrid (ndim * [offset_t ], indexing = "ij" )
562+ if ndim == 2 :
563+ flat_inds = (inds_t [0 ] * shape [1 ] + inds_t [1 ]).flatten ()
564+ flat_inds = flat_inds + (seeds1 [:,0 ] * shape [1 ] + seeds1 [:,1 ])[:,None ]
565+ else :
566+ flat_inds = (inds_t [0 ] * shape [1 ] * shape [2 ] + inds_t [1 ] * shape [2 ] + inds_t [2 ]).flatten ()
567+ flat_inds = flat_inds + (seeds1 [:,0 ] * shape [1 ] * shape [2 ] + seeds1 [:,1 ] * shape [2 ] + seeds1 [:,2 ])[:,None ]
568+
569+ h1 = h1 .view (- 1 )
570+ h_slc = h1 [flat_inds ]
571+ h_slc = h_slc .reshape (n_seeds , * [11 ]* ndim )
572+
548573 del h1
574+
549575 seed_masks = torch .zeros ((n_seeds , * [11 ]* ndim ), device = seeds1 .device )
550576 if ndim == 2 :
551577 seed_masks [:,5 ,5 ] = 1
@@ -557,16 +583,22 @@ def get_masks_torch(pt, inds, shape0, rpad=20, max_size_fraction=0.4):
557583 seed_masks = max_pool_nd (seed_masks , kernel_size = 3 )
558584 seed_masks *= h_slc > 2
559585 del h_slc
560- seeds_new = [tuple ((torch .nonzero (seed_masks [k ]) + seeds1 [k ] - 5 ).T )
561- for k in range (n_seeds )]
562- del seed_masks
563586
587+
588+ # speed up from issue #1435 from weiyusheng
564589 dtype = torch .int32 if n_seeds < 2 ** 16 else torch .int64
565- M1 = torch .zeros (shape , dtype = dtype , device = device )
566- for k in range (n_seeds ):
567- M1 [seeds_new [k ]] = 1 + k
590+ M1 = torch .zeros (np .prod (shape ), device = device , dtype = dtype )
591+ ipix = torch .nonzero (seed_masks ).to (dtype )
592+ mask_idx = ipix [:, 0 ]
593+ mask_pos = ipix [:, 1 :] + seeds1 [mask_idx ] - 5
594+ if ndim == 2 :
595+ flat_inds = mask_pos [:, 0 ] * shape [1 ] + mask_pos [:, 1 ]
596+ else :
597+ flat_inds = mask_pos [:, 0 ] * shape [1 ] * shape [2 ] + mask_pos [:, 1 ] * shape [2 ] + mask_pos [:, 2 ]
598+ M1 .scatter_reduce_ (0 , flat_inds , mask_idx + 1 , reduce = "amax" , include_self = False )
599+ M1 = M1 .reshape (shape )
568600
569- M1 = M1 [tuple ( pt ) ]
601+ M1 = M1 [pt [ 0 ], pt [ 1 ]] if ndim == 2 else M1 [ pt [ 0 ], pt [ 1 ], pt [ 2 ] ]
570602 M1 = M1 .cpu ().numpy ()
571603
572604 dtype = "uint16" if n_seeds < 2 ** 16 else "uint32"
0 commit comments