2
2
from itertools import permutations
3
3
import os
4
4
import torch
5
+ from typing import Optional
5
6
6
7
7
8
class Model (torch .nn .Module ):
@@ -328,55 +329,98 @@ def forward(self, predictions: torch.Tensor):
328
329
return torch .mean (predictions )
329
330
330
331
331
- class BoltzmannCombination (Combination ):
332
+ class MaxCombination (Combination ):
332
333
"""
333
- Combine a list of deltaG predictions according to their Boltzmann weight .
334
+ Approximate max/min of the predictions using the LogSumExp function for smoothness .
334
335
"""
335
336
336
- def __init__ (self ):
337
- super (BoltzmannCombination , self ).__init__ ()
337
+ def __init__ (self , neg = True , scale = 1000.0 ):
338
+ """
339
+ Parameters
340
+ ----------
341
+ neg : bool, default=True
342
+ Negate the predictions before calculating the LSE, effectively finding
343
+ the min. Preds are negated again before being returned
344
+ scale : float, default=1000.0
345
+ Fixed positive value to scale predictions by before taking the LSE. This
346
+ tightens the bounds of the LSE approximation
347
+ """
348
+ super (MaxCombination , self ).__init__ ()
349
+
350
+ self .neg = - 1 * neg
351
+ self .scale = scale
338
352
339
- from simtk .unit import (
340
- BOLTZMANN_CONSTANT_kB as kB ,
341
- elementary_charge ,
342
- coulomb ,
353
+ def forward (self , predictions : torch .Tensor ):
354
+ return (
355
+ self .neg
356
+ * torch .logsumexp (self .neg * self .scale * predictions , dim = 0 )
357
+ / self .scale
343
358
)
344
359
345
- ## Convert kB to eV (calibrate to SchNet predictions)
346
- electron_volt = elementary_charge .conversion_factor_to (coulomb )
347
360
348
- self .kT = (kB / electron_volt * 298.0 )._value
361
+ class BoltzmannCombination (Combination ):
362
+ """
363
+ Combine a list of deltaG predictions according to their Boltzmann weight. Use LSE
364
+ approximation of min energy to improve numerical stability. Treat energy in implicit
365
+ kT units.
366
+ """
367
+
368
+ def __init__ (self ):
369
+ super (BoltzmannCombination , self ).__init__ ()
349
370
350
371
def forward (self , predictions : torch .Tensor ):
351
- return - self .kT * torch .logsumexp (- predictions , dim = 0 )
372
+ # First calculate LSE (no scale here bc math)
373
+ lse = torch .logsumexp (- predictions , dim = 0 )
374
+ # Calculate Boltzmann weights for each prediction
375
+ w = torch .exp (- predictions - lse )
376
+
377
+ return torch .dot (w , predictions )
352
378
353
379
354
380
class PIC50Readout (Readout ):
355
381
"""
356
- Readout implementation to convert delta G values to pIC50 values.
382
+ Readout implementation to convert delta G values to pIC50 values. This new
383
+ implementation assumes implicit energy units, WHICH WILL INVALIDATE MODELS TRAINED
384
+ PRIOR TO v0.3.0.
385
+ Assuming implicit energy units:
386
+ deltaG = ln(Ki)
387
+ Ki = exp(deltaG)
388
+ Using the Cheng-Prusoff equation:
389
+ Ki = IC50 / (1 + [S]/Km)
390
+ exp(deltaG) = IC50 / (1 + [S]/Km)
391
+ IC50 = exp(deltaG) * (1 + [S]/Km)
392
+ pIC50 = -log10(exp(deltaG) * (1 + [S]/Km))
393
+ pIC50 = -log10(exp(deltaG)) - log10(1 + [S]/Km)
394
+ pIC50 = -ln(exp(deltaG))/ln(10) - log10(1 + [S]/Km)
395
+ pIC50 = -deltaG/ln(10) - log10(1 + [S]/Km)
396
+ Estimating Ki as the IC50 value:
397
+ Ki = IC50
398
+ IC50 = exp(deltaG)
399
+ pIC50 = -log10(exp(deltaG))
400
+ pIC50 = -ln(exp(deltaG))/ln(10)
401
+ pIC50 = -deltaG/ln(10)
357
402
"""
358
403
359
- def __init__ (self , T = 298.0 ):
404
+ def __init__ (self , substrate : Optional [ float ] = None , Km : Optional [ float ] = None ):
360
405
"""
361
- Initialize conversion with specified T (assume 298 K).
406
+ Initialize conversion with specified substrate concentration and Km. If either
407
+ is left blank, the IC50 approximation will be used.
362
408
363
409
Parameters
364
410
----------
365
- T : float, default=298
366
- Temperature for conversion.
411
+ substrate : float, optional
412
+ Substrate concentration for use in the Cheng-Prusoff equation. Assumed to be
413
+ in the same units as Km
414
+ Km : float, optional
415
+ Km value for use in the Cheng-Prusoff equation. Assumed to be in the same
416
+ units as substrate
367
417
"""
368
418
super (PIC50Readout , self ).__init__ ()
369
419
370
- from simtk .unit import (
371
- BOLTZMANN_CONSTANT_kB as kB ,
372
- elementary_charge ,
373
- coulomb ,
374
- )
375
-
376
- ## Convert kB to eV (calibrate to SchNet predictions)
377
- electron_volt = elementary_charge .conversion_factor_to (coulomb )
378
-
379
- self .kT = (kB / electron_volt * T )._value
420
+ if substrate and Km :
421
+ self .cp_val = 1 + substrate / Km
422
+ else :
423
+ self .cp_val = None
380
424
381
425
def forward (self , delta_g ):
382
426
"""
@@ -392,7 +436,9 @@ def forward(self, delta_g):
392
436
float
393
437
Calculated pIC50 value.
394
438
"""
395
- ## IC50 value = exp(dG/kT) => pic50 = -log10(exp(dg/kT))
396
- ## Rearrange a bit more to avoid disappearing floats:
397
- ## pic50 = -dg/kT / ln(10)
398
- return - delta_g / self .kT / torch .log (torch .tensor (10 , dtype = delta_g .dtype ))
439
+ pic50 = - delta_g / torch .log (torch .tensor (10 , dtype = delta_g .dtype ))
440
+ # Using Cheng-Prusoff
441
+ if self .cp_val :
442
+ pic50 -= torch .log10 (torch .tensor (self .cp_val , dtype = delta_g .dtype ))
443
+
444
+ return pic50
0 commit comments