Cached and reused the label counts.
We now cache and reuse the label counts instead of re-computing them by
reloading the input files.

PiperOrigin-RevId: 722900183
Vineet Joshi authored and The spade_anomaly_detection Authors committed Feb 4, 2025
1 parent 94305f1 commit 9b58c52
Showing 2 changed files with 38 additions and 58 deletions.
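
In miniature, the change moves label counting to construction time: the loader
counts labels during its first pass over the data, keeps the result on the
object, and later callers read the cached values instead of re-scanning the
input. A minimal self-contained sketch of that pattern (hypothetical class and
names, not the project's actual data-loader API):

from collections import Counter


class CachedCountsLoader:
  """Counts labels on the first load and serves cached values afterward."""

  def __init__(self, labels: list[int]):
    self._labels = labels
    self.label_counts: dict[int, int] = {}

  def load(self) -> list[int]:
    # The expensive pass over the data happens here, exactly once.
    self.label_counts = dict(Counter(self._labels))
    return self._labels


loader = CachedCountsLoader([1, 0, -1, -1, 1])
_ = loader.load()
total_record_count = sum(loader.label_counts.values())  # reused, no reload
assert total_record_count == 5
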
94 changes: 37 additions & 57 deletions spade_anomaly_detection/runner.py
@@ -82,6 +82,8 @@ class Runner:
     int_positive_data_value: The integer value of the positive data value.
     int_negative_data_value: The integer value of the negative data value.
     int_unlabeled_data_value: The integer value of the unlabeled data value.
+    train_label_counts: Dictionary of counts of labels in training data.
+    total_record_count: Total number of records in the training data.
   """
 
   def __init__(self, runner_parameters: parameters.RunnerParameters):
@@ -95,13 +97,21 @@ def __init__(self, runner_parameters: parameters.RunnerParameters):
     else:
       self.data_format = DataFormat.CSV
 
+    self.train_label_counts: dict[int | str, int] | None = None
+    self.total_record_count: int | None = None
     if self.data_format == DataFormat.BIGQUERY:
       # BigQuery data loaders are the same for input, output and test data.
       self.input_data_loader = data_loader.DataLoader(self.runner_parameters)
       # Type hint to prevent linter errors.
       self.input_data_loader = cast(
           data_loader.DataLoader, self.input_data_loader
       )
+      self.total_record_count = (
+          self.input_data_loader.get_query_record_result_length(
+              input_path=self.runner_parameters.input_bigquery_table_path,
+              where_statements=self.runner_parameters.where_statements,
+          )
+      )
       if not self.runner_parameters.upload_only:
         self.test_data_loader = self.input_data_loader
       else:
@@ -112,6 +122,29 @@ def __init__(self, runner_parameters: parameters.RunnerParameters):
       self.input_data_loader = cast(
           csv_data_loader.CsvDataLoader, self.input_data_loader
       )
+      _ = self.input_data_loader.load_tf_dataset_from_csv(
+          input_path=self.runner_parameters.data_input_gcs_uri,
+          label_col_name=self.runner_parameters.label_col_name,
+          batch_size=1,
+      )
+      self.train_label_counts = self.input_data_loader.label_counts
+      self.total_record_count = sum(self.train_label_counts.values())
+      if (
+          self.runner_parameters.labeling_and_model_training_batch_size
+          and self.runner_parameters.labeling_and_model_training_batch_size
+          > self.total_record_count
+      ):
+        self.runner_parameters.labeling_and_model_training_batch_size = (
+            self.total_record_count
+        )
+        logging.info(
+            'Labeling and model training batch size is reduced to %s',
+            self.runner_parameters.labeling_and_model_training_batch_size,
+        )
+      logging.info(
+          'Initial label counts (before supervised training): %s',
+          self.train_label_counts,
+      )
       if not self.runner_parameters.upload_only:
         self.test_data_loader = csv_data_loader.CsvDataLoader(
             self.runner_parameters
@@ -182,15 +215,7 @@ def _get_table_statistics(self) -> Mapping[str, float]:
           self.runner_parameters.input_bigquery_table_path
       )
     else:
-      stats_data_loader = csv_data_loader.CsvDataLoader(self.runner_parameters)
-      # Type hint to prevent linter errors.
-      stats_data_loader = cast(csv_data_loader.CsvDataLoader, stats_data_loader)
-      _ = stats_data_loader.load_tf_dataset_from_csv(
-          input_path=self.runner_parameters.data_input_gcs_uri,
-          label_col_name=self.runner_parameters.label_col_name,
-          batch_size=1,
-      )
-      input_table_statistics = stats_data_loader.get_label_thresholds()
+      input_table_statistics = self.input_data_loader.get_label_thresholds()
     logging.info('Input table statistics: %s', input_table_statistics)
     return input_table_statistics
 
@@ -751,52 +776,7 @@ def run(self) -> None:
     logging.info('SPADE training started.')
 
     self._check_runner_parameters()
-
-    if self.data_format == DataFormat.BIGQUERY:
-      # Type hint to prevent linter errors.
-      self.input_data_loader = cast(
-          data_loader.DataLoader, self.input_data_loader
-      )
-      total_record_count = (
-          self.input_data_loader.get_query_record_result_length(
-              input_path=self.runner_parameters.input_bigquery_table_path,
-              where_statements=self.runner_parameters.where_statements,
-          )
-      )
-    else:
-      # Type hint to prevent linter errors.
-      self.input_data_loader = cast(
-          csv_data_loader.CsvDataLoader, self.input_data_loader
-      )
-      # Call the data loader to read all the files. This is needed to get the
-      # label counts.
-      _ = self.input_data_loader.load_tf_dataset_from_csv(
-          input_path=self.runner_parameters.data_input_gcs_uri,
-          label_col_name=self.runner_parameters.label_col_name,
-          batch_size=1,
-      )
-      train_label_counts = self.input_data_loader.label_counts
-      # This is not ideal, we should not need to read the files
-      # again. Find a way to get the label counts without reading the files.
-      # Assumes that data loader has already been used to read the input table.
-      total_record_count = sum(train_label_counts.values())
-      if (
-          self.runner_parameters.labeling_and_model_training_batch_size
-          and self.runner_parameters.labeling_and_model_training_batch_size
-          > total_record_count
-      ):
-        self.runner_parameters.labeling_and_model_training_batch_size = (
-            total_record_count
-        )
-        logging.info(
-            'Labeling and model training batch size is reduced to %s',
-            self.runner_parameters.labeling_and_model_training_batch_size,
-        )
-      logging.info(
-          'Label counts before supervised training: %s', train_label_counts
-      )
-
-    logging.info('Total record count: %s', total_record_count)
+    logging.info('Total record count: %s', self.total_record_count)
     unlabeled_record_count = self._get_record_count_based_on_labels(
         self.int_unlabeled_data_value
     )
@@ -805,7 +785,7 @@
     )
 
     self.check_data_tables(
-        total_record_count=total_record_count,
+        total_record_count=self.total_record_count,
         unlabeled_record_count=unlabeled_record_count,
     )
 
@@ -816,7 +796,7 @@
 
     batch_size = (
         self.runner_parameters.labeling_and_model_training_batch_size
-        or total_record_count
+        or self.total_record_count
     )
     if self.data_format == DataFormat.BIGQUERY:
       self.input_data_loader = cast(
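The batch-size handling that moved into __init__ above clamps an oversized
request to the dataset size and falls back to the full record count when no
batch size is set. The same logic as a standalone function (an assumed
simplification of the diff, with logging elided):

def clamp_batch_size(requested: int | None, total_record_count: int) -> int:
  # An oversized request is reduced to the record count; a missing request
  # defaults to a single batch over the whole dataset.
  if requested and requested > total_record_count:
    return total_record_count
  return requested or total_record_count


assert clamp_batch_size(500, 100) == 100  # reduced, as the log line reports
assert clamp_batch_size(32, 100) == 32    # smaller requests pass through
assert clamp_batch_size(None, 100) == 100  # falls back to the full dataset
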
2 changes: 1 addition & 1 deletion spade_anomaly_detection/runner_test.py
@@ -813,8 +813,8 @@ def test_evaluation_dataset_batch_training(self):
     # batches from the entire dataset.
     self.runner_parameters.labeling_and_model_training_batch_size = 50
 
-    runner_object = runner.Runner(self.runner_parameters)
     self._create_mock_datasets()
+    runner_object = runner.Runner(self.runner_parameters)
 
     runner_object.run()
 
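The test reordering follows from the constructor change: Runner.__init__ now
reads the input data to cache the label counts, so the mock datasets must
exist before the Runner is constructed. Condensed (a hypothetical abbreviation
of the test above):

def test_evaluation_dataset_batch_training(self):
  self.runner_parameters.labeling_and_model_training_batch_size = 50
  self._create_mock_datasets()  # must precede construction now
  runner_object = runner.Runner(self.runner_parameters)  # __init__ loads data
  runner_object.run()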
