Skip to content

Commit a003b7c

Browse files
DELF updates (#9095)
* Merged commit includes the following changes: 326369548 by Andre Araujo: Fix import issues. -- 326159826 by Andre Araujo: Changed the implementation of the cosine weights from Keras layer to tf.Variable to manually control for L2 normalization. -- 326139082 by Andre Araujo: Support local feature matching using ratio test. To allow for easily choosing which matching type to use, we rename a flag/argument and modify all related files to avoid breakages. Also include a small change when computing nearest neighbors for geometric matching, to parallelize computation, which saves a little bit of time during execution (argument "n_jobs=-1"). -- 326119848 by Andre Araujo: Option to measure DELG latency taking binarization into account. -- 324316608 by Andre Araujo: DELG global features training. -- 323693131 by Andre Araujo: PY3 conversion for delf public lib. -- 321046157 by Andre Araujo: Purely Google refactor -- PiperOrigin-RevId: 326369548 * Added export of delg_model module. Co-authored-by: Andre Araujo <[email protected]>
1 parent b4c4a53 commit a003b7c

File tree

14 files changed

+406
-49
lines changed

14 files changed

+406
-49
lines changed

research/delf/delf/python/delg/extract_features.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,4 @@
1+
# Lint as: python3
12
# Copyright 2020 The TensorFlow Authors All Rights Reserved.
23
#
34
# Licensed under the Apache License, Version 2.0 (the "License");

research/delf/delf/python/delg/measure_latency.py

Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -42,6 +42,11 @@
4242
'Path to list of images whose features will be extracted.')
4343
flags.DEFINE_integer('repeat_per_image', 10,
4444
'Number of times to repeat extraction per image.')
45+
flags.DEFINE_boolean(
46+
'binary_local_features', False,
47+
'Whether to binarize local features after extraction, and take this extra '
48+
'latency into account. This should only be used if use_local_features is '
49+
'set in the input DelfConfig from `delf_config_path`.')
4550

4651
# Pace to report extraction log.
4752
_STATUS_CHECK_ITERATIONS = 100
@@ -103,6 +108,12 @@ def main(argv):
103108
# Extract and save features.
104109
extracted_features = extractor_fn(im)
105110

111+
# Binarize local features, if desired (and if there are local features).
112+
if (config.use_local_features and FLAGS.binary_local_features and
113+
extracted_features['local_features']['attention'].size):
114+
packed_descriptors = np.packbits(
115+
extracted_features['local_features']['descriptors'] > 0, axis=1)
116+
106117

107118
if __name__ == '__main__':
108119
app.run(main)

research/delf/delf/python/delg/perform_retrieval.py

Lines changed: 15 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,4 @@
1+
# Lint as: python3
12
# Copyright 2020 The TensorFlow Authors All Rights Reserved.
23
#
34
# Licensed under the Apache License, Version 2.0 (the "License");
@@ -44,15 +45,19 @@
4445
'If True, performs re-ranking using local feature-based geometric '
4546
'verification.')
4647
flags.DEFINE_float(
47-
'local_feature_distance_threshold', 1.0,
48+
'local_descriptor_matching_threshold', 1.0,
4849
'Optional, only used if `use_geometric_verification` is True. '
49-
'Distance threshold below which a pair of local descriptors is considered '
50+
'Threshold below which a pair of local descriptors is considered '
5051
'a potential match, and will be fed into RANSAC.')
5152
flags.DEFINE_float(
5253
'ransac_residual_threshold', 20.0,
5354
'Optional, only used if `use_geometric_verification` is True. '
5455
'Residual error threshold for considering matches as inliers, used in '
5556
'RANSAC algorithm.')
57+
flags.DEFINE_boolean(
58+
'use_ratio_test', False,
59+
'Optional, only used if `use_geometric_verification` is True. '
60+
'Whether to use ratio test for local feature matching.')
5661
flags.DEFINE_string(
5762
'output_dir', '/tmp/retrieval',
5863
'Directory where retrieval output will be written to. A file containing '
@@ -152,8 +157,10 @@ def main(argv):
152157
junk_ids=set(medium_ground_truth[i]['junk']),
153158
local_feature_extension=_DELG_LOCAL_EXTENSION,
154159
ransac_seed=0,
155-
feature_distance_threshold=FLAGS.local_feature_distance_threshold,
156-
ransac_residual_threshold=FLAGS.ransac_residual_threshold)
160+
descriptor_matching_threshold=FLAGS
161+
.local_descriptor_matching_threshold,
162+
ransac_residual_threshold=FLAGS.ransac_residual_threshold,
163+
use_ratio_test=FLAGS.use_ratio_test)
157164
hard_ranks_after_gv[i] = image_reranking.RerankByGeometricVerification(
158165
input_ranks=ranks_before_gv[i],
159166
initial_scores=similarities,
@@ -164,8 +171,10 @@ def main(argv):
164171
junk_ids=set(hard_ground_truth[i]['junk']),
165172
local_feature_extension=_DELG_LOCAL_EXTENSION,
166173
ransac_seed=0,
167-
feature_distance_threshold=FLAGS.local_feature_distance_threshold,
168-
ransac_residual_threshold=FLAGS.ransac_residual_threshold)
174+
descriptor_matching_threshold=FLAGS
175+
.local_descriptor_matching_threshold,
176+
ransac_residual_threshold=FLAGS.ransac_residual_threshold,
177+
use_ratio_test=FLAGS.use_ratio_test)
169178

170179
elapsed = (time.time() - start)
171180
print('done! Retrieval for query %d took %f seconds' % (i, elapsed))

research/delf/delf/python/detect_to_retrieve/cluster_delf_features.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,4 @@
1+
# Lint as: python3
12
# Copyright 2019 The TensorFlow Authors All Rights Reserved.
23
#
34
# Licensed under the Apache License, Version 2.0 (the "License");

research/delf/delf/python/detect_to_retrieve/extract_query_features.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,4 @@
1+
# Lint as: python3
12
# Copyright 2017 The TensorFlow Authors All Rights Reserved.
23
#
34
# Licensed under the Apache License, Version 2.0 (the "License");

research/delf/delf/python/detect_to_retrieve/image_reranking.py

Lines changed: 49 additions & 25 deletions
Original file line numberDiff line numberDiff line change
@@ -47,12 +47,13 @@ def MatchFeatures(query_locations,
4747
index_image_locations,
4848
index_image_descriptors,
4949
ransac_seed=None,
50-
feature_distance_threshold=0.9,
50+
descriptor_matching_threshold=0.9,
5151
ransac_residual_threshold=10.0,
5252
query_im_array=None,
5353
index_im_array=None,
5454
query_im_scale_factors=None,
55-
index_im_scale_factors=None):
55+
index_im_scale_factors=None,
56+
use_ratio_test=False):
5657
"""Matches local features using geometric verification.
5758
5859
First, finds putative local feature matches by matching `query_descriptors`
@@ -70,8 +71,10 @@ def MatchFeatures(query_locations,
7071
index_image_descriptors: Descriptors of local features for index image.
7172
NumPy array of shape [#index_image_features, depth].
7273
ransac_seed: Seed used by RANSAC. If None (default), no seed is provided.
73-
feature_distance_threshold: Distance threshold below which a pair of
74-
features is considered a potential match, and will be fed into RANSAC.
74+
descriptor_matching_threshold: Threshold below which a pair of local
75+
descriptors is considered a potential match, and will be fed into RANSAC.
76+
If use_ratio_test==False, this is a simple distance threshold. If
77+
use_ratio_test==True, this is Lowe's ratio test threshold.
7578
ransac_residual_threshold: Residual error threshold for considering matches
7679
as inliers, used in RANSAC algorithm.
7780
query_im_array: Optional. If not None, contains a NumPy array with the query
@@ -83,6 +86,8 @@ def MatchFeatures(query_locations,
8386
(ie, feature locations are not scaled).
8487
index_im_scale_factors: Optional. Same as `query_im_scale_factors`, but for
8588
index image.
89+
use_ratio_test: If True, descriptor matching is performed via ratio test,
90+
instead of distance-based threshold.
8691
8792
Returns:
8893
score: Number of inliers of match. If no match is found, returns 0.
@@ -105,22 +110,38 @@ def MatchFeatures(query_locations,
105110
'Local feature dimensionality is not consistent for query and index '
106111
'images.')
107112

108-
# Find nearest-neighbor matches using a KD tree.
113+
# Construct KD-tree used to find nearest neighbors.
109114
index_image_tree = spatial.cKDTree(index_image_descriptors)
110-
_, indices = index_image_tree.query(
111-
query_descriptors, distance_upper_bound=feature_distance_threshold)
112-
113-
# Select feature locations for putative matches.
114-
query_locations_to_use = np.array([
115-
query_locations[i,]
116-
for i in range(num_features_query)
117-
if indices[i] != num_features_index_image
118-
])
119-
index_image_locations_to_use = np.array([
120-
index_image_locations[indices[i],]
121-
for i in range(num_features_query)
122-
if indices[i] != num_features_index_image
123-
])
115+
if use_ratio_test:
116+
distances, indices = index_image_tree.query(
117+
query_descriptors, k=2, n_jobs=-1)
118+
query_locations_to_use = np.array([
119+
query_locations[i,]
120+
for i in range(num_features_query)
121+
if distances[i][0] < descriptor_matching_threshold * distances[i][1]
122+
])
123+
index_image_locations_to_use = np.array([
124+
index_image_locations[indices[i][0],]
125+
for i in range(num_features_query)
126+
if distances[i][0] < descriptor_matching_threshold * distances[i][1]
127+
])
128+
else:
129+
_, indices = index_image_tree.query(
130+
query_descriptors,
131+
distance_upper_bound=descriptor_matching_threshold,
132+
n_jobs=-1)
133+
134+
# Select feature locations for putative matches.
135+
query_locations_to_use = np.array([
136+
query_locations[i,]
137+
for i in range(num_features_query)
138+
if indices[i] != num_features_index_image
139+
])
140+
index_image_locations_to_use = np.array([
141+
index_image_locations[indices[i],]
142+
for i in range(num_features_query)
143+
if indices[i] != num_features_index_image
144+
])
124145

125146
# If there are not enough putative matches, early return 0.
126147
if query_locations_to_use.shape[0] <= _MIN_RANSAC_SAMPLES:
@@ -175,8 +196,9 @@ def RerankByGeometricVerification(input_ranks,
175196
junk_ids,
176197
local_feature_extension=_DELF_EXTENSION,
177198
ransac_seed=None,
178-
feature_distance_threshold=0.9,
179-
ransac_residual_threshold=10.0):
199+
descriptor_matching_threshold=0.9,
200+
ransac_residual_threshold=10.0,
201+
use_ratio_test=False):
180202
"""Re-ranks retrieval results using geometric verification.
181203
182204
Args:
@@ -195,10 +217,11 @@ def RerankByGeometricVerification(input_ranks,
195217
local_feature_extension: String, extension to use for loading local feature
196218
files.
197219
ransac_seed: Seed used by RANSAC. If None (default), no seed is provided.
198-
feature_distance_threshold: Distance threshold below which a pair of local
199-
features is considered a potential match, and will be fed into RANSAC.
220+
descriptor_matching_threshold: Threshold used for local descriptor matching.
200221
ransac_residual_threshold: Residual error threshold for considering matches
201222
as inliers, used in RANSAC algorithm.
223+
use_ratio_test: If True, descriptor matching is performed via ratio test,
224+
instead of distance-based threshold.
202225
203226
Returns:
204227
output_ranks: 1D NumPy array with index image indices, sorted from the most
@@ -258,8 +281,9 @@ def RerankByGeometricVerification(input_ranks,
258281
index_image_locations,
259282
index_image_descriptors,
260283
ransac_seed=ransac_seed,
261-
feature_distance_threshold=feature_distance_threshold,
262-
ransac_residual_threshold=ransac_residual_threshold)
284+
descriptor_matching_threshold=descriptor_matching_threshold,
285+
ransac_residual_threshold=ransac_residual_threshold,
286+
use_ratio_test=use_ratio_test)
263287

264288
# Sort based on (inliers_score, initial_score).
265289
def _InliersInitialScoresSorting(k):

research/delf/delf/python/detect_to_retrieve/perform_retrieval.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,4 @@
1+
# Lint as: python3
12
# Copyright 2019 The TensorFlow Authors All Rights Reserved.
23
#
34
# Licensed under the Apache License, Version 2.0 (the "License");

research/delf/delf/python/examples/match_images.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,4 @@
1+
# Lint as: python3
12
# Copyright 2017 The TensorFlow Authors All Rights Reserved.
23
#
34
# Licensed under the Apache License, Version 2.0 (the "License");

research/delf/delf/python/training/model/__init__.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -19,6 +19,7 @@
1919

2020
# pylint: disable=unused-import
2121
from delf.python.training.model import delf_model
22+
from delf.python.training.model import delg_model
2223
from delf.python.training.model import export_model_utils
2324
from delf.python.training.model import resnet50
2425
# pylint: enable=unused-import

research/delf/delf/python/training/model/delf_model.py

Lines changed: 25 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -89,12 +89,20 @@ class Delf(tf.keras.Model):
8989
from conv_4 are used to compute an attention map of the same resolution.
9090
"""
9191

92-
def __init__(self, block3_strides=True, name='DELF'):
92+
def __init__(self, block3_strides=True, name='DELF', pooling='avg',
93+
gem_power=3.0, embedding_layer=False, embedding_layer_dim=2048):
9394
"""Initialization of DELF model.
9495
9596
Args:
9697
block3_strides: bool, whether to add strides to the output of block3.
9798
name: str, name to identify model.
99+
pooling: str, pooling mode for global feature extraction; possible values
100+
are 'None', 'avg', 'max', 'gem.'
101+
gem_power: float, GeM power for GeM pooling. Only used if
102+
pooling == 'gem'.
103+
embedding_layer: bool, whether to create an embedding layer (FC whitening
104+
layer).
105+
embedding_layer_dim: int, size of the embedding layer.
98106
"""
99107
super(Delf, self).__init__(name=name)
100108

@@ -103,31 +111,38 @@ def __init__(self, block3_strides=True, name='DELF'):
103111
'channels_last',
104112
name='backbone',
105113
include_top=False,
106-
pooling='avg',
114+
pooling=pooling,
107115
block3_strides=block3_strides,
108-
average_pooling=False)
116+
average_pooling=False,
117+
gem_power=gem_power,
118+
embedding_layer=embedding_layer,
119+
embedding_layer_dim=embedding_layer_dim)
109120

110121
# Attention model.
111122
self.attention = AttentionModel(name='attention')
112123

113-
# Define classifiers for training backbone and attention models.
114-
def init_classifiers(self, num_classes):
124+
def init_classifiers(self, num_classes, desc_classification=None):
125+
"""Define classifiers for training backbone and attention models."""
115126
self.num_classes = num_classes
116-
self.desc_classification = layers.Dense(
117-
num_classes, activation=None, kernel_regularizer=None, name='desc_fc')
118-
127+
if desc_classification is None:
128+
self.desc_classification = layers.Dense(num_classes,
129+
activation=None,
130+
kernel_regularizer=None,
131+
name='desc_fc')
132+
else:
133+
self.desc_classification = desc_classification
119134
self.attn_classification = layers.Dense(
120135
num_classes, activation=None, kernel_regularizer=None, name='att_fc')
121136

122-
# Weights to optimize for descriptor fine tuning.
123137
@property
124138
def desc_trainable_weights(self):
139+
"""Weights to optimize for descriptor fine tuning."""
125140
return (self.backbone.trainable_weights +
126141
self.desc_classification.trainable_weights)
127142

128-
# Weights to optimize for attention model training.
129143
@property
130144
def attn_trainable_weights(self):
145+
"""Weights to optimize for attention model training."""
131146
return (self.attention.trainable_weights +
132147
self.attn_classification.trainable_weights)
133148

0 commit comments

Comments
 (0)