Skip to content

Commit aafd56b

Browse files
authored
Merge pull request #25 from choderalab/fix-boltzmann-mean
Fix Boltzmann mean
2 parents 07ea82b + 6a17958 commit aafd56b

File tree

1 file changed

+77
-31
lines changed

1 file changed

+77
-31
lines changed

mtenn/model.py

+77-31
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,7 @@
22
from itertools import permutations
33
import os
44
import torch
5+
from typing import Optional
56

67

78
class Model(torch.nn.Module):
@@ -328,55 +329,98 @@ def forward(self, predictions: torch.Tensor):
328329
return torch.mean(predictions)
329330

330331

331-
class BoltzmannCombination(Combination):
332+
class MaxCombination(Combination):
332333
"""
333-
Combine a list of deltaG predictions according to their Boltzmann weight.
334+
Approximate max/min of the predictions using the LogSumExp function for smoothness.
334335
"""
335336

336-
def __init__(self):
337-
super(BoltzmannCombination, self).__init__()
337+
def __init__(self, neg=True, scale=1000.0):
338+
"""
339+
Parameters
340+
----------
341+
neg : bool, default=True
342+
Negate the predictions before calculating the LSE, effectively finding
343+
the min. Preds are negated again before being returned
344+
scale : float, default=1000.0
345+
Fixed positive value to scale predictions by before taking the LSE. This
346+
tightens the bounds of the LSE approximation
347+
"""
348+
super(MaxCombination, self).__init__()
349+
350+
self.neg = -1 * neg
351+
self.scale = scale
338352

339-
from simtk.unit import (
340-
BOLTZMANN_CONSTANT_kB as kB,
341-
elementary_charge,
342-
coulomb,
353+
def forward(self, predictions: torch.Tensor):
354+
return (
355+
self.neg
356+
* torch.logsumexp(self.neg * self.scale * predictions, dim=0)
357+
/ self.scale
343358
)
344359

345-
## Convert kB to eV (calibrate to SchNet predictions)
346-
electron_volt = elementary_charge.conversion_factor_to(coulomb)
347360

348-
self.kT = (kB / electron_volt * 298.0)._value
361+
class BoltzmannCombination(Combination):
362+
"""
363+
Combine a list of deltaG predictions according to their Boltzmann weight. Use LSE
364+
approximation of min energy to improve numerical stability. Treat energy in implicit
365+
kT units.
366+
"""
367+
368+
def __init__(self):
369+
super(BoltzmannCombination, self).__init__()
349370

350371
def forward(self, predictions: torch.Tensor):
351-
return -self.kT * torch.logsumexp(-predictions, dim=0)
372+
# First calculate LSE (no scale here bc math)
373+
lse = torch.logsumexp(-predictions, dim=0)
374+
# Calculate Boltzmann weights for each prediction
375+
w = torch.exp(-predictions - lse)
376+
377+
return torch.dot(w, predictions)
352378

353379

354380
class PIC50Readout(Readout):
355381
"""
356-
Readout implementation to convert delta G values to pIC50 values.
382+
Readout implementation to convert delta G values to pIC50 values. This new
383+
implementation assumes implicit energy units, WHICH WILL INVALIDATE MODELS TRAINED
384+
PRIOR TO v0.3.0.
385+
Assuming implicit energy units:
386+
deltaG = ln(Ki)
387+
Ki = exp(deltaG)
388+
Using the Cheng-Prusoff equation:
389+
Ki = IC50 / (1 + [S]/Km)
390+
exp(deltaG) = IC50 / (1 + [S]/Km)
391+
IC50 = exp(deltaG) * (1 + [S]/Km)
392+
pIC50 = -log10(exp(deltaG) * (1 + [S]/Km))
393+
pIC50 = -log10(exp(deltaG)) - log10(1 + [S]/Km)
394+
pIC50 = -ln(exp(deltaG))/ln(10) - log10(1 + [S]/Km)
395+
pIC50 = -deltaG/ln(10) - log10(1 + [S]/Km)
396+
Estimating Ki as the IC50 value:
397+
Ki = IC50
398+
IC50 = exp(deltaG)
399+
pIC50 = -log10(exp(deltaG))
400+
pIC50 = -ln(exp(deltaG))/ln(10)
401+
pIC50 = -deltaG/ln(10)
357402
"""
358403

359-
def __init__(self, T=298.0):
404+
def __init__(self, substrate: Optional[float] = None, Km: Optional[float] = None):
360405
"""
361-
Initialize conversion with specified T (assume 298 K).
406+
Initialize conversion with specified substrate concentration and Km. If either
407+
is left blank, the IC50 approximation will be used.
362408
363409
Parameters
364410
----------
365-
T : float, default=298
366-
Temperature for conversion.
411+
substrate : float, optional
412+
Substrate concentration for use in the Cheng-Prusoff equation. Assumed to be
413+
in the same units as Km
414+
Km : float, optional
415+
Km value for use in the Cheng-Prusoff equation. Assumed to be in the same
416+
units as substrate
367417
"""
368418
super(PIC50Readout, self).__init__()
369419

370-
from simtk.unit import (
371-
BOLTZMANN_CONSTANT_kB as kB,
372-
elementary_charge,
373-
coulomb,
374-
)
375-
376-
## Convert kB to eV (calibrate to SchNet predictions)
377-
electron_volt = elementary_charge.conversion_factor_to(coulomb)
378-
379-
self.kT = (kB / electron_volt * T)._value
420+
if substrate and Km:
421+
self.cp_val = 1 + substrate / Km
422+
else:
423+
self.cp_val = None
380424

381425
def forward(self, delta_g):
382426
"""
@@ -392,7 +436,9 @@ def forward(self, delta_g):
392436
float
393437
Calculated pIC50 value.
394438
"""
395-
## IC50 value = exp(dG/kT) => pic50 = -log10(exp(dg/kT))
396-
## Rearrange a bit more to avoid disappearing floats:
397-
## pic50 = -dg/kT / ln(10)
398-
return -delta_g / self.kT / torch.log(torch.tensor(10, dtype=delta_g.dtype))
439+
pic50 = -delta_g / torch.log(torch.tensor(10, dtype=delta_g.dtype))
440+
# Using Cheng-Prusoff
441+
if self.cp_val:
442+
pic50 -= torch.log10(torch.tensor(self.cp_val, dtype=delta_g.dtype))
443+
444+
return pic50

0 commit comments

Comments
 (0)