Skip to content

Commit

Permalink
linting + update to latest mace version
Browse files Browse the repository at this point in the history
  • Loading branch information
svandenhaute committed Jun 12, 2024
1 parent dfb71bd commit 1dca815
Show file tree
Hide file tree
Showing 10 changed files with 212 additions and 124 deletions.
2 changes: 1 addition & 1 deletion psiflow/__init__.py
Original file line number Diff line number Diff line change
@@ -1,11 +1,11 @@
from .config import setup_slurm_config # noqa: F401
from .execution import ExecutionContextLoader
from .serialization import ( # noqa: F401
_DataFuture,
deserialize,
serializable,
serialize,
)
from .config import setup_slurm_config # noqa: F401

load = ExecutionContextLoader.load
context = ExecutionContextLoader.context
Expand Down
1 change: 0 additions & 1 deletion psiflow/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,6 @@ def get_partitions():
partition_info[partition_name] = partition_dict

scontrol_output = subprocess.check_output(["scontrol", "show", "node"], text=True)
node_info = {}

nodes = scontrol_output.strip().split("\n\n")
for node in nodes:
Expand Down
3 changes: 1 addition & 2 deletions psiflow/execution.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,10 +17,9 @@
from parsl.addresses import address_by_hostname
from parsl.config import Config
from parsl.data_provider.files import File
from parsl.executors import (
from parsl.executors import ( # WorkQueueExecutor,
HighThroughputExecutor,
ThreadPoolExecutor,
# WorkQueueExecutor,
)
from parsl.executors.base import ParslExecutor
from parsl.launchers import SimpleLauncher, WrappedLauncher
Expand Down
2 changes: 1 addition & 1 deletion psiflow/hamiltonians/hamiltonian.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,8 +6,8 @@
import typeguard
from parsl.app.app import python_app
from parsl.app.futures import DataFuture
from parsl.dataflow.futures import AppFuture
from parsl.data_provider.files import File
from parsl.dataflow.futures import AppFuture

import psiflow
from psiflow.data import Dataset, batch_apply
Expand Down
40 changes: 21 additions & 19 deletions psiflow/models/_mace.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,34 +28,38 @@ class MACEConfig:
results_dir: str = ""
downloads_dir: str = ""
device: str = "cuda"
default_dtype: str = "float32"
default_dtype: str = "float32" # default: float64
# distributed: bool = False # this is automatically set based on execution config
log_level: str = "INFO"
error_table: str = "PerAtomRMSE"
model: str = "MACE"
r_max: float = 5.0
radial_type: str = "bessel"
num_radial_basis: int = 8
num_cutoff_basis: int = 5
pair_repulsion: bool = False
distance_transform: Optional[str] = None
interaction: str = "RealAgnosticResidualInteractionBlock"
interaction_first: str = "RealAgnosticResidualInteractionBlock"
max_ell: int = 3
correlation: int = 3
num_interactions: int = 2
MLP_irreps: str = "16x0e"
radial_MLP: str = "[64, 64, 64]"
num_channels: int = 16 # hidden_irreps is determined by num_channels and max_L
num_channels: int = 16 # default: 128 channels
max_L: int = 1
gate: str = "silu"
scaling: str = "rms_forces_scaling"
avg_num_neighbors: Optional[float] = None
compute_avg_num_neighbors: bool = True
compute_stress: bool = True
compute_stress: bool = True # default: False
compute_forces: bool = True
train_file: Optional[str] = None
valid_file: Optional[str] = None
# model_dtype: str = "float32"
valid_fraction: float = 1e-12 # never split training set
test_file: Optional[str] = None
num_workers: int = 0
pin_memory: bool = True
E0s: Optional[str] = "average"
energy_key: str = "energy"
forces_key: str = "forces"
Expand All @@ -64,21 +68,19 @@ class MACEConfig:
dipole_key: str = "dipole"
charges_key: str = "charges"
loss: str = "weighted"
forces_weight: float = 1
forces_weight: float = 1 # default: 100
swa_forces_weight: float = 1
energy_weight: float = 10
energy_weight: float = 10 # default: 1
swa_energy_weight: float = 100
virials_weight: float = 0
swa_virials_weight: float = 0
stress_weight: float = 0
swa_stress_weight: float = 0
dipole_weight: float = 0
swa_dipole_weight: float = 0
virials_weight: float = 0 # default: 1
swa_virials_weight: float = 0 # default: 10
stress_weight: float = 0 # default: 1
swa_stress_weight: float = 0 # default: 10
config_type_weights: str = '{"Default":1.0}'
huber_delta: float = 0.01
optimizer: str = "adam"
batch_size: int = 1
valid_batch_size: int = 8
batch_size: int = 10
valid_batch_size: int = 10
lr: float = 0.01
swa_lr: float = 0.001
weight_decay: float = 5e-7
Expand All @@ -87,16 +89,16 @@ class MACEConfig:
lr_factor: float = 0.8
scheduler_patience: int = 50
lr_scheduler_gamma: float = 0.9993
swa: bool = False
start_swa: int = int(1e12) # never start swa
swa: bool = True # default: False
start_swa: Optional[int] = None # never start swa
ema: bool = False
ema_decay: float = 0.99
max_num_epochs: int = int(1e6)
max_num_epochs: int = 2048
patience: int = 2048
eval_interval: int = 2
keep_checkpoints: bool = False
restart_latest: bool = False
save_cpu: bool = True
save_cpu: bool = True # default: False
clip_grad: Optional[float] = 10
wandb: bool = False
wandb_project: Optional[str] = "psiflow"
Expand All @@ -113,6 +115,7 @@ def serialize(config: dict):
"restart_latest",
"save_cpu",
"wandb",
"pair_repulsion",
]
config_str = ""
for key, value in config.items():
Expand Down Expand Up @@ -198,7 +201,6 @@ class MACE(Model):

def __init__(self, **config) -> None:
config = MACEConfig(**config) # validate input
assert not config.swa, "usage of SWA is currently not supported"
config.save_cpu = True # assert model is saved to CPU after training
config.device = "cpu"
self._config = asdict(config)
Expand Down
Loading

0 comments on commit 1dca815

Please sign in to comment.