Skip to content

Commit

Permalink
Merge remote-tracking branch 'origin/main' into gtir-dace-exclusive_if
Browse files Browse the repository at this point in the history
  • Loading branch information
edopao committed Feb 5, 2025
2 parents 5246b79 + 14d18b3 commit c6ae8b1
Show file tree
Hide file tree
Showing 16 changed files with 448 additions and 239 deletions.
2 changes: 1 addition & 1 deletion .pre-commit-config.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@ repos:

- repo: https://github.com/astral-sh/uv-pre-commit
# uv version.
rev: 0.5.10
rev: 0.5.25
hooks:
- id: uv-lock

Expand Down
6 changes: 5 additions & 1 deletion noxfile.py
Original file line number Diff line number Diff line change
Expand Up @@ -61,6 +61,10 @@
"internal": {"extras": [], "markers": ["not requires_dace"]},
"dace": {"extras": ["dace"], "markers": ["requires_dace"]},
}
# Use dace-next for GT4Py-next, to install a different dace version than in cartesian
CodeGenNextTestSettings = CodeGenTestSettings | {
"dace": {"extras": ["dace-next"], "markers": ["requires_dace"]},
}


# -- nox sessions --
Expand Down Expand Up @@ -158,7 +162,7 @@ def test_next(
) -> None:
"""Run selected 'gt4py.next' tests."""

codegen_settings = CodeGenTestSettings[codegen]
codegen_settings = CodeGenNextTestSettings[codegen]
device_settings = DeviceTestSettings[device]
groups: list[str] = ["test"]
mesh_markers: list[str] = []
Expand Down
12 changes: 11 additions & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -132,7 +132,8 @@ all = ['gt4py[dace,formatting,jax,performance,testing]']
cuda11 = ['cupy-cuda11x>=12.0']
cuda12 = ['cupy-cuda12x>=12.0']
# features
dace = ['dace>=1.0.0,<1.1.0'] # v1.x will contain breaking changes, see https://github.com/spcl/dace/milestone/4
dace = ['dace>=1.0.1,<1.1.0'] # v1.x will contain breaking changes, see https://github.com/spcl/dace/milestone/4
dace-next = ['dace'] # pull dace latest version from the git repository
formatting = ['clang-format>=9.0']
jax = ['jax>=0.4.26']
jax-cuda12 = ['jax[cuda12_local]>=0.4.26', 'gt4py[cuda12]']
Expand Down Expand Up @@ -438,6 +439,14 @@ conflicts = [
{extra = 'jax-cuda12'},
{extra = 'rocm4_3'},
{extra = 'rocm5_0'}
],
[
{extra = 'dace'},
{extra = 'dace-next'}
],
[
{extra = 'all'},
{extra = 'dace-next'}
]
]

Expand All @@ -448,3 +457,4 @@ url = 'https://test.pypi.org/simple/'

[tool.uv.sources]
atlas4py = {index = "test.pypi"}
dace = {git = "https://github.com/spcl/dace", branch = "main", extra = "dace-next"}
37 changes: 13 additions & 24 deletions src/gt4py/next/program_processors/runners/dace/gtir_dataflow.py
Original file line number Diff line number Diff line change
Expand Up @@ -113,7 +113,7 @@ class IteratorExpr:

field: dace.nodes.AccessNode
gt_dtype: ts.ListType | ts.ScalarType
field_domain: list[tuple[gtx_common.Dimension, dace.symbolic.SymExpr]]
field_domain: list[tuple[gtx_common.Dimension, dace.symbolic.SymbolicType]]
indices: dict[gtx_common.Dimension, DataExpr]

def get_field_type(self) -> ts.FieldType:
Expand Down Expand Up @@ -798,9 +798,6 @@ def write_output_of_nested_sdfg_to_temporary(inner_value: ValueExpr) -> ValueExp

assert len(node.args) == 3

# TODO(edopao): enable once supported in next DaCe release
use_conditional_block: Final[bool] = False

# evaluate the if-condition that will write to a boolean scalar node
condition_value = self.visit(node.args[0])
assert (
Expand All @@ -816,26 +813,18 @@ def write_output_of_nested_sdfg_to_temporary(inner_value: ValueExpr) -> ValueExp
nsdfg.debuginfo = gtir_sdfg_utils.debug_info(node, default=self.sdfg.debuginfo)

# create states inside the nested SDFG for the if-branches
if use_conditional_block:
if_region = dace.sdfg.state.ConditionalBlock("if")
nsdfg.add_node(if_region)
entry_state = nsdfg.add_state("entry", is_start_block=True)
nsdfg.add_edge(entry_state, if_region, dace.InterstateEdge())

then_body = dace.sdfg.state.ControlFlowRegion("then_body", sdfg=nsdfg)
tstate = then_body.add_state("true_branch", is_start_block=True)
if_region.add_branch(dace.sdfg.state.CodeBlock("__cond"), then_body)

else_body = dace.sdfg.state.ControlFlowRegion("else_body", sdfg=nsdfg)
fstate = else_body.add_state("false_branch", is_start_block=True)
if_region.add_branch(dace.sdfg.state.CodeBlock("not (__cond)"), else_body)

else:
entry_state = nsdfg.add_state("entry", is_start_block=True)
tstate = nsdfg.add_state("true_branch")
nsdfg.add_edge(entry_state, tstate, dace.InterstateEdge(condition="__cond"))
fstate = nsdfg.add_state("false_branch")
nsdfg.add_edge(entry_state, fstate, dace.InterstateEdge(condition="not (__cond)"))
if_region = dace.sdfg.state.ConditionalBlock("if")
nsdfg.add_node(if_region)
entry_state = nsdfg.add_state("entry", is_start_block=True)
nsdfg.add_edge(entry_state, if_region, dace.InterstateEdge())

then_body = dace.sdfg.state.ControlFlowRegion("then_body", sdfg=nsdfg)
tstate = then_body.add_state("true_branch", is_start_block=True)
if_region.add_branch(dace.sdfg.state.CodeBlock("__cond"), then_body)

else_body = dace.sdfg.state.ControlFlowRegion("else_body", sdfg=nsdfg)
fstate = else_body.add_state("false_branch", is_start_block=True)
if_region.add_branch(dace.sdfg.state.CodeBlock("not (__cond)"), else_body)

input_memlets: dict[str, MemletExpr | ValueExpr] = {}
nsdfg_symbols_mapping: Optional[dict[str, dace.symbol]] = None
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -312,6 +312,10 @@ def _lower_lambda_to_nested_sdfg(
# the lambda expression, i.e. body of the scan, will be created inside a nested SDFG.
nsdfg = dace.SDFG(sdfg_builder.unique_nsdfg_name(sdfg, "scan"))
nsdfg.debuginfo = gtir_sdfg_utils.debug_info(lambda_node, default=sdfg.debuginfo)
# We set `using_explicit_control_flow=True` because the vertical scan is lowered to a `LoopRegion`.
# This property is used by pattern matching in the SDFG transformation framework
# to skip those transformations that do not yet support control flow blocks.
nsdfg.using_explicit_control_flow = True
lambda_translator = sdfg_builder.setup_nested_context(lambda_node, nsdfg, lambda_symbols)

# use the vertical dimension in the domain as scan dimension
Expand Down
4 changes: 0 additions & 4 deletions src/gt4py/next/program_processors/runners/dace/gtir_sdfg.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,6 @@
from typing import Any, Dict, Iterable, List, Optional, Protocol, Sequence, Set, Tuple, Union

import dace
from dace.sdfg import utils as dace_sdfg_utils

from gt4py import eve
from gt4py.eve import concepts
Expand Down Expand Up @@ -999,9 +998,6 @@ def build_sdfg_from_gtir(
sdfg = sdfg_genenerator.visit(ir)
assert isinstance(sdfg, dace.SDFG)

# TODO(edopao): remove inlining when DaCe transformations support LoopRegion construct
dace_sdfg_utils.inline_loop_blocks(sdfg)

if disable_field_origin_on_program_arguments:
_remove_field_origin_symbols(ir, sdfg)

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -280,6 +280,7 @@ def gt_auto_optimize(
# For compatibility with DaCe (and until we found out why) the GT4Py
# auto optimizer will emulate this behaviour.
for state in sdfg.states():
assert isinstance(state, dace.SDFGState)
for edge in state.edges():
edge.data.wcr_nonatomic = False

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -160,62 +160,9 @@ def gt_gpu_transform_non_standard_memlet(
correct loop order.
- This function should be called after `gt_set_iteration_order()` has run.
"""
new_maps: set[dace_nodes.MapEntry] = set()

# This code is copied from DaCe's code generator.
for e, state in list(sdfg.all_edges_recursive()):
nsdfg = state.parent
if (
isinstance(e.src, dace_nodes.AccessNode)
and isinstance(e.dst, dace_nodes.AccessNode)
and e.src.desc(nsdfg).storage == dace_dtypes.StorageType.GPU_Global
and e.dst.desc(nsdfg).storage == dace_dtypes.StorageType.GPU_Global
):
a: dace_nodes.AccessNode = e.src
b: dace_nodes.AccessNode = e.dst

copy_shape, src_strides, dst_strides, _, _ = dace_cpp.memlet_copy_to_absolute_strides(
None, nsdfg, state, e, a, b
)
dims = len(copy_shape)
if dims == 1:
continue
elif dims == 2:
if src_strides[-1] != 1 or dst_strides[-1] != 1:
try:
is_src_cont = src_strides[0] / src_strides[1] == copy_shape[1]
is_dst_cont = dst_strides[0] / dst_strides[1] == copy_shape[1]
except (TypeError, ValueError):
is_src_cont = False
is_dst_cont = False
if is_src_cont and is_dst_cont:
continue
else:
continue
elif dims > 2:
if not (src_strides[-1] != 1 or dst_strides[-1] != 1):
continue

# For identifying the new map, we first store all neighbors of `a`.
old_neighbors_of_a: list[dace_nodes.AccessNode] = [
edge.dst for edge in state.out_edges(a)
]

# Turn unsupported copy to a map
try:
dace_transformation.dataflow.CopyToMap.apply_to(
nsdfg, save=False, annotate=False, a=a, b=b
)
except ValueError: # If transformation doesn't match, continue normally
continue

# We find the new map by comparing the new neighborhood of `a` with the old one.
new_nodes: set[dace_nodes.MapEntry] = {
edge.dst for edge in state.out_edges(a) if edge.dst not in old_neighbors_of_a
}
assert any(isinstance(new_node, dace_nodes.MapEntry) for new_node in new_nodes)
assert len(new_nodes) == 1
new_maps.update(new_nodes)
# Expand all non standard memlets and get the new MapEntries.
new_maps: set[dace_nodes.MapEntry] = _gt_expand_non_standard_memlets(sdfg)

# If there are no Memlets that are translated to copy-Maps, then we have nothing to do.
if len(new_maps) == 0:
Expand Down Expand Up @@ -283,6 +230,88 @@ def restrict_fusion_to_newly_created_maps(
return sdfg


def _gt_expand_non_standard_memlets(
    sdfg: dace.SDFG,
) -> set[dace_nodes.MapEntry]:
    """Recursively expand all non standard Memlets found in `sdfg`.

    Helper of `gt_gpu_transform_non_standard_memlet()`: every Memlet that can
    not be lowered to a plain `memcpy()` is rewritten into a copy-kernel Map.
    The whole SDFG tree is processed, i.e. nested SDFGs are visited as well.

    Returns:
        The `MapEntry` nodes of all Maps that were created by the expansion.
    """
    # Fold the per-SDFG results of the non-recursive worker into a single set.
    return set().union(
        *(
            _gt_expand_non_standard_memlets_sdfg(inner_sdfg)
            for inner_sdfg in sdfg.all_sdfgs_recursive()
        )
    )


def _gt_expand_non_standard_memlets_sdfg(
    sdfg: dace.SDFG,
) -> set[dace_nodes.MapEntry]:
    """Implementation of `_gt_expand_non_standard_memlets()` that processes a single SDFG.

    Scans every state of `sdfg` (non-recursively) for edges that directly connect
    two access nodes in GPU global memory and whose copy cannot be expressed as a
    plain `memcpy()`; each such edge is rewritten into a copy Map using DaCe's
    `CopyToMap` transformation.

    Returns:
        The `MapEntry` nodes of the newly created copy Maps.
    """
    new_maps: set[dace_nodes.MapEntry] = set()
    # The implementation is based on DaCe's code generator.
    for state in sdfg.states():
        for e in state.edges():
            # We are only interested in edges that connects two access nodes of GPU memory.
            if not (
                isinstance(e.src, dace_nodes.AccessNode)
                and isinstance(e.dst, dace_nodes.AccessNode)
                and e.src.desc(sdfg).storage == dace_dtypes.StorageType.GPU_Global
                and e.dst.desc(sdfg).storage == dace_dtypes.StorageType.GPU_Global
            ):
                continue

            a: dace_nodes.AccessNode = e.src
            b: dace_nodes.AccessNode = e.dst
            copy_shape, src_strides, dst_strides, _, _ = dace_cpp.memlet_copy_to_absolute_strides(
                None, sdfg, state, e, a, b
            )
            dims = len(copy_shape)
            if dims == 1:
                # A 1D copy can always be expressed as a (strided) `memcpy()`.
                continue
            elif dims == 2:
                if src_strides[-1] != 1 or dst_strides[-1] != 1:
                    # NOTE(review): mirrors DaCe's code generator — presumably a 2D copy
                    # whose innermost stride is not 1 is still `memcpy()`-able when both
                    # sides are contiguous, i.e. outer stride / inner stride equals the
                    # inner copy extent; confirm against DaCe's CUDA codegen.
                    try:
                        is_src_cont = src_strides[0] / src_strides[1] == copy_shape[1]
                        is_dst_cont = dst_strides[0] / dst_strides[1] == copy_shape[1]
                    except (TypeError, ValueError):
                        # Symbolic strides/shapes may not support this arithmetic;
                        # conservatively treat the copy as non-contiguous.
                        is_src_cont = False
                        is_dst_cont = False
                    if is_src_cont and is_dst_cont:
                        continue
                else:
                    # Unit innermost strides on both sides: standard copy, nothing to do.
                    continue
            elif dims > 2:
                # Higher-dimensional copies only need expansion when the innermost
                # stride of either side is not 1.
                if not (src_strides[-1] != 1 or dst_strides[-1] != 1):
                    continue

            # For identifying the new map, we first store all neighbors of `a`.
            old_neighbors_of_a: list[dace_nodes.AccessNode] = [
                edge.dst for edge in state.out_edges(a)
            ]

            # Turn unsupported copy to a map
            try:
                dace_transformation.dataflow.CopyToMap.apply_to(
                    sdfg, save=False, annotate=False, a=a, b=b
                )
            except ValueError:  # If transformation doesn't match, continue normally
                continue

            # We find the new map by comparing the new neighborhood of `a` with the old one.
            new_nodes: set[dace_nodes.MapEntry] = {
                edge.dst for edge in state.out_edges(a) if edge.dst not in old_neighbors_of_a
            }
            # `CopyToMap` is expected to have introduced exactly one new MapEntry.
            assert any(isinstance(new_node, dace_nodes.MapEntry) for new_node in new_nodes)
            assert len(new_nodes) == 1
            new_maps.update(new_nodes)
    return new_maps


def gt_set_gpu_blocksize(
sdfg: dace.SDFG,
block_size: Optional[Sequence[int | str] | str],
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -38,7 +38,6 @@ def gt_create_local_double_buffering(
it is not needed that the whole data is stored, but only the working set
of a single thread.
"""

processed_maps = 0
for nsdfg in sdfg.all_sdfgs_recursive():
processed_maps += _create_local_double_buffering_non_recursive(nsdfg)
Expand All @@ -60,6 +59,7 @@ def _create_local_double_buffering_non_recursive(

processed_maps = 0
for state in sdfg.states():
assert isinstance(state, dace.SDFGState)
scope_dict = state.scope_dict()
for node in state.nodes():
if not isinstance(node, dace_nodes.MapEntry):
Expand Down
Loading

0 comments on commit c6ae8b1

Please sign in to comment.