@@ -958,9 +958,10 @@ def fit(self, X, y=None, Xt=None, yt=None, domains=None, **fit_params):
         epochs = fit_params.get("epochs", 1)
         batch_size = fit_params.pop("batch_size", 32)
         shuffle = fit_params.pop("shuffle", True)
+        buffer_size = fit_params.pop("buffer_size", None)
         validation_data = fit_params.pop("validation_data", None)
         validation_split = fit_params.pop("validation_split", 0.)
-        validation_batch_size = fit_params.pop("validation_batch_size", batch_size)
+        validation_batch_size = fit_params.get("validation_batch_size", batch_size)
 
         # 2. Prepare datasets
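A note on the parameter handling above: keys read with pop are consumed before fit_params is forwarded downstream, while epochs and (after this patch) validation_batch_size are read with get so they stay visible to later calls. A minimal usage sketch of the new buffer_size option (the model instance and data names are hypothetical):

    # Hypothetical call; buffer_size caps the tf.data shuffle buffer so a
    # very large dataset need not be held in memory for a full shuffle.
    model.fit(X, y, Xt=Xt, epochs=10, batch_size=64,
              shuffle=True, buffer_size=2048, validation_split=0.1)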
@@ -998,8 +999,7 @@ def fit(self, X, y=None, Xt=None, yt=None, domains=None, **fit_params):
                                 for dom in range(self.n_sources_))
             )
 
-            dataset_src = tf.data.Dataset.zip((dataset_Xs, dataset_ys))
-
+            dataset_src = tf.data.Dataset.zip((dataset_Xs, dataset_ys))
         else:
             dataset_src = X
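For readers unfamiliar with the pattern: tf.data.Dataset.zip pairs the source inputs and labels element-wise, so each element of dataset_src is an (x, y) tuple. A self-contained illustration:

    import tensorflow as tf

    dataset_Xs = tf.data.Dataset.from_tensor_slices([[0.], [1.], [2.]])
    dataset_ys = tf.data.Dataset.from_tensor_slices([0, 1, 0])

    # Iterating the zipped dataset yields (x, y) tuples.
    for x, y in tf.data.Dataset.zip((dataset_Xs, dataset_ys)):
        print(x.numpy(), y.numpy())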
@@ -1029,47 +1029,62 @@ def fit(self, X, y=None, Xt=None, yt=None, domains=None, **fit_params):
         self._initialize_networks()
         if isinstance(Xt, tf.data.Dataset):
             first_elem = next(iter(Xt))
-            if (not isinstance(first_elem, tuple) or
-                not len(first_elem) == 2):
-                raise ValueError("When first argument is a dataset. "
-                                 "It should return (x, y) tuples.")
+            if not isinstance(first_elem, tuple):
+                shape = first_elem.shape
             else:
                 shape = first_elem[0].shape
+            if self._check_for_batch(Xt):
+                shape = shape[1:]
         else:
             shape = Xt.shape[1:]
         self._initialize_weights(shape)
 
-        # validation_data = self._check_validation_data(validation_data,
-        #                                               validation_batch_size,
-        #                                               shuffle)
+
+        # 3.5 Get datasets length
+        self.length_src_ = self._get_length_dataset(dataset_src, domain="src")
+        self.length_tgt_ = self._get_length_dataset(dataset_tgt, domain="tgt")
+
 
         # 4. Prepare validation dataset
         if validation_data is None and validation_split > 0.:
             if shuffle:
-                dataset_src = dataset_src.shuffle(buffer_size=1024)
-            frac = int(len(dataset_src)*validation_split)
+                dataset_src = dataset_src.shuffle(buffer_size=self.length_src_,
+                                                  reshuffle_each_iteration=False)
+            frac = int(self.length_src_*validation_split)
             validation_data = dataset_src.take(frac)
             dataset_src = dataset_src.skip(frac)
-            validation_data = validation_data.batch(batch_size)
+            if not self._check_for_batch(validation_data):
+                validation_data = validation_data.batch(validation_batch_size)
+
+        if validation_data is not None:
+            if isinstance(validation_data, tf.data.Dataset):
+                if not self._check_for_batch(validation_data):
+                    validation_data = validation_data.batch(validation_batch_size)
 
+
         # 5. Set datasets
         # Same length for src and tgt + complete last batch + shuffle
-        try:
-            max_size = max(len(dataset_src), len(dataset_tgt))
-            max_size = np.ceil(max_size/batch_size)*batch_size
-            repeat_src = np.ceil(max_size/len(dataset_src))
-            repeat_tgt = np.ceil(max_size/len(dataset_tgt))
-
-            dataset_src = dataset_src.repeat(repeat_src)
-            dataset_tgt = dataset_tgt.repeat(repeat_tgt)
-
-            self.total_steps_ = float(np.ceil(max_size/batch_size)*epochs)
-        except:
-            pass
-
         if shuffle:
-            dataset_src = dataset_src.shuffle(buffer_size=1024)
-            dataset_tgt = dataset_tgt.shuffle(buffer_size=1024)
+            if buffer_size is None:
+                dataset_src = dataset_src.shuffle(buffer_size=self.length_src_,
+                                                  reshuffle_each_iteration=True)
+                dataset_tgt = dataset_tgt.shuffle(buffer_size=self.length_tgt_,
+                                                  reshuffle_each_iteration=True)
+            else:
+                dataset_src = dataset_src.shuffle(buffer_size=buffer_size,
+                                                  reshuffle_each_iteration=True)
+                dataset_tgt = dataset_tgt.shuffle(buffer_size=buffer_size,
+                                                  reshuffle_each_iteration=True)
+
+        max_size = max(self.length_src_, self.length_tgt_)
+        max_size = np.ceil(max_size/batch_size)*batch_size
+        repeat_src = np.ceil(max_size/self.length_src_)
+        repeat_tgt = np.ceil(max_size/self.length_tgt_)
+
+        dataset_src = dataset_src.repeat(repeat_src).take(max_size)
+        dataset_tgt = dataset_tgt.repeat(repeat_tgt).take(max_size)
+
+        self.total_steps_ = float(np.ceil(max_size/batch_size)*epochs)
 
         # 5. Pretraining
         if not hasattr(self, "pretrain_"):
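The repeat/take arithmetic above makes both domains stream the same number of samples, rounded up to a whole number of batches, which is what makes total_steps_ well defined. A standalone sketch of the same alignment with arbitrary sizes:

    import numpy as np
    import tensorflow as tf

    batch_size, length_src, length_tgt = 32, 100, 70

    max_size = max(length_src, length_tgt)                       # 100
    # Cast to int: Dataset.repeat and Dataset.take expect integer counts.
    max_size = int(np.ceil(max_size / batch_size) * batch_size)  # 128
    repeat_src = int(np.ceil(max_size / length_src))             # 2
    repeat_tgt = int(np.ceil(max_size / length_tgt))             # 2

    src = tf.data.Dataset.range(length_src).repeat(repeat_src).take(max_size)
    tgt = tf.data.Dataset.range(length_tgt).repeat(repeat_tgt).take(max_size)
    assert len(src) == len(tgt) == max_size                      # 4 batches each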
@@ -1097,14 +1112,14 @@ def fit(self, X, y=None, Xt=None, yt=None, domains=None, **fit_params):
             pre_verbose = prefit_params.pop("verbose", verbose)
             pre_epochs = prefit_params.pop("epochs", epochs)
             pre_batch_size = prefit_params.pop("batch_size", batch_size)
-            pre_shuffle = prefit_params.pop("shuffle", shuffle)
             prefit_params.pop("validation_data", None)
-            prefit_params.pop("validation_split", None)
-            prefit_params.pop("validation_batch_size", None)
 
             # !!! shuffle is already done
-            dataset = tf.data.Dataset.zip((dataset_src, dataset_tgt)).batch(pre_batch_size)
-
+            dataset = tf.data.Dataset.zip((dataset_src, dataset_tgt))
+
+            if not self._check_for_batch(dataset):
+                dataset = dataset.batch(pre_batch_size)
+
             hist = super().fit(dataset, validation_data=validation_data,
                                epochs=pre_epochs, verbose=pre_verbose, **prefit_params)
@@ -1121,7 +1136,10 @@ def fit(self, X, y=None, Xt=None, yt=None, domains=None, **fit_params):
             self.history_ = {}
 
         # .7 Training
-        dataset = tf.data.Dataset.zip((dataset_src, dataset_tgt)).batch(batch_size)
+        dataset = tf.data.Dataset.zip((dataset_src, dataset_tgt))
+
+        if not self._check_for_batch(dataset):
+            dataset = dataset.batch(batch_size)
 
         self.pretrain_ = False
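The _check_for_batch guard in both the pretraining and training paths avoids batching a dataset twice, which would silently nest the batch dimension rather than raise an error. A short demonstration of the failure mode it prevents:

    import tensorflow as tf

    ds = tf.data.Dataset.range(8).batch(4)   # already batched: elements of shape (4,)
    ds2 = ds.batch(2)                        # batching again nests the dimension
    print(next(iter(ds2)).shape)             # (2, 4) instead of the expected (2,)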
@@ -1257,7 +1275,8 @@ def compile(self,
1257
1275
if "_" in name :
1258
1276
new_name = ""
1259
1277
for split in name .split ("_" ):
1260
- new_name += split [0 ]
1278
+ if len (split ) > 0 :
1279
+ new_name += split [0 ]
1261
1280
name = new_name
1262
1281
else :
1263
1282
name = name [:3 ]
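The added length check matters because str.split("_") yields empty segments for leading, trailing, or doubled underscores, and indexing an empty string raises IndexError. For example:

    name = "disc_acc_"   # trailing underscore: split("_") -> ["disc", "acc", ""]
    abbrev = ""
    for split in name.split("_"):
        if len(split) > 0:
            abbrev += split[0]
    print(abbrev)        # "da", where the unguarded loop would raise IndexError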
@@ -1571,6 +1590,37 @@ def _initialize_weights(self, shape_X):
         X_enc = self.encoder_(np.zeros((1,) + shape_X))
         if hasattr(self, "discriminator_"):
             self.discriminator_(X_enc)
+
+
+    def _get_length_dataset(self, dataset, domain="src"):
+        try:
+            length = len(dataset)
+        except:
+            if self.verbose:
+                print("Computing %s dataset size..." % domain)
+            if not hasattr(self, "length_%s_" % domain):
+                length = 0
+                for _ in dataset:
+                    length += 1
+            else:
+                length = getattr(self, "length_%s_" % domain)
+            if self.verbose:
+                print("Done!")
+        return length
+
+
+    def _check_for_batch(self, dataset):
+        if dataset.__class__.__name__ == "BatchDataset":
+            return True
+        if hasattr(dataset, "_input_dataset"):
+            return self._check_for_batch(dataset._input_dataset)
+        elif hasattr(dataset, "_datasets"):
+            checks = []
+            for data in dataset._datasets:
+                checks.append(self._check_for_batch(data))
+            return np.all(checks)
+        else:
+            return False
 
 
     def _unpack_data(self, data):
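Note that _check_for_batch relies on private tf.data attributes (_input_dataset for chained transformations, _datasets for zipped ones) to see a BatchDataset through later ops such as repeat or take. A quick check of the premise it rests on (these attribute names are TensorFlow internals and may change between versions):

    import tensorflow as tf

    batched = tf.data.Dataset.range(10).batch(5)
    repeated = batched.repeat(2)
    print(batched.__class__.__name__)    # "BatchDataset"
    print(repeated.__class__.__name__)   # "RepeatDataset"
    # The wrapper still references the batched dataset internally:
    print(repeated._input_dataset.__class__.__name__)  # "BatchDataset"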