@@ -76,9 +76,28 @@ def on_epoch_end(self, epoch, logs=None):
       self.checkpoint_manager.save(checkpoint_number=step_counter)
 
 
-def get_input_fn(data_filepattern, batch_size):
-  """Get input_fn for recommendation model estimator."""
+class InputFn:
+  """InputFn for recommendation model estimator."""
+
+  def __init__(self,
+               data_filepattern,
+               batch_size=10,
+               shuffle=True,
+               repeat=True):
+    """Init input fn.
+
+    Args:
+      data_filepattern: str, file pattern (allows wildcards like ? and *).
+      batch_size: int, batch examples with a given size.
+      shuffle: boolean, whether to shuffle examples.
+      repeat: boolean, whether to repeat examples.
+    """
+    self.data_filepattern = data_filepattern
+    self.batch_size = batch_size
+    self.shuffle = shuffle
+    self.repeat = repeat
 
+  @staticmethod
   def decode_example(serialized_proto):
     """Decode single serialized example."""
     name_to_features = dict(
@@ -98,26 +117,28 @@ def decode_example(serialized_proto):
     features['label'] = record_features['label']
     return features, record_features['label']
 
-  def input_fn():
+  @staticmethod
+  def read_dataset(data_filepattern):
+    input_files = utils.GetShardFilenames(data_filepattern)
+    return tf.data.TFRecordDataset(input_files)
+
+  def __call__(self):
     """An input_fn satisfying the TF estimator spec.
 
     Returns:
       a Dataset where each element is a batch of `features` dicts, passed to the
       Estimator model_fn.
-
     """
-    input_files = utils.GetShardFilenames(data_filepattern)
-    d = tf.data.TFRecordDataset(input_files)
-    d.shuffle(len(input_files))
-    d = d.repeat()
-    d = d.shuffle(buffer_size=100)
-    d = d.map(decode_example)
-    d = d.batch(batch_size, drop_remainder=True)
-    d = d.prefetch(1)
+    d = self.read_dataset(self.data_filepattern)
+    if self.repeat:
+      d = d.repeat()
+    if self.shuffle:
+      buffer_size = max(3 * self.batch_size, 100)
+      d = d.shuffle(buffer_size=buffer_size)
+    d = d.map(self.decode_example)
+    d = d.batch(self.batch_size, drop_remainder=True)
     return d
 
-  return input_fn
-
 
 def _get_optimizer(learning_rate, gradient_clip_norm=None):
   """Gets model optimizer."""
@@ -139,15 +160,19 @@ def _get_metrics(eval_top_k):
   return metrics_list
 
 
-def build_keras_model(params):
-  """Construct and compile recommendation keras model."""
-  model = recommendation_model.RecommendationModel(params)
+def compile_model(model, params, learning_rate, gradient_clip_norm):
+  """Compile keras model."""
   model.compile(
       optimizer=_get_optimizer(
-          learning_rate=FLAGS.learning_rate,
-          gradient_clip_norm=FLAGS.gradient_clip_norm),
+          learning_rate=learning_rate, gradient_clip_norm=gradient_clip_norm),
       loss=losses.GlobalSoftmax(),
       metrics=_get_metrics(params['eval_top_k']))
+
+
+def build_keras_model(params, learning_rate, gradient_clip_norm):
+  """Construct and compile recommendation keras model."""
+  model = recommendation_model.RecommendationModel(params)
+  compile_model(model, params, learning_rate, gradient_clip_norm)
   return model
 
 
@@ -185,7 +210,7 @@ def train_and_eval(model, model_dir, train_input_fn, eval_input_fn,
   return model
 
 
-def export(checkpoint_path, export_dir, params):
+def export(checkpoint_path, export_dir, params, max_history_length):
   """Export savedmodel."""
   model = recommendation_model.RecommendationModel(params)
   checkpoint = tf.train.Checkpoint(model=model)
@@ -194,7 +219,7 @@ def export(checkpoint_path, export_dir, params):
       tf.saved_model.DEFAULT_SERVING_SIGNATURE_DEF_KEY:
           model.serve.get_concrete_function(
               input_context=tf.TensorSpec(
-                  shape=[FLAGS.max_history_length],
+                  shape=[max_history_length],
                   dtype=tf.dtypes.int32,
                   name='context'))
   }
@@ -226,12 +251,12 @@ def main(_):
   params['num_predictions'] = FLAGS.num_predictions
 
   logger.info('Setting up train and eval input_fns.')
-  train_input_fn = get_input_fn(FLAGS.training_data_filepattern,
-                                FLAGS.batch_size)
-  eval_input_fn = get_input_fn(FLAGS.testing_data_filepattern, FLAGS.batch_size)
+  train_input_fn = InputFn(FLAGS.training_data_filepattern, FLAGS.batch_size)
+  eval_input_fn = InputFn(FLAGS.testing_data_filepattern, FLAGS.batch_size)
 
   logger.info('Build keras model for mode: {}.'.format(FLAGS.run_mode))
-  model = build_keras_model(params=params)
+  model = build_keras_model(params, FLAGS.learning_rate,
+                            FLAGS.gradient_clip_norm)
 
   if FLAGS.run_mode == 'train_and_eval':
     train_and_eval(
@@ -248,7 +273,8 @@ def main(_):
     export(
         checkpoint_path=FLAGS.checkpoint_path,
         export_dir=export_dir,
-        params=params)
+        params=params,
+        max_history_length=FLAGS.max_history_length)
     logger.info('Converting model to tflite model.')
     export_tflite(export_dir)
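
For reference, a sketch of the resulting call flow once FLAGS are threaded through explicitly; the flag values below are hypothetical stand-ins for what main() reads from absl flags:

    # Hypothetical values standing in for FLAGS.learning_rate, etc.
    model = build_keras_model(params, learning_rate=0.1, gradient_clip_norm=1.0)
    export(checkpoint_path='/tmp/model/ckpt-100',
           export_dir='/tmp/model/export',
           params=params,
           max_history_length=10)

Keeping FLAGS out of build_keras_model and export means both can be exercised in tests without initializing a flag parser.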