Skip to content

Commit 1532ec8

Browse files
committed
add tests for cubed
1 parent 003ffaf commit 1532ec8

File tree

1 file changed

+77
-0
lines changed

1 file changed

+77
-0
lines changed
Lines changed: 77 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,77 @@
1+
from typing import ContextManager
2+
from contextlib import nullcontext
3+
4+
import pytest
5+
import hypothesis.strategies as st
6+
from hypothesis import note
7+
import numpy as np
8+
import numpy.testing as npt
9+
10+
from xarray_array_testing.base import DuckArrayTestMixin
11+
from xarray_array_testing.creation import CreationTests
12+
from xarray_array_testing.reduction import ReductionTests
13+
14+
import cubed
15+
import cubed.random
16+
17+
18+
def cubed_random_array(shape: tuple[int], dtype: np.dtype) -> cubed.Array:
19+
"""
20+
Generates a random cubed array
21+
22+
Supports integer and float dtypes.
23+
"""
24+
# TODO hypothesis doesn't like us using random inside strategies
25+
rng = np.random.default_rng()
26+
27+
if np.issubdtype(dtype, np.integer):
28+
arr = rng.integers(low=0, high=+3, size=shape, dtype=dtype)
29+
return cubed.from_array(arr)
30+
else:
31+
# TODO generate general chunking pattern
32+
ca = cubed.random.random(size=shape, chunks=shape)
33+
return cubed.array_api.astype(ca, dtype)
34+
35+
36+
def random_cubed_arrays_fn(
37+
*, shape: tuple[int, ...], dtype: np.dtype,
38+
) -> st.SearchStrategy[cubed.Array]:
39+
return st.builds(cubed_random_array, shape=st.just(shape), dtype=st.just(dtype))
40+
41+
42+
class CubedTestMixin(DuckArrayTestMixin):
43+
@property
44+
def xp(self) -> type[cubed.array_api]:
45+
return cubed.array_api
46+
47+
@property
48+
def array_type(self) -> type[cubed.Array]:
49+
return cubed.Array
50+
51+
@staticmethod
52+
def array_strategy_fn(*, shape, dtype) -> st.SearchStrategy[cubed.Array]:
53+
return random_cubed_arrays_fn(shape=shape, dtype=dtype)
54+
55+
@staticmethod
56+
def assert_equal(a: cubed.Array, b: cubed.Array):
57+
npt.assert_equal(a.compute(), b.compute())
58+
59+
60+
61+
class TestCreationCubed(CreationTests, CubedTestMixin):
62+
pass
63+
64+
65+
class TestReductionCubed(ReductionTests, CubedTestMixin):
66+
@staticmethod
67+
def expected_errors(op, **parameters) -> ContextManager:
68+
var = parameters.get('variable')
69+
70+
note(f"op = {op}")
71+
note(f"dtype = {var.dtype}")
72+
note(f"is_integer = {cubed.array_api.isdtype(var.dtype, 'integral')}")
73+
74+
if op == 'mean' and cubed.array_api.isdtype(var.dtype, "integral") or var.dtype == np.dtype('float16'):
75+
return pytest.raises(TypeError, match='Only real floating-point dtypes are allowed in mean')
76+
else:
77+
return nullcontext()

0 commit comments

Comments
 (0)