Skip to content

Commit e6fa313

Browse files
committed
Implemented BatchedEinsumArrayContext
1 parent bb03c86 commit e6fa313

File tree

3 files changed

+851
-0
lines changed

3 files changed

+851
-0
lines changed

arraycontext/__init__.py

+2
Original file line numberDiff line numberDiff line change
@@ -53,6 +53,7 @@
5353
from .impl.jax import EagerJAXArrayContext
5454
from .impl.pyopencl import PyOpenCLArrayContext
5555
from .impl.pytato import PytatoJAXArrayContext, PytatoPyOpenCLArrayContext
56+
from .impl.pytato.batched_einsum import BatchedEinsumPytatoPyOpenCLArrayContext
5657
from .impl.pytato.split_actx import SplitPytatoPyOpenCLArrayContext
5758
from .loopy import make_loopy_program
5859
# deprecated, remove in 2022.
@@ -100,6 +101,7 @@
100101

101102
"PyOpenCLArrayContext", "PytatoPyOpenCLArrayContext",
102103
"SplitPytatoPyOpenCLArrayContext",
104+
"BatchedEinsumPytatoPyOpenCLArrayContext",
103105

104106
"PytatoJAXArrayContext",
105107
"EagerJAXArrayContext",
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,382 @@
1+
"""
2+
.. autoclass:: BatchedEinsumPytatoPyOpenCLArrayContext """
3+
4+
__copyright__ = """
5+
Copyright (C) 2023 Kaushik Kulkarni
6+
Copyright (C) 2022 Andreas Kloeckner
7+
Copyright (C) 2022 Matthias Diener
8+
Copyright (C) 2022 Matt Smith
9+
"""
10+
11+
__license__ = """
12+
Permission is hereby granted, free of charge, to any person obtaining a copy
13+
of this software and associated documentation files (the "Software"), to deal
14+
in the Software without restriction, including without limitation the rights
15+
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
16+
copies of the Software, and to permit persons to whom the Software is
17+
furnished to do so, subject to the following conditions:
18+
19+
The above copyright notice and this permission notice shall be included in
20+
all copies or substantial portions of the Software.
21+
22+
THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
23+
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
24+
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
25+
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
26+
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
27+
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
28+
THE SOFTWARE.
29+
"""
30+
31+
32+
import logging
33+
import sys
34+
from typing import TYPE_CHECKING, Any, Callable, Optional, Type
35+
from warnings import warn
36+
37+
import numpy as np
38+
39+
import loopy as lp
40+
from pytools import ProcessLogger
41+
from pytools.tag import Tag
42+
43+
from arraycontext.impl.pytato import PytatoPyOpenCLArrayContext
44+
45+
46+
logger = logging.getLogger(__name__)
47+
48+
49+
if TYPE_CHECKING or getattr(sys, "_BUILDING_SPHINX_DOCS", False):
50+
import pyopencl as cl
51+
import pytato
52+
53+
54+
class BatchedEinsumPytatoPyOpenCLArrayContext(PytatoPyOpenCLArrayContext):
55+
r"""
56+
.. attribute:: loop_fusion_axis_tag_t
57+
58+
A subtype of :class:`pytato.tag.Tag` that are attached to the
59+
:class:`~pytato.array.Array`\ 's axes in an expression graph. Loops that
60+
iterate over axes tagged with instances of same such tag types will form the
61+
candidate loops for Kennedy's unweighted Loop Fusion algorithm.
62+
63+
.. attribute:: fallback_to_no_fusion
64+
65+
If *True*, during the compilation of an array expression graph for which
66+
loop fusion fails (see note) transformation routines from
67+
:class:`arraycontext.SplitPytatoPyOpenCLArrayContext` are invoked.
68+
69+
.. attribute:: feinsum_db
70+
71+
An instance of :class:`str` corresponding to the database of tuned batched
72+
einsums. If *None*, then a static transformation strategy is applied to the
73+
batched einsums kernels.
74+
75+
.. attribute:: log_loopy_statistics
76+
77+
If *True*, statistics of compiled :class:`loopy.TranslationUnit` will be
78+
logged. If enable, we log the FLOPS and global memory access footprint for
79+
each of the programs. If *False*, nothing is done.
80+
81+
.. note::
82+
83+
The conditions under which we fallback (or raise) are:
84+
85+
#. There exists an array that is to be materialized but at least one of its
86+
axes is not tagged with tags of :attr:`loop_fusion_axis_tag_t`.
87+
"""
88+
def __init__(
89+
self,
90+
queue: "cl.CommandQueue", allocator=None,
91+
*,
92+
loop_fusion_axis_tag_t: Type[Tag],
93+
fallback_to_no_fusion: bool = True,
94+
assume_all_indirection_maps_as_non_negative: bool = False,
95+
compile_trace_callback: Optional[Callable[[Any, str, Any], None]] = None,
96+
feinsum_db: Optional[str] = None,
97+
log_loopy_statistics: bool = False,
98+
fused_loop_name_prefix_getter: Optional[Callable[[Tag], str]] = None,
99+
) -> None:
100+
super().__init__(queue,
101+
allocator,
102+
compile_trace_callback=compile_trace_callback)
103+
104+
self.loop_fusion_axis_tag_t = loop_fusion_axis_tag_t
105+
self.fallback_to_no_fusion = fallback_to_no_fusion
106+
self.feinsum_db = feinsum_db
107+
self.assume_all_indirection_maps_as_non_negative = (
108+
assume_all_indirection_maps_as_non_negative)
109+
self.log_loopy_statistics = log_loopy_statistics
110+
if fused_loop_name_prefix_getter:
111+
self.fused_loop_name_prefix_getter = fused_loop_name_prefix_getter
112+
else:
113+
self.fused_loop_name_prefix_getter = lambda tag_t: "ifused"
114+
115+
def transform_dag(self,
116+
dag: "pytato.DictOfNamedArrays") -> "pytato.DictOfNamedArrays":
117+
import pytato as pt
118+
119+
from .utils import (
120+
_make_passthrough_arg, get_indirection_maps,
121+
get_inputs_and_outputs_of_reduction_nodes)
122+
from arraycontext.impl.pytato.split_actx.utils import (
123+
get_inputs_and_outputs_of_einsum)
124+
125+
# Step 1. Collapse equivalent nodes in DAG.
126+
# -----------------------------------------
127+
# type-ignore-reason: mypy is right pytato provides imprecise types.
128+
dag = pt.transform.deduplicate_data_wrappers(dag) # type: ignore[assignment]
129+
130+
# Step 2. Materialize einsum/reduction outputs.
131+
# ---------------------------------------------
132+
_, einsum_outputs = get_inputs_and_outputs_of_einsum(dag)
133+
_, reduction_outputs = get_inputs_and_outputs_of_reduction_nodes(dag)
134+
135+
def materialize_all_einsums_or_reduces(expr):
136+
if (expr in einsum_outputs
137+
or expr in reduction_outputs):
138+
return expr.tagged(pt.tags.ImplStored())
139+
else:
140+
return expr
141+
142+
# type-ignore-reason: mypy is right pytato provides imprecise types.
143+
dag = pt.transform.map_and_copy(dag, # type: ignore[assignment]
144+
materialize_all_einsums_or_reduces)
145+
146+
# Step 3. Materialize with MPMS
147+
# -----------------------------
148+
dag = pt.transform.materialize_with_mpms(dag)
149+
150+
# Step 4. Mark all indirection maps as non-negative
151+
# -------------------------------------------------
152+
if self.assume_all_indirection_maps_as_non_negative:
153+
indirection_maps = get_indirection_maps(dag)
154+
155+
def tag_indices_as_non_negative(ary):
156+
if ary in indirection_maps:
157+
return ary.tagged(pt.tags.AssumeNonNegative())
158+
else:
159+
return ary
160+
161+
# type-ignore-reason: mypy is right pytato provides imprecise types.
162+
dag = pt.transform.map_and_copy(dag, # type: ignore[assignment]
163+
tag_indices_as_non_negative)
164+
165+
# Step 5. Get rid of broadcasts in einsum expressions (helps feinsum)
166+
# -------------------------------------------------------------------
167+
dag = pt.rewrite_einsums_with_no_broadcasts(dag)
168+
169+
# Step 6. Infer axis tags
170+
# -----------------------
171+
# type-ignore-reason: mypy is right pytato provides imprecise types.
172+
dag = pt.unify_axes_tags(dag) # type: ignore[assignment]
173+
174+
# Step 7. Make all pt.einsum/pt.reduction inputs as substitutions
175+
# ---------------------------------------------------------------
176+
def implement_einsum_reduction_inputs_as_substs(expr):
177+
from immutables import Map
178+
179+
from pytato.target.loopy import ImplSubstitution
180+
if isinstance(expr, pt.Einsum):
181+
# make the arguments passthrough to make use of already stored
182+
# values.
183+
# pylint and 'attrs' have poor compatibility
184+
# pylint: disable=too-many-function-args,redundant-keyword-arg
185+
# pylint: disable=unexpected-keyword-arg
186+
return pt.Einsum(
187+
expr.access_descriptors,
188+
tuple(_make_passthrough_arg(arg, ImplSubstitution())
189+
for arg in expr.args),
190+
expr.redn_axis_to_redn_descr,
191+
expr.index_to_access_descr,
192+
tags=expr.tags,
193+
axes=expr.axes,
194+
)
195+
elif isinstance(expr, pt.IndexLambda) and expr.var_to_reduction_descr:
196+
# make the arguments passthrough to make use of already stored
197+
# values.
198+
# pylint: disable=too-many-function-args,redundant-keyword-arg
199+
# pylint: disable=unexpected-keyword-arg
200+
return pt.IndexLambda(
201+
expr.expr,
202+
expr.shape,
203+
expr.dtype,
204+
Map({name: _make_passthrough_arg(bnd, ImplSubstitution())
205+
for name, bnd in expr.bindings.items()}),
206+
expr.var_to_reduction_descr,
207+
tags=expr.tags,
208+
axes=expr.axes,
209+
)
210+
else:
211+
return expr
212+
213+
# type-ignore-reason: mypy is right pytato provides imprecise types.
214+
dag = pt.transform.map_and_copy(dag, # type: ignore[assignment]
215+
implement_einsum_reduction_inputs_as_substs)
216+
217+
return dag
218+
219+
def transform_loopy_program(self,
220+
t_unit: lp.TranslationUnit) -> lp.TranslationUnit:
221+
knl_name = t_unit.default_entrypoint.name
222+
223+
logger.info(f"[{self.__class__}.transform_loopy_program]:"
224+
f" Transforming kernel '{knl_name}' with"
225+
f" {len(t_unit.default_entrypoint.instructions)} statements.")
226+
227+
# Step 0. Fallback if cannot t_unit cannot be transformed
228+
# -------------------------------------------------------
229+
for iname in t_unit.default_entrypoint.all_inames():
230+
if not t_unit.default_entrypoint.iname_tags_of_type(
231+
iname, self.loop_fusion_axis_tag_t):
232+
if self.fallback_to_no_fusion:
233+
warn(f"[{knl_name}]: Falling back to a slower transformation"
234+
" strategy as some loops are uninferred which mesh entity"
235+
" they belong to.",
236+
stacklevel=2)
237+
from arraycontext.impl.pytato.split_actx import (
238+
SplitPytatoPyOpenCLArrayContext)
239+
240+
# type-ignore-reason: mypy is right, we are passing incorrect
241+
# types, but knowing the implementation of
242+
# SplitPytatoPyOpenCLArrayContext this should be fine.
243+
return SplitPytatoPyOpenCLArrayContext.transform_loopy_program(
244+
self, t_unit) # type: ignore[arg-type]
245+
else:
246+
raise RuntimeError(f"Iname '{iname}' is not tagged with tags"
247+
f" of type '{self.loop_fusion_axis_tag_t}'"
248+
" => Not allowed since Kennedy's Loop fusion"
249+
" cannot be applied.")
250+
251+
# Step 0.5. Make offsets as 0. (FIXME: move this to loopy knl invocation)
252+
# -----------------------------------------------------------------------
253+
knl = t_unit.default_entrypoint
254+
knl = knl.copy(args=[arg.copy(offset=0) for arg in knl.args])
255+
t_unit = t_unit.with_kernel(knl)
256+
del knl
257+
258+
# Step 1. Fuse loops indexing over the same tag
259+
# ---------------------------------------------
260+
with ProcessLogger(logger, f"[{knl_name}] Loop Fusion"):
261+
from .utils import apply_kennedy_fusion_with_batched_einsum_extension
262+
t_unit = apply_kennedy_fusion_with_batched_einsum_extension(
263+
t_unit, self.loop_fusion_axis_tag_t,
264+
self.fused_loop_name_prefix_getter)
265+
266+
# Step 2. Combine the domains of individual loop nests into individual
267+
# BasicSets
268+
# --------------------------------------------------------------------
269+
from .utils import combine_domains_of_perfect_loop_nests
270+
t_unit = combine_domains_of_perfect_loop_nests(t_unit)
271+
272+
# Step 3. Remove dead temporaries
273+
# -------------------------------
274+
from .utils import remove_dead_temporaries
275+
t_unit = remove_dead_temporaries(t_unit)
276+
277+
# Step 4. Contract arrays
278+
# -----------------------
279+
with ProcessLogger(logger, f"[{knl_name}] Array Contraction"):
280+
from .utils import contract_arrays
281+
t_unit = contract_arrays(t_unit)
282+
283+
# Step 5. Collect statistics
284+
# --------------------------
285+
286+
# {{{ compute stats
287+
288+
if self.log_loopy_statistics:
289+
290+
with ProcessLogger(logger, f"[{knl_name}] Count kernel metrics"):
291+
from loopy.kernel.array import ArrayBase
292+
from pytools import product
293+
knl = t_unit.default_entrypoint
294+
knl = knl.copy(
295+
silenced_warnings=(knl.silenced_warnings
296+
+ ["insn_count_subgroups_upper_bound",
297+
"summing_if_branches_ops"]))
298+
299+
t_unit = t_unit.with_kernel(knl)
300+
del knl
301+
302+
op_map = lp.get_op_map(t_unit, subgroup_size=32)
303+
304+
c64_ops = {op_type: (op_map.filter_by(dtype=[np.complex64],
305+
name=op_type,
306+
kernel_name=knl_name)
307+
.eval_and_sum({}))
308+
for op_type in ["add", "mul", "div"]}
309+
c128_ops = {op_type: (op_map.filter_by(dtype=[np.complex128],
310+
name=op_type,
311+
kernel_name=knl_name)
312+
.eval_and_sum({}))
313+
for op_type in ["add", "mul", "div"]}
314+
f32_ops = ((op_map.filter_by(dtype=[np.float32],
315+
kernel_name=knl_name)
316+
.eval_and_sum({}))
317+
+ (2 * c64_ops["add"]
318+
+ 6 * c64_ops["mul"]
319+
+ (6 + 3 + 2) * c64_ops["div"]))
320+
f64_ops = ((op_map.filter_by(dtype=[np.float64],
321+
kernel_name="_pt_kernel")
322+
.eval_and_sum({}))
323+
+ (2 * c128_ops["add"]
324+
+ 6 * c128_ops["mul"]
325+
+ (6 + 3 + 2) * c128_ops["div"]))
326+
327+
# {{{ footprint gathering
328+
329+
nfootprint_bytes = 0
330+
331+
for ary in knl.args:
332+
if (isinstance(ary, ArrayBase)
333+
and ary.address_space == lp.AddressSpace.GLOBAL):
334+
nfootprint_bytes += (product(ary.shape)
335+
* ary.dtype.itemsize)
336+
337+
for ary in knl.temporary_variables.values():
338+
if ary.address_space == lp.AddressSpace.GLOBAL:
339+
# global temps would be written once and read once
340+
nfootprint_bytes += (2 * product(ary.shape)
341+
* ary.dtype.itemsize)
342+
343+
# }}}
344+
345+
if f32_ops:
346+
logger.info(f"Single-prec. GFlOps: {f32_ops * 1e-9}")
347+
if f64_ops:
348+
logger.info(f"Double-prec. GFlOps: {f64_ops * 1e-9}")
349+
logger.info(f"Footprint GBs: {nfootprint_bytes * 1e-9}")
350+
351+
# }}}
352+
353+
# Step 6. Draw kernel boundaries between batched einsum kernels
354+
# -------------------------------------------------------------
355+
from arraycontext.impl.pytato.split_actx.utils import (
356+
add_gbarrier_between_disjoint_loop_nests)
357+
358+
t_unit = add_gbarrier_between_disjoint_loop_nests(t_unit)
359+
360+
# Step 7. Alias global temporaries with disjoint live intervals
361+
# -------------------------------------------------------------
362+
from arraycontext.impl.pytato.split_actx.utils import (
363+
alias_global_temporaries)
364+
t_unit = alias_global_temporaries(t_unit)
365+
366+
# Step 8. Macro-kernel optimizations
367+
# ----------------------------------
368+
if self.feinsum_db:
369+
from .utils import apply_feinsum_transformations
370+
t_unit = apply_feinsum_transformations(
371+
t_unit, self.feinsum_db, self.queue.device)
372+
else:
373+
from arraycontext.impl.pytato.split_actx.utils import (
374+
parallelize_reduce_to_scalars,
375+
split_iteration_domain_across_work_items)
376+
t_unit = split_iteration_domain_across_work_items(t_unit,
377+
self.queue.device)
378+
t_unit = parallelize_reduce_to_scalars(t_unit, self.queue.device)
379+
380+
return t_unit
381+
382+
# vim: fdm=marker

0 commit comments

Comments
 (0)