3
3
import itertools as it
4
4
import math
5
5
import random
6
- from typing import Iterable , Iterator , Optional , Sequence , TypeVar
6
+ from typing import Any , Callable , Iterable , Iterator , Optional , Sequence , TypeVar
7
7
8
8
import numpy as np
9
9
import torch
10
10
11
11
import tiledb
12
12
13
- from ._batch_utils import SparseTileDBTensorGenerator , tensor_generator
13
+ from ._batch_utils import SparseTileDBTensorGenerator , get_attr_names , tensor_generator
14
14
15
15
16
16
def PyTorchTileDBDataLoader (
@@ -35,11 +35,25 @@ def PyTorchTileDBDataLoader(
35
35
:param y_attrs: Attribute names of y_array.
36
36
:param num_workers: how many subprocesses to use for data loading
37
37
"""
38
- dataset = PyTorchTileDBDataset (
39
- x_array , y_array , buffer_bytes , shuffle_buffer_size , x_attrs , y_attrs
40
- )
38
+ x_schema = x_array .schema
39
+ y_schema = y_array .schema
41
40
return torch .utils .data .DataLoader (
42
- dataset , batch_size = batch_size , num_workers = num_workers
41
+ dataset = PyTorchTileDBDataset (
42
+ x_array , y_array , buffer_bytes , shuffle_buffer_size , x_attrs , y_attrs
43
+ ),
44
+ batch_size = batch_size ,
45
+ num_workers = num_workers ,
46
+ collate_fn = CompositeCollator (
47
+ (
48
+ np_arrays_collate
49
+ if not is_sparse
50
+ else torch .utils .data .dataloader .default_collate
51
+ )
52
+ for is_sparse in it .chain (
53
+ it .repeat (x_schema .sparse , len (x_attrs or get_attr_names (x_schema ))),
54
+ it .repeat (y_schema .sparse , len (y_attrs or get_attr_names (y_schema ))),
55
+ )
56
+ ),
43
57
)
44
58
45
59
@@ -93,6 +107,30 @@ def __iter__(self) -> Iterator[Sequence[torch.Tensor]]:
93
107
return rows
94
108
95
109
110
+ class CompositeCollator :
111
+ """
112
+ A callable for collating "rows" of data into Tensors.
113
+
114
+ Each data "column" is collated to a torch.Tensor by a different collator function.
115
+ Finally, the collated columns are returned as a sequence of torch.Tensors.
116
+ """
117
+
118
+ def __init__ (self , collators : Iterable [Callable [[Sequence [Any ]], torch .Tensor ]]):
119
+ self ._collators = tuple (collators )
120
+
121
+ def __call__ (self , rows : Sequence [Sequence [Any ]]) -> Sequence [torch .Tensor ]:
122
+ columns = list (zip (* rows ))
123
+ assert len (columns ) == len (self ._collators )
124
+ return [collator (column ) for collator , column in zip (self ._collators , columns )]
125
+
126
+
127
+ def np_arrays_collate (arrays : Sequence [np .ndarray ]) -> torch .Tensor :
128
+ # Specialized version of default_collate for collating Numpy arrays
129
+ # Faster than `torch.as_tensor(arrays)` (https://github.com/pytorch/pytorch/pull/51731)
130
+ # and `torch.stack([torch.as_tensor(array) for array in arrays]])`
131
+ return torch .as_tensor (np .stack (arrays ))
132
+
133
+
96
134
T = TypeVar ("T" )
97
135
98
136
0 commit comments