@@ -154,6 +154,16 @@ def add_bias(forces, biases):
154154 def add_bias (forces , biases ):
155155 forces [:, :3 ] += factor * biases
156156
157+ def restore_vm (view , snapshot , prev_snapshot ):
158+ velocities = view (snapshot .vel_mass [0 ])
159+ masses_types = snapshot .vel_mass [1 ]
160+ masses = view (masses_types [0 ])
161+ types = view (masses_types [1 ])
162+ prev_masses_types = prev_snapshot .vel_mass [1 ]
163+ velocities [:] = view (prev_snapshot .vel_mass [0 ])
164+ masses [:] = view (prev_masses_types [0 ])
165+ types [:] = view (prev_masses_types [1 ])
166+
157167 # TODO: check if this can be sped up. # pylint: disable=W0511
158168 def bias (snapshot , state ):
159169 """Adds the computed bias to the forces."""
@@ -166,7 +176,7 @@ def bias(snapshot, state):
166176
167177 snapshot_methods = build_snapshot_methods (sampling_method , on_gpu )
168178 flags = sampling_method .snapshot_flags
169- restore = partial (restore_fn , view )
179+ restore = partial (restore_fn , view , restore_vm = restore_vm )
170180 helpers = HelperMethods (build_data_querier (snapshot_methods , flags ), lambda : dim )
171181
172182 return helpers , restore , bias
0 commit comments