Skip to content

Commit e82560c

Browse files
authored
Merge pull request #61 from issp-center-dev/fix_60
Fix `policy.load` with MPI
2 parents cfad27f + 2a03337 commit e82560c

File tree

2 files changed

+14
-2
lines changed

2 files changed

+14
-2
lines changed

physbo/search/discrete/policy.py

+7-1
Original file line numberDiff line numberDiff line change
@@ -694,7 +694,13 @@ def load(self, file_history, file_training=None, file_predictor=None):
694694
self.predictor = pickle.load(f)
695695

696696
N = self.history.total_num_search
697-
self.actions = self._delete_actions(self.history.chosen_actions[:N])
697+
698+
visited = self.history.chosen_actions[:N]
699+
local_index = np.searchsorted(self.actions, visited)
700+
local_index = local_index[
701+
np.take(self.actions, local_index, mode="clip") == visited
702+
]
703+
self.actions = self._delete_actions(local_index)
698704

699705
def export_predictor(self):
700706
"""

physbo/search/discrete_multi/policy.py

+7-1
Original file line numberDiff line numberDiff line change
@@ -490,7 +490,13 @@ def load(self, file_history, file_training_list=None, file_predictor_list=None):
490490
self.load_predictor_list(file_predictor_list)
491491

492492
N = self.history.total_num_search
493-
self.actions = self._delete_actions(self.history.chosen_actions[:N])
493+
494+
visited = self.history.chosen_actions[:N]
495+
local_index = np.searchsorted(self.actions, visited)
496+
local_index = local_index[
497+
np.take(self.actions, local_index, mode="clip") == visited
498+
]
499+
self.actions = self._delete_actions(local_index)
494500

495501
def save_predictor_list(self, file_name):
496502
with open(file_name, "wb") as f:

0 commit comments

Comments
 (0)