@@ -2,6 +2,7 @@
 import torch.multiprocessing as multiprocessing
 from .sampler import SequentialSampler, RandomSampler, BatchSampler
 import collections
+import re
 import sys
 import traceback
 import threading
@@ -81,6 +82,9 @@ def _pin_memory_loop(in_queue, out_queue, done_event):
 
 def default_collate(batch):
     "Puts each data field into a tensor with outer dimension batch size"
+
+    error_msg = "batch must contain tensors, numbers, dicts or lists; found {}"
+    elem_type = type(batch[0])
     if torch.is_tensor(batch[0]):
         out = None
         if _use_shared_memory:
@@ -90,9 +94,14 @@ def default_collate(batch):
             storage = batch[0].storage()._new_shared(numel)
             out = batch[0].new(storage)
         return torch.stack(batch, 0, out=out)
-    elif type(batch[0]).__module__ == 'numpy':
+    elif elem_type.__module__ == 'numpy' and elem_type.__name__ != 'str_' \
+            and elem_type.__name__ != 'string_':
         elem = batch[0]
-        if type(elem).__name__ == 'ndarray':
+        if elem_type.__name__ == 'ndarray':
+            # array of string classes and object
+            if re.search('[SaUO]', elem.dtype.str) is not None:
+                raise TypeError(error_msg.format(elem.dtype))
+
             return torch.stack([torch.from_numpy(b) for b in batch], 0)
         if elem.shape == ():  # scalars
             py_type = float if elem.dtype.name.startswith('float') else int
@@ -109,8 +118,7 @@ def default_collate(batch):
         transposed = zip(*batch)
         return [default_collate(samples) for samples in transposed]
 
-    raise TypeError(("batch must contain tensors, numbers, dicts or lists; found {}"
-                     .format(type(batch[0]))))
+    raise TypeError((error_msg.format(type(batch[0]))))
 
 
 def pin_memory_batch(batch):
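
The effect of this patch is that numpy string/object arrays are rejected up front with the shared error message instead of failing later inside torch.from_numpy, while numeric numpy arrays keep their old stacking behavior. Below is a minimal sketch of that behavior, assuming default_collate is importable from torch.utils.data.dataloader (the module this diff patches); the exact dtype shown in the error message depends on the numpy build.

```python
import numpy as np
from torch.utils.data.dataloader import default_collate

# Numeric numpy arrays are still stacked into a single tensor.
numeric_batch = [np.array([1.0, 2.0]), np.array([3.0, 4.0])]
print(default_collate(numeric_batch).shape)  # torch.Size([2, 2])

# A batch of numpy string arrays now raises the explicit TypeError,
# since re.search('[SaUO]', elem.dtype.str) matches the 'U'/'S' dtype.
string_batch = [np.array(['a', 'b']), np.array(['c', 'd'])]
try:
    default_collate(string_batch)
except TypeError as err:
    print(err)  # e.g. "batch must contain tensors, numbers, dicts or lists; found <U1"
```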