|
1 |
| -import numpy as np |
2 |
| -import pytest |
3 |
| - |
4 |
| -import tiledb |
5 |
| -from tiledb.ml.readers._batch_utils import ( |
6 |
| - estimate_row_bytes, |
7 |
| - get_max_buffer_size, |
8 |
| - iter_batches, |
9 |
| -) |
10 |
| - |
11 |
| - |
12 |
@pytest.fixture
def dense_uri(tmp_path):
    """Create a 3-D dense TileDB array with three numeric attributes.

    Dimensions: d0 (10000 rows, tile 123), d1 (5, tile 2), d2 (2, tile 1).
    Attributes: af8 (float64), af4 (float32), au1 (uint8), all filled with
    random data. Returns the array URI as a string.
    """
    uri = str(tmp_path / "dense")
    dims = [
        tiledb.Dim(name="d0", domain=(0, 9999), dtype=np.uint32, tile=123),
        tiledb.Dim(name="d1", domain=(1, 5), dtype=np.uint32, tile=2),
        tiledb.Dim(name="d2", domain=(1, 2), dtype=np.uint32, tile=1),
    ]
    attrs = [
        tiledb.Attr(name="af8", dtype=np.float64),
        tiledb.Attr(name="af4", dtype=np.float32),
        tiledb.Attr(name="au1", dtype=np.uint8),
    ]
    schema = tiledb.ArraySchema(
        sparse=False,
        domain=tiledb.Domain(*dims),
        attrs=attrs,
    )
    tiledb.Array.create(uri, schema)
    with tiledb.open(uri, "w") as array:
        num_cells = array.schema.domain.size
        array[:] = {
            "af8": np.random.rand(num_cells),
            "af4": np.random.rand(num_cells).astype(np.float32),
            "au1": np.random.randint(128, size=num_cells, dtype=np.uint8),
        }
    return uri
37 |
| - |
38 |
| - |
39 |
@pytest.fixture
def sparse_uri(tmp_path):
    """Create a 3-D sparse TileDB array with three numeric attributes.

    Writes 1000 logical rows (d0 in [0, 999]) with exactly 3 cells per row
    (duplicates allowed), random d1/d2 coordinates, and random attribute
    values. Returns the array URI as a string.
    """
    uri = str(tmp_path / "sparse")
    schema = tiledb.ArraySchema(
        sparse=True,
        allows_duplicates=True,
        domain=tiledb.Domain(
            tiledb.Dim(name="d0", domain=(0, 999), dtype=np.int32),
            tiledb.Dim(name="d1", domain=(-5000, 5000), dtype=np.int32),
            tiledb.Dim(name="d2", domain=(1, 10), dtype=np.int32),
        ),
        attrs=[
            tiledb.Attr(name="af8", dtype=np.float64),
            tiledb.Attr(name="af4", dtype=np.float32),
            tiledb.Attr(name="au1", dtype=np.uint8),
        ],
    )
    tiledb.Array.create(uri, schema)
    with tiledb.open(uri, "w") as a:
        num_rows = 1000
        cells_per_row = 3
        num_cells = num_rows * cells_per_row
        # Fix: coordinate dtypes now match the declared int32 dimension
        # dtypes (the original built d0 as uint32 and d2 as uint16, relying
        # on implicit casting at write time). np.tile replaces the
        # concatenate-of-identical-aranges, producing the same values.
        d0 = np.tile(np.arange(num_rows, dtype=np.int32), cells_per_row)
        d1 = np.random.randint(-5000, 5001, num_cells).astype(np.int32)
        d2 = np.random.randint(1, 11, num_cells).astype(np.int32)
        a[d0, d1, d2] = {
            "af8": np.random.rand(num_cells),
            "af4": np.random.rand(num_cells).astype(np.float32),
            "au1": np.random.randint(128, size=num_cells, dtype=np.uint8),
        }
    return uri
72 |
| - |
73 |
| - |
74 |
def test_estimate_row_bytes_dense(dense_uri):
    """estimate_row_bytes for a dense array: cells/row times bytes/cell."""
    # Each row spans 10 cells (5 * 2 over the non-row dimensions).
    # Expected bytes/row for each attribute subset:
    cases = [
        (None, 130),             # all attrs: 8+4+1 = 13 bytes/cell
        (["af8", "au1"], 90),    # 8+1 = 9 bytes/cell
        (["af4"], 40),           # 4 bytes/cell
    ]
    with tiledb.open(dense_uri) as array:
        for attrs, expected in cases:
            if attrs is None:
                assert estimate_row_bytes(array) == expected
            else:
                assert estimate_row_bytes(array, attrs=attrs) == expected
82 |
| - |
83 |
| - |
84 |
def test_estimate_row_bytes_sparse(sparse_uri):
    """estimate_row_bytes for a sparse array includes coordinate bytes."""
    # 3 cells/row; every cell carries 3*4 = 12 bytes of int32 coordinates
    # plus the bytes of the selected attributes.
    cases = [
        (None, 75),              # 12 + (8+4+1) = 25 bytes/cell
        (["af8", "au1"], 63),    # 12 + (8+1) = 21 bytes/cell
        (["af4"], 48),           # 12 + 4 = 16 bytes/cell
    ]
    with tiledb.open(sparse_uri) as array:
        for attrs, expected in cases:
            if attrs is None:
                assert estimate_row_bytes(array) == expected
            else:
                assert estimate_row_bytes(array, attrs=attrs) == expected
| 1 | +from tiledb.ml.readers._batch_utils import iter_batches |
92 | 2 |
|
93 | 3 |
|
94 | 4 | def test_iter_batches():
|
@@ -120,31 +30,3 @@ def test_iter_batches():
|
120 | 30 | "Batch(18, x[0:18], y[9:27], x_read[194:212])",
|
121 | 31 | "Batch(1, x[0:1], y[0:1], x_read[212:213], y_read[212:213])",
|
122 | 32 | ]
|
123 |
| - |
124 |
| - |
125 |
@pytest.mark.parametrize("memory_budget", [2**i for i in range(14, 20)])
@pytest.mark.parametrize(
    "attrs",
    [(), ("af8",), ("af4",), ("au1",), ("af8", "af4"), ("af8", "au1"), ("af4", "au1")],
)
def test_get_max_buffer_size(dense_uri, memory_budget, attrs):
    """get_max_buffer_size yields the largest row count readable in one shot.

    With incomplete-read retries disabled, a slice larger than the buffer
    size must raise instead of silently retrying.
    """
    ctx_config = {
        "sm.memory_budget": memory_budget,
        "py.max_incomplete_retries": 0,
    }
    with tiledb.scope_ctx(ctx_config), tiledb.open(dense_uri) as array:
        buffer_size = get_max_buffer_size(array.schema, attrs)
        # The buffer must cover a whole number of row tiles.
        assert buffer_size % array.dim(0).tile == 0

        num_rows = array.shape[0]
        query = array.query(attrs=attrs or None)

        # Stepping through the array in buffer_size chunks must never
        # trigger an incomplete read (which would raise, retries being 0).
        for start in range(0, num_rows, buffer_size):
            query[start : start + buffer_size]

        if buffer_size < num_rows:
            # buffer_size is maximal: one extra row must overflow the
            # memory budget and raise.
            oversized = buffer_size + 1
            with pytest.raises(tiledb.TileDBError):
                for start in range(0, num_rows, oversized):
                    query[start : start + oversized]
0 commit comments