
Commit 71f0727

Merge pull request #16 from google-research/test_635124847
Update the OCC training to use negative and unlabeled samples for training.
2 parents 903affb + edbf26c commit 71f0727
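
In short, the OCC (one-class classifier) ensemble no longer trains on unlabeled rows alone: rows carrying the negative label are included as well, and batch sizing is derived from the combined count. The sketch below only illustrates that selection rule in plain Python; it is not code from this repository, and the two label values are invented stand-ins for runner_parameters.unlabeled_data_value and runner_parameters.negative_data_value:

    # Hypothetical stand-in label values, for illustration only.
    UNLABELED_DATA_VALUE = -1
    NEGATIVE_DATA_VALUE = 0


    def used_for_occ_training(label: int) -> bool:
      # Old behavior: label == UNLABELED_DATA_VALUE.
      # New behavior: unlabeled and negative rows are both used.
      return label in (UNLABELED_DATA_VALUE, NEGATIVE_DATA_VALUE)


    labels = [1, -1, 0, -1, 1, 0]
    print([v for v in labels if used_for_occ_training(v)])  # [-1, 0, -1, 0]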

File tree

5 files changed: +75 -29 lines changed

  CHANGELOG.md
  spade_anomaly_detection/__init__.py
  spade_anomaly_detection/performance_test.py
  spade_anomaly_detection/runner.py
  spade_anomaly_detection/runner_test.py


CHANGELOG.md

Lines changed: 6 additions & 1 deletion
@@ -23,6 +23,10 @@ To release a new version (e.g. from `1.0.0` -> `2.0.0`):
 
 ## [Unreleased]
 
+## [0.2.2] - 2024-05-19
+
+* Update the OCC training to use negative and unlabeled samples for training.
+
 ## [0.2.1] - 2024-05-18
 
 * Updates to data loaders. Label column filter can now be a list of integers.
@@ -35,7 +39,8 @@ To release a new version (e.g. from `1.0.0` -> `2.0.0`):
 
 * Initial release
 
-[Unreleased]: https://github.com/google-research/spade_anomaly_detection/compare/v0.2.1...HEAD
+[Unreleased]: https://github.com/google-research/spade_anomaly_detection/compare/v0.2.2...HEAD
+[0.2.2]: https://github.com/google-research/spade_anomaly_detection/compare/v0.2.1...v0.2.2
 [0.2.1]: https://github.com/google-research/spade_anomaly_detection/compare/v0.2.0...v0.2.1
 [0.2.0]: https://github.com/google-research/spade_anomaly_detection/compare/v0.1.0...v0.2.0
 [0.1.0]: https://github.com/google-research/spade_anomaly_detection/releases/tag/v0.1.0

spade_anomaly_detection/__init__.py

Lines changed: 1 addition & 1 deletion
@@ -31,4 +31,4 @@
 
 # A new PyPI release will be pushed every time `__version__` is increased.
 # When changing this, also update the CHANGELOG.md.
-__version__ = '0.2.1'
+__version__ = '0.2.2'

spade_anomaly_detection/performance_test.py

Lines changed: 7 additions & 0 deletions
@@ -78,6 +78,11 @@ def setUp(self):
         filter_label_value=self.runner_parameters.unlabeled_data_value,
     )
     self.unlabeled_record_count = len(self.unlabeled_labels)
+    _, negative_labels = data_loader.load_dataframe(
+        dataset_name=csv_path,
+        filter_label_value=self.runner_parameters.negative_data_value,
+    )
+    self.negative_record_count = len(negative_labels)
 
     self.occ_fit_batch_size = (
         self.unlabeled_record_count // self.runner_parameters.ensemble_count
@@ -128,6 +133,7 @@ def test_spade_auc_performance_pnu_single_batch(self):
     self.mock_get_total_records.side_effect = [
         self.total_record_count,
         self.unlabeled_record_count,
+        self.negative_record_count,
     ]
 
     runner_object = runner.Runner(self.runner_parameters)
@@ -151,6 +157,7 @@ def test_spade_auc_performance_pu_single_batch(self):
     self.mock_get_total_records.side_effect = [
         self.total_record_count,
         self.unlabeled_record_count,
+        self.negative_record_count,
     ]
 
     runner_object = runner.Runner(self.runner_parameters)
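
The setUp change above mirrors the runner: it loads the CSV a second time filtered to the negative label and keeps the row count. A rough pandas equivalent of those two counts (hypothetical; load_dataframe in this repo wraps additional logic, and the file name, column name, and label values below are invented for illustration):

    import pandas as pd

    # Invented file name, column name, and label values, for illustration only.
    csv_path = 'labeled_features.csv'
    LABEL_COL = 'label'
    UNLABELED_DATA_VALUE = -1
    NEGATIVE_DATA_VALUE = 0

    df = pd.read_csv(csv_path)
    # Two filtered counts, mirroring the two load_dataframe calls in setUp.
    unlabeled_record_count = int((df[LABEL_COL] == UNLABELED_DATA_VALUE).sum())
    negative_record_count = int((df[LABEL_COL] == NEGATIVE_DATA_VALUE).sum())
    print(unlabeled_record_count, negative_record_count)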

spade_anomaly_detection/runner.py

Lines changed: 45 additions & 21 deletions
@@ -104,6 +104,31 @@ def __init__(self, runner_parameters: parameters.RunnerParameters):
         else self.runner_parameters.negative_threshold
     )
 
+  def _get_record_count_based_on_labels(self, label_value: int) -> int:
+    """Gets the number of records in the table.
+
+    Args:
+      label_value: The value of the label to use as the filter for records.
+
+    Returns:
+      The count of records.
+    """
+    label_record_count_filter = (
+        f'{self.runner_parameters.label_col_name} = {label_value}'
+    )
+    if self.runner_parameters.where_statements:
+      label_record_count_where_statements = [
+          self.runner_parameters.where_statements
+      ] + [label_record_count_filter]
+    else:
+      label_record_count_where_statements = [label_record_count_filter]
+
+    label_record_count = self.data_loader.get_query_record_result_length(
+        input_path=self.runner_parameters.input_bigquery_table_path,
+        where_statements=label_record_count_where_statements,
+    )
+    return label_record_count
+
   def check_data_tables(
       self,
       total_record_count: int,
@@ -166,12 +191,13 @@ def check_data_tables(
     )
 
   def instantiate_and_fit_ensemble(
-      self, unlabeled_record_count: int
+      self, unlabeled_record_count: int, negative_record_count: int
   ) -> occ_ensemble.GmmEnsemble:
     """Creates and fits an OCC ensemble on the specified input data.
 
     Args:
       unlabeled_record_count: Number of unlabeled records in the table.
+      negative_record_count: Number of negative records in the table.
 
     Returns:
       A trained one class classifier ensemble.
@@ -183,7 +209,8 @@ def instantiate_and_fit_ensemble(
         negative_threshold=self.runner_parameters.negative_threshold,
     )
 
-    records_per_occ = unlabeled_record_count // ensemble_object.ensemble_count
+    training_record_count = unlabeled_record_count + negative_record_count
+    records_per_occ = training_record_count // ensemble_object.ensemble_count
     batch_size = records_per_occ // self.runner_parameters.batches_per_model
     batch_size = np.min([batch_size, self.runner_parameters.max_occ_batch_size])
 
@@ -195,7 +222,11 @@ def instantiate_and_fit_ensemble(
         where_statements=self.runner_parameters.where_statements,
         ignore_columns=self.runner_parameters.ignore_columns,
         batch_size=batch_size,
-        label_column_filter_value=self.runner_parameters.unlabeled_data_value,
+        # Train using negative labeled data and unlabeled data.
+        label_column_filter_value=[
+            self.runner_parameters.unlabeled_data_value,
+            self.runner_parameters.negative_data_value,
+        ],
     )
 
     logging.info('Fitting ensemble.')
@@ -527,20 +558,11 @@ def run(self) -> None:
         where_statements=self.runner_parameters.where_statements,
     )
 
-    unlabeled_record_count_filter = (
-        f'{self.runner_parameters.label_col_name} = '
-        f'{self.runner_parameters.unlabeled_data_value}'
+    unlabeled_record_count = self._get_record_count_based_on_labels(
+        self.runner_parameters.unlabeled_data_value
     )
-    if self.runner_parameters.where_statements:
-      unlabeled_record_count_where_statements = [
-          self.runner_parameters.where_statements
-      ] + [unlabeled_record_count_filter]
-    else:
-      unlabeled_record_count_where_statements = [unlabeled_record_count_filter]
-
-    unlabeled_record_count = self.data_loader.get_query_record_result_length(
-        input_path=self.runner_parameters.input_bigquery_table_path,
-        where_statements=unlabeled_record_count_where_statements,
+    negative_record_count = self._get_record_count_based_on_labels(
+        self.runner_parameters.negative_data_value
     )
 
     self.check_data_tables(
@@ -549,7 +571,8 @@ def run(self) -> None:
     )
 
     ensemble_object = self.instantiate_and_fit_ensemble(
-        unlabeled_record_count=unlabeled_record_count
+        unlabeled_record_count=unlabeled_record_count,
+        negative_record_count=negative_record_count,
    )
 
     batch_size = (
@@ -615,10 +638,11 @@ def run(self) -> None:
     )
 
     if not self.runner_parameters.upload_only:
+      if self.supervised_model_object is None:
+        raise ValueError('Supervised model was not created and trained.')
       self.evaluate_model()
-      if self.supervised_model_object is not None:
-        self.supervised_model_object.save(
-            save_location=self.runner_parameters.output_gcs_uri
-        )
+      self.supervised_model_object.save(
+          save_location=self.runner_parameters.output_gcs_uri
+      )
 
     logging.info('SPADE training completed.')
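
Two parts of this runner change are worth a concrete reading. The new _get_record_count_based_on_labels helper simply appends a "label_col_name = value" predicate to any user-supplied WHERE statements before asking the data loader for a count, and the OCC batch size is now driven by the combined unlabeled-plus-negative count. A small worked sketch of that batch-size arithmetic with invented numbers (the helper below is not the repository's code; it only restates the math from instantiate_and_fit_ensemble):

    def occ_batch_size(unlabeled_count: int, negative_count: int,
                       ensemble_count: int, batches_per_model: int,
                       max_occ_batch_size: int) -> int:
      # The combined count now drives per-model batching (previously unlabeled only).
      training_record_count = unlabeled_count + negative_count
      records_per_occ = training_record_count // ensemble_count
      return min(records_per_occ // batches_per_model, max_occ_batch_size)


    # Invented example: 9,000 unlabeled + 1,000 negative rows, 5 models,
    # 10 batches per model, capped at 50,000 records per batch.
    print(occ_batch_size(9_000, 1_000, 5, 10, 50_000))  # (9000 + 1000) // 5 // 10 = 200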

spade_anomaly_detection/runner_test.py

Lines changed: 16 additions & 6 deletions
@@ -129,6 +129,7 @@ def _create_mock_datasets(self) -> None:
         self.per_class_labeled_example_count * 2
     ) + self.unlabeled_examples
     self.total_test_records = self.per_class_labeled_example_count * 2
+    self.negative_examples = self.per_class_labeled_example_count * 1
 
     unlabeled_features = np.random.rand(
         self.unlabeled_examples, num_features
@@ -203,6 +204,7 @@ def _create_mock_datasets(self) -> None:
       self.mock_get_query_record_result_length.side_effect = [
           self.all_examples,
           self.unlabeled_examples,
+          self.negative_examples,
           self.total_test_records,
       ]
     else:
@@ -213,6 +215,7 @@ def _create_mock_datasets(self) -> None:
       self.mock_get_query_record_result_length.side_effect = [
           self.all_examples,
           self.unlabeled_examples,
+          self.negative_examples,
       ]
 
   def test_runner_data_loader_no_error(self):
@@ -228,8 +231,14 @@ def test_runner_data_loader_no_error(self):
        label_col_name=self.runner_parameters.label_col_name,
        where_statements=self.runner_parameters.where_statements,
        ignore_columns=self.runner_parameters.ignore_columns,
-        label_column_filter_value=self.runner_parameters.unlabeled_data_value,
-        batch_size=self.unlabeled_examples
+        # Verify that both negative and unlabeled samples are used.
+        label_column_filter_value=[
+            self.runner_parameters.unlabeled_data_value,
+            self.runner_parameters.negative_data_value,
+        ],
+        # Verify that batch size is computed with both negative and unlabeled
+        # sample counts.
+        batch_size=(self.unlabeled_examples + self.negative_examples)
        // self.runner_parameters.ensemble_count,
     )
     # Assert that the data loader is also called to fetch all records.
@@ -311,7 +320,7 @@ def test_runner_get_record_count_without_where_statement_no_error(self):
 
   def test_runner_record_count_raise_error(self):
     self.runner_parameters.ensemble_count = 10
-    self.mock_get_query_record_result_length.side_effect = [5, 0]
+    self.mock_get_query_record_result_length.side_effect = [5, 0, 1]
     runner_object = runner.Runner(self.runner_parameters)
 
     with self.assertRaisesRegex(
@@ -320,7 +329,7 @@ def test_runner_record_count_raise_error(self):
       runner_object.run()
 
   def test_runner_no_records_raise_error(self):
-    self.mock_get_query_record_result_length.side_effect = [0, 0]
+    self.mock_get_query_record_result_length.side_effect = [0, 0, 0]
     runner_object = runner.Runner(self.runner_parameters)
 
     with self.assertRaisesRegex(
@@ -340,7 +349,7 @@ def _assert_regex_in(
 
   def test_record_count_warning_raise(self):
     # Will raise a warning when there are < 1k samples in the entire dataset.
-    self.mock_get_query_record_result_length.side_effect = [500, 100]
+    self.mock_get_query_record_result_length.side_effect = [500, 100, 10]
     runner_object = runner.Runner(self.runner_parameters)
 
     with self.assertLogs() as training_logs:
@@ -452,7 +461,7 @@ def test_batch_sizing_no_error(self, mock_split, mock_pseudo_label):
 
   def test_batch_size_too_large_throw_error(self):
     self.runner_parameters.labeling_and_model_training_batch_size = 1000
-    self.mock_get_query_record_result_length.side_effect = [100, 5]
+    self.mock_get_query_record_result_length.side_effect = [100, 5, 10]
     runner_object = runner.Runner(self.runner_parameters)
 
     with self.assertRaisesRegex(
@@ -695,6 +704,7 @@ def test_dataset_label_values_positive_and_negative_throws_error(self):
     self.mock_get_query_record_result_length.side_effect = [
         self.all_examples,
         self.unlabeled_examples,
+        self.negative_examples,
         total_test_records,
     ]

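The test updates follow from call order: run() now issues a third record-count query (for negative records) after the total and unlabeled counts, so every mocked get_query_record_result_length gains one more side_effect entry. A tiny standalone illustration of why the ordering matters, using unittest.mock directly and independent of this test suite:

    from unittest import mock

    loader = mock.Mock()
    # run() call order: total records, then unlabeled records, then negative records.
    loader.get_query_record_result_length.side_effect = [500, 100, 10]

    total_count = loader.get_query_record_result_length()      # 500
    unlabeled_count = loader.get_query_record_result_length()  # 100
    negative_count = loader.get_query_record_result_length()   # 10
    print(total_count, unlabeled_count, negative_count)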