Skip to content

Commit d22dddc

Browse files
kaushikcfdinducer
authored andcommitted
test NumpyArrayContext
1 parent ce8ab7c commit d22dddc

File tree

2 files changed

+25
-0
lines changed

2 files changed

+25
-0
lines changed

arraycontext/pytest.py

+22
Original file line numberDiff line numberDiff line change
@@ -35,6 +35,7 @@
3535
from typing import Any, Callable, Dict, Sequence, Type, Union
3636

3737
from arraycontext.context import ArrayContext
38+
from arraycontext import NumpyArrayContext
3839

3940

4041
# {{{ array context factories
@@ -195,6 +196,26 @@ def __str__(self):
195196
return "<PytatoJAXArrayContext>"
196197

197198

199+
# {{{ _PytestArrayContextFactory
200+
201+
class _NumpyArrayContextForTests(NumpyArrayContext):
202+
def transform_loopy_program(self, t_unit):
203+
return t_unit
204+
205+
206+
class _PytestNumpyArrayContextFactory(PytestArrayContextFactory):
207+
def __init__(self, *args, **kwargs):
208+
super().__init__()
209+
210+
def __call__(self):
211+
return _NumpyArrayContextForTests()
212+
213+
def __str__(self):
214+
return "<NumpyArrayContext>"
215+
216+
# }}}
217+
218+
198219
_ARRAY_CONTEXT_FACTORY_REGISTRY: \
199220
Dict[str, Type[PytestArrayContextFactory]] = {
200221
"pyopencl": _PytestPyOpenCLArrayContextFactoryWithClass,
@@ -203,6 +224,7 @@ def __str__(self):
203224
"pytato:pyopencl": _PytestPytatoPyOpenCLArrayContextFactory,
204225
"pytato:jax": _PytestPytatoJaxArrayContextFactory,
205226
"eagerjax": _PytestEagerJaxArrayContextFactory,
227+
"numpy": _PytestNumpyArrayContextFactory,
206228
}
207229

208230

test/test_arraycontext.py

+3
Original file line numberDiff line numberDiff line change
@@ -45,6 +45,8 @@
4545
_PytestPytatoPyOpenCLArrayContextFactory,
4646
_PytestEagerJaxArrayContextFactory,
4747
_PytestPytatoJaxArrayContextFactory)
48+
_PytestPytatoPyOpenCLArrayContextFactory,
49+
_PytestNumpyArrayContextFactory)
4850

4951

5052
import logging
@@ -93,6 +95,7 @@ class _PytatoPyOpenCLArrayContextForTestsFactory(
9395
_PytatoPyOpenCLArrayContextForTestsFactory,
9496
_PytestEagerJaxArrayContextFactory,
9597
_PytestPytatoJaxArrayContextFactory,
98+
_PytestNumpyArrayContextFactory,
9699
])
97100

98101

0 commit comments

Comments
 (0)