Skip to content

Commit 18ef3a0

Browse files
michael-diggin authored
and pytorchmergebot committed
Add option in data loader for out of order data (pytorch#141833)
Fixes pytorch#105203 Facing a similar problem to the linked issue, where variable sized input data can mean that a handful of slow-to-process samples holds up smaller and faster-to-process samples from being used. This also leads to lower GPU utilization. In certain cases, e.g. evaluation epochs, inference pipelines or other cases where reproducibility isn't important, this can bring significant speed ups. This PR adds an `allow_out_of_order` bool input to the `DataLoader` class, defaulting to `false`, which when set to `true` will return data from workers in whatever order they are ready/processed in, rather than in the strict index order. Instead of storing data that was returned out of order, it is passed directly to the main thread and the entry in `_task_info` is deleted. The main changes are to check that an entry in `_task_info` exists, and to only increase `self._rcvd_idx` when the lowest remaining index gets returned. Two tests are added to test this for iterable-type datasets and index-type datasets. Pull Request resolved: pytorch#141833 Approved by: https://github.com/andrewkho
1 parent 61a7c83 commit 18ef3a0

File tree

2 files changed

+115
-8
lines changed

2 files changed

+115
-8
lines changed

test/test_dataloader.py

Lines changed: 93 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -3501,6 +3501,99 @@ def test_conv_after_fork(self):
35013501
self.assertEqual(x.shape, (1, 1, 1, 23999))
35023502

35033503

3504+
class TestSlowIndexDataset(Dataset):
    """Map-style dataset of consecutive integers where fetching one
    designated index (``slow_index``) is artificially slow.

    Used to exercise the DataLoader ``in_order`` flag: the slow sample
    should arrive late when out-of-order delivery is allowed.
    """

    def __init__(self, end: int, slow_index: int):
        self.end = end
        self.slow_index = slow_index

    def __getitem__(self, index):
        # Fast path: every sample except the designated straggler.
        if index != self.slow_index:
            return index
        # Simulate a straggler sample that blocks its worker briefly.
        time.sleep(0.5)
        return index

    def __len__(self):
        return self.end
3516+
3517+
3518+
class TestSlowIterableDataset(IterableDataset):
    """Iterable-style dataset that shards ``[start, end)`` across workers;
    every value in the upper half of the range (``>= mid``) is slow to yield.

    Used to exercise the DataLoader ``in_order`` flag with sharded
    iterable datasets.
    """

    def __init__(self, start: int, end: int):
        self.start = start
        self.end = end
        # First value of the slow half of the range.
        self.mid = math.ceil((self.end - self.start) / 2)

    def give_data(self, iter_start, iter_end):
        """Yield ints in ``[iter_start, iter_end)``, sleeping on the slow half."""
        for value in range(iter_start, iter_end):
            if value >= self.mid:
                time.sleep(0.5)
            yield value

    def __iter__(self):
        # Shard the range evenly across the DataLoader workers; only valid
        # inside a worker process (get_worker_info() is None in the main process).
        info = torch.utils.data.get_worker_info()
        span = int(math.ceil((self.end - self.start) / float(info.num_workers)))
        lo = self.start + info.id * span
        hi = min(lo + span, self.end)
        return self.give_data(lo, hi)
3539+
3540+
3541+
class TestOutOfOrderDataLoader(TestCase):
    """Tests for the DataLoader ``in_order`` flag with two workers.

    With ``in_order=True`` samples must arrive in strict index order;
    with ``in_order=False`` fast samples may overtake slow ones.
    """

    def _collect(self, ds, in_order):
        # Drain a two-worker DataLoader and return sample values in arrival order.
        loader = torch.utils.data.DataLoader(
            ds,
            num_workers=2,
            in_order=in_order,
        )
        return [sample.item() for sample in loader]

    def test_in_order_index_ds(self):
        ds = TestSlowIndexDataset(end=10, slow_index=2)
        self.assertEqual([0, 1, 2, 3, 4, 5, 6, 7, 8, 9], self._collect(ds, in_order=True))

    def test_out_of_order_index_ds(self):
        ds = TestSlowIndexDataset(end=10, slow_index=2)
        # With in_order=True this would be [0, 1, 2, 3, 4, 5, 6, 7, 8, 9];
        # here the slow index 2 arrives after the fast samples behind it.
        self.assertEqual([0, 1, 3, 5, 7, 2, 4, 6, 8, 9], self._collect(ds, in_order=False))

    def test_in_order_iterable_ds(self):
        ds = TestSlowIterableDataset(start=0, end=10)
        # Two workers interleave their shards [0..4] and [5..9] in index order.
        self.assertEqual([0, 5, 1, 6, 2, 7, 3, 8, 4, 9], self._collect(ds, in_order=True))

    def test_out_of_order_iterable_ds(self):
        ds = TestSlowIterableDataset(start=0, end=10)
        # With in_order=True this would be [0, 5, 1, 6, 2, 7, 3, 8, 4, 9];
        # the fast first shard overtakes the slow second shard.
        self.assertEqual([0, 1, 2, 3, 5, 4, 6, 7, 8, 9], self._collect(ds, in_order=False))
3595+
3596+
35043597
instantiate_device_type_tests(TestDataLoaderDeviceType, globals())
35053598

35063599

torch/utils/data/dataloader.py

Lines changed: 22 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -185,6 +185,8 @@ class DataLoader(Generic[_T_co]):
185185
maintain the workers `Dataset` instances alive. (default: ``False``)
186186
pin_memory_device (str, optional): the device to :attr:`pin_memory` to if ``pin_memory`` is
187187
``True``.
188+
in_order (bool, optional): If ``False``, the data loader will not enforce that batches
189+
are returned in a first-in, first-out order. Only applies when ``num_workers > 0``. (default: ``True``)
188190
189191
190192
.. warning:: If the ``spawn`` start method is used, :attr:`worker_init_fn`
@@ -213,6 +215,9 @@ class DataLoader(Generic[_T_co]):
213215
.. warning:: See :ref:`reproducibility`, and :ref:`dataloader-workers-random-seed`, and
214216
:ref:`data-loading-randomness` notes for random seed related questions.
215217
218+
.. warning:: Setting `in_order` to `False` can harm reproducibility and may lead to a skewed data
219+
distribution being fed to the trainer in cases with imbalanced data.
220+
216221
.. _multiprocessing context:
217222
https://docs.python.org/3/library/multiprocessing.html#contexts-and-start-methods
218223
"""
@@ -248,6 +253,7 @@ def __init__(
248253
prefetch_factor: Optional[int] = None,
249254
persistent_workers: bool = False,
250255
pin_memory_device: str = "",
256+
in_order: bool = True,
251257
):
252258
torch._C._log_api_usage_once("python.data_loader")
253259

@@ -281,6 +287,7 @@ def __init__(
281287
self.timeout = timeout
282288
self.worker_init_fn = worker_init_fn
283289
self.multiprocessing_context = multiprocessing_context
290+
self.in_order = in_order
284291

285292
# Adds forward compatibilities so classic DataLoader can work with DataPipes:
286293
# _DataPipeSerializationWrapper container makes it easier to serialize without redefining pickler
@@ -1074,6 +1081,7 @@ def __init__(self, loader):
10741081
super().__init__(loader)
10751082

10761083
self._prefetch_factor = loader.prefetch_factor
1084+
self._in_order = loader.in_order
10771085

10781086
assert self._num_workers > 0
10791087
assert self._prefetch_factor > 0
@@ -1423,13 +1431,14 @@ def _next_data(self):
14231431
# call and `_IterableDatasetStopIteration` check below can mark
14241432
# extra worker(s) as dead.
14251433
while self._rcvd_idx < self._send_idx:
1426-
info = self._task_info[self._rcvd_idx]
1427-
worker_id = info[0]
1428-
if (
1429-
len(info) == 2 or self._workers_status[worker_id]
1430-
): # has data or is still active
1431-
break
1432-
del self._task_info[self._rcvd_idx]
1434+
info = self._task_info.get(self._rcvd_idx, None)
1435+
if info:
1436+
worker_id = info[0]
1437+
if (
1438+
len(info) == 2 or self._workers_status[worker_id]
1439+
): # has data or is still active
1440+
break
1441+
del self._task_info[self._rcvd_idx]
14331442
self._rcvd_idx += 1
14341443
else:
14351444
# no valid `self._rcvd_idx` is found (i.e., didn't break)
@@ -1442,6 +1451,7 @@ def _next_data(self):
14421451
# Check if the next sample has already been generated
14431452
if len(self._task_info[self._rcvd_idx]) == 2:
14441453
data = self._task_info.pop(self._rcvd_idx)[1]
1454+
self._rcvd_idx += 1
14451455
return self._process_data(data)
14461456

14471457
assert not self._shutdown and self._tasks_outstanding > 0
@@ -1458,10 +1468,15 @@ def _next_data(self):
14581468
continue
14591469

14601470
if idx != self._rcvd_idx:
1471+
if not self._in_order:
1472+
# don't store it for later, process now
1473+
del self._task_info[idx]
1474+
return self._process_data(data)
14611475
# store out-of-order samples
14621476
self._task_info[idx] += (data,)
14631477
else:
14641478
del self._task_info[idx]
1479+
self._rcvd_idx += 1
14651480
return self._process_data(data)
14661481

14671482
def _try_put_index(self):
@@ -1485,7 +1500,6 @@ def _try_put_index(self):
14851500
self._send_idx += 1
14861501

14871502
def _process_data(self, data):
1488-
self._rcvd_idx += 1
14891503
self._try_put_index()
14901504
if isinstance(data, ExceptionWrapper):
14911505
data.reraise()

0 commit comments

Comments
 (0)