@@ -76,9 +76,28 @@ def on_epoch_end(self, epoch, logs=None):
7676 self .checkpoint_manager .save (checkpoint_number = step_counter )
7777
7878
79- def get_input_fn (data_filepattern , batch_size ):
80- """Get input_fn for recommendation model estimator."""
79+ class InputFn :
80+ """InputFn for recommendation model estimator."""
81+
82+ def __init__ (self ,
83+ data_filepattern ,
84+ batch_size = 10 ,
85+ shuffle = True ,
86+ repeat = True ):
87+ """Init input fn.
88+
89+ Args:
90+ data_filepattern: str, file pattern. (allow wild match like ? and *)
91+ batch_size: int, batch examples with a given size.
92+ shuffle: boolean, whether to shuffle examples.
93+ repeat: boolean, whether to repeat examples.
94+ """
95+ self .data_filepattern = data_filepattern
96+ self .batch_size = batch_size
97+ self .shuffle = shuffle
98+ self .repeat = repeat
8199
100+ @staticmethod
82101 def decode_example (serialized_proto ):
83102 """Decode single serialized example."""
84103 name_to_features = dict (
@@ -98,26 +117,28 @@ def decode_example(serialized_proto):
98117 features ['label' ] = record_features ['label' ]
99118 return features , record_features ['label' ]
100119
101- def input_fn ():
120+ @staticmethod
121+ def read_dataset (data_filepattern ):
122+ input_files = utils .GetShardFilenames (data_filepattern )
123+ return tf .data .TFRecordDataset (input_files )
124+
125+ def __call__ (self ):
102126 """An input_fn satisfying the TF estimator spec.
103127
104128 Returns:
105129 a Dataset where each element is a batch of `features` dicts, passed to the
106130 Estimator model_fn.
107-
108131 """
109- input_files = utils . GetShardFilenames ( data_filepattern )
110- d = tf . data . TFRecordDataset ( input_files )
111- d . shuffle ( len ( input_files ) )
112- d = d . repeat ()
113- d = d . shuffle ( buffer_size = 100 )
114- d = d .map ( decode_example )
115- d = d .batch ( batch_size , drop_remainder = True )
116- d = d .prefetch ( 1 )
132+ d = self . read_dataset ( self . data_filepattern )
133+ if self . repeat :
134+ d = d . repeat ( )
135+ if self . shuffle :
136+ buffer_size = max ( 3 * self . batch_size , 100 )
137+ d = d .shuffle ( buffer_size = buffer_size )
138+ d = d .map ( self . decode_example )
139+ d = d .batch ( self . batch_size , drop_remainder = True )
117140 return d
118141
119- return input_fn
120-
121142
122143def _get_optimizer (learning_rate , gradient_clip_norm = None ):
123144 """Gets model optimizer."""
@@ -139,15 +160,19 @@ def _get_metrics(eval_top_k):
139160 return metrics_list
140161
141162
142- def build_keras_model (params ):
143- """Construct and compile recommendation keras model."""
144- model = recommendation_model .RecommendationModel (params )
163+ def compile_model (model , params , learning_rate , gradient_clip_norm ):
164+ """Compile keras model."""
145165 model .compile (
146166 optimizer = _get_optimizer (
147- learning_rate = FLAGS .learning_rate ,
148- gradient_clip_norm = FLAGS .gradient_clip_norm ),
167+ learning_rate = learning_rate , gradient_clip_norm = gradient_clip_norm ),
149168 loss = losses .GlobalSoftmax (),
150169 metrics = _get_metrics (params ['eval_top_k' ]))
170+
171+
172+ def build_keras_model (params , learning_rate , gradient_clip_norm ):
173+ """Construct and compile recommendation keras model."""
174+ model = recommendation_model .RecommendationModel (params )
175+ compile_model (model , params , learning_rate , gradient_clip_norm )
151176 return model
152177
153178
@@ -185,7 +210,7 @@ def train_and_eval(model, model_dir, train_input_fn, eval_input_fn,
185210 return model
186211
187212
188- def export (checkpoint_path , export_dir , params ):
213+ def export (checkpoint_path , export_dir , params , max_history_length ):
189214 """Export savedmodel."""
190215 model = recommendation_model .RecommendationModel (params )
191216 checkpoint = tf .train .Checkpoint (model = model )
@@ -194,7 +219,7 @@ def export(checkpoint_path, export_dir, params):
194219 tf .saved_model .DEFAULT_SERVING_SIGNATURE_DEF_KEY :
195220 model .serve .get_concrete_function (
196221 input_context = tf .TensorSpec (
197- shape = [FLAGS . max_history_length ],
222+ shape = [max_history_length ],
198223 dtype = tf .dtypes .int32 ,
199224 name = 'context' ))
200225 }
@@ -226,12 +251,12 @@ def main(_):
226251 params ['num_predictions' ] = FLAGS .num_predictions
227252
228253 logger .info ('Setting up train and eval input_fns.' )
229- train_input_fn = get_input_fn (FLAGS .training_data_filepattern ,
230- FLAGS .batch_size )
231- eval_input_fn = get_input_fn (FLAGS .testing_data_filepattern , FLAGS .batch_size )
254+ train_input_fn = InputFn (FLAGS .training_data_filepattern , FLAGS .batch_size )
255+ eval_input_fn = InputFn (FLAGS .testing_data_filepattern , FLAGS .batch_size )
232256
233257 logger .info ('Build keras model for mode: {}.' .format (FLAGS .run_mode ))
234- model = build_keras_model (params = params )
258+ model = build_keras_model (params , FLAGS .learning_rate ,
259+ FLAGS .gradient_clip_norm )
235260
236261 if FLAGS .run_mode == 'train_and_eval' :
237262 train_and_eval (
@@ -248,7 +273,8 @@ def main(_):
248273 export (
249274 checkpoint_path = FLAGS .checkpoint_path ,
250275 export_dir = export_dir ,
251- params = params )
276+ params = params ,
277+ max_history_length = FLAGS .max_history_length )
252278 logger .info ('Converting model to tflite model.' )
253279 export_tflite (export_dir )
254280
0 commit comments