Skip to content

Commit e4c2cf1

Browse files
authored
Filter elements by country (#36)
Signed-off-by: Geoffroy Jamgotchian <[email protected]>
1 parent 13c4707 commit e4c2cf1

File tree

12 files changed

+183
-51
lines changed

12 files changed

+183
-51
lines changed

Diff for: cpp/src/bindings.cpp

+3-2
Original file line numberDiff line numberDiff line change
@@ -45,11 +45,12 @@ PYBIND11_MODULE(_gridpy, m) {
4545
.value("LINE", element_type::LINE)
4646
.value("TWO_WINDINGS_TRANSFORMER", element_type::TWO_WINDINGS_TRANSFORMER)
4747
.value("GENERATOR", element_type::GENERATOR)
48+
.value("LOAD", element_type::LOAD)
4849
.export_values();
4950

5051
m.def("get_network_elements_ids", &gridpy::getNetworkElementsIds, "Get network elements ids for a given element type",
51-
py::arg("network"), py::arg("element_type"), py::arg("nominal_voltage"),
52-
py::arg("main_connected_component"));
52+
py::arg("network"), py::arg("element_type"), py::arg("nominal_voltages"),
53+
py::arg("countries"), py::arg("main_connected_component"));
5354

5455
m.def("load_network", &gridpy::loadNetwork, "Load a network from a file");
5556

Diff for: cpp/src/gridpy.cpp

+58-30
Original file line numberDiff line numberDiff line change
@@ -80,6 +80,57 @@ Array<limit_violation>::~Array() {
8080
// already freed by contingency_result
8181
}
8282

83+
template<typename T>
84+
class ToPtr {
85+
public:
86+
~ToPtr() {
87+
delete[] ptr_;
88+
}
89+
90+
T* get() const {
91+
return ptr_;
92+
}
93+
94+
protected:
95+
explicit ToPtr(size_t size)
96+
: ptr_(new T[size])
97+
{}
98+
99+
T* ptr_;
100+
};
101+
102+
class ToCharPtrPtr : public ToPtr<char*> {
103+
public:
104+
explicit ToCharPtrPtr(const std::vector<std::string>& strings)
105+
: ToPtr<char*>(strings.size())
106+
{
107+
for (int i = 0; i < strings.size(); i++) {
108+
ptr_[i] = (char *) strings[i].data();
109+
}
110+
}
111+
};
112+
113+
class ToDoublePtr : public ToPtr<double> {
114+
public:
115+
explicit ToDoublePtr(const std::vector<double>& doubles)
116+
: ToPtr<double>(doubles.size())
117+
{
118+
for (int i = 0; i < doubles.size(); i++) {
119+
ptr_[i] = doubles[i];
120+
}
121+
}
122+
};
123+
124+
std::vector<std::string> fromCharPtrPtr(array* arrayPtr) {
125+
std::vector<std::string> strings;
126+
strings.reserve(arrayPtr->length);
127+
for (int i = 0; i < arrayPtr->length; i++) {
128+
std::string str = *((char**) arrayPtr->ptr + i);
129+
strings.emplace_back(str);
130+
}
131+
return strings;
132+
}
133+
83134
void setDebugMode(bool debug) {
84135
GraalVmGuard guard;
85136
setDebugMode(guard.thread(), debug);
@@ -125,15 +176,14 @@ bool updateConnectableStatus(void* network, const std::string& id, bool connecte
125176
return updateConnectableStatus(guard.thread(), network, (char*) id.data(), connected);
126177
}
127178

128-
std::vector<std::string> getNetworkElementsIds(void* network, element_type elementType, double nominalVoltage, bool mainCc) {
179+
std::vector<std::string> getNetworkElementsIds(void* network, element_type elementType, const std::vector<double>& nominalVoltages,
180+
const std::vector<std::string>& countries, bool mainCc) {
129181
GraalVmGuard guard;
130-
array* elementsIdsArrayPtr = getNetworkElementsIds(guard.thread(), network, elementType, nominalVoltage, mainCc);
131-
std::vector<std::string> elementsIds;
132-
elementsIds.reserve(elementsIdsArrayPtr->length);
133-
for (int i = 0; i < elementsIdsArrayPtr->length; i++) {
134-
std::string elementId = *((char**) elementsIdsArrayPtr->ptr + i);
135-
elementsIds.emplace_back(elementId);
136-
}
182+
ToDoublePtr nominalVoltagePtr(nominalVoltages);
183+
ToCharPtrPtr countryPtr(countries);
184+
array* elementsIdsArrayPtr = getNetworkElementsIds(guard.thread(), network, elementType, nominalVoltagePtr.get(), nominalVoltages.size(),
185+
countryPtr.get(), countries.size(), mainCc);
186+
std::vector<std::string> elementsIds = fromCharPtrPtr(elementsIdsArrayPtr);
137187
freeNetworkElementsIds(guard.thread(), elementsIdsArrayPtr);
138188
return elementsIds;
139189
}
@@ -168,28 +218,6 @@ void* createSecurityAnalysis() {
168218
return createSecurityAnalysis(guard.thread());
169219
}
170220

171-
class ToCharPtrPtr {
172-
public:
173-
explicit ToCharPtrPtr(const std::vector<std::string>& strings)
174-
: charPtrPtr_(new char*[strings.size()])
175-
{
176-
for (int i = 0; i < strings.size(); i++) {
177-
charPtrPtr_[i] = (char *) strings[i].data();
178-
}
179-
}
180-
181-
~ToCharPtrPtr() {
182-
delete[] charPtrPtr_;
183-
}
184-
185-
char** get() const {
186-
return charPtrPtr_;
187-
}
188-
189-
private:
190-
char** charPtrPtr_;
191-
};
192-
193221
void addContingency(void* analysisContext, const std::string& contingencyId, const std::vector<std::string>& elementsIds) {
194222
GraalVmGuard guard;
195223
ToCharPtrPtr elementIdPtr(elementsIds);

Diff for: cpp/src/gridpy.h

+2-1
Original file line numberDiff line numberDiff line change
@@ -87,7 +87,8 @@ bool updateSwitchPosition(void* network, const std::string& id, bool open);
8787

8888
bool updateConnectableStatus(void* network, const std::string& id, bool connected);
8989

90-
std::vector<std::string> getNetworkElementsIds(void* network, element_type elementType, double nominalVoltage, bool mainCc);
90+
std::vector<std::string> getNetworkElementsIds(void* network, element_type elementType, const std::vector<double>& nominalVoltages,
91+
const std::vector<std::string>& countries, bool mainCc);
9192

9293
void* loadNetwork(const std::string& file);
9394

Diff for: gridpy/network.py

+5-2
Original file line numberDiff line numberDiff line change
@@ -11,6 +11,7 @@
1111
from _gridpy import ElementType
1212
from gridpy.util import ObjectHandle
1313
from typing import List
14+
from typing import Set
1415

1516

1617
Bus.__repr__ = lambda self: f"{self.__class__.__name__}("\
@@ -73,8 +74,10 @@ def dump(self, file: str, format: str = 'XIIDM'):
7374
def write_single_line_diagram_svg(self, container_id: str, svg_file: str):
7475
_gridpy.write_single_line_diagram_svg(self.ptr, container_id, svg_file)
7576

76-
def get_elements_ids(self, element_type: _gridpy.ElementType, nominal_voltage: float = float('NaN'), main_connected_component: bool = True) -> List[str]:
77-
return _gridpy.get_network_elements_ids(self.ptr, element_type, nominal_voltage, main_connected_component)
77+
def get_elements_ids(self, element_type: _gridpy.ElementType, nominal_voltages: Set[float] = None, countries: Set[str] = None,
78+
main_connected_component: bool = True) -> List[str]:
79+
return _gridpy.get_network_elements_ids(self.ptr, element_type, [] if nominal_voltages is None else list(nominal_voltages),
80+
[] if countries is None else list(countries), main_connected_component)
7881

7982

8083
def create_empty(id: str = "Default") -> Network:

Diff for: java/pom.xml

+9
Original file line numberDiff line numberDiff line change
@@ -27,6 +27,7 @@
2727
<graalvm.version>20.3.0</graalvm.version>
2828
<gridsuite-dependencies.version>4</gridsuite-dependencies.version>
2929
<janino.version>3.1.0</janino.version>
30+
<junit-jupiter.version>5.5.2</junit-jupiter.version>
3031
<logback.version>1.2.3</logback.version>
3132
<mapdb.version>3.0.8</mapdb.version>
3233
<maven-shade-plugin.version>3.2.4</maven-shade-plugin.version>
@@ -204,6 +205,14 @@
204205
<artifactId>powsybl-open-loadflow</artifactId>
205206
<version>${open-load-flow.version}</version>
206207
</dependency>
208+
209+
<!-- test -->
210+
<dependency>
211+
<groupId>org.junit.jupiter</groupId>
212+
<artifactId>junit-jupiter-engine</artifactId>
213+
<version>${junit-jupiter.version}</version>
214+
<scope>test</scope>
215+
</dependency>
207216
</dependencies>
208217

209218
</project>

Diff for: java/src/main/java/org/gridsuite/gridpy/CTypeUtil.java

+10
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,7 @@
88

99
import org.graalvm.nativeimage.c.type.CCharPointer;
1010
import org.graalvm.nativeimage.c.type.CCharPointerPointer;
11+
import org.graalvm.nativeimage.c.type.CDoublePointer;
1112
import org.graalvm.nativeimage.c.type.CTypeConversion;
1213

1314
import java.util.ArrayList;
@@ -30,4 +31,13 @@ static List<String> createStringList(CCharPointerPointer charPtrPtr, int length)
3031
}
3132
return stringList;
3233
}
34+
35+
static List<Double> createDoubleList(CDoublePointer doublePtr, int length) {
36+
List<Double> doubleList = new ArrayList<>(length);
37+
for (int i = 0; i < length; i++) {
38+
double d = doublePtr.read(i);
39+
doubleList.add(d);
40+
}
41+
return doubleList;
42+
}
3343
}

Diff for: java/src/main/java/org/gridsuite/gridpy/GridPyApi.java

+6-5
Original file line numberDiff line numberDiff line change
@@ -39,9 +39,7 @@
3939
import org.slf4j.LoggerFactory;
4040

4141
import java.nio.file.Paths;
42-
import java.util.Collection;
43-
import java.util.List;
44-
import java.util.Objects;
42+
import java.util.*;
4543
import java.util.stream.Collectors;
4644

4745
import static org.gridsuite.gridpy.GridPyApiHeader.*;
@@ -270,9 +268,12 @@ public static boolean updateConnectableStatus(IsolateThread thread, ObjectHandle
270268

271269
@CEntryPoint(name = "getNetworkElementsIds")
272270
public static ArrayPointer<CCharPointerPointer> getNetworkElementsIds(IsolateThread thread, ObjectHandle networkHandle, ElementType elementType,
273-
double nominalVoltage, boolean mainCc) {
271+
CDoublePointer nominalVoltagePtr, int nominalVoltageCount,
272+
CCharPointerPointer countryPtr, int countryCount, boolean mainCc) {
274273
Network network = ObjectHandles.getGlobal().get(networkHandle);
275-
List<String> elementsIds = NetworkUtil.getElementsIds(network, elementType, nominalVoltage, mainCc);
274+
Set<Double> nominalVoltages = new HashSet<>(CTypeUtil.createDoubleList(nominalVoltagePtr, nominalVoltageCount));
275+
Set<String> countries = new HashSet<>(CTypeUtil.createStringList(countryPtr, countryCount));
276+
List<String> elementsIds = NetworkUtil.getElementsIds(network, elementType, nominalVoltages, countries, mainCc);
276277
CCharPointerPointer elementsIdsPtr = UnmanagedMemory.calloc(elementsIds.size() * SizeOf.get(CCharPointerPointer.class));
277278
for (int i = 0; i < elementsIds.size(); i++) {
278279
elementsIdsPtr.addressOf(i).write(CTypeConversion.toCString(elementsIds.get(i)).get());

Diff for: java/src/main/java/org/gridsuite/gridpy/GridPyApiHeader.java

+2-1
Original file line numberDiff line numberDiff line change
@@ -338,7 +338,8 @@ interface ContingencyResultPointer extends PointerBase {
338338
enum ElementType {
339339
LINE,
340340
TWO_WINDINGS_TRANSFORMER,
341-
GENERATOR;
341+
GENERATOR,
342+
LOAD;
342343

343344
@CEnumValue
344345
public native int getCValue();

Diff for: java/src/main/java/org/gridsuite/gridpy/NetworkUtil.java

+51-9
Original file line numberDiff line numberDiff line change
@@ -11,6 +11,7 @@
1111
import org.gridsuite.gridpy.GridPyApiHeader.ElementType;
1212

1313
import java.util.List;
14+
import java.util.Set;
1415
import java.util.stream.Collectors;
1516

1617
/**
@@ -71,31 +72,72 @@ private static boolean isInMainCc(Terminal t) {
7172
return bus != null && bus.getConnectedComponent().getNum() == ComponentConstants.MAIN_NUM;
7273
}
7374

74-
static List<String> getElementsIds(Network network, ElementType elementType, double nominalVoltage, boolean mainCc) {
75+
private static boolean filter(Branch branch, Set<Double> nominalVoltages, Set<String> countries, boolean mainCc) {
76+
Terminal terminal1 = branch.getTerminal1();
77+
Terminal terminal2 = branch.getTerminal2();
78+
VoltageLevel voltageLevel1 = terminal1.getVoltageLevel();
79+
VoltageLevel voltageLevel2 = terminal2.getVoltageLevel();
80+
if (!(nominalVoltages.isEmpty()
81+
|| nominalVoltages.contains(voltageLevel1.getNominalV())
82+
|| nominalVoltages.contains(voltageLevel2.getNominalV()))) {
83+
return false;
84+
}
85+
if (!(countries.isEmpty()
86+
|| countries.contains(voltageLevel1.getSubstation().getCountry().map(Country::name).orElse(null))
87+
|| countries.contains(voltageLevel2.getSubstation().getCountry().map(Country::name).orElse(null)))) {
88+
return false;
89+
}
90+
if (mainCc && !(isInMainCc(terminal1) && isInMainCc(terminal2))) {
91+
return false;
92+
}
93+
return true;
94+
}
95+
96+
private static boolean filter(Injection injection, Set<Double> nominalVoltages, Set<String> countries, boolean mainCc) {
97+
Terminal terminal = injection.getTerminal();
98+
VoltageLevel voltageLevel = terminal.getVoltageLevel();
99+
if (!(nominalVoltages.isEmpty()
100+
|| nominalVoltages.contains(voltageLevel.getNominalV()))) {
101+
return false;
102+
}
103+
if (!(countries.isEmpty()
104+
|| countries.contains(voltageLevel.getSubstation().getCountry().map(Country::name).orElse(null)))) {
105+
return false;
106+
}
107+
if (mainCc && !isInMainCc(terminal)) {
108+
return false;
109+
}
110+
return true;
111+
}
112+
113+
static List<String> getElementsIds(Network network, ElementType elementType, Set<Double> nominalVoltages,
114+
Set<String> countries, boolean mainCc) {
75115
List<String> elementsIds;
76116
switch (elementType) {
77117
case LINE:
78118
elementsIds = network.getLineStream()
79-
.filter(l -> Double.isNaN(nominalVoltage) || l.getTerminal1().getVoltageLevel().getNominalV() == nominalVoltage)
80-
.filter(l -> !mainCc || (isInMainCc(l.getTerminal1()) && isInMainCc(l.getTerminal2())))
119+
.filter(l -> filter(l, nominalVoltages, countries, mainCc))
81120
.map(Identifiable::getId)
82121
.collect(Collectors.toList());
83122
break;
84123

85124
case TWO_WINDINGS_TRANSFORMER:
86125
elementsIds = network.getTwoWindingsTransformerStream()
87-
.filter(twt -> Double.isNaN(nominalVoltage)
88-
|| twt.getTerminal1().getVoltageLevel().getNominalV() == nominalVoltage
89-
|| twt.getTerminal2().getVoltageLevel().getNominalV() == nominalVoltage)
90-
.filter(l -> !mainCc || (isInMainCc(l.getTerminal1()) && isInMainCc(l.getTerminal2())))
126+
.filter(twt -> filter(twt, nominalVoltages, countries, mainCc))
91127
.map(Identifiable::getId)
92128
.collect(Collectors.toList());
93129
break;
94130

95131
case GENERATOR:
96132
elementsIds = network.getGeneratorStream()
97-
.filter(g -> Double.isNaN(nominalVoltage) || g.getTerminal().getVoltageLevel().getNominalV() == nominalVoltage)
98-
.filter(g -> !mainCc || isInMainCc(g.getTerminal()))
133+
.filter(g -> filter(g, nominalVoltages, countries, mainCc))
134+
.map(Identifiable::getId)
135+
.collect(Collectors.toList());
136+
break;
137+
138+
case LOAD:
139+
elementsIds = network.getLoadStream()
140+
.filter(g -> filter(g, nominalVoltages, countries, mainCc))
99141
.map(Identifiable::getId)
100142
.collect(Collectors.toList());
101143
break;

Diff for: java/src/main/resources/gridpy-api.h

+1
Original file line numberDiff line numberDiff line change
@@ -71,6 +71,7 @@ typedef enum {
7171
LINE = 0,
7272
TWO_WINDINGS_TRANSFORMER,
7373
GENERATOR,
74+
LOAD,
7475
} element_type;
7576

7677
typedef struct matrix_struct {
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,29 @@
1+
/**
2+
* Copyright (c) 2021, RTE (http://www.rte-france.com)
3+
* This Source Code Form is subject to the terms of the Mozilla Public
4+
* License, v. 2.0. If a copy of the MPL was not distributed with this
5+
* file, You can obtain one at http://mozilla.org/MPL/2.0/.
6+
*/
7+
package org.gridsuite.gridpy;
8+
9+
import com.powsybl.iidm.network.Network;
10+
import com.powsybl.iidm.network.test.EurostagTutorialExample1Factory;
11+
import org.junit.jupiter.api.Test;
12+
13+
import java.util.Collections;
14+
import java.util.List;
15+
16+
import static org.junit.jupiter.api.Assertions.assertEquals;
17+
18+
/**
19+
* @author Geoffroy Jamgotchian {@literal <geoffroy.jamgotchian at rte-france.com>}
20+
*/
21+
class NetworkUtilTest {
22+
23+
@Test
24+
void test() {
25+
Network network = EurostagTutorialExample1Factory.create();
26+
List<String> elementsIds = NetworkUtil.getElementsIds(network, GridPyApiHeader.ElementType.TWO_WINDINGS_TRANSFORMER, Collections.singleton(24.0), Collections.singleton("FR"), true);
27+
assertEquals(Collections.singletonList("NGEN_NHV1"), elementsIds);
28+
}
29+
}

Diff for: tests/test.py

+7-1
Original file line numberDiff line numberDiff line change
@@ -84,7 +84,13 @@ def test_security_analysis(self):
8484
def test_get_network_element_ids(self):
8585
n = gp.network.create_eurostag_tutorial_example1_network()
8686
self.assertEqual(['NGEN_NHV1', 'NHV2_NLOAD'], n.get_elements_ids(gp.network.ElementType.TWO_WINDINGS_TRANSFORMER))
87-
self.assertEqual(['NGEN_NHV1'], n.get_elements_ids(gp.network.ElementType.TWO_WINDINGS_TRANSFORMER, 24))
87+
self.assertEqual(['NGEN_NHV1'], n.get_elements_ids(element_type=gp.network.ElementType.TWO_WINDINGS_TRANSFORMER, nominal_voltages={24}))
88+
self.assertEqual(['NGEN_NHV1', 'NHV2_NLOAD'], n.get_elements_ids(element_type=gp.network.ElementType.TWO_WINDINGS_TRANSFORMER, nominal_voltages={24, 150}))
89+
self.assertEqual(['LOAD'], n.get_elements_ids(element_type=gp.network.ElementType.LOAD, nominal_voltages={150}))
90+
self.assertEqual(['LOAD'], n.get_elements_ids(element_type=gp.network.ElementType.LOAD, nominal_voltages={150}, countries={'FR'}))
91+
self.assertEqual([], n.get_elements_ids(element_type=gp.network.ElementType.LOAD, nominal_voltages={150}, countries={'BE'}))
92+
self.assertEqual(['NGEN_NHV1'], n.get_elements_ids(element_type=gp.network.ElementType.TWO_WINDINGS_TRANSFORMER, nominal_voltages={24}, countries={'FR'}))
93+
self.assertEqual([], n.get_elements_ids(element_type=gp.network.ElementType.TWO_WINDINGS_TRANSFORMER, nominal_voltages={24}, countries={'BE'}))
8894

8995
def test_sensitivity_analysis(self):
9096
n = gp.network.create_ieee14()

0 commit comments

Comments
 (0)