Skip to content

Commit d534ced

Browse files
authored
Merge pull request #567 from jinningwang/findidx
2 parents c7e7304 + 3adc110 commit d534ced

File tree

7 files changed

+218
-42
lines changed

7 files changed

+218
-42
lines changed

andes/core/model/modeldata.py

+40-24
Original file line numberDiff line numberDiff line change
@@ -4,13 +4,13 @@
44

55
import logging
66
from collections import OrderedDict
7-
from typing import Iterable, Sized
87

98
import numpy as np
109
from andes.core.model.modelcache import ModelCache
1110
from andes.core.param import (BaseParam, DataParam, IdxParam, NumParam,
1211
TimerParam)
1312
from andes.shared import pd
13+
from andes.utils.func import validate_keys_values
1414

1515
logger = logging.getLogger(__name__)
1616

@@ -277,7 +277,7 @@ def find_param(self, prop):
277277

278278
return out
279279

280-
def find_idx(self, keys, values, allow_none=False, default=False):
280+
def find_idx(self, keys, values, allow_none=False, default=False, allow_all=False):
281281
"""
282282
Find `idx` of devices whose values match the given pattern.
283283
@@ -288,49 +288,65 @@ def find_idx(self, keys, values, allow_none=False, default=False):
288288
values : array, array of arrays, Sized
289289
Values for the corresponding key to search for. If keys is a str, values should be an array of
290290
elements. If keys is a list, values should be an array of arrays, each corresponds to the key.
291-
allow_none : bool, Sized
291+
allow_none : bool, Sized, optional
292292
Allow key, value to be not found. Used by groups.
293-
default : bool
293+
default : bool, optional
294294
Default idx to return if not found (missing)
295+
allow_all : bool, optional
296+
If True, returns a list of lists where each nested list contains all the matches for the
297+
corresponding search criteria.
295298
296299
Returns
297300
-------
298301
list
299302
indices of devices
300-
"""
301-
if isinstance(keys, str):
302-
keys = (keys,)
303-
if not isinstance(values, (int, float, str, np.floating)) and not isinstance(values, Iterable):
304-
raise ValueError(f"value must be a string, scalar or an iterable, got {values}")
305303
306-
if len(values) > 0 and not isinstance(values[0], (list, tuple, np.ndarray)):
307-
values = (values,)
304+
Notes
305+
-----
306+
- Only the first match is returned by default.
307+
- If all matches are needed, set `allow_all` to True.
308+
309+
Examples
310+
--------
311+
>>> # Use example case of IEEE 14-bus system with PVD1
312+
>>> ss = andes.load(andes.get_case('ieee14/ieee14_pvd1.xlsx'))
313+
314+
>>> # To find the idx of `PVD1` with `name` of 'PVD1_1' and 'PVD1_2'
315+
>>> ss.PVD1.find_idx(keys='name', values=['PVD1_1', 'PVD1_2'])
316+
[1, 2]
317+
318+
>>> # To find the idx of `PVD1` connected to bus 4
319+
>>> ss.PVD1.find_idx(keys='bus', values=[4])
320+
[1]
308321
309-
elif isinstance(keys, Sized):
310-
if not isinstance(values, Iterable):
311-
raise ValueError(f"value must be an iterable, got {values}")
322+
>>> # To find ALL the idx of `PVD1` with `gammap` equals to 0.1
323+
>>> ss.PVD1.find_idx(keys='gammap', values=[0.1], allow_all=True)
324+
[[1, 2, 3, 4, 5, 6, 7, 8, 9, 10]]
312325
313-
if len(values) > 0 and not isinstance(values[0], Iterable):
314-
raise ValueError(f"if keys is an iterable, values must be an iterable of iterables. got {values}")
326+
>>> # To find the idx of `PVD1` with `gammap` equals to 0.1 and `name` of 'PVD1_1'
327+
>>> ss.PVD1.find_idx(keys=['gammap', 'name'], values=[[0.1], ['PVD1_1']])
328+
[1]
329+
"""
315330

316-
if len(keys) != len(values):
317-
raise ValueError("keys and values must have the same length")
331+
keys, values = validate_keys_values(keys, values)
318332

319333
v_attrs = [self.__dict__[key].v for key in keys]
320334

321335
idxes = []
322336
for v_search in zip(*values):
323-
v_idx = None
337+
v_idx = []
324338
for pos, v_attr in enumerate(zip(*v_attrs)):
325339
if all([i == j for i, j in zip(v_search, v_attr)]):
326-
v_idx = self.idx.v[pos]
327-
break
328-
if v_idx is None:
340+
v_idx.append(self.idx.v[pos])
341+
if not v_idx:
329342
if allow_none is False:
330343
raise IndexError(f'{list(keys)}={v_search} not found in {self.class_name}')
331344
else:
332-
v_idx = default
345+
v_idx = [default]
333346

334-
idxes.append(v_idx)
347+
if allow_all:
348+
idxes.append(v_idx)
349+
else:
350+
idxes.append(v_idx[0])
335351

336352
return idxes

andes/models/group.py

+57-16
Original file line numberDiff line numberDiff line change
@@ -5,7 +5,7 @@
55
import numpy as np
66

77
from andes.core.service import BackRef
8-
from andes.utils.func import list_flatten
8+
from andes.utils.func import list_flatten, validate_keys_values
99

1010
logger = logging.getLogger(__name__)
1111

@@ -243,30 +243,71 @@ def set(self, src: str, idx, attr, value):
243243

244244
return True
245245

246-
def find_idx(self, keys, values, allow_none=False, default=None):
246+
def find_idx(self, keys, values, allow_none=False, default=None, allow_all=False):
247247
"""
248248
Find indices of devices that satisfy the given `key=value` condition.
249249
250250
This method iterates over all models in this group.
251+
252+
Parameters
253+
----------
254+
keys : str, array-like, Sized
255+
A string or an array-like of strings containing the names of parameters for the search criteria.
256+
values : array, array of arrays, Sized
257+
Values for the corresponding key to search for. If keys is a str, values should be an array of
258+
elements. If keys is a list, values should be an array of arrays, each corresponding to the key.
259+
allow_none : bool, optional
260+
Allow key, value to be not found. Used by groups. Default is False.
261+
default : bool, optional
262+
Default idx to return if not found (missing). Default is None.
263+
allow_all : bool, optional
264+
Return all matches if set to True. Default is False.
265+
266+
Returns
267+
-------
268+
list
269+
Indices of devices.
251270
"""
271+
272+
keys, values = validate_keys_values(keys, values)
273+
274+
n_mdl, n_pair = len(self.models), len(values[0])
275+
252276
indices_found = []
253277
# `indices_found` contains found indices returned from all models of this group
254278
for model in self.models.values():
255-
indices_found.append(model.find_idx(keys, values, allow_none=True, default=default))
256-
257-
out = []
258-
for idx, idx_found in enumerate(zip(*indices_found)):
259-
if not allow_none:
260-
if idx_found.count(None) == len(idx_found):
261-
missing_values = [item[idx] for item in values]
262-
raise IndexError(f'{list(keys)} = {missing_values} not found in {self.class_name}')
263-
264-
real_idx = default
265-
for item in idx_found:
266-
if item is not None:
267-
real_idx = item
279+
indices_found.append(model.find_idx(keys, values, allow_none=True, default=default, allow_all=True))
280+
281+
# --- find missing pairs ---
282+
i_val_miss = []
283+
for i in range(n_pair):
284+
idx_cross_mdls = [indices_found[j][i] for j in range(n_mdl)]
285+
if all(item == [default] for item in idx_cross_mdls):
286+
i_val_miss.append(i)
287+
288+
if (not allow_none) and i_val_miss:
289+
miss_pairs = []
290+
for i in i_val_miss:
291+
miss_pairs.append([values[j][i] for j in range(len(keys))])
292+
raise IndexError(f'{keys} = {miss_pairs} not found in {self.class_name}')
293+
294+
# --- output ---
295+
out_pre = []
296+
for i in range(n_pair):
297+
idx_cross_mdls = [indices_found[j][i] for j in range(n_mdl)]
298+
if all(item == [default] for item in idx_cross_mdls):
299+
out_pre.append([default])
300+
continue
301+
for item in idx_cross_mdls:
302+
if item != [default]:
303+
out_pre.append(item)
268304
break
269-
out.append(real_idx)
305+
306+
if allow_all:
307+
out = out_pre
308+
else:
309+
out = [item[0] for item in out_pre]
310+
270311
return out
271312

272313
def _check_src(self, src: str):

andes/models/misc/output.py

+2-2
Original file line numberDiff line numberDiff line change
@@ -50,9 +50,9 @@ def in1d(self, addr, v_code):
5050
"""
5151

5252
if v_code == 'x':
53-
return np.in1d(self.xidx, addr)
53+
return np.isin(self.xidx, addr)
5454
if v_code == 'y':
55-
return np.in1d(self.yidx, addr)
55+
return np.isin(self.yidx, addr)
5656

5757
raise NotImplementedError("v_code <%s> not recognized" % v_code)
5858

andes/utils/func.py

+48
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,6 @@
11
import functools
22
import operator
3+
from typing import Iterable, Sized
34

45
from andes.shared import np
56

@@ -36,3 +37,50 @@ def interp_n2(t, x, y):
3637
"""
3738

3839
return y[:, 0] + (t - x[0]) * (y[:, 1] - y[:, 0]) / (x[1] - x[0])
40+
41+
42+
def validate_keys_values(keys, values):
43+
"""
44+
Validate the inputs for the func `find_idx`.
45+
46+
Parameters
47+
----------
48+
keys : str, array-like, Sized
49+
A string or an array-like of strings containing the names of parameters for the search criteria.
50+
values : array, array of arrays, Sized
51+
Values for the corresponding key to search for. If keys is a str, values should be an array of
52+
elements. If keys is a list, values should be an array of arrays, each corresponds to the key.
53+
54+
Returns
55+
-------
56+
tuple
57+
Sanitized keys and values
58+
59+
Raises
60+
------
61+
ValueError
62+
If the inputs are not valid.
63+
"""
64+
if isinstance(keys, str):
65+
keys = (keys,)
66+
if not isinstance(values, (int, float, str, np.floating)) and not isinstance(values, Iterable):
67+
raise ValueError(f"value must be a string, scalar or an iterable, got {values}")
68+
69+
if len(values) > 0 and not isinstance(values[0], (list, tuple, np.ndarray)):
70+
values = (values,)
71+
72+
elif isinstance(keys, Sized):
73+
if not isinstance(values, Iterable):
74+
raise ValueError(f"value must be an iterable, got {values}")
75+
76+
if len(values) > 0 and not isinstance(values[0], Iterable):
77+
raise ValueError(f"if keys is an iterable, values must be an iterable of iterables. got {values}")
78+
79+
if len(keys) != len(values):
80+
raise ValueError("keys and values must have the same length")
81+
82+
if isinstance(values[0], Iterable):
83+
if not all([len(val) == len(values[0]) for val in values]):
84+
raise ValueError("All items in values must have the same length")
85+
86+
return keys, values

docs/source/release-notes.rst

+1
Original file line numberDiff line numberDiff line change
@@ -19,6 +19,7 @@ v1.9.3 (2024-04-XX)
1919
- Adjust `BusFreq.Tw.default` to 0.1.
2020
- Add parameter from_csv=None in TDS.run() to allow loading data from CSV files at TDS begining.
2121
- Fix `TDS.init()` and `TDS._csv_step()` to fit loading from CSV when `Output` exists.
22+
- Add parameter `allow_all=False` to `ModelData.find_idx()` `GroupBase.find_idx()` to allow searching all matches.
2223

2324
v1.9.2 (2024-03-25)
2425
-------------------

tests/test_group.py

+17
Original file line numberDiff line numberDiff line change
@@ -71,6 +71,7 @@ def test_group_access(self):
7171
[6, 7, 8, 1])
7272

7373
# --- find_idx ---
74+
# same Model
7475
self.assertListEqual(ss.DG.find_idx('name', ['PVD1_1', 'PVD1_2']),
7576
ss.PVD1.find_idx('name', ['PVD1_1', 'PVD1_2']),
7677
)
@@ -82,6 +83,22 @@ def test_group_access(self):
8283
[('PVD1_1', 'PVD1_2'),
8384
(1.0, 1.0)]))
8485

86+
# cross Model, given results
87+
self.assertListEqual(ss.StaticGen.find_idx(keys='bus',
88+
values=[1, 2, 3, 4]),
89+
[1, 2, 3, 6])
90+
self.assertListEqual(ss.StaticGen.find_idx(keys='bus',
91+
values=[1, 2, 3, 4],
92+
allow_all=True),
93+
[[1], [2], [3], [6]])
94+
95+
self.assertListEqual(ss.StaticGen.find_idx(keys='bus',
96+
values=[1, 2, 3, 4, 2024],
97+
allow_none=True,
98+
default=2011,
99+
allow_all=True),
100+
[[1], [2], [3], [6], [2011]])
101+
85102
# --- get_field ---
86103
ff = ss.DG.get_field('f', list(ss.DG._idx2model.keys()), 'v_code')
87104
self.assertTrue(any([item == 'y' for item in ff]))

tests/test_model_set.py

+53
Original file line numberDiff line numberDiff line change
@@ -54,3 +54,56 @@ def test_model_set(self):
5454
ss.GENROU.set("M", np.array(["GENROU_4"]), "v", 6.0)
5555
np.testing.assert_equal(ss.GENROU.M.v[3], 6.0)
5656
self.assertEqual(ss.TDS.Teye[omega_addr[3], omega_addr[3]], 6.0)
57+
58+
def test_find_idx(self):
59+
ss = andes.load(andes.get_case('ieee14/ieee14_pvd1.xlsx'))
60+
mdl = ss.PVD1
61+
62+
# not allow all matches
63+
self.assertListEqual(mdl.find_idx(keys='gammap', values=[0.1], allow_all=False),
64+
[1])
65+
66+
# allow all matches
67+
self.assertListEqual(mdl.find_idx(keys='gammap', values=[0.1], allow_all=True),
68+
[[1, 2, 3, 4, 5, 6, 7, 8, 9, 10]])
69+
70+
# multiple values
71+
self.assertListEqual(mdl.find_idx(keys='name', values=['PVD1_1', 'PVD1_2'],
72+
allow_none=False, default=False),
73+
[1, 2])
74+
# non-existing value
75+
self.assertListEqual(mdl.find_idx(keys='name', values=['PVD1_999'],
76+
allow_none=True, default=False),
77+
[False])
78+
79+
# non-existing value is not allowed
80+
with self.assertRaises(IndexError):
81+
mdl.find_idx(keys='name', values=['PVD1_999'],
82+
allow_none=False, default=False)
83+
84+
# multiple keys
85+
self.assertListEqual(mdl.find_idx(keys=['gammap', 'name'],
86+
values=[[0.1, 0.1], ['PVD1_1', 'PVD1_2']]),
87+
[1, 2])
88+
89+
# multiple keys, with non-existing values
90+
self.assertListEqual(mdl.find_idx(keys=['gammap', 'name'],
91+
values=[[0.1, 0.1], ['PVD1_1', 'PVD1_999']],
92+
allow_none=True, default='CURENT'),
93+
[1, 'CURENT'])
94+
95+
# multiple keys, with non-existing values not allowed
96+
with self.assertRaises(IndexError):
97+
mdl.find_idx(keys=['gammap', 'name'],
98+
values=[[0.1, 0.1], ['PVD1_1', 'PVD1_999']],
99+
allow_none=False, default=999)
100+
101+
# multiple keys, values are not iterable
102+
with self.assertRaises(ValueError):
103+
mdl.find_idx(keys=['gammap', 'name'],
104+
values=[0.1, 0.1])
105+
106+
# multiple keys, items length are inconsistent in values
107+
with self.assertRaises(ValueError):
108+
mdl.find_idx(keys=['gammap', 'name'],
109+
values=[[0.1, 0.1], ['PVD1_1']])

0 commit comments

Comments
 (0)