Skip to content

Commit fe0fa83

Browse files
wanghan-iapcmHan Wang
and
Han Wang
authored
Replace the scheduler in the old workflow by the new one (#114)
When resubmit, update the scheduler of the old workflow. Otherwise the workflow will exactly follow the old schedule and one has no opportunity to update the schedule. Also update the report printing: print the trust levels and if the iteration is converged. Co-authored-by: Han Wang <[email protected]>
1 parent 1261939 commit fe0fa83

13 files changed

+720
-56
lines changed

dpgen2/entrypoint/main.py

+5-1
Original file line numberDiff line numberDiff line change
@@ -91,7 +91,10 @@ def main_parser() -> argparse.ArgumentParser:
9191
"-l", "--list", action='store_true', help="list the Steps of the existing workflow."
9292
)
9393
parser_resubmit.add_argument(
94-
"--reuse", type=str, nargs='+', default=None, help="specify which Steps to reuse."
94+
"-u", "--reuse", type=str, nargs='+', default=None, help="specify which Steps to reuse."
95+
)
96+
parser_resubmit.add_argument(
97+
"-k", "--keep-schedule", action='store_true', help="if set then keep schedule of the old workflow. otherwise use the schedule defined in the input file"
9598
)
9699
parser_resubmit.add_argument(
97100
"-o", "--old-compatible", action='store_true', help="compatible with old-style input script used in dpgen2 < 0.0.6."
@@ -241,6 +244,7 @@ def main():
241244
list_steps=args.list,
242245
reuse=args.reuse,
243246
old_style=args.old_compatible,
247+
replace_scheduler=(not args.keep_schedule),
244248
)
245249
elif args.command == "status":
246250
with open(args.CONFIG) as fp:

dpgen2/entrypoint/status.py

+1
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,7 @@
44
)
55
from dpgen2.utils.dflow_query import (
66
get_last_scheduler,
7+
get_all_schedulers,
78
)
89
from typing import (
910
Optional, Dict, Union, List,

dpgen2/entrypoint/submit.py

+87-3
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
import glob, dpdata, os, pickle
1+
import glob, dpdata, os, pickle, logging, copy
22
from pathlib import Path
33
from dflow import (
44
InputParameter,
@@ -82,6 +82,7 @@
8282
workflow_config_from_dict,
8383
matched_step_key,
8484
bohrium_config_from_dict,
85+
get_subkey,
8586
)
8687
from dpgen2.utils.step_config import normalize as normalize_step_dict
8788
from dpgen2.entrypoint.common import (
@@ -388,10 +389,76 @@ def workflow_concurrent_learning(
388389
return dpgen_step
389390

390391

392+
def get_scheduler_ids(
393+
reuse_step,
394+
):
395+
scheduler_ids = []
396+
for idx,ii in enumerate(reuse_step):
397+
if get_subkey(ii.key, 1) == "scheduler":
398+
scheduler_ids.append(idx)
399+
scheduler_keys = [reuse_step[ii].key for ii in scheduler_ids]
400+
assert(sorted(scheduler_keys) == scheduler_keys),\
401+
"The scheduler keys are not properly sorted"
402+
403+
if len(scheduler_ids) == 0:
404+
logging.warning("No scheduler found in the workflow, "
405+
"does not do any replacement."
406+
)
407+
return scheduler_ids
408+
409+
def update_reuse_step_scheduler(
410+
reuse_step,
411+
scheduler_new,
412+
):
413+
scheduler_ids = get_scheduler_ids(reuse_step)
414+
if len(scheduler_ids) == 0:
415+
return reuse_step
416+
417+
# do replacement
418+
reuse_step[scheduler_ids[-1]].modify_output_parameter(
419+
"exploration_scheduler", scheduler_new)
420+
421+
return reuse_step
422+
423+
def copy_scheduler_plans(
424+
scheduler_new,
425+
scheduler_old,
426+
):
427+
if len(scheduler_old.stage_schedulers) == 0:
428+
return scheduler_new
429+
if len(scheduler_new.stage_schedulers) < len(scheduler_old.stage_schedulers):
430+
raise RuntimeError(
431+
'The new scheduler has less stages than the old scheduler, '
432+
'scheduler copy is not supported.'
433+
)
434+
# the scheduler_old is planned. minic the init call of the scheduler
435+
if scheduler_old.get_iteration() > -1:
436+
scheduler_new.plan_next_iteration()
437+
for ii in range(len(scheduler_old.stage_schedulers)):
438+
old_stage = scheduler_old.stage_schedulers[ii]
439+
old_reports = old_stage.get_reports()
440+
if old_stage.next_iteration() > 0:
441+
if ii != scheduler_new.get_stage():
442+
raise RuntimeError(
443+
f'The stage {scheduler_new.get_stage()} of the new '
444+
f'scheduler does not match'
445+
f'the stage {ii} of the old scheduler. '
446+
f'scheduler, which should not happen'
447+
)
448+
for report in old_reports:
449+
scheduler_new.plan_next_iteration(report)
450+
if old_stage.complete() and \
451+
(not scheduler_new.stage_schedulers[ii].complete()):
452+
scheduler_new.force_stage_complete()
453+
else:
454+
break
455+
return scheduler_new
456+
391457
def submit_concurrent_learning(
392458
wf_config,
393-
reuse_step = None,
394-
old_style = False,
459+
reuse_step : Optional[List[Step]] = None,
460+
old_style : bool = False,
461+
replace_scheduler: bool = False,
395462
):
396463
# normalize args
397464
wf_config = normalize_args(wf_config)
@@ -401,6 +468,21 @@ def submit_concurrent_learning(
401468
context = global_config_workflow(wf_config, do_lebesgue=do_lebesgue)
402469

403470
dpgen_step = workflow_concurrent_learning(wf_config, old_style=old_style)
471+
472+
if reuse_step is not None and replace_scheduler:
473+
scheduler_new = copy.deepcopy(dpgen_step.inputs.parameters['exploration_scheduler'].value)
474+
idx_old = get_scheduler_ids(reuse_step)[-1]
475+
scheduler_old = reuse_step[idx_old].inputs.parameters['exploration_scheduler'].value
476+
scheduler_new = copy_scheduler_plans(scheduler_new, scheduler_old)
477+
exploration_report = reuse_step[idx_old].inputs.parameters['exploration_report'].value
478+
# plan next
479+
# hack! trajs is set to None...
480+
conv, lmp_task_grp, selector = scheduler_new.plan_next_iteration(exploration_report, trajs=None)
481+
# update output of the scheduler step
482+
reuse_step[idx_old].modify_output_parameter("converged", conv,)
483+
reuse_step[idx_old].modify_output_parameter("exploration_scheduler", scheduler_new,)
484+
reuse_step[idx_old].modify_output_parameter("lmp_task_grp", lmp_task_grp,)
485+
reuse_step[idx_old].modify_output_parameter("conf_selector", selector,)
404486

405487
wf = Workflow(name="dpgen", context=context)
406488
wf.add(dpgen_step)
@@ -449,6 +531,7 @@ def resubmit_concurrent_learning(
449531
list_steps = False,
450532
reuse = None,
451533
old_style = False,
534+
replace_scheduler = False,
452535
):
453536
wf_config = normalize_args(wf_config)
454537

@@ -474,6 +557,7 @@ def resubmit_concurrent_learning(
474557
wf_config,
475558
reuse_step=reuse_step,
476559
old_style=old_style,
560+
replace_scheduler=replace_scheduler,
477561
)
478562

479563
return wf

dpgen2/exploration/report/report_trust_levels.py

+34-14
Original file line numberDiff line numberDiff line change
@@ -9,12 +9,6 @@
99
from dflow.python import FatalError
1010

1111
class ExplorationReportTrustLevels(ExplorationReport):
12-
# class attrs
13-
spaces = [8, 8, 8, 10, 10, 10]
14-
fmt_str = ' '.join([f'%{ii}s' for ii in spaces])
15-
fmt_flt = '%.4f'
16-
header_str = '#' + fmt_str % ('stage', 'id_stg.', 'iter.', 'accu.', 'cand.', 'fail.')
17-
1812
def __init__(
1913
self,
2014
trust_level,
@@ -26,6 +20,21 @@ def __init__(
2620
self.v_level = ( (self.trust_level.level_v_lo is not None) and \
2721
(self.trust_level.level_v_hi is not None) )
2822

23+
print_tuple = ('stage', 'id_stg.', 'iter.',
24+
'accu.', 'cand.', 'fail.',
25+
'lvl_f_lo', 'lvl_f_hi',
26+
)
27+
spaces = [8, 8, 8, 10, 10, 10, 10, 10]
28+
if self.v_level:
29+
print_tuple += ('v_lo', 'v_hi',)
30+
spaces += [10, 10]
31+
print_tuple += ('cvged',)
32+
spaces += [8]
33+
self.fmt_str = ' '.join([f'%{ii}s' for ii in spaces])
34+
self.fmt_flt = '%.4f'
35+
self.header_str = '#' + self.fmt_str % print_tuple
36+
37+
2938
def clear(
3039
self,
3140
):
@@ -186,7 +195,7 @@ def _get_candidates(
186195

187196
def print_header(self) -> str:
188197
r"""Print the header of report"""
189-
return ExplorationReportTrustLevels.header_str
198+
return self.header_str
190199

191200
def print(
192201
self,
@@ -195,12 +204,23 @@ def print(
195204
iter_idx : int,
196205
) -> str:
197206
r"""Print the report"""
198-
fmt_str = ExplorationReportTrustLevels.fmt_str
199-
fmt_flt = ExplorationReportTrustLevels.fmt_flt
200-
ret = ' ' + fmt_str % (
201-
str(stage_idx), str(idx_in_stage), str(iter_idx),
202-
fmt_flt%(self.accurate_ratio()),
203-
fmt_flt%(self.candidate_ratio()),
204-
fmt_flt%(self.failed_ratio()),
207+
fmt_str = self.fmt_str
208+
fmt_flt = self.fmt_flt
209+
print_tuple = (
210+
str(stage_idx), str(idx_in_stage), str(iter_idx),
211+
fmt_flt%(self.accurate_ratio()),
212+
fmt_flt%(self.candidate_ratio()),
213+
fmt_flt%(self.failed_ratio()),
214+
fmt_flt%(self.trust_level.level_f_lo),
215+
fmt_flt%(self.trust_level.level_f_hi),
216+
)
217+
if self.v_level:
218+
print_tuple += (
219+
fmt_flt%(self.trust_level.level_v_lo),
220+
fmt_flt%(self.trust_level.level_v_hi),
221+
)
222+
print_tuple += (
223+
str(self.converged()),
205224
)
225+
ret = ' ' + fmt_str % print_tuple
206226
return ret

dpgen2/exploration/scheduler/convergence_check_stage_scheduler.py

+13
Original file line numberDiff line numberDiff line change
@@ -30,9 +30,18 @@ def __init__(
3030
self.complete_ = False
3131
self.reports = []
3232

33+
def get_reports(self):
34+
return self.reports
35+
3336
def complete(self):
3437
return self.complete_
3538

39+
def force_complete(self):
40+
self.complete_ = True
41+
42+
def next_iteration(self):
43+
return self.nxt_iter
44+
3645
def converged(self):
3746
return self.conv
3847

@@ -44,6 +53,10 @@ def plan_next_iteration(
4453
report : Optional[ExplorationReport] = None,
4554
trajs : Optional[List[Path]] = None,
4655
) -> Tuple[bool, Optional[ExplorationTaskGroup], Optional[ConfSelector]] :
56+
if self.complete():
57+
raise FatalError(
58+
'Cannot plan because the stage has completed.'
59+
)
4760
if report is None:
4861
stg_complete = False
4962
self.conv = stg_complete

dpgen2/exploration/scheduler/scheduler.py

+50-4
Original file line numberDiff line numberDiff line change
@@ -28,7 +28,6 @@ def __init__(
2828
):
2929
self.stage_schedulers = []
3030
self.cur_stage = 0
31-
self.iteration = -1
3231
self.complete_ = False
3332

3433
def add_stage_scheduler(
@@ -66,7 +65,15 @@ def get_iteration(self):
6665
Iteration index increase when `self.plan_next_iteration` returns valid `lmp_task_grp` and `conf_selector` for the next iteration.
6766
6867
"""
69-
return self.iteration
68+
tot_iter = -1
69+
for idx,ii in enumerate(self.stage_schedulers):
70+
if ii.complete():
71+
# the last plan is not used because the stage
72+
# is found converged
73+
tot_iter += ii.next_iteration() - 1
74+
else:
75+
tot_iter += ii.next_iteration()
76+
return tot_iter
7077

7178
def complete(self):
7279
"""
@@ -75,6 +82,20 @@ def complete(self):
7582
"""
7683
return self.complete_
7784

85+
def force_stage_complete(self):
86+
"""
87+
Force complete the current stage
88+
89+
"""
90+
self.stage_schedulers[self.cur_stage].force_complete()
91+
self.cur_stage += 1
92+
if self.cur_stage < len(self.stage_schedulers):
93+
# goes to next stage
94+
self.plan_next_iteration()
95+
else:
96+
# all stages complete
97+
self.complete_ = True
98+
7899
def plan_next_iteration(
79100
self,
80101
report : Optional[ExplorationReport] = None,
@@ -109,7 +130,7 @@ def plan_next_iteration(
109130
)
110131
except FatalError as e:
111132
raise FatalError(f'stage {self.cur_stage}: ' + str(e))
112-
133+
113134
if stg_complete:
114135
self.cur_stage += 1
115136
if self.cur_stage < len(self.stage_schedulers):
@@ -120,7 +141,6 @@ def plan_next_iteration(
120141
self.complete_ = True
121142
return True, None, None,
122143
else :
123-
self.iteration += 1
124144
return stg_complete, lmp_task_grp, conf_selector
125145

126146

@@ -188,6 +208,32 @@ def _print_prev_summary(self, prev_stg_idx):
188208
else:
189209
return None
190210

211+
212+
def print_last_iteration(self, print_header=False):
213+
stages = self.stage_schedulers
214+
215+
stage_idx, idx_in_stage, iter_idx = self.get_stage_of_iterations()
216+
217+
if np.size(iter_idx) == 0:
218+
return "No finished iteration found\n"
219+
220+
iidx = np.size(iter_idx)-1
221+
222+
ret = []
223+
if print_header:
224+
ret.append(
225+
stages[stage_idx[iidx]].reports[idx_in_stage[iidx]].print_header())
226+
ret.append(
227+
stages[stage_idx[iidx]].reports[idx_in_stage[iidx]]\
228+
.print(stage_idx[iidx], idx_in_stage[iidx], iidx)
229+
)
230+
231+
if self.complete():
232+
ret.append(f'# All stages converged')
233+
return '\n'.join(ret + [''])
234+
235+
236+
191237
def print_convergence(self):
192238
ret = []
193239
stages = self.stage_schedulers

0 commit comments

Comments
 (0)