@@ -154,6 +154,16 @@ def add_bias(forces, biases):
154
154
def add_bias (forces , biases ):
155
155
forces [:, :3 ] += factor * biases
156
156
157
+ def restore_vm (view , snapshot , prev_snapshot ):
158
+ velocities = view (snapshot .vel_mass [0 ])
159
+ masses_types = snapshot .vel_mass [1 ]
160
+ masses = view (masses_types [0 ])
161
+ types = view (masses_types [1 ])
162
+ prev_masses_types = prev_snapshot .vel_mass [1 ]
163
+ velocities [:] = view (prev_snapshot .vel_mass [0 ])
164
+ masses [:] = view (prev_masses_types [0 ])
165
+ types [:] = view (prev_masses_types [1 ])
166
+
157
167
# TODO: check if this can be sped up. # pylint: disable=W0511
158
168
def bias (snapshot , state ):
159
169
"""Adds the computed bias to the forces."""
@@ -166,7 +176,7 @@ def bias(snapshot, state):
166
176
167
177
snapshot_methods = build_snapshot_methods (sampling_method , on_gpu )
168
178
flags = sampling_method .snapshot_flags
169
- restore = partial (restore_fn , view )
179
+ restore = partial (restore_fn , view , restore_vm = restore_vm )
170
180
helpers = HelperMethods (build_data_querier (snapshot_methods , flags ), lambda : dim )
171
181
172
182
return helpers , restore , bias
0 commit comments