| 1 | +As of v0.4.0 the `Combination` class has been reworked so that it can run on normal-sized |
| 2 | +GPUs. Because the all-atom protein-ligand complex representation is so large, storing the |
| 3 | +autograd computation graphs for every pose at once exhausted the GPU memory. By writing the |
| 4 | +combined gradient as a function of the per-pose gradients, we only need to keep one |
| 5 | +computation graph in memory at a time. This document contains the derivation of that |
| 6 | +split-up math. |
| 7 | + |
| 8 | +# `MSE Loss` |
| 9 | +```math |
| 10 | +L = (\Delta G_{\mathrm{pred}} \left ( \theta \right ) - \Delta G_{\mathrm{target}})^2 |
| 11 | +``` |
| 12 | +```math |
| 13 | +\frac{\partial L}{\partial \theta} = 2(\Delta G_{\mathrm{pred}} \left ( \theta \right ) - \Delta G_{\mathrm{target}}) \frac{\partial \Delta G_{\mathrm{pred}} \left ( \theta \right )}{\partial \theta} |
| 14 | +``` |
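The chain rule above means the loss gradient is just the gradient of the combined prediction multiplied by a scalar. A minimal sketch, assuming the prediction gradient is already available as a flat list of floats; the function name `mse_loss_grad` and its arguments are hypothetical and not taken from the library:

```python
def mse_loss_grad(dg_pred, dg_target, dpred_dtheta):
    """Chain rule for L = (dg_pred - dg_target)**2.

    dpred_dtheta is the gradient of the combined prediction with respect
    to the parameters, given here as a plain list of floats (an assumption
    made for illustration).
    """
    scale = 2.0 * (dg_pred - dg_target)
    return [scale * g for g in dpred_dtheta]


grad = mse_loss_grad(dg_pred=3.0, dg_target=1.0, dpred_dtheta=[0.5, -1.0])
print(grad)  # [2.0, -4.0]
```

Because the scalar factor only depends on the combined prediction, the remaining sections only need to show how to build the prediction gradient from per-pose gradients.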
| 15 | + |
| 16 | +# `MeanCombination` |
| 17 | +Take the mean of the per-pose predictions, so the gradient is straightforward: |
| 18 | +```math |
| 19 | +\Delta G(\theta) = \frac{1}{N} \sum_{n=1}^{N} \Delta G_n (\theta) |
| 20 | +``` |
| 21 | +```math |
| 22 | +\frac{\partial \Delta G(\theta)}{\partial \theta} = \frac{1}{N} \sum_{n=1}^{N} \frac{\partial \Delta G_n (\theta)}{\partial \theta} |
| 23 | +``` |
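Since the combined gradient is a plain average, it can be accumulated one pose at a time, which is what allows each pose's computation graph to be discarded before the next one is built. A rough sketch of that accumulation, assuming each per-pose gradient arrives as a list of floats from an iterator; `mean_combination_grad` is a hypothetical name used only for illustration:

```python
def mean_combination_grad(pose_grads):
    """Accumulate (1/N) * sum_n dG_n/dtheta one pose at a time.

    pose_grads may be any iterable (for example a generator), so only one
    per-pose gradient has to exist in memory at any moment.
    """
    total = None
    n_poses = 0
    for grad in pose_grads:
        if total is None:
            total = [0.0] * len(grad)
        for k, g in enumerate(grad):
            total[k] += g
        n_poses += 1
    if n_poses == 0:
        raise ValueError("need at least one pose")
    return [t / n_poses for t in total]


grads = ([1.0, 2.0], [3.0, 6.0])
print(mean_combination_grad(iter(grads)))  # [2.0, 4.0]
```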
| 24 | + |
| 25 | +# `MaxCombination` |
| 26 | +Combine according to a smooth max approximation using LogSumExp (LSE), where `t` is a scaling factor and `Q` is shorthand for the log-sum term: |
| 27 | +```math |
| 28 | +\Delta G(\theta) = \frac{-1}{t} \mathrm{ln} \sum_{n=1}^N \mathrm{exp} (-t \Delta G_n (\theta)) |
| 29 | +``` |
| 30 | +```math |
| 31 | +Q = \mathrm{ln} \sum_{n=1}^N \mathrm{exp} (-t \Delta G_n (\theta)) |
| 32 | +``` |
| 33 | +```math |
| 34 | +\frac{\partial \Delta G(\theta)}{\partial \theta} = \frac{1}{\sum_{n=1}^N \mathrm{exp} (-t \Delta G_n (\theta))} \sum_{n=1}^N \left[ \frac{\partial \Delta G_n (\theta)}{\partial \theta} \mathrm{exp} (-t \Delta G_n (\theta)) \right] |
| 35 | +``` |
| 36 | +```math |
| 37 | +\frac{\partial \Delta G(\theta)}{\partial \theta} = \frac{1}{\mathrm{exp}(Q)} \sum_{n=1}^N \left[ \mathrm{exp} \left( -t \Delta G_n (\theta) \right) \frac{\partial \Delta G_n (\theta)}{\partial \theta} \right] |
| 38 | +``` |
| 39 | +```math |
| 40 | +\frac{\partial \Delta G(\theta)}{\partial \theta} = \sum_{n=1}^N \left[ \mathrm{exp} \left( -t \Delta G_n (\theta) - Q \right) \frac{\partial \Delta G_n (\theta)}{\partial \theta} \right] |
| 41 | +``` |
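The last line says each per-pose gradient is weighted by `exp(-t * dG_n - Q)`, so `Q` has to be known before any gradient is accumulated. One way to arrange that, sketched below under the assumption that the per-pose values are first computed without gradients and the per-pose gradients are then visited one at a time. All function names are hypothetical, and the max shift inside `lse_q` is a standard numerical-stability trick rather than something stated in the derivation:

```python
import math


def lse_q(dg_values, t):
    """Q = ln sum_n exp(-t * dG_n), shifted by the max term for stability."""
    shifted = [-t * dg for dg in dg_values]
    m = max(shifted)
    return m + math.log(sum(math.exp(s - m) for s in shifted))


def max_combination(dg_values, t):
    """Combined prediction: -Q / t."""
    return -lse_q(dg_values, t) / t


def max_combination_grad(dg_values, pose_grads, t):
    """Sum over poses of exp(-t * dG_n - Q) * dG_n/dtheta."""
    q = lse_q(dg_values, t)
    total = [0.0] * len(pose_grads[0])
    for dg, grad in zip(dg_values, pose_grads):
        weight = math.exp(-t * dg - q)  # per-pose weights sum to 1
        for k, g in enumerate(grad):
            total[k] += weight * g
    return total


# With a single pose the weight is exactly 1, so the gradient passes through.
print(max_combination_grad([0.7], [[1.0, -2.0]], t=2.0))  # [1.0, -2.0]
```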
| 42 | +# `BoltzmannCombination` |
| 43 | +Combine according to Boltzmann weighting: |
| 44 | +```math |
| 45 | +\Delta G(\theta) = \sum_{n=1}^{N} w_n \Delta G_n (\theta) |
| 46 | +``` |
| 47 | + |
| 48 | +```math |
| 49 | +w_n = \mathrm{exp} \left[ -\Delta G_n (\theta) - \mathrm{ln} \sum_{i=1}^N \mathrm{exp} (-\Delta G_i (\theta)) \right] |
| 50 | +``` |
| 51 | + |
| 52 | +```math |
| 53 | +Q = \mathrm{ln} \sum_{n=1}^N \mathrm{exp} (-\Delta G_n (\theta)) |
| 54 | +``` |
| 55 | + |
| 56 | +```math |
| 57 | +\frac{\partial \Delta G(\theta)}{\partial \theta} = \sum_{n=1}^N \left[ \frac{\partial w_n}{\partial \theta} \Delta G_n (\theta) + w_n \frac{\partial \Delta G_n (\theta)}{\partial \theta} \right] |
| 58 | +``` |
| 59 | + |
| 60 | +```math |
| 61 | +\frac{\partial w_n}{\partial \theta} = \mathrm{exp} \left[ -\Delta G_n (\theta) - Q \right] \left[ \frac{-\partial \Delta G_n (\theta)}{\partial \theta} - \frac{\partial Q}{\partial \theta} \right] |
| 62 | +``` |
| 63 | + |
| 64 | +```math |
| 65 | +\frac{\partial Q}{\partial \theta} = \frac{1}{\sum_{n=1}^N \mathrm{exp} (-\Delta G_n (\theta))} \sum_{n=1}^{N} \left[ \mathrm{exp} (-\Delta G_n (\theta)) \frac{-\partial \Delta G_n (\theta)}{\partial \theta} \right] |
| 66 | +``` |
| 67 | + |
| 68 | +```math |
| 69 | +\frac{\partial Q}{\partial \theta} = \frac{-1}{\mathrm{exp} (Q)} \sum_{n=1}^{N} \left[ \mathrm{exp} (-\Delta G_n (\theta)) \frac{\partial \Delta G_n (\theta)}{\partial \theta} \right] |
| 70 | +``` |
| 71 | + |
| 72 | +```math |
| 73 | +\frac{\partial Q}{\partial \theta} = -\sum_{n=1}^{N} \left[ \mathrm{exp} (-\Delta G_n (\theta) - Q) \frac{\partial \Delta G_n (\theta)}{\partial \theta} \right] |
| 74 | +``` |
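Substituting the final expression for the gradient of `Q` into the gradient of `w_n`, and then into the product-rule sum, collects everything into one coefficient per pose: `w_n * (1 + dG - dG_n)`, where `dG` is the combined prediction. The sketch below uses that collected form and assumes the per-pose values are computed first (without gradients) so that the weights and `dG` are known before the per-pose gradients are accumulated. The function names are hypothetical, and the max shift is a standard stability trick that is not part of the derivation:

```python
import math


def boltzmann_weights(dg_values):
    """w_n = exp(-dG_n - Q), with Q = ln sum_i exp(-dG_i)."""
    shifted = [-dg for dg in dg_values]
    m = max(shifted)
    q = m + math.log(sum(math.exp(s - m) for s in shifted))
    return [math.exp(-dg - q) for dg in dg_values]


def boltzmann_combination(dg_values):
    """Combined prediction: sum_n w_n * dG_n."""
    weights = boltzmann_weights(dg_values)
    return sum(w * dg for w, dg in zip(weights, dg_values))


def boltzmann_combination_grad(dg_values, pose_grads):
    """Sum over poses of w_n * (1 + dG - dG_n) * dG_n/dtheta."""
    weights = boltzmann_weights(dg_values)
    dg_comb = sum(w * dg for w, dg in zip(weights, dg_values))
    total = [0.0] * len(pose_grads[0])
    for w, dg, grad in zip(weights, dg_values, pose_grads):
        coeff = w * (1.0 + dg_comb - dg)
        for k, g in enumerate(grad):
            total[k] += coeff * g
    return total


# Two poses with equal values get equal weights, so the result is the mean.
print(boltzmann_combination_grad([1.0, 1.0], [[2.0, 0.0], [4.0, 2.0]]))
```

The per-pose coefficient depends only on scalar values, so, as in the other combinations, each pose's gradient can be scaled and added to a running total before its computation graph is released.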