Skip to content

Commit 09f9001

Browse files
kaushikcfdinducer
authored andcommitted
test NumpyArrayContext
1 parent d8dfe02 commit 09f9001

File tree

2 files changed

+24
-0
lines changed

2 files changed

+24
-0
lines changed

Diff for: arraycontext/pytest.py

+22
Original file line numberDiff line numberDiff line change
@@ -34,6 +34,7 @@
3434

3535
from typing import Any, Callable, Dict, Sequence, Type, Union
3636

37+
from arraycontext import NumpyArrayContext
3738
from arraycontext.context import ArrayContext
3839

3940

@@ -222,6 +223,26 @@ def __str__(self):
222223
return "<PytatoJAXArrayContext>"
223224

224225

226+
# {{{ _PytestArrayContextFactory
227+
228+
class _NumpyArrayContextForTests(NumpyArrayContext):
229+
def transform_loopy_program(self, t_unit):
230+
return t_unit
231+
232+
233+
class _PytestNumpyArrayContextFactory(PytestArrayContextFactory):
234+
def __init__(self, *args, **kwargs):
235+
super().__init__()
236+
237+
def __call__(self):
238+
return _NumpyArrayContextForTests()
239+
240+
def __str__(self):
241+
return "<NumpyArrayContext>"
242+
243+
# }}}
244+
245+
225246
_ARRAY_CONTEXT_FACTORY_REGISTRY: \
226247
Dict[str, Type[PytestArrayContextFactory]] = {
227248
"pyopencl": _PytestPyOpenCLArrayContextFactoryWithClass,
@@ -230,6 +251,7 @@ def __str__(self):
230251
"pytato:pyopencl": _PytestPytatoPyOpenCLArrayContextFactory,
231252
"pytato:jax": _PytestPytatoJaxArrayContextFactory,
232253
"eagerjax": _PytestEagerJaxArrayContextFactory,
254+
"numpy": _PytestNumpyArrayContextFactory,
233255
}
234256

235257

Diff for: test/test_arraycontext.py

+2
Original file line numberDiff line numberDiff line change
@@ -46,6 +46,7 @@
4646
)
4747
from arraycontext.pytest import (
4848
_PytestEagerJaxArrayContextFactory,
49+
_PytestNumpyArrayContextFactory,
4950
_PytestPyOpenCLArrayContextFactoryWithClass,
5051
_PytestPytatoJaxArrayContextFactory,
5152
_PytestPytatoPyOpenCLArrayContextFactory,
@@ -97,6 +98,7 @@ class _PytatoPyOpenCLArrayContextForTestsFactory(
9798
_PytatoPyOpenCLArrayContextForTestsFactory,
9899
_PytestEagerJaxArrayContextFactory,
99100
_PytestPytatoJaxArrayContextFactory,
101+
_PytestNumpyArrayContextFactory,
100102
])
101103

102104

0 commit comments

Comments
 (0)