34
34
# Type alias: a dict with `str` keys and `tf.Tensor` values.
Features = dict [str , tf .Tensor ]
35
35
36
36
37
def get_class_map_tf_lookup(
    source_class_list: namespace.ClassList,
    target_class_list: namespace.ClassList,
) -> tuple[tf.lookup.StaticHashTable, tf.Tensor]:
  """Create a static hash map for class indices.

  Create a lookup table for use in TF Datasets, for, eg, converting between
  ClassList defined for a dataset to a ClassList used as model outputs.
  Classes in the source ClassList which do not appear in the target_class_list
  will be mapped to -1. It is recommended to drop these labels subsequently
  with: tf.gather(x, tf.where(x >= 0)[:, 0])

  Args:
    source_class_list: Class list to map from.
    target_class_list: Class list to target.

  Returns:
    A tensorflow StaticHashTable mapping indices in the source class list to
    indices in the target class list (-1 for source classes absent from the
    target), and an int64 indicator vector with one entry per target class
    marking the target classes which also appear in the source class list.

  Raises:
    ValueError: If the two class lists have different namespaces.
  """
  if source_class_list.namespace != target_class_list.namespace:
    raise ValueError('namespaces must match when creating a class map.')

  def first_index(classes):
    # Same result as classes.index(c) for every c, but built in one pass so
    # that each later lookup is O(1) instead of a linear scan.
    positions = {}
    for position, class_name in enumerate(classes):
      positions.setdefault(class_name, position)
    return positions

  source_index = first_index(source_class_list.classes)
  target_index = first_index(target_class_list.classes)

  # Sort the shared classes so the table contents are built deterministically.
  shared_classes = sorted(source_index.keys() & target_index.keys())
  keys = tuple(source_index[c] for c in shared_classes)
  values = tuple(target_index[c] for c in shared_classes)

  table = tf.lookup.StaticHashTable(
      tf.lookup.KeyValueTensorInitializer(keys, values, tf.int64, tf.int64),
      default_value=-1,
  )
  image_mask = tf.constant(
      [c in source_index for c in target_class_list.classes],
      tf.int64,
  )
  return table, image_mask
72
+
73
+
74
def get_namespace_map_tf_lookup(
    source_class_list: namespace.ClassList,
    mapping: namespace.Mapping,
    keep_unknown: bool | None = None,
    target_class_list: namespace.ClassList | None = None,
) -> tf.lookup.StaticHashTable:
  """Create a tf.lookup.StaticHashTable for namespace mappings.

  Args:
    source_class_list: Class list to map from.
    mapping: Mapping to apply.
    keep_unknown: How to handle unknowns. If true, then unknown labels in the
      class list are maintained as unknown in the mapped values (provided the
      target class list contains the unknown label). If false then the unknown
      value is discarded, i.e. mapped to -1. The default (`None`) will raise an
      error if an unknown value is in the source class list.
    target_class_list: Optional class list for ordering of mapping output. If
      not provided, the class list produced by
      `source_class_list.apply_namespace_mapping` is used.

  Returns:
    A Tensorflow StaticHashTable whose keys are indices into the source class
    list and whose values are the indices of the mapped classes in the target
    class list. Discarded unknowns and keys missing from the table map to -1.

  Raises:
    ValueError: If 'unknown' label is in source classes and keep_unknown was
      not specified.
    ValueError: If a target class list was passed and the namespace of this
      does not match the mapping target namespace.
    KeyError: If a source class has no entry in the mapping, or its mapped
      class is missing from the target class list.
  """
  if (
      namespace.UNKNOWN_LABEL in source_class_list.classes
      and keep_unknown is None
  ):
    raise ValueError(
        "'unknown' found in source classes. Explicitly set keep_unknown to"
        " True or False. Alternatively, remove 'unknown' from source classes"
    )
  # If no target_class_list is passed, default to apply_namespace_mapping
  if target_class_list is None:
    target_class_list = source_class_list.apply_namespace_mapping(
        mapping, keep_unknown=keep_unknown
    )
  elif target_class_list.namespace != mapping.target_namespace:
    raise ValueError(
        f'target class list namespace ({target_class_list.namespace}) '
        'does not match mapping target namespace '
        f'({mapping.target_namespace})'
    )
  # Unknowns can only be kept if the target class list has an 'unknown' entry.
  keep_unknown = (
      keep_unknown and namespace.UNKNOWN_LABEL in target_class_list.classes
  )
  # Dict which maps classes to an index
  target_class_indices = {k: i for i, k in enumerate(target_class_list.classes)}
  # Add unknown to mapped pairs so that 'unknown' always maps to itself.
  mapped_pairs = mapping.mapped_pairs | {
      namespace.UNKNOWN_LABEL: namespace.UNKNOWN_LABEL
  }
  # If keep unknown==False, set unknown index to -1 to discard unknowns
  if not keep_unknown:
    target_class_indices[namespace.UNKNOWN_LABEL] = -1
  # Get keys and values to be used in the lookup table
  keys = list(range(len(source_class_list.classes)))
  values = []
  for source_class in source_class_list.classes:
    try:
      values.append(target_class_indices[mapped_pairs[source_class]])
    except KeyError as e:
      # Same exception type as before, but say which source class failed.
      raise KeyError(
          f'source class {source_class!r} could not be mapped to an index in'
          f' the target class list (missing key: {e.args[0]!r})'
      ) from e
  # Create the static hash table. If a key doesnt exist, set as -1.
  table = tf.lookup.StaticHashTable(
      tf.lookup.KeyValueTensorInitializer(keys, values, tf.int64, tf.int64),
      default_value=-1,
  )
  return table
147
+
148
+
37
149
class FeaturesPreprocessOp :
38
150
"""Preprocessing op which applies changes to specific features."""
39
151
@@ -843,10 +955,15 @@ def load_tables(
843
955
self , source_classes : namespace .ClassList
844
956
) -> Tuple [tf .lookup .StaticHashTable , tf .Tensor ]:
845
957
"""Return a TensorFlow lookup table and a mask from source classes."""
958
+ if self .db is None :
959
+ raise ValueError ('Database not loaded.' )
846
960
mapping = self .db .mappings ['reef_class_to_soundtype' ]
847
961
target_classes = self .db .class_lists [self .target_class_list ]
848
- soundtype_table = source_classes .get_namespace_map_tf_lookup (
849
- mapping , target_class_list = target_classes , keep_unknown = True
962
+ soundtype_table = get_namespace_map_tf_lookup (
963
+ source_classes ,
964
+ mapping ,
965
+ target_class_list = target_classes ,
966
+ keep_unknown = True ,
850
967
)
851
968
# Mask is all 1's. So everything multiplied by 1. Add 0's for a real mask.
852
969
mask = tf .ones ([len (target_classes .classes )])
@@ -934,10 +1051,12 @@ def load_tables(
934
1051
"""
935
1052
tables = {}
936
1053
masks = {}
1054
+ if self .db is None :
1055
+ raise ValueError ('Database not loaded.' )
937
1056
target_classes = self .db .class_lists [self .target_class_list ]
938
1057
939
- label_table , label_mask = source_class_list . get_class_map_tf_lookup (
940
- target_classes
1058
+ label_table , label_mask = get_class_map_tf_lookup (
1059
+ source_class_list , target_classes
941
1060
)
942
1061
tables [self .species_feature_name ] = label_table
943
1062
masks [self .species_feature_name ] = label_mask
@@ -962,11 +1081,11 @@ def load_tables(
962
1081
target_taxa_classes = target_classes .apply_namespace_mapping (
963
1082
namespace_mapping , keep_unknown = True
964
1083
)
965
- namespace_table = source_class_list . get_namespace_map_tf_lookup (
966
- namespace_mapping , keep_unknown = True
1084
+ namespace_table = get_namespace_map_tf_lookup (
1085
+ source_class_list , namespace_mapping , keep_unknown = True
967
1086
)
968
- class_table , label_mask = source_taxa_classes . get_class_map_tf_lookup (
969
- target_taxa_classes
1087
+ class_table , label_mask = get_class_map_tf_lookup (
1088
+ source_taxa_classes , target_taxa_classes
970
1089
)
971
1090
tables [key + '_namespace' ] = namespace_table
972
1091
tables [key + '_class' ] = class_table
0 commit comments