[BUG] CP3 incorrect UNet architecture when training with --all_channels  #1427

@neaamdawood

Describe the bug

I have cellpose==3.0.7 installed in a cellpose3 virtual environment (cp3), and I need to train a model with this specific version. I run a Python script from VS Code that trains a 3-channel model with the following parameters:

pretrained_model = r"C:\Users\nd\Downloads\pt_model_3ch"
traindir = r"C:\Users\nd\Desktop\data\segmentations_3ch\train"
testdir = r"C:\Users\nd\Desktop\data\segmentations_3ch\test"

cmd = (
    f'python -m cellpose '
    f'--train '
    f'--dir "{traindir}" '
    f'--test_dir "{testdir}" '
    f'--pretrained_model "{pretrained_model}" '
    f'--all_channels '
    f'--diam_mean 15 '
    f'--verbose '
    f'--no_norm '
    f'--save_every 25 '
    f'--mask_filter _seg.npy '
    f'--model_name_out model_3ch '
    f'--learning_rate 0.01 '
    f'--n_epochs 100 '
)
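
The snippet above only builds the command string; it does not show how the command is launched. A minimal sketch of one way to do that, assuming the script uses the standard-library subprocess module and the interpreter of the active environment (the helper names build_train_args and run_training are hypothetical, the flags are the ones listed above):

```python
import subprocess
import sys


def build_train_args(traindir, testdir, pretrained_model):
    """Return the cellpose training command as an argument list.

    Passing a list to subprocess avoids manual quoting of Windows paths.
    """
    return [
        sys.executable, "-m", "cellpose",
        "--train",
        "--dir", traindir,
        "--test_dir", testdir,
        "--pretrained_model", pretrained_model,
        "--all_channels",
        "--diam_mean", "15",
        "--verbose",
        "--no_norm",
        "--save_every", "25",
        "--mask_filter", "_seg.npy",
        "--model_name_out", "model_3ch",
        "--learning_rate", "0.01",
        "--n_epochs", "100",
    ]


def run_training(args):
    """Run the command and raise CalledProcessError on a non-zero exit code."""
    subprocess.run(args, check=True)


# Usage (not executed here):
# run_training(build_train_args(traindir, testdir, pretrained_model))
```

Using sys.executable keeps the call inside the same environment (cp3 here) that runs the script.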

When I run this script in my cellpose3 environment, I get the following runtime error:

Error:
Traceback (most recent call last):
  File "C:\Users\nd\AppData\Local\Programs\Python\Python310\lib\runpy.py", line 196, in _run_module_as_main
    return _run_code(code, main_globals, None,
  File "C:\Users\nd\AppData\Local\Programs\Python\Python310\lib\runpy.py", line 86, in _run_code
    exec(code, run_globals)
  File "C:\Users\nd\cp3\lib\site-packages\cellpose\__main__.py", line 308, in <module>
    main()
  File "C:\Users\nd\cp3\lib\site-packages\cellpose\__main__.py", line 261, in main
    cpmodel_path = train.train_seg(
  File "C:\Users\nd\cp3\lib\site-packages\cellpose\train.py", line 430, in train_seg
    y = net(X)[0]
  File "C:\Users\nd\cp3\lib\site-packages\torch\nn\modules\module.py", line 1779, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
  File "C:\Users\nd\cp3\lib\site-packages\torch\nn\modules\module.py", line 1790, in _call_impl
    return forward_call(*args, **kwargs)
  File "C:\Users\nd\cp3\lib\site-packages\cellpose\resnet_torch.py", line 261, in forward
    T0 = self.downsample(data)
  File "C:\Users\nd\cp3\lib\site-packages\torch\nn\modules\module.py", line 1779, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
  File "C:\Users\nd\cp3\lib\site-packages\torch\nn\modules\module.py", line 1790, in _call_impl
    return forward_call(*args, **kwargs)
  File "C:\Users\nd\cp3\lib\site-packages\cellpose\resnet_torch.py", line 77, in forward
    xd.append(self.down[n](y))
  File "C:\Users\nd\cp3\lib\site-packages\torch\nn\modules\module.py", line 1779, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
  File "C:\Users\nd\cp3\lib\site-packages\torch\nn\modules\module.py", line 1790, in _call_impl
    return forward_call(*args, **kwargs)
  File "C:\Users\nd\cp3\lib\site-packages\cellpose\resnet_torch.py", line 50, in forward
    x = self.proj(x) + self.conv[1](self.conv[0](x))
  File "C:\Users\nd\cp3\lib\site-packages\torch\nn\modules\module.py", line 1779, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
  File "C:\Users\nd\cp3\lib\site-packages\torch\nn\modules\module.py", line 1790, in _call_impl
    return forward_call(*args, **kwargs)
  File "C:\Users\nd\cp3\lib\site-packages\torch\nn\modules\container.py", line 253, in forward
    input = module(input)
  File "C:\Users\nd\cp3\lib\site-packages\torch\nn\modules\module.py", line 1779, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
  File "C:\Users\nd\cp3\lib\site-packages\torch\nn\modules\module.py", line 1790, in _call_impl
    return forward_call(*args, **kwargs)
  File "C:\Users\nd\cp3\lib\site-packages\torch\nn\modules\batchnorm.py", line 194, in forward
    return F.batch_norm(
  File "C:\Users\nd\cp3\lib\site-packages\torch\nn\functional.py", line 2846, in batch_norm
    return torch.batch_norm(
RuntimeError: running_mean should contain 608 elements not 3
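
For reading the last line: PyTorch's batch norm takes the channel count from axis 1 of the input, so the message says the input has 608 entries on that axis while the layer holds statistics for 3 channels. One possible explanation, which is an assumption and not a confirmed diagnosis, is that the images reach the network with an image dimension on the channel axis (for example a channels-last layout). A minimal sketch of that shape check in plain Python (the helper name and the example batch shape are hypothetical):

```python
def explain_channel_mismatch(input_shape, layer_channels):
    """Compare an (N, C, H, W) input shape with a BatchNorm layer's channel count.

    Returns (matches, candidate_axes). candidate_axes lists the other axes whose
    size equals layer_channels, which hints at where the channels actually are.
    """
    seen = input_shape[1]  # PyTorch reads the channel count from axis 1
    if seen == layer_channels:
        return True, []
    candidates = [
        axis for axis, size in enumerate(input_shape)
        if axis != 1 and size == layer_channels
    ]
    return False, candidates


# Hypothetical batch: 8 images of 608 x 608 pixels stored channels-last.
print(explain_channel_mismatch((8, 608, 608, 3), 3))  # prints (False, [3])
```

If the real batch looks like this, moving the 3-channel axis to position 1 before training would satisfy the layer; whether cellpose 3.0.7 is expected to do this itself with --all_channels is the open question of this report.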

Could the cellpose source code be changed to fix this bug, and a Cellpose 3 release published in which training a 3-channel model with --all_channels no longer raises this error?
