@@ -64,6 +64,21 @@ def __iter__(self):
64
64
stop = random .random () < self .p_stop
65
65
66
66
67
class SimpleIterableDataset(IterableDataset):
    """A minimal iterable dataset used to test dataloader preparation.

    Yields `num_samples` random 1-element tensors. Implements `__len__` and
    `set_epoch` so it can stand in for sharded/epoch-aware iterable datasets
    in tests.

    Args:
        num_samples (`int`, defaults to 1000): number of tensors to yield.
    """

    def __init__(self, num_samples=1000):
        self.num_samples = num_samples
        # Initialize eagerly so `self.epoch` exists even if `set_epoch` is
        # never called (reading it would otherwise raise AttributeError).
        self.epoch = 0

    def __iter__(self):
        # Fresh random values on every pass; one scalar-shaped tensor per item.
        for _ in range(self.num_samples):
            yield torch.rand(1)

    def __len__(self):
        return self.num_samples

    def set_epoch(self, epoch):
        # Mimics the `set_epoch` hook that distributed samplers/datasets expose.
        self.epoch = epoch
+
67
82
class DataLoaderTester (unittest .TestCase ):
68
83
def check_batch_sampler_shards (self , batch_sampler , expected , split_batches = False , even_batches = True ):
69
84
batch_sampler_shards = [
@@ -384,6 +399,14 @@ def test_iterable_dataset_shard(self):
384
399
self .check_iterable_dataset_shards (dataset , seed , batch_size = 4 , drop_last = False , split_batches = True )
385
400
self .check_iterable_dataset_shards (dataset , seed , batch_size = 4 , drop_last = True , split_batches = True )
386
401
402
+ def test_iterable_dataset_using_none_batch_size (self ):
403
+ dataset = SimpleIterableDataset (100 )
404
+ accelerator = Accelerator ()
405
+ dataloader = DataLoader (dataset , batch_size = None )
406
+ dataloader = accelerator .prepare (dataloader )
407
+ for d in dataloader :
408
+ assert isinstance (d , torch .Tensor )
409
+
387
410
def test_skip_batch_sampler (self ):
388
411
batch_sampler = BatchSampler (range (16 ), batch_size = 4 , drop_last = False )
389
412
new_batch_sampler = SkipBatchSampler (batch_sampler , 2 )
0 commit comments