Skip to content

Commit 22cf7d7

Browse files
committed
LPIPS loss
1 parent fc1b26d commit 22cf7d7

File tree

3 files changed

+260
-0
lines changed

3 files changed

+260
-0
lines changed

keras/src/applications/lpips.py

+188
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,188 @@
1+
from keras.src import backend
2+
from keras.src import layers
3+
from keras.src import ops
4+
from keras.src.api_export import keras_export
5+
from keras.src.applications import imagenet_utils
6+
from keras.src.applications import vgg16
7+
from keras.src.models import Functional
8+
from keras.src.utils import file_utils
9+
10+
WEIGHTS_PATH = (
11+
"https://storage.googleapis.com/tensorflow/keras-applications/"
12+
"lpips/lpips_vgg16_weights.h5"
13+
)
14+
15+
16+
def vgg_backbone(layer_names):
17+
"""VGG backbone for LPIPS.
18+
19+
Args:
20+
layer_names: list of layer names to extract features from
21+
22+
Returns:
23+
Functional model with outputs at specified layers
24+
"""
25+
vgg = vgg16.VGG16(include_top=False, weights=None)
26+
outputs = [
27+
layer.output for layer in vgg.layers if layer.name in layer_names
28+
]
29+
return Functional(vgg.input, outputs)
30+
31+
32+
def linear_model(channels):
33+
"""Get the linear head model for LPIPS.
34+
Combines feature differences from VGG backbone.
35+
36+
Args:
37+
channels: list of channel sizes for feature differences
38+
39+
Returns:
40+
Functional model
41+
"""
42+
inputs, outputs = [], []
43+
for ii, channel in enumerate(channels):
44+
x = layers.Input(shape=(None, None, channel))
45+
y = layers.Dropout(rate=0.5)(x)
46+
y = layers.Conv2D(
47+
filters=1,
48+
kernel_size=1,
49+
use_bias=False,
50+
name=f"linear_{ii}",
51+
)(y)
52+
inputs.append(x)
53+
outputs.append(y)
54+
55+
model = Functional(inputs=inputs, outputs=outputs, name="linear_model")
56+
return model
57+
58+
59+
@keras_export(["keras.applications.lpips.LPIPS", "keras.applications.LPIPS"])
60+
def LPIPS(
61+
weights="imagenet",
62+
input_tensor=None,
63+
input_shape=None,
64+
network_type="vgg",
65+
name="lpips",
66+
):
67+
"""Instantiates the LPIPS model.
68+
69+
Reference:
70+
- [The Unreasonable Effectiveness of Deep Features as a Perceptual Metric](
71+
https://arxiv.org/abs/1801.03924)
72+
73+
Args:
74+
weights: one of `None` (random initialization),
75+
`"imagenet"` (pre-training on ImageNet),
76+
or the path to the weights file to be loaded.
77+
input_tensor: optional Keras tensor for model input
78+
input_shape: optional shape tuple, defaults to (None, None, 3)
79+
network_type: backbone network type (currently only 'vgg' supported)
80+
name: model name string
81+
82+
Returns:
83+
A `Model` instance.
84+
"""
85+
if network_type != "vgg":
86+
raise ValueError(
87+
"Currently only VGG backbone is supported. "
88+
f"Got network_type={network_type}"
89+
)
90+
91+
if not (weights in {"imagenet", None} or file_utils.exists(weights)):
92+
raise ValueError(
93+
"The `weights` argument should be either "
94+
"`None` (random initialization), 'imagenet' "
95+
"(pre-training on ImageNet), "
96+
"or the path to the weights file to be loaded."
97+
)
98+
99+
# Define inputs
100+
if input_tensor is None:
101+
img_input1 = layers.Input(
102+
shape=input_shape or (None, None, 3), name="input1"
103+
)
104+
img_input2 = layers.Input(
105+
shape=input_shape or (None, None, 3), name="input2"
106+
)
107+
else:
108+
if not backend.is_keras_tensor(input_tensor):
109+
img_input1 = layers.Input(tensor=input_tensor, shape=input_shape)
110+
img_input2 = layers.Input(tensor=input_tensor, shape=input_shape)
111+
else:
112+
img_input1 = input_tensor
113+
img_input2 = input_tensor
114+
115+
# VGG feature extraction
116+
vgg_layers = [
117+
"block1_conv2",
118+
"block2_conv2",
119+
"block3_conv3",
120+
"block4_conv3",
121+
"block5_conv3",
122+
]
123+
vgg_net = vgg_backbone(vgg_layers)
124+
125+
# Process inputs
126+
feat1 = vgg_net(img_input1)
127+
feat2 = vgg_net(img_input2)
128+
129+
# Normalize features
130+
def normalize(x):
131+
return x * ops.rsqrt(ops.sum(ops.square(x), axis=-1, keepdims=True))
132+
133+
norm1 = [layers.Lambda(normalize)(f) for f in feat1]
134+
norm2 = [layers.Lambda(normalize)(f) for f in feat2]
135+
136+
# Feature differences
137+
diffs = [
138+
layers.Lambda(lambda x: ops.square(x[0] - x[1]))([n1, n2])
139+
for n1, n2 in zip(norm1, norm2)
140+
]
141+
142+
# Get shapes for linear model
143+
channels = [f.shape[-1] for f in feat1]
144+
145+
linear_net = linear_model(channels)
146+
147+
lin_out = linear_net(diffs)
148+
149+
spatial_average = [
150+
layers.Lambda(lambda x: ops.mean(x, axis=[1, 2]))(t) for t in lin_out
151+
]
152+
153+
output = layers.Lambda(
154+
lambda x: ops.squeeze(
155+
ops.sum(backend.convert_to_tensor(x), axis=0), axis=-1
156+
)
157+
)(spatial_average)
158+
159+
# Create model
160+
model = Functional([img_input1, img_input2], output, name=name)
161+
162+
# Load weights
163+
if weights == "imagenet":
164+
weights_path = file_utils.get_file(
165+
"lpips_vgg16_weights.h5",
166+
WEIGHTS_PATH,
167+
cache_subdir="models",
168+
file_hash=None, # TODO: add hash
169+
)
170+
model.load_weights(weights_path)
171+
elif weights is not None:
172+
model.load_weights(weights)
173+
174+
return model
175+
176+
177+
@keras_export("keras.applications.lpips.preprocess_input")
178+
def preprocess_input(x, data_format=None):
179+
return imagenet_utils.preprocess_input(
180+
x, data_format=data_format, mode="torch"
181+
)
182+
183+
184+
preprocess_input.__doc__ = imagenet_utils.PREPROCESS_INPUT_DOC.format(
185+
mode="",
186+
ret=imagenet_utils.PREPROCESS_INPUT_RET_DOC_CAFFE,
187+
error=imagenet_utils.PREPROCESS_INPUT_ERROR_DOC,
188+
)

keras/src/losses/__init__.py

+3
Original file line numberDiff line numberDiff line change
@@ -20,6 +20,7 @@
2020
from keras.src.losses.losses import MeanAbsolutePercentageError
2121
from keras.src.losses.losses import MeanSquaredError
2222
from keras.src.losses.losses import MeanSquaredLogarithmicError
23+
from keras.src.losses.losses import PerceptualSimilarity
2324
from keras.src.losses.losses import Poisson
2425
from keras.src.losses.losses import SparseCategoricalCrossentropy
2526
from keras.src.losses.losses import SquaredHinge
@@ -76,6 +77,8 @@
7677
Tversky,
7778
# Similarity
7879
Circle,
80+
# Feature Extraction
81+
PerceptualSimilarity,
7982
# Sequence
8083
CTC,
8184
# Probabilistic

keras/src/losses/losses.py

+69
Original file line numberDiff line numberDiff line change
@@ -1504,6 +1504,75 @@ def get_config(self):
15041504
return config
15051505

15061506

1507+
class PerceptualSimilarity(LossFunctionWrapper):
1508+
"""Computes the Learned Perceptual Image Patch Similarity (LPIPS) loss.
1509+
1510+
Reference:
1511+
- [The Unreasonable Effectiveness of Deep Features as a Perceptual Metric](
1512+
https://arxiv.org/abs/1801.03924)
1513+
1514+
LPIPS measures perceptual similarity between images by comparing deep
1515+
features, which is a more perceptually-aligned metric compared to
1516+
pixel-wise losses.
1517+
1518+
Args:
1519+
weights: one of `None` (random initialization),
1520+
`"imagenet"` (pre-training on ImageNet),
1521+
or the path to the weights file to be loaded.
1522+
network_type: backbone network type (currently only 'vgg' supported)
1523+
preprocess_inputs: Whether to preprocess inputs using the same
1524+
preprocessing function as the original LPIPS implementation.
1525+
Defaults to `True`. If set to `False`, the inputs are expected
1526+
to be normalized to the range [-1, 1] and in RGB format.
1527+
If set to `True`, the inputs are expected to to be in the
1528+
range [0, 255], after which they will be standardized to
1529+
imageNet mean and std.
1530+
reduction: Type of reduction to apply to the loss. In almost all cases
1531+
this should be `"sum_over_batch_size"`. Supported options are
1532+
`"sum"`, `"sum_over_batch_size"`, `"mean"`,
1533+
`"mean_with_sample_weight"` or `None`. `"sum"` sums the loss,
1534+
`"sum_over_batch_size"` and `"mean"` sum the loss and divide by the
1535+
sample size, and `"mean_with_sample_weight"` sums the loss and
1536+
divides by the sum of the sample weights. `"none"` and `None`
1537+
perform no aggregation. Defaults to `"sum_over_batch_size"`.
1538+
name: Optional name for the loss instance.
1539+
dtype: The dtype of the loss's computations. Defaults to `None`, which
1540+
means using `keras.backend.floatx()`. `keras.backend.floatx()` is a
1541+
`"float32"` unless set to different value
1542+
(via `keras.backend.set_floatx()`). If a `keras.DTypePolicy` is
1543+
provided, then the `compute_dtype` will be utilized.
1544+
"""
1545+
1546+
def __init__(
1547+
self,
1548+
weights="imagenet",
1549+
network_type="vgg",
1550+
preprocess_inputs=True,
1551+
reduction="sum_over_batch_size",
1552+
name="lpips",
1553+
dtype=None,
1554+
):
1555+
from keras.src.applications import lpips # lazy import
1556+
1557+
lpips_model = lpips.LPIPS(weights=weights, network_type=network_type)
1558+
1559+
def _lpips_wrapper(y_true, y_pred):
1560+
if preprocess_inputs:
1561+
y_true = lpips.preprocess_input(y_true)
1562+
y_pred = lpips.preprocess_input(y_pred)
1563+
return lpips_model([y_true, y_pred])
1564+
1565+
super().__init__(
1566+
_lpips_wrapper,
1567+
name=name,
1568+
reduction=reduction,
1569+
dtype=dtype,
1570+
)
1571+
1572+
def get_config(self):
1573+
return Loss.get_config(self)
1574+
1575+
15071576
def convert_binary_labels_to_hinge(y_true):
15081577
"""Converts binary labels into -1/1 for hinge loss/metric calculation."""
15091578
are_zeros = ops.equal(y_true, 0)

0 commit comments

Comments
 (0)