Skip to content

Commit 6f6d8e8

Browse files
authored
Merge pull request #26 from choderalab/split-comb-calcs
Rework Combination class
2 parents 54c94b0 + 58c43f2 commit 6f6d8e8

13 files changed

+873
-270
lines changed

README_COMBINATION.md

+74
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,74 @@
1+
As of v0.4.0 the `Combination` class has been reworked to be able to run on normal sized
2+
GPUs. Due to the size of the all-atom protein-ligand complex representation, storing all
3+
of the autograd computation graphs for every pose used all the GPU memory. By splitting
4+
the gradient math up into a function of the gradient from each pose, we can reduce the
5+
need to store more than one comp graph at a time. This document contains the derivation
6+
of the split up math.
7+
8+
# `MSE Loss`
9+
```math
10+
L = (\Delta G_{\mathrm{pred}} \left ( \theta \right ) - \Delta G_{\mathrm{target}})^2
11+
```
12+
```math
13+
\frac{\partial L}{\partial \theta} = 2(\Delta G_{\mathrm{pred}} \left ( \theta \right ) - \Delta G_{\mathrm{target}}) \frac{\partial \Delta G_{\mathrm{pred}} \left ( \theta \right )}{\partial \theta}
14+
```
15+
16+
# `MeanCombination`
17+
Just take the mean of all preds, so the gradient is straightforward:
18+
```math
19+
\Delta G(\theta) = \frac{1}{N} \sum_{n=1}^{N} \Delta G_n (\theta)
20+
```
21+
```math
22+
\frac{\partial \Delta G(\theta)}{\partial \theta} = \frac{1}{N} \sum_{n=1}^{N} \frac{\partial \Delta G_n (\theta)}{\partial \theta}
23+
```
24+
25+
# `MaxCombination`
26+
Combine according to a smooth max approximation using LSE:
27+
```math
28+
\Delta G(\theta) = \frac{-1}{t} \mathrm{ln} \sum_{n=1}^N \mathrm{exp} (-t \Delta G_n (\theta))
29+
```
30+
```math
31+
Q = \mathrm{ln} \sum_{n=1}^N \mathrm{exp} (-t \Delta G_n (\theta))
32+
```
33+
```math
34+
\frac{\partial \Delta G(\theta)}{\partial \theta} = \frac{1}{\sum_{n=1}^N \mathrm{exp} (-t \Delta G_n (\theta))} \sum_{n=1}^N \left[ \frac{\partial \Delta G_n (\theta)}{\partial \theta} \mathrm{exp} (-t \Delta G_n (\theta)) \right]
35+
```
36+
```math
37+
\frac{\partial \Delta G(\theta)}{\partial \theta} = \frac{1}{\mathrm{exp}(Q)} \sum_{n=1}^N \left[ \mathrm{exp} \left( -t \Delta G_n (\theta) \right) \frac{\partial \Delta G_n (\theta)}{\partial \theta} \right]
38+
```
39+
```math
40+
\frac{\partial \Delta G(\theta)}{\partial \theta} = \sum_{n=1}^N \left[ \mathrm{exp} \left( -t \Delta G_n (\theta) - Q \right) \frac{\partial \Delta G_n (\theta)}{\partial \theta} \right]
41+
```
42+
# `BoltzmannCombination`
43+
Combine according to Boltzmann weighting:
44+
```math
45+
\Delta G(\theta) = \sum_{n=1}^{N} w_n \Delta G_n (\theta)
46+
```
47+
48+
```math
49+
w_n = \mathrm{exp} \left[ -\Delta G_n (\theta) - \mathrm{ln} \sum_{i=1}^N \mathrm{exp} (-\Delta G_i (\theta)) \right]
50+
```
51+
52+
```math
53+
Q = \mathrm{ln} \sum_{n=1}^N \mathrm{exp} (-\Delta G_n (\theta))
54+
```
55+
56+
```math
57+
\frac{\partial \Delta G(\theta)}{\partial \theta} = \sum_{n=1}^N \left[ \frac{\partial w_n}{\partial \theta} \Delta G_n (\theta) + w_n \frac{\partial \Delta G_n (\theta)}{\partial \theta} \right]
58+
```
59+
60+
```math
61+
\frac{\partial w_n}{\partial \theta} = \mathrm{exp} \left[ -\Delta G_n (\theta) - Q \right] \left[ \frac{-\partial \Delta G_n (\theta)}{\partial \theta} - \frac{\partial Q}{\partial \theta} \right]
62+
```
63+
64+
```math
65+
\frac{\partial Q}{\partial \theta} = \frac{1}{\sum_{n=1}^N \mathrm{exp} (-\Delta G_n (\theta))} \sum_{i=1}^{N} \left[ \mathrm{exp} (-\Delta G_i (\theta)) \frac{-\partial \Delta G_i (\theta)}{\partial \theta} \right]
66+
```
67+
68+
```math
69+
\frac{\partial Q}{\partial \theta} = \frac{-1}{\mathrm{exp} (Q)} \sum_{n=1}^{N} \left[ \mathrm{exp} (-\Delta G_n (\theta)) \frac{\partial \Delta G_n (\theta)}{\partial \theta} \right]
70+
```
71+
72+
```math
73+
\frac{\partial Q}{\partial \theta} = -\sum_{n=1}^{N} \left[ \mathrm{exp} (-\Delta G_n (\theta) - Q) \frac{\partial \Delta G_n (\theta)}{\partial \theta} \right]
74+
```

devtools/conda-envs/test_env.yaml

+1
Original file line numberDiff line numberDiff line change
@@ -13,6 +13,7 @@ dependencies:
1313
- dgllife
1414
- dgl
1515
- rdkit
16+
- ase
1617
# testing dependencies
1718
- pytest
1819
- pytest-cov

environment-gpu.yml

+2-1
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,6 @@
11
name: mtenn-gpu
22
channels:
33
- conda-forge
4-
- dglteam
54
dependencies:
65
- pytorch
76
- pytorch-gpu
@@ -14,3 +13,5 @@ dependencies:
1413
- e3nn
1514
- dgllife
1615
- dgl
16+
- rdkit
17+
- ase

environment.yml

+2-1
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,6 @@
11
name: mtenn
22
channels:
33
- conda-forge
4-
- dglteam
54
dependencies:
65
- pytorch
76
- pytorch_geometric
@@ -13,3 +12,5 @@ dependencies:
1312
- e3nn
1413
- dgllife
1514
- dgl
15+
- rdkit
16+
- ase

0 commit comments

Comments
 (0)