
Commit 49431bc

lintian06 authored and copybara-github committed
Create recommendation_dataloader, and use the original script as the third_party code.
PiperOrigin-RevId: 351733952
1 parent 49c8f0b · commit 49431bc

22 files changed · +2,364 −6 lines

lite/examples/recommendation/ml/data/example_generation_movielens.py

+11 −5
@@ -184,8 +184,14 @@ def write_vocab_json(vocab_movies, filename):
     json.dump(vocab_movies, jsonfile, indent=2)
 
 
-def generate_datasets(data_dir, output_dir, min_timeline_length,
-                      max_context_length, build_movie_vocab):
+def generate_datasets(data_dir,
+                      output_dir,
+                      min_timeline_length,
+                      max_context_length,
+                      build_movie_vocab,
+                      train_filename=OUTPUT_TRAINING_DATA_FILENAME,
+                      test_filename=OUTPUT_TESTING_DATA_FILENAME,
+                      vocab_filename=OUTPUT_MOVIE_VOCAB_FILENAME):
   """Generates train and test datasets as TFRecord, and returns stats."""
   if not tf.io.gfile.exists(data_dir):
     tf.io.gfile.makedirs(data_dir)
@@ -200,9 +206,9 @@ def generate_datasets(data_dir, output_dir, min_timeline_length,
 
   if not tf.io.gfile.exists(output_dir):
     tf.io.gfile.makedirs(output_dir)
-  train_file = os.path.join(output_dir, OUTPUT_TRAINING_DATA_FILENAME)
+  train_file = os.path.join(output_dir, train_filename)
   train_size = write_tfrecords(tf_examples=train_examples, filename=train_file)
-  test_file = os.path.join(output_dir, OUTPUT_TESTING_DATA_FILENAME)
+  test_file = os.path.join(output_dir, test_filename)
   test_size = write_tfrecords(tf_examples=test_examples, filename=test_file)
   stats = {
       "train_size": train_size,
@@ -213,7 +219,7 @@ def generate_datasets(data_dir, output_dir, min_timeline_length,
   if build_movie_vocab:
     vocab_movies = generate_sorted_movie_vocab(
         movies_df=movies_df, movie_counts=movie_counts)
-    vocab_file = os.path.join(output_dir, OUTPUT_MOVIE_VOCAB_FILENAME)
+    vocab_file = os.path.join(output_dir, vocab_filename)
     write_vocab_json(vocab_movies=vocab_movies, filename=vocab_file)
     stats.update(vocab_size=len(vocab_movies), vocab_file=vocab_file)
   return stats
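
For context, a minimal sketch of calling the extended function with explicit output names (the directory paths and file names below are placeholders, not values from this commit; omitting the new keyword arguments falls back to the OUTPUT_* defaults):

stats = generate_datasets(
    data_dir='/tmp/movielens/raw',          # placeholder path
    output_dir='/tmp/movielens/generated',  # placeholder path
    min_timeline_length=3,
    max_context_length=10,
    build_movie_vocab=True,
    train_filename='train.tfrecord',
    test_filename='test.tfrecord',
    vocab_filename='movie_vocab.json',
)
# stats is a dict carrying 'train_file'/'train_size', 'test_file'/'test_size',
# and, when build_movie_vocab=True, 'vocab_file'/'vocab_size'.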
tensorflow_examples/lite/model_maker/core/data_util/recommendation_dataloader.py (new file)

+195 −0
@@ -0,0 +1,195 @@
+# Copyright 2020 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the 'License');
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an 'AS IS' BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+"""Recommendation dataloader class."""
+
+import json
+import os
+
+import tensorflow as tf
+
+from tensorflow_examples.lite.model_maker.core import file_util
+from tensorflow_examples.lite.model_maker.core.data_util import dataloader
+from tensorflow_examples.lite.model_maker.third_party.recommendation.ml.data import example_generation_movielens as _gen
+from tensorflow_examples.lite.model_maker.third_party.recommendation.ml.model import recommendation_model_launcher_keras as _launcher
+
+
+class RecommendationDataLoader(dataloader.DataLoader):
+  """Recommendation data loader."""
+
+  def __init__(self, dataset, size, vocab_file):
+    """Init data loader.
+
+    Dataset is tf.data.Dataset of examples, containing:
+      for inputs:
+      - 'context': int64[], context ids as the input of variable length.
+      for outputs:
+      - 'label': int64[1], label id to predict.
+    where context is controlled by `max_context_length` in generating examples.
+
+    The vocab file should be in JSON format: a list of lists of size 4, where
+    the 4 elements are ordered as:
+      [id=int, title=str, genres=str joined with '|', count=int]
+
+    Args:
+      dataset: tf.data.Dataset for recommendation.
+      size: int, dataset size.
+      vocab_file: str, vocab file in JSON format.
+    """
+    super(RecommendationDataLoader, self).__init__(dataset, size)
+    self.vocab_file = vocab_file
+
+  def gen_dataset(self,
+                  batch_size=1,
+                  is_training=False,
+                  shuffle=False,
+                  input_pipeline_context=None,
+                  preprocess=None,
+                  drop_remainder=True):
+    """Generates dataset, and overwrites default drop_remainder = True."""
+    return super(RecommendationDataLoader, self).gen_dataset(
+        batch_size=batch_size,
+        is_training=is_training,
+        shuffle=shuffle,
+        input_pipeline_context=input_pipeline_context,
+        preprocess=preprocess,
+        drop_remainder=drop_remainder,
+    )
+
+  def split(self, fraction):
+    return self._split(fraction, self.vocab_file)
+
+  def load_vocab_and_item_size(self):
+    """Loads vocab from file.
+
+    The vocab file should be in JSON format: a list of lists of size 4, where
+    the 4 elements are ordered as:
+      [id=int, title=str, genres=str joined with '|', count=int]
+    It is generated when preparing the movielens dataset.
+
+    Returns:
+      vocab list: a list of vocab dicts representing movies
+        {
+          'id': int,
+          'title': str,
+          'genres': list of str,
+          'count': int,
+        }
+      item size: int, the max id of all vocab.
+    """
+    with tf.io.gfile.GFile(self.vocab_file) as f:
+      vocab_json = json.load(f)
+    vocab = []
+    for v in vocab_json:
+      vocab.append({
+          'id': v[0],
+          'title': v[1],
+          'genres': v[2].split('|'),
+          'count': v[3],
+      })
+    item_size = max((v['id'] for v in vocab))
+    return vocab, item_size
+
+  @staticmethod
+  def read_as_dataset(filepattern):
+    """Reads file pattern as dataset."""
+    dataset = _launcher.InputFn.read_dataset(filepattern)
+    return dataset.map(
+        _launcher.InputFn.decode_example,
+        num_parallel_calls=tf.data.experimental.AUTOTUNE)
+
+  @classmethod
+  def _prepare_movielens_datasets(cls,
+                                  raw_data_dir,
+                                  generated_dir,
+                                  train_filename,
+                                  test_filename,
+                                  vocab_filename,
+                                  meta_filename,
+                                  min_timeline_length=3,
+                                  max_context_length=10,
+                                  build_movie_vocab=True):
+    """Prepares movielens datasets, and returns a dict containing meta."""
+    train_file = os.path.join(generated_dir, train_filename)
+    test_file = os.path.join(generated_dir, test_filename)
+    meta_file = os.path.join(generated_dir, meta_filename)
+    # Create the datasets and meta only if they do not exist yet.
+    if not all([os.path.exists(f) for f in (train_file, test_file, meta_file)]):
+      stats = _gen.generate_datasets(
+          data_dir=raw_data_dir,
+          output_dir=generated_dir,
+          min_timeline_length=min_timeline_length,
+          max_context_length=max_context_length,
+          build_movie_vocab=build_movie_vocab,
+          train_filename=train_filename,
+          test_filename=test_filename,
+          vocab_filename=vocab_filename,
+      )
+      file_util.write_json_file(meta_file, stats)
+    meta = file_util.load_json_file(meta_file)
+    return meta
+
+  @classmethod
+  def from_movielens(cls,
+                     generated_dir,
+                     data_tag,
+                     raw_data_dir,
+                     min_timeline_length=3,
+                     max_context_length=10,
+                     build_movie_vocab=True,
+                     train_filename='train_movielens_1m.tfrecord',
+                     test_filename='test_movielens_1m.tfrecord',
+                     vocab_filename='movie_vocab.json',
+                     meta_filename='meta.json'):
+    """Generates data loader from movielens dataset.
+
+    The method downloads and prepares the dataset, then generates train/eval data.
+
+    For `movielens` data format, see:
+    - function `_generate_fake_data` in `recommendation_testutil.py`
+    - Or, zip file: http://files.grouplens.org/datasets/movielens/ml-1m.zip
+
+    Args:
+      generated_dir: str, path to generate preprocessed examples.
+      data_tag: str, specify dataset in {'train', 'test'}.
+      raw_data_dir: str, path to download raw data, and unzip.
+      min_timeline_length: int, min timeline length to split train/eval set.
+      max_context_length: int, max context length as the input.
+      build_movie_vocab: boolean, whether to build movie vocab.
+      train_filename: str, generated file name for training data.
+      test_filename: str, generated file name for test data.
+      vocab_filename: str, generated file name for vocab data.
+      meta_filename: str, generated file name for meta data.
+
+    Returns:
+      Data Loader.
+    """
+    if data_tag not in ('train', 'test'):
+      raise ValueError(
+          'Expected data_tag is train or test, but got {}'.format(data_tag))
+    meta = cls._prepare_movielens_datasets(
+        raw_data_dir,
+        generated_dir,
+        train_filename=train_filename,
+        test_filename=test_filename,
+        vocab_filename=vocab_filename,
+        meta_filename=meta_filename,
+        min_timeline_length=min_timeline_length,
+        max_context_length=max_context_length,
+        build_movie_vocab=build_movie_vocab)
+    if data_tag == 'train':
+      ds = cls.read_as_dataset(meta['train_file'])
+      return cls(ds, meta['train_size'], meta['vocab_file'])
+    elif data_tag == 'test':
+      ds = cls.read_as_dataset(meta['test_file'])
+      return cls(ds, meta['test_size'], meta['vocab_file'])
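
To make the vocab format above concrete, a sketch of a one-entry vocab file and what load_vocab_and_item_size returns for it (the movie entry is illustrative):

# movie_vocab.json (a single-entry example):
#   [[260, "Star Wars: Episode IV - A New Hope (1977)", "Action|Adventure|Sci-Fi", 2991]]
#
# load_vocab_and_item_size() then returns:
#   vocab == [{'id': 260,
#              'title': 'Star Wars: Episode IV - A New Hope (1977)',
#              'genres': ['Action', 'Adventure', 'Sci-Fi'],
#              'count': 2991}]
#   item_size == 260  # the max 'id' over all entries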
tensorflow_examples/lite/model_maker/core/data_util/recommendation_dataloader_test.py (new file)

+94 −0
@@ -0,0 +1,94 @@
+# Copyright 2020 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the 'License');
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an 'AS IS' BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+"""Tests for Recommendation dataloader."""
+
+import os
+
+import tensorflow.compat.v2 as tf
+from tensorflow_examples.lite.model_maker.core.data_util import recommendation_dataloader as _dl
+from tensorflow_examples.lite.model_maker.core.data_util import recommendation_testutil as _testutil
+
+
+class RecommendationDataLoaderTest(tf.test.TestCase):
+
+  def setUp(self):
+    super().setUp()
+    _testutil.setup_fake_testdata(self)
+
+  def test_prepare_movielens_datasets(self):
+    loader = _dl.RecommendationDataLoader
+    with _testutil.patch_download_and_extract_data(self.movielens_dir):
+      stats = loader._prepare_movielens_datasets(
+          self.test_tempdir, self.generated_dir, 'train.tfrecord',
+          'test.tfrecord', 'movie_vocab.json', 'meta.json')
+      self.assertDictContainsSubset(
+          {
+              'train_file': os.path.join(self.generated_dir, 'train.tfrecord'),
+              'test_file': os.path.join(self.generated_dir, 'test.tfrecord'),
+              'vocab_file': os.path.join(self.generated_dir, 'movie_vocab.json'),
+              'train_size': _testutil.TRAIN_SIZE,
+              'test_size': _testutil.TEST_SIZE,
+              'vocab_size': _testutil.VOCAB_SIZE,
+          }, stats)
+
+      self.assertTrue(os.path.exists(self.movielens_dir))
+      self.assertGreater(len(os.listdir(self.movielens_dir)), 0)
+
+      meta_file = os.path.join(self.generated_dir, 'meta.json')
+      self.assertTrue(os.path.exists(meta_file))
+
+  def test_from_movielens(self):
+    with _testutil.patch_download_and_extract_data(self.movielens_dir):
+      train_loader = _dl.RecommendationDataLoader.from_movielens(
+          self.generated_dir, 'train', self.test_tempdir)
+      test_loader = _dl.RecommendationDataLoader.from_movielens(
+          self.generated_dir, 'test', self.test_tempdir)
+
+      self.assertEqual(len(train_loader), _testutil.TRAIN_SIZE)
+      self.assertIsNotNone(train_loader._dataset)
+
+      self.assertEqual(len(test_loader), _testutil.TEST_SIZE)
+      self.assertIsNotNone(test_loader._dataset)
+
+  def test_split(self):
+    with _testutil.patch_download_and_extract_data(self.movielens_dir):
+      test_loader = _dl.RecommendationDataLoader.from_movielens(
+          self.generated_dir, 'test', self.test_tempdir)
+      test0, test1 = test_loader.split(0.1)
+      expected_size0 = int(0.1 * _testutil.TEST_SIZE)
+      expected_size1 = _testutil.TEST_SIZE - expected_size0
+      self.assertEqual(len(test0), expected_size0)
+      self.assertIsNotNone(test0._dataset)
+
+      self.assertEqual(len(test1), expected_size1)
+      self.assertIsNotNone(test1._dataset)
+
+  def test_load_vocab_and_item_size(self):
+    with _testutil.patch_download_and_extract_data(self.movielens_dir):
+      test_loader = _dl.RecommendationDataLoader.from_movielens(
+          self.generated_dir, 'test', self.test_tempdir)
+      vocab, item_size = test_loader.load_vocab_and_item_size()
+      self.assertEqual(len(vocab), _testutil.VOCAB_SIZE)
+      self.assertEqual(item_size, _testutil.ITEM_SIZE)
+
+  def test_gen_dataset(self):
+    with _testutil.patch_download_and_extract_data(self.movielens_dir):
+      test_loader = _dl.RecommendationDataLoader.from_movielens(
+          self.generated_dir, 'test', self.test_tempdir)
+      ds = test_loader.gen_dataset(10, is_training=False)
+      self.assertIsInstance(ds, tf.data.Dataset)
+
+
+if __name__ == '__main__':
+  tf.test.main()
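
Outside the test harness, end-to-end use of the new loader might look like the sketch below; the directories are placeholders, and the calls follow the API added in this commit:

from tensorflow_examples.lite.model_maker.core.data_util import recommendation_dataloader as _dl

train_loader = _dl.RecommendationDataLoader.from_movielens(
    generated_dir='/tmp/recommendation/generated',  # placeholder path
    data_tag='train',
    raw_data_dir='/tmp/recommendation/raw')         # placeholder path
vocab, item_size = train_loader.load_vocab_and_item_size()
train_ds = train_loader.gen_dataset(batch_size=16, is_training=True, shuffle=True)
# Each element carries a variable-length 'context' of ids and a 'label' id to
# predict, as described in the RecommendationDataLoader docstring.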
