Skip to content

Commit 2d4f1dd

Browse files
Fix batch_sampler maybe None error (#3025)
* Fix batch_sampler maybe None For more details, see: #3011 * Update test_data_loader.py Add unit test for dataloader with batch_size=None when using Iterabledataset * Update tests/test_data_loader.py Co-authored-by: Zach Mueller <[email protected]> * Fix inconsistent indentation Fix inconsistent indentation --------- Co-authored-by: Zach Mueller <[email protected]>
1 parent c0cf860 commit 2d4f1dd

File tree

2 files changed

+24
-1
lines changed

2 files changed

+24
-1
lines changed

src/accelerate/data_loader.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -820,7 +820,7 @@ def set_epoch(self, epoch: int):
820820
# In case it is manually passed in, the user can set it to what they like
821821
if self.iteration != epoch:
822822
self.iteration = epoch
823-
if hasattr(self.batch_sampler.sampler, "set_epoch"):
823+
if hasattr(self.batch_sampler, "sampler") and hasattr(self.batch_sampler.sampler, "set_epoch"):
824824
self.batch_sampler.sampler.set_epoch(epoch)
825825
elif hasattr(self.dataset, "set_epoch"):
826826
self.dataset.set_epoch(epoch)

tests/test_data_loader.py

Lines changed: 23 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -64,6 +64,21 @@ def __iter__(self):
6464
stop = random.random() < self.p_stop
6565

6666

67+
class SimpleIterableDataset(IterableDataset):
68+
def __init__(self, num_samples=1000):
69+
self.num_samples = num_samples
70+
71+
def __iter__(self):
72+
for _ in range(self.num_samples):
73+
yield torch.rand(1)
74+
75+
def __len__(self):
76+
return self.num_samples
77+
78+
def set_epoch(self, epoch):
79+
self.epoch = epoch
80+
81+
6782
class DataLoaderTester(unittest.TestCase):
6883
def check_batch_sampler_shards(self, batch_sampler, expected, split_batches=False, even_batches=True):
6984
batch_sampler_shards = [
@@ -384,6 +399,14 @@ def test_iterable_dataset_shard(self):
384399
self.check_iterable_dataset_shards(dataset, seed, batch_size=4, drop_last=False, split_batches=True)
385400
self.check_iterable_dataset_shards(dataset, seed, batch_size=4, drop_last=True, split_batches=True)
386401

402+
def test_iterable_dataset_using_none_batch_size(self):
403+
dataset = SimpleIterableDataset(100)
404+
accelerator = Accelerator()
405+
dataloader = DataLoader(dataset, batch_size=None)
406+
dataloader = accelerator.prepare(dataloader)
407+
for d in dataloader:
408+
assert isinstance(d, torch.Tensor)
409+
387410
def test_skip_batch_sampler(self):
388411
batch_sampler = BatchSampler(range(16), batch_size=4, drop_last=False)
389412
new_batch_sampler = SkipBatchSampler(batch_sampler, 2)

0 commit comments

Comments
 (0)