Skip to content

Commit 36b9f8b

Browse files
authored
Migrate to pydantic 2.0+
2 parents 8e157d8 + 965f0b7 commit 36b9f8b

File tree

3 files changed: +37 -28 lines changed

devtools/conda-envs/mtenn.yaml

+3-1
Original file line numberDiff line numberDiff line change
@@ -3,14 +3,16 @@ channels:
33
- conda-forge
44
dependencies:
55
- pytorch
6-
- pytorch_geometric
6+
- pytorch_geometric >=2.5.0
77
- pytorch_cluster
88
- pytorch_scatter
99
- pytorch_sparse
10+
- pydantic >=2.0.0a0
1011
- numpy
1112
- h5py
1213
- e3nn
1314
- dgllife
1415
- dgl
1516
- rdkit
1617
- ase
18+
- fsspec

devtools/conda-envs/test_env.yaml

+1-1
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,7 @@ dependencies:
77
- pytorch_cluster
88
- pytorch_scatter
99
- pytorch_sparse
10+
- pydantic >=2.0.0a0
1011
- numpy
1112
- h5py
1213
- e3nn
@@ -19,5 +20,4 @@ dependencies:
1920
- pytest
2021
- pytest-cov
2122
- codecov
22-
- pydantic >=1.10.8,<2.0.0a0
2323
- fsspec

mtenn/config.py

+33-26
Original file line numberDiff line numberDiff line change
@@ -15,9 +15,9 @@
1515

1616
import abc
1717
from enum import Enum
18-
from pydantic import BaseModel, Field, root_validator
18+
from pydantic import model_validator, ConfigDict, BaseModel, Field
1919
import random
20-
from typing import Callable, ClassVar
20+
from typing import Literal, Callable, ClassVar
2121
import mtenn.combination
2222
import mtenn.readout
2323
import mtenn.model
@@ -140,7 +140,8 @@ class ModelConfigBase(BaseModel):
140140
to implement the ``_build`` method in order to be used.
141141
"""
142142

143-
model_type: ModelType = Field(ModelType.INVALID, const=True, allow_mutation=False)
143+
model_type: Literal[ModelType.INVALID] = ModelType.INVALID
144+
144145

145146
# Random seed optional for reproducibility
146147
rand_seed: int | None = Field(
@@ -240,9 +241,7 @@ class ModelConfigBase(BaseModel):
240241
"``comb_substrate``."
241242
),
242243
)
243-
244-
class Config:
245-
validate_assignment = True
244+
model_config = ConfigDict(validate_assignment=True)
246245

247246
def build(self) -> mtenn.model.Model:
248247
"""
@@ -394,7 +393,7 @@ def _check_grouped(values):
394393
Makes sure that a Combination method is passed if using a GroupedModel. Only
395394
needs to be called for structure-based models.
396395
"""
397-
if values["grouped"] and (not values["combination"]):
396+
if values.grouped and not values.combination:
398397
raise ValueError("combination must be specified for a GroupedModel.")
399398

400399

@@ -436,7 +435,7 @@ class GATModelConfig(ModelConfigBase):
436435
"biases": bool,
437436
} #: :meta private:
438437

439-
model_type: ModelType = Field(ModelType.GAT, const=True)
438+
model_type: Literal[ModelType.GAT] = ModelType.GAT
440439

441440
in_feats: int = Field(
442441
_CanonicalAtomFeaturizer().feat_size(),
@@ -527,14 +526,16 @@ class GATModelConfig(ModelConfigBase):
527526
# num_layers
528527
_from_num_layers = False
529528

530-
@root_validator(pre=False)
531-
def massage_into_lists(cls, values) -> GATModelConfig:
529+
@model_validator(mode="after")
530+
def massage_into_lists(self) -> GATModelConfig:
532531
"""
533532
Validator to handle unifying all the values into the proper list forms based on
534533
the rules described in the class docstring.
535534
"""
535+
values = self.dict()
536+
536537
# First convert string lists to actual lists
537-
for param, param_type in cls.LIST_PARAMS.items():
538+
for param, param_type in self.LIST_PARAMS.items():
538539
param_val = values[param]
539540
if isinstance(param_val, str):
540541
try:
@@ -548,7 +549,7 @@ def massage_into_lists(cls, values) -> GATModelConfig:
548549

549550
# Get sizes of all lists
550551
list_lens = {}
551-
for p in cls.LIST_PARAMS:
552+
for p in self.LIST_PARAMS:
552553
param_val = values[p]
553554
if not isinstance(param_val, list):
554555
# Shouldn't be possible at this point but just in case
@@ -577,14 +578,17 @@ def massage_into_lists(cls, values) -> GATModelConfig:
577578
# If we just want a model with one layer, can return early since we've already
578579
# converted everything into lists
579580
if num_layers == 1:
580-
return values
581+
# update self with the new values
582+
self.__dict__.update(values)
583+
581584

582585
# Adjust any length 1 list to be the right length
583586
for p, list_len in list_lens.items():
584587
if list_len == 1:
585588
values[p] = values[p] * num_layers
586589

587-
return values
590+
self.__dict__.update(values)
591+
return self
588592

589593
def _build(self, mtenn_params={}):
590594
"""
@@ -681,7 +685,7 @@ class SchNetModelConfig(ModelConfigBase):
681685
given in PyG.
682686
"""
683687

684-
model_type: ModelType = Field(ModelType.schnet, const=True)
688+
model_type: Literal[ModelType.schnet] = ModelType.schnet
685689

686690
hidden_channels: int = Field(128, description="Hidden embedding size.")
687691
num_filters: int = Field(
@@ -738,13 +742,14 @@ class SchNetModelConfig(ModelConfigBase):
738742
),
739743
)
740744

741-
@root_validator(pre=False)
745+
@model_validator(mode="after")
746+
@classmethod
742747
def validate(cls, values):
743748
# Make sure the grouped stuff is properly assigned
744749
ModelConfigBase._check_grouped(values)
745750

746751
# Make sure atomref length is correct (this is required by PyG)
747-
atomref = values["atomref"]
752+
atomref = values.atomref
748753
if (atomref is not None) and (len(atomref) != 100):
749754
raise ValueError(f"atomref must be length 100 (got {len(atomref)})")
750755

@@ -816,7 +821,7 @@ class E3NNModelConfig(ModelConfigBase):
816821
Class for constructing an e3nn ML model.
817822
"""
818823

819-
model_type: ModelType = Field(ModelType.e3nn, const=True)
824+
model_type: Literal[ModelType.e3nn] = ModelType.e3nn
820825

821826
num_atom_types: int = Field(
822827
100,
@@ -862,7 +867,8 @@ class E3NNModelConfig(ModelConfigBase):
862867
num_neighbors: float = Field(25, description="Typical number of neighbor nodes.")
863868
num_nodes: float = Field(4700, description="Typical number of nodes in a graph.")
864869

865-
@root_validator(pre=False)
870+
@model_validator(mode="after")
871+
@classmethod
866872
def massage_irreps(cls, values):
867873
"""
868874
Check that the value given for ``irreps_hidden`` can be converted into an Irreps
@@ -874,7 +880,7 @@ def massage_irreps(cls, values):
874880
ModelConfigBase._check_grouped(values)
875881

876882
# Now deal with irreps
877-
irreps = values["irreps_hidden"]
883+
irreps = values.irreps_hidden
878884
# First see if this string should be converted into a dict
879885
if isinstance(irreps, str):
880886
if ":" in irreps:
@@ -923,7 +929,7 @@ def massage_irreps(cls, values):
923929
except ValueError:
924930
raise ValueError(f"Couldn't parse irreps dict: {orig_irreps}")
925931

926-
values["irreps_hidden"] = irreps
932+
values.irreps_hidden = irreps
927933
return values
928934

929935
def _build(self, mtenn_params={}):
@@ -994,7 +1000,7 @@ class ViSNetModelConfig(ModelConfigBase):
9941000
given in PyG.
9951001
"""
9961002

997-
model_type: ModelType = Field(ModelType.visnet, const=True)
1003+
model_type: Literal[ModelType.visnet] = ModelType.visnet
9981004
lmax: int = Field(1, description="The maximum degree of the spherical harmonics.")
9991005
vecnorm_type: str | None = Field(
10001006
None, description="The type of normalization to apply to the vectors."
@@ -1041,7 +1047,8 @@ class ViSNetModelConfig(ModelConfigBase):
10411047
),
10421048
)
10431049

1044-
@root_validator(pre=False)
1050+
@model_validator(mode="after")
1051+
@classmethod
10451052
def validate(cls, values):
10461053
"""
10471054
Check that ``atomref`` and ``max_z`` agree.
@@ -1050,10 +1057,10 @@ def validate(cls, values):
10501057
ModelConfigBase._check_grouped(values)
10511058

10521059
# Make sure atomref length is correct (this is required by PyG)
1053-
atomref = values["atomref"]
1054-
if (atomref is not None) and (len(atomref) != values["max_z"]):
1060+
atomref = values.atomref
1061+
if (atomref is not None) and (len(atomref) != values.max_z):
10551062
raise ValueError(
1056-
f"atomref length must match max_z. (Expected {values['max_z']}, got {len(atomref)})"
1063+
f"atomref length must match max_z. (Expected {values.max_z}, got {len(atomref)})"
10571064
)
10581065

10591066
return values

0 commit comments

Comments
 (0)