Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Quantity factory to use dsl Float type as default #44

Merged
merged 12 commits into from
Jun 3, 2024
Merged
9 changes: 5 additions & 4 deletions ndsl/grid/geometry.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
from ndsl.comm.partitioner import TilePartitioner
from ndsl.dsl.typing import Float
from ndsl.grid.gnomonic import (
get_lonlat_vect,
get_unit_vector_direction,
Expand Down Expand Up @@ -591,7 +592,7 @@ def edge_factors(
nhalo: int,
tile_partitioner: TilePartitioner,
rank: int,
radius: float,
radius: Float,
np,
):
"""
Expand Down Expand Up @@ -704,7 +705,7 @@ def efactor_a2c_v(
nhalo: int,
tile_partitioner: TilePartitioner,
rank: int,
radius: float,
radius: Float,
np,
):
"""
Expand Down Expand Up @@ -888,7 +889,7 @@ def unit_vector_lonlat(grid, np):
return unit_lon, unit_lat


def _fill_halo_corners(field, value: float, nhalo: int, tile_partitioner, rank):
def _fill_halo_corners(field, value: Float, nhalo: int, tile_partitioner, rank):
"""
Fills a tile halo corners (ghost cells) of a field
with a set value along the first 2 axes
Expand All @@ -905,7 +906,7 @@ def _fill_halo_corners(field, value: float, nhalo: int, tile_partitioner, rank):
field[-nhalo:, -nhalo:] = value # NE corner


def _fill_single_halo_corner(field, value: float, nhalo: int, corner: str):
def _fill_single_halo_corner(field, value: Float, nhalo: int, corner: str):
"""
Fills a tile halo corner (ghost cells) of a field
with a set value along the first 2 axes
Expand Down
5 changes: 3 additions & 2 deletions ndsl/grid/helper.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@
split_cartesian_into_storages = None
import ndsl.constants as constants
from ndsl.constants import Z_DIM, Z_INTERFACE_DIM
from ndsl.dsl.typing import Float
from ndsl.filesystem import get_fs
from ndsl.grid.generation import MetricTerms
from ndsl.initialization.allocator import QuantityFactory
Expand Down Expand Up @@ -226,13 +227,13 @@ def dp(self) -> Quantity:
return self._dp_ref

@property
def ptop(self) -> float:
def ptop(self) -> Float:
"""
top of atmosphere pressure (Pa)
"""
if self.bk.view[0] != 0:
raise ValueError("ptop is not well-defined when top-of-atmosphere bk != 0")
return float(self.ak.view[0])
return Float(self.ak.view[0])


@dataclasses.dataclass(frozen=True)
Expand Down
11 changes: 6 additions & 5 deletions ndsl/initialization/allocator.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
import numpy as np

from ndsl.constants import SPATIAL_DIMS
from ndsl.dsl.typing import Float
from ndsl.initialization.sizer import GridSizer
from ndsl.optional_imports import gt4py
from ndsl.quantity import Quantity, QuantityHaloSpec
Expand Down Expand Up @@ -60,7 +61,7 @@ def empty(
self,
dims: Sequence[str],
units: str,
dtype: type = np.float64,
dtype: type = Float,
allow_mismatch_float_precision: bool = False,
):
return self._allocate(
Expand All @@ -71,7 +72,7 @@ def zeros(
self,
dims: Sequence[str],
units: str,
dtype: type = np.float64,
dtype: type = Float,
allow_mismatch_float_precision: bool = False,
):
return self._allocate(
Expand All @@ -82,7 +83,7 @@ def ones(
self,
dims: Sequence[str],
units: str,
dtype: type = np.float64,
dtype: type = Float,
allow_mismatch_float_precision: bool = False,
):
return self._allocate(
Expand Down Expand Up @@ -116,7 +117,7 @@ def _allocate(
allocator: Callable,
dims: Sequence[str],
units: str,
dtype: type = np.float64,
dtype: type = Float,
allow_mismatch_float_precision: bool = False,
):
origin = self.sizer.get_origin(dims)
Expand Down Expand Up @@ -150,7 +151,7 @@ def get_quantity_halo_spec(
self,
dims: Sequence[str],
n_halo: Optional[int] = None,
dtype: type = np.float64,
dtype: type = Float,
) -> QuantityHaloSpec:
"""Build memory specifications for the halo update.

Expand Down
7 changes: 3 additions & 4 deletions ndsl/quantity.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@

import ndsl.constants as constants
from ndsl.comm._boundary_utils import bound_default_slice, shift_boundary_slice_tuple
from ndsl.dsl.typing import Float
from ndsl.optional_imports import cupy, dace, gt4py
from ndsl.optional_imports import xarray as xr
from ndsl.types import NumpyModule
Expand Down Expand Up @@ -260,7 +261,8 @@ def _validate_quantity_property_lengths(shape, dims, origin, extent):
def _is_float(dtype):
"""Expected floating point type for Pace"""
return (
dtype == float
dtype == Float
or dtype == float
or dtype == np.float32
or dtype == np.float64
or dtype == np.float16
Expand Down Expand Up @@ -296,9 +298,6 @@ def __init__(
storage attribute is disabled and will raise an exception. Will raise
a TypeError if this is given with a gt4py storage type as data
"""
# ToDo: [Florian 01/23] Kill the abomination.
# See https://github.com/NOAA-GFDL/pace/issues/3
from ndsl.dsl.typing import Float

if (
not allow_mismatch_float_precision
Expand Down
Loading