@@ -110,33 +110,41 @@ def set_buffer_offset(self, buffer: Mapping[str, np.ndarray], offset: int) -> No
110
110
# Normalize indices: We want the coords indices to be in the [0, batch_size]
111
111
# range. If we do not normalize the sparse tensor is being created but with a
112
112
# dimension [0, max(coord_index)], which is overkill
113
- self ._buffer_csr = sp .csr_matrix ((buffer [self ._attrs [0 ]], (row - offset , col )))
113
+ self ._buffer_csrs = [
114
+ sp .csr_matrix ((buffer [attr ], (row - offset , col ))) for attr in self ._attrs
115
+ ]
114
116
115
117
def set_batch_slice (self , batch_slice : slice ) -> None :
116
- assert hasattr (self , "_buffer_csr " ), "set_buffer_offset() not called"
117
- self ._batch_csr = self . _buffer_csr [ batch_slice ]
118
+ assert hasattr (self , "_buffer_csrs " ), "set_buffer_offset() not called"
119
+ self ._batch_csrs = [ buffer_csr [ batch_slice ] for buffer_csr in self . _buffer_csrs ]
118
120
119
121
def iter_tensors (self , perm_idxs : Optional [np .ndarray ] = None ) -> Iterator [Tensor ]:
120
- assert hasattr (self , "_batch_csr " ), "set_batch_slice() not called"
122
+ assert hasattr (self , "_batch_csrs " ), "set_batch_slice() not called"
121
123
if perm_idxs is not None :
122
124
raise NotImplementedError (
123
125
"within_batch_shuffle not implemented for sparse arrays"
124
126
)
125
- batch_coo = self ._batch_csr . tocoo ()
126
- data = batch_coo . data
127
- coords = np . stack (( batch_coo . row , batch_coo . col ), axis = - 1 )
128
- for dtype in self . _attr_dtypes :
127
+ for batch_csr , dtype in zip ( self ._batch_csrs , self . _attr_dtypes ):
128
+ batch_coo = batch_csr . tocoo ()
129
+ data = batch_coo . data
130
+ coords = np . stack (( batch_coo . row , batch_coo . col ), axis = - 1 )
129
131
yield self ._tensor_from_coo (data , coords , self ._dense_shape , dtype )
130
132
131
133
def __len__ (self ) -> int :
132
- assert hasattr (self , "_batch_csr " ), "set_batch_slice() not called"
134
+ assert hasattr (self , "_batch_csrs " ), "set_batch_slice() not called"
133
135
# return number of non-zero rows
134
- return int ((self ._batch_csr .getnnz (axis = 1 ) > 0 ).sum ())
136
+ lengths = {
137
+ int ((batch_csr .getnnz (axis = 1 ) > 0 ).sum ()) for batch_csr in self ._batch_csrs
138
+ }
139
+ assert len (lengths ) == 1 , f"Multiple different batch lengths: { lengths } "
140
+ return lengths .pop ()
135
141
136
142
def __bool__ (self ) -> bool :
137
- assert hasattr (self , "_batch_csr " ), "set_batch_slice() not called"
143
+ assert hasattr (self , "_batch_csrs " ), "set_batch_slice() not called"
138
144
# faster version of __len__() > 0
139
- return len (self ._batch_csr .data ) > 0
145
+ lengths = {len (batch_csr .data ) for batch_csr in self ._batch_csrs }
146
+ assert len (lengths ) == 1 , f"Multiple different batch lengths: { lengths } "
147
+ return lengths .pop () > 0
140
148
141
149
@staticmethod
142
150
@abstractmethod
0 commit comments