diff --git a/chirp/preprocessing/pipeline.py b/chirp/preprocessing/pipeline.py
index 7b1f3112..e6e79fb0 100644
--- a/chirp/preprocessing/pipeline.py
+++ b/chirp/preprocessing/pipeline.py
@@ -34,6 +34,118 @@
 Features = dict[str, tf.Tensor]
 
 
+def get_class_map_tf_lookup(
+    source_class_list: namespace.ClassList,
+    target_class_list: namespace.ClassList,
+):
+  """Create a static hash map for class indices.
+
+  Create a lookup table for use in TF Datasets, e.g. for converting from the
+  ClassList defined for a dataset (source_class_list) to a ClassList used as
+  model outputs. Source classes which do not appear in target_class_list are
+  mapped to -1. It is recommended to drop these labels subsequently
+  with: tf.gather(x, tf.where(x >= 0)[:, 0])
+
+  Args:
+    target_class_list: Class list to target.
+
+  Returns:
+    A tensorflow StaticHashTable from source to target class indices, and an
+    int64 indicator vector marking target classes present in the source list.
+  """
+  if source_class_list.namespace != target_class_list.namespace:
+    raise ValueError('namespaces must match when creating a class map.')
+  intersection = set(source_class_list.classes) & set(target_class_list.classes)
+  intersection = sorted(tuple(intersection))
+  keys = tuple(source_class_list.classes.index(c) for c in intersection)
+  values = tuple(target_class_list.classes.index(c) for c in intersection)
+
+  table = tf.lookup.StaticHashTable(
+      tf.lookup.KeyValueTensorInitializer(keys, values, tf.int64, tf.int64),
+      default_value=-1,
+  )
+  image_mask = tf.constant(
+      [k in source_class_list.classes for k in target_class_list.classes],
+      tf.int64,
+  )
+  return table, image_mask
+
+
+def get_namespace_map_tf_lookup(
+    source_class_list: namespace.ClassList,
+    mapping: namespace.Mapping,
+    keep_unknown: bool | None = None,
+    target_class_list: namespace.ClassList | None = None,
+) -> tf.lookup.StaticHashTable:
+  """Create a tf.lookup.StaticHashTable for namespace mappings.
+
+  Args:
+    source_class_list: Class list to map from.
+    mapping: Mapping to apply.
+    keep_unknown: How to handle unknowns. If true, then unknown labels in the
+      class list are maintained as unknown in the mapped values. If false then
+      the unknown value is discarded. The default (`None`) will raise an error
+      if an unknown value is in the source class list.
+    target_class_list: Optional class list for ordering of mapping output. If
+      not provided, a class list consisting of the alphabetized image set of the
+      mapping will be used.
+
+  Returns:
+    A Tensorflow StaticHashTable mapping source class indices to indices in the
+    target class list (-1 for discarded unknowns).
+
+  Raises:
+    ValueError: If 'unknown' label is in source classes and keep_unknown was
+      not specified.
+    ValueError: If a target class list was passed and the namespace of this
+      does not match the mapping target namespace.
+  """
+  if (
+      namespace.UNKNOWN_LABEL in source_class_list.classes
+      and keep_unknown is None
+  ):
+    raise ValueError(
+        "'unknown' found in source classes. Explicitly set keep_unknown to"
+        " True or False. Alternatively, remove 'unknown' from source classes"
+    )
+  # If no target_class_list is passed, default to apply_namespace_mapping
+  if target_class_list is None:
+    target_class_list = source_class_list.apply_namespace_mapping(
+        mapping, keep_unknown=keep_unknown
+    )
+  else:
+    if target_class_list.namespace != mapping.target_namespace:
+      raise ValueError(
+          f'target class list namespace ({target_class_list.namespace}) '
+          'does not match mapping target namespace '
+          f'({mapping.target_namespace})'
+      )
+    # Now check if 'unknown' label present in target_class_list.classes
+    keep_unknown = (
+        keep_unknown and namespace.UNKNOWN_LABEL in target_class_list.classes
+    )
+  # Dict which maps classes to an index
+  target_class_indices = {k: i for i, k in enumerate(target_class_list.classes)}
+  # Add unknown to mapped pairs
+  mapped_pairs = mapping.mapped_pairs | {
+      namespace.UNKNOWN_LABEL: namespace.UNKNOWN_LABEL
+  }
+  # If keep unknown==False, set unknown index to -1 to discard unknowns
+  if not keep_unknown:
+    target_class_indices[namespace.UNKNOWN_LABEL] = -1
+  # Get keys and values to be used in the lookup table
+  keys = list(range(len(source_class_list.classes)))
+  values = [
+      target_class_indices[mapped_pairs[k]] for k in source_class_list.classes
+  ]
+  # Create the static hash table. If a key doesn't exist, set as -1.
+  table = tf.lookup.StaticHashTable(
+      tf.lookup.KeyValueTensorInitializer(keys, values, tf.int64, tf.int64),
+      default_value=-1,
+  )
+  return table
+
+
 class FeaturesPreprocessOp:
   """Preprocessing op which applies changes to specific features."""
 
@@ -843,10 +955,15 @@ def load_tables(
       self, source_classes: namespace.ClassList
   ) -> Tuple[tf.lookup.StaticHashTable, tf.Tensor]:
     """Return a TensorFlow lookup table and a mask from source classes."""
+    if self.db is None:
+      raise ValueError('Database not loaded.')
     mapping = self.db.mappings['reef_class_to_soundtype']
     target_classes = self.db.class_lists[self.target_class_list]
-    soundtype_table = source_classes.get_namespace_map_tf_lookup(
-        mapping, target_class_list=target_classes, keep_unknown=True
+    soundtype_table = get_namespace_map_tf_lookup(
+        source_classes,
+        mapping,
+        target_class_list=target_classes,
+        keep_unknown=True,
     )
     # Mask is all 1's. So everything multiplied by 1. Add 0's for a real mask.
     mask = tf.ones([len(target_classes.classes)])
@@ -934,10 +1051,12 @@ def load_tables(
     """
     tables = {}
     masks = {}
+    if self.db is None:
+      raise ValueError('Database not loaded.')
     target_classes = self.db.class_lists[self.target_class_list]
 
-    label_table, label_mask = source_class_list.get_class_map_tf_lookup(
-        target_classes
+    label_table, label_mask = get_class_map_tf_lookup(
+        source_class_list, target_classes
     )
     tables[self.species_feature_name] = label_table
     masks[self.species_feature_name] = label_mask
@@ -962,11 +1081,11 @@ def load_tables(
       target_taxa_classes = target_classes.apply_namespace_mapping(
           namespace_mapping, keep_unknown=True
       )
-      namespace_table = source_class_list.get_namespace_map_tf_lookup(
-          namespace_mapping, keep_unknown=True
+      namespace_table = get_namespace_map_tf_lookup(
+          source_class_list, namespace_mapping, keep_unknown=True
       )
-      class_table, label_mask = source_taxa_classes.get_class_map_tf_lookup(
-          target_taxa_classes
+      class_table, label_mask = get_class_map_tf_lookup(
+          source_taxa_classes, target_taxa_classes
       )
       tables[key + '_namespace'] = namespace_table
       tables[key + '_class'] = class_table