Skip to content

Commit 9681561

Browse files
committed
optimize order parameter evaluation, allow hamiltonians to be evaluated on single states
1 parent c0df3a5 commit 9681561

File tree

7 files changed

+116
-34
lines changed

7 files changed

+116
-34
lines changed

psiflow/data.py

Lines changed: 6 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -812,7 +812,7 @@ def _batch_frames(
812812
@join_app
813813
@typeguard.typechecked
814814
def batch_apply(
815-
func: Callable,
815+
funcs: list[Callable],
816816
batch_size: int,
817817
length: int,
818818
inputs: list = [],
@@ -821,6 +821,9 @@ def batch_apply(
821821
nbatches = math.ceil(length / batch_size)
822822
batches = [psiflow.context().new_file("data_", ".xyz") for _ in range(nbatches)]
823823
future = batch_frames(batch_size, inputs=[inputs[0]], outputs=batches)
824-
evaluated = [func(Dataset(None, extxyz=e)) for e in future.outputs]
825-
f = join_frames(inputs=[e.extxyz for e in evaluated], outputs=[outputs[0]])
824+
datasets = [Dataset(None, extxyz=e) for e in future.outputs]
825+
for func in funcs:
826+
datasets = [func(d) for d in datasets]
827+
# evaluated = [func(Dataset(None, extxyz=e)) for e in future.outputs]
828+
f = join_frames(inputs=[d.extxyz for d in datasets], outputs=[outputs[0]])
826829
return f

psiflow/hamiltonians/_plumed.py

Lines changed: 5 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -42,13 +42,14 @@ def try_manual_plumed_linking() -> str:
4242
def remove_comments_printflush(plumed_input: str) -> str:
4343
new_input = []
4444
for line in list(plumed_input.split("\n")):
45-
if line.strip().startswith("#"):
45+
pre_comment = line.strip().split("#")[0].strip()
46+
if len(pre_comment) == 0:
4647
continue
47-
if line.strip().startswith("PRINT"):
48+
if pre_comment.startswith("PRINT"):
4849
continue
49-
if line.strip().startswith("FLUSH"):
50+
if pre_comment.startswith("FLUSH"):
5051
continue
51-
new_input.append(line)
52+
new_input.append(pre_comment)
5253
return "\n".join(new_input)
5354

5455

psiflow/hamiltonians/hamiltonian.py

Lines changed: 37 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -1,11 +1,12 @@
11
from __future__ import annotations # necessary for type-guarding class methods
22

33
import logging
4-
from typing import Callable, Optional
4+
from typing import Callable, Optional, Union
55

66
import typeguard
77
from parsl.app.app import python_app
88
from parsl.app.futures import DataFuture
9+
from parsl.dataflow.futures import AppFuture
910
from parsl.data_provider.files import File
1011

1112
import psiflow
@@ -22,16 +23,20 @@ def evaluate_function(
2223
outputs: list = [],
2324
parsl_resource_specification: dict = {},
2425
**parameters, # dict values can be futures, so app must wait for those
25-
) -> None:
26+
) -> Optional[Geometry]:
2627
import numpy as np
2728
from ase import Atoms
2829

2930
from psiflow.data import _read_frames, _write_frames
3031
from psiflow.geometry import NullState
3132

3233
assert len(inputs) >= 1
33-
assert len(outputs) == 1
34-
states = _read_frames(inputs=[inputs[0]])
34+
if isinstance(inputs[0], Geometry):
35+
assert len(outputs) == 0
36+
states = [inputs[0]]
37+
else:
38+
assert len(outputs) == 1
39+
states = _read_frames(inputs=[inputs[0]])
3540
calculators, index_mapping = load_calculators(states, inputs[1], **parameters)
3641
for i, state in enumerate(states):
3742
if state == NullState:
@@ -54,25 +59,41 @@ def evaluate_function(
5459
print(e)
5560
stress = np.zeros((3, 3))
5661
state.stress = stress
57-
_write_frames(*states, outputs=[outputs[0]])
62+
if isinstance(inputs[0], Geometry):
63+
return states[0]
64+
else:
65+
_write_frames(*states, outputs=[outputs[0]])
5866

5967

6068
@typeguard.typechecked
6169
@psiflow.serializable # otherwise MixtureHamiltonian.hamiltonians is not serialized
6270
class Hamiltonian:
6371
external: Optional[psiflow._DataFuture]
6472

65-
def evaluate(self, dataset: Dataset, batch_size: Optional[int] = 100) -> Dataset:
66-
future = batch_apply(
67-
self.single_evaluate,
68-
batch_size,
69-
dataset.length(),
70-
inputs=[dataset.extxyz],
71-
outputs=[
72-
psiflow.context().new_file("data_", ".xyz")
73-
], # join_app needs outputs kwarg here!
74-
)
75-
return Dataset(None, future.outputs[0])
73+
def evaluate(
74+
self,
75+
arg: Union[Dataset, Geometry, AppFuture[Geometry]],
76+
batch_size: Optional[int] = 100,
77+
) -> Union[AppFuture, Dataset]:
78+
if isinstance(arg, Dataset):
79+
future = batch_apply(
80+
[self.single_evaluate],
81+
batch_size,
82+
arg.length(),
83+
inputs=[arg.extxyz],
84+
outputs=[
85+
psiflow.context().new_file("data_", ".xyz")
86+
], # join_app needs outputs kwarg here!
87+
)
88+
return Dataset(None, future.outputs[0])
89+
else:
90+
future = self.evaluate_app(
91+
self.load_calculators,
92+
inputs=[arg, self.external],
93+
outputs=[],
94+
**self.parameters,
95+
)
96+
return future
7697

7798
# mostly for internal use
7899
def single_evaluate(self, dataset: Dataset) -> Dataset:

psiflow/sampling/order.py

Lines changed: 54 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -1,20 +1,22 @@
11
from __future__ import annotations # necessary for type-guarding class methods
22

3-
from typing import Union
3+
from functools import partial
4+
from typing import Union, Optional
45

56
import typeguard
67
from ase.units import kJ, mol
78
from parsl.app.app import python_app
89
from parsl.dataflow.futures import AppFuture
910

1011
import psiflow
11-
from psiflow.data import Dataset
12+
from psiflow.data import Dataset, batch_apply
1213
from psiflow.geometry import Geometry
1314
from psiflow.hamiltonians._plumed import PlumedHamiltonian
1415
from psiflow.hamiltonians.hamiltonian import Hamiltonian
1516

1617

17-
def _insert_in_state(
18+
@typeguard.typechecked
19+
def insert_in_state(
1820
state: Geometry,
1921
name: str,
2022
) -> Geometry:
@@ -24,7 +26,32 @@ def _insert_in_state(
2426
return state
2527

2628

27-
insert_in_state = python_app(_insert_in_state, executors=["default_threads"])
29+
@typeguard.typechecked
30+
def _insert(
31+
state_or_states: Union[Geometry, list[Geometry]],
32+
name: str,
33+
) -> Union[list[Geometry], Geometry]:
34+
if not isinstance(state_or_states, list):
35+
return insert_in_state(state_or_states, name)
36+
else:
37+
for state in state_or_states:
38+
insert_in_state(state, name) # modify list in place
39+
return state_or_states
40+
41+
42+
insert = python_app(_insert, executors=["default_threads"])
43+
44+
45+
@typeguard.typechecked
46+
def insert_in_dataset(
47+
data: Dataset,
48+
name: str,
49+
) -> Dataset:
50+
geometries = insert(
51+
data.geometries(),
52+
name,
53+
)
54+
return Dataset(geometries)
2855

2956

3057
@typeguard.typechecked
@@ -51,11 +78,29 @@ def __init__(self, name: str, hamiltonian: Hamiltonian):
5178
super().__init__(name)
5279
self.hamiltonian = hamiltonian
5380

54-
def evaluate(self, state: Union[Geometry, AppFuture]) -> AppFuture:
55-
return insert_in_state(
56-
self.hamiltonian.evaluate(Dataset([state]))[0],
57-
self.name,
58-
)
81+
def evaluate(
82+
self,
83+
arg: Union[Dataset, Geometry, AppFuture[Geometry]],
84+
batch_size: Optional[int] = 100,
85+
) -> Union[Dataset, AppFuture]:
86+
if isinstance(arg, Dataset):
87+
# avoid batching the dataset twice:
88+
# apply hamiltonian in batched sense and put insert afterwards
89+
funcs = [
90+
self.hamiltonian.single_evaluate,
91+
partial(insert_in_dataset, name=self.name),
92+
]
93+
future = batch_apply(
94+
funcs,
95+
batch_size,
96+
arg.length(),
97+
inputs=[arg.extxyz],
98+
outputs=[psiflow.context().new_file("data_", ".xyz")],
99+
)
100+
return Dataset(None, future.outputs[0])
101+
else:
102+
state = self.hamiltonian.evaluate(arg)
103+
return insert(state, self.name)
59104

60105
def __eq__(self, other):
61106
if type(other) is not HamiltonianOrderParameter:

psiflow/sampling/walker.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -187,6 +187,7 @@ def quench(walkers: list[Walker], dataset: Dataset) -> None:
187187
coefficients = []
188188
for walker in walkers:
189189
c = all_hamiltonians.get_coefficients(1.0 * walker.hamiltonian)
190+
assert c is not None
190191
coefficients.append(c)
191192
coefficients = np.array(coefficients)
192193

tests/test_hamiltonian.py

Lines changed: 6 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -32,9 +32,9 @@ def test_get_filename_hills():
3232
RESTART
3333
UNITS LENGTH=A ENERGY=kj/mol TIME=fs
3434
CV: VOLUME
35-
CV0: CV
35+
CV0: CV #lkasdjf
3636
METAD ARG=CV0 SIGMA=100 HEIGHT=2 PACE=50 LABEL=metad FILE=test_hills sdld
37-
METADD ARG=CV SIGMA=100 HEIGHT=2 PACE=50 LABEL=metad sdld
37+
METADD ARG=CV SIGMA=100 HEIGHT=2 PACE=50 LABEL=metad sdld #fjalsdkfj
3838
PRINT ARG=CV,metad.bias STRIDE=10 FILE=COLVAR
3939
FLUSH STRIDE=10
4040
"""
@@ -62,6 +62,10 @@ def test_einstein(dataset, dataset_h2):
6262
for i in range(1, 10):
6363
assert evaluated[i].result().energy > 0.0
6464
assert not np.allclose(evaluated[i].result().stress, 0.0)
65+
assert np.allclose(
66+
evaluated[i].result().energy,
67+
hamiltonian.evaluate(evaluated[i]).result().energy,
68+
)
6569

6670
# test evaluation with NullState in data
6771
data = hamiltonian.evaluate(dataset[:5] + Dataset([NullState]) + dataset[5:10])

tests/test_sampling.py

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -450,6 +450,13 @@ def test_order_parameter(dataset):
450450
assert state.energy is None
451451
assert np.allclose(CV, np.linalg.det(dataset[3].result().cell))
452452

453+
# test batch evaluation of order parameter
454+
data = order.evaluate(dataset[:10], batch_size=5)
455+
volumes = data.get("CV").result()
456+
for i in range(10):
457+
volume = np.linalg.det(dataset[i].result().cell)
458+
assert np.allclose(volume, volumes[i])
459+
453460

454461
def test_walker_serialization(dataset, tmp_path):
455462
einstein = EinsteinCrystal(dataset[0], force_constant=0.1)

0 commit comments

Comments
 (0)