Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Perceptual Similarity loss #20844

Open
wants to merge 4 commits into
base: master
Choose a base branch
from
Open

Conversation

tristan-deep
Copy link

@tristan-deep tristan-deep commented Feb 2, 2025

As described in #20839.

Conversion of weights partly based on an implementation found here.

Added

TODO

  • add to Metrics?
  • include tests
  • upload weights

I uploaded the LPIPS weights to Hugging Face, as I'm not sure how to upload to storage.googleapis.com.

Testing code snippet

Tests both the model and loss object and compares to the torch metrics implementation.

# !pip install torch torchmetrics
# !pip install scikit-image

import torch
from skimage import data
from torchmetrics.image.lpip import LearnedPerceptualImagePatchSimilarity

from keras import ops
from keras.src.applications import lpips
from keras.src.losses import PerceptualSimilarity

## dummy images
image1 = data.camera()
image2 = data.gravel()

expected_lpips_value = 0.7237163782119751
print(f"Expected LPIPS value: {expected_lpips_value}")

image1 = image1[None, ..., None].astype("float32")
image2 = image2[None, ..., None].astype("float32")

image1 = ops.repeat(image1, 3, axis=-1)
image2 = ops.repeat(image2, 3, axis=-1)

# placeholder for weights
# https://huggingface.co/tristan-deep/lpips/blob/main/lpips_vgg16.weights.h5
custom_weight_path = <add path to weights here> 

## Keras model
model = lpips.LPIPS(weights=custom_weight_path)

model.summary()
print("Model loaded")

image1_tensor = lpips.preprocess_input(image1)
image2_tensor = lpips.preprocess_input(image2)

score = model([image1_tensor, image2_tensor])
print(f"Keras LPIPS score: {score}")

## Loss
loss = PerceptualSimilarity(weights=custom_weight_path)
value = loss(image1, image2)
print(f"Keras loss value: {value}")


## torchmetrics check
def preprocess_input_for_torch(x):
    # Convert from [0, 255] to [-1, 1]
    x = x / 127.5 - 1
    # Convert to numpy array
    x = ops.convert_to_numpy(x)
    # Rearrange from NHWC to NCHW format
    x = x.transpose(0, 3, 1, 2)
    # Convert to torch tensor
    x = torch.tensor(x)
    return x


image1_torch = preprocess_input_for_torch(image1)
image2_torch = preprocess_input_for_torch(image2)

# needs to be in range [-1, 1]
lpips = LearnedPerceptualImagePatchSimilarity(
    net_type="vgg",
    normalize=False,
)
score = lpips(image1_torch, image2_torch)

print(f"Torch LPIPS score: {score}")

Copy link

google-cla bot commented Feb 2, 2025

Thanks for your pull request! It looks like this may be your first contribution to a Google open source project. Before we can look at your pull request, you'll need to sign a Contributor License Agreement (CLA).

View this failed invocation of the CLA check for more information.

For the most up to date status, view the checks section at the bottom of the pull request.

@codecov-commenter
Copy link

codecov-commenter commented Feb 2, 2025

Codecov Report

Attention: Patch coverage is 28.57143% with 60 lines in your changes missing coverage. Please review.

Project coverage is 84.76%. Comparing base (fc1b26d) to head (07771f1).
Report is 10 commits behind head on master.

Files with missing lines Patch % Lines
keras/src/applications/lpips.py 25.39% 47 Missing ⚠️
keras/src/losses/losses.py 25.00% 9 Missing ⚠️
keras/api/_tf_keras/keras/applications/__init__.py 0.00% 2 Missing ⚠️
...api/_tf_keras/keras/applications/lpips/__init__.py 0.00% 2 Missing ⚠️
Additional details and impacted files
@@            Coverage Diff             @@
##           master   #20844      +/-   ##
==========================================
+ Coverage   82.04%   84.76%   +2.72%     
==========================================
  Files         559      563       +4     
  Lines       52367    52590     +223     
  Branches     8096     8133      +37     
==========================================
+ Hits        42964    44580    +1616     
+ Misses       7427     5870    -1557     
- Partials     1976     2140     +164     
Flag Coverage Δ
keras 84.58% <28.57%> (+2.72%) ⬆️
keras-jax 66.75% <28.57%> (+2.49%) ⬆️
keras-numpy 59.06% <28.57%> (+0.07%) ⬆️
keras-openvino 32.68% <28.57%> (+2.86%) ⬆️
keras-tensorflow 67.39% <28.57%> (+2.57%) ⬆️
keras-torch 66.80% <28.57%> (+2.65%) ⬆️
keras.applications 82.01% <25.39%> (?)
keras.applications-jax 82.01% <25.39%> (?)
keras.applications-numpy 22.88% <25.39%> (?)
keras.applications-openvino 22.88% <25.39%> (?)
keras.applications-tensorflow 82.01% <25.39%> (?)
keras.applications-torch 81.75% <25.39%> (?)

Flags with carried forward coverage won't be shown. Click here to find out more.

☔ View full report in Codecov by Sentry.
📢 Have feedback on the report? Share it here.

from keras.src.models import Functional
from keras.src.utils import file_utils

WEIGHTS_PATH = (
Copy link
Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Correct weights path should still be added here. Currently I've put a placeholder. I uploaded the LPIPS weights to Hugging Face, as I'm not sure how to upload to storage.googleapis.com.

name="lpips",
dtype=None,
):
from keras.src.applications import lpips # lazy import
Copy link
Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

is there a better way to do this?

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Projects
None yet
Development

Successfully merging this pull request may close these issues.

3 participants