Skip to content

Commit

Permalink
Merge branch 'develop' into feature/translate-test-multi-rank-failure
Browse files Browse the repository at this point in the history
  • Loading branch information
FlorianDeconinck authored Feb 11, 2025
2 parents ba0e923 + 0aa1775 commit c1e4b25
Show file tree
Hide file tree
Showing 7 changed files with 48 additions and 29 deletions.
25 changes: 0 additions & 25 deletions ndsl/comm/mpi.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,6 @@
from typing import Dict, List, Optional, TypeVar, cast

from ndsl.comm.comm_abc import Comm, ReductionOperator, Request
from ndsl.logging import ndsl_log


T = TypeVar("T")
Expand Down Expand Up @@ -43,70 +42,46 @@ def Get_size(self) -> int:
return self._comm.Get_size()

def bcast(self, value: Optional[T], root=0) -> T:
ndsl_log.debug("bcast from root %s on rank %s", root, self._comm.Get_rank())
return self._comm.bcast(value, root=root)

def barrier(self):
ndsl_log.debug("barrier on rank %s", self._comm.Get_rank())
self._comm.barrier()

def Barrier(self):
pass

def Scatter(self, sendbuf, recvbuf, root=0, **kwargs):
ndsl_log.debug("Scatter on rank %s with root %s", self._comm.Get_rank(), root)
self._comm.Scatter(sendbuf, recvbuf, root=root, **kwargs)

def Gather(self, sendbuf, recvbuf, root=0, **kwargs):
ndsl_log.debug("Gather on rank %s with root %s", self._comm.Get_rank(), root)
self._comm.Gather(sendbuf, recvbuf, root=root, **kwargs)

def allgather(self, sendobj: T) -> List[T]:
ndsl_log.debug("allgather on rank %s", self._comm.Get_rank())
return self._comm.allgather(sendobj)

def Send(self, sendbuf, dest, tag: int = 0, **kwargs):
ndsl_log.debug("Send on rank %s with dest %s", self._comm.Get_rank(), dest)
self._comm.Send(sendbuf, dest, tag=tag, **kwargs)

def sendrecv(self, sendbuf, dest, **kwargs):
ndsl_log.debug("sendrecv on rank %s with dest %s", self._comm.Get_rank(), dest)
return self._comm.sendrecv(sendbuf, dest, **kwargs)

def Isend(self, sendbuf, dest, tag: int = 0, **kwargs) -> Request:
ndsl_log.debug("Isend on rank %s with dest %s", self._comm.Get_rank(), dest)
return self._comm.Isend(sendbuf, dest, tag=tag, **kwargs)

def Recv(self, recvbuf, source, tag: int = 0, **kwargs):
ndsl_log.debug("Recv on rank %s with source %s", self._comm.Get_rank(), source)
self._comm.Recv(recvbuf, source, tag=tag, **kwargs)

def Irecv(self, recvbuf, source, tag: int = 0, **kwargs) -> Request:
ndsl_log.debug("Irecv on rank %s with source %s", self._comm.Get_rank(), source)
return self._comm.Irecv(recvbuf, source, tag=tag, **kwargs)

def Split(self, color, key) -> "Comm":
ndsl_log.debug(
"Split on rank %s with color %s, key %s", self._comm.Get_rank(), color, key
)
return self._comm.Split(color, key)

def allreduce(self, sendobj: T, op: Optional[ReductionOperator] = None) -> T:
ndsl_log.debug(
"allreduce on rank %s with operator %s", self._comm.Get_rank(), op
)
return self._comm.allreduce(sendobj, self._op_mapping[op])

def Allreduce(self, sendobj_or_inplace: T, recvobj: T, op: ReductionOperator) -> T:
ndsl_log.debug(
"Allreduce on rank %s with operator %s", self._comm.Get_rank(), op
)
return self._comm.Allreduce(sendobj_or_inplace, recvobj, self._op_mapping[op])

def Allreduce_inplace(self, recvobj: T, op: ReductionOperator) -> T:
ndsl_log.debug(
"Allreduce (in place) on rank %s with operator %s",
self._comm.Get_rank(),
op,
)
return self._comm.Allreduce(mpi4py.MPI.IN_PLACE, recvobj, self._op_mapping[op])
5 changes: 4 additions & 1 deletion ndsl/dsl/dace/orchestration.py
Original file line number Diff line number Diff line change
Expand Up @@ -438,7 +438,7 @@ def __get__(self, obj, objtype=None) -> SDFGEnabledCallable:
def orchestrate(
*,
obj: object,
config: DaceConfig,
config: Optional[DaceConfig],
method_to_orchestrate: str = "__call__",
dace_compiletime_args: Optional[Sequence[str]] = None,
):
Expand All @@ -455,6 +455,9 @@ def orchestrate(
dace_compiletime_args: list of names of arguments to be flagged has
dace.compiletime for orchestration to behave
"""
if config is None:
raise ValueError("DaCe config cannot be None")

if dace_compiletime_args is None:
dace_compiletime_args = []

Expand Down
3 changes: 3 additions & 0 deletions ndsl/dsl/stencil.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,7 @@
from ndsl.dsl.stencil_config import CompilationConfig, RunMode, StencilConfig
from ndsl.dsl.typing import Float, Index3D, cast_to_index3d
from ndsl.initialization.sizer import GridSizer, SubtileGridSizer
from ndsl.logging import ndsl_log
from ndsl.quantity import Quantity
from ndsl.testing.comparison import LegacyMetric

Expand Down Expand Up @@ -374,6 +375,8 @@ def nothing_function(*args, **kwargs):
setattr(self, "__call__", nothing_function)

def __call__(self, *args, **kwargs) -> None:
if self.stencil_config.verbose:
ndsl_log.debug(f"Running {self._func_name}")
args_list = list(args)
_convert_quantities_to_storage(args_list, kwargs)
args = tuple(args_list)
Expand Down
1 change: 1 addition & 0 deletions ndsl/dsl/stencil_config.py
Original file line number Diff line number Diff line change
Expand Up @@ -169,6 +169,7 @@ class StencilConfig(Hashable):
compare_to_numpy: bool = False
compilation_config: CompilationConfig = CompilationConfig()
dace_config: Optional[DaceConfig] = None
verbose: bool = False

def __post_init__(self):
self.backend_opts = {
Expand Down
2 changes: 1 addition & 1 deletion ndsl/grid/generation.py
Original file line number Diff line number Diff line change
Expand Up @@ -298,7 +298,7 @@ def __init__(
self._dy_center = None
self._area = None
self._area_c = None
if eta_file is not None:
if eta_file is not None or ak is not None or bk is not None:
(
self._ks,
self._ptop,
Expand Down
36 changes: 34 additions & 2 deletions ndsl/logging.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,9 @@
from __future__ import annotations

import logging
import os
import sys
from typing import Annotated

from mpi4py import MPI

Expand All @@ -18,7 +21,7 @@
}


def _ndsl_logger():
def _ndsl_logger() -> logging.Logger:
name_log = logging.getLogger(__name__)
name_log.setLevel(LOGLEVEL)

Expand All @@ -36,4 +39,33 @@ def _ndsl_logger():
return name_log


ndsl_log = _ndsl_logger()
def _ndsl_logger_on_rank_0() -> logging.Logger:
name_log = logging.getLogger(f"{__name__}_on_rank_0")
name_log.setLevel(LOGLEVEL)

rank = MPI.COMM_WORLD.Get_rank()

if rank == 0:
handler = logging.StreamHandler(sys.stdout)
handler.setLevel(LOGLEVEL)
formatter = logging.Formatter(
fmt=(
f"%(asctime)s|%(levelname)s|rank {MPI.COMM_WORLD.Get_rank()}|"
"%(name)s:%(message)s"
),
datefmt="%Y-%m-%d %H:%M:%S",
)
handler.setFormatter(formatter)
name_log.addHandler(handler)
else:
name_log.disabled = True
return name_log


ndsl_log: Annotated[
logging.Logger, "NDSL Python logger, logs on all rank"
] = _ndsl_logger()

ndsl_log_on_rank_0: Annotated[
logging.Logger, "NDSL Python logger, logs on rank 0 only"
] = _ndsl_logger_on_rank_0()
5 changes: 5 additions & 0 deletions ndsl/quantity/quantity.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@

import matplotlib.pyplot as plt
import numpy as np
from mpi4py import MPI

import ndsl.constants as constants
from ndsl.dsl.typing import Float, is_float
Expand Down Expand Up @@ -152,6 +153,10 @@ def from_data_array(
gt4py_backend=gt4py_backend,
)

def to_netcdf(self, name: str, rank: int = -1) -> None:
if rank < 0 or MPI.COMM_WORLD.Get_rank() == rank:
self.data_array.to_netcdf(f"{name}__r{rank}.nc4")

def halo_spec(self, n_halo: int) -> QuantityHaloSpec:
return QuantityHaloSpec(
n_halo,
Expand Down

0 comments on commit c1e4b25

Please sign in to comment.