@@ -170,12 +170,22 @@ def __call__(self):
170170 return self .actx_class (queue , allocator = alloc )
171171
172172 def __str__ (self ):
173- return ("<PytatoPyOpenCLArrayContext for <pyopencl.Device '%s' on '%s'>>" %
173+ return ("<%s for <pyopencl.Device '%s' on '%s'>>" %
174174 (
175+ self .__class__ .__name__ ,
175176 self .device .name .strip (),
176177 self .device .platform .name .strip ()))
177178
178179
180+ class _PytestSplitPytatoPyOpenCLArrayContextFactory (
181+ _PytestPytatoPyOpenCLArrayContextFactory ):
182+ @property
183+ def actx_class (self ):
184+ from arraycontext .impl .pytato .split_actx import (
185+ SplitPytatoPyOpenCLArrayContext )
186+ return SplitPytatoPyOpenCLArrayContext
187+
188+
179189class _PytestEagerJaxArrayContextFactory (PytestArrayContextFactory ):
180190 def __init__ (self , * args , ** kwargs ):
181191 pass
@@ -231,6 +241,7 @@ def __str__(self):
231241 _PytestPyOpenCLArrayContextFactoryWithClassAndHostScalars ,
232242 "pytato:pyopencl" : _PytestPytatoPyOpenCLArrayContextFactory ,
233243 "pytato:jax" : _PytestPytatoJaxArrayContextFactory ,
244+ "pytato:split" : _PytestSplitPytatoPyOpenCLArrayContextFactory ,
234245 "eagerjax" : _PytestEagerJaxArrayContextFactory ,
235246 }
236247
0 commit comments