|
36 | 36 | """
|
37 | 37 |
|
38 | 38 | import enum
|
39 |
| -# TODO(b/247116870): Change to collections when Vertex supports python 3.9 |
40 | 39 | from typing import Mapping, Optional, Tuple, cast
|
41 | 40 |
|
42 | 41 | from absl import logging
|
|
49 | 48 | from spade_anomaly_detection import supervised_model
|
50 | 49 | import tensorflow as tf
|
51 | 50 |
|
| 51 | +# TODO(b/247116870): Change to collections when Vertex supports python 3.9 |
| 52 | + |
52 | 53 |
|
53 | 54 | @enum.unique
|
54 | 55 | class DataFormat(enum.Enum):
|
@@ -135,6 +136,7 @@ def __init__(self, runner_parameters: parameters.RunnerParameters):
|
135 | 136 | else:
|
136 | 137 | self.supervised_model_object = None
|
137 | 138 |
|
| 139 | + # If the thresholds are not set, use the thresholds from the input table. |
138 | 140 | if (
|
139 | 141 | self.runner_parameters.positive_threshold is None
|
140 | 142 | or self.runner_parameters.negative_threshold is None
|
@@ -760,7 +762,7 @@ def run(self) -> None:
|
760 | 762 | batch_size=1,
|
761 | 763 | )
|
762 | 764 | train_label_counts = self.input_data_loader.label_counts
|
763 |
| - # TODO(sinharaj): This is not ideal, we should not need to read the files |
| 765 | + # This is not ideal, we should not need to read the files |
764 | 766 | # again. Find a way to get the label counts without reading the files.
|
765 | 767 | # Assumes that data loader has already been used to read the input table.
|
766 | 768 | total_record_count = sum(train_label_counts.values())
|
@@ -885,6 +887,7 @@ def run(self) -> None:
|
885 | 887 | labels=updated_labels,
|
886 | 888 | weights=weights,
|
887 | 889 | )
|
| 890 | + # End of pseudolabeling and supervised model training loop. |
888 | 891 |
|
889 | 892 | if not self.runner_parameters.upload_only:
|
890 | 893 | self.evaluate_model()
|
|
0 commit comments