2
2
from typing import Generic , Iterator , Sequence , Type , TypeVar , Union
3
3
4
4
import numpy as np
5
- import scipy . sparse as sp
5
+ import sparse
6
6
7
7
import tiledb
8
8
11
11
Tensor = TypeVar ("Tensor" )
12
12
13
13
14
- class TileDBTensorGenerator ( ABC , Generic [ Tensor ]) :
14
+ class TileDBNumpyGenerator :
15
15
"""Base class for generating tensors read from a TileDB array."""
16
16
17
17
def __init__ (self , array : tiledb .Array , attrs : Sequence [str ]) -> None :
@@ -21,83 +21,50 @@ def __init__(self, array: tiledb.Array, attrs: Sequence[str]) -> None:
21
21
"""
22
22
self ._query = array .query (attrs = attrs )
23
23
24
- @abstractmethod
25
24
def read_buffer (self , array_slice : slice ) -> None :
26
25
"""
27
26
Read an array slice and save it as the current buffer.
28
27
29
28
:param array_slice: Requested array slice.
30
29
"""
30
+ self ._buf_arrays = tuple (self ._query [array_slice ].values ())
31
31
32
- @abstractmethod
33
- def iter_tensors (self , buffer_slice : slice ) -> Iterator [Tensor ]:
32
+ def iter_tensors (self , buffer_slice : slice ) -> Iterator [np .ndarray ]:
34
33
"""
35
34
Return an iterator of tensors for the given slice, one tensor per attribute
36
35
37
36
Must be called after `read_buffer`.
38
37
39
38
:param buffer_slice: Slice of the current buffer to convert to tensors.
40
39
"""
40
+ return (buf_array [buffer_slice ] for buf_array in self ._buf_arrays )
41
41
42
42
43
- class TileDBNumpyGenerator (TileDBTensorGenerator [np .ndarray ]):
44
- def read_buffer (self , array_slice : slice ) -> None :
45
- self ._buf_arrays = tuple (self ._query [array_slice ].values ())
46
-
47
- def iter_tensors (self , buffer_slice : slice ) -> Iterator [np .ndarray ]:
48
- for buf_array in self ._buf_arrays :
49
- yield buf_array [buffer_slice ]
50
-
51
-
52
- class TileDBSparseTensorGenerator (TileDBTensorGenerator [Tensor ]):
43
+ class TileDBSparseTensorGenerator (TileDBNumpyGenerator , ABC , Generic [Tensor ]):
53
44
def __init__ (self , array : tiledb .Array , attrs : Sequence [str ]) -> None :
54
- schema = array .schema
55
- if schema .ndim != 2 :
56
- raise NotImplementedError ("Only 2D sparse tensors are currently supported" )
57
- self ._row_dim = schema .domain .dim (0 ).name
58
- self ._col_dim = schema .domain .dim (1 ).name
59
- self ._row_shape = schema .shape [1 :]
60
- self ._attr_dtypes = tuple (schema .attr (attr ).dtype for attr in attrs )
45
+ self ._dims = tuple (array .domain .dim (i ).name for i in range (array .ndim ))
46
+ self ._row_shape = array .shape [1 :]
61
47
super ().__init__ (array , attrs )
62
48
63
49
def read_buffer (self , array_slice : slice ) -> None :
64
50
buffer = self ._query [array_slice ]
65
- # COO to CSR transformation for batching and row slicing
66
- row = buffer .pop (self ._row_dim )
67
- col = buffer .pop (self ._col_dim )
68
- # Normalize indices: We want the coords indices to be in the [0, array_slice size]
69
- # range. If we do not normalize the sparse tensor is being created but with a
70
- # dimension [0, max(coord_index)], which is overkill
51
+ coords = [buffer .pop (dim ) for dim in self ._dims ]
52
+ # normalize the first coordinate dimension to start at start_offset
71
53
start_offset = array_slice .start
72
- stop_offset = array_slice . stop
73
- shape = ( stop_offset - start_offset , * self . _row_shape )
74
- self . _buf_csrs = tuple (
75
- sp . csr_matrix (( data , ( row - start_offset , col )), shape = shape )
76
- for data in buffer .values ()
54
+ if start_offset :
55
+ coords [ 0 ] -= start_offset
56
+ shape = ( array_slice . stop - start_offset , * self . _row_shape )
57
+ self . _buf_arrays = tuple (
58
+ sparse . COO ( coords , data , shape ) for data in buffer .values ()
77
59
)
78
60
79
61
def iter_tensors (self , buffer_slice : slice ) -> Iterator [Tensor ]:
80
- for buf_csr , dtype in zip (self ._buf_csrs , self ._attr_dtypes ):
81
- batch_csr = buf_csr [buffer_slice ]
82
- batch_coo = batch_csr .tocoo ()
83
- data = batch_coo .data
84
- coords = np .stack ((batch_coo .row , batch_coo .col ), axis = - 1 )
85
- dense_shape = (batch_csr .shape [0 ], * self ._row_shape )
86
- yield self ._tensor_from_coo (data , coords , dense_shape , dtype )
62
+ return map (self ._tensor_from_coo , super ().iter_tensors (buffer_slice ))
87
63
88
64
@staticmethod
89
65
@abstractmethod
90
- def _tensor_from_coo (
91
- data : np .ndarray ,
92
- coords : np .ndarray ,
93
- dense_shape : Sequence [int ],
94
- dtype : np .dtype ,
95
- ) -> Tensor :
96
- """Convert a scipy.sparse.coo_matrix to a Tensor"""
97
-
98
-
99
- DT = TypeVar ("DT" )
100
- ST = TypeVar ("ST" )
66
+ def _tensor_from_coo (coo : sparse .COO ) -> Tensor :
67
+ """Convert a sparse.COO to a Tensor"""
101
68
102
69
103
70
def tensor_generator (
@@ -107,11 +74,10 @@ def tensor_generator(
107
74
y_buffer_size : int ,
108
75
x_attrs : Sequence [str ],
109
76
y_attrs : Sequence [str ],
77
+ sparse_generator_cls : Type [TileDBSparseTensorGenerator [Tensor ]],
110
78
start_offset : int = 0 ,
111
79
stop_offset : int = 0 ,
112
- dense_generator_cls : Type [TileDBTensorGenerator [DT ]] = TileDBNumpyGenerator ,
113
- sparse_generator_cls : Type [TileDBTensorGenerator [ST ]] = TileDBSparseTensorGenerator ,
114
- ) -> Iterator [Sequence [Union [DT , ST ]]]:
80
+ ) -> Iterator [Sequence [Union [np .ndarray , Tensor ]]]:
115
81
"""
116
82
Generator for batches of tensors.
117
83
@@ -126,18 +92,17 @@ def tensor_generator(
126
92
:param y_attrs: Attribute names of y_array.
127
93
:param start_offset: Start row offset; defaults to 0.
128
94
:param stop_offset: Stop row offset; defaults to number of rows.
129
- :param dense_generator_cls: Dense tensor generator type.
130
95
:param sparse_generator_cls: Sparse tensor generator type.
131
96
"""
132
- x_gen : Union [TileDBTensorGenerator [ DT ], TileDBTensorGenerator [ ST ]] = (
97
+ x_gen : Union [TileDBNumpyGenerator , TileDBSparseTensorGenerator [ Tensor ]] = (
133
98
sparse_generator_cls (x_array , x_attrs )
134
99
if x_array .schema .sparse
135
- else dense_generator_cls (x_array , x_attrs )
100
+ else TileDBNumpyGenerator (x_array , x_attrs )
136
101
)
137
- y_gen : Union [TileDBTensorGenerator [ DT ], TileDBTensorGenerator [ ST ]] = (
102
+ y_gen : Union [TileDBNumpyGenerator , TileDBSparseTensorGenerator [ Tensor ]] = (
138
103
sparse_generator_cls (y_array , y_attrs )
139
104
if y_array .schema .sparse
140
- else dense_generator_cls (y_array , y_attrs )
105
+ else TileDBNumpyGenerator (y_array , y_attrs )
141
106
)
142
107
if not stop_offset :
143
108
stop_offset = x_array .shape [0 ]
0 commit comments