diff --git a/keras/api/_tf_keras/keras/applications/__init__.py b/keras/api/_tf_keras/keras/applications/__init__.py index 183b3ca66142..5280067a58d7 100644 --- a/keras/api/_tf_keras/keras/applications/__init__.py +++ b/keras/api/_tf_keras/keras/applications/__init__.py @@ -11,6 +11,7 @@ from keras.api.applications import imagenet_utils from keras.api.applications import inception_resnet_v2 from keras.api.applications import inception_v3 +from keras.api.applications import lpips from keras.api.applications import mobilenet from keras.api.applications import mobilenet_v2 from keras.api.applications import mobilenet_v3 @@ -46,6 +47,7 @@ from keras.src.applications.efficientnet_v2 import EfficientNetV2S from keras.src.applications.inception_resnet_v2 import InceptionResNetV2 from keras.src.applications.inception_v3 import InceptionV3 +from keras.src.applications.lpips import LPIPS from keras.src.applications.mobilenet import MobileNet from keras.src.applications.mobilenet_v2 import MobileNetV2 from keras.src.applications.mobilenet_v3 import MobileNetV3Large diff --git a/keras/api/_tf_keras/keras/applications/lpips/__init__.py b/keras/api/_tf_keras/keras/applications/lpips/__init__.py new file mode 100644 index 000000000000..96ae46ac93f4 --- /dev/null +++ b/keras/api/_tf_keras/keras/applications/lpips/__init__.py @@ -0,0 +1,8 @@ +"""DO NOT EDIT. + +This file was autogenerated. Do not edit it by hand, +since your modifications would be overwritten. 
+""" + +from keras.src.applications.lpips import LPIPS +from keras.src.applications.lpips import preprocess_input diff --git a/keras/api/applications/__init__.py b/keras/api/applications/__init__.py index 183b3ca66142..5280067a58d7 100644 --- a/keras/api/applications/__init__.py +++ b/keras/api/applications/__init__.py @@ -11,6 +11,7 @@ from keras.api.applications import imagenet_utils from keras.api.applications import inception_resnet_v2 from keras.api.applications import inception_v3 +from keras.api.applications import lpips from keras.api.applications import mobilenet from keras.api.applications import mobilenet_v2 from keras.api.applications import mobilenet_v3 @@ -46,6 +47,7 @@ from keras.src.applications.efficientnet_v2 import EfficientNetV2S from keras.src.applications.inception_resnet_v2 import InceptionResNetV2 from keras.src.applications.inception_v3 import InceptionV3 +from keras.src.applications.lpips import LPIPS from keras.src.applications.mobilenet import MobileNet from keras.src.applications.mobilenet_v2 import MobileNetV2 from keras.src.applications.mobilenet_v3 import MobileNetV3Large diff --git a/keras/api/applications/lpips/__init__.py b/keras/api/applications/lpips/__init__.py new file mode 100644 index 000000000000..96ae46ac93f4 --- /dev/null +++ b/keras/api/applications/lpips/__init__.py @@ -0,0 +1,8 @@ +"""DO NOT EDIT. + +This file was autogenerated. Do not edit it by hand, +since your modifications would be overwritten. 
+""" + +from keras.src.applications.lpips import LPIPS +from keras.src.applications.lpips import preprocess_input diff --git a/keras/src/applications/applications_test.py b/keras/src/applications/applications_test.py index 7ceb4dbd36b4..75fd75c3be18 100644 --- a/keras/src/applications/applications_test.py +++ b/keras/src/applications/applications_test.py @@ -12,6 +12,7 @@ from keras.src.applications import efficientnet_v2 from keras.src.applications import inception_resnet_v2 from keras.src.applications import inception_v3 +from keras.src.applications import lpips from keras.src.applications import mobilenet from keras.src.applications import mobilenet_v2 from keras.src.applications import mobilenet_v3 @@ -81,7 +82,7 @@ (resnet_v2.ResNet101V2, 2048, resnet_v2), (resnet_v2.ResNet152V2, 2048, resnet_v2), ] -MODELS_UNSUPPORTED_CHANNELS_FIRST = ["ConvNeXt", "DenseNet", "NASNet"] +MODELS_UNSUPPORTED_CHANNELS_FIRST = ["ConvNeXt", "DenseNet", "NASNet", "LPIPS"] # Add names for `named_parameters`, and add each data format for each model test_parameters = [ @@ -264,3 +265,43 @@ def test_application_classifier_activation(self, app, *_): ) last_layer_act = model.layers[-1].activation.__name__ self.assertEqual(last_layer_act, "softmax") + + @parameterized.named_parameters( + [ + ( + "{}_{}".format(lpips.LPIPS.__name__, image_data_format), + image_data_format, + ) + for image_data_format in ["channels_first", "channels_last"] + ] + ) + def test_application_lpips(self, image_data_format): + self.skip_if_invalid_image_data_format_for_model( + lpips.LPIPS, image_data_format + ) + backend.set_image_data_format(image_data_format) + + model = lpips.LPIPS() + output_shape = list(model.outputs[0].shape) + + # Two images as input + self.assertEqual(len(model.input_shape), 2) + + # Single output + self.assertEqual(output_shape, [None]) + + # Can run a correct inference on a test image + if image_data_format == "channels_first": + shape = model.input_shape[0][2:4] + else: + shape = 
model.input_shape[0][1:3] + + x = _get_elephant(shape) + + x = lpips.preprocess_input(x) + y = x.copy() + + preds = model.predict([x, y]) + + # same image so lpips should be 0 + self.assertAllClose(preds, 0.0, atol=1e-5) diff --git a/keras/src/applications/lpips.py b/keras/src/applications/lpips.py new file mode 100644 index 000000000000..4519a746a78f --- /dev/null +++ b/keras/src/applications/lpips.py @@ -0,0 +1,189 @@ +from keras.src import backend +from keras.src import layers +from keras.src import ops +from keras.src.api_export import keras_export +from keras.src.applications import imagenet_utils +from keras.src.applications import vgg16 +from keras.src.models import Functional +from keras.src.utils import file_utils + +WEIGHTS_PATH = ( + "https://storage.googleapis.com/tensorflow/keras-applications/" + "lpips/lpips_vgg16_weights.h5" +) # TODO: store weights at this location + + +def vgg_backbone(layer_names): + """VGG backbone for LPIPS. + + Args: + layer_names: list of layer names to extract features from + + Returns: + Functional model with outputs at specified layers + """ + vgg = vgg16.VGG16(include_top=False, weights=None) + outputs = [ + layer.output for layer in vgg.layers if layer.name in layer_names + ] + return Functional(vgg.input, outputs) + + +def linear_model(channels): + """Get the linear head model for LPIPS. + Combines feature differences from VGG backbone. 
+ + Args: + channels: list of channel sizes for feature differences + + Returns: + Functional model + """ + inputs, outputs = [], [] + for ii, channel in enumerate(channels): + x = layers.Input(shape=(None, None, channel)) + y = layers.Dropout(rate=0.5)(x) + y = layers.Conv2D( + filters=1, + kernel_size=1, + use_bias=False, + name=f"linear_{ii}", + )(y) + inputs.append(x) + outputs.append(y) + + model = Functional(inputs=inputs, outputs=outputs, name="linear_model") + return model + + +@keras_export(["keras.applications.lpips.LPIPS", "keras.applications.LPIPS"]) +def LPIPS( + weights="imagenet", + input_tensor=None, + input_shape=None, + network_type="vgg", + name="lpips", +): + """Instantiates the LPIPS model. + + Reference: + - [The Unreasonable Effectiveness of Deep Features as a Perceptual Metric]( + https://arxiv.org/abs/1801.03924) + + Args: + weights: one of `None` (random initialization), + `"imagenet"` (pre-training on ImageNet), + or the path to the weights file to be loaded. + input_tensor: optional Keras tensor for model input + input_shape: optional shape tuple, defaults to (None, None, 3) + network_type: backbone network type (currently only 'vgg' supported) + name: model name string + + Returns: + A `Model` instance. + """ + if network_type != "vgg": + raise ValueError( + "Currently only VGG backbone is supported. " + f"Got network_type={network_type}" + ) + + if backend.image_data_format() == "channels_first": + raise ValueError( + "LPIPS does not support the `channels_first` image data " + "format. Switch to `channels_last` by editing your local " + "config file at ~/.keras/keras.json" + ) + + if not (weights in {"imagenet", None} or file_utils.exists(weights)): + raise ValueError( + "The `weights` argument should be either " + "`None` (random initialization), 'imagenet' " + "(pre-training on ImageNet), " + "or the path to the weights file to be loaded." 
+ ) + + # Define inputs + if input_tensor is None: + img_input1 = layers.Input( + shape=input_shape or (None, None, 3), name="input1" + ) + img_input2 = layers.Input( + shape=input_shape or (None, None, 3), name="input2" + ) + else: + if not backend.is_keras_tensor(input_tensor): + img_input1 = layers.Input(tensor=input_tensor, shape=input_shape) + img_input2 = layers.Input(tensor=input_tensor, shape=input_shape) + else: + img_input1 = input_tensor + img_input2 = input_tensor + + # VGG feature extraction + vgg_layers = [ + "block1_conv2", + "block2_conv2", + "block3_conv3", + "block4_conv3", + "block5_conv3", + ] + vgg_net = vgg_backbone(vgg_layers) + + feat1 = vgg_net(img_input1) + feat2 = vgg_net(img_input2) + + def normalize(x, eps: float = 1e-8): + return x * ops.rsqrt( + eps + ops.sum(ops.square(x), axis=-1, keepdims=True) + ) + + norm1 = [normalize(f) for f in feat1] + norm2 = [normalize(f) for f in feat2] + + diffs = [ops.square(t1 - t2) for t1, t2 in zip(norm1, norm2)] + + channels = [f.shape[-1] for f in feat1] + + linear_net = linear_model(channels) + + lin_out = linear_net(diffs) + + spatial_average = [ + ops.mean(t, axis=[1, 2], keepdims=False) for t in lin_out + ] + + # need a layer to convert list to tensor + output = layers.Lambda(lambda x: ops.convert_to_tensor(x))(spatial_average) + + output = ops.squeeze(ops.sum(output, axis=0), axis=-1) + + # Create model + model = Functional([img_input1, img_input2], output, name=name) + + # Load weights + if weights == "imagenet": + weights_path = file_utils.get_file( + "lpips_vgg16_weights.h5", + WEIGHTS_PATH, + cache_subdir="models", + file_hash=None, # TODO: add hash + ) + model.load_weights(weights_path) + elif weights is not None: + model.load_weights(weights) + + return model + + +@keras_export("keras.applications.lpips.preprocess_input") +def preprocess_input(x, data_format=None): + return imagenet_utils.preprocess_input( + x, data_format=data_format, mode="torch" + ) + + +preprocess_input.__doc__ = 
imagenet_utils.PREPROCESS_INPUT_DOC.format( + mode="", + ret=imagenet_utils.PREPROCESS_INPUT_RET_DOC_TORCH, + error=imagenet_utils.PREPROCESS_INPUT_ERROR_DOC, +)