Skip to content

Commit 8be81e8

Browse files
committed
support channel_axis in _get_batch
1 parent 6d23968 commit 8be81e8

File tree

1 file changed

+3
-2
lines changed

1 file changed

+3
-2
lines changed

cellpose/train.py

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -86,7 +86,7 @@ def _reshape_norm(data, channel_axis=None, normalize_params={"normalize": False}
8686
return data
8787

8888
def _get_batch(inds, data=None, labels=None, files=None, labels_files=None,
89-
normalize_params={"normalize": False}):
89+
channel_axis=None, normalize_params={"normalize": False}):
9090
"""
9191
Get a batch of images and labels.
9292
@@ -96,6 +96,7 @@ def _get_batch(inds, data=None, labels=None, files=None, labels_files=None,
9696
labels (list or None): List of label data. If None, labels will be loaded from files.
9797
files (list or None): List of file paths for images.
9898
labels_files (list or None): List of file paths for labels.
99+
channel_axis (int or None): Axis of channel dimension.
99100
normalize_params (dict): Dictionary of parameters for image normalization (will be faster, if loading from files to pre-normalize).
100101
101102
Returns:
@@ -104,7 +105,7 @@ def _get_batch(inds, data=None, labels=None, files=None, labels_files=None,
104105
if data is None:
105106
lbls = None
106107
imgs = [io.imread(files[i]) for i in inds]
107-
imgs = _reshape_norm(imgs, normalize_params=normalize_params)
108+
imgs = _reshape_norm(imgs, channel_axis=channel_axis, normalize_params=normalize_params)
108109
if labels_files is not None:
109110
lbls = [io.imread(labels_files[i])[1:] for i in inds]
110111
else:

0 commit comments

Comments
 (0)