Skip to content

Commit 6f23757

Browse files
committed
Updated cvt_gray and typos
1 parent 0baf5d6 commit 6f23757

File tree

1 file changed

+22
-7
lines changed

1 file changed

+22
-7
lines changed

transforming/augment.py

+22-7
Original file line numberDiff line numberDiff line change
@@ -92,8 +92,19 @@ def cvt_gray(
9292
"""
9393
img = Image.fromarray(sample_input.images.numpy())
9494
if random.uniform(0, 1) < probability:
95+
new_img = Image.new(
96+
img.mode,
97+
(
98+
img.size[0],
99+
img.size[1],
100+
),
101+
(0, 0, 0),
102+
)
95103
img = ImageOps.grayscale(img)
96-
sample_output.images.append(img)
104+
new_img.paste(img, (0, 0))
105+
sample_output.images.append(new_img)
106+
else:
107+
sample_output.images.append(img)
97108
sample_output.labels.append(sample_input.labels.numpy())
98109
return sample_output
99110

@@ -158,7 +169,7 @@ def cvt_crop(
158169
Args:
159170
sample_input: input dataset passed to generate output dataset.
160171
sample_output: output dataset which will contain transforms of input dataset.
161-
crop_locations: tuple (start_x,start_y,end_x,end_y) to determine region for crop. Defaults to -1 which causes a centre crop.
172+
crop_locations: tuple (start_x,start_y,end_x,end_y) to determine region for crop. Defaults to None which causes a centre crop.
162173
probability: probability to randomly apply transformation. Defaults to 0.5.
163174
164175
Returns:
@@ -191,7 +202,7 @@ def cvt_resize(
191202
Args:
192203
sample_input: input dataset passed to generate output dataset.
193204
sample_output: output dataset which will contain transforms of input dataset.
194-
resize_size: tuple (width,height) to determine dimensions for resize. Defaults to -1 which prevents resizing.
205+
resize_size: tuple (width,height) to determine dimensions for resize. Defaults to None which prevents resizing.
195206
probability: probability to randomly apply transformation. Defaults to 0.5.
196207
197208
Returns:
@@ -298,7 +309,9 @@ def cvt_padding(
298309
pad_color,
299310
)
300311
new_img.paste(img, (pad_size[0], pad_size[1]))
301-
sample_output.images.append(new_img)
312+
sample_output.images.append(new_img)
313+
else:
314+
sample_output.images.append(img)
302315
sample_output.labels.append(sample_input.labels.numpy())
303316
return sample_output
304317

@@ -325,13 +338,15 @@ def cvt_padding(
325338
args = parser.parse_args()
326339

327340
ds_input = hub.load(args.input_path)
328-
ds_output = hub.like(args.output_path, ds_input)
341+
ds_output = hub.like(args.output_path, ds_input, overwrite=True)
329342
pipeline = hub.compose(
330343
[
331-
cvt_horizontal_flip(probability=0.4),
344+
cvt_horizontal_flip(probability=0.5),
332345
cvt_crop(probability=0.8),
333346
cvt_gray(probability=0.7),
334-
cvt_padding(pad_size=(10, 10, 10, 10), bg_color=(0, 0, 0), probability=0.5),
347+
cvt_padding(
348+
pad_size=(10, 10, 10, 10), pad_color=(0, 0, 0), probability=0.4
349+
),
335350
cvt_resize(resize_size=(100, 80), probability=0.6),
336351
]
337352
)

0 commit comments

Comments
 (0)