|
3 | 3 | import pytest
|
4 | 4 | import torch
|
5 | 5 |
|
6 |
| -from mtenn.combination import MeanCombination, MaxCombination, BoltzmannCombination |
| 6 | +from mtenn.combination import MeanCombination, MaxCombination |
7 | 7 | from mtenn.conversion_utils.schnet import SchNet
|
8 | 8 |
|
9 | 9 |
|
@@ -89,33 +89,3 @@ def test_max_combination(models_and_inputs):
|
89 | 89 | for n, p in model_test.named_parameters()
|
90 | 90 | ]
|
91 | 91 | )
|
92 |
| - |
93 |
| - |
94 |
| -def test_boltzmann_combination(models_and_inputs): |
95 |
| - model_test, model_ref, inp_list, target, loss_func = models_and_inputs |
96 |
| - |
97 |
| - # Ref calc |
98 |
| - pred_list = torch.stack([model_ref(X)[0] for X in inp_list]) |
99 |
| - w = torch.exp(-pred_list - torch.logsumexp(-pred_list, axis=0)) |
100 |
| - pred_ref = torch.dot(w.flatten(), pred_list.flatten()) |
101 |
| - loss = loss_func(pred_ref, target) |
102 |
| - loss.backward() |
103 |
| - |
104 |
| - # Finish setting up GroupedModel |
105 |
| - model_test = SchNet.get_model( |
106 |
| - model_test, grouped=True, strategy="complex", combination=BoltzmannCombination() |
107 |
| - ) |
108 |
| - |
109 |
| - # Test GroupedModel |
110 |
| - pred_test, _ = model_test(inp_list) |
111 |
| - loss = loss_func(pred_test, target) |
112 |
| - loss.backward() |
113 |
| - |
114 |
| - # Compare |
115 |
| - ref_param_dict = dict(model_ref.named_parameters()) |
116 |
| - assert all( |
117 |
| - [ |
118 |
| - np.allclose(p.grad, ref_param_dict[n].grad, atol=5e-7) |
119 |
| - for n, p in model_test.named_parameters() |
120 |
| - ] |
121 |
| - ) |
0 commit comments