Skip to content

Commit 9ae80a4

Browse files
authored
Read from sparse TileDB arrays to ragged tensors (#171)
* Extract BaseSparseTensorSchema as base class for all sparse TileDB array TensorSchema classes * Add TensorKind.RAGGED and RaggedTensorSchema * Set default tensor_kind to RAGGED if the TileDB array is sparse and any non-key dimension is not integer * Enable TensorKind.RAGGED for TensorflowTileDBDataset and (if supported) for PyTorchTileDBDataLoader * Update tests for ragged tensors * Perform a stable sort along the key dimension
1 parent c35c4ee commit 9ae80a4

File tree

7 files changed

+207
-85
lines changed

7 files changed

+207
-85
lines changed

tests/readers/test_pytorch.py

+29-17
Original file line numberDiff line numberDiff line change
@@ -8,9 +8,14 @@
88

99
from .utils import ingest_in_tiledb, parametrize_for_dataset, validate_tensor_generator
1010

11+
if hasattr(torch, "nested_tensor"):
12+
non_key_dim_dtype = (np.dtype(np.int32), np.dtype(np.float32))
13+
else:
14+
non_key_dim_dtype = (np.dtype(np.int32),)
15+
1116

1217
class TestPyTorchTileDBDataLoader:
13-
@parametrize_for_dataset()
18+
@parametrize_for_dataset(non_key_dim_dtype=non_key_dim_dtype)
1419
def test_dataloader(
1520
self, tmpdir, x_spec, y_spec, batch_size, shuffle_buffer_size, num_workers
1621
):
@@ -33,6 +38,7 @@ def test_dataloader(
3338
)
3439

3540
@parametrize_for_dataset(
41+
non_key_dim_dtype=non_key_dim_dtype,
3642
# Add one extra key on X
3743
x_shape=((108, 10), (108, 10, 3)),
3844
y_shape=((107, 5), (107, 5, 2)),
@@ -52,7 +58,12 @@ def test_unequal_num_keys(
5258
)
5359
assert "All arrays must have the same key range" in str(ex.value)
5460

55-
@parametrize_for_dataset(num_fields=[0], shuffle_buffer_size=[0], num_workers=[0])
61+
@parametrize_for_dataset(
62+
non_key_dim_dtype=non_key_dim_dtype,
63+
num_fields=[0],
64+
shuffle_buffer_size=[0],
65+
num_workers=[0],
66+
)
5667
def test_dataloader_order(
5768
self, tmpdir, x_spec, y_spec, batch_size, shuffle_buffer_size, num_workers
5869
):
@@ -72,21 +83,22 @@ def test_dataloader_order(
7283
)
7384
# since num_fields is 0, fields are all the array attributes of each array
7485
# the first item of each batch corresponds to the first attribute (="data")
75-
x_data_batches, y_data_batches = [], []
86+
x_batch_tensors, y_batch_tensors = [], []
7687
for x_tensors, y_tensors in dataloader:
77-
x_data_batch = x_tensors[0]
78-
if x_spec.sparse:
79-
x_data_batch = x_data_batch.to_dense()
80-
x_data_batches.append(x_data_batch)
88+
x_batch_tensors.append(x_tensors[0])
89+
y_batch_tensors.append(y_tensors[0])
90+
assert_tensors_almost_equal_array(x_batch_tensors, x_data)
91+
assert_tensors_almost_equal_array(y_batch_tensors, y_data)
8192

82-
y_data_batch = y_tensors[0]
83-
if y_spec.sparse:
84-
y_data_batch = y_data_batch.to_dense()
85-
y_data_batches.append(y_data_batch)
8693

87-
np.testing.assert_array_almost_equal(
88-
np.concatenate(x_data_batches), x_data
89-
)
90-
np.testing.assert_array_almost_equal(
91-
np.concatenate(y_data_batches), y_data
92-
)
94+
def assert_tensors_almost_equal_array(batch_tensors, array):
95+
if getattr(batch_tensors[0], "is_nested", False):
96+
# compare each ragged tensor row with the non-zero values of the respective array row
97+
tensors = [tensor for batch_tensor in batch_tensors for tensor in batch_tensor]
98+
assert len(tensors) == len(array)
99+
for tensor_row, row in zip(tensors, array):
100+
np.testing.assert_array_almost_equal(tensor_row, row[np.nonzero(row)])
101+
else:
102+
if batch_tensors[0].layout in (torch.sparse_coo, torch.sparse_csr):
103+
batch_tensors = [batch_tensor.to_dense() for batch_tensor in batch_tensors]
104+
np.testing.assert_array_almost_equal(np.concatenate(batch_tensors), array)

tests/readers/test_tensorflow.py

+16-15
Original file line numberDiff line numberDiff line change
@@ -78,21 +78,22 @@ def test_dataset_order(
7878
)
7979
# since num_fields is 0, fields are all the array attributes of each array
8080
# the first item of each batch corresponds to the first attribute (="data")
81-
x_data_batches, y_data_batches = [], []
81+
x_batch_tensors, y_batch_tensors = [], []
8282
for x_tensors, y_tensors in dataset:
83-
x_data_batch = x_tensors[0]
84-
if x_spec.sparse:
85-
x_data_batch = tf.sparse.to_dense(x_data_batch)
86-
x_data_batches.append(x_data_batch)
83+
x_batch_tensors.append(x_tensors[0])
84+
y_batch_tensors.append(y_tensors[0])
85+
assert_tensors_almost_equal_array(x_batch_tensors, x_data)
86+
assert_tensors_almost_equal_array(y_batch_tensors, y_data)
8787

88-
y_data_batch = y_tensors[0]
89-
if y_spec.sparse:
90-
y_data_batch = tf.sparse.to_dense(y_data_batch)
91-
y_data_batches.append(y_data_batch)
9288

93-
np.testing.assert_array_almost_equal(
94-
np.concatenate(x_data_batches), x_data
95-
)
96-
np.testing.assert_array_almost_equal(
97-
np.concatenate(y_data_batches), y_data
98-
)
89+
def assert_tensors_almost_equal_array(batch_tensors, array):
90+
if isinstance(batch_tensors[0], tf.RaggedTensor):
91+
# compare each ragged tensor row with the non-zero values of the respective array row
92+
tensors = [tensor for batch_tensor in batch_tensors for tensor in batch_tensor]
93+
assert len(tensors) == len(array)
94+
for tensor_row, row in zip(tensors, array):
95+
np.testing.assert_array_almost_equal(tensor_row, row[np.nonzero(row)])
96+
else:
97+
if isinstance(batch_tensors[0], tf.SparseTensor):
98+
batch_tensors = list(map(tf.sparse.to_dense, batch_tensors))
99+
np.testing.assert_array_almost_equal(np.concatenate(batch_tensors), array)

tests/readers/utils.py

+27-5
Original file line numberDiff line numberDiff line change
@@ -22,11 +22,14 @@ class ArraySpec:
2222
shape: Sequence[int]
2323
key_dim: int
2424
key_dim_dtype: np.dtype
25+
non_key_dim_dtype: np.dtype
2526
num_fields: int
2627

2728
def tensor_kind(self, supports_csr: bool) -> TensorKind:
2829
if not self.sparse:
2930
return TensorKind.DENSE
31+
elif not np.issubdtype(self.non_key_dim_dtype, np.integer):
32+
return TensorKind.RAGGED
3033
elif len(self.shape) == 2 and supports_csr:
3134
return TensorKind.SPARSE_CSR
3235
else:
@@ -42,6 +45,7 @@ def parametrize_for_dataset(
4245
x_key_dim=(0, 1),
4346
y_key_dim=(0, 1),
4447
key_dim_dtype=(np.dtype(np.int32), np.dtype("datetime64[D]"), np.dtype(np.bytes_)),
48+
non_key_dim_dtype=(np.dtype(np.int32), np.dtype(np.float32)),
4549
num_fields=(0, 1, 2),
4650
batch_size=(8,),
4751
shuffle_buffer_size=(16,),
@@ -57,6 +61,7 @@ def parametrize_for_dataset(
5761
x_key_dim_,
5862
y_key_dim_,
5963
key_dim_dtype_,
64+
non_key_dim_dtype_,
6065
num_fields_,
6166
batch_size_,
6267
shuffle_buffer_size_,
@@ -69,6 +74,7 @@ def parametrize_for_dataset(
6974
x_key_dim,
7075
y_key_dim,
7176
key_dim_dtype,
77+
non_key_dim_dtype,
7278
num_fields,
7379
batch_size,
7480
shuffle_buffer_size,
@@ -78,9 +84,12 @@ def parametrize_for_dataset(
7884
if not x_sparse_ or not y_sparse_:
7985
if not np.issubdtype(key_dim_dtype_, np.integer):
8086
continue
87+
if not np.issubdtype(non_key_dim_dtype_, np.integer):
88+
continue
8189

82-
x_spec = ArraySpec(x_sparse_, x_shape_, x_key_dim_, key_dim_dtype_, num_fields_)
83-
y_spec = ArraySpec(y_sparse_, y_shape_, y_key_dim_, key_dim_dtype_, num_fields_)
90+
common_args = (key_dim_dtype_, non_key_dim_dtype_, num_fields_)
91+
x_spec = ArraySpec(x_sparse_, x_shape_, x_key_dim_, *common_args)
92+
y_spec = ArraySpec(y_sparse_, y_shape_, y_key_dim_, *common_args)
8493
argvalues.append(
8594
(x_spec, y_spec, batch_size_, shuffle_buffer_size_, num_workers_)
8695
)
@@ -101,7 +110,7 @@ def ingest_in_tiledb(tmpdir, spec: ArraySpec):
101110
transforms = []
102111
for i in range(data.ndim):
103112
n = data.shape[i]
104-
dtype = spec.key_dim_dtype if i == spec.key_dim else np.dtype("int32")
113+
dtype = spec.key_dim_dtype if i == spec.key_dim else spec.non_key_dim_dtype
105114
if np.issubdtype(dtype, np.number):
106115
# set the domain to (-n/2, n/2) to test negative domain indexing
107116
min_value = -(n // 2)
@@ -216,16 +225,28 @@ def validate_tensor_generator(generator, x_spec, y_spec, batch_size, supports_cs
216225
def _validate_tensor(tensor, spec, batch_size, supports_csr):
217226
tensor_kind = _get_tensor_kind(tensor)
218227
assert tensor_kind is spec.tensor_kind(supports_csr)
219-
num_rows, *row_shape = tensor.shape
228+
229+
spec_row_shape = spec.shape[1:]
230+
if tensor_kind is not TensorKind.RAGGED:
231+
num_rows, *row_shape = tensor.shape
232+
assert tuple(row_shape) == spec_row_shape
233+
else:
234+
# every ragged array row has at most `np.prod(spec_row_shape)` elements,
235+
# the product of all non-key dimension sizes
236+
row_lengths = tuple(map(len, tensor))
237+
assert all(row_length <= np.prod(spec_row_shape) for row_length in row_lengths)
238+
num_rows = len(row_lengths)
239+
220240
# num_rows may be less than batch_size
221241
assert num_rows <= batch_size, (num_rows, batch_size)
222-
assert tuple(row_shape) == spec.shape[1:]
223242

224243

225244
def _get_tensor_kind(tensor) -> TensorKind:
226245
if isinstance(tensor, tf.Tensor):
227246
return TensorKind.DENSE
228247
if isinstance(tensor, torch.Tensor):
248+
if getattr(tensor, "is_nested", False):
249+
return TensorKind.RAGGED
229250
return _torch_tensor_layout_to_kind[tensor.layout]
230251
return _tensor_type_to_kind[type(tensor)]
231252

@@ -236,6 +257,7 @@ def _get_tensor_kind(tensor) -> TensorKind:
236257
scipy.sparse.coo_matrix: TensorKind.SPARSE_COO,
237258
scipy.sparse.csr_matrix: TensorKind.SPARSE_CSR,
238259
tf.SparseTensor: TensorKind.SPARSE_COO,
260+
tf.RaggedTensor: TensorKind.RAGGED,
239261
}
240262

241263
_torch_tensor_layout_to_kind = {

0 commit comments

Comments
 (0)