diff --git a/ndsl/comm/mpi.py b/ndsl/comm/mpi.py index 6b3ff17f..3c466950 100644 --- a/ndsl/comm/mpi.py +++ b/ndsl/comm/mpi.py @@ -6,7 +6,6 @@ from typing import Dict, List, Optional, TypeVar, cast from ndsl.comm.comm_abc import Comm, ReductionOperator, Request -from ndsl.logging import ndsl_log T = TypeVar("T") @@ -43,70 +42,46 @@ def Get_size(self) -> int: return self._comm.Get_size() def bcast(self, value: Optional[T], root=0) -> T: - ndsl_log.debug("bcast from root %s on rank %s", root, self._comm.Get_rank()) return self._comm.bcast(value, root=root) def barrier(self): - ndsl_log.debug("barrier on rank %s", self._comm.Get_rank()) self._comm.barrier() def Barrier(self): pass def Scatter(self, sendbuf, recvbuf, root=0, **kwargs): - ndsl_log.debug("Scatter on rank %s with root %s", self._comm.Get_rank(), root) self._comm.Scatter(sendbuf, recvbuf, root=root, **kwargs) def Gather(self, sendbuf, recvbuf, root=0, **kwargs): - ndsl_log.debug("Gather on rank %s with root %s", self._comm.Get_rank(), root) self._comm.Gather(sendbuf, recvbuf, root=root, **kwargs) def allgather(self, sendobj: T) -> List[T]: - ndsl_log.debug("allgather on rank %s", self._comm.Get_rank()) return self._comm.allgather(sendobj) def Send(self, sendbuf, dest, tag: int = 0, **kwargs): - ndsl_log.debug("Send on rank %s with dest %s", self._comm.Get_rank(), dest) self._comm.Send(sendbuf, dest, tag=tag, **kwargs) def sendrecv(self, sendbuf, dest, **kwargs): - ndsl_log.debug("sendrecv on rank %s with dest %s", self._comm.Get_rank(), dest) return self._comm.sendrecv(sendbuf, dest, **kwargs) def Isend(self, sendbuf, dest, tag: int = 0, **kwargs) -> Request: - ndsl_log.debug("Isend on rank %s with dest %s", self._comm.Get_rank(), dest) return self._comm.Isend(sendbuf, dest, tag=tag, **kwargs) def Recv(self, recvbuf, source, tag: int = 0, **kwargs): - ndsl_log.debug("Recv on rank %s with source %s", self._comm.Get_rank(), source) self._comm.Recv(recvbuf, source, tag=tag, **kwargs) def Irecv(self, recvbuf, source, tag: int = 0, **kwargs) -> Request: - ndsl_log.debug("Irecv on rank %s with source %s", self._comm.Get_rank(), source) return self._comm.Irecv(recvbuf, source, tag=tag, **kwargs) def Split(self, color, key) -> "Comm": - ndsl_log.debug( - "Split on rank %s with color %s, key %s", self._comm.Get_rank(), color, key - ) return self._comm.Split(color, key) def allreduce(self, sendobj: T, op: Optional[ReductionOperator] = None) -> T: - ndsl_log.debug( - "allreduce on rank %s with operator %s", self._comm.Get_rank(), op - ) return self._comm.allreduce(sendobj, self._op_mapping[op]) def Allreduce(self, sendobj_or_inplace: T, recvobj: T, op: ReductionOperator) -> T: - ndsl_log.debug( - "Allreduce on rank %s with operator %s", self._comm.Get_rank(), op - ) return self._comm.Allreduce(sendobj_or_inplace, recvobj, self._op_mapping[op]) def Allreduce_inplace(self, recvobj: T, op: ReductionOperator) -> T: - ndsl_log.debug( - "Allreduce (in place) on rank %s with operator %s", - self._comm.Get_rank(), - op, - ) return self._comm.Allreduce(mpi4py.MPI.IN_PLACE, recvobj, self._op_mapping[op]) diff --git a/ndsl/dsl/dace/orchestration.py b/ndsl/dsl/dace/orchestration.py index 767610c3..c09b69cc 100644 --- a/ndsl/dsl/dace/orchestration.py +++ b/ndsl/dsl/dace/orchestration.py @@ -438,7 +438,7 @@ def __get__(self, obj, objtype=None) -> SDFGEnabledCallable: def orchestrate( *, obj: object, - config: DaceConfig, + config: Optional[DaceConfig], method_to_orchestrate: str = "__call__", dace_compiletime_args: Optional[Sequence[str]] = None, ): @@ -455,6 +455,9 @@ def orchestrate( dace_compiletime_args: list of names of arguments to be flagged has dace.compiletime for orchestration to behave """ + if config is None: + raise ValueError("DaCe config cannot be None") + if dace_compiletime_args is None: dace_compiletime_args = [] diff --git a/ndsl/dsl/stencil.py b/ndsl/dsl/stencil.py index 5e917e66..daf78091 100644 --- a/ndsl/dsl/stencil.py +++ b/ndsl/dsl/stencil.py @@ -31,6 +31,7 @@ from ndsl.dsl.stencil_config import CompilationConfig, RunMode, StencilConfig from ndsl.dsl.typing import Float, Index3D, cast_to_index3d from ndsl.initialization.sizer import GridSizer, SubtileGridSizer +from ndsl.logging import ndsl_log from ndsl.quantity import Quantity from ndsl.testing.comparison import LegacyMetric @@ -374,6 +375,8 @@ def nothing_function(*args, **kwargs): setattr(self, "__call__", nothing_function) def __call__(self, *args, **kwargs) -> None: + if self.stencil_config.verbose: + ndsl_log.debug(f"Running {self._func_name}") args_list = list(args) _convert_quantities_to_storage(args_list, kwargs) args = tuple(args_list) diff --git a/ndsl/dsl/stencil_config.py b/ndsl/dsl/stencil_config.py index 6b8f75eb..4d3eafab 100644 --- a/ndsl/dsl/stencil_config.py +++ b/ndsl/dsl/stencil_config.py @@ -169,6 +169,7 @@ class StencilConfig(Hashable): compare_to_numpy: bool = False compilation_config: CompilationConfig = CompilationConfig() dace_config: Optional[DaceConfig] = None + verbose: bool = False def __post_init__(self): self.backend_opts = { diff --git a/ndsl/grid/generation.py b/ndsl/grid/generation.py index f77e2cd2..172f4d53 100644 --- a/ndsl/grid/generation.py +++ b/ndsl/grid/generation.py @@ -298,7 +298,7 @@ def __init__( self._dy_center = None self._area = None self._area_c = None - if eta_file is not None: + if eta_file is not None or ak is not None or bk is not None: ( self._ks, self._ptop, diff --git a/ndsl/logging.py b/ndsl/logging.py index 44cdb690..73b7979c 100644 --- a/ndsl/logging.py +++ b/ndsl/logging.py @@ -1,6 +1,9 @@ +from __future__ import annotations + import logging import os import sys +from typing import Annotated from mpi4py import MPI @@ -18,7 +21,7 @@ } -def _ndsl_logger(): +def _ndsl_logger() -> logging.Logger: name_log = logging.getLogger(__name__) name_log.setLevel(LOGLEVEL) @@ -36,4 +39,33 @@ def _ndsl_logger(): return name_log -ndsl_log = _ndsl_logger() +def _ndsl_logger_on_rank_0() -> logging.Logger: + name_log = logging.getLogger(f"{__name__}_on_rank_0") + name_log.setLevel(LOGLEVEL) + + rank = MPI.COMM_WORLD.Get_rank() + + if rank == 0: + handler = logging.StreamHandler(sys.stdout) + handler.setLevel(LOGLEVEL) + formatter = logging.Formatter( + fmt=( + f"%(asctime)s|%(levelname)s|rank {MPI.COMM_WORLD.Get_rank()}|" + "%(name)s:%(message)s" + ), + datefmt="%Y-%m-%d %H:%M:%S", + ) + handler.setFormatter(formatter) + name_log.addHandler(handler) + else: + name_log.disabled = True + return name_log + + +ndsl_log: Annotated[ + logging.Logger, "NDSL Python logger, logs on all rank" +] = _ndsl_logger() + +ndsl_log_on_rank_0: Annotated[ + logging.Logger, "NDSL Python logger, logs on rank 0 only" +] = _ndsl_logger_on_rank_0() diff --git a/ndsl/quantity/quantity.py b/ndsl/quantity/quantity.py index c88ba140..33d72a44 100644 --- a/ndsl/quantity/quantity.py +++ b/ndsl/quantity/quantity.py @@ -3,6 +3,7 @@ import matplotlib.pyplot as plt import numpy as np +from mpi4py import MPI import ndsl.constants as constants from ndsl.dsl.typing import Float, is_float @@ -152,6 +153,10 @@ def from_data_array( gt4py_backend=gt4py_backend, ) + def to_netcdf(self, name: str, rank: int = -1) -> None: + if rank < 0 or MPI.COMM_WORLD.Get_rank() == rank: + self.data_array.to_netcdf(f"{name}__r{rank}.nc4") + def halo_spec(self, n_halo: int) -> QuantityHaloSpec: return QuantityHaloSpec( n_halo,