diff --git a/ndsl/dsl/gt4py_utils.py b/ndsl/dsl/gt4py_utils.py index d263c5d2..227eb57d 100644 --- a/ndsl/dsl/gt4py_utils.py +++ b/ndsl/dsl/gt4py_utils.py @@ -158,7 +158,7 @@ def make_storage_data( data, shape, start, dummy, axis, read_only, backend=backend ) elif n_dims == 4: - + data = _make_storage_data_4d(data, shape, start, backend=backend) else: data = _make_storage_data_3d(data, shape, start, backend=backend) @@ -262,6 +262,7 @@ def _make_storage_data_3d( ] = asarray(data, type(buffer)) return buffer + def _make_storage_data_4d( data: Field, shape: Tuple[int, ...], @@ -280,6 +281,7 @@ def _make_storage_data_4d( ] = asarray(data, type(buffer)) return buffer + def make_storage_from_shape( shape: Tuple[int, ...], origin: Tuple[int, ...] = origin, @@ -333,6 +335,7 @@ def make_storage_dict( axis: int = 2, *, backend: str, + dtype: DTypes = Float, ) -> Dict[str, "Field"]: assert names is not None, "for 4d variable storages, specify a list of names" if shape is None: @@ -347,6 +350,7 @@ def make_storage_dict( dummy=dummy, axis=axis, backend=backend, + dtype=dtype, ) return data_dict diff --git a/ndsl/stencils/testing/translate.py b/ndsl/stencils/testing/translate.py index 09143083..779d1339 100644 --- a/ndsl/stencils/testing/translate.py +++ b/ndsl/stencils/testing/translate.py @@ -113,9 +113,9 @@ def make_storage_data( elif not full_shape and len(array.shape) < 3 and axis == len(array.shape) - 1: use_shape[1] = 1 start = (int(istart), int(jstart), int(kstart)) - if 'float' in str(array.dtype): + if "float" in str(array.dtype): dtype = Float - elif 'int' in str(array.dtype): + elif "int" in str(array.dtype): dtype = Int else: dtype = array.dtype @@ -133,7 +133,7 @@ def make_storage_data( ) else: if len(array.shape) == 4: - start = (int(istart), int(jstart), int(kstart), 0) + start = (int(istart), int(jstart), int(kstart), 0) # type: ignore use_shape.append(array.shape[-1]) return utils.make_storage_data( array,