@@ -965,27 +965,39 @@ def load(
965
965
if model_name .lower () != "m3gnet" :
966
966
raise NotImplementedError
967
967
968
- current_dir = os .path .dirname (__file__ )
968
+ checkpoint_folder = os .path .expanduser ("~/.local/mattersim/pretrained_models" )
969
+ os .makedirs (checkpoint_folder , exist_ok = True )
969
970
if (
970
971
load_path is None
971
972
or load_path .lower () == "mattersim-v1.0.0-1m.pth"
972
973
or load_path .lower () == "mattersim-v1.0.0-1m"
973
974
):
974
- load_path = os .path .join (
975
- current_dir , ".." , "pretrained_models/mattersim-v1.0.0-1M.pth"
976
- )
975
+ load_path = os .path .join (checkpoint_folder , "mattersim-v1.0.0-1M.pth" )
976
+ if not os .path .exists (load_path ):
977
+ logger .info (
978
+ "The pre-trained model is not found locally, "
979
+ "attempting to download it from the server."
980
+ )
981
+ download_checkpoint (
982
+ "mattersim-v1.0.0-1M.pth" , save_folder = checkpoint_folder
983
+ )
977
984
logger .info (f"Loading the pre-trained { os .path .basename (load_path )} model" )
978
985
elif (
979
986
load_path .lower () == "mattersim-v1.0.0-5m.pth"
980
987
or load_path .lower () == "mattersim-v1.0.0-5m"
981
988
):
982
- load_path = os .path .join (
983
- current_dir , ".." , "pretrained_models/mattersim-v1.0.0-5M.pth"
984
- )
989
+ load_path = os .path .join (checkpoint_folder , "mattersim-v1.0.0-5M.pth" )
990
+ if not os .path .exists (load_path ):
991
+ logger .info (
992
+ "The pre-trained model is not found locally, "
993
+ "attempting to download it from the server."
994
+ )
995
+ download_checkpoint (
996
+ "mattersim-v1.0.0-5M.pth" , save_folder = checkpoint_folder
997
+ )
985
998
logger .info (f"Loading the pre-trained { os .path .basename (load_path )} model" )
986
999
else :
987
1000
logger .info ("Loading the model from %s" % load_path )
988
-
989
1001
assert os .path .exists (load_path ), f"Model file { load_path } not found"
990
1002
991
1003
checkpoint = torch .load (load_path , map_location = device )
0 commit comments