Skip to content

Commit 0701609

Browse files
authored
Merge pull request #36 from csdms/mcflugen/add-annotations
Add type annotations
2 parents 040d072 + 0401f3f commit 0701609

File tree

7 files changed

+142
-95
lines changed

7 files changed

+142
-95
lines changed

CHANGES.rst

+1-1
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,7 @@ Changelog for bmi-example-python
44
2.1.3 (unreleased)
55
------------------
66

7-
- Nothing changed yet.
7+
- Added type annotation for the heat package (#36)
88

99

1010
2.1.2 (2024-01-05)

heat/__init__.py

+2-1
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,7 @@
22

33
from ._version import __version__
44
from .bmi_heat import BmiHeat
5-
from .heat import Heat, solve_2d
5+
from .heat import Heat
6+
from .heat import solve_2d
67

78
__all__ = ["__version__", "BmiHeat", "solve_2d", "Heat"]

heat/bmi_heat.py

+68-52
Original file line numberDiff line numberDiff line change
@@ -1,8 +1,11 @@
11
#! /usr/bin/env python
22
"""Basic Model Interface implementation for the 2D heat model."""
33

4+
from typing import Any
5+
46
import numpy as np
57
from bmipy import Bmi
8+
from numpy.typing import NDArray
69

710
from .heat import Heat
811

@@ -14,20 +17,21 @@ class BmiHeat(Bmi):
1417
_input_var_names = ("plate_surface__temperature",)
1518
_output_var_names = ("plate_surface__temperature",)
1619

17-
def __init__(self):
20+
def __init__(self) -> None:
1821
"""Create a BmiHeat model that is ready for initialization."""
19-
self._model = None
20-
self._values = {}
21-
self._var_units = {}
22-
self._var_loc = {}
23-
self._grids = {}
24-
self._grid_type = {}
22+
# self._model: Heat | None = None
23+
self._model: Heat
24+
self._values: dict[str, NDArray[Any]] = {}
25+
self._var_units: dict[str, str] = {}
26+
self._var_loc: dict[str, str] = {}
27+
self._grids: dict[int, list[str]] = {}
28+
self._grid_type: dict[int, str] = {}
2529

2630
self._start_time = 0.0
27-
self._end_time = np.finfo("d").max
31+
self._end_time = float(np.finfo("d").max)
2832
self._time_units = "s"
2933

30-
def initialize(self, filename=None):
34+
def initialize(self, filename: str | None = None) -> None:
3135
"""Initialize the Heat model.
3236
3337
Parameters
@@ -39,7 +43,7 @@ def initialize(self, filename=None):
3943
self._model = Heat()
4044
elif isinstance(filename, str):
4145
with open(filename) as file_obj:
42-
self._model = Heat.from_file_like(file_obj.read())
46+
self._model = Heat.from_file_like(file_obj)
4347
else:
4448
self._model = Heat.from_file_like(filename)
4549

@@ -49,11 +53,11 @@ def initialize(self, filename=None):
4953
self._grids = {0: ["plate_surface__temperature"]}
5054
self._grid_type = {0: "uniform_rectilinear"}
5155

52-
def update(self):
56+
def update(self) -> None:
5357
"""Advance model by one time step."""
5458
self._model.advance_in_time()
5559

56-
def update_frac(self, time_frac):
60+
def update_frac(self, time_frac: float) -> None:
5761
"""Update model by a fraction of a time step.
5862
5963
Parameters
@@ -66,7 +70,7 @@ def update_frac(self, time_frac):
6670
self.update()
6771
self._model.time_step = time_step
6872

69-
def update_until(self, then):
73+
def update_until(self, then: float) -> None:
7074
"""Update model until a particular time.
7175
7276
Parameters
@@ -80,11 +84,12 @@ def update_until(self, then):
8084
self.update()
8185
self.update_frac(n_steps - int(n_steps))
8286

83-
def finalize(self):
87+
def finalize(self) -> None:
8488
"""Finalize model."""
85-
self._model = None
89+
del self._model
90+
# self._model = None
8691

87-
def get_var_type(self, var_name):
92+
def get_var_type(self, var_name: str) -> str:
8893
"""Data type of variable.
8994
9095
Parameters
@@ -99,7 +104,7 @@ def get_var_type(self, var_name):
99104
"""
100105
return str(self.get_value_ptr(var_name).dtype)
101106

102-
def get_var_units(self, var_name):
107+
def get_var_units(self, var_name: str) -> str:
103108
"""Get units of variable.
104109
105110
Parameters
@@ -114,7 +119,7 @@ def get_var_units(self, var_name):
114119
"""
115120
return self._var_units[var_name]
116121

117-
def get_var_nbytes(self, var_name):
122+
def get_var_nbytes(self, var_name: str) -> int:
118123
"""Get units of variable.
119124
120125
Parameters
@@ -129,13 +134,13 @@ def get_var_nbytes(self, var_name):
129134
"""
130135
return self.get_value_ptr(var_name).nbytes
131136

132-
def get_var_itemsize(self, name):
137+
def get_var_itemsize(self, name: str) -> int:
133138
return np.dtype(self.get_var_type(name)).itemsize
134139

135-
def get_var_location(self, name):
140+
def get_var_location(self, name: str) -> str:
136141
return self._var_loc[name]
137142

138-
def get_var_grid(self, var_name):
143+
def get_var_grid(self, var_name: str) -> int | None:
139144
"""Grid id for a variable.
140145
141146
Parameters
@@ -151,8 +156,9 @@ def get_var_grid(self, var_name):
151156
for grid_id, var_name_list in self._grids.items():
152157
if var_name in var_name_list:
153158
return grid_id
159+
return None
154160

155-
def get_grid_rank(self, grid_id):
161+
def get_grid_rank(self, grid_id: int) -> int:
156162
"""Rank of grid.
157163
158164
Parameters
@@ -167,7 +173,7 @@ def get_grid_rank(self, grid_id):
167173
"""
168174
return len(self._model.shape)
169175

170-
def get_grid_size(self, grid_id):
176+
def get_grid_size(self, grid_id: int) -> int:
171177
"""Size of grid.
172178
173179
Parameters
@@ -182,7 +188,7 @@ def get_grid_size(self, grid_id):
182188
"""
183189
return int(np.prod(self._model.shape))
184190

185-
def get_value_ptr(self, var_name):
191+
def get_value_ptr(self, var_name: str) -> NDArray[Any]:
186192
"""Reference to values.
187193
188194
Parameters
@@ -197,7 +203,7 @@ def get_value_ptr(self, var_name):
197203
"""
198204
return self._values[var_name]
199205

200-
def get_value(self, var_name, dest):
206+
def get_value(self, var_name: str, dest: NDArray[Any]) -> NDArray[Any]:
201207
"""Copy of values.
202208
203209
Parameters
@@ -215,7 +221,9 @@ def get_value(self, var_name, dest):
215221
dest[:] = self.get_value_ptr(var_name).flatten()
216222
return dest
217223

218-
def get_value_at_indices(self, var_name, dest, indices):
224+
def get_value_at_indices(
225+
self, var_name: str, dest: NDArray[Any], indices: NDArray[np.int_]
226+
) -> NDArray[Any]:
219227
"""Get values at particular indices.
220228
221229
Parameters
@@ -235,7 +243,7 @@ def get_value_at_indices(self, var_name, dest, indices):
235243
dest[:] = self.get_value_ptr(var_name).take(indices)
236244
return dest
237245

238-
def set_value(self, var_name, src):
246+
def set_value(self, var_name: str, src: NDArray[Any]) -> None:
239247
"""Set model values.
240248
241249
Parameters
@@ -248,7 +256,9 @@ def set_value(self, var_name, src):
248256
val = self.get_value_ptr(var_name)
249257
val[:] = src.reshape(val.shape)
250258

251-
def set_value_at_indices(self, name, inds, src):
259+
def set_value_at_indices(
260+
self, name: str, inds: NDArray[np.int_], src: NDArray[Any]
261+
) -> None:
252262
"""Set model values at particular indices.
253263
254264
Parameters
@@ -263,76 +273,80 @@ def set_value_at_indices(self, name, inds, src):
263273
val = self.get_value_ptr(name)
264274
val.flat[inds] = src
265275

266-
def get_component_name(self):
276+
def get_component_name(self) -> str:
267277
"""Name of the component."""
268278
return self._name
269279

270-
def get_input_item_count(self):
280+
def get_input_item_count(self) -> int:
271281
"""Get names of input variables."""
272282
return len(self._input_var_names)
273283

274-
def get_output_item_count(self):
284+
def get_output_item_count(self) -> int:
275285
"""Get names of output variables."""
276286
return len(self._output_var_names)
277287

278-
def get_input_var_names(self):
288+
def get_input_var_names(self) -> tuple[str, ...]:
279289
"""Get names of input variables."""
280290
return self._input_var_names
281291

282-
def get_output_var_names(self):
292+
def get_output_var_names(self) -> tuple[str, ...]:
283293
"""Get names of output variables."""
284294
return self._output_var_names
285295

286-
def get_grid_shape(self, grid_id, shape):
296+
def get_grid_shape(self, grid_id: int, shape: NDArray[np.int_]) -> NDArray[np.int_]:
287297
"""Number of rows and columns of uniform rectilinear grid."""
288298
var_name = self._grids[grid_id][0]
289299
shape[:] = self.get_value_ptr(var_name).shape
290300
return shape
291301

292-
def get_grid_spacing(self, grid_id, spacing):
302+
def get_grid_spacing(
303+
self, grid_id: int, spacing: NDArray[np.float64]
304+
) -> NDArray[np.float64]:
293305
"""Spacing of rows and columns of uniform rectilinear grid."""
294306
spacing[:] = self._model.spacing
295307
return spacing
296308

297-
def get_grid_origin(self, grid_id, origin):
309+
def get_grid_origin(
310+
self, grid_id: int, origin: NDArray[np.float64]
311+
) -> NDArray[np.float64]:
298312
"""Origin of uniform rectilinear grid."""
299313
origin[:] = self._model.origin
300314
return origin
301315

302-
def get_grid_type(self, grid_id):
316+
def get_grid_type(self, grid_id: int) -> str:
303317
"""Type of grid."""
304318
return self._grid_type[grid_id]
305319

306-
def get_start_time(self):
320+
def get_start_time(self) -> float:
307321
"""Start time of model."""
308322
return self._start_time
309323

310-
def get_end_time(self):
324+
def get_end_time(self) -> float:
311325
"""End time of model."""
312326
return self._end_time
313327

314-
def get_current_time(self):
328+
def get_current_time(self) -> float:
315329
return self._model.time
316330

317-
def get_time_step(self):
331+
def get_time_step(self) -> float:
318332
return self._model.time_step
319333

320-
def get_time_units(self):
334+
def get_time_units(self) -> str:
321335
return self._time_units
322336

323-
def get_grid_edge_count(self, grid):
337+
def get_grid_edge_count(self, grid: int) -> int:
324338
raise NotImplementedError("get_grid_edge_count")
325339

326-
def get_grid_edge_nodes(self, grid, edge_nodes):
340+
def get_grid_edge_nodes(self, grid: int, edge_nodes: NDArray[np.int_]) -> None:
327341
raise NotImplementedError("get_grid_edge_nodes")
328342

329-
def get_grid_face_count(self, grid):
343+
def get_grid_face_count(self, grid: int) -> None:
330344
raise NotImplementedError("get_grid_face_count")
331345

332-
def get_grid_face_nodes(self, grid, face_nodes):
346+
def get_grid_face_nodes(self, grid: int, face_nodes: NDArray[np.int_]) -> None:
333347
raise NotImplementedError("get_grid_face_nodes")
334348

335-
def get_grid_node_count(self, grid):
349+
def get_grid_node_count(self, grid: int) -> int:
336350
"""Number of grid nodes.
337351
338352
Parameters
@@ -347,17 +361,19 @@ def get_grid_node_count(self, grid):
347361
"""
348362
return self.get_grid_size(grid)
349363

350-
def get_grid_nodes_per_face(self, grid, nodes_per_face):
364+
def get_grid_nodes_per_face(
365+
self, grid: int, nodes_per_face: NDArray[np.int_]
366+
) -> None:
351367
raise NotImplementedError("get_grid_nodes_per_face")
352368

353-
def get_grid_face_edges(self, grid, face_edges):
369+
def get_grid_face_edges(self, grid: int, face_edges: NDArray[np.int_]) -> None:
354370
raise NotImplementedError("get_grid_face_edges")
355371

356-
def get_grid_x(self, grid, x):
372+
def get_grid_x(self, grid: int, x: NDArray[np.float64]) -> None:
357373
raise NotImplementedError("get_grid_x")
358374

359-
def get_grid_y(self, grid, y):
375+
def get_grid_y(self, grid: int, y: NDArray[np.float64]) -> None:
360376
raise NotImplementedError("get_grid_y")
361377

362-
def get_grid_z(self, grid, z):
378+
def get_grid_z(self, grid: int, z: NDArray[np.float64]) -> None:
363379
raise NotImplementedError("get_grid_z")

0 commit comments

Comments
 (0)