
Commit fa831d3

lintian06 authored and copybara-github committed
Refactor and extract functions from recommendation code.
PiperOrigin-RevId: 351467621
1 parent 7f14caa commit fa831d3

2 files changed: +66, -37 lines


lite/examples/recommendation/ml/model/recommendation_model_launcher_keras.py

Lines changed: 52 additions & 26 deletions
@@ -76,9 +76,28 @@ def on_epoch_end(self, epoch, logs=None):
     self.checkpoint_manager.save(checkpoint_number=step_counter)


-def get_input_fn(data_filepattern, batch_size):
-  """Get input_fn for recommendation model estimator."""
+class InputFn:
+  """InputFn for recommendation model estimator."""
+
+  def __init__(self,
+               data_filepattern,
+               batch_size=10,
+               shuffle=True,
+               repeat=True):
+    """Init input fn.
+
+    Args:
+      data_filepattern: str, file pattern. (allow wild match like ? and *)
+      batch_size: int, batch examples with a given size.
+      shuffle: boolean, whether to shuffle examples.
+      repeat: boolean, whether to repeat examples.
+    """
+    self.data_filepattern = data_filepattern
+    self.batch_size = batch_size
+    self.shuffle = shuffle
+    self.repeat = repeat

+  @staticmethod
   def decode_example(serialized_proto):
     """Decode single serialized example."""
     name_to_features = dict(
@@ -98,26 +117,28 @@ def decode_example(serialized_proto):
     features['label'] = record_features['label']
     return features, record_features['label']

-  def input_fn():
+  @staticmethod
+  def read_dataset(data_filepattern):
+    input_files = utils.GetShardFilenames(data_filepattern)
+    return tf.data.TFRecordDataset(input_files)
+
+  def __call__(self):
     """An input_fn satisfying the TF estimator spec.

     Returns:
       a Dataset where each element is a batch of `features` dicts, passed to the
       Estimator model_fn.
-
     """
-    input_files = utils.GetShardFilenames(data_filepattern)
-    d = tf.data.TFRecordDataset(input_files)
-    d.shuffle(len(input_files))
-    d = d.repeat()
-    d = d.shuffle(buffer_size=100)
-    d = d.map(decode_example)
-    d = d.batch(batch_size, drop_remainder=True)
-    d = d.prefetch(1)
+    d = self.read_dataset(self.data_filepattern)
+    if self.repeat:
+      d = d.repeat()
+    if self.shuffle:
+      buffer_size = max(3 * self.batch_size, 100)
+      d = d.shuffle(buffer_size=buffer_size)
+    d = d.map(self.decode_example)
+    d = d.batch(self.batch_size, drop_remainder=True)
     return d

-  return input_fn
-

 def _get_optimizer(learning_rate, gradient_clip_norm=None):
   """Gets model optimizer."""
@@ -139,15 +160,19 @@ def _get_metrics(eval_top_k):
   return metrics_list


-def build_keras_model(params):
-  """Construct and compile recommendation keras model."""
-  model = recommendation_model.RecommendationModel(params)
+def compile_model(model, params, learning_rate, gradient_clip_norm):
+  """Compile keras model."""
   model.compile(
       optimizer=_get_optimizer(
-          learning_rate=FLAGS.learning_rate,
-          gradient_clip_norm=FLAGS.gradient_clip_norm),
+          learning_rate=learning_rate, gradient_clip_norm=gradient_clip_norm),
       loss=losses.GlobalSoftmax(),
       metrics=_get_metrics(params['eval_top_k']))
+
+
+def build_keras_model(params, learning_rate, gradient_clip_norm):
+  """Construct and compile recommendation keras model."""
+  model = recommendation_model.RecommendationModel(params)
+  compile_model(model, params, learning_rate, gradient_clip_norm)
   return model

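As a rough sketch (the hyperparameter values below are placeholders, not defaults from this commit), the split means build_keras_model now receives the optimizer settings explicitly, while compile_model can also recompile an already constructed RecommendationModel with the same params dict:

# Hypothetical calls from inside the launcher module; values are placeholders.
model = build_keras_model(params, learning_rate=0.1, gradient_clip_norm=1.0)
# Recompile the same model with a different learning rate, reusing params.
compile_model(model, params, learning_rate=0.01, gradient_clip_norm=1.0)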

@@ -185,7 +210,7 @@ def train_and_eval(model, model_dir, train_input_fn, eval_input_fn,
   return model


-def export(checkpoint_path, export_dir, params):
+def export(checkpoint_path, export_dir, params, max_history_length):
   """Export savedmodel."""
   model = recommendation_model.RecommendationModel(params)
   checkpoint = tf.train.Checkpoint(model=model)
@@ -194,7 +219,7 @@ def export(checkpoint_path, export_dir, params):
       tf.saved_model.DEFAULT_SERVING_SIGNATURE_DEF_KEY:
           model.serve.get_concrete_function(
               input_context=tf.TensorSpec(
-                  shape=[FLAGS.max_history_length],
+                  shape=[max_history_length],
                   dtype=tf.dtypes.int32,
                   name='context'))
   }
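
A minimal sketch of the updated export signature, which now takes max_history_length as an argument instead of reading FLAGS.max_history_length; the checkpoint path, export directory, and length value below are placeholders:

# Hypothetical export call from inside the launcher module; paths and the
# history length are placeholders.
export(
    checkpoint_path='/tmp/model_dir/ckpt-1000',
    export_dir='/tmp/exported_savedmodel',
    params=params,
    max_history_length=10)
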
@@ -226,12 +251,12 @@ def main(_):
   params['num_predictions'] = FLAGS.num_predictions

   logger.info('Setting up train and eval input_fns.')
-  train_input_fn = get_input_fn(FLAGS.training_data_filepattern,
-                                FLAGS.batch_size)
-  eval_input_fn = get_input_fn(FLAGS.testing_data_filepattern, FLAGS.batch_size)
+  train_input_fn = InputFn(FLAGS.training_data_filepattern, FLAGS.batch_size)
+  eval_input_fn = InputFn(FLAGS.testing_data_filepattern, FLAGS.batch_size)

   logger.info('Build keras model for mode: {}.'.format(FLAGS.run_mode))
-  model = build_keras_model(params=params)
+  model = build_keras_model(params, FLAGS.learning_rate,
+                            FLAGS.gradient_clip_norm)

   if FLAGS.run_mode == 'train_and_eval':
     train_and_eval(
@@ -248,7 +273,8 @@ def main(_):
     export(
         checkpoint_path=FLAGS.checkpoint_path,
         export_dir=export_dir,
-        params=params)
+        params=params,
+        max_history_length=FLAGS.max_history_length)
     logger.info('Converting model to tflite model.')
     export_tflite(export_dir)

lite/examples/recommendation/ml/model/recommendation_model_launcher_keras_test.py

Lines changed: 14 additions & 11 deletions
@@ -78,11 +78,12 @@ def setUp(self):
   def testModelFnTrainModeExecute(self):
     """Verifies that 'model_fn' can be executed in train and eval mode."""
     self.params['encoder_type'] = FLAGS.encoder_type
-    train_input_fn = launcher.get_input_fn(FLAGS.training_data_filepattern,
-                                           FLAGS.batch_size)
-    eval_input_fn = launcher.get_input_fn(FLAGS.testing_data_filepattern,
-                                          FLAGS.batch_size)
-    model = launcher.build_keras_model(params=self.params)
+    train_input_fn = launcher.InputFn(FLAGS.training_data_filepattern,
+                                      FLAGS.batch_size)
+    eval_input_fn = launcher.InputFn(FLAGS.testing_data_filepattern,
+                                     FLAGS.batch_size)
+    model = launcher.build_keras_model(self.params, FLAGS.learning_rate,
+                                       FLAGS.gradient_clip_norm)
     launcher.train_and_eval(
         model=model,
         model_dir=FLAGS.model_dir,
@@ -99,11 +100,12 @@ def testModelFnExportModeExecute(self):
     """Verifies model can be exported to savedmodel and tflite model."""
     self.params['encoder_type'] = FLAGS.encoder_type
     self.params['num_predictions'] = FLAGS.num_predictions
-    train_input_fn = launcher.get_input_fn(FLAGS.training_data_filepattern,
-                                           FLAGS.batch_size)
-    eval_input_fn = launcher.get_input_fn(FLAGS.testing_data_filepattern,
-                                          FLAGS.batch_size)
-    model = launcher.build_keras_model(params=self.params)
+    train_input_fn = launcher.InputFn(FLAGS.training_data_filepattern,
+                                      FLAGS.batch_size)
+    eval_input_fn = launcher.InputFn(FLAGS.testing_data_filepattern,
+                                     FLAGS.batch_size)
+    model = launcher.build_keras_model(self.params, FLAGS.learning_rate,
+                                       FLAGS.gradient_clip_norm)
     launcher.train_and_eval(
         model=model,
         model_dir=FLAGS.model_dir,
@@ -117,7 +119,8 @@ def testModelFnExportModeExecute(self):
     launcher.export(
         checkpoint_path=latest_checkpoint,
         export_dir=export_dir,
-        params=self.params)
+        params=self.params,
+        max_history_length=FLAGS.max_history_length)
     savedmodel_path = os.path.join(export_dir, 'saved_model.pb')
     self.assertTrue(os.path.exists(savedmodel_path))
     imported = tf.saved_model.load(export_dir, tags=None)
