Skip to content

Commit 5ffb31f

Browse files
Merging 0f1d781 into trunk-temp/pr-344/816e4f01-336c-419f-8b5f-a6adf97a4b6f
2 parents 5f5bfc7 + 0f1d781 commit 5ffb31f

File tree

1 file changed

+11
-1
lines changed

1 file changed

+11
-1
lines changed

pysages/backends/lammps.py

Lines changed: 11 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -154,6 +154,16 @@ def add_bias(forces, biases):
154154
def add_bias(forces, biases):
155155
forces[:, :3] += factor * biases
156156

157+
def restore_vm(view, snapshot, prev_snapshot):
158+
velocities = view(snapshot.vel_mass[0])
159+
masses_types = snapshot.vel_mass[1]
160+
masses = view(masses_types[0])
161+
types = view(masses_types[1])
162+
prev_masses_types = prev_snapshot.vel_mass[1]
163+
velocities[:] = view(prev_snapshot.vel_mass[0])
164+
masses[:] = view(prev_masses_types[0])
165+
types[:] = view(prev_masses_types[1])
166+
157167
# TODO: check if this can be sped up. # pylint: disable=W0511
158168
def bias(snapshot, state):
159169
"""Adds the computed bias to the forces."""
@@ -166,7 +176,7 @@ def bias(snapshot, state):
166176

167177
snapshot_methods = build_snapshot_methods(sampling_method, on_gpu)
168178
flags = sampling_method.snapshot_flags
169-
restore = partial(restore_fn, view)
179+
restore = partial(restore_fn, view, restore_vm=restore_vm)
170180
helpers = HelperMethods(build_data_querier(snapshot_methods, flags), lambda: dim)
171181

172182
return helpers, restore, bias

0 commit comments

Comments
 (0)