Skip to content

Commit b35a2b0

Browse files
declare T_flat directly
1 parent e69b179 commit b35a2b0

1 file changed

Lines changed: 5 additions & 7 deletions

File tree

cellpose/dynamics.py

Lines changed: 5 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -34,11 +34,10 @@ def _extend_centers_gpu(neighbors, meds, isneighbor, shape, n_iter=200, device=t
3434
Returns:
3535
np.ndarray: Generated flows.
3636
"""
37-
if torch.prod(torch.tensor(shape)) > 4e7 or device.type == "mps":
38-
T = torch.zeros(shape, dtype=torch.float, device=device)
39-
else:
40-
T = torch.zeros(shape, dtype=torch.double, device=device)
41-
37+
38+
dtype = torch.float32 if torch.prod(torch.tensor(shape)) > 4e7 or device.type == "mps" else torch.float64
39+
T_flat = torch.zeros(np.prod(shape), dtype=dtype, device=device)
40+
4241
ndim = len(shape)
4342
Ly, Lx = shape[-2:]
4443
# speed up with flattened inds
@@ -51,8 +50,7 @@ def _extend_centers_gpu(neighbors, meds, isneighbor, shape, n_iter=200, device=t
5150
flat_meds = (meds[:, 0] * (Ly * Lx) + meds[:, 1] * Lx + meds[:, 2]).long()
5251

5352
flat_center = flat_neighbors[0]
54-
T_flat = T.view(-1)
55-
53+
5654
nneigh = flat_neighbors.shape[0]
5755
for i in range(n_iter):
5856
T_flat[flat_meds] += 1

0 commit comments

Comments
 (0)