Skip to content

Commit 2b295e2

Browse files
committed
fix bug in refernce.compute
1 parent 669f50a commit 2b295e2

File tree

2 files changed

+6
-3
lines changed

2 files changed

+6
-3
lines changed

psiflow/reference/reference.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -115,7 +115,7 @@ def compute(self, dataset: Dataset, *outputs: Optional[Union[str, tuple]]):
115115
for output in outputs_:
116116
if output not in self.outputs:
117117
raise ValueError("output {} not in {}".format(output, self.outputs))
118-
index = outputs_.index(output)
118+
index = self.outputs.index(output)
119119
to_return.append(compute_outputs[index])
120120
if len(outputs_) == 1:
121121
return to_return[0]

tests/test_reference.py

+5-2
Original file line numberDiff line numberDiff line change
@@ -181,8 +181,10 @@ def test_reference_d3(context, dataset, tmp_path):
181181
assert state.energy is not None
182182
assert state.energy < 0.0 # dispersion is attractive
183183

184-
data = dataset[:3].evaluate(reference)
185-
energy = reference.compute(dataset[:3], "energy")
184+
subset = dataset[:3]
185+
data = subset.evaluate(reference)
186+
energy = reference.compute(subset, "energy")
187+
forces = reference.compute(subset, "forces")
186188

187189
assert np.allclose(
188190
data.get("energy").result(),
@@ -193,6 +195,7 @@ def test_reference_d3(context, dataset, tmp_path):
193195
0.0,
194196
)
195197

198+
assert len(forces.result().shape) == 3
196199

197200
@pytest.mark.filterwarnings("ignore:Original input file not found")
198201
def test_cp2k_success(context, simple_cp2k_input):

0 commit comments

Comments
 (0)