Skip to content

Commit da245fd

Browse files
Merge pull request #1447 from MouseLand/dynamics_speedup
speed up w/ flattened inds instead of tuple (claude suggestion) and #…
2 parents f878e3e + b35a2b0 commit da245fd

1 file changed

Lines changed: 85 additions & 53 deletions

File tree

cellpose/dynamics.py

Lines changed: 85 additions & 53 deletions
Original file line numberDiff line numberDiff line change
@@ -20,8 +20,7 @@
2020
import torch
2121
import torch.nn.functional as F
2222

23-
def _extend_centers_gpu(neighbors, meds, isneighbor, shape, n_iter=200,
24-
device=torch.device("cpu")):
23+
def _extend_centers_gpu(neighbors, meds, isneighbor, shape, n_iter=200, device=torch.device("cpu")):
2524
"""Runs diffusion on GPU to generate flows for training images or quality control.
2625
2726
Args:
@@ -30,41 +29,51 @@ def _extend_centers_gpu(neighbors, meds, isneighbor, shape, n_iter=200,
3029
isneighbor (torch.Tensor): Valid neighbor boolean 9 x pixels.
3130
shape (tuple): Shape of the tensor.
3231
n_iter (int, optional): Number of iterations. Defaults to 200.
33-
device (torch.device, optional): Device to run the computation on. Defaults to torch.device("cpu").
32+
device (torch.device, optional): Device to run on. Defaults to torch.device("cpu").
3433
3534
Returns:
36-
torch.Tensor: Generated flows.
37-
35+
np.ndarray: Generated flows.
3836
"""
39-
if torch.prod(torch.tensor(shape)) > 4e7 or device.type == "mps":
40-
T = torch.zeros(shape, dtype=torch.float, device=device)
37+
38+
dtype = torch.float32 if torch.prod(torch.tensor(shape)) > 4e7 or device.type == "mps" else torch.float64
39+
T_flat = torch.zeros(np.prod(shape), dtype=dtype, device=device)
40+
41+
ndim = len(shape)
42+
Ly, Lx = shape[-2:]
43+
# speed up with flattened inds
44+
if ndim == 2:
45+
Ly, Lx = shape
46+
flat_neighbors = (neighbors[0] * Lx + neighbors[1]).long()
47+
flat_meds = (meds[:, 0] * Lx + meds[:, 1]).long()
4148
else:
42-
T = torch.zeros(shape, dtype=torch.double, device=device)
43-
49+
flat_neighbors = (neighbors[0] * (Ly * Lx) + neighbors[1] * Lx + neighbors[2]).long()
50+
flat_meds = (meds[:, 0] * (Ly * Lx) + meds[:, 1] * Lx + meds[:, 2]).long()
51+
52+
flat_center = flat_neighbors[0]
53+
54+
nneigh = flat_neighbors.shape[0]
4455
for i in range(n_iter):
45-
T[tuple(meds.T)] += 1
46-
Tneigh = T[tuple(neighbors)]
47-
Tneigh *= isneighbor
48-
T[tuple(neighbors[:, 0])] = Tneigh.mean(axis=0)
49-
del meds, isneighbor, Tneigh
50-
51-
if T.ndim == 2:
52-
grads = T[neighbors[0, [2, 1, 4, 3]], neighbors[1, [2, 1, 4, 3]]]
53-
del neighbors
56+
T_flat[flat_meds] += 1
57+
Tneigh = T_flat[flat_neighbors]
58+
T_flat[flat_center] = (Tneigh * isneighbor).sum(dim=0) / nneigh
59+
del flat_meds, neighbors, meds, isneighbor, Tneigh
60+
61+
if ndim == 2:
62+
grads = T_flat[flat_neighbors[[2, 1, 4, 3]]]
5463
dy = grads[0] - grads[1]
5564
dx = grads[2] - grads[3]
5665
del grads
57-
mu_torch = np.stack((dy.cpu().squeeze(0), dx.cpu().squeeze(0)), axis=-2)
66+
mu = np.stack((dy.cpu().numpy(), dx.cpu().numpy()), axis=0)
5867
else:
59-
grads = T[tuple(neighbors[:, 1:])]
60-
del neighbors
68+
grads = T_flat[flat_neighbors[1:]]
6169
dz = grads[0] - grads[1]
6270
dy = grads[2] - grads[3]
6371
dx = grads[4] - grads[5]
6472
del grads
65-
mu_torch = np.stack(
66-
(dz.cpu().squeeze(0), dy.cpu().squeeze(0), dx.cpu().squeeze(0)), axis=-2)
67-
return mu_torch
73+
mu = np.stack((dz.cpu().numpy(), dy.cpu().numpy(), dx.cpu().numpy()), axis=0)
74+
75+
return mu
76+
6877

6978
def center_of_mass(mask):
7079
yi, xi = np.nonzero(mask)
@@ -78,6 +87,7 @@ def center_of_mass(mask):
7887

7988
return ymean, xmean
8089

90+
8191
def get_centers(masks, slices):
8292
centers = [center_of_mass(masks[slices[i]]==(i+1)) for i in range(len(slices))]
8393
centers = np.array([np.array([centers[i][0] + slices[i][0].start, centers[i][1] + slices[i][1].start])
@@ -105,7 +115,7 @@ def masks_to_flows_gpu(masks, device=torch.device("cpu"), niter=None):
105115
meds_p are cell centers.
106116
"""
107117
if device is None:
108-
device = torch.device('cuda') if torch.cuda.is_available() else torch.device('mps') if torch.backends.mps.is_available() else None
118+
device = torch.device('cuda') if torch.cuda.is_available() else torch.device('mps') if torch.backends.mps.is_available() else torch.device('cpu')
109119

110120
if masks.max() > 0:
111121
Ly0, Lx0 = masks.shape
@@ -120,6 +130,7 @@ def masks_to_flows_gpu(masks, device=torch.device("cpu"), niter=None):
120130
y = y.int()
121131
x = x.int()
122132
neighbors = torch.zeros((2, 9, y.shape[0]), dtype=torch.int, device=device)
133+
123134
yxi = [[0, -1, 1, 0, 0, -1, -1, 1, 1], [0, 0, 0, -1, 1, -1, 1, -1, 1]]
124135
for i in range(9):
125136
neighbors[0, i] = y + yxi[0][i]
@@ -139,6 +150,7 @@ def masks_to_flows_gpu(masks, device=torch.device("cpu"), niter=None):
139150
### run diffusion
140151
n_iter = 2 * ext.max() if niter is None else niter
141152
mu = _extend_centers_gpu(neighbors, meds_p, isneighbor, shape, n_iter=n_iter,
153+
142154
device=device)
143155
mu = mu.astype("float64")
144156

@@ -151,22 +163,24 @@ def masks_to_flows_gpu(masks, device=torch.device("cpu"), niter=None):
151163
else:
152164
# no masks, return empty flows
153165
mu0 = np.zeros((2, masks.shape[0], masks.shape[1]))
154-
return mu0
166+
slices = None
167+
168+
return mu0, slices
155169

156-
def masks_to_flows_gpu_3d(masks, device=None, niter=None):
170+
def masks_to_flows_gpu_3d(masks, device=torch.device('cpu'), niter=None):
157171
"""Convert masks to flows using diffusion from center pixel.
158172
159173
Args:
160174
masks (int, 2D or 3D array): Labelled masks. 0=NO masks; 1,2,...=mask labels.
161-
device (torch.device, optional): The device to run the computation on. Defaults to None.
175+
device (torch.device, optional): The device to run the computation on. Defaults to torch.device('cpu').
162176
niter (int, optional): Number of iterations for the diffusion process. Defaults to None.
163177
164178
Returns:
165179
np.ndarray: A 4D array representing the flows for each pixel in Z, X, and Y.
166180
167181
"""
168182
if device is None:
169-
device = torch.device('cuda') if torch.cuda.is_available() else torch.device('mps') if torch.backends.mps.is_available() else None
183+
device = torch.device('cuda') if torch.cuda.is_available() else torch.device('mps') if torch.backends.mps.is_available() else torch.device('cpu')
170184

171185
Lz0, Ly0, Lx0 = masks.shape
172186
Lz, Ly, Lx = Lz0 + 2, Ly0 + 2, Lx0 + 2
@@ -185,7 +199,7 @@ def masks_to_flows_gpu_3d(masks, device=None, niter=None):
185199
# get mask centers
186200
slices = find_objects(masks)
187201

188-
centers = np.zeros((masks.max(), 3), "int")
202+
centers = torch.zeros((masks.max(), 3), dtype=torch.int, device=device)
189203
for i, si in enumerate(slices):
190204
if si is not None:
191205
sz, sy, sx = si
@@ -205,7 +219,7 @@ def masks_to_flows_gpu_3d(masks, device=None, niter=None):
205219
centers[i, 2] = xmed + sx.start
206220

207221
# get neighbor validator (not all neighbors are in same mask)
208-
neighbor_masks = masks_padded[tuple(neighbors)]
222+
neighbor_masks = masks_padded[neighbors[0], neighbors[1], neighbors[2]]
209223
isneighbor = neighbor_masks == neighbor_masks[0]
210224
ext = np.array(
211225
[[sz.stop - sz.start + 1, sy.stop - sy.start + 1, sx.stop - sx.start + 1]
@@ -258,7 +272,7 @@ def labels_to_flows(labels, files=None, device=None, redo_flows=False, niter=Non
258272
iterator = trange if nimg > 1 else range
259273
for n in iterator(nimg):
260274
labels[n][0] = fastremap.renumber(labels[n][0], in_place=True)[0]
261-
vecn = masks_to_flows_gpu(labels[n][0].astype(int), device=device, niter=niter)
275+
vecn = masks_to_flows_gpu(labels[n][0].astype(int), device=device, niter=niter)[0]
262276

263277
# concatenate labels, distance transform, vector flows, heat (boundary and mask are computed in augmentations)
264278
flow = np.concatenate((labels[n], labels[n] > 0.5, vecn),
@@ -299,13 +313,11 @@ def flow_error(maski, dP_net, device=None):
299313
return
300314

301315
# flows predicted from estimated masks
302-
dP_masks = masks_to_flows_gpu(maski, device=device)
303-
# difference between predicted flows vs mask flows
304-
flow_errors = np.zeros(maski.max())
305-
for i in range(dP_masks.shape[0]):
306-
flow_errors += mean((dP_masks[i] - dP_net[i] / 5.)**2, maski,
307-
index=np.arange(1,
308-
maski.max() + 1))
316+
dP_masks, slices = masks_to_flows_gpu(maski, device=device)
317+
318+
# assign flow error to each mask, as mean squared error between predicted flows and flows from masks
319+
err = ((dP_masks - dP_net / 5.)**2).sum(axis=0)
320+
flow_errors = np.array([err[slc[0], slc[1]][maski[slc[0], slc[1]] == (j+1)].mean() for j, slc in enumerate(slices)])
309321

310322
return flow_errors, dP_masks
311323

@@ -357,7 +369,8 @@ def steps_interp(dP, inds, niter, device=torch.device("cpu")):
357369
for t in range(niter):
358370
dPt = torch.nn.functional.grid_sample(im, pt, align_corners=False)
359371
for k in range(ndim): #clamp the final pixel locations
360-
pt[..., k] = torch.clamp(pt[..., k] + dPt[:, k], -1., 1.)
372+
pt[..., k] += dPt[:, k]
373+
torch.clamp_(pt[..., k], -1., 1.)
361374

362375
#undo the normalization from before, reverse order of operations
363376
pt += 1
@@ -444,10 +457,12 @@ def mem_info():
444457

445458
merrors, _ = flow_error(masks, flows, device0)
446459
badi = 1 + (merrors > threshold).nonzero()[0]
447-
masks[np.isin(masks, badi)] = 0
460+
fastremap.mask(masks, badi, in_place=True)
461+
fastremap.renumber(masks, in_place=True)
448462
return masks
449463

450464

465+
451466
def max_pool1d(h, kernel_size=5, axis=1, out=None):
452467
""" memory efficient max_pool thanks to Mark Kittisopikul
453468
@@ -536,16 +551,27 @@ def get_masks_torch(pt, inds, shape0, rpad=20, max_size_fraction=0.4):
536551
dynamics_logger.warning("no seeds found in get_masks_torch - no masks found.")
537552
return np.zeros(shape0, dtype="uint16")
538553

539-
npts = h1[tuple(seeds1.T)]
554+
npts = h1[seeds1[:,0], seeds1[:,1]] if ndim == 2 else h1[seeds1[:,0], seeds1[:,1], seeds1[:,2]]
540555
isort1 = npts.argsort()
541556
seeds1 = seeds1[isort1]
542557

543558
n_seeds = len(seeds1)
544-
h_slc = torch.zeros((n_seeds, *[11]*ndim), device=seeds1.device)
545-
for k in range(n_seeds):
546-
slc = tuple([slice(seeds1[k][j]-5, seeds1[k][j]+6) for j in range(ndim)])
547-
h_slc[k] = h1[slc]
559+
# speed up with flattened inds
560+
offset_t = torch.arange(-5, 6, device=seeds1.device)
561+
inds_t = torch.meshgrid(ndim * [offset_t], indexing="ij")
562+
if ndim == 2:
563+
flat_inds = (inds_t[0] * shape[1] + inds_t[1]).flatten()
564+
flat_inds = flat_inds + (seeds1[:,0] * shape[1] + seeds1[:,1])[:,None]
565+
else:
566+
flat_inds = (inds_t[0] * shape[1] * shape[2] + inds_t[1] * shape[2] + inds_t[2]).flatten()
567+
flat_inds = flat_inds + (seeds1[:,0] * shape[1] * shape[2] + seeds1[:,1] * shape[2] + seeds1[:,2])[:,None]
568+
569+
h1 = h1.view(-1)
570+
h_slc = h1[flat_inds]
571+
h_slc = h_slc.reshape(n_seeds, *[11]*ndim)
572+
548573
del h1
574+
549575
seed_masks = torch.zeros((n_seeds, *[11]*ndim), device=seeds1.device)
550576
if ndim==2:
551577
seed_masks[:,5,5] = 1
@@ -557,16 +583,22 @@ def get_masks_torch(pt, inds, shape0, rpad=20, max_size_fraction=0.4):
557583
seed_masks = max_pool_nd(seed_masks, kernel_size=3)
558584
seed_masks *= h_slc > 2
559585
del h_slc
560-
seeds_new = [tuple((torch.nonzero(seed_masks[k]) + seeds1[k] - 5).T)
561-
for k in range(n_seeds)]
562-
del seed_masks
563586

587+
588+
# speed up from issue #1435 from weiyusheng
564589
dtype = torch.int32 if n_seeds < 2**16 else torch.int64
565-
M1 = torch.zeros(shape, dtype=dtype, device=device)
566-
for k in range(n_seeds):
567-
M1[seeds_new[k]] = 1 + k
590+
M1 = torch.zeros(np.prod(shape), device=device, dtype=dtype)
591+
ipix = torch.nonzero(seed_masks).to(dtype)
592+
mask_idx = ipix[:, 0]
593+
mask_pos = ipix[:, 1:] + seeds1[mask_idx] - 5
594+
if ndim == 2:
595+
flat_inds = mask_pos[:, 0] * shape[1] + mask_pos[:, 1]
596+
else:
597+
flat_inds = mask_pos[:, 0] * shape[1] * shape[2] + mask_pos[:, 1] * shape[2] + mask_pos[:, 2]
598+
M1.scatter_reduce_(0, flat_inds, mask_idx + 1, reduce="amax", include_self=False)
599+
M1 = M1.reshape(shape)
568600

569-
M1 = M1[tuple(pt)]
601+
M1 = M1[pt[0], pt[1]] if ndim == 2 else M1[pt[0], pt[1], pt[2]]
570602
M1 = M1.cpu().numpy()
571603

572604
dtype = "uint16" if n_seeds < 2**16 else "uint32"

0 commit comments

Comments (0)