Skip to content

Commit 74290e1

Browse files
wanghan-iapcmHan Wang
andauthored
show status of workflows (#71)
* show status of workflows. * fix bugs * change scheduler converged to complete. add complete for stage_scheduler * fix == None Co-authored-by: Han Wang <wang_han@iapcm.ac.cn>
1 parent 0a27bf8 commit 74290e1

8 files changed

Lines changed: 507 additions & 35 deletions

File tree

dpgen2/entrypoint/main.py

Lines changed: 28 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -18,6 +18,9 @@
1818
submit_concurrent_learning,
1919
resubmit_concurrent_learning,
2020
)
21+
from .status import (
22+
status,
23+
)
2124
from dpgen2 import (
2225
__version__
2326
)
@@ -48,7 +51,7 @@ def main_parser() -> argparse.ArgumentParser:
4851
formatter_class=argparse.ArgumentDefaultsHelpFormatter,
4952
)
5053
parser_run.add_argument(
51-
"INPUT", help="the input file in json format defining the workflow."
54+
"CONFIG", help="the config file in json format defining the workflow."
5255
)
5356

5457
parser_resubmit = subparsers.add_parser(
@@ -57,7 +60,7 @@ def main_parser() -> argparse.ArgumentParser:
5760
formatter_class=argparse.ArgumentDefaultsHelpFormatter,
5861
)
5962
parser_resubmit.add_argument(
60-
"INPUT", help="the input file in json format defining the workflow."
63+
"CONFIG", help="the config file in json format defining the workflow."
6164
)
6265
parser_resubmit.add_argument(
6366
"ID", help="the ID of the existing workflow."
@@ -69,6 +72,18 @@ def main_parser() -> argparse.ArgumentParser:
6972
"--reuse", type=str, nargs='+', default=None, help="specify which Steps to reuse."
7073
)
7174

75+
parser_status = subparsers.add_parser(
76+
"status",
77+
help="Print the status (stage, iteration, convergence) of the DPGEN2 workflow",
78+
formatter_class=argparse.ArgumentDefaultsHelpFormatter,
79+
)
80+
parser_status.add_argument(
81+
"CONFIG", help="the config file in json format."
82+
)
83+
parser_status.add_argument(
84+
"ID", help="the ID of the existing workflow."
85+
)
86+
7287
# --version
7388
parser.add_argument(
7489
'--version',
@@ -102,16 +117,25 @@ def main():
102117
dict_args = vars(args)
103118

104119
if args.command == "submit":
105-
with open(args.INPUT) as fp:
120+
with open(args.CONFIG) as fp:
106121
config = json.load(fp)
107122
submit_concurrent_learning(config)
108123
elif args.command == "resubmit":
109-
with open(args.INPUT) as fp:
124+
with open(args.CONFIG) as fp:
110125
config = json.load(fp)
111126
wfid = args.ID
112127
resubmit_concurrent_learning(
113128
config, wfid, list_steps=args.list, reuse=args.reuse,
114129
)
130+
elif args.command == "status":
131+
with open(args.CONFIG) as fp:
132+
config = json.load(fp)
133+
wfid = args.ID
134+
status(
135+
wfid, config,
136+
)
137+
elif args.command is None:
138+
pass
115139
else:
116140
raise RuntimeError(f"unknown command {args.command}")
117141

dpgen2/entrypoint/status.py

Lines changed: 32 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,32 @@
1+
import logging
2+
from dflow import (
3+
Workflow,
4+
)
5+
from dpgen2.utils import (
6+
dflow_config,
7+
)
8+
from dpgen2.utils.dflow_query import (
9+
get_last_scheduler,
10+
)
11+
from typing import (
12+
Optional, Dict, Union, List,
13+
)
14+
15+
def status(
16+
workflow_id,
17+
wf_config : Optional[Dict] = {},
18+
):
19+
dflow_config_data = wf_config.get('dflow_config', None)
20+
dflow_config(dflow_config_data)
21+
22+
wf = Workflow(id=workflow_id)
23+
24+
wf_keys = wf.query_keys_of_steps()
25+
26+
scheduler = get_last_scheduler(wf, wf_keys)
27+
28+
if scheduler is not None:
29+
ptr_str = scheduler.print_convergence()
30+
print(ptr_str)
31+
else:
32+
logging.warn('no scheduler is finished')

dpgen2/exploration/scheduler/convergence_check_stage_scheduler.py

Lines changed: 26 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -26,20 +26,34 @@ def __init__(
2626
self.max_numb_iter = max_numb_iter
2727
self.fatal_at_max = fatal_at_max
2828
self.nxt_iter = 0
29+
self.conv = False
30+
self.reached_max_iter = False
31+
self.complete_ = False
32+
self.reports = []
33+
34+
def complete(self):
35+
return self.complete_
36+
37+
def converged(self):
38+
return self.conv
39+
40+
def reached_max_iteration(self):
41+
return self.reached_max_iter
2942

3043
def plan_next_iteration(
3144
self,
32-
hist_reports : List[ExplorationReport] = [],
3345
report : ExplorationReport = None,
3446
trajs : List[Path] = None,
3547
) -> Tuple[bool, ExplorationTaskGroup, ConfSelector] :
3648
if report is None:
37-
converged = False
49+
stg_complete = False
50+
self.conv = stg_complete
3851
lmp_task_grp = self.stage.make_task()
3952
ret_selector = self.selector
4053
else :
41-
converged = report.accurate_ratio() >= self.conv_accuracy
42-
if not converged:
54+
stg_complete = report.accurate_ratio() >= self.conv_accuracy
55+
self.conv = stg_complete
56+
if not stg_complete:
4357
# check if we have any candidate to improve the quality of the model
4458
if report.candidate_ratio() == 0.0:
4559
raise FatalError(
@@ -49,20 +63,23 @@ def plan_next_iteration(
4963
'improved and the iteraction would not end. '
5064
'Please try to increase the higher trust levels. '
5165
)
52-
# if not converged, check max iter
66+
# if not stg_complete, check max iter
5367
if self.max_numb_iter is not None and self.nxt_iter == self.max_numb_iter:
68+
self.reached_max_iter = True
5469
if self.fatal_at_max:
5570
raise FatalError('reached maximal number of iterations')
5671
else:
57-
converged = True
72+
stg_complete = True
5873
# make lmp tasks
59-
if converged:
60-
# if converged, no more lmp task
74+
if stg_complete:
75+
# if stg_complete, no more lmp task
6176
lmp_task_grp = None
6277
ret_selector = None
6378
else :
6479
lmp_task_grp = self.stage.make_task()
6580
ret_selector = self.selector
81+
self.reports.append(report)
6682
self.nxt_iter += 1
67-
return converged, lmp_task_grp, ret_selector
83+
self.complete_ = stg_complete
84+
return stg_complete, lmp_task_grp, ret_selector
6885

dpgen2/exploration/scheduler/scheduler.py

Lines changed: 112 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,5 @@
1+
import numpy as np
2+
13
from typing import (
24
List,
35
Tuple,
@@ -24,9 +26,9 @@ def __init__(
2426
self,
2527
):
2628
self.stage_schedulers = []
27-
self.stage_reports = [[]]
2829
self.cur_stage = 0
2930
self.iteration = -1
31+
self.complete_ = False
3032

3133
def add_stage_scheduler(
3234
self,
@@ -44,6 +46,7 @@ def add_stage_scheduler(
4446
4547
"""
4648
self.stage_schedulers.append(stage_scheduler)
49+
self.complete_ = False
4750
return self
4851

4952
def get_stage(self):
@@ -64,6 +67,13 @@ def get_iteration(self):
6467
"""
6568
return self.iteration
6669

70+
def complete(self):
71+
"""
72+
Tell if all stages are converged.
73+
74+
"""
75+
return self.complete_
76+
6777
def plan_next_iteration(
6878
self,
6979
report : ExplorationReport = None,
@@ -81,8 +91,8 @@ def plan_next_iteration(
8191
8292
Returns
8393
-------
84-
converged: bool
85-
If DPGEN converges.
94+
complete: bool
95+
If all the DPGEN stages complete.
8696
task: ExplorationTaskGroup
8797
A `ExplorationTaskGroup` defining the exploration of the next iteration. Should be `None` if converged.
8898
conf_selector: ConfSelector
@@ -91,26 +101,118 @@ def plan_next_iteration(
91101
"""
92102

93103
try:
94-
converged, lmp_task_grp, conf_selector = \
104+
stg_complete, lmp_task_grp, conf_selector = \
95105
self.stage_schedulers[self.cur_stage].plan_next_iteration(
96-
self.stage_reports[self.cur_stage],
97106
report,
98107
trajs,
99108
)
100-
self.stage_reports[self.cur_stage].append(report)
101109
except FatalError as e:
102110
raise FatalError(f'stage {self.cur_stage}: ' + str(e))
103111

104-
if converged:
112+
if stg_complete:
105113
self.cur_stage += 1
106-
self.stage_reports.append([])
107114
if self.cur_stage < len(self.stage_schedulers):
108115
# goes to next stage
109116
return self.plan_next_iteration()
110117
else:
111-
# all stages converged
118+
# all stages complete
119+
self.complete_ = True
112120
return True, None, None,
113121
else :
114122
self.iteration += 1
115-
return converged, lmp_task_grp, conf_selector
123+
return stg_complete, lmp_task_grp, conf_selector
124+
125+
126+
def get_stage_of_iterations(self):
127+
"""
128+
Get the stage index and the index in the stage of iterations.
116129
130+
"""
131+
stages = self.stage_schedulers
132+
n_stage_iters = []
133+
for ii in range(self.get_stage() + 1):
134+
if ii < len(stages) and len(stages[ii].reports) > 0:
135+
n_stage_iters.append(len(stages[ii].reports))
136+
cumsum_stage_iters = np.cumsum(n_stage_iters)
137+
138+
max_iter = self.get_iteration()
139+
if self.complete() or max_iter == -1:
140+
max_iter += 1
141+
stage_idx = []
142+
idx_in_stage = []
143+
iter_idx = []
144+
for ii in range(max_iter):
145+
idx = np.searchsorted(cumsum_stage_iters, ii+1)
146+
stage_idx.append(idx)
147+
if idx > 0:
148+
idx_in_stage.append(ii - cumsum_stage_iters[idx-1])
149+
else :
150+
idx_in_stage.append(ii)
151+
iter_idx.append(ii)
152+
assert( len(stage_idx) == max_iter)
153+
assert( len(idx_in_stage) == max_iter)
154+
assert( len(iter_idx) == max_iter)
155+
return stage_idx, idx_in_stage, iter_idx
156+
157+
158+
def get_convergence_ratio(self):
159+
"""
160+
Get the accurate, candidate and failed ratios of the iterations
161+
162+
Returns
163+
-------
164+
accu np.ndarray
165+
The accurate ratio. length of array the same as # iterations.
166+
cand np.ndarray
167+
The candidate ratio. length of array the same as # iterations.
168+
fail np.ndarray
169+
The failed ration. length of array the same as # iterations.
170+
"""
171+
stages = self.stage_schedulers
172+
stag_idx, idx_in_stag, iter_idx = self.get_stage_of_iterations()
173+
accu = []
174+
cand = []
175+
fail = []
176+
for ii in range(np.size(iter_idx)):
177+
accu.append(stages[stag_idx[ii]].reports[idx_in_stag[ii]].accurate_ratio())
178+
cand.append(stages[stag_idx[ii]].reports[idx_in_stag[ii]].candidate_ratio())
179+
fail.append(stages[stag_idx[ii]].reports[idx_in_stag[ii]].failed_ratio())
180+
return np.array(accu), np.array(cand), np.array(fail)
181+
182+
def _print_prev_summary(self, prev_stg_idx):
183+
if prev_stg_idx >= 0:
184+
yes = 'YES' if self.stage_schedulers[prev_stg_idx].converged() else 'NO '
185+
rmx = 'YES' if self.stage_schedulers[prev_stg_idx].reached_max_iteration() else 'NO '
186+
return f'# Stage {prev_stg_idx:4d} converged {yes} reached max numb iterations {rmx}'
187+
else:
188+
return None
189+
190+
def print_convergence(self):
191+
spaces = [8, 8, 8, 10, 10, 10]
192+
fmt_str = ' '.join([f'%{ii}s' for ii in spaces])
193+
fmt_flt = '%.4f'
194+
header_str = '#' + fmt_str % ('stage', 'id_stg.', 'iter.', 'accu.', 'cand.', 'fail.')
195+
ret = [header_str]
196+
197+
stage_idx, idx_in_stage, iter_idx = self.get_stage_of_iterations()
198+
accu, cand, fail = self.get_convergence_ratio()
199+
200+
iidx = 0
201+
prev_stg_idx = -1
202+
for iidx in range(len(accu)):
203+
if stage_idx[iidx] != prev_stg_idx:
204+
if prev_stg_idx >= 0:
205+
ret.append(self._print_prev_summary(prev_stg_idx))
206+
ret.append(f'# Stage {stage_idx[iidx]:4d} ' + '-'*20)
207+
prev_stg_idx = stage_idx[iidx]
208+
ret.append(' ' + fmt_str % (
209+
str(stage_idx[iidx]), str(idx_in_stage[iidx]), str(iidx),
210+
fmt_flt%(accu[iidx]*1),
211+
fmt_flt%(cand[iidx]*1),
212+
fmt_flt%(fail[iidx]*1),
213+
))
214+
if self.complete():
215+
if prev_stg_idx >= 0:
216+
ret.append(self._print_prev_summary(prev_stg_idx))
217+
ret.append(f'# All stages converged')
218+
return '\n'.join(ret + [''])

dpgen2/exploration/scheduler/stage_scheduler.py

Lines changed: 16 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -16,10 +16,21 @@ class StageScheduler(ABC):
1616
The scheduler for an exploration stage.
1717
"""
1818

19+
@abstractmethod
20+
def converged(self):
21+
"""
22+
Tell if the stage is converged
23+
24+
Returns
25+
-------
26+
converged bool
27+
the convergence
28+
"""
29+
pass
30+
1931
@abstractmethod
2032
def plan_next_iteration(
2133
self,
22-
hist_reports : List[ExplorationReport],
2334
report : ExplorationReport,
2435
trajs : List[Path],
2536
) -> Tuple[bool, ExplorationTaskGroup, ConfSelector] :
@@ -39,8 +50,10 @@ def plan_next_iteration(
3950
4051
Returns
4152
-------
42-
converged: bool
43-
If the stage converged.
53+
stg_complete: bool
54+
If the stage completed. Two cases may happen:
55+
1. converged.
56+
2. when not fatal_at_max, not converged but reached max number of iterations.
4457
task: ExplorationTaskGroup
4558
A `ExplorationTaskGroup` defining the exploration of the next iteration. Should be `None` if the stage is converged.
4659
conf_selector: ConfSelector

dpgen2/utils/dflow_query.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -25,7 +25,7 @@ def get_last_scheduler(
2525
return None
2626
else:
2727
skey = sorted(scheduler_keys)[-1]
28-
step = wf.query_step(key=skey)
28+
step = wf.query_step(key=skey)[0]
2929
return step.outputs.parameters['exploration_scheduler'].value
3030

3131

0 commit comments

Comments
 (0)