@@ -1180,8 +1180,23 @@ def calculate(
11801180
11811181 self .args_dict ["batch_size" ] = 1
11821182 self .args_dict ["only_inference" ] = 1
1183+ cutoff = (
1184+ self .potential .model .model_args ["cutoff" ]
1185+ if self .potential .model_name == "m3gnet"
1186+ else 5.0
1187+ )
1188+ threebody_cutoff = (
1189+ self .potential .model .model_args ["threebody_cutoff" ]
1190+ if self .potential .model_name == "m3gnet"
1191+ else 4.0
1192+ )
1193+
11831194 dataloader = build_dataloader (
1184- [atoms ], model_type = self .potential .model_name , ** self .args_dict
1195+ [atoms ],
1196+ model_type = self .potential .model_name ,
1197+ cutoff = cutoff ,
1198+ threebody_cutoff = threebody_cutoff ,
1199+ ** self .args_dict ,
11851200 )
11861201 for graph_batch in dataloader :
11871202 # Resemble input dictionary
@@ -1323,8 +1338,23 @@ def calculate(
13231338
13241339 self .args_dict ["batch_size" ] = 1
13251340 self .args_dict ["only_inference" ] = 1
1341+ cutoff = (
1342+ self .potential .model .model_args ["cutoff" ]
1343+ if self .potential .model_name == "m3gnet"
1344+ else 5.0
1345+ )
1346+ threebody_cutoff = (
1347+ self .potential .model .model_args ["threebody_cutoff" ]
1348+ if self .potential .model_name == "m3gnet"
1349+ else 4.0
1350+ )
1351+
13261352 dataloader = build_dataloader (
1327- [atoms ], model_type = self .potential .model_name , ** self .args_dict
1353+ [atoms ],
1354+ model_type = self .potential .model_name ,
1355+ cutoff = cutoff ,
1356+ threebody_cutoff = threebody_cutoff ,
1357+ ** self .args_dict ,
13281358 )
13291359 for graph_batch in dataloader :
13301360 # Resemble input dictionary
0 commit comments