diff --git a/xarray_array_testing/creation.py b/xarray_array_testing/creation.py index 291e082..3696318 100644 --- a/xarray_array_testing/creation.py +++ b/xarray_array_testing/creation.py @@ -1,13 +1,12 @@ -import hypothesis.strategies as st +from functools import partial + import xarray.testing.strategies as xrst -from hypothesis import given from xarray_array_testing.base import DuckArrayTestMixin +from xarray_array_testing.decorator import delayed_given class CreationTests(DuckArrayTestMixin): - @given(st.data()) - def test_create_variable(self, data): - variable = data.draw(xrst.variables(array_strategy_fn=self.array_strategy_fn)) - + @delayed_given(partial(xrst.variables)) + def test_create_variable(self, variable): assert isinstance(variable.data, self.array_type) diff --git a/xarray_array_testing/decorator.py b/xarray_array_testing/decorator.py new file mode 100644 index 0000000..3e6babf --- /dev/null +++ b/xarray_array_testing/decorator.py @@ -0,0 +1,51 @@ +from functools import partial + +from hypothesis import given + + +def instantiate_given(params, **kwargs): + def maybe_apply_kwargs(param, **kwargs): + if not isinstance(param, partial): + return param + else: + return param(**kwargs) + + given_args, given_kwargs = params + instantiated_args = tuple( + maybe_apply_kwargs(param, **kwargs) for param in given_args + ) + instantiated_kwargs = { + name: maybe_apply_kwargs(param, **kwargs) + for name, param in given_kwargs.items() + } + + return instantiated_args, instantiated_kwargs + + +def initialize_tests(cls): + for name in dir(cls): + if not name.startswith("test_"): + continue + + method = getattr(cls, name) + + if not hasattr(method, "__hypothesis_given__"): + continue + params = method.__hypothesis_given__ + args, kwargs = instantiate_given( + params, array_strategy_fn=cls.array_strategy_fn + ) + decorated = given(*args, **kwargs)(method) + + setattr(cls, name, decorated) + + return cls + + +def delayed_given(*_given_args, **_given_kwargs): + def wrapper(f): + f.__hypothesis_given__ = (_given_args, _given_kwargs) + + return f + + return wrapper diff --git a/xarray_array_testing/reduction.py b/xarray_array_testing/reduction.py index 54af662..6d5574d 100644 --- a/xarray_array_testing/reduction.py +++ b/xarray_array_testing/reduction.py @@ -1,11 +1,11 @@ from contextlib import nullcontext +from functools import partial -import hypothesis.strategies as st import pytest import xarray.testing.strategies as xrst -from hypothesis import given from xarray_array_testing.base import DuckArrayTestMixin +from xarray_array_testing.decorator import delayed_given class ReductionTests(DuckArrayTestMixin): @@ -14,10 +14,8 @@ def expected_errors(op, **parameters): return nullcontext() @pytest.mark.parametrize("op", ["mean", "sum", "prod", "std", "var"]) - @given(st.data()) - def test_variable_mean(self, op, data): - variable = data.draw(xrst.variables(array_strategy_fn=self.array_strategy_fn)) - + @delayed_given(partial(xrst.variables)) + def test_variable(self, op, variable): with self.expected_errors(op, variable=variable): # compute using xr.Variable.() actual = getattr(variable, op)().data diff --git a/xarray_array_testing/tests/test_numpy.py b/xarray_array_testing/tests/test_numpy.py index 2a9d95b..5541674 100644 --- a/xarray_array_testing/tests/test_numpy.py +++ b/xarray_array_testing/tests/test_numpy.py @@ -5,6 +5,7 @@ from xarray_array_testing.base import DuckArrayTestMixin from xarray_array_testing.creation import CreationTests +from xarray_array_testing.decorator import initialize_tests from xarray_array_testing.reduction import ReductionTests @@ -26,9 +27,11 @@ def array_strategy_fn(*, shape, dtype): return create_numpy_array(shape=shape, dtype=dtype) +@initialize_tests class TestCreationNumpy(CreationTests, NumpyTestMixin): pass +@initialize_tests class TestReductionNumpy(ReductionTests, NumpyTestMixin): pass