Skip to content

Commit d4989ce

Browse files
authored
Merge pull request #55 from choderalab/fix-issue-54
Add Irreps string filter
2 parents 69d1399 + 06aea6d commit d4989ce

File tree

1 file changed

+22
-18
lines changed

1 file changed

+22
-18
lines changed

mtenn/config.py

+22-18
Original file line numberDiff line numberDiff line change
@@ -40,6 +40,7 @@ class ModelType(StringEnum):
4040
INVALID = "INVALID"
4141
visnet = "visnet"
4242

43+
4344
class StrategyConfig(StringEnum):
4445
"""
4546
Enum for possible MTENN Strategy classes.
@@ -708,7 +709,11 @@ def massage_irreps(cls, values):
708709

709710
# Combine Irreps into str
710711
irreps = "+".join(
711-
[f"{num_irreps}x{irrep}" for irrep, num_irreps in irreps.items()]
712+
[
713+
f"{num_irreps}x{irrep}"
714+
for irrep, num_irreps in irreps.items()
715+
if num_irreps > 0
716+
]
712717
)
713718

714719
# Make sure this Irreps string is valid
@@ -781,28 +786,25 @@ class ViSNetModelConfig(ModelConfigBase):
781786
max_z: int = Field(100, description="The maximum atomic numbers.")
782787
cutoff: float = Field(5.0, description="The cutoff distance.")
783788
max_num_neighbors: int = Field(
784-
32,
785-
description="The maximum number of neighbors considered for each atom."
786-
)
787-
vertex: bool = Field(
788-
False,
789-
description="Whether to use vertex geometric features."
789+
32, description="The maximum number of neighbors considered for each atom."
790790
)
791+
vertex: bool = Field(False, description="Whether to use vertex geometric features.")
791792
atomref: list[float] | None = Field(
792793
None,
793794
description=(
794795
"Reference values for single-atom properties. Should have length max_z"
795-
)
796+
),
796797
)
797798
reduce_op: str = Field(
798-
"sum",
799-
description="The type of reduction operation to apply. ['sum', 'mean']"
799+
"sum", description="The type of reduction operation to apply. ['sum', 'mean']"
800800
)
801801
mean: float = Field(0.0, description="The mean of the output distribution.")
802-
std: float = Field(1.0, description="The standard deviation of the output distribution.")
802+
std: float = Field(
803+
1.0, description="The standard deviation of the output distribution."
804+
)
803805
derivative: bool = Field(
804-
False,
805-
description="Whether to compute the derivative of the output with respect to the positions."
806+
False,
807+
description="Whether to compute the derivative of the output with respect to the positions.",
806808
)
807809

808810
@root_validator(pre=False)
@@ -813,11 +815,11 @@ def validate(cls, values):
813815
# Make sure atomref length is correct (this is required by PyG)
814816
atomref = values["atomref"]
815817
if (atomref is not None) and (len(atomref) != values["max_z"]):
816-
raise ValueError(f"atomref length must match max_z. (Expected {values['max_z']}, got {len(atomref)})")
818+
raise ValueError(
819+
f"atomref length must match max_z. (Expected {values['max_z']}, got {len(atomref)})"
820+
)
817821

818822
return values
819-
820-
821823

822824
def _build(self, mtenn_params={}):
823825
"""
@@ -837,6 +839,7 @@ def _build(self, mtenn_params={}):
837839
# Create an MTENN ViSNet model from PyG ViSNet model
838840

839841
from mtenn.conversion_utils.visnet import HAS_VISNET
842+
840843
if HAS_VISNET:
841844
from mtenn.conversion_utils import ViSNet
842845

@@ -874,5 +877,6 @@ def _build(self, mtenn_params={}):
874877
)
875878

876879
else:
877-
raise ImportError("ViSNet not found. Is your PyG >=2.5.0? Refer to issue #42.")
878-
880+
raise ImportError(
881+
"ViSNet not found. Is your PyG >=2.5.0? Refer to issue #42."
882+
)

0 commit comments

Comments
 (0)