Skip to content

Commit a3672f2

Browse files
authored
Implementation of Parallelization to MDAnalysis.analysis.contacts (#4820)
* Fixes #4660
* Summary of changes:
  - added backends and aggregators to Contacts in analysis.contacts
  - added private _get_box_func method because lambdas cannot be used for parallelization
  - added the client_Contacts fixture in conftest.py
  - passed client_Contacts to run() in test_contacts.py
* Update CHANGELOG
1 parent 80b28c8 commit a3672f2

File tree

4 files changed

+105
-40
lines changed

4 files changed

+105
-40
lines changed

package/CHANGELOG

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -25,10 +25,12 @@ Fixes
2525
the function to prevent shared state. (Issue #4655)
2626

2727
Enhancements
28+
* Enable parallelization for analysis.contacts.Contacts (Issue #4660)
2829
* Enable parallelization for analysis.nucleicacids.NucPairDist (Issue #4670)
2930
* Add check and warning for empty (all zero) coordinates in RDKit converter (PR #4824)
3031
* Added `precision` for XYZWriter (Issue #4775, PR #4771)
3132

33+
3234
Changes
3335

3436
Deprecations

package/MDAnalysis/analysis/contacts.py

Lines changed: 41 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -223,7 +223,7 @@ def is_any_closer(r, r0, dist=2.5):
223223
from MDAnalysis.lib.util import openany
224224
from MDAnalysis.analysis.distances import distance_array
225225
from MDAnalysis.core.groups import AtomGroup, UpdatingAtomGroup
226-
from .base import AnalysisBase
226+
from .base import AnalysisBase, ResultsGroup
227227

228228
logger = logging.getLogger("MDAnalysis.analysis.contacts")
229229

@@ -376,8 +376,22 @@ class Contacts(AnalysisBase):
376376
:class:`MDAnalysis.analysis.base.Results` instance.
377377
.. versionchanged:: 2.2.0
378378
:class:`Contacts` accepts both AtomGroup and string for `select`
379+
.. versionchanged:: 2.9.0
380+
Introduced :meth:`get_supported_backends` allowing
381+
for parallel execution on :mod:`multiprocessing`
382+
and :mod:`dask` backends.
379383
"""
380384

385+
_analysis_algorithm_is_parallelizable = True
386+
387+
@classmethod
388+
def get_supported_backends(cls):
389+
return (
390+
"serial",
391+
"multiprocessing",
392+
"dask",
393+
)
394+
381395
def __init__(self, u, select, refgroup, method="hard_cut", radius=4.5,
382396
pbc=True, kwargs=None, **basekwargs):
383397
"""
@@ -444,11 +458,8 @@ def __init__(self, u, select, refgroup, method="hard_cut", radius=4.5,
444458
self.r0 = []
445459
self.initial_contacts = []
446460

447-
#get dimension of box if pbc set to True
448-
if self.pbc:
449-
self._get_box = lambda ts: ts.dimensions
450-
else:
451-
self._get_box = lambda ts: None
461+
# get dimensions via partial for parallelization compatibility
462+
self._get_box = functools.partial(self._get_box_func, pbc=self.pbc)
452463

453464
if isinstance(refgroup[0], AtomGroup):
454465
refA, refB = refgroup
@@ -464,7 +475,6 @@ def __init__(self, u, select, refgroup, method="hard_cut", radius=4.5,
464475

465476
self.n_initial_contacts = self.initial_contacts[0].sum()
466477

467-
468478
@staticmethod
469479
def _get_atomgroup(u, sel):
470480
select_error_message = ("selection must be either string or a "
@@ -480,6 +490,28 @@ def _get_atomgroup(u, sel):
480490
else:
481491
raise TypeError(select_error_message)
482492

493+
@staticmethod
494+
def _get_box_func(ts, pbc):
495+
"""Retrieve the dimensions of the simulation box based on PBC.
496+
497+
Parameters
498+
----------
499+
ts : Timestep
500+
The current timestep of the simulation, which contains the
501+
box dimensions.
502+
pbc : bool
503+
A flag indicating whether periodic boundary conditions (PBC)
504+
are enabled. If `True`, the box dimensions are returned,
505+
else returns `None`.
506+
507+
Returns
508+
-------
509+
box_dimensions : ndarray or None
510+
The dimensions of the simulation box as a NumPy array if PBC
511+
is True, else returns `None`.
512+
"""
513+
return ts.dimensions if pbc else None
514+
483515
def _prepare(self):
484516
self.results.timeseries = np.empty((self.n_frames, len(self.r0)+1))
485517

@@ -506,6 +538,8 @@ def timeseries(self):
506538
warnings.warn(wmsg, DeprecationWarning)
507539
return self.results.timeseries
508540

541+
def _get_aggregator(self):
542+
return ResultsGroup(lookup={'timeseries': ResultsGroup.ndarray_vstack})
509543

510544
def _new_selections(u_orig, selections, frame):
511545
"""create stand alone AGs from selections at frame"""

testsuite/MDAnalysisTests/analysis/conftest.py

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -15,6 +15,7 @@
1515
HydrogenBondAnalysis,
1616
)
1717
from MDAnalysis.analysis.nucleicacids import NucPairDist
18+
from MDAnalysis.analysis.contacts import Contacts
1819
from MDAnalysis.lib.util import is_installed
1920

2021

@@ -149,3 +150,10 @@ def client_HydrogenBondAnalysis(request):
149150
@pytest.fixture(scope="module", params=params_for_cls(NucPairDist))
150151
def client_NucPairDist(request):
151152
return request.param
153+
154+
155+
# MDAnalysis.analysis.contacts
156+
157+
@pytest.fixture(scope="module", params=params_for_cls(Contacts))
158+
def client_Contacts(request):
159+
return request.param

testsuite/MDAnalysisTests/analysis/test_contacts.py

Lines changed: 54 additions & 33 deletions
Original file line numberDiff line numberDiff line change
@@ -171,8 +171,8 @@ def universe():
171171
return mda.Universe(PSF, DCD)
172172

173173
def _run_Contacts(
174-
self, universe,
175-
start=None, stop=None, step=None, **kwargs
174+
self, universe, client_Contacts, start=None,
175+
stop=None, step=None, **kwargs
176176
):
177177
acidic = universe.select_atoms(self.sel_acidic)
178178
basic = universe.select_atoms(self.sel_basic)
@@ -181,7 +181,8 @@ def _run_Contacts(
181181
select=(self.sel_acidic, self.sel_basic),
182182
refgroup=(acidic, basic),
183183
radius=6.0,
184-
**kwargs).run(start=start, stop=stop, step=step)
184+
**kwargs
185+
).run(**client_Contacts, start=start, stop=stop, step=step)
185186

186187
@pytest.mark.parametrize("seltxt", [sel_acidic, sel_basic])
187188
def test_select_valid_types(self, universe, seltxt):
@@ -195,7 +196,7 @@ def test_select_valid_types(self, universe, seltxt):
195196

196197
assert ag_from_string == ag_from_ag
197198

198-
def test_contacts_selections(self, universe):
199+
def test_contacts_selections(self, universe, client_Contacts):
199200
"""Test if Contacts can take both string and AtomGroup as selections.
200201
"""
201202
aga = universe.select_atoms(self.sel_acidic)
@@ -210,8 +211,8 @@ def test_contacts_selections(self, universe):
210211
refgroup=(aga, agb)
211212
)
212213

213-
cag.run()
214-
csel.run()
214+
cag.run(**client_Contacts)
215+
csel.run(**client_Contacts)
215216

216217
assert cag.grA == csel.grA
217218
assert cag.grB == csel.grB
@@ -228,26 +229,31 @@ def test_select_wrong_types(self, universe, ag):
228229
) as te:
229230
contacts.Contacts._get_atomgroup(universe, ag)
230231

231-
def test_startframe(self, universe):
232+
def test_startframe(self, universe, client_Contacts):
232233
"""test_startframe: TestContactAnalysis1: start frame set to 0 (resolution of
233234
Issue #624)
234235
235236
"""
236-
CA1 = self._run_Contacts(universe)
237+
CA1 = self._run_Contacts(universe, client_Contacts=client_Contacts)
237238
assert len(CA1.results.timeseries) == universe.trajectory.n_frames
238239

239-
def test_end_zero(self, universe):
240+
def test_end_zero(self, universe, client_Contacts):
240241
"""test_end_zero: TestContactAnalysis1: stop frame 0 is not ignored"""
241-
CA1 = self._run_Contacts(universe, stop=0)
242+
CA1 = self._run_Contacts(
243+
universe, client_Contacts=client_Contacts, stop=0
244+
)
242245
assert len(CA1.results.timeseries) == 0
243246

244-
def test_slicing(self, universe):
247+
def test_slicing(self, universe, client_Contacts):
245248
start, stop, step = 10, 30, 5
246-
CA1 = self._run_Contacts(universe, start=start, stop=stop, step=step)
249+
CA1 = self._run_Contacts(
250+
universe, client_Contacts=client_Contacts,
251+
start=start, stop=stop, step=step
252+
)
247253
frames = np.arange(universe.trajectory.n_frames)[start:stop:step]
248254
assert len(CA1.results.timeseries) == len(frames)
249255

250-
def test_villin_folded(self):
256+
def test_villin_folded(self, client_Contacts):
251257
# one folded, one unfolded
252258
f = mda.Universe(contacts_villin_folded)
253259
u = mda.Universe(contacts_villin_unfolded)
@@ -259,12 +265,12 @@ def test_villin_folded(self):
259265
select=(sel, sel),
260266
refgroup=(grF, grF),
261267
method="soft_cut")
262-
q.run()
268+
q.run(**client_Contacts)
263269

264270
results = soft_cut(f, u, sel, sel)
265271
assert_allclose(q.results.timeseries[:, 1], results[:, 1], rtol=0, atol=1.5e-7)
266272

267-
def test_villin_unfolded(self):
273+
def test_villin_unfolded(self, client_Contacts):
268274
# both folded
269275
f = mda.Universe(contacts_villin_folded)
270276
u = mda.Universe(contacts_villin_folded)
@@ -276,13 +282,13 @@ def test_villin_unfolded(self):
276282
select=(sel, sel),
277283
refgroup=(grF, grF),
278284
method="soft_cut")
279-
q.run()
285+
q.run(**client_Contacts)
280286

281287
results = soft_cut(f, u, sel, sel)
282288
assert_allclose(q.results.timeseries[:, 1], results[:, 1], rtol=0, atol=1.5e-7)
283289

284-
def test_hard_cut_method(self, universe):
285-
ca = self._run_Contacts(universe)
290+
def test_hard_cut_method(self, universe, client_Contacts):
291+
ca = self._run_Contacts(universe, client_Contacts=client_Contacts)
286292
expected = [1., 0.58252427, 0.52427184, 0.55339806, 0.54368932,
287293
0.54368932, 0.51456311, 0.46601942, 0.48543689, 0.52427184,
288294
0.46601942, 0.58252427, 0.51456311, 0.48543689, 0.48543689,
@@ -306,7 +312,7 @@ def test_hard_cut_method(self, universe):
306312
assert len(ca.results.timeseries) == len(expected)
307313
assert_allclose(ca.results.timeseries[:, 1], expected, rtol=0, atol=1.5e-7)
308314

309-
def test_radius_cut_method(self, universe):
315+
def test_radius_cut_method(self, universe, client_Contacts):
310316
acidic = universe.select_atoms(self.sel_acidic)
311317
basic = universe.select_atoms(self.sel_basic)
312318
r = contacts.distance_array(acidic.positions, basic.positions)
@@ -316,15 +322,20 @@ def test_radius_cut_method(self, universe):
316322
r = contacts.distance_array(acidic.positions, basic.positions)
317323
expected.append(contacts.radius_cut_q(r[initial_contacts], None, radius=6.0))
318324

319-
ca = self._run_Contacts(universe, method='radius_cut')
325+
ca = self._run_Contacts(
326+
universe, client_Contacts=client_Contacts, method="radius_cut"
327+
)
320328
assert_array_equal(ca.results.timeseries[:, 1], expected)
321329

322330
@staticmethod
323331
def _is_any_closer(r, r0, dist=2.5):
324332
return np.any(r < dist)
325333

326-
def test_own_method(self, universe):
327-
ca = self._run_Contacts(universe, method=self._is_any_closer)
334+
def test_own_method(self, universe, client_Contacts):
335+
ca = self._run_Contacts(
336+
universe, client_Contacts=client_Contacts,
337+
method=self._is_any_closer
338+
)
328339

329340
bound_expected = [1., 1., 0., 1., 1., 0., 0., 1., 0., 1., 1., 0., 0.,
330341
1., 0., 0., 0., 0., 1., 1., 0., 0., 0., 1., 0., 1.,
@@ -340,21 +351,28 @@ def test_own_method(self, universe):
340351
def _weird_own_method(r, r0):
341352
return 'aaa'
342353

343-
def test_own_method_no_array_cast(self, universe):
354+
def test_own_method_no_array_cast(self, universe, client_Contacts):
344355
with pytest.raises(ValueError):
345-
self._run_Contacts(universe, method=self._weird_own_method, stop=2)
346-
347-
def test_non_callable_method(self, universe):
356+
self._run_Contacts(
357+
universe,
358+
client_Contacts=client_Contacts,
359+
method=self._weird_own_method,
360+
stop=2,
361+
)
362+
363+
def test_non_callable_method(self, universe, client_Contacts):
348364
with pytest.raises(ValueError):
349-
self._run_Contacts(universe, method=2, stop=2)
365+
self._run_Contacts(
366+
universe, client_Contacts=client_Contacts, method=2, stop=2
367+
)
350368

351369
@pytest.mark.parametrize("pbc,expected", [
352370
(True, [1., 0.43138152, 0.3989021, 0.43824337, 0.41948765,
353371
0.42223239, 0.41354071, 0.43641354, 0.41216834, 0.38334858]),
354372
(False, [1., 0.42327791, 0.39192399, 0.40950119, 0.40902613,
355373
0.42470309, 0.41140143, 0.42897862, 0.41472684, 0.38574822])
356374
])
357-
def test_distance_box(self, pbc, expected):
375+
def test_distance_box(self, pbc, expected, client_Contacts):
358376
u = mda.Universe(TPR, XTC)
359377
sel_basic = "(resname ARG LYS)"
360378
sel_acidic = "(resname ASP GLU)"
@@ -363,13 +381,15 @@ def test_distance_box(self, pbc, expected):
363381

364382
r = contacts.Contacts(u, select=(sel_acidic, sel_basic),
365383
refgroup=(acidic, basic), radius=6.0, pbc=pbc)
366-
r.run()
384+
r.run(**client_Contacts)
367385
assert_allclose(r.results.timeseries[:, 1], expected,rtol=0, atol=1.5e-7)
368386

369-
def test_warn_deprecated_attr(self, universe):
387+
def test_warn_deprecated_attr(self, universe, client_Contacts):
370388
"""Test for warning message emitted on using deprecated `timeseries`
371389
attribute"""
372-
CA1 = self._run_Contacts(universe, stop=1)
390+
CA1 = self._run_Contacts(
391+
universe, client_Contacts=client_Contacts, stop=1
392+
)
373393
wmsg = "The `timeseries` attribute was deprecated in MDAnalysis"
374394
with pytest.warns(DeprecationWarning, match=wmsg):
375395
assert_equal(CA1.timeseries, CA1.results.timeseries)
@@ -385,10 +405,11 @@ def test_n_initial_contacts(self, datafiles, expected):
385405
r = contacts.Contacts(u, select=select, refgroup=refgroup)
386406
assert_equal(r.n_initial_contacts, expected)
387407

388-
def test_q1q2():
408+
409+
def test_q1q2(client_Contacts):
389410
u = mda.Universe(PSF, DCD)
390411
q1q2 = contacts.q1q2(u, 'name CA', radius=8)
391-
q1q2.run()
412+
q1q2.run(**client_Contacts)
392413

393414
q1_expected = [1., 0.98092643, 0.97366031, 0.97275204, 0.97002725,
394415
0.97275204, 0.96276113, 0.96730245, 0.9582198, 0.96185286,

0 commit comments

Comments
 (0)