@@ -1180,8 +1180,23 @@ def calculate(
1180
1180
1181
1181
self .args_dict ["batch_size" ] = 1
1182
1182
self .args_dict ["only_inference" ] = 1
1183
+ cutoff = (
1184
+ self .potential .model .model_args ["cutoff" ]
1185
+ if self .potential .model_name == "m3gnet"
1186
+ else 5.0
1187
+ )
1188
+ threebody_cutoff = (
1189
+ self .potential .model .model_args ["threebody_cutoff" ]
1190
+ if self .potential .model_name == "m3gnet"
1191
+ else 4.0
1192
+ )
1193
+
1183
1194
dataloader = build_dataloader (
1184
- [atoms ], model_type = self .potential .model_name , ** self .args_dict
1195
+ [atoms ],
1196
+ model_type = self .potential .model_name ,
1197
+ cutoff = cutoff ,
1198
+ threebody_cutoff = threebody_cutoff ,
1199
+ ** self .args_dict ,
1185
1200
)
1186
1201
for graph_batch in dataloader :
1187
1202
# Resemble input dictionary
@@ -1323,8 +1338,23 @@ def calculate(
1323
1338
1324
1339
self .args_dict ["batch_size" ] = 1
1325
1340
self .args_dict ["only_inference" ] = 1
1341
+ cutoff = (
1342
+ self .potential .model .model_args ["cutoff" ]
1343
+ if self .potential .model_name == "m3gnet"
1344
+ else 5.0
1345
+ )
1346
+ threebody_cutoff = (
1347
+ self .potential .model .model_args ["threebody_cutoff" ]
1348
+ if self .potential .model_name == "m3gnet"
1349
+ else 4.0
1350
+ )
1351
+
1326
1352
dataloader = build_dataloader (
1327
- [atoms ], model_type = self .potential .model_name , ** self .args_dict
1353
+ [atoms ],
1354
+ model_type = self .potential .model_name ,
1355
+ cutoff = cutoff ,
1356
+ threebody_cutoff = threebody_cutoff ,
1357
+ ** self .args_dict ,
1328
1358
)
1329
1359
for graph_batch in dataloader :
1330
1360
# Resemble input dictionary
0 commit comments