Skip to content

Commit 7415c29

Browse files
raj-sinha and The spade_anomaly_detection Authors
authored and
The spade_anomaly_detection Authors
committed
Update to allow label class values to be unique arbitrary strings. Also allow the unlabeled value to be the empty string.
For example: positive label value = "positive", negative label value = "negative", unlabeled label value = "". PiperOrigin-RevId: 714146904
1 parent 462e53e commit 7415c29

File tree

6 files changed

+299
-78
lines changed

6 files changed

+299
-78
lines changed

spade_anomaly_detection/csv_data_loader.py

Lines changed: 92 additions & 23 deletions
Original file line numberDiff line numberDiff line change
@@ -44,11 +44,22 @@
4444
import tensorflow as tf
4545

4646

47-
# Types are from //cloud/ml/research/data_utils/feature_metadata.py
4847
_FEATURES_TYPE: Final[str] = 'FLOAT64'
4948
_SOURCE_LABEL_TYPE: Final[str] = 'STRING'
5049
_SOURCE_LABEL_DEFAULT_VALUE: Final[str] = '-1'
5150
_LABEL_TYPE: Final[str] = 'INT64'
51+
_STRING_TO_INTEGER_LABEL_MAP: dict[str | int, int] = {
52+
1: 1,
53+
0: 0,
54+
-1: -1,
55+
'': -1,
56+
'-1': -1,
57+
'0': 0,
58+
'1': 1,
59+
'positive': 1,
60+
'negative': 0,
61+
'unlabeled': -1,
62+
}
5263

5364
# Setting the shuffle buffer size to 1M seems to be necessary to get the CSV
5465
# reader to provide a diversity of data to the model.
@@ -167,12 +178,12 @@ def from_inputs_file(
167178
raise ValueError(
168179
f'Label column {label_column_name} not found in the header: {header}'
169180
)
170-
num_features = len(all_columns) - 1
171181
features_types = [_FEATURES_TYPE] * len(all_columns)
172182
column_names_dict = collections.OrderedDict(
173183
zip(all_columns, features_types)
174184
)
175185
column_names_dict[label_column_name] = _SOURCE_LABEL_DEFAULT_VALUE
186+
num_features = len(all_columns) - 1
176187
return ColumnNamesInfo(
177188
column_names_dict=column_names_dict,
178189
header=header,
@@ -216,6 +227,13 @@ def __init__(self, runner_parameters: parameters.RunnerParameters):
216227
self.runner_parameters.negative_data_value,
217228
self.runner_parameters.unlabeled_data_value,
218229
]
230+
# Add any labels that are not already in the map.
231+
_STRING_TO_INTEGER_LABEL_MAP[self.runner_parameters.positive_data_value] = 1
232+
_STRING_TO_INTEGER_LABEL_MAP[self.runner_parameters.negative_data_value] = 0
233+
_STRING_TO_INTEGER_LABEL_MAP[
234+
self.runner_parameters.unlabeled_data_value
235+
] = -1
236+
219237
# Construct a label remap from string labels to integers. The table is not
220238
# necessary for the case when the labels are all integers. But instead of
221239
# checking if the labels are all integers, we construct the table and use
@@ -286,7 +304,8 @@ def get_inputs_metadata(
286304
)
287305
# Get information about the columns.
288306
column_names_info = ColumnNamesInfo.from_inputs_file(
289-
csv_filenames[0], label_column_name
307+
csv_filenames[0],
308+
label_column_name,
290309
)
291310
logging.info(
292311
'Obtained metadata for data with CSV prefix %s (number of features=%d)',
@@ -360,22 +379,19 @@ def filter_func(features: tf.Tensor, label: tf.Tensor) -> bool: # pylint: disab
360379
@classmethod
def convert_str_to_int(cls, value: str) -> int:
  """Translates a label into its integer class via the module label map.

  Args:
    value: The label to translate. Any key currently present in
      `_STRING_TO_INTEGER_LABEL_MAP` is accepted.

  Returns:
    The integer that `_STRING_TO_INTEGER_LABEL_MAP` associates with `value`.

  Raises:
    ValueError: If `value` is not a key of `_STRING_TO_INTEGER_LABEL_MAP`.
  """
  # A private sentinel distinguishes "key absent" from any stored value, so
  # the map is consulted only once.
  not_found = object()
  mapped = _STRING_TO_INTEGER_LABEL_MAP.get(value, not_found)
  if mapped is not_found:
    raise ValueError(
        f'Label {value} of type {type(value)} is not a string integer or '
        'mappable to an integer.'
    )
  return mapped
371389

372390
@classmethod
373391
def _get_label_remap_table(
374392
cls, labels_mapping: dict[str, int]
375393
) -> tf.lookup.StaticHashTable:
376394
"""Returns a label remap table that converts string labels to integers."""
377-
# The possible keys are '', '-1, '0', '1'. None is not included because the
378-
# Data Loader will default to '' if the label is None.
379395
keys_tensor = tf.constant(
380396
list(labels_mapping.keys()),
381397
dtype=tf.dtypes.as_dtype(_SOURCE_LABEL_TYPE.lower()),
@@ -390,6 +406,14 @@ def _get_label_remap_table(
390406
)
391407
return label_remap_table
392408

409+
def remap_label(self, label: str | tf.Tensor) -> int | tf.Tensor:
  """Maps a string label to an integer through the label remap table.

  Args:
    label: A Python string, or a tensor. Only Python strings and tensors of
      dtype `tf.string` are sent through `self._label_remap_table`.

  Returns:
    The result of the table lookup for string inputs; otherwise `label`
    itself, unchanged.
  """
  needs_lookup = isinstance(label, str) or (
      isinstance(label, tf.Tensor) and label.dtype == tf.dtypes.string
  )
  if not needs_lookup:
    # Non-string labels are passed through untouched.
    return label
  return self._label_remap_table.lookup(label)
416+
393417
def load_tf_dataset_from_csv(
394418
self,
395419
input_path: str,
@@ -441,6 +465,7 @@ def load_tf_dataset_from_csv(
441465
self._last_read_metadata.column_names_info.column_names_dict.values()
442466
)
443467
]
468+
logging.info('column_defaults: %s', column_defaults)
444469

445470
# Construct a single dataset out of multiple CSV files.
446471
# TODO(sinharaj): Remove the determinism after testing.
@@ -456,7 +481,7 @@ def load_tf_dataset_from_csv(
456481
na_value='',
457482
header=True,
458483
num_epochs=1,
459-
shuffle=True,
484+
shuffle=False,
460485
shuffle_buffer_size=_SHUFFLE_BUFFER_SIZE,
461486
shuffle_seed=self.runner_parameters.random_seed,
462487
prefetch_buffer_size=tf.data.AUTOTUNE,
@@ -473,17 +498,9 @@ def load_tf_dataset_from_csv(
473498
'created.'
474499
)
475500

476-
def remap_label(label: str | tf.Tensor) -> int | tf.Tensor:
477-
"""Remaps the label to an integer."""
478-
if isinstance(label, str) or (
479-
isinstance(label, tf.Tensor) and label.dtype == tf.dtypes.string
480-
):
481-
return self._label_remap_table.lookup(label)
482-
return label
483-
484501
# The Dataset can have labels of type int or str. Cast them to int.
485502
dataset = dataset.map(
486-
lambda features, label: (features, remap_label(label)),
503+
lambda features, label: (features, self.remap_label(label)),
487504
num_parallel_calls=tf.data.AUTOTUNE,
488505
deterministic=True,
489506
)
@@ -535,7 +552,6 @@ def combine_features_dict_into_tensor(
535552
self._label_counts = {
536553
k: v.numpy() for k, v in self.counts_by_label(dataset).items()
537554
}
538-
logging.info('Label counts: %s', self._label_counts)
539555

540556
return dataset
541557

@@ -554,11 +570,11 @@ def counts_by_label(self, dataset: tf.data.Dataset) -> Dict[int, tf.Tensor]:
554570

555571
@tf.function
556572
def count_class(
557-
counts: Dict[int, int], # Keys are always strings.
573+
counts: Dict[int, int],
558574
batch: Tuple[tf.Tensor, tf.Tensor],
559575
) -> Dict[int, int]:
560576
_, labels = batch
561-
# Keys are always strings.
577+
labels = self.remap_label(labels)
562578
new_counts: Dict[int, int] = counts.copy()
563579
for i in self.all_labels:
564580
# This function is called after the Dataset is constructed and the
@@ -582,6 +598,59 @@ def count_class(
582598
)
583599
return counts
584600

601+
def counts_by_original_label(
    self, dataset: tf.data.Dataset
) -> tuple[dict[int, tf.Tensor], dict[str, tf.Tensor]]:
  """Counts the number of samples per original label value in the dataset.

  The dataset is reduced twice: once with a state keyed by the integer keys
  of `_STRING_TO_INTEGER_LABEL_MAP` and once with a state keyed by its string
  keys. In each pass, the labels of every element are compared against the
  entries of `self.all_labels` that have the matching Python type.

  Args:
    dataset: Dataset whose elements unpack as `(features, labels)`.

  Returns:
    A tuple `(int_counts, str_counts)`. `int_counts` is keyed by integer
    labels (an empty dict when the map has no integer keys) and `str_counts`
    is keyed by string labels.
  """

  all_int_labels = [
      label for label in self.all_labels if isinstance(label, int)
  ]
  logging.info('all_int_labels: %s', all_int_labels)
  all_str_labels = [
      label for label in self.all_labels if isinstance(label, str)
  ]
  logging.info('all_str_labels: %s', all_str_labels)

  @tf.function
  def count_original_class(
      counts: Dict[int | str, int],
      batch: Tuple[tf.Tensor, tf.Tensor],
  ) -> Dict[int | str, int]:
    # The key type of the running state tells which of the two reductions
    # (integer-keyed or string-keyed) is being executed.
    keys_are_int = all(isinstance(k, int) for k in counts)
    all_labels = all_int_labels if keys_are_int else all_str_labels
    _, labels = batch
    new_counts: Dict[int | str, int] = counts.copy()
    for label in all_labels:
      matches: tf.Tensor = tf.reduce_sum(tf.cast(labels == label, tf.int32))
      # NOTE(review): the else-branch adds a key that was not in the initial
      # state. `Dataset.reduce` presumably requires the state structure to
      # stay fixed -- confirm that `self.all_labels` never contains a label
      # missing from `_STRING_TO_INTEGER_LABEL_MAP`.
      if label in new_counts:
        new_counts[label] += matches
      else:
        new_counts[label] = matches
    return new_counts

  # Only the keys of the label map are needed to seed the reduction states.
  initial_int_state = {
      int(label): 0
      for label in _STRING_TO_INTEGER_LABEL_MAP
      if isinstance(label, int)
  }
  if initial_int_state:
    int_counts = dataset.reduce(
        initial_state=initial_int_state, reduce_func=count_original_class
    )
  else:
    int_counts = {}
  initial_str_state = {
      str(label): 0
      for label in _STRING_TO_INTEGER_LABEL_MAP
      if isinstance(label, str)
  }
  str_counts = dataset.reduce(
      initial_state=initial_str_state, reduce_func=count_original_class
  )
  return int_counts, str_counts
653+
585654
def get_label_thresholds(self) -> Mapping[str, float]:
586655
"""Computes positive and negative thresholds based on label ratios.
587656

0 commit comments

Comments
 (0)