1
1
from __future__ import annotations
2
2
3
3
from abc import ABC , abstractmethod
4
- from concurrent . futures import ThreadPoolExecutor
5
- from typing import Generic , Iterator , Mapping , Optional , Sequence , Type , TypeVar , Union
4
+ from concurrent import futures
5
+ from typing import Generic , Iterator , Optional , Sequence , Type , TypeVar , Union
6
6
7
7
import numpy as np
8
8
import scipy .sparse as sp
@@ -25,18 +25,18 @@ def __init__(
25
25
self ._batch_size = batch_size
26
26
27
27
@abstractmethod
def read_buffer(self, array: tiledb.Array, buffer_slice: slice) -> None:
    """Read a slice from a TileDB array into an internal buffer.

    Subclasses fetch the attributes named in ``self._attrs`` for the
    given slice and store them in whatever layout suits their tensor
    format (dense ndarray mapping, CSR matrices, ...).

    :param array: TileDB array to read from.
    :param buffer_slice: Slice of the array to read.
    """
34
34
35
35
@abstractmethod
def set_batch_slice(self, batch_slice: slice) -> None:
    """Select the current batch as a slice of the read buffer.

    Must be called after `read_buffer`.

    :param batch_slice: Slice of the buffer to be used as the current batch.
    """
@@ -62,12 +62,14 @@ def __len__(self) -> int:
62
62
63
63
64
64
class BaseDenseBatch (BaseBatch [Tensor ]):
65
- def set_buffer_offset (self , buffer : Mapping [ str , np . ndarray ], offset : int ) -> None :
66
- self ._buffer = buffer
65
def read_buffer(self, array: tiledb.Array, buffer_slice: slice) -> None:
    """Read the requested slice of `array` into a dense attribute buffer."""
    # Query only attribute values (dims=() drops the coordinate arrays);
    # the result maps each attribute name to its ndarray for the slice.
    query = array.query(dims=(), attrs=self._attrs)
    self._buffer = query[buffer_slice]
67
67
68
68
def set_batch_slice(self, batch_slice: slice) -> None:
    """Slice each buffered attribute array to form the current batch."""
    assert hasattr(self, "_buffer"), "read_buffer() not called"
    batches = []
    for attr in self._attrs:
        batches.append(self._buffer[attr][batch_slice])
    self._attr_batches = tuple(batches)
71
73
72
74
def iter_tensors (self , perm_idxs : Optional [np .ndarray ] = None ) -> Iterator [Tensor ]:
73
75
assert hasattr (self , "_attr_batches" ), "set_batch_slice() not called"
@@ -103,20 +105,24 @@ def __init__(
103
105
self ._dense_shape = (batch_size , schema .shape [1 ])
104
106
self ._attr_dtypes = tuple (schema .attr (attr ).dtype for attr in self ._attrs )
105
107
106
- def set_buffer_offset (self , buffer : Mapping [str , np .ndarray ], offset : int ) -> None :
108
def read_buffer(self, array: tiledb.Array, buffer_slice: slice) -> None:
    """Read a slice of the sparse array and buffer it as CSR matrices."""
    buffer = array.query(attrs=self._attrs)[buffer_slice]
    # COO to CSR transformation for batching and row slicing.
    row = buffer[self._row_dim]
    col = buffer[self._col_dim]
    # Normalize indices: shift row coordinates so they start at 0 for this
    # buffer. Without the shift the sparse matrix would be created with a
    # [0, max(coord_index)] row dimension, which is overkill.
    offset = buffer_slice.start
    self._buffer_csrs = tuple(
        sp.csr_matrix((buffer[attr], (row - offset, col)))
        for attr in self._attrs
    )
116
120
117
121
def set_batch_slice(self, batch_slice: slice) -> None:
    """Row-slice every buffered CSR matrix to form the current batch."""
    assert hasattr(self, "_buffer_csrs"), "read_buffer() not called"
    sliced = [csr[batch_slice] for csr in self._buffer_csrs]
    self._batch_csrs = tuple(sliced)
120
126
121
127
def iter_tensors (self , perm_idxs : Optional [np .ndarray ] = None ) -> Iterator [Tensor ]:
122
128
assert hasattr (self , "_batch_csrs" ), "set_batch_slice() not called"
@@ -208,15 +214,15 @@ def batch_factory(
208
214
209
215
x_batch = batch_factory (x_array .schema , x_attrs )
210
216
y_batch = batch_factory (y_array .schema , y_attrs )
211
- with ThreadPoolExecutor (max_workers = 2 ) as executor :
217
+ with futures . ThreadPoolExecutor (max_workers = 2 ) as executor :
212
218
for offset in range (start_offset , stop_offset , buffer_size ):
213
- x_buffer , y_buffer = executor .map (
214
- lambda array : array [offset : offset + buffer_size ], # type: ignore
215
- (x_array , y_array ),
219
+ buffer_slice = slice (offset , offset + buffer_size )
220
+ futures .wait (
221
+ (
222
+ executor .submit (x_batch .read_buffer , x_array , buffer_slice ),
223
+ executor .submit (y_batch .read_buffer , y_array , buffer_slice ),
224
+ )
216
225
)
217
- x_batch .set_buffer_offset (x_buffer , offset )
218
- y_batch .set_buffer_offset (y_buffer , offset )
219
-
220
226
# Split the buffer_size into batch_size chunks
221
227
batch_offsets = np .arange (
222
228
0 , min (buffer_size , stop_offset - offset ), batch_size
0 commit comments