@@ -104,6 +104,31 @@ def __init__(self, runner_parameters: parameters.RunnerParameters):
104104 else self .runner_parameters .negative_threshold
105105 )
106106
107+ def _get_record_count_based_on_labels (self , label_value : int ) -> int :
108+ """Gets the number of records in the table.
109+
110+ Args:
111+ label_value: The value of the label to use as the filter for records.
112+
113+ Returns:
114+ The count of records.
115+ """
116+ label_record_count_filter = (
117+ f'{ self .runner_parameters .label_col_name } = { label_value } '
118+ )
119+ if self .runner_parameters .where_statements :
120+ label_record_count_where_statements = [
121+ self .runner_parameters .where_statements
122+ ] + [label_record_count_filter ]
123+ else :
124+ label_record_count_where_statements = [label_record_count_filter ]
125+
126+ label_record_count = self .data_loader .get_query_record_result_length (
127+ input_path = self .runner_parameters .input_bigquery_table_path ,
128+ where_statements = label_record_count_where_statements ,
129+ )
130+ return label_record_count
131+
107132 def check_data_tables (
108133 self ,
109134 total_record_count : int ,
@@ -166,12 +191,13 @@ def check_data_tables(
166191 )
167192
168193 def instantiate_and_fit_ensemble (
169- self , unlabeled_record_count : int
194+ self , unlabeled_record_count : int , negative_record_count : int
170195 ) -> occ_ensemble .GmmEnsemble :
171196 """Creates and fits an OCC ensemble on the specified input data.
172197
173198 Args:
174199 unlabeled_record_count: Number of unlabeled records in the table.
200+ negative_record_count: Number of negative records in the table.
175201
176202 Returns:
177203 A trained one class classifier ensemble.
@@ -183,7 +209,8 @@ def instantiate_and_fit_ensemble(
183209 negative_threshold = self .runner_parameters .negative_threshold ,
184210 )
185211
186- records_per_occ = unlabeled_record_count // ensemble_object .ensemble_count
212+ training_record_count = unlabeled_record_count + negative_record_count
213+ records_per_occ = training_record_count // ensemble_object .ensemble_count
187214 batch_size = records_per_occ // self .runner_parameters .batches_per_model
188215 batch_size = np .min ([batch_size , self .runner_parameters .max_occ_batch_size ])
189216
@@ -195,7 +222,11 @@ def instantiate_and_fit_ensemble(
195222 where_statements = self .runner_parameters .where_statements ,
196223 ignore_columns = self .runner_parameters .ignore_columns ,
197224 batch_size = batch_size ,
198- label_column_filter_value = self .runner_parameters .unlabeled_data_value ,
225+ # Train using negative labeled data and unlabeled data.
226+ label_column_filter_value = [
227+ self .runner_parameters .unlabeled_data_value ,
228+ self .runner_parameters .negative_data_value ,
229+ ],
199230 )
200231
201232 logging .info ('Fitting ensemble.' )
@@ -527,20 +558,11 @@ def run(self) -> None:
527558 where_statements = self .runner_parameters .where_statements ,
528559 )
529560
530- unlabeled_record_count_filter = (
531- f'{ self .runner_parameters .label_col_name } = '
532- f'{ self .runner_parameters .unlabeled_data_value } '
561+ unlabeled_record_count = self ._get_record_count_based_on_labels (
562+ self .runner_parameters .unlabeled_data_value
533563 )
534- if self .runner_parameters .where_statements :
535- unlabeled_record_count_where_statements = [
536- self .runner_parameters .where_statements
537- ] + [unlabeled_record_count_filter ]
538- else :
539- unlabeled_record_count_where_statements = [unlabeled_record_count_filter ]
540-
541- unlabeled_record_count = self .data_loader .get_query_record_result_length (
542- input_path = self .runner_parameters .input_bigquery_table_path ,
543- where_statements = unlabeled_record_count_where_statements ,
564+ negative_record_count = self ._get_record_count_based_on_labels (
565+ self .runner_parameters .negative_data_value
544566 )
545567
546568 self .check_data_tables (
@@ -549,7 +571,8 @@ def run(self) -> None:
549571 )
550572
551573 ensemble_object = self .instantiate_and_fit_ensemble (
552- unlabeled_record_count = unlabeled_record_count
574+ unlabeled_record_count = unlabeled_record_count ,
575+ negative_record_count = negative_record_count ,
553576 )
554577
555578 batch_size = (
@@ -615,10 +638,11 @@ def run(self) -> None:
615638 )
616639
617640 if not self .runner_parameters .upload_only :
641+ if self .supervised_model_object is None :
642+ raise ValueError ('Supervised model was not created and trained.' )
618643 self .evaluate_model ()
619- if self .supervised_model_object is not None :
620- self .supervised_model_object .save (
621- save_location = self .runner_parameters .output_gcs_uri
622- )
644+ self .supervised_model_object .save (
645+ save_location = self .runner_parameters .output_gcs_uri
646+ )
623647
624648 logging .info ('SPADE training completed.' )
0 commit comments