Skip to content

Commit d60b4e9

Browse files
Optimize space usage of ExplorationReport before saving (#279)
<!-- This is an auto-generated comment: release notes by coderabbit.ai --> ## Summary by CodeRabbit - **New Features** - Added `no_candidate()` method across multiple exploration report classes to check candidate availability. - Enhanced `get_candidate_ids()` method with optional `clear` parameter for memory management. - **Improvements** - Optimized ratio calculations in exploration report classes. - Introduced more efficient state tracking for candidate configurations. - **Technical Updates** - Updated method signatures in exploration report classes to include new parameters. - Refined candidate selection and reporting mechanisms. <!-- end of auto-generated comment: release notes by coderabbit.ai --> --------- Signed-off-by: zjgemi <[email protected]> Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
1 parent fc72c85 commit d60b4e9

File tree

6 files changed

+58
-16
lines changed

6 files changed

+58
-16
lines changed

dpgen2/exploration/report/report.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -60,9 +60,10 @@ def converged(
6060
"""
6161
pass
6262

63+
@abstractmethod
6364
def no_candidate(self) -> bool:
6465
r"""If no candidate configuration is found"""
65-
return all([len(ii) == 0 for ii in self.get_candidate_ids()])
66+
pass
6667

6768
@abstractmethod
6869
def get_candidate_ids(

dpgen2/exploration/report/report_adaptive_lower.py

Lines changed: 18 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -127,6 +127,10 @@ def __init__(
127127
self.fmt_str = " ".join([f"%{ii}s" for ii in spaces])
128128
self.fmt_flt = "%.4f"
129129
self.header_str = "#" + self.fmt_str % print_tuple
130+
self._no_candidate = False
131+
self._failed_ratio = None
132+
self._accurate_ratio = None
133+
self._candidate_ratio = None
130134

131135
@staticmethod
132136
def doc() -> str:
@@ -274,6 +278,10 @@ def record(
274278
# accurate set is substracted by the candidate set
275279
self.accur = self.accur - self.candi
276280
self.model_devi = model_devi
281+
self._no_candidate = len(self.candi) == 0
282+
self._failed_ratio = float(len(self.failed)) / float(self.nframes)
283+
self._accurate_ratio = float(len(self.accur)) / float(self.nframes)
284+
self._candidate_ratio = float(len(self.candi)) / float(self.nframes)
277285

278286
def _record_one_traj(
279287
self,
@@ -346,29 +354,36 @@ def failed_ratio(
346354
self,
347355
tag=None,
348356
):
349-
return float(len(self.failed)) / float(self.nframes)
357+
return self._failed_ratio
350358

351359
def accurate_ratio(
352360
self,
353361
tag=None,
354362
):
355-
return float(len(self.accur)) / float(self.nframes)
363+
return self._accurate_ratio
356364

357365
def candidate_ratio(
358366
self,
359367
tag=None,
360368
):
361-
return float(len(self.candi)) / float(self.nframes)
369+
return self._candidate_ratio
370+
371+
def no_candidate(self) -> bool:
372+
return self._no_candidate
362373

363374
def get_candidate_ids(
364375
self,
365376
max_nframes: Optional[int] = None,
377+
clear: bool = True,
366378
) -> List[List[int]]:
367379
ntraj = self.ntraj
368380
id_cand = self._get_candidates(max_nframes)
369381
id_cand_list = [[] for ii in range(ntraj)]
370382
for ii in id_cand:
371383
id_cand_list[ii[0]].append(ii[1])
384+
# free the memory, this method should only be called once
385+
if clear:
386+
self.clear()
372387
return id_cand_list
373388

374389
def _get_candidates(

dpgen2/exploration/report/report_trust_levels_base.py

Lines changed: 20 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -64,6 +64,10 @@ def __init__(
6464
self.fmt_str = " ".join([f"%{ii}s" for ii in spaces])
6565
self.fmt_flt = "%.4f"
6666
self.header_str = "#" + self.fmt_str % print_tuple
67+
self._no_candidate = False
68+
self._failed_ratio = None
69+
self._accurate_ratio = None
70+
self._candidate_ratio = None
6771

6872
@staticmethod
6973
def args() -> List[Argument]:
@@ -133,6 +137,16 @@ def record(
133137
assert len(self.traj_accu) == ntraj
134138
assert len(self.traj_fail) == ntraj
135139
self.model_devi = model_devi
140+
self._no_candidate = sum([len(ii) for ii in self.traj_cand]) == 0
141+
self._failed_ratio = float(sum([len(ii) for ii in self.traj_fail])) / float(
142+
sum(self.traj_nframes)
143+
)
144+
self._accurate_ratio = float(sum([len(ii) for ii in self.traj_accu])) / float(
145+
sum(self.traj_nframes)
146+
)
147+
self._candidate_ratio = float(sum([len(ii) for ii in self.traj_cand])) / float(
148+
sum(self.traj_nframes)
149+
)
136150

137151
def _get_indexes(
138152
self,
@@ -205,22 +219,22 @@ def failed_ratio(
205219
self,
206220
tag=None,
207221
):
208-
traj_nf = [len(ii) for ii in self.traj_fail]
209-
return float(sum(traj_nf)) / float(sum(self.traj_nframes))
222+
return self._failed_ratio
210223

211224
def accurate_ratio(
212225
self,
213226
tag=None,
214227
):
215-
traj_nf = [len(ii) for ii in self.traj_accu]
216-
return float(sum(traj_nf)) / float(sum(self.traj_nframes))
228+
return self._accurate_ratio
217229

218230
def candidate_ratio(
219231
self,
220232
tag=None,
221233
):
222-
traj_nf = [len(ii) for ii in self.traj_cand]
223-
return float(sum(traj_nf)) / float(sum(self.traj_nframes))
234+
return self._candidate_ratio
235+
236+
def no_candidate(self) -> bool:
237+
return self._no_candidate
224238

225239
@abstractmethod
226240
def get_candidate_ids(

dpgen2/exploration/report/report_trust_levels_max.py

Lines changed: 7 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -41,17 +41,23 @@ def converged(
4141
converged bool
4242
If the exploration is converged.
4343
"""
44-
return self.accurate_ratio() >= self.conv_accuracy
44+
accurate_ratio = self.accurate_ratio()
45+
assert isinstance(accurate_ratio, float)
46+
return accurate_ratio >= self.conv_accuracy
4547

4648
def get_candidate_ids(
4749
self,
4850
max_nframes: Optional[int] = None,
51+
clear: bool = True,
4952
) -> List[List[int]]:
5053
ntraj = len(self.traj_nframes)
5154
id_cand = self._get_candidates(max_nframes)
5255
id_cand_list = [[] for ii in range(ntraj)]
5356
for ii in id_cand:
5457
id_cand_list[ii[0]].append(ii[1])
58+
# free the memory, this method should only be called once
59+
if clear:
60+
self.clear()
5561
return id_cand_list
5662

5763
def _get_candidates(

dpgen2/exploration/report/report_trust_levels_random.py

Lines changed: 7 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -41,17 +41,23 @@ def converged(
4141
converged bool
4242
If the exploration is converged.
4343
"""
44-
return self.accurate_ratio() >= self.conv_accuracy
44+
accurate_ratio = self.accurate_ratio()
45+
assert isinstance(accurate_ratio, float)
46+
return accurate_ratio >= self.conv_accuracy
4547

4648
def get_candidate_ids(
4749
self,
4850
max_nframes: Optional[int] = None,
51+
clear: bool = True,
4952
) -> List[List[int]]:
5053
ntraj = len(self.traj_nframes)
5154
id_cand = self._get_candidates(max_nframes)
5255
id_cand_list = [[] for ii in range(ntraj)]
5356
for ii in id_cand:
5457
id_cand_list[ii[0]].append(ii[1])
58+
# free the memory, this method should only be called once
59+
if clear:
60+
self.clear()
5561
return id_cand_list
5662

5763
def _get_candidates(

tests/exploration/test_report_adaptive_lower.py

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -88,7 +88,7 @@ class MockedReport:
8888
self.assertFalse(ter.converged([mr, mr1, mr]))
8989
self.assertTrue(ter.converged([mr1, mr, mr]))
9090

91-
picked = ter.get_candidate_ids(2)
91+
picked = ter.get_candidate_ids(2, clear=False)
9292
npicked = 0
9393
self.assertEqual(len(picked), 2)
9494
for ii in range(2):
@@ -218,12 +218,12 @@ def faked_choices(
218218
return ret
219219

220220
ter.record(model_devi)
221-
with mock.patch("random.choices", faked_choices):
222-
picked = ter.get_candidate_ids(11)
223-
self.assertFalse(ter.converged([]))
224221
self.assertEqual(ter.candi, expected_cand)
225222
self.assertEqual(ter.accur, expected_accu)
226223
self.assertEqual(set(ter.failed), expected_fail)
224+
with mock.patch("random.choices", faked_choices):
225+
picked = ter.get_candidate_ids(11)
226+
self.assertFalse(ter.converged([]))
227227
self.assertEqual(len(picked), 2)
228228
self.assertEqual(sorted(picked[0]), [1, 3])
229229
self.assertEqual(sorted(picked[1]), [1, 5, 7])

0 commit comments

Comments
 (0)