diff --git a/loopy/__init__.py b/loopy/__init__.py index 7a1942f3d..d23b37d8f 100644 --- a/loopy/__init__.py +++ b/loopy/__init__.py @@ -116,6 +116,7 @@ from loopy.target.execution import ExecutorBase from loopy.target.ispc import ISPCTarget from loopy.target.opencl import OpenCLTarget +from loopy.target.pycuda import PyCudaTarget, PyCudaWithPackedArgsTarget from loopy.target.pyopencl import PyOpenCLTarget from loopy.tools import Optional, clear_in_mem_caches, memoize_on_disk, t_unit_to_python from loopy.transform.add_barrier import add_barrier @@ -147,6 +148,7 @@ tag_array_axes, tag_data_axes, ) +from loopy.transform.domain import decouple_domain from loopy.transform.fusion import fuse_kernels from loopy.transform.iname import ( add_inames_for_unused_hw_axes, @@ -183,6 +185,10 @@ simplify_indices, tag_instructions, ) +from loopy.transform.loop_fusion import ( + get_kennedy_unweighted_fusion_candidates, + rename_inames_in_batch +) from loopy.transform.pack_and_unpack_args import pack_and_unpack_args_for_call from loopy.transform.padding import ( add_padding, @@ -198,6 +204,10 @@ unprivatize_temporaries_with_inames, ) from loopy.transform.realize_reduction import realize_reduction +from loopy.transform.reduction import ( + hoist_invariant_multiplicative_terms_in_sum_reduction, + extract_multiplicative_terms_in_sum_reduction_as_subst) +from loopy.transform.reindex import reindex_temporary_using_seghir_loechner_scheme from loopy.transform.save import save_and_reload_temporaries from loopy.transform.subst import ( assignment_to_subst, @@ -265,6 +275,8 @@ "Options", "OrderedAtomic", "PreambleInfo", + "PyCudaTarget", + "PyCudaWithPackedArgsTarget", "PyOpenCLTarget", "Reduction", "ScalarCallable", @@ -305,8 +317,10 @@ "clear_in_mem_caches", "collect_common_factors_on_increment", "concatenate_arrays", + "decouple_domain", "duplicate_inames", "expand_subst", + "extract_multiplicative_terms_in_sum_reduction_as_subst", "extract_subst", "find_instructions", "find_most_recent_global_barrier", @@ -328,6 +342,7 @@ "get_dot_dependency_graph", "get_global_barrier_order", "get_iname_duplication_options", + "get_kennedy_unweighted_fusion_candidates", "get_mem_access_map", "get_one_linearized_kernel", "get_one_scheduled_kernel", @@ -336,6 +351,7 @@ "get_subkernels", "get_synchronization_map", "has_schedulable_iname_nesting", + "hoist_invariant_multiplicative_terms_in_sum_reduction", "infer_arg_descr", "infer_unknown_types", "inline_callable_kernel", @@ -365,6 +381,7 @@ "register_preamble_generators", "register_reduction_parser", "register_symbol_manglers", + "reindex_temporary_using_seghir_loechner_scheme", "remove_inames_from_insn", "remove_instructions", "remove_predicates_from_insn", @@ -374,6 +391,7 @@ "rename_callable", "rename_iname", "rename_inames", + "rename_inames_in_batch", "replace_instruction_ids", "save_and_reload_temporaries", "set_argument_order", @@ -402,6 +420,14 @@ "untag_inames", ] +try: + import loopy.relations as relations +except ImportError: + # catching ImportErrors to avoid making minikanren a hard-dep + pass +else: + __all__ += ["relations"] + # }}} diff --git a/loopy/codegen/loop.py b/loopy/codegen/loop.py index 44bfa07cc..1dbdede8f 100644 --- a/loopy/codegen/loop.py +++ b/loopy/codegen/loop.py @@ -27,10 +27,12 @@ import islpy as isl from islpy import dim_type from pymbolic.mapper.stringifier import PREC_NONE +from typing import FrozenSet from loopy.codegen.control import build_loop_nest from loopy.codegen.result import merge_codegen_results from loopy.diagnostic import LoopyError, 
warn
+from loopy.kernel import LoopKernel
 from loopy.symbolic import flatten


@@ -351,6 +353,16 @@ def set_up_hw_parallel_loops(codegen_state, schedule_index, next_func,

 # {{{ sequential loop

+def _get_intersecting_inames(kernel: LoopKernel, iname: str) -> FrozenSet[str]:
+    from functools import reduce
+    return reduce(frozenset.union,
+                  ((kernel.id_to_insn[insn].within_inames
+                    | kernel.id_to_insn[insn].reduction_inames()
+                    | kernel.id_to_insn[insn].sub_array_ref_inames())
+                   for insn in kernel.iname_to_insns()[iname]),
+                  frozenset())
+
+
 def generate_sequential_loop_dim_code(codegen_state, sched_index, hints):
     kernel = codegen_state.kernel

@@ -362,8 +374,18 @@ def generate_sequential_loop_dim_code(codegen_state, sched_index, hints):

     from loopy.codegen.bounds import get_usable_inames_for_conditional

     # Note: this does not include loop_iname itself!
     usable_inames = get_usable_inames_for_conditional(kernel, sched_index,
             codegen_state.codegen_cache_manager)

+    # Get rid of disjoint loop nests: only inames that share an instruction
+    # with loop_iname may be used in conditionals.
+    usable_inames = usable_inames & _get_intersecting_inames(kernel,
+                                                             loop_iname)
+
     domain = kernel.get_inames_domain(loop_iname)

diff --git a/loopy/kernel/tools.py b/loopy/kernel/tools.py
index 856ba19c5..3899ca6a9 100644
--- a/loopy/kernel/tools.py
+++ b/loopy/kernel/tools.py
@@ -2140,6 +2140,65 @@ def get_outer_params(domains):
 # }}}


+# {{{ get access map from an instruction
+
+from functools import reduce
+
+from loopy.symbolic import CombineMapper
+
+
+class _IndexCollector(CombineMapper):
+    def __init__(self, var):
+        self.var = var
+        super().__init__()
+
+    def combine(self, values):
+        import operator
+        return reduce(operator.or_, values, frozenset())
+
+    def map_subscript(self, expr):
+        if expr.aggregate.name == self.var:
+            return (super().map_subscript(expr)
+                    | frozenset([expr.index_tuple]))
+        else:
+            return super().map_subscript(expr)
+
+    def map_algebraic_leaf(self, expr):
+        return frozenset()
+
+    map_constant = map_algebraic_leaf
+
+
+def _project_out_inames_from_maps(amaps, inames_to_project_out):
+    new_amaps = []
+    for amap in amaps:
+        for iname in inames_to_project_out:
+            dt, pos = amap.get_var_dict()[iname]
+            amap = amap.project_out(dt, pos, 1)
+
+        new_amaps.append(amap)
+
+    return new_amaps
+
+
+def _union_amaps(amaps):
+    import islpy as isl
+    return reduce(isl.Map.union, amaps[1:], amaps[0])
+
+
+def get_insn_access_map(kernel, insn_id, var):
+    """
+    Returns the access footprint of the variable *var* in the instruction
+    *insn_id* of *kernel*, as an :class:`islpy.Map` from the instruction's
+    execution domain to the indices of *var* it accesses.
+    """
+    from loopy.symbolic import get_access_map
+    from loopy.transform.subst import expand_subst
+
+    insn = kernel.id_to_insn[insn_id]
+
+    kernel = expand_subst(kernel)
+    indices = list(_IndexCollector(var)((insn.expression,
+                                         insn.assignees,
+                                         tuple(insn.predicates))))
+
+    amaps = [get_access_map(kernel.get_inames_domain(insn.within_inames),
+                            idx, kernel.assumptions)
+             for idx in indices]
+
+    return _union_amaps(amaps)
+
+# }}}
+

 def get_hw_axis_base_for_codegen(kernel: LoopKernel, iname: str) -> isl.Aff:
     """
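A quick sketch of driving the new get_insn_access_map helper; the kernel,
instruction id and argument name below are invented for illustration, and the
printed map is shown schematically:

    import loopy as lp
    from loopy.kernel.tools import get_insn_access_map

    knl = lp.make_kernel(
        "{[i, j]: 0 <= i, j < n}",
        "out[i, j] = a[j, i]  {id=copy}")

    amap = get_insn_access_map(knl.default_entrypoint, "copy", "a")
    print(amap)  # roughly: [n] -> { [i, j] -> [j, i] : 0 <= i, j < n }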
diff --git a/loopy/relations.py b/loopy/relations.py
new file mode 100644
index 000000000..5d47bfa1d
--- /dev/null
+++ b/loopy/relations.py
@@ -0,0 +1,122 @@
+"""Expose a kernel's metadata (inames, arguments, tags, dependencies) as
+:mod:`kanren` relations."""
+
+from kanren import Relation, facts
+
+
+def get_inameo(kernel):
+    inameo = Relation()
+    for iname in kernel.all_inames():
+        facts(inameo, (iname,))
+    return inameo
+
+
+def get_argo(kernel):
+    argo = Relation()
+    for arg in kernel.args:
+        facts(argo, (arg.name,))
+
+    return argo
+
+
+def get_tempo(kernel):
+    tempo = Relation()
+    for tv in kernel.temporary_variables:
+        facts(tempo, (tv,))
+
+    return tempo
+
+
+def get_insno(kernel):
+    insno = Relation()
+    for insn in kernel.instructions:
+        facts(insno, (insn.id,))
+
+    return insno
+
+
+def get_taggedo(kernel):
+    taggedo = Relation()
+
+    for arg_name, arg in kernel.arg_dict.items():
+        for tag in arg.tags:
+            facts(taggedo, (arg_name, tag))
+
+    for iname_name, iname in kernel.inames.items():
+        for tag in iname.tags:
+            facts(taggedo, (iname_name, tag))
+
+    for insn in kernel.instructions:
+        for tag in insn.tags:
+            facts(taggedo, (insn.id, tag))
+
+    return taggedo
+
+
+def get_taggedo_of_type(kernel, tag_type):
+    taggedo = Relation()
+
+    for arg_name, arg in kernel.arg_dict.items():
+        for tag in arg.tags_of_type(tag_type):
+            facts(taggedo, (arg_name, tag))
+
+    for iname_name, iname in kernel.inames.items():
+        for tag in iname.tags_of_type(tag_type):
+            facts(taggedo, (iname_name, tag))
+
+    for insn in kernel.instructions:
+        for tag in insn.tags_of_type(tag_type):
+            facts(taggedo, (insn.id, tag))
+
+    return taggedo
+
+
+def get_producero(kernel):
+    producero = Relation()
+
+    for insn in kernel.instructions:
+        for var in insn.assignee_var_names():
+            facts(producero, (insn.id, var))
+
+    return producero
+
+
+def get_consumero(kernel):
+    consumero = Relation()
+
+    for insn in kernel.instructions:
+        for var in insn.read_dependency_names():
+            facts(consumero, (insn.id, var))
+
+    return consumero
+
+
+def get_withino(kernel):
+    withino = Relation()
+
+    for insn in kernel.instructions:
+        facts(withino, (insn.id, insn.within_inames))
+
+    return withino
+
+
+def get_reduce_insno(kernel):
+    reduce_insno = Relation()
+
+    for insn in kernel.instructions:
+        if insn.reduction_inames():
+            facts(reduce_insno, (insn.id,))
+
+    return reduce_insno
+
+
+def get_reduce_inameo(kernel):
+    from functools import reduce
+    reduce_inameo = Relation()
+
+    for iname in reduce(frozenset.union,
+                        (insn.reduction_inames()
+                         for insn in kernel.instructions),
+                        frozenset()):
+        facts(reduce_inameo, (iname,))
+
+    return reduce_inameo
+
+# vim: fdm=marker
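A hedged sketch of querying these relations (assumes minikanren is installed;
the kernel and instruction ids are made up):

    from kanren import run, var
    import loopy as lp
    from loopy.relations import get_consumero

    knl = lp.make_kernel(
        "{[i]: 0 <= i < 10}",
        """
        t[i] = 2*i       {id=write_t}
        out[i] = t[i]    {id=write_out}
        """)

    consumero = get_consumero(knl.default_entrypoint)
    insn = var()
    # which instructions read 't'?
    print(run(0, insn, consumero(insn, "t")))  # e.g. ('write_out',)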
+ """ + common_ancestors = (within1 & within2) | {""} + + innermost_parent = max(common_ancestors, + key=lambda k: tree.depth(k)) + iname1, = frozenset(tree.children(innermost_parent)) & within1 + iname2, = frozenset(tree.children(innermost_parent)) & within2 + + return iname1, iname2 + + +class V2SchedulerNotImplementedException(RuntimeError): + pass + + +def generate_loop_schedules_v2(kernel): + # from loopy.schedule.tools import get_loop_nest_tree + from loopy.schedule.tools import get_loop_tree + from functools import reduce + from pytools.graph import compute_topological_order + from loopy.kernel.data import ConcurrentTag, IlpBaseTag, VectorizeTag + + concurrent_inames = {iname for iname in kernel.all_inames() + if kernel.iname_tags_of_type(iname, ConcurrentTag)} + ilp_inames = {iname for iname in kernel.all_inames() + if kernel.iname_tags_of_type(iname, IlpBaseTag)} + vec_inames = {iname for iname in kernel.all_inames() + if kernel.iname_tags_of_type(iname, VectorizeTag)} + parallel_inames = (concurrent_inames - ilp_inames - vec_inames) + + # {{{ can v2 scheduler handle?? + + if any(len(insn.conflicts_with_groups) != 0 for insn in kernel.instructions): + raise V2SchedulerNotImplementedException("v2 scheduler cannot schedule" + " kernels with instruction having conflicts with groups.") + + if any(insn.priority != 0 for insn in kernel.instructions): + raise V2SchedulerNotImplementedException("v2 scheduler cannot schedule" + " kernels with instruction priorities set.") + + if kernel.linearization is not None: + # cannnot handle preschedule yet + raise V2SchedulerNotImplementedException("v2 scheduler cannot schedule" + " prescheduled kernels.") + + if ilp_inames or vec_inames: + raise V2SchedulerNotImplementedException("v2 scheduler cannot schedule" + " loops tagged with 'ilp'/'vec' as they are not guaranteed to" + " be single entry loops.") + + # }}} + + # loop_nest_tree = get_loop_nest_tree(kernel) + loop_nest_tree = get_loop_tree(kernel) + + # loop_inames: inames that are realized as loops. Concurrent inames aren't + # realized as a loop in the generated code for a loopy.TargetBase. + loop_inames = (reduce(frozenset.union, (insn.within_inames + for insn in kernel.instructions), + frozenset()) + - parallel_inames) + + # The idea here is to build a DAG, where nodes are schedule items and if + # there exists an edge from schedule item A to schedule item B in the DAG => + # B *must* come after A in the linearized result. + + dag = {} + + # LeaveLoop(i) *must* follow EnterLoop(i) + dag.update({EnterLoop(iname=iname): frozenset({LeaveLoop(iname=iname)}) + for iname in loop_inames}) + dag.update({LeaveLoop(iname=iname): frozenset() + for iname in loop_inames}) + dag.update({RunInstruction(insn_id=insn.id): frozenset() + for insn in kernel.instructions}) + + # {{{ add constraints imposed by the loop nesting + + for outer_loop in loop_nest_tree.nodes(): + if outer_loop == "": + continue + + for child in loop_nest_tree.children(outer_loop): + inner_loop = child + dag[EnterLoop(iname=outer_loop)] |= {EnterLoop(iname=inner_loop)} + dag[LeaveLoop(iname=inner_loop)] |= {LeaveLoop(iname=outer_loop)} + + # }}} + + # {{{ add deps. b/w schedule items coming from insn. 
+
+
+def generate_loop_schedules_internal(sched_state, debug=None):
     # allow_insn is set to False initially and after entering each loop
     # to give loops containing high-priority instructions a chance.
     kernel = sched_state.kernel
@@ -1095,7 +1252,7 @@ def _generate_loop_schedules_internal(

         if isinstance(next_preschedule_item, CallKernel):
             assert sched_state.within_subkernel is False
-            yield from _generate_loop_schedules_internal(
+            yield from generate_loop_schedules_internal(
                     sched_state.copy(
                         schedule=(*sched_state.schedule, next_preschedule_item),
                         preschedule=sched_state.preschedule[1:],
@@ -1108,7 +1265,7 @@ def _generate_loop_schedules_internal(
             assert sched_state.within_subkernel is True
             # Make sure all subkernel inames have finished.
             if sched_state.active_inames == sched_state.enclosing_subkernel_inames:
-                yield from _generate_loop_schedules_internal(
+                yield from generate_loop_schedules_internal(
                         sched_state.copy(
                             schedule=(*sched_state.schedule, next_preschedule_item),
                             preschedule=sched_state.preschedule[1:],
@@ -1127,7 +1284,7 @@ def _generate_loop_schedules_internal(
             if (
                     isinstance(next_preschedule_item, Barrier)
                     and next_preschedule_item.originating_insn_id is None):
-                yield from _generate_loop_schedules_internal(
+                yield from generate_loop_schedules_internal(
                         sched_state.copy(
                             schedule=(*sched_state.schedule, next_preschedule_item),
                             preschedule=sched_state.preschedule[1:]),
@@ -1304,7 +1461,7 @@ def insn_sort_key(insn_id):

         # Don't be eager about entering/leaving loops--if progress has been
         # made, revert to top of scheduler and see if more progress can be
         # made.
- for sub_sched in _generate_loop_schedules_internal( + for sub_sched in generate_loop_schedules_internal( new_sched_state, debug=debug): yield sub_sched @@ -1400,7 +1557,7 @@ def insn_sort_key(insn_id): if can_leave and not debug_mode: - for sub_sched in _generate_loop_schedules_internal( + for sub_sched in generate_loop_schedules_internal( sched_state.copy( schedule=( (*sched_state.schedule, @@ -1610,7 +1767,7 @@ def insn_sort_key(insn_id): key_iname), reverse=True): - for sub_sched in _generate_loop_schedules_internal( + for sub_sched in generate_loop_schedules_internal( sched_state.copy( schedule=( (*sched_state.schedule, EnterLoop(iname=iname))), @@ -2230,14 +2387,16 @@ def _generate_loop_schedules_inner( if debug_args is None: debug_args = {} + debug = ScheduleDebugger(**debug_args) + from loopy.kernel import KernelState if kernel.state not in (KernelState.PREPROCESSED, KernelState.LINEARIZED): raise LoopyError("cannot schedule a kernel that has not been " - "preprocessed") + "preprocessed") from loopy.schedule.tools import V2SchedulerNotImplementedError try: - gen_sched = _generate_loop_schedules_v2(kernel) + gen_sched = generate_loop_schedules_v2(kernel) yield _postprocess_schedule(kernel, callables_table, gen_sched) return @@ -2250,8 +2409,6 @@ def _generate_loop_schedules_inner( schedule_count = 0 - debug = ScheduleDebugger(**debug_args) - preschedule = (kernel.linearization if kernel.state == KernelState.LINEARIZED else ()) @@ -2347,7 +2504,7 @@ def print_longest_dead_end(): debug.debug_length = len(debug.longest_rejected_schedule) while True: try: - for _ in _generate_loop_schedules_internal( + for _ in generate_loop_schedules_internal( sched_state, debug=debug, **schedule_gen_kwargs): pass @@ -2358,7 +2515,7 @@ def print_longest_dead_end(): break try: - for gen_sched in _generate_loop_schedules_internal( + for gen_sched in generate_loop_schedules_internal( sched_state, debug=debug, **schedule_gen_kwargs): debug.stop() diff --git a/loopy/schedule/tools.py b/loopy/schedule/tools.py index 709e3705e..4c63a1d37 100644 --- a/loopy/schedule/tools.py +++ b/loopy/schedule/tools.py @@ -29,7 +29,7 @@ .. autoclass:: LoopTree .. autofunction:: separate_loop_nest -.. autofunction:: get_partial_loop_nest_tree +.. autofunction:: _get_partial_loop_nest_tree .. autofunction:: get_loop_tree References @@ -69,6 +69,7 @@ from typing_extensions import TypeAlias import islpy as isl + from pytools import memoize_method, memoize_on_first_arg from loopy.diagnostic import LoopyError @@ -954,7 +955,7 @@ def _get_parallel_inames(kernel: LoopKernel) -> AbstractSet[str]: return (concurrent_inames - ilp_inames - vec_inames) -def get_partial_loop_nest_tree(kernel: LoopKernel) -> LoopNestTree: +def _get_partial_loop_nest_tree(kernel: LoopKernel) -> LoopNestTree: """ Returns a tree representing the *kernel*'s loop nests. 
@@ -1077,7 +1078,7 @@ def get_loop_tree(kernel: LoopKernel) -> LoopTree:
     """
     from islpy import dim_type

-    tree = get_partial_loop_nest_tree(kernel)
+    tree = _get_partial_loop_nest_tree(kernel)
     iname_to_tree_node_id = (
         _get_iname_to_tree_node_id_from_partial_loop_nest_tree(tree))

diff --git a/loopy/symbolic.py b/loopy/symbolic.py
index 20ff55fea..136d0bf07 100644
--- a/loopy/symbolic.py
+++ b/loopy/symbolic.py
@@ -37,7 +37,10 @@
     Generic,
     Mapping,
     Sequence,
+    Tuple,
     TypeAlias,
+    TypeVar,
+    Union,
     cast,
 )
 from warnings import warn
@@ -217,6 +220,26 @@ def map_type_annotation(

         return type(expr)(expr.type, new_child)

+    def map_type_cast(self, expr, *args, **kwargs):
+        return self.rec(expr.child, *args, **kwargs)
+
+    def map_sub_array_ref(self, expr, *args, **kwargs):
+        return self.combine((
+            self.rec(expr.subscript, *args, **kwargs),
+            self.combine(tuple(
+                self.rec(idx, *args, **kwargs)
+                for idx in expr.swept_inames))))
+
     def map_sub_array_ref(self, expr, *args: P.args, **kwargs: P.kwargs):
         new_inames = self.rec(expr.swept_inames, *args, **kwargs)
         new_subscript = self.rec(expr.subscript, *args, **kwargs)
@@ -249,6 +272,16 @@ def is_expr_integer_valued(self, expr: Expression) -> bool:
     return True


+ArithmeticOrExpressionT = TypeVar(
+    "ArithmeticOrExpressionT",
+    ArithmeticExpression,
+    Expression)
+
+
 def flatten(expr: ArithmeticOrExpressionT) -> ArithmeticOrExpressionT:
     return cast("ArithmeticOrExpressionT", FlattenMapper()(expr))

@@ -2852,4 +2885,29 @@ def is_tuple_of_expressions_equal(

 # }}}

+
+def _is_isl_set_universe(isl_set: Union[isl.BasicSet, isl.Set]) -> bool:
+    if isinstance(isl_set, isl.BasicSet):
+        return isl_set.is_universe()
+    else:
+        assert isinstance(isl_set, isl.Set)
+        return isl_set.complement().is_empty()
+
+
+def pw_qpolynomial_to_expr(pw_qpoly: isl.PwQPolynomial
+                           ) -> Expression:
+    from pymbolic.primitives import If
+
+    result = 0
+
+    for bset, qpoly in reversed(pw_qpoly.get_pieces()):
+        if _is_isl_set_universe(bset):
+            result = qpolynomial_to_expr(qpoly)
+        else:
+            result = If(set_to_cond_expr(bset),
+                        qpolynomial_to_expr(qpoly),
+                        result)
+
+    return result
+
 # vim: foldmethod=marker
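A quick illustration of the new pw_qpolynomial_to_expr (a hedged sketch; it
assumes islpy's .card() as in the islpy documentation, and the output is shown
schematically):

    import islpy as isl
    from loopy.symbolic import pw_qpolynomial_to_expr

    # number of points in { [i] : 0 <= i < n }, as a piecewise quasi-polynomial
    card = isl.BasicSet("[n] -> { [i] : 0 <= i < n }").card()
    print(pw_qpolynomial_to_expr(card))  # roughly: If(n >= 1, n, 0)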
diff --git a/loopy/target/c/__init__.py b/loopy/target/c/__init__.py
index c170fb323..dfa46754b 100644
--- a/loopy/target/c/__init__.py
+++ b/loopy/target/c/__init__.py
@@ -265,8 +265,8 @@ def _preamble_generator(preamble_info, func_qualifier="inline"):
                 n = -n;
             }""")

         yield (f"07_{func.c_name}", f"""
-        inline {res_ctype} {func.c_name}({base_ctype} x, {exp_ctype} n) {{
+        {func_qualifier} {res_ctype} {func.c_name}({base_ctype} x, {exp_ctype} n) {{
           if (n == 0)
             return 1;
           {re.sub(r"^", 14*" ", signed_exponent_preamble, flags=re.M)}

diff --git a/loopy/target/cuda.py b/loopy/target/cuda.py
index 50d2ac7fe..f39042598 100644
--- a/loopy/target/cuda.py
+++ b/loopy/target/cuda.py
@@ -329,6 +329,12 @@ def known_callables(self):
         callables.update(get_cuda_callables())
         return callables

+    def symbol_manglers(self):
+        from loopy.target.opencl import opencl_symbol_mangler
+        return (
+            super().symbol_manglers() + [
+                opencl_symbol_mangler
+                ])
+
 # }}}

 # {{{ top-level codegen

diff --git a/loopy/target/opencl.py b/loopy/target/opencl.py
index 3fe951c4e..a1f6aca7c 100644
--- a/loopy/target/opencl.py
+++ b/loopy/target/opencl.py
@@ -457,6 +457,7 @@ def get_opencl_callables():
 # {{{ symbol mangler

 def opencl_symbol_mangler(kernel, name):
+    # Also used by loopy.target.cuda.CUDACASTBuilder.symbol_manglers
     # FIXME: should be more picky about exact names
     if name.startswith("FLT_"):
         return NumpyType(np.dtype(np.float32)), name
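Before diving into the new target below, a hedged sketch of the intended
end-to-end usage (requires pycuda and a CUDA-capable device; the kernel is
made up, and argument handling is as wired up in pycuda_execution.py further
down):

    import numpy as np
    import pycuda.autoinit  # noqa: F401
    import loopy as lp

    knl = lp.make_kernel(
        "{[i]: 0 <= i < n}",
        "out[i] = 2*a[i]",
        target=lp.PyCudaTarget())

    a = np.random.rand(128)
    evt, (out,) = knl(a=a)  # numpy in, numpy out; n inferred from a.shape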
+""" + +import numpy as np +import pymbolic.primitives as p +import genpy + +from loopy.target.cuda import (CudaTarget, CUDACASTBuilder, + ExpressionToCudaCExpressionMapper) +from loopy.target.python import PythonASTBuilderBase +from typing import Sequence, Tuple +from loopy.codegen import CodeGenerationState +from loopy.codegen.result import CodeGenerationResult +from loopy.target.c import CMathCallable +from loopy.diagnostic import LoopyError +from loopy.types import NumpyType +from loopy.codegen import CodeGenerationState +from loopy.codegen.result import CodeGenerationResult +from cgen import Generable + +import logging +logger = logging.getLogger(__name__) + + +# {{{ preamble generator + +def pycuda_preamble_generator(preamble_info): + has_complex = False + + for dtype in preamble_info.seen_dtypes: + if dtype.involves_complex(): + has_complex = True + + if has_complex: + yield ("03_include_complex_header", """ + #include + """) + +# }}} + + +# {{{ PyCudaCallable + +class PyCudaCallable(CMathCallable): + def with_types(self, arg_id_to_dtype, callables_table): + if any(dtype.is_complex() for dtype in arg_id_to_dtype.values()): + if self.name in ["abs", "real", "imag"]: + if not (set(arg_id_to_dtype) <= {0, -1}): + raise LoopyError(f"'{self.name}' takes only one argument") + if arg_id_to_dtype.get(0) is None: + # not specialized enough + return (self.copy(arg_id_to_dtype=arg_id_to_dtype), + callables_table) + else: + real_dtype = np.empty(0, + arg_id_to_dtype[0].numpy_dtype).real.dtype + arg_id_to_dtype = arg_id_to_dtype.copy() + arg_id_to_dtype[-1] = NumpyType(real_dtype) + return (self.copy(arg_id_to_dtype=arg_id_to_dtype, + name_in_target=self.name), + callables_table) + elif self.name in ["sqrt", "conj", + "sin", "cos", "tan", + "sinh", "cosh", "tanh", "exp", + "log", "log10"]: + if not (set(arg_id_to_dtype) <= {0, -1}): + raise LoopyError(f"'{self.name}' takes only one argument") + if arg_id_to_dtype.get(0) is None: + # not specialized enough + return (self.copy(arg_id_to_dtype=arg_id_to_dtype), + callables_table) + else: + arg_id_to_dtype = arg_id_to_dtype.copy() + arg_id_to_dtype[-1] = arg_id_to_dtype[0] + return (self.copy(arg_id_to_dtype=arg_id_to_dtype, + name_in_target=self.name), + callables_table) + else: + raise LoopyError(f"'{self.name}' does not take complex" + " arguments.") + else: + if self.name in ["real", "imag", "conj"]: + if arg_id_to_dtype.get(0): + raise NotImplementedError("'{self.name}' for real arguments" + ", not yet supported.") + return super().with_types(arg_id_to_dtype, callables_table) + + +def get_pycuda_callables(): + cmath_ids = ["abs", "acos", "asin", "atan", "cos", "cosh", "sin", + "sinh", "pow", "atan2", "tanh", "exp", "log", "log10", + "sqrt", "ceil", "floor", "max", "min", "fmax", "fmin", + "fabs", "tan", "erf", "erfc", "isnan", "real", "imag", + "conj"] + return {id_: PyCudaCallable(id_) for id_ in cmath_ids} + +# }}} + + +# {{{ expression mapper + +def _get_complex_tmplt_arg(dtype) -> str: + if dtype == np.complex128: + return "double" + elif dtype == np.complex64: + return "float" + else: + raise RuntimeError(f"unsupported complex type {dtype}.") + + +class ExpressionToPyCudaCExpressionMapper(ExpressionToCudaCExpressionMapper): + """ + .. note:: + + - PyCUDA (very conveniently) provides access to complex arithmetic + headers which is not the default in CUDA. To access such additional + features we introduce this mapper. 
+ """ + def wrap_in_typecast_lazy(self, actual_type_func, needed_dtype, s): + if needed_dtype.is_complex(): + return self.wrap_in_typecast(actual_type_func(), needed_dtype, s) + else: + return super().wrap_in_typecast_lazy(actual_type_func, + needed_dtype, s) + + def wrap_in_typecast(self, actual_type, needed_dtype, s): + if not actual_type.is_complex() and needed_dtype.is_complex(): + tmplt_arg = _get_complex_tmplt_arg(needed_dtype.numpy_dtype) + return p.Variable(f"pycuda::complex<{tmplt_arg}>")(s) + else: + return super().wrap_in_typecast_lazy(actual_type, + needed_dtype, s) + + def map_constant(self, expr, type_context): + if isinstance(expr, (complex, np.complexfloating)): + try: + dtype = expr.dtype + except AttributeError: + # (COMPLEX_GUESS_LOGIC) This made it through type 'guessing' in + # type inference, and it was concluded there (search for + # COMPLEX_GUESS_LOGIC in loopy.type_inference), that no + # accuracy was lost by using single precision. + dtype = np.complex64 + else: + tmplt_arg = _get_complex_tmplt_arg(dtype) + + return p.Variable(f"pycuda::complex<{tmplt_arg}>")(self.rec(expr.real, + type_context), + self.rec(expr.imag, + type_context)) + + return super().map_constant(expr, type_context) + +# }}} + + +# {{{ target + +class PyCudaTarget(CudaTarget): + """A code generation target that takes special advantage of :mod:`pycuda` + features such as run-time knowledge of the target device (to generate + warnings) and support for complex numbers. + """ + + # FIXME make prefixes conform to naming rules + # (see Reference: Loopy’s Model of a Kernel) + + host_program_name_prefix = "_lpy_host_" + host_program_name_suffix = "" + + def __init__(self, pycuda_module_name="_lpy_cuda"): + # import pycuda.tools import to populate the TYPE_REGISTRY + import pycuda.tools # noqa: F401 + super().__init__() + self.pycuda_module_name = pycuda_module_name + + # NB: Not including 'device', as that is handled specially here. 
+ hash_fields = CudaTarget.hash_fields + ( + "pycuda_module_name",) + comparison_fields = CudaTarget.comparison_fields + ( + "pycuda_module_name",) + + def get_host_ast_builder(self): + return PyCudaPythonASTBuilder(self) + + def get_device_ast_builder(self): + return PyCudaCASTBuilder(self) + + def get_kernel_executor_cache_key(self, **kwargs): + return (kwargs["entrypoint"],) + + def get_dtype_registry(self): + from pycuda.compyte.dtypes import TYPE_REGISTRY + return TYPE_REGISTRY + + def preprocess_translation_unit_for_passed_args(self, t_unit, epoint, + passed_args_dict): + + # {{{ ValueArgs -> GlobalArgs if passed as array shapes + + from loopy.kernel.data import ValueArg, GlobalArg + import pycuda.gpuarray as cu_np + + knl = t_unit[epoint] + new_args = [] + + for arg in knl.args: + if isinstance(arg, ValueArg): + if (arg.name in passed_args_dict + and isinstance(passed_args_dict[arg.name], cu_np.GPUArray) + and passed_args_dict[arg.name].shape == ()): + arg = GlobalArg(name=arg.name, dtype=arg.dtype, shape=(), + is_output=False, is_input=True) + + new_args.append(arg) + + knl = knl.copy(args=new_args) + + t_unit = t_unit.with_kernel(knl) + + # }}} + + return t_unit + + def get_kernel_executor(self, t_unit, **kwargs): + from loopy.target.pycuda_execution import PyCudaKernelExecutor + + epoint = kwargs.pop("entrypoint") + t_unit = self.preprocess_translation_unit_for_passed_args(t_unit, + epoint, + kwargs) + + return PyCudaKernelExecutor(t_unit, entrypoint=epoint) + + +class PyCudaWithPackedArgsTarget(PyCudaTarget): + + def get_kernel_executor(self, t_unit, **kwargs): + from loopy.target.pycuda_execution import PyCudaWithPackedArgsKernelExecutor + + epoint = kwargs.pop("entrypoint") + t_unit = self.preprocess_translation_unit_for_passed_args(t_unit, + epoint, + kwargs) + + return PyCudaWithPackedArgsKernelExecutor(t_unit, entrypoint=epoint) + + def get_host_ast_builder(self): + return PyCudaWithPackedArgsPythonASTBuilder(self) + + def get_device_ast_builder(self): + return PyCudaWithPackedArgsCASTBuilder(self) + +# }}} + + +# {{{ host ast builder + +class PyCudaPythonASTBuilder(PythonASTBuilderBase): + """A Python host AST builder for integration with PyCuda. + """ + + # {{{ code generation guts + + def get_function_definition( + self, codegen_state, codegen_result, + schedule_index: int, function_decl, function_body: genpy.Generable + ) -> genpy.Function: + assert schedule_index == 0 + + from loopy.schedule.tools import get_kernel_arg_info + kai = get_kernel_arg_info(codegen_state.kernel) + + args = ( + ["_lpy_cuda_functions"] + + [arg_name for arg_name in kai.passed_arg_names] + + ["wait_for=()", "allocator=None", "stream=None"]) + + from genpy import (For, Function, Suite, Return, Line, Statement as S) + return Function( + codegen_result.current_program(codegen_state).name, + args, + Suite([ + Line(), + ] + [ + Line(), + function_body, + Line(), + ] + ([ + For("_tv", "_global_temporaries", + # Free global temporaries. + # Zero-size temporaries allocate as None, tolerate that. 
+ S("if _tv is not None: _tv.free()")) + ] if self._get_global_temporaries(codegen_state) else [] + ) + [ + Line(), + Return("_lpy_evt"), + ])) + + def get_function_declaration( + self, codegen_state: CodeGenerationState, + codegen_result: CodeGenerationResult, schedule_index: int + ) -> Tuple[Sequence[Tuple[str, str]], genpy.Generable]: + # no such thing in Python + return [], None + + def _get_global_temporaries(self, codegen_state): + from loopy.kernel.data import AddressSpace + + return sorted( + (tv for tv in codegen_state.kernel.temporary_variables.values() + if tv.address_space == AddressSpace.GLOBAL), + key=lambda tv: tv.name) + + def get_temporary_decls(self, codegen_state, schedule_index): + from genpy import Assign, Comment, Line + + from pymbolic.mapper.stringifier import PREC_NONE + ecm = self.get_expression_to_code_mapper(codegen_state) + + global_temporaries = self._get_global_temporaries(codegen_state) + if not global_temporaries: + return [] + + allocated_var_names = [] + code_lines = [] + code_lines.append(Line()) + code_lines.append(Comment("{{{ allocate global temporaries")) + code_lines.append(Line()) + + for tv in global_temporaries: + if not tv.base_storage: + nbytes_str = ecm(tv.nbytes, PREC_NONE, "i") + allocated_var_names.append(tv.name) + code_lines.append(Assign(tv.name, + f"allocator({nbytes_str})")) + + code_lines.append(Assign("_global_temporaries", "[{tvs}]".format( + tvs=", ".join(tv for tv in allocated_var_names)))) + + code_lines.append(Line()) + code_lines.append(Comment("}}}")) + code_lines.append(Line()) + + return code_lines + + def get_kernel_call(self, + codegen_state, subkernel_name, grid, block): + from genpy import Suite, Assign, Line, Comment, Statement + from pymbolic.mapper.stringifier import PREC_NONE + + from loopy.schedule.tools import get_subkernel_arg_info + skai = get_subkernel_arg_info( + codegen_state.kernel, subkernel_name) + + # make grid/block a 3-tuple + grid = grid + (1,) * (3-len(grid)) + block = block + (1,) * (3-len(block)) + ecm = self.get_expression_to_code_mapper(codegen_state) + + grid_str = ecm(grid, prec=PREC_NONE, type_context="i") + block_str = ecm(block, prec=PREC_NONE, type_context="i") + + return Suite([ + Comment("{{{ launch %s" % subkernel_name), + Line(), + Statement("for _lpy_cu_evt in wait_for: _lpy_cu_evt.synchronize()"), + Line(), + Assign("_lpy_knl", f"_lpy_cuda_functions['{subkernel_name}']"), + Line(), + Statement("_lpy_knl.prepared_async_call(" + f"{grid_str}, {block_str}, " + "stream, " + f"{', '.join(arg_name for arg_name in skai.passed_names)}" + ")",), + Assign("_lpy_evt", "_lpy_cuda.Event().record(stream)"), + Assign("wait_for", "[_lpy_evt]"), + Line(), + Comment("}}}"), + Line(), + ]) + + # }}} + + +class PyCudaWithPackedArgsPythonASTBuilder(PyCudaPythonASTBuilder): + + def get_kernel_call(self, + codegen_state, subkernel_name, grid, block): + from genpy import Suite, Assign, Line, Comment, Statement + from pymbolic.mapper.stringifier import PREC_NONE + from loopy.kernel.data import ValueArg, ArrayArg + + from loopy.schedule.tools import get_subkernel_arg_info + kernel = codegen_state.kernel + skai = get_subkernel_arg_info(kernel, subkernel_name) + + # make grid/block a 3-tuple + grid = grid + (1,) * (3-len(grid)) + block = block + (1,) * (3-len(block)) + ecm = self.get_expression_to_code_mapper(codegen_state) + + grid_str = ecm(grid, prec=PREC_NONE, type_context="i") + block_str = ecm(block, prec=PREC_NONE, type_context="i") + + struct_format = [] + for arg_name in skai.passed_names: + if arg_name 
in codegen_state.kernel.all_inames(): + struct_format.append(kernel.index_dtype.numpy_dtype.char) + if kernel.index_dtype.numpy_dtype.itemsize < 8: + struct_format.append("x") + elif arg_name in codegen_state.kernel.temporary_variables: + struct_format.append("P") + else: + knl_arg = codegen_state.kernel.arg_dict[arg_name] + if isinstance(knl_arg, ValueArg): + struct_format.append(knl_arg.dtype.numpy_dtype.char) + if knl_arg.dtype.numpy_dtype.itemsize < 8: + struct_format.append("x") + else: + struct_format.append("P") + + def _arg_cast(arg_name: str) -> str: + if arg_name in skai.passed_inames: + return ("_lpy_np" + f".{codegen_state.kernel.index_dtype.numpy_dtype.name}" + f"({arg_name})") + elif arg_name in skai.passed_temporaries: + return f"_lpy_np.uintp(int({arg_name}))" + else: + knl_arg = codegen_state.kernel.arg_dict[arg_name] + if isinstance(knl_arg, ValueArg): + assert knl_arg.dtype is not None + return f"_lpy_np.{knl_arg.dtype.numpy_dtype.name}({arg_name})" + else: + assert isinstance(knl_arg, ArrayArg) + return f"_lpy_np.uintp(int({arg_name}))" + + return Suite([ + Comment("{{{ launch %s" % subkernel_name), + Line(), + Statement("for _lpy_cu_evt in wait_for: _lpy_cu_evt.synchronize()"), + Line(), + Assign("_lpy_knl", f"_lpy_cuda_functions['{subkernel_name}']"), + Line(), + Assign("_lpy_args_on_dev", f"allocator({len(skai.passed_names)*8})"), + Assign("_lpy_args_on_host", + f"_lpy_struct.pack('{''.join(struct_format)}'," + f"{','.join(_arg_cast(arg) for arg in skai.passed_names)})"), + Statement("_lpy_cuda.memcpy_htod(_lpy_args_on_dev, _lpy_args_on_host)"), + Line(), + Statement("_lpy_knl.prepared_async_call(" + f"{grid_str}, {block_str}, " + "stream, _lpy_args_on_dev)"), + Assign("_lpy_evt", "_lpy_cuda.Event().record(stream)"), + Assign("wait_for", "[_lpy_evt]"), + Line(), + Comment("}}}"), + Line(), + ]) + +# }}} + + +# {{{ device ast builder + +class PyCudaCASTBuilder(CUDACASTBuilder): + """A C device AST builder for integration with PyCUDA. 
+ """ + + # {{{ library + + def preamble_generators(self): + return ([pycuda_preamble_generator] + + super().preamble_generators()) + + @property + def known_callables(self): + callables = super().known_callables + callables.update(get_pycuda_callables()) + return callables + + # }}} + + def get_expression_to_c_expression_mapper(self, codegen_state): + return ExpressionToPyCudaCExpressionMapper(codegen_state) + + +class PyCudaWithPackedArgsCASTBuilder(PyCudaCASTBuilder): + def arg_struct_name(self, kernel_name: str): + return f"_lpy_{kernel_name}_packed_args" + + def get_function_declaration(self, codegen_state, codegen_result, + schedule_index): + from loopy.target.c import FunctionDeclarationWrapper + from cgen import FunctionDeclaration, Value, Pointer, Extern + from cgen.cuda import CudaGlobal, CudaDevice, CudaLaunchBounds + + kernel = codegen_state.kernel + + assert kernel.linearization is not None + name = codegen_result.current_program(codegen_state).name + arg_type = self.arg_struct_name(name) + + if self.target.fortran_abi: + name += "_" + + fdecl = FunctionDeclaration( + Value("void", name), + [Pointer(Value(arg_type, "_lpy_args"))]) + + if codegen_state.is_entrypoint: + fdecl = CudaGlobal(fdecl) + if self.target.extern_c: + fdecl = Extern("C", fdecl) + + from loopy.schedule import get_insn_ids_for_block_at + _, lsize = kernel.get_grid_sizes_for_insn_ids_as_exprs( + get_insn_ids_for_block_at(kernel.linearization, schedule_index), + codegen_state.callables_table) + + from loopy.symbolic import get_dependencies + if not get_dependencies(lsize): + # Sizes can't have parameter dependencies if they are + # to be used in static thread block size. + from pytools import product + nthreads = product(lsize) + + fdecl = CudaLaunchBounds(nthreads, fdecl) + + return [], FunctionDeclarationWrapper(fdecl) + else: + return [], CudaDevice(fdecl) + + def get_function_definition( + self, codegen_state: CodeGenerationState, + codegen_result: CodeGenerationResult, + schedule_index: int, function_decl: Generable, function_body: Generable + ) -> Generable: + from typing import cast + from loopy.target.c import generate_array_literal + from loopy.schedule import CallKernel + from loopy.schedule.tools import get_subkernel_arg_info + from loopy.kernel.data import ValueArg, AddressSpace + from cgen import (FunctionBody, + Module as Collection, + Initializer, + Line, Value, Pointer, Struct as GenerableStruct) + kernel = codegen_state.kernel + assert kernel.linearization is not None + assert isinstance(kernel.linearization[schedule_index], CallKernel) + kernel_name = (cast(CallKernel, + kernel.linearization[schedule_index]) + .kernel_name) + + skai = get_subkernel_arg_info(kernel, kernel_name) + + result = [] + + # We only need to write declarations for global variables with + # the first device program. `is_first_dev_prog` determines + # whether this is the first device program in the schedule. 
+ is_first_dev_prog = codegen_state.is_generating_device_code + for i in range(schedule_index): + if isinstance(kernel.linearization[i], CallKernel): + is_first_dev_prog = False + break + if is_first_dev_prog: + for tv in sorted( + kernel.temporary_variables.values(), + key=lambda key_tv: key_tv.name): + + if tv.address_space == AddressSpace.GLOBAL and ( + tv.initializer is not None): + assert tv.read_only + + decl = self.wrap_global_constant( + self.get_temporary_var_declarator(codegen_state, tv)) + + if tv.initializer is not None: + decl = Initializer(decl, generate_array_literal( + codegen_state, tv, tv.initializer)) + + result.append(decl) + + # {{{ declare+unpack the struct type + + struct_fields = [] + + for arg_name in skai.passed_names: + if arg_name in skai.passed_inames: + struct_fields.append( + Value(self.target.dtype_to_typename(kernel.index_dtype), + f"{arg_name}, __padding_{arg_name}")) + elif arg_name in skai.passed_temporaries: + tv = kernel.temporary_variables[arg_name] + struct_fields.append(Pointer( + Value(self.target.dtype_to_typename(tv.dtype), arg_name))) + else: + knl_arg = kernel.arg_dict[arg_name] + if isinstance(knl_arg, ValueArg): + struct_fields.append( + Value(self.target.dtype_to_typename(knl_arg.dtype), + f"{arg_name}, __padding_{arg_name}")) + else: + struct_fields.append( + Pointer(Value(self.target.dtype_to_typename(knl_arg.dtype), + arg_name))) + + function_body.insert(0, Line()) + for arg_name in skai.passed_names[::-1]: + function_body.insert(0, Initializer( + self.arg_to_cgen_declarator( + kernel, arg_name, + arg_name in kernel.get_written_variables()), + f"_lpy_args->{arg_name}" + )) + + # }}} + + fbody = FunctionBody(function_decl, function_body) + + return Collection([*result, + Line(), + GenerableStruct(self.arg_struct_name(kernel_name), + struct_fields), + Line(), + fbody]) + + +# }}} + +# vim: foldmethod=marker diff --git a/loopy/target/pycuda_execution.py b/loopy/target/pycuda_execution.py new file mode 100644 index 000000000..ec0c1834a --- /dev/null +++ b/loopy/target/pycuda_execution.py @@ -0,0 +1,394 @@ +__copyright__ = """ +Copyright (C) 2012 Andreas Kloeckner +Copyright (C) 2022 Kaushik Kulkarni +""" + +__license__ = """ +Permission is hereby granted, free of charge, to any person obtaining a copy +of this software and associated documentation files (the "Software"), to deal +in the Software without restriction, including without limitation the rights +to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +copies of the Software, and to permit persons to whom the Software is +furnished to do so, subject to the following conditions: + +The above copyright notice and this permission notice shall be included in +all copies or substantial portions of the Software. + +THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN +THE SOFTWARE. 
+""" + + +from typing import (Sequence, Tuple, Union, Callable, Any, Optional, + TYPE_CHECKING) +from dataclasses import dataclass + +import numpy as np +from immutables import Map + +from pytools import memoize_method +from pytools.codegen import Indentation, CodeGenerator + +from loopy.types import LoopyType +from loopy.typing import ExpressionT +from loopy.kernel import LoopKernel +from loopy.kernel.data import ArrayArg +from loopy.translation_unit import TranslationUnit +from loopy.schedule.tools import KernelArgInfo +from loopy.target.execution import ( + KernelExecutorBase, ExecutionWrapperGeneratorBase) +import logging +logger = logging.getLogger(__name__) + + +if TYPE_CHECKING: + import pycuda.driver as cuda + + +# {{{ invoker generation + +# /!\ This code runs in a namespace controlled by the user. +# Prefix all auxiliary variables with "_lpy". + + +class PyCudaExecutionWrapperGenerator(ExecutionWrapperGeneratorBase): + """ + Specialized form of the :class:`ExecutionWrapperGeneratorBase` for + pycuda execution + """ + + def __init__(self): + system_args = [ + "_lpy_cuda_functions", "stream=None", "allocator=None", "wait_for=()", + # ignored if options.no_numpy + "out_host=None" + ] + super().__init__(system_args) + + def python_dtype_str_inner(self, dtype): + from pycuda.tools import dtype_to_ctype + # Test for types built into numpy. dtype.isbuiltin does not work: + # https://github.com/numpy/numpy/issues/4317 + # Guided by https://numpy.org/doc/stable/reference/arrays.scalars.html + if issubclass(dtype.type, (np.bool_, np.number)): + name = dtype.name + if dtype.type == np.bool_: + name = "bool8" + return f"_lpy_np.dtype(_lpy_np.{name})" + else: + return ('_lpy_cuda_tools.get_or_register_dtype("%s")' + % dtype_to_ctype(dtype)) + + # {{{ handle non-numpy args + + def handle_non_numpy_arg(self, gen, arg): + gen("if isinstance(%s, _lpy_np.ndarray):" % arg.name) + with Indentation(gen): + gen("# retain originally passed array") + gen(f"_lpy_{arg.name}_np_input = {arg.name}") + gen("# synchronous, nothing to worry about") + gen("%s = _lpy_cuda_array.to_gpu_async(" + "%s, allocator=allocator, stream=stream)" + % (arg.name, arg.name)) + gen("_lpy_encountered_numpy = True") + gen("elif %s is not None:" % arg.name) + with Indentation(gen): + gen("_lpy_encountered_dev = True") + gen("_lpy_%s_np_input = None" % arg.name) + gen("else:") + with Indentation(gen): + gen("_lpy_%s_np_input = None" % arg.name) + + gen("") + + # }}} + + # {{{ handle allocation of unspecified arguments + + def handle_alloc( + self, gen: CodeGenerator, arg: ArrayArg, + strify: Callable[[Union[ExpressionT, Tuple[ExpressionT]]], str], + skip_arg_checks: bool) -> None: + """ + Handle allocation of non-specified arguments for pycuda execution + """ + from pymbolic import var + + from loopy.kernel.array import get_strides + strides = get_strides(arg) + num_axes = len(strides) + + assert arg.dtype is not None + itemsize = arg.dtype.numpy_dtype.itemsize + for i in range(num_axes): + gen("_lpy_ustrides_%d = %s" % (i, strify(strides[i]))) + + if not skip_arg_checks: + for i in range(num_axes): + gen("assert _lpy_ustrides_%d >= 0, " + "\"'%s' has negative stride in axis %d\"" + % (i, arg.name, i)) + + assert isinstance(arg.shape, tuple) + sym_ustrides = tuple( + var("_lpy_ustrides_%d" % i) + for i in range(num_axes)) + sym_shape = tuple(arg.shape[i] for i in range(num_axes)) + + size_expr = (sum(astrd*(alen-1) + for alen, astrd in zip(sym_shape, sym_ustrides)) + + 1) + + gen("_lpy_size = %s" % strify(size_expr)) + sym_strides 
= tuple(itemsize*s_i for s_i in sym_ustrides) + + dtype_name = self.python_dtype_str(gen, arg.dtype.numpy_dtype) + gen(f"{arg.name} = _lpy_cuda_array.GPUArray({strify(sym_shape)}, " + f"{dtype_name}, strides={strify(sym_strides)}, " + f"gpudata=allocator({strify(itemsize * var('_lpy_size'))}), " + "allocator=allocator)") + + for i in range(num_axes): + gen("del _lpy_ustrides_%d" % i) + gen("del _lpy_size") + gen("") + + # }}} + + def target_specific_preamble(self, gen): + """ + Add default pycuda imports to preamble + """ + gen.add_to_preamble("import numpy as _lpy_np") + gen.add_to_preamble("import pycuda.driver as _lpy_cuda") + gen.add_to_preamble("import pycuda.gpuarray as _lpy_cuda_array") + gen.add_to_preamble("import pycuda.tools as _lpy_cuda_tools") + gen.add_to_preamble("import struct as _lpy_struct") + from loopy.target.c.c_execution import DEF_EVEN_DIV_FUNCTION + gen.add_to_preamble(DEF_EVEN_DIV_FUNCTION) + + def initialize_system_args(self, gen): + """ + Initializes possibly empty system arguments + """ + gen("if allocator is None:") + with Indentation(gen): + gen("allocator = _lpy_cuda.mem_alloc") + gen("") + + # {{{ generate invocation + + def generate_invocation(self, gen: CodeGenerator, kernel: LoopKernel, + kai: KernelArgInfo, host_program_name: str, args: Sequence[str]) -> None: + arg_list = (["_lpy_cuda_functions"] + + list(args) + + ["stream=stream", "wait_for=wait_for", "allocator=allocator"]) + gen(f"_lpy_evt = {host_program_name}({', '.join(arg_list)})") + + # }}} + + # {{{ generate_output_handler + + def generate_output_handler(self, gen: CodeGenerator, + kernel: LoopKernel, kai: KernelArgInfo) -> None: + options = kernel.options + + if not options.no_numpy: + gen("if out_host is None and (_lpy_encountered_numpy " + "and not _lpy_encountered_dev):") + with Indentation(gen): + gen("out_host = True") + + for arg_name in kai.passed_arg_names: + arg = kernel.arg_dict[arg_name] + if arg.is_output: + np_name = "_lpy_%s_np_input" % arg.name + gen("if out_host or %s is not None:" % np_name) + with Indentation(gen): + gen("%s = %s.get(stream=stream, ary=%s)" + % (arg.name, arg.name, np_name)) + + gen("") + + if options.return_dict: + gen("return _lpy_evt, {%s}" + % ", ".join(f'"{arg_name}": {arg_name}' + for arg_name in kai.passed_arg_names + if kernel.arg_dict[arg_name].is_output)) + else: + out_names = [arg_name for arg_name in kai.passed_arg_names + if kernel.arg_dict[arg_name].is_output] + if out_names: + gen("return _lpy_evt, (%s,)" + % ", ".join(out_names)) + else: + gen("return _lpy_evt, ()") + + # }}} + + def generate_host_code(self, gen, codegen_result): + gen.add_to_preamble(codegen_result.host_code()) + + def get_arg_pass(self, arg): + return "%s.gpudata" % arg.name + +# }}} + + +@dataclass(frozen=True) +class _KernelInfo: + t_unit: TranslationUnit + cuda_functions: Map[str, "cuda.Function"] + invoker: Callable[..., Any] + + +# {{{ kernel executor + +class PyCudaKernelExecutor(KernelExecutorBase): + """ + An object connecting a kernel to a :mod:`pycuda` + for execution. + + .. automethod:: __init__ + .. 
automethod:: __call__ + """ + + def get_invoker_uncached(self, t_unit, entrypoint, codegen_result): + generator = PyCudaExecutionWrapperGenerator() + return generator(t_unit, entrypoint, codegen_result) + + def get_wrapper_generator(self): + return PyCudaExecutionWrapperGenerator() + + def _get_arg_dtypes(self, knl, subkernel_name): + from loopy.schedule.tools import get_subkernel_arg_info + from loopy.kernel.data import ValueArg + + skai = get_subkernel_arg_info(knl, subkernel_name) + arg_dtypes = [] + for arg in skai.passed_names: + if arg in skai.passed_inames: + arg_dtypes.append(knl.index_dtype.numpy_dtype) + elif arg in skai.passed_temporaries: + arg_dtypes.append("P") + else: + assert arg in knl.arg_dict + if isinstance(knl.arg_dict[arg], ValueArg): + arg_dtypes.append(knl.arg_dict[arg].dtype.numpy_dtype) + else: + # Array Arg + arg_dtypes.append("P") + + return arg_dtypes + + @memoize_method + def translation_unit_info(self, + arg_to_dtype: Optional[Map[str, LoopyType]] = None + ) -> _KernelInfo: + t_unit = self.get_typed_and_scheduled_translation_unit(self.entrypoint, + arg_to_dtype) + + # FIXME: now just need to add the types to the arguments + from loopy.codegen import generate_code_v2 + from loopy.target.execution import get_highlighted_code + codegen_result = generate_code_v2(t_unit) + + dev_code = codegen_result.device_code() + epoint_knl = t_unit[self.entrypoint] + + if t_unit[self.entrypoint].options.write_code: + #FIXME: redirect to "translation unit" level option as well. + output = dev_code + if self.t_unit[self.entrypoint].options.allow_terminal_colors: + output = get_highlighted_code(output) + + if epoint_knl.options.write_code is True: + print(output) + else: + with open(epoint_knl.options.write_code, "w") as outf: + outf.write(output) + + if epoint_knl.options.edit_code: + #FIXME: redirect to "translation unit" level option as well. + from pytools import invoke_editor + dev_code = invoke_editor(dev_code, "code.cu") + + from pycuda.compiler import SourceModule + from loopy.kernel.tools import get_subkernels + + #FIXME: redirect to "translation unit" level option as well. + src_module = SourceModule(dev_code, + options=epoint_knl.options.build_options) + + cuda_functions = Map({name: (src_module + .get_function(name) + .prepare(self._get_arg_dtypes(epoint_knl, name)) + ) + for name in get_subkernels(epoint_knl)}) + return _KernelInfo( + t_unit=t_unit, + cuda_functions=cuda_functions, + invoker=self.get_invoker(t_unit, self.entrypoint, codegen_result)) + + def __call__(self, *, + stream=None, allocator=None, wait_for=(), out_host=None, + **kwargs): + """ + :arg allocator: a callable that accepts a byte count and returns + an instance of :class:`pycuda.driver.DeviceAllocation`. Typically + one of :func:`pycuda.driver.mem_alloc` or + :meth:`pycuda.tools.DeviceMemoryPool.allocate`. + :arg wait_for: A sequence of :class:`pycuda.driver.Event` instances + for which to wait before launching the CUDA kernels. + :arg out_host: :class:`bool` + Decides whether output arguments (i.e. arguments + written by the kernel) are to be returned as + :mod:`numpy` arrays. *True* for yes, *False* for no. + + For the default value of *None*, if all (input) array + arguments are :mod:`numpy` arrays, defaults to + returning :mod:`numpy` arrays as well. + + :returns: ``(evt, output)`` where *evt* is a + :class:`pycuda.driver.Event` that is recorded right after the + kernel has been launched and output is a tuple of output arguments + (arguments that are written as part of the kernel). 
The order is + given by the order of kernel arguments. If this order is + unspecified (such as when kernel arguments are inferred + automatically), enable :attr:`loopy.Options.return_dict` to make + *output* a :class:`dict` instead, with keys of argument names and + values of the returned arrays. + """ + + if "entrypoint" in kwargs: + assert kwargs.pop("entrypoint") == self.entrypoint + from warnings import warn + warn("Obtained a redundant argument 'entrypoint'. This will" + " be an error in 2023.", DeprecationWarning, stacklevel=2) + + if __debug__: + self.check_for_required_array_arguments(kwargs.keys()) + + if self.packing_controller is not None: + kwargs = self.packing_controller(kwargs) + + translation_unit_info = self.translation_unit_info(self.arg_to_dtype(kwargs)) + + return translation_unit_info.invoker( + translation_unit_info.cuda_functions, stream, allocator, wait_for, + out_host, **kwargs) + + +class PyCudaWithPackedArgsKernelExecutor(PyCudaKernelExecutor): + + def _get_arg_dtypes(self, knl, subkernel_name): + return ["P"] + +# }}} + +# vim: foldmethod=marker diff --git a/loopy/tools.py b/loopy/tools.py index e9f9932b7..1deb13e36 100644 --- a/loopy/tools.py +++ b/loopy/tools.py @@ -27,6 +27,8 @@ import logging from functools import cached_property from sys import intern +from typing import FrozenSet, Generic, TypeVar, Iterator +from dataclasses import dataclass import numpy as np from constantdict import constantdict diff --git a/loopy/transform/domain.py b/loopy/transform/domain.py new file mode 100644 index 000000000..03bef1547 --- /dev/null +++ b/loopy/transform/domain.py @@ -0,0 +1,90 @@ +__copyright__ = "Copyright (C) 2023 Kaushik Kulkarni" + +__license__ = """ +Permission is hereby granted, free of charge, to any person obtaining a copy +of this software and associated documentation files (the "Software"), to deal +in the Software without restriction, including without limitation the rights +to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +copies of the Software, and to permit persons to whom the Software is +furnished to do so, subject to the following conditions: + +The above copyright notice and this permission notice shall be included in +all copies or substantial portions of the Software. + +THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN +THE SOFTWARE. +""" + +__doc__ = """ +.. currentmodule:: loopy + +.. autofunction:: decouple_domain +""" + +import islpy as isl + +from loopy.translation_unit import for_each_kernel +from loopy.kernel import LoopKernel +from loopy.diagnostic import LoopyError +from collections.abc import Collection + + +@for_each_kernel +def decouple_domain(kernel: LoopKernel, + inames: Collection[str], + parent_inames: Collection[str]) -> LoopKernel: + r""" + Returns a copy of *kernel* with altered domains. The home domain of + *inames* i.e. :math:`\mathcal{D}^{\text{home}}({\text{inames}})` is + replaced with two domains :math:`\mathcal{D}_1` and :math:`\mathcal{D}_2`. 
+    :math:`\mathcal{D}_1` is the domain with the dimensions corresponding to
+    *inames* projected out and :math:`\mathcal{D}_2` is the domain with all
+    dimensions other than the ones corresponding to *inames* projected out
+    (the *parent_inames* are retained as parameters).
+
+    .. note::
+
+        An error is raised unless all of the *inames* share the same home
+        domain in *kernel*.
+    """
+
+    if not inames:
+        raise LoopyError("No inames were provided to decouple into"
+                         " a different domain.")
+
+    hdi = kernel.get_home_domain_index(next(iter(inames)))
+    for iname in inames:
+        if kernel.get_home_domain_index(iname) != hdi:
+            raise LoopyError("inames are not a part of the same home domain.")
+
+    for parent_iname in parent_inames:
+        if parent_iname not in set(kernel.domains[hdi].get_var_dict()):
+            raise LoopyError(f"Parent iname '{parent_iname}' not a part of the"
+                             f" corresponding home domain '{kernel.domains[hdi]}'.")
+
+    all_dims = frozenset(kernel.domains[hdi].get_var_dict())
+    dom1 = kernel.domains[hdi]
+    dom2 = kernel.domains[hdi]
+
+    for iname in sorted(all_dims):
+        if iname in inames:
+            dt, pos = dom1.get_var_dict()[iname]
+            dom1 = dom1.project_out(dt, pos, 1)
+        elif iname in parent_inames:
+            dt, pos = dom2.get_var_dict()[iname]
+            if dt != isl.dim_type.param:
+                n_params = dom2.dim(isl.dim_type.param)
+                dom2 = dom2.move_dims(isl.dim_type.param, n_params, dt, pos, 1)
+        else:
+            dt, pos = dom2.get_var_dict()[iname]
+            dom2 = dom2.project_out(dt, pos, 1)
+
+    new_domains = kernel.domains[:]
+    new_domains[hdi] = dom1
+    new_domains.append(dom2)
+    kernel = kernel.copy(domains=new_domains)
+    return kernel
diff --git a/loopy/transform/iname.py b/loopy/transform/iname.py
index 0118795f5..46cd6a039 100644
--- a/loopy/transform/iname.py
+++ b/loopy/transform/iname.py
@@ -27,6 +27,7 @@
 from collections.abc import Collection, Iterable, Mapping, Sequence
 from typing import TYPE_CHECKING, Any
 
+from immutabledict import immutabledict
 from typing_extensions import TypeAlias
 
 import islpy as isl
@@ -1601,6 +1602,8 @@ def __init__(self, rule_mapping_context, inames, within):
 
     def get_cache_key(self, expr, expn_state):
        return (super().get_cache_key(expr, expn_state),
+                # immutabledict(self.iname_to_red_count),
+                # immutabledict(self.iname_to_nonsimultaneous_red_count),)
                 hash(frozenset(self.iname_to_red_count.items())),
                 hash(frozenset(self.iname_to_nonsimultaneous_red_count.items())),
                 )
diff --git a/loopy/transform/loop_fusion.py b/loopy/transform/loop_fusion.py
new file mode 100644
index 000000000..fcd4ab50b
--- /dev/null
+++ b/loopy/transform/loop_fusion.py
@@ -0,0 +1,822 @@
+__copyright__ = """
+Copyright (C) 2021 Kaushik Kulkarni
+"""
+
+__license__ = """
+Permission is hereby granted, free of charge, to any person obtaining a copy
+of this software and associated documentation files (the "Software"), to deal
+in the Software without restriction, including without limitation the rights
+to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
+copies of the Software, and to permit persons to whom the Software is
+furnished to do so, subject to the following conditions:
+
+The above copyright notice and this permission notice shall be included in
+all copies or substantial portions of the Software.
+
+THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
+IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
+FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
+AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
+LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
+OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
+THE SOFTWARE.
+"""
+
+from loopy.diagnostic import LoopyError
+from loopy.symbolic import RuleAwareIdentityMapper
+from loopy.kernel import LoopKernel
+from typing import FrozenSet, Mapping, Tuple, Dict, Set, Callable, Optional
+from functools import reduce
+from dataclasses import dataclass
+from pytools import memoize_on_first_arg
+
+__doc__ = """
+.. autofunction:: rename_inames_in_batch
+.. autofunction:: get_kennedy_unweighted_fusion_candidates
+"""
+
+
+# {{{ Loop Dependence graph class + builder
+
+
+@dataclass(frozen=True, eq=True)
+class LoopDependenceGraph:
+    """
+    .. attribute:: successors
+
+        A mapping from iname (``i``) to the collection of inames that can be
+        scheduled only after the loop corresponding to ``i`` has been exited.
+
+    .. attribute:: predecessors
+
+        A mapping from iname (``i``) to the collection of inames that must
+        have been exited before entering ``i``.
+
+    .. attribute:: is_infusible
+
+        A mapping from the edges in the loop dependence graph to their
+        fusibility criterion. An edge in this mapping is represented by a pair
+        of inames ``(iname_i, iname_j)`` such that the edge ``iname_i ->
+        iname_j`` is present in the graph.
+
+    .. note::
+
+        Both :attr:`successors` and :attr:`predecessors` are maintained to
+        reduce the complexity of graph primitive operations (like remove node,
+        add edge, etc.).
+    """
+    successors: Mapping[str, FrozenSet[str]]
+    predecessors: Mapping[str, FrozenSet[str]]
+    is_infusible: Mapping[Tuple[str, str], bool]
+
+    @classmethod
+    def new(cls, successors, is_infusible):
+        predecessors = {node: set()
+                        for node in successors}
+        for node, succs in successors.items():
+            for succ in succs:
+                predecessors[succ].add(node)
+
+        predecessors = {node: frozenset(preds)
+                        for node, preds in predecessors.items()}
+        successors = {node: frozenset(succs)
+                      for node, succs in successors.items()}
+
+        return LoopDependenceGraph(successors, predecessors, is_infusible)
+
+    def is_empty(self):
+        """
+        Returns *True* only if the loop dependence graph contains no nodes.
+        """
+        return (len(self.successors) == 0)
+
+    def get_loops_with_no_predecessors(self):
+        return {loop
+                for loop, preds in self.predecessors.items()
+                if len(preds) == 0}
+
+    def remove_nodes(self, nodes_to_remove):
+        """
+        Returns a copy of *self* after removing *nodes_to_remove* from the
+        graph. This routine adds the necessary edges after removing
+        *nodes_to_remove* to conserve the scheduling constraints present in
+        the graph.
+        """
+        # {{{ Step 1. Remove the nodes
+
+        new_successors = {node: succs
+                          for node, succs in self.successors.items()
+                          if node not in nodes_to_remove}
+        new_predecessors = {node: preds
+                            for node, preds in self.predecessors.items()
+                            if node not in nodes_to_remove}
+
+        new_is_infusible = {(from_, to): v
+                            for (from_, to), v in self.is_infusible.items()
+                            if (from_ not in nodes_to_remove
+                                and to not in nodes_to_remove)}
+
+        # }}}
+
+        # {{{ Step 2. Propagate dependencies
+
+        # For every node 'R' to be removed and every pair (S, P) such that
+        # 1. there exists an edge 'P' -> 'R' in the original graph, and
+        # 2. there exists an edge 'R' -> 'S' in the original graph,
+        # add the edge 'P' -> 'S' in the new graph.
+
+        for node_to_remove in nodes_to_remove:
+            for succ in (self.successors[node_to_remove]
+                         - nodes_to_remove):
+                new_predecessors[succ] = (new_predecessors[succ]
+                                          - frozenset([node_to_remove]))
+
+            for pred in (self.predecessors[node_to_remove]
+                         - nodes_to_remove):
+                new_successors[pred] = (new_successors[pred]
+                                        - frozenset([node_to_remove]))
+
+        # }}}
+
+        return LoopDependenceGraph(new_successors,
+                                   new_predecessors,
+                                   new_is_infusible)
+
+
+@dataclass
+class LoopDependenceGraphBuilder:
+    """
+    A mutable type to act as a helper to instantiate a
+    :class:`LoopDependenceGraph`.
+    """
+    _dag: Dict[str, Set[str]]
+    _is_infusible: Mapping[Tuple[str, str], bool]
+
+    @classmethod
+    def new(cls, candidates):
+        return LoopDependenceGraphBuilder({iname: set()
+                                           for iname in candidates},
+                                          {})
+
+    def add_edge(self, from_: str, to: str, is_infusible: bool):
+        self._dag[from_].add(to)
+        self._is_infusible[(from_, to)] = (is_infusible
+                                           or self._is_infusible.get((from_, to),
+                                                                     False))
+
+    def done(self):
+        """
+        Returns the built :class:`LoopDependenceGraph`.
+        """
+        return LoopDependenceGraph.new(self._dag, self._is_infusible)
+
+# }}}
+
+
+# {{{ _build_ldg
+
+@dataclass(frozen=True, eq=True, repr=True)
+class PreLDGNode:
+    """
+    A node in the graph representing the dependencies before building the
+    :class:`LoopDependenceGraph`.
+    """
+
+
+@dataclass(frozen=True, eq=True, repr=True)
+class CandidateLoop(PreLDGNode):
+    iname: str
+
+
+@dataclass(frozen=True, eq=True, repr=True)
+class NonCandidateLoop(PreLDGNode):
+    loop_nest: FrozenSet[str]
+
+
+@dataclass(frozen=True, eq=True, repr=True)
+class OuterLoopNestStatement(PreLDGNode):
+    insn_id: str
+
+
+def _remove_non_candidate_pre_ldg_nodes(kernel,
+                                        predecessors: Mapping[PreLDGNode,
+                                                              FrozenSet[PreLDGNode]],
+                                        successors: Mapping[PreLDGNode,
+                                                            FrozenSet[PreLDGNode]],
+                                        candidates: FrozenSet[str]):
+    """
+    Returns ``(new_predecessors, new_successors, infusible_edges)`` where
+    ``(new_predecessors, new_successors)`` is the graph describing the
+    dependencies between the candidate loops, obtained by removing the
+    instances of :class:`NonCandidateLoop` and :class:`OuterLoopNestStatement`
+    from the graph described by *predecessors*, *successors*.
+
+    New dependency edges are added in the new graph to preserve the transitive
+    dependencies that exist in the original graph.
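+
+    For instance (an illustrative sketch), for the statement-level graph::
+
+        CandidateLoop("i") -> NonCandidateLoop(frozenset({"k"})) -> CandidateLoop("j")
+
+    removing the non-candidate node yields ``{"i": frozenset({"j"})}`` as the
+    new successors, with the edge ``("i", "j")`` recorded as infusible.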
+ """ + # {{{ input validation + + assert set(predecessors) == set(successors) + assert all(isinstance(val, frozenset) for val in predecessors.values()) + assert all(isinstance(val, frozenset) for val in successors.values()) + + # }}} + + nodes_to_remove = {node + for node in predecessors + if isinstance(node, (NonCandidateLoop, + OuterLoopNestStatement)) + } + new_predecessors = predecessors.copy() + new_successors = successors.copy() + infusible_edges_in_statement_dag = set() + + for node_to_remove in nodes_to_remove: + for pred in new_predecessors[node_to_remove]: + new_successors[pred] = ((new_successors[pred] + - frozenset([node_to_remove])) + | new_successors[node_to_remove]) + + for succ in new_successors[node_to_remove]: + new_predecessors[succ] = ((new_predecessors[succ] + - frozenset([node_to_remove])) + | new_predecessors[node_to_remove]) + + for pred in new_predecessors[node_to_remove]: + for succ in new_successors[node_to_remove]: + # now mark the edge from pred -> succ infusible iff both 'pred' and + # 'succ' are *not* in insns_to_remove + if ((pred not in nodes_to_remove) and (succ not in nodes_to_remove)): + assert isinstance(pred, CandidateLoop) + assert isinstance(succ, CandidateLoop) + infusible_edges_in_statement_dag.add((pred.iname, succ.iname)) + + del new_predecessors[node_to_remove] + del new_successors[node_to_remove] + + return ({key.iname: frozenset({n.iname for n in value}) + for key, value in new_predecessors.items()}, + {key.iname: frozenset({n.iname for n in value}) + for key, value in new_successors.items()}, + infusible_edges_in_statement_dag) + + +def _get_ldg_nodes_from_loopy_insn(kernel, insn, candidates, non_candidates, + just_outer_loop_nest): + """ + Helper used in :func:`_build_ldg`. + + :arg just_outer_inames: A :class:`frozenset` of the loop nest that appears + just outer to the *candidates* in the partial loop nest tree. + """ + if (insn.within_inames | insn.reduction_inames()) & candidates: + # => the statement containing + return [CandidateLoop(candidate) + for candidate in ((insn.within_inames + | insn.reduction_inames()) + & candidates)] + else: + non_candidate = {loop_nest + for loop_nest in non_candidates + if (loop_nest & insn.within_inames)} + + if non_candidate: + non_candidate, = non_candidate + return [NonCandidateLoop(non_candidate)] + else: + assert ((insn.within_inames & just_outer_loop_nest) + or (insn.within_inames == just_outer_loop_nest)) + return [OuterLoopNestStatement(insn.id)] + + +@memoize_on_first_arg +def _compute_isinfusible_via_access_map(kernel, + insn_pred, candidate_pred, + insn_succ, candidate_succ, + outer_inames, + var): + """ + Returns *True* if the inames *candidate_pred* and *candidate_succ* are fused then + that might lead to a loop carried dependency for *var*. + + Helper used in :func:`_build_ldg`. + """ + import islpy as isl + from loopy.kernel.tools import get_insn_access_map + import pymbolic.primitives as prim + from loopy.symbolic import isl_set_from_expr + from loopy.diagnostic import UnableToDetermineAccessRangeError + + try: + amap_pred = get_insn_access_map(kernel, insn_pred, var) + amap_succ = get_insn_access_map(kernel, insn_succ, var) + except UnableToDetermineAccessRangeError: + # either predecessors or successors has a non-affine access i.e. 
+ # fallback to the safer option => infusible + return True + + amap_pred = amap_pred.project_out_except(outer_inames | {candidate_pred}, + [isl.dim_type.param, + isl.dim_type.in_]) + amap_succ = amap_succ.project_out_except(outer_inames | {candidate_succ}, + [isl.dim_type.param, + isl.dim_type.in_]) + + for outer_iname in sorted(outer_inames): + amap_pred = amap_pred.move_dims(dst_type=isl.dim_type.param, + dst_pos=amap_pred.dim(isl.dim_type.param), + src_type=isl.dim_type.in_, + src_pos=amap_pred.get_var_dict()[ + outer_iname][1], + n=1) + amap_succ = amap_succ.move_dims(dst_type=isl.dim_type.param, + dst_pos=amap_succ.dim(isl.dim_type.param), + src_type=isl.dim_type.in_, + src_pos=amap_succ.get_var_dict()[ + outer_iname][1], + n=1) + + # since both ranges denote the same variable they must be subscripted with + # the same number of indices. + assert amap_pred.dim(isl.dim_type.out) == amap_succ.dim(isl.dim_type.out) + assert amap_pred.dim(isl.dim_type.in_) == 1 + assert amap_succ.dim(isl.dim_type.in_) == 1 + + if amap_pred == amap_succ: + return False + + ndim = amap_pred.dim(isl.dim_type.out) + + # {{{ set the out dim names as `amap_a_dim0`, `amap_a_dim1`, ... + + for idim in range(ndim): + amap_pred = amap_pred.set_dim_name(isl.dim_type.out, + idim, + f"_lpy_amap_a_dim{idim}") + amap_succ = amap_succ.set_dim_name(isl.dim_type.out, + idim, + f"_lpy_amap_b_dim{idim}") + + # }}} + + # {{{ amap_pred -> set_pred, amap_succ -> set_succ + + amap_pred = amap_pred.move_dims(isl.dim_type.in_, + amap_pred.dim(isl.dim_type.in_), + isl.dim_type.out, + 0, amap_pred.dim(isl.dim_type.out)) + + amap_succ = amap_succ.move_dims(isl.dim_type.in_, + amap_succ.dim(isl.dim_type.in_), + isl.dim_type.out, + 0, amap_succ.dim(isl.dim_type.out)) + set_pred, set_succ = amap_pred.domain(), amap_succ.domain() + set_pred, set_succ = isl.align_two(set_pred, set_succ) + + # }}} + + # {{{ build the bset, both accesses access the same element + + accesses_same_index_set = isl.BasicSet.universe(set_pred.space) + for idim in range(ndim): + cnstrnt = isl.Constraint.eq_from_names(set_pred.space, + {f"_lpy_amap_a_dim{idim}": 1, + f"_lpy_amap_b_dim{idim}": -1}) + accesses_same_index_set = accesses_same_index_set.add_constraint(cnstrnt) + + # }}} + + candidates_not_equal = isl_set_from_expr(set_pred.space, + prim.Comparison( + prim.Variable(candidate_pred), + ">", + prim.Variable(candidate_succ))) + result = (not (set_pred + & set_succ + & accesses_same_index_set & candidates_not_equal).is_empty()) + + return result + + +def _build_ldg(kernel: LoopKernel, + candidates: FrozenSet[str], + outer_inames: FrozenSet[str]): + """ + Returns an instance of :class:`LoopDependenceGraph` needed while fusing + *candidates*. Invoked as a helper function in + :func:`get_kennedy_unweighted_fusion_candidates`. 
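+
+    The graph is assembled in two steps: first, the statement-level dependency
+    DAG is projected onto the candidate loops (dependencies routed through
+    non-candidate nodes become infusible edges); second, additional infusible
+    edges are added between candidates whose accesses to a common variable may
+    induce a loop-carried dependency (see
+    :func:`_compute_isinfusible_via_access_map`).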
+ """ + + from pytools.graph import compute_topological_order + + loop_nest_tree = _get_partial_loop_nest_tree_for_fusion(kernel) + + non_candidate_loop_nests = { + child_loop_nest + for child_loop_nest in loop_nest_tree.children(outer_inames) + if len(child_loop_nest & candidates) == 0} + + insns = reduce(frozenset.intersection, + (frozenset(kernel.iname_to_insns()[iname]) + for iname in outer_inames), + frozenset(kernel.id_to_insn)) + predecessors = {} + successors = {} + + for insn in insns: + for successor in _get_ldg_nodes_from_loopy_insn(kernel, + kernel.id_to_insn[insn], + candidates, + non_candidate_loop_nests, + outer_inames): + predecessors.setdefault(successor, set()) + successors.setdefault(successor, set()) + for dep in kernel.id_to_insn[insn].depends_on: + if ((kernel.id_to_insn[dep].within_inames & outer_inames) + != outer_inames): + # this is not an instruction in 'outer_inames' => bogus dep. + continue + for predecessor in _get_ldg_nodes_from_loopy_insn( + kernel, + kernel.id_to_insn[dep], + candidates, + non_candidate_loop_nests, + outer_inames): + if predecessor != successor: + predecessors.setdefault(successor, set()).add(predecessor) + successors.setdefault(predecessor, set()).add(successor) + + predecessors, successors, infusible_edges = ( + _remove_non_candidate_pre_ldg_nodes( + kernel, + {key: frozenset(value) + for key, value in predecessors.items()}, + {key: frozenset(value) + for key, value in successors.items()}, + candidates)) + del predecessors + + builder = LoopDependenceGraphBuilder.new(candidates) + + # Interpret the statement DAG as LDG + for pred, succs in successors.items(): + for succ in succs: + builder.add_edge(pred, succ, + (pred, succ) in infusible_edges) + + # {{{ add infusible edges to the LDG depending on memory deps. + + all_candidate_insns = reduce(frozenset.union, + (kernel.iname_to_insns()[iname] + for iname in candidates), + frozenset()) + + dep_inducing_vars = reduce(frozenset.union, + (frozenset(kernel + .id_to_insn[insn] + .assignee_var_names()) + for insn in all_candidate_insns), + frozenset()) + wmap = kernel.writer_map() + rmap = kernel.reader_map() + + topo_order = {el: i + for i, el in enumerate(compute_topological_order(successors))} + + for var in dep_inducing_vars: + for writer_id in (wmap.get(var, frozenset()) + & all_candidate_insns): + for access_id in ((rmap.get(var, frozenset()) + | wmap.get(var, frozenset())) + & all_candidate_insns): + if writer_id == access_id: + # no need to add self dependence + continue + + writer_candidate, = (kernel.id_to_insn[writer_id].within_inames + & candidates) + access_candidate, = (kernel.id_to_insn[access_id].within_inames + & candidates) + (pred_candidate, pred), (succ_candidate, succ) = sorted( + [(writer_candidate, writer_id), + (access_candidate, access_id)], + key=lambda x: topo_order[x[0]]) + + is_infusible = _compute_isinfusible_via_access_map(kernel, + pred, + pred_candidate, + succ, + succ_candidate, + outer_inames, + var) + + builder.add_edge(pred_candidate, succ_candidate, is_infusible) + + # }}} + + return builder.done() + +# }}} + + +def _fuse_sequential_loops_with_outer_loops(kernel: LoopKernel, + candidates: FrozenSet[str], + outer_inames: FrozenSet[str], + name_gen, prefix, force_infusible): + from collections import deque + ldg = _build_ldg(kernel, candidates, outer_inames) + + fused_chunks = {} + + while not ldg.is_empty(): + + # sorting to have a deterministic order. + # prefer 'deque' over list, as popping elements off the queue would be + # O(1). 
+
+        loops_with_no_preds = sorted(ldg.get_loops_with_no_predecessors())
+
+        queue = deque([loops_with_no_preds[0]])
+        for node in loops_with_no_preds[1:]:
+            if not force_infusible(node, loops_with_no_preds[0]):
+                queue.append(node)
+
+        loops_to_be_fused = set()
+        non_fusible_loops = set()
+        while queue:
+            next_loop_in_queue = queue.popleft()
+
+            if next_loop_in_queue in non_fusible_loops:
+                # had a non-fusible edge with an already scheduled loop.
+                # Sorry 'next_loop_in_queue', until next time :'(.
+                continue
+
+            if next_loop_in_queue in loops_to_be_fused:
+                # already fused, no need to fuse again ;)
+                continue
+
+            if not (ldg.predecessors[next_loop_in_queue] <= loops_to_be_fused):
+                # this loop still needs some other loops to be scheduled
+                # before it can be scheduled itself.
+                # Bye bye 'next_loop_in_queue' :'(, see you when all your
+                # predecessors have been scheduled.
+                continue
+
+            loops_to_be_fused.add(next_loop_in_queue)
+
+            for succ in ldg.successors[next_loop_in_queue]:
+                if (ldg.is_infusible.get((next_loop_in_queue, succ), False)
+                        or force_infusible(next_loop_in_queue, succ)):
+                    non_fusible_loops.add(succ)
+                else:
+                    queue.append(succ)
+
+        ldg = ldg.remove_nodes(loops_to_be_fused)
+        fused_chunks[name_gen(prefix)] = loops_to_be_fused
+
+    assert reduce(frozenset.union, fused_chunks.values(), frozenset()) == candidates
+    assert sum(len(val) for val in fused_chunks.values()) == len(candidates)
+
+    return fused_chunks
+
+
+class ReductionLoopInserter(RuleAwareIdentityMapper):
+    """
+    Main mapper used by :func:`_add_reduction_loops_in_partial_loop_nest_tree`.
+    """
+    def __init__(self, rule_mapping_context, tree):
+        super().__init__(rule_mapping_context)
+        self.tree = tree
+        from loopy.schedule.tools import (
+            _get_iname_to_tree_node_id_from_partial_loop_nest_tree)
+        self.iname_to_tree_node_id = (
+            _get_iname_to_tree_node_id_from_partial_loop_nest_tree(tree))
+
+    def map_reduction(self, expr, expn_state, *, outer_redn_inames=frozenset()):
+        redn_inames = frozenset(expr.inames)
+        iname_chain = (expn_state.instruction.within_inames
+                       | outer_redn_inames
+                       | redn_inames)
+        not_seen_inames = frozenset(iname for iname in iname_chain
+                                    if iname not in self.iname_to_tree_node_id)
+        seen_inames = iname_chain - not_seen_inames
+
+        # {{{ verbatim copied from loopy/schedule/tools.py
+
+        # from loopy.schedule.tools import (_pull_out_loop_nest,
+        #                                   _add_inner_loops)
+
+        from loopy.schedule.tools import (separate_loop_nest,
+                                          _add_inner_loops)
+
+        all_nests = {self.iname_to_tree_node_id[iname]
+                     for iname in seen_inames}
+
+        self.tree, outer_loop, inner_loop = separate_loop_nest(
+            self.tree, (all_nests | {frozenset()}), seen_inames)
+        if not_seen_inames:
+            # make 'not_seen_inames' nest inside the seen ones.
+            # example: if there is already a loop nesting "i,j,k" and the
+            # current iname chain is "i,j,l", the only way this is possible is
+            # if "l" nests within the "i,j"-loops.
+            self.tree = _add_inner_loops(self.tree, outer_loop, not_seen_inames)
+
+        # {{{ update iname to node id
+
+        for iname in outer_loop:
+            self.iname_to_tree_node_id = self.iname_to_tree_node_id.set(
+                iname, outer_loop)
+
+        if inner_loop is not None:
+            for iname in inner_loop:
+                self.iname_to_tree_node_id = self.iname_to_tree_node_id.set(
+                    iname, inner_loop)
+
+        for iname in not_seen_inames:
+            self.iname_to_tree_node_id = self.iname_to_tree_node_id.set(
+                iname, not_seen_inames)
+
+        # }}}
+
+        # }}}
+
+        assert not (outer_redn_inames & redn_inames)
+        return super().map_reduction(
+            expr,
+            expn_state,
+            outer_redn_inames=(outer_redn_inames | redn_inames))
+
+
+def _add_reduction_loops_in_partial_loop_nest_tree(kernel, tree):
+    """
+    Returns a partial loop nest tree with the loop nests corresponding to the
+    reduction inames added to *tree*.
+    """
+    from loopy.symbolic import SubstitutionRuleMappingContext
+    rule_mapping_context = SubstitutionRuleMappingContext(
+        kernel.substitutions, kernel.get_var_name_generator())
+    reduction_loop_inserter = ReductionLoopInserter(rule_mapping_context, tree)
+
+    def does_insn_have_reduce(kernel, insn, *args):
+        return bool(insn.reduction_inames())
+
+    reduction_loop_inserter.map_kernel(kernel,
+                                       within=does_insn_have_reduce,
+                                       map_tvs=False, map_args=False)
+    return reduction_loop_inserter.tree
+
+
+def _get_partial_loop_nest_tree_for_fusion(kernel):
+    from loopy.schedule.tools import _get_partial_loop_nest_tree
+    tree = _get_partial_loop_nest_tree(kernel)
+    tree = _add_reduction_loops_in_partial_loop_nest_tree(kernel, tree)
+    return tree
+
+
+def get_kennedy_unweighted_fusion_candidates(
+        kernel: LoopKernel,
+        candidates: FrozenSet[str],
+        *,
+        force_infusible: Optional[Callable[[str, str], bool]] = None,
+        prefix="ifused"):
+    """
+    Returns a mapping from new iname names to the batches of *candidates* they
+    are to replace, computed along the lines of Ken Kennedy's unweighted
+    loop-fusion algorithm. The returned mapping can be fed to
+    :func:`rename_inames_in_batch`.
+
+    :arg force_infusible: An optional callable that takes two candidate inames
+        and returns *True* if the pair must not be fused. Defaults to a
+        callable that always returns *False*.
+    :arg prefix: Prefix for the fused inames.
+    """
+    from loopy.kernel.data import ConcurrentTag
+    from loopy.schedule.tools import (
+        _get_iname_to_tree_node_id_from_partial_loop_nest_tree)
+    from collections.abc import Collection
+    assert not isinstance(candidates, str)
+    assert isinstance(candidates, Collection)
+    assert isinstance(kernel, LoopKernel)
+
+    candidates = frozenset(candidates)
+    vng = kernel.get_var_name_generator()
+    fused_chunks = {}
+
+    if force_infusible is None:
+        force_infusible = lambda x, y: False  # noqa: E731
+
+    # {{{ implementation restriction
+
+    # All of the candidates must be either "pure" reduction loops or
+    # pure-within_inames loops.
+    # Reason: otherwise _compute_isinfusible_via_access_map might return
+    # spurious results.
+    # One option is to simply perform 'realize_reduction' before running this
+    # algorithm, but that seems like an unnecessary cost to pay.
+    if any(candidates & insn.reduction_inames()
+           for insn in kernel.instructions):
+        if any(candidates & insn.within_inames
+               for insn in kernel.instructions):
+            raise NotImplementedError("Some candidates are reduction"
+                                      " inames while some of them are not."
+                                      " Such cases are not yet supported.")
+
+    # }}}
+
+    # {{{ handle concurrent inames
+
+    # filter out concurrent loops.
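+    # Concurrent inames carry no sequential ordering, so all candidates
+    # sharing the same concurrent tag (e.g. "g.0") are batched into a single
+    # fused iname up front, without consulting the loop dependence graph.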
+    all_concurrent_tags = reduce(frozenset.union,
+                                 (kernel.inames[iname].tags_of_type(ConcurrentTag)
+                                  for iname in candidates),
+                                 frozenset())
+
+    concurrent_tag_to_inames = {tag: set()
+                                for tag in all_concurrent_tags}
+
+    for iname in candidates:
+        if kernel.inames[iname].tags_of_type(ConcurrentTag):
+            # since ConcurrentTag is a UniqueTag there must be exactly one of
+            # it.
+            tag, = kernel.inames[iname].tags_of_type(ConcurrentTag)
+            concurrent_tag_to_inames[tag].add(iname)
+
+    for inames in concurrent_tag_to_inames.values():
+        fused_chunks[vng(prefix)] = inames
+        candidates = candidates - inames
+
+    # }}}
+
+    tree = _get_partial_loop_nest_tree_for_fusion(kernel)
+
+    iname_to_tree_node_id = (
+        _get_iname_to_tree_node_id_from_partial_loop_nest_tree(tree))
+
+    # {{{ sanity checks
+
+    _nest_tree_id_to_candidate = {}
+    _fusable_candidates = set()
+
+    for iname in candidates:
+        loop_nest_tree_node_id = iname_to_tree_node_id[iname]
+        if loop_nest_tree_node_id not in _nest_tree_id_to_candidate:
+            _nest_tree_id_to_candidate[loop_nest_tree_node_id] = iname
+            _fusable_candidates.add(iname)
+        else:
+            conflict_iname = _nest_tree_id_to_candidate[loop_nest_tree_node_id]
+            from warnings import warn
+            warn(f"'{iname}' and '{conflict_iname}' "
+                 "cannot be fused as they may be nested within one another.")
+
+    for iname in _fusable_candidates:
+        outer_loops = reduce(frozenset.union,
+                             tree.ancestors(iname_to_tree_node_id[iname]),
+                             frozenset())
+        if outer_loops & _fusable_candidates:
+            raise LoopyError(f"Cannot fuse '{iname}' with"
+                             f" '{outer_loops & _fusable_candidates}' as they"
+                             " may be nested within one another.")
+
+    del _nest_tree_id_to_candidate
+
+    # }}}
+
+    # just_outer_loop_nest: mapping from a loop nest to the candidates it
+    # contains
+    just_outer_loop_nest = {tree.parent(iname_to_tree_node_id[iname]): set()
+                            for iname in candidates}
+
+    for iname in _fusable_candidates:
+        just_outer_loop_nest[tree.parent(iname_to_tree_node_id[iname])].add(iname)
+
+    for outer_inames, inames in just_outer_loop_nest.items():
+        fused_chunks.update(_fuse_sequential_loops_with_outer_loops(
+            kernel, inames, outer_inames, vng, prefix, force_infusible))
+
+    return fused_chunks
+
+
+def rename_inames_in_batch(kernel, batches: Mapping[str, FrozenSet[str]]):
+    """
+    Returns a copy of *kernel* with inames renamed according to *batches*.
+
+    :arg kernel: An instance of :class:`loopy.LoopKernel`.
+    :arg batches: A mapping from ``new_iname`` to a :class:`frozenset` of
+        inames that are to be renamed to ``new_iname``.
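+
+    Example usage (mirroring the tests in ``test/test_loop_fusion.py``)::
+
+        knl = lp.rename_inames_in_batch(
+            knl,
+            lp.get_kennedy_unweighted_fusion_candidates(
+                knl, frozenset(["j0", "j1"])))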
+ """ + from loopy.transform.iname import rename_inames, remove_unused_inames + for new_iname, candidates in batches.items(): + # pylint:disable=unexpected-keyword-arg + kernel = rename_inames( + kernel, candidates, new_iname, + remove_newly_unused_inames=False + ) + + return remove_unused_inames(kernel, reduce(frozenset.union, + batches.values(), + frozenset())) + +# vim: foldmethod=marker diff --git a/loopy/transform/precompute.py b/loopy/transform/precompute.py index 406664126..9882c3a45 100644 --- a/loopy/transform/precompute.py +++ b/loopy/transform/precompute.py @@ -369,8 +369,9 @@ def map_kernel(self, kernel): dep_insn = kernel.id_to_insn[dep] if (frozenset(dep_insn.assignee_var_names()) & self.compute_read_variables): - self.compute_insn_depends_on.update( - insn.depends_on - excluded_insn_ids) + # self.compute_insn_depends_on.update( + # insn.depends_on - excluded_insn_ids) + self.compute_insn_depends_on.add(dep) new_insns.append(insn) diff --git a/loopy/transform/reduction.py b/loopy/transform/reduction.py new file mode 100644 index 000000000..8824dd1c1 --- /dev/null +++ b/loopy/transform/reduction.py @@ -0,0 +1,292 @@ +""" +.. currentmodule:: loopy + +.. autofunction:: hoist_invariant_multiplicative_terms_in_sum_reduction + +.. autofunction:: extract_multiplicative_terms_in_sum_reduction_as_subst +""" + +__copyright__ = "Copyright (C) 2022 Kaushik Kulkarni" + +__license__ = """ +Permission is hereby granted, free of charge, to any person obtaining a copy +of this software and associated documentation files (the "Software"), to deal +in the Software without restriction, including without limitation the rights +to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +copies of the Software, and to permit persons to whom the Software is +furnished to do so, subject to the following conditions: + +The above copyright notice and this permission notice shall be included in +all copies or substantial portions of the Software. + +THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN +THE SOFTWARE. +""" + +import pymbolic.primitives as p + +from typing import (FrozenSet, TypeVar, Callable, List, Tuple, Iterable, Union, Any, + Optional, Sequence) +from loopy.symbolic import IdentityMapper, Reduction, CombineMapper +from loopy.kernel import LoopKernel +from loopy.kernel.data import SubstitutionRule +from loopy.diagnostic import LoopyError + + +# {{{ partition (copied from more-itertools) + +Tpart = TypeVar("Tpart") + + +def partition(pred: Callable[[Tpart], bool], + iterable: Iterable[Tpart]) -> Tuple[List[Tpart], + List[Tpart]]: + """ + Use a predicate to partition entries into false entries and true + entries + """ + # Inspired from https://docs.python.org/3/library/itertools.html + # partition(is_odd, range(10)) --> 0 2 4 6 8 and 1 3 5 7 9 + from itertools import tee, filterfalse + t1, t2 = tee(iterable) + return list(filterfalse(pred, t1)), list(filter(pred, t2)) + +# }}} + + +# {{{ hoist_reduction_invariant_terms + +class EinsumTermsHoister(IdentityMapper): + """ + Mapper to hoist products out of a sum-reduction. + + .. 
attribute:: reduction_inames + + Inames of the reduction expressions to perform the hoisting. + """ + def __init__(self, reduction_inames: FrozenSet[str]): + super().__init__() + self.reduction_inames = reduction_inames + + # type-ignore-reason: super-class.map_reduction returns 'Any' + def map_reduction(self, expr: Reduction # type: ignore[override] + ) -> p.Expression: + if frozenset(expr.inames) != self.reduction_inames: + return super().map_reduction(expr) + + from loopy.library.reduction import SumReductionOperation + from loopy.symbolic import get_dependencies + if isinstance(expr.operation, SumReductionOperation): + if isinstance(expr.expr, p.Product): + from pymbolic.primitives import flattened_product + multiplicative_terms = (flattened_product(self.rec(expr.expr) + .children) + .children) + else: + multiplicative_terms = (expr.expr,) + + invariants, variants = partition(lambda x: (get_dependencies(x) + & self.reduction_inames), + multiplicative_terms) + if not variants: + # -> everything is invariant + return self.rec(expr.expr) * Reduction( + expr.operation, + inames=expr.inames, + expr=1, # FIXME: invalid dtype (not sure how?) + allow_simultaneous=expr.allow_simultaneous) + if not invariants: + # -> nothing to hoist + return Reduction( + expr.operation, + inames=expr.inames, + expr=self.rec(expr.expr), + allow_simultaneous=expr.allow_simultaneous) + + return p.Product(tuple(invariants)) * Reduction( + expr.operation, + inames=expr.inames, + expr=p.Product(tuple(variants)), + allow_simultaneous=expr.allow_simultaneous) + else: + return super().map_reduction(expr) + + +def hoist_invariant_multiplicative_terms_in_sum_reduction( + kernel: LoopKernel, + reduction_inames: Union[str, FrozenSet[str]], + within: Any = None +) -> LoopKernel: + """ + Hoists loop-invariant multiplicative terms in a sum-reduction expression. + + :arg reduction_inames: The inames over which reduction is performed that defines + the reduction expression that is to be transformed. + :arg within: A match expression understood by :func:`loopy.match.parse_match` + that specifies the instructions over which the transformation is to be + performed. + """ + from loopy.transform.instruction import map_instructions + if isinstance(reduction_inames, str): + reduction_inames = frozenset([reduction_inames]) + + if not (reduction_inames <= kernel.all_inames()): + raise ValueError(f"Some inames in '{reduction_inames}' not a part of" + " the kernel.") + + term_hoister = EinsumTermsHoister(reduction_inames) + + return map_instructions(kernel, + insn_match=within, + f=lambda x: x.with_transformed_expressions(term_hoister) + ) + +# }}} + + +# {{{ extract_multiplicative_terms_in_sum_reduction_as_subst + +class ContainsSumReduction(CombineMapper): + """ + Returns *True* only if the mapper maps over an expression containing a + SumReduction operation. + """ + def combine(self, values: Iterable[bool]) -> bool: + return any(values) + + # type-ignore-reason: super-class.map_reduction returns 'Any' + def map_reduction(self, expr: Reduction) -> bool: # type: ignore[override] + from loopy.library.reduction import SumReductionOperation + return (isinstance(expr.operation, SumReductionOperation) + or self.rec(expr.expr)) + + def map_variable(self, expr: p.Variable) -> bool: + return False + + def map_algebraic_leaf(self, expr: Any) -> bool: + return False + + +class MultiplicativeTermReplacer(IdentityMapper): + """ + Primary mapper of + :func:`extract_multiplicative_terms_in_sum_reduction_as_subst`. 
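+
+    Records the substitution rule assembled from the terms selected by
+    *terms_filter* in :attr:`collected_subst_rule` (*None* until a
+    sum-reduction has been visited).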
+ """ + def __init__(self, + *, + terms_filter: Callable[[p.Expression], bool], + subst_name: str, + subst_arguments: Tuple[str, ...]) -> None: + self.subst_name = subst_name + self.subst_arguments = subst_arguments + self.terms_filter = terms_filter + super().__init__() + + # mutable state to record the expression collected by the terms_filter + self.collected_subst_rule: Optional[SubstitutionRule] = None + + # type-ignore-reason: super-class.map_reduction returns 'Any' + def map_reduction(self, expr: Reduction) -> Reduction: # type: ignore[override] + from loopy.library.reduction import SumReductionOperation + from loopy.symbolic import SubstitutionMapper + if isinstance(expr.operation, SumReductionOperation): + if self.collected_subst_rule is not None: + # => there was already a sum-reduction operation -> raise + raise ValueError("Multiple sum reduction expressions found -> not" + " allowed.") + + if isinstance(expr.expr, p.Product): + from pymbolic.primitives import flattened_product + terms = flattened_product(expr.expr.children).children + else: + terms = (expr.expr,) + + unfiltered_terms, filtered_terms = partition(self.terms_filter, terms) + submap = SubstitutionMapper({ + argument_expr: p.Variable(f"arg{i}") + for i, argument_expr in enumerate(self.subst_arguments)}.get) + self.collected_subst_rule = SubstitutionRule( + name=self.subst_name, + arguments=tuple(f"arg{i}" for i in range(len(self.subst_arguments))), + expression=submap(p.Product(tuple(filtered_terms)) + if filtered_terms + else 1) + ) + return Reduction( + expr.operation, + expr.inames, + p.Product((p.Variable(self.subst_name)(*self.subst_arguments), + *unfiltered_terms)), + expr.allow_simultaneous) + else: + return super().map_reduction(expr) + + +def extract_multiplicative_terms_in_sum_reduction_as_subst( + kernel: LoopKernel, + within: Any, + subst_name: str, + arguments: Sequence[p.Expression], + terms_filter: Callable[[p.Expression], bool], +) -> LoopKernel: + """ + Returns a copy of *kernel* with a new substitution named *subst_name* and + *arguments* as arguments for the aggregated multiplicative terms in a + sum-reduction expression. + + :arg within: A match expression understood by :func:`loopy.match.parse_match` + to specify the instructions over which the transformation is to be + performed. + :arg terms_filter: A callable to filter which terms of the sum-reduction + comprise the body of substitution rule. + :arg arguments: The sub-expressions of the product of the filtered terms that + form the arguments of the extract substitution rule in the same order. + + .. note:: + + A ``LoopyError`` is raised if none or more than 1 sum-reduction expression + appear in *within*. 
+ """ + from loopy.match import parse_match + within = parse_match(within) + + matched_insns = [ + insn + for insn in kernel.instructions + if within(kernel, insn) and ContainsSumReduction()((insn.expression, + tuple(insn.predicates))) + ] + + if len(matched_insns) == 0: + raise LoopyError(f"No instructions found matching '{within}'" + " with sum-reductions found.") + if len(matched_insns) > 1: + raise LoopyError(f"More than one instruction found matching '{within}'" + " with sum-reductions found -> not allowed.") + + insn, = matched_insns + replacer = MultiplicativeTermReplacer(subst_name=subst_name, + subst_arguments=tuple(arguments), + terms_filter=terms_filter) + new_insn = insn.with_transformed_expressions(replacer) + new_rule = replacer.collected_subst_rule + new_substitutions = dict(kernel.substitutions).copy() + if subst_name in new_substitutions: + raise LoopyError(f"Kernel '{kernel.name}' already contains a substitution" + " rule named '{subst_name}'.") + assert new_rule is not None + new_substitutions[subst_name] = new_rule + + return kernel.copy(instructions=[new_insn if insn.id == new_insn.id else insn + for insn in kernel.instructions], + substitutions=new_substitutions) + +# }}} + + +# vim: foldmethod=marker diff --git a/loopy/transform/reindex.py b/loopy/transform/reindex.py new file mode 100644 index 000000000..3d4e7c562 --- /dev/null +++ b/loopy/transform/reindex.py @@ -0,0 +1,329 @@ +""" +.. currentmodule:: loopy + +.. autofunction:: reindex_temporary_using_seghir_loechner_scheme +""" + +__copyright__ = "Copyright (C) 2022 Kaushik Kulkarni" + +__license__ = """ +Permission is hereby granted, free of charge, to any person obtaining a copy +of this software and associated documentation files (the "Software"), to deal +in the Software without restriction, including without limitation the rights +to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +copies of the Software, and to permit persons to whom the Software is +furnished to do so, subject to the following conditions: + +The above copyright notice and this permission notice shall be included in +all copies or substantial portions of the Software. + +THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN +THE SOFTWARE. +""" + + +import islpy as isl +from typing import Union, Iterable, Tuple +# from loopy.typing import ExpressionT +from loopy.typing import Expression +from loopy.kernel import LoopKernel +from loopy.diagnostic import LoopyError +from loopy.symbolic import CombineMapper +from loopy.kernel.instruction import (MultiAssignmentBase, + CInstruction, BarrierInstruction) +from loopy.symbolic import RuleAwareIdentityMapper + + +ISLMapT = Union[isl.BasicMap, isl.Map] +ISLSetT = Union[isl.BasicSet, isl.Set] + + +def _add_prime_to_dim_names(isl_map: ISLMapT, + dts: Iterable[isl.dim_type]) -> ISLMapT: + """ + Returns a copy of *isl_map* with dims of types *dts* having their names + suffixed with an apostrophe (``'``). + + .. testsetup:: + + >>> import islpy as isl + >>> from loopy.transform.reindex import _add_prime_to_dim_names + + .. 
doctest:: + + >>> amap = isl.Map("{[i]->[j=2i]}") + >>> _add_prime_to_dim_names(amap, [isl.dim_type.in_, isl.dim_type.out]) + Map("{ [i'] -> [j' = 2i'] }") + """ + for dt in dts: + for idim in range(isl_map.dim(dt)): + old_name = isl_map.get_dim_name(dt, idim) + new_name = f"{old_name}'" + isl_map = isl_map.set_dim_name(dt, idim, new_name) + + return isl_map + + +def _get_seghir_loechner_reindexing_from_range(access_range: ISLSetT + ) -> Tuple[isl.PwQPolynomial, + isl.PwQPolynomial]: + """ + Returns ``(reindex_map, new_shape)``, where, + + * ``reindex_map`` is a quasi-polynomial of the form ``[i1, .., in] -> {f(i1, + .., in)}`` representing that an array indexed via the subscripts + ``[i1, ..,in]`` should be re-indexed into a 1-dimensional array as + ``f(i1, .., in)``. + * ``new_shape`` is a quasi-polynomial corresponding to the shape of the + re-indexed 1-dimensional array. + """ + + # {{{ create amap: an ISL map which is an identity map from access_map's range + + amap = isl.BasicMap.identity( + access_range + .space + .add_dims(isl.dim_type.in_, access_range.dim(isl.dim_type.out))) + + # set amap's dim names + for idim in range(amap.dim(isl.dim_type.in_)): + amap = amap.set_dim_name(isl.dim_type.in_, idim, + f"_lpy_in_{idim}") + amap = amap.set_dim_name(isl.dim_type.out, idim, + f"_lpy_out_{idim}") + + amap = amap.intersect_domain(access_range) + + # }}} + + n_in = amap.dim(isl.dim_type.out) + n_out = amap.dim(isl.dim_type.out) + + amap_lexmin = amap.lexmin() + primed_amap_lexmin = _add_prime_to_dim_names(amap_lexmin, [isl.dim_type.in_, + isl.dim_type.out]) + + lex_lt_map = isl.Map.lex_lt_map(primed_amap_lexmin, amap_lexmin) + + # make the lexmin map parametric in terms of it's previous access expressions. + lex_lt_set = (lex_lt_map + .move_dims(isl.dim_type.param, 0, isl.dim_type.out, 0, n_in) + .domain()) + + # {{{ initialize amap_to_count + + amap_to_count = _add_prime_to_dim_names(amap, [isl.dim_type.in_]) + amap_to_count = amap_to_count.insert_dims(isl.dim_type.param, 0, n_in) + + for idim in range(n_in): + amap_to_count = amap_to_count.set_dim_name( + isl.dim_type.param, idim, + amap.get_dim_name(isl.dim_type.in_, idim)) + + amap_to_count = amap_to_count.intersect_domain(lex_lt_set) + + # }}} + + result = amap_to_count.range().card() + + # {{{ simplify 'result' by gisting with 'access_range' + + aligned_access_range = access_range.move_dims(isl.dim_type.param, 0, + isl.dim_type.set, 0, n_out) + + for idim in range(result.dim(isl.dim_type.param)): + aligned_access_range = ( + aligned_access_range + .set_dim_name(isl.dim_type.param, idim, + result.space.get_dim_name(isl.dim_type.param, + idim))) + + result = result.gist_params(aligned_access_range.params()) + + # }}} + + return result, access_range.card() + + +class _IndexCollector(CombineMapper): + """ + A mapper that collects all instances of + :class:`pymbolic.primitives.Subscript` accessing :attr:`var_name`. 
+ """ + def __init__(self, var_name): + super().__init__() + self.var_name = var_name + + def combine(self, values): + from functools import reduce + return reduce(frozenset.union, values, frozenset()) + + def map_subscript(self, expr): + if expr.aggregate.name == self.var_name: + return frozenset([expr]) | super().map_subscript(expr) + else: + return super().map_subscript(expr) + + def map_constant(self, expr): + return frozenset() + + map_variable = map_constant + map_function_symbol = map_constant + map_tagged_variable = map_constant + map_type_cast = map_constant + map_nan = map_constant + + +class ReindexingApplier(RuleAwareIdentityMapper): + def __init__(self, rule_mapping_context, + var_to_reindex, + reindexed_var_name, + new_index_expr, + index_names): + + super().__init__(rule_mapping_context) + + self.var_to_reindex = var_to_reindex + self.reindexed_var_name = reindexed_var_name + self.new_index_expr = new_index_expr + self.index_names = index_names + + def map_subscript(self, expr, expn_state): + if expr.aggregate.name != self.var_to_reindex: + return super().map_subscript(expr, expn_state) + + from loopy.symbolic import SubstitutionMapper + from pymbolic.mapper.substitutor import make_subst_func + from pymbolic.primitives import Subscript, Variable + + rec_indices = tuple(self.rec(idx, expn_state) for idx in expr.index_tuple) + + assert len(self.index_names) == len(rec_indices) + subst_func = make_subst_func({idx_name: rec_idx + for idx_name, rec_idx in zip(self.index_names, + rec_indices)}) + + return SubstitutionMapper(subst_func)( + Subscript(Variable(self.reindexed_var_name), + self.new_index_expr) + ) + + +def reindex_temporary_using_seghir_loechner_scheme(kernel: LoopKernel, + var_name: str, + ) -> LoopKernel: + """ + Returns a kernel with expressions of the form ``var_name[i1, .., in]`` + replaced with ``var_name_reindexed[f(i1, .., in)]`` where ``f`` is a + quasi-polynomial as outlined in [Seghir_2006]_. + """ + from loopy.transform.subst import expand_subst + from loopy.symbolic import (BatchedAccessMapMapper, pw_qpolynomial_to_expr, + SubstitutionRuleMappingContext) + + if var_name not in kernel.temporary_variables: + raise LoopyError(f"'{var_name}' not in temporary variable in kernel" + f" '{kernel.name}'.") + + # {{{ compute the access_range of *var_name* in *kernel* + + subst_kernel = expand_subst(kernel) + access_map_recorder = BatchedAccessMapMapper( + subst_kernel, + frozenset([var_name])) + + # access_exprs: Tuple[ExpressionT, ...] + access_exprs: Tuple[Expression, ...] 
+    for insn in subst_kernel.instructions:
+        if var_name in insn.dependency_names():
+            if isinstance(insn, MultiAssignmentBase):
+                access_exprs = (insn.assignees,
+                                insn.expression,
+                                tuple(insn.predicates))
+            elif isinstance(insn, (CInstruction, BarrierInstruction)):
+                access_exprs = tuple(insn.predicates)
+            else:
+                raise NotImplementedError(type(insn))
+
+            access_map_recorder(access_exprs, insn.within_inames)
+
+    vng = kernel.get_var_name_generator()
+    new_var_name = vng(var_name+"_reindexed")
+
+    access_range = access_map_recorder.get_access_range(var_name)
+
+    del subst_kernel
+    del access_map_recorder
+
+    # }}}
+
+    subst, new_shape = _get_seghir_loechner_reindexing_from_range(
+        access_range)
+
+    # {{{ simplify new_shape with the assumptions from kernel
+
+    new_shape = new_shape.gist_params(kernel.assumptions)
+
+    # }}}
+
+    # {{{ update kernel.temporary_variables
+
+    new_shape = new_shape.drop_unused_params()
+
+    new_temps = dict(kernel.temporary_variables)
+    new_temps[new_var_name] = new_temps.pop(var_name).copy(
+        name=new_var_name,
+        shape=pw_qpolynomial_to_expr(new_shape),
+        strides=None,
+        dim_tags=None,
+        dim_names=None,
+    )
+
+    kernel = kernel.copy(temporary_variables=new_temps)
+
+    # }}}
+
+    # {{{ perform the substitution, i.e. reindex the accesses
+
+    subst_expr = pw_qpolynomial_to_expr(subst)
+    subst_dim_names = tuple(
+        subst.space.get_dim_name(isl.dim_type.param, idim)
+        for idim in range(access_range.dim(isl.dim_type.out)))
+    assert not (set(subst_dim_names) & kernel.all_variable_names())
+
+    rule_mapping_context = SubstitutionRuleMappingContext(kernel.substitutions,
+                                                          vng)
+    reindexing_mapper = ReindexingApplier(rule_mapping_context,
+                                          var_name, new_var_name,
+                                          subst_expr, subst_dim_names)
+
+    def _does_access_var_name(kernel, insn, *args):
+        return var_name in insn.dependency_names()
+
+    kernel = reindexing_mapper.map_kernel(kernel,
+                                          within=_does_access_var_name,
+                                          map_args=False,
+                                          map_tvs=False)
+    kernel = rule_mapping_context.finish_kernel(kernel)
+
+    # }}}
+
+    # Note: Distributing a piece of code that depends on loopy and that
+    # conditionally or unconditionally calls this routine does *NOT* make that
+    # code a derivative work under the GPLv2 (under which barvinok is
+    # distributed): as per section 0 of the GPLv2, a derivative is defined as
+    # "a work containing the Program or a portion of it, either verbatim or
+    # with modifications and/or translated into another language."
+    #
+    # Loopy does *NOT* contain any portion of the barvinok library in its
+    # source code.
+ + return kernel + +# vim: fdm=marker diff --git a/loopy/transform/subst.py b/loopy/transform/subst.py index 3ca981aa0..ff2036b53 100644 --- a/loopy/transform/subst.py +++ b/loopy/transform/subst.py @@ -26,7 +26,7 @@ import logging from pymbolic import var -from pytools import ImmutableRecord +from pytools import ImmutableRecord, memoize_on_first_arg from loopy.diagnostic import LoopyError from loopy.kernel.function_interface import CallableKernel, ScalarCallable @@ -511,6 +511,7 @@ def _accesses_lhs(kernel, insn, *args): # {{{ expand_subst @for_each_kernel +@memoize_on_first_arg def expand_subst(kernel, within=None): """ Returns an instance of :class:`loopy.LoopKernel` with the substitutions diff --git a/requirements.txt b/requirements.txt index 0751539e6..dffe232bf 100644 --- a/requirements.txt +++ b/requirements.txt @@ -13,3 +13,5 @@ ply>=3.6 # Optional, for testing special math function scipy +# Optional, kanren-style relation helpers +git+https://github.com/pythological/kanren.git#egg=miniKanren diff --git a/test/test_loop_fusion.py b/test/test_loop_fusion.py new file mode 100644 index 000000000..678718295 --- /dev/null +++ b/test/test_loop_fusion.py @@ -0,0 +1,422 @@ +__copyright__ = "Copyright (C) 2021 Kaushik Kulkarni" + +__license__ = """ +Permission is hereby granted, free of charge, to any person obtaining a copy +of this software and associated documentation files (the "Software"), to deal +in the Software without restriction, including without limitation the rights +to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +copies of the Software, and to permit persons to whom the Software is +furnished to do so, subject to the following conditions: + +The above copyright notice and this permission notice shall be included in +all copies or substantial portions of the Software. + +THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN +THE SOFTWARE. 
+""" + +import sys +import numpy as np +import loopy as lp +import pyopencl as cl +import pyopencl.clmath # noqa +import pyopencl.clrandom # noqa + +import logging +logger = logging.getLogger(__name__) + +try: + import faulthandler +except ImportError: + pass +else: + faulthandler.enable() + +from pyopencl.tools import pytest_generate_tests_for_pyopencl \ + as pytest_generate_tests + +from loopy.version import LOOPY_USE_LANGUAGE_VERSION_2018_2 # noqa + +__all__ = [ + "pytest_generate_tests", + "cl" # "cl.create_some_context" + ] + + +def test_loop_fusion_vanilla(ctx_factory): + ctx = ctx_factory() + + knl = lp.make_kernel( + "{[i0, i1, j0, j1]: 0 <= i0, i1, j0, j1 < 10}", + """ + a[i0] = 1 + b[i1, j0] = 2 {id=write_b} + c[j1] = 3 {id=write_c} + """) + ref_knl = knl + + fused_chunks = lp.get_kennedy_unweighted_fusion_candidates(knl["loopy_kernel"], + frozenset(["j0", "j1"])) + + knl = knl.with_kernel(lp.rename_inames_in_batch(knl["loopy_kernel"], + fused_chunks)) + assert len(ref_knl["loopy_kernel"].all_inames()) == 4 + assert len(knl["loopy_kernel"].all_inames()) == 3 + assert len(knl["loopy_kernel"].id_to_insn["write_b"].within_inames + & knl["loopy_kernel"].id_to_insn["write_c"].within_inames) == 1 + + lp.auto_test_vs_ref(ref_knl, ctx, knl) + + +def test_loop_fusion_outer_iname_preventing_fusion(ctx_factory): + ctx = ctx_factory() + + knl = lp.make_kernel( + "{[i0, j0, j1]: 0 <= i0, j0, j1 < 10}", + """ + a[i0] = 1 + b[i0, j0] = 2 {id=write_b} + c[j1] = 3 {id=write_c} + """) + ref_knl = knl + + fused_chunks = lp.get_kennedy_unweighted_fusion_candidates(knl["loopy_kernel"], + frozenset(["j0", "j1"])) + + knl = knl.with_kernel(lp.rename_inames_in_batch(knl["loopy_kernel"], + fused_chunks)) + + assert len(knl["loopy_kernel"].all_inames()) == 3 + assert len(knl["loopy_kernel"].all_inames()) == 3 + assert len(knl["loopy_kernel"].id_to_insn["write_b"].within_inames + & knl["loopy_kernel"].id_to_insn["write_c"].within_inames) == 0 + + lp.auto_test_vs_ref(ref_knl, ctx, knl) + + +def test_loop_fusion_with_loop_independent_deps(ctx_factory): + ctx = ctx_factory() + + knl = lp.make_kernel( + "{[j0, j1]: 0 <= j0, j1 < 10}", + """ + a[j0] = 1 + b[j1] = 2 * a[j1] + """, seq_dependencies=True) + + ref_knl = knl + + fused_chunks = lp.get_kennedy_unweighted_fusion_candidates(knl["loopy_kernel"], + frozenset(["j0", "j1"])) + + knl = knl.with_kernel(lp.rename_inames_in_batch(knl["loopy_kernel"], + fused_chunks)) + + assert len(ref_knl["loopy_kernel"].all_inames()) == 2 + assert len(knl["loopy_kernel"].all_inames()) == 1 + + lp.auto_test_vs_ref(ref_knl, ctx, knl) + + +def test_loop_fusion_constrained_by_outer_loop_deps(ctx_factory): + ctx = ctx_factory() + + knl = lp.make_kernel( + "{[j0, j1]: 0 <= j0, j1 < 10}", + """ + a[j0] = 1 {id=write_a} + b = 2 {id=write_b} + c[j1] = 2 * a[j1] {id=write_c} + """, seq_dependencies=True) + + ref_knl = knl + + fused_chunks = lp.get_kennedy_unweighted_fusion_candidates(knl["loopy_kernel"], + frozenset(["j0", "j1"])) + + knl = knl.with_kernel(lp.rename_inames_in_batch(knl["loopy_kernel"], + fused_chunks)) + + assert len(ref_knl["loopy_kernel"].all_inames()) == 2 + assert len(knl["loopy_kernel"].all_inames()) == 2 + assert len(knl["loopy_kernel"].id_to_insn["write_a"].within_inames + & knl["loopy_kernel"].id_to_insn["write_c"].within_inames) == 0 + + lp.auto_test_vs_ref(ref_knl, ctx, knl) + + +def test_loop_fusion_with_loop_carried_deps1(ctx_factory): + + ctx = ctx_factory() + knl = lp.make_kernel( + "{[i0, i1]: 1<=i0, i1<10}", + """ + x[i0] = i0 {id=first_write} + x[i1-1] 
= i1 ** 2 {id=second_write} + """, + seq_dependencies=True) + + ref_knl = knl + + fused_chunks = lp.get_kennedy_unweighted_fusion_candidates(knl["loopy_kernel"], + frozenset(["i0", + "i1"])) + + knl = knl.with_kernel(lp.rename_inames_in_batch(knl["loopy_kernel"], + fused_chunks)) + + assert len(ref_knl["loopy_kernel"].all_inames()) == 2 + assert len(knl["loopy_kernel"].all_inames()) == 1 + assert len(knl["loopy_kernel"].id_to_insn["first_write"].within_inames + & knl["loopy_kernel"].id_to_insn["second_write"].within_inames) == 1 + + lp.auto_test_vs_ref(ref_knl, ctx, knl) + + +def test_loop_fusion_with_loop_carried_deps2(ctx_factory): + ctx = ctx_factory() + knl = lp.make_kernel( + "{[i0, i1]: 1<=i0, i1<10}", + """ + x[i0-1] = i0 {id=first_write} + x[i1] = i1 ** 2 {id=second_write} + """, + seq_dependencies=True) + + ref_knl = knl + + fused_chunks = lp.get_kennedy_unweighted_fusion_candidates(knl["loopy_kernel"], + frozenset(["i0", + "i1"])) + + knl = knl.with_kernel(lp.rename_inames_in_batch(knl["loopy_kernel"], + fused_chunks)) + + assert len(ref_knl["loopy_kernel"].all_inames()) == 2 + assert len(knl["loopy_kernel"].all_inames()) == 2 + assert len(knl["loopy_kernel"].id_to_insn["first_write"].within_inames + & knl["loopy_kernel"].id_to_insn["second_write"].within_inames) == 0 + + lp.auto_test_vs_ref(ref_knl, ctx, knl) + + +def test_loop_fusion_with_indirection(ctx_factory): + ctx = ctx_factory() + map_ = np.random.permutation(10) + cq = cl.CommandQueue(ctx) + + knl = lp.make_kernel( + "{[i0, i1]: 0<=i0, i1<10}", + """ + x[i0] = i0 {id=first_write} + x[map[i1]] = i1 ** 2 {id=second_write} + """, + seq_dependencies=True) + + ref_knl = knl + + fused_chunks = lp.get_kennedy_unweighted_fusion_candidates(knl["loopy_kernel"], + frozenset(["i0", + "i1"])) + + knl = knl.with_kernel(lp.rename_inames_in_batch(knl["loopy_kernel"], + fused_chunks)) + + assert len(ref_knl["loopy_kernel"].all_inames()) == 2 + assert len(knl["loopy_kernel"].all_inames()) == 2 + assert len(knl["loopy_kernel"].id_to_insn["first_write"].within_inames + & knl["loopy_kernel"].id_to_insn["second_write"].within_inames) == 0 + + _, (out1,) = ref_knl(cq, map=map_) + _, (out2,) = knl(cq, map=map_) + np.testing.assert_allclose(out1, out2) + + +def test_loop_fusion_with_induced_dependencies_from_sibling_nests(ctx_factory): + ctx = ctx_factory() + t_unit = lp.make_kernel( + "{[i0, j, i1, i2]: 0<=i0, j, i1, i2<10}", + """ + <> tmp0[i0] = i0 + <> tmp1[j] = tmp0[j] + <> tmp2[j] = j + out1[i1] = tmp2[i1] + out2[i2] = 2 * tmp1[i2] + """) + ref_t_unit = t_unit + knl = t_unit.default_entrypoint + knl = lp.rename_inames_in_batch( + knl, + lp.get_kennedy_unweighted_fusion_candidates( + knl, frozenset(["i0", "i1"]))) + t_unit = t_unit.with_kernel(knl) + + # 'i1', 'i2' should not be fused. If fused that would lead to an + # unshcedulable kernel. Making sure that the kernel 'runs' suffices that + # the transformation was successful. 
+    lp.auto_test_vs_ref(ref_t_unit, ctx, t_unit)
+
+
+def test_loop_fusion_on_reduction_inames(ctx_factory):
+    ctx = ctx_factory()
+
+    t_unit = lp.make_kernel(
+        "{[i, j0, j1, j2]: 0<=i, j0, j1, j2<10}",
+        """
+        y0[i] = sum(j0, sum([j1], 2*A[i, j0, j1]))
+        y1[i] = sum(j0, sum([j2], 3*A[i, j0, j2]))
+        """, [lp.GlobalArg("A",
+                           dtype=np.float64,
+                           shape=lp.auto), ...])
+    ref_t_unit = t_unit
+    knl = t_unit.default_entrypoint
+    knl = lp.rename_inames_in_batch(
+        knl,
+        lp.get_kennedy_unweighted_fusion_candidates(
+            knl, frozenset(["j1", "j2"])))
+    assert (knl.id_to_insn["insn"].reduction_inames()
+            == knl.id_to_insn["insn_0"].reduction_inames())
+
+    t_unit = t_unit.with_kernel(knl)
+    lp.auto_test_vs_ref(ref_t_unit, ctx, t_unit)
+
+
+def test_loop_fusion_on_reduction_inames_with_depth_mismatch(ctx_factory):
+    ctx = ctx_factory()
+
+    t_unit = lp.make_kernel(
+        "{[i, j0, j1, j2, j3]: 0<=i, j0, j1, j2, j3<10}",
+        """
+        y0[i] = sum(j0, sum([j1], 2*A[i, j0, j1]))
+        y1[i] = sum(j2, sum([j3], 3*A[i, j3, j2]))
+        """, [lp.GlobalArg("A",
+                           dtype=np.float64,
+                           shape=lp.auto),
+              ...])
+    ref_t_unit = t_unit
+    knl = t_unit.default_entrypoint
+    knl = lp.rename_inames_in_batch(
+        knl,
+        lp.get_kennedy_unweighted_fusion_candidates(
+            knl, frozenset(["j1", "j3"])))
+
+    # cannot fuse 'j1', 'j3' because they are not nested within the same outer
+    # inames.
+    assert (knl.id_to_insn["insn"].reduction_inames()
+            != knl.id_to_insn["insn_0"].reduction_inames())
+
+    t_unit = t_unit.with_kernel(knl)
+    lp.auto_test_vs_ref(ref_t_unit, ctx, t_unit)
+
+
+def test_loop_fusion_on_outer_reduction_inames(ctx_factory):
+    ctx = ctx_factory()
+
+    t_unit = lp.make_kernel(
+        "{[i, j0, j1, j2, j3]: 0<=i, j0, j1, j2, j3<10}",
+        """
+        y0[i] = sum(j0, sum([j1], 2*A[i, j0, j1]))
+        y1[i] = sum(j2, sum([j3], 3*A[i, j3, j2]))
+        """, [lp.GlobalArg("A",
+                           dtype=np.float64,
+                           shape=lp.auto),
+              ...])
+    ref_t_unit = t_unit
+    knl = t_unit.default_entrypoint
+    knl = lp.rename_inames_in_batch(
+        knl,
+        lp.get_kennedy_unweighted_fusion_candidates(
+            knl, frozenset(["j0", "j2"])))
+
+    assert len(knl.id_to_insn["insn"].reduction_inames()
+               & knl.id_to_insn["insn_0"].reduction_inames()) == 1
+
+    t_unit = t_unit.with_kernel(knl)
+    lp.auto_test_vs_ref(ref_t_unit, ctx, t_unit)
+
+
+def test_loop_fusion_reduction_inames_simple(ctx_factory):
+    ctx = ctx_factory()
+
+    t_unit = lp.make_kernel(
+        "{[i, j0, j1]: 0<=i, j0, j1<10}",
+        """
+        y0[i] = sum(j0, 2*A[i, j0])
+        y1[i] = sum(j1, 3*A[i, j1])
+        """, [lp.GlobalArg("A",
+                           dtype=np.float64,
+                           shape=lp.auto),
+              ...])
+    ref_t_unit = t_unit
+    knl = t_unit.default_entrypoint
+    knl = lp.rename_inames_in_batch(
+        knl,
+        lp.get_kennedy_unweighted_fusion_candidates(
+            knl, frozenset(["j0", "j1"])))
+
+    assert (knl.id_to_insn["insn"].reduction_inames()
+            == knl.id_to_insn["insn_0"].reduction_inames())
+
+    t_unit = t_unit.with_kernel(knl)
+    lp.auto_test_vs_ref(ref_t_unit, ctx, t_unit)
+
+
+def test_redn_loop_fusion_with_non_candidates_loops_in_nest(ctx_factory):
+    ctx = ctx_factory()
+    t_unit = lp.make_kernel(
+        "{[i, j1, j2, d]: 0<=i, j1, j2, d<10}",
+        """
+        for i
+            for d
+                out1[i, d] = sum(j1, 2 * j1*i)
+            end
+            out2[i] = sum(j2, 2 * j2)
+        end
+        """, seq_dependencies=True)
+    ref_t_unit = t_unit
+
+    knl = t_unit.default_entrypoint
+    knl = lp.rename_inames_in_batch(
+        knl,
+        lp.get_kennedy_unweighted_fusion_candidates(
+            knl, frozenset(["j1", "j2"])))
+
+    assert not (knl.id_to_insn["insn"].reduction_inames()
+                & knl.id_to_insn["insn_0"].reduction_inames())
+
+    lp.auto_test_vs_ref(ref_t_unit, ctx, t_unit.with_kernel(knl))
+
+
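+# All of the fusion tests above drive the same two-step API: collect the
+# fusable batches with get_kennedy_unweighted_fusion_candidates, then apply
+# them with rename_inames_in_batch. As a usage sketch, the recurring pattern
+# could be captured in a small wrapper (the name 'fuse_loops' is purely
+# illustrative and not part of loopy's API):
+#
+#     def fuse_loops(knl, candidates):
+#         # candidates: frozenset of inames considered for fusion
+#         return lp.rename_inames_in_batch(
+#             knl,
+#             lp.get_kennedy_unweighted_fusion_candidates(knl, candidates))
+#
+#     knl = fuse_loops(t_unit.default_entrypoint, frozenset(["j1", "j2"]))
+#     t_unit = t_unit.with_kernel(knl)
+
+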
+def test_reduction_loop_fusion_with_multiple_redn_in_same_insn(ctx_factory):
+    ctx = ctx_factory()
+    t_unit = lp.make_kernel(
+        "{[j1, j2]: 0<=j1, j2<10}",
+        """
+        out = sum(j1, 2*j1) + sum(j2, 2*j2)
+        """, seq_dependencies=True)
+    ref_t_unit = t_unit
+
+    knl = t_unit.default_entrypoint
+    knl = lp.rename_inames_in_batch(
+        knl,
+        lp.get_kennedy_unweighted_fusion_candidates(
+            knl, frozenset(["j1", "j2"])))
+
+    assert len(knl.id_to_insn["insn"].reduction_inames()) == 1
+
+    lp.auto_test_vs_ref(ref_t_unit, ctx, t_unit.with_kernel(knl))
+
+
+if __name__ == "__main__":
+    if len(sys.argv) > 1:
+        exec(sys.argv[1])
+    else:
+        from pytest import main
+        main([__file__])
+
+# vim: fdm=marker
diff --git a/test/test_pycuda_invoker.py b/test/test_pycuda_invoker.py
new file mode 100644
index 000000000..8cf5bcad4
--- /dev/null
+++ b/test/test_pycuda_invoker.py
@@ -0,0 +1,305 @@
+__copyright__ = "Copyright (C) 2022 Kaushik Kulkarni"
+
+__license__ = """
+Permission is hereby granted, free of charge, to any person obtaining a copy
+of this software and associated documentation files (the "Software"), to deal
+in the Software without restriction, including without limitation the rights
+to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
+copies of the Software, and to permit persons to whom the Software is
+furnished to do so, subject to the following conditions:
+
+The above copyright notice and this permission notice shall be included in
+all copies or substantial portions of the Software.
+
+THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
+IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
+FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
+AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
+LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
+OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
+THE SOFTWARE.
+""" + +import sys +import numpy as np +import loopy as lp +import pytest +pytest.importorskip("pycuda") +import pycuda.gpuarray as cu_np +import itertools + +import logging +logger = logging.getLogger(__name__) + +try: + import faulthandler +except ImportError: + pass +else: + faulthandler.enable() + +from typing import Tuple, Any +from pycuda.tools import init_cuda_context_fixture +from loopy.version import LOOPY_USE_LANGUAGE_VERSION_2018_2 # noqa + + +@pytest.fixture(autouse=True) +def init_cuda_context(): + yield from init_cuda_context_fixture() + + +def get_random_array(rng, shape: Tuple[int, ...], dtype: np.dtype[Any]): + if np.issubdtype(dtype, np.complexfloating): + subdtype = np.empty(0, dtype=dtype).real.dtype + return (get_random_array(rng, shape, subdtype) + + dtype.type(1j) * get_random_array(rng, shape, subdtype)) + else: + assert np.issubdtype(dtype, np.floating) + return rng.random(shape, dtype=dtype) + + +@pytest.mark.parametrize("target", [lp.PyCudaTarget(), + lp.PyCudaWithPackedArgsTarget()]) +def test_pycuda_invoker(target): + m = 5 + n = 6 + + knl = lp.make_kernel( + "{[i, j]: 0<=i tmp[i] = sin(x[i]) + z[i] = 2 * tmp[i] + """, + target=target) + knl = lp.set_temporary_address_space(knl, "tmp", lp.AddressSpace.GLOBAL) + + evt, (out,) = knl(x=x, out_host=False) + np.testing.assert_allclose(2*np.sin(x), out.get(), rtol=1e-6) + + +@pytest.mark.parametrize("target", [lp.PyCudaTarget(), + lp.PyCudaWithPackedArgsTarget()]) +@pytest.mark.parametrize("dtype", [np.float32, np.float64]) +def test_multi_entrypoints(target, dtype): + rng = np.random.default_rng(seed=314) + x = rng.random(42, dtype=dtype) + + knl1 = lp.make_kernel( + "{[i]: 0<=i tmp[i] = 21*sin(x[i]) + 864.5*cos(y[i]) + z[i] = 2 * tmp[i] + """, + [lp.GlobalArg("x,y", + offset=lp.auto, shape=lp.auto), + ...], + target=target) + knl = lp.set_temporary_address_space(knl, "tmp", lp.AddressSpace.GLOBAL) + + evt, (out,) = knl(x=x, y=y) + np.testing.assert_allclose(42*np.sin(x) + 1729*np.cos(y), out) + + +@pytest.mark.parametrize("target", [lp.PyCudaTarget(), + lp.PyCudaWithPackedArgsTarget()]) +@pytest.mark.parametrize("dtype,rtol", [(np.complex64, 1e-6), + (np.complex128, 1e-14), + (np.float32, 1e-6), + (np.float64, 1e-14)]) +def test_sum_of_array(target, dtype, rtol): + # Reported by Mit Kotak + rng = np.random.default_rng(seed=0) + knl = lp.make_kernel( + "{[i]: 0 <= i < N}", + """ + out = sum(i, x[i]) + """, + target=target) + x = get_random_array(rng, (42,), np.dtype(dtype)) + evt, (out,) = knl(x=x) + np.testing.assert_allclose(np.sum(x), out, rtol=rtol) + + +@pytest.mark.parametrize("target", [lp.PyCudaTarget(), + lp.PyCudaWithPackedArgsTarget()]) +@pytest.mark.parametrize("dtype,rtol", [(np.complex64, 1e-6), + (np.complex128, 1e-14), + (np.float32, 1e-6), + (np.float64, 1e-14)]) +def test_int_pow(target, dtype, rtol): + rng = np.random.default_rng(seed=0) + knl = lp.make_kernel( + "{[i]: 0 <= i < N}", + """ + out[i] = x[i] ** i + """, + target=target) + x = get_random_array(rng, (10,), np.dtype(dtype)) + evt, (out,) = knl(x=x) + np.testing.assert_allclose(x ** np.arange(10, dtype=np.int32), out, + rtol=rtol) + + +@pytest.mark.parametrize("target", [lp.PyCudaTarget(), + lp.PyCudaWithPackedArgsTarget()]) +@pytest.mark.parametrize("dtype", [np.complex64, np.complex128, + np.float32, np.float64]) +@pytest.mark.parametrize("func", ["abs", "sqrt", + "sin", "cos", "tan", + "sinh", "cosh", "tanh", + "exp", "log", "log10"]) +def test_math_functions(target, dtype, func): + rng = np.random.default_rng(seed=0) + knl = 
+        "{[i]: 0 <= i < N}",
+        f"""
+        y[i] = {func}(x[i])
+        """,
+        target=target)
+    x = get_random_array(rng, (42,), np.dtype(dtype))
+    _, (out,) = knl(x=x)
+    np.testing.assert_allclose(getattr(np, func)(x),
+                               out, rtol=1e-6)
+
+
+def test_pycuda_packargs_tgt_avoids_param_space_overflow():
+    from pymbolic.primitives import Sum
+    from loopy.symbolic import parse
+
+    nargs = 1_000
+    rng = np.random.default_rng(32)
+    knl = lp.make_kernel(
+        "{[i]: 0<=i
+
+
+if __name__ == "__main__":
+    if len(sys.argv) > 1:
+        exec(sys.argv[1])
+    else:
+        from pytest import main
+        main([__file__])
+
+# vim: foldmethod=marker
diff --git a/test/test_transform.py b/test/test_transform.py
index 3e3aabf14..dcb8ba7b6 100644
--- a/test/test_transform.py
+++ b/test/test_transform.py
@@ -1650,6 +1650,156 @@ def test_concatenate_arrays(ctx_factory):
     lp.auto_test_vs_ref(ref_t_unit, ctx, t_unit)
 
 
+def test_reindexing_strided_access(ctx_factory):
+    import islpy as isl
+
+    if not hasattr(isl.Set, "card"):
+        pytest.skip("No barvinok support")
+
+    ctx = ctx_factory()
+
+    tunit = lp.make_kernel(
+        "{[i, j]: 0<=j,i<10}",
+        """
+        <> tmp[2*i, 2*j] = a[i, j]
+        out[i, j] = tmp[2*i, 2*j]**2
+        """)
+
+    tunit = lp.add_dtypes(tunit, {"a": "float64"})
+    ref_tunit = tunit
+
+    knl = lp.reindex_temporary_using_seghir_loechner_scheme(
+        tunit.default_entrypoint, "tmp")
+    tunit = tunit.with_kernel(knl)
+
+    tv, = tunit.default_entrypoint.temporary_variables.values()
+    assert tv.shape == (100,)
+
+    lp.auto_test_vs_ref(ref_tunit, ctx, tunit)
+
+
+def test_reindexing_figurate(ctx_factory):
+    import islpy as isl
+
+    if not hasattr(isl.Set, "card"):
+        pytest.skip("No barvinok support")
+
+    ctx = ctx_factory()
+
+    tunit = lp.make_kernel(
+        "{[i, j]: 0<=j<=i<10}",
+        """
+        <> tmp[2*i, 2*j] = a[i, j]
+        out[i, j] = tmp[2*i, 2*j]**2
+        """)
+
+    tunit = lp.add_dtypes(tunit, {"a": "float64"})
+    ref_tunit = tunit
+
+    knl = lp.reindex_temporary_using_seghir_loechner_scheme(
+        tunit.default_entrypoint, "tmp")
+    tunit = tunit.with_kernel(knl)
+
+    tv, = tunit.default_entrypoint.temporary_variables.values()
+    assert tv.shape == (55,)
+
+    lp.auto_test_vs_ref(ref_tunit, ctx, tunit)
+
+
+def test_reindexing_figurate_parametric_shape(ctx_factory):
+    import islpy as isl
+    from loopy.symbolic import parse
+
+    if not hasattr(isl.Set, "card"):
+        pytest.skip("No barvinok support")
+
+    ctx = ctx_factory()
+
+    tunit = lp.make_kernel(
+        "{[i, j]: 0<=j<=i<n}",
+        """
+        <> tmp[i, j] = a[i, j]
+        out[i, j] = tmp[i, j]**2
+        """,
+        assumptions="n > 0",
+    )
+
+    tunit = lp.add_dtypes(tunit, {"a": "float64"})
+    tunit = lp.set_temporary_address_space(tunit, "tmp",
+                                           lp.AddressSpace.GLOBAL)
+    ref_tunit = tunit
+
+    knl = lp.reindex_temporary_using_seghir_loechner_scheme(
+        tunit.default_entrypoint, "tmp")
+    tunit = tunit.with_kernel(knl)
+
+    tv, = tunit.default_entrypoint.temporary_variables.values()
+    assert tv.shape == (parse("(n + n**2) // 2"),)
+
+    lp.auto_test_vs_ref(ref_tunit, ctx, tunit, parameters={"n": 20})
+
+
+def test_sum_redn_algebraic_transforms(ctx_factory):
+    from pymbolic import variables
+    from loopy.symbolic import Reduction
+
+    t_unit = lp.make_kernel(
+        "{[e,i,j,x,r]: 0<=e