[WIP] Add PyG-based GAT implementation. #67

Open · wants to merge 14 commits into `main`
README.md: 3 changes (2 additions, 1 deletion)
@@ -36,7 +36,8 @@ The input passed to this model should be a `dict` with the following keys (based
* `pos`: Tensor of coordinates for each atom, shape of `(n,3)`
* `z`: Tensor of bool labels of whether each atom is a protein atom (`False`) or ligand atom (`True`), shape of `(n,)`
* `GAT`
* `g`: DGL graph object
* `x`: Tensor of input atom (node) features, shape of `(n,feats)`
* `edge_index`: Tensor giving source (first row) and destination (second row) atom indices, shape of `(2,n_bonds)` (see the sketch below)
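As a rough illustration (a made-up 4-atom example, not from the repo), these inputs might be assembled as:

```python
import torch

# Hypothetical 4-atom chain with 100-dim node features
node_feats = torch.rand(4, 100)             # shape (n, feats)
edge_index = torch.tensor([[0, 1, 1, 2],    # source atom indices
                           [1, 0, 2, 1]])   # destination atom indices; shape (2, n_bonds)
gat_input = {"x": node_feats, "edge_index": edge_index}
```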

The prediction can then be generated simply with:
```python
# (code collapsed in the diff view)
```
devtools/conda-envs/mtenn.yaml: 2 changes (0 additions, 2 deletions)
@@ -10,7 +10,5 @@ dependencies:
- numpy
- h5py
- e3nn
- dgllife
- dgl
- rdkit
- ase
devtools/conda-envs/test_env.yaml: 2 changes (0 additions, 2 deletions)
@@ -10,8 +10,6 @@ dependencies:
- numpy
- h5py
- e3nn
- dgllife
- dgl
- rdkit
- ase
- fsspec
docs/docs/basic_usage.rst: 32 changes (25 additions, 7 deletions)
@@ -7,19 +7,37 @@ Below, we detail a basic example of building a default Graph Attention model and

.. code-block:: python

from dgllife.utils import CanonicalAtomFeaturizer, SMILESToBigraph
from mtenn.config import GATModelConfig
import rdkit.Chem as Chem
import torch

# Build model with GAT defaults
model = GATModelConfig().build()

# Build graph from SMILES
# Build mol
smiles = "CCCC"
g = SMILESToBigraph(
add_self_loop=True,
node_featurizer=CanonicalAtomFeaturizer(),
)(smiles)
mol = Chem.MolFromSmiles(smiles)
Contributor: Need a convenience function to do this easily for the user; easy to mess up.

Collaborator (author): I've added one in asapdiscovery for us to use, but since there's no one right way to featurize a molecule, I didn't want to add anything opinionated in here.

# Get atomic numbers and bond indices (both directions)
atomic_nums = [a.GetAtomicNum() for a in mol.GetAtoms()]
bond_idxs = [
atom_pair
for bond in mol.GetBonds()
for atom_pair in (
(bond.GetBeginAtomIdx(), bond.GetEndAtomIdx()),
(bond.GetEndAtomIdx(), bond.GetBeginAtomIdx()),
)
]
# Add self bonds
bond_idxs += [(a.GetIdx(), a.GetIdx()) for a in mol.GetAtoms()]

# Encode atomic numbers as one-hot vectors, assuming a max atomic number of 100
node_feats = torch.nn.functional.one_hot(
torch.tensor(atomic_nums), num_classes=100
).to(dtype=torch.float)
# Reshape bond index pairs into a (2, n_bonds) edge_index tensor
edge_index = torch.tensor(bond_idxs).t()

# Make a prediction
pred, _ = model({"g": g})
pred, _ = model({"x": node_feats, "edge_index": edge_index})

docs/requirements.yaml: 2 changes (0 additions, 2 deletions)
@@ -10,8 +10,6 @@ dependencies:
- numpy
- h5py
- e3nn
- dgllife
- dgl
- rdkit
- ase
- pydantic >=1.10.8,<2.0.0a0
environment-gpu.yml: 4 changes (1 addition, 3 deletions)
@@ -11,7 +11,5 @@ dependencies:
- numpy
- h5py
- e3nn
- dgllife
- dgl
- rdkit
- rdkit
- ase
environment.yml: 4 changes (1 addition, 3 deletions)
@@ -10,7 +10,5 @@ dependencies:
- numpy
- h5py
- e3nn
- dgllife
- dgl
- rdkit
- ase
- ase
mtenn/config.py: 244 changes (18 additions, 226 deletions)
@@ -404,192 +404,27 @@ def _check_grouped(values):

class GATModelConfig(ModelConfigBase):
"""
Class for constructing a graph attention ML model. Note that there are two methods
for defining the size of the model:

* If single values are passed for all parameters, the value of ``num_layers`` will
be used as the size of the model, and each layer will have the parameters given

* If a list of values is passed for any parameters, all parameters must be lists of
the same size, or single values. For parameters that are single values, that same
value will be used for each layer. For parameters that are lists, those lists will
be used

Parameters passed as strings are assumed to be comma-separated lists, and will first
be cast to lists of the appropriate type, and then processed as described above.

If lists of multiple different (non-1) sizes are found, an error will be raised.

Default values here are the default values given in DGL-LifeSci.
Class for constructing a GAT ML model. Default values here are based on the values
in DGL-LifeSci.
Contributor: DGL-LifeSci gone now?

Collaborator (author): The defaults are still based on the defaults in that package, even though we're not using their code anymore.

"""

# Import as private, mainly so Sphinx doesn't autodoc it
from dgllife.utils import CanonicalAtomFeaturizer as _CanonicalAtomFeaturizer

# Dict of model params that can be passed as a list, and the type that each will be
# cast to
LIST_PARAMS: ClassVar[dict] = {
"hidden_feats": int,
"num_heads": int,
"feat_drops": float,
"attn_drops": float,
"alphas": float,
"residuals": bool,
"agg_modes": str,
"activations": None,
"biases": bool,
} #: :meta private:

model_type: ModelType = Field(ModelType.GAT, const=True)

in_feats: int = Field(
_CanonicalAtomFeaturizer().feat_size(),
description=(
"Input node feature size. Defaults to size of the "
"``CanonicalAtomFeaturizer``."
),
)
num_layers: int = Field(
2,
description=(
"Number of GAT layers. Ignored if a list of values is passed for any "
"other argument."
),
)
hidden_feats: str | int | list[int] = Field(
32,
description=(
"Output size of each GAT layer. If an ``int`` is passed, the value for "
"``num_layers`` will be used to determine the size of the model. If a list "
"of ``int`` s is passed, the size of the model will be inferred from the "
"length of the list."
),
)
num_heads: str | int | list[int] = Field(
4,
description=(
"Number of attention heads for each GAT layer. Passing an ``int`` or list "
"of ``int`` s functions similarly as for ``hidden_feats``."
),
)
feat_drops: str | float | list[float] = Field(
0,
description=(
"Dropout of input features for each GAT layer. Passing a ``float`` or "
"list of ``float`` s functions similarly as for ``hidden_feats``."
),
)
attn_drops: str | float | list[float] = Field(
0,
description=(
"Dropout of attention values for each GAT layer. Passing a ``float`` or "
"list of ``float`` s functions similarly as for ``hidden_feats``."
),
)
alphas: str | float | list[float] = Field(
0.2,
description=(
"Hyperparameter for ``LeakyReLU`` gate for each GAT layer. Passing a "
"``float`` or list of ``float`` s functions similarly as for "
"``hidden_feats``."
),
)
residuals: str | bool | list[bool] = Field(
True,
description=(
"Whether to use residual connection for each GAT layer. Passing a ``bool`` "
"or list of ``bool`` s functions similarly as for ``hidden_feats``."
),
)
agg_modes: str | list[str] = Field(
"flatten",
description=(
"Which aggregation mode [flatten, mean] to use for each GAT layer. "
"Passing a ``str`` or list of ``str`` s functions similarly as for "
"``hidden_feats``."
),
)
activations: Callable | list[Callable] | list[None] | None = Field(
None,
description=(
"Activation function for each GAT layer. Passing a function or "
"list of functions functions similarly as for ``hidden_feats``."
),
)
biases: str | bool | list[bool] = Field(
True,
in_channels: int = Field(
-1,
description=(
"Whether to use bias for each GAT layer. Passing a ``bool`` or "
"list of ``bool`` s functions similarly as for ``hidden_feats``."
"Input size. Can be left as -1 (default) to interpret based on "
"first forward call."
),
)
allow_zero_in_degree: bool = Field(
False, description="Allow zero in degree nodes for all graph layers."
hidden_channels: int = Field(32, description="Hidden embedding size.")
num_layers: int = Field(2, description="Number of GAT layers.")
dropout: float = Field(0, description="Dropout probability.")
heads: int = Field(4, description="Number of attention heads for each GAT layer.")
negative_slope: float = Field(
0.2, description="LeakyReLU angle of the negative slope."
)
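# Example (editorial note, not part of the diff): with the fields above, a
# non-default model could be configured and built as
#
#     config = GATModelConfig(hidden_channels=64, num_layers=3, heads=8)
#     model = config.build()
#
# where build() comes from ModelConfigBase, as used in docs/docs/basic_usage.rst.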

# Internal tracker for if the parameters were originally built from lists or using
# num_layers
_from_num_layers = False

@root_validator(pre=False)
def massage_into_lists(cls, values) -> GATModelConfig:
"""
Validator to handle unifying all the values into the proper list forms based on
the rules described in the class docstring.
"""
# First convert string lists to actual lists
for param, param_type in cls.LIST_PARAMS.items():
param_val = values[param]
if isinstance(param_val, str):
try:
param_val = list(map(param_type, param_val.split(",")))
except ValueError:
raise ValueError(
f"Unable to parse value {param_val} for parameter {param}. "
f"Expected type of {param_type}."
)
values[param] = param_val

# Get sizes of all lists
list_lens = {}
for p in cls.LIST_PARAMS:
param_val = values[p]
if not isinstance(param_val, list):
# Shouldn't be possible at this point but just in case
param_val = [param_val]
values[p] = param_val
list_lens[p] = len(param_val)

# Check that there's only one length present
list_lens_set = set(list_lens.values())
# This could be 0 if lists of length 1 were passed, which is valid
if len(list_lens_set - {1}) > 1:
raise ValueError(
"All passed parameter lists must be the same value. "
f"Instead got list lengths of: {list_lens}"
)
elif list_lens_set == {1}:
# If all lists have only one value, we defer to the value passed to
# num_layers, as described in the class docstring
num_layers = values["num_layers"]
values["_from_num_layers"] = True
else:
num_layers = max(list_lens_set)
values["_from_num_layers"] = False

values["num_layers"] = num_layers
# If we just want a model with one layer, can return early since we've already
# converted everything into lists
if num_layers == 1:
return values

# Adjust any length 1 list to be the right length
for p, list_len in list_lens.items():
if list_len == 1:
values[p] = values[p] * num_layers

return values

def _build(self, mtenn_params={}):
"""
Build an ``mtenn`` GAT ``Model`` from this config.
@@ -624,60 +459,17 @@ def _build(self, mtenn_params={}):
from mtenn.conversion_utils.gat import GAT

model = GAT(
in_feats=self.in_feats,
hidden_feats=self.hidden_feats,
num_heads=self.num_heads,
feat_drops=self.feat_drops,
attn_drops=self.attn_drops,
alphas=self.alphas,
residuals=self.residuals,
agg_modes=self.agg_modes,
activations=self.activations,
biases=self.biases,
allow_zero_in_degree=self.allow_zero_in_degree,
in_channels=self.in_channels,
hidden_channels=self.hidden_channels,
num_layers=self.num_layers,
dropout=self.dropout,
heads=self.heads,
negative_slope=self.negative_slope,
)
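# (Editorial note, not in the diff: these arguments match the signature of
# PyG's built-in torch_geometric.nn.models.GAT, so mtenn.conversion_utils.gat.GAT
# presumably wraps that class, with heads and negative_slope forwarded to the
# underlying GATConv layers.)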

pred_readout = mtenn_params.get("pred_readout", None)
return GAT.get_model(model=model, pred_readout=pred_readout, fix_device=True)

def _update(self, config_updates={}) -> GATModelConfig:
"""
GAT-specific implementation of updating logic. Need to handle stuff specially
to make sure that the original method of specifying parameters (either from a
passed value of ``num_layers`` or inferred from each parameter being a list) is
maintained.

:meta public:

Parameters
----------
config_updates : dict
Dictionary mapping from field names to new values

Returns
-------
GATModelConfig
New ``GATModelConfig`` object
"""
orig_config = self.dict()
if self._from_num_layers or ("num_layers" in config_updates):
# If originally generated from num_layers, want to pull out the first entry
# in each list param so it can be re-broadcast with (potentially) new
# num_layers
for param_name in GATModelConfig.LIST_PARAMS.keys():
orig_config[param_name] = orig_config[param_name][0]

# Get new config by overwriting old stuff with any new stuff
new_config = orig_config | config_updates

# A bit hacky, maybe try and change?
if isinstance(new_config["activations"], list) and (
new_config["activations"][0] is None
):
new_config["activations"] = None

return GATModelConfig(**new_config)


class SchNetModelConfig(ModelConfigBase):
"""