
Commit 62a3463

feat: Add finetune method for MatterSim (#68)
* feat: Add finetune method for MatterSim
* Update pyproject.toml
1 parent 79e48a1 commit 62a3463

5 files changed: +371 -21 lines changed

pyproject.toml

+1
@@ -49,6 +49,7 @@ dependencies = [
     "torchaudio>=2.2.0",
     "torchmetrics>=0.10.0",
     "torchvision>=0.17.0",
+    "wandb",
 ]

 [project.optional-dependencies]

script/finetune_mattersim.py

+248
@@ -0,0 +1,248 @@
# -*- coding: utf-8 -*-
import argparse
import os
import pickle as pkl
import random

import numpy as np
import torch
import torch.distributed
import wandb
from ase.units import GPa

from mattersim.datasets.utils.build import build_dataloader
from mattersim.forcefield.m3gnet.scaling import AtomScaling
from mattersim.forcefield.potential import Potential
from mattersim.utils.atoms_utils import AtomsAdaptor
from mattersim.utils.logger_utils import get_logger

logger = get_logger()
torch.distributed.init_process_group(backend="nccl")
local_rank = int(os.environ["LOCAL_RANK"])


def main(args):
    args_dict = vars(args)
    if args.wandb and local_rank == 0:
        wandb_api_key = (
            args.wandb_api_key
            if args.wandb_api_key is not None
            else os.getenv("WANDB_API_KEY")
        )
        wandb.login(key=wandb_api_key)
        wandb.init(
            project=args.wandb_project,
            name=args.run_name,
            config=args,
            # id=args.run_name,
            # resume="allow",
        )

    if args.wandb:
        args_dict["wandb"] = wandb

    torch.distributed.barrier()

    # set random seed
    random.seed(args.seed)
    np.random.seed(args.seed)
    torch.manual_seed(args.seed)

    torch.cuda.set_device(local_rank)

    if args.train_data_path.endswith(".pkl"):
        with open(args.train_data_path, "rb") as f:
            atoms_train = pkl.load(f)
    else:
        atoms_train = AtomsAdaptor.from_file(filename=args.train_data_path)
    energies = []
    forces = [] if args.include_forces else None
    stresses = [] if args.include_stresses else None
    logger.info("Processing training datasets...")
    for atoms in atoms_train:
        energies.append(atoms.get_potential_energy())
        if args.include_forces:
            forces.append(atoms.get_forces())
        if args.include_stresses:
            stresses.append(atoms.get_stress(voigt=False) / GPa)  # convert to GPa

    dataloader = build_dataloader(
        atoms_train,
        energies,
        forces,
        stresses,
        shuffle=True,
        pin_memory=True,
        is_distributed=True,
        **args_dict,
    )

    device = "cuda" if torch.cuda.is_available() else "cpu"
    # build energy normalization module
    if args.re_normalize:
        scale = AtomScaling(
            atoms=atoms_train,
            total_energy=energies,
            forces=forces,
            verbose=True,
            **args_dict,
        ).to(device)

    if args.valid_data_path is not None:
        if args.valid_data_path.endswith(".pkl"):
            with open(args.valid_data_path, "rb") as f:
                atoms_val = pkl.load(f)
        else:
            atoms_val = AtomsAdaptor.from_file(filename=args.valid_data_path)
        energies = []
        forces = [] if args.include_forces else None
        stresses = [] if args.include_stresses else None
        logger.info("Processing validation datasets...")
        for atoms in atoms_val:
            energies.append(atoms.get_potential_energy())
            if args.include_forces:
                forces.append(atoms.get_forces())
            if args.include_stresses:
                stresses.append(atoms.get_stress(voigt=False) / GPa)  # convert to GPa
        val_dataloader = build_dataloader(
            atoms_val,
            energies,
            forces,
            stresses,
            pin_memory=True,
            is_distributed=True,
            **args_dict,
        )
    else:
        val_dataloader = None

    potential = Potential.from_checkpoint(
        load_path=args.load_model_path,
        load_training_state=False,
        **args_dict,
    )

    if args.re_normalize:
        potential.model.set_normalizer(scale)

    potential.model = torch.nn.parallel.DistributedDataParallel(potential.model)
    torch.distributed.barrier()

    potential.train_model(
        dataloader,
        val_dataloader,
        loss=torch.nn.HuberLoss(delta=0.01),
        is_distributed=True,
        **args_dict,
    )

    if local_rank == 0 and args.save_checkpoint:
        wandb.save(os.path.join(args.save_path, "best_model.pth"))


if __name__ == "__main__":
    # Some important arguments
    parser = argparse.ArgumentParser()

    # path parameters
    parser.add_argument(
        "--run_name", type=str, default="example", help="name of the run"
    )
    parser.add_argument(
        "--train_data_path", type=str, default="./sample.xyz", help="train data path"
    )
    parser.add_argument(
        "--valid_data_path", type=str, default=None, help="valid data path"
    )
    parser.add_argument(
        "--load_model_path",
        type=str,
        default="mattersim-v1.0.0-1m",
        help="path to load the model",
    )
    parser.add_argument(
        "--save_path", type=str, default="./results", help="path to save the model"
    )
    parser.add_argument(
        "--save_checkpoint",
        type=bool,
        default=False,
        action=argparse.BooleanOptionalAction,
    )
    parser.add_argument(
        "--ckpt_interval",
        type=int,
        default=10,
        help="save checkpoint every ckpt_interval epochs",
    )
    parser.add_argument("--device", type=str, default="cuda")

    # model parameters
    parser.add_argument("--cutoff", type=float, default=5.0, help="cutoff radius")
    parser.add_argument(
        "--threebody_cutoff",
        type=float,
        default=4.0,
        help="cutoff radius for three-body term, which should be smaller than cutoff (two-body)",  # noqa: E501
    )

    # training parameters
    parser.add_argument("--epochs", type=int, default=1000, help="number of epochs")
    parser.add_argument("--batch_size", type=int, default=16)
    parser.add_argument("--lr", type=float, default=2e-4)
    parser.add_argument(
        "--step_size",
        type=int,
        default=10,
        help="step epoch for learning rate scheduler",
    )
    parser.add_argument(
        "--include_forces",
        type=bool,
        default=True,
        action=argparse.BooleanOptionalAction,
    )
    parser.add_argument(
        "--include_stresses",
        type=bool,
        default=False,
        action=argparse.BooleanOptionalAction,
    )
    parser.add_argument("--force_loss_ratio", type=float, default=1.0)
    parser.add_argument("--stress_loss_ratio", type=float, default=0.1)
    parser.add_argument("--early_stop_patience", type=int, default=10)
    parser.add_argument("--seed", type=int, default=42)

    # scaling parameters
    parser.add_argument(
        "--re_normalize",
        type=bool,
        default=False,
        action=argparse.BooleanOptionalAction,
        help="re-normalize the energy and forces according to the new data",
    )
    parser.add_argument("--scale_key", type=str, default="per_species_forces_rms")
    parser.add_argument(
        "--shift_key", type=str, default="per_species_energy_mean_linear_reg"
    )
    parser.add_argument("--init_scale", type=float, default=None)
    parser.add_argument("--init_shift", type=float, default=None)
    parser.add_argument(
        "--trainable_scale",
        type=bool,
        default=False,
        action=argparse.BooleanOptionalAction,
    )
    parser.add_argument(
        "--trainable_shift",
        type=bool,
        default=False,
        action=argparse.BooleanOptionalAction,
    )

    # wandb parameters
    parser.add_argument("--wandb", action="store_true")
    parser.add_argument("--wandb_api_key", type=str, default=None)
    parser.add_argument("--wandb_project", type=str, default="wandb_test")
    args = parser.parse_args()
    main(args)
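Note on inputs: --train_data_path and --valid_data_path accept either a pickled list of ASE Atoms objects (*.pkl) or any structure file that AtomsAdaptor.from_file can parse (the default is ./sample.xyz). Below is a minimal sketch of building such a pickle, assuming each Atoms carries its reference labels through an ASE SinglePointCalculator so that get_potential_energy() and get_forces() work; the Cu structure and label values are placeholders, not real DFT data:

# Sketch: prepare a pickled dataset for --train_data_path.
# Assumption: labels are attached via SinglePointCalculator;
# the energy/forces values below are placeholders.
import pickle as pkl

from ase.build import bulk
from ase.calculators.singlepoint import SinglePointCalculator

atoms = bulk("Cu", "fcc", a=3.6)
atoms.calc = SinglePointCalculator(
    atoms, energy=-3.72, forces=[[0.0, 0.0, 0.0]]
)

with open("train.pkl", "wb") as f:
    pkl.dump([atoms], f)  # the script iterates over a list of Atoms

Since the script calls torch.distributed.init_process_group(backend="nccl") at import time and reads LOCAL_RANK from the environment, it is presumably meant to be launched through a distributed launcher such as torchrun rather than invoked as plain python.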

script/vasprun_to_xyz.py

+63
@@ -0,0 +1,63 @@
# -*- coding: utf-8 -*-
import os
import random

from ase.io import write

from mattersim.utils.atoms_utils import AtomsAdaptor

vasp_files = [
    "work/data/H/vasp/vasprun.xml",
    "work/data/H/vasp_2/vasprun.xml",
    "work/data/H/vasp_3/vasprun.xml",
    "work/data/H/vasp_4/vasprun.xml",
    "work/data/H/vasp_5/vasprun.xml",
    "work/data/H/vasp_6/vasprun.xml",
    "work/data/H/vasp_7/vasprun.xml",
    "work/data/H/vasp_8/vasprun.xml",
    "work/data/H/vasp_9/vasprun.xml",
    "work/data/H/vasp_10/vasprun.xml",
]
train_ratio = 0.8
validation_ratio = 0.1
test_ratio = 0.1

save_dir = "./xyz_files"
os.makedirs(save_dir, exist_ok=True)


def main():
    atoms_train = []
    atoms_validation = []
    atoms_test = []

    random.seed(42)

    for vasp_file in vasp_files:
        atoms_list = AtomsAdaptor.from_file(filename=vasp_file)
        random.shuffle(atoms_list)
        num_structures = len(atoms_list)  # each entry is one Atoms snapshot
        num_train = int(num_structures * train_ratio)
        num_validation = int(num_structures * validation_ratio)

        atoms_train.extend(atoms_list[:num_train])
        atoms_validation.extend(atoms_list[num_train : num_train + num_validation])
        atoms_test.extend(atoms_list[num_train + num_validation :])

    print(
        f"Total number of structures: {len(atoms_train) + len(atoms_validation) + len(atoms_test)}"  # noqa: E501
    )

    print(f"Number of structures in the training set: {len(atoms_train)}")
    print(f"Number of structures in the validation set: {len(atoms_validation)}")
    print(f"Number of structures in the test set: {len(atoms_test)}")

    # Save the training, validation, and test datasets to xyz files
    write(f"{save_dir}/train.xyz", atoms_train)
    write(f"{save_dir}/valid.xyz", atoms_validation)
    write(f"{save_dir}/test.xyz", atoms_test)


if __name__ == "__main__":
    main()
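The generated splits line up with what finetune_mattersim.py expects: train.xyz and valid.xyz can be passed as --train_data_path and --valid_data_path, since both go through AtomsAdaptor.from_file. A quick sanity check that the files parse back into lists of Atoms (a sketch, assuming the default save_dir above):

# Sanity check: read the splits back with ASE.
from ase.io import read

train = read("./xyz_files/train.xyz", index=":")  # ":" returns all frames
valid = read("./xyz_files/valid.xyz", index=":")
test = read("./xyz_files/test.xyz", index=":")
print(len(train), len(valid), len(test))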

0 commit comments
