Skip to content

Commit 02efa73

Browse files
ZeroKnightingXixian
and
Xixian
authored
fix: cutoff auto detect in calculate (#62)
Co-authored-by: Xixian <[email protected]>
1 parent afe58a0 commit 02efa73

File tree

1 file changed

+32
-2
lines changed

1 file changed

+32
-2
lines changed

src/mattersim/forcefield/potential.py

+32-2
Original file line numberDiff line numberDiff line change
@@ -1180,8 +1180,23 @@ def calculate(
11801180

11811181
self.args_dict["batch_size"] = 1
11821182
self.args_dict["only_inference"] = 1
1183+
cutoff = (
1184+
self.potential.model.model_args["cutoff"]
1185+
if self.potential.model_name == "m3gnet"
1186+
else 5.0
1187+
)
1188+
threebody_cutoff = (
1189+
self.potential.model.model_args["threebody_cutoff"]
1190+
if self.potential.model_name == "m3gnet"
1191+
else 4.0
1192+
)
1193+
11831194
dataloader = build_dataloader(
1184-
[atoms], model_type=self.potential.model_name, **self.args_dict
1195+
[atoms],
1196+
model_type=self.potential.model_name,
1197+
cutoff=cutoff,
1198+
threebody_cutoff=threebody_cutoff,
1199+
**self.args_dict,
11851200
)
11861201
for graph_batch in dataloader:
11871202
# Resemble input dictionary
@@ -1323,8 +1338,23 @@ def calculate(
13231338

13241339
self.args_dict["batch_size"] = 1
13251340
self.args_dict["only_inference"] = 1
1341+
cutoff = (
1342+
self.potential.model.model_args["cutoff"]
1343+
if self.potential.model_name == "m3gnet"
1344+
else 5.0
1345+
)
1346+
threebody_cutoff = (
1347+
self.potential.model.model_args["threebody_cutoff"]
1348+
if self.potential.model_name == "m3gnet"
1349+
else 4.0
1350+
)
1351+
13261352
dataloader = build_dataloader(
1327-
[atoms], model_type=self.potential.model_name, **self.args_dict
1353+
[atoms],
1354+
model_type=self.potential.model_name,
1355+
cutoff=cutoff,
1356+
threebody_cutoff=threebody_cutoff,
1357+
**self.args_dict,
13281358
)
13291359
for graph_batch in dataloader:
13301360
# Resemble input dictionary

0 commit comments

Comments
 (0)