Skip to content

Commit 9bba8e4

Browse files
committed
Add SplitPytatoArrayContext to the test suite
1 parent 33cdfc4 commit 9bba8e4

File tree

2 files changed

+15
-2
lines changed

2 files changed

+15
-2
lines changed

arraycontext/pytest.py

Lines changed: 12 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -170,12 +170,22 @@ def __call__(self):
170170
return self.actx_class(queue, allocator=alloc)
171171

172172
def __str__(self):
173-
return ("<PytatoPyOpenCLArrayContext for <pyopencl.Device '%s' on '%s'>>" %
173+
return ("<%s for <pyopencl.Device '%s' on '%s'>>" %
174174
(
175+
self.__class__.__name__,
175176
self.device.name.strip(),
176177
self.device.platform.name.strip()))
177178

178179

180+
class _PytestSplitPytatoPyOpenCLArrayContextFactory(
181+
_PytestPytatoPyOpenCLArrayContextFactory):
182+
@property
183+
def actx_class(self):
184+
from arraycontext.impl.pytato.split_actx import (
185+
SplitPytatoPyOpenCLArrayContext)
186+
return SplitPytatoPyOpenCLArrayContext
187+
188+
179189
class _PytestEagerJaxArrayContextFactory(PytestArrayContextFactory):
180190
def __init__(self, *args, **kwargs):
181191
pass
@@ -231,6 +241,7 @@ def __str__(self):
231241
_PytestPyOpenCLArrayContextFactoryWithClassAndHostScalars,
232242
"pytato:pyopencl": _PytestPytatoPyOpenCLArrayContextFactory,
233243
"pytato:jax": _PytestPytatoJaxArrayContextFactory,
244+
"pytato:split": _PytestSplitPytatoPyOpenCLArrayContextFactory,
234245
"eagerjax": _PytestEagerJaxArrayContextFactory,
235246
}
236247

test/test_arraycontext.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -36,7 +36,8 @@
3636
serialize_container, tag_axes, with_array_context, with_container_arithmetic)
3737
from arraycontext.pytest import (
3838
_PytestEagerJaxArrayContextFactory, _PytestPyOpenCLArrayContextFactoryWithClass,
39-
_PytestPytatoJaxArrayContextFactory, _PytestPytatoPyOpenCLArrayContextFactory)
39+
_PytestPytatoJaxArrayContextFactory, _PytestPytatoPyOpenCLArrayContextFactory,
40+
_PytestSplitPytatoPyOpenCLArrayContextFactory)
4041

4142

4243
logger = logging.getLogger(__name__)
@@ -84,6 +85,7 @@ class _PytatoPyOpenCLArrayContextForTestsFactory(
8485
_PytatoPyOpenCLArrayContextForTestsFactory,
8586
_PytestEagerJaxArrayContextFactory,
8687
_PytestPytatoJaxArrayContextFactory,
88+
_PytestSplitPytatoPyOpenCLArrayContextFactory,
8789
])
8890

8991

0 commit comments

Comments
 (0)