# patches_extraction.py
from typing import Union, List
from pathlib import Path
import configparser

import tables
from PIL import Image
import numpy as np
from numpy import ndarray
from sklearn.feature_extraction.image import extract_patches


class Extractor:
"""
A generic extractor class with size-specific settings
and unified interface for extracting patches from images.
Can be used to extract patches from both images and masks.
"""
def __init__(self,
resize: float = 0.125,
mirror_pad_size: int = 128,
patch_size: int = 256,
stride_size: int = 64,
normalize_mask: bool = False,
config_section_name: str = None):
"""
Configure the extractor with specific size arguments;
Parameters
----------
resize : float
resize factor of original slide images, e.g.,
in order to get a desired 5x magnification patches
from 40x magnification slides, the resize factor should
be 0.125, since (0.125) * 40 = 5.
mirror_pad_size : int
size of padding regions in front of or behind the first two axis
mirror/reflecting padding is used here
patch_size : int
height and width of a patch;
for now only square patches are extracted
stride_size : int
stride used in patches extraction
normalize_mask : bool
reshape the mask to add the third channel and rescale our labels to
continuous integers, i.e, from 0 to len(unique_labels)-1.
config_section_name : str
section name in config file "extractor_param.ini";
if not None (default value), it will be used to initialize the extractor,
ignoring all the above input arguments
"""
        if config_section_name is None:
            self.resize = resize
            self.mirror_pad_size = mirror_pad_size
            self.patch_size = patch_size
            self.stride_size = stride_size
            self.normalize_mask = normalize_mask
        else:
            config = configparser.ConfigParser()
            config.read(Path(__file__).parent / "extractor_param.ini")
            assert config_section_name in config, \
                f"{config_section_name} is not a valid section name.\n" \
                f"Valid sections: {config.sections()}"
            section = config[config_section_name]
            self.resize = section.getfloat("resize")
            self.mirror_pad_size = section.getint("mirror_pad_size")
            self.patch_size = section.getint("patch_size")
            self.stride_size = section.getint("stride_size")
            self.normalize_mask = section.getboolean("normalize_mask")
    def extract_patches(self, img: Image.Image, interp_method=Image.BICUBIC) -> ndarray:
        """
        Interface for extracting patches from an image after resizing.

        Parameters
        ----------
        img : PIL.Image.Image
            image to extract patches from
        interp_method : int, PIL.Image.BICUBIC by default
            resampling filter used for resizing;
            must be one of the PIL.Image filters.
            See https://pillow.readthedocs.io/en/stable/handbook/concepts.html#filters

        Returns
        -------
        img_patches : ndarray of shape (n_tiles, patch_size, patch_size, 3)
            patches extracted from the input image, without any filtering;
            filtering is delegated to the caller
        """
        resized_shape = tuple(map(lambda x: int(x * self.resize), img.size))
        img = img.resize(size=resized_shape, resample=interp_method)
        # apply mirror padding before and after the first two axes
        # to make sure the border pixels are all preserved
        # TODO: make sure the mirror_pad_size is large enough to cover every original pixel
        pad_front_end = (self.mirror_pad_size, self.mirror_pad_size)
        img = np.pad(img,
                     pad_width=[pad_front_end, pad_front_end, (0, 0)],
                     mode="reflect")
        # convert the input image into overlapping tiles of shape
        # ntile_r x ntile_c x 1 x patch_size x patch_size x 3
        # TODO: fix the deprecation warnings
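        # NOTE (assumption, not part of the original code): the private
        # sklearn.feature_extraction.image.extract_patches helper used below
        # was deprecated in scikit-learn 0.22 and removed in 0.24.  On newer
        # stacks (numpy >= 1.20), an equivalent replacement might be:
        #   windows = np.lib.stride_tricks.sliding_window_view(
        #       np.asarray(img), (self.patch_size, self.patch_size, 3))
        #   img_patches = windows[::self.stride_size, ::self.stride_size, 0]
        # which reshapes to the same (-1, patch_size, patch_size, 3) result below.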
        img_patches: np.ndarray = extract_patches(img,
                                                  (self.patch_size, self.patch_size, 3),
                                                  self.stride_size)
        # reshape to n_tiles x patch_size x patch_size x 3
        img_patches = img_patches.reshape((-1, self.patch_size, self.patch_size, 3))  # type: ndarray
        return img_patches
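

# Minimal usage sketch (illustrative, not part of the original module;
# the slide file name is hypothetical):
#   extractor = Extractor(resize=0.125, patch_size=256, stride_size=64)
#   patches = extractor.extract_patches(Image.open("slide_40x.png"))
#   patches.shape  # -> (n_tiles, 256, 256, 3)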


def extract_img_patches(img: Union[Path, str, np.ndarray, Image.Image],
                        extractor: Extractor):
    """
    Helper function for extracting patches from raw images
    with a configured extractor instance.
    A tissue filter is applied: only patches containing tissue
    are marked as kept.

    Parameters
    ----------
    img : str or Path or numpy array or PIL Image
        path to the image file OR
        an image already loaded in memory
    extractor : Extractor
        an instance of the Extractor class with its parameters specified;
        the caller should pass the same instance that is used in the
        "extract_mask_patches" function

    Returns
    -------
    img_patches : ndarray
        all extracted patches, before the filtering strategy is applied
    keep_indices : List[int]
        indices of the patches satisfying the filtering criterion
    """
    if isinstance(img, str) or isinstance(img, Path):
        img_path = str(img)
        img = Image.open(img_path)
    if isinstance(img, np.ndarray):
        img = Image.fromarray(img)
    img_patches = extractor.extract_patches(img)
    # only keep patches with tissue in them;
    # the indices of the valid patches are returned
    # TODO: determine a more robust way of tissue detection;
    # currently, keep all patches whose mean(R_channel) <= 220
    keep_indices = filter(lambda patch_id: np.mean(img_patches[patch_id, ..., 0]) <= 220,
                          range(len(img_patches)))
    keep_indices = list(keep_indices)
    return img_patches, keep_indices
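
# Example (illustrative; the image path is hypothetical):
#   img_patches, keep = extract_img_patches("slide_01.png", extractor)
#   tissue_only = img_patches[keep, ...]  # patches whose mean red channel <= 220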


def extract_mask_patches(mask_path: Union[Path, str],
                         extractor: Extractor):
    """
    Helper function for extracting patches from masks
    with a configured extractor instance.
    Filtering is applied so that only positive patches with annotations,
    or negative patches passing the random sampling, are kept.

    Parameters
    ----------
    mask_path : Union[Path, str]
        path to the mask file
    extractor : Extractor
        an instance of the Extractor class with its parameters specified;
        the caller should pass the same instance that is used in the
        "extract_img_patches" function

    Returns
    -------
    mask_patches : ndarray
        all extracted mask patches, before the filtering strategy is applied
    keep_indices : List[int]
        indices of the patches satisfying the filtering criterion
    """
    # make sure that the input mask has three channels
    # and that every class label is represented as a uint8 value
    mask_path = str(mask_path)
    mask = np.array(Image.open(mask_path), dtype=np.uint8)  # load as uint8 labels
    unique_labels = np.sort(np.unique(mask))
    if extractor.normalize_mask:
        # rescale labels
        if np.max(unique_labels) != len(unique_labels) - 1:
            label_mappings = {old_label: new_label for new_label, old_label in enumerate(unique_labels)}
            for old_label in label_mappings:
                new_label = label_mappings[old_label]
                if old_label != new_label:
                    mask[mask == old_label] = new_label
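        # e.g. (illustrative): original labels {0, 2, 7} would be
        # remapped to {0, 1, 2} by the loop above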
        # reshape mask to 3D
        if len(mask.shape) == 2:
            mask = np.repeat(mask[..., np.newaxis], 3, axis=2)
    assert mask.dtype == np.uint8, \
        f"Only uint8 labels are supported, but the current dtype for mask {mask_path} is {mask.dtype}"
    assert len(mask.shape) == 3 and mask.shape[-1] == 3, \
        f"Mask must have three channels. Current mask {mask_path} has shape {mask.shape}."
    assert mask.max() == len(unique_labels) - 1, \
        f"Labels are not continuous integers from 0 to {len(unique_labels) - 1}; Max: {mask.max()}"
    # use nearest-neighbour interpolation;
    # otherwise resizing may produce non-existing classes
    # through interpolation (e.g., ".25")
    mask_patches = extractor.extract_patches(Image.fromarray(mask), interp_method=Image.NEAREST)
    # keep all positive patches and randomly sample negative patches;
    # TODO: determine a more intuitive sampling strategy;
    # currently, uniform random numbers are drawn and a negative patch
    # is kept when its U(0, 1) draw exceeds 0.8 (roughly 20% of negatives)
    random_flags = np.random.rand(len(mask_patches))
    keep_indices = filter(lambda patch_id: np.any(mask_patches[patch_id, ..., 0])
                          or random_flags[patch_id] > 0.8,
                          range(len(mask_patches)))
    keep_indices = list(keep_indices)
    return mask_patches, keep_indices
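
# Example (illustrative; the mask path is hypothetical):
#   mask_patches, keep = extract_mask_patches("slide_01_mask.png", extractor)
#   kept = mask_patches[keep, ..., 0]  # annotated patches plus ~20% sampled negatives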


def crop_and_save_patches_to_hdf5(hdf5_dataset_fname, images, masks, extractor: Extractor):
    """
    Crop images and masks and save all extracted patches to an HDF5 file.
    The resulting HDF5 file has three main columns, namely "src_image_fname",
    "img" and "mask". Each of them can be loaded with the syntax
    <h5_file>.root.<col_name>[index, ...], which returns a numpy array.

    Parameters
    ----------
    hdf5_dataset_fname : str or Path
        target location to create and write the HDF5 file
    images : List[str] or List[Path]
        list of paths to image files
    masks : List[str] or List[Path]
        list of paths to mask files;
        note that a mask can contain multiple classes, where
        every unique non-zero value stands for a unique class.
        Currently the mapping is {tubules: 1, artery: 2, glomerulus: 3, arteriole: 4}
    extractor : Extractor
        instance of the Extractor class, containing the extraction parameters

    Returns
    -------
    None, but with the side effect of writing an HDF5 file to the target location
    """
    img_dtype = tables.UInt8Atom()
    filename_dtype = tables.StringAtom(itemsize=255)
    img_shape = (extractor.patch_size, extractor.patch_size, 3)
    mask_shape = (extractor.patch_size, extractor.patch_size)  # mask is just a 2D matrix
    with tables.open_file(hdf5_dataset_fname, mode='w') as hdf5_file:
        # use blosc compression
        filters = tables.Filters(complevel=1, complib='blosc')
        # filenames, images and masks are saved as three separate
        # earrays in the hdf5 file tree
        src_img_fnames = hdf5_file.create_earray(
            hdf5_file.root,
            name="src_image_fname",
            atom=filename_dtype,
            shape=(0, ))
        img_array = hdf5_file.create_earray(
            hdf5_file.root,
            name="img",
            atom=img_dtype,
            shape=(0, *img_shape),
            chunkshape=(1, *img_shape),
            filters=filters)
        mask_array = hdf5_file.create_earray(
            hdf5_file.root,
            name="mask",
            atom=img_dtype,
            shape=(0, *mask_shape),
            chunkshape=(1, *mask_shape),
            filters=filters)
        for img_path, mask_path in zip(images, masks):
            # append the newly created patches for every image/mask pair
            # and flush them incrementally to the hdf5 file
            img_patches, img_keep_indices = extract_img_patches(img_path, extractor)
            mask_patches, mask_keep_indices = extract_mask_patches(mask_path, extractor)
            # take the intersection of the kept indices for images and masks
            keep_indices = np.intersect1d(img_keep_indices, mask_keep_indices)
            img_array.append(img_patches[keep_indices, ...])
            # only the first channel of the mask patches is stored
            mask_array.append(mask_patches[keep_indices, ..., 0])
            # one source-filename entry per kept patch
            src_img_fnames.append([str(img_path)] * len(keep_indices))
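

# Usage sketch (illustrative; the file names below are hypothetical):
#   crop_and_save_patches_to_hdf5("patches.h5", images=["slide_01.png"],
#                                 masks=["slide_01_mask.png"], extractor=extractor)
# Reading the saved patches back, using the <h5_file>.root.<col_name>[index, ...]
# syntax described in the docstring above:
#   with tables.open_file("patches.h5", mode="r") as h5:
#       img0 = h5.root.img[0, ...]              # uint8, (patch_size, patch_size, 3)
#       mask0 = h5.root.mask[0, ...]            # uint8, (patch_size, patch_size)
#       fname0 = h5.root.src_image_fname[0]     # source image filename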