diff --git a/big_vision/pp/archive/autoaugment.py b/big_vision/pp/archive/autoaugment.py index f1af14e..4664f25 100644 --- a/big_vision/pp/archive/autoaugment.py +++ b/big_vision/pp/archive/autoaugment.py @@ -202,6 +202,8 @@ def color(image, factor): def contrast(image, factor): """Equivalent of PIL Contrast.""" + image_height = tf.shape(image)[0] + image_width = tf.shape(image)[1] degenerate = tf.image.rgb_to_grayscale(image) # Cast before calling tf.histogram. degenerate = tf.cast(degenerate, tf.int32) @@ -210,7 +212,8 @@ def contrast(image, factor): # and create a constant image size of that value. Use that as the # blending degenerate target of the original image. hist = tf.histogram_fixed_width(degenerate, [0, 255], nbins=256) - mean = tf.reduce_sum(tf.cast(hist, tf.float32)) / 256.0 + mean = tf.reduce_sum( + tf.cast(hist, tf.float32) * tf.linspace(0., 255., 256)) / float(image_height * image_width) degenerate = tf.ones_like(degenerate, dtype=tf.float32) * mean degenerate = tf.clip_by_value(degenerate, 0.0, 255.0) degenerate = tf.image.grayscale_to_rgb(tf.cast(degenerate, tf.uint8)) diff --git a/big_vision/pp/autoaugment.py b/big_vision/pp/autoaugment.py index 6cc45f1..e08a192 100644 --- a/big_vision/pp/autoaugment.py +++ b/big_vision/pp/autoaugment.py @@ -202,6 +202,8 @@ def color(image, factor): def contrast(image, factor): """Equivalent of PIL Contrast.""" + image_height = tf.shape(image)[0] + image_width = tf.shape(image)[1] degenerate = tf.image.rgb_to_grayscale(image) # Cast before calling tf.histogram. degenerate = tf.cast(degenerate, tf.int32) @@ -210,7 +212,8 @@ def contrast(image, factor): # and create a constant image size of that value. Use that as the # blending degenerate target of the original image. hist = tf.histogram_fixed_width(degenerate, [0, 255], nbins=256) - mean = tf.reduce_sum(tf.cast(hist, tf.float32)) / 256.0 + mean = tf.reduce_sum( + tf.cast(hist, tf.float32) * tf.linspace(0., 255., 256)) / float(image_height * image_width) degenerate = tf.ones_like(degenerate, dtype=tf.float32) * mean degenerate = tf.clip_by_value(degenerate, 0.0, 255.0) degenerate = tf.image.grayscale_to_rgb(tf.cast(degenerate, tf.uint8))