Skip to content

Commit 399fc4c

Browse files
committed
Decouple memory budget (buffer_bytes) from tensor_generator
1 parent 5713585 commit 399fc4c

File tree

6 files changed

+291
-273
lines changed

6 files changed

+291
-273
lines changed

tests/readers/test_batch_utils.py

+1-119
Original file line numberDiff line numberDiff line change
@@ -1,94 +1,4 @@
1-
import numpy as np
2-
import pytest
3-
4-
import tiledb
5-
from tiledb.ml.readers._batch_utils import (
6-
estimate_row_bytes,
7-
get_max_buffer_size,
8-
iter_batches,
9-
)
10-
11-
12-
@pytest.fixture
13-
def dense_uri(tmp_path):
14-
uri = str(tmp_path / "dense")
15-
schema = tiledb.ArraySchema(
16-
sparse=False,
17-
domain=tiledb.Domain(
18-
tiledb.Dim(name="d0", domain=(0, 9999), dtype=np.uint32, tile=123),
19-
tiledb.Dim(name="d1", domain=(1, 5), dtype=np.uint32, tile=2),
20-
tiledb.Dim(name="d2", domain=(1, 2), dtype=np.uint32, tile=1),
21-
),
22-
attrs=[
23-
tiledb.Attr(name="af8", dtype=np.float64),
24-
tiledb.Attr(name="af4", dtype=np.float32),
25-
tiledb.Attr(name="au1", dtype=np.uint8),
26-
],
27-
)
28-
tiledb.Array.create(uri, schema)
29-
with tiledb.open(uri, "w") as a:
30-
size = a.schema.domain.size
31-
a[:] = {
32-
"af8": np.random.rand(size),
33-
"af4": np.random.rand(size).astype(np.float32),
34-
"au1": np.random.randint(128, size=size, dtype=np.uint8),
35-
}
36-
return uri
37-
38-
39-
@pytest.fixture
40-
def sparse_uri(tmp_path):
41-
uri = str(tmp_path / "sparse")
42-
schema = tiledb.ArraySchema(
43-
sparse=True,
44-
allows_duplicates=True,
45-
domain=tiledb.Domain(
46-
tiledb.Dim(name="d0", domain=(0, 999), dtype=np.int32),
47-
tiledb.Dim(name="d1", domain=(-5000, 5000), dtype=np.int32),
48-
tiledb.Dim(name="d2", domain=(1, 10), dtype=np.int32),
49-
),
50-
attrs=[
51-
tiledb.Attr(name="af8", dtype=np.float64),
52-
tiledb.Attr(name="af4", dtype=np.float32),
53-
tiledb.Attr(name="au1", dtype=np.uint8),
54-
],
55-
)
56-
tiledb.Array.create(uri, schema)
57-
with tiledb.open(uri, "w") as a:
58-
num_rows = 1000
59-
cells_per_row = 3
60-
num_cells = num_rows * cells_per_row
61-
d0 = np.concatenate(
62-
[np.arange(num_rows, dtype=np.uint32) for _ in range(cells_per_row)]
63-
)
64-
d1 = np.random.randint(-5000, 5001, num_cells).astype(np.int32)
65-
d2 = np.random.randint(1, 11, num_cells).astype(np.uint16)
66-
a[d0, d1, d2] = {
67-
"af8": np.random.rand(num_cells),
68-
"af4": np.random.rand(num_cells).astype(np.float32),
69-
"au1": np.random.randint(128, size=num_cells, dtype=np.uint8),
70-
}
71-
return uri
72-
73-
74-
def test_estimate_row_bytes_dense(dense_uri):
75-
with tiledb.open(dense_uri) as a:
76-
# 10 cells/row, 8+4+1=13 bytes/cell
77-
assert estimate_row_bytes(a) == 130
78-
# 10 cells/row, 8+1=9 bytes/cell
79-
assert estimate_row_bytes(a, attrs=["af8", "au1"]) == 90
80-
# 10 cells/row, 4 bytes/cell
81-
assert estimate_row_bytes(a, attrs=["af4"]) == 40
82-
83-
84-
def test_estimate_row_bytes_sparse(sparse_uri):
85-
with tiledb.open(sparse_uri) as a:
86-
# 3 cells/row, 3*4 bytes for dims + 8+4+1=13 bytes for attrs = 25 bytes/cell
87-
assert estimate_row_bytes(a) == 75
88-
# 3 cells/row, 3*4 bytes for dims + 8+1=9 bytes for attrs = 21 bytes/cell
89-
assert estimate_row_bytes(a, attrs=["af8", "au1"]) == 63
90-
# 3 cells/row, 3*4 bytes for dims + 4 bytes for attrs = 16 bytes/cell
91-
assert estimate_row_bytes(a, attrs=["af4"]) == 48
1+
from tiledb.ml.readers._batch_utils import iter_batches
922

933

944
def test_iter_batches():
@@ -120,31 +30,3 @@ def test_iter_batches():
12030
"Batch(18, x[0:18], y[9:27], x_read[194:212])",
12131
"Batch(1, x[0:1], y[0:1], x_read[212:213], y_read[212:213])",
12232
]
123-
124-
125-
@pytest.mark.parametrize("memory_budget", [2**i for i in range(14, 20)])
126-
@pytest.mark.parametrize(
127-
"attrs",
128-
[(), ("af8",), ("af4",), ("au1",), ("af8", "af4"), ("af8", "au1"), ("af4", "au1")],
129-
)
130-
def test_get_max_buffer_size(dense_uri, memory_budget, attrs):
131-
config = {
132-
"sm.memory_budget": memory_budget,
133-
"py.max_incomplete_retries": 0,
134-
}
135-
with tiledb.scope_ctx(config), tiledb.open(dense_uri) as a:
136-
buffer_size = get_max_buffer_size(a.schema, attrs)
137-
# Check that the buffer size is a multiple of the row tile extent
138-
assert buffer_size % a.dim(0).tile == 0
139-
140-
# Check that we can slice with buffer_size without incomplete reads
141-
query = a.query(attrs=attrs or None)
142-
for offset in range(0, a.shape[0], buffer_size):
143-
query[offset : offset + buffer_size]
144-
145-
if buffer_size < a.shape[0]:
146-
# Check that buffer_size is the max size we can slice without incomplete reads
147-
buffer_size += 1
148-
with pytest.raises(tiledb.TileDBError):
149-
for offset in range(0, a.shape[0], buffer_size):
150-
query[offset : offset + buffer_size]

tests/readers/test_buffer_utils.py

+115
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,115 @@
1+
import numpy as np
2+
import pytest
3+
4+
import tiledb
5+
from tiledb.ml.readers._buffer_utils import estimate_row_bytes, get_max_buffer_size
6+
7+
8+
@pytest.fixture
def dense_uri(tmp_path):
    """Create a 3-D dense TileDB array with three fixed-width attributes.

    Dimensions: d0 in [0, 9999] (tile 123), d1 in [1, 5] (tile 2),
    d2 in [1, 2] (tile 1). Attributes: af8 (float64), af4 (float32),
    au1 (uint8), filled with random data. Returns the array URI.
    """
    path = str(tmp_path / "dense")
    dims = [
        tiledb.Dim(name="d0", domain=(0, 9999), dtype=np.uint32, tile=123),
        tiledb.Dim(name="d1", domain=(1, 5), dtype=np.uint32, tile=2),
        tiledb.Dim(name="d2", domain=(1, 2), dtype=np.uint32, tile=1),
    ]
    attrs = [
        tiledb.Attr(name="af8", dtype=np.float64),
        tiledb.Attr(name="af4", dtype=np.float32),
        tiledb.Attr(name="au1", dtype=np.uint8),
    ]
    tiledb.Array.create(
        path,
        tiledb.ArraySchema(sparse=False, domain=tiledb.Domain(*dims), attrs=attrs),
    )
    with tiledb.open(path, "w") as array:
        num_cells = array.schema.domain.size
        array[:] = {
            "af8": np.random.rand(num_cells),
            "af4": np.random.rand(num_cells).astype(np.float32),
            "au1": np.random.randint(128, size=num_cells, dtype=np.uint8),
        }
    return path
33+
34+
35+
@pytest.fixture
def sparse_uri(tmp_path):
    """Create a 3-D sparse TileDB array with three fixed-width attributes.

    Writes 1000 d0 rows with exactly 3 cells per row (duplicates allowed).
    Dimensions: d0 in [0, 999], d1 in [-5000, 5000], d2 in [1, 10], all int32.
    Attributes: af8 (float64), af4 (float32), au1 (uint8), filled with random
    data. Returns the array URI.
    """
    uri = str(tmp_path / "sparse")
    schema = tiledb.ArraySchema(
        sparse=True,
        allows_duplicates=True,
        domain=tiledb.Domain(
            tiledb.Dim(name="d0", domain=(0, 999), dtype=np.int32),
            tiledb.Dim(name="d1", domain=(-5000, 5000), dtype=np.int32),
            tiledb.Dim(name="d2", domain=(1, 10), dtype=np.int32),
        ),
        attrs=[
            tiledb.Attr(name="af8", dtype=np.float64),
            tiledb.Attr(name="af4", dtype=np.float32),
            tiledb.Attr(name="au1", dtype=np.uint8),
        ],
    )
    tiledb.Array.create(uri, schema)
    with tiledb.open(uri, "w") as a:
        num_rows = 1000
        cells_per_row = 3
        num_cells = num_rows * cells_per_row
        # Each d0 row appears exactly cells_per_row times.
        # Fix: coordinate buffers now match the declared int32 dimension dtypes
        # (d0 was built as uint32 and d2 as uint16, relying on implicit casting).
        d0 = np.tile(np.arange(num_rows, dtype=np.int32), cells_per_row)
        d1 = np.random.randint(-5000, 5001, num_cells).astype(np.int32)
        d2 = np.random.randint(1, 11, num_cells).astype(np.int32)
        a[d0, d1, d2] = {
            "af8": np.random.rand(num_cells),
            "af4": np.random.rand(num_cells).astype(np.float32),
            "au1": np.random.randint(128, size=num_cells, dtype=np.uint8),
        }
    return uri
68+
69+
70+
def test_estimate_row_bytes_dense(dense_uri):
    """Dense array: bytes/row = 10 cells/row * total itemsize of requested attrs."""
    with tiledb.open(dense_uri) as array:
        # All attrs: 8 + 4 + 1 = 13 bytes/cell -> 130 bytes/row
        assert estimate_row_bytes(array) == 130
        # Subset af8 + au1: 8 + 1 = 9 bytes/cell -> 90 bytes/row
        assert estimate_row_bytes(array, attrs=["af8", "au1"]) == 90
        # Single attr af4: 4 bytes/cell -> 40 bytes/row
        assert estimate_row_bytes(array, attrs=["af4"]) == 40
78+
79+
80+
def test_estimate_row_bytes_sparse(sparse_uri):
    """Sparse array: bytes/row = 3 cells/row * (dim coord bytes + attr bytes)."""
    with tiledb.open(sparse_uri) as array:
        # All attrs: 3*4 dim bytes + 8+4+1 = 25 bytes/cell -> 75 bytes/row
        assert estimate_row_bytes(array) == 75
        # Subset af8 + au1: 12 dim bytes + 8+1 = 21 bytes/cell -> 63 bytes/row
        assert estimate_row_bytes(array, attrs=["af8", "au1"]) == 63
        # Single attr af4: 12 dim bytes + 4 = 16 bytes/cell -> 48 bytes/row
        assert estimate_row_bytes(array, attrs=["af4"]) == 48
88+
89+
90+
@pytest.mark.parametrize("memory_budget", [2**i for i in range(14, 20)])
@pytest.mark.parametrize(
    "attrs",
    [(), ("af8",), ("af4",), ("au1",), ("af8", "af4"), ("af8", "au1"), ("af4", "au1")],
)
def test_get_max_buffer_size(dense_uri, memory_budget, attrs):
    """get_max_buffer_size returns the largest tile-aligned row count readable
    in a single query under the given memory budget."""
    config = {
        "sm.memory_budget": memory_budget,
        # Fail immediately on an incomplete read instead of silently retrying.
        "py.max_incomplete_retries": 0,
    }
    with tiledb.scope_ctx(config), tiledb.open(dense_uri) as array:
        num_rows = array.shape[0]
        buffer_size = get_max_buffer_size(array.schema, attrs)
        # The buffer must cover whole tiles along the row dimension.
        assert buffer_size % array.dim(0).tile == 0

        def read_all(rows_per_slice):
            # Slice the full row range in chunks of rows_per_slice.
            query = array.query(attrs=attrs or None)
            for start in range(0, num_rows, rows_per_slice):
                query[start : start + rows_per_slice]

        # buffer_size rows at a time must never trigger an incomplete read.
        read_all(buffer_size)

        if buffer_size < num_rows:
            # One extra row must overflow the budget, proving maximality.
            with pytest.raises(tiledb.TileDBError):
                read_all(buffer_size + 1)

0 commit comments

Comments
 (0)