22
22
import os
23
23
import shutil
24
24
import subprocess
25
- from typing import List , Tuple , Any , Optional
25
+ from typing import List , Tuple , Any , Optional , Type
26
26
27
27
from absl import flags
28
28
from absl import logging
@@ -830,7 +830,9 @@ def _synthetic_train_and_test(
830
830
test_numerical : Optional [bool ] = False ,
831
831
test_multidimensional_numerical : Optional [bool ] = False ,
832
832
test_categorical : Optional [bool ] = False ,
833
- test_categorical_set : Optional [bool ] = False ):
833
+ test_categorical_set : Optional [bool ] = False ,
834
+ label_shape : Optional [int ] = None ,
835
+ fit_raises : Optional [Type [Exception ]] = None ):
834
836
"""Trains a model on a synthetic dataset."""
835
837
836
838
train_path = os .path .join (self .get_temp_dir (), "train.rio.gz" )
@@ -868,12 +870,13 @@ def _synthetic_train_and_test(
868
870
popen .wait ()
869
871
870
872
feature_spec = {}
873
+ label_shape = [label_shape ] if label_shape else []
871
874
if task == keras .Task .CLASSIFICATION :
872
- feature_spec ["LABEL" ] = tf .io .FixedLenFeature ([] , tf .int64 )
875
+ feature_spec ["LABEL" ] = tf .io .FixedLenFeature (label_shape , tf .int64 )
873
876
elif task == keras .Task .REGRESSION :
874
- feature_spec ["LABEL" ] = tf .io .FixedLenFeature ([] , tf .float32 )
877
+ feature_spec ["LABEL" ] = tf .io .FixedLenFeature (label_shape , tf .float32 )
875
878
elif task == keras .Task .RANKING :
876
- feature_spec ["LABEL" ] = tf .io .FixedLenFeature ([] , tf .float32 )
879
+ feature_spec ["LABEL" ] = tf .io .FixedLenFeature (label_shape , tf .float32 )
877
880
feature_spec ["GROUP" ] = tf .io .FixedLenFeature ([], tf .string )
878
881
else :
879
882
assert False
@@ -964,8 +967,16 @@ def on_epoch_end(self, epoch, logs=None):
964
967
self .evaluation = model .evaluate (test_dataset )
965
968
966
969
callback = _TestEvalCallback ()
967
- history = model .fit (train_dataset , validation_data = test_dataset ,
968
- callbacks = [callback ])
970
+ history = None
971
+ if fit_raises is not None :
972
+ with self .assertRaises (fit_raises ):
973
+ model .fit (
974
+ train_dataset , validation_data = test_dataset , callbacks = [callback ])
975
+ else :
976
+ history = model .fit (
977
+ train_dataset , validation_data = test_dataset , callbacks = [callback ])
978
+ if history is None :
979
+ return
969
980
model .summary ()
970
981
971
982
train_evaluation = model .evaluate (train_dataset )
@@ -991,6 +1002,23 @@ def test_synthetic_classification_numerical(self):
991
1002
self ._synthetic_train_and_test (
992
1003
keras .Task .CLASSIFICATION , 0.8 , 0.72 , test_numerical = True )
993
1004
1005
+ def test_synthetic_classification_squeeze_label (self ):
1006
+ self ._synthetic_train_and_test (
1007
+ keras .Task .CLASSIFICATION ,
1008
+ 0.8 ,
1009
+ 0.72 ,
1010
+ test_numerical = True ,
1011
+ label_shape = 1 )
1012
+
1013
+ def test_synthetic_classification_squeeze_label_invalid_shape (self ):
1014
+ self ._synthetic_train_and_test (
1015
+ keras .Task .CLASSIFICATION ,
1016
+ 0.8 ,
1017
+ 0.72 ,
1018
+ test_numerical = True ,
1019
+ label_shape = 2 ,
1020
+ fit_raises = ValueError )
1021
+
994
1022
def test_synthetic_classification_categorical (self ):
995
1023
self ._synthetic_train_and_test (
996
1024
keras .Task .CLASSIFICATION , 0.95 , 0.70 , test_categorical = True )
0 commit comments