Skip to content

Commit ae8e04f

Browse files
committed
Dockerfile
1 parent e2913e3 commit ae8e04f

File tree

10 files changed

+43
-19
lines changed

10 files changed

+43
-19
lines changed

Dockerfile.gpu

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,7 @@
1+
FROM python:3.12-slim
2+
3+
ADD . .
4+
5+
RUN pip install ".[cuda-12]"
6+
7+
CMD ["python"]

pyproject.toml

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -7,7 +7,7 @@ name = "sparse"
77
dynamic = ["version"]
88
description = "Sparse n-dimensional arrays for the PyData ecosystem"
99
readme = "README.md"
10-
dependencies = ["numpy>=1.17", "numba>=0.49"]
10+
dependencies = ["numpy>=1.17", "numba>=0.49", "array_api_compat>=1.11"]
1111
maintainers = [{ name = "Hameer Abbasi", email = "[email protected]" }]
1212
requires-python = ">=3.10"
1313
license = { file = "LICENSE" }
@@ -51,6 +51,8 @@ tests = [
5151
"pre-commit",
5252
"pytest-codspeed",
5353
]
54+
cuda-12 = ["cupy-cuda12x"]
55+
cuda-11 = ["cupy-cuda11x"]
5456
tox = ["sparse[tests]", "tox"]
5557
notebooks = ["sparse[tests]", "nbmake", "matplotlib"]
5658
all = ["sparse[docs,tox,notebooks,mlir]", "matrepr"]

sparse/numba_backend/_common.py

Lines changed: 5 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -262,7 +262,7 @@ def matmul(a, b):
262262

263263
from ._settings import NUMPY_DEVICE
264264

265-
if a.device != NUMPY_DEVICE or b.device != NUMPY_DEVICE:
265+
if getattr(a, "device", NUMPY_DEVICE) != NUMPY_DEVICE or getattr(b, "device", NUMPY_DEVICE) != NUMPY_DEVICE:
266266
import cupyx.scipy.sparse as cps
267267

268268
if isinstance(a, COO):
@@ -2074,7 +2074,10 @@ def pad(array, pad_width, mode="constant", **kwargs):
20742074
if mode.lower() != "constant":
20752075
raise NotImplementedError(f"Mode '{mode}' is not yet supported.")
20762076

2077-
if not equivalent(kwargs.pop("constant_values", _zero_of_dtype(array.dtype)), array.fill_value):
2077+
if not equivalent(
2078+
array._component_namespace.asarray(kwargs.pop("constant_values", _zero_of_dtype(array.dtype, array.device))),
2079+
array.fill_value,
2080+
):
20782081
raise ValueError("constant_values can only be equal to fill value.")
20792082

20802083
if kwargs:

sparse/numba_backend/_compressed/compressed.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -184,6 +184,7 @@ def __init__(
184184

185185
self._compressed_axes = tuple(compressed_axes) if isinstance(compressed_axes, Iterable) else None
186186
self.fill_value = self.data.dtype.type(fill_value)
187+
self.data, self.indices, self.indptr = np.asarray(self.data), np.asarray(self.indices), np.asarray(self.indptr)
187188

188189
if prune:
189190
self._prune()

sparse/numba_backend/_coo/core.py

Lines changed: 5 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -209,6 +209,8 @@ def __init__(
209209
fill_value=None,
210210
idx_dtype=None,
211211
):
212+
import array_api_compat
213+
212214
from .._common import _coerce_to_supported_dense
213215

214216
if isinstance(coords, COO):
@@ -232,6 +234,7 @@ def __init__(
232234

233235
self.data = _coerce_to_supported_dense(data)
234236
self.coords = _coerce_to_supported_dense(coords)
237+
xp = array_api_compat.get_namespace(self.data, self.coords)
235238

236239
if self.coords.ndim == 1:
237240
if self.coords.size == 0 and shape is not None:
@@ -240,7 +243,7 @@ def __init__(
240243
self.coords = self.coords[None, :]
241244

242245
if self.data.ndim == 0:
243-
self.data = self._component_namespace.broadcast_to(self.data, self.coords.shape[1])
246+
self.data = xp.broadcast_to(self.data, self.coords.shape[1])
244247

245248
if self.data.ndim != 1:
246249
raise ValueError("`data` must be a scalar or 1-dimensional.")
@@ -255,9 +258,7 @@ def __init__(
255258
shape = tuple(shape)
256259

257260
if shape and not self.coords.size:
258-
self.coords = self._component_namespace.zeros(
259-
(len(shape) if isinstance(shape, Iterable) else 1, 0), dtype=np.intp
260-
)
261+
self.coords = xp.zeros((len(shape) if isinstance(shape, Iterable) else 1, 0), dtype=np.intp)
261262
super().__init__(shape, fill_value=fill_value)
262263
if idx_dtype:
263264
if not can_store(idx_dtype, max(shape)):

sparse/numba_backend/_coo/indexing.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -118,7 +118,7 @@ def getitem(x, index):
118118
if n != 0:
119119
return x.data[mask][0]
120120

121-
return x.fill_value
121+
return x.fill_value[()]
122122

123123
shape = tuple(shape)
124124
data = x.data[mask]

sparse/numba_backend/_sparse_array.py

Lines changed: 6 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -50,11 +50,14 @@ def __init__(self, shape, fill_value=None):
5050

5151
self.fill_value = _zero_of_dtype(self.dtype, getattr(getattr(self, "data", None), "device", NUMPY_DEVICE))
5252
data = getattr(self, "data", None)
53-
if data is not None:
53+
if data is not None and not isinstance(data, dict):
5454
import array_api_compat
5555

56-
self.fill_value = array_api_compat.array_namespace(data).asarray(self.fill_value)
56+
xp = array_api_compat.array_namespace(data)
57+
else:
58+
xp = np
5759

60+
self.fill_value = xp.asarray(self.fill_value)
5861
self.device # noqa: B018
5962

6063
dtype = None
@@ -402,7 +405,7 @@ def _gpu_ufunc(self, ufunc, method, *inputs, **kwargs):
402405
def __array_ufunc__(self, ufunc, method, *inputs, **kwargs):
403406
from ._settings import NUMPY_DEVICE
404407

405-
if not all(i.device == NUMPY_DEVICE for i in inputs):
408+
if not all(getattr(i, "device", NUMPY_DEVICE) == NUMPY_DEVICE for i in inputs):
406409
return self._gpu_ufunc(ufunc, method, *inputs, **kwargs)
407410
out = kwargs.pop("out", None)
408411
if out is not None and not all(isinstance(x, type(self)) for x in out):

sparse/numba_backend/_umath.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -514,8 +514,9 @@ def _get_fill_value(self):
514514
import array_api_compat
515515

516516
from ._coo import COO
517+
from ._sparse_array import SparseArray
517518

518-
xp = array_api_compat.array_namespace(*(a.data for a in self.args))
519+
xp = array_api_compat.array_namespace(*(a.data if isinstance(a, SparseArray) else a for a in self.args))
519520

520521
def get_zero_arg(x):
521522
if isinstance(x, COO):

sparse/numba_backend/_utils.py

Lines changed: 9 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -73,7 +73,11 @@ def assert_gcxs_slicing(s, x):
7373

7474

7575
def assert_nnz(s, x):
76-
fill_value = s.fill_value if hasattr(s, "fill_value") else _zero_of_dtype(s.dtype, s.device)
76+
from ._settings import NUMPY_DEVICE
77+
78+
fill_value = (
79+
s.fill_value if hasattr(s, "fill_value") else _zero_of_dtype(s.dtype, getattr(s, "device", NUMPY_DEVICE))
80+
)
7781
assert np.sum(~equivalent(x, fill_value)) == s.nnz
7882

7983

@@ -442,7 +446,7 @@ def equivalent(x, y, /, loose=False):
442446

443447
from ._common import _coerce_to_supported_dense
444448

445-
namespace = array_api_compat.array_namespace(x, y)
449+
xp = array_api_compat.array_namespace(x, y)
446450
x = _coerce_to_supported_dense(x)
447451
y = _coerce_to_supported_dense(y)
448452
# Can't contain NaNs
@@ -458,9 +462,9 @@ def equivalent(x, y, /, loose=False):
458462
return (x == y) | ((x != x) & (y != y))
459463

460464
if x.size == 0 or y.size == 0:
461-
shape = namespace.broadcast_shapes(x.shape, y.shape)
462-
return namespace.empty(shape, dtype=np.bool_)
463-
x, y = namespace.broadcast_arrays(x[..., None], y[..., None])
465+
shape = xp.broadcast_shapes(x.shape, y.shape)
466+
return xp.empty(shape, dtype=np.bool_)
467+
x, y = xp.broadcast_arrays(x[..., None], y[..., None])
464468
return (x.astype(dt).view(np.uint8) == y.astype(dt).view(np.uint8)).all(axis=-1)
465469

466470

sparse/numba_backend/tests/test_compressed.py

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -181,8 +181,10 @@ def test_tranpose(a, b):
181181
@pytest.mark.parametrize("format", [sparse.COO, sparse._compressed.CSR])
182182
def test_to_scipy_sparse(fill_value_in, fill_value_out, format):
183183
s = sparse.random((3, 5), density=0.5, format=format, fill_value=fill_value_in)
184-
185-
if not ((fill_value_in in {0, None} and fill_value_out in {0, None}) or equivalent(fill_value_in, fill_value_out)):
184+
if not (
185+
(fill_value_in in {0, None} and fill_value_out in {0, None})
186+
or equivalent(np.asarray(fill_value_in), np.asarray(fill_value_out))
187+
):
186188
with pytest.raises(ValueError, match=r"fill_value=.* but should be in .*\."):
187189
s.to_scipy_sparse(accept_fv=fill_value_out)
188190
return

0 commit comments

Comments
 (0)