     get_data_structure,
     initialize_tensors,
     is_torch_version,
+    is_torchdata_stateful_dataloader_available,
     send_to_device,
     slice_tensors,
     synchronize_rng_states,
@@ -388,9 +389,75 @@ def end(self):
         self.gradient_state._remove_dataloader(self)


-class DataLoaderShard(DataLoader, DataLoaderStateMixin):
+class DataLoaderAdapter:
     """
-    Subclass of a PyTorch `DataLoader` that will deal with device placement and current distributed setup.
+    A class which wraps around a PyTorch `DataLoader` (or variants of it) to be used with the `Accelerator`. For
+    compatibility reasons, this class inherits from the class it wraps around, so it can be used as a drop-in.
+    """
+
+    def __init__(self, dataset, use_stateful_dataloader=False, batch_sampler=None, **kwargs):
+        self.use_stateful_dataloader = use_stateful_dataloader
+        if is_torchdata_stateful_dataloader_available():
+            from torchdata.stateful_dataloader import StatefulDataLoader
+
+        if use_stateful_dataloader and not is_torchdata_stateful_dataloader_available():
+            raise ImportError(
+                "StatefulDataLoader is not available. Please install torchdata version 0.8.0 or higher to use it."
+            )
+        if use_stateful_dataloader:
+            self.base_dataloader = StatefulDataLoader(dataset, batch_sampler=batch_sampler, **kwargs)
+        else:
+            self.base_dataloader = DataLoader(dataset, batch_sampler=batch_sampler, **kwargs)
+
+        # Dynamically mixin the parent class. See https://stackoverflow.com/a/31075641
+        # In C++ terms, this is analogous to creating `DataLoaderAdapter<T> : T`, where T is a DataLoader or
+        # StatefulDataLoader
+        #
+        # The same functionality could be achieved by directly creating the required subclasses for both {DataLoader,
+        # StatefulDataLoader}, however that could lead to much messier code, with duplicated classes and conditional
+        # dispatching scattered throughout various functions and files.
+        #
+        # This code is incredibly awkward but it's the only way to make `isinstance(obj, StatefulDataLoader)` work
+        # transparently.
+        #
+        # A more robust solution is for DataLoaderAdapter to not inherit from DataLoader (compose rather than inherit),
+        # but this would not be backwards compatible with existing code which assumes
+        # DataLoaderShard/DataLoaderDispatcher are DataLoaders.
+        base_cls = self.__class__
+        base_cls_name = self.__class__.__name__
+        parent_cls_name = self.base_dataloader.__class__
+        self.__class__ = type(base_cls_name, (base_cls, parent_cls_name), {})
+
+        if hasattr(self.base_dataloader, "state_dict"):
+            self.dl_state_dict = self.base_dataloader.state_dict()
+
+    def __getattr__(self, name):
+        # Avoid infinite recursion if we try to access a nonexistent base_dataloader attribute.
+        if name == "base_dataloader":
+            raise AttributeError()
+        # Delegate attribute access to the internal dataloader
+        return getattr(self.base_dataloader, name)
+
+    def state_dict(self):
+        return self.dl_state_dict
+
+    def load_state_dict(self, state_dict):
+        self.base_dataloader.load_state_dict(state_dict)
+        self.dl_state_dict = self.state_dict
+
+    def _update_state_dict(self):
+        # The state_dict of the underlying base_dataloader may be ahead of what is currently being yielded.
+        # E.g. the implementation of DataLoaderShard involves having an underlying iterator 1 element ahead of
+        # what it wants to yield.
+        #
+        # _update_state_dict is called to snapshot the state_dict that would properly recover the DataLoaderAdapter.
+        if hasattr(self.base_dataloader, "state_dict"):
+            self.dl_state_dict = self.base_dataloader.state_dict()
+
+
+class DataLoaderShard(DataLoaderAdapter, DataLoaderStateMixin):
+    """
+    Subclass of `DataLoaderAdapter` that will deal with device placement and current distributed setup.

     Args:
         dataset (`torch.utils.data.dataset.Dataset`):
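The dynamic mixin in `DataLoaderAdapter.__init__` is easier to see in isolation. The sketch below is not from this commit; it is a minimal reproduction of the same `type()`-based pattern with made-up `Adapter`/`Base` names, showing why `isinstance` checks against the wrapped class keep working while attribute access is delegated.

```python
# Minimal sketch of the dynamic-mixin pattern used by DataLoaderAdapter (illustrative names only).
class Base:
    def __init__(self, value):
        self.value = value


class Adapter:
    def __init__(self, value):
        self.wrapped = Base(value)
        # Rebuild this instance's class as type("Adapter", (Adapter, Base), {}),
        # so isinstance(adapter, Base) is True without writing a subclass per wrapped type.
        base_cls = self.__class__
        wrapped_cls = self.wrapped.__class__
        self.__class__ = type(base_cls.__name__, (base_cls, wrapped_cls), {})

    def __getattr__(self, name):
        # Only called when normal lookup fails; delegate to the wrapped object.
        if name == "wrapped":
            raise AttributeError(name)
        return getattr(self.wrapped, name)


adapter = Adapter(3)
print(isinstance(adapter, Base))  # True, thanks to the dynamically created class
print(adapter.value)              # 3, delegated through __getattr__
```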
@@ -409,6 +476,8 @@ class DataLoaderShard(DataLoader, DataLoaderStateMixin):
             A random number generator to keep synchronized across processes.
         skip_batches (`int`, *optional*, defaults to 0):
             The number of batches to skip at the beginning.
+        use_stateful_dataloader (`bool`, *optional*, defaults to `False`):
+            Whether to have this class adapt `StatefulDataLoader` from `torchdata` instead of the regular `DataLoader`.
         **kwargs (additional keyword arguments, *optional*):
             All other keyword arguments to pass to the regular `DataLoader` initialization.
@@ -428,11 +497,12 @@ def __init__(
         rng_types=None,
         synchronized_generator=None,
         skip_batches=0,
+        use_stateful_dataloader=False,
         _drop_last: bool = False,
         _non_blocking: bool = False,
         **kwargs,
     ):
-        super().__init__(dataset, **kwargs)
+        super().__init__(dataset, use_stateful_dataloader=use_stateful_dataloader, **kwargs)
         self.device = device
         self.rng_types = rng_types
         self.synchronized_generator = synchronized_generator
@@ -448,7 +518,7 @@ def __iter__(self):
         self.begin()

         self.set_epoch(self.iteration)
-        dataloader_iter = super().__iter__()
+        dataloader_iter = self.base_dataloader.__iter__()
         # We iterate one batch ahead to check when we are at the end
         try:
             current_batch = next(dataloader_iter)
@@ -461,6 +531,7 @@ def __iter__(self):
                 # But we still move it to the device so it is done before `StopIteration` is reached
                 if self.device is not None:
                     current_batch = send_to_device(current_batch, self.device, non_blocking=self._non_blocking)
+                self._update_state_dict()
                 next_batch = next(dataloader_iter)
                 if batch_index >= self.skip_batches:
                     yield current_batch
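The placement of `self._update_state_dict()` before the look-ahead `next()` matters because the underlying iterator runs one batch ahead of what is yielded. The toy sketch below (not from this commit) mimics that pattern with a plain counter standing in for the base dataloader's state, to show that the snapshot must be taken before advancing if it is to describe the batch actually being yielded.

```python
# Illustrative only: snapshotting state before the look-ahead next() keeps the snapshot
# aligned with the batch being yielded rather than with the prefetched one.
def lookahead_iter(items):
    it = iter(items)
    consumed = 0              # stand-in for the base dataloader's internal state
    current = next(it)
    consumed += 1
    while True:
        snapshot = consumed   # analogous to calling _update_state_dict() here
        try:
            nxt = next(it)    # look one element ahead to detect the end
            consumed += 1
            yield current, snapshot
            current = nxt
        except StopIteration:
            yield current, snapshot
            break


for batch, state in lookahead_iter(["a", "b", "c"]):
    print(batch, state)       # a 1, then b 2, then c 3
```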
@@ -564,10 +635,10 @@ def dataloader(self):
         return self._loader


-class DataLoaderDispatcher(DataLoader, DataLoaderStateMixin):
+class DataLoaderDispatcher(DataLoaderAdapter, DataLoaderStateMixin):
     """
-    Subclass of a PyTorch `DataLoader` that will iterate and preprocess on process 0 only, then dispatch on each
-    process their part of the batch.
+    Subclass of `DataLoaderAdapter` that will iterate and preprocess on process 0 only, then dispatch on each process
+    their part of the batch.

     Args:
         split_batches (`bool`, *optional*, defaults to `False`):
@@ -579,6 +650,8 @@ class DataLoaderDispatcher(DataLoader, DataLoaderStateMixin):
             size of the `dataloader` is a round multiple of `batch_size`.
         skip_batches (`int`, *optional*, defaults to 0):
             The number of batches to skip at the beginning of an iteration.
+        use_stateful_dataloader (`bool`, *optional*, defaults to `False`):
+            Whether to have this class adapt `StatefulDataLoader` from `torchdata` instead of the regular `DataLoader`.

     **Available attributes:**

@@ -594,6 +667,7 @@ def __init__(
         dataset,
         split_batches: bool = False,
         skip_batches=0,
+        use_stateful_dataloader=False,
         _drop_last: bool = False,
         _non_blocking: bool = False,
         slice_fn=None,
@@ -606,7 +680,7 @@ def __init__(
         # We need to save the shuffling state of the DataPipe
         if isinstance(dataset, ShufflerIterDataPipe):
             shuffle = dataset._shuffle_enabled
-        super().__init__(dataset, **kwargs)
+        super().__init__(dataset, use_stateful_dataloader=use_stateful_dataloader, **kwargs)
         self.split_batches = split_batches
         if shuffle:
             torch.utils.data.graph_settings.apply_shuffle_settings(dataset, shuffle=shuffle)
@@ -627,12 +701,14 @@ def _fetch_batches(self, iterator):
         try:
             if self.split_batches:
                 # One batch of the main iterator is dispatched and split.
+                self._update_state_dict()
                 batch = next(iterator)
             else:
                 # num_processes batches of the main iterator are concatenated then dispatched and split.
                 # We add the batches one by one so we have the remainder available when drop_last=False.
                 batches = []
                 for _ in range(self.state.num_processes):
+                    self._update_state_dict()
                     batches.append(next(iterator))
                 try:
                     batch = concatenate(batches, dim=0)
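The comments in `_fetch_batches` above describe the concatenate-then-split dispatch. As a rough illustration only (shapes and helper names are assumed, and the broadcast/`slice_fn` machinery the real code uses is ignored), the indexing works out as follows:

```python
# Illustrative sketch: process 0 pulls num_processes sub-batches, concatenates them along
# dim 0, and each rank ends up with its contiguous slice of the big batch.
import torch

num_processes = 4
per_process_batch_size = 8
batches = [torch.randn(per_process_batch_size, 16) for _ in range(num_processes)]
big_batch = torch.cat(batches, dim=0)                     # shape (32, 16)

def slice_for_rank(batch, rank, num_processes):
    size = batch.shape[0] // num_processes                # rows per rank
    return batch[rank * size : (rank + 1) * size]

print(slice_for_rank(big_batch, rank=2, num_processes=num_processes).shape)  # torch.Size([8, 16])
```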
@@ -673,9 +749,9 @@ def __iter__(self):
             # NOTE PyTorch DataLoader adds forward compatibilities for DataPipes, which broadcasts
             # shared seed to all dist processes. Thus, we need to create iterator for all dist processes.
             # But, we only iterate through the DataLoader on process 0.
-            main_iterator = super().__iter__()
+            main_iterator = self.base_dataloader.__iter__()
         elif self.state.process_index == 0:
-            main_iterator = super().__iter__()
+            main_iterator = self.base_dataloader.__iter__()
         stop_iteration = False
         self._stop_iteration = False
         first_batch = None
@@ -812,6 +888,7 @@ def prepare_data_loader(
     slice_fn_for_dispatch: Optional[Callable] = None,
     use_seedable_sampler: bool = False,
     non_blocking: bool = False,
+    use_stateful_dataloader: bool = False,
 ) -> DataLoader:
     """
     Wraps a PyTorch `DataLoader` to generate batches for one of the processes only.
@@ -873,6 +950,10 @@ def prepare_data_loader(
         non_blocking (`bool`, *optional*, defaults to `False`):
             If set to `True`, dataloader will utilize non-blocking host-to-device transfers. If the dataloader has
             `pin_memory` set to `True`, this will help to increase overlap between data transfer and computations.
+        use_stateful_dataloader (`bool`, *optional*, defaults to `False`):
+            If set to `True`, the dataloader prepared by the Accelerator will be backed by
+            [torchdata.StatefulDataLoader](https://github.com/pytorch/data/tree/main/torchdata/stateful_dataloader).
+            This requires `torchdata` version 0.8.0 or higher that supports StatefulDataLoader to be installed.


     Returns:
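A minimal usage sketch for the new argument, assuming `torchdata>=0.8.0` is installed and running on a single process; the `TensorDataset` here is only a placeholder for a real dataset, and the resume comment describes the intended behavior rather than an exact guarantee.

```python
# Hypothetical usage sketch of prepare_data_loader(..., use_stateful_dataloader=True).
import torch
from torch.utils.data import DataLoader, TensorDataset

from accelerate.data_loader import prepare_data_loader

dataset = TensorDataset(torch.arange(64).float())
dataloader = DataLoader(dataset, batch_size=8)

prepared = prepare_data_loader(dataloader, use_stateful_dataloader=True)

# Because the prepared loader is backed by torchdata's StatefulDataLoader, its progress
# can be captured mid-epoch and restored later.
for step, batch in enumerate(prepared):
    if step == 2:
        snapshot = prepared.state_dict()
        break

prepared.load_state_dict(snapshot)  # a new iteration should resume near where the snapshot was taken
```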
@@ -1006,6 +1087,7 @@ def prepare_data_loader(
             _drop_last=dataloader.drop_last,
             _non_blocking=non_blocking,
             slice_fn=slice_fn_for_dispatch,
+            use_stateful_dataloader=use_stateful_dataloader,
             **kwargs,
         )
     elif sampler_is_batch_sampler:
@@ -1018,6 +1100,7 @@ def prepare_data_loader(
             _drop_last=dataloader.drop_last,
             _non_blocking=non_blocking,
             synchronized_generator=synchronized_generator,
+            use_stateful_dataloader=use_stateful_dataloader,
             **kwargs,
         )
     else:
@@ -1029,6 +1112,7 @@ def prepare_data_loader(
             synchronized_generator=synchronized_generator,
             _drop_last=dataloader.drop_last,
             _non_blocking=non_blocking,
+            use_stateful_dataloader=use_stateful_dataloader,
             **kwargs,
         )

@@ -1046,6 +1130,7 @@ class SkipBatchSampler(BatchSampler):

     def __init__(self, batch_sampler, skip_batches=0):
         self.batch_sampler = batch_sampler
+        self.sampler = batch_sampler.sampler
         self.skip_batches = skip_batches

     def __iter__(self):
@@ -1061,7 +1146,7 @@ def __len__(self):
         return len(self.batch_sampler) - self.skip_batches


-class SkipDataLoader(DataLoader):
+class SkipDataLoader(DataLoaderAdapter):
     """
     Subclass of a PyTorch `DataLoader` that will skip the first batches.

@@ -1070,24 +1155,30 @@ class SkipDataLoader(DataLoader):
             The dataset to use to build this datalaoder.
         skip_batches (`int`, *optional*, defaults to 0):
             The number of batches to skip at the beginning.
+        use_stateful_dataloader (`bool`, *optional*, defaults to `False`):
+            Whether to have this class adapt `StatefulDataLoader` from `torchdata` instead of the regular `DataLoader`.
         kwargs:
             All other keyword arguments to pass to the regular `DataLoader` initialization.
     """

-    def __init__(self, dataset, skip_batches=0, **kwargs):
-        super().__init__(dataset, **kwargs)
+    def __init__(self, dataset, skip_batches=0, use_stateful_dataloader=False, **kwargs):
+        super().__init__(dataset, use_stateful_dataloader=use_stateful_dataloader, **kwargs)
         self.skip_batches = skip_batches

     def __iter__(self):
-        for index, batch in enumerate(super().__iter__()):
+        for index, batch in enumerate(self.base_dataloader.__iter__()):
             if index >= self.skip_batches:
+                self._update_state_dict()
                 yield batch


 def skip_first_batches(dataloader, num_batches=0):
     """
     Creates a `torch.utils.data.DataLoader` that will efficiently skip the first `num_batches`.
     """
+    if is_torchdata_stateful_dataloader_available():
+        from torchdata.stateful_dataloader import StatefulDataLoader
+
     state = PartialState()
     if state.distributed_type == DistributedType.XLA:
         device = dataloader.device
@@ -1131,6 +1222,7 @@ def skip_first_batches(dataloader, num_batches=0):
             split_batches=dataloader.split_batches,
             batch_sampler=new_batch_sampler,
             _drop_last=dataloader._drop_last,
+            use_stateful_dataloader=dataloader.use_stateful_dataloader,
             **kwargs,
         )
     elif isinstance(dataloader, DataLoaderShard):
@@ -1147,12 +1239,17 @@ def skip_first_batches(dataloader, num_batches=0):
             device=dataloader.device,
             rng_types=dataloader.rng_types,
             synchronized_generator=dataloader.synchronized_generator,
+            use_stateful_dataloader=dataloader.use_stateful_dataloader,
             **kwargs,
         )
     else:
         if new_batch_sampler is None:
             # Need to manually skip batches in the dataloader
-            dataloader = SkipDataLoader(dataset, skip_batches=num_batches, **kwargs)
+            dataloader = SkipDataLoader(
+                dataset, skip_batches=num_batches, use_stateful_dataloader=dataloader.use_stateful_dataloader, **kwargs
+            )
+        elif is_torchdata_stateful_dataloader_available() and isinstance(dataloader, StatefulDataLoader):
+            dataloader = StatefulDataLoader(dataset, batch_sampler=new_batch_sampler, **kwargs)
         else:
             dataloader = DataLoader(dataset, batch_sampler=new_batch_sampler, **kwargs)

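For completeness, a hypothetical sketch of how `skip_first_batches` composes with the stateful path after this change: the hunks above propagate `use_stateful_dataloader=dataloader.use_stateful_dataloader` into the rebuilt loader, so a mid-epoch resume keeps the stateful backing. Dataset, batch size, and the single-process setup below are assumptions for illustration, not part of the commit.

```python
# Hypothetical sketch: resuming mid-epoch by skipping already-consumed batches.
import torch
from torch.utils.data import DataLoader, TensorDataset

from accelerate.data_loader import prepare_data_loader, skip_first_batches

dataset = TensorDataset(torch.arange(40).float())
prepared = prepare_data_loader(DataLoader(dataset, batch_size=4), use_stateful_dataloader=True)

# Pretend 3 batches were already consumed before an interruption.
resumed = skip_first_batches(prepared, num_batches=3)

for batch in resumed:
    first = batch
    break
print(first)                              # the first 3 batches (12 samples) have been skipped
print(resumed.use_stateful_dataloader)    # True: the flag is carried over to the new loader
```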