From 84cbfe7e8cece5a9ba13d3bce4d009be379b3b0d Mon Sep 17 00:00:00 2001
From: Raj Sinha
Date: Sun, 14 Jul 2024 03:14:12 +0000
Subject: [PATCH] Write out the pseudolabel weights and a flag that indicates
 whether a sample has a ground truth label (0) or a pseudolabel (1).

PiperOrigin-RevId: 652155894
---
 CHANGELOG.md                                 |   7 +-
 pyproject.toml                               |   2 +-
 spade_anomaly_detection/__init__.py          |   2 +-
 spade_anomaly_detection/csv_data_loader.py   |  43 ++++++-
 .../csv_data_loader_test.py                  |  57 ++++++++-
 spade_anomaly_detection/data_loader.py       |  42 ++++++-
 spade_anomaly_detection/data_loader_test.py  | 111 ++++++++++++++++++
 spade_anomaly_detection/occ_ensemble.py      |  27 +++--
 spade_anomaly_detection/occ_ensemble_test.py |  36 ++++--
 spade_anomaly_detection/runner.py            |  25 ++--
 spade_anomaly_detection/runner_test.py       |   4 +
 11 files changed, 315 insertions(+), 41 deletions(-)

diff --git a/CHANGELOG.md b/CHANGELOG.md
index 10236fa..47a8bac 100644
--- a/CHANGELOG.md
+++ b/CHANGELOG.md
@@ -24,6 +24,10 @@ To release a new version (e.g. from `1.0.0` -> `2.0.0`):
 
 ## [Unreleased]
 
+## [0.3.1] - 2024-07-13
+
+* Now writes out the pseudolabel weights and a flag that indicates whether a sample has a ground truth label (0) or a pseudolabel (1).
+
 ## [0.3.0] - 2024-07-10
 
 * Add the ability to use CSV files on GCS as data input/output/test sources.
@@ -45,7 +49,8 @@ To release a new version (e.g. from `1.0.0` -> `2.0.0`):
 
 * Initial release
 
-[Unreleased]: https://github.com/google-research/spade_anomaly_detection/compare/v0.3.0...HEAD
+[Unreleased]: https://github.com/google-research/spade_anomaly_detection/compare/v0.3.1...HEAD
+[0.3.1]: https://github.com/google-research/spade_anomaly_detection/compare/v0.3.0...v0.3.1
 [0.3.0]: https://github.com/google-research/spade_anomaly_detection/compare/v0.2.2...v0.3.0
 [0.2.2]: https://github.com/google-research/spade_anomaly_detection/compare/v0.2.1...v0.2.2
 [0.2.1]: https://github.com/google-research/spade_anomaly_detection/compare/v0.2.0...v0.2.1
diff --git a/pyproject.toml b/pyproject.toml
index 5923319..094ddf7 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -23,7 +23,7 @@ dependencies = [
   "pyarrow==14.0.1",
   "retry==0.9.2",
   "scikit-learn==1.4.2",
-  "tensorflow",
+  "tensorflow==2.12.1",
   "tensorflow-datasets==4.9.6",
   "parameterized==0.8.1",
   "pytest==7.1.2",
diff --git a/spade_anomaly_detection/__init__.py b/spade_anomaly_detection/__init__.py
index f9d8ac9..f938372 100644
--- a/spade_anomaly_detection/__init__.py
+++ b/spade_anomaly_detection/__init__.py
@@ -31,4 +31,4 @@
 
 # A new PyPI release will be pushed every time `__version__` is increased.
 # When changing this, also update the CHANGELOG.md.
-__version__ = '0.3.0'
+__version__ = '0.3.1'
diff --git a/spade_anomaly_detection/csv_data_loader.py b/spade_anomaly_detection/csv_data_loader.py
index bfee840..4055966 100644
--- a/spade_anomaly_detection/csv_data_loader.py
+++ b/spade_anomaly_detection/csv_data_loader.py
@@ -38,6 +38,7 @@
 from google.cloud import storage
 import numpy as np
 import pandas as pd
+from spade_anomaly_detection import data_loader
 from spade_anomaly_detection import parameters
 import tensorflow as tf
 
@@ -489,6 +490,8 @@ def upload_dataframe_to_gcs(
       batch: int,
       features: np.ndarray,
       labels: np.ndarray,
+      weights: Optional[np.ndarray] = None,
+      pseudolabel_flags: Optional[np.ndarray] = None,
   ) -> None:
     """Uploads the dataframe to GCS, creating or replacing the CSV file.
 
     Args:
       batch: The batch number of the pseudo-labeled data.
       features: Numpy array of features.
      labels: Numpy array of labels.
+      weights: Optional numpy array of weights.
+      pseudolabel_flags: Optional numpy array of pseudolabel flags.
 
     Returns:
       None.
@@ -515,15 +520,37 @@ def upload_dataframe_to_gcs(
           'Data output GCS URI is not set in the runner parameters. Please set '
           'the `data_output_gcs_uri` field in the runner parameters.'
       )
-    combined_data = np.concatenate(
-        [features, labels.reshape(len(features), 1)], axis=1
-    )
+    combined_data = features
     column_names = list(
         self._last_read_metadata.column_names_info.column_names_dict.keys()
     )
+
+    # If the weights are provided, add them to the column names and to the
+    # combined data.
+    if weights is not None:
+      column_names.append(data_loader.WEIGHT_COLUMN_NAME)
+      combined_data = np.concatenate(
+          [combined_data, weights.reshape(len(features), 1).astype(np.float64)],
+          axis=1,
+      )
+
+    # If the pseudolabel flags are provided, add them to the column names and
+    # to the combined data.
+    if pseudolabel_flags is not None:
+      column_names.append(data_loader.PSEUDOLABEL_FLAG_COLUMN_NAME)
+      combined_data = np.concatenate(
+          [
+              combined_data,
+              pseudolabel_flags.reshape(len(features), 1).astype(np.int64),
+          ],
+          axis=1,
+      )
+
     # Make sure the label column is the last column.
-    # TODO(b/347332980): Add support for the pseudolabel flag.
+    combined_data = np.concatenate(
+        [combined_data, labels.reshape(len(features), 1)], axis=1
+    )
     column_names.remove(self.runner_parameters.label_col_name)
     column_names.append(self.runner_parameters.label_col_name)
@@ -536,6 +563,14 @@ def upload_dataframe_to_gcs(
       complete_dataframe[self.runner_parameters.label_col_name].astype('bool')
     )
 
+    # Adjust pseudolabel flag column type.
+    if pseudolabel_flags is not None:
+      complete_dataframe[data_loader.PSEUDOLABEL_FLAG_COLUMN_NAME] = (
+          complete_dataframe[data_loader.PSEUDOLABEL_FLAG_COLUMN_NAME].astype(
+              np.int64
+          )
+      )
+
     output_path = os.path.join(
         self.runner_parameters.data_output_gcs_uri,
         f'pseudo_labeled_batch_{batch}.csv',
diff --git a/spade_anomaly_detection/csv_data_loader_test.py b/spade_anomaly_detection/csv_data_loader_test.py
index e2b4a93..10f53d5 100644
--- a/spade_anomaly_detection/csv_data_loader_test.py
+++ b/spade_anomaly_detection/csv_data_loader_test.py
@@ -385,30 +385,75 @@ def test_upload_dataframe_to_gcs(self):
     all_features = self.data_df[["x1", "x2"]].to_numpy()
     all_labels = self.data_df["y"].to_numpy()
     # Create 2 batches of features and labels.
-    # TODO(b/347332980): Update test when pseudolabel flag is added.
     features1 = all_features[0:2]
     labels1 = all_labels[0:2]
+    # Add weights and flags to the first batch. These are pseudolabeled samples.
+    weights1 = (
+        np.repeat([0.1], len(features1))
+        .reshape(len(features1), 1)
+        .astype(np.float64)
+    )
+    flags1 = (
+        np.repeat([1], len(features1))
+        .reshape(len(features1), 1)
+        .astype(np.int64)
+    )
+    # Add weights and flags to the second batch. These are ground truth samples.
     features2 = all_features[2:]
     labels2 = all_labels[2:]
-    # Upload batch 1.
+    weights2 = (
+        np.repeat([1.0], len(features2))
+        .reshape(len(features2), 1)
+        .astype(np.float64)
+    )
+    flags2 = (
+        np.repeat([0], len(features2))
+        .reshape(len(features2), 1)
+        .astype(np.int64)
+    )
+    # Upload batch 1.
     data_loader.upload_dataframe_to_gcs(
         batch=1,
         features=features1,
         labels=labels1,
+        weights=weights1,
+        pseudolabel_flags=flags1,
     )
     # Upload batch 2.
     data_loader.upload_dataframe_to_gcs(
         batch=2,
         features=features2,
         labels=labels2,
+        weights=weights2,
+        pseudolabel_flags=flags2,
     )
     # Sorting means batch 1 file will be first.
files_list = sorted(tf.io.gfile.listdir(output_dir)) self.assertLen(files_list, 2) - expected_dfs = [ - self.data_df.iloc[0:2].reset_index(drop=True), - self.data_df.iloc[2:].reset_index(drop=True), - ] + col_names = ["x1", "x2", "alpha", "is_pseudolabel", "y"] + expected_df1 = pd.concat( + [ + self.data_df.iloc[0:2, 0:-1].reset_index(drop=True), + pd.DataFrame(weights1, columns=["alpha"]), + pd.DataFrame(flags1, columns=["is_pseudolabel"]), + self.data_df.iloc[0:2, -1].reset_index(drop=True), + ], + names=col_names, + ignore_index=True, + axis=1, + ) + expected_df1.columns = col_names + expected_df2 = pd.concat( + [ + self.data_df.iloc[2:, 0:-1].reset_index(drop=True), + pd.DataFrame(weights2, columns=["alpha"]), + pd.DataFrame(flags2, columns=["is_pseudolabel"]), + self.data_df.iloc[2:, -1].reset_index(drop=True), + ], + ignore_index=True, + axis=1, + ) + expected_df2.columns = col_names + expected_dfs = [expected_df1, expected_df2] for i, file_name in enumerate(files_list): with self.subTest(msg=f"file_{i}"): file_path = os.path.join(output_dir, file_name) diff --git a/spade_anomaly_detection/data_loader.py b/spade_anomaly_detection/data_loader.py index 68299a5..c269894 100644 --- a/spade_anomaly_detection/data_loader.py +++ b/spade_anomaly_detection/data_loader.py @@ -54,6 +54,9 @@ _DATA_ROOT: Final[str] = 'spade_anomaly_detection/example_data/' +WEIGHT_COLUMN_NAME: Final[str] = 'alpha' +PSEUDOLABEL_FLAG_COLUMN_NAME: Final[str] = 'is_pseudolabel' + def load_dataframe( dataset_name: str, @@ -691,12 +694,19 @@ def upload_dataframe_as_bigquery_table( self, features: np.ndarray, labels: np.ndarray, + weights: Optional[np.ndarray] = None, + pseudolabel_flags: Optional[np.ndarray] = None, ) -> None: """Uploads the dataframe to BigQuery, create or replace table. Args: features: Numpy array of features. labels: Numpy array of labels. + weights: Optional numpy array of weights. + pseudolabel_flags: Optional numpy array of pseudolabel flags. + + Raises: + ValueError: If the metadata has not been read yet. """ if not self.input_feature_metadata: raise ValueError( @@ -705,11 +715,31 @@ def upload_dataframe_as_bigquery_table( 'load_tf_dataset_from_bigquery() before this method ' 'is called.' ) - combined_data = np.concatenate( - [features, labels.reshape(len(features), 1)], axis=1 - ) + combined_data = features + # Get the list of feature and label column names. column_names = list(self.input_feature_metadata.names) + + # If the weights are provided, add them to the column names and to the + # combined data. + if weights is not None: + column_names.append(WEIGHT_COLUMN_NAME) + combined_data = np.concatenate( + [combined_data, weights.reshape(len(features), 1)], axis=1 + ) + + # If the pseudolabel flags are provided, add them to the column names and + # to the combined data. + if pseudolabel_flags is not None: + column_names.append(PSEUDOLABEL_FLAG_COLUMN_NAME) + combined_data = np.concatenate( + [combined_data, pseudolabel_flags.reshape(len(features), 1)], axis=1 + ) + + # Make sure the label column is the last column. + combined_data = np.concatenate( + [combined_data, labels.reshape(len(features), 1)], axis=1 + ) column_names.remove(self.runner_parameters.label_col_name) column_names.append(self.runner_parameters.label_col_name) @@ -722,6 +752,12 @@ def upload_dataframe_as_bigquery_table( complete_dataframe[self.runner_parameters.label_col_name].astype('bool') ) + # Adjust pseudolabel flag column type. 
+    if pseudolabel_flags is not None:
+      complete_dataframe[PSEUDOLABEL_FLAG_COLUMN_NAME] = complete_dataframe[
+          PSEUDOLABEL_FLAG_COLUMN_NAME
+      ].astype(np.int64)
+
     with bigquery.Client(
         project=self.table_parts.project_id
     ) as big_query_client:
diff --git a/spade_anomaly_detection/data_loader_test.py b/spade_anomaly_detection/data_loader_test.py
index 3ca3b30..1b9ea11 100644
--- a/spade_anomaly_detection/data_loader_test.py
+++ b/spade_anomaly_detection/data_loader_test.py
@@ -536,6 +536,117 @@ def test_bigquery_table_upload_throw_error_metadata(self):
           features=features, labels=labels
       )
 
+  @mock.patch.object(bigquery, 'LoadJobConfig', autospec=True)
+  def test_upload_dataframe_with_wts_flags_as_bigquery_table_no_error(
+      self, mock_bqclient_loadjobconfig
+  ):
+    self.runner_parameters.output_bigquery_table_path = (
+        'project.dataset.pseudo_labeled_data'
+    )
+    data_loader_object = data_loader.DataLoader(self.runner_parameters)
+    feature_column_names = [
+        'x1',
+        'x2',
+        data_loader.WEIGHT_COLUMN_NAME,
+        data_loader.PSEUDOLABEL_FLAG_COLUMN_NAME,
+        self.runner_parameters.label_col_name,
+    ]
+
+    features = np.random.rand(10, 2).astype(np.float32)
+    labels = np.repeat(0, 10).reshape(10, 1).astype(np.int8)
+    # Two possible values for weight (alpha), repeated 10/2 = 5 times each.
+    weights = np.repeat([0.1, 1.0], 5).reshape(10, 1).astype(np.float32)
+    # The corresponding pseudolabel flags are 1 (pseudolabel) then 0 (ground
+    # truth), repeated 5 times each.
+    flags = np.repeat([1, 0], 5).reshape(10, 1).astype(np.int8)
+
+    tf_dataset_instance_mock = mock.create_autospec(
+        tf.data.Dataset, instance=True
+    )
+
+    feature1_metadata = feature_metadata.FeatureMetadata('x1', 0, 'FLOAT64')
+    feature2_metadata = feature_metadata.FeatureMetadata('x2', 0, 'FLOAT64')
+    label_metadata = feature_metadata.FeatureMetadata(
+        self.runner_parameters.label_col_name, 1, 'INT64'
+    )
+    metadata_container = feature_metadata.FeatureMetadataContainer(
+        [feature1_metadata, feature2_metadata, label_metadata]
+    )
+
+    self.mock_bq_dataset.return_value = (
+        tf_dataset_instance_mock,
+        metadata_container,
+    )
+
+    # Perform this call so that FeatureMetadata is set.
+    data_loader_object.load_tf_dataset_from_bigquery(
+        input_path=self.runner_parameters.input_bigquery_table_path,
+        label_col_name=self.runner_parameters.label_col_name,
+        batch_size=self.batch_size,
+    )
+
+    data_loader_object.upload_dataframe_as_bigquery_table(
+        features=features,
+        labels=labels,
+        weights=weights,
+        pseudolabel_flags=flags,
+    )
+    job_config_object = mock_bqclient_loadjobconfig.return_value
+
+    load_table_mock_kwargs = (
+        self.mock_bq_client.return_value.__enter__.return_value.load_table_from_dataframe.call_args.kwargs
+    )
+
+    with self.subTest(name='LabelColumnCorrect'):
+      self.assertListEqual(
+          list(
+              load_table_mock_kwargs['dataframe'][
+                  self.runner_parameters.label_col_name
+              ]
+          ),
+          list(labels),
+      )
+
+    with self.subTest(name='LabelColumnDataTypeBool'):
+      self.assertEqual(
+          load_table_mock_kwargs['dataframe'][
+              self.runner_parameters.label_col_name
+          ].dtype,
+          bool,
+      )
+
+    with self.subTest(name='WeightsColumnCorrect'):
+      self.assertListEqual(
+          list(
+              load_table_mock_kwargs['dataframe'][
+                  data_loader.WEIGHT_COLUMN_NAME
+              ]
+          ),
+          list(weights),
+      )
+
+    with self.subTest(name='PseudolabelFlagsColumnCorrect'):
+      self.assertListEqual(
+          list(
+              load_table_mock_kwargs['dataframe'][
+                  data_loader.PSEUDOLABEL_FLAG_COLUMN_NAME
+              ]
+          ),
+          list(flags),
+      )
+
+    with self.subTest(name='EqualColumnNames'):
+      self.assertListEqual(
+          feature_column_names,
+          list(load_table_mock_kwargs['dataframe'].columns),
+      )
+    with self.subTest(name='EqualDestinationPath'):
+      self.assertEqual(
+          self.runner_parameters.output_bigquery_table_path,
+          load_table_mock_kwargs['destination'],
+      )
+    with self.subTest(name='EqualJobConfig'):
+      self.assertEqual(job_config_object, load_table_mock_kwargs['job_config'])
+
   def test_get_label_thresholds_no_error(self):
     mock_query_return_dictionary = {
         self.runner_parameters.label_col_name: [
diff --git a/spade_anomaly_detection/occ_ensemble.py b/spade_anomaly_detection/occ_ensemble.py
index 61f7028..ce59be7 100644
--- a/spade_anomaly_detection/occ_ensemble.py
+++ b/spade_anomaly_detection/occ_ensemble.py
@@ -224,10 +224,10 @@ def pseudo_label(self,
         scaler.
       labels: A numpy array of labels. Strings or integers can be used for
         denoting the positive, negative, or unlabeled features.
-      positive_data_value: The value used in the label column to denote
-        positive data - data points that are anomalous.
-      negative_data_value: The value used in the label column to denote
-        negative data - data points that are not anomalous.
+      positive_data_value: The value used in the label column to denote positive
+        data - data points that are anomalous.
+      negative_data_value: The value used in the label column to denote negative
+        data - data points that are not anomalous.
       unlabeled_data_value: The value used in the label column to denote
         unlabeled data.
       alpha: This value is used to adjust the influence of the pseudo labeled
@@ -238,9 +238,11 @@ def pseudo_label(self,
     Returns:
       A sequence including updated features (features for which we now have
      labels), updated labels (includes pseudo labeled positive and negative
-      values, as well as ground truth), and the weights (correct alpha values)
-      for the new pseudo labeled data points. Labels are in the format of 1 for
-      positive and 0 for negative.
+      values, as well as ground truth), the weights (correct alpha values)
+      for the new pseudo labeled data points, and a binary flag that indicates
+      whether the data point is newly pseudolabeled or ground truth.
Labels are + in the format of 1 for positive and 0 for negative. Flag is 1 for + pseudo-labeled and 0 for ground truth. """ original_positive_idx = np.where(labels == positive_data_value)[0] original_negative_idx = np.where(labels == negative_data_value)[0] @@ -282,6 +284,15 @@ def pseudo_label(self, np.ones([len(original_negative_idx)]) ], axis=0) + pseudolabel_flags = np.concatenate( + [ + np.ones(len(new_positive_indices)), + np.ones(len(new_negative_indices)), + np.zeros(len(original_positive_idx)), + np.zeros(len(original_negative_idx)), + ], + axis=0, + ) if verbose: logging.info('Number of new positive labels: %s', @@ -289,4 +300,4 @@ def pseudo_label(self, logging.info('Number of new negative labels: %s', len(new_negative_indices)) - return new_features, new_labels, weights + return new_features, new_labels, weights, pseudolabel_flags diff --git a/spade_anomaly_detection/occ_ensemble_test.py b/spade_anomaly_detection/occ_ensemble_test.py index 9450db8..0baf11e 100644 --- a/spade_anomaly_detection/occ_ensemble_test.py +++ b/spade_anomaly_detection/occ_ensemble_test.py @@ -120,13 +120,15 @@ def test_score_unlabeled_data_no_error(self): np.where((labels == 0) | (labels == 1))[0] ) - updated_features, updated_labels, weights = ensemble_obj.pseudo_label( - features=features, - labels=labels, - alpha=alpha, - positive_data_value=positive_data_value, - negative_data_value=negative_data_value, - unlabeled_data_value=unlabeled_data_value, + updated_features, updated_labels, weights, pseudolabel_flags = ( + ensemble_obj.pseudo_label( + features=features, + labels=labels, + alpha=alpha, + positive_data_value=positive_data_value, + negative_data_value=negative_data_value, + unlabeled_data_value=unlabeled_data_value, + ) ) label_count_after_labeling = len( @@ -155,10 +157,28 @@ def test_score_unlabeled_data_no_error(self): msg='Label count after labeling was not more than before.', ) + with self.subTest(name='AlphaValuesCorrespondToPseudoLabels'): + # Note that this test will fail if the alpha value is 1.0 (the ground + # truth weight). + weights_are_alpha = np.where(weights == alpha)[0] + pseudolabel_flags_are_1 = np.where(pseudolabel_flags == 1)[0] + self.assertNDArrayNear( + weights_are_alpha, + pseudolabel_flags_are_1, + err=1e-6, + msg=( + 'The data samples where the weights are equal to the alpha ' + 'values are not the same as the samples where the pseudolabel ' + 'flags are equal to 1.' 
+ ), + ) + with self.subTest(name='LabelFeatureArraysEqual'): self.assertLen(updated_labels, len(updated_features)) - with self.subTest(name='LabelWeightArraysEqual'): + with self.subTest(name='LabelWeightArraysEqualLen'): self.assertLen(updated_labels, len(weights)) + with self.subTest(name='PseudolabelWeightArraysEqualLen'): + self.assertLen(pseudolabel_flags, len(weights)) if __name__ == '__main__': diff --git a/spade_anomaly_detection/runner.py b/spade_anomaly_detection/runner.py index a40951f..ff774da 100644 --- a/spade_anomaly_detection/runner.py +++ b/spade_anomaly_detection/runner.py @@ -772,14 +772,16 @@ def run(self) -> None: ) logging.info('Labeling started.') - updated_features, updated_labels, weights = ensemble_object.pseudo_label( - features=train_x, - labels=train_y, - positive_data_value=self.runner_parameters.positive_data_value, - negative_data_value=self.runner_parameters.negative_data_value, - unlabeled_data_value=self.runner_parameters.unlabeled_data_value, - alpha=self.runner_parameters.alpha, - verbose=self.runner_parameters.verbose, + updated_features, updated_labels, weights, pseudolabel_flags = ( + ensemble_object.pseudo_label( + features=train_x, + labels=train_y, + positive_data_value=self.runner_parameters.positive_data_value, + negative_data_value=self.runner_parameters.negative_data_value, + unlabeled_data_value=self.runner_parameters.unlabeled_data_value, + alpha=self.runner_parameters.alpha, + verbose=self.runner_parameters.verbose, + ) ) logging.info('Labeling completed.') @@ -792,7 +794,10 @@ def run(self) -> None: data_loader.DataLoader, self.input_data_loader ) self.input_data_loader.upload_dataframe_as_bigquery_table( - features=updated_features, labels=updated_labels + features=updated_features, + labels=updated_labels, + weights=weights, + pseudolabel_flags=pseudolabel_flags, ) elif ( self.runner_parameters.data_output_gcs_uri @@ -805,6 +810,8 @@ def run(self) -> None: batch=batch_number, features=updated_features, labels=updated_labels, + weights=weights, + pseudolabel_flags=pseudolabel_flags, ) else: logging.info('No output path specified, skipping upload.') diff --git a/spade_anomaly_detection/runner_test.py b/spade_anomaly_detection/runner_test.py index 85d664b..cb9a506 100644 --- a/spade_anomaly_detection/runner_test.py +++ b/spade_anomaly_detection/runner_test.py @@ -437,6 +437,7 @@ def test_batch_sizing_no_error(self, mock_split, mock_pseudo_label): np.empty((1, 1)), np.empty((1, 1)), np.empty((1, 1)), + np.empty((1, 1)), ) runner_object = runner.Runner(self.runner_parameters) @@ -480,6 +481,7 @@ def test_preprocessing_pu_no_error(self, mock_pseudo_label): np.empty((1, 1)), np.empty((1, 1)), np.empty((1, 1)), + np.empty((1, 1)), ) self.runner_parameters.verbose = True self.runner_parameters.test_dataset_holdout_fraction = 0.2 @@ -517,6 +519,7 @@ def test_preprocessing_pnu_no_error(self, mock_pseudo_label): np.empty((1, 1)), np.empty((1, 1)), np.empty((1, 1)), + np.empty((1, 1)), ) self.runner_parameters.verbose = True self.runner_parameters.test_dataset_holdout_fraction = 0.2 @@ -565,6 +568,7 @@ def test_preprocessing_array_sizes_no_error( np.empty((1, 1)), np.empty((1, 1)), np.empty((1, 1)), + np.empty((1, 1)), ) self.runner_parameters.verbose = True self.runner_parameters.train_setting = train_setting
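
Note for reviewers: the snippet below is an illustrative sketch, not part of the patch, showing how the new weight and flag outputs line up in the uploaded table. The `alpha` and `is_pseudolabel` column names come from the `WEIGHT_COLUMN_NAME` and `PSEUDOLABEL_FLAG_COLUMN_NAME` constants added in `data_loader.py`; the toy feature values and the `x1`/`x2`/`y` column names merely mirror the test fixtures.

```python
# Illustrative only: mimics the row/column layout produced by
# OccEnsemble.pseudo_label() followed by upload_dataframe_as_bigquery_table().
import numpy as np
import pandas as pd

alpha = 0.1  # weight applied to pseudo-labeled samples

# pseudo_label() concatenates newly pseudo-labeled samples first, then the
# ground-truth samples. The feature values here are invented.
features = np.array([[0.5, 1.2], [0.7, 0.9], [1.1, 0.3], [0.2, 0.8]])
labels = np.array([1, 0, 1, 0])  # 1 = positive, 0 = negative

# Pseudo-labeled rows carry weight=alpha and flag=1; ground-truth rows carry
# weight=1.0 and flag=0, matching the new `pseudolabel_flags` concatenation.
weights = np.array([alpha, alpha, 1.0, 1.0])
flags = np.array([1, 1, 0, 0])

# The uploaders append the weight and flag columns after the features and
# keep the label column last.
df = pd.DataFrame(features, columns=["x1", "x2"])
df["alpha"] = weights.astype(np.float64)
df["is_pseudolabel"] = flags.astype(np.int64)
df["y"] = labels.astype(bool)
print(df)  # columns: x1, x2, alpha, is_pseudolabel, y
```

Pseudo-labeled rows keep the configured `alpha` as their weight together with flag 1, while ground-truth rows keep weight 1.0 and flag 0, which is exactly what the new subtests in `occ_ensemble_test.py` and `data_loader_test.py` assert.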