Skip to content

Commit 21b515e

Browse files
lintian06copybara-github
authored andcommitted
Create model_spec for recommendation models.
PiperOrigin-RevId: 351740256
1 parent 49431bc commit 21b515e

File tree

2 files changed

+27
-16
lines changed

2 files changed

+27
-16
lines changed

tensorflow_examples/lite/model_maker/core/task/model_spec/recommendation_spec.py

+25-12
Original file line numberDiff line numberDiff line change
@@ -16,17 +16,15 @@
1616
import functools
1717

1818
import tensorflow as tf # pylint: disable=unused-import
19-
20-
HAS_RECOMMENDATION = True
21-
try:
22-
from tensorflow_examples.lite.model_maker.third_party.recommendation.ml.model import recommendation_model as rm # pylint: disable=g-import-not-at-top
23-
except ImportError:
24-
HAS_RECOMMENDATION = False
19+
from tensorflow_examples.lite.model_maker.core import compat
20+
from tensorflow_examples.lite.model_maker.third_party.recommendation.ml.model import recommendation_model as _rm
2521

2622

2723
class RecommendationSpec(object):
2824
"""Recommendation model spec."""
2925

26+
compat_tf_versions = compat.get_compat_tf_versions(2)
27+
3028
def __init__(self,
3129
encoder_type='bow',
3230
context_embedding_dim=128,
@@ -36,7 +34,9 @@ def __init__(self,
3634
hidden_layer_dim_ratios=None,
3735
conv_num_filter_ratios=None,
3836
conv_kernel_size=None,
39-
lstm_num_units=None):
37+
lstm_num_units=None,
38+
eval_top_k=None,
39+
batch_size=16):
4040
"""Initialize spec.
4141
4242
Args:
@@ -51,15 +51,23 @@ def __init__(self,
5151
ratios based on context_embedding_dim.
5252
conv_kernel_size: int, for 'rnn', Conv1D layers' kernel size.
5353
lstm_num_units: int, for 'rnn', LSTM layer's unit number.
54+
eval_top_k: list of int, evaluation metrics for a list of top k.
55+
batch_size: int, default batch size.
5456
"""
5557
hidden_layer_dim_ratios = hidden_layer_dim_ratios or [1.0, 0.5, 0.25]
5658

59+
if encoder_type not in ('bow', 'cnn', 'rnn'):
60+
raise ValueError('Not valid encoder_type: {}'.format(encoder_type))
61+
5762
if encoder_type == 'cnn':
5863
conv_num_filter_ratios = conv_num_filter_ratios or [2, 4]
5964
conv_kernel_size = conv_kernel_size or 4
6065
elif encoder_type == 'rnn':
6166
lstm_num_units = lstm_num_units or 16
6267

68+
if eval_top_k is None:
69+
eval_top_k = [1, 5, 10]
70+
6371
self.encoder_type = encoder_type
6472
self.context_embedding_dim = context_embedding_dim
6573
self.label_embedding_dim = label_embedding_dim
@@ -69,8 +77,10 @@ def __init__(self,
6977
self.conv_num_filter_ratios = conv_num_filter_ratios
7078
self.conv_kernel_size = conv_kernel_size
7179
self.lstm_num_units = lstm_num_units
80+
self.eval_top_k = eval_top_k
81+
self.batch_size = batch_size
7282

73-
self._params = {
83+
self.params = {
7484
'encoder_type': encoder_type,
7585
'context_embedding_dim': context_embedding_dim,
7686
'label_embedding_dim': label_embedding_dim,
@@ -80,13 +90,16 @@ def __init__(self,
8090
'conv_num_filter_ratios': conv_num_filter_ratios,
8191
'conv_kernel_size': conv_kernel_size,
8292
'lstm_num_units': lstm_num_units,
93+
'eval_top_k': eval_top_k,
8394
}
8495

8596
def create_model(self):
86-
"""Creates recommendation model based on params."""
87-
if not HAS_RECOMMENDATION:
88-
return None
89-
return rm.RecommendationModel(self._params)
97+
"""Creates recommendation model based on params.
98+
99+
Returns:
100+
Keras model.
101+
"""
102+
return _rm.RecommendationModel(self.params)
90103

91104

92105
recommendation_bow_spec = functools.partial(

tensorflow_examples/lite/model_maker/core/task/model_spec/recommendation_spec_test.py

+2-4
Original file line numberDiff line numberDiff line change
@@ -17,6 +17,7 @@
1717
import tensorflow.compat.v2 as tf
1818

1919
from tensorflow_examples.lite.model_maker.core.task.model_spec import recommendation_spec
20+
from tensorflow_examples.lite.model_maker.third_party.recommendation.ml.model import recommendation_model as _rm
2021

2122

2223
class RecommendationSpecTest(tf.test.TestCase, parameterized.TestCase):
@@ -29,10 +30,7 @@ class RecommendationSpecTest(tf.test.TestCase, parameterized.TestCase):
2930
def test_create_recommendation_model(self, encoder_type):
3031
spec = recommendation_spec.RecommendationSpec(encoder_type)
3132
model = spec.create_model()
32-
if recommendation_spec.HAS_RECOMMENDATION:
33-
self.assertIsInstance(model, recommendation_spec.rm.RecommendationModel)
34-
else:
35-
self.assertIsNone(model)
33+
self.assertIsInstance(model, _rm.RecommendationModel)
3634

3735

3836
if __name__ == '__main__':

0 commit comments

Comments
 (0)