diff --git a/dpgen2/entrypoint/args.py b/dpgen2/entrypoint/args.py
index ea08afd9..df71d761 100644
--- a/dpgen2/entrypoint/args.py
+++ b/dpgen2/entrypoint/args.py
@@ -19,11 +19,13 @@ def dp_train_args():
     doc_numb_models = "Number of models trained for evaluating the model deviation"
     doc_config = "Configuration of training"
     doc_template_script = "File names of the template training script. It can be a `List[Dict]`, the length of which is the same as `numb_models`. Each template script in the list is used to train a model. Can be a `Dict`, the models share the same template training script. "
+    doc_init_models_paths = "the paths to initial models"
 
     return [
         Argument("config", dict, RunDPTrain.training_args(), optional=True, default=RunDPTrain.normalize_config({}), doc=doc_numb_models),
         Argument("numb_models", int, optional=True, default=4, doc=doc_numb_models),
         Argument("template_script", [list,str], optional=False, doc=doc_template_script),
+        Argument("init_models_paths", list, optional=True, doc=doc_init_models_paths, alias=['training_iter0_model_path']),
     ]
 
 def variant_train():
diff --git a/dpgen2/entrypoint/submit.py b/dpgen2/entrypoint/submit.py
index 65483c6d..0f25334a 100644
--- a/dpgen2/entrypoint/submit.py
+++ b/dpgen2/entrypoint/submit.py
@@ -303,7 +303,7 @@ def workflow_concurrent_learning(
     collect_data_config = normalize_step_dict(config.get('collect_data_config', default_config)) if old_style else config['step_configs']['collect_data_config']
     cl_step_config = normalize_step_dict(config.get('cl_step_config', default_config)) if old_style else config['step_configs']['cl_step_config']
     upload_python_packages = config.get('upload_python_packages', None)
-    init_models_paths = config.get('training_iter0_model_path', None) if old_style else config['train'].get('training_iter0_model_path', None)
+    init_models_paths = config.get('training_iter0_model_path', None) if old_style else config['train'].get('init_models_paths', None)
     if upload_python_packages is not None and isinstance(upload_python_packages, str):
         upload_python_packages = [upload_python_packages]
     if upload_python_packages is not None:
diff --git a/dpgen2/op/run_dp_train.py b/dpgen2/op/run_dp_train.py
index 27512d02..aa389574 100644
--- a/dpgen2/op/run_dp_train.py
+++ b/dpgen2/op/run_dp_train.py
@@ -1,4 +1,4 @@
-import os, json, dpdata, glob
+import os, json, dpdata, glob, shutil
 from pathlib import Path
 from dpgen2.utils.run_command import run_command
 from dpgen2.utils.chdir import set_directory
@@ -125,6 +125,14 @@ def execute(
         train_dict = RunDPTrain.write_other_to_input_script(
             train_dict, config, do_init_model, major_version)
 
+        if RunDPTrain.skip_training(work_dir, train_dict, init_model, iter_data):
+            return OPIO({
+                "script" : work_dir / train_script_name,
+                "model" : work_dir / "frozen_model.pb",
+                "lcurve" : work_dir / "lcurve.out",
+                "log" : work_dir / "train.log",
+            })
+
         with set_directory(work_dir):
             # open log
             fplog = open('train.log', 'w')
@@ -224,6 +232,30 @@ def write_other_to_input_script(
             raise RuntimeError('unsupported DeePMD-kit major version', major_version)
         return odict
 
+    @staticmethod
+    def skip_training(
+            work_dir,
+            train_dict,
+            init_model,
+            iter_data,
+    ):
+        # we have init model and no iter data, skip training
+        if (init_model is not None) and \
+           (iter_data is None or len(iter_data) == 0) :
+            with set_directory(work_dir):
+                with open(train_script_name, 'w') as fp:
+                    json.dump(train_dict, fp, indent=4)
+                Path('train.log').write_text(
+                    f'We have init model {init_model} and '
+                    f'no iteration training data. '
+                    f'The training is skipped.\n'
+                )
+                Path('lcurve.out').touch()
+                shutil.copy(init_model, 'frozen_model.pb')
+            return True
+        else:
+            return False
+
     @staticmethod
     def decide_init_model(
         config,
diff --git a/tests/op/test_run_dp_train.py b/tests/op/test_run_dp_train.py
index 49470ac8..ea053d39 100644
--- a/tests/op/test_run_dp_train.py
+++ b/tests/op/test_run_dp_train.py
@@ -577,7 +577,7 @@ def setUp(self):
 
 
     def tearDown(self):
-        for ii in ['init', self.task_path, self.task_name, 'foo' ]:
+        for ii in ['init', self.task_path, self.task_name, 'foo']:
             if Path(ii).exists():
                 shutil.rmtree(str(ii))
 
@@ -592,10 +592,7 @@ def test_update_input_dict_v2_empty_list(self):
         self.assertDictEqual(odict, self.expected_odict_v2)
 
 
-    @patch('dpgen2.op.run_dp_train.run_command')
-    def test_exec_v2_empty_list(self, mocked_run):
-        mocked_run.side_effect = [ (0, 'foo\n', ''), (0, 'bar\n', '') ]
-
+    def test_exec_v2_empty_list(self):
         config = self.config.copy()
         config['init_model_policy'] = 'no'
 
@@ -606,6 +603,9 @@ def test_exec_v2_empty_list(self, mocked_run):
         task_name = self.task_name
         work_dir = Path(task_name)
 
+        self.init_model = self.init_model.absolute()
+        self.init_model.write_text('this is init model')
+
         ptrain = RunDPTrain()
         out = ptrain.execute(
             OPIO({
@@ -621,26 +621,20 @@ def test_exec_v2_empty_list(self, mocked_run):
         self.assertEqual(out['model'], work_dir/'frozen_model.pb')
         self.assertEqual(out['lcurve'], work_dir/'lcurve.out')
         self.assertEqual(out['log'], work_dir/'train.log')
-
-        calls = [
-            call(['dp', 'train', train_script_name]),
-            call(['dp', 'freeze', '-o', 'frozen_model.pb']),
-        ]
-        mocked_run.assert_has_calls(calls)
-
+
         self.assertTrue(work_dir.is_dir())
         self.assertTrue(out['log'].is_file())
         self.assertEqual(out['log'].read_text(),
-                         '#=================== train std out ===================\n'
-                         'foo\n'
-                         '#=================== train std err ===================\n'
-                         '#=================== freeze std out ===================\n'
-                         'bar\n'
-                         '#=================== freeze std err ===================\n'
+                         f'We have init model {self.init_model} and '
+                         f'no iteration training data. '
+                         f'The training is skipped.\n'
                          )
         with open(out['script']) as fp:
             jdata = json.load(fp)
         self.assertDictEqual(jdata, self.expected_odict_v2)
+        self.assertEqual(Path(out['model']).read_text(), "this is init model")
+
+        os.remove(self.init_model)
 
 
     @patch('dpgen2.op.run_dp_train.run_command')