Skip to content

Commit d76033d

Browse files
authored
Faster collate (batching) function for PyTorchTileDBDataLoader (#127)
1 parent 7e5be1b commit d76033d

File tree

1 file changed

+44
-6
lines changed

1 file changed

+44
-6
lines changed

tiledb/ml/readers/pytorch.py

Lines changed: 44 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -3,14 +3,14 @@
33
import itertools as it
44
import math
55
import random
6-
from typing import Iterable, Iterator, Optional, Sequence, TypeVar
6+
from typing import Any, Callable, Iterable, Iterator, Optional, Sequence, TypeVar
77

88
import numpy as np
99
import torch
1010

1111
import tiledb
1212

13-
from ._batch_utils import SparseTileDBTensorGenerator, tensor_generator
13+
from ._batch_utils import SparseTileDBTensorGenerator, get_attr_names, tensor_generator
1414

1515

1616
def PyTorchTileDBDataLoader(
@@ -35,11 +35,25 @@ def PyTorchTileDBDataLoader(
3535
:param y_attrs: Attribute names of y_array.
3636
:param num_workers: how many subprocesses to use for data loading
3737
"""
38-
dataset = PyTorchTileDBDataset(
39-
x_array, y_array, buffer_bytes, shuffle_buffer_size, x_attrs, y_attrs
40-
)
38+
x_schema = x_array.schema
39+
y_schema = y_array.schema
4140
return torch.utils.data.DataLoader(
42-
dataset, batch_size=batch_size, num_workers=num_workers
41+
dataset=PyTorchTileDBDataset(
42+
x_array, y_array, buffer_bytes, shuffle_buffer_size, x_attrs, y_attrs
43+
),
44+
batch_size=batch_size,
45+
num_workers=num_workers,
46+
collate_fn=CompositeCollator(
47+
(
48+
np_arrays_collate
49+
if not is_sparse
50+
else torch.utils.data.dataloader.default_collate
51+
)
52+
for is_sparse in it.chain(
53+
it.repeat(x_schema.sparse, len(x_attrs or get_attr_names(x_schema))),
54+
it.repeat(y_schema.sparse, len(y_attrs or get_attr_names(y_schema))),
55+
)
56+
),
4357
)
4458

4559

@@ -93,6 +107,30 @@ def __iter__(self) -> Iterator[Sequence[torch.Tensor]]:
93107
return rows
94108

95109

110+
class CompositeCollator:
    """
    Collate a batch of "rows" column by column.

    The batch is transposed so that the i-th value of every row forms the i-th
    "column"; each column is then passed to the collator function registered at
    the same position. The collated columns are returned as a list, in order.
    """

    def __init__(self, collators: Iterable[Callable[[Sequence[Any]], torch.Tensor]]):
        """
        :param collators: One collator per data column, in column order. The
            iterable is consumed once and stored as a tuple, so generators and
            other one-shot iterables are accepted.
        """
        self._collators = tuple(collators)

    def __call__(self, rows: Sequence[Sequence[Any]]) -> Sequence[torch.Tensor]:
        """
        Collate the given rows into one result per column.

        :param rows: Batch of rows; each row holds one value per column.
        :return: List with the output of each collator applied to its column.
        """
        # Transpose the batch: rows of values -> columns of values.
        transposed = tuple(zip(*rows))
        assert len(transposed) == len(self._collators)
        collated = []
        for collate, column in zip(self._collators, transposed):
            collated.append(collate(column))
        return collated
125+
126+
127+
def np_arrays_collate(arrays: Sequence[np.ndarray]) -> torch.Tensor:
    """
    Collate a sequence of Numpy arrays into a single batched tensor.

    Specialized version of default_collate for Numpy arrays: the arrays are
    stacked along a new leading axis with Numpy and the stacked array is then
    converted to a tensor.

    :param arrays: Numpy arrays to batch together.
    :return: Tensor holding the stacked arrays.
    """
    # Faster than `torch.as_tensor(arrays)` (https://github.com/pytorch/pytorch/pull/51731)
    # and `torch.stack([torch.as_tensor(array) for array in arrays])`
    stacked = np.stack(arrays)
    return torch.as_tensor(stacked)
132+
133+
96134
# Generic type variable for use in type annotations.
T = TypeVar("T")
97135

98136

0 commit comments

Comments
 (0)