Commit c18b034

Merge pull request #66 from choderalab/connect-pred-grads
Connect the individual prediction gradients into the Combination workflow
2 parents 54e24f1 + 3d0de92

File tree

5 files changed: +372 −340 lines changed
docs/docs/combination.rst (+123 −3)
@@ -20,7 +20,7 @@ The prediction for each pose is generated by the same single-pose model (:math:`

 .. math::

-    \hat{y}_i = f( \mathrm{X}_i, \theta )
+    \hat{y}_i = f( \text{X}_i, \theta )

 and the final prediction for this compound is found by applying the combination function (:math:`h`) to this set of individual predictions:
@@ -32,13 +32,13 @@ We then calculate the loss of our prediction compared to a target value

 .. math::

-    \mathrm{loss} = L ( \hat{y}(\theta), y )
+    \text{loss} = L ( \hat{y}(\theta), y )

 and backprop is performed by calculating the gradient of that loss wrt the model parameters:

 .. math::

-    \frac{\partial \mathrm{loss}}{\partial \theta} = \frac{\partial L}{\partial \hat{y}} \frac{\partial \hat{y}}{\partial \theta}
+    \frac{\partial \text{loss}}{\partial \theta} = \frac{\partial L}{\partial \hat{y}} \frac{\partial \hat{y}}{\partial \theta}

 The :math:`\frac{\partial L}{\partial \hat{y}}` term can be calculated automatically using the ``pytorch.autograd`` capabilities.
 However, because we've decoupled the single-pose model predictions from the overall multi-pose prediction, we must manually account for the relation between the :math:`\frac{\partial \hat{y}}{\partial \theta}` term and the individual gradients that we calculated during the forward pass (:math:`\frac{\partial \hat{y}_i}{\partial \theta}`).
@@ -50,3 +50,123 @@ Arbitrarily, this will be some function (:math:`g`) that depends on the individu

    g( \hat{y}_1, ..., \hat{y}_n, \frac{\partial \hat{y}_1}{\partial \theta}, ..., \frac{\partial \hat{y}_n}{\partial \theta} )

In practice, this function :math:`g` will need to be analytically determined and manually implemented within the ``Combination`` block (see :ref:`the guide <new-combination-guide>` for more practical information).
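To make the shape of this workflow concrete, here is a minimal sketch (with :math:`h` as a simple mean stand-in, and all names hypothetical rather than the package's actual API) of how :math:`h` combines the per-pose predictions while :math:`g` separately combines their gradients:

.. code-block:: python

    import torch

    def h(preds):
        # combination function for the predictions; a simple mean stand-in
        return preds.mean()

    def g(preds, grads_per_pose):
        # analytically derived for each Combination; for the mean it is just
        # the average of the per-pose gradients dy_i/dtheta
        return grads_per_pose.mean(dim=0)

    # per-pose predictions y_i and their gradients dy_i/dtheta
    # (here for a model with P = 10 flattened parameters)
    preds = torch.tensor([-9.5, -8.7, -7.9])
    grads_per_pose = torch.randn(3, 10)

    y_hat = h(preds)                         # overall prediction
    dyhat_dtheta = g(preds, grads_per_pose)  # dy_hat/dtheta, shape (10,)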
.. _implemented-combs:

Math for Implemented Combinations
----------------------------------

Below, we detail the math required for appropriately combining gradients.
This math is used in the ``backward`` pass in the various ``Combination`` classes.

.. _imp-comb-loss-fn:

Loss Functions
^^^^^^^^^^^^^^
We anticipate these ``Combination`` methods being used with a linear combination of two types of loss functions:

* Loss based on the final combined prediction (i.e. :math:`L = f(\Delta \text{G} (\theta))`)
* Loss based on a linear combination of the per-pose predictions (i.e. :math:`L = f(\Delta \text{G}_1 (\theta), \Delta \text{G}_2 (\theta), ...)`)

Ultimately, for backprop we need to return the gradients of the loss wrt each model parameter.
The gradients for each of these types of losses are given below.
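For concreteness, a minimal sketch of these two loss setups (the tensor names and the choice of ``MSELoss`` here are illustrative assumptions, not part of the package):

.. code-block:: python

    import torch

    # illustrative stand-ins: combined prediction, per-pose predictions, target
    delta_g_combined = torch.tensor(-9.2)
    delta_g_per_pose = torch.tensor([-9.5, -8.7, -7.9])
    target = torch.tensor(-9.0)
    weights = torch.tensor([0.5, 0.3, 0.2])

    loss_fn = torch.nn.MSELoss()

    # loss on the final combined prediction: L = f(dG(theta))
    combined_loss = loss_fn(delta_g_combined, target)

    # linear combination of per-pose losses:
    # L = f(dG_1(theta), dG_2(theta), ...)
    per_pose_losses = torch.stack([loss_fn(dg, target) for dg in delta_g_per_pose])
    per_pose_loss = (weights * per_pose_losses).sum()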
Combined Prediction
"""""""""""""""""""

.. math::
    :label: comb-grad

    \frac{\partial L}{\partial \theta} =
    \frac{\partial L}{\partial \Delta \text{G}}
    \frac{\partial \Delta \text{G}}{\partial \theta}

The :math:`\frac{\partial L}{\partial \Delta \text{G}}` part of this equation will be a scalar that is calculated automatically by ``pytorch`` and fed to our ``Combination`` class.
The :math:`\frac{\partial \Delta \text{G}}{\partial \theta}` parts will be computed internally.
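In code, this backward step is just a broadcast multiply of that scalar against each internally stored gradient (a sketch; ``grad_output`` and ``dG_dtheta`` are assumed names):

.. code-block:: python

    import torch

    # dL/d(dG): the scalar pytorch's autograd passes to backward()
    grad_output = torch.tensor(0.31)

    # d(dG)/dtheta: one gradient tensor per model parameter, stored internally
    dG_dtheta = [torch.randn(4, 4), torch.randn(4)]

    # chain rule: dL/dtheta = dL/d(dG) * d(dG)/dtheta
    param_grads = [grad_output * g for g in dG_dtheta]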
Per-Pose Prediction
"""""""""""""""""""

Because we assume this loss is based on a linear combination of the individual :math:`\Delta \text{G}_i` predictions, we can decompose the loss as:

.. math::
    :label: pose-grad

    \frac{\partial L}{\partial \theta} =
    \sum_{i=1}^N
    \frac{\partial L}{\partial \Delta \text{G}_i}
    \frac{\partial \Delta \text{G}_i}{\partial \theta}

As before, the :math:`\frac{\partial L}{\partial \Delta \text{G}_i}` parts of this equation will be scalars calculated automatically by ``pytorch`` and fed to our ``Combination`` class, and the :math:`\frac{\partial \Delta \text{G}_i}{\partial \theta}` parts will be computed internally.
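The per-pose case is the same idea with a sum over poses (again a sketch with assumed names):

.. code-block:: python

    import torch

    n_poses, shapes = 3, [(4, 4), (4,)]

    # dL/d(dG_i): one scalar per pose, from pytorch's autograd
    grad_outputs = torch.randn(n_poses)

    # d(dG_i)/dtheta: per-pose gradients for each model parameter
    dG_dtheta = [[torch.randn(*s) for s in shapes] for _ in range(n_poses)]

    # sum the chain-rule contributions across all poses, per parameter
    param_grads = [
        sum(grad_outputs[i] * dG_dtheta[i][p] for i in range(n_poses))
        for p in range(len(shapes))
    ]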
.. _mean-comb-imp:

Mean Combination
^^^^^^^^^^^^^^^^

This is mostly included as an example, but it can be illustrative.

.. math::
    :label: mean-comb-pred

    \Delta \text{G}(\theta) = \frac{1}{N} \sum_{i=1}^{N} \Delta \text{G}_i (\theta)

.. math::
    :label: mean-comb-grad

    \frac{\partial \Delta \text{G}(\theta)}{\partial \theta} = \frac{1}{N} \sum_{i=1}^{N} \frac{\partial \Delta \text{G}_i (\theta)}{\partial \theta}
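These two equations translate almost directly into code (a sketch with hypothetical names):

.. code-block:: python

    import torch

    def mean_combination(preds, dG_dtheta):
        # preds: (N,) per-pose dG_i values
        # dG_dtheta: per-pose lists of gradient tensors, one per parameter
        comb_pred = preds.mean()
        comb_grads = [torch.stack(g_p).mean(dim=0) for g_p in zip(*dG_dtheta)]
        return comb_pred, comb_grads

    preds = torch.tensor([-9.5, -8.7, -7.9])
    dG_dtheta = [[torch.randn(4, 4), torch.randn(4)] for _ in preds]
    comb_pred, comb_grads = mean_combination(preds, dG_dtheta)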
.. _max-comb-imp:

Max Combination
^^^^^^^^^^^^^^^

This will likely be the more useful of the currently implemented ``Combination`` classes.
In the below equations, we define the following variables:

* :math:`n` : A sign multiplier taking the value of :math:`-1` if we are taking the min value (generally the case if the inputs are :math:`\Delta \text{G}` values) or :math:`1` if we are taking the max
* :math:`t` : A scaling value that brings the final combined value closer to the actual max/min of the input values (see `here <https://en.wikipedia.org/wiki/LogSumExp#Properties>`_ for more details).
  Setting :math:`t = 1` reduces this operation to the standard LogSumExp operation.

.. math::
    :label: max-comb-pred

    \Delta \text{G}(\theta) = n \frac{1}{t} \text{ln} \sum_{i=1}^N \text{exp} (n t \Delta \text{G}_i (\theta))

We define a constant :math:`Q` for simplicity as well as for numerical stability:

.. math::
    :label: max-comb-q

    Q = \text{ln} \sum_{i=1}^N \text{exp} (n t \Delta \text{G}_i (\theta))
.. math::
    :label: max-comb-grad-initial

    \frac{\partial \Delta \text{G}(\theta)}{\partial \theta} =
    n^2
    \frac{1}{\sum_{i=1}^N \text{exp} (n t \Delta \text{G}_i (\theta))}
    \sum_{i=1}^N \left[
    \frac{\partial \Delta \text{G}_i (\theta)}{\partial \theta} \text{exp} (n t \Delta \text{G}_i (\theta))
    \right]

Substituting in :math:`Q` (and noting that :math:`n^2 = 1`, since :math:`n = \pm 1`):
.. math::
    :label: max-comb-grad-sub

    \frac{\partial \Delta \text{G}(\theta)}{\partial \theta} =
    \frac{1}{\text{exp}(Q)}
    \sum_{i=1}^N \left[
    \text{exp} \left( n t \Delta \text{G}_i (\theta) \right) \frac{\partial \Delta \text{G}_i (\theta)}{\partial \theta}
    \right]

.. math::
    :label: max-comb-grad-final

    \frac{\partial \Delta \text{G}(\theta)}{\partial \theta} =
    \sum_{i=1}^N \left[
    \text{exp} \left( n t \Delta \text{G}_i (\theta) - Q \right) \frac{\partial \Delta \text{G}_i (\theta)}{\partial \theta}
    \right]
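Putting the prediction equation and the final gradient equation together, a numerically stable sketch (hypothetical names; ``torch.logsumexp`` computes :math:`Q` directly):

.. code-block:: python

    import torch

    def max_combination(preds, dG_dtheta, neg=True, t=1.0):
        # preds: (N,) per-pose dG_i values
        # dG_dtheta: per-pose lists of gradient tensors, one per parameter
        # neg=True gives the (soft) min, appropriate for dG values
        n = -1.0 if neg else 1.0
        # Q = ln sum_i exp(n t dG_i), computed stably
        Q = torch.logsumexp(n * t * preds, dim=0)
        # dG(theta) = n (1/t) Q
        comb_pred = n * Q / t
        # per-pose weights exp(n t dG_i - Q) from the final gradient equation
        w = torch.exp(n * t * preds - Q)
        comb_grads = [
            sum(w[i] * g for i, g in enumerate(g_p)) for g_p in zip(*dG_dtheta)
        ]
        return comb_pred, comb_grads

    preds = torch.tensor([-9.5, -8.7, -7.9])
    dG_dtheta = [[torch.randn(4, 4), torch.randn(4)] for _ in preds]
    comb_pred, comb_grads = max_combination(preds, dG_dtheta)

With ``neg=True``, ``comb_pred`` sits at or below the smallest entry of ``preds``, approaching the true min as ``t`` grows.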
