Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

refactor[cartesian]: unexpanded sdfg cleanups #1860

Merged
merged 5 commits into from
Feb 17, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 0 additions & 4 deletions src/gt4py/cartesian/backend/dace_backend.py
Original file line number Diff line number Diff line change
Expand Up @@ -151,10 +151,6 @@ def _pre_expand_transformations(gtir_pipeline: GtirPipeline, sdfg: dace.SDFG, la
sdfg.add_state(gtir_pipeline.gtir.name)
return sdfg

for array in sdfg.arrays.values():
if array.transient:
array.lifetime = dace.AllocationLifetime.Persistent

sdfg.simplify(validate=False)

_set_expansion_orders(sdfg)
Expand Down
15 changes: 15 additions & 0 deletions src/gt4py/cartesian/gtc/dace/constants.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,15 @@
# GT4Py - GridTools Framework
#
# Copyright (c) 2014-2024, ETH Zurich
# All rights reserved.
#
# Please, refer to the LICENSE file in the root directory.
# SPDX-License-Identifier: BSD-3-Clause


from typing import Final


# StencilComputation in/out connector prefixes
CONNECTOR_PREFIX_IN: Final = "__in_"
CONNECTOR_PREFIX_OUT: Final = "__out_"
17 changes: 11 additions & 6 deletions src/gt4py/cartesian/gtc/dace/expansion/expansion.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@
import sympy

from gt4py.cartesian.gtc.dace import daceir as dcir
from gt4py.cartesian.gtc.dace.constants import CONNECTOR_PREFIX_IN, CONNECTOR_PREFIX_OUT
from gt4py.cartesian.gtc.dace.expansion.daceir_builder import DaCeIRBuilder
from gt4py.cartesian.gtc.dace.expansion.sdfg_builder import StencilComputationSDFGBuilder

Expand Down Expand Up @@ -77,11 +78,11 @@ def _fix_context(
"""
# change connector names
for in_edge in parent_state.in_edges(node):
assert in_edge.dst_conn.startswith("__in_")
in_edge.dst_conn = in_edge.dst_conn[len("__in_") :]
assert in_edge.dst_conn.startswith(CONNECTOR_PREFIX_IN)
in_edge.dst_conn = in_edge.dst_conn.removeprefix(CONNECTOR_PREFIX_IN)
for out_edge in parent_state.out_edges(node):
assert out_edge.src_conn.startswith("__out_")
out_edge.src_conn = out_edge.src_conn[len("__out_") :]
assert out_edge.src_conn.startswith(CONNECTOR_PREFIX_OUT)
out_edge.src_conn = out_edge.src_conn.removeprefix(CONNECTOR_PREFIX_OUT)

# union input and output subsets
subsets = {}
Expand Down Expand Up @@ -125,9 +126,13 @@ def _get_parent_arrays(
) -> Dict[str, dace.data.Data]:
parent_arrays: Dict[str, dace.data.Data] = {}
for edge in (e for e in parent_state.in_edges(node) if e.dst_conn is not None):
parent_arrays[edge.dst_conn[len("__in_") :]] = parent_sdfg.arrays[edge.data.data]
parent_arrays[edge.dst_conn.removeprefix(CONNECTOR_PREFIX_IN)] = parent_sdfg.arrays[
edge.data.data
]
for edge in (e for e in parent_state.out_edges(node) if e.src_conn is not None):
parent_arrays[edge.src_conn[len("__out_") :]] = parent_sdfg.arrays[edge.data.data]
parent_arrays[edge.src_conn.removeprefix(CONNECTOR_PREFIX_OUT)] = parent_sdfg.arrays[
edge.data.data
]
return parent_arrays

@staticmethod
Expand Down
10 changes: 4 additions & 6 deletions src/gt4py/cartesian/gtc/dace/expansion/sdfg_builder.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,9 +20,8 @@
from gt4py import eve
from gt4py.cartesian.gtc.dace import daceir as dcir
from gt4py.cartesian.gtc.dace.expansion.tasklet_codegen import TaskletCodegen
from gt4py.cartesian.gtc.dace.expansion.utils import get_dace_debuginfo
from gt4py.cartesian.gtc.dace.symbol_utils import data_type_to_dace_typeclass
from gt4py.cartesian.gtc.dace.utils import make_dace_subset
from gt4py.cartesian.gtc.dace.utils import get_dace_debuginfo, make_dace_subset


class StencilComputationSDFGBuilder(eve.VisitorWithSymbolTableTrait):
Expand Down Expand Up @@ -268,13 +267,13 @@ def visit_ComputationState(
for memlet in computation.read_memlets:
if memlet.field not in read_acc_and_conn:
read_acc_and_conn[memlet.field] = (
sdfg_ctx.state.add_access(memlet.field, debuginfo=dace.DebugInfo(0)),
sdfg_ctx.state.add_access(memlet.field),
None,
)
for memlet in computation.write_memlets:
if memlet.field not in write_acc_and_conn:
write_acc_and_conn[memlet.field] = (
sdfg_ctx.state.add_access(memlet.field, debuginfo=dace.DebugInfo(0)),
sdfg_ctx.state.add_access(memlet.field),
None,
)
node_ctx = StencilComputationSDFGBuilder.NodeContext(
Expand All @@ -298,7 +297,7 @@ def visit_FieldDecl(
dtype=data_type_to_dace_typeclass(node.dtype),
storage=node.storage.to_dace_storage(),
transient=node.name not in non_transients,
debuginfo=dace.DebugInfo(0),
debuginfo=get_dace_debuginfo(node),
)

def visit_SymbolDecl(
Expand Down Expand Up @@ -343,7 +342,6 @@ def visit_NestedSDFG(
inputs=node.input_connectors,
outputs=node.output_connectors,
symbol_mapping=symbol_mapping,
debuginfo=dace.DebugInfo(0),
)
self.visit(
node.read_memlets,
Expand Down
14 changes: 0 additions & 14 deletions src/gt4py/cartesian/gtc/dace/expansion/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,11 +10,6 @@

from typing import TYPE_CHECKING, List

import dace
import dace.data
import dace.library
import dace.subsets

from gt4py import eve
from gt4py.cartesian.gtc import common, oir
from gt4py.cartesian.gtc.dace import daceir as dcir
Expand All @@ -25,15 +20,6 @@
from gt4py.cartesian.gtc.dace.nodes import StencilComputation


def get_dace_debuginfo(node: common.LocNode):
if node.loc is not None:
return dace.dtypes.DebugInfo(
node.loc.line, node.loc.column, node.loc.line, node.loc.column, node.loc.filename
)
else:
return dace.dtypes.DebugInfo(0)


class HorizontalIntervalRemover(eve.NodeTranslator):
def visit_HorizontalMask(self, node: common.HorizontalMask, *, axis: dcir.Axis):
mask_attrs = dict(i=node.i, j=node.j)
Expand Down
6 changes: 3 additions & 3 deletions src/gt4py/cartesian/gtc/dace/nodes.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,12 +23,12 @@
from gt4py.cartesian.gtc import common, oir
from gt4py.cartesian.gtc.dace import daceir as dcir
from gt4py.cartesian.gtc.dace.expansion.expansion import StencilComputationExpansion
from gt4py.cartesian.gtc.dace.expansion.utils import HorizontalExecutionSplitter
from gt4py.cartesian.gtc.dace.expansion_specification import ExpansionItem, make_expansion_order
from gt4py.cartesian.gtc.dace.utils import get_dace_debuginfo
from gt4py.cartesian.gtc.definitions import Extent
from gt4py.cartesian.gtc.oir import Decl, FieldDecl, VerticalLoop, VerticalLoopSection

from .expansion.utils import HorizontalExecutionSplitter, get_dace_debuginfo
from .expansion_specification import ExpansionItem, make_expansion_order


def _set_expansion_order(
node: StencilComputation, expansion_order: Union[List[ExpansionItem], List[str]]
Expand Down
32 changes: 19 additions & 13 deletions src/gt4py/cartesian/gtc/dace/oir_to_dace.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,9 +18,14 @@
import gt4py.cartesian.gtc.oir as oir
from gt4py import eve
from gt4py.cartesian.gtc.dace import daceir as dcir
from gt4py.cartesian.gtc.dace.constants import CONNECTOR_PREFIX_IN, CONNECTOR_PREFIX_OUT
from gt4py.cartesian.gtc.dace.nodes import StencilComputation
from gt4py.cartesian.gtc.dace.symbol_utils import data_type_to_dace_typeclass
from gt4py.cartesian.gtc.dace.utils import compute_dcir_access_infos, make_dace_subset
from gt4py.cartesian.gtc.dace.utils import (
compute_dcir_access_infos,
get_dace_debuginfo,
make_dace_subset,
)
from gt4py.cartesian.gtc.definitions import Extent
from gt4py.cartesian.gtc.passes.oir_optimizations.utils import (
AccessCollector,
Expand Down Expand Up @@ -93,9 +98,7 @@ def _make_dace_subset(self, local_access_info, field):
global_access_info, local_access_info, self.decls[field].data_dims
)

def visit_VerticalLoop(
self, node: oir.VerticalLoop, *, ctx: OirSDFGBuilder.SDFGContext, **kwargs
):
def visit_VerticalLoop(self, node: oir.VerticalLoop, *, ctx: OirSDFGBuilder.SDFGContext):
declarations = {
acc.name: ctx.decls[acc.name]
for acc in node.walk_values().if_isinstance(oir.FieldAccess, oir.ScalarAccess)
Expand All @@ -117,22 +120,24 @@ def visit_VerticalLoop(
access_collection = AccessCollector.apply(node)

for field in access_collection.read_fields():
access_node = state.add_access(field, debuginfo=dace.DebugInfo(0))
library_node.add_in_connector("__in_" + field)
access_node = state.add_access(field, debuginfo=get_dace_debuginfo(declarations[field]))
connector_name = f"{CONNECTOR_PREFIX_IN}{field}"
library_node.add_in_connector(connector_name)
subset = ctx.make_input_dace_subset(node, field)
state.add_edge(
access_node, None, library_node, "__in_" + field, dace.Memlet(field, subset=subset)
access_node, None, library_node, connector_name, dace.Memlet(field, subset=subset)
)

for field in access_collection.write_fields():
access_node = state.add_access(field, debuginfo=dace.DebugInfo(0))
library_node.add_out_connector("__out_" + field)
access_node = state.add_access(field, debuginfo=get_dace_debuginfo(declarations[field]))
connector_name = f"{CONNECTOR_PREFIX_OUT}{field}"
library_node.add_out_connector(connector_name)
subset = ctx.make_output_dace_subset(node, field)
state.add_edge(
library_node, "__out_" + field, access_node, None, dace.Memlet(field, subset=subset)
library_node, connector_name, access_node, None, dace.Memlet(field, subset=subset)
)

def visit_Stencil(self, node: oir.Stencil, **kwargs):
def visit_Stencil(self, node: oir.Stencil):
ctx = OirSDFGBuilder.SDFGContext(stencil=node)
for param in node.params:
if isinstance(param, oir.FieldDecl):
Expand All @@ -148,7 +153,7 @@ def visit_Stencil(self, node: oir.Stencil, **kwargs):
],
dtype=data_type_to_dace_typeclass(param.dtype),
transient=False,
debuginfo=dace.DebugInfo(0),
debuginfo=get_dace_debuginfo(param),
)
else:
ctx.sdfg.add_symbol(param.name, stype=data_type_to_dace_typeclass(param.dtype))
Expand All @@ -166,7 +171,8 @@ def visit_Stencil(self, node: oir.Stencil, **kwargs):
],
dtype=data_type_to_dace_typeclass(decl.dtype),
transient=True,
debuginfo=dace.DebugInfo(0),
lifetime=dace.AllocationLifetime.Persistent,
debuginfo=get_dace_debuginfo(decl),
)
self.generic_visit(node, ctx=ctx)
ctx.sdfg.validate()
Expand Down
9 changes: 9 additions & 0 deletions src/gt4py/cartesian/gtc/dace/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,15 @@
from gt4py.cartesian.gtc.passes.oir_optimizations.utils import compute_horizontal_block_extents


def get_dace_debuginfo(node: common.LocNode) -> dace.dtypes.DebugInfo:
if node.loc is None:
return dace.dtypes.DebugInfo(0)

return dace.dtypes.DebugInfo(
node.loc.line, node.loc.column, node.loc.line, node.loc.column, node.loc.filename
)


def array_dimensions(array: dace.data.Array):
dims = [
any(
Expand Down