|
1 | 1 | import glob
|
2 | 2 | import json
|
3 | 3 | import logging
|
| 4 | +import math |
4 | 5 | import os
|
| 6 | +import random |
5 | 7 | import shutil
|
6 | 8 | from pathlib import (
|
7 | 9 | Path,
|
@@ -197,6 +199,12 @@ def execute(
|
197 | 199 | valid_data = ip["valid_data"]
|
198 | 200 | iter_data_old_exp = _expand_all_multi_sys_to_sys(iter_data[:-1])
|
199 | 201 | iter_data_new_exp = _expand_all_multi_sys_to_sys(iter_data[-1:])
|
| 202 | + if config["split_last_iter_valid_ratio"] is not None: |
| 203 | + train_systems, valid_systems = split_valid( |
| 204 | + iter_data_new_exp, config["split_last_iter_valid_ratio"] |
| 205 | + ) |
| 206 | + iter_data_new_exp = train_systems |
| 207 | + valid_data = append_valid_data(config, valid_data, valid_systems) |
200 | 208 | iter_data_exp = iter_data_old_exp + iter_data_new_exp
|
201 | 209 | work_dir = Path(task_name)
|
202 | 210 | init_model_with_finetune = config["init_model_with_finetune"]
|
@@ -517,6 +525,9 @@ def training_args():
|
517 | 525 | doc_head = "Head to use in the multitask training"
|
518 | 526 | doc_init_model_with_finetune = "Use finetune for init model"
|
519 | 527 | doc_train_args = "Extra arguments for dp train"
|
| 528 | + doc_split_last_iter_valid_ratio = ( |
| 529 | + "Ratio of valid data if split data of last iter" |
| 530 | + ) |
520 | 531 | return [
|
521 | 532 | Argument(
|
522 | 533 | "command",
|
@@ -618,6 +629,13 @@ def training_args():
|
618 | 629 | default="",
|
619 | 630 | doc=doc_train_args,
|
620 | 631 | ),
|
| 632 | + Argument( |
| 633 | + "split_last_iter_valid_ratio", |
| 634 | + float, |
| 635 | + optional=True, |
| 636 | + default=None, |
| 637 | + doc=doc_split_last_iter_valid_ratio, |
| 638 | + ), |
621 | 639 | ]
|
622 | 640 |
|
623 | 641 | @staticmethod
|
@@ -672,4 +690,75 @@ def _expand_all_multi_sys_to_sys(list_multi_sys):
|
672 | 690 | return all_sys_dirs
|
673 | 691 |
|
674 | 692 |
|
def _write_split(multi_systems, target: str, mixed_type: bool) -> str:
    """Dump one split of a system to ``target`` and return its absolute path."""
    if mixed_type:
        # The multisystem is loaded from one dir, thus we can safely keep one dir
        tmp_dir = target + ".tmp"
        multi_systems.to_deepmd_npy_mixed(tmp_dir)  # type: ignore
        entries = os.listdir(tmp_dir)
        assert len(entries) == 1
        # Promote the single dumped sub-directory to the final location.
        os.rename(os.path.join(tmp_dir, entries[0]), target)
        os.rmdir(tmp_dir)
    else:
        multi_systems[0].to_deepmd_npy(target)  # type: ignore
    return os.path.abspath(target)


def split_valid(systems: List[str], valid_ratio: float):
    """Randomly split the frames of each system into train and valid parts.

    For every system directory the frames are loaded with dpdata (as
    ``deepmd/npy/mixed`` when a ``*/real_atom_types.npy`` file is found under
    it, otherwise as ``deepmd/npy``).  The number of validation frames is
    ``len(s) * valid_ratio`` with stochastic rounding: the fractional part is
    used as the probability of taking one extra frame.  The resulting parts
    are written under ``train_data/<system>`` and ``valid_data/<system>``
    relative to the current working directory; empty parts are not written.

    Parameters
    ----------
    systems : List[str]
        Paths of the system directories to split.
    valid_ratio : float
        Fraction of frames moved to the validation set, within ``[0, 1]``.

    Returns
    -------
    train_systems : list of str
        Absolute paths of the written training systems.
    valid_systems : list of str
        Absolute paths of the written validation systems.

    Raises
    ------
    ValueError
        If ``valid_ratio`` is not within ``[0, 1]`` (NaN included).
    """
    # Fail early with a clear message instead of an opaque random.sample error.
    if not 0.0 <= valid_ratio <= 1.0:
        raise ValueError(
            "valid_ratio must be within [0, 1], got %r" % (valid_ratio,)
        )

    train_systems = []
    valid_systems = []
    for system in systems:
        # Escape the directory so glob metacharacters in the path are literal.
        pattern = os.path.join(glob.escape(system), "*", "real_atom_types.npy")
        mixed_type = len(glob.glob(pattern)) > 0

        d = dpdata.MultiSystems()
        if mixed_type:
            d.load_systems_from_file(system, fmt="deepmd/npy/mixed")
        else:
            d.append(dpdata.LabeledSystem(system, fmt="deepmd/npy"))

        train_multi_systems = dpdata.MultiSystems()
        valid_multi_systems = dpdata.MultiSystems()
        for s in d:
            nframes = len(s)
            expected = nframes * valid_ratio
            nvalid = math.floor(expected)
            # Stochastic rounding keeps the expected valid fraction unbiased.
            if random.random() < expected - nvalid:
                nvalid += 1
            valid_indices = random.sample(range(nframes), nvalid)
            chosen = set(valid_indices)
            # Keep the remaining frames in their original (ascending) order.
            train_indices = [i for i in range(nframes) if i not in chosen]
            if valid_indices:
                valid_multi_systems.append(s.sub_system(valid_indices))
            if train_indices:
                train_multi_systems.append(s.sub_system(train_indices))

        if len(train_multi_systems) > 0:
            train_systems.append(
                _write_split(train_multi_systems, "train_data/" + system, mixed_type)
            )

        if len(valid_multi_systems) > 0:
            valid_systems.append(
                _write_split(valid_multi_systems, "valid_data/" + system, mixed_type)
            )

    return train_systems, valid_systems
| 745 | + |
| 746 | + |
def append_valid_data(config, valid_data, valid_systems):
    """Return ``valid_data`` extended with ``valid_systems``.

    The inputs are never modified in place; new containers are built so that
    the caller's original ``valid_data`` stays untouched.

    Parameters
    ----------
    config : dict
        Training configuration. ``config["multitask"]`` selects the layout of
        ``valid_data``; ``config["head"]`` is read only in the multitask case.
    valid_data : list, dict or None
        Existing validation systems: a list in the single-task case, a dict
        mapping head name to a list in the multitask case. Falsy values are
        treated as empty.
    valid_systems : list
        Validation systems to add.

    Returns
    -------
    list, dict or None
        ``valid_data`` itself when ``valid_systems`` is empty, otherwise a new
        list (single-task) or a new dict (multitask) including the additions.
    """
    if not valid_systems:
        return valid_data

    additions = list(valid_systems)
    if config["multitask"]:
        head = config["head"]
        # Shallow-copy the mapping; only the entry of ``head`` is rebuilt.
        merged = dict(valid_data) if valid_data else {}
        merged[head] = list(merged.get(head) or []) + additions
        return merged

    return (list(valid_data) if valid_data else []) + additions
| 762 | + |
| 763 | + |
675 | 764 | config_args = RunDPTrain.training_args
|
0 commit comments