diff --git a/README.md b/README.md index 3a9917b..17dbaf3 100644 --- a/README.md +++ b/README.md @@ -1,18 +1,25 @@ # Geometric Vector Perceptron -Implementation of equivariant GVP-GNNs as described in [Learning from Protein Structure with Geometric Vector Perceptrons](https://openreview.net/forum?id=1YLJDvSx6J4) by B Jing, S Eismann, P Suriana, RJL Townshend, and RO Dror. +Implementation of equivariant GVP-GNNs as described +in [Learning from Protein Structure with Geometric Vector Perceptrons](https://openreview.net/forum?id=1YLJDvSx6J4) by B +Jing, S Eismann, P Suriana, RJL Townshend, and RO Dror. -**UPDATE:** Also includes equivariant GNNs with vector gating as described in [Equivariant Graph Neural Networks for 3D Macromolecular Structure](https://arxiv.org/abs/2106.03843) by B Jing, S Eismann, P Soni, and RO Dror. +**UPDATE:** Also includes equivariant GNNs with vector gating as described +in [Equivariant Graph Neural Networks for 3D Macromolecular Structure](https://arxiv.org/abs/2106.03843) by B Jing, S +Eismann, P Soni, and RO Dror. -Scripts for training / testing / sampling on protein design and training / testing on all [ATOM3D](https://arxiv.org/abs/2012.04035) tasks are provided. +Scripts for training / testing / sampling on protein design and training / testing on +all [ATOM3D](https://arxiv.org/abs/2012.04035) tasks are provided. -**Note:** This implementation is in PyTorch Geometric. The original TensorFlow code, which is not maintained, can be found [here](https://github.com/drorlab/gvp). +**Note:** This implementation is in PyTorch Geometric. The original TensorFlow code, which is not maintained, can be +found [here](https://github.com/drorlab/gvp).

## Requirements + * UNIX environment * python==3.6.13 * torch==1.8.1 @@ -29,26 +36,37 @@ While we have not tested with other versions, any reasonably recent versions of ## General usage We provide classes in three modules: + * `gvp`: core GVP modules and GVP-GNN layers * `gvp.data`: data pipelines for both general use and protein design * `gvp.models`: implementations of MQA and CPD models * `gvp.atom3d`: models and data pipelines for ATOM3D -The core modules in `gvp` are meant to be as general as possible, but you will likely have to modify `gvp.data` and `gvp.models` for your specific application, with the existing classes serving as examples. +The core modules in `gvp` are meant to be as general as possible, but you will likely have to modify `gvp.data` +and `gvp.models` for your specific application, with the existing classes serving as examples. -**Installation:** Download this repository and run `python setup.py develop` or `pip install . -e`. Be sure to manually install `torch_geometric` first! +**Installation:** Download this repository and run `python setup.py develop` or `pip install . -e`. Be sure to manually +install `torch_geometric` first! -**Tuple representation:** All inputs and outputs with both scalar and vector channels are represented as a tuple of two tensors `(s, V)`. Similarly, all dimensions should be specified as tuples `(n_scalar, n_vector)` where `n_scalar` and `n_vector` are the number of scalar and vector features, respectively. All `V` tensors must be shaped as `[..., n_vector, 3]`, not `[..., 3, n_vector]`. +**Tuple representation:** All inputs and outputs with both scalar and vector channels are represented as a tuple of two +tensors `(s, V)`. Similarly, all dimensions should be specified as tuples `(n_scalar, n_vector)` where `n_scalar` +and `n_vector` are the number of scalar and vector features, respectively. All `V` tensors must be shaped +as `[..., n_vector, 3]`, not `[..., 3, n_vector]`. -**Batching:** We adopt the `torch_geometric` convention of absorbing the batch dimension into the node dimension and keeping track of batch index in a separate tensor. +**Batching:** We adopt the `torch_geometric` convention of absorbing the batch dimension into the node dimension and +keeping track of batch index in a separate tensor. -**Amino acids:** Models view sequences as int tensors and are agnostic to aa-to-int mappings. Such mappings are specified as the `letter_to_num` attribute of `gvp.data.ProteinGraphDataset`. Currently, only the 20 standard amino acids are supported. +**Amino acids:** Models view sequences as int tensors and are agnostic to aa-to-int mappings. Such mappings are +specified as the `letter_to_num` attribute of `gvp.data.ProteinGraphDataset`. Currently, only the 20 standard amino +acids are supported. -For all classes, see the docstrings for more detailed usage. If you have any questions, please contact bjing@cs.stanford.edu. +For all classes, see the docstrings for more detailed usage. If you have any questions, please contact +bjing@cs.stanford.edu. ### Core GVP classes The class `gvp.GVP` implements a Geometric Vector Perceptron. + ``` import gvp @@ -56,17 +74,26 @@ in_dims = scalars_in, vectors_in out_dims = scalars_out, vectors_out gvp_ = gvp.GVP(in_dims, out_dims) ``` + To use vector gating, pass in `vector_gate=True` and the appropriate activations. + ``` gvp_ = gvp.GVP(in_dims, out_dims, activations=(F.relu, None), vector_gate=True) ``` -The classes `gvp.Dropout` and `gvp.LayerNorm` implement vector-channel dropout and layer norm, while using normal dropout and layer norm for scalar channels. Both expect inputs and return outputs of form `(s, V)`, but will also behave like their scalar-valued counterparts if passed a single tensor. + +The classes `gvp.Dropout` and `gvp.LayerNorm` implement vector-channel dropout and layer norm, while using normal +dropout and layer norm for scalar channels. Both expect inputs and return outputs of form `(s, V)`, but will also behave +like their scalar-valued counterparts if passed a single tensor. + ``` dropout = gvp.Dropout(drop_rate=0.1) layernorm = gvp.LayerNorm(out_dims) ``` -The function `gvp.randn` returns tuples `(s, V)` drawn from a standard normal. Such tuples can be directly used in a forward pass. + +The function `gvp.randn` returns tuples `(s, V)` drawn from a standard normal. Such tuples can be directly used in a +forward pass. + ``` x = gvp.randn(n=5, dims=in_dims) # x = (s, V) with s.shape = [5, scalars_in] and V.shape = [5, vectors_in, 3] @@ -75,7 +102,9 @@ out = gvp_(x) out = drouput(out) out = layernorm(out) ``` + Finally, we provide utility functions for adding, concatenating, and indexing into such tuples. + ``` y = gvp.randn(n=5, dims=in_dims) z = gvp.tuple_sum(x, y) @@ -85,8 +114,12 @@ z = gvp.tuple_cat(x, y, dim=-2) # concat along node / batch axis node_mask = torch.rand(5) < 0.5 z = gvp.tuple_index(x, node_mask) # select half the nodes / batch at random ``` + ### GVP-GNN layers -The class `GVPConv` is a `torch_geometric.MessagePassing` module which forms messages and aggregates them at the destination node, returning new node embeddings. The original embeddings are not updated. + +The class `GVPConv` is a `torch_geometric.MessagePassing` module which forms messages and aggregates them at the +destination node, returning new node embeddings. The original embeddings are not updated. + ``` nodes = gvp.randn(n=5, in_dims) edges = gvp.randn(n=10, edge_dims) # 10 random edges @@ -95,22 +128,31 @@ edge_index = torch.randint(0, 5, (2, 10), device=device) conv = gvp.GVPConv(in_dims, out_dims, edge_dims) out = conv(nodes, edge_index, edges) ``` -The class `GVPConvLayer` is a `nn.Module` that forms messages using a `GVPConv` and updates the node embeddings as described in the paper. Because the updates are residual, the dimensionality of the embeddings are not changed. + +The class `GVPConvLayer` is a `nn.Module` that forms messages using a `GVPConv` and updates the node embeddings as +described in the paper. Because the updates are residual, the dimensionality of the embeddings are not changed. + ``` layer = gvp.GVPConvLayer(node_dims, edge_dims) nodes = layer(nodes, edge_index, edges) ``` -The class also allows updates where incoming messages where src >= dst are computed using a different set of source embeddings, as in autoregressive models. + +The class also allows updates where incoming messages where src >= dst are computed using a different set of source +embeddings, as in autoregressive models. + ``` nodes_static = gvp.randn(n=5, in_dims) layer = gvp.GVPConvLayer(node_dims, edge_dims, autoregressive=True) nodes = layer(nodes, edge_index, edges, autoregressive_x=nodes_static) ``` + Both `GVPConv` and `GVPConvLayer` accept arguments `activations` and `vector_gate` to use vector gating. -### Loading data +### Loading data -The class `gvp.data.ProteinGraphDataset` transforms protein backbone structures into featurized graphs. Following [Ingraham, et al, NeurIPS 2019](https://github.com/jingraham/neurips19-graph-protein-design), we use a JSON/dictionary format to specify backbone structures: +The class `gvp.data.ProteinGraphDataset` transforms protein backbone structures into featurized graphs. +Following [Ingraham, et al, NeurIPS 2019](https://github.com/jingraham/neurips19-graph-protein-design), we use a +JSON/dictionary format to specify backbone structures: ``` [ @@ -122,7 +164,10 @@ The class `gvp.data.ProteinGraphDataset` transforms protein backbone structures ... ] ``` -For each structure, `coords` should be a `num_residues x 4 x 3` nested list of the positions of the backbone N, C-alpha, C, and O atoms of each residue (in that order). + +For each structure, `coords` should be a `num_residues x 4 x 3` nested list of the positions of the backbone N, C-alpha, +C, and O atoms of each residue (in that order). + ``` import gvp.data @@ -130,7 +175,9 @@ import gvp.data dataset = gvp.data.ProteinGraphDataset(structures) # dataset[i] is featurized graph corresponding to structures[i] ``` + The returned graphs are of type `torch_geometric.data.Data` with attributes + * `x`: alpha carbon coordinates * `seq`: sequence converted to int tensor according to attribute `self.letter_to_num` * `name`, `edge_index` @@ -138,13 +185,20 @@ The returned graphs are of type `torch_geometric.data.Data` with attributes * `edge_s`, `edge_v`: edge features as described in the paper with dims `(32, 1)` * `mask`: false for nodes with any nan coordinates -The `gvp.data.ProteinGraphDataset` can be used with a `torch.utils.data.DataLoader`. We supply a class `gvp.data.BatchSampler` which will form batches based on the number of total nodes in a batch. Use of this sampler is optional. +The `gvp.data.ProteinGraphDataset` can be used with a `torch.utils.data.DataLoader`. We supply a +class `gvp.data.BatchSampler` which will form batches based on the number of total nodes in a batch. Use of this sampler +is optional. + ``` node_counts = [len(s['seq']) for s in structures] sampler = gvp.data.BatchSampler(node_counts, max_nodes=3000) dataloader = torch.utils.data.DataLoader(dataset, batch_sampler=sampler) ``` -The dataloader will return batched graphs of type `torch_geometric.data.Batch` with an additional `batch` attibute. The attributes of the `Batch` will then need to be formed into `(s, V)` tuples before passing into a GVP-GNN layer or network. + +The dataloader will return batched graphs of type `torch_geometric.data.Batch` with an additional `batch` attibute. The +attributes of the `Batch` will then need to be formed into `(s, V)` tuples before passing into a GVP-GNN layer or +network. + ``` for batch in dataloader: batch = batch.to(device) # optional @@ -155,7 +209,12 @@ for batch in dataloader: ``` ### Ready-to-use protein GNNs -We provide two fully specified networks which take in protein graphs and output a scalar prediction for each graph (`gvp.models.MQAModel`) or a 20-dimensional feature vector for each node (`gvp.models.CPDModel`), corresponding to the two tasks in our paper. Note that if you are using the unmodified `gvp.data.ProteinGraphDataset`, `node_in_dims` and `edge_in_dims` must be `(6, 3)` and `(32, 1)`, respectively. + +We provide two fully specified networks which take in protein graphs and output a scalar prediction for each +graph (`gvp.models.MQAModel`) or a 20-dimensional feature vector for each node (`gvp.models.CPDModel`), corresponding to +the two tasks in our paper. Note that if you are using the unmodified `gvp.data.ProteinGraphDataset`, `node_in_dims` +and `edge_in_dims` must be `(6, 3)` and `(32, 1)`, respectively. + ``` import gvp.models @@ -173,13 +232,21 @@ out = cpd_model(nodes, batch.edge_index, ``` ## Protein design -We provide a script `run_cpd.py` to train, validate, and test a `CPDModel` as specified in the paper using the CATH 4.2 dataset and TS50 dataset. If you want to use a trained model on new structures, see the section "Sampling" below. + +We provide a script `run_cpd.py` to train, validate, and test a `CPDModel` as specified in the paper using the CATH 4.2 +dataset and TS50 dataset. If you want to use a trained model on new structures, see the section "Sampling" below. ### Fetching data -Run `getCATH.sh` in `data/` to fetch the CATH 4.2 dataset. If you are interested in testing on the TS 50 test set, also run `grep -Fv -f ts50remove.txt chain_set.jsonl > chain_set_ts50.jsonl` to produce a training set without overlap with the TS 50 test set. + +Run `getCATH.sh` in `data/` to fetch the CATH 4.2 dataset. If you are interested in testing on the TS 50 test set, also +run `grep -Fv -f ts50remove.txt chain_set.jsonl > chain_set_ts50.jsonl` to produce a training set without overlap with +the TS 50 test set. ### Training / testing -To train a model, simply run `python run_cpd.py --train`. To test a trained model on both the CATH 4.2 test set and the TS50 test set, run `python run_cpd --test-r PATH` for perplexity or with `--test-p` for perplexity. Run `python run_cpd.py -h` for more detailed options. + +To train a model, simply run `python run_cpd.py --train`. To test a trained model on both the CATH 4.2 test set and the +TS50 test set, run `python run_cpd --test-r PATH` for perplexity or with `--test-p` for perplexity. +Run `python run_cpd.py -h` for more detailed options. ``` $ python run_cpd.py -h @@ -200,10 +267,14 @@ optional arguments: --test-p PATH evaluate a trained model on perplexity (without training) --n-samples N number of sequences to sample (if testing recovery), default=100 ``` -**Confusion matrices:** Note that the values are normalized such that each row (corresponding to true class) sums to 1000, with the actual number of residues in that class printed under the "Count" column. + +**Confusion matrices:** Note that the values are normalized such that each row (corresponding to true class) sums to +1000, with the actual number of residues in that class printed under the "Count" column. ### Sampling -To sample from a `CPDModel`, prepare a `ProteinGraphDataset`, but do NOT pass into a `DataLoader`. The sequences are not used, so placeholders can be used for the `seq` attributes of the original structures dicts. + +To sample from a `CPDModel`, prepare a `ProteinGraphDataset`, but do NOT pass into a `DataLoader`. The sequences are not +used, so placeholders can be used for the `seq` attributes of the original structures dicts. ``` protein = dataset[i] @@ -213,31 +284,52 @@ edges = (protein.edge_s, protein.edge_v) sample = model.sample(nodes, protein.edge_index, # shape = (n_samples, n_nodes) edges, n_samples=n_samples) ``` + The output will be an int tensor, with mappings corresponding to those used when training the model. ## ATOM3D -We provide models and dataloaders for all ATOM3D tasks in `gvp.atom3d`, as well as a training and testing script in `run_atom3d.py`. This also supports loading pretrained weights for transfer learning experiments. + +We provide models and dataloaders for all ATOM3D tasks in `gvp.atom3d`, as well as a training and testing script +in `run_atom3d.py`. This also supports loading pretrained weights for transfer learning experiments. ### Models / data loaders -The GVP-GNNs for ATOM3D are supplied in `gvp.atom3d` and are named after each task: `gvp.atom3d.MSPModel`, `gvp.atom3d.PPIModel`, etc. All of these extend the base class `gvp.atom3d.BaseModel`. These classes take no arguments at initialization, take in a `torch_geometric.data.Batch` representation of a batch of structures, and return an output corresponding to the task. Details vary based on the exact task---see the docstrings. + +The GVP-GNNs for ATOM3D are supplied in `gvp.atom3d` and are named after each task: `gvp.atom3d.MSPModel` +, `gvp.atom3d.PPIModel`, etc. All of these extend the base class `gvp.atom3d.BaseModel`. These classes take no arguments +at initialization, take in a `torch_geometric.data.Batch` representation of a batch of structures, and return an output +corresponding to the task. Details vary based on the exact task---see the docstrings. + ``` psr_model = gvp.atom3d.PSRModel() ``` -`gvp.atom3d` also includes data loaders to produce `torch_geometric.data.Batch` objects from an underlying `atom3d.datasets.LMDBDataset`. In the case of all tasks except PPI and RES, these are in the form of callable transform objects---`gvp.atom3d.SMPTransform`, `gvp.atom3d.RSRTransform`, etc---which should be passed into the constructor of a `atom3d.datasets.LMDBDataset`: + +`gvp.atom3d` also includes data loaders to produce `torch_geometric.data.Batch` objects from an +underlying `atom3d.datasets.LMDBDataset`. In the case of all tasks except PPI and RES, these are in the form of callable +transform objects---`gvp.atom3d.SMPTransform`, `gvp.atom3d.RSRTransform`, etc---which should be passed into the +constructor of a `atom3d.datasets.LMDBDataset`: + ``` psr_dataset = atom3d.datasets.LMDBDataset(path_to_dataset, transform=gvp.atom3d.PSRTransform()) ``` -On the other hand, `gvp.atom3d.PPIDataset` and `gvp.atom3d.RESDataset` take the place of / are wrappers around the `atom3d.datasets.LMDBDataset`: + +On the other hand, `gvp.atom3d.PPIDataset` and `gvp.atom3d.RESDataset` take the place of / are wrappers around +the `atom3d.datasets.LMDBDataset`: + ``` ppi_dataset = gvp.atom3d.PPIDataset(path_to_dataset) res_dataset = gvp.atom3d.RESDataset(path_to_dataset, path_to_split) # see docstring ``` + All datasets must be then wrapped in a `torch_geometric.data.DataLoader`: + ``` psr_dataloader = torch_geometric.data.DataLoader(psr_dataset, batch_size=batch_size) ``` -The dataloaders can be directly iterated over to yield `torch_geometric.data.Batch` objects, which can then be passed into the models. + +The dataloaders can be directly iterated over to yield `torch_geometric.data.Batch` objects, which can then be passed +into the models. + ``` for batch in psr_dataloader: pred = psr_model(batch) # pred.shape = (batch_size,) @@ -245,7 +337,9 @@ for batch in psr_dataloader: ### Training / testing -To run training / testing on ATOM3D, download the datasets as described [here](https://www.atom3d.ai/). Modify the function `get_datasets` in `run_atom3d.py` with the paths to the datasets. Then run: +To run training / testing on ATOM3D, download the datasets as described [here](https://www.atom3d.ai/). Modify the +function `get_datasets` in `run_atom3d.py` with the paths to the datasets. Then run: + ``` $ python run_atom3d.py -h @@ -273,7 +367,9 @@ optional arguments: --lr RATE learning rate --load PATH initialize first 2 GNN layers with pretrained weights ``` + For example: + ``` # train a model python run_atom3d.py PSR @@ -286,9 +382,13 @@ python run_atom3d.py PSR --test PATH ``` ## Acknowledgements -Portions of the input data pipeline were adapted from [Ingraham, et al, NeurIPS 2019](https://github.com/jingraham/neurips19-graph-protein-design). We thank Pratham Soni for portions of the implementation in PyTorch. + +Portions of the input data pipeline were adapted +from [Ingraham, et al, NeurIPS 2019](https://github.com/jingraham/neurips19-graph-protein-design). We thank Pratham Soni +for portions of the implementation in PyTorch. ## Citation + ``` @inproceedings{ jing2021learning, diff --git a/gvp/__init__.py b/gvp/__init__.py index eaf8435..4e6c690 100644 --- a/gvp/__init__.py +++ b/gvp/__init__.py @@ -1,15 +1,18 @@ -import torch, functools -from torch import nn +import functools +import torch import torch.nn.functional as F +from torch import nn from torch_geometric.nn import MessagePassing from torch_scatter import scatter_add + def tuple_sum(*args): ''' Sums any number of tuples (s, V) elementwise. ''' return tuple(map(sum, zip(*args))) + def tuple_cat(*args, dim=-1): ''' Concatenates any number of tuples (s, V) elementwise. @@ -23,6 +26,7 @@ def tuple_cat(*args, dim=-1): s_args, v_args = list(zip(*args)) return torch.cat(s_args, dim=dim), torch.cat(v_args, dim=dim) + def tuple_index(x, idx): ''' Indexes into a tuple (s, V) along the first dimension. @@ -31,7 +35,8 @@ def tuple_index(x, idx): ''' return x[0][idx], x[1][idx] -def randn(n, dims, device="cpu"): + +def randn(n, dims, d=3, device="cpu"): ''' Returns random tuples (s, V) drawn elementwise from a normal distribution. @@ -42,7 +47,8 @@ def randn(n, dims, device="cpu"): V.shape = (n, n_vector, 3) ''' return torch.randn(n, dims[0], device=device), \ - torch.randn(n, dims[1], 3, device=device) + torch.randn(n, dims[1], d, device=device) + def _norm_no_nan(x, axis=-1, keepdims=False, eps=1e-8, sqrt=True): ''' @@ -53,7 +59,8 @@ def _norm_no_nan(x, axis=-1, keepdims=False, eps=1e-8, sqrt=True): out = torch.clamp(torch.sum(torch.square(x), axis, keepdims), min=eps) return torch.sqrt(out) if sqrt else out -def _split(x, nv): + +def _split(x, nv, vector_dim=3): ''' Splits a merged representation of (s, V) back into a tuple. Should be used only with `_merge(s, V)` and only if the tuple @@ -62,20 +69,22 @@ def _split(x, nv): :param x: the `torch.Tensor` returned from `_merge` :param nv: the number of vector channels in the input to `_merge` ''' - v = torch.reshape(x[..., -3*nv:], x.shape[:-1] + (nv, 3)) - s = x[..., :-3*nv] + v = torch.reshape(x[..., -vector_dim * nv:], x.shape[:-1] + (nv, vector_dim)) + s = x[..., :-vector_dim * nv] return s, v -def _merge(s, v): + +def _merge(s, v, vector_dim=3): ''' Merges a tuple (s, V) into a single `torch.Tensor`, where the vector channels are flattened and appended to the scalar channels. Should be used only if the tuple representation cannot be used. Use `_split(x, nv)` to reverse. ''' - v = torch.reshape(v, v.shape[:-2] + (3*v.shape[-2],)) + v = torch.reshape(v, v.shape[:-2] + (vector_dim * v.shape[-2],)) return torch.cat([s, v], -1) + class GVP(nn.Module): ''' Geometric Vector Perceptron. See manuscript and README.md @@ -88,14 +97,15 @@ class GVP(nn.Module): :param vector_gate: whether to use vector gating. (vector_act will be used as sigma^+ in vector gating if `True`) ''' + def __init__(self, in_dims, out_dims, h_dim=None, activations=(F.relu, torch.sigmoid), vector_gate=False): super(GVP, self).__init__() self.si, self.vi = in_dims self.so, self.vo = out_dims self.vector_gate = vector_gate - if self.vi: - self.h_dim = h_dim or max(self.vi, self.vo) + if self.vi: + self.h_dim = h_dim or max(self.vi, self.vo) self.wh = nn.Linear(self.vi, self.h_dim, bias=False) self.ws = nn.Linear(self.h_dim + self.si, self.so) if self.vo: @@ -103,10 +113,10 @@ def __init__(self, in_dims, out_dims, h_dim=None, if self.vector_gate: self.wsv = nn.Linear(self.so, self.vo) else: self.ws = nn.Linear(self.si, self.so) - + self.scalar_act, self.vector_act = activations self.dummy_param = nn.Parameter(torch.empty(0)) - + def forward(self, x): ''' :param x: tuple (s, V) of `torch.Tensor`, @@ -117,13 +127,13 @@ def forward(self, x): if self.vi: s, v = x v = torch.transpose(v, -1, -2) - vh = self.wh(v) + vh = self.wh(v) vn = _norm_no_nan(vh, axis=-2) s = self.ws(torch.cat([s, vn], -1)) - if self.vo: - v = self.wv(vh) + if self.vo: + v = self.wv(vh) v = torch.transpose(v, -1, -2) - if self.vector_gate: + if self.vector_gate: if self.vector_act: gate = self.wsv(self.vector_act(s)) else: @@ -139,14 +149,16 @@ def forward(self, x): device=self.dummy_param.device) if self.scalar_act: s = self.scalar_act(s) - + return (s, v) if self.vo else s + class _VDropout(nn.Module): ''' Vector channel dropout where the elements of each vector channel are dropped together. ''' + def __init__(self, drop_rate): super(_VDropout, self).__init__() self.drop_rate = drop_rate @@ -165,11 +177,13 @@ def forward(self, x): x = mask * x / (1 - self.drop_rate) return x + class Dropout(nn.Module): ''' Combined dropout for tuples (s, V). Takes tuples (s, V) as input and as output. ''' + def __init__(self, drop_rate): super(Dropout, self).__init__() self.sdropout = nn.Dropout(drop_rate) @@ -186,16 +200,18 @@ def forward(self, x): s, v = x return self.sdropout(s), self.vdropout(v) + class LayerNorm(nn.Module): ''' Combined LayerNorm for tuples (s, V). Takes tuples (s, V) as input and as output. ''' + def __init__(self, dims): super(LayerNorm, self).__init__() self.s, self.v = dims self.scalar_norm = nn.LayerNorm(self.s) - + def forward(self, x): ''' :param x: tuple (s, V) of `torch.Tensor`, @@ -209,6 +225,7 @@ def forward(self, x): vn = torch.sqrt(torch.mean(vn, dim=-2, keepdim=True)) return self.scalar_norm(s), v / vn + class GVPConv(MessagePassing): ''' Graph convolution / message passing with Geometric Vector Perceptrons. @@ -229,31 +246,33 @@ class GVPConv(MessagePassing): :param vector_gate: whether to use vector gating. (vector_act will be used as sigma^+ in vector gating if `True`) ''' - def __init__(self, in_dims, out_dims, edge_dims, - n_layers=3, module_list=None, aggr="mean", + + def __init__(self, in_dims, out_dims, edge_dims, vector_dim=3, + n_layers=3, module_list=None, aggr="mean", activations=(F.relu, torch.sigmoid), vector_gate=False): super(GVPConv, self).__init__(aggr=aggr) self.si, self.vi = in_dims self.so, self.vo = out_dims self.se, self.ve = edge_dims - - GVP_ = functools.partial(GVP, - activations=activations, vector_gate=vector_gate) - + self.vector_dim = vector_dim + + GVP_ = functools.partial(GVP, + activations=activations, vector_gate=vector_gate) + module_list = module_list or [] if not module_list: if n_layers == 1: module_list.append( - GVP_((2*self.si + self.se, 2*self.vi + self.ve), - (self.so, self.vo), activations=(None, None))) + GVP_((2 * self.si + self.se, 2 * self.vi + self.ve), + (self.so, self.vo), activations=(None, None))) else: module_list.append( - GVP_((2*self.si + self.se, 2*self.vi + self.ve), out_dims) + GVP_((2 * self.si + self.se, 2 * self.vi + self.ve), out_dims) ) for i in range(n_layers - 2): module_list.append(GVP_(out_dims, out_dims)) module_list.append(GVP_(out_dims, out_dims, - activations=(None, None))) + activations=(None, None))) self.message_func = nn.Sequential(*module_list) def forward(self, x, edge_index, edge_attr): @@ -263,17 +282,17 @@ def forward(self, x, edge_index, edge_attr): :param edge_attr: tuple (s, V) of `torch.Tensor` ''' x_s, x_v = x - message = self.propagate(edge_index, - s=x_s, v=x_v.reshape(x_v.shape[0], 3*x_v.shape[1]), - edge_attr=edge_attr) - return _split(message, self.vo) + message = self.propagate(edge_index, + s=x_s, v=x_v.reshape(x_v.shape[0], x_v.shape[1] * x_v.shape[2]), + edge_attr=edge_attr) + return _split(message, self.vo, vector_dim=self.vector_dim) def message(self, s_i, v_i, s_j, v_j, edge_attr): - v_j = v_j.view(v_j.shape[0], v_j.shape[1]//3, 3) - v_i = v_i.view(v_i.shape[0], v_i.shape[1]//3, 3) + v_j = v_j.view(v_j.shape[0], v_j.shape[1] // self.vector_dim, self.vector_dim) + v_i = v_i.view(v_i.shape[0], v_i.shape[1] // self.vector_dim, self.vector_dim) message = tuple_cat((s_j, v_j), edge_attr, (s_i, v_i)) message = self.message_func(message) - return _merge(*message) + return _merge(*message, vector_dim=self.vector_dim) class GVPConvLayer(nn.Module): @@ -297,17 +316,18 @@ class GVPConvLayer(nn.Module): :param vector_gate: whether to use vector gating. (vector_act will be used as sigma^+ in vector gating if `True`) ''' - def __init__(self, node_dims, edge_dims, + + def __init__(self, node_dims, edge_dims, vector_dim=3, n_message=3, n_feedforward=2, drop_rate=.1, - autoregressive=False, + autoregressive=False, activations=(F.relu, torch.sigmoid), vector_gate=False): - + super(GVPConvLayer, self).__init__() - self.conv = GVPConv(node_dims, node_dims, edge_dims, n_message, - aggr="add" if autoregressive else "mean", - activations=activations, vector_gate=vector_gate) - GVP_ = functools.partial(GVP, - activations=activations, vector_gate=vector_gate) + self.conv = GVPConv(node_dims, node_dims, edge_dims, vector_dim, n_message, + aggr="add" if autoregressive else "mean", + activations=activations, vector_gate=vector_gate) + GVP_ = functools.partial(GVP, + activations=activations, vector_gate=vector_gate) self.norm = nn.ModuleList([LayerNorm(node_dims) for _ in range(2)]) self.dropout = nn.ModuleList([Dropout(drop_rate) for _ in range(2)]) @@ -315,9 +335,9 @@ def __init__(self, node_dims, edge_dims, if n_feedforward == 1: ff_func.append(GVP_(node_dims, node_dims, activations=(None, None))) else: - hid_dims = 4*node_dims[0], 2*node_dims[1] + hid_dims = 4 * node_dims[0], 2 * node_dims[1] ff_func.append(GVP_(node_dims, hid_dims)) - for i in range(n_feedforward-2): + for i in range(n_feedforward - 2): ff_func.append(GVP_(hid_dims, hid_dims)) ff_func.append(GVP_(hid_dims, node_dims, activations=(None, None))) self.ff_func = nn.Sequential(*ff_func) @@ -337,7 +357,7 @@ def forward(self, x, edge_index, edge_attr, dim of node embeddings (s, V). If not `None`, only these nodes will be updated. ''' - + if autoregressive_x is not None: src, dst = edge_index mask = src < dst @@ -345,29 +365,29 @@ def forward(self, x, edge_index, edge_attr, edge_index_backward = edge_index[:, ~mask] edge_attr_forward = tuple_index(edge_attr, mask) edge_attr_backward = tuple_index(edge_attr, ~mask) - + dh = tuple_sum( self.conv(x, edge_index_forward, edge_attr_forward), self.conv(autoregressive_x, edge_index_backward, edge_attr_backward) ) - + count = scatter_add(torch.ones_like(dst), dst, - dim_size=dh[0].size(0)).clamp(min=1).unsqueeze(-1) - + dim_size=dh[0].size(0)).clamp(min=1).unsqueeze(-1) + dh = dh[0] / count, dh[1] / count.unsqueeze(-1) else: dh = self.conv(x, edge_index, edge_attr) - + if node_mask is not None: x_ = x x, dh = tuple_index(x, node_mask), tuple_index(dh, node_mask) - + x = self.norm[0](tuple_sum(x, self.dropout[0](dh))) - + dh = self.ff_func(x) x = self.norm[1](tuple_sum(x, self.dropout[1](dh))) - + if node_mask is not None: x_[0][node_mask], x_[1][node_mask] = x[0], x[1] x = x_ diff --git a/gvp/atom3d.py b/gvp/atom3d.py index e37e346..650dcd6 100644 --- a/gvp/atom3d.py +++ b/gvp/atom3d.py @@ -1,25 +1,31 @@ -import torch, random, scipy, math +import atom3d.datasets.ppi.neighbors as nb +import math +import numpy as np +import pandas as pd +import random +import scipy +import torch import torch.nn as nn import torch.nn.functional as F -import pandas as pd -import numpy as np +import torch_cluster +import torch_geometric +import torch_scatter from atom3d.datasets import LMDBDataset -import atom3d.datasets.ppi.neighbors as nb from torch.utils.data import IterableDataset + from . import GVP, GVPConvLayer, LayerNorm -import torch_cluster, torch_geometric, torch_scatter from .data import _normalize, _rbf _NUM_ATOM_TYPES = 9 _element_mapping = lambda x: { - 'H' : 0, - 'C' : 1, - 'N' : 2, - 'O' : 3, - 'F' : 4, - 'S' : 5, + 'H': 0, + 'C': 1, + 'N': 2, + 'O': 3, + 'F': 4, + 'S': 5, 'Cl': 6, 'CL': 6, - 'P' : 7 + 'P': 7 }.get(x, 8) _amino_acids = lambda x: { 'ALA': 0, @@ -46,20 +52,21 @@ _DEFAULT_V_DIM = (100, 16) _DEFAULT_E_DIM = (32, 1) + def _edge_features(coords, edge_index, D_max=4.5, num_rbf=16, device='cpu'): - E_vectors = coords[edge_index[0]] - coords[edge_index[1]] - rbf = _rbf(E_vectors.norm(dim=-1), + rbf = _rbf(E_vectors.norm(dim=-1), D_max=D_max, D_count=num_rbf, device=device) edge_s = rbf edge_v = _normalize(E_vectors).unsqueeze(-2) edge_s, edge_v = map(torch.nan_to_num, - (edge_s, edge_v)) + (edge_s, edge_v)) return edge_s, edge_v + class BaseTransform: ''' Implementation of an ATOM3D Transform which featurizes the atomic @@ -86,11 +93,12 @@ class BaseTransform: :param num_rbf: number of radial bases to encode the distance on each edge :device: if "cuda", will do preprocessing on the GPU ''' + def __init__(self, edge_cutoff=4.5, num_rbf=16, device='cpu'): self.edge_cutoff = edge_cutoff self.num_rbf = num_rbf self.device = device - + def __call__(self, df): ''' :param df: `pandas.DataFrame` of atomic coordinates @@ -102,15 +110,16 @@ def __call__(self, df): coords = torch.as_tensor(df[['x', 'y', 'z']].to_numpy(), dtype=torch.float32, device=self.device) atoms = torch.as_tensor(list(map(_element_mapping, df.element)), - dtype=torch.long, device=self.device) + dtype=torch.long, device=self.device) edge_index = torch_cluster.radius_graph(coords, r=self.edge_cutoff) - edge_s, edge_v = _edge_features(coords, edge_index, - D_max=self.edge_cutoff, num_rbf=self.num_rbf, device=self.device) + edge_s, edge_v = _edge_features(coords, edge_index, + D_max=self.edge_cutoff, num_rbf=self.num_rbf, device=self.device) return torch_geometric.data.Data(x=coords, atoms=atoms, - edge_index=edge_index, edge_s=edge_s, edge_v=edge_v) + edge_index=edge_index, edge_s=edge_s, edge_v=edge_v) + class BaseModel(nn.Module): ''' @@ -125,43 +134,44 @@ class BaseModel(nn.Module): :param num_rbf: number of radial bases to use in the edge embedding ''' + def __init__(self, num_rbf=16): - + super().__init__() activations = (F.relu, None) - + self.embed = nn.Embedding(_NUM_ATOM_TYPES, _NUM_ATOM_TYPES) - + self.W_e = nn.Sequential( LayerNorm((num_rbf, 1)), - GVP((num_rbf, 1), _DEFAULT_E_DIM, + GVP((num_rbf, 1), _DEFAULT_E_DIM, activations=(None, None), vector_gate=True) ) - + self.W_v = nn.Sequential( LayerNorm((_NUM_ATOM_TYPES, 0)), GVP((_NUM_ATOM_TYPES, 0), _DEFAULT_V_DIM, activations=(None, None), vector_gate=True) ) - + self.layers = nn.ModuleList( - GVPConvLayer(_DEFAULT_V_DIM, _DEFAULT_E_DIM, - activations=activations, vector_gate=True) + GVPConvLayer(_DEFAULT_V_DIM, _DEFAULT_E_DIM, + activations=activations, vector_gate=True) for _ in range(5)) - + ns, _ = _DEFAULT_V_DIM self.W_out = nn.Sequential( LayerNorm(_DEFAULT_V_DIM), - GVP(_DEFAULT_V_DIM, (ns, 0), + GVP(_DEFAULT_V_DIM, (ns, 0), activations=activations, vector_gate=True) ) - + self.dense = nn.Sequential( - nn.Linear(ns, 2*ns), nn.ReLU(inplace=True), + nn.Linear(ns, 2 * ns), nn.ReLU(inplace=True), nn.Dropout(p=0.1), - nn.Linear(2*ns, 1) + nn.Linear(2 * ns, 1) ) - + def forward(self, batch, scatter_mean=True, dense=True): ''' Forward pass which can be adjusted based on task formulation. @@ -177,9 +187,9 @@ def forward(self, batch, scatter_mean=True, dense=True): h_E = (batch.edge_s, batch.edge_v) h_V = self.W_v(h_V) h_E = self.W_e(h_E) - + batch_id = batch.batch - + for layer in self.layers: h_V = layer(h_V, batch.edge_index, h_E) @@ -188,6 +198,7 @@ def forward(self, batch, scatter_mean=True, dense=True): if dense: out = self.dense(out).squeeze(-1) return out + ######################################################################## class SMPTransform(BaseTransform): @@ -199,15 +210,18 @@ class SMPTransform(BaseTransform): Includes hydrogen atoms. ''' + def __call__(self, elem): data = super().__call__(elem['atoms']) with torch.no_grad(): - data.label = torch.as_tensor(elem['labels'], - device=self.device, dtype=torch.float32) + data.label = torch.as_tensor(elem['labels'], + device=self.device, dtype=torch.float32) return data - + + SMPModel = BaseModel - + + ######################################################################## class PPIDataset(IterableDataset): @@ -231,15 +245,16 @@ class PPIDataset(IterableDataset): :param lmdb_dataset: path to ATOM3D dataset ''' + def __init__(self, lmdb_dataset): self.dataset = LMDBDataset(lmdb_dataset) self.transform = BaseTransform() - + def __iter__(self): worker_info = torch.utils.data.get_worker_info() if worker_info is None: gen = self._dataset_generator(list(range(len(self.dataset))), shuffle=True) - else: + else: per_worker = int(math.ceil(len(self.dataset) / float(worker_info.num_workers))) worker_id = worker_info.id iter_start = worker_id * per_worker @@ -250,26 +265,26 @@ def __iter__(self): return gen def _df_to_graph(self, struct_df, chain_res, label): - + struct_df = struct_df[struct_df.element != 'H'].reset_index(drop=True) chain, resnum = chain_res res_df = struct_df[(struct_df.chain == chain) & (struct_df.residue == resnum)] if 'CA' not in res_df.name.tolist(): return None - ca_pos = res_df[res_df['name']=='CA'][['x', 'y', 'z']].astype(np.float32).to_numpy()[0] + ca_pos = res_df[res_df['name'] == 'CA'][['x', 'y', 'z']].astype(np.float32).to_numpy()[0] - kd_tree = scipy.spatial.KDTree(struct_df[['x','y','z']].to_numpy()) + kd_tree = scipy.spatial.KDTree(struct_df[['x', 'y', 'z']].to_numpy()) graph_pt_idx = kd_tree.query_ball_point(ca_pos, r=30.0, p=2.0) graph_df = struct_df.iloc[graph_pt_idx].reset_index(drop=True) - + ca_idx = np.where((graph_df.chain == chain) & (graph_df.residue == resnum) & (graph_df.name == 'CA'))[0] if len(ca_idx) != 1: return None - + data = self.transform(graph_df) data.label = label - + data.ca_idx = int(ca_idx) data.n_nodes = data.num_nodes @@ -283,16 +298,16 @@ def _dataset_generator(self, indices, shuffle=True): neighbors = data['atoms_neighbors'] pairs = data['atoms_pairs'] - + for i, (ensemble_name, target_df) in enumerate(pairs.groupby(['ensemble'])): sub_names, (bound1, bound2, _, _) = nb.get_subunits(target_df) positives = neighbors[neighbors.ensemble0 == ensemble_name] negatives = nb.get_negatives(positives, bound1, bound2) negatives['label'] = 0 labels = self._create_labels(positives, negatives, num_pos=10, neg_pos_ratio=1) - + for index, row in labels.iterrows(): - + label = float(row['label']) chain_res1 = row[['chain0', 'residue0']].values chain_res2 = row[['chain1', 'residue1']].values @@ -311,6 +326,7 @@ def _create_labels(self, positives, negatives, num_pos, neg_pos_ratio): labels = pd.concat([positives, negatives])[['chain0', 'residue0', 'chain1', 'residue1', 'label']] return labels + class PPIModel(BaseModel): ''' GVP-GNN for the PPI task. @@ -326,26 +342,26 @@ class PPIModel(BaseModel): Returns a single scalar for each graph pair which can be used as a logit in binary classification. ''' + def __init__(self, **kwargs): - super().__init__(**kwargs) ns, _ = _DEFAULT_V_DIM self.dense = nn.Sequential( - nn.Linear(2*ns, 4*ns), nn.ReLU(inplace=True), + nn.Linear(2 * ns, 4 * ns), nn.ReLU(inplace=True), nn.Dropout(p=0.1), - nn.Linear(4*ns, 1) + nn.Linear(4 * ns, 1) ) - def forward(self, batch): + def forward(self, batch): graph1, graph2 = batch out1, out2 = map(self._gnn_forward, (graph1, graph2)) out = torch.cat([out1, out2], dim=-1) out = self.dense(out) return torch.sigmoid(out).squeeze(-1) - + def _gnn_forward(self, graph): out = super().forward(graph, scatter_mean=False, dense=False) - return out[graph.ca_idx+graph.ptr[:-1]] + return out[graph.ca_idx + graph.ptr[:-1]] ######################################################################## @@ -362,10 +378,11 @@ class LBATransform(BaseTransform): Includes hydrogen atoms. ''' + def __call__(self, elem): pocket, ligand = elem['atoms_pocket'], elem['atoms_ligand'] df = pd.concat([pocket, ligand], ignore_index=True) - + data = super().__call__(df) with torch.no_grad(): data.label = elem['scores']['neglog_aff'] @@ -374,10 +391,12 @@ def __call__(self, elem): data.lig_flag = lig_flag return data + LBAModel = BaseModel - + + ######################################################################## - + class LEPTransform(BaseTransform): ''' Transforms dict-style entries from the ATOM3D LEP dataset @@ -392,16 +411,18 @@ class LEPTransform(BaseTransform): Excludes hydrogen atoms. ''' + def __call__(self, elem): active, inactive = elem['atoms_active'], elem['atoms_inactive'] with torch.no_grad(): active, inactive = map(self._to_graph, (active, inactive)) active.label = inactive.label = 1. if elem['label'] == 'A' else 0. return active, inactive - + def _to_graph(self, df): df = df[df.element != 'H'].reset_index(drop=True) - return super().__call__(df) + return super().__call__(df) + class LEPModel(BaseModel): ''' @@ -415,27 +436,29 @@ class LEPModel(BaseModel): Returns a single scalar for each graph pair which can be used as a logit in binary classification. ''' + def __init__(self, **kwargs): super().__init__(**kwargs) ns, _ = _DEFAULT_V_DIM self.dense = nn.Sequential( - nn.Linear(2*ns, 4*ns), nn.ReLU(inplace=True), + nn.Linear(2 * ns, 4 * ns), nn.ReLU(inplace=True), nn.Dropout(p=0.1), - nn.Linear(4*ns, 1) + nn.Linear(4 * ns, 1) ) - - def forward(self, batch): + + def forward(self, batch): out1, out2 = map(self._gnn_forward, batch) out = torch.cat([out1, out2], dim=-1) out = self.dense(out) return torch.sigmoid(out).squeeze(-1) - + def _gnn_forward(self, graph): return super().forward(graph, dense=False) - + + ######################################################################## -class MSPTransform(BaseTransform): +class MSPTransform(BaseTransform): ''' Transforms dict-style entries from the ATOM3D MSP dataset to featurized graphs. Returns a tuple (original, mutated) of @@ -452,6 +475,7 @@ class MSPTransform(BaseTransform): Excludes hydrogen atoms. ''' + def __call__(self, elem): mutation = elem['id'].split('_')[-1] orig_df = elem['original_atoms'].reset_index(drop=True) @@ -461,21 +485,21 @@ def __call__(self, elem): self._transform(mut_df, mutation) original.label = mutated.label = 1. if elem['label'] == '1' else 0. return original, mutated - + def _transform(self, df, mutation): - df = df[df.element != 'H'].reset_index(drop=True) data = super().__call__(df) data.node_mask = self._extract_node_mask(df, mutation) return data - + def _extract_node_mask(self, df, mutation): chain, res = mutation[1], int(mutation[2:-1]) idx = df.index[(df.chain.values == chain) & (df.residue.values == res)].values mask = torch.zeros(len(df), dtype=torch.long, device=self.device) mask[idx] = 1 return mask - + + class MSPModel(BaseModel): ''' GVP-GNN for the MSP task. @@ -491,31 +515,33 @@ class MSPModel(BaseModel): Returns a single scalar for each graph pair which can be used as a logit in binary classification. ''' + def __init__(self, **kwargs): super().__init__(**kwargs) ns, _ = _DEFAULT_V_DIM self.dense = nn.Sequential( - nn.Linear(2*ns, 4*ns), nn.ReLU(inplace=True), + nn.Linear(2 * ns, 4 * ns), nn.ReLU(inplace=True), nn.Dropout(p=0.1), - nn.Linear(4*ns, 1) + nn.Linear(4 * ns, 1) ) - - def forward(self, batch): + + def forward(self, batch): out1, out2 = map(self._gnn_forward, batch) out = torch.cat([out1, out2], dim=-1) out = self.dense(out) return torch.sigmoid(out).squeeze(-1) - + def _gnn_forward(self, graph): out = super().forward(graph, scatter_mean=False, dense=False) out = out * graph.node_mask.unsqueeze(-1) out = torch_scatter.scatter_add(out, graph.batch, dim=0) count = torch_scatter.scatter_add(graph.node_mask, graph.batch) return out / count.unsqueeze(-1) - + + ######################################################################## - -class PSRTransform(BaseTransform): + +class PSRTransform(BaseTransform): ''' Transforms dict-style entries from the ATOM3D PSR dataset to featurized graphs. Returns a `torch_geometric.data.Data` @@ -525,6 +551,7 @@ class PSRTransform(BaseTransform): Includes hydrogen atoms. ''' + def __call__(self, elem): df = elem['atoms'] df = df[df.element != 'H'].reset_index(drop=True) @@ -533,11 +560,13 @@ def __call__(self, elem): data.id = eval(elem['id'])[0] return data + PSRModel = BaseModel + ######################################################################## - -class RSRTransform(BaseTransform): + +class RSRTransform(BaseTransform): ''' Transforms dict-style entries from the ATOM3D RSR dataset to featurized graphs. Returns a `torch_geometric.data.Data` @@ -547,6 +576,7 @@ class RSRTransform(BaseTransform): Includes hydrogen atoms. ''' + def __call__(self, elem): df = elem['atoms'] df = df[df.element != 'H'].reset_index(drop=True) @@ -555,8 +585,10 @@ def __call__(self, elem): data.id = eval(elem['id'])[0] return data + RSRModel = BaseModel + ######################################################################## class RESDataset(IterableDataset): @@ -574,25 +606,26 @@ class RESDataset(IterableDataset): :param lmdb_dataset: path to ATOM3D dataset :param split_path: path to the ATOM3D split file ''' + def __init__(self, lmdb_dataset, split_path): self.dataset = LMDBDataset(lmdb_dataset) self.idx = list(map(int, open(split_path).read().split())) self.transform = BaseTransform() - + def __iter__(self): worker_info = torch.utils.data.get_worker_info() if worker_info is None: - gen = self._dataset_generator(list(range(len(self.idx))), - shuffle=True) - else: + gen = self._dataset_generator(list(range(len(self.idx))), + shuffle=True) + else: per_worker = int(math.ceil(len(self.idx) / float(worker_info.num_workers))) worker_id = worker_info.id iter_start = worker_id * per_worker iter_end = min(iter_start + per_worker, len(self.idx)) gen = self._dataset_generator(list(range(len(self.idx)))[iter_start:iter_end], - shuffle=True) + shuffle=True) return gen - + def _dataset_generator(self, indices, shuffle=True): if shuffle: random.shuffle(indices) with torch.no_grad(): @@ -606,13 +639,14 @@ def _dataset_generator(self, indices, shuffle=True): my_atoms = atoms.iloc[data['subunit_indices'][sub.Index]].reset_index(drop=True) ca_idx = np.where((my_atoms.residue == num) & (my_atoms.name == 'CA'))[0] if len(ca_idx) != 1: continue - + with torch.no_grad(): graph = self.transform(my_atoms) graph.label = aa graph.ca_idx = int(ca_idx) yield graph - + + class RESModel(BaseModel): ''' GVP-GNN for the RES task. @@ -624,14 +658,16 @@ class RESModel(BaseModel): As noted in the manuscript, RESModel uses the final alpha carbon embeddings instead of the graph mean embedding. ''' + def __init__(self, **kwargs): super().__init__(**kwargs) ns, _ = _DEFAULT_V_DIM self.dense = nn.Sequential( - nn.Linear(ns, 2*ns), nn.ReLU(inplace=True), + nn.Linear(ns, 2 * ns), nn.ReLU(inplace=True), nn.Dropout(p=0.1), - nn.Linear(2*ns, 20) + nn.Linear(2 * ns, 20) ) + def forward(self, batch): out = super().forward(batch, scatter_mean=False) - return out[batch.ca_idx+batch.ptr[:-1]] \ No newline at end of file + return out[batch.ca_idx + batch.ptr[:-1]] diff --git a/gvp/data.py b/gvp/data.py index c115a68..cef8333 100644 --- a/gvp/data.py +++ b/gvp/data.py @@ -1,11 +1,15 @@ import json + +import math import numpy as np -import tqdm, random -import torch, math -import torch.utils.data as data +import random +import torch import torch.nn.functional as F -import torch_geometric +import torch.utils.data as data import torch_cluster +import torch_geometric +import tqdm + def _normalize(tensor, dim=-1): ''' @@ -43,33 +47,35 @@ class CATHDataset: :param path: path to chain_set.jsonl :param splits_path: path to chain_set_splits.json or equivalent. ''' + def __init__(self, path, splits_path): with open(splits_path) as f: dataset_splits = json.load(f) train_list, val_list, test_list = dataset_splits['train'], \ - dataset_splits['validation'], dataset_splits['test'] - + dataset_splits['validation'], dataset_splits['test'] + self.train, self.val, self.test = [], [], [] - + with open(path) as f: lines = f.readlines() - + for line in tqdm.tqdm(lines): entry = json.loads(line) name = entry['name'] coords = entry['coords'] - + entry['coords'] = list(zip( coords['N'], coords['CA'], coords['C'], coords['O'] )) - + if name in train_list: self.train.append(entry) elif name in val_list: self.val.append(entry) elif name in test_list: self.test.append(entry) - + + class BatchSampler(data.Sampler): ''' From https://github.com/jingraham/neurips19-graph-protein-design. @@ -82,15 +88,16 @@ class BatchSampler(data.Sampler): including batches of a single element :param shuffle: if `True`, batches in shuffled order ''' + def __init__(self, node_counts, max_nodes=3000, shuffle=True): - + self.node_counts = node_counts - self.idx = [i for i in range(len(node_counts)) - if node_counts[i] <= max_nodes] + self.idx = [i for i in range(len(node_counts)) + if node_counts[i] <= max_nodes] self.shuffle = shuffle self.max_nodes = max_nodes self._form_batches() - + def _form_batches(self): self.batches = [] if self.shuffle: random.shuffle(self.idx) @@ -103,15 +110,16 @@ def _form_batches(self): n_nodes += self.node_counts[next_idx] batch.append(next_idx) self.batches.append(batch) - - def __len__(self): + + def __len__(self): if not self.batches: self._form_batches() return len(self.batches) - + def __iter__(self): if not self.batches: self._form_batches() for batch in self.batches: yield batch + class ProteinGraphDataset(data.Dataset): ''' A map-syle `torch.utils.data.Dataset` which transforms JSON/dictionary-style @@ -136,69 +144,69 @@ class ProteinGraphDataset(data.Dataset): :param top_k: number of edges to draw per node (as destination node) :param device: if "cuda", will do preprocessing on the GPU ''' - def __init__(self, data_list, + + def __init__(self, data_list, num_positional_embeddings=16, top_k=30, num_rbf=16, device="cpu"): - super(ProteinGraphDataset, self).__init__() - + self.data_list = data_list self.top_k = top_k self.num_rbf = num_rbf self.num_positional_embeddings = num_positional_embeddings self.device = device self.node_counts = [len(e['seq']) for e in data_list] - + self.letter_to_num = {'C': 4, 'D': 3, 'S': 15, 'Q': 5, 'K': 11, 'I': 9, - 'P': 14, 'T': 16, 'F': 13, 'A': 0, 'G': 7, 'H': 8, - 'E': 6, 'L': 10, 'R': 1, 'W': 17, 'V': 19, - 'N': 2, 'Y': 18, 'M': 12} - self.num_to_letter = {v:k for k, v in self.letter_to_num.items()} - + 'P': 14, 'T': 16, 'F': 13, 'A': 0, 'G': 7, 'H': 8, + 'E': 6, 'L': 10, 'R': 1, 'W': 17, 'V': 19, + 'N': 2, 'Y': 18, 'M': 12} + self.num_to_letter = {v: k for k, v in self.letter_to_num.items()} + def __len__(self): return len(self.data_list) - + def __getitem__(self, i): return self._featurize_as_graph(self.data_list[i]) - + def _featurize_as_graph(self, protein): name = protein['name'] with torch.no_grad(): - coords = torch.as_tensor(protein['coords'], - device=self.device, dtype=torch.float32) + coords = torch.as_tensor(protein['coords'], + device=self.device, dtype=torch.float32) seq = torch.as_tensor([self.letter_to_num[a] for a in protein['seq']], device=self.device, dtype=torch.long) - - mask = torch.isfinite(coords.sum(dim=(1,2))) + + mask = torch.isfinite(coords.sum(dim=(1, 2))) coords[~mask] = np.inf - + X_ca = coords[:, 1] edge_index = torch_cluster.knn_graph(X_ca, k=self.top_k) - + pos_embeddings = self._positional_embeddings(edge_index) E_vectors = X_ca[edge_index[0]] - X_ca[edge_index[1]] rbf = _rbf(E_vectors.norm(dim=-1), D_count=self.num_rbf, device=self.device) - - dihedrals = self._dihedrals(coords) + + dihedrals = self._dihedrals(coords) orientations = self._orientations(X_ca) sidechains = self._sidechains(coords) - + node_s = dihedrals node_v = torch.cat([orientations, sidechains.unsqueeze(-2)], dim=-2) edge_s = torch.cat([rbf, pos_embeddings], dim=-1) edge_v = _normalize(E_vectors).unsqueeze(-2) - + node_s, node_v, edge_s, edge_v = map(torch.nan_to_num, - (node_s, node_v, edge_s, edge_v)) - + (node_s, node_v, edge_s, edge_v)) + data = torch_geometric.data.Data(x=X_ca, seq=seq, name=name, node_s=node_s, node_v=node_v, edge_s=edge_s, edge_v=edge_v, edge_index=edge_index, mask=mask) return data - + def _dihedrals(self, X, eps=1e-7): # From https://github.com/jingraham/neurips19-graph-protein-design - - X = torch.reshape(X[:, :3], [3*X.shape[0], 3]) + + X = torch.reshape(X[:, :3], [3 * X.shape[0], 3]) dX = X[1:] - X[:-1] U = _normalize(dX, dim=-1) u_2 = U[:-2] @@ -215,20 +223,19 @@ def _dihedrals(self, X, eps=1e-7): D = torch.sign(torch.sum(u_2 * n_1, -1)) * torch.acos(cosD) # This scheme will remove phi[0], psi[-1], omega[-1] - D = F.pad(D, [1, 2]) + D = F.pad(D, [1, 2]) D = torch.reshape(D, [-1, 3]) # Lift angle representations to the circle D_features = torch.cat([torch.cos(D), torch.sin(D)], 1) return D_features - - - def _positional_embeddings(self, edge_index, + + def _positional_embeddings(self, edge_index, num_embeddings=None, period_range=[2, 1000]): # From https://github.com/jingraham/neurips19-graph-protein-design num_embeddings = num_embeddings or self.num_positional_embeddings d = edge_index[0] - edge_index[1] - + frequency = torch.exp( torch.arange(0, num_embeddings, 2, dtype=torch.float32, device=self.device) * -(np.log(10000.0) / num_embeddings) @@ -250,4 +257,4 @@ def _sidechains(self, X): bisector = _normalize(c + n) perp = _normalize(torch.cross(c, n)) vec = -bisector * math.sqrt(1 / 3) - perp * math.sqrt(2 / 3) - return vec \ No newline at end of file + return vec diff --git a/gvp/models.py b/gvp/models.py index 54433e7..932e235 100644 --- a/gvp/models.py +++ b/gvp/models.py @@ -1,10 +1,11 @@ -import numpy as np import torch import torch.nn as nn -from . import GVP, GVPConvLayer, LayerNorm, tuple_index from torch.distributions import Categorical from torch_scatter import scatter_mean +from . import GVP, GVPConvLayer, LayerNorm, tuple_index + + class CPDModel(torch.nn.Module): ''' GVP-GNN for structure-conditioned autoregressive @@ -33,12 +34,13 @@ class CPDModel(torch.nn.Module): and decoder modules :param drop_rate: rate to use in all dropout layers ''' - def __init__(self, node_in_dim, node_h_dim, - edge_in_dim, edge_h_dim, + + def __init__(self, node_in_dim, node_h_dim, + edge_in_dim, edge_h_dim, vector_dim=3, num_layers=3, drop_rate=0.1): - + super(CPDModel, self).__init__() - + self.W_v = nn.Sequential( GVP(node_in_dim, node_h_dim, activations=(None, None)), LayerNorm(node_h_dim) @@ -47,19 +49,19 @@ def __init__(self, node_in_dim, node_h_dim, GVP(edge_in_dim, edge_h_dim, activations=(None, None)), LayerNorm(edge_h_dim) ) - + self.encoder_layers = nn.ModuleList( - GVPConvLayer(node_h_dim, edge_h_dim, drop_rate=drop_rate) + GVPConvLayer(node_h_dim, edge_h_dim, vector_dim=vector_dim, drop_rate=drop_rate) for _ in range(num_layers)) - + self.W_s = nn.Embedding(20, 20) edge_h_dim = (edge_h_dim[0] + 20, edge_h_dim[1]) - + self.decoder_layers = nn.ModuleList( - GVPConvLayer(node_h_dim, edge_h_dim, - drop_rate=drop_rate, autoregressive=True) + GVPConvLayer(node_h_dim, edge_h_dim, vector_dim=vector_dim, + drop_rate=drop_rate, autoregressive=True) for _ in range(num_layers)) - + self.W_out = GVP(node_h_dim, (20, 0), activations=(None, None)) def forward(self, h_V, edge_index, h_E, seq): @@ -73,24 +75,24 @@ def forward(self, h_V, edge_index, h_E, seq): ''' h_V = self.W_v(h_V) h_E = self.W_e(h_E) - + for layer in self.encoder_layers: h_V = layer(h_V, edge_index, h_E) - + encoder_embeddings = h_V - + h_S = self.W_s(seq) h_S = h_S[edge_index[0]] h_S[edge_index[0] >= edge_index[1]] = 0 h_E = (torch.cat([h_E[0], h_S], dim=-1), h_E[1]) - + for layer in self.decoder_layers: - h_V = layer(h_V, edge_index, h_E, autoregressive_x = encoder_embeddings) - + h_V = layer(h_V, edge_index, h_E, autoregressive_x=encoder_embeddings) + logits = self.W_out(h_V) - + return logits - + def sample(self, h_V, edge_index, h_E, n_samples, temperature=0.1): ''' Samples sequences autoregressively from the distribution @@ -106,61 +108,62 @@ def sample(self, h_V, edge_index, h_E, n_samples, temperature=0.1): :return: int `torch.Tensor` of shape [n_samples, n_nodes] based on the residue-to-int mapping of the original training data ''' - + with torch.no_grad(): - + device = edge_index.device L = h_V[0].shape[0] - + h_V = self.W_v(h_V) h_E = self.W_e(h_E) - + for layer in self.encoder_layers: - h_V = layer(h_V, edge_index, h_E) - + h_V = layer(h_V, edge_index, h_E) + h_V = (h_V[0].repeat(n_samples, 1), h_V[1].repeat(n_samples, 1, 1)) - + h_E = (h_E[0].repeat(n_samples, 1), h_E[1].repeat(n_samples, 1, 1)) - + edge_index = edge_index.expand(n_samples, -1, -1) offset = L * torch.arange(n_samples, device=device).view(-1, 1, 1) edge_index = torch.cat(tuple(edge_index + offset), dim=-1) - + seq = torch.zeros(n_samples * L, device=device, dtype=torch.int) h_S = torch.zeros(n_samples * L, 20, device=device) - + h_V_cache = [(h_V[0].clone(), h_V[1].clone()) for _ in self.decoder_layers] - + for i in range(L): - + h_S_ = h_S[edge_index[0]] h_S_[edge_index[0] >= edge_index[1]] = 0 h_E_ = (torch.cat([h_E[0], h_S_], dim=-1), h_E[1]) - + edge_mask = edge_index[1] % L == i edge_index_ = edge_index[:, edge_mask] h_E_ = tuple_index(h_E_, edge_mask) node_mask = torch.zeros(n_samples * L, device=device, dtype=torch.bool) node_mask[i::L] = True - + for j, layer in enumerate(self.decoder_layers): out = layer(h_V_cache[j], edge_index_, h_E_, - autoregressive_x=h_V_cache[0], node_mask=node_mask) - + autoregressive_x=h_V_cache[0], node_mask=node_mask) + out = tuple_index(out, node_mask) - - if j < len(self.decoder_layers)-1: - h_V_cache[j+1][0][i::L] = out[0] - h_V_cache[j+1][1][i::L] = out[1] - + + if j < len(self.decoder_layers) - 1: + h_V_cache[j + 1][0][i::L] = out[0] + h_V_cache[j + 1][1][i::L] = out[1] + logits = self.W_out(out) seq[i::L] = Categorical(logits=logits / temperature).sample() h_S[i::L] = self.W_s(seq[i::L]) - + return seq.view(n_samples, L) - + + class MQAModel(nn.Module): ''' GVP-GNN for Model Quality Assessment as described in manuscript. @@ -185,16 +188,17 @@ class MQAModel(nn.Module): :param num_layers: number of GVP-GNN layers :param drop_rate: rate to use in all dropout layers ''' - def __init__(self, node_in_dim, node_h_dim, - edge_in_dim, edge_h_dim, + + def __init__(self, node_in_dim, node_h_dim, + edge_in_dim, edge_h_dim, vector_dim=3, seq_in=False, num_layers=3, drop_rate=0.1): - + super(MQAModel, self).__init__() - + if seq_in: self.W_s = nn.Embedding(20, 20) node_in_dim = (node_in_dim[0] + 20, node_in_dim[1]) - + self.W_v = nn.Sequential( LayerNorm(node_in_dim), GVP(node_in_dim, node_h_dim, activations=(None, None)) @@ -203,23 +207,23 @@ def __init__(self, node_in_dim, node_h_dim, LayerNorm(edge_in_dim), GVP(edge_in_dim, edge_h_dim, activations=(None, None)) ) - + self.layers = nn.ModuleList( - GVPConvLayer(node_h_dim, edge_h_dim, drop_rate=drop_rate) + GVPConvLayer(node_h_dim, edge_h_dim, vector_dim=vector_dim, drop_rate=drop_rate) for _ in range(num_layers)) - + ns, _ = node_h_dim self.W_out = nn.Sequential( LayerNorm(node_h_dim), GVP(node_h_dim, (ns, 0))) - + self.dense = nn.Sequential( - nn.Linear(ns, 2*ns), nn.ReLU(inplace=True), + nn.Linear(ns, 2 * ns), nn.ReLU(inplace=True), nn.Dropout(p=drop_rate), - nn.Linear(2*ns, 1) + nn.Linear(2 * ns, 1) ) - def forward(self, h_V, edge_index, h_E, seq=None, batch=None): + def forward(self, h_V, edge_index, h_E, seq=None, batch=None): ''' :param h_V: tuple (s, V) of node embeddings :param edge_index: `torch.Tensor` of shape [2, num_edges] @@ -235,8 +239,10 @@ def forward(self, h_V, edge_index, h_E, seq=None, batch=None): for layer in self.layers: h_V = layer(h_V, edge_index, h_E) out = self.W_out(h_V) - - if batch is None: out = out.mean(dim=0, keepdims=True) - else: out = scatter_mean(out, batch, dim=0) - - return self.dense(out).squeeze(-1) + 0.5 \ No newline at end of file + + if batch is None: + out = out.mean(dim=0, keepdims=True) + else: + out = scatter_mean(out, batch, dim=0) + + return self.dense(out).squeeze(-1) + 0.5 diff --git a/run_atom3d.py b/run_atom3d.py index eb6373b..79ce130 100644 --- a/run_atom3d.py +++ b/run_atom3d.py @@ -2,13 +2,13 @@ parser = argparse.ArgumentParser() parser.add_argument('task', metavar='TASK', choices=[ - 'PSR', 'RSR', 'PPI', 'RES', 'MSP', 'SMP', 'LBA', 'LEP' - ], help="{PSR, RSR, PPI, RES, MSP, SMP, LBA, LEP}") + 'PSR', 'RSR', 'PPI', 'RES', 'MSP', 'SMP', 'LBA', 'LEP' +], help="{PSR, RSR, PPI, RES, MSP, SMP, LBA, LEP}") parser.add_argument('--num-workers', metavar='N', type=int, default=4, - help='number of threads for loading data, default=4') + help='number of threads for loading data, default=4') parser.add_argument('--smp-idx', metavar='IDX', type=int, default=None, - choices=list(range(20)), - help='label index for SMP, in range 0-19') + choices=list(range(20)), + help='label index for SMP, in range 0-19') parser.add_argument('--lba-split', metavar='SPLIT', type=int, choices=[30, 60], help='identity cutoff for LBA, 30 (default) or 60', default=30) parser.add_argument('--batch', metavar='SIZE', type=int, default=8, @@ -23,7 +23,7 @@ help='evaluate a trained model') parser.add_argument('--lr', metavar='RATE', default=1e-4, type=float, help='learning rate') -parser.add_argument('--load', metavar='PATH', default=None, +parser.add_argument('--load', metavar='PATH', default=None, help='initialize first 2 GNN layers with pretrained weights') args = parser.parse_args() @@ -34,28 +34,29 @@ from functools import partial import gvp.atom3d import torch.nn as nn -import tqdm, torch, time, os +import tqdm, torch, time import numpy as np from atom3d.util import metrics import sklearn.metrics as sk_metrics from collections import defaultdict -import scipy.stats as stats + print = partial(print, flush=True) models_dir = 'models' device = 'cuda' if torch.cuda.is_available() else 'cpu' model_id = float(time.time()) + def main(): datasets = get_datasets(args.task, args.lba_split) - dataloader = partial(torch_geometric.data.DataLoader, - num_workers=args.num_workers, batch_size=args.batch) + dataloader = partial(torch_geometric.data.DataLoader, + num_workers=args.num_workers, batch_size=args.batch) if args.task not in ['PPI', 'RES']: dataloader = partial(dataloader, shuffle=True) - - trainset, valset, testset = map(dataloader, datasets) + + trainset, valset, testset = map(dataloader, datasets) model = get_model(args.task).to(device) - + if args.test: test(model, testset) @@ -63,7 +64,8 @@ def main(): if args.load: load(model, args.load) train(model, trainset, valset) - + + def test(model, testset): model.load_state_dict(torch.load(args.test)) model.eval() @@ -87,12 +89,12 @@ def test(model, testset): value = func(targets, predicts) print(f"{name}: {value}") + def train(model, trainset, valset): - optimizer = torch.optim.Adam(model.parameters(), lr=args.lr) - + best_path, best_val = None, np.inf - + for epoch in range(args.epochs): model.train() loss = loop(trainset, model, optimizer=optimizer, max_time=args.train_time) @@ -106,54 +108,57 @@ def train(model, trainset, valset): if loss < best_val: best_path, best_val = path, loss print(f'BEST {best_path} VAL loss: {best_val:.8f}') - + + def loop(dataset, model, optimizer=None, max_time=None): start = time.time() - + loss_fn = get_loss(args.task) t = tqdm.tqdm(dataset) total_loss, total_count = 0, 0 - + for batch in t: - if max_time and (time.time() - start) > 60*max_time: break + if max_time and (time.time() - start) > 60 * max_time: break if optimizer: optimizer.zero_grad() try: out = forward(model, batch, device) except RuntimeError as e: - if "CUDA out of memory" not in str(e): raise(e) + if "CUDA out of memory" not in str(e): raise (e) torch.cuda.empty_cache() print('Skipped batch due to OOM', flush=True) continue - + label = get_label(batch, args.task, args.smp_idx) loss_value = loss_fn(out, label) total_loss += float(loss_value) total_count += 1 - + if optimizer: try: loss_value.backward() optimizer.step() except RuntimeError as e: - if "CUDA out of memory" not in str(e): raise(e) + if "CUDA out of memory" not in str(e): raise (e) torch.cuda.empty_cache() print('Skipped batch due to OOM', flush=True) continue - - t.set_description(f"{total_loss/total_count:.8f}") - + + t.set_description(f"{total_loss / total_count:.8f}") + return total_loss / total_count + def load(model, path): params = torch.load(path) state_dict = model.state_dict() for name, p in params.items(): if name in state_dict and \ - name[:8] in ['layers.0', 'layers.1'] and \ - state_dict[name].shape == p.shape: + name[:8] in ['layers.0', 'layers.1'] and \ + state_dict[name].shape == p.shape: print("Loading", name) model.state_dict()[name].copy_(p) - + + ####################################################################### def get_label(batch, task, smp_idx=None): @@ -163,6 +168,7 @@ def get_label(batch, task, smp_idx=None): return batch.label[smp_idx::20] return batch.label + def get_metrics(task): def _correlation(metric, targets, predict, ids=None, glob=True): if glob: return metric(targets, predict) @@ -171,31 +177,36 @@ def _correlation(metric, targets, predict, ids=None, glob=True): _targets[_id].append(_t) _predict[_id].append(_p) return np.mean([metric(_targets[_id], _predict[_id]) for _id in _targets]) - + correlations = { 'pearson': partial(_correlation, metrics.pearson), 'kendall': partial(_correlation, metrics.kendall), 'spearman': partial(_correlation, metrics.spearman) } - mean_correlations = {f'mean {k}' : partial(v, glob=False) \ - for k, v in correlations.items()} - - return { - 'RSR' : {**correlations, **mean_correlations}, - 'PSR' : {**correlations, **mean_correlations}, - 'PPI' : {'auroc': metrics.auroc}, - 'RES' : {'accuracy': metrics.accuracy}, - 'MSP' : {'auroc': metrics.auroc, 'auprc': metrics.auprc}, - 'LEP' : {'auroc': metrics.auroc, 'auprc': metrics.auprc}, - 'LBA' : {**correlations, 'rmse': partial(sk_metrics.mean_squared_error, squared=False)}, - 'SMP' : {'mae': sk_metrics.mean_absolute_error} + mean_correlations = {f'mean {k}': partial(v, glob=False) \ + for k, v in correlations.items()} + + return { + 'RSR': {**correlations, **mean_correlations}, + 'PSR': {**correlations, **mean_correlations}, + 'PPI': {'auroc': metrics.auroc}, + 'RES': {'accuracy': metrics.accuracy}, + 'MSP': {'auroc': metrics.auroc, 'auprc': metrics.auprc}, + 'LEP': {'auroc': metrics.auroc, 'auprc': metrics.auprc}, + 'LBA': {**correlations, 'rmse': partial(sk_metrics.mean_squared_error, squared=False)}, + 'SMP': {'mae': sk_metrics.mean_absolute_error} }[task] - + + def get_loss(task): - if task in ['PSR', 'RSR', 'SMP', 'LBA']: return nn.MSELoss() # regression - elif task in ['PPI', 'MSP', 'LEP']: return nn.BCELoss() # binary classification - elif task in ['RES']: return nn.CrossEntropyLoss() # multiclass classification - + if task in ['PSR', 'RSR', 'SMP', 'LBA']: + return nn.MSELoss() # regression + elif task in ['PPI', 'MSP', 'LEP']: + return nn.BCELoss() # binary classification + elif task in ['RES']: + return nn.CrossEntropyLoss() # multiclass classification + + def forward(model, batch, device): if type(batch) in [list, tuple]: batch = batch[0].to(device), batch[1].to(device) @@ -203,57 +214,60 @@ def forward(model, batch, device): batch = batch.to(device) return model(batch) + def get_datasets(task, lba_split=30): data_path = { - 'RES' : 'atom3d-data/RES/raw/RES/data/', - 'PPI' : 'atom3d-data/PPI/splits/DIPS-split/data/', - 'RSR' : 'atom3d-data/RSR/splits/candidates-split-by-time/data/', - 'PSR' : 'atom3d-data/PSR/splits/split-by-year/data/', - 'MSP' : 'atom3d-data/MSP/splits/split-by-sequence-identity-30/data/', - 'LEP' : 'atom3d-data/LEP/splits/split-by-protein/data/', - 'LBA' : f'atom3d-data/LBA/splits/split-by-sequence-identity-{lba_split}/data/', - 'SMP' : 'atom3d-data/SMP/splits/random/data/' + 'RES': 'atom3d-data/RES/raw/RES/data/', + 'PPI': 'atom3d-data/PPI/splits/DIPS-split/data/', + 'RSR': 'atom3d-data/RSR/splits/candidates-split-by-time/data/', + 'PSR': 'atom3d-data/PSR/splits/split-by-year/data/', + 'MSP': 'atom3d-data/MSP/splits/split-by-sequence-identity-30/data/', + 'LEP': 'atom3d-data/LEP/splits/split-by-protein/data/', + 'LBA': f'atom3d-data/LBA/splits/split-by-sequence-identity-{lba_split}/data/', + 'SMP': 'atom3d-data/SMP/splits/random/data/' }[task] - + if task == 'RES': split_path = 'atom3d-data/RES/splits/split-by-cath-topology/indices/' - dataset = partial(gvp.atom3d.RESDataset, data_path) - trainset = dataset(split_path=split_path+'train_indices.txt') - valset = dataset(split_path=split_path+'val_indices.txt') - testset = dataset(split_path=split_path+'test_indices.txt') - + dataset = partial(gvp.atom3d.RESDataset, data_path) + trainset = dataset(split_path=split_path + 'train_indices.txt') + valset = dataset(split_path=split_path + 'val_indices.txt') + testset = dataset(split_path=split_path + 'test_indices.txt') + elif task == 'PPI': - trainset = gvp.atom3d.PPIDataset(data_path+'train') - valset = gvp.atom3d.PPIDataset(data_path+'val') - testset = gvp.atom3d.PPIDataset(data_path+'test') - + trainset = gvp.atom3d.PPIDataset(data_path + 'train') + valset = gvp.atom3d.PPIDataset(data_path + 'val') + testset = gvp.atom3d.PPIDataset(data_path + 'test') + else: - transform = { - 'RSR' : gvp.atom3d.RSRTransform, - 'PSR' : gvp.atom3d.PSRTransform, - 'MSP' : gvp.atom3d.MSPTransform, - 'LEP' : gvp.atom3d.LEPTransform, - 'LBA' : gvp.atom3d.LBATransform, - 'SMP' : gvp.atom3d.SMPTransform, + transform = { + 'RSR': gvp.atom3d.RSRTransform, + 'PSR': gvp.atom3d.PSRTransform, + 'MSP': gvp.atom3d.MSPTransform, + 'LEP': gvp.atom3d.LEPTransform, + 'LBA': gvp.atom3d.LBATransform, + 'SMP': gvp.atom3d.SMPTransform, }[task]() - - trainset = LMDBDataset(data_path+'train', transform=transform) - valset = LMDBDataset(data_path+'val', transform=transform) - testset = LMDBDataset(data_path+'test', transform=transform) - + + trainset = LMDBDataset(data_path + 'train', transform=transform) + valset = LMDBDataset(data_path + 'val', transform=transform) + testset = LMDBDataset(data_path + 'test', transform=transform) + return trainset, valset, testset + def get_model(task): return { - 'RES' : gvp.atom3d.RESModel, - 'PPI' : gvp.atom3d.PPIModel, - 'RSR' : gvp.atom3d.RSRModel, - 'PSR' : gvp.atom3d.PSRModel, - 'MSP' : gvp.atom3d.MSPModel, - 'LEP' : gvp.atom3d.LEPModel, - 'LBA' : gvp.atom3d.LBAModel, - 'SMP' : gvp.atom3d.SMPModel + 'RES': gvp.atom3d.RESModel, + 'PPI': gvp.atom3d.PPIModel, + 'RSR': gvp.atom3d.RSRModel, + 'PSR': gvp.atom3d.PSRModel, + 'MSP': gvp.atom3d.MSPModel, + 'LEP': gvp.atom3d.LEPModel, + 'LBA': gvp.atom3d.LBAModel, + 'SMP': gvp.atom3d.SMPModel }[task]() + if __name__ == "__main__": - main() \ No newline at end of file + main() diff --git a/run_cpd.py b/run_cpd.py index f781a90..a516486 100644 --- a/run_cpd.py +++ b/run_cpd.py @@ -4,7 +4,7 @@ parser.add_argument('--models-dir', metavar='PATH', default='./models/', help='directory to save trained models, default=./models/') parser.add_argument('--num-workers', metavar='N', type=int, default=4, - help='number of threads for loading data, default=4') + help='number of threads for loading data, default=4') parser.add_argument('--max-nodes', metavar='N', type=int, default=3000, help='max number of nodes per batch, default=3000') parser.add_argument('--epochs', metavar='N', type=int, default=100, @@ -36,6 +36,7 @@ from sklearn.metrics import confusion_matrix import torch_geometric from functools import partial + print = partial(print, flush=True) node_dim = (100, 16) @@ -44,41 +45,45 @@ if not os.path.exists(args.models_dir): os.makedirs(args.models_dir) model_id = int(datetime.timestamp(datetime.now())) -dataloader = lambda x: torch_geometric.data.DataLoader(x, - num_workers=args.num_workers, - batch_sampler=gvp.data.BatchSampler( - x.node_counts, max_nodes=args.max_nodes)) +dataloader = lambda x: torch_geometric.data.DataLoader(x, + num_workers=args.num_workers, + batch_sampler=gvp.data.BatchSampler( + x.node_counts, max_nodes=args.max_nodes)) + def main(): - model = gvp.models.CPDModel((6, 3), node_dim, (32, 1), edge_dim).to(device) - + print("Loading CATH dataset") cath = gvp.data.CATHDataset(path="data/chain_set.jsonl", - splits_path="data/chain_set_splits.json") - + splits_path="data/chain_set_splits.json") + trainset, valset, testset = map(gvp.data.ProteinGraphDataset, (cath.train, cath.val, cath.test)) - + if args.test_r or args.test_p: ts50set = gvp.data.ProteinGraphDataset(json.load(open(args.ts50))) model.load_state_dict(torch.load(args.test_r or args.test_p)) - + if args.test_r: - print("Testing on CATH testset"); test_recovery(model, testset) - print("Testing on TS50 set"); test_recovery(model, ts50set) - + print("Testing on CATH testset"); + test_recovery(model, testset) + print("Testing on TS50 set"); + test_recovery(model, ts50set) + elif args.test_p: - print("Testing on CATH testset"); test_perplexity(model, testset) - print("Testing on TS50 set"); test_perplexity(model, ts50set) - + print("Testing on CATH testset"); + test_perplexity(model, testset) + print("Testing on TS50 set"); + test_perplexity(model, ts50set) + elif args.train: train(model, trainset, valset, testset) - - + + def train(model, trainset, valset, testset): train_loader, val_loader, test_loader = map(dataloader, - (trainset, valset, testset)) + (trainset, valset, testset)) optimizer = torch.optim.Adam(model.parameters()) best_path, best_val = None, np.inf lookup = train_loader.dataset.num_to_letter @@ -89,25 +94,26 @@ def train(model, trainset, valset, testset): torch.save(model.state_dict(), path) print(f'EPOCH {epoch} TRAIN loss: {loss:.4f} acc: {acc:.4f}') print_confusion(confusion, lookup=lookup) - + model.eval() with torch.no_grad(): - loss, acc, confusion = loop(model, val_loader) + loss, acc, confusion = loop(model, val_loader) print(f'EPOCH {epoch} VAL loss: {loss:.4f} acc: {acc:.4f}') print_confusion(confusion, lookup=lookup) - + if loss < best_val: best_path, best_val = path, loss print(f'BEST {best_path} VAL loss: {best_val:.4f}') - + print(f"TESTING: loading from {best_path}") model.load_state_dict(torch.load(best_path)) - + model.eval() with torch.no_grad(): loss, acc, confusion = loop(model, test_loader) print(f'TEST loss: {loss:.4f} acc: {acc:.4f}') - print_confusion(confusion,lookup=lookup) + print_confusion(confusion, lookup=lookup) + def test_perplexity(model, dataset): model.eval() @@ -116,37 +122,38 @@ def test_perplexity(model, dataset): print(f'TEST perplexity: {np.exp(loss):.4f}') print_confusion(confusion, lookup=dataset.num_to_letter) + def test_recovery(model, dataset): recovery = [] - + for protein in tqdm.tqdm(dataset): protein = protein.to(device) h_V = (protein.node_s, protein.node_v) - h_E = (protein.edge_s, protein.edge_v) - sample = model.sample(h_V, protein.edge_index, + h_E = (protein.edge_s, protein.edge_v) + sample = model.sample(h_V, protein.edge_index, h_E, n_samples=args.n_samples) - + recovery_ = sample.eq(protein.seq).float().mean().cpu().numpy() recovery.append(recovery_) print(protein.name, recovery_, flush=True) recovery = np.median(recovery) print(f'TEST recovery: {recovery:.4f}') - -def loop(model, dataloader, optimizer=None): + +def loop(model, dataloader, optimizer=None): confusion = np.zeros((20, 20)) t = tqdm.tqdm(dataloader) loss_fn = nn.CrossEntropyLoss() total_loss, total_correct, total_count = 0, 0, 0 - + for batch in t: if optimizer: optimizer.zero_grad() - + batch = batch.to(device) h_V = (batch.node_s, batch.node_v) h_E = (batch.edge_s, batch.edge_v) - + logits = model(h_V, batch.edge_index, h_E, seq=batch.seq) logits, seq = logits[batch.mask], batch.seq[batch.mask] loss_value = loss_fn(logits, seq) @@ -162,12 +169,13 @@ def loop(model, dataloader, optimizer=None): true = seq.detach().cpu().numpy() total_correct += (pred == true).sum() confusion += confusion_matrix(true, pred, labels=range(20)) - t.set_description("%.5f" % float(total_loss/total_count)) - + t.set_description("%.5f" % float(total_loss / total_count)) + torch.cuda.empty_cache() - + return total_loss / total_count, total_correct / total_count, confusion - + + def print_confusion(mat, lookup): counts = mat.astype(np.int32) mat = (counts.T / counts.sum(axis=-1, keepdims=True).T).T @@ -181,6 +189,7 @@ def print_confusion(mat, lookup): res += '\t'.join('{}'.format(n) for n in mat[i]) res += '\t{}\n'.format(sum(counts[i])) print(res) - -if __name__== "__main__": - main() \ No newline at end of file + + +if __name__ == "__main__": + main() diff --git a/setup.py b/setup.py index 7314149..de4e64b 100644 --- a/setup.py +++ b/setup.py @@ -25,4 +25,4 @@ 'sklearn', 'atom3d' ] -) \ No newline at end of file +) diff --git a/test_equivariance.py b/test_equivariance.py index a32610c..067e7ed 100644 --- a/test_equivariance.py +++ b/test_equivariance.py @@ -1,10 +1,12 @@ -import gvp -import gvp.models -import gvp.data +import math +import unittest + import torch from torch import nn -from scipy.spatial.transform import Rotation -import unittest + +import gvp +import gvp.data +import gvp.models device = torch.device('cuda' if torch.cuda.is_available() else 'cpu') @@ -12,20 +14,21 @@ edge_dim = (32, 1) n_nodes = 300 n_edges = 10000 +d = 100 -nodes = gvp.randn(n_nodes, node_dim, device=device) -edges = gvp.randn(n_edges, edge_dim, device=device) +nodes = gvp.randn(n_nodes, node_dim, d=d, device=device) +edges = gvp.randn(n_edges, edge_dim, d=d, device=device) edge_index = torch.randint(0, n_nodes, (2, n_edges), device=device) batch_idx = torch.randint(0, 5, (n_nodes,), device=device) seq = torch.randint(0, 20, (n_nodes,), device=device) -class EquivarianceTest(unittest.TestCase): +class EquivarianceTest(unittest.TestCase): def test_gvp(self): model = gvp.GVP(node_dim, node_dim).to(device).eval() model_fn = lambda h_V, h_E: model(h_V) test_equivariance(model_fn, nodes, edges) - + def test_gvp_vector_gate(self): model = gvp.GVP(node_dim, node_dim, vector_gate=True).to(device).eval() model_fn = lambda h_V, h_E: model(h_V) @@ -39,7 +42,7 @@ def test_gvp_sequence(self): ).to(device).eval() model_fn = lambda h_V, h_E: model(h_V) test_equivariance(model_fn, nodes, edges) - + def test_gvp_sequence_vector_gate(self): model = nn.Sequential( gvp.GVP(node_dim, node_dim, vector_gate=True), @@ -48,59 +51,81 @@ def test_gvp_sequence_vector_gate(self): ).to(device).eval() model_fn = lambda h_V, h_E: model(h_V) test_equivariance(model_fn, nodes, edges) - + def test_gvp_conv(self): - model = gvp.GVPConv(node_dim, node_dim, edge_dim).to(device).eval() + model = gvp.GVPConv(node_dim, node_dim, edge_dim, vector_dim=d).to(device).eval() model_fn = lambda h_V, h_E: model(h_V, edge_index, h_E) test_equivariance(model_fn, nodes, edges) - + def test_gvp_conv_vector_gate(self): - model = gvp.GVPConv(node_dim, node_dim, edge_dim, vector_gate=True).to(device).eval() + model = gvp.GVPConv(node_dim, node_dim, edge_dim, vector_dim=d, vector_gate=True).to(device).eval() model_fn = lambda h_V, h_E: model(h_V, edge_index, h_E) test_equivariance(model_fn, nodes, edges) - + def test_gvp_conv_layer(self): - model = gvp.GVPConvLayer(node_dim, edge_dim).to(device).eval() + model = gvp.GVPConvLayer(node_dim, edge_dim, vector_dim=d).to(device).eval() model_fn = lambda h_V, h_E: model(h_V, edge_index, h_E, autoregressive_x=h_V) test_equivariance(model_fn, nodes, edges) - + def test_gvp_conv_layer_vector_gate(self): - model = gvp.GVPConvLayer(node_dim, edge_dim, vector_gate=True).to(device).eval() + model = gvp.GVPConvLayer(node_dim, edge_dim, vector_dim=d, vector_gate=True).to(device).eval() model_fn = lambda h_V, h_E: model(h_V, edge_index, h_E, autoregressive_x=h_V) test_equivariance(model_fn, nodes, edges) - + def test_mqa_model(self): - model = gvp.models.MQAModel(node_dim, node_dim, - edge_dim, edge_dim).to(device).eval() - model_fn = lambda h_V, h_E: (model(h_V, edge_index, h_E, batch=batch_idx), \ + model = gvp.models.MQAModel(node_dim, node_dim, + edge_dim, edge_dim, vector_dim=d).to(device).eval() + model_fn = lambda h_V, h_E: (model(h_V, edge_index, h_E, batch=batch_idx), torch.zeros_like(nodes[1])) test_equivariance(model_fn, nodes, edges) - + def test_cpd_model(self): - model = gvp.models.CPDModel(node_dim, node_dim, - edge_dim, edge_dim).to(device).eval() - model_fn = lambda h_V, h_E: (model(h_V, edge_index, h_E, seq=seq), \ + model = gvp.models.CPDModel(node_dim, node_dim, + edge_dim, edge_dim, vector_dim=d).to(device).eval() + model_fn = lambda h_V, h_E: (model(h_V, edge_index, h_E, seq=seq), torch.zeros_like(nodes[1])) test_equivariance(model_fn, nodes, edges) - - + + def test_equivariance(model, nodes, edges): - - random = torch.as_tensor(Rotation.random().as_matrix(), - dtype=torch.float32, device=device) - + def rotation_matrix(theta: torch.Tensor, n_1: torch.Tensor, n_2: torch.Tensor) -> torch.Tensor: + """ + This method returns a rotation matrix which rotates any vector + in the 2 dimensional plane spanned by + @n1 and @n2 an angle @theta. The vectors @n1 and @n2 have to be orthogonal. + Inspired by + https://analyticphysics.com/Higher%20Dimensions/Rotations%20in%20Higher%20Dimensions.htm + :param @n1: first vector spanning 2-d rotation plane, needs to be orthogonal to @n2 + :param @n2: second vector spanning 2-d rotation plane, needs to be orthogonal to @n1 + :param @theta: rotation angle + :returns : rotation matrix + """ + dim = len(n_1) + assert len(n_1) == len(n_2) + assert (n_1.dot(n_2).abs() < 1e-4) + return (torch.eye(dim) + + (torch.outer(n_2, n_1) - torch.outer(n_1, n_2)) * torch.sin(theta) + + (torch.outer(n_1, n_1) + torch.outer(n_2, n_2)) * (torch.cos(theta) - 1) + ) + with torch.no_grad(): - out_s, out_v = model(nodes, edges) + theta = 2 * math.pi * torch.rand(1) + v1 = torch.rand(d) + v1 /= v1.norm() + v2 = torch.rand(d) + v2 -= v1.dot(v2) * v1 + v2 /= v2.norm() + random = rotation_matrix(theta, v1, v2) n_v_rot, e_v_rot = nodes[1] @ random, edges[1] @ random out_v_rot = out_v @ random out_s_prime, out_v_prime = model((nodes[0], n_v_rot), (edges[0], e_v_rot)) - + assert torch.allclose(out_s, out_s_prime, atol=1e-5, rtol=1e-4) assert torch.allclose(out_v_rot, out_v_prime, atol=1e-5, rtol=1e-4) if __name__ == "__main__": - unittest.main() \ No newline at end of file + unittest.main()