Commit ad3f574

byi8220, faaany, and muellerzr authored
Add early support for torchdata.stateful_dataloader.StatefulDataLoader within the Accelerator (#2895)
* temporary commit
* checkout?
* dataloader wrapper
* tmp
* weird failing test
* trying multiple inheritance
* DataLoaderAdapter
* make style
* Some dark magic dynamic reflection (for backwards compat)
* typo
* some tests
* more mixin stuff
* maybe found broken test?
* this is a very invasive feature
* i think the feature is done?
* add xpu support (#2864)
* better tests
* discovered a bug
* maybe fixed bug?
* make style
* hopefully this is PR ready
* properly skip tests
* parameterize
* temporary commit
* checkout?
* dataloader wrapper
* tmp
* weird failing test
* trying multiple inheritance
* DataLoaderAdapter
* make style
* Some dark magic dynamic reflection (for backwards compat)
* typo
* some tests
* more mixin stuff
* maybe found broken test?
* this is a very invasive feature
* i think the feature is done?
* better tests
* discovered a bug
* maybe fixed bug?
* make style
* hopefully this is PR ready
* properly skip tests
* parameterize
* Update src/accelerate/utils/dataclasses.py Co-authored-by: Zach Mueller <[email protected]>
* Update src/accelerate/data_loader.py Co-authored-by: Zach Mueller <[email protected]>
* merge conflicts
* move imports
* make style
* merges are breaking tests
* fix test name
* Require safetensors>=0.4.3
* undo last commit
* minor style
* address pr comments
* Torchdata version 0.8.0 is stable now
* added docs and require torchdata>=0.8.0 for testing
* test base_dataloader attr doesn't cause infinite recursion
* address pr
* replace super().__iter__ with self.base_dataloader.__iter__

---------

Co-authored-by: Fanli Lin <[email protected]>
Co-authored-by: Zach Mueller <[email protected]>
1 parent 1a6af0b commit ad3f574

12 files changed: +605 −21 lines changed


docs/source/basic_tutorials/migration.md

Lines changed: 6 additions & 0 deletions
@@ -219,3 +219,9 @@ During training, you may want to save the current state of the model, optimizer,
 To further customize where and how states are saved through [`~Accelerator.save_state`], use the [`~utils.ProjectConfiguration`] class. For example, if `automatic_checkpoint_naming` is enabled, each saved checkpoint is stored at `Accelerator.project_dir/checkpoints/checkpoint_{checkpoint_number}`.
 
 Any other stateful items to be stored should be registered with the [`~Accelerator.register_for_checkpointing`] method so they can be saved and loaded. Every object passed to this method to be stored must have a `load_state_dict` and `state_dict` function.
+
+<Note>
+
+If you have [`torchdata>=0.8.0`](https://github.com/pytorch/data/tree/main) installed, you can additionally pass `use_stateful_dataloader=True` into your [`~utils.DataLoaderConfiguration`]. This extends Accelerate's DataLoader classes with `load_state_dict` and `state_dict` methods, so that `Accelerator.save_state` and `Accelerator.load_state` also track how far into the training dataset the dataloader has read when persisting the model.
+
+</Note>
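A minimal sketch of how the new option is intended to be used, assuming `torchdata>=0.8.0` is installed; the dataset, model, and checkpoint directory below are toy placeholders rather than part of this commit:

```python
# Hedged sketch: checkpoint and resume mid-epoch with use_stateful_dataloader=True.
import torch
from torch.utils.data import DataLoader, TensorDataset

from accelerate import Accelerator
from accelerate.utils import DataLoaderConfiguration

accelerator = Accelerator(dataloader_config=DataLoaderConfiguration(use_stateful_dataloader=True))

dataset = TensorDataset(torch.randn(64, 4), torch.randn(64, 1))  # toy data
dataloader = DataLoader(dataset, batch_size=8)
model = torch.nn.Linear(4, 1)
optimizer = torch.optim.SGD(model.parameters(), lr=1e-3)
model, optimizer, dataloader = accelerator.prepare(model, optimizer, dataloader)

for step, (x, y) in enumerate(dataloader):
    loss = torch.nn.functional.mse_loss(model(x), y)
    accelerator.backward(loss)
    optimizer.step()
    optimizer.zero_grad()
    if step == 2:
        # With a stateful dataloader, this checkpoint also records how far the
        # dataloader has advanced, not just the model/optimizer state.
        accelerator.save_state("my_checkpoint")
        break

# Later: restores the model, the optimizer, and the dataloader position.
accelerator.load_state("my_checkpoint")
```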

docs/source/concept_guides/internal_mechanism.md

Lines changed: 6 additions & 0 deletions
@@ -69,4 +69,10 @@ setting the same seed in the main random number generator in all processes.
 
 </Tip>
 
+<Note>
+
+If you have [`torchdata>=0.8.0`](https://github.com/pytorch/data/tree/main) installed, and you have passed `use_stateful_dataloader=True` into your [`~utils.DataLoaderConfiguration`], these classes will directly inherit from `StatefulDataLoader` instead, and maintain a `state_dict`.
+
+</Note>
+
 For more details about the internals, see the [Internals page](package_reference/torch_wrappers).
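A small sketch of what that inheritance means in practice, assuming `torchdata>=0.8.0` is installed; the toy dataset is illustrative only:

```python
# Hedged sketch: with use_stateful_dataloader=True, the prepared dataloader behaves as a
# StatefulDataLoader subclass, so isinstance checks and state_dict/load_state_dict both work.
import torch
from torch.utils.data import DataLoader, TensorDataset
from torchdata.stateful_dataloader import StatefulDataLoader

from accelerate import Accelerator
from accelerate.utils import DataLoaderConfiguration

accelerator = Accelerator(dataloader_config=DataLoaderConfiguration(use_stateful_dataloader=True))
dataloader = DataLoader(TensorDataset(torch.arange(32, dtype=torch.float32)), batch_size=4)
prepared = accelerator.prepare(dataloader)

assert isinstance(prepared, StatefulDataLoader)  # true thanks to the dynamic mixin
state = prepared.state_dict()                    # snapshot of the iteration state
prepared.load_state_dict(state)                  # can be restored later
```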

setup.py

Lines changed: 1 addition & 0 deletions
@@ -27,6 +27,7 @@
     "datasets",
     "diffusers",
     "evaluate",
+    "torchdata>=0.8.0",
     "torchpippy>=0.2.0",
     "transformers",
     "scipy",

src/accelerate/accelerator.py

Lines changed: 7 additions & 0 deletions
@@ -583,6 +583,12 @@ def use_seedable_sampler(self):
     def non_blocking(self):
         return self.dataloader_config.non_blocking
 
+    @property
+    def use_stateful_dataloader(self):
+        if hasattr(self.dataloader_config, "use_stateful_dataloader"):
+            return self.dataloader_config.use_stateful_dataloader
+        return False
+
     @property
     def project_dir(self):
         return self.project_configuration.project_dir
@@ -2068,6 +2074,7 @@ def prepare_data_loader(
             slice_fn_for_dispatch=slice_fn_for_dispatch,
             use_seedable_sampler=self.use_seedable_sampler,
             non_blocking=self.non_blocking,
+            use_stateful_dataloader=self.use_stateful_dataloader,
         )
         self._dataloaders.append(prepared_data_loader)
         return prepared_data_loader

src/accelerate/data_loader.py

Lines changed: 112 additions & 15 deletions
@@ -30,6 +30,7 @@
     get_data_structure,
     initialize_tensors,
     is_torch_version,
+    is_torchdata_stateful_dataloader_available,
     send_to_device,
     slice_tensors,
     synchronize_rng_states,
@@ -388,9 +389,75 @@ def end(self):
         self.gradient_state._remove_dataloader(self)
 
 
-class DataLoaderShard(DataLoader, DataLoaderStateMixin):
+class DataLoaderAdapter:
     """
-    Subclass of a PyTorch `DataLoader` that will deal with device placement and current distributed setup.
+    A class which wraps around a PyTorch `DataLoader` (or variants of it) to be used with the `Accelerator`. For
+    compatibility reasons, this class inherits from the class it wraps around, so it can be used as a drop-in.
+    """
+
+    def __init__(self, dataset, use_stateful_dataloader=False, batch_sampler=None, **kwargs):
+        self.use_stateful_dataloader = use_stateful_dataloader
+        if is_torchdata_stateful_dataloader_available():
+            from torchdata.stateful_dataloader import StatefulDataLoader
+
+        if use_stateful_dataloader and not is_torchdata_stateful_dataloader_available():
+            raise ImportError(
+                "StatefulDataLoader is not available. Please install torchdata version 0.8.0 or higher to use it."
+            )
+        if use_stateful_dataloader:
+            self.base_dataloader = StatefulDataLoader(dataset, batch_sampler=batch_sampler, **kwargs)
+        else:
+            self.base_dataloader = DataLoader(dataset, batch_sampler=batch_sampler, **kwargs)
+
+        # Dynamically mix in the parent class. See https://stackoverflow.com/a/31075641
+        # In C++ terms, this is analogous to creating `DataLoaderAdapter<T> : T`, where T is a DataLoader or
+        # StatefulDataLoader
+        #
+        # The same functionality could be achieved by directly creating the required subclasses for both {DataLoader,
+        # StatefulDataLoader}, however that could lead to much messier code, with duplicated classes and conditional
+        # dispatching scattered throughout various functions and files.
+        #
+        # This code is incredibly awkward but it's the only way to make `isinstance(obj, StatefulDataLoader)` work
+        # transparently.
+        #
+        # A more robust solution is for DataLoaderAdapter to not inherit from DataLoader (compose rather than inherit),
+        # but this would not be backwards compatible with existing code which assumes
+        # DataLoaderShard/DataLoaderDispatcher are DataLoaders.
+        base_cls = self.__class__
+        base_cls_name = self.__class__.__name__
+        parent_cls_name = self.base_dataloader.__class__
+        self.__class__ = type(base_cls_name, (base_cls, parent_cls_name), {})
+
+        if hasattr(self.base_dataloader, "state_dict"):
+            self.dl_state_dict = self.base_dataloader.state_dict()
+
+    def __getattr__(self, name):
+        # Avoid infinite recursion if we try to access a nonexistent base_dataloader attribute.
+        if name == "base_dataloader":
+            raise AttributeError()
+        # Delegate attribute access to the internal dataloader
+        return getattr(self.base_dataloader, name)
+
+    def state_dict(self):
+        return self.dl_state_dict
+
+    def load_state_dict(self, state_dict):
+        self.base_dataloader.load_state_dict(state_dict)
+        self.dl_state_dict = self.state_dict
+
+    def _update_state_dict(self):
+        # The state_dict of the underlying base_dataloader may be ahead of what is currently being yielded.
+        # E.g. the implementation of DataLoaderShard involves having an underlying iterator 1 element ahead of
+        # what it wants to yield.
+        #
+        # _update_state_dict is called to snapshot the state_dict that would properly recover the DataLoaderAdapter.
+        if hasattr(self.base_dataloader, "state_dict"):
+            self.dl_state_dict = self.base_dataloader.state_dict()
+
+
+class DataLoaderShard(DataLoaderAdapter, DataLoaderStateMixin):
+    """
+    Subclass of `DataLoaderAdapter` that will deal with device placement and current distributed setup.
 
     Args:
         dataset (`torch.utils.data.dataset.Dataset`):
@@ -409,6 +476,8 @@ class DataLoaderShard(DataLoader, DataLoaderStateMixin):
             A random number generator to keep synchronized across processes.
         skip_batches (`int`, *optional*, defaults to 0):
             The number of batches to skip at the beginning.
+        use_stateful_dataloader (`bool`, *optional*, defaults to `False`):
+            Whether to have this class adapt `StatefulDataLoader` from `torchdata` instead of the regular `DataLoader`.
         **kwargs (additional keyword arguments, *optional*):
             All other keyword arguments to pass to the regular `DataLoader` initialization.
 
@@ -428,11 +497,12 @@ def __init__(
         rng_types=None,
         synchronized_generator=None,
         skip_batches=0,
+        use_stateful_dataloader=False,
         _drop_last: bool = False,
         _non_blocking: bool = False,
         **kwargs,
     ):
-        super().__init__(dataset, **kwargs)
+        super().__init__(dataset, use_stateful_dataloader=use_stateful_dataloader, **kwargs)
         self.device = device
         self.rng_types = rng_types
         self.synchronized_generator = synchronized_generator
@@ -448,7 +518,7 @@ def __iter__(self):
         self.begin()
 
         self.set_epoch(self.iteration)
-        dataloader_iter = super().__iter__()
+        dataloader_iter = self.base_dataloader.__iter__()
         # We iterate one batch ahead to check when we are at the end
         try:
             current_batch = next(dataloader_iter)
@@ -461,6 +531,7 @@ def __iter__(self):
                 # But we still move it to the device so it is done before `StopIteration` is reached
                 if self.device is not None:
                     current_batch = send_to_device(current_batch, self.device, non_blocking=self._non_blocking)
+                self._update_state_dict()
                 next_batch = next(dataloader_iter)
                 if batch_index >= self.skip_batches:
                     yield current_batch
@@ -564,10 +635,10 @@ def dataloader(self):
         return self._loader
 
 
-class DataLoaderDispatcher(DataLoader, DataLoaderStateMixin):
+class DataLoaderDispatcher(DataLoaderAdapter, DataLoaderStateMixin):
     """
-    Subclass of a PyTorch `DataLoader` that will iterate and preprocess on process 0 only, then dispatch on each
-    process their part of the batch.
+    Subclass of `DataLoaderAdapter` that will iterate and preprocess on process 0 only, then dispatch on each process
+    their part of the batch.
 
     Args:
         split_batches (`bool`, *optional*, defaults to `False`):
@@ -579,6 +650,8 @@ class DataLoaderDispatcher(DataLoader, DataLoaderStateMixin):
             size of the `dataloader` is a round multiple of `batch_size`.
         skip_batches (`int`, *optional*, defaults to 0):
             The number of batches to skip at the beginning of an iteration.
+        use_stateful_dataloader (`bool`, *optional*, defaults to `False`):
+            Whether to have this class adapt `StatefulDataLoader` from `torchdata` instead of the regular `DataLoader`.
 
     **Available attributes:**
 
@@ -594,6 +667,7 @@ def __init__(
         dataset,
         split_batches: bool = False,
         skip_batches=0,
+        use_stateful_dataloader=False,
         _drop_last: bool = False,
         _non_blocking: bool = False,
         slice_fn=None,
@@ -606,7 +680,7 @@ def __init__(
         # We need to save the shuffling state of the DataPipe
         if isinstance(dataset, ShufflerIterDataPipe):
            shuffle = dataset._shuffle_enabled
-        super().__init__(dataset, **kwargs)
+        super().__init__(dataset, use_stateful_dataloader=use_stateful_dataloader, **kwargs)
         self.split_batches = split_batches
         if shuffle:
             torch.utils.data.graph_settings.apply_shuffle_settings(dataset, shuffle=shuffle)
@@ -627,12 +701,14 @@ def _fetch_batches(self, iterator):
             try:
                 if self.split_batches:
                     # One batch of the main iterator is dispatched and split.
+                    self._update_state_dict()
                     batch = next(iterator)
                 else:
                     # num_processes batches of the main iterator are concatenated then dispatched and split.
                     # We add the batches one by one so we have the remainder available when drop_last=False.
                     batches = []
                     for _ in range(self.state.num_processes):
+                        self._update_state_dict()
                         batches.append(next(iterator))
                     try:
                         batch = concatenate(batches, dim=0)
@@ -673,9 +749,9 @@ def __iter__(self):
             # NOTE PyTorch DataLoader adds forward compatibilities for DataPipes, which broadcasts
             # shared seed to all dist processes. Thus, we need to create iterator for all dist processes.
             # But, we only iterate through the DataLoader on process 0.
-            main_iterator = super().__iter__()
+            main_iterator = self.base_dataloader.__iter__()
         elif self.state.process_index == 0:
-            main_iterator = super().__iter__()
+            main_iterator = self.base_dataloader.__iter__()
         stop_iteration = False
         self._stop_iteration = False
         first_batch = None
@@ -812,6 +888,7 @@ def prepare_data_loader(
     slice_fn_for_dispatch: Optional[Callable] = None,
     use_seedable_sampler: bool = False,
     non_blocking: bool = False,
+    use_stateful_dataloader: bool = False,
 ) -> DataLoader:
     """
     Wraps a PyTorch `DataLoader` to generate batches for one of the processes only.
@@ -873,6 +950,10 @@ def prepare_data_loader(
         non_blocking (`bool`, *optional*, defaults to `False`):
             If set to `True`, dataloader will utilize non-blocking host-to-device transfers. If the dataloader has
             `pin_memory` set to `True`, this will help to increase overlap between data transfer and computations.
+        use_stateful_dataloader (`bool`, *optional*, defaults to `False`):
+            If set to `True`, the dataloader prepared by the Accelerator will be backed by
+            [torchdata.StatefulDataLoader](https://github.com/pytorch/data/tree/main/torchdata/stateful_dataloader).
+            This requires `torchdata` version 0.8.0 or higher with `StatefulDataLoader` support to be installed.
 
 
     Returns:
@@ -1006,6 +1087,7 @@ def prepare_data_loader(
             _drop_last=dataloader.drop_last,
             _non_blocking=non_blocking,
             slice_fn=slice_fn_for_dispatch,
+            use_stateful_dataloader=use_stateful_dataloader,
             **kwargs,
         )
     elif sampler_is_batch_sampler:
@@ -1018,6 +1100,7 @@ def prepare_data_loader(
             _drop_last=dataloader.drop_last,
             _non_blocking=non_blocking,
             synchronized_generator=synchronized_generator,
+            use_stateful_dataloader=use_stateful_dataloader,
             **kwargs,
         )
     else:
@@ -1029,6 +1112,7 @@ def prepare_data_loader(
             synchronized_generator=synchronized_generator,
             _drop_last=dataloader.drop_last,
             _non_blocking=non_blocking,
+            use_stateful_dataloader=use_stateful_dataloader,
             **kwargs,
         )

@@ -1046,6 +1130,7 @@ class SkipBatchSampler(BatchSampler):
 
     def __init__(self, batch_sampler, skip_batches=0):
         self.batch_sampler = batch_sampler
+        self.sampler = batch_sampler.sampler
         self.skip_batches = skip_batches
 
     def __iter__(self):
@@ -1061,7 +1146,7 @@ def __len__(self):
         return len(self.batch_sampler) - self.skip_batches
 
 
-class SkipDataLoader(DataLoader):
+class SkipDataLoader(DataLoaderAdapter):
     """
     Subclass of a PyTorch `DataLoader` that will skip the first batches.
 
@@ -1070,24 +1155,30 @@ class SkipDataLoader(DataLoader):
            The dataset to use to build this dataloader.
         skip_batches (`int`, *optional*, defaults to 0):
             The number of batches to skip at the beginning.
+        use_stateful_dataloader (`bool`, *optional*, defaults to `False`):
+            Whether to have this class adapt `StatefulDataLoader` from `torchdata` instead of the regular `DataLoader`.
         kwargs:
             All other keyword arguments to pass to the regular `DataLoader` initialization.
     """
 
-    def __init__(self, dataset, skip_batches=0, **kwargs):
-        super().__init__(dataset, **kwargs)
+    def __init__(self, dataset, skip_batches=0, use_stateful_dataloader=False, **kwargs):
+        super().__init__(dataset, use_stateful_dataloader=use_stateful_dataloader, **kwargs)
         self.skip_batches = skip_batches
 
     def __iter__(self):
-        for index, batch in enumerate(super().__iter__()):
+        for index, batch in enumerate(self.base_dataloader.__iter__()):
             if index >= self.skip_batches:
+                self._update_state_dict()
                 yield batch
 
 
 def skip_first_batches(dataloader, num_batches=0):
     """
     Creates a `torch.utils.data.DataLoader` that will efficiently skip the first `num_batches`.
     """
+    if is_torchdata_stateful_dataloader_available():
+        from torchdata.stateful_dataloader import StatefulDataLoader
+
     state = PartialState()
     if state.distributed_type == DistributedType.XLA:
         device = dataloader.device
@@ -1131,6 +1222,7 @@ def skip_first_batches(dataloader, num_batches=0):
             split_batches=dataloader.split_batches,
             batch_sampler=new_batch_sampler,
             _drop_last=dataloader._drop_last,
+            use_stateful_dataloader=dataloader.use_stateful_dataloader,
             **kwargs,
         )
     elif isinstance(dataloader, DataLoaderShard):
@@ -1147,12 +1239,17 @@
             device=dataloader.device,
             rng_types=dataloader.rng_types,
             synchronized_generator=dataloader.synchronized_generator,
+            use_stateful_dataloader=dataloader.use_stateful_dataloader,
             **kwargs,
         )
     else:
         if new_batch_sampler is None:
             # Need to manually skip batches in the dataloader
-            dataloader = SkipDataLoader(dataset, skip_batches=num_batches, **kwargs)
+            dataloader = SkipDataLoader(
+                dataset, skip_batches=num_batches, use_stateful_dataloader=dataloader.use_stateful_dataloader, **kwargs
+            )
+        elif is_torchdata_stateful_dataloader_available() and isinstance(dataloader, StatefulDataLoader):
+            dataloader = StatefulDataLoader(dataset, batch_sampler=new_batch_sampler, **kwargs)
         else:
             dataloader = DataLoader(dataset, batch_sampler=new_batch_sampler, **kwargs)

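The dynamic re-parenting trick in `DataLoaderAdapter.__init__` above can be shown in isolation. The following is a minimal, self-contained sketch of the same pattern using hypothetical stand-in classes (`Adapter`, `Source`), not Accelerate's actual ones:

```python
# Hedged sketch of the dynamic-mixin pattern used by DataLoaderAdapter: wrap an object,
# then rewrite the wrapper's __class__ so isinstance() against the wrapped type still works.
class Adapter:
    def __init__(self, base):
        self.base = base
        # Create a one-off subclass of (Adapter, type(base)) and re-parent this instance;
        # this mirrors the `type(base_cls_name, (base_cls, parent_cls_name), {})` line above.
        self.__class__ = type(type(self).__name__, (type(self), type(base)), {})

    def __getattr__(self, name):
        # Guard against infinite recursion if `base` itself is missing, then delegate.
        if name == "base":
            raise AttributeError(name)
        return getattr(self.base, name)


class Source:
    def __init__(self):
        self.position = 0

    def state_dict(self):
        return {"position": self.position}


adapted = Adapter(Source())
assert isinstance(adapted, Adapter) and isinstance(adapted, Source)  # drop-in compatibility
assert adapted.state_dict() == {"position": 0}  # inherited method; self.position resolves via __getattr__
```

As the in-code comments note, a composition-only design would avoid rewriting `__class__`, but it would break existing callers that expect the prepared object to literally be a `DataLoader`.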
src/accelerate/test_utils/scripts/test_sync.py

Lines changed: 1 addition & 1 deletion
@@ -305,12 +305,12 @@ def test_gradient_accumulation_with_opt_and_scheduler(
 
 def test_dataloader_break():
     accelerator = Accelerator()
-
     first_dset = RegressionDataset(length=80)
     first_dataloader = DataLoader(first_dset, batch_size=16)
     second_dset = RegressionDataset(length=96)
     second_dataloader = DataLoader(second_dset, batch_size=16)
     first_dataloader, second_dataloader = accelerator.prepare(first_dataloader, second_dataloader)
+
     assert accelerator.gradient_state.active_dataloader is None
     for iteration, _ in enumerate(first_dataloader):
         assert id(accelerator.gradient_state.active_dataloader) == id(first_dataloader)

0 commit comments