6969 dump_object_to_file ,
7070 load_object_from_file ,
7171)
72+ from dpgen2 .utils .step_config import normalize as normalize_step_dict
73+ default_config = normalize_step_dict (
74+ {
75+ "template_config" : {
76+ "image" : default_image ,
77+ }
78+ }
79+ )
7280
7381def make_concurrent_learning_op (
7482 train_style : str = 'dp' ,
7583 explore_style : str = 'lmp' ,
7684 fp_style : str = 'vasp' ,
77- prep_train_image : str = default_image ,
78- run_train_image : str = default_image ,
79- prep_explore_image : str = default_image ,
80- run_explore_image : str = default_image ,
81- prep_fp_image : str = default_image ,
82- run_fp_image : str = default_image ,
83- select_confs_image : str = default_image ,
84- collect_data_image : str = default_image ,
85+ prep_train_config : str = default_config ,
86+ run_train_config : str = default_config ,
87+ prep_explore_config : str = default_config ,
88+ run_explore_config : str = default_config ,
89+ prep_fp_config : str = default_config ,
90+ run_fp_config : str = default_config ,
91+ select_confs_config : str = default_config ,
92+ collect_data_config : str = default_config ,
8593 upload_python_package : bool = None ,
8694):
8795 if train_style == 'dp' :
8896 prep_run_train_op = PrepRunDPTrain (
8997 "prep-run-dp-train" ,
9098 PrepDPTrain ,
9199 RunDPTrain ,
92- prep_image = prep_train_image ,
93- run_image = run_train_image ,
100+ prep_config = prep_train_config ,
101+ run_config = run_train_config ,
94102 upload_python_package = upload_python_package ,
95103 )
96104 else :
@@ -100,8 +108,8 @@ def make_concurrent_learning_op (
100108 "prep-run-lmp" ,
101109 PrepLmp ,
102110 RunLmp ,
103- prep_image = prep_explore_image ,
104- run_image = run_explore_image ,
111+ prep_config = prep_explore_config ,
112+ run_config = run_explore_config ,
105113 upload_python_package = upload_python_package ,
106114 )
107115 else :
@@ -111,8 +119,8 @@ def make_concurrent_learning_op (
111119 "prep-run-vasp" ,
112120 PrepVasp ,
113121 RunVasp ,
114- prep_image = prep_fp_image ,
115- run_image = run_fp_image ,
122+ prep_config = prep_fp_config ,
123+ run_config = run_fp_config ,
116124 upload_python_package = upload_python_package ,
117125 )
118126 else :
@@ -126,16 +134,16 @@ def make_concurrent_learning_op (
126134 SelectConfs ,
127135 prep_run_fp_op ,
128136 CollectData ,
129- select_confs_image = select_confs_image ,
130- collect_data_image = collect_data_image ,
137+ select_confs_config = select_confs_config ,
138+ collect_data_config = collect_data_config ,
131139 upload_python_package = upload_python_package ,
132140 )
133141 # dpgen
134142 dpgen_op = ConcurrentLearning (
135143 "concurrent-learning" ,
136144 block_cl_op ,
137145 upload_python_package = upload_python_package ,
138- image = default_image ,
146+ step_config = default_config ,
139147 )
140148
141149 return dpgen_op
@@ -242,25 +250,25 @@ def workflow_concurrent_learning(
242250 train_style = config ['train_style' ]
243251 explore_style = config ['explore_style' ]
244252 fp_style = config ['fp_style' ]
245- prep_train_image = config [ 'prep_train_image' ]
246- run_train_image = config [ 'run_train_image' ]
247- prep_explore_image = config [ 'prep_explore_image' ]
248- run_explore_image = config [ 'run_explore_image' ]
249- prep_fp_image = config [ 'prep_fp_image' ]
250- run_fp_image = config [ 'run_fp_image' ]
253+ prep_train_config = normalize_step_dict ( config . get ( 'prep_train_config' , {}))
254+ run_train_config = normalize_step_dict ( config . get ( 'run_train_config' , {}))
255+ prep_explore_config = normalize_step_dict ( config . get ( 'prep_explore_config' , {}))
256+ run_explore_config = normalize_step_dict ( config . get ( 'run_explore_config' , {}))
257+ prep_fp_config = normalize_step_dict ( config . get ( 'prep_fp_config' , {}))
258+ run_fp_config = normalize_step_dict ( config . get ( 'run_fp_config' , {}))
251259 upload_python_package = config .get ('upload_python_package' , None )
252260 init_models_paths = config .get ('training_iter0_model_path' )
253261
254262 concurrent_learning_op = make_concurrent_learning_op (
255263 train_style ,
256264 explore_style ,
257265 fp_style ,
258- prep_train_image = prep_train_image ,
259- run_train_image = run_train_image ,
260- prep_explore_image = prep_explore_image ,
261- run_explore_image = run_explore_image ,
262- prep_fp_image = prep_fp_image ,
263- run_fp_image = run_fp_image ,
266+ prep_train_config = prep_train_config ,
267+ run_train_config = run_train_config ,
268+ prep_explore_config = prep_explore_config ,
269+ run_explore_config = run_explore_config ,
270+ prep_fp_config = prep_fp_config ,
271+ run_fp_config = run_fp_config ,
264272 upload_python_package = upload_python_package ,
265273 )
266274 scheduler = make_naive_exploration_scheduler (config )
@@ -315,15 +323,39 @@ def workflow_concurrent_learning(
315323 "iter_data" : iter_data ,
316324 },
317325 )
318-
319- wf = Workflow (name = "dpgen" , host = default_host )
320- wf .add (dpgen_step )
321-
322- return wf
326+ return dpgen_step
323327
324328def submit_concurrent_learning (
325- config ,
329+ wf_config ,
326330):
327- wf = workflow_concurrent_learning (config )
331+ # set global config
332+ from dflow import config , s3_config
333+ dflow_config = wf_config .get ('dflow_config' , None )
334+ if dflow_config :
335+ config ["host" ] = dflow_config .get ('host' , None )
336+ s3_config ["endpoint" ] = dflow_config .get ('s3_endpoint' , None )
337+ config ["k8s_api_server" ] = dflow_config .get ('k8s_api_server' , None )
338+ config ["token" ] = dflow_config .get ('token' , None )
339+
340+ # lebesque context
341+ from dflow .plugins .lebesgue import LebesgueContext
342+ lb_context_config = wf_config .get ("lebesque_context_config" , None )
343+ if lb_context_config :
344+ lebesgue_context = LebesgueContext (
345+ ** lb_context_config ,
346+ )
347+ else :
348+ lebesgue_context = None
349+
350+ # print('config:', config)
351+ # print('s3_config:',s3_config)
352+ # print('lebsque context:', lb_context_config)
353+
354+ dpgen_step = workflow_concurrent_learning (wf_config )
355+
356+ wf = Workflow (name = "dpgen" , context = lebesgue_context )
357+ wf .add (dpgen_step )
358+
328359 wf .submit ()
360+
329361 return wf
0 commit comments