 import abc
 from enum import Enum
-from pydantic import BaseModel, Field, root_validator
+from pydantic import model_validator, ConfigDict, BaseModel, Field
 import random
-from typing import Callable, ClassVar
+from typing import Literal, Callable, ClassVar
 import mtenn.combination
 import mtenn.readout
 import mtenn.model
@@ -140,7 +140,8 @@ class ModelConfigBase(BaseModel):
     to implement the ``_build`` method in order to be used.
     """

-    model_type: ModelType = Field(ModelType.INVALID, const=True, allow_mutation=False)
+    model_type: Literal[ModelType.INVALID] = ModelType.INVALID
+

     # Random seed optional for reproducibility
     rand_seed: int | None = Field(
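
Pydantic v2 dropped the `const=True` and `allow_mutation=False` arguments to `Field`; typing the field as a `Literal` of the single allowed enum member gives essentially the same effect (any other value is rejected, and with `validate_assignment` enabled it cannot be reassigned to a different value). A minimal sketch of the pattern; the two-member `ModelType` enum and the `ExampleConfig` class below are cut-down stand-ins, not the real mtenn definitions:

from enum import Enum
from typing import Literal

from pydantic import BaseModel, ValidationError


class ModelType(str, Enum):
    INVALID = "INVALID"
    GAT = "GAT"


class ExampleConfig(BaseModel):
    # Only ModelType.GAT validates; any other value raises a ValidationError,
    # which reproduces the old Field(..., const=True) behavior.
    model_type: Literal[ModelType.GAT] = ModelType.GAT


ExampleConfig()  # OK, defaults to ModelType.GAT
try:
    ExampleConfig(model_type=ModelType.INVALID)
except ValidationError as exc:
    print(exc)  # literal_error: input should be ModelType.GAT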
@@ -240,9 +241,7 @@ class ModelConfigBase(BaseModel):
             "``comb_substrate``."
         ),
     )
-
-    class Config:
-        validate_assignment = True
+    model_config = ConfigDict(validate_assignment=True)

     def build(self) -> mtenn.model.Model:
         """
@@ -394,7 +393,7 @@ def _check_grouped(values):
        Makes sure that a Combination method is passed if using a GroupedModel. Only
        needs to be called for structure-based models.
        """
-        if values["grouped"] and (not values["combination"]):
+        if values.grouped and not values.combination:
            raise ValueError("combination must be specified for a GroupedModel.")

@@ -436,7 +435,7 @@ class GATModelConfig(ModelConfigBase):
        "biases": bool,
    }  #: :meta private:

-    model_type: ModelType = Field(ModelType.GAT, const=True)
+    model_type: Literal[ModelType.GAT] = ModelType.GAT

    in_feats: int = Field(
        _CanonicalAtomFeaturizer().feat_size(),
@@ -527,14 +526,16 @@ class GATModelConfig(ModelConfigBase):
     # num_layers
     _from_num_layers = False

-    @root_validator(pre=False)
-    def massage_into_lists(cls, values) -> GATModelConfig:
+    @model_validator(mode="after")
+    def massage_into_lists(self) -> GATModelConfig:
         """
         Validator to handle unifying all the values into the proper list forms based on
         the rules described in the class docstring.
         """
+        values = self.dict()
+
         # First convert string lists to actual lists
-        for param, param_type in cls.LIST_PARAMS.items():
+        for param, param_type in self.LIST_PARAMS.items():
             param_val = values[param]
             if isinstance(param_val, str):
                 try:
@@ -548,7 +549,7 @@ def massage_into_lists(cls, values) -> GATModelConfig:

         # Get sizes of all lists
         list_lens = {}
-        for p in cls.LIST_PARAMS:
+        for p in self.LIST_PARAMS:
             param_val = values[p]
             if not isinstance(param_val, list):
                 # Shouldn't be possible at this point but just in case
@@ -577,14 +578,17 @@ def massage_into_lists(cls, values) -> GATModelConfig:
         # If we just want a model with one layer, can return early since we've already
         # converted everything into lists
         if num_layers == 1:
-            return values
+            # update self with the new values
+            self.__dict__.update(values)
+

         # Adjust any length 1 list to be the right length
         for p, list_len in list_lens.items():
             if list_len == 1:
                 values[p] = values[p] * num_layers

-        return values
+        self.__dict__.update(values)
+        return self

     def _build(self, mtenn_params={}):
         """
@@ -681,7 +685,7 @@ class SchNetModelConfig(ModelConfigBase):
     given in PyG.
     """

-    model_type: ModelType = Field(ModelType.schnet, const=True)
+    model_type: Literal[ModelType.schnet] = ModelType.schnet

     hidden_channels: int = Field(128, description="Hidden embedding size.")
     num_filters: int = Field(
@@ -738,13 +742,14 @@ class SchNetModelConfig(ModelConfigBase):
         ),
     )

-    @root_validator(pre=False)
+    @model_validator(mode="after")
+    @classmethod
     def validate(cls, values):
         # Make sure the grouped stuff is properly assigned
         ModelConfigBase._check_grouped(values)

         # Make sure atomref length is correct (this is required by PyG)
-        atomref = values["atomref"]
+        atomref = values.atomref
         if (atomref is not None) and (len(atomref) != 100):
             raise ValueError(f"atomref must be length 100 (got {len(atomref)})")

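
The SchNet, e3nn, and ViSNet validators keep their `(cls, values)` signatures but are now decorated with `@model_validator(mode="after")` plus `@classmethod`; in that form Pydantic v2 passes the validated model instance as the second argument, which is why the dict lookups become attribute accesses. A minimal sketch of that calling convention; `AtomrefConfig` is an illustrative class, not part of mtenn:

from pydantic import BaseModel, model_validator


class AtomrefConfig(BaseModel):
    atomref: list[float] | None = None

    @model_validator(mode="after")
    @classmethod
    def validate_atomref(cls, values):
        # "values" is the constructed model instance here, not a dict,
        # so fields are read as attributes.
        atomref = values.atomref
        if atomref is not None and len(atomref) != 100:
            raise ValueError(f"atomref must be length 100 (got {len(atomref)})")
        return values


AtomrefConfig(atomref=[0.0] * 100)   # OK
AtomrefConfig(atomref=None)          # OK
# AtomrefConfig(atomref=[0.0] * 10)  # would raise a ValidationError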
@@ -816,7 +821,7 @@ class E3NNModelConfig(ModelConfigBase):
     Class for constructing an e3nn ML model.
     """

-    model_type: ModelType = Field(ModelType.e3nn, const=True)
+    model_type: Literal[ModelType.e3nn] = ModelType.e3nn

     num_atom_types: int = Field(
         100,
@@ -862,7 +867,8 @@ class E3NNModelConfig(ModelConfigBase):
     num_neighbors: float = Field(25, description="Typical number of neighbor nodes.")
     num_nodes: float = Field(4700, description="Typical number of nodes in a graph.")

-    @root_validator(pre=False)
+    @model_validator(mode="after")
+    @classmethod
     def massage_irreps(cls, values):
         """
         Check that the value given for ``irreps_hidden`` can be converted into an Irreps
@@ -874,7 +880,7 @@ def massage_irreps(cls, values):
        ModelConfigBase._check_grouped(values)

        # Now deal with irreps
-        irreps = values["irreps_hidden"]
+        irreps = values.irreps_hidden
        # First see if this string should be converted into a dict
        if isinstance(irreps, str):
            if ":" in irreps:
@@ -923,7 +929,7 @@ def massage_irreps(cls, values):
            except ValueError:
                raise ValueError(f"Couldn't parse irreps dict: {orig_irreps}")

-        values["irreps_hidden"] = irreps
+        values.irreps_hidden = irreps
        return values

    def _build(self, mtenn_params={}):
@@ -994,7 +1000,7 @@ class ViSNetModelConfig(ModelConfigBase):
     given in PyG.
     """

-    model_type: ModelType = Field(ModelType.visnet, const=True)
+    model_type: Literal[ModelType.visnet] = ModelType.visnet
     lmax: int = Field(1, description="The maximum degree of the spherical harmonics.")
     vecnorm_type: str | None = Field(
         None, description="The type of normalization to apply to the vectors."
@@ -1041,7 +1047,8 @@ class ViSNetModelConfig(ModelConfigBase):
         ),
     )

-    @root_validator(pre=False)
+    @model_validator(mode="after")
+    @classmethod
     def validate(cls, values):
         """
         Check that ``atomref`` and ``max_z`` agree.
@@ -1050,10 +1057,10 @@ def validate(cls, values):
        ModelConfigBase._check_grouped(values)

        # Make sure atomref length is correct (this is required by PyG)
-        atomref = values["atomref"]
-        if (atomref is not None) and (len(atomref) != values["max_z"]):
+        atomref = values.atomref
+        if (atomref is not None) and (len(atomref) != values.max_z):
            raise ValueError(
-                f"atomref length must match max_z. (Expected {values['max_z']}, got {len(atomref)})"
+                f"atomref length must match max_z. (Expected {values.max_z}, got {len(atomref)})"
            )

        return values
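
End to end, the validator behavior callers see should be unchanged by the migration. A rough usage sketch of the SchNet length check above; the `mtenn.config` import path is an assumption (the diff only shows the class definitions), and the other SchNet fields are assumed to keep usable defaults:

# Assumed import path; not shown in this diff.
from mtenn.config import SchNetModelConfig
from pydantic import ValidationError

SchNetModelConfig(atomref=[0.0] * 100)       # passes the length check
try:
    SchNetModelConfig(atomref=[0.0] * 10)    # wrong length
except ValidationError as exc:
    # In Pydantic v2, a ValueError raised inside a model validator is
    # surfaced to the caller as a ValidationError.
    print(exc)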