Skip to content

Commit 3d0de92

Browse files
author
kaminow
committed
Remove BoltzmannCombination references.
1 parent e74d59c commit 3d0de92

File tree

2 files changed

+1
-37
lines changed

2 files changed

+1
-37
lines changed

mtenn/config.py

-6
Original file line numberDiff line numberDiff line change
@@ -128,14 +128,10 @@ class CombinationConfig(StringEnum):
128128
* mean: :py:class:`MeanCombination <mtenn.combination.MeanCombination>`
129129
130130
* max: :py:class:`MaxCombination <mtenn.combination.MaxCombination>`
131-
132-
* boltzmann:
133-
:py:class:`BoltzmannCombination <mtenn.combination.BoltzmannCombination>`
134131
"""
135132

136133
mean = "mean"
137134
max = "max"
138-
boltzmann = "boltzmann"
139135

140136

141137
class ModelConfigBase(BaseModel):
@@ -273,8 +269,6 @@ def build(self) -> mtenn.model.Model:
273269
mtenn_combination = mtenn.combination.MaxCombination(
274270
negate_preds=self.max_comb_neg, pred_scale=self.max_comb_scale
275271
)
276-
case CombinationConfig.boltzmann:
277-
mtenn_combination = mtenn.combination.BoltzmannCombination()
278272
case None:
279273
mtenn_combination = None
280274

mtenn/tests/test_combination.py

+1-31
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,7 @@
33
import pytest
44
import torch
55

6-
from mtenn.combination import MeanCombination, MaxCombination, BoltzmannCombination
6+
from mtenn.combination import MeanCombination, MaxCombination
77
from mtenn.conversion_utils.schnet import SchNet
88

99

@@ -89,33 +89,3 @@ def test_max_combination(models_and_inputs):
8989
for n, p in model_test.named_parameters()
9090
]
9191
)
92-
93-
94-
def test_boltzmann_combination(models_and_inputs):
95-
model_test, model_ref, inp_list, target, loss_func = models_and_inputs
96-
97-
# Ref calc
98-
pred_list = torch.stack([model_ref(X)[0] for X in inp_list])
99-
w = torch.exp(-pred_list - torch.logsumexp(-pred_list, axis=0))
100-
pred_ref = torch.dot(w.flatten(), pred_list.flatten())
101-
loss = loss_func(pred_ref, target)
102-
loss.backward()
103-
104-
# Finish setting up GroupedModel
105-
model_test = SchNet.get_model(
106-
model_test, grouped=True, strategy="complex", combination=BoltzmannCombination()
107-
)
108-
109-
# Test GroupedModel
110-
pred_test, _ = model_test(inp_list)
111-
loss = loss_func(pred_test, target)
112-
loss.backward()
113-
114-
# Compare
115-
ref_param_dict = dict(model_ref.named_parameters())
116-
assert all(
117-
[
118-
np.allclose(p.grad, ref_param_dict[n].grad, atol=5e-7)
119-
for n, p in model_test.named_parameters()
120-
]
121-
)

0 commit comments

Comments
 (0)