Skip to content

Commit d2bc1f1

Browse files
authored
Fix: prevent type change when using tmp_remove_nodes (#60)
* Fix: prevent type change in graph.external_ids when using tmp_remove_nodes_array Signed-off-by: Thijs Baaijen <[email protected]> * provide correct type to tmp_remove_nodes Signed-off-by: Thijs Baaijen <[email protected]> --------- Signed-off-by: Thijs Baaijen <[email protected]>
1 parent df324f4 commit d2bc1f1

File tree

3 files changed

+10
-2
lines changed

3 files changed

+10
-2
lines changed

src/power_grid_model_ds/_core/model/graphs/models/base.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -208,7 +208,7 @@ def tmp_remove_nodes(self, nodes: list[int]) -> Generator:
208208
yield
209209

210210
for node in nodes:
211-
self.add_node(node)
211+
self.add_node(int(node)) # convert to int to avoid type issues when input is e.g. a numpy array
212212
for source, target in edge_list:
213213
self.add_branch(source, target)
214214

src/power_grid_model_ds/_core/model/grids/helpers.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -46,7 +46,7 @@ def set_feeder_ids(grid: "Grid"):
4646
"""
4747
_reset_feeder_ids(grid)
4848
feeder_node_ids = grid.node.filter(node_type=NodeType.SUBSTATION_NODE)["id"]
49-
with grid.graphs.active_graph.tmp_remove_nodes(feeder_node_ids):
49+
with grid.graphs.active_graph.tmp_remove_nodes(feeder_node_ids.tolist()):
5050
components = grid.graphs.active_graph.get_components()
5151
for component_node_ids in components:
5252
component_branches = _get_active_component_branches(grid, component_node_ids)

tests/unit/model/graphs/test_graph_model.py

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -171,6 +171,14 @@ def test_tmp_remove_nodes(graph_with_2_routes: BaseGraphModel) -> None:
171171
assert counter_before == counter_after
172172

173173

174+
def test_tmp_remove_nodes_array_input(graph_with_2_routes: BaseGraphModel) -> None:
175+
with graph_with_2_routes.tmp_remove_nodes(np.array([1, 2])): # type: ignore[arg-type]
176+
pass
177+
178+
# check that the external ids are still all integers instead of e.g. np.int
179+
assert all([isinstance(e_id, int) for e_id in graph_with_2_routes.external_ids])
180+
181+
174182
def test_get_components(graph_with_2_routes: BaseGraphModel):
175183
graph = graph_with_2_routes
176184
graph.add_node(99)

0 commit comments

Comments
 (0)