From 1dca8159e9f492547f1a7d69f4e690a81a7ae721 Mon Sep 17 00:00:00 2001 From: Sander Vandenhaute Date: Wed, 12 Jun 2024 13:04:47 -0400 Subject: [PATCH] linting + update to latest mace version --- psiflow/__init__.py | 2 +- psiflow/config.py | 1 - psiflow/execution.py | 3 +- psiflow/hamiltonians/hamiltonian.py | 2 +- psiflow/models/_mace.py | 40 ++--- psiflow/models/mace_utils.py | 253 ++++++++++++++++++---------- psiflow/sampling/order.py | 2 +- psiflow/tools/server.py | 2 +- tests/test_hamiltonian.py | 24 ++- tests/test_models.py | 7 +- 10 files changed, 212 insertions(+), 124 deletions(-) diff --git a/psiflow/__init__.py b/psiflow/__init__.py index 28cfec35..f4e7d955 100644 --- a/psiflow/__init__.py +++ b/psiflow/__init__.py @@ -1,3 +1,4 @@ +from .config import setup_slurm_config # noqa: F401 from .execution import ExecutionContextLoader from .serialization import ( # noqa: F401 _DataFuture, @@ -5,7 +6,6 @@ serializable, serialize, ) -from .config import setup_slurm_config # noqa: F401 load = ExecutionContextLoader.load context = ExecutionContextLoader.context diff --git a/psiflow/config.py b/psiflow/config.py index 2ea65de5..fd4c85ca 100644 --- a/psiflow/config.py +++ b/psiflow/config.py @@ -22,7 +22,6 @@ def get_partitions(): partition_info[partition_name] = partition_dict scontrol_output = subprocess.check_output(["scontrol", "show", "node"], text=True) - node_info = {} nodes = scontrol_output.strip().split("\n\n") for node in nodes: diff --git a/psiflow/execution.py b/psiflow/execution.py index dfa166df..9859de85 100644 --- a/psiflow/execution.py +++ b/psiflow/execution.py @@ -17,10 +17,9 @@ from parsl.addresses import address_by_hostname from parsl.config import Config from parsl.data_provider.files import File -from parsl.executors import ( +from parsl.executors import ( # WorkQueueExecutor, HighThroughputExecutor, ThreadPoolExecutor, - # WorkQueueExecutor, ) from parsl.executors.base import ParslExecutor from parsl.launchers import SimpleLauncher, WrappedLauncher diff --git a/psiflow/hamiltonians/hamiltonian.py b/psiflow/hamiltonians/hamiltonian.py index e6c734c4..a993399d 100644 --- a/psiflow/hamiltonians/hamiltonian.py +++ b/psiflow/hamiltonians/hamiltonian.py @@ -6,8 +6,8 @@ import typeguard from parsl.app.app import python_app from parsl.app.futures import DataFuture -from parsl.dataflow.futures import AppFuture from parsl.data_provider.files import File +from parsl.dataflow.futures import AppFuture import psiflow from psiflow.data import Dataset, batch_apply diff --git a/psiflow/models/_mace.py b/psiflow/models/_mace.py index 793a95cc..598b3616 100644 --- a/psiflow/models/_mace.py +++ b/psiflow/models/_mace.py @@ -28,7 +28,8 @@ class MACEConfig: results_dir: str = "" downloads_dir: str = "" device: str = "cuda" - default_dtype: str = "float32" + default_dtype: str = "float32" # default: float64 + # distributed: bool = False # this is automatically set based on execution config log_level: str = "INFO" error_table: str = "PerAtomRMSE" model: str = "MACE" @@ -36,6 +37,8 @@ class MACEConfig: radial_type: str = "bessel" num_radial_basis: int = 8 num_cutoff_basis: int = 5 + pair_repulsion: bool = False + distance_transform: Optional[str] = None interaction: str = "RealAgnosticResidualInteractionBlock" interaction_first: str = "RealAgnosticResidualInteractionBlock" max_ell: int = 3 @@ -43,19 +46,20 @@ class MACEConfig: num_interactions: int = 2 MLP_irreps: str = "16x0e" radial_MLP: str = "[64, 64, 64]" - num_channels: int = 16 # hidden_irreps is determined by num_channels and max_L + num_channels: int = 16 # default: 128 channels max_L: int = 1 gate: str = "silu" scaling: str = "rms_forces_scaling" avg_num_neighbors: Optional[float] = None compute_avg_num_neighbors: bool = True - compute_stress: bool = True + compute_stress: bool = True # default: False compute_forces: bool = True train_file: Optional[str] = None valid_file: Optional[str] = None - # model_dtype: str = "float32" valid_fraction: float = 1e-12 # never split training set test_file: Optional[str] = None + num_workers: int = 0 + pin_memory: bool = True E0s: Optional[str] = "average" energy_key: str = "energy" forces_key: str = "forces" @@ -64,21 +68,19 @@ class MACEConfig: dipole_key: str = "dipole" charges_key: str = "charges" loss: str = "weighted" - forces_weight: float = 1 + forces_weight: float = 1 # default: 100 swa_forces_weight: float = 1 - energy_weight: float = 10 + energy_weight: float = 10 # default: 1 swa_energy_weight: float = 100 - virials_weight: float = 0 - swa_virials_weight: float = 0 - stress_weight: float = 0 - swa_stress_weight: float = 0 - dipole_weight: float = 0 - swa_dipole_weight: float = 0 + virials_weight: float = 0 # default: 1 + swa_virials_weight: float = 0 # default: 10 + stress_weight: float = 0 # default: 1 + swa_stress_weight: float = 0 # default: 10 config_type_weights: str = '{"Default":1.0}' huber_delta: float = 0.01 optimizer: str = "adam" - batch_size: int = 1 - valid_batch_size: int = 8 + batch_size: int = 10 + valid_batch_size: int = 10 lr: float = 0.01 swa_lr: float = 0.001 weight_decay: float = 5e-7 @@ -87,16 +89,16 @@ class MACEConfig: lr_factor: float = 0.8 scheduler_patience: int = 50 lr_scheduler_gamma: float = 0.9993 - swa: bool = False - start_swa: int = int(1e12) # never start swa + swa: bool = True # default: False + start_swa: Optional[int] = None # never start swa ema: bool = False ema_decay: float = 0.99 - max_num_epochs: int = int(1e6) + max_num_epochs: int = 2048 patience: int = 2048 eval_interval: int = 2 keep_checkpoints: bool = False restart_latest: bool = False - save_cpu: bool = True + save_cpu: bool = True # default: False clip_grad: Optional[float] = 10 wandb: bool = False wandb_project: Optional[str] = "psiflow" @@ -113,6 +115,7 @@ def serialize(config: dict): "restart_latest", "save_cpu", "wandb", + "pair_repulsion", ] config_str = "" for key, value in config.items(): @@ -198,7 +201,6 @@ class MACE(Model): def __init__(self, **config) -> None: config = MACEConfig(**config) # validate input - assert not config.swa, "usage of SWA is currently not supported" config.save_cpu = True # assert model is saved to CPU after training config.device = "cpu" self._config = asdict(config) diff --git a/psiflow/models/mace_utils.py b/psiflow/models/mace_utils.py index a6cb4af2..315e85fc 100644 --- a/psiflow/models/mace_utils.py +++ b/psiflow/models/mace_utils.py @@ -6,7 +6,7 @@ """ -MACE utils for use in psiflow -- initially copied from mace@d520aba +MACE utils for use in psiflow -- copied from mace@dee204f The following changes were made: - use signal module to wrap tools.train() call with timeout such that @@ -18,6 +18,7 @@ """ +import argparse import ast import json import logging @@ -38,8 +39,11 @@ from mace.tools.scripts_utils import ( LRScheduler, create_error_table, + get_atomic_energies, + get_config_type_weights, get_dataset_from_xyz, ) +from torch.nn.parallel import DistributedDataParallel as DDP from torch.optim.swa_utils import SWALR, AveragedModel from torch_ema import ExponentialMovingAverage @@ -175,17 +179,9 @@ def timeout_handler(signum, frame): raise TimeoutException -def main() -> None: - # extend MACE arg parser with ability to pass initialized model; set tmpdirs - parser = tools.build_default_arg_parser() - parser.add_argument( - "--initialized_model", - help="path to initialized model", - default=None, - type=str, - ) - args = parser.parse_args() +def run(rank: int, args: argparse.Namespace, world_size: int) -> None: + # extend MACE arg parser with ability to pass initialized model; set tmpdirs args.log_dir = os.path.join(os.getcwd(), "log") args.model_dir = os.path.join(os.getcwd()) args.results_dir = os.path.join(os.getcwd(), "results") @@ -193,6 +189,18 @@ def main() -> None: args.checkpoints_dir = os.path.join(os.getcwd(), "checkpoints") tag = tools.get_tag(name=args.name, seed=args.seed) + if args.distributed: + local_rank = rank + os.environ["MASTER_ADDR"] = "localhost" + os.environ["MASTER_PORT"] = "12355" + torch.distributed.init_process_group( + backend="nccl", + rank=rank, + world_size=world_size, + ) + torch.cuda.set_device(rank) + else: + pass # Setup tools.set_seeds(args.seed) @@ -205,16 +213,11 @@ def main() -> None: device = tools.init_device(args.device) tools.set_default_dtype(args.default_dtype) - try: - config_type_weights = ast.literal_eval(args.config_type_weights) - assert isinstance(config_type_weights, dict) - except Exception as e: # pylint: disable=W0703 - logging.warning( - f"Config type weights not specified correctly ({e}), using Default" - ) - config_type_weights = {"Default": 1.0} + assert args.foundation_model is None + assert args.statistics_file is None # Data preparation + config_type_weights = get_config_type_weights(args.config_type_weights) collections, atomic_energies_dict = get_dataset_from_xyz( train_path=args.train_file, valid_path=args.valid_file, @@ -237,14 +240,36 @@ def main() -> None: # Atomic number table # yapf: disable - z_table = tools.get_atomic_number_table_from_zs( - z - for configs in (collections.train, collections.valid) - for config in configs - for z in config.atomic_numbers - ) + if args.atomic_numbers is None: + assert args.train_file.endswith(".xyz"), "Must specify atomic_numbers when using .h5 train_file input" + z_table = tools.get_atomic_number_table_from_zs( + z + for configs in (collections.train, collections.valid) + for config in configs + for z in config.atomic_numbers + ) + else: + if args.statistics_file is None: + logging.info("Using atomic numbers from command line argument") + else: + logging.info("Using atomic numbers from statistics file") + zs_list = ast.literal_eval(args.atomic_numbers) + assert isinstance(zs_list, list) + z_table = tools.get_atomic_number_table_from_zs(zs_list) # yapf: enable logging.info(z_table) + + if atomic_energies_dict is None or len(atomic_energies_dict) == 0: + if args.E0s.lower() == "foundation": + raise NotImplementedError + else: + if args.train_file.endswith(".xyz"): + atomic_energies_dict = get_atomic_energies( + args.E0s, collections.train, z_table + ) + else: + atomic_energies_dict = get_atomic_energies(args.E0s, None, z_table) + if args.model == "AtomicDipolesMACE": atomic_energies = None dipole_only = True @@ -264,30 +289,6 @@ def main() -> None: else: compute_energy = True compute_dipole = False - if atomic_energies_dict is None or len(atomic_energies_dict) == 0: - if args.E0s is not None: - logging.info( - "Atomic Energies not in training file, using command line argument E0s" - ) - if args.E0s.lower() == "average": - logging.info( - "Computing average Atomic Energies using least squares regression" - ) - atomic_energies_dict = data.compute_average_E0s( - collections.train, z_table - ) - else: - try: - atomic_energies_dict = ast.literal_eval(args.E0s) - assert isinstance(atomic_energies_dict, dict) - except Exception as e: - raise RuntimeError( - f"E0s specified invalidly, error {e} occurred" - ) from e - else: - raise RuntimeError( - "E0s not found in training file and not specified in command line" - ) atomic_energies: np.ndarray = np.array( [atomic_energies_dict[z] for z in z_table.zs] ) @@ -295,23 +296,51 @@ def main() -> None: args.batch_size = min(len(collections.train), args.batch_size) print("actual batch size: {}".format(args.batch_size)) + train_set = [ + data.AtomicData.from_config(config, z_table=z_table, cutoff=args.r_max) + for config in collections.train + ] + valid_set = [ + data.AtomicData.from_config(config, z_table=z_table, cutoff=args.r_max) + for config in collections.valid + ] + train_sampler, valid_sampler = None, None + if args.distributed: + train_sampler = torch.utils.data.distributed.DistributedSampler( + train_set, + num_replicas=world_size, + rank=rank, + shuffle=True, + drop_last=True, + seed=args.seed, + ) + valid_sampler = torch.utils.data.distributed.DistributedSampler( + valid_set, + num_replicas=world_size, + rank=rank, + shuffle=True, + drop_last=True, + seed=args.seed, + ) train_loader = torch_geometric.dataloader.DataLoader( - dataset=[ - data.AtomicData.from_config(config, z_table=z_table, cutoff=args.r_max) - for config in collections.train - ], + dataset=train_set, batch_size=args.batch_size, - shuffle=True, - drop_last=True, + sampler=train_sampler, + shuffle=(train_sampler is None), + drop_last=(train_sampler is None), + pin_memory=args.pin_memory, + num_workers=args.num_workers, + generator=torch.Generator().manual_seed(args.seed), ) valid_loader = torch_geometric.dataloader.DataLoader( - dataset=[ - data.AtomicData.from_config(config, z_table=z_table, cutoff=args.r_max) - for config in collections.valid - ], + dataset=valid_set, batch_size=args.valid_batch_size, + sampler=valid_sampler, shuffle=False, drop_last=False, + pin_memory=args.pin_memory, + num_workers=args.num_workers, + generator=torch.Generator().manual_seed(args.seed), ) loss_fn: torch.nn.Module @@ -360,7 +389,17 @@ def main() -> None: logging.info(loss_fn) if args.compute_avg_num_neighbors: - args.avg_num_neighbors = modules.compute_avg_num_neighbors(train_loader) + avg_num_neighbors = modules.compute_avg_num_neighbors(train_loader) + if args.distributed: + num_graphs = torch.tensor(len(train_loader.dataset)).to(device) + num_neighbors = num_graphs * torch.tensor(avg_num_neighbors).to(device) + torch.distributed.all_reduce(num_graphs, op=torch.distributed.ReduceOp.SUM) + torch.distributed.all_reduce( + num_neighbors, op=torch.distributed.ReduceOp.SUM + ) + args.avg_num_neighbors = (num_neighbors / num_graphs).item() + else: + args.avg_num_neighbors = avg_num_neighbors logging.info(f"Average number of neighbors: {args.avg_num_neighbors}") # Selecting outputs @@ -565,12 +604,17 @@ def main() -> None: assert dipole_only is False, "swa for dipole fitting not implemented" swas.append(True) if args.start_swa is None: - args.start_swa = ( - args.max_num_epochs // 4 * 3 - ) # if not set start swa at 75% of training + args.start_swa = max(1, args.max_num_epochs // 4 * 3) + else: + if args.start_swa > args.max_num_epochs: + logging.info( + f"Start swa must be less than max_num_epochs, got {args.start_swa} > {args.max_num_epochs}" + ) + args.start_swa = max(1, args.max_num_epochs // 4 * 3) + logging.info(f"Setting start swa to {args.start_swa}") if args.loss == "forces_only": - logging.info("Can not select swa with forces only loss.") - elif args.loss == "virials": + raise ValueError("Can not select swa with forces only loss.") + if args.loss == "virials": loss_fn_energy = modules.WeightedEnergyForcesVirialsLoss( energy_weight=args.swa_energy_weight, forces_weight=args.swa_forces_weight, @@ -638,6 +682,9 @@ def main() -> None: ema: Optional[ExponentialMovingAverage] = None if args.ema: ema = ExponentialMovingAverage(model.parameters(), decay=args.ema_decay) + else: + for group in optimizer.param_groups: + group["lr"] = args.lr logging.info(model) logging.info(f"Number of parameters: {tools.count_parameters(model)}") @@ -660,6 +707,11 @@ def main() -> None: ) wandb.run.summary["params"] = args_dict_json + if args.distributed: + distributed_model = DDP(model, device_ids=[local_rank]) + else: + distributed_model = None + try: tools.train( model=model, @@ -674,6 +726,7 @@ def main() -> None: max_num_epochs=args.max_num_epochs, logger=logger, patience=args.patience, + save_all_checkpoints=args.save_all_checkpoints, output_args=output_args, device=device, swa=swa, @@ -681,6 +734,10 @@ def main() -> None: max_grad_norm=args.clip_grad, log_errors=args.error_table, log_wandb=args.wandb, + distributed=args.distributed, + distributed_model=distributed_model, + train_sampler=train_sampler, + rank=rank, ) except TimeoutException: logging.info("received SIGTERM!") @@ -689,10 +746,10 @@ def main() -> None: # Evaluation on test datasets logging.info("Computing metrics for training, validation, and test sets") - all_collections = [ - ("train", collections.train), - ("valid", collections.valid), - ] + collections.tests + all_data_loaders = { + "train": train_loader, + "valid": valid_loader, + } for swa_eval in swas: epoch = checkpoint_handler.load_latest( @@ -701,42 +758,60 @@ def main() -> None: device=device, ) model.to(device) + if args.distributed: + distributed_model = DDP(model, device_ids=[local_rank]) + model_to_evaluate = model if not args.distributed else distributed_model logging.info(f"Loaded model from epoch {epoch}") for param in model.parameters(): param.requires_grad = False table = create_error_table( table_type=args.error_table, - all_collections=all_collections, - z_table=z_table, - r_max=args.r_max, - valid_batch_size=args.valid_batch_size, - model=model, + all_data_loaders=all_data_loaders, + model=model_to_evaluate, loss_fn=loss_fn, output_args=output_args, log_wandb=args.wandb, device=device, + distributed=args.distributed, ) logging.info("\n" + str(table)) - # Save entire model - if swa_eval: - model_path = Path(args.checkpoints_dir) / (tag + "_swa.model") - else: - model_path = Path(args.checkpoints_dir) / (tag + ".model") - logging.info(f"Saving model to {model_path}") - if args.save_cpu: - model = model.to("cpu") - torch.save(model, "model.pth") - - if swa_eval: - torch.save(model, Path(args.model_dir) / (args.name + "_swa.model")) - else: - torch.save(model, Path(args.model_dir) / (args.name + ".model")) + if rank == 0: + # Save entire model + # if swa_eval: + # model_path = Path.cwd() / 'model_swa.pth' + # else: + model_path = Path.cwd() / "model.pth" + logging.info("swa: {}".format(swa_eval)) + logging.info(f"Saving model to {model_path}") + if args.save_cpu: + model = model.to("cpu") + torch.save(model, model_path) + + if args.distributed: + torch.distributed.barrier() logging.info("Done") + if args.distributed: + torch.distributed.destroy_process_group() if __name__ == "__main__": signal.signal(signal.SIGTERM, timeout_handler) - main() + # main() + parser = tools.build_default_arg_parser() + parser.add_argument( + "--initialized_model", + help="path to initialized model", + default=None, + type=str, + ) + args = parser.parse_args() + if args.distributed: + world_size = torch.cuda.device_count() + import torch.multiprocessing as mp + + mp.spawn(run, args=(args, world_size), nprocs=world_size) + else: + run(0, args, 1) diff --git a/psiflow/sampling/order.py b/psiflow/sampling/order.py index 0a7126b7..b1c1f924 100644 --- a/psiflow/sampling/order.py +++ b/psiflow/sampling/order.py @@ -1,7 +1,7 @@ from __future__ import annotations # necessary for type-guarding class methods from functools import partial -from typing import Union, Optional +from typing import Optional, Union import typeguard from ase.units import kJ, mol diff --git a/psiflow/tools/server.py b/psiflow/tools/server.py index 1e62c125..b6bf4a3e 100644 --- a/psiflow/tools/server.py +++ b/psiflow/tools/server.py @@ -420,7 +420,7 @@ def cleanup(args): else: try: cleanup(args) - except BaseException: + except BaseException: # noqa: B036 print("i-PI cleanup failed!") print("files in directory:") for filepath in Path.cwd().glob("*"): diff --git a/tests/test_hamiltonian.py b/tests/test_hamiltonian.py index a49ba539..8ca924f4 100644 --- a/tests/test_hamiltonian.py +++ b/tests/test_hamiltonian.py @@ -211,7 +211,9 @@ def test_plumed_evaluate(dataset, tmp_path): UNITS LENGTH=A ENERGY=kj/mol TIME=fs CV: DISTANCE ATOMS=1,2 NOPBC RESTRAINT ARG=CV AT={center} KAPPA={kappa} -""".format(center=center, kappa=kappa / (kJ / mol)) +""".format( + center=center, kappa=kappa / (kJ / mol) + ) calculator = PlumedCalculator(plumed_input) calculator.calculate(atoms) energy = calculator.results["energy"] @@ -249,7 +251,9 @@ def test_plumed_evaluate(dataset, tmp_path): UNITS LENGTH=A ENERGY=kj/mol TIME=fs CV: DISTANCE ATOMS=1,2 NOPBC METAD ARG=CV PACE=1 SIGMA=3 HEIGHT=342 FILE={} -""".format(path_hills) +""".format( + path_hills + ) calculator = PlumedCalculator(plumed_input, path_hills) for _ in range(30): @@ -277,7 +281,9 @@ def test_plumed_hamiltonian(dataset, tmp_path): UNITS LENGTH=A ENERGY=kj/mol TIME=fs CV: VOLUME RESTRAINT ARG=CV AT={center} KAPPA={kappa} -""".format(center=center, kappa=kappa / (kJ / mol)) +""".format( + center=center, kappa=kappa / (kJ / mol) + ) hamiltonian = PlumedHamiltonian(plumed_input) evaluated = hamiltonian.evaluate(dataset).geometries().result() for geometry in evaluated: @@ -306,7 +312,9 @@ def test_plumed_hamiltonian(dataset, tmp_path): UNITS LENGTH=A ENERGY=kj/mol TIME=fs CV: DISTANCE ATOMS=1,2 NOPBC METAD ARG=CV PACE=1 SIGMA=3 HEIGHT=342 FILE={} -""".format(data_future.filepath) +""".format( + data_future.filepath + ) hamiltonian = PlumedHamiltonian(plumed_input, data_future) data = hamiltonian.evaluate(dataset) assert np.all(data.get("energy")[0].result() > 0) @@ -385,7 +393,9 @@ def test_serialization(dataset, tmp_path, mace_model): UNITS LENGTH=A ENERGY=kj/mol TIME=fs CV: DISTANCE ATOMS=1,2 NOPBC METAD ARG=CV PACE=1 SIGMA=3 HEIGHT=342 FILE={} -""".format(data_future.filepath) +""".format( + data_future.filepath + ) hamiltonian = PlumedHamiltonian(plumed_input, data_future) evaluated = hamiltonian.evaluate(dataset[:3]) @@ -460,7 +470,9 @@ def test_hamiltonian_serialize(dataset): UNITS LENGTH=A ENERGY=kj/mol TIME=fs CV: VOLUME RESTRAINT ARG=CV AT={center} KAPPA={kappa} -""".format(center=center, kappa=kappa / (kJ / mol)) +""".format( + center=center, kappa=kappa / (kJ / mol) + ) plumed = PlumedHamiltonian(plumed_input) data = json.loads(psiflow.serialize(einstein).result()) assert "EinsteinCrystal" in data diff --git a/tests/test_models.py b/tests/test_models.py index fb0b4ba1..6935aaca 100644 --- a/tests/test_models.py +++ b/tests/test_models.py @@ -27,9 +27,9 @@ def test_mace_init(mace_config, dataset): assert value == _config_[key] config = copy.deepcopy(mace_config) - config[ - "batch_size" - ] = 100000 # bigger than ntrain --> should get reduced internally + config["batch_size"] = ( + 100000 # bigger than ntrain --> should get reduced internally + ) model = MACE(**config) model.seed = 1 model.initialize(dataset[:3]) @@ -83,6 +83,7 @@ def test_mace_train(gpu, mace_config, dataset, tmp_path): # it with the manually computed value training = dataset[:-5] validation = dataset[-5:] + mace_config["start_swa"] = 1000 model = MACE(**mace_config) model.initialize(training) hamiltonian0 = model.create_hamiltonian()