Skip to content

Commit 9e391d7

Browse files
authored
Support N-dimensional sparse arrays (#130)
* Replace scipy.sparse with pydata/sparse
* Enable tests for >2D sparse arrays
* Merge the abstract base class TileDBTensorGenerator into TileDBNumpyGenerator
1 parent e38a75b commit 9e391d7

File tree

5 files changed

+41
-101
lines changed

5 files changed

+41
-101
lines changed

Diff for: setup.cfg

+1-1
Original file line numberDiff line numberDiff line change
@@ -34,7 +34,7 @@ zip_safe = False
3434
packages = find_namespace:
3535
python_requires = >=3.7
3636
test_suite = tests
37-
install_requires = tiledb == 0.12.3
37+
install_requires = sparse; tiledb == 0.12.3
3838
setup_requires = setuptools_scm <= 6.0.0
3939

4040
[options.extras_require]

Diff for: tests/readers/utils.py

+10-20
Original file line numberDiff line numberDiff line change
@@ -22,13 +22,6 @@ def parametrize_for_dataset(
2222
buffer_bytes=(1024, None),
2323
shuffle_buffer_size=(0, 16),
2424
):
25-
def is_valid_combination(t):
26-
x_sparse_, y_sparse_, x_shape_, y_shape_, *_ = t
27-
# sparse not supported with multi-dimensional arrays
28-
if x_sparse_ and len(x_shape_) > 1 or y_sparse_ and len(y_shape_) > 1:
29-
return False
30-
return True
31-
3225
argnames = [
3326
"x_sparse",
3427
"y_sparse",
@@ -40,19 +33,16 @@ def is_valid_combination(t):
4033
"batch_size",
4134
"shuffle_buffer_size",
4235
]
43-
argvalues = filter(
44-
is_valid_combination,
45-
it.product(
46-
x_sparse,
47-
y_sparse,
48-
x_shape,
49-
y_shape,
50-
num_attrs,
51-
pass_attrs,
52-
buffer_bytes,
53-
batch_size,
54-
shuffle_buffer_size,
55-
),
36+
argvalues = it.product(
37+
x_sparse,
38+
y_sparse,
39+
x_shape,
40+
y_shape,
41+
num_attrs,
42+
pass_attrs,
43+
buffer_bytes,
44+
batch_size,
45+
shuffle_buffer_size,
5646
)
5747
return pytest.mark.parametrize(argnames, argvalues)
5848

Diff for: tiledb/ml/readers/_tensor_gen.py

+24-59
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,7 @@
22
from typing import Generic, Iterator, Sequence, Type, TypeVar, Union
33

44
import numpy as np
5-
import scipy.sparse as sp
5+
import sparse
66

77
import tiledb
88

@@ -11,7 +11,7 @@
1111
Tensor = TypeVar("Tensor")
1212

1313

14-
class TileDBTensorGenerator(ABC, Generic[Tensor]):
14+
class TileDBNumpyGenerator:
1515
"""Base class for generating tensors read from a TileDB array."""
1616

1717
def __init__(self, array: tiledb.Array, attrs: Sequence[str]) -> None:
@@ -21,83 +21,50 @@ def __init__(self, array: tiledb.Array, attrs: Sequence[str]) -> None:
2121
"""
2222
self._query = array.query(attrs=attrs)
2323

24-
@abstractmethod
2524
def read_buffer(self, array_slice: slice) -> None:
2625
"""
2726
Read an array slice and save it as the current buffer.
2827
2928
:param array_slice: Requested array slice.
3029
"""
30+
self._buf_arrays = tuple(self._query[array_slice].values())
3131

32-
@abstractmethod
33-
def iter_tensors(self, buffer_slice: slice) -> Iterator[Tensor]:
32+
def iter_tensors(self, buffer_slice: slice) -> Iterator[np.ndarray]:
3433
"""
3534
Return an iterator of tensors for the given slice, one tensor per attribute
3635
3736
Must be called after `read_buffer`.
3837
3938
:param buffer_slice: Slice of the current buffer to convert to tensors.
4039
"""
40+
return (buf_array[buffer_slice] for buf_array in self._buf_arrays)
4141

4242

43-
class TileDBNumpyGenerator(TileDBTensorGenerator[np.ndarray]):
44-
def read_buffer(self, array_slice: slice) -> None:
45-
self._buf_arrays = tuple(self._query[array_slice].values())
46-
47-
def iter_tensors(self, buffer_slice: slice) -> Iterator[np.ndarray]:
48-
for buf_array in self._buf_arrays:
49-
yield buf_array[buffer_slice]
50-
51-
52-
class TileDBSparseTensorGenerator(TileDBTensorGenerator[Tensor]):
43+
class TileDBSparseTensorGenerator(TileDBNumpyGenerator, ABC, Generic[Tensor]):
5344
def __init__(self, array: tiledb.Array, attrs: Sequence[str]) -> None:
54-
schema = array.schema
55-
if schema.ndim != 2:
56-
raise NotImplementedError("Only 2D sparse tensors are currently supported")
57-
self._row_dim = schema.domain.dim(0).name
58-
self._col_dim = schema.domain.dim(1).name
59-
self._row_shape = schema.shape[1:]
60-
self._attr_dtypes = tuple(schema.attr(attr).dtype for attr in attrs)
45+
self._dims = tuple(array.domain.dim(i).name for i in range(array.ndim))
46+
self._row_shape = array.shape[1:]
6147
super().__init__(array, attrs)
6248

6349
def read_buffer(self, array_slice: slice) -> None:
6450
buffer = self._query[array_slice]
65-
# COO to CSR transformation for batching and row slicing
66-
row = buffer.pop(self._row_dim)
67-
col = buffer.pop(self._col_dim)
68-
# Normalize indices: We want the coords indices to be in the [0, array_slice size]
69-
# range. If we do not normalize the sparse tensor is being created but with a
70-
# dimension [0, max(coord_index)], which is overkill
51+
coords = [buffer.pop(dim) for dim in self._dims]
52+
# normalize the first coordinate dimension to start at 0 by subtracting start_offset
7153
start_offset = array_slice.start
72-
stop_offset = array_slice.stop
73-
shape = (stop_offset - start_offset, *self._row_shape)
74-
self._buf_csrs = tuple(
75-
sp.csr_matrix((data, (row - start_offset, col)), shape=shape)
76-
for data in buffer.values()
54+
if start_offset:
55+
coords[0] -= start_offset
56+
shape = (array_slice.stop - start_offset, *self._row_shape)
57+
self._buf_arrays = tuple(
58+
sparse.COO(coords, data, shape) for data in buffer.values()
7759
)
7860

7961
def iter_tensors(self, buffer_slice: slice) -> Iterator[Tensor]:
80-
for buf_csr, dtype in zip(self._buf_csrs, self._attr_dtypes):
81-
batch_csr = buf_csr[buffer_slice]
82-
batch_coo = batch_csr.tocoo()
83-
data = batch_coo.data
84-
coords = np.stack((batch_coo.row, batch_coo.col), axis=-1)
85-
dense_shape = (batch_csr.shape[0], *self._row_shape)
86-
yield self._tensor_from_coo(data, coords, dense_shape, dtype)
62+
return map(self._tensor_from_coo, super().iter_tensors(buffer_slice))
8763

8864
@staticmethod
8965
@abstractmethod
90-
def _tensor_from_coo(
91-
data: np.ndarray,
92-
coords: np.ndarray,
93-
dense_shape: Sequence[int],
94-
dtype: np.dtype,
95-
) -> Tensor:
96-
"""Convert a scipy.sparse.coo_matrix to a Tensor"""
97-
98-
99-
DT = TypeVar("DT")
100-
ST = TypeVar("ST")
66+
def _tensor_from_coo(coo: sparse.COO) -> Tensor:
67+
"""Convert a sparse.COO to a Tensor"""
10168

10269

10370
def tensor_generator(
@@ -107,11 +74,10 @@ def tensor_generator(
10774
y_buffer_size: int,
10875
x_attrs: Sequence[str],
10976
y_attrs: Sequence[str],
77+
sparse_generator_cls: Type[TileDBSparseTensorGenerator[Tensor]],
11078
start_offset: int = 0,
11179
stop_offset: int = 0,
112-
dense_generator_cls: Type[TileDBTensorGenerator[DT]] = TileDBNumpyGenerator,
113-
sparse_generator_cls: Type[TileDBTensorGenerator[ST]] = TileDBSparseTensorGenerator,
114-
) -> Iterator[Sequence[Union[DT, ST]]]:
80+
) -> Iterator[Sequence[Union[np.ndarray, Tensor]]]:
11581
"""
11682
Generator for batches of tensors.
11783
@@ -126,18 +92,17 @@ def tensor_generator(
12692
:param y_attrs: Attribute names of y_array.
12793
:param start_offset: Start row offset; defaults to 0.
12894
:param stop_offset: Stop row offset; defaults to number of rows.
129-
:param dense_generator_cls: Dense tensor generator type.
13095
:param sparse_generator_cls: Sparse tensor generator type.
13196
"""
132-
x_gen: Union[TileDBTensorGenerator[DT], TileDBTensorGenerator[ST]] = (
97+
x_gen: Union[TileDBNumpyGenerator, TileDBSparseTensorGenerator[Tensor]] = (
13398
sparse_generator_cls(x_array, x_attrs)
13499
if x_array.schema.sparse
135-
else dense_generator_cls(x_array, x_attrs)
100+
else TileDBNumpyGenerator(x_array, x_attrs)
136101
)
137-
y_gen: Union[TileDBTensorGenerator[DT], TileDBTensorGenerator[ST]] = (
102+
y_gen: Union[TileDBNumpyGenerator, TileDBSparseTensorGenerator[Tensor]] = (
138103
sparse_generator_cls(y_array, y_attrs)
139104
if y_array.schema.sparse
140-
else dense_generator_cls(y_array, y_attrs)
105+
else TileDBNumpyGenerator(y_array, y_attrs)
141106
)
142107
if not stop_offset:
143108
stop_offset = x_array.shape[0]

Diff for: tiledb/ml/readers/pytorch.py

+3-9
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,7 @@
66
from typing import Any, Callable, Iterable, Iterator, Optional, Sequence, TypeVar
77

88
import numpy as np
9+
import sparse
910
import torch
1011

1112
import tiledb
@@ -168,12 +169,5 @@ def iter_shuffled(iterable: Iterable[T], buffer_size: int) -> Iterator[T]:
168169

169170
class PyTorchSparseTensorGenerator(TileDBSparseTensorGenerator[torch.Tensor]):
170171
@staticmethod
171-
def _tensor_from_coo(
172-
data: np.ndarray,
173-
coords: np.ndarray,
174-
dense_shape: Sequence[int],
175-
dtype: np.dtype,
176-
) -> torch.Tensor:
177-
return torch.sparse_coo_tensor(
178-
torch.tensor(coords).t(), data, dense_shape, requires_grad=False
179-
)
172+
def _tensor_from_coo(coo: sparse.COO) -> torch.Tensor:
173+
return torch.sparse_coo_tensor(coo.coords, coo.data, coo.shape)

Diff for: tiledb/ml/readers/tensorflow.py

+3-12
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,7 @@
33
from functools import partial
44
from typing import Iterator, Optional, Sequence, Union
55

6-
import numpy as np
6+
import sparse
77
import tensorflow as tf
88

99
import tiledb
@@ -79,14 +79,5 @@ def _iter_tensor_specs(
7979

8080
class TensorflowSparseTensorGenerator(TileDBSparseTensorGenerator[tf.SparseTensor]):
8181
@staticmethod
82-
def _tensor_from_coo(
83-
data: np.ndarray,
84-
coords: np.ndarray,
85-
dense_shape: Sequence[int],
86-
dtype: np.dtype,
87-
) -> tf.SparseTensor:
88-
return tf.SparseTensor(
89-
indices=tf.constant(coords, dtype=tf.int64),
90-
values=tf.constant(data, dtype=dtype),
91-
dense_shape=dense_shape,
92-
)
82+
def _tensor_from_coo(coo: sparse.COO) -> tf.SparseTensor:
83+
return tf.SparseTensor(coo.coords.T, coo.data, coo.shape)

0 commit comments

Comments
 (0)