Skip to content

Commit

Permalink
Pull TF table creation into perch preprocessing to allow removal of T…
Browse files Browse the repository at this point in the history
…F dependency in taxonomy library.

PiperOrigin-RevId: 724501733
  • Loading branch information
sdenton4 authored and copybara-github committed Feb 8, 2025
1 parent 1a49cfb commit 26b8d79
Showing 1 changed file with 127 additions and 8 deletions.
135 changes: 127 additions & 8 deletions chirp/preprocessing/pipeline.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,6 +34,118 @@
Features = dict[str, tf.Tensor]


def _first_occurrence_index(classes) -> dict:
  """Map each class name to the index of its first occurrence in `classes`."""
  index = {}
  for position, name in enumerate(classes):
    # setdefault keeps the earliest index, matching tuple.index semantics.
    index.setdefault(name, position)
  return index


def get_class_map_tf_lookup(
    source_class_list: namespace.ClassList,
    target_class_list: namespace.ClassList,
) -> tuple[tf.lookup.StaticHashTable, tf.Tensor]:
  """Create a static hash map for class indices.

  Create a lookup table for use in TF Datasets, for, eg, converting between
  ClassList defined for a dataset to a ClassList used as model outputs.
  Classes in the source ClassList which do not appear in the target_class_list
  will be mapped to -1. It is recommended to drop these labels subsequently
  with: tf.gather(x, tf.where(x >= 0)[:, 0])

  Args:
    source_class_list: Class list whose indices are the lookup keys.
    target_class_list: Class list to target; its indices are the lookup
      values.

  Returns:
    A tensorflow StaticHashTable mapping source indices to target indices, and
    an int64 indicator vector with one entry per target class, marking the
    target classes that also appear in the source class list (the image of the
    classlist mapping).

  Raises:
    ValueError: If the two class lists are not in the same namespace.
  """
  if source_class_list.namespace != target_class_list.namespace:
    raise ValueError('namespaces must match when creating a class map.')

  # Build the name -> index maps once so every lookup below is O(1) instead of
  # a linear scan (`.index` / `in`) over the class tuple per element.
  source_index = _first_occurrence_index(source_class_list.classes)
  target_index = _first_occurrence_index(target_class_list.classes)

  # Sorted so the initializer contents are deterministic across runs.
  intersection = sorted(source_index.keys() & target_index.keys())
  keys = tuple(source_index[c] for c in intersection)
  values = tuple(target_index[c] for c in intersection)

  table = tf.lookup.StaticHashTable(
      tf.lookup.KeyValueTensorInitializer(keys, values, tf.int64, tf.int64),
      default_value=-1,
  )
  image_mask = tf.constant(
      [k in source_index for k in target_class_list.classes],
      tf.int64,
  )
  return table, image_mask


def get_namespace_map_tf_lookup(
    source_class_list: namespace.ClassList,
    mapping: namespace.Mapping,
    keep_unknown: bool | None = None,
    target_class_list: namespace.ClassList | None = None,
) -> tf.lookup.StaticHashTable:
  """Create a tf.lookup.StaticHashTable for namespace mappings.

  The table is keyed by the index of each class in `source_class_list` and
  yields the index of the mapped class in the target class list.

  Args:
    source_class_list: Class list to map from.
    mapping: Mapping to apply.
    keep_unknown: How to handle unknowns. If true, then unknown labels in the
      class list are maintained as unknown in the mapped values. If false then
      the unknown value is discarded (mapped to -1). The default (`None`) will
      raise an error if an unknown value is in the source class list.
    target_class_list: Optional class list for ordering of mapping output. If
      not provided, the result of
      `source_class_list.apply_namespace_mapping(mapping, ...)` is used.

  Returns:
    A Tensorflow StaticHashTable from source class indices to target class
    indices, with -1 as the default value for missing keys.

  Raises:
    ValueError: If 'unknown' label is in source classes and keep_unknown was
      not specified.
    ValueError: If a target class list was passed and the namespace of this
      does not match the mapping target namespace.
  """
  unknown = namespace.UNKNOWN_LABEL

  # Refuse to guess what to do with 'unknown' when the caller did not say.
  if keep_unknown is None and unknown in source_class_list.classes:
    raise ValueError(
        "'unknown' found in source classes. Explicitly set keep_unknown to"
        " True or False. Alternatively, remove 'unknown' from source classes"
    )

  if target_class_list is None:
    # No explicit target ordering: derive it from the mapping itself.
    target_class_list = source_class_list.apply_namespace_mapping(
        mapping, keep_unknown=keep_unknown
    )
  else:
    if target_class_list.namespace != mapping.target_namespace:
      raise ValueError(
          f'target class list namespace ({target_class_list.namespace}) '
          'does not match mapping target namespace '
          f'({mapping.target_namespace})'
      )
    # Unknown can only be kept if the supplied target list has a slot for it.
    keep_unknown = keep_unknown and unknown in target_class_list.classes

  # Position of every target class within the target class list.
  target_positions = dict(
      zip(target_class_list.classes, range(len(target_class_list.classes)))
  )
  if not keep_unknown:
    # Route 'unknown' to -1 so that it is discarded downstream.
    target_positions[unknown] = -1

  # 'unknown' always maps onto itself, whatever the mapping says.
  pairs_with_unknown = mapping.mapped_pairs | {unknown: unknown}

  # One table entry per source class: source index -> target index.
  keys = []
  values = []
  for source_position, source_name in enumerate(source_class_list.classes):
    keys.append(source_position)
    values.append(target_positions[pairs_with_unknown[source_name]])

  # Keys absent from the table resolve to -1.
  initializer = tf.lookup.KeyValueTensorInitializer(
      keys, values, tf.int64, tf.int64
  )
  return tf.lookup.StaticHashTable(initializer, default_value=-1)


class FeaturesPreprocessOp:
"""Preprocessing op which applies changes to specific features."""

Expand Down Expand Up @@ -843,10 +955,15 @@ def load_tables(
self, source_classes: namespace.ClassList
) -> Tuple[tf.lookup.StaticHashTable, tf.Tensor]:
"""Return a TensorFlow lookup table and a mask from source classes."""
if self.db is None:
raise ValueError('Database not loaded.')
mapping = self.db.mappings['reef_class_to_soundtype']
target_classes = self.db.class_lists[self.target_class_list]
soundtype_table = source_classes.get_namespace_map_tf_lookup(
mapping, target_class_list=target_classes, keep_unknown=True
soundtype_table = get_namespace_map_tf_lookup(
source_classes,
mapping,
target_class_list=target_classes,
keep_unknown=True,
)
# Mask is all 1's. So everything multiplied by 1. Add 0's for a real mask.
mask = tf.ones([len(target_classes.classes)])
Expand Down Expand Up @@ -934,10 +1051,12 @@ def load_tables(
"""
tables = {}
masks = {}
if self.db is None:
raise ValueError('Database not loaded.')
target_classes = self.db.class_lists[self.target_class_list]

label_table, label_mask = source_class_list.get_class_map_tf_lookup(
target_classes
label_table, label_mask = get_class_map_tf_lookup(
source_class_list, target_classes
)
tables[self.species_feature_name] = label_table
masks[self.species_feature_name] = label_mask
Expand All @@ -962,11 +1081,11 @@ def load_tables(
target_taxa_classes = target_classes.apply_namespace_mapping(
namespace_mapping, keep_unknown=True
)
namespace_table = source_class_list.get_namespace_map_tf_lookup(
namespace_mapping, keep_unknown=True
namespace_table = get_namespace_map_tf_lookup(
source_class_list, namespace_mapping, keep_unknown=True
)
class_table, label_mask = source_taxa_classes.get_class_map_tf_lookup(
target_taxa_classes
class_table, label_mask = get_class_map_tf_lookup(
source_taxa_classes, target_taxa_classes
)
tables[key + '_namespace'] = namespace_table
tables[key + '_class'] = class_table
Expand Down

0 comments on commit 26b8d79

Please sign in to comment.