@@ -40,6 +40,7 @@ class ModelType(StringEnum):
40
40
INVALID = "INVALID"
41
41
visnet = "visnet"
42
42
43
+
43
44
class StrategyConfig (StringEnum ):
44
45
"""
45
46
Enum for possible MTENN Strategy classes.
@@ -708,7 +709,11 @@ def massage_irreps(cls, values):
708
709
709
710
# Combine Irreps into str
710
711
irreps = "+" .join (
711
- [f"{ num_irreps } x{ irrep } " for irrep , num_irreps in irreps .items ()]
712
+ [
713
+ f"{ num_irreps } x{ irrep } "
714
+ for irrep , num_irreps in irreps .items ()
715
+ if num_irreps > 0
716
+ ]
712
717
)
713
718
714
719
# Make sure this Irreps string is valid
@@ -781,28 +786,25 @@ class ViSNetModelConfig(ModelConfigBase):
781
786
max_z : int = Field (100 , description = "The maximum atomic numbers." )
782
787
cutoff : float = Field (5.0 , description = "The cutoff distance." )
783
788
max_num_neighbors : int = Field (
784
- 32 ,
785
- description = "The maximum number of neighbors considered for each atom."
786
- )
787
- vertex : bool = Field (
788
- False ,
789
- description = "Whether to use vertex geometric features."
789
+ 32 , description = "The maximum number of neighbors considered for each atom."
790
790
)
791
+ vertex : bool = Field (False , description = "Whether to use vertex geometric features." )
791
792
atomref : list [float ] | None = Field (
792
793
None ,
793
794
description = (
794
795
"Reference values for single-atom properties. Should have length max_z"
795
- )
796
+ ),
796
797
)
797
798
reduce_op : str = Field (
798
- "sum" ,
799
- description = "The type of reduction operation to apply. ['sum', 'mean']"
799
+ "sum" , description = "The type of reduction operation to apply. ['sum', 'mean']"
800
800
)
801
801
mean : float = Field (0.0 , description = "The mean of the output distribution." )
802
- std : float = Field (1.0 , description = "The standard deviation of the output distribution." )
802
+ std : float = Field (
803
+ 1.0 , description = "The standard deviation of the output distribution."
804
+ )
803
805
derivative : bool = Field (
804
- False ,
805
- description = "Whether to compute the derivative of the output with respect to the positions."
806
+ False ,
807
+ description = "Whether to compute the derivative of the output with respect to the positions." ,
806
808
)
807
809
808
810
@root_validator (pre = False )
@@ -813,11 +815,11 @@ def validate(cls, values):
813
815
# Make sure atomref length is correct (this is required by PyG)
814
816
atomref = values ["atomref" ]
815
817
if (atomref is not None ) and (len (atomref ) != values ["max_z" ]):
816
- raise ValueError (f"atomref length must match max_z. (Expected { values ['max_z' ]} , got { len (atomref )} )" )
818
+ raise ValueError (
819
+ f"atomref length must match max_z. (Expected { values ['max_z' ]} , got { len (atomref )} )"
820
+ )
817
821
818
822
return values
819
-
820
-
821
823
822
824
def _build (self , mtenn_params = {}):
823
825
"""
@@ -837,6 +839,7 @@ def _build(self, mtenn_params={}):
837
839
# Create an MTENN ViSNet model from PyG ViSNet model
838
840
839
841
from mtenn .conversion_utils .visnet import HAS_VISNET
842
+
840
843
if HAS_VISNET :
841
844
from mtenn .conversion_utils import ViSNet
842
845
@@ -874,5 +877,6 @@ def _build(self, mtenn_params={}):
874
877
)
875
878
876
879
else :
877
- raise ImportError ("ViSNet not found. Is your PyG >=2.5.0? Refer to issue #42." )
878
-
880
+ raise ImportError (
881
+ "ViSNet not found. Is your PyG >=2.5.0? Refer to issue #42."
882
+ )
0 commit comments