Skip to content

Commit ac12c02

Browse files
Vineet Joshi and The spade_anomaly_detection Authors
Vineet Joshi
authored and
The spade_anomaly_detection Authors
committed
Cache and reuse the label counts.
We want to cache and reuse the label counts instead of re-computing them by reloading the input files. PiperOrigin-RevId: 722900183
1 parent 94305f1 commit ac12c02

File tree

2 files changed

+38
-58
lines changed

2 files changed

+38
-58
lines changed

spade_anomaly_detection/runner.py

Lines changed: 37 additions & 57 deletions
Original file line numberDiff line numberDiff line change
@@ -82,6 +82,8 @@ class Runner:
8282
int_positive_data_value: The integer value of the positive data value.
8383
int_negative_data_value: The integer value of the negative data value.
8484
int_unlabeled_data_value: The integer value of the unlabeled data value.
85+
train_label_counts: Dictionary of counts of labels in training data.
86+
total_record_count: Total number of records in the training data.
8587
"""
8688

8789
def __init__(self, runner_parameters: parameters.RunnerParameters):
@@ -95,13 +97,21 @@ def __init__(self, runner_parameters: parameters.RunnerParameters):
9597
else:
9698
self.data_format = DataFormat.CSV
9799

100+
self.train_label_counts: dict[int | str, int] | None = None
101+
self.total_record_count: int | None = None
98102
if self.data_format == DataFormat.BIGQUERY:
99103
# BigQuery data loaders are the same for input, output and test data.
100104
self.input_data_loader = data_loader.DataLoader(self.runner_parameters)
101105
# Type hint to prevent linter errors.
102106
self.input_data_loader = cast(
103107
data_loader.DataLoader, self.input_data_loader
104108
)
109+
self.total_record_count = (
110+
self.input_data_loader.get_query_record_result_length(
111+
input_path=self.runner_parameters.input_bigquery_table_path,
112+
where_statements=self.runner_parameters.where_statements,
113+
)
114+
)
105115
if not self.runner_parameters.upload_only:
106116
self.test_data_loader = self.input_data_loader
107117
else:
@@ -112,6 +122,29 @@ def __init__(self, runner_parameters: parameters.RunnerParameters):
112122
self.input_data_loader = cast(
113123
csv_data_loader.CsvDataLoader, self.input_data_loader
114124
)
125+
_ = self.input_data_loader.load_tf_dataset_from_csv(
126+
input_path=self.runner_parameters.data_input_gcs_uri,
127+
label_col_name=self.runner_parameters.label_col_name,
128+
batch_size=1,
129+
)
130+
self.train_label_counts = self.input_data_loader.label_counts
131+
self.total_record_count = sum(self.train_label_counts.values())
132+
if (
133+
self.runner_parameters.labeling_and_model_training_batch_size
134+
and self.runner_parameters.labeling_and_model_training_batch_size
135+
> self.total_record_count
136+
):
137+
self.runner_parameters.labeling_and_model_training_batch_size = (
138+
self.total_record_count
139+
)
140+
logging.info(
141+
'Labeling and model training batch size is reduced to %s',
142+
self.runner_parameters.labeling_and_model_training_batch_size,
143+
)
144+
logging.info(
145+
'Initial label counts (before supervised training): %s',
146+
self.train_label_counts,
147+
)
115148
if not self.runner_parameters.upload_only:
116149
self.test_data_loader = csv_data_loader.CsvDataLoader(
117150
self.runner_parameters
@@ -182,15 +215,7 @@ def _get_table_statistics(self) -> Mapping[str, float]:
182215
self.runner_parameters.input_bigquery_table_path
183216
)
184217
else:
185-
stats_data_loader = csv_data_loader.CsvDataLoader(self.runner_parameters)
186-
# Type hint to prevent linter errors.
187-
stats_data_loader = cast(csv_data_loader.CsvDataLoader, stats_data_loader)
188-
_ = stats_data_loader.load_tf_dataset_from_csv(
189-
input_path=self.runner_parameters.data_input_gcs_uri,
190-
label_col_name=self.runner_parameters.label_col_name,
191-
batch_size=1,
192-
)
193-
input_table_statistics = stats_data_loader.get_label_thresholds()
218+
input_table_statistics = self.input_data_loader.get_label_thresholds()
194219
logging.info('Input table statistics: %s', input_table_statistics)
195220
return input_table_statistics
196221

@@ -751,52 +776,7 @@ def run(self) -> None:
751776
logging.info('SPADE training started.')
752777

753778
self._check_runner_parameters()
754-
755-
if self.data_format == DataFormat.BIGQUERY:
756-
# Type hint to prevent linter errors.
757-
self.input_data_loader = cast(
758-
data_loader.DataLoader, self.input_data_loader
759-
)
760-
total_record_count = (
761-
self.input_data_loader.get_query_record_result_length(
762-
input_path=self.runner_parameters.input_bigquery_table_path,
763-
where_statements=self.runner_parameters.where_statements,
764-
)
765-
)
766-
else:
767-
# Type hint to prevent linter errors.
768-
self.input_data_loader = cast(
769-
csv_data_loader.CsvDataLoader, self.input_data_loader
770-
)
771-
# Call the data loader to read all the files. This is needed to get the
772-
# label counts.
773-
_ = self.input_data_loader.load_tf_dataset_from_csv(
774-
input_path=self.runner_parameters.data_input_gcs_uri,
775-
label_col_name=self.runner_parameters.label_col_name,
776-
batch_size=1,
777-
)
778-
train_label_counts = self.input_data_loader.label_counts
779-
# This is not ideal, we should not need to read the files
780-
# again. Find a way to get the label counts without reading the files.
781-
# Assumes that data loader has already been used to read the input table.
782-
total_record_count = sum(train_label_counts.values())
783-
if (
784-
self.runner_parameters.labeling_and_model_training_batch_size
785-
and self.runner_parameters.labeling_and_model_training_batch_size
786-
> total_record_count
787-
):
788-
self.runner_parameters.labeling_and_model_training_batch_size = (
789-
total_record_count
790-
)
791-
logging.info(
792-
'Labeling and model training batch size is reduced to %s',
793-
self.runner_parameters.labeling_and_model_training_batch_size,
794-
)
795-
logging.info(
796-
'Label counts before supervised training: %s', train_label_counts
797-
)
798-
799-
logging.info('Total record count: %s', total_record_count)
779+
logging.info('Total record count: %s', self.total_record_count)
800780
unlabeled_record_count = self._get_record_count_based_on_labels(
801781
self.int_unlabeled_data_value
802782
)
@@ -805,7 +785,7 @@ def run(self) -> None:
805785
)
806786

807787
self.check_data_tables(
808-
total_record_count=total_record_count,
788+
total_record_count=self.total_record_count,
809789
unlabeled_record_count=unlabeled_record_count,
810790
)
811791

@@ -816,7 +796,7 @@ def run(self) -> None:
816796

817797
batch_size = (
818798
self.runner_parameters.labeling_and_model_training_batch_size
819-
or total_record_count
799+
or self.total_record_count
820800
)
821801
if self.data_format == DataFormat.BIGQUERY:
822802
self.input_data_loader = cast(

spade_anomaly_detection/runner_test.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -813,8 +813,8 @@ def test_evaluation_dataset_batch_training(self):
813813
# batches from the entire dataset.
814814
self.runner_parameters.labeling_and_model_training_batch_size = 50
815815

816-
runner_object = runner.Runner(self.runner_parameters)
817816
self._create_mock_datasets()
817+
runner_object = runner.Runner(self.runner_parameters)
818818

819819
runner_object.run()
820820

0 commit comments

Comments
 (0)