@@ -82,6 +82,8 @@ class Runner:
82
82
int_positive_data_value: The integer value of the positive data value.
83
83
int_negative_data_value: The integer value of the negative data value.
84
84
int_unlabeled_data_value: The integer value of the unlabeled data value.
85
+ train_label_counts: Dictionary of counts of labels in training data.
86
+ total_record_count: Total number of records in the training data.
85
87
"""
86
88
87
89
def __init__ (self , runner_parameters : parameters .RunnerParameters ):
@@ -95,13 +97,21 @@ def __init__(self, runner_parameters: parameters.RunnerParameters):
95
97
else :
96
98
self .data_format = DataFormat .CSV
97
99
100
+ self .train_label_counts : dict [int | str , int ] | None = None
101
+ self .total_record_count : int | None = None
98
102
if self .data_format == DataFormat .BIGQUERY :
99
103
# BigQuery data loaders are the same for input, output and test data.
100
104
self .input_data_loader = data_loader .DataLoader (self .runner_parameters )
101
105
# Type hint to prevent linter errors.
102
106
self .input_data_loader = cast (
103
107
data_loader .DataLoader , self .input_data_loader
104
108
)
109
+ self .total_record_count = (
110
+ self .input_data_loader .get_query_record_result_length (
111
+ input_path = self .runner_parameters .input_bigquery_table_path ,
112
+ where_statements = self .runner_parameters .where_statements ,
113
+ )
114
+ )
105
115
if not self .runner_parameters .upload_only :
106
116
self .test_data_loader = self .input_data_loader
107
117
else :
@@ -112,6 +122,29 @@ def __init__(self, runner_parameters: parameters.RunnerParameters):
112
122
self .input_data_loader = cast (
113
123
csv_data_loader .CsvDataLoader , self .input_data_loader
114
124
)
125
+ _ = self .input_data_loader .load_tf_dataset_from_csv (
126
+ input_path = self .runner_parameters .data_input_gcs_uri ,
127
+ label_col_name = self .runner_parameters .label_col_name ,
128
+ batch_size = 1 ,
129
+ )
130
+ self .train_label_counts = self .input_data_loader .label_counts
131
+ self .total_record_count = sum (self .train_label_counts .values ())
132
+ if (
133
+ self .runner_parameters .labeling_and_model_training_batch_size
134
+ and self .runner_parameters .labeling_and_model_training_batch_size
135
+ > self .total_record_count
136
+ ):
137
+ self .runner_parameters .labeling_and_model_training_batch_size = (
138
+ self .total_record_count
139
+ )
140
+ logging .info (
141
+ 'Labeling and model training batch size is reduced to %s' ,
142
+ self .runner_parameters .labeling_and_model_training_batch_size ,
143
+ )
144
+ logging .info (
145
+ 'Initial label counts (before supervised training): %s' ,
146
+ self .train_label_counts ,
147
+ )
115
148
if not self .runner_parameters .upload_only :
116
149
self .test_data_loader = csv_data_loader .CsvDataLoader (
117
150
self .runner_parameters
@@ -182,15 +215,7 @@ def _get_table_statistics(self) -> Mapping[str, float]:
182
215
self .runner_parameters .input_bigquery_table_path
183
216
)
184
217
else :
185
- stats_data_loader = csv_data_loader .CsvDataLoader (self .runner_parameters )
186
- # Type hint to prevent linter errors.
187
- stats_data_loader = cast (csv_data_loader .CsvDataLoader , stats_data_loader )
188
- _ = stats_data_loader .load_tf_dataset_from_csv (
189
- input_path = self .runner_parameters .data_input_gcs_uri ,
190
- label_col_name = self .runner_parameters .label_col_name ,
191
- batch_size = 1 ,
192
- )
193
- input_table_statistics = stats_data_loader .get_label_thresholds ()
218
+ input_table_statistics = self .input_data_loader .get_label_thresholds ()
194
219
logging .info ('Input table statistics: %s' , input_table_statistics )
195
220
return input_table_statistics
196
221
@@ -751,52 +776,7 @@ def run(self) -> None:
751
776
logging .info ('SPADE training started.' )
752
777
753
778
self ._check_runner_parameters ()
754
-
755
- if self .data_format == DataFormat .BIGQUERY :
756
- # Type hint to prevent linter errors.
757
- self .input_data_loader = cast (
758
- data_loader .DataLoader , self .input_data_loader
759
- )
760
- total_record_count = (
761
- self .input_data_loader .get_query_record_result_length (
762
- input_path = self .runner_parameters .input_bigquery_table_path ,
763
- where_statements = self .runner_parameters .where_statements ,
764
- )
765
- )
766
- else :
767
- # Type hint to prevent linter errors.
768
- self .input_data_loader = cast (
769
- csv_data_loader .CsvDataLoader , self .input_data_loader
770
- )
771
- # Call the data loader to read all the files. This is needed to get the
772
- # label counts.
773
- _ = self .input_data_loader .load_tf_dataset_from_csv (
774
- input_path = self .runner_parameters .data_input_gcs_uri ,
775
- label_col_name = self .runner_parameters .label_col_name ,
776
- batch_size = 1 ,
777
- )
778
- train_label_counts = self .input_data_loader .label_counts
779
- # This is not ideal, we should not need to read the files
780
- # again. Find a way to get the label counts without reading the files.
781
- # Assumes that data loader has already been used to read the input table.
782
- total_record_count = sum (train_label_counts .values ())
783
- if (
784
- self .runner_parameters .labeling_and_model_training_batch_size
785
- and self .runner_parameters .labeling_and_model_training_batch_size
786
- > total_record_count
787
- ):
788
- self .runner_parameters .labeling_and_model_training_batch_size = (
789
- total_record_count
790
- )
791
- logging .info (
792
- 'Labeling and model training batch size is reduced to %s' ,
793
- self .runner_parameters .labeling_and_model_training_batch_size ,
794
- )
795
- logging .info (
796
- 'Label counts before supervised training: %s' , train_label_counts
797
- )
798
-
799
- logging .info ('Total record count: %s' , total_record_count )
779
+ logging .info ('Total record count: %s' , self .total_record_count )
800
780
unlabeled_record_count = self ._get_record_count_based_on_labels (
801
781
self .int_unlabeled_data_value
802
782
)
@@ -805,7 +785,7 @@ def run(self) -> None:
805
785
)
806
786
807
787
self .check_data_tables (
808
- total_record_count = total_record_count ,
788
+ total_record_count = self . total_record_count ,
809
789
unlabeled_record_count = unlabeled_record_count ,
810
790
)
811
791
@@ -816,7 +796,7 @@ def run(self) -> None:
816
796
817
797
batch_size = (
818
798
self .runner_parameters .labeling_and_model_training_batch_size
819
- or total_record_count
799
+ or self . total_record_count
820
800
)
821
801
if self .data_format == DataFormat .BIGQUERY :
822
802
self .input_data_loader = cast (
0 commit comments