Skip to content

Commit 94ed43d

Browse files
Merge pull request #1008 from MouseLand/rev3
Rev3
2 parents dc3848d + 5711b85 commit 94ed43d

File tree

11 files changed

+738
-426
lines changed

11 files changed

+738
-426
lines changed

cellpose/core.py

Lines changed: 5 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -234,7 +234,7 @@ def run_net(net, imgs, batch_size=8, augment=False, tile=True, tile_overlap=0.1,
234234
# slices from padding
235235
# slc = [slice(0, self.nclasses) for n in range(imgs.ndim)] # changed from imgs.shape[n]+1 for first slice size
236236
slc = [slice(0, imgs.shape[n] + 1) for n in range(imgs.ndim)]
237-
slc[-3] = slice(0, 3)
237+
slc[-3] = slice(0, net.nout)
238238
slc[-2] = slice(ysub[0], ysub[-1] + 1)
239239
slc[-1] = slice(xsub[0], xsub[-1] + 1)
240240
slc = tuple(slc)
@@ -286,7 +286,8 @@ def _run_tiled(net, imgi, batch_size=8, augment=False, bsize=224, tile_overlap=0
286286
yf = np.zeros((Lz, nout, imgi.shape[-2], imgi.shape[-1]), np.float32)
287287
styles = []
288288
if ny * nx > batch_size:
289-
ziterator = trange(Lz, file=tqdm_out)
289+
ziterator = (trange(Lz, file=tqdm_out, mininterval=30)
290+
if Lz > 1 else range(Lz))
290291
for i in ziterator:
291292
yfi, stylei = _run_tiled(net, imgi[i], augment=augment, bsize=bsize,
292293
tile_overlap=tile_overlap)
@@ -297,7 +298,8 @@ def _run_tiled(net, imgi, batch_size=8, augment=False, bsize=224, tile_overlap=0
297298
ntiles = ny * nx
298299
nimgs = max(2, int(np.round(batch_size / ntiles)))
299300
niter = int(np.ceil(Lz / nimgs))
300-
ziterator = trange(niter, file=tqdm_out)
301+
ziterator = (trange(niter, file=tqdm_out, mininterval=30)
302+
if Lz > 1 else range(niter))
301303
for k in ziterator:
302304
IMGa = np.zeros((ntiles * nimgs, nchan, ly, lx), np.float32)
303305
for i in range(min(Lz - k * nimgs, nimgs)):

cellpose/denoise.py

Lines changed: 204 additions & 116 deletions
Large diffs are not rendered by default.

cellpose/gui/gui.py

Lines changed: 23 additions & 23 deletions
Original file line numberDiff line numberDiff line change
@@ -733,32 +733,31 @@ def make_buttons(self):
733733
self.l0.addWidget(self.denoiseBox, b, 0, 1, 9)
734734

735735
b0 = 0
736-
self.denoiseBoxG.addWidget(QLabel("mode:"), b0, 0, 1, 3)
737-
736+
738737
# DENOISING
739738
self.DenoiseButtons = []
740739
nett = [
741-
"filter image (settings below)",
742740
"clear restore/filter",
741+
"filter image (settings below)",
743742
"denoise (please set cell diameter first)",
744743
"deblur (please set cell diameter first)",
745744
"upsample to 30. diameter (cyto3) or 17. diameter (nuclei) (please set cell diameter first) (disabled in 3D)",
745+
"one-click model trained to denoise+deblur+upsample (please set cell diameter first)"
746746
]
747-
self.denoise_text = ["filter", "none", "denoise", "deblur", "upsample"]
747+
self.denoise_text = ["none", "filter", "denoise", "deblur", "upsample", "one-click"]
748748
self.restore = None
749749
self.ratio = 1.
750-
jj = 3
750+
jj = 0
751+
w = 3
751752
for j in range(len(self.denoise_text)):
752753
self.DenoiseButtons.append(
753754
guiparts.DenoiseButton(self, self.denoise_text[j]))
754-
w = 3
755755
self.denoiseBoxG.addWidget(self.DenoiseButtons[-1], b0, jj, 1, w)
756-
jj += w
757756
self.DenoiseButtons[-1].setFixedWidth(75)
758757
self.DenoiseButtons[-1].setToolTip(nett[j])
759758
self.DenoiseButtons[-1].setFont(self.medfont)
760-
b0 += 1 if j == 1 else 0
761-
jj = 0 if j == 1 else jj
759+
b0 += 1 if j%2==1 else 0
760+
jj = 0 if j%2==1 else jj + w
762761

763762
# b0+=1
764763
self.save_norm = QCheckBox("save restored/filtered image")
@@ -767,22 +766,23 @@ def make_buttons(self):
767766
self.save_norm.setChecked(True)
768767
# self.denoiseBoxG.addWidget(self.save_norm, b0, 0, 1, 8)
769768

770-
b0 += 1
771-
label = QLabel("Cellpose3 model type:")
769+
b0 -= 3
770+
label = QLabel("restore-dataset:")
772771
label.setToolTip(
773-
"choose model type and click [denoise], [deblur], or [upsample]")
772+
"choose dataset and click [denoise], [deblur], [upsample], or [one-click]")
774773
label.setFont(self.medfont)
775-
self.denoiseBoxG.addWidget(label, b0, 0, 1, 4)
774+
self.denoiseBoxG.addWidget(label, b0, 6, 1, 3)
776775

776+
b0 += 1
777777
self.DenoiseChoose = QComboBox()
778778
self.DenoiseChoose.setFont(self.medfont)
779-
self.DenoiseChoose.addItems(["one-click", "nuclei"])
780-
self.DenoiseChoose.setFixedWidth(100)
779+
self.DenoiseChoose.addItems(["cyto3", "cyto2", "nuclei"])
780+
self.DenoiseChoose.setFixedWidth(85)
781781
tipstr = "choose model type and click [denoise], [deblur], or [upsample]"
782782
self.DenoiseChoose.setToolTip(tipstr)
783-
self.denoiseBoxG.addWidget(self.DenoiseChoose, b0, 5, 1, 4)
783+
self.denoiseBoxG.addWidget(self.DenoiseChoose, b0, 6, 1, 3)
784784

785-
b0 += 1
785+
b0 += 2
786786
# FILTERING
787787
self.filtBox = QCollapsible("custom filter settings")
788788
self.filtBox._toggle_btn.setFont(self.medfont)
@@ -1019,7 +1019,7 @@ def enable_buttons(self):
10191019
for i in range(len(self.DenoiseButtons)):
10201020
self.DenoiseButtons[i].setEnabled(True)
10211021
if self.load_3D:
1022-
self.DenoiseButtons[-1].setEnabled(False)
1022+
self.DenoiseButtons[-2].setEnabled(False)
10231023
self.ModelButtonB.setEnabled(True)
10241024
self.SizeButton.setEnabled(True)
10251025
self.newmodel.setEnabled(True)
@@ -2213,7 +2213,7 @@ def compute_restore(self):
22132213
self.DenoiseChoose.setCurrentIndex(1)
22142214
if "upsample" in self.restore:
22152215
i = self.DenoiseChoose.currentIndex()
2216-
diam_up = 30. if i == 0 else 17.
2216+
diam_up = 30. if i==0 or i==1 else 17.
22172217
print(diam_up, self.ratio)
22182218
self.Diameter.setText(str(diam_up / self.ratio))
22192219
self.compute_denoise_model(model_type=model_type)
@@ -2264,16 +2264,16 @@ def compute_denoise_model(self, model_type=None):
22642264
self.progress.setValue(0)
22652265
try:
22662266
tic = time.time()
2267-
nstr = "cyto3" if self.DenoiseChoose.currentText(
2268-
) == "one-click" else "nuclei"
2269-
print(model_type)
2267+
nstr = self.DenoiseChoose.currentText()
2268+
nstr.replace("-", "")
22702269
self.clear_restore()
22712270
model_name = model_type + "_" + nstr
2271+
print(model_name)
22722272
# denoising model
22732273
self.denoise_model = denoise.DenoiseModel(gpu=self.useGPU.isChecked(),
22742274
model_type=model_name)
22752275
self.progress.setValue(10)
2276-
diam_up = 30. if "cyto3" in model_name else 17.
2276+
diam_up = 30. if "cyto" in model_name else 17.
22772277

22782278
# params
22792279
channels = self.get_channels()

cellpose/gui/gui3d.py

Lines changed: 18 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -38,8 +38,8 @@ def avg3d(C):
3838
"""
3939
Ly, Lx = C.shape
4040
# pad T by 2
41-
T = np.zeros((Ly + 2, Lx + 2), np.float32)
42-
M = np.zeros((Ly, Lx), np.float32)
41+
T = np.zeros((Ly + 2, Lx + 2), "float32")
42+
M = np.zeros((Ly, Lx), "float32")
4343
T[1:-1, 1:-1] = C.copy()
4444
y, x = np.meshgrid(np.arange(0, Ly, 1, int), np.arange(0, Lx, 1, int),
4545
indexing="ij")
@@ -244,7 +244,7 @@ def add_mask(self, points=None, color=(100, 200, 50), dense=True):
244244
vc = stroke[iz, 2]
245245
if iz.sum() > 0:
246246
# get points inside drawn points
247-
mask = np.zeros((np.ptp(vr) + 4, np.ptp(vc) + 4), np.uint8)
247+
mask = np.zeros((np.ptp(vr) + 4, np.ptp(vc) + 4), "uint8")
248248
pts = np.stack((vc - vc.min() + 2, vr - vr.min() + 2),
249249
axis=-1)[:, np.newaxis, :]
250250
mask = cv2.fillPoly(mask, [pts], (255, 0, 0))
@@ -265,7 +265,7 @@ def add_mask(self, points=None, color=(100, 200, 50), dense=True):
265265
elif ioverlap.sum() > 0:
266266
ar, ac = ar[~ioverlap], ac[~ioverlap]
267267
# compute outline of new mask
268-
mask = np.zeros((np.ptp(ar) + 4, np.ptp(ac) + 4), np.uint8)
268+
mask = np.zeros((np.ptp(ar) + 4, np.ptp(ac) + 4), "uint8")
269269
mask[ar - ar.min() + 2, ac - ac.min() + 2] = 1
270270
contours = cv2.findContours(mask, cv2.RETR_EXTERNAL,
271271
cv2.CHAIN_APPROX_NONE)
@@ -282,7 +282,7 @@ def add_mask(self, points=None, color=(100, 200, 50), dense=True):
282282
pix = np.append(pix, np.vstack((ars, acs)), axis=-1)
283283

284284
mall = mall[:, pix[0].min():pix[0].max() + 1,
285-
pix[1].min():pix[1].max() + 1].astype(np.float32)
285+
pix[1].min():pix[1].max() + 1].astype("float32")
286286
ymin, xmin = pix[0].min(), pix[1].min()
287287
if len(zdraw) > 1:
288288
mall, zfill = interpZ(mall, zdraw - zmin)
@@ -422,15 +422,15 @@ def update_ortho(self):
422422
for j in range(2):
423423
if j == 0:
424424
if self.view == 0:
425-
image = self.stack[zmin:zmax, :, x].transpose(1, 0, 2)
425+
image = self.stack[zmin:zmax, :, x].transpose(1, 0, 2).copy()
426426
else:
427427
image = self.stack_filtered[zmin:zmax, :,
428-
x].transpose(1, 0, 2)
428+
x].transpose(1, 0, 2).copy()
429429
else:
430430
image = self.stack[
431431
zmin:zmax,
432-
y, :] if self.view == 0 else self.stack_filtered[zmin:zmax,
433-
y, :]
432+
y, :].copy() if self.view == 0 else self.stack_filtered[zmin:zmax,
433+
y, :].copy()
434434
if self.nchan == 1:
435435
# show single channel
436436
image = image[..., 0]
@@ -458,28 +458,30 @@ def update_ortho(self):
458458
self.imgOrtho[j].setLevels(
459459
self.saturation[0][self.currentZ])
460460
elif self.color == 4:
461-
image = image.astype(np.float32).mean(axis=-1).astype(np.uint8)
461+
if image.ndim > 2:
462+
image = image.astype("float32").mean(axis=2).astype("uint8")
462463
self.imgOrtho[j].setImage(image, autoLevels=False, lut=None)
463464
self.imgOrtho[j].setLevels(self.saturation[0][self.currentZ])
464465
elif self.color == 5:
465-
image = image.astype(np.float32).mean(axis=-1).astype(np.uint8)
466+
if image.ndim > 2:
467+
image = image.astype("float32").mean(axis=2).astype("uint8")
466468
self.imgOrtho[j].setImage(image, autoLevels=False,
467469
lut=self.cmap[0])
468470
self.imgOrtho[j].setLevels(self.saturation[0][self.currentZ])
469471
self.pOrtho[0].setAspectLocked(lock=True, ratio=self.zaspect)
470472
self.pOrtho[1].setAspectLocked(lock=True, ratio=1. / self.zaspect)
471473

472474
else:
473-
image = np.zeros((10, 10), np.uint8)
475+
image = np.zeros((10, 10), "uint8")
474476
self.imgOrtho[0].setImage(image, autoLevels=False, lut=None)
475477
self.imgOrtho[0].setLevels([0.0, 255.0])
476478
self.imgOrtho[1].setImage(image, autoLevels=False, lut=None)
477479
self.imgOrtho[1].setLevels([0.0, 255.0])
478480

479481
zrange = zmax - zmin
480482
self.layer_ortho = [
481-
np.zeros((self.Ly, zrange, 4), np.uint8),
482-
np.zeros((zrange, self.Lx, 4), np.uint8)
483+
np.zeros((self.Ly, zrange, 4), "uint8"),
484+
np.zeros((zrange, self.Lx, 4), "uint8")
483485
]
484486
if self.masksOn:
485487
for j in range(2):
@@ -488,7 +490,7 @@ def update_ortho(self):
488490
else:
489491
cp = self.cellpix[zmin:zmax, y]
490492
self.layer_ortho[j][..., :3] = self.cellcolors[cp, :]
491-
self.layer_ortho[j][..., 3] = self.opacity * (cp > 0).astype(np.uint8)
493+
self.layer_ortho[j][..., 3] = self.opacity * (cp > 0).astype("uint8")
492494
if self.selected > 0:
493495
self.layer_ortho[j][cp == self.selected] = np.array(
494496
[255, 255, 255, self.opacity])
@@ -499,7 +501,7 @@ def update_ortho(self):
499501
op = self.outpix[zmin:zmax, :, x].T
500502
else:
501503
op = self.outpix[zmin:zmax, y]
502-
self.layer_ortho[j][op > 0] = np.array(self.outcolor).astype(np.uint8)
504+
self.layer_ortho[j][op > 0] = np.array(self.outcolor).astype("uint8")
503505

504506
for j in range(2):
505507
self.layerOrtho[j].setImage(self.layer_ortho[j])

cellpose/models.py

Lines changed: 17 additions & 32 deletions
Original file line numberDiff line numberDiff line change
@@ -35,7 +35,7 @@
3535
"lowhigh": None,
3636
"percentile": None,
3737
"normalize": True,
38-
"norm3D": False,
38+
"norm3D": True,
3939
"sharpen_radius": 0,
4040
"smooth_radius": 0,
4141
"tile_norm_blocksize": 0,
@@ -263,7 +263,7 @@ def __init__(self, gpu=False, pretrained_model=False, model_type=None,
263263
if (pretrained_model and not Path(pretrained_model).exists() and
264264
np.any([pretrained_model == s for s in all_models])):
265265
model_type = pretrained_model
266-
266+
267267
# check if model_type is builtin or custom user model saved in .cellpose/models
268268
if model_type is not None and np.any([model_type == s for s in all_models]):
269269
if np.any([model_type == s for s in MODEL_NAMES]):
@@ -286,6 +286,10 @@ def __init__(self, gpu=False, pretrained_model=False, model_type=None,
286286
models_logger.warning(
287287
"pretrained_model path does not exist, using default model")
288288
use_default = True
289+
elif pretrained_model:
290+
if pretrained_model[-13:] == "nucleitorch_0":
291+
builtin = True
292+
self.diam_mean = 17.
289293

290294
builtin = True if use_default else builtin
291295
self.pretrained_model = model_path(
@@ -503,37 +507,18 @@ def _run_cp(self, x, compute_masks=True, normalize=True, invert=False, niter=Non
503507
del yf
504508
else:
505509
tqdm_out = utils.TqdmToLogger(models_logger, level=logging.INFO)
506-
iterator = trange(nimg, file=tqdm_out,
507-
mininterval=30) if nimg > 1 else range(nimg)
508-
styles = np.zeros((nimg, self.nbase[-1]), np.float32)
510+
img = np.asarray(x)
511+
if do_normalization:
512+
img = transforms.normalize_img(img, **normalize_params)
513+
if rescale != 1.0:
514+
img = transforms.resize_image(img, rsz=rescale)
515+
yf, style = run_net(self.net, img, bsize=bsize, augment=augment,
516+
tile=tile, tile_overlap=tile_overlap)
509517
if resample:
510-
dP = np.zeros((2, nimg, shape[1], shape[2]), np.float32)
511-
cellprob = np.zeros((nimg, shape[1], shape[2]), np.float32)
512-
else:
513-
dP = np.zeros(
514-
(2, nimg, int(shape[1] * rescale), int(shape[2] * rescale)),
515-
np.float32)
516-
cellprob = np.zeros(
517-
(nimg, int(shape[1] * rescale), int(shape[2] * rescale)),
518-
np.float32)
519-
for i in iterator:
520-
img = np.asarray(x[i])
521-
if do_normalization:
522-
img = transforms.normalize_img(img, **normalize_params)
523-
if rescale != 1.0:
524-
img = transforms.resize_image(img, rsz=rescale)
525-
yf, style = run_net(self.net, img, bsize=bsize, augment=augment,
526-
tile=tile, tile_overlap=tile_overlap)
527-
if resample:
528-
yf = transforms.resize_image(yf, shape[1], shape[2])
529-
530-
cellprob[i] = yf[:, :, 2]
531-
dP[:, i] = yf[:, :, :2].transpose((2, 0, 1))
532-
if self.nclasses == 4:
533-
if i == 0:
534-
bd = np.zeros_like(cellprob)
535-
bd[i] = yf[:, :, 3]
536-
styles[i][:len(style)] = style
518+
yf = transforms.resize_image(yf, shape[1], shape[2])
519+
dP = np.moveaxis(yf[..., :2], source=-1, destination=0).copy()
520+
cellprob = yf[..., 2]
521+
styles = style
537522
del yf, style
538523
styles = styles.squeeze()
539524

cellpose/resnet_torch.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -199,6 +199,7 @@ class CPnet(nn.Module):
199199
def __init__(self, nbase, nout, sz, mkldnn=False, conv_3D=False, max_pool=True,
200200
diam_mean=30.):
201201
super().__init__()
202+
self.nchan = nbase[0]
202203
self.nbase = nbase
203204
self.nout = nout
204205
self.sz = sz

cellpose/train.py

Lines changed: 15 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -331,7 +331,7 @@ def train_seg(net, train_data=None, train_labels=None, train_files=None,
331331
test_probs=None, load_files=True, batch_size=8, learning_rate=0.005,
332332
n_epochs=2000, weight_decay=1e-5, momentum=0.9, SGD=False, channels=None,
333333
channel_axis=None, rgb=False, normalize=True, compute_flows=False,
334-
save_path=None, save_every=100, nimg_per_epoch=None,
334+
save_path=None, save_every=100, save_each=False, nimg_per_epoch=None,
335335
nimg_test_per_epoch=None, rescale=True, scale_range=None, bsize=224,
336336
min_train_masks=5, model_name=None):
337337
"""
@@ -362,6 +362,7 @@ def train_seg(net, train_data=None, train_labels=None, train_files=None,
362362
compute_flows (bool, optional): Boolean - whether to compute flows during training. Defaults to False.
363363
save_path (str, optional): String - where to save the trained model. Defaults to None.
364364
save_every (int, optional): Integer - save the network every [save_every] epochs. Defaults to 100.
365+
save_each (bool, optional): Boolean - save the network to a new filename at every [save_each] epoch. Defaults to False.
365366
nimg_per_epoch (int, optional): Integer - minimum number of images to train on per epoch. Defaults to None.
366367
nimg_test_per_epoch (int, optional): Integer - minimum number of images to test on per epoch. Defaults to None.
367368
rescale (bool, optional): Boolean - whether or not to rescale images during training. Defaults to True.
@@ -444,10 +445,10 @@ def train_seg(net, train_data=None, train_labels=None, train_files=None,
444445
t0 = time.time()
445446
model_name = f"cellpose_{t0}" if model_name is None else model_name
446447
save_path = Path.cwd() if save_path is None else Path(save_path)
447-
model_path = save_path / "models" / model_name
448+
filename = save_path / "models" / model_name
448449
(save_path / "models").mkdir(exist_ok=True)
449450

450-
train_logger.info(f">>> saving model to {model_path}")
451+
train_logger.info(f">>> saving model to {filename}")
451452

452453
lavg, nsum = 0, 0
453454
for iepoch in range(n_epochs):
@@ -518,15 +519,21 @@ def train_seg(net, train_data=None, train_labels=None, train_files=None,
518519
lavgt /= len(rperm)
519520
lavg /= nsum
520521
train_logger.info(
521-
f"{iepoch}, train_loss={lavg:.4f}, test_loss={lavgt:.4f}, LR={LR[iepoch]:.4f}, time {time.time()-t0:.2f}s"
522+
f"{iepoch}, train_loss={lavg:.4f}, test_loss={lavgt:.4f}, LR={LR[iepoch]:.6f}, time {time.time()-t0:.2f}s"
522523
)
523524
lavg, nsum = 0, 0
524525

525-
if iepoch > 0 and iepoch % save_every == 0:
526-
net.save_model(model_path)
527-
net.save_model(model_path)
526+
if iepoch == n_epochs - 1 or (iepoch % save_every == 0 and iepoch != 0):
527+
if save_each and iepoch != n_epochs - 1: #separate files as model progresses
528+
filename0 = str(filename) + f"_epoch_{iepoch:%04d}"
529+
else:
530+
filename0 = filename
531+
train_logger.info(f"saving network parameters to {filename0}")
532+
net.save_model(filename0)
533+
534+
net.save_model(filename)
528535

529-
return model_path
536+
return filename
530537

531538

532539
def train_size(net, pretrained_model, train_data=None, train_labels=None,

0 commit comments

Comments
 (0)