Skip to content

Commit 00698d5

Browse files
authored
Fix typing issue for numpy 2.3 and python 3.11 (#70)
Signed-off-by: Thijs Baaijen <[email protected]>
1 parent d2bc1f1 commit 00698d5

File tree

3 files changed

+10
-9
lines changed

3 files changed

+10
-9
lines changed

src/power_grid_model_ds/_core/model/arrays/base/_build.py

Lines changed: 7 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -25,13 +25,13 @@ def build_array(*args: tuple[Any], dtype: np.dtype, defaults: dict[str, np.gener
2525
return array
2626

2727
if isinstance(parsed_input, np.ndarray) and parsed_input.dtype.names:
28-
_check_missing_columns(array.dtype.names, defaults, set(parsed_input.dtype.names))
28+
_check_missing_columns(array.dtype.names or (), defaults, set(parsed_input.dtype.names))
2929
return _parse_structured_array(parsed_input, array)
3030
if isinstance(parsed_input, np.ndarray):
3131
# Note: defaults are not supported when working with unstructured arrays
3232
return _parse_array(parsed_input, array.dtype)
3333

34-
_check_missing_columns(array.dtype.names, defaults, set(parsed_input.keys()))
34+
_check_missing_columns(array.dtype.names or (), defaults, set(parsed_input.keys()))
3535
_fill_with_kwargs(array, parsed_input)
3636
return array
3737

@@ -54,7 +54,7 @@ def _parse_input(*args: Any, dtype: np.dtype, **kwargs):
5454
return {}, 0
5555

5656

57-
def _check_missing_columns(array_columns: tuple, defaults: dict[str, np.generic], provided_columns: set[str]):
57+
def _check_missing_columns(array_columns: tuple[str, ...], defaults: dict[str, np.generic], provided_columns: set[str]):
5858
required_columns = set(array_columns) - set(defaults.keys())
5959
if missing_columns := required_columns - provided_columns:
6060
raise ValueError(f"Missing required columns: {missing_columns}")
@@ -64,7 +64,8 @@ def _fill_defaults(array: np.ndarray, defaults: dict[str, np.generic]):
6464
"""Fills the defaults into the array."""
6565
for column, default in defaults.items():
6666
if default is empty:
67-
array[column] = empty(array.dtype[column]) # type: ignore[call-overload]
67+
column_type: type = array.dtype[column]
68+
array[column] = empty(column_type) # type: ignore[call-overload]
6869
else:
6970
array[column] = default # type: ignore[call-overload]
7071

@@ -87,8 +88,8 @@ def _parse_structured_array(from_array: np.ndarray, to_array: np.ndarray) -> np.
8788

8889
def _determine_column_overlap(from_array: np.ndarray, to_array: np.ndarray) -> tuple[list[str], list[str]]:
8990
"""Returns two lists: columns present in both arrays and the columns that are only present in from_array"""
90-
from_columns = set(from_array.dtype.names)
91-
to_columns = set(to_array.dtype.names)
91+
from_columns = set(from_array.dtype.names or ())
92+
to_columns = set(to_array.dtype.names or ())
9293

9394
return list(from_columns & to_columns), list(from_columns - to_columns)
9495

src/power_grid_model_ds/_core/model/arrays/base/_filters.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -20,7 +20,7 @@ def get_filter_mask(
2020
"""Returns a mask that matches the input parameters."""
2121
parsed_kwargs = _parse(args, kwargs)
2222

23-
if invalid_kwargs := set(parsed_kwargs.keys()) - set(array.dtype.names):
23+
if invalid_kwargs := set(parsed_kwargs.keys()) - set(array.dtype.names or ()):
2424
raise ValueError(f"Invalid kwargs: {invalid_kwargs}")
2525

2626
filter_mask = _initialize_filter_mask(mode_, array.size)

src/power_grid_model_ds/_core/model/arrays/base/_modify.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -12,7 +12,7 @@ def re_order(array: np.ndarray, new_order: ArrayLike, column: str = "id") -> np.
1212
"""Re-order an id-array by the id column so that it follows a new_order.
1313
Expects the new_order input to contain the same values as self.id
1414
"""
15-
if column not in array.dtype.names:
15+
if column not in (array.dtype.names or ()):
1616
raise ValueError(f"Cannot re-order array: column {column} does not exist.")
1717
if not np.array_equal(np.sort(array[column]), np.sort(new_order)):
1818
raise ValueError(f"Cannot re-order array: mismatch between new_order and values in '{column}'-column.")
@@ -50,7 +50,7 @@ def update_by_id(array: np.ndarray, ids: ArrayLike, allow_missing: bool, **kwarg
5050

5151
def check_ids(array: np.ndarray, return_duplicates: bool = False) -> NDArray | None:
5252
"""Check for duplicate ids within the array"""
53-
if "id" not in array.dtype.names:
53+
if "id" not in (array.dtype.names or ()):
5454
raise AttributeError("Array has no 'id' column.")
5555

5656
unique, counts = np.unique(array["id"], return_counts=True)

0 commit comments

Comments
 (0)