4
4
5
5
import logging
6
6
from collections import OrderedDict
7
- from typing import Iterable , Sized
8
7
9
8
import numpy as np
10
9
from andes .core .model .modelcache import ModelCache
11
10
from andes .core .param import (BaseParam , DataParam , IdxParam , NumParam ,
12
11
TimerParam )
13
12
from andes .shared import pd
13
+ from andes .utils .func import validate_keys_values
14
14
15
15
logger = logging .getLogger (__name__ )
16
16
@@ -277,7 +277,7 @@ def find_param(self, prop):
277
277
278
278
return out
279
279
280
- def find_idx (self , keys , values , allow_none = False , default = False ):
280
+ def find_idx (self , keys , values , allow_none = False , default = False , allow_all = False ):
281
281
"""
282
282
Find `idx` of devices whose values match the given pattern.
283
283
@@ -288,49 +288,65 @@ def find_idx(self, keys, values, allow_none=False, default=False):
288
288
values : array, array of arrays, Sized
289
289
Values for the corresponding key to search for. If keys is a str, values should be an array of
290
290
elements. If keys is a list, values should be an array of arrays, each corresponds to the key.
291
- allow_none : bool, Sized
291
+ allow_none : bool, Sized, optional
292
292
Allow key, value to be not found. Used by groups.
293
- default : bool
293
+ default : bool, optional
294
294
Default idx to return if not found (missing)
295
+ allow_all : bool, optional
296
+ If True, returns a list of lists where each nested list contains all the matches for the
297
+ corresponding search criteria.
295
298
296
299
Returns
297
300
-------
298
301
list
299
302
indices of devices
300
- """
301
- if isinstance (keys , str ):
302
- keys = (keys ,)
303
- if not isinstance (values , (int , float , str , np .floating )) and not isinstance (values , Iterable ):
304
- raise ValueError (f"value must be a string, scalar or an iterable, got { values } " )
305
303
306
- if len (values ) > 0 and not isinstance (values [0 ], (list , tuple , np .ndarray )):
307
- values = (values ,)
304
+ Notes
305
+ -----
306
+ - Only the first match is returned by default.
307
+ - If all matches are needed, set `allow_all` to True.
308
+
309
+ Examples
310
+ --------
311
+ >>> # Use example case of IEEE 14-bus system with PVD1
312
+ >>> ss = andes.load(andes.get_case('ieee14/ieee14_pvd1.xlsx'))
313
+
314
+ >>> # To find the idx of `PVD1` with `name` of 'PVD1_1' and 'PVD1_2'
315
+ >>> ss.PVD1.find_idx(keys='name', values=['PVD1_1', 'PVD1_2'])
316
+ [1, 2]
317
+
318
+ >>> # To find the idx of `PVD1` connected to bus 4
319
+ >>> ss.PVD1.find_idx(keys='bus', values=[4])
320
+ [1]
308
321
309
- elif isinstance ( keys , Sized ):
310
- if not isinstance ( values , Iterable ):
311
- raise ValueError ( f"value must be an iterable, got { values } " )
322
+ >>> # To find ALL the idx of `PVD1` with `gammap` equals to 0.1
323
+ >>> ss.PVD1.find_idx(keys='gammap', values=[0.1], allow_all=True)
324
+ [[1, 2, 3, 4, 5, 6, 7, 8, 9, 10]]
312
325
313
- if len (values ) > 0 and not isinstance (values [0 ], Iterable ):
314
- raise ValueError (f"if keys is an iterable, values must be an iterable of iterables. got { values } " )
326
+ >>> # To find the idx of `PVD1` with `gammap` equals to 0.1 and `name` of 'PVD1_1'
327
+ >>> ss.PVD1.find_idx(keys=['gammap', 'name'], values=[[0.1], ['PVD1_1']])
328
+ [1]
329
+ """
315
330
316
- if len (keys ) != len (values ):
317
- raise ValueError ("keys and values must have the same length" )
331
+ keys , values = validate_keys_values (keys , values )
318
332
319
333
v_attrs = [self .__dict__ [key ].v for key in keys ]
320
334
321
335
idxes = []
322
336
for v_search in zip (* values ):
323
- v_idx = None
337
+ v_idx = []
324
338
for pos , v_attr in enumerate (zip (* v_attrs )):
325
339
if all ([i == j for i , j in zip (v_search , v_attr )]):
326
- v_idx = self .idx .v [pos ]
327
- break
328
- if v_idx is None :
340
+ v_idx .append (self .idx .v [pos ])
341
+ if not v_idx :
329
342
if allow_none is False :
330
343
raise IndexError (f'{ list (keys )} ={ v_search } not found in { self .class_name } ' )
331
344
else :
332
- v_idx = default
345
+ v_idx = [ default ]
333
346
334
- idxes .append (v_idx )
347
+ if allow_all :
348
+ idxes .append (v_idx )
349
+ else :
350
+ idxes .append (v_idx [0 ])
335
351
336
352
return idxes
0 commit comments