@@ -86,7 +86,7 @@ def _reshape_norm(data, channel_axis=None, normalize_params={"normalize": False}
8686 return data
8787
8888def _get_batch (inds , data = None , labels = None , files = None , labels_files = None ,
89- normalize_params = {"normalize" : False }):
89+ channel_axis = None , normalize_params = {"normalize" : False }):
9090 """
9191 Get a batch of images and labels.
9292
@@ -96,6 +96,7 @@ def _get_batch(inds, data=None, labels=None, files=None, labels_files=None,
9696 labels (list or None): List of label data. If None, labels will be loaded from files.
9797 files (list or None): List of file paths for images.
9898 labels_files (list or None): List of file paths for labels.
99+ channel_axis (int or None): Axis of channel dimension.
99100 normalize_params (dict): Dictionary of parameters for image normalization (will be faster, if loading from files to pre-normalize).
100101
101102 Returns:
@@ -104,7 +105,7 @@ def _get_batch(inds, data=None, labels=None, files=None, labels_files=None,
104105 if data is None :
105106 lbls = None
106107 imgs = [io .imread (files [i ]) for i in inds ]
107- imgs = _reshape_norm (imgs , normalize_params = normalize_params )
108+ imgs = _reshape_norm (imgs , channel_axis = channel_axis , normalize_params = normalize_params )
108109 if labels_files is not None :
109110 lbls = [io .imread (labels_files [i ])[1 :] for i in inds ]
110111 else :
0 commit comments