Skip to content

Commit 26b8d79

Browse files
sdenton4copybara-github
authored andcommitted
Pull TF table creation into perch preprocessing to allow removal of TF dependency in taxonomy library.
PiperOrigin-RevId: 724501733
1 parent 1a49cfb commit 26b8d79

File tree

1 file changed

+127
-8
lines changed

1 file changed

+127
-8
lines changed

chirp/preprocessing/pipeline.py

Lines changed: 127 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -34,6 +34,118 @@
3434
Features = dict[str, tf.Tensor]
3535

3636

37+
def get_class_map_tf_lookup(
38+
source_class_list: namespace.ClassList,
39+
target_class_list: namespace.ClassList,
40+
):
41+
"""Create a static hash map for class indices.
42+
43+
Create a lookup table for use in TF Datasets, for, eg, converting between
44+
ClassList defined for a dataset to a ClassList used as model outputs.
45+
Classes in the source ClassList which do not appear in the target_class_list
46+
will be mapped to -1. It is recommended to drop these labels subsequently
47+
with: tf.gather(x, tf.where(x >= 0)[:, 0])
48+
49+
Args:
50+
target_class_list: Class list to target.
51+
52+
Returns:
53+
A tensorflow StaticHashTable and an indicator vector for the image of
54+
the classlist mapping.
55+
"""
56+
if source_class_list.namespace != target_class_list.namespace:
57+
raise ValueError('namespaces must match when creating a class map.')
58+
intersection = set(source_class_list.classes) & set(target_class_list.classes)
59+
intersection = sorted(tuple(intersection))
60+
keys = tuple(source_class_list.classes.index(c) for c in intersection)
61+
values = tuple(target_class_list.classes.index(c) for c in intersection)
62+
63+
table = tf.lookup.StaticHashTable(
64+
tf.lookup.KeyValueTensorInitializer(keys, values, tf.int64, tf.int64),
65+
default_value=-1,
66+
)
67+
image_mask = tf.constant(
68+
[k in source_class_list.classes for k in target_class_list.classes],
69+
tf.int64,
70+
)
71+
return table, image_mask
72+
73+
74+
def get_namespace_map_tf_lookup(
75+
source_class_list: namespace.ClassList,
76+
mapping: namespace.Mapping,
77+
keep_unknown: bool | None = None,
78+
target_class_list: namespace.ClassList | None = None,
79+
) -> tf.lookup.StaticHashTable:
80+
"""Create a tf.lookup.StaticHasTable for namespace mappings.
81+
82+
Args:
83+
source_class_list: Class list to map from.
84+
mapping: Mapping to apply.
85+
keep_unknown: How to handle unknowns. If true, then unknown labels in the
86+
class list are maintained as unknown in the mapped values. If false then
87+
the unknown value is discarded. The default (`None`) will raise an error
88+
if an unknown value is in the source classt list.
89+
target_class_list: Optional class list for ordering of mapping output. If
90+
not provided, a class list consisting of the alphabetized image set of the
91+
mapping will be used.
92+
93+
Returns:
94+
A Tensorflow StaticHashTable and the image ClassList in the mapping's
95+
target namespace.
96+
97+
Raises:
98+
ValueError: If 'unknown' label is in source classes and keep_unknown was
99+
not specified.
100+
ValueError: If a target class list was passed and the namespace of this
101+
does not match the mapping target namespace.
102+
"""
103+
if (
104+
namespace.UNKNOWN_LABEL in source_class_list.classes
105+
and keep_unknown is None
106+
):
107+
raise ValueError(
108+
"'unknown' found in source classes. Explicitly set keep_unknown to"
109+
" True or False. Alternatively, remove 'unknown' from source classes"
110+
)
111+
# If no target_class_list is passed, default to apply_namespace_mapping
112+
if target_class_list is None:
113+
target_class_list = source_class_list.apply_namespace_mapping(
114+
mapping, keep_unknown=keep_unknown
115+
)
116+
else:
117+
if target_class_list.namespace != mapping.target_namespace:
118+
raise ValueError(
119+
f'target class list namespace ({target_class_list.namespace}) '
120+
'does not match mapping target namespace '
121+
f'({mapping.target_namespace})'
122+
)
123+
# Now check if 'unknown' label present in target_class_list.classes
124+
keep_unknown = (
125+
keep_unknown and namespace.UNKNOWN_LABEL in target_class_list.classes
126+
)
127+
# Dict which maps classes to an index
128+
target_class_indices = {k: i for i, k in enumerate(target_class_list.classes)}
129+
# Add unknown to mapped pairs
130+
mapped_pairs = mapping.mapped_pairs | {
131+
namespace.UNKNOWN_LABEL: namespace.UNKNOWN_LABEL
132+
}
133+
# If keep unknown==False, set unknown index to -1 to discard unknowns
134+
if not keep_unknown:
135+
target_class_indices[namespace.UNKNOWN_LABEL] = -1
136+
# Get keys and values to be used in the lookup table
137+
keys = list(range(len(source_class_list.classes)))
138+
values = [
139+
target_class_indices[mapped_pairs[k]] for k in source_class_list.classes
140+
]
141+
# Create the static hash table. If a key doesnt exist, set as -1.
142+
table = tf.lookup.StaticHashTable(
143+
tf.lookup.KeyValueTensorInitializer(keys, values, tf.int64, tf.int64),
144+
default_value=-1,
145+
)
146+
return table
147+
148+
37149
class FeaturesPreprocessOp:
38150
"""Preprocessing op which applies changes to specific features."""
39151

@@ -843,10 +955,15 @@ def load_tables(
843955
self, source_classes: namespace.ClassList
844956
) -> Tuple[tf.lookup.StaticHashTable, tf.Tensor]:
845957
"""Return a TensorFlow lookup table and a mask from source classes."""
958+
if self.db is None:
959+
raise ValueError('Database not loaded.')
846960
mapping = self.db.mappings['reef_class_to_soundtype']
847961
target_classes = self.db.class_lists[self.target_class_list]
848-
soundtype_table = source_classes.get_namespace_map_tf_lookup(
849-
mapping, target_class_list=target_classes, keep_unknown=True
962+
soundtype_table = get_namespace_map_tf_lookup(
963+
source_classes,
964+
mapping,
965+
target_class_list=target_classes,
966+
keep_unknown=True,
850967
)
851968
# Mask is all 1's. So everything multiplied by 1. Add 0's for a real mask.
852969
mask = tf.ones([len(target_classes.classes)])
@@ -934,10 +1051,12 @@ def load_tables(
9341051
"""
9351052
tables = {}
9361053
masks = {}
1054+
if self.db is None:
1055+
raise ValueError('Database not loaded.')
9371056
target_classes = self.db.class_lists[self.target_class_list]
9381057

939-
label_table, label_mask = source_class_list.get_class_map_tf_lookup(
940-
target_classes
1058+
label_table, label_mask = get_class_map_tf_lookup(
1059+
source_class_list, target_classes
9411060
)
9421061
tables[self.species_feature_name] = label_table
9431062
masks[self.species_feature_name] = label_mask
@@ -962,11 +1081,11 @@ def load_tables(
9621081
target_taxa_classes = target_classes.apply_namespace_mapping(
9631082
namespace_mapping, keep_unknown=True
9641083
)
965-
namespace_table = source_class_list.get_namespace_map_tf_lookup(
966-
namespace_mapping, keep_unknown=True
1084+
namespace_table = get_namespace_map_tf_lookup(
1085+
source_class_list, namespace_mapping, keep_unknown=True
9671086
)
968-
class_table, label_mask = source_taxa_classes.get_class_map_tf_lookup(
969-
target_taxa_classes
1087+
class_table, label_mask = get_class_map_tf_lookup(
1088+
source_taxa_classes, target_taxa_classes
9701089
)
9711090
tables[key + '_namespace'] = namespace_table
9721091
tables[key + '_class'] = class_table

0 commit comments

Comments
 (0)