Skip to content

Commit fb42c99

Browse files
committed
Test BatchedEinsumPytatoPyOpenCLArrayContext
1 parent 1d8c2f2 commit fb42c99

File tree

1 file changed

+217
-0
lines changed

1 file changed

+217
-0
lines changed

test/test_batched_einsum_actx.py

+217
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,217 @@
1+
import pytest
2+
3+
4+
try:
5+
import feinsum # noqa: F401
6+
except ModuleNotFoundError:
7+
pytest.skip(reason="BatchedEinsumActx imposes feinsum as a hard dep.",
8+
allow_module_level=True)
9+
10+
try:
11+
from loopy import get_kennedy_unweighted_fusion_candidates # noqa: F401
12+
from loopy import rename_inames_in_batch # noqa: F401
13+
except ImportError:
14+
pytest.skip(reason="BatchedEinsumActx imposes loop-fusion support in "
15+
"loopy as a hard dep.", allow_module_level=True)
16+
17+
import numpy as np
18+
19+
from pytools.tag import UniqueTag
20+
21+
from arraycontext import (
22+
BatchedEinsumPytatoPyOpenCLArrayContext, PyOpenCLArrayContext,
23+
PytatoPyOpenCLArrayContext, tag_axes)
24+
from arraycontext.pytest import (
25+
_PytestEagerJaxArrayContextFactory, _PytestPyOpenCLArrayContextFactoryWithClass,
26+
_PytestPytatoJaxArrayContextFactory, _PytestPytatoPyOpenCLArrayContextFactory,
27+
_PytestSplitPytatoPyOpenCLArrayContextFactory,
28+
pytest_generate_tests_for_array_contexts)
29+
30+
31+
# {{{ axes tag types for image processing
32+
33+
class ImageDimensionTag(UniqueTag):
34+
"""
35+
An abstract tag type that is tagged to an array's axis indexing along an image's
36+
axis.
37+
"""
38+
39+
40+
class XDimension(ImageDimensionTag):
41+
"""
42+
A tag that is attached to a :class:`pytato.array.Axis` that indexes along the
43+
x-dimension of an image.
44+
"""
45+
46+
47+
class YDimension(ImageDimensionTag):
48+
"""
49+
A tag that is attached to a :class:`pytato.array.Axis` that indexes along the
50+
y-dimension of an image.
51+
"""
52+
53+
54+
class ChannelDimension(ImageDimensionTag):
55+
"""
56+
A tag that is attached to a :class:`pytato.array.Axis` that indexes along the
57+
channels of an image.
58+
"""
59+
60+
# }}}
61+
62+
63+
# {{{ array context fixture
64+
65+
class ImageProcessingFusionPytatoPyOpenCLArrayContextForImageProc(
66+
BatchedEinsumPytatoPyOpenCLArrayContext):
67+
def __init__(self, queue, allocator=None):
68+
super().__init__(queue, allocator,
69+
fallback_to_no_fusion=False,
70+
loop_fusion_axis_tag_t=ImageDimensionTag)
71+
72+
73+
class _PyOpenCLArrayContextForTests(PyOpenCLArrayContext):
74+
"""Like :class:`PyOpenCLArrayContext`, but applies no program transformations
75+
whatsoever. Only to be used for testing internal to :mod:`arraycontext`.
76+
"""
77+
78+
def transform_loopy_program(self, t_unit):
79+
return t_unit
80+
81+
82+
class _PytatoPyOpenCLArrayContextForTests(PytatoPyOpenCLArrayContext):
83+
"""Like :class:`PytatoPyOpenCLArrayContext`, but applies no program
84+
transformations whatsoever. Only to be used for testing internal to
85+
:mod:`arraycontext`.
86+
"""
87+
88+
def transform_loopy_program(self, t_unit):
89+
return t_unit
90+
91+
92+
class _PytatoPyOpenCLArrayContextForTestsFactory(
93+
_PytestPytatoPyOpenCLArrayContextFactory):
94+
actx_class = _PytatoPyOpenCLArrayContextForTests
95+
96+
97+
class _PyOpenCLArrayContextForTestsFactoryWithHostScalars(
98+
_PytestPyOpenCLArrayContextFactoryWithClass):
99+
force_device_scalars = True
100+
actx_class = _PyOpenCLArrayContextForTests
101+
102+
103+
class _PytestBatchedEinsumPytatoPyOpenCLArrayContextFactory(
104+
_PytestPytatoPyOpenCLArrayContextFactory):
105+
@property
106+
def actx_class(self):
107+
return ImageProcessingFusionPytatoPyOpenCLArrayContextForImageProc
108+
109+
110+
pytest_generate_tests = pytest_generate_tests_for_array_contexts([
111+
_PyOpenCLArrayContextForTestsFactoryWithHostScalars,
112+
_PytatoPyOpenCLArrayContextForTestsFactory,
113+
_PytestEagerJaxArrayContextFactory,
114+
_PytestPytatoJaxArrayContextFactory,
115+
_PytestSplitPytatoPyOpenCLArrayContextFactory,
116+
_PytestBatchedEinsumPytatoPyOpenCLArrayContextFactory,
117+
])
118+
119+
# }}}
120+
121+
122+
def test_simple_add(actx_factory):
123+
# Lesson 01 of Halide Tutorial
124+
actx = actx_factory()
125+
126+
rng = np.random.default_rng(0)
127+
a_np = rng.random((800, 600))
128+
b_np = rng.random((800, 600))
129+
a = actx.from_numpy(a_np)
130+
b = actx.from_numpy(b_np)
131+
132+
a = tag_axes(actx, {0: XDimension(), 1: YDimension()}, a)
133+
b = tag_axes(actx, {0: XDimension(), 1: YDimension()}, b)
134+
135+
out = actx.to_numpy(a + b)
136+
ref_out = a_np + b_np
137+
138+
np.testing.assert_allclose(out, ref_out)
139+
140+
141+
def test_brighten_image(actx_factory):
142+
# Lesson 02 of Halide Tutorial
143+
actx = actx_factory()
144+
145+
rng = np.random.default_rng(0)
146+
147+
img_np = 255*rng.random((800, 600, 3), dtype=np.float32)
148+
149+
img = actx.from_numpy(img_np)
150+
img = tag_axes(actx,
151+
{0: XDimension(), 1: YDimension(), 2: ChannelDimension()},
152+
img)
153+
154+
brightened_img = 1.5*img
155+
clamped_brightened_img = actx.np.minimum(brightened_img, np.float32(255))
156+
157+
out = actx.to_numpy(clamped_brightened_img)
158+
ref_out = np.minimum(1.5*img_np, np.float32(255))
159+
160+
np.testing.assert_allclose(out, ref_out)
161+
162+
163+
def test_simple_einsum(actx_factory):
164+
actx = actx_factory()
165+
166+
rng = np.random.default_rng()
167+
168+
a_np = rng.random((10, 4))
169+
a = actx.from_numpy(a_np)
170+
a = tag_axes(actx,
171+
{0: XDimension(), 1: YDimension()}, a)
172+
173+
out1 = actx.einsum("ij,ij->i", a, a+1)
174+
out2 = actx.einsum("ij,ij->i", 2*a, 3*a+7)
175+
176+
ref_out = (np.einsum("ij,ij->i", a_np, a_np + 1)
177+
+ np.einsum("ij,ij->i", 2*a_np, 3*a_np+7))
178+
out = actx.to_numpy(out1 + out2)
179+
180+
np.testing.assert_allclose(ref_out, out)
181+
182+
183+
def test_nested_einsum(actx_factory):
184+
actx = actx_factory()
185+
186+
rng = np.random.default_rng()
187+
188+
a_np = rng.random((10, 4))
189+
190+
# {{{ compute out
191+
192+
a = actx.from_numpy(a_np)
193+
a = tag_axes(actx,
194+
{0: XDimension(), 1: YDimension()}, a)
195+
b = a + 1
196+
197+
out1 = actx.einsum("ij,ij->i", a, b)
198+
out2 = actx.einsum("ij,ij->i", 2*a, 3*a+7)
199+
out3 = actx.einsum("ij,i->i", 3*b, 2*out1)
200+
201+
out = actx.to_numpy(out1 + out2 + out3)
202+
203+
# }}}
204+
205+
# {{{ compute ref_out
206+
207+
b_np = a_np + 1
208+
out1_np = np.einsum("ij,ij->i", a_np, a_np+1)
209+
out2_np = np.einsum("ij,ij->i", 2*a_np, 3*a_np+7)
210+
out3_np = np.einsum("ij,i->i", 3*b_np, 2*out1_np)
211+
ref_out = out1_np + out2_np + out3_np
212+
213+
# }}}
214+
215+
np.testing.assert_allclose(ref_out, out)
216+
217+
# vim: fdm=marker

0 commit comments

Comments
 (0)