
Commit 3eb4051

Develop a reader to read records from recordio files (#2502)
* Fix the example of TensorFlow 1.x
* Develop a reader to read records from recordio files
* The master does not add the worker host if it is in the next group
* Add a docstring explaining how to find which block the shard is in
1 parent 48d37bd commit 3eb4051

File tree

4 files changed: +101 -22 lines changed


elasticai_api/io/__init__.py

+12
@@ -0,0 +1,12 @@
+# Copyright 2021 The ElasticDL Authors. All rights reserved.
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.

elasticai_api/io/recordio_reader.py

+66
@@ -0,0 +1,66 @@
+# Copyright 2020 The ElasticDL Authors. All rights reserved.
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import collections
+from contextlib import closing
+
+import recordio
+
+RecordBlock = collections.namedtuple("RecordBlock", ("name", "start", "end"))
+
+
+class RecordIOReader(object):
+    def __init__(self, recordio_files):
+        self._init_file_record_count(recordio_files)
+
+    def _init_file_record_count(self, recordio_files):
+        self._data_blocks = []
+        start = 0
+        for file_path in recordio_files:
+            with closing(recordio.Index(file_path)) as rio:
+                num_records = rio.num_records()
+            end = start + num_records
+            self._data_blocks.append(RecordBlock(file_path, start, end))
+            start = end
+
+    def read_records(self, start, end):
+        target_files = self._get_record_file(start, end)
+        for file_path, start, count in target_files:
+            with closing(recordio.Scanner(file_path, start, count)) as reader:
+                while True:
+                    record = reader.record()
+                    if record:
+                        yield record
+                    else:
+                        break
+
+    def _get_record_file(self, start, end):
+        """The block ranges in data_blocks are sorted in
+        increasing order. For example, blocks are
+        [[0, 100), [100, 200), [200, 300)]. So we can find
+        which block the shard is in by sequential search.
+        """
+        target_files = []
+        for block in self._data_blocks:
+            if start < block.end:
+                if end < block.end:
+                    target_files.append(
+                        (block.name, start - block.start, end - start)
+                    )
+                    break
+                else:
+                    target_files.append(
+                        (block.name, start - block.start, block.end - start)
+                    )
+                    start = block.end
+        return target_files
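A minimal usage sketch of the new reader (the file paths and record counts below are illustrative, not from this commit): with two recordio files of 100 records each, the reader builds the blocks [0, 100) and [100, 200), and a read that crosses the boundary is split across both files.

from elasticai_api.io.recordio_reader import RecordIOReader

# Hypothetical files, each assumed to contain 100 records, giving
# blocks [0, 100) and [100, 200).
reader = RecordIOReader(["/data/train-0.rio", "/data/train-1.rio"])

# The range [90, 110) crosses the block boundary, so _get_record_file
# returns [("/data/train-0.rio", 90, 10), ("/data/train-1.rio", 0, 10)]
# and read_records scans each file in turn.
for record in reader.read_records(90, 110):
    pass  # each record is the raw serialized payload stored in the file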

elasticai_api/tensorflow/controller.py

+1 -6

@@ -37,11 +37,7 @@
 
 
 def create_elastic_controller(
-    batch_size,
-    num_epochs=None,
-    dataset_size=None,
-    shuffle=False,
-    training_data=None,
+    batch_size, num_epochs=None, dataset_size=None, shuffle=False,
 ):
     """Create an elastic AllReduce controller with data shard service.
     Users can use the `controller.data_shard_service` to get data
@@ -64,7 +60,6 @@ def create_elastic_controller(
         num_epochs=num_epochs,
         dataset_size=dataset_size,
         shuffle=shuffle,
-        training_data=training_data,
     )
     if _IS_TF2:
         controller = TensorFlowV2AllReduceController(
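With `training_data` dropped from the signature, callers now size the shard service explicitly through `dataset_size`. A sketch of a call under the new API (argument values are illustrative, except dataset_size=50000, which the MNIST script below passes):

from elasticai_api.tensorflow.controller import create_elastic_controller

# The shard service now only hands out [start, end) record ranges;
# mapping those ranges to files is left to RecordIOReader.
allreduce_controller = create_elastic_controller(
    batch_size=64, num_epochs=1, dataset_size=50000, shuffle=False,
)
data_shard_service = allreduce_controller.data_shard_service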

model_zoo/mnist/mnist_train_tfv1.py

+22 -16

@@ -12,11 +12,11 @@
 # limitations under the License.
 
 import argparse
-from contextlib import closing
+import os
 
-import recordio
 import tensorflow as tf
 
+from elasticai_api.io.recordio_reader import RecordIOReader
 from elasticai_api.tensorflow.controller import create_elastic_controller
 from elasticai_api.tensorflow.optimizer import (
     AdjustBackwardPassesPerStepHook,
@@ -27,31 +27,35 @@
 layers = tf.layers
 
 
-def get_dataset_gen(data_shard_service):
+def get_dataset_gen(data_shard_service, data_reader):
     def gen():
         while True:
             shard = data_shard_service.fetch_shard()
             if not shard:
                 raise StopIteration("No data")
-            with closing(
-                recordio.Scanner(
-                    shard.name, shard.start, shard.end - shard.start,
-                )
-            ) as reader:
-                for i in range(shard.start, shard.end):
-                    record = reader.record()
-                    if record:
-                        yield record
+            count = 0
+            for record in data_reader.read_records(shard.start, shard.end):
+                count += 1
+                yield record
 
     return gen
 
 
-def create_dataset(data_shard_service):
-    gen = get_dataset_gen(data_shard_service)
+def create_dataset(data_shard_service, training_data_dir):
+    data_files = _get_data_files(training_data_dir)
+    data_reader = RecordIOReader(data_files)
+    gen = get_dataset_gen(data_shard_service, data_reader)
     dataset = tf.data.Dataset.from_generator(gen, tf.string)
     return dataset
 
 
+def _get_data_files(data_dir):
+    data_files = []
+    for filename in os.listdir(data_dir):
+        data_files.append(os.path.join(data_dir, filename))
+    return data_files
+
+
 def conv_model(feature, target, mode):
     """2-layer convolution model."""
     # Convert the target to a one-hot tensor of shape (batch_size, 10) and
@@ -109,9 +113,11 @@ def train(args):
     allreduce_controller = create_elastic_controller(
         batch_size=args.batch_size,
         num_epochs=args.num_epochs,
-        training_data=args.training_data,
+        dataset_size=50000,
+    )
+    dataset = create_dataset(
+        allreduce_controller.data_shard_service, args.training_data
     )
-    dataset = create_dataset(allreduce_controller.data_shard_service)
     dataset = feed(dataset)
     dataset = dataset.batch(args.batch_size).prefetch(1)
     dataset_it = dataset.make_one_shot_iterator()
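Condensed, the script's new data path reads as below (a sketch stitched from the hunks above; `feed` decodes the raw records and is defined elsewhere in the script, outside this diff):

# The controller sizes the shard service, create_dataset wires a
# RecordIOReader over the files in args.training_data, and tf.data
# batches the records the generator yields.
allreduce_controller = create_elastic_controller(
    batch_size=args.batch_size, num_epochs=args.num_epochs, dataset_size=50000,
)
dataset = create_dataset(
    allreduce_controller.data_shard_service, args.training_data
)
dataset = feed(dataset)
dataset = dataset.batch(args.batch_size).prefetch(1)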
