@@ -104,6 +104,31 @@ def __init__(self, runner_parameters: parameters.RunnerParameters):
104
104
else self .runner_parameters .negative_threshold
105
105
)
106
106
107
+ def _get_record_count_based_on_labels (self , label_value : int ) -> int :
108
+ """Gets the number of records in the table.
109
+
110
+ Args:
111
+ label_value: The value of the label to use as the filter for records.
112
+
113
+ Returns:
114
+ The count of records.
115
+ """
116
+ label_record_count_filter = (
117
+ f'{ self .runner_parameters .label_col_name } = { label_value } '
118
+ )
119
+ if self .runner_parameters .where_statements :
120
+ label_record_count_where_statements = [
121
+ self .runner_parameters .where_statements
122
+ ] + [label_record_count_filter ]
123
+ else :
124
+ label_record_count_where_statements = [label_record_count_filter ]
125
+
126
+ label_record_count = self .data_loader .get_query_record_result_length (
127
+ input_path = self .runner_parameters .input_bigquery_table_path ,
128
+ where_statements = label_record_count_where_statements ,
129
+ )
130
+ return label_record_count
131
+
107
132
def check_data_tables (
108
133
self ,
109
134
total_record_count : int ,
@@ -166,12 +191,13 @@ def check_data_tables(
166
191
)
167
192
168
193
def instantiate_and_fit_ensemble (
169
- self , unlabeled_record_count : int
194
+ self , unlabeled_record_count : int , negative_record_count : int
170
195
) -> occ_ensemble .GmmEnsemble :
171
196
"""Creates and fits an OCC ensemble on the specified input data.
172
197
173
198
Args:
174
199
unlabeled_record_count: Number of unlabeled records in the table.
200
+ negative_record_count: Number of negative records in the table.
175
201
176
202
Returns:
177
203
A trained one class classifier ensemble.
@@ -183,7 +209,8 @@ def instantiate_and_fit_ensemble(
183
209
negative_threshold = self .runner_parameters .negative_threshold ,
184
210
)
185
211
186
- records_per_occ = unlabeled_record_count // ensemble_object .ensemble_count
212
+ training_record_count = unlabeled_record_count + negative_record_count
213
+ records_per_occ = training_record_count // ensemble_object .ensemble_count
187
214
batch_size = records_per_occ // self .runner_parameters .batches_per_model
188
215
batch_size = np .min ([batch_size , self .runner_parameters .max_occ_batch_size ])
189
216
@@ -195,7 +222,11 @@ def instantiate_and_fit_ensemble(
195
222
where_statements = self .runner_parameters .where_statements ,
196
223
ignore_columns = self .runner_parameters .ignore_columns ,
197
224
batch_size = batch_size ,
198
- label_column_filter_value = self .runner_parameters .unlabeled_data_value ,
225
+ # Train using negative labeled data and unlabeled data.
226
+ label_column_filter_value = [
227
+ self .runner_parameters .unlabeled_data_value ,
228
+ self .runner_parameters .negative_data_value ,
229
+ ],
199
230
)
200
231
201
232
logging .info ('Fitting ensemble.' )
@@ -527,20 +558,11 @@ def run(self) -> None:
527
558
where_statements = self .runner_parameters .where_statements ,
528
559
)
529
560
530
- unlabeled_record_count_filter = (
531
- f'{ self .runner_parameters .label_col_name } = '
532
- f'{ self .runner_parameters .unlabeled_data_value } '
561
+ unlabeled_record_count = self ._get_record_count_based_on_labels (
562
+ self .runner_parameters .unlabeled_data_value
533
563
)
534
- if self .runner_parameters .where_statements :
535
- unlabeled_record_count_where_statements = [
536
- self .runner_parameters .where_statements
537
- ] + [unlabeled_record_count_filter ]
538
- else :
539
- unlabeled_record_count_where_statements = [unlabeled_record_count_filter ]
540
-
541
- unlabeled_record_count = self .data_loader .get_query_record_result_length (
542
- input_path = self .runner_parameters .input_bigquery_table_path ,
543
- where_statements = unlabeled_record_count_where_statements ,
564
+ negative_record_count = self ._get_record_count_based_on_labels (
565
+ self .runner_parameters .negative_data_value
544
566
)
545
567
546
568
self .check_data_tables (
@@ -549,7 +571,8 @@ def run(self) -> None:
549
571
)
550
572
551
573
ensemble_object = self .instantiate_and_fit_ensemble (
552
- unlabeled_record_count = unlabeled_record_count
574
+ unlabeled_record_count = unlabeled_record_count ,
575
+ negative_record_count = negative_record_count ,
553
576
)
554
577
555
578
batch_size = (
@@ -615,10 +638,11 @@ def run(self) -> None:
615
638
)
616
639
617
640
if not self .runner_parameters .upload_only :
641
+ if self .supervised_model_object is None :
642
+ raise ValueError ('Supervised model was not created and trained.' )
618
643
self .evaluate_model ()
619
- if self .supervised_model_object is not None :
620
- self .supervised_model_object .save (
621
- save_location = self .runner_parameters .output_gcs_uri
622
- )
644
+ self .supervised_model_object .save (
645
+ save_location = self .runner_parameters .output_gcs_uri
646
+ )
623
647
624
648
logging .info ('SPADE training completed.' )
0 commit comments