Skip to content

Commit 8b8d41f

Browse files
committed
reunited hamiltonian and reference compute interface
1 parent 0d33c73 commit 8b8d41f

File tree

3 files changed

+24
-16
lines changed

3 files changed

+24
-16
lines changed

psiflow/data/dataset.py

+4-13
Original file line numberDiff line numberDiff line change
@@ -553,7 +553,7 @@ def get_length(arg):
553553
def compute(
554554
arg: Union[Dataset, AppFuture[list], list, AppFuture, Geometry],
555555
*apply_apps: Union[PythonApp, Callable],
556-
outputs_: Union[str, list[str], None] = None,
556+
outputs_: Union[str, list[str], tuple[str, ...], None] = None,
557557
reduce_func: Union[PythonApp, Callable] = aggregate_multiple,
558558
batch_size: Optional[int] = None,
559559
) -> Union[list[AppFuture], AppFuture]:
@@ -570,7 +570,7 @@ def compute(
570570
Returns:
571571
Union[list[AppFuture], AppFuture]: Future(s) representing computation results.
572572
"""
573-
if outputs_ is not None and not isinstance(outputs_, list):
573+
if type(outputs_) is str:
574574
outputs_ = [outputs_]
575575
if batch_size is not None:
576576
if isinstance(arg, Dataset):
@@ -627,7 +627,7 @@ class Computable:
627627
def compute(
628628
self,
629629
arg: Union[Dataset, AppFuture[list], list, AppFuture, Geometry],
630-
outputs: Union[str, list[str], None] = None,
630+
*outputs: Optional[str],
631631
batch_size: Optional[int] = -1, # if -1: take class default
632632
) -> Union[list[AppFuture], AppFuture]:
633633
"""
@@ -641,13 +641,4 @@ def compute(
641641
Returns:
642642
Union[list[AppFuture], AppFuture]: Future(s) representing computation results.
643643
"""
644-
if outputs is None:
645-
outputs = list(self.__class__.outputs)
646-
if batch_size == -1:
647-
batch_size = self.__class__.batch_size
648-
return compute(
649-
arg,
650-
self.app,
651-
outputs_=outputs,
652-
batch_size=batch_size,
653-
)
644+
raise NotImplementedError

psiflow/hamiltonians.py

+17
Original file line numberDiff line numberDiff line change
@@ -34,6 +34,23 @@ class Hamiltonian(Computable):
3434
outputs: ClassVar[tuple] = ("energy", "forces", "stress")
3535
batch_size = 1000
3636

37+
def compute(
38+
self,
39+
arg: Union[Dataset, AppFuture[list], list, AppFuture, Geometry],
40+
*outputs: Optional[str],
41+
batch_size: Optional[int] = -1, # if -1: take class default
42+
) -> Union[list[AppFuture], AppFuture]:
43+
if len(outputs) == 0:
44+
outputs = tuple(self.__class__.outputs)
45+
if batch_size == -1:
46+
batch_size = self.__class__.batch_size
47+
return compute(
48+
arg,
49+
self.app,
50+
outputs_=outputs,
51+
batch_size=batch_size,
52+
)
53+
3754
def __eq__(self, hamiltonian: Hamiltonian) -> bool:
3855
raise NotImplementedError
3956

tests/test_function.py

+3-3
Original file line numberDiff line numberDiff line change
@@ -45,7 +45,7 @@ def test_einstein_crystal(dataset):
4545
hamiltonian = EinsteinCrystal(dataset[0], force_constant=1.0)
4646

4747
forces_, stress_, energy_ = hamiltonian.compute(
48-
dataset[:4], outputs=["forces", "stress", "energy"]
48+
dataset[:4], "forces", "stress", "energy"
4949
)
5050
assert np.allclose(
5151
energy_.result(),
@@ -56,7 +56,7 @@ def test_einstein_crystal(dataset):
5656
forces,
5757
)
5858

59-
forces = hamiltonian.compute(dataset[:4], outputs=["forces"], batch_size=3)
59+
forces = hamiltonian.compute(dataset[:4], "forces", batch_size=3)
6060
assert np.allclose(
6161
forces.result(),
6262
forces_.result(),
@@ -173,7 +173,7 @@ def test_plumed_function(tmp_path, dataset, dataset_h2):
173173
distance = np.linalg.norm(positions[:, 0, :] - positions[:, 1, :], axis=1)
174174
distance = distance.reshape(-1, 1)
175175

176-
energy = hamiltonian.compute(dataset[:10], ["energy"]).result()
176+
energy = hamiltonian.compute(dataset[:10], "energy").result()
177177

178178
sigma = 2 * np.ones((1, 2))
179179
height = np.array([70, 70]).reshape(1, -1) * (kJ / mol) # unit consistency

0 commit comments

Comments
 (0)