Skip to content

Commit e579ae7

Browse files
zou3519 authored and soumith committed
Fix error when default_collate is passed a collection of numpy.str_ (pytorch#3404)
* Fix error when default_collate is passed a collection of numpy.str_
* Error if default_collate input is nested nparray containing non-numbers
1 parent be071d7 commit e579ae7

File tree

2 files changed

+30
-4
lines changed

2 files changed

+30
-4
lines changed

test/test_dataloader.py

Lines changed: 18 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -4,6 +4,7 @@
44
import traceback
55
import unittest
66
from torch.utils.data import Dataset, TensorDataset, DataLoader, ConcatDataset
7+
from torch.utils.data.dataloader import default_collate
78
from common import TestCase, run_tests, TEST_NUMPY
89
from common_nn import TEST_CUDA
910

@@ -276,6 +277,23 @@ def __len__(self):
276277
batch = next(iter(loader))
277278
self.assertIsInstance(batch, tt)
278279

280+
@unittest.skipIf(not TEST_NUMPY, "numpy unavailable")
281+
def test_default_colate_bad_numpy_types(self):
282+
import numpy as np
283+
284+
# Should be a no-op
285+
arr = np.array(['a', 'b', 'c'])
286+
default_collate(arr)
287+
288+
arr = np.array([[['a', 'b', 'c']]])
289+
self.assertRaises(TypeError, lambda: default_collate(arr))
290+
291+
arr = np.array([object(), object(), object()])
292+
self.assertRaises(TypeError, lambda: default_collate(arr))
293+
294+
arr = np.array([[[object(), object(), object()]]])
295+
self.assertRaises(TypeError, lambda: default_collate(arr))
296+
279297

280298
class StringDataset(Dataset):
281299
def __init__(self):

torch/utils/data/dataloader.py

Lines changed: 12 additions & 4 deletions
Original file line number | Diff line number | Diff line change
@@ -2,6 +2,7 @@
22
import torch.multiprocessing as multiprocessing
33
from .sampler import SequentialSampler, RandomSampler, BatchSampler
44
import collections
5+
import re
56
import sys
67
import traceback
78
import threading
@@ -81,6 +82,9 @@ def _pin_memory_loop(in_queue, out_queue, done_event):
8182

8283
def default_collate(batch):
8384
"Puts each data field into a tensor with outer dimension batch size"
85+
86+
error_msg = "batch must contain tensors, numbers, dicts or lists; found {}"
87+
elem_type = type(batch[0])
8488
if torch.is_tensor(batch[0]):
8589
out = None
8690
if _use_shared_memory:
@@ -90,9 +94,14 @@ def default_collate(batch):
9094
storage = batch[0].storage()._new_shared(numel)
9195
out = batch[0].new(storage)
9296
return torch.stack(batch, 0, out=out)
93-
elif type(batch[0]).__module__ == 'numpy':
97+
elif elem_type.__module__ == 'numpy' and elem_type.__name__ != 'str_' \
98+
and elem_type.__name__ != 'string_':
9499
elem = batch[0]
95-
if type(elem).__name__ == 'ndarray':
100+
if elem_type.__name__ == 'ndarray':
101+
# array of string classes and object
102+
if re.search('[SaUO]', elem.dtype.str) is not None:
103+
raise TypeError(error_msg.format(elem.dtype))
104+
96105
return torch.stack([torch.from_numpy(b) for b in batch], 0)
97106
if elem.shape == (): # scalars
98107
py_type = float if elem.dtype.name.startswith('float') else int
@@ -109,8 +118,7 @@ def default_collate(batch):
109118
transposed = zip(*batch)
110119
return [default_collate(samples) for samples in transposed]
111120

112-
raise TypeError(("batch must contain tensors, numbers, dicts or lists; found {}"
113-
.format(type(batch[0]))))
121+
raise TypeError((error_msg.format(type(batch[0]))))
114122

115123

116124
def pin_memory_batch(batch):

0 commit comments

Comments
 (0)