7070 load_object_from_file ,
7171 normalize_alloy_conf_dict ,
7272 generate_alloy_conf_file_content ,
73- dflow_config ,
7473 sort_slice_ops ,
7574 print_keys_in_nice_format ,
75+ workflow_config_from_dict ,
7676)
7777from dpgen2 .utils .step_config import normalize as normalize_step_dict
78+ from dpgen2 .entrypoint .submit_args import normalize as normalize_submit_args
7879from typing import (
79- Union , List ,
80+ Union , List , Dict , Optional ,
8081)
8182
8283default_config = normalize_step_dict (
@@ -196,24 +197,30 @@ def make_conf_list(
196197
197198def make_naive_exploration_scheduler (
198199 config ,
200+ old_style = False ,
199201):
200202 # use npt task group
201- model_devi_jobs = config ['model_devi_jobs' ]
202- sys_configs = config ['sys_configs' ]
203- mass_map = config ['mass_map' ]
204- type_map = config ['type_map' ]
205- numb_models = config ['numb_models' ]
206- fp_task_max = config ['fp_task_max' ]
207- conv_accuracy = config ['conv_accuracy' ]
208- max_numb_iter = config ['max_numb_iter' ]
209- fatal_at_max = config .get ('fatal_at_max' , True )
203+ model_devi_jobs = config ['model_devi_jobs' ] if old_style else config ['explore' ]['stages' ]
204+ sys_configs = config ['sys_configs' ] if old_style else config ['explore' ]['configurations' ]
205+ sys_prefix = config .get ('sys_prefix' )
206+ if sys_prefix is not None :
207+ for ii in range (len (sys_configs )):
208+ if isinstance (sys_configs [ii ], list ):
209+ sys_configs [ii ] = [os .path .join (sys_prefix , jj ) for jj in sys_prefix [ii ]]
210+ mass_map = config ['mass_map' ] if old_style else config ['inputs' ]['mass_map' ]
211+ type_map = config ['type_map' ] if old_style else config ['inputs' ]['type_map' ]
212+ numb_models = config ['numb_models' ] if old_style else config ['train' ]['numb_models' ]
213+ fp_task_max = config ['fp_task_max' ] if old_style else config ['fp' ]['task_max' ]
214+ conv_accuracy = config ['conv_accuracy' ] if old_style else config ['explore' ]['conv_accuracy' ]
215+ max_numb_iter = config ['max_numb_iter' ] if old_style else config ['explore' ]['max_numb_iter' ]
216+ fatal_at_max = config .get ('fatal_at_max' , True ) if old_style else config ['explore' ]['fatal_at_max' ]
210217 scheduler = ExplorationScheduler ()
211218
212219 for job in model_devi_jobs :
213220 # task group
214221 tgroup = NPTTaskGroup ()
215222 ## ignore the expansion of sys_idx
216- # get all file names of md initial configuraitons
223+ # get all file names of md initial configurations
217224 sys_idx = job ['sys_idx' ]
218225 conf_list = []
219226 for ii in sys_idx :
@@ -244,8 +251,10 @@ def make_naive_exploration_scheduler(
244251 stage .add_task_group (tasks )
245252 # trust level
246253 trust_level = TrustLevel (
247- config ['model_devi_f_trust_lo' ],
248- config ['model_devi_f_trust_hi' ],
254+ config ['model_devi_f_trust_lo' ] if old_style else config ['explore' ]['f_trust_lo' ],
255+ config ['model_devi_f_trust_hi' ] if old_style else config ['explore' ]['f_trust_hi' ],
256+ level_v_lo = config .get ('model_devi_v_trust_lo' ) if old_style else config ['explore' ]['v_trust_lo' ],
257+ level_v_hi = config .get ('model_devi_v_trust_hi' ) if old_style else config ['explore' ]['v_trust_hi' ],
249258 )
250259 # selector
251260 selector = ConfSelectorLammpsFrames (
@@ -285,22 +294,23 @@ def get_kspacing_kgamma_from_incar(
285294
286295
287296def workflow_concurrent_learning (
288- config ,
297+ config : Dict ,
298+ old_style : Optional [bool ] = False ,
289299):
290- default_config = normalize_step_dict (config .get ('default_config' , {}))
291-
292- train_style = config [ 'train_style' ]
293- explore_style = config [ 'explore_style' ]
294- fp_style = config [ 'fp_style' ]
295- prep_train_config = normalize_step_dict (config .get ('prep_train_config' , default_config ))
296- run_train_config = normalize_step_dict (config .get ('run_train_config' , default_config ))
297- prep_explore_config = normalize_step_dict (config .get ('prep_explore_config' , default_config ))
298- run_explore_config = normalize_step_dict (config .get ('run_explore_config' , default_config ))
299- prep_fp_config = normalize_step_dict (config .get ('prep_fp_config' , default_config ))
300- run_fp_config = normalize_step_dict (config .get ('run_fp_config' , default_config ))
301- select_confs_config = normalize_step_dict (config .get ('select_confs_config' , default_config ))
302- collect_data_config = normalize_step_dict (config .get ('collect_data_config' , default_config ))
303- cl_step_config = normalize_step_dict (config .get ('cl_step_config' , default_config ))
300+ default_config = normalize_step_dict (config .get ('default_config' , {})) if old_style else config [ 'default_step_config' ]
301+
302+ train_style = config . get ( 'train_style' , 'dp' ) if old_style else config [ 'train' ][ 'type ' ]
303+ explore_style = config . get ( 'explore_style' , 'lmp' ) if old_style else config [ 'explore' ][ 'type ' ]
304+ fp_style = config . get ( 'fp_style' , 'vasp' ) if old_style else config [ 'fp' ][ 'type ' ]
305+ prep_train_config = normalize_step_dict (config .get ('prep_train_config' , default_config )) if old_style else config [ 'step_configs' ][ 'prep_train_config' ]
306+ run_train_config = normalize_step_dict (config .get ('run_train_config' , default_config )) if old_style else config [ 'step_configs' ][ 'run_train_config' ]
307+ prep_explore_config = normalize_step_dict (config .get ('prep_explore_config' , default_config )) if old_style else config [ 'step_configs' ][ 'prep_explore_config' ]
308+ run_explore_config = normalize_step_dict (config .get ('run_explore_config' , default_config )) if old_style else config [ 'step_configs' ][ 'run_explore_config' ]
309+ prep_fp_config = normalize_step_dict (config .get ('prep_fp_config' , default_config )) if old_style else config [ 'step_configs' ][ 'prep_fp_config' ]
310+ run_fp_config = normalize_step_dict (config .get ('run_fp_config' , default_config )) if old_style else config [ 'step_configs' ][ 'run_fp_config' ]
311+ select_confs_config = normalize_step_dict (config .get ('select_confs_config' , default_config )) if old_style else config [ 'step_configs' ][ 'select_confs_config' ]
312+ collect_data_config = normalize_step_dict (config .get ('collect_data_config' , default_config )) if old_style else config [ 'step_configs' ][ 'collect_data_config' ]
313+ cl_step_config = normalize_step_dict (config .get ('cl_step_config' , default_config )) if old_style else config [ 'step_configs' ][ 'cl_step_config' ]
304314 upload_python_package = config .get ('upload_python_package' , None )
305315 init_models_paths = config .get ('training_iter0_model_path' )
306316
@@ -319,24 +329,27 @@ def workflow_concurrent_learning(
319329 cl_step_config = cl_step_config ,
320330 upload_python_package = upload_python_package ,
321331 )
322- scheduler = make_naive_exploration_scheduler (config )
323-
324- type_map = config ['type_map' ]
325- numb_models = config ['numb_models' ]
326- template_script = config ['default_training_param' ]
327- train_config = {}
328- lmp_config = config .get ('lmp_config' , {})
329- fp_config = config .get ('fp_config' , {})
330- kspacing , kgamma = get_kspacing_kgamma_from_incar (config ['fp_incar' ])
331- fp_pp_files = config ['fp_pp_files' ]
332- incar_file = config ['fp_incar' ]
332+ scheduler = make_naive_exploration_scheduler (config , old_style = old_style )
333+
334+ type_map = config ['type_map' ] if old_style else config [ 'inputs' ][ 'type_map' ]
335+ numb_models = config ['numb_models' ] if old_style else config [ 'train' ][ 'numb_models' ]
336+ template_script = config ['default_training_param' ] if old_style else config [ 'train' ][ 'template_script' ]
337+ train_config = {} if old_style else config [ 'train' ][ 'config' ]
338+ lmp_config = config .get ('lmp_config' , {}) if old_style else config [ 'explore' ][ 'config' ]
339+ fp_config = config .get ('fp_config' , {}) if old_style else config [ 'fp' ][ 'config' ]
340+ kspacing , kgamma = get_kspacing_kgamma_from_incar (config ['fp_incar' ] if old_style else config [ 'fp' ][ 'incar' ] )
341+ fp_pp_files = config ['fp_pp_files' ] if old_style else config [ 'fp' ][ 'pp_files' ]
342+ incar_file = config ['fp_incar' ] if old_style else config [ 'fp' ][ 'incar' ]
333343 fp_inputs = VaspInputs (
334344 kspacing = kspacing ,
335345 kgamma = kgamma ,
336346 incar_template_name = incar_file ,
337347 potcar_names = fp_pp_files ,
338348 )
339- init_data = config ['init_data_sys' ]
349+ init_data_prefix = config .get ('init_data_prefix' ) if old_style else config ['inputs' ]['init_data_prefix' ]
350+ init_data = config ['init_data_sys' ] if old_style else config ['inputs' ]['init_data_sys' ]
351+ if init_data_prefix is not None :
352+ init_data = [os .path .join (init_data_prefix , ii ) for ii in init_data_sys ]
340353 if isinstance (init_data ,str ):
341354 init_data = expand_sys_str (init_data )
342355 init_data = upload_artifact (init_data )
@@ -372,8 +385,7 @@ def workflow_concurrent_learning(
372385def wf_global_workflow (
373386 wf_config ,
374387):
375- dflow_config_data = wf_config .get ('dflow_config' , None )
376- dflow_config (dflow_config_data )
388+ workflow_config_from_dict (wf_config )
377389
378390 # lebesgue context
379391 from dflow .plugins .lebesgue import LebesgueContext
@@ -391,10 +403,13 @@ def wf_global_workflow(
391403def submit_concurrent_learning (
392404 wf_config ,
393405 reuse_step = None ,
406+ old_style = False ,
394407):
408+ wf_config = normalize_submit_args (wf_config )
409+
395410 context = wf_global_workflow (wf_config )
396411
397- dpgen_step = workflow_concurrent_learning (wf_config )
412+ dpgen_step = workflow_concurrent_learning (wf_config , old_style = old_style )
398413
399414 wf = Workflow (name = "dpgen" , context = context )
400415 wf .add (dpgen_step )
@@ -450,7 +465,10 @@ def resubmit_concurrent_learning(
450465 wfid ,
451466 list_steps = False ,
452467 reuse = None ,
468+ old_style = False ,
453469):
470+ wf_config = normalize_submit_args (wf_config )
471+
454472 context = wf_global_workflow (wf_config )
455473
456474 old_wf = Workflow (id = wfid )
@@ -474,6 +492,7 @@ def resubmit_concurrent_learning(
474492 wf = submit_concurrent_learning (
475493 wf_config ,
476494 reuse_step = reuse_step ,
495+ old_style = old_style ,
477496 )
478497
479498 return wf
0 commit comments