3434
3535from typing import Any , Callable , Dict , Sequence , Type , Union
3636
37+ from arraycontext import NumpyArrayContext
3738from arraycontext .context import ArrayContext
3839
3940
@@ -222,6 +223,26 @@ def __str__(self):
222223 return "<PytatoJAXArrayContext>"
223224
224225
226+ # {{{ _PytestArrayContextFactory
227+
228+ class _NumpyArrayContextForTests (NumpyArrayContext ):
229+ def transform_loopy_program (self , t_unit ):
230+ return t_unit
231+
232+
233+ class _PytestNumpyArrayContextFactory (PytestArrayContextFactory ):
234+ def __init__ (self , * args , ** kwargs ):
235+ super ().__init__ ()
236+
237+ def __call__ (self ):
238+ return _NumpyArrayContextForTests ()
239+
240+ def __str__ (self ):
241+ return "<NumpyArrayContext>"
242+
243+ # }}}
244+
245+
225246_ARRAY_CONTEXT_FACTORY_REGISTRY : \
226247 Dict [str , Type [PytestArrayContextFactory ]] = {
227248 "pyopencl" : _PytestPyOpenCLArrayContextFactoryWithClass ,
@@ -230,6 +251,7 @@ def __str__(self):
230251 "pytato:pyopencl" : _PytestPytatoPyOpenCLArrayContextFactory ,
231252 "pytato:jax" : _PytestPytatoJaxArrayContextFactory ,
232253 "eagerjax" : _PytestEagerJaxArrayContextFactory ,
254+ "numpy" : _PytestNumpyArrayContextFactory ,
233255 }
234256
235257
0 commit comments