Skip to content

Commit afe58a0

Browse files
ZeroKnightingXixian
and
Xixian
authored
fix: Potential load method (#64)
Co-authored-by: Xixian <[email protected]>
1 parent 92af875 commit afe58a0

File tree

2 files changed

+21
-9
lines changed

2 files changed

+21
-9
lines changed

docs/user_guide/getting_started.rst

+1-1
Original file line numberDiff line numberDiff line change
@@ -59,7 +59,7 @@ a list of structures using the ``Potential`` class.
5959
# load the model
6060
device = "cuda" if torch.cuda.is_available() else "cpu"
6161
print(f"Running MatterSim on {device}")
62-
potential = Potential.load(device=device)
62+
potential = Potential.from_checkpoint(device=device)
6363
6464
# build the dataloader that is compatible with MatterSim
6565
dataloader = build_dataloader(structures, only_inference=True)

src/mattersim/forcefield/potential.py

+20-8
Original file line numberDiff line numberDiff line change
@@ -965,27 +965,39 @@ def load(
965965
if model_name.lower() != "m3gnet":
966966
raise NotImplementedError
967967

968-
current_dir = os.path.dirname(__file__)
968+
checkpoint_folder = os.path.expanduser("~/.local/mattersim/pretrained_models")
969+
os.makedirs(checkpoint_folder, exist_ok=True)
969970
if (
970971
load_path is None
971972
or load_path.lower() == "mattersim-v1.0.0-1m.pth"
972973
or load_path.lower() == "mattersim-v1.0.0-1m"
973974
):
974-
load_path = os.path.join(
975-
current_dir, "..", "pretrained_models/mattersim-v1.0.0-1M.pth"
976-
)
975+
load_path = os.path.join(checkpoint_folder, "mattersim-v1.0.0-1M.pth")
976+
if not os.path.exists(load_path):
977+
logger.info(
978+
"The pre-trained model is not found locally, "
979+
"attempting to download it from the server."
980+
)
981+
download_checkpoint(
982+
"mattersim-v1.0.0-1M.pth", save_folder=checkpoint_folder
983+
)
977984
logger.info(f"Loading the pre-trained {os.path.basename(load_path)} model")
978985
elif (
979986
load_path.lower() == "mattersim-v1.0.0-5m.pth"
980987
or load_path.lower() == "mattersim-v1.0.0-5m"
981988
):
982-
load_path = os.path.join(
983-
current_dir, "..", "pretrained_models/mattersim-v1.0.0-5M.pth"
984-
)
989+
load_path = os.path.join(checkpoint_folder, "mattersim-v1.0.0-5M.pth")
990+
if not os.path.exists(load_path):
991+
logger.info(
992+
"The pre-trained model is not found locally, "
993+
"attempting to download it from the server."
994+
)
995+
download_checkpoint(
996+
"mattersim-v1.0.0-5M.pth", save_folder=checkpoint_folder
997+
)
985998
logger.info(f"Loading the pre-trained {os.path.basename(load_path)} model")
986999
else:
9871000
logger.info("Loading the model from %s" % load_path)
988-
9891001
assert os.path.exists(load_path), f"Model file {load_path} not found"
9901002

9911003
checkpoint = torch.load(load_path, map_location=device)

0 commit comments

Comments
 (0)