3
3
import itertools
4
4
import random
5
5
from dataclasses import dataclass
6
+ from operator import methodcaller
6
7
from typing import (
7
8
Any ,
8
9
Callable ,
9
10
Dict ,
10
11
Iterable ,
11
12
Iterator ,
13
+ Mapping ,
12
14
Sequence ,
13
15
Tuple ,
14
16
TypeVar ,
@@ -61,7 +63,7 @@ def PyTorchTileDBDataLoader(
61
63
the following arguments: 'shuffle', 'sampler', 'batch_sampler', 'worker_init_fn' and 'collate_fn'.
62
64
"""
63
65
schemas = tuple (
64
- array_params .to_tensor_schema () for array_params in all_array_params
66
+ array_params .to_tensor_schema (_transforms ) for array_params in all_array_params
65
67
)
66
68
key_range = schemas [0 ].key_range
67
69
if not all (key_range .equal_values (schema .key_range ) for schema in schemas [1 :]):
@@ -140,13 +142,20 @@ def _ndarray_collate(arrays: Sequence[np.ndarray]) -> torch.Tensor:
140
142
return torch .from_numpy (np .stack (arrays ))
141
143
142
144
143
def _coo_collate(arrays: Sequence[sparse.COO]) -> torch.Tensor:
    """Stack several sparse.COO arrays into one torch.Tensor with sparse_coo layout."""
    batch = sparse.stack(arrays)
    # The stacked COO array already exposes the three pieces that
    # torch.sparse_coo_tensor takes: indices, values and overall shape.
    coords, values, shape = batch.coords, batch.data, batch.shape
    return torch.sparse_coo_tensor(coords, values, shape)
147
149
148
150
149
- def _sparse_csr_collate (arrays : Sequence [scipy .sparse .csr_matrix ]) -> torch .Tensor :
151
+ def _csr_to_coo_collate (arrays : Sequence [scipy .sparse .csr_matrix ]) -> torch .Tensor :
152
+ """Collate multiple Scipy CSR matrices to a torch.Tensor with sparse_coo layout."""
153
+ stacked = scipy .sparse .vstack (arrays ).tocoo ()
154
+ coords = np .stack ((stacked .row , stacked .col ))
155
+ return torch .sparse_coo_tensor (coords , stacked .data , stacked .shape )
156
+
157
+
158
+ def _csr_collate (arrays : Sequence [scipy .sparse .csr_matrix ]) -> torch .Tensor :
150
159
"""Collate multiple Scipy CSR matrices to a torch.Tensor with sparse_csr layout."""
151
160
stacked = scipy .sparse .vstack (arrays )
152
161
return torch .sparse_csr_tensor (
@@ -157,24 +166,37 @@ def _sparse_csr_collate(arrays: Sequence[scipy.sparse.csr_matrix]) -> torch.Tens
157
166
)
158
167
159
168
160
- _collators = {
161
- TensorKind .DENSE : _ndarray_collate ,
162
- TensorKind .SPARSE_COO : _sparse_coo_collate ,
163
- TensorKind .SPARSE_CSR : _sparse_csr_collate ,
164
- }
165
-
166
-
167
169
def _get_tensor_collator(
    schema: TensorSchema[Tensor],
) -> Union[_SingleCollator, _CompositeCollator]:
    """Select the collate function that matches the given tensor schema.

    The base collator is chosen from ``schema.kind``:

    - ``DENSE``: ``_ndarray_collate``.
    - ``SPARSE_COO``: ``_csr_to_coo_collate`` when ``schema.shape`` has exactly two
      dimensions, ``_coo_collate`` otherwise.
    - ``SPARSE_CSR``: ``_csr_collate``; only two-dimensional shapes are accepted.

    :param schema: Tensor schema providing ``kind``, ``shape`` and ``num_fields``.
    :return: The base collator itself when the schema has a single field, otherwise
        a ``_CompositeCollator`` built from one copy of the base collator per field.
    :raises ValueError: If the kind is ``SPARSE_CSR`` and the shape is not 2D.
    :raises AssertionError: If ``schema.kind`` is not one of the handled kinds.
    """
    kind = schema.kind
    is_2d = len(schema.shape) == 2

    if kind is TensorKind.DENSE:
        collator = _ndarray_collate
    elif kind is TensorKind.SPARSE_COO:
        # The 2D case goes through the collator that accepts Scipy CSR matrices and
        # converts the stacked result to COO; every other rank uses sparse.COO arrays.
        collator = _csr_to_coo_collate if is_2d else _coo_collate
    elif kind is TensorKind.SPARSE_CSR:
        if not is_2d:
            raise ValueError("SPARSE_CSR is supported only for 2D tensors")
        collator = _csr_collate
    else:
        # Raise explicitly rather than `assert False`: assert statements are removed
        # under `python -O`, which would leave `collator` unbound further down.
        raise AssertionError(kind)

    num_fields = schema.num_fields
    if num_fields == 1:
        return collator
    return _CompositeCollator(*itertools.repeat(collator, num_fields))
176
191
177
192
193
# Both sparse kinds share the same transform: call `.to_sparse_array()` on the value.
_to_sparse_array = methodcaller("to_sparse_array")

# Per-kind transform table handed to `to_tensor_schema()` by the data loader.
# NOTE(review): DENSE maps to the plain flag `True` rather than a callable —
# presumably "supported, no conversion needed"; confirm against `to_tensor_schema`.
_transforms: Mapping[TensorKind, Union[Callable[[Any], Any], bool]] = {
    TensorKind.DENSE: True,
    TensorKind.SPARSE_COO: _to_sparse_array,
    TensorKind.SPARSE_CSR: _to_sparse_array,
}
198
+
199
+
178
200
_T = TypeVar ("_T" )
179
201
180
202
0 commit comments