
Commit 66587d0

Merge pull request #24 from google-research/test_651030738
Add the ability to use CSV files on GCS as data input/output/test sources.
2 parents: 26a1249 + 01c5247

18 files changed: 1,894 additions and 194 deletions

CHANGELOG.md

Lines changed: 9 additions & 2 deletions
@@ -18,11 +18,17 @@ To release a new version (e.g. from `1.0.0` -> `2.0.0`):
 * Define the new link url:
   `[2.0.0]: https://github.com/google-research/spade_anomaly_detection/compare/v1.0.0...v2.0.0`
 * Update the `[Unreleased]` url: `v1.0.0...HEAD` -> `v2.0.0...HEAD`
-
+* If updating the PyPi version, also update the `__version__` variable in the
+  `__init__.py` file at the root of the module.
 -->

 ## [Unreleased]

+## [0.3.0] - 2024-07-10
+
+* Add the ability to use CSV files on GCS as data input/output/test sources.
+* Miscellaneous bugfixes.
+
 ## [0.2.2] - 2024-05-19

 * Update the OCC training to use negative and unlabeled samples for training.
@@ -39,7 +45,8 @@ To release a new version (e.g. from `1.0.0` -> `2.0.0`):

 * Initial release

-[Unreleased]: https://github.com/google-research/spade_anomaly_detection/compare/v0.2.2...HEAD
+[Unreleased]: https://github.com/google-research/spade_anomaly_detection/compare/v0.3.0...HEAD
+[0.3.0]: https://github.com/google-research/spade_anomaly_detection/compare/v0.2.2...v0.3.0
 [0.2.2]: https://github.com/google-research/spade_anomaly_detection/compare/v0.2.1...v0.2.2
 [0.2.1]: https://github.com/google-research/spade_anomaly_detection/compare/v0.2.0...v0.2.1
 [0.2.0]: https://github.com/google-research/spade_anomaly_detection/compare/v0.1.0...v0.2.0

README.md

Lines changed: 60 additions & 36 deletions
@@ -64,9 +64,11 @@ The metric reported by the pipeline is model [AUC](https://developers.google.com

 <span style="color:yellow;background-color:lightgrey">**train_setting (string)**</span>: The 'PNU' (Positive, Negative, and Unlabeled) setting will train the supervised model using ground truth negative and positive data, as well as pseudo-labeled samples from the pseudo-labeler. The 'PU' (Positive and Unlabeled) setting will only use negative data from the pseudo-labeler along with the rest of the positive data (ground truth plus pseudo-labeled) to train the supervised model. For model evaluation, ground truth negative data is still required in the BigQuery dataset; it just won't be used during training. Default is PNU.

-<span style="color:red;background-color:lightgrey">input_bigquery_table_path (string)</span>: A BigQuery table path in the format 'project.dataset.table'. If this is the only BigQuery path provided, this will be used in conjunction with the test_dataset_holdout_fraction parameter to create a train/test split.
+<span style="color:red;background-color:lightgrey">input_bigquery_table_path (string)</span>: A BigQuery table path in the format 'project.dataset.table'. If this is the only BigQuery path provided, it will be used in conjunction with the test_dataset_holdout_fraction parameter to create a train/test split. Use exactly one of BigQuery locations or GCS locations.

-<span style="color:red;background-color:lightgrey">output_gcs_uri (string)</span>: Cloud Storage location to store the supervised model assets. The location should be in the form gs://bucketname/foldername. A timestamp will be added to the end of the folder so that multiple runs of this won't overwrite previous runs.
+<span style="color:red;background-color:lightgrey">data_input_gcs_uri (string)</span>: Cloud Storage location of the input CSV data. If this is the only Cloud Storage location provided, it will be used in conjunction with the test_dataset_holdout_fraction parameter to create a train/test split. Use exactly one of BigQuery locations or GCS locations.
+
+<span style="color:red;background-color:lightgrey">output_gcs_uri (string)</span>: Cloud Storage location to store the supervised model assets. The location should be in the form gs://bucketname/foldername. A timestamp will be added to the end of the folder so that multiple runs of this won't overwrite previous runs. Supervised model assets are always stored to GCS.

 <span style="color:red;background-color:lightgrey">label_col_name (string)</span>: The name of the label column in the input BigQuery table.

@@ -89,6 +91,14 @@ one class classifier ensemble to label a point as negative. The higher this valu

 <span style="color:yellow;background-color:lightgrey">test_dataset_holdout_fraction</span>: This setting is used if <span style="color:yellow;background-color:lightgrey">test_bigquery_table_path</span> is not provided. Float between 0 and 1 representing the fraction of samples to hold out as a test set. Default is 0.2, meaning 20% of the data is held out for testing. In the PU setting, this means that 20% of the positive labels and 100% of the ground truth negative data (since no ground truth negative data is used for supervised model training) will be used for creating the test set. In the PNU setting, it is just 20% of positive and negative samples, sampled uniformly at random; all other data is used for training.

+<span style="color:yellow;background-color:lightgrey">data_test_gcs_uri</span>: Cloud Storage location of the CSV data used to evaluate the supervised model. Note that the positive and negative label values in this test set must be the same as in the input data; alternatively, use 1 for positive and 0 for negative. Use exactly one of BigQuery locations or GCS locations.
+
+<span style="color:yellow;background-color:lightgrey">upload_only</span>: Use this setting in conjunction with `output_bigquery_table_path` or `data_output_gcs_uri`. When `True`, the algorithm will only upload the pseudo-labeled data to the specified location and will skip training a supervised model. When `False`, the algorithm will also train a supervised model and upload it to a GCS location. Default is `False`.
+
+<span style="color:yellow;background-color:lightgrey">output_bigquery_table_path</span>: A complete BigQuery path in the form 'project.dataset.table' to be used for uploading the pseudo-labeled data. This includes features and new labels. By default, the column names from the input_bigquery_table_path BigQuery table are used. Use exactly one of BigQuery locations or GCS locations.
+
+<span style="color:yellow;background-color:lightgrey">data_output_gcs_uri</span>: Cloud Storage location for uploading the pseudo-labeled data as CSV. This includes features and new labels. By default, the column names from the data_input_gcs_uri data are used. Use exactly one of BigQuery locations or GCS locations.
+
 <span style="color:yellow;background-color:lightgrey">alpha (float)</span>: Sample weights for weighting the loss function, applied only to pseudo-labeled data from the OCC ensemble. Original labeled data has a weight of 1.0. By default, alpha = 1.0.

 <span style="color:yellow;background-color:lightgrey">ensemble_count</span>: Integer representing the number of one class classifiers in the ensemble used for pseudo-labeling unlabeled data points. The more models in the ensemble, the less likely it is for all the models to reach consensus, which reduces the number of pseudo-labeled data points. By default, 5 one class classifiers are used.
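As a concrete illustration of the "use exactly one of BigQuery locations or GCS locations" rule documented above, here is a minimal sketch of a GCS-only configuration for the runner script in the next hunk; the bucket and folder names are hypothetical placeholders:

~~~
# Minimal GCS-only sketch (hypothetical bucket/folder names).
# The BigQuery paths are left empty: exactly one family of locations is set.
INPUT_BIGQUERY_TABLE_PATH=""
OUTPUT_BIGQUERY_TABLE_PATH=""
TEST_BIGQUERY_TABLE_PATH=""
DATA_INPUT_GCS_URI="gs://my-bucket/spade/input"            # input CSV data
DATA_OUTPUT_GCS_URI="gs://my-bucket/spade/pseudo_labeled"  # pseudo-labeled CSV output
DATA_TEST_GCS_URI="gs://my-bucket/spade/test"              # evaluation CSV data
OUTPUT_GCS_URI="gs://my-bucket/spade/models"               # model assets always go to GCS
~~~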
@@ -154,28 +164,37 @@ echo "Built and pushed ${IMAGE_URI_ML}"
 set -x

 PROJECT_ID=${1:-"project_id"}
-#Args, maintain same order as runner.run and task.py.
-TRAIN_SETTING=${15:-"PNU"}
-
-TRIAL_NAME="prod_spade_credit_${TRAIN_SETTING}_${USER}"
+DATETIME=$(date '+%Y%m%d_%H%M%S')

-INPUT_BIGQUERY_TABLE_PATH=${2:-"bq_table_path"}
-OUTPUT_GCS_URI=${14:-"gs://[your-gcs-bucket]/models/model_experiment_$(date '+%Y%m%d_%H%M%S')"}
-LABEL_COL_NAME=${3:-"y"}
+# Give a unique name to your training job.
+TRIAL_NAME="spade_${USER}_${DATETIME}"
+
+# Args
+TRAIN_SETTING=${2:-"PNU"}
+
+# Use either BigQuery or GCS for input/output/test data.
+INPUT_BIGQUERY_TABLE_PATH=${3:-"${PROJECT_ID}.[bq-dataset].[bq-input-table]"}
+DATA_INPUT_GCS_URI=${4:-""}
+OUTPUT_BIGQUERY_TABLE_PATH=${5:-"${PROJECT_ID}.[bq-dataset].[bq-output-table]"}
+DATA_OUTPUT_GCS_URI=${6:-""}
+OUTPUT_GCS_URI=${7:-"gs://[gcs-bucket]/[model-folder]"}
+LABEL_COL_NAME=${8:-"y"}
 # The label column is of type float, these must match in order for array
 # filtering to work correctly.
-POSITIVE_DATA_VALUE=${4:-"1"}
-NEGATIVE_DATA_VALUE=${5:-"0"}
-UNLABELED_DATA_VALUE=${6:-"-1"}
-POSITIVE_THRESHOLD=${7:-".1"}
-NEGATIVE_THRESHOLD=${8:-"95"}
-TEST_BIGQUERY_TABLE_PATH=${16:-"table_path"}
-TEST_LABEL_COL_NAME=${17:-"y"}
-ALPHA=${10:-"1.0"}
-ENSEMBLE_COUNT=${12:-"5"}
-VERBOSE=${13:-"True"}
-
-PROD_IMAGE_URI="us-docker.pkg.dev/[project_id]/spade-anomaly-detection/spade:latest"
+POSITIVE_DATA_VALUE=${9:-"1"}
+NEGATIVE_DATA_VALUE=${10:-"0"}
+UNLABELED_DATA_VALUE=${11:-"-1"}
+POSITIVE_THRESHOLD=${12:-".1"}
+NEGATIVE_THRESHOLD=${13:-"95"}
+TEST_BIGQUERY_TABLE_PATH=${14:-"${PROJECT_ID}.[bq-dataset].[bq-test-table]"}
+DATA_TEST_GCS_URI=${15:-""}
+TEST_LABEL_COL_NAME=${16:-"y"}
+ALPHA=${17:-"1.0"}
+ENSEMBLE_COUNT=${19:-"5"}
+VERBOSE=${22:-"True"}
+UPLOAD_ONLY=${23:-"False"}
+
+IMAGE_URI="us-docker.pkg.dev/[project_id]/spade-anomaly-detection/spade:latest"

 REGION="us-central1"

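The new `UPLOAD_ONLY` flag pairs with the data-output locations above. A minimal sketch of an upload-only run, reusing the script's variable names (the table name is a hypothetical placeholder):

~~~
# Upload-only sketch: write pseudo-labeled data to BigQuery and skip
# training/uploading a supervised model.
UPLOAD_ONLY="True"
OUTPUT_BIGQUERY_TABLE_PATH="${PROJECT_ID}.my_dataset.pseudo_labeled"  # hypothetical
DATA_OUTPUT_GCS_URI=""  # empty: BigQuery, not GCS, receives the data output
~~~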
@@ -187,21 +206,26 @@ gcloud ai custom-jobs create \
   --region="${REGION}" \
   --project="${PROJECT_ID}" \
   --display-name="${TRIAL_NAME}" \
-  --worker-pool-spec="${WORKER_MACHINE}",replica-count=1,container-image-uri="${PROD_IMAGE_URI}" \
-  --args=--train_setting="${TRAIN_SETTING}" \
-  --args=--input_bigquery_table_path="${INPUT_BIGQUERY_TABLE_PATH}" \
-  --args=--output_gcs_uri="${OUTPUT_GCS_URI}" \
-  --args=--label_col_name="${LABEL_COL_NAME}" \
-  --args=--positive_data_value="${POSITIVE_DATA_VALUE}" \
-  --args=--negative_data_value="${NEGATIVE_DATA_VALUE}" \
-  --args=--unlabeled_data_value="${UNLABELED_DATA_VALUE}" \
-  --args=--positive_threshold="${POSITIVE_THRESHOLD}" \
-  --args=--negative_threshold="${NEGATIVE_THRESHOLD}" \
-  --args=--test_bigquery_table_path="${TEST_BIGQUERY_TABLE_PATH}" \
-  --args=--test_label_col_name="${TEST_LABEL_COL_NAME}" \
-  --args=--alpha="${ALPHA}" \
-  --args=--ensemble_count="${ENSEMBLE_COUNT}" \
-  --args=--verbose="${VERBOSE}"
+  --worker-pool-spec="${WORKER_MACHINE}",replica-count=1,container-image-uri="${IMAGE_URI}" \
+  --args=--train_setting="${TRAIN_SETTING}" \
+  --args=--input_bigquery_table_path="${INPUT_BIGQUERY_TABLE_PATH}" \
+  --args=--data_input_gcs_uri="${DATA_INPUT_GCS_URI}" \
+  --args=--output_bigquery_table_path="${OUTPUT_BIGQUERY_TABLE_PATH}" \
+  --args=--data_output_gcs_uri="${DATA_OUTPUT_GCS_URI}" \
+  --args=--output_gcs_uri="${OUTPUT_GCS_URI}" \
+  --args=--label_col_name="${LABEL_COL_NAME}" \
+  --args=--positive_data_value="${POSITIVE_DATA_VALUE}" \
+  --args=--negative_data_value="${NEGATIVE_DATA_VALUE}" \
+  --args=--unlabeled_data_value="${UNLABELED_DATA_VALUE}" \
+  --args=--positive_threshold="${POSITIVE_THRESHOLD}" \
+  --args=--negative_threshold="${NEGATIVE_THRESHOLD}" \
+  --args=--test_bigquery_table_path="${TEST_BIGQUERY_TABLE_PATH}" \
+  --args=--data_test_gcs_uri="${DATA_TEST_GCS_URI}" \
+  --args=--test_label_col_name="${TEST_LABEL_COL_NAME}" \
+  --args=--alpha="${ALPHA}" \
+  --args=--ensemble_count="${ENSEMBLE_COUNT}" \
+  --args=--upload_only="${UPLOAD_ONLY}" \
+  --args=--verbose="${VERBOSE}"
 ~~~

 ## Example Datasets and their Licenses

pyproject.toml

Lines changed: 2 additions & 0 deletions
@@ -23,6 +23,8 @@ dependencies = [
     "pyarrow==14.0.1",
     "retry==0.9.2",
     "scikit-learn==1.4.2",
+    "tensorflow",
+    "tensorflow-datasets==4.9.6",
     "parameterized==0.8.1",
     "pytest==7.1.2",
     "fastavro[codecs]==1.4.12",

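With `tensorflow` and `tensorflow-datasets` now declared as dependencies, a fresh install resolves them automatically. A minimal sketch, assuming a local checkout of the repository root:

~~~
# Install the package plus the newly declared TensorFlow dependencies.
pip install -e .
# Sanity-check that the pinned tensorflow-datasets version resolved.
python -c "import tensorflow_datasets as tfds; print(tfds.__version__)"
~~~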
spade_anomaly_detection/__init__.py

Lines changed: 1 addition & 1 deletion
@@ -31,4 +31,4 @@

 # A new PyPI release will be pushed every time `__version__` is increased.
 # When changing this, also update the CHANGELOG.md.
-__version__ = '0.2.2'
+__version__ = '0.3.0'
