|
1 | | -import numpy as np |
2 | | -import pytest |
3 | | - |
4 | | -import tiledb |
5 | | -from tiledb.ml.readers._batch_utils import ( |
6 | | - estimate_row_bytes, |
7 | | - get_max_buffer_size, |
8 | | - iter_batches, |
9 | | -) |
10 | | - |
11 | | - |
12 | | -@pytest.fixture |
13 | | -def dense_uri(tmp_path): |
14 | | - uri = str(tmp_path / "dense") |
15 | | - schema = tiledb.ArraySchema( |
16 | | - sparse=False, |
17 | | - domain=tiledb.Domain( |
18 | | - tiledb.Dim(name="d0", domain=(0, 9999), dtype=np.uint32, tile=123), |
19 | | - tiledb.Dim(name="d1", domain=(1, 5), dtype=np.uint32, tile=2), |
20 | | - tiledb.Dim(name="d2", domain=(1, 2), dtype=np.uint32, tile=1), |
21 | | - ), |
22 | | - attrs=[ |
23 | | - tiledb.Attr(name="af8", dtype=np.float64), |
24 | | - tiledb.Attr(name="af4", dtype=np.float32), |
25 | | - tiledb.Attr(name="au1", dtype=np.uint8), |
26 | | - ], |
27 | | - ) |
28 | | - tiledb.Array.create(uri, schema) |
29 | | - with tiledb.open(uri, "w") as a: |
30 | | - size = a.schema.domain.size |
31 | | - a[:] = { |
32 | | - "af8": np.random.rand(size), |
33 | | - "af4": np.random.rand(size).astype(np.float32), |
34 | | - "au1": np.random.randint(128, size=size, dtype=np.uint8), |
35 | | - } |
36 | | - return uri |
37 | | - |
38 | | - |
39 | | -@pytest.fixture |
40 | | -def sparse_uri(tmp_path): |
41 | | - uri = str(tmp_path / "sparse") |
42 | | - schema = tiledb.ArraySchema( |
43 | | - sparse=True, |
44 | | - allows_duplicates=True, |
45 | | - domain=tiledb.Domain( |
46 | | - tiledb.Dim(name="d0", domain=(0, 999), dtype=np.int32), |
47 | | - tiledb.Dim(name="d1", domain=(-5000, 5000), dtype=np.int32), |
48 | | - tiledb.Dim(name="d2", domain=(1, 10), dtype=np.int32), |
49 | | - ), |
50 | | - attrs=[ |
51 | | - tiledb.Attr(name="af8", dtype=np.float64), |
52 | | - tiledb.Attr(name="af4", dtype=np.float32), |
53 | | - tiledb.Attr(name="au1", dtype=np.uint8), |
54 | | - ], |
55 | | - ) |
56 | | - tiledb.Array.create(uri, schema) |
57 | | - with tiledb.open(uri, "w") as a: |
58 | | - num_rows = 1000 |
59 | | - cells_per_row = 3 |
60 | | - num_cells = num_rows * cells_per_row |
61 | | - d0 = np.concatenate( |
62 | | - [np.arange(num_rows, dtype=np.uint32) for _ in range(cells_per_row)] |
63 | | - ) |
64 | | - d1 = np.random.randint(-5000, 5001, num_cells).astype(np.int32) |
65 | | - d2 = np.random.randint(1, 11, num_cells).astype(np.uint16) |
66 | | - a[d0, d1, d2] = { |
67 | | - "af8": np.random.rand(num_cells), |
68 | | - "af4": np.random.rand(num_cells).astype(np.float32), |
69 | | - "au1": np.random.randint(128, size=num_cells, dtype=np.uint8), |
70 | | - } |
71 | | - return uri |
72 | | - |
73 | | - |
74 | | -def test_estimate_row_bytes_dense(dense_uri): |
75 | | - with tiledb.open(dense_uri) as a: |
76 | | - # 10 cells/row, 8+4+1=13 bytes/cell |
77 | | - assert estimate_row_bytes(a) == 130 |
78 | | - # 10 cells/row, 8+1=9 bytes/cell |
79 | | - assert estimate_row_bytes(a, attrs=["af8", "au1"]) == 90 |
80 | | - # 10 cells/row, 4 bytes/cell |
81 | | - assert estimate_row_bytes(a, attrs=["af4"]) == 40 |
82 | | - |
83 | | - |
84 | | -def test_estimate_row_bytes_sparse(sparse_uri): |
85 | | - with tiledb.open(sparse_uri) as a: |
86 | | - # 3 cells/row, 3*4 bytes for dims + 8+4+1=13 bytes for attrs = 25 bytes/cell |
87 | | - assert estimate_row_bytes(a) == 75 |
88 | | - # 3 cells/row, 3*4 bytes for dims + 8+1=9 bytes for attrs = 21 bytes/cell |
89 | | - assert estimate_row_bytes(a, attrs=["af8", "au1"]) == 63 |
90 | | - # 3 cells/row, 3*4 bytes for dims + 4 bytes for attrs = 16 bytes/cell |
91 | | - assert estimate_row_bytes(a, attrs=["af4"]) == 48 |
| 1 | +from tiledb.ml.readers._batch_utils import iter_batches |
92 | 2 |
|
93 | 3 |
|
94 | 4 | def test_iter_batches(): |
@@ -120,31 +30,3 @@ def test_iter_batches(): |
120 | 30 | "Batch(18, x[0:18], y[9:27], x_read[194:212])", |
121 | 31 | "Batch(1, x[0:1], y[0:1], x_read[212:213], y_read[212:213])", |
122 | 32 | ] |
123 | | - |
124 | | - |
125 | | -@pytest.mark.parametrize("memory_budget", [2**i for i in range(14, 20)]) |
126 | | -@pytest.mark.parametrize( |
127 | | - "attrs", |
128 | | - [(), ("af8",), ("af4",), ("au1",), ("af8", "af4"), ("af8", "au1"), ("af4", "au1")], |
129 | | -) |
130 | | -def test_get_max_buffer_size(dense_uri, memory_budget, attrs): |
131 | | - config = { |
132 | | - "sm.memory_budget": memory_budget, |
133 | | - "py.max_incomplete_retries": 0, |
134 | | - } |
135 | | - with tiledb.scope_ctx(config), tiledb.open(dense_uri) as a: |
136 | | - buffer_size = get_max_buffer_size(a.schema, attrs) |
137 | | - # Check that the buffer size is a multiple of the row tile extent |
138 | | - assert buffer_size % a.dim(0).tile == 0 |
139 | | - |
140 | | - # Check that we can slice with buffer_size without incomplete reads |
141 | | - query = a.query(attrs=attrs or None) |
142 | | - for offset in range(0, a.shape[0], buffer_size): |
143 | | - query[offset : offset + buffer_size] |
144 | | - |
145 | | - if buffer_size < a.shape[0]: |
146 | | - # Check that buffer_size is the max size we can slice without incomplete reads |
147 | | - buffer_size += 1 |
148 | | - with pytest.raises(tiledb.TileDBError): |
149 | | - for offset in range(0, a.shape[0], buffer_size): |
150 | | - query[offset : offset + buffer_size] |
0 commit comments