Skip to content

Commit 2eaf47b

Browse files
authored
Merge pull request #108 from TileDB-Inc/gsa/sc-14423/refactor-optimize-tiledb-ml-tests
Refactor data loader tests
2 parents d589390 + 663cbe3 commit 2eaf47b

12 files changed

+599
-1485
lines changed

tests/readers/__init__.py

Whitespace-only changes.

tests/readers/test_pytorch.py

Lines changed: 190 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,190 @@
1+
"""Tests for TileDB integration with PyTorch Data API."""
2+
3+
import numpy as np
4+
import pytest
5+
import torch
6+
7+
from tiledb.ml.readers.pytorch import PyTorchTileDBDataset
8+
9+
from .utils import (
10+
ingest_in_tiledb,
11+
parametrize_for_dataset,
12+
rand_array,
13+
validate_tensor_generator,
14+
)
15+
16+
17+
@pytest.mark.parametrize("num_rows", [107])
class TestPyTorchTileDBDataset:
    """Tests for ``PyTorchTileDBDataset``; every test runs with ``num_rows=107``."""

    @parametrize_for_dataset()
    @pytest.mark.parametrize("num_workers", [0, 2])
    def test_generator(
        self,
        tmpdir,
        num_rows,
        num_workers,
        x_sparse,
        y_sparse,
        x_shape,
        y_shape,
        num_attrs,
        pass_attrs,
        batch_size,
        buffer_size,
        batch_shuffle,
        within_batch_shuffle,
    ):
        """Validate the dataset's batches and that a DataLoader yields no repeats.

        The dataset is first checked with ``validate_tensor_generator``. It is
        then wrapped in a ``DataLoader`` (with 0 or 2 workers), and after every
        batch the number of distinct X tensors and distinct Y tensors seen so
        far must equal the number of batches consumed.
        """
        if num_workers and (x_sparse or y_sparse):
            pytest.skip("multiple workers not supported with sparse arrays")

        def record_if_new(tensor, is_sparse, seen):
            # Sparse tensors are densified first so torch.equal can compare them.
            dense = tensor.to_dense() if is_sparse else tensor
            if all(not torch.equal(dense, other) for other in seen):
                seen.append(dense)

        x_data = rand_array(num_rows, *x_shape, sparse=x_sparse)
        y_data = rand_array(num_rows, *y_shape, sparse=y_sparse)
        with ingest_in_tiledb(
            tmpdir,
            x_data=x_data,
            y_data=y_data,
            x_sparse=x_sparse,
            y_sparse=y_sparse,
            batch_size=batch_size,
            num_attrs=num_attrs,
            pass_attrs=pass_attrs,
            buffer_size=buffer_size,
            batch_shuffle=batch_shuffle,
            within_batch_shuffle=within_batch_shuffle,
        ) as kwargs:
            torch_dataset = PyTorchTileDBDataset(**kwargs)
            assert isinstance(torch_dataset, torch.utils.data.IterableDataset)
            validate_tensor_generator(
                torch_dataset,
                x_sparse=x_sparse,
                y_sparse=y_sparse,
                x_shape=x_shape,
                y_shape=y_shape,
                batch_size=batch_size,
                num_attrs=num_attrs,
            )

            # batch_size=None: the DataLoader passes the dataset's items through
            # as they are yielded.
            loader = torch.utils.data.DataLoader(
                torch_dataset, batch_size=None, num_workers=num_workers
            )
            seen_x = []
            seen_y = []
            for batch_idx, batch in enumerate(loader):
                for attr_idx in range(num_attrs):
                    # The first num_attrs entries of a batch are read as X
                    # tensors, the following num_attrs entries as Y tensors.
                    record_if_new(batch[attr_idx], x_sparse, seen_x)
                    record_if_new(batch[attr_idx + num_attrs], y_sparse, seen_y)

                # Exactly one new distinct X and Y tensor per consumed batch.
                assert len(seen_x) - 1 == batch_idx
                assert len(seen_y) - 1 == batch_idx

    @parametrize_for_dataset(batch_size=[32], buffer_size=[31])
    def test_buffer_size_smaller_than_batch_size(
        self,
        tmpdir,
        num_rows,
        x_sparse,
        y_sparse,
        x_shape,
        y_shape,
        num_attrs,
        pass_attrs,
        batch_size,
        buffer_size,
        batch_shuffle,
        within_batch_shuffle,
    ):
        """Constructing the dataset with buffer_size=31 < batch_size=32 must fail."""
        x_data = rand_array(num_rows, *x_shape, sparse=x_sparse)
        y_data = rand_array(num_rows, *y_shape, sparse=y_sparse)
        with ingest_in_tiledb(
            tmpdir,
            x_data=x_data,
            y_data=y_data,
            x_sparse=x_sparse,
            y_sparse=y_sparse,
            batch_size=batch_size,
            num_attrs=num_attrs,
            pass_attrs=pass_attrs,
            buffer_size=buffer_size,
            batch_shuffle=batch_shuffle,
            within_batch_shuffle=within_batch_shuffle,
        ) as kwargs:
            with pytest.raises(ValueError) as excinfo:
                PyTorchTileDBDataset(**kwargs)
            assert "buffer_size must be >= batch_size" in str(excinfo.value)

    @parametrize_for_dataset()
    def test_unequal_num_rows(
        self,
        tmpdir,
        num_rows,
        x_sparse,
        y_sparse,
        x_shape,
        y_shape,
        num_attrs,
        pass_attrs,
        batch_size,
        buffer_size,
        batch_shuffle,
        within_batch_shuffle,
    ):
        """Constructing the dataset must fail when X has one more row than Y."""
        # X deliberately gets one extra row compared to Y.
        x_data = rand_array(num_rows + 1, *x_shape, sparse=x_sparse)
        y_data = rand_array(num_rows, *y_shape, sparse=y_sparse)
        with ingest_in_tiledb(
            tmpdir,
            x_data=x_data,
            y_data=y_data,
            x_sparse=x_sparse,
            y_sparse=y_sparse,
            batch_size=batch_size,
            num_attrs=num_attrs,
            pass_attrs=pass_attrs,
            buffer_size=buffer_size,
            batch_shuffle=batch_shuffle,
            within_batch_shuffle=within_batch_shuffle,
        ) as kwargs:
            with pytest.raises(ValueError) as excinfo:
                PyTorchTileDBDataset(**kwargs)
            assert "X and Y arrays must have the same number of rows" in str(
                excinfo.value
            )

    @parametrize_for_dataset(x_sparse=[True])
    def test_x_sparse_unequal_num_rows_in_batch(
        self,
        tmpdir,
        num_rows,
        x_sparse,
        y_sparse,
        x_shape,
        y_shape,
        num_attrs,
        pass_attrs,
        batch_size,
        buffer_size,
        batch_shuffle,
        within_batch_shuffle,
    ):
        """Iterating must fail when one randomly chosen row of sparse X is zeroed.

        NOTE(review): presumably an all-zero row is absent from the sparse X
        array, which leaves its batch shorter than the matching Y batch -
        confirm against the reader implementation.
        """
        x_data = rand_array(num_rows, *x_shape, sparse=x_sparse)
        zeroed_row = np.random.randint(len(x_data))
        x_data[zeroed_row] = 0
        y_data = rand_array(num_rows, *y_shape, sparse=y_sparse)
        with ingest_in_tiledb(
            tmpdir,
            x_data=x_data,
            y_data=y_data,
            x_sparse=x_sparse,
            y_sparse=y_sparse,
            batch_size=batch_size,
            num_attrs=num_attrs,
            pass_attrs=pass_attrs,
            buffer_size=buffer_size,
            batch_shuffle=batch_shuffle,
            within_batch_shuffle=within_batch_shuffle,
        ) as kwargs:
            torch_dataset = PyTorchTileDBDataset(**kwargs)
            # Construction succeeds; the error only surfaces while iterating.
            with pytest.raises(ValueError) as excinfo:
                for _ in torch_dataset:
                    pass
            assert "x and y batches should have the same length" in str(
                excinfo.value
            )

tests/readers/test_tensorflow.py

Lines changed: 187 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,187 @@
1+
"""Tests for TileDB integration with Tensorflow Data API."""
2+
3+
import os
4+
5+
import numpy as np
6+
import pytest
7+
import tensorflow as tf
8+
9+
from tiledb.ml.readers._batch_utils import tensor_generator
10+
from tiledb.ml.readers.tensorflow import (
11+
TensorflowDenseBatch,
12+
TensorflowSparseBatch,
13+
TensorflowTileDBDataset,
14+
)
15+
16+
from .utils import (
17+
ingest_in_tiledb,
18+
parametrize_for_dataset,
19+
rand_array,
20+
validate_tensor_generator,
21+
)
22+
23+
# Suppress all Tensorflow messages: level "3" for the native (C++) log output and
# ERROR verbosity for the Python-side tf.compat.v1 logger.
# NOTE(review): this assignment runs after the `import tensorflow as tf` above;
# TF_CPP_MIN_LOG_LEVEL may need to be set before that import to take effect - confirm.
24+
os.environ["TF_CPP_MIN_LOG_LEVEL"] = "3"
25+
26+
tf.compat.v1.logging.set_verbosity(tf.compat.v1.logging.ERROR)
27+
28+
29+
@pytest.mark.parametrize("num_rows", [107])
class TestTensorflowTileDBDataset:
    """Tests for ``TensorflowTileDBDataset``; every test runs with ``num_rows=107``."""

    @parametrize_for_dataset()
    def test_generator(
        self,
        tmpdir,
        num_rows,
        x_sparse,
        y_sparse,
        x_shape,
        y_shape,
        num_attrs,
        pass_attrs,
        batch_size,
        buffer_size,
        batch_shuffle,
        within_batch_shuffle,
    ):
        """Validate batches from both the public dataset and ``tensor_generator``."""
        x_data = rand_array(num_rows, *x_shape, sparse=x_sparse)
        y_data = rand_array(num_rows, *y_shape, sparse=y_sparse)
        with ingest_in_tiledb(
            tmpdir,
            x_data=x_data,
            y_data=y_data,
            x_sparse=x_sparse,
            y_sparse=y_sparse,
            batch_size=batch_size,
            num_attrs=num_attrs,
            pass_attrs=pass_attrs,
            buffer_size=buffer_size,
            batch_shuffle=batch_shuffle,
            within_batch_shuffle=within_batch_shuffle,
        ) as kwargs:
            tf_dataset = TensorflowTileDBDataset(**kwargs)
            assert isinstance(tf_dataset, tf.data.Dataset)

            # The generator is exercised twice: through the public API
            # (TensorflowTileDBDataset) and by calling tensor_generator directly.
            # The former calls the latter internally, but the coverage report
            # does not count it as covered because of
            # https://github.com/tensorflow/tensorflow/issues/33759
            generator_kwargs = dict(kwargs, buffer_size=buffer_size or batch_size)
            direct_generator = tensor_generator(
                dense_batch_cls=TensorflowDenseBatch,
                sparse_batch_cls=TensorflowSparseBatch,
                **generator_kwargs,
            )
            for candidate in (tf_dataset, direct_generator):
                validate_tensor_generator(
                    candidate,
                    x_sparse=x_sparse,
                    y_sparse=y_sparse,
                    x_shape=x_shape,
                    y_shape=y_shape,
                    batch_size=batch_size,
                    num_attrs=num_attrs,
                )

    @parametrize_for_dataset(batch_size=[32], buffer_size=[31])
    def test_buffer_size_smaller_than_batch_size(
        self,
        tmpdir,
        num_rows,
        x_sparse,
        y_sparse,
        x_shape,
        y_shape,
        num_attrs,
        pass_attrs,
        batch_size,
        buffer_size,
        batch_shuffle,
        within_batch_shuffle,
    ):
        """Constructing the dataset with buffer_size=31 < batch_size=32 must fail."""
        x_data = rand_array(num_rows, *x_shape, sparse=x_sparse)
        y_data = rand_array(num_rows, *y_shape, sparse=y_sparse)
        with ingest_in_tiledb(
            tmpdir,
            x_data=x_data,
            y_data=y_data,
            x_sparse=x_sparse,
            y_sparse=y_sparse,
            batch_size=batch_size,
            num_attrs=num_attrs,
            pass_attrs=pass_attrs,
            buffer_size=buffer_size,
            batch_shuffle=batch_shuffle,
            within_batch_shuffle=within_batch_shuffle,
        ) as kwargs:
            with pytest.raises(ValueError) as excinfo:
                TensorflowTileDBDataset(**kwargs)
            assert "buffer_size must be >= batch_size" in str(excinfo.value)

    @parametrize_for_dataset()
    def test_unequal_num_rows(
        self,
        tmpdir,
        num_rows,
        x_sparse,
        y_sparse,
        x_shape,
        y_shape,
        num_attrs,
        pass_attrs,
        batch_size,
        buffer_size,
        batch_shuffle,
        within_batch_shuffle,
    ):
        """Constructing the dataset must fail when X has one more row than Y."""
        # X deliberately gets one extra row compared to Y.
        x_data = rand_array(num_rows + 1, *x_shape, sparse=x_sparse)
        y_data = rand_array(num_rows, *y_shape, sparse=y_sparse)
        with ingest_in_tiledb(
            tmpdir,
            x_data=x_data,
            y_data=y_data,
            x_sparse=x_sparse,
            y_sparse=y_sparse,
            batch_size=batch_size,
            num_attrs=num_attrs,
            pass_attrs=pass_attrs,
            buffer_size=buffer_size,
            batch_shuffle=batch_shuffle,
            within_batch_shuffle=within_batch_shuffle,
        ) as kwargs:
            with pytest.raises(ValueError) as excinfo:
                TensorflowTileDBDataset(**kwargs)
            assert "X and Y arrays must have the same number of rows" in str(
                excinfo.value
            )

    @parametrize_for_dataset(x_sparse=[True])
    def test_x_sparse_unequal_num_rows_in_batch(
        self,
        tmpdir,
        num_rows,
        x_sparse,
        y_sparse,
        x_shape,
        y_shape,
        num_attrs,
        pass_attrs,
        batch_size,
        buffer_size,
        batch_shuffle,
        within_batch_shuffle,
    ):
        """Iterating must fail when one randomly chosen row of sparse X is zeroed.

        NOTE(review): presumably an all-zero row is absent from the sparse X
        array, which leaves its batch shorter than the matching Y batch -
        confirm against the reader implementation.
        """
        x_data = rand_array(num_rows, *x_shape, sparse=x_sparse)
        zeroed_row = np.random.randint(len(x_data))
        x_data[zeroed_row] = 0
        y_data = rand_array(num_rows, *y_shape, sparse=y_sparse)
        with ingest_in_tiledb(
            tmpdir,
            x_data=x_data,
            y_data=y_data,
            x_sparse=x_sparse,
            y_sparse=y_sparse,
            batch_size=batch_size,
            num_attrs=num_attrs,
            pass_attrs=pass_attrs,
            buffer_size=buffer_size,
            batch_shuffle=batch_shuffle,
            within_batch_shuffle=within_batch_shuffle,
        ) as kwargs:
            tf_dataset = TensorflowTileDBDataset(**kwargs)
            # Construction succeeds; the failure is expected as a Tensorflow
            # InvalidArgumentError raised while iterating the dataset.
            with pytest.raises(tf.errors.InvalidArgumentError) as excinfo:
                for _ in tf_dataset:
                    pass
            assert "x and y batches should have the same length" in str(
                excinfo.value
            )

0 commit comments

Comments
 (0)