35
35
from typing import Any , Callable , Dict , Sequence , Type , Union
36
36
37
37
from arraycontext .context import ArrayContext
38
+ from arraycontext import NumpyArrayContext
38
39
39
40
40
41
# {{{ array context factories
@@ -195,6 +196,26 @@ def __str__(self):
195
196
return "<PytatoJAXArrayContext>"
196
197
197
198
199
+ # {{{ _PytestArrayContextFactory
200
+
201
+ class _NumpyArrayContextForTests (NumpyArrayContext ):
202
+ def transform_loopy_program (self , t_unit ):
203
+ return t_unit
204
+
205
+
206
+ class _PytestNumpyArrayContextFactory (PytestArrayContextFactory ):
207
+ def __init__ (self , * args , ** kwargs ):
208
+ super ().__init__ ()
209
+
210
+ def __call__ (self ):
211
+ return _NumpyArrayContextForTests ()
212
+
213
+ def __str__ (self ):
214
+ return "<NumpyArrayContext>"
215
+
216
+ # }}}
217
+
218
+
198
219
_ARRAY_CONTEXT_FACTORY_REGISTRY : \
199
220
Dict [str , Type [PytestArrayContextFactory ]] = {
200
221
"pyopencl" : _PytestPyOpenCLArrayContextFactoryWithClass ,
@@ -203,6 +224,7 @@ def __str__(self):
203
224
"pytato:pyopencl" : _PytestPytatoPyOpenCLArrayContextFactory ,
204
225
"pytato:jax" : _PytestPytatoJaxArrayContextFactory ,
205
226
"eagerjax" : _PytestEagerJaxArrayContextFactory ,
227
+ "numpy" : _PytestNumpyArrayContextFactory ,
206
228
}
207
229
208
230
0 commit comments