import tensorflow as tf


- # Types are from //cloud/ml/research/data_utils/feature_metadata.py
_FEATURES_TYPE: Final[str] = 'FLOAT64'
_SOURCE_LABEL_TYPE: Final[str] = 'STRING'
_SOURCE_LABEL_DEFAULT_VALUE: Final[str] = '-1'
_LABEL_TYPE: Final[str] = 'INT64'
+ _STRING_TO_INTEGER_LABEL_MAP: dict[str | int, int] = {
+     1: 1,
+     0: 0,
+     -1: -1,
+     '': -1,
+     '-1': -1,
+     '0': 0,
+     '1': 1,
+     'positive': 1,
+     'negative': 0,
+     'unlabeled': -1,
+ }

# Setting the shuffle buffer size to 1M seems to be necessary to get the CSV
# reader to provide a diversity of data to the model.
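The new module-level map accepts both already-integer labels and their common string spellings, since the Dataset can carry labels as either ints or strings. A minimal standalone sketch (a copy of the map, not part of the file) of how lookups resolve:

```python
_STRING_TO_INTEGER_LABEL_MAP = {
    1: 1, 0: 0, -1: -1,
    '': -1, '-1': -1, '0': 0, '1': 1,
    'positive': 1, 'negative': 0, 'unlabeled': -1,
}

for raw in (1, '1', 'positive', '', 'unlabeled'):
  print(repr(raw), '->', _STRING_TO_INTEGER_LABEL_MAP[raw])
# 1 -> 1, '1' -> 1, 'positive' -> 1, '' -> -1, 'unlabeled' -> -1
```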
@@ -167,12 +178,12 @@ def from_inputs_file(
      raise ValueError(
          f'Label column {label_column_name} not found in the header: {header}'
      )
-     num_features = len(all_columns) - 1
    features_types = [_FEATURES_TYPE] * len(all_columns)
    column_names_dict = collections.OrderedDict(
        zip(all_columns, features_types)
    )
    column_names_dict[label_column_name] = _SOURCE_LABEL_DEFAULT_VALUE
+     num_features = len(all_columns) - 1
    return ColumnNamesInfo(
        column_names_dict=column_names_dict,
        header=header,
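For reference, a minimal standalone sketch of what this block computes, using a hypothetical three-column header (the column names below are illustrative, not from the file):

```python
import collections

all_columns = ['x1', 'x2', 'y']          # hypothetical header
label_column_name = 'y'

features_types = ['FLOAT64'] * len(all_columns)
column_names_dict = collections.OrderedDict(zip(all_columns, features_types))
column_names_dict[label_column_name] = '-1'   # _SOURCE_LABEL_DEFAULT_VALUE
num_features = len(all_columns) - 1           # every column except the label

print(dict(column_names_dict))  # {'x1': 'FLOAT64', 'x2': 'FLOAT64', 'y': '-1'}
print(num_features)             # 2
```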
@@ -216,6 +227,13 @@ def __init__(self, runner_parameters: parameters.RunnerParameters):
        self.runner_parameters.negative_data_value,
        self.runner_parameters.unlabeled_data_value,
    ]
+     # Add any labels that are not already in the map.
+     _STRING_TO_INTEGER_LABEL_MAP[self.runner_parameters.positive_data_value] = 1
+     _STRING_TO_INTEGER_LABEL_MAP[self.runner_parameters.negative_data_value] = 0
+     _STRING_TO_INTEGER_LABEL_MAP[
+         self.runner_parameters.unlabeled_data_value
+     ] = -1
+
    # Construct a label remap from string labels to integers. The table is not
    # necessary for the case when the labels are all integers. But instead of
    # checking if the labels are all integers, we construct the table and use
@@ -286,7 +304,8 @@ def get_inputs_metadata(
    )
    # Get information about the columns.
    column_names_info = ColumnNamesInfo.from_inputs_file(
-         csv_filenames[0], label_column_name
+         csv_filenames[0],
+         label_column_name,
    )
    logging.info(
        'Obtained metadata for data with CSV prefix %s (number of features=%d)',
@@ -360,22 +379,19 @@ def filter_func(features: tf.Tensor, label: tf.Tensor) -> bool: # pylint: disab
  @classmethod
  def convert_str_to_int(cls, value: str) -> int:
    """Converts a string integer label to an integer label."""
-     if isinstance(value, str) and value.lstrip('-').isdigit():
-       return int(value)
-     elif isinstance(value, int):
-       return value
+     if value in _STRING_TO_INTEGER_LABEL_MAP:
+       return _STRING_TO_INTEGER_LABEL_MAP[value]
    else:
      raise ValueError(
-           f'Label {value} of type {type(value)} is not a string integer.'
+           f'Label {value} of type {type(value)} is not a string integer or '
+           'mappable to an integer.'
      )

  @classmethod
  def _get_label_remap_table(
      cls, labels_mapping: dict[str, int]
  ) -> tf.lookup.StaticHashTable:
    """Returns a label remap table that converts string labels to integers."""
-     # The possible keys are '', '-1, '0', '1'. None is not included because the
-     # Data Loader will default to '' if the label is None.
    keys_tensor = tf.constant(
        list(labels_mapping.keys()),
        dtype=tf.dtypes.as_dtype(_SOURCE_LABEL_TYPE.lower()),
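With the lookup-based branch, only values present in the map are accepted; anything else, including numeric strings such as '2' that the removed digit-parsing branch would have converted, now raises. A standalone sketch (the map here is an abbreviated copy of _STRING_TO_INTEGER_LABEL_MAP):

```python
_STRING_TO_INTEGER_LABEL_MAP = {1: 1, 0: 0, -1: -1, '1': 1, '0': 0, '-1': -1, 'positive': 1}

def convert_str_to_int(value):
  if value in _STRING_TO_INTEGER_LABEL_MAP:
    return _STRING_TO_INTEGER_LABEL_MAP[value]
  else:
    raise ValueError(
        f'Label {value} of type {type(value)} is not a string integer or '
        'mappable to an integer.'
    )

print(convert_str_to_int('positive'))  # 1
print(convert_str_to_int(-1))          # -1
# convert_str_to_int('2') raises ValueError; the old branch returned int('2') == 2.
```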
@@ -390,6 +406,14 @@ def _get_label_remap_table(
    )
    return label_remap_table

+   def remap_label(self, label: str | tf.Tensor) -> int | tf.Tensor:
+     """Remaps the label to an integer."""
+     if isinstance(label, str) or (
+         isinstance(label, tf.Tensor) and label.dtype == tf.dtypes.string
+     ):
+       return self._label_remap_table.lookup(label)
+     return label
+
  def load_tf_dataset_from_csv(
      self,
      input_path: str,
@@ -441,6 +465,7 @@ def load_tf_dataset_from_csv(
            self._last_read_metadata.column_names_info.column_names_dict.values()
        )
    ]
+     logging.info('column_defaults: %s', column_defaults)

    # Construct a single dataset out of multiple CSV files.
    # TODO(sinharaj): Remove the determinism after testing.
@@ -456,7 +481,7 @@ def load_tf_dataset_from_csv(
        na_value='',
        header=True,
        num_epochs=1,
-         shuffle=True,
+         shuffle=False,
        shuffle_buffer_size=_SHUFFLE_BUFFER_SIZE,
        shuffle_seed=self.runner_parameters.random_seed,
        prefetch_buffer_size=tf.data.AUTOTUNE,
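For context, a hedged sketch of how these arguments are consumed by tf.data.experimental.make_csv_dataset (the file pattern, batch size, and label column below are hypothetical). With shuffle=False the reader yields rows in file order, so shuffle_buffer_size and shuffle_seed are effectively inert even though they are still passed:

```python
import tensorflow as tf

dataset = tf.data.experimental.make_csv_dataset(
    'data-*.csv',                  # hypothetical file pattern
    batch_size=64,                 # hypothetical batch size
    label_name='y',                # hypothetical label column
    na_value='',
    header=True,
    num_epochs=1,
    shuffle=False,
    shuffle_buffer_size=1_000_000,
    shuffle_seed=42,
    prefetch_buffer_size=tf.data.AUTOTUNE,
)
```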
@@ -473,17 +498,9 @@ def load_tf_dataset_from_csv(
          'created.'
      )

-     def remap_label(label: str | tf.Tensor) -> int | tf.Tensor:
-       """Remaps the label to an integer."""
-       if isinstance(label, str) or (
-           isinstance(label, tf.Tensor) and label.dtype == tf.dtypes.string
-       ):
-         return self._label_remap_table.lookup(label)
-       return label
-
    # The Dataset can have labels of type int or str. Cast them to int.
    dataset = dataset.map(
-         lambda features, label: (features, remap_label(label)),
+         lambda features, label: (features, self.remap_label(label)),
        num_parallel_calls=tf.data.AUTOTUNE,
        deterministic=True,
    )
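The local closure is replaced by the instance method self.remap_label added earlier in this change, which routes string-typed labels through the lookup table and passes integer labels straight through. A self-contained sketch of the same pattern (table contents and data are illustrative, not taken from the file):

```python
import tensorflow as tf

table = tf.lookup.StaticHashTable(
    tf.lookup.KeyValueTensorInitializer(
        keys=tf.constant(['', '-1', '0', '1', 'positive', 'negative', 'unlabeled']),
        values=tf.constant([-1, -1, 0, 1, 1, 0, -1], dtype=tf.int64),
    ),
    default_value=-1,
)

dataset = tf.data.Dataset.from_tensor_slices(
    ({'x': [1.0, 2.0, 3.0]}, ['positive', 'negative', ''])
)
dataset = dataset.map(
    lambda features, label: (features, table.lookup(label)),
    num_parallel_calls=tf.data.AUTOTUNE,
    deterministic=True,
)
print([int(label) for _, label in dataset.as_numpy_iterator()])  # [1, 0, -1]
```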
@@ -535,7 +552,6 @@ def combine_features_dict_into_tensor(
    self._label_counts = {
        k: v.numpy() for k, v in self.counts_by_label(dataset).items()
    }
-     logging.info('Label counts: %s', self._label_counts)

    return dataset
@@ -554,11 +570,11 @@ def counts_by_label(self, dataset: tf.data.Dataset) -> Dict[int, tf.Tensor]:
    @tf.function
    def count_class(
-         counts: Dict[int, int],  # Keys are always strings.
+         counts: Dict[int, int],
        batch: Tuple[tf.Tensor, tf.Tensor],
    ) -> Dict[int, int]:
      _, labels = batch
-       # Keys are always strings.
+       labels = self.remap_label(labels)
      new_counts: Dict[int, int] = counts.copy()
      for i in self.all_labels:
        # This function is called after the Dataset is constructed and the
@@ -582,6 +598,59 @@ def count_class(
    )
    return counts

+   def counts_by_original_label(
+       self, dataset: tf.data.Dataset
+   ) -> tuple[dict[str, tf.Tensor], dict[int, tf.Tensor]]:
+     """Counts the number of samples in each label class in the dataset."""
+
+     all_int_labels = [l for l in self.all_labels if isinstance(l, int)]
+     logging.info('all_int_labels: %s', all_int_labels)
+     all_str_labels = [l for l in self.all_labels if isinstance(l, str)]
+     logging.info('all_str_labels: %s', all_str_labels)
+
+     @tf.function
+     def count_original_class(
+         counts: Dict[int | str, int],
+         batch: Tuple[tf.Tensor, tf.Tensor],
+     ) -> Dict[int | str, int]:
+       keys_are_int = all(isinstance(k, int) for k in counts.keys())
+       if keys_are_int:
+         all_labels = all_int_labels
+       else:
+         all_labels = all_str_labels
+       _, labels = batch
+       new_counts: Dict[int | str, int] = counts.copy()
+       for label in all_labels:
+         cc: tf.Tensor = tf.cast(labels == label, tf.int32)
+         if label in list(new_counts.keys()):
+           new_counts[label] += tf.reduce_sum(cc)
+         else:
+           new_counts[label] = tf.reduce_sum(cc)
+       return new_counts
+
+     int_keys_map = {
+         k: v
+         for k, v in _STRING_TO_INTEGER_LABEL_MAP.items()
+         if isinstance(k, int)
+     }
+     initial_int_state = dict((int(label), 0) for label in int_keys_map.keys())
+     if initial_int_state:
+       int_counts = dataset.reduce(
+           initial_state=initial_int_state, reduce_func=count_original_class
+       )
+     else:
+       int_counts = {}
+     str_keys_map = {
+         k: v
+         for k, v in _STRING_TO_INTEGER_LABEL_MAP.items()
+         if isinstance(k, str)
+     }
+     initial_str_state = dict((str(label), 0) for label in str_keys_map.keys())
+     str_counts = dataset.reduce(
+         initial_state=initial_str_state, reduce_func=count_original_class
+     )
+     return int_counts, str_counts
+

  def get_label_thresholds(self) -> Mapping[str, float]:
    """Computes positive and negative thresholds based on label ratios.
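The new counts_by_original_label helper relies on tf.data.Dataset.reduce with a dict-valued state, running one pass for integer-keyed counts and one for string-keyed counts. A toy sketch of that reduction pattern (labels, batching, and names below are illustrative):

```python
import tensorflow as tf

labels_ds = tf.data.Dataset.from_tensor_slices(tf.constant([1, 0, 1, -1, 1, 0])).batch(2)
initial_state = {-1: 0, 0: 0, 1: 0}

@tf.function
def count_labels(counts, batch):
  new_counts = dict(counts)
  for label in (-1, 0, 1):
    new_counts[label] += tf.reduce_sum(tf.cast(batch == label, tf.int32))
  return new_counts

counts = labels_ds.reduce(initial_state, count_labels)
print({k: int(v) for k, v in counts.items()})  # {-1: 1, 0: 2, 1: 3}
```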