Skip to content

Commit 0a27bf8

Browse files
wanghan-iapcm and Han Wang authored
print step keys in nice format (#70)
Co-authored-by: Han Wang <wang_han@iapcm.ac.cn>
1 parent 2346d7a commit 0a27bf8

3 files changed

Lines changed: 331 additions & 13 deletions

File tree

dpgen2/entrypoint/submit.py

Lines changed: 21 additions & 13 deletions
Original file line number | Diff line number | Diff line change
@@ -367,9 +367,8 @@ def workflow_concurrent_learning(
367367
return dpgen_step
368368

369369

370-
def submit_concurrent_learning(
370+
def wf_global_workflow(
371371
wf_config,
372-
reuse_step = None,
373372
):
374373
dflow_config_data = wf_config.get('dflow_config', None)
375374
dflow_config(dflow_config_data)
@@ -384,9 +383,18 @@ def submit_concurrent_learning(
384383
else :
385384
lebesgue_context = None
386385

386+
return lebesgue_context
387+
388+
389+
def submit_concurrent_learning(
390+
wf_config,
391+
reuse_step = None,
392+
):
393+
context = wf_global_workflow(wf_config)
394+
387395
dpgen_step = workflow_concurrent_learning(wf_config)
388396

389-
wf = Workflow(name="dpgen", context=lebesgue_context)
397+
wf = Workflow(name="dpgen", context=context)
390398
wf.add(dpgen_step)
391399

392400
wf.submit(reuse_step=reuse_step)
@@ -441,20 +449,16 @@ def resubmit_concurrent_learning(
441449
list_steps = False,
442450
reuse = None,
443451
):
444-
# set global config
445-
from dflow import config, s3_config
446-
dflow_config = wf_config.get('dflow_config', None)
447-
if dflow_config :
448-
config["host"] = dflow_config.get('host', None)
449-
s3_config["endpoint"] = dflow_config.get('s3_endpoint', None)
450-
config["k8s_api_server"] = dflow_config.get('k8s_api_server', None)
451-
config["token"] = dflow_config.get('token', None)
452+
context = wf_global_workflow(wf_config)
452453

453454
old_wf = Workflow(id=wfid)
454455

455456
all_step_keys = successful_step_keys(old_wf)
457+
all_step_keys = sort_slice_ops(
458+
all_step_keys, ['run-train', 'run-lmp', 'run-fp'],)
456459
if list_steps:
457-
prt_str = print_list_steps(all_step_keys)
460+
prt_str = print_keys_in_nice_format(
461+
all_step_keys, ['run-train', 'run-lmp', 'run-fp'],)
458462
print(prt_str)
459463

460464
if reuse is None:
@@ -465,6 +469,10 @@ def resubmit_concurrent_learning(
465469
for ii in reuse_idx:
466470
reuse_step += old_wf_info.get_step(key=all_step_keys[ii])
467471

468-
wf = submit_concurrent_learning(wf_config, reuse_step=reuse_step)
472+
wf = submit_concurrent_learning(
473+
wf_config,
474+
context=context,
475+
reuse_step=reuse_step,
476+
)
469477

470478
return wf

dpgen2/utils/dflow_query.py

Lines changed: 124 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -0,0 +1,124 @@
1+
import numpy as np
2+
import re
3+
from typing import (
4+
List, Optional, Any,
5+
)
6+
7+
def get_subkey(
        key : str,
        idx : Optional[int] = -1,
):
    """
    Split `key` on the '--' separator and return the piece at position `idx`.

    Parameters
    ----------
    key
        A step key made of subkeys joined by '--'.
    idx
        Position of the wanted subkey; defaults to -1, i.e. the last one.

    Returns
    -------
    The selected subkey. A key without any '--' is returned unchanged for
    `idx` 0 or -1.
    """
    pieces = key.split('--')
    return pieces[idx]
12+
13+
def get_last_scheduler(
        wf : Any,
        keys : List[str],
):
    """
    Get the output scheduler of the last successful iteration.

    The keys whose final '--'-separated subkey is 'scheduler' are collected
    and the lexicographically greatest one is queried on `wf` through
    `wf.query_step(key=...)`. The value of that step's
    'exploration_scheduler' output parameter is returned.

    Returns None when `keys` contains no scheduler key.
    """
    candidates = [kk for kk in keys if get_subkey(kk) == 'scheduler']
    if not candidates:
        return None
    candidates.sort()
    last_step = wf.query_step(key=candidates[-1])
    return last_step.outputs.parameters['exploration_scheduler'].value
30+
31+
32+
def get_last_iteration(
        keys : List[str],
):
    """
    Get the index of the last iteration from a list of step keys.

    The leading subkey of every key is extracted, the lexicographically
    greatest one is picked, and the integer found after its first '-' is
    returned (e.g. 'iter-000001' gives 1).
    """
    heads = [get_subkey(kk, 0) for kk in keys]
    heads.sort()
    last_head = heads[-1]
    return int(last_head.split('-')[1])
39+
40+
41+
def find_slice_ranges(
        keys : List[str],
        sliced_subkey : str,
):
    """
    Find the ranges of sliced OPs whose keys match the pattern
    'iter-[0-9]*--{sliced_subkey}-[0-9]*'.

    Parameters
    ----------
    keys
        The step keys, scanned in order.
    sliced_subkey
        Subkey of the sliced OP, e.g. 'run-lmp'. It is inserted into the
        regular expression as is (not escaped).

    Returns
    -------
    A list of `[start, end]` pairs, one per maximal run of consecutive
    matching keys. `start` is the index of the first matching key and `end`
    is one past the last matching key, so `keys[start:end]` is the run.
    A run that extends to the end of `keys` is reported with
    `end == len(keys)`.
    """
    # compile once instead of re-building the pattern for every key
    pattern = re.compile(f'iter-[0-9]*--{sliced_subkey}-[0-9]*')
    found_range = []
    start = None  # index where the currently open run began, if any
    for idx, key in enumerate(keys):
        if pattern.match(key):
            if start is None:
                start = idx
        elif start is not None:
            found_range.append([start, idx])
            start = None
    # close a run that is still open when the keys are exhausted; without
    # this the trailing slice would be lost
    if start is not None:
        found_range.append([start, len(keys)])
    return found_range
65+
66+
67+
def _sort_slice_ops(keys, sliced_subkey):
    """
    Sort, in place, every run of keys belonging to the sliced OP
    `sliced_subkey`, and return the same list object.
    """
    for bounds in find_slice_ranges(keys, sliced_subkey):
        begin, end = bounds[0], bounds[1]
        chunk = keys[begin:end]
        chunk.sort()
        # slice assignment keeps the modification in the caller's list
        keys[begin:end] = chunk
    return keys
72+
73+
74+
def sort_slice_ops(
        keys : List[str],
        sliced_subkey : List[str],
):
    """
    Sort the keys of the sliced OPs, i.e. the keys that contain one of the
    subkeys in `sliced_subkey`. A single subkey may be given as a plain
    string. The keys are processed one subkey after another and the
    resulting list is returned.
    """
    subkeys = [sliced_subkey] if isinstance(sliced_subkey, str) else sliced_subkey
    for subkey in subkeys:
        keys = _sort_slice_ops(keys, subkey)
    return keys
86+
87+
88+
def print_keys_in_nice_format(
        keys : List[str],
        sliced_subkey : List[str],
        idx_fmt_len : int = 8,
):
    """
    Format step keys as an indexed listing, one entry per line, collapsing
    each run of sliced-OP keys into a single 'first -> last' entry.

    Parameters
    ----------
    keys
        The step keys. The list passed in is not modified; a sorted copy is
        formatted.
    sliced_subkey
        Subkeys of the sliced OPs (e.g. ['run-train', 'run-lmp', 'run-fp']).
        A single subkey given as a plain string is accepted as well.
    idx_fmt_len
        Width reserved for one index. The index column is
        `idx_fmt_len*2+4` characters wide.

    Returns
    -------
    The listing as one string, every entry terminated by a newline.
    Ordinary keys are printed as '<idx> : <key>', runs of sliced OPs as
    '<first idx> -> <last idx> : <first key> -> <last key>'.
    """
    # accept a bare string the same way sort_slice_ops does; iterating it
    # directly would loop over its characters
    if isinstance(sliced_subkey, str):
        sliced_subkey = [sliced_subkey]
    # sort_slice_ops sorts in place, so hand it a copy of the caller's list
    keys = sort_slice_ops(list(keys), sliced_subkey)

    # map the first index of every run to its (exclusive) end index; the
    # first range found for a given start wins
    range_end = {}
    for subkey in sliced_subkey:
        for begin, end in find_slice_ranges(keys, subkey):
            range_end.setdefault(begin, end)

    width = idx_fmt_len * 2 + 4
    normal_fmt = f'%{width}d : %s'
    range_fmt = '%d -> %d'
    range_s_fmt = f'%{width}s : %s -> %s'

    ret = []
    idx = 0
    while idx < len(keys):
        if idx in range_end:
            last = range_end[idx] - 1
            range_str = range_fmt % (idx, last)
            ret.append(range_s_fmt % (range_str, keys[idx], keys[last]))
            # continue right after the collapsed run
            idx = last + 1
        else:
            ret.append(normal_fmt % (idx, keys[idx]))
            idx += 1
    return '\n'.join(ret + [''])
123+
124+

tests/utils/test_dflow_query.py

Lines changed: 186 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -0,0 +1,186 @@
1+
import os, textwrap
2+
import numpy as np
3+
import unittest
4+
5+
from typing import Set, List
6+
from pathlib import Path
7+
try:
8+
from exploration.context import dpgen2
9+
except ModuleNotFoundError:
10+
# case of upload everything to argo, no context needed
11+
pass
12+
from dflow.python import (
13+
FatalError,
14+
)
15+
from dpgen2.exploration.scheduler import (
16+
ConvergenceCheckStageScheduler,
17+
ExplorationScheduler,
18+
)
19+
from dpgen2.exploration.report import ExplorationReport
20+
from dpgen2.exploration.task import ExplorationTaskGroup, ExplorationStage
21+
from dpgen2.exploration.selector import TrustLevel, ConfSelectorLammpsFrames
22+
from mocked_ops import (
23+
MockedExplorationReport,
24+
MockedExplorationTaskGroup,
25+
MockedExplorationTaskGroup1,
26+
MockedStage,
27+
MockedStage1,
28+
)
29+
from dpgen2.utils.dflow_query import (
30+
get_last_scheduler,
31+
get_subkey,
32+
get_last_iteration,
33+
find_slice_ranges,
34+
sort_slice_ops,
35+
print_keys_in_nice_format,
36+
)
37+
38+
# Step keys used as the fixture for the tests of this module. The keys of
# the sliced OPs (run-train, run-lmp, run-fp) are listed out of order.
dpgen_keys = [
    # initialization
    'init--scheduler',
    'init--id',
    # iteration 0: training
    'iter-000000--prep-train',
    'iter-000000--run-train-0002',
    'iter-000000--run-train-0000',
    'iter-000000--run-train-0001',
    'iter-000000--prep-run-train',
    # iteration 0: lammps
    'iter-000000--prep-lmp',
    'iter-000000--run-lmp-000001',
    'iter-000000--run-lmp-000004',
    'iter-000000--run-lmp-000005',
    'iter-000000--run-lmp-000002',
    'iter-000000--run-lmp-000003',
    'iter-000000--run-lmp-000000',
    'iter-000000--prep-run-lmp',
    'iter-000000--select-confs',
    # iteration 0: fp
    'iter-000000--prep-fp',
    'iter-000000--run-fp-000001',
    'iter-000000--run-fp-000000',
    'iter-000000--prep-run-fp',
    'iter-000000--collect-data',
    'iter-000000--block',
    'iter-000000--scheduler',
    'iter-000000--id',
    # iteration 1: training
    'iter-000001--prep-train',
    'iter-000001--run-train-0000',
    'iter-000001--run-train-0001',
    'iter-000001--run-train-0002',
    'iter-000001--prep-run-train',
    # iteration 1: lammps
    'iter-000001--prep-lmp',
    'iter-000001--run-lmp-000003',
    'iter-000001--run-lmp-000000',
    'iter-000001--run-lmp-000001',
    'iter-000001--run-lmp-000005',
    'iter-000001--run-lmp-000002',
    'iter-000001--run-lmp-000004',
    'iter-000001--prep-run-lmp',
    'iter-000001--select-confs',
    # iteration 1: fp
    'iter-000001--prep-fp',
    'iter-000001--run-fp-000001',
    'iter-000001--run-fp-000000',
    'iter-000001--prep-run-fp',
    'iter-000001--collect-data',
    'iter-000001--block',
    'iter-000001--scheduler',
    'iter-000001--id',
    'iter-000000--loop',
]
87+
88+
class MockedTar:
    """Minimal object that exposes a fixed `value` attribute equal to 10."""

    value = 10
90+
91+
class MockedFoo:
    """Holds a `parameters` mapping with a single 'exploration_scheduler' entry."""

    parameters = {'exploration_scheduler' : MockedTar()}
95+
96+
class MockedBar:
    """Object whose `outputs` attribute is the MockedFoo class."""

    outputs = MockedFoo
98+
99+
class MockedWF:
    """Workflow stub whose `query_step` only accepts the key 'iter1--scheduler'."""

    def query_step(self, key=None):
        # any other key means the caller picked the wrong scheduler step
        assert key == 'iter1--scheduler'
        return MockedBar()
103+
104+
class TestDflowQuery(unittest.TestCase):
    """Unit tests for the helpers of dpgen2.utils.dflow_query."""

    def test_get_subkey(self):
        self.assertEqual(get_subkey('aa--bb--cc', 0), 'aa')
        self.assertEqual(get_subkey('aa--bb--cc', 1), 'bb')
        self.assertEqual(get_subkey('aa--bb--cc', 2), 'cc')
        self.assertEqual(get_subkey('aa--bb--cc'), 'cc')
        self.assertEqual(get_subkey('aa'), 'aa')
        self.assertEqual(get_subkey('aa---bb'), '-bb')
        self.assertEqual(get_subkey('aa----bb', 1), '')
        self.assertEqual(get_subkey(''), '')

    def test_get_last_scheduler(self):
        value = get_last_scheduler(
            MockedWF(),
            ['iter1--scheduler', 'foo', 'bar', 'iter0--scheduler'],
        )
        self.assertEqual(value, 10)

    def test_get_last_iteration(self):
        last = get_last_iteration(dpgen_keys)
        self.assertEqual(last, 1)

    # This test used to be named test_sort_slice_ops as well, so the later
    # definition of the same name replaced it and it was never executed.
    def test_find_slice_ranges(self):
        idxes = find_slice_ranges(dpgen_keys, 'run-lmp')
        self.assertEqual(idxes, [[8, 14], [30, 36]])

    def test_sort_slice_ops(self):
        expected_output = [
            'init--scheduler',
            'init--id',
            'iter-000000--prep-train',
            'iter-000000--run-train-0000',
            'iter-000000--run-train-0001',
            'iter-000000--run-train-0002',
            'iter-000000--prep-run-train',
            'iter-000000--prep-lmp',
            'iter-000000--run-lmp-000000',
            'iter-000000--run-lmp-000001',
            'iter-000000--run-lmp-000002',
            'iter-000000--run-lmp-000003',
            'iter-000000--run-lmp-000004',
            'iter-000000--run-lmp-000005',
            'iter-000000--prep-run-lmp',
            'iter-000000--select-confs',
            'iter-000000--prep-fp',
            'iter-000000--run-fp-000000',
            'iter-000000--run-fp-000001',
            'iter-000000--prep-run-fp',
            'iter-000000--collect-data',
            'iter-000000--block',
            'iter-000000--scheduler',
            'iter-000000--id',
            'iter-000001--prep-train',
            'iter-000001--run-train-0000',
            'iter-000001--run-train-0001',
            'iter-000001--run-train-0002',
            'iter-000001--prep-run-train',
        ]
        ncheck = len(expected_output)
        # a slice is passed so that the module-level fixture is not sorted
        # in place
        self.assertEqual(
            sort_slice_ops(dpgen_keys[:ncheck], ['run-train', 'run-lmp', 'run-fp']),
            expected_output,
        )

    def test_print_keys(self):
        idx_fmt_len = 8
        # the index column is right-justified to idx_fmt_len*2+4 characters;
        # build the padding explicitly rather than relying on literal spaces
        width = idx_fmt_len * 2 + 4
        expected_output = [
            '0'.rjust(width) + ' : init--scheduler',
            '1'.rjust(width) + ' : init--id',
            '2'.rjust(width) + ' : iter-000000--prep-train',
            '3 -> 5'.rjust(width)
            + ' : iter-000000--run-train-0000 -> iter-000000--run-train-0002',
            '6'.rjust(width) + ' : iter-000000--prep-run-train',
        ]
        expected_output = '\n'.join(expected_output + [''])

        ret = print_keys_in_nice_format(
            dpgen_keys[:7],
            ['run-train', 'run-lmp', 'run-fp'],
            idx_fmt_len = idx_fmt_len,
        )

        self.assertEqual(expected_output, ret)
185+
186+

0 commit comments

Comments
 (0)