@@ -170,12 +170,22 @@ def __call__(self):
170
170
return self .actx_class (queue , allocator = alloc )
171
171
172
172
def __str__ (self ):
173
- return ("<PytatoPyOpenCLArrayContext for <pyopencl.Device '%s' on '%s'>>" %
173
+ return ("<%s for <pyopencl.Device '%s' on '%s'>>" %
174
174
(
175
+ self .__class__ .__name__ ,
175
176
self .device .name .strip (),
176
177
self .device .platform .name .strip ()))
177
178
178
179
180
+ class _PytestSplitPytatoPyOpenCLArrayContextFactory (
181
+ _PytestPytatoPyOpenCLArrayContextFactory ):
182
+ @property
183
+ def actx_class (self ):
184
+ from arraycontext .impl .pytato .split_actx import (
185
+ SplitPytatoPyOpenCLArrayContext )
186
+ return SplitPytatoPyOpenCLArrayContext
187
+
188
+
179
189
class _PytestEagerJaxArrayContextFactory (PytestArrayContextFactory ):
180
190
def __init__ (self , * args , ** kwargs ):
181
191
pass
@@ -231,6 +241,7 @@ def __str__(self):
231
241
_PytestPyOpenCLArrayContextFactoryWithClassAndHostScalars ,
232
242
"pytato:pyopencl" : _PytestPytatoPyOpenCLArrayContextFactory ,
233
243
"pytato:jax" : _PytestPytatoJaxArrayContextFactory ,
244
+ "pytato:split" : _PytestSplitPytatoPyOpenCLArrayContextFactory ,
234
245
"eagerjax" : _PytestEagerJaxArrayContextFactory ,
235
246
}
236
247
0 commit comments