Skip to content

Commit

Permalink
fix dict, lint
Browse files Browse the repository at this point in the history
  • Loading branch information
oelbert committed Sep 5, 2024
1 parent d44551e commit ed3d431
Show file tree
Hide file tree
Showing 2 changed files with 8 additions and 4 deletions.
6 changes: 5 additions & 1 deletion ndsl/dsl/gt4py_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -158,7 +158,7 @@ def make_storage_data(
data, shape, start, dummy, axis, read_only, backend=backend
)
elif n_dims == 4:

data = _make_storage_data_4d(data, shape, start, backend=backend)
else:
data = _make_storage_data_3d(data, shape, start, backend=backend)
Expand Down Expand Up @@ -262,6 +262,7 @@ def _make_storage_data_3d(
] = asarray(data, type(buffer))
return buffer


def _make_storage_data_4d(
data: Field,
shape: Tuple[int, ...],
Expand All @@ -280,6 +281,7 @@ def _make_storage_data_4d(
] = asarray(data, type(buffer))
return buffer


def make_storage_from_shape(
shape: Tuple[int, ...],
origin: Tuple[int, ...] = origin,
Expand Down Expand Up @@ -333,6 +335,7 @@ def make_storage_dict(
axis: int = 2,
*,
backend: str,
dtype: DTypes = Float,
) -> Dict[str, "Field"]:
assert names is not None, "for 4d variable storages, specify a list of names"
if shape is None:
Expand All @@ -347,6 +350,7 @@ def make_storage_dict(
dummy=dummy,
axis=axis,
backend=backend,
dtype=dtype,
)
return data_dict

Expand Down
6 changes: 3 additions & 3 deletions ndsl/stencils/testing/translate.py
Original file line number Diff line number Diff line change
Expand Up @@ -113,9 +113,9 @@ def make_storage_data(
elif not full_shape and len(array.shape) < 3 and axis == len(array.shape) - 1:
use_shape[1] = 1
start = (int(istart), int(jstart), int(kstart))
if 'float' in str(array.dtype):
if "float" in str(array.dtype):
dtype = Float
elif 'int' in str(array.dtype):
elif "int" in str(array.dtype):
dtype = Int
else:
dtype = array.dtype
Expand All @@ -133,7 +133,7 @@ def make_storage_data(
)
else:
if len(array.shape) == 4:
start = (int(istart), int(jstart), int(kstart), 0)
start = (int(istart), int(jstart), int(kstart), 0) # type: ignore
use_shape.append(array.shape[-1])
return utils.make_storage_data(
array,
Expand Down

0 comments on commit ed3d431

Please sign in to comment.