
Commit 0a89b72

Authored by zjgemi, pre-commit-ci[bot], wanghan-iapcm, and coderabbitai[bot]
Split data of last iteration to training data and validation data (#293)
## Summary by CodeRabbit

- **New Features**
  - Added an option to automatically split the last iteration's training data into separate training and validation sets using a configurable ratio.
- **Bug Fixes**
  - Fixed an issue ensuring temperature settings are correctly applied when preparing VASP input files.
  - Improved validation data loading logic for multitask and single-task workflows.
- **Tests**
  - Added tests to verify the correct splitting of datasets into training and validation subsets.

Signed-off-by: zjgemi <[email protected]>
Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
Co-authored-by: Han Wang <[email protected]>
Co-authored-by: coderabbitai[bot] <136622811+coderabbitai[bot]@users.noreply.github.com>
1 parent 62e1661 commit 0a89b72
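The new behavior is controlled by a single training option, `split_last_iter_valid_ratio` (a float, optional, default `None`, meaning no split), added to `RunDPTrain.training_args` in the diff below. A hedged sketch of how it might be set, written here as a Python dict for illustration, presumably in the train config section of the dpgen2 input; all other keys are omitted:

```python
# Illustrative only: the fraction of last-iteration frames to reserve for validation.
# None (the default) keeps the previous behavior of training on all new frames.
train_config = {
    "split_last_iter_valid_ratio": 0.1,  # roughly 10% of new frames become validation data
}
```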

4 files changed (+134 / -2 lines)


dpgen2/entrypoint/submit.py

Lines changed: 1 addition & 1 deletion

@@ -533,7 +533,7 @@ def workflow_concurrent_learning(
     else:
         if config["inputs"]["valid_data_uri"] is not None:
             valid_data = get_artifact_from_uri(config["inputs"]["valid_data_uri"])
-        elif config["inputs"]["valid_data_prefix"] is not None:
+        elif config["inputs"]["valid_data_sys"] is not None:
             valid_data_prefix = config["inputs"]["valid_data_prefix"]
             valid_data = config["inputs"]["valid_data_sys"]
             valid_data = get_systems_from_data(valid_data, valid_data_prefix)
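Before this fix the `elif` was gated on `valid_data_prefix`, so an input that set a prefix but no `valid_data_sys` would presumably enter the branch with nothing to load; gating on `valid_data_sys` ties the branch to the key that actually lists the validation systems. A hedged, self-contained sketch of the corrected control flow (keys mirror `config["inputs"]`, paths are placeholders):

```python
# Hypothetical sketch of the branch changed by the one-line fix above.
inputs = {
    "valid_data_uri": None,
    "valid_data_prefix": "",                           # prefix joined to each system path
    "valid_data_sys": ["valid/sys-A", "valid/sys-B"],  # hypothetical system directories
}

if inputs["valid_data_uri"] is not None:
    valid_data = "artifact-from-uri"                   # stand-in for get_artifact_from_uri(...)
elif inputs["valid_data_sys"] is not None:             # fixed: gate on the system list itself
    prefix = inputs["valid_data_prefix"]
    valid_data = inputs["valid_data_sys"]
    # valid_data = get_systems_from_data(valid_data, prefix)  # dpgen2 helper, omitted here
```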

dpgen2/fp/vasp.py

Lines changed: 1 addition & 1 deletion

@@ -141,7 +141,7 @@ def prep_task(

         conf_frame.to("vasp/poscar", vasp_conf_name)
         incar = vasp_inputs.incar_template
-        self.set_ele_temp(conf_frame, incar)
+        incar = self.set_ele_temp(conf_frame, incar)

         Path(vasp_input_name).write_text(incar)
         # fix the case when some element have 0 atom, e.g. H0O2
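The one-line change above matters because Python strings are immutable: `set_ele_temp` cannot edit `incar` in place, so discarding its return value silently dropped the electronic-temperature settings from the INCAR that was written out. A minimal sketch of the pattern, using a hypothetical helper rather than the real dpgen2 method:

```python
def add_ele_temp_sketch(incar: str, ele_temp_ev: float) -> str:
    """Hypothetical stand-in for set_ele_temp: returns a NEW string, leaving the input unchanged."""
    return incar + f"\nSIGMA = {ele_temp_ev:.6f}"

incar = "ISMEAR = -1"
add_ele_temp_sketch(incar, 0.02)           # return value discarded: incar is unchanged (the old bug)
incar = add_ele_temp_sketch(incar, 0.02)   # reassignment keeps the temperature setting (the fix)
print(incar)
```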

dpgen2/op/run_dp_train.py

Lines changed: 89 additions & 0 deletions

@@ -1,7 +1,9 @@
 import glob
 import json
 import logging
+import math
 import os
+import random
 import shutil
 from pathlib import (
     Path,

@@ -197,6 +199,12 @@ def execute(
         valid_data = ip["valid_data"]
         iter_data_old_exp = _expand_all_multi_sys_to_sys(iter_data[:-1])
         iter_data_new_exp = _expand_all_multi_sys_to_sys(iter_data[-1:])
+        if config["split_last_iter_valid_ratio"] is not None:
+            train_systems, valid_systems = split_valid(
+                iter_data_new_exp, config["split_last_iter_valid_ratio"]
+            )
+            iter_data_new_exp = train_systems
+            valid_data = append_valid_data(config, valid_data, valid_systems)
         iter_data_exp = iter_data_old_exp + iter_data_new_exp
         work_dir = Path(task_name)
         init_model_with_finetune = config["init_model_with_finetune"]

@@ -517,6 +525,9 @@ def training_args():
         doc_head = "Head to use in the multitask training"
         doc_init_model_with_finetune = "Use finetune for init model"
         doc_train_args = "Extra arguments for dp train"
+        doc_split_last_iter_valid_ratio = (
+            "Ratio of valid data if split data of last iter"
+        )
         return [
             Argument(
                 "command",

@@ -618,6 +629,13 @@ def training_args():
                 default="",
                 doc=doc_train_args,
             ),
+            Argument(
+                "split_last_iter_valid_ratio",
+                float,
+                optional=True,
+                default=None,
+                doc=doc_split_last_iter_valid_ratio,
+            ),
         ]

     @staticmethod

@@ -672,4 +690,75 @@ def _expand_all_multi_sys_to_sys(list_multi_sys):
     return all_sys_dirs


+def split_valid(systems: List[str], valid_ratio: float):
+    train_systems = []
+    valid_systems = []
+    for system in systems:
+        d = dpdata.MultiSystems()
+        mixed_type = len(glob.glob("%s/*/real_atom_types.npy" % system)) > 0
+        if mixed_type:
+            d.load_systems_from_file(system, fmt="deepmd/npy/mixed")
+        else:
+            k = dpdata.LabeledSystem(system, fmt="deepmd/npy")
+            d.append(k)
+
+        train_multi_systems = dpdata.MultiSystems()
+        valid_multi_systems = dpdata.MultiSystems()
+        for s in d:
+            nvalid = math.floor(len(s) * valid_ratio)
+            if random.random() < len(s) * valid_ratio - nvalid:
+                nvalid += 1
+            valid_indices = random.sample(range(len(s)), nvalid)
+            train_indices = list(set(range(len(s))).difference(valid_indices))
+            if len(valid_indices) > 0:
+                valid_multi_systems.append(s.sub_system(valid_indices))
+            if len(train_indices) > 0:
+                train_multi_systems.append(s.sub_system(train_indices))
+
+        if len(train_multi_systems) > 0:
+            target = "train_data/" + system
+            if mixed_type:
+                # The multisystem is loaded from one dir, thus we can safely keep one dir
+                train_multi_systems.to_deepmd_npy_mixed("%s.tmp" % target)  # type: ignore
+                fs = os.listdir("%s.tmp" % target)
+                assert len(fs) == 1
+                os.rename(os.path.join("%s.tmp" % target, fs[0]), target)
+                os.rmdir("%s.tmp" % target)
+            else:
+                train_multi_systems[0].to_deepmd_npy(target)  # type: ignore
+            train_systems.append(os.path.abspath(target))
+
+        if len(valid_multi_systems) > 0:
+            target = "valid_data/" + system
+            if mixed_type:
+                # The multisystem is loaded from one dir, thus we can safely keep one dir
+                valid_multi_systems.to_deepmd_npy_mixed("%s.tmp" % target)  # type: ignore
+                fs = os.listdir("%s.tmp" % target)
+                assert len(fs) == 1
+                os.rename(os.path.join("%s.tmp" % target, fs[0]), target)
+                os.rmdir("%s.tmp" % target)
+            else:
+                valid_multi_systems[0].to_deepmd_npy(target)  # type: ignore
+            valid_systems.append(os.path.abspath(target))
+
+    return train_systems, valid_systems
+
+
+def append_valid_data(config, valid_data, valid_systems):
+    if not valid_systems:
+        return valid_data
+    if config["multitask"]:
+        head = config["head"]
+        if not valid_data:
+            valid_data = {}
+        if head not in valid_data:
+            valid_data[head] = []
+        valid_data[head] += valid_systems
+    else:
+        if not valid_data:
+            valid_data = []
+        valid_data += valid_systems
+    return valid_data
+
+
 config_args = RunDPTrain.training_args
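One detail worth noting in `split_valid`: the per-system validation count is `floor(len(s) * valid_ratio)`, incremented by one with probability equal to the fractional remainder. This keeps the expected validation fraction at `valid_ratio` even for small systems, where plain flooring would often select zero frames. A self-contained sketch of just that rounding rule:

```python
import math
import random

def expected_valid_count(n_frames: int, valid_ratio: float, trials: int = 100_000) -> float:
    """Average validation-frame count produced by the floor-plus-probabilistic-carry rule."""
    total = 0
    for _ in range(trials):
        nvalid = math.floor(n_frames * valid_ratio)
        # carry the fractional part with matching probability, as split_valid does
        if random.random() < n_frames * valid_ratio - nvalid:
            nvalid += 1
        total += nvalid
    return total / trials

# With 10 frames and ratio 0.25, each draw yields 2 or 3 frames, but the mean is ~2.5.
print(expected_valid_count(10, 0.25))
```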

tests/op/test_run_dp_train.py

Lines changed: 43 additions & 0 deletions

@@ -7,6 +7,7 @@
     Path,
 )

+import dpdata
 import numpy as np
 from dflow.python import (
     OP,

@@ -37,6 +38,7 @@
     RunDPTrain,
     _get_data_size_of_all_mult_sys,
     _make_train_command,
+    split_valid,
 )

 # isort: on

@@ -942,3 +944,44 @@ def test_exec_v2_empty_dir(self, mocked_run):
         with open(out["script"]) as fp:
             jdata = json.load(fp)
         self.assertDictEqual(jdata, self.expected_odict_v2)
+
+
+class TestSplitValid(unittest.TestCase):
+    def setUp(self):
+        s = fake_system(10, 1)
+        s.to_deepmd_npy("fake_data")
+        ms = fake_multi_sys([10, 20], [1, 2])
+        ms.to_deepmd_npy_mixed("fake_mixed_data")
+
+    def test_split_valid(self):
+        train_systems, valid_systems = split_valid(["fake_data"], 0.1)
+        self.assertEqual(len(train_systems), 1)
+        s = dpdata.LabeledSystem(train_systems[0], fmt="deepmd/npy")
+        self.assertEqual(len(s), 9)
+        self.assertEqual(len(valid_systems), 1)
+        s = dpdata.LabeledSystem(valid_systems[0], fmt="deepmd/npy")
+        self.assertEqual(len(s), 1)
+
+    def test_split_valid_mixed(self):
+        train_systems, valid_systems = split_valid(
+            ["fake_mixed_data/1", "fake_mixed_data/2"], 0.1
+        )
+        self.assertEqual(len(train_systems), 2)
+        ms = dpdata.MultiSystems()
+        ms.load_systems_from_file(train_systems[0], fmt="deepmd/npy/mixed")
+        self.assertEqual(len(ms[0]), 9)
+        ms = dpdata.MultiSystems()
+        ms.load_systems_from_file(train_systems[1], fmt="deepmd/npy/mixed")
+        self.assertEqual(len(ms[0]), 18)
+        self.assertEqual(len(valid_systems), 2)
+        ms = dpdata.MultiSystems()
+        ms.load_systems_from_file(valid_systems[0], fmt="deepmd/npy/mixed")
+        self.assertEqual(len(ms[0]), 1)
+        ms = dpdata.MultiSystems()
+        ms.load_systems_from_file(valid_systems[1], fmt="deepmd/npy/mixed")
+        self.assertEqual(len(ms[0]), 2)
+
+    def tearDown(self):
+        for f in ["fake_data", "fake_mixed_data", "train_data", "valid_data"]:
+            if os.path.exists(f):
+                shutil.rmtree(f)