Skip to content

Commit

Permalink
Simplifying and unifying dispatch a good bit across packages #24
Browse files Browse the repository at this point in the history
  • Loading branch information
eric-czech committed May 9, 2020
1 parent c86157e commit 1700044
Show file tree
Hide file tree
Showing 10 changed files with 213 additions and 56 deletions.
30 changes: 27 additions & 3 deletions notebooks/platform/xarray/lib/api.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,31 @@
from .core import *
from .core import (
GeneticDataset,
GenotypeAlleleCountDataset,
GenotypeCallDataset,
GenotypeCountDataset,
GenotypeDosageDataset,
GenotypeProbabilityDataset
)

from .config import config

from . import io
from .io import (
from .io.core import (
read_plink,
write_zarr
)
from .config import config

from . import stats
from .stats.core import (
ld_matrix
)

from . import method
from .method.core import (
ld_prune
)

from . import graph
from .graph.core import (
maximal_independent_set
)
57 changes: 40 additions & 17 deletions notebooks/platform/xarray/lib/dispatch.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
from __future__ import annotations
import functools
from typing import Hashable, Callable, Union, Sequence
from typing import Hashable, Callable, Union, Sequence, Dict, Type
from typing_extensions import Protocol
from .config import Configuration
from .config import config as global_config
Expand Down Expand Up @@ -47,12 +47,7 @@ class Dispatchable(Protocol):
def dispatch(self, fn: Callable, *args, **kwargs): ...


class Frontend(Dispatchable):
    """Dispatchable that additionally declares a ``domain`` attribute."""

    # Declared only; no value is assigned at class level
    domain: Domain


class Backend(Dispatchable):
domain: Domain
id: Hashable

def requirements(self) -> Sequence[Requirement]: ...
Expand Down Expand Up @@ -93,8 +88,8 @@ def is_compatible(backend: Backend):
return True


class FrontendDispatcher(Frontend):
"""Base model for all Frontend instances with default logic for backend selection"""
class Dispatcher(Dispatchable):
"""Default dispatch model"""

def __init__(self, domain: Union[str, Domain], config: Configuration = None):
self.domain = Domain(domain)
Expand All @@ -107,12 +102,16 @@ def _update_config(self, default='auto'):
key = str(self.domain.append('backend'))
self.config.register(key, default, f'Options: {options}; default is {default}')

def register(self, backend: Backend) -> None:
    """Check that ``backend`` is on the same domain as this frontend.

    Parameters
    ----------
    backend : Backend
        Object whose ``domain`` attribute is compared with ``self.domain``.

    Raises
    ------
    ValueError
        If ``backend.domain`` is not equal to ``self.domain``.
    """
    # Only allow frontends to work with backends on the same domain
    if not backend.domain == self.domain:
        # The message must be an f-string, otherwise the placeholders are
        # emitted literally instead of naming the offending domains.
        raise ValueError(
            f'Backend with domain {backend.domain} not compatible '
            f'with frontend domain {self.domain}'
        )
def register_function(self, fn: Callable) -> Callable:
    """Return a stand-in for ``fn`` that routes every call through ``self.dispatch``.

    Parameters
    ----------
    fn : Callable
        Function to wrap; its metadata is copied onto the returned wrapper.

    Returns
    -------
    Callable
        Wrapper that calls ``self.dispatch(fn, *args, **kwargs)``.
    """
    def dispatched(*args, **kwargs):
        return self.dispatch(fn, *args, **kwargs)

    # Same effect as decorating with @functools.wraps(fn)
    return functools.wraps(fn)(dispatched)

def register_backend(self, backend: Backend) -> Backend:
    """Store ``backend`` under its ``id`` and refresh the configuration.

    Parameters
    ----------
    backend : Backend
        Backend to store in ``self.backends``, keyed by ``backend.id``.

    Returns
    -------
    Backend
        The same ``backend`` object that was passed in.
    """
    backend_id = backend.id
    self.backends[backend_id] = backend
    # Called after the insert so the new backend is already present
    self._update_config()
    return backend

def resolve(self, fn: Callable, *args, **kwargs) -> Backend:
# Passed parameters get highest priority
Expand All @@ -137,8 +136,32 @@ def dispatch(self, fn: Callable, *args, **kwargs):
kwargs.pop('backend', None) # Pop this off since implementations cannot expect it
return backend.dispatch(fn, *args, **kwargs)

def add(self, fn: Callable):
    """Wrap ``fn`` so that calling it invokes ``self.dispatch`` with ``fn`` first.

    Parameters
    ----------
    fn : Callable
        Function to wrap; its metadata is copied onto the wrapper.

    Returns
    -------
    Callable
        Wrapper forwarding ``*args`` and ``**kwargs`` to ``self.dispatch``.
    """
    def routed(*args, **kwargs):
        return self.dispatch(fn, *args, **kwargs)

    # update_wrapper is what functools.wraps applies under the hood
    return functools.update_wrapper(routed, fn)

# ----------------------------------------------------------------------
# Registry
#
# These functions define the sole interaction points for all
# frontend/backend coordination across the project

# Dispatcher registry keyed by domain; starts out empty
dispatchers: Dict[Domain, Dispatchable] = {}

def register_function(domain):
    """Build a decorator that registers a function with a domain's dispatcher.

    A ``Dispatcher`` is created and stored in ``dispatchers`` the first time a
    domain is seen; later calls for the same domain reuse it.

    Parameters
    ----------
    domain
        Value accepted by ``Domain(...)`` identifying the dispatcher to use.

    Returns
    -------
    Callable
        Decorator that returns
        ``dispatchers[Domain(domain)].register_function(fn)``.
    """
    key = Domain(domain)
    if key not in dispatchers:
        dispatchers[key] = Dispatcher(key)

    def decorate(fn: Callable):
        # The dispatcher is looked up when the decorator is applied
        return dispatchers[key].register_function(fn)

    return decorate


def register_backend(domain):
    """Build a decorator that registers a backend with an existing dispatcher.

    Unlike ``register_function``, this never creates a dispatcher: the domain
    must already be present in ``dispatchers``.

    Parameters
    ----------
    domain
        Value accepted by ``Domain(...)`` identifying the dispatcher to use.

    Returns
    -------
    Callable
        Decorator accepting either a backend instance or a backend class.
        A class is instantiated with no arguments before registration. The
        decorator returns the object it was given, unchanged.

    Raises
    ------
    NotImplementedError
        If no dispatcher has been registered for ``domain``.
    """
    domain = Domain(domain)
    if domain not in dispatchers:
        # Must be an f-string, otherwise "{domain}" is emitted literally
        raise NotImplementedError(f'Dispatcher for domain {domain} not implemented')

    def register(backend: Union[Backend, Type[Backend]]):
        # Classes are instantiated here; instances are registered as given
        instance = backend() if isinstance(backend, type) else backend
        dispatchers[domain].register_backend(instance)
        # Return the original object so the decorated name keeps its value
        return backend
    return register
9 changes: 9 additions & 0 deletions notebooks/platform/xarray/lib/graph/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,9 @@
from .core import *

from . import numba_backend

# TODO: Add when ready
# try:
# from . import networkx_backend
# except ImportError:
# pass
12 changes: 12 additions & 0 deletions notebooks/platform/xarray/lib/graph/core.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,12 @@
"""Graph API"""
from ..dispatch import register_function


DOMAIN = 'graph'


@register_function(DOMAIN)
def maximal_independent_set(df):
    """Maximal Independent Set

    Stub with no logic of its own; it is registered under ``DOMAIN`` through
    ``register_function``.

    Parameters
    ----------
    df
        Not used by this stub.
        NOTE(review): presumably the sparse LD matrix DataFrame -- confirm
        against the backend implementation.
    """

8 changes: 5 additions & 3 deletions notebooks/platform/xarray/lib/io/__init__.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,6 @@
from lib.io.core import *
from .core import *

# Choose the backends to register
import lib.io.pysnptools_backend
try:
from . import pysnptools_backend
except ImportError:
pass
37 changes: 4 additions & 33 deletions notebooks/platform/xarray/lib/io/core.py
Original file line number Diff line number Diff line change
@@ -1,36 +1,13 @@
"""I/O backend implementations and configuration"""
from ..dispatch import FrontendDispatcher, ClassBackend, Domain
"""I/O API"""
from ..dispatch import register_function
from ..core import isdstype, GenotypeCountDataset
from xarray import Dataset

DOMAIN = Domain('io')
PLINK_DOMAIN = DOMAIN.append('plink')

# ----------------------------------------------------------------------
# IO API
PLINK_DOMAIN = 'io.plink'


class IOBackend(ClassBackend):
    """``ClassBackend`` whose ``domain`` is the module-level ``DOMAIN``."""

    domain = DOMAIN


class PLINKBackend(IOBackend):
    """``IOBackend`` whose ``domain`` is overridden with ``PLINK_DOMAIN``."""

    domain = PLINK_DOMAIN


# Dispatchers created so far, keyed by the value passed to `dispatch`
dispatchers = {}


def dispatch(domain):
    """Build a decorator that adds a function to the dispatcher for ``domain``.

    A ``FrontendDispatcher`` on ``DOMAIN.append(domain)`` is created and stored
    in ``dispatchers`` the first time ``domain`` is seen.

    Parameters
    ----------
    domain
        Key into ``dispatchers``; also appended to ``DOMAIN`` when a new
        dispatcher is created.

    Returns
    -------
    Callable
        Decorator returning ``dispatchers[domain].add(fn)``.
    """
    if domain not in dispatchers:
        sub_domain = DOMAIN.append(domain)
        dispatchers[domain] = FrontendDispatcher(sub_domain)

    def bind(fn):
        # The dispatcher is looked up when the decorator is applied
        return dispatchers[domain].add(fn)

    return bind


# NOTE(review): two decorator lines are stacked here (`@dispatch('plink')` and
# `@register_function(PLINK_DOMAIN)`). They look like the two sides of one
# change; presumably only one is meant to apply -- confirm before relying on it.
# The body is a stub (`pass`): `path`, `backend` and `**kwargs` are not used here.
@dispatch('plink')
@register_function(PLINK_DOMAIN)
def read_plink(path, backend=None, **kwargs):
"""Import PLINK dataset"""
pass
Expand All @@ -47,9 +24,3 @@ def write_zarr(ds: Dataset, path, **kwargs):

return ds.to_zarr(path, **kwargs)

# ----------------------------------------------------------------------
# IO Backend Registration


def register_backend(backend: IOBackend):
    """Register ``backend`` with the dispatcher selected by its domain.

    Parameters
    ----------
    backend : IOBackend
        Backend to register; ``backend.domain[-1]`` is used as the key into
        ``dispatchers``.
    """
    # The last item of the backend's domain selects the dispatcher
    key = backend.domain[-1]
    dispatchers[key].register(backend)
6 changes: 6 additions & 0 deletions notebooks/platform/xarray/lib/method/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,6 @@
from .core import *

try:
from .ld_prune import dask_backend
except ImportError:
pass
34 changes: 34 additions & 0 deletions notebooks/platform/xarray/lib/method/core.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,34 @@
from ..dispatch import register_function
from xarray import Dataset
from typing import Optional

DOMAIN = 'method'

@register_function(DOMAIN)
def ld_prune(ldm: 'DataFrame', use_cmp: bool = True):
    """LD Prune

    Prune variants within a dataset using a sparse LD matrix to find a
    maximally independent set (MIS).

    This function is a stub with no logic of its own; it is registered under
    ``DOMAIN`` through ``register_function``.

    Note: when `use_cmp` is True and the comparisons were based on MAF scores
    (or anything else) provided during pair-wise LD matrix evaluation, the
    result is not a true MIS unless those scores were all identical.
    NOTE(review): wording reconstructed from the original note -- confirm.

    Parameters
    ----------
    ldm : DataFrame
        LD matrix from `ld_matrix`
    use_cmp : bool
        Whether or not to use precomputed score-based comparisons
        TODO: wire this up in MIS

    Returns
    -------
    array
        Array with indexes of rows that should be removed
    """
6 changes: 6 additions & 0 deletions notebooks/platform/xarray/lib/stats/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,6 @@
from .core import *

try:
from .ld_matrix import dask_backend
except ImportError:
pass
70 changes: 70 additions & 0 deletions notebooks/platform/xarray/lib/stats/core.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,70 @@
from ..dispatch import register_function
from xarray import Dataset
from typing import Optional

DOMAIN = 'stats'

@register_function(DOMAIN)
def ld_matrix(
    ds: Dataset,
    window: int = 1_000_000,
    threshold: float = 0.2,
    step: Optional[int] = None,
    groups='contig',
    positions='pos',
    scores=None,
    **kwargs
):
    """Compute Sparse LD Matrix (for tall-skinny matrices only at the moment)

    This method works by first computing all overlapping variant ranges
    and then breaking up necessary pairwise comparisons into approximately
    equal sized chunks (along dimension 0). These chunks are exactly equal
    in size when windows are of fixed size but vary based on variant density
    when using base pair ranges. For each pair of variants in a window,
    R2 is calculated and only those exceeding the provided threshold are
    returned.

    This function is a stub with no logic of its own; it is registered under
    ``DOMAIN`` through ``register_function``.

    Parameters
    ----------
    ds : Dataset
        Dataset from which `groups`, `positions` and `scores` are fetched
        when they are given as variable names.
    window : int
        Size of window for LD comparisons (between rows). This is either in
        base pairs if `positions` is not None or is a fixed window size
        otherwise. By default this is 1,000,000 (1,000 kbp)
    threshold : float
        R2 threshold below which no variant pairs will be returned. This
        should almost always be something at least slightly above 0 to avoid
        the large density very near zero LD present in most datasets.
        Defaults to 0.2
    step : int, optional
        Fixed step size to move each window by. Has no effect when
        `positions` is provided and must be provided when base pair ranges
        are not being used.
    groups : str or array-like, optional
        Name of field to use to represent disconnected components (typically
        contigs). Will be used directly if provided as an array and otherwise
        fetched from `ds` if given as a variable name.
    positions : str or array-like, optional
        Name of field to use to represent base pair positions. Will be used
        directly if provided as an array and otherwise fetched from `ds` if
        given as a variable name.
    scores : str or array-like, optional
        Name of field to use to prioritize variant selection (e.g. MAF). Will
        be used directly if provided as an array and otherwise fetched from
        `ds` if given as a variable name.
    **kwargs
        Backend-specific options.
        `return_intervals` (bool): whether or not to also return the variant
        interval calculations (which can be useful for analyzing variant
        density), by default False.
        NOTE(review): `return_intervals` is not a named parameter of this
        signature; presumably it is passed through `**kwargs` -- confirm.

    Returns
    -------
    DataFrame or (DataFrame, (DataFrame, DataFrame))
        Upper triangle (including diagonal) of LD matrix as COO in dataframe.
        Fields:

        `i`: Row (variant) index 1
        `j`: Row (variant) index 2
        `value`: R2 value
        `cmp`: If scores are provided, this is 1, 0, or -1 indicating whether
        or not i > j (1), i < j (-1), or i == j (0)

        When `return_intervals` is True, the second tuple contains the
        results from `axis_intervals`
    """

0 comments on commit 1700044

Please sign in to comment.