Commit 94305f1

raj-sinha, The spade_anomaly_detection Authors authored and The spade_anomaly_detection Authors committed
Major revision to extend SPADE with new capabilities. Now it is possible to set a voting strategy for the pseudolabeler. It is possible to have a separate number of GMM components per model. The `alpha` weight parameter can now be set separately for positive and negative pseudo-labels. See the updated README for more details.

PiperOrigin-RevId: 718863197
1 parent 0bcf797 commit 94305f1

14 files changed: +394 -115 lines changed

CHANGELOG.md

Lines changed: 11 additions & 1 deletion
@@ -24,6 +24,15 @@ To release a new version (e.g. from `1.0.0` -> `2.0.0`):
 
 ## [Unreleased]
 
+## [0.4.0] - 2025-01-22
+
+* Major revision to extend SPADE with new capabilities. Now it is possible to
+set a voting strategy for the pseudolabeler. It is possible to have a separate
+number of GMM components per model. The `alpha` weight parameter can now be set
+separately for positive and negative pseudolabels.
+* Allow labels to be arbitrary strings.
+* Upgrade Docker base image to Tensorflow 2.17.
+
 ## [0.3.3] - 2024-08-05
 
 * Add support for wildcards in GCS URIs in CSV data loader.
@@ -58,7 +67,8 @@ To release a new version (e.g. from `1.0.0` -> `2.0.0`):
 
 * Initial release
 
-[Unreleased]: https://github.com/google-research/spade_anomaly_detection/compare/v0.3.3...HEAD
+[Unreleased]: https://github.com/google-research/spade_anomaly_detection/compare/v0.4.0...HEAD
+[0.4.0]: https://github.com/google-research/spade_anomaly_detection/compare/v0.3.3...v0.4.0
 [0.3.3]: https://github.com/google-research/spade_anomaly_detection/compare/v0.3.2...v0.3.3
 [0.3.2]: https://github.com/google-research/spade_anomaly_detection/compare/v0.3.1...v0.3.2
 [0.3.1]: https://github.com/google-research/spade_anomaly_detection/compare/v0.3.0...v0.3.1

README.md

Lines changed: 26 additions & 10 deletions
@@ -78,6 +78,8 @@ The metric reported by the pipeline is model [AUC](https://developers.google.com
 
 <span style="color:red;background-color:lightgrey">label_col_name (string)</span>: The name of the label column in the input BigQuery table.
 
+<span style="color:red;background-color:lightgrey">labels_are_strings</span>: Whether the labels in the input dataset are strings or integers.
+
 <span style="color:red;background-color:lightgrey">positive_data_value (integer)</span>: The value used in the label column to denote positive data - data points that are anomalous. “1” can be used, for example.
 
 <span style="color:red;background-color:lightgrey">negative_data_value (integer)</span>: The value used in the label column to denote negative data - data points that are not anomalous. “0” can be used, for example.
@@ -99,17 +101,21 @@ one class classifier ensemble to label a point as negative. The higher this valu
 
 <span style="color:yellow;background-color:lightgrey">data_test_gcs_uri</span>: Cloud Storage location to store the CSV data to be used for evaluating the supervised model. Note that the positive and negative label values must also be the same in this testing set. It is okay to have your test labels in that form, or use 1 for positive and 0 for negative. Use exactly one of BigQuery locations or GCS locations.
 
-<span style="color:yellow;background-color:lightgrey">upload_only</span>: Use this setting in conjunction with `output_bigquery_table_path` or `data_output_gcs_uri`. When `True`, the algorithm will just upload the pseudo labeled data to the specified table, and will skip training a supervised model. When set to `False`, the algorithm will also train a supervised model and upload it to a GCS location. Default is `False`.
+<span style="color:yellow;background-color:lightgrey">upload_only (bool)</span>: Use this setting in conjunction with `output_bigquery_table_path` or `data_output_gcs_uri`. When `True`, the algorithm will just upload the pseudo labeled data to the specified table, and will skip training a supervised model. When set to `False`, the algorithm will also train a supervised model and upload it to a GCS location. Default is `False`.
 
 <span style="color:yellow;background-color:lightgrey">output_bigquery_table_path</span>: A complete BigQuery path in the form of 'project.dataset.table' to be used for uploading the pseudo labeled data. This includes features and new labels. By default, we will use the column names from the input_bigquery_table_path BigQuery table. Use exactly one of BigQuery locations or GCS locations.
 
 <span style="color:yellow;background-color:lightgrey">data_output_gcs_uri</span>: Cloud Storage location used for uploading the pseudo labeled data as CSV. This includes features and new labels. By default, we will use the column names from the data_input_gcs_uri table. Use exactly one of BigQuery locations or GCS locations.
 
-<span style="color:yellow;background-color:lightgrey">alpha (float)</span>: Sample weights for weighting the loss function, only for pseudo-labeled data from the occ ensemble. Original data that is labeled will have a weight of 1.0. By default, we use alpha = 1.0.
+<span style="color:yellow;background-color:lightgrey">voting_strategy (bool)</span>: The voting strategy to use when determining if a data point is anomalous. By default, we use unanimous voting, meaning all the models in the ensemble need to agree in order to label a data point as anomalous.
+
+<span style="color:yellow;background-color:lightgrey">alpha (float)</span>: Sample weights for weighting the loss function, only for positively pseudo-labeled data from the occ ensemble. Original data that is labeled will have a weight of 1.0. If this is provided and `alpha_negative_pseudolabels` is not provided, then this value will be used for both positive and negative pseudo-labeled data. By default, we use alpha = 1.0.
+
+<span style="color:yellow;background-color:lightgrey">alpha_negative_pseudolabels (float)</span>: Sample weights for weighting the loss function, only for negatively pseudo-labeled data from the occ ensemble. Original data that is labeled will have a weight of 1.0. If this is not provided, then the `alpha` value will be used for both positive and negative pseudo-labeled data. By default, we use alpha_negative_pseudolabels = 1.0.
 
 <span style="color:yellow;background-color:lightgrey">ensemble_count</span>: Integer representing the number of one class classifiers in the ensemble used for pseudo labeling unlabeled data points. The more models in the ensemble, the less likely it is for all the models to gain consensus, and thus will reduce the amount of labeled data points. By default, we use 5 one class classifiers.
 
-<span style="color:yellow;background-color:lightgrey">n_components</span>: Integer representing the number of components to use in the one class classifier ensemble. By default, we use 1 component.
+<span style="color:yellow;background-color:lightgrey">n_components</span>: The number of components to use in the one class classifier ensemble. By default, we use 1 component. Pass a single integer if all the ensemble models should have the same number of components. Pass a space-separated list of integers if you want to use different numbers of components for each model in the ensemble. By default, we use 1 component.
 
 <span style="color:yellow;background-color:lightgrey">covariance_type</span>: String representing the covariance type to use in the one class classifier ensemble. By default, we use 'full' covariance. Note that when there are many components, a 'full' covariance matrix may not be suitable.
 
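The fallback between `alpha` and `alpha_negative_pseudolabels` described in the README hunk above can be sketched as follows. This is a minimal illustration, not the project's code: the helper name `resolve_alphas` is hypothetical, and modelling "not provided" as `None` is an assumption (the README also states a default of 1.0 for `alpha_negative_pseudolabels`, so the real pipeline may resolve this differently).

```python
from typing import Optional


def resolve_alphas(
    alpha: float = 1.0,
    alpha_negative_pseudolabels: Optional[float] = None,
) -> tuple[float, float]:
    """Returns (positive_weight, negative_weight) for pseudo-labeled rows.

    Hypothetical helper: `None` stands in for "not provided".
    """
    if alpha_negative_pseudolabels is None:
        # Only `alpha` was given, so it applies to both pseudo-label classes.
        return alpha, alpha
    return alpha, alpha_negative_pseudolabels


both_from_alpha = resolve_alphas(alpha=0.5)
separate = resolve_alphas(alpha=0.5, alpha_negative_pseudolabels=0.1)
```

Here `both_from_alpha` carries the same weight for both classes, while `separate` keeps the two weights distinct. Ground-truth rows are not affected either way, since the README gives them a fixed weight of 1.0.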
@@ -189,20 +195,30 @@ OUTPUT_BIGQUERY_TABLE_PATH=${5:-"${PROJECT_ID}.[bq-dataset].[bq-output-table]"}
 DATA_OUTPUT_GCS_URI=${6:-""}
 OUTPUT_GCS_URI=${7:-"gs://[gcs-bucket]/[model-folder]"}
 LABEL_COL_NAME=${8:-"y"}
-# The label column is of type float, these must match in order for array
+# The label column is of type string, these must match in order for array
 # filtering to work correctly.
 POSITIVE_DATA_VALUE=${9:-"1"}
 NEGATIVE_DATA_VALUE=${10:-"0"}
 UNLABELED_DATA_VALUE=${11:-"-1"}
 POSITIVE_THRESHOLD=${12:-".1"}
 NEGATIVE_THRESHOLD=${13:-"95"}
 TEST_BIGQUERY_TABLE_PATH=${14:-"${PROJECT_ID}.[bq-dataset].[bq-test-table]"}
-DATA_TEST_GCS_URI=${15:-""}
-TEST_LABEL_COL_NAME=${16:-"y"}
-ALPHA=${17:-"1.0"}
-ENSEMBLE_COUNT=${19:-"5"}
-VERBOSE=${22:-"True"}
-UPLOAD_ONLY=${23:-"False"}
+TEST_DATASET_HOLDOUT_FRACTION=${15:-"0"}
+DATA_TEST_GCS_URI=${16:-""}
+TEST_LABEL_COL_NAME=${17:-"y"}
+VOTING_STRATEGY=${18:-"UNANIMOUS"}
+ALPHA=${19:-"1.0"}
+ALPHA_NEGATIVE_PSEUDOLABELS=${20:-"1.0"}
+BATCHES_PER_MODEL=${21:-"1"}
+ENSEMBLE_COUNT=${22:-"5"}
+N_COMPONENTS=${23:-"1"}
+# N_COMPONENTS=${23:-"1,3,5,7,9"}
+COVARIANCE_TYPE=${24:-"full"}
+MAX_OCC_BATCH_SIZE=${25:-"50000"}
+LABELING_AND_MODEL_TRAINING_BATCH_SIZE=${26:-"100000"}
+LABELS_ARE_STRINGS=${27:-"True"}
+VERBOSE=${28:-"True"}
+UPLOAD_ONLY=${29:-"False"}
 
 IMAGE_URI="us-docker.pkg.dev/[project_id]/spade-anomaly-detection/spade:latest"
 
spade_anomaly_detection/Dockerfile

Lines changed: 1 addition & 1 deletion
@@ -1,5 +1,5 @@
 # Google Cloud's optimized Tensorflow image
-FROM gcr.io/deeplearning-platform-release/tf-cpu.2-10
+FROM us-docker.pkg.dev/deeplearning-platform-release/gcr.io/tf2-cpu.2-17.py310
 
 # Alternative Tensorflow image with GPU.
 # FROM gcr.io/deeplearning-platform-release/tf-gpu.2-10:latest

spade_anomaly_detection/__init__.py

Lines changed: 1 addition & 1 deletion
@@ -31,4 +31,4 @@
 
 # A new PyPI release will be pushed every time `__version__` is increased.
 # When changing this, also update the CHANGELOG.md.
-__version__ = '0.3.3'
+__version__ = '0.4.0'

spade_anomaly_detection/data_loader_test.py

Lines changed: 1 addition & 0 deletions
@@ -73,6 +73,7 @@ def setUp(self):
         output_bigquery_table_path='',
         data_output_gcs_uri='',
         alpha=1.0,
+        alpha_negative_pseudolabels=1.0,
         batches_per_model=1,
         labeling_and_model_training_batch_size=None,
         ensemble_count=5,

spade_anomaly_detection/occ_ensemble.py

Lines changed: 85 additions & 20 deletions
@@ -32,14 +32,34 @@
 # Using typing instead of collections due to Vertex training containers
 # not supporting them.
 
+import dataclasses
 from typing import Final, MutableMapping, Optional, Sequence
 
 from absl import logging
 import numpy as np
 from sklearn import mixture
+from spade_anomaly_detection import parameters
 import tensorflow as tf
 
 
+@dataclasses.dataclass
+class PseudolabelsContainer:
+  """Container to hold the outputs of the pseudolabeling process.
+
+  Attributes:
+    new_features: np.ndarray of features for the new pseudolabeled data.
+    new_labels: np.ndarray of labels for the new pseudolabeled data.
+    weights: np.ndarray of weights for the new pseudolabeled data.
+    pseudolabel_flags: np.ndarray of flags indicating whether the data point is
+      ground truth or pseudolabeled.
+  """
+
+  new_features: np.ndarray
+  new_labels: np.ndarray
+  weights: np.ndarray
+  pseudolabel_flags: np.ndarray
+
+
 _RANDOM_SEED: Final[int] = 42
 _SHUFFLE_BUFFER_SIZE: Final[int] = 10_000
 _LABEL_TYPE: Final[str] = 'INT64'
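The hunk above replaces a positional return value with a named container. A rough sketch of why that helps callers, using a hypothetical stand-in class `LabeledBatch` with plain lists instead of `np.ndarray` so that it runs with the standard library only (the field names mirror the diff, everything else is illustrative):

```python
import dataclasses


@dataclasses.dataclass
class LabeledBatch:
    """Hypothetical stand-in for PseudolabelsContainer, using plain lists."""

    new_features: list[list[float]]
    new_labels: list[int]
    weights: list[float]
    pseudolabel_flags: list[int]


batch = LabeledBatch(
    new_features=[[0.1, 0.2], [0.9, 0.8]],
    new_labels=[0, 1],
    weights=[1.0, 0.5],
    pseudolabel_flags=[0, 1],  # 1 = pseudo-labeled, 0 = ground truth
)

# Fields are read by name, so a field added later does not break callers
# the way positional tuple unpacking would.
pseudolabeled_weights = [
    weight
    for weight, flag in zip(batch.weights, batch.pseudolabel_flags)
    if flag == 1
]
```

Any caller that previously unpacked four values from `pseudo_label` would need to switch to attribute access after this commit.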
@@ -76,6 +96,9 @@ class GmmEnsemble:
       precision when raising this value, and an increase in recall when lowering
       it. Equavalent to saying the given data point needs to be X percentile or
       greater in order to be considered anomalous.
+    voting_strategy: The voting strategy to use when determining if a data point
+      is anomalous. By default, we use unanimous voting, meaning all the models
+      in the ensemble need to agree in order to label a data point as anomalous.
     unlabeled_record_count: The number of unlabeled records in the dataset.
     negative_record_count: The number of negative records in the dataset.
     unlabeled_data_value: The value used in the label column to denote unlabeled
@@ -90,13 +113,16 @@ class GmmEnsemble:
   # TODO(b/247116870): Create dataclass when another OCC is added.
   def __init__(
       self,
-      n_components: int = 1,
+      n_components: tuple[int, ...] = (1,),
       covariance_type: str = 'full',
       init_params: str = 'kmeans',
       max_iter: int = 100,
       ensemble_count: int = 5,
       positive_threshold: float = 1.0,
       negative_threshold: float = 95.0,
+      voting_strategy: parameters.VotingStrategy = (
+          parameters.VotingStrategy.UNANIMOUS
+      ),
       random_seed: int = _RANDOM_SEED,
       unlabeled_record_count: int | None = None,
       negative_record_count: int | None = None,
@@ -111,6 +137,7 @@ def __init__(
     self.ensemble_count = ensemble_count
     self.positive_threshold = positive_threshold
     self.negative_threshold = negative_threshold
+    self.voting_strategy = voting_strategy
     self._random_seed = random_seed
     self.unlabeled_record_count = unlabeled_record_count
     self.negative_record_count = negative_record_count
@@ -122,14 +149,21 @@ def __init__(
 
     self._warm_start = False
 
-  def _get_model(self) -> mixture.GaussianMixture:
+  def _get_model(self, idx: int) -> mixture.GaussianMixture:
     """Instantiates a Gaussian mixture model.
 
+    Args:
+      idx: The index of the model in the ensemble.
+
     Returns:
       Gaussian mixture model with class attributes.
     """
     return mixture.GaussianMixture(
-        n_components=self.n_components,
+        n_components=(
+            self.n_components[idx]
+            if len(self.n_components) == self.ensemble_count
+            else self.n_components[0]
+        ),
         covariance_type=self.covariance_type,
         init_params=self.init_params,
         warm_start=self._warm_start,
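The selection rule inside `_get_model` above can be isolated as a small pure function. This is a sketch that mirrors the conditional in the hunk; the function name `components_for_model` is hypothetical and no scikit-learn model is constructed here.

```python
def components_for_model(
    n_components: tuple[int, ...], ensemble_count: int, idx: int
) -> int:
    """Picks the number of GMM components for ensemble member `idx`.

    Mirrors the conditional in the diff: a per-model value is used only when
    the tuple has exactly one entry per ensemble member; otherwise the first
    entry is used for every model.
    """
    if len(n_components) == ensemble_count:
        return n_components[idx]
    return n_components[0]


per_model = [components_for_model((1, 3, 5, 7, 9), 5, i) for i in range(5)]
shared = [components_for_model((2,), 5, i) for i in range(5)]
mismatched = [components_for_model((1, 3), 5, i) for i in range(5)]
```

Note what the `mismatched` case implies: when the tuple length differs from `ensemble_count`, the rule as written falls back to the first entry for every model rather than raising an error, so a list of the wrong length is silently treated like a single value.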
@@ -249,8 +283,8 @@ def fit(
     )
     dataset_iterator = ds_batched.as_numpy_iterator()
 
-    for _ in range(self.ensemble_count):
-      model = self._get_model()
+    for idx in range(self.ensemble_count):
+      model = self._get_model(idx=idx)
 
       for _ in range(batches_per_occ):
         features, _ = dataset_iterator.next()
@@ -269,6 +303,26 @@ def fit(
 
     return self.ensemble
 
+  def _vote(self, model_scores: np.ndarray) -> np.ndarray:
+    """Votes on whether a data point is anomalous or not.
+
+    Args:
+      model_scores: The scores for each model in the ensemble for a given data
+        point. Can be the positive score or the negative score.
+
+    Returns:
+      True if the data point is anomalous, False otherwise.
+    """
+    if self.voting_strategy == parameters.VotingStrategy.UNANIMOUS:
+      return model_scores == self.ensemble_count
+    elif self.voting_strategy == parameters.VotingStrategy.MAJORITY:
+      return model_scores > self.ensemble_count // 2
+    else:
+      raise ValueError(
+          f'Unsupported voting strategy: {self.voting_strategy}. Supported'
+          ' strategies are UNANIMOUS and MAJORITY.'
+      )
+
   def _score_unlabeled_data(
       self,
       unlabeled_features: np.ndarray,
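To make the two voting rules in `_vote` concrete, here is a standard-library sketch that applies the same comparisons to a list of per-point vote counts. The `VotingStrategy` enum below is an assumed stand-in, since the definition of `parameters.VotingStrategy` is not part of the hunks shown, and plain lists replace the NumPy arrays used in the diff.

```python
import enum


class VotingStrategy(enum.Enum):
    """Assumed stand-in for parameters.VotingStrategy (not shown in the diff)."""

    UNANIMOUS = "UNANIMOUS"
    MAJORITY = "MAJORITY"


def vote(
    model_scores: list[int], ensemble_count: int, strategy: VotingStrategy
) -> list[bool]:
    """Applies the diff's voting comparisons to per-point vote counts."""
    if strategy == VotingStrategy.UNANIMOUS:
        # Every model in the ensemble must have voted for the label.
        return [score == ensemble_count for score in model_scores]
    if strategy == VotingStrategy.MAJORITY:
        # Strictly more than half of the models must have voted for the label.
        return [score > ensemble_count // 2 for score in model_scores]
    raise ValueError(f"Unsupported voting strategy: {strategy}")


counts = [5, 3, 2, 0]  # votes received by four data points from 5 models
unanimous = vote(counts, 5, VotingStrategy.UNANIMOUS)
majority = vote(counts, 5, VotingStrategy.MAJORITY)
even_split = vote([2, 3], 4, VotingStrategy.MAJORITY)
```

With five models, majority voting accepts a point with three or more votes. With an even ensemble such as four models, `ensemble_count // 2` is 2, so a 2-2 tie does not count as a majority.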
@@ -310,8 +364,8 @@ def _score_unlabeled_data(
       model_scores_pos += binary_scores_pos
       model_scores_neg += binary_scores_neg
 
-    positive_indices = np.where(model_scores_pos == self.ensemble_count)[0]
-    negative_indices = np.where(model_scores_neg == self.ensemble_count)[0]
+    positive_indices = np.where(self._vote(model_scores_pos))[0]
+    negative_indices = np.where(self._vote(model_scores_neg))[0]
 
     return {
         'positive_indices': positive_indices,
@@ -326,8 +380,9 @@ def pseudo_label(
       negative_data_value: str | int | None,
       unlabeled_data_value: str | int,
       alpha: float = 1.0,
+      alpha_negative_pseudolabels: float = 1.0,
       verbose: Optional[bool] = False,
-  ) -> Sequence[np.ndarray]:
+  ) -> PseudolabelsContainer:
     """Labels unlabeled data using the trained ensemble of OCCs.
 
     Args:
@@ -341,16 +396,18 @@ def pseudo_label(
         data - data points that are not anomalous.
       unlabeled_data_value: The value used in the label column to denote
         unlabeled data.
-      alpha: This value is used to adjust the influence of the pseudo labeled
-        data in training the supervised model.
+      alpha: This value is used to adjust the influence of the positively pseudo
+        labeled data in training the supervised model.
+      alpha_negative_pseudolabels: This value is used to adjust the influence of
+        the negatively pseudo labeled data in training the supervised model.
       verbose: Chooses the amount of logging info to display. This can be useful
         when debugging model performance.
 
     Returns:
-      A sequence including updated features (features for which we now have
+      A container including updated features (features for which we now have
       labels for), updated labels (includes pseudo labeled positive and negative
       values, as well as ground truth), the weights (correct alpha values)
-      for the new pseudo labeled data points, and a binary flag that indicates
+      for the new pseudo labeled data points, a binary flag that indicates
       whether the data point is newly pseudolabeled, or ground truth. Labels are
       in the format of 1 for positive and 0 for negative. Flag is 1 for
       pseudo-labeled and 0 for ground truth.
@@ -390,13 +447,15 @@ def pseudo_label(
         ],
         axis=0,
     )
-    weights = np.concatenate([
-        np.repeat(alpha, len(new_positive_indices)),
-        np.repeat(alpha, len(new_negative_indices)),
-        np.ones([len(original_positive_idx)]),
-        np.ones([len(original_negative_idx)])
-    ],
-                             axis=0)
+    weights = np.concatenate(
+        [
+            np.repeat(alpha, len(new_positive_indices)),
+            np.repeat(alpha_negative_pseudolabels, len(new_negative_indices)),
+            np.ones([len(original_positive_idx)]),
+            np.ones([len(original_negative_idx)]),
+        ],
+        axis=0,
+    )
     pseudolabel_flags = np.concatenate(
         [
             np.ones(len(new_positive_indices)),
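The weight layout built by the new `np.concatenate` call can be sketched with plain lists. The helper name `build_weights` and its count-based arguments are hypothetical; the ordering (new positives, new negatives, original positives, original negatives) and the fixed weight of 1.0 for ground-truth rows follow the hunk above.

```python
def build_weights(
    n_new_positive: int,
    n_new_negative: int,
    n_original_positive: int,
    n_original_negative: int,
    alpha: float = 1.0,
    alpha_negative_pseudolabels: float = 1.0,
) -> list[float]:
    """Builds per-row sample weights in the same order as the diff."""
    return (
        [alpha] * n_new_positive  # positively pseudo-labeled rows
        + [alpha_negative_pseudolabels] * n_new_negative  # negatively pseudo-labeled
        + [1.0] * n_original_positive  # ground-truth positives
        + [1.0] * n_original_negative  # ground-truth negatives
    )


weights = build_weights(
    n_new_positive=2,
    n_new_negative=1,
    n_original_positive=1,
    n_original_negative=2,
    alpha=0.5,
    alpha_negative_pseudolabels=0.1,
)
```

Before this commit, the second segment also used `alpha`, so the two pseudo-label classes could not be weighted independently.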
@@ -412,4 +471,10 @@ def pseudo_label(
                    len(new_positive_indices))
       logging.info('Number of new negative labels: %s',
                    len(new_negative_indices))
-    return new_features, new_labels, weights, pseudolabel_flags
+
+    return PseudolabelsContainer(
+        new_features=new_features,
+        new_labels=new_labels,
+        weights=weights,
+        pseudolabel_flags=pseudolabel_flags,
+    )
