Skip to content

Commit 3e0fee5

Browse files
committed
validate as a csr matrix
1 parent 9a50873 commit 3e0fee5

File tree

1 file changed

+19
-12
lines changed

1 file changed

+19
-12
lines changed

pydiso/mkl_solver.py

Lines changed: 19 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -143,14 +143,16 @@ def __init__(self, A, matrix_type=None, factor=True, verbose=False):
143143

144144
self.matrix_type = matrix_type
145145

146-
indptr = np.asarray(A.indptr) # double check it's a numpy array
146+
A = self._validate_csr_matrix(A)
147+
148+
max_a_ind_itemsize = max(A.indptr.itemsize, A.indices.itemsize)
147149
mkl_int_size = get_mkl_int_size()
148150
mkl_int64_size = get_mkl_int64_size()
149151

150-
target_int_size = mkl_int_size if indptr.itemsize <= mkl_int_size else mkl_int64_size
152+
target_int_size = mkl_int_size if max_a_ind_itemsize <= mkl_int_size else mkl_int64_size
151153
self._ind_dtype = np.dtype(f"i{target_int_size}")
152154

153-
data, indptr, indices = self._validate_matrix(A)
155+
data, indptr, indices = self._validate_matrix_dtypes(A)
154156
self._data = data
155157
self._indptr = indptr
156158
self._indices = indices
@@ -185,7 +187,9 @@ def refactor(self, A):
185187
raise TypeError("A is not a sparse matrix.")
186188
if A.shape != self.shape:
187189
raise ValueError("A is not the same size as the previous matrix.")
188-
data, indptr, indices = self._validate_matrix(A)
190+
191+
A = self._validate_csr_matrix(A)
192+
data, indptr, indices = self._validate_matrix_dtypes(A)
189193
if len(data) != len(self._data):
190194
raise ValueError("new A matrix does not have the same number of non zeros.")
191195

@@ -284,21 +288,24 @@ def iparm(self):
284288
"""
285289
return np.array(self._handle.iparm)
286290

287-
def _validate_matrix(self, mat):
288-
291+
def _validate_csr_matrix(self, mat):
289292
if self.matrix_type in [-2, 2, -4, 4, 6]:
290-
# Symmetric matrices must have only the upper triangle
291-
if sp.isspmatrix_csc(mat):
292-
mat = mat.T # Transpose to get a CSR matrix since it's symmetric
293+
# only grab the upper triangle.
293294
mat = sp.triu(mat, format='csr')
294295

295-
if not (sp.isspmatrix_csr(mat)):
296-
warnings.warn("Converting %s matrix to CSR format."
297-
% mat.__class__.__name__, PardisoTypeConversionWarning)
296+
if mat.format != 'csr':
297+
warnings.warn(
298+
"Converting %s matrix to CSR format."% A.__class__.__name__,
299+
PardisoTypeConversionWarning,
300+
stacklevel=3
301+
)
298302
mat = mat.tocsr()
303+
299304
mat.sort_indices()
300305
mat.sum_duplicates()
306+
return mat
301307

308+
def _validate_matrix_dtypes(self, mat):
302309
data = np.require(mat.data, self._data_dtype, requirements="C")
303310
indptr = np.require(mat.indptr, self._ind_dtype, requirements="C")
304311
indices = np.require(mat.indices, self._ind_dtype, requirements="C")

0 commit comments

Comments
 (0)