from __future__ import annotations

from abc import ABC, abstractmethod
-from concurrent.futures import ThreadPoolExecutor
-from typing import Generic, Iterator, Mapping, Optional, Sequence, Type, TypeVar, Union
+from concurrent import futures
+from typing import Generic, Iterator, Optional, Sequence, Type, TypeVar, Union

import numpy as np
import scipy.sparse as sp
@@ -25,18 +25,18 @@ def __init__(
        self._batch_size = batch_size

    @abstractmethod
-    def set_buffer_offset(self, buffer: Mapping[str, np.ndarray], offset: int) -> None:
-        """Set the current buffer from which subsequent batches are to be read.
+    def read_buffer(self, array: tiledb.Array, buffer_slice: slice) -> None:
+        """Read a slice from a TileDB array into a buffer.

-        :param buffer: Mapping of attribute names to numpy arrays.
-        :param offset: Start offset of the buffer in the TileDB array.
+        :param array: TileDB array to read from.
+        :param buffer_slice: Slice of the array to read.
        """

    @abstractmethod
    def set_batch_slice(self, batch_slice: slice) -> None:
-        """Set the current batch as a slice of the set buffer.
+        """Set the current batch as a slice of the read buffer.

-        Must be called after `set_buffer_offset`.
+        Must be called after `read_buffer`.

        :param batch_slice: Slice of the buffer to be used as the current batch.
        """
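
Taken together, the new abstract API keeps the read, slice, and convert steps separate: `read_buffer` pulls one buffer-sized slice out of the TileDB array, `set_batch_slice` carves the current batch out of that in-memory buffer, and `iter_tensors` yields the per-attribute tensors. A minimal usage sketch, assuming a concrete `batch` instance and illustrative `offset`, `buffer_size`, and `batch_size` values (these names are not part of this diff):

    batch.read_buffer(array, slice(offset, offset + buffer_size))  # one TileDB read per buffer
    for batch_offset in range(0, buffer_size, batch_size):
        batch.set_batch_slice(slice(batch_offset, batch_offset + batch_size))
        for tensor in batch.iter_tensors():
            ...  # hand the tensor to the training loop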
@@ -62,12 +62,14 @@ def __len__(self) -> int:


class BaseDenseBatch(BaseBatch[Tensor]):
-    def set_buffer_offset(self, buffer: Mapping[str, np.ndarray], offset: int) -> None:
-        self._buffer = buffer
+    def read_buffer(self, array: tiledb.Array, buffer_slice: slice) -> None:
+        self._buffer = array.query(dims=(), attrs=self._attrs)[buffer_slice]

    def set_batch_slice(self, batch_slice: slice) -> None:
-        assert hasattr(self, "_buffer"), "set_buffer_offset() not called"
-        self._attr_batches = [self._buffer[attr][batch_slice] for attr in self._attrs]
+        assert hasattr(self, "_buffer"), "read_buffer() not called"
+        self._attr_batches = tuple(
+            self._buffer[attr][batch_slice] for attr in self._attrs
+        )

    def iter_tensors(self, perm_idxs: Optional[np.ndarray] = None) -> Iterator[Tensor]:
        assert hasattr(self, "_attr_batches"), "set_batch_slice() not called"
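
The dense implementation now issues the TileDB query itself, selecting only the configured attributes and no dimension coordinates (`dims=()`); slicing that query yields a dict-like mapping from attribute name to NumPy array, which is exactly what `set_batch_slice` indexes into. Roughly, with a hypothetical attribute name "features":

    buf = x_array.query(dims=(), attrs=("features",))[0:buffer_size]
    first_batch = buf["features"][0:batch_size]  # what set_batch_slice(slice(0, batch_size)) keeps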
@@ -103,20 +105,24 @@ def __init__(
        self._dense_shape = (batch_size, schema.shape[1])
        self._attr_dtypes = tuple(schema.attr(attr).dtype for attr in self._attrs)

-    def set_buffer_offset(self, buffer: Mapping[str, np.ndarray], offset: int) -> None:
+    def read_buffer(self, array: tiledb.Array, buffer_slice: slice) -> None:
+        buffer = array.query(attrs=self._attrs)[buffer_slice]
        # COO to CSR transformation for batching and row slicing
        row = buffer[self._row_dim]
        col = buffer[self._col_dim]
        # Normalize indices: we want the coord indices to be in the [0, batch_size]
        # range. If we do not normalize, the sparse tensor is still created, but with a
        # dimension of [0, max(coord_index)], which is overkill.
-        self._buffer_csrs = [
+        offset = buffer_slice.start
+        self._buffer_csrs = tuple(
            sp.csr_matrix((buffer[attr], (row - offset, col))) for attr in self._attrs
-        ]
+        )

    def set_batch_slice(self, batch_slice: slice) -> None:
-        assert hasattr(self, "_buffer_csrs"), "set_buffer_offset() not called"
-        self._batch_csrs = [buffer_csr[batch_slice] for buffer_csr in self._buffer_csrs]
+        assert hasattr(self, "_buffer_csrs"), "read_buffer() not called"
+        self._batch_csrs = tuple(
+            buffer_csr[batch_slice] for buffer_csr in self._buffer_csrs
+        )

    def iter_tensors(self, perm_idxs: Optional[np.ndarray] = None) -> Iterator[Tensor]:
        assert hasattr(self, "_batch_csrs"), "set_batch_slice() not called"
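
The `offset` normalization is what keeps the per-buffer CSR matrices small: the row coordinates returned by the array are global, and `scipy.sparse.csr_matrix` sizes its row dimension from the largest index it sees. A toy illustration with made-up coordinates for a buffer read at offset 1000:

    import numpy as np
    import scipy.sparse as sp

    row = np.array([1000, 1001, 1002])  # global row coordinates
    col = np.array([0, 2, 1])
    data = np.array([1.0, 2.0, 3.0])
    sp.csr_matrix((data, (row, col))).shape          # (1003, 3): wastefully tall
    sp.csr_matrix((data, (row - 1000, col))).shape   # (3, 3): rows relative to the buffer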
@@ -208,15 +214,15 @@ def batch_factory(

    x_batch = batch_factory(x_array.schema, x_attrs)
    y_batch = batch_factory(y_array.schema, y_attrs)
-    with ThreadPoolExecutor(max_workers=2) as executor:
+    with futures.ThreadPoolExecutor(max_workers=2) as executor:
        for offset in range(start_offset, stop_offset, buffer_size):
-            x_buffer, y_buffer = executor.map(
-                lambda array: array[offset : offset + buffer_size],  # type: ignore
-                (x_array, y_array),
+            buffer_slice = slice(offset, offset + buffer_size)
+            futures.wait(
+                (
+                    executor.submit(x_batch.read_buffer, x_array, buffer_slice),
+                    executor.submit(y_batch.read_buffer, y_array, buffer_slice),
+                )
            )
-            x_batch.set_buffer_offset(x_buffer, offset)
-            y_batch.set_buffer_offset(y_buffer, offset)
-
            # Split the buffer_size into batch_size chunks
            batch_offsets = np.arange(
                0, min(buffer_size, stop_offset - offset), batch_size
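
With the reads moved into the batch objects, each outer iteration builds one `buffer_slice`, submits both `read_buffer` calls to the two-worker pool, and blocks on `futures.wait` until the x and y buffers are in memory before splitting the buffer into `batch_size` chunks. For example, with hypothetical values `buffer_size=100` and `batch_size=32`:

    np.arange(0, min(100, 100), 32)  # array([ 0, 32, 64, 96]): four batches per full buffer
    np.arange(0, min(100, 40), 32)   # array([ 0, 32]): a final buffer with only 40 rows left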