
Commit c35c4ee

Avoid intermediate sparse.COO instances if possible (#172)
* SparseTensorSchema: yield lightweight SparseData instances instead of sparse.COO.
  Convert to sparse.COO only for >2D sparse arrays, otherwise convert to scipy.sparse.{csr,coo}_matrix.
1 parent f6e09b6 commit c35c4ee

3 files changed: +69 -32 lines changed
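
The commit message above describes a shape-based dispatch: for 2D sparse reads the coordinate and value buffers can be turned directly into a scipy.sparse matrix, and only >2D reads still need a pydata sparse.COO. The following standalone sketch (not code from this commit; the dispatch_sparse helper and the sample arrays are invented for illustration, and numpy, scipy and the pydata sparse package are assumed to be installed) shows the idea:

    import numpy as np
    import scipy.sparse
    import sparse  # pydata/sparse


    def dispatch_sparse(coords: np.ndarray, data: np.ndarray, shape: tuple):
        """Illustrative helper: pick a sparse container based on dimensionality."""
        if len(shape) == 2:
            # 2D: build a scipy CSR matrix directly, skipping the intermediate sparse.COO
            return scipy.sparse.csr_matrix((data, coords), shape)
        # >2D: scipy has no N-dimensional sparse type, so sparse.COO is still needed
        return sparse.COO(coords, data, shape)


    # 2D coordinates/values -> scipy.sparse.csr_matrix
    m2d = dispatch_sparse(np.array([[0, 1], [1, 0]]), np.array([1.0, 2.0]), (2, 2))

    # 3D coordinates/values -> sparse.COO
    m3d = dispatch_sparse(np.array([[0, 1], [1, 0], [0, 1]]), np.array([1.0, 2.0]), (2, 2, 2))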

tiledb/ml/readers/_tensor_schema.py (+24 -17)

@@ -5,7 +5,7 @@
 from collections import Counter
 from dataclasses import dataclass
 from math import ceil
-from operator import itemgetter, methodcaller
+from operator import itemgetter
 from typing import (
     Any,
     Callable,
@@ -14,6 +14,7 @@
     Iterable,
     Sequence,
     Tuple,
+    Type,
     TypeVar,
     Union,
     cast,
@@ -234,9 +235,22 @@ def max_partition_weight(self) -> int:
         return max(1, int(rows_per_slice * num_slices))
 
 
-class SparseTensorSchema(TensorSchema[sparse.COO]):
+@dataclass(frozen=True)
+class SparseData:
+    coords: np.ndarray
+    data: np.ndarray
+    shape: Sequence[int]
+
+    def to_sparse_array(self) -> Union[scipy.sparse.csr_matrix, sparse.COO]:
+        if len(self.shape) == 2:
+            return scipy.sparse.csr_matrix((self.data, self.coords), self.shape)
+        else:
+            return sparse.COO(self.coords, self.data, self.shape)
+
+
+class SparseTensorSchema(TensorSchema[SparseData]):
     """
-    TensorSchema for reading sparse TileDB arrays as sparse.COO instances.
+    TensorSchema for reading sparse TileDB arrays as SparseData instances.
     """
 
     def __init__(self, **kwargs: Any):
@@ -263,7 +277,7 @@ def key_range(self) -> InclusiveRange[Any, int]:
 
     def iter_tensors(
         self, key_ranges: Iterable[InclusiveRange[Any, int]]
-    ) -> Union[Iterable[sparse.COO], Iterable[Sequence[sparse.COO]]]:
+    ) -> Union[Iterable[SparseData], Iterable[Sequence[SparseData]]]:
         shape = list(self.shape)
         query = self.query
         get_data = itemgetter(*self._fields)
@@ -285,12 +299,13 @@ def iter_tensors(
                 field_arrays.pop(dim) - dim_start
                 for dim, dim_start in zip(non_key_dims, non_key_dim_starts)
             )
+            coords = np.array(coords)
 
-            # yield either a single tensor or a sequence of tensors, one for each field
+            # yield either a single SparseData or one SparseData per field
            if single_field:
-                yield sparse.COO(coords, data, shape)
+                yield SparseData(coords, data, shape)
            else:
-                yield tuple(sparse.COO(coords, d, shape) for d in data)
+                yield tuple(SparseData(coords, d, shape) for d in data)
 
     @property
     def max_partition_weight(self) -> int:
@@ -322,18 +337,10 @@ def max_partition_weight(self) -> int:
         return max(1, memory_budget // ceil(max(bytes_per_cell)))
 
 
-def SparseCSRTensorSchema(**kwargs: Any) -> TensorSchema[scipy.sparse.csr_matrix]:
-    """
-    Return a TensorSchema for reading sparse 2D TileDB arrays as scipy.sparse.csr_matrix
-    instances.
-    """
-    return MappedTensorSchema(SparseTensorSchema(**kwargs), methodcaller("tocsr"))
-
-
-TensorSchemaFactories = {
+TensorSchemaFactories: Dict[TensorKind, Type[TensorSchema[Any]]] = {
     TensorKind.DENSE: DenseTensorSchema,
     TensorKind.SPARSE_COO: SparseTensorSchema,
-    TensorKind.SPARSE_CSR: SparseCSRTensorSchema,
+    TensorKind.SPARSE_CSR: SparseTensorSchema,
 }
 
 
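For the 2D case the new path should be equivalent to the old one, which always built a sparse.COO and (for SPARSE_CSR) called .tocsr() on it via the removed SparseCSRTensorSchema. A quick standalone check of that equivalence, with invented sample data and the same package assumptions as above:

    import numpy as np
    import scipy.sparse
    import sparse  # pydata/sparse

    coords = np.array([[0, 0, 1], [0, 2, 1]])  # 2 x nnz row/col coordinates
    data = np.array([10.0, 20.0, 30.0])
    shape = (2, 3)

    # New path: scipy CSR built directly from (data, coords)
    direct = scipy.sparse.csr_matrix((data, coords), shape)

    # Old path: intermediate sparse.COO, then .tocsr()
    via_coo = sparse.COO(coords, data, shape).tocsr()

    assert np.array_equal(direct.toarray(), via_coo.toarray())
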
tiledb/ml/readers/pytorch.py (+33 -11)

@@ -3,12 +3,14 @@
 import itertools
 import random
 from dataclasses import dataclass
+from operator import methodcaller
 from typing import (
     Any,
     Callable,
     Dict,
     Iterable,
     Iterator,
+    Mapping,
     Sequence,
     Tuple,
     TypeVar,
@@ -61,7 +63,7 @@ def PyTorchTileDBDataLoader(
     the following arguments: 'shuffle', 'sampler', 'batch_sampler', 'worker_init_fn' and 'collate_fn'.
     """
     schemas = tuple(
-        array_params.to_tensor_schema() for array_params in all_array_params
+        array_params.to_tensor_schema(_transforms) for array_params in all_array_params
     )
     key_range = schemas[0].key_range
     if not all(key_range.equal_values(schema.key_range) for schema in schemas[1:]):
@@ -140,13 +142,20 @@ def _ndarray_collate(arrays: Sequence[np.ndarray]) -> torch.Tensor:
     return torch.from_numpy(np.stack(arrays))
 
 
-def _sparse_coo_collate(arrays: Sequence[sparse.COO]) -> torch.Tensor:
+def _coo_collate(arrays: Sequence[sparse.COO]) -> torch.Tensor:
     """Collate multiple sparse.COO arrays to a torch.Tensor with sparse_coo layout."""
     stacked = sparse.stack(arrays)
     return torch.sparse_coo_tensor(stacked.coords, stacked.data, stacked.shape)
 
 
-def _sparse_csr_collate(arrays: Sequence[scipy.sparse.csr_matrix]) -> torch.Tensor:
+def _csr_to_coo_collate(arrays: Sequence[scipy.sparse.csr_matrix]) -> torch.Tensor:
+    """Collate multiple Scipy CSR matrices to a torch.Tensor with sparse_coo layout."""
+    stacked = scipy.sparse.vstack(arrays).tocoo()
+    coords = np.stack((stacked.row, stacked.col))
+    return torch.sparse_coo_tensor(coords, stacked.data, stacked.shape)
+
+
+def _csr_collate(arrays: Sequence[scipy.sparse.csr_matrix]) -> torch.Tensor:
     """Collate multiple Scipy CSR matrices to a torch.Tensor with sparse_csr layout."""
     stacked = scipy.sparse.vstack(arrays)
     return torch.sparse_csr_tensor(
@@ -157,24 +166,37 @@ def _sparse_csr_collate(arrays: Sequence[scipy.sparse.csr_matrix]) -> torch.Tens
     )
 
 
-_collators = {
-    TensorKind.DENSE: _ndarray_collate,
-    TensorKind.SPARSE_COO: _sparse_coo_collate,
-    TensorKind.SPARSE_CSR: _sparse_csr_collate,
-}
-
-
 def _get_tensor_collator(
     schema: TensorSchema[Tensor],
 ) -> Union[_SingleCollator, _CompositeCollator]:
-    collator = _collators[schema.kind]
+    if schema.kind is TensorKind.DENSE:
+        collator = _ndarray_collate
+    elif schema.kind is TensorKind.SPARSE_COO:
+        if len(schema.shape) != 2:
+            collator = _coo_collate
+        else:
+            collator = _csr_to_coo_collate
+    elif schema.kind is TensorKind.SPARSE_CSR:
+        if len(schema.shape) != 2:
+            raise ValueError("SPARSE_CSR is supported only for 2D tensors")
+        collator = _csr_collate
+    else:
+        assert False, schema.kind
+
     num_fields = schema.num_fields
     if num_fields == 1:
         return collator
     else:
         return _CompositeCollator(*itertools.repeat(collator, num_fields))
 
 
+_transforms: Mapping[TensorKind, Union[Callable[[Any], Any], bool]] = {
+    TensorKind.DENSE: True,
+    TensorKind.SPARSE_COO: methodcaller("to_sparse_array"),
+    TensorKind.SPARSE_CSR: methodcaller("to_sparse_array"),
+}
+
+
 _T = TypeVar("_T")
 
 
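As a rough illustration of what the two CSR-based collators produce (this is not code from the commit; the sample matrices are invented and only numpy, scipy and torch are assumed): given a batch of 2D scipy CSR slices, _csr_collate builds a tensor with sparse_csr layout and _csr_to_coo_collate a tensor with sparse_coo layout.

    import numpy as np
    import scipy.sparse
    import torch

    # Two single-row CSR matrices, standing in for per-sample 2D slices
    rows = [
        scipy.sparse.csr_matrix(np.array([[1.0, 0.0, 2.0]])),
        scipy.sparse.csr_matrix(np.array([[0.0, 3.0, 0.0]])),
    ]

    # sparse_csr layout (mirrors _csr_collate)
    stacked = scipy.sparse.vstack(rows, format="csr")
    csr_batch = torch.sparse_csr_tensor(
        torch.from_numpy(stacked.indptr),
        torch.from_numpy(stacked.indices),
        torch.from_numpy(stacked.data),
        stacked.shape,
    )

    # sparse_coo layout (mirrors _csr_to_coo_collate)
    coo = stacked.tocoo()
    coo_batch = torch.sparse_coo_tensor(np.stack((coo.row, coo.col)), coo.data, coo.shape)
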
tiledb/ml/readers/tensorflow.py (+12 -4)

@@ -5,7 +5,7 @@
 import numpy as np
 import tensorflow as tf
 
-from ._tensor_schema import TensorKind, TensorSchema
+from ._tensor_schema import SparseData, TensorKind, TensorSchema
 from .types import ArrayParams
 
 Tensor = Union[np.ndarray, tf.SparseTensor]
@@ -70,10 +70,18 @@ def _get_tensor_specs(
     return specs if len(specs) > 1 else specs[0]
 
 
+def _to_sparse_tensor(sd: SparseData) -> tf.SparseTensor:
+    sa = sd.to_sparse_array()
+    coords = getattr(sa, "coords", None)
+    if coords is None:
+        # sa is a scipy.sparse.csr_matrix
+        coo = sa.tocoo()
+        coords = np.array((coo.row, coo.col))
+    return tf.SparseTensor(coords.T, sa.data, sa.shape)
+
+
 _transforms: Mapping[TensorKind, Union[Callable[[Any], Any], bool]] = {
     TensorKind.DENSE: True,
-    TensorKind.SPARSE_COO: (
-        lambda coo: tf.SparseTensor(coo.coords.T, coo.data, coo.shape)
-    ),
+    TensorKind.SPARSE_COO: _to_sparse_tensor,
     TensorKind.SPARSE_CSR: False,
 }
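The removed lambda only handled pydata sparse.COO, which exposes a .coords attribute; scipy matrices do not, hence the tocoo() branch in _to_sparse_tensor. A standalone sketch of that 2D branch (sample data invented; assumes only numpy, scipy and tensorflow):

    import numpy as np
    import scipy.sparse
    import tensorflow as tf

    # A 2D CSR matrix, as SparseData.to_sparse_array() returns for 2D shapes
    csr = scipy.sparse.csr_matrix(
        (np.array([1.0, 2.0]), (np.array([0, 1]), np.array([2, 0]))), shape=(2, 3)
    )

    # scipy matrices have no .coords attribute: recover row/col via COO
    coo = csr.tocoo()
    coords = np.array((coo.row, coo.col))

    # tf.SparseTensor expects indices as an (nnz, ndims) int64 array
    st = tf.SparseTensor(coords.T.astype(np.int64), coo.data, csr.shape)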