Description
Is your feature request related to a problem?
In some analyses, e.g. DistanceMatrix in diffusionmap, the computation for a single frame needs to run over the full trajectory, i.e. the trajectory sliced by start, stop, and step in run(). However, in the current parallel analysis implementation, information about the full trajectory is lost: only the per-process slices are visible.
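A toy illustration of the problem, not MDAnalysis code: the sketch below fills a pairwise-distance matrix from two chunks of a five-frame slice. The `frames` values, the `fill_rows` and `empty_matrix` helpers, and the shared matrix are hypothetical simplifications. They only show that rows land in the wrong place when a worker knows its local index but not its position in the full slice.

```python
# Toy 1-D "trajectory": one scalar coordinate per frame of the sliced trajectory.
frames = [0.0, 1.0, 3.0, 6.0, 10.0]
n_frames_total = len(frames)


def fill_rows(matrix, chunk, offset):
    """Write one distance row per frame in `chunk`, starting at row `offset`."""
    for local_index, value in enumerate(chunk):
        row = offset + local_index
        matrix[row] = [abs(value - other) for other in frames]


def empty_matrix():
    """Return an n x n matrix of zeros as nested lists."""
    return [[0.0] * n_frames_total for _ in range(n_frames_total)]


# Two hypothetical workers: the first gets frames 0-2, the second frames 3-4.
chunks = [(0, frames[:3]), (3, frames[3:])]

# Only local indices known: every worker starts writing at row 0,
# so the second worker overwrites rows 0-1 and rows 3-4 stay empty.
local_only = empty_matrix()
for _, chunk in chunks:
    fill_rows(local_only, chunk, offset=0)

# Global offsets known: every row ends up at its position in the full slice.
with_global = empty_matrix()
for offset, chunk in chunks:
    fill_rows(with_global, chunk, offset=offset)

print(local_only[3])   # [0.0, 0.0, 0.0, 0.0, 0.0]
print(with_global[3])  # [6.0, 5.0, 3.0, 0.0, 4.0]
```

In this sketch each row also needs every frame of the full slice (`frames`), which is the other piece of information a worker cannot see from its own slice alone.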
Describe the solution you’d like
Firstly, in _setup_frames (which only runs in the main process), store the sliced trajectory information in self._global_slicer. This makes it possible to retrieve global details such as the total number of frames.
Secondly, in self._compute, also store self._global_frame_index so the analysis can track the global frame index for each frame, rather than just per-process slices.
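The bookkeeping described in these two steps can be sketched in plain Python, independent of MDAnalysis. The functions `setup_frames`, `split_evenly`, and `compute` are hypothetical stand-ins, not the library's implementation. The 10-frame trajectory and the split into two groups are assumptions, chosen to give 5 frames at step=2 across 2 workers.

```python
def setup_frames(n_total, start=None, stop=None, step=None):
    """Main-process step: remember the global slice and the frames it selects."""
    global_slicer = slice(start, stop, step)
    selected = list(range(n_total)[global_slicer])
    return global_slicer, selected


def split_evenly(items, n_parts):
    """Split a list into n_parts contiguous chunks whose sizes differ by at most one."""
    base, extra = divmod(len(items), n_parts)
    chunks, begin = [], 0
    for part in range(n_parts):
        size = base + (1 if part < extra else 0)
        chunks.append(items[begin:begin + size])
        begin += size
    return chunks


def compute(indexed_frames):
    """Worker step: record the local and the global index of every frame."""
    records = []
    for local_index, (global_index, frame_number) in enumerate(indexed_frames):
        records.append({
            "frame_index": local_index,          # restarts at 0 in each worker
            "global_frame_index": global_index,  # position in the full slice
            "frame": frame_number,               # frame number in the trajectory
        })
    return records


global_slicer, selected = setup_frames(10, step=2)
indexed = list(enumerate(selected))  # (global index, frame number) pairs
groups = split_evenly(indexed, n_parts=2)
records = [rec for group in groups for rec in compute(group)]

frame_index = [rec["frame_index"] for rec in records]
global_frame_index = [rec["global_frame_index"] for rec in records]
global_n_frames = len(range(10)[global_slicer])

print(frame_index)         # [0, 1, 2, 0, 1]
print(global_frame_index)  # [0, 1, 2, 3, 4]
print(global_n_frames)     # 5
```

Pairing each frame with its global index before splitting means that a worker never has to reconstruct its offset. Keeping the slice object itself lets any worker recover the total frame count.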
The following code illustrates how the proposed attributes would be used:
import MDAnalysis as mda
from MDAnalysis.analysis.base import AnalysisBase
from MDAnalysis.analysis.results import ResultsGroup
from MDAnalysisTests.datafiles import PDB, XTC

u = mda.Universe(PDB, XTC)


class MyAnalysis(AnalysisBase):
    _analysis_algorithm_is_parallelizable = True

    @classmethod
    def get_supported_backends(cls):
        return ('serial', 'multiprocessing', 'dask',)

    def _prepare(self):
        self.results.frame_index = []
        self.results.global_frame_index = []
        self.results.n_frames = []
        self.results.global_n_frames = []

    def _single_frame(self):
        # per-process values
        self.results.frame_index.append(self._frame_index)
        self.results.n_frames.append(self.n_frames)
        # proposed attributes: index within, and length of, the full sliced trajectory
        self.results.global_frame_index.append(self._global_frame_index)
        global_n_frames = len(self._trajectory[self._global_slicer])
        self.results.global_n_frames.append(global_n_frames)

    def _get_aggregator(self):
        # concatenate the per-process lists in order
        return ResultsGroup(lookup={
            'frame_index': ResultsGroup.flatten_sequence,
            'global_frame_index': ResultsGroup.flatten_sequence,
            'n_frames': ResultsGroup.flatten_sequence,
            'global_n_frames': ResultsGroup.flatten_sequence,
        })

# serial analysis
ana = MyAnalysis(u.trajectory)
ana.run(step=2)
ana.results
> {'frame_index': [0, 1, 2, 3, 4], 'global_frame_index': [0, 1, 2, 3, 4], 'n_frames': [5, 5, 5, 5, 5], 'global_n_frames': [5, 5, 5, 5, 5]}
# parallel analysis
ana = MyAnalysis(u.trajectory)
ana.run(step=2, backend='dask', n_workers=2)
ana.results
> {'frame_index': [0, 1, 2, 0, 1], 'global_frame_index': [0, 1, 2, 3, 4], 'n_frames': [3, 3, 3, 2, 2], 'global_n_frames': [5, 5, 5, 5, 5]}

n_frames and frame_index differ between the serial and the parallel run, whereas global_frame_index and global_n_frames are identical in both.