Skip to content

Commit 35e0b97

Browse files
Support parallelization of conf filter (#268)
<!-- This is an auto-generated comment: release notes by coderabbit.ai -->

## Summary by CodeRabbit

## Release Notes

- **New Features**
  - Introduced batch processing capabilities for configuration checks, improving efficiency when handling multiple frames.
  - Added new filter classes (`BarFilter`, `BazFilter`) with specific checks for frame coordinates.
- **Bug Fixes**
  - Enhanced clarity and efficiency in the configuration filtering process, streamlining logic and reducing complexity.
- **Tests**
  - Updated test cases to reflect new filter logic and ensure accurate validation of frame counts and coordinate values.

<!-- end of auto-generated comment: release notes by coderabbit.ai -->

---------

Signed-off-by: zjgemi <[email protected]>
Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
1 parent 08d8d6e commit 35e0b97

File tree

4 files changed

+152
-89
lines changed

4 files changed

+152
-89
lines changed

dpgen2/exploration/render/traj_render_lammps.py

Lines changed: 1 addition & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -130,10 +130,5 @@ def get_confs(
130130
ss = ss.sub_system(id_selected[ii])
131131
ms.append(ss)
132132
if conf_filters is not None:
133-
ms2 = dpdata.MultiSystems(type_map=type_map)
134-
for s in ms:
135-
s2 = conf_filters.check(s)
136-
if len(s2) > 0:
137-
ms2.append(s2)
138-
ms = ms2
133+
ms = conf_filters.check(ms)
139134
return ms

dpgen2/exploration/selector/conf_filter.py

Lines changed: 38 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,9 @@
66
ABC,
77
abstractmethod,
88
)
9+
from typing import (
10+
List,
11+
)
912

1013
import dpdata
1114
import numpy as np
@@ -32,6 +35,25 @@ def check(
3235
"""
3336
pass
3437

38+
def batched_check(
39+
self,
40+
frames: List[dpdata.System],
41+
) -> List[bool]:
42+
"""Check if a list of configurations are valid.
43+
44+
Parameters
45+
----------
46+
frames : List[dpdata.System]
47+
A list of dpdata.System each containing a single frame
48+
49+
Returns
50+
-------
51+
valid : List[bool]
52+
`True` if the configuration is a valid configuration, else `False`.
53+
54+
"""
55+
return list(map(self.check, frames))
56+
3557

3658
class ConfFilters:
3759
def __init__(
@@ -48,11 +70,20 @@ def add(
4870

4971
def check(
5072
self,
51-
conf: dpdata.System,
52-
) -> dpdata.System:
53-
natoms = sum(conf["atom_numbs"]) # type: ignore
54-
selected_idx = np.arange(conf.get_nframes())
73+
ms: dpdata.MultiSystems,
74+
) -> dpdata.MultiSystems:
75+
selected_idx = []
76+
for i in range(len(ms)):
77+
for j in range(ms[i].get_nframes()):
78+
selected_idx.append((i, j))
5579
for ff in self._filters:
56-
fsel = np.where([ff.check(conf[ii]) for ii in range(conf.get_nframes())])[0]
57-
selected_idx = np.intersect1d(selected_idx, fsel)
58-
return conf.sub_system(selected_idx)
80+
res = ff.batched_check([ms[i][j] for i, j in selected_idx])
81+
selected_idx = [idx for i, idx in enumerate(selected_idx) if res[i]]
82+
selected_idx_list = [[] for _ in range(len(ms))]
83+
for i, j in selected_idx:
84+
selected_idx_list[i].append(j)
85+
ms2 = dpdata.MultiSystems(type_map=ms.atom_names)
86+
for i in range(len(ms)):
87+
if len(selected_idx_list[i]) > 0:
88+
ms2.append(ms[i].sub_system(selected_idx_list[i]))
89+
return ms2

dpgen2/exploration/selector/distance_conf_filter.py

Lines changed: 72 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,7 @@
11
import logging
2+
from concurrent.futures import (
3+
ProcessPoolExecutor,
4+
)
25
from copy import (
36
deepcopy,
47
)
@@ -133,7 +136,8 @@ def check_multiples(a, b, c, multiple):
133136

134137

135138
class DistanceConfFilter(ConfFilter):
136-
def __init__(self, custom_safe_dist=None, safe_dist_ratio=1.0):
139+
def __init__(self, max_workers=None, custom_safe_dist=None, safe_dist_ratio=1.0):
140+
self.max_workers = max_workers
137141
self.custom_safe_dist = custom_safe_dist if custom_safe_dist is not None else {}
138142
self.safe_dist_ratio = safe_dist_ratio
139143

@@ -187,6 +191,16 @@ def check(
187191

188192
return True
189193

194+
def batched_check(
195+
self,
196+
frames: List[dpdata.System],
197+
):
198+
if self.max_workers == 1:
199+
return list(map(self.check, frames))
200+
else:
201+
with ProcessPoolExecutor(self.max_workers) as executor:
202+
return list(executor.map(self.check, frames))
203+
190204
@staticmethod
191205
def args() -> List[dargs.Argument]:
192206
r"""The argument definition of the `ConfFilter`.
@@ -197,9 +211,20 @@ def args() -> List[dargs.Argument]:
197211
List of dargs.Argument defines the arguments of the `ConfFilter`.
198212
"""
199213

214+
doc_max_workers = (
215+
"The maximum number of processes used to filter configurations, "
216+
+ "None represents as many as the processors of the machine, and 1 for serial"
217+
)
200218
doc_custom_safe_dist = "Custom safe distance (in unit of bohr) for each element"
201219
doc_safe_dist_ratio = "The ratio multiplied to the safe distance"
202220
return [
221+
Argument(
222+
"max_workers",
223+
int,
224+
optional=True,
225+
default=None,
226+
doc=doc_max_workers,
227+
),
203228
Argument(
204229
"custom_safe_dist",
205230
dict,
@@ -218,7 +243,8 @@ def args() -> List[dargs.Argument]:
218243

219244

220245
class BoxSkewnessConfFilter(ConfFilter):
221-
def __init__(self, theta=60.0):
246+
def __init__(self, max_workers=None, theta=60.0):
247+
self.max_workers = max_workers
222248
self.theta = theta
223249

224250
def check(
@@ -251,6 +277,16 @@ def check(
251277
return False
252278
return True
253279

280+
def batched_check(
281+
self,
282+
frames: List[dpdata.System],
283+
):
284+
if self.max_workers == 1:
285+
return list(map(self.check, frames))
286+
else:
287+
with ProcessPoolExecutor(self.max_workers) as executor:
288+
return list(executor.map(self.check, frames))
289+
254290
@staticmethod
255291
def args() -> List[dargs.Argument]:
256292
r"""The argument definition of the `ConfFilter`.
@@ -261,8 +297,19 @@ def args() -> List[dargs.Argument]:
261297
List of dargs.Argument defines the arguments of the `ConfFilter`.
262298
"""
263299

300+
doc_max_workers = (
301+
"The maximum number of processes used to filter configurations, "
302+
+ "None represents as many as the processors of the machine, and 1 for serial"
303+
)
264304
doc_theta = "The threshold for angles between the edges of the cell. If all angles are larger than this value the check is passed"
265305
return [
306+
Argument(
307+
"max_workers",
308+
int,
309+
optional=True,
310+
default=None,
311+
doc=doc_max_workers,
312+
),
266313
Argument(
267314
"theta",
268315
float,
@@ -274,7 +321,8 @@ def args() -> List[dargs.Argument]:
274321

275322

276323
class BoxLengthFilter(ConfFilter):
277-
def __init__(self, length_ratio=5.0):
324+
def __init__(self, max_workers=None, length_ratio=5.0):
325+
self.max_workers = max_workers
278326
self.length_ratio = length_ratio
279327

280328
def check(
@@ -307,6 +355,16 @@ def check(
307355
return False
308356
return True
309357

358+
def batched_check(
359+
self,
360+
frames: List[dpdata.System],
361+
):
362+
if self.max_workers == 1:
363+
return list(map(self.check, frames))
364+
else:
365+
with ProcessPoolExecutor(self.max_workers) as executor:
366+
return list(executor.map(self.check, frames))
367+
310368
@staticmethod
311369
def args() -> List[dargs.Argument]:
312370
r"""The argument definition of the `ConfFilter`.
@@ -317,8 +375,19 @@ def args() -> List[dargs.Argument]:
317375
List of dargs.Argument defines the arguments of the `ConfFilter`.
318376
"""
319377

378+
doc_max_workers = (
379+
"The maximum number of processes used to filter configurations, "
380+
+ "None represents as many as the processors of the machine, and 1 for serial"
381+
)
320382
doc_length_ratio = "The threshold for the length ratio between the edges of the cell. If all length ratios are smaller than this value the check is passed"
321383
return [
384+
Argument(
385+
"max_workers",
386+
int,
387+
optional=True,
388+
default=None,
389+
doc=doc_max_workers,
390+
),
322391
Argument(
323392
"length_ratio",
324393
float,

tests/exploration/test_conf_filter.py

Lines changed: 41 additions & 73 deletions
Original file line numberDiff line numberDiff line change
@@ -27,110 +27,78 @@ def check(
2727
self,
2828
frame: dpdata.System,
2929
) -> bool:
30-
return True
30+
return frame["coords"][0][0][0] > 0.0
3131

3232

33-
class faked_filter:
34-
myiter = -1
35-
myret = [True]
33+
class BarFilter(ConfFilter):
34+
def check(
35+
self,
36+
frame: dpdata.System,
37+
) -> bool:
38+
return frame["coords"][0][0][1] > 0.0
39+
3640

37-
@classmethod
38-
def faked_check(cls, frame):
39-
cls.myiter += 1
40-
cls.myiter = cls.myiter % len(cls.myret)
41-
return cls.myret[cls.myiter]
41+
class BazFilter(ConfFilter):
42+
def check(
43+
self,
44+
frame: dpdata.System,
45+
) -> bool:
46+
return frame["coords"][0][0][2] > 0.0
4247

4348

4449
class TestConfFilter(unittest.TestCase):
45-
@patch.object(FooFilter, "check", faked_filter.faked_check)
4650
def test_filter_0(self):
47-
faked_filter.myiter = -1
48-
faked_filter.myret = [
49-
True,
50-
True,
51-
False,
52-
True,
53-
False,
54-
True,
55-
True,
56-
False,
57-
True,
58-
True,
59-
False,
60-
False,
61-
]
6251
faked_sys = fake_system(4, 3)
6352
# expected only frame 1 is preserved.
64-
faked_sys["coords"][1][0][0] = 1.0
53+
faked_sys["coords"][1][0] = 1.0
54+
faked_sys["coords"][0][0][0] = 2.0
55+
faked_sys["coords"][2][0][1] = 3.0
56+
faked_sys["coords"][3][0][2] = 4.0
6557
filters = ConfFilters()
66-
filters.add(FooFilter()).add(FooFilter()).add(FooFilter())
67-
sel_sys = filters.check(faked_sys)
58+
filters.add(FooFilter()).add(BarFilter()).add(BazFilter())
59+
ms = dpdata.MultiSystems()
60+
ms.append(faked_sys)
61+
sel_sys = filters.check(ms)[0]
6862
self.assertEqual(sel_sys.get_nframes(), 1)
6963
self.assertAlmostEqual(sel_sys["coords"][0][0][0], 1)
7064

71-
@patch.object(FooFilter, "check", faked_filter.faked_check)
7265
def test_filter_1(self):
73-
faked_filter.myiter = -1
74-
faked_filter.myret = [
75-
True,
76-
True,
77-
False,
78-
True,
79-
False,
80-
True,
81-
True,
82-
True,
83-
True,
84-
True,
85-
False,
86-
True,
87-
]
8866
faked_sys = fake_system(4, 3)
8967
# expected frames 1 and 3 are preserved.
90-
faked_sys["coords"][1][0][0] = 1.0
91-
faked_sys["coords"][3][0][0] = 3.0
68+
faked_sys["coords"][1][0] = 1.0
69+
faked_sys["coords"][3][0] = 3.0
9270
filters = ConfFilters()
93-
filters.add(FooFilter()).add(FooFilter()).add(FooFilter())
94-
sel_sys = filters.check(faked_sys)
71+
filters.add(FooFilter()).add(BarFilter()).add(BazFilter())
72+
ms = dpdata.MultiSystems()
73+
ms.append(faked_sys)
74+
sel_sys = filters.check(ms)[0]
9575
self.assertEqual(sel_sys.get_nframes(), 2)
9676
self.assertAlmostEqual(sel_sys["coords"][0][0][0], 1)
9777
self.assertAlmostEqual(sel_sys["coords"][1][0][0], 3)
9878

99-
@patch.object(FooFilter, "check", faked_filter.faked_check)
10079
def test_filter_all(self):
101-
faked_filter.myiter = -1
102-
faked_filter.myret = [
103-
True,
104-
True,
105-
True,
106-
True,
107-
]
10880
faked_sys = fake_system(4, 3)
10981
# expected all frames are preserved.
110-
faked_sys["coords"][0][0][0] = 0.5
111-
faked_sys["coords"][1][0][0] = 1.0
112-
faked_sys["coords"][2][0][0] = 2.0
113-
faked_sys["coords"][3][0][0] = 3.0
82+
faked_sys["coords"][0][0] = 0.5
83+
faked_sys["coords"][1][0] = 1.0
84+
faked_sys["coords"][2][0] = 2.0
85+
faked_sys["coords"][3][0] = 3.0
11486
filters = ConfFilters()
115-
filters.add(FooFilter()).add(FooFilter()).add(FooFilter())
116-
sel_sys = filters.check(faked_sys)
87+
filters.add(FooFilter()).add(BarFilter()).add(BazFilter())
88+
ms = dpdata.MultiSystems()
89+
ms.append(faked_sys)
90+
sel_sys = filters.check(ms)[0]
11791
self.assertEqual(sel_sys.get_nframes(), 4)
11892
self.assertAlmostEqual(sel_sys["coords"][0][0][0], 0.5)
11993
self.assertAlmostEqual(sel_sys["coords"][1][0][0], 1)
12094
self.assertAlmostEqual(sel_sys["coords"][2][0][0], 2)
12195
self.assertAlmostEqual(sel_sys["coords"][3][0][0], 3)
12296

123-
@patch.object(FooFilter, "check", faked_filter.faked_check)
12497
def test_filter_none(self):
125-
faked_filter.myiter = -1
126-
faked_filter.myret = [
127-
False,
128-
False,
129-
False,
130-
False,
131-
]
13298
faked_sys = fake_system(4, 3)
13399
filters = ConfFilters()
134-
filters.add(FooFilter()).add(FooFilter()).add(FooFilter())
135-
sel_sys = filters.check(faked_sys)
136-
self.assertEqual(sel_sys.get_nframes(), 0)
100+
filters.add(FooFilter()).add(BarFilter()).add(BazFilter())
101+
ms = dpdata.MultiSystems()
102+
ms.append(faked_sys)
103+
sel_ms = filters.check(ms)
104+
self.assertEqual(sel_ms.get_nframes(), 0)

0 commit comments

Comments
 (0)