Skip to content

Automatic parallelization for dask arrays in apply_ufunc #1517

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 7 commits into from
Oct 9, 2017
Merged
Show file tree
Hide file tree
Changes from 5 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
196 changes: 131 additions & 65 deletions xarray/core/computation.py
Original file line number Diff line number Diff line change
Expand Up @@ -64,32 +64,13 @@ def all_core_dims(self):
return self._all_core_dims

@property
def n_inputs(self):
def num_inputs(self):
return len(self.input_core_dims)

@property
def n_outputs(self):
def num_outputs(self):
return len(self.output_core_dims)

@classmethod
def default(cls, n_inputs):
return cls([()] * n_inputs, [()])

@classmethod
def from_sequence(cls, nested):
if (not isinstance(nested, collections.Sequence) or
not len(nested) == 2 or
any(not isinstance(arg_list, collections.Sequence)
for arg_list in nested) or
any(isinstance(arg, basestring) or
not isinstance(arg, collections.Sequence)
for arg_list in nested for arg in arg_list)):
raise TypeError('functions signatures not provided as a string '
'must be a triply nested sequence providing the '
'list of core dimensions for each variable, for '
'both input and output.')
return cls(*nested)

def __eq__(self, other):
try:
return (self.input_core_dims == other.input_core_dims and
Expand Down Expand Up @@ -190,7 +171,7 @@ def apply_dataarray_ufunc(func, *args, **kwargs):
data_vars = [getattr(a, 'variable', a) for a in args]
result_var = func(*data_vars)

if signature.n_outputs > 1:
if signature.num_outputs > 1:
out = tuple(DataArray(variable, coords, name=name, fastpath=True)
for variable, coords in zip(result_var, result_coords))
else:
Expand Down Expand Up @@ -269,10 +250,10 @@ def _as_variables_or_variable(arg):

def _unpack_dict_tuples(
result_vars, # type: Mapping[Any, Tuple[Variable]]
n_outputs, # type: int
num_outputs, # type: int
):
# type: (...) -> Tuple[Dict[Any, Variable]]
out = tuple(OrderedDict() for _ in range(n_outputs))
out = tuple(OrderedDict() for _ in range(num_outputs))
for name, values in result_vars.items():
for value, results_dict in zip(values, out):
results_dict[name] = value
Expand All @@ -298,8 +279,8 @@ def apply_dict_of_variables_ufunc(func, *args, **kwargs):
for name, variable_args in zip(names, grouped_by_name):
result_vars[name] = func(*variable_args)

if signature.n_outputs > 1:
return _unpack_dict_tuples(result_vars, signature.n_outputs)
if signature.num_outputs > 1:
return _unpack_dict_tuples(result_vars, signature.num_outputs)
else:
return result_vars

Expand Down Expand Up @@ -335,8 +316,8 @@ def apply_dataset_ufunc(func, *args, **kwargs):

if (dataset_join not in _JOINS_WITHOUT_FILL_VALUES and
fill_value is _DEFAULT_FILL_VALUE):
raise TypeError('To apply an operation to datasets with different ',
'data variables, you must supply the ',
raise TypeError('to apply an operation to datasets with different '
'data variables with apply_ufunc, you must supply the '
'dataset_fill_value argument.')

if kwargs:
Expand All @@ -353,7 +334,7 @@ def apply_dataset_ufunc(func, *args, **kwargs):
func, *args, signature=signature, join=dataset_join,
fill_value=fill_value)

if signature.n_outputs > 1:
if signature.num_outputs > 1:
out = tuple(_fast_dataset(*args)
for args in zip(result_vars, list_of_coords))
else:
Expand Down Expand Up @@ -388,12 +369,12 @@ def apply_groupby_ufunc(func, *args):
from .variable import Variable

groupbys = [arg for arg in args if isinstance(arg, GroupBy)]
if not groupbys:
raise ValueError('must have at least one groupby to iterate over')
assert groupbys, 'must have at least one groupby to iterate over'
first_groupby = groupbys[0]
if any(not first_groupby._group.equals(gb._group) for gb in groupbys[1:]):
raise ValueError('can only perform operations over multiple groupbys '
'at once if they are all grouped the same way')
raise ValueError('apply_ufunc can only perform operations over '
'multiple GroupBy objects at once if they are all '
'grouped the same way')

grouped_dim = first_groupby._group.name
unique_values = first_groupby._unique_coord.values
Expand Down Expand Up @@ -430,7 +411,7 @@ def unified_dim_sizes(variables, exclude_dims=frozenset()):
for var in variables:
if len(set(var.dims)) < len(var.dims):
raise ValueError('broadcasting cannot handle duplicate '
'dimensions: %r' % list(var.dims))
'dimensions on a variable: %r' % list(var.dims))
for dim, size in zip(var.dims, var.shape):
if dim not in exclude_dims:
if dim not in dim_sizes:
Expand Down Expand Up @@ -462,15 +443,17 @@ def broadcast_compat_data(variable, broadcast_dims, core_dims):
set_old_dims = set(old_dims)
missing_core_dims = [d for d in core_dims if d not in set_old_dims]
if missing_core_dims:
raise ValueError('operation requires dimensions missing on input '
'variable: %r' % missing_core_dims)
raise ValueError('operand to apply_ufunc has required core dimensions '
'%r, but some of these are missing on the input '
'variable: %r' % (list(core_dims), missing_core_dims))

set_new_dims = set(new_dims)
unexpected_dims = [d for d in old_dims if d not in set_new_dims]
if unexpected_dims:
raise ValueError('operation encountered unexpected dimensions %r '
'on input variable: these are core dimensions on '
'other input or output variables' % unexpected_dims)
raise ValueError('operand to apply_ufunc encountered unexpected '
'dimensions %r on an input variable: these are core '
'dimensions on other input or output variables'
% unexpected_dims)

# for consistency with numpy, keep broadcast dimensions to the left
old_broadcast_dims = tuple(d for d in broadcast_dims if d in set_old_dims)
Expand Down Expand Up @@ -500,6 +483,9 @@ def apply_variable_ufunc(func, *args, **kwargs):

signature = kwargs.pop('signature')
exclude_dims = kwargs.pop('exclude_dims', _DEFAULT_FROZEN_SET)
dask = kwargs.pop('dask', 'forbidden')
output_dtypes = kwargs.pop('output_dtypes', None)
output_sizes = kwargs.pop('output_sizes', None)
if kwargs:
raise TypeError('apply_variable_ufunc() got unexpected keyword '
'arguments: %s' % list(kwargs))
Expand All @@ -515,9 +501,28 @@ def apply_variable_ufunc(func, *args, **kwargs):
else arg
for arg, core_dims in zip(args, signature.input_core_dims)]

if any(isinstance(array, dask_array_type) for array in input_data):
if dask == 'forbidden':
raise ValueError('apply_ufunc encountered a dask array on an '
'argument, but handling for dask arrays has not '
'been enabled. Either set the ``dask`` argument '
'or load your data into memory first with '
'``.load()`` or ``.compute()``')
elif dask == 'parallelized':
input_dims = [broadcast_dims + input_dims
for input_dims in signature.input_core_dims]
numpy_func = func
func = lambda *arrays: _apply_with_dask_atop(
numpy_func, arrays, input_dims, output_dims, signature,
output_dtypes, output_sizes)
elif dask == 'allowed':
pass
else:
raise ValueError('unknown setting for dask array handling in '
'apply_ufunc: {}'.format(dask))
result_data = func(*input_data)

if signature.n_outputs > 1:
if signature.num_outputs > 1:
output = []
for dims, data in zip(output_dims, result_data):
output.append(Variable(dims, data))
Expand All @@ -527,24 +532,71 @@ def apply_variable_ufunc(func, *args, **kwargs):
return Variable(dims, result_data)


def _apply_with_dask_atop(func, args, input_dims, output_dims, signature,
output_dtypes, output_sizes=None):
import dask.array as da

if signature.num_outputs > 1:
raise NotImplementedError('multiple outputs from apply_ufunc not yet '
"supported with dask='parallelized'")

if output_dtypes is None:
raise ValueError('output dtypes (output_dtypes) must be supplied to '
"apply_ufunc when using dask='parallelized'")
if len(output_dtypes) != signature.num_outputs:
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think we need to make sure output_dtypes is an iterable before calling len.

raise ValueError('apply_ufunc arguments output_dtypes and '
'output_core_dims must have the same length: {} vs {}'
.format(len(output_dtypes), signature.num_outputs))
(dtype,) = output_dtypes

if output_sizes is None:
output_sizes = {}

new_dims = signature.all_output_core_dims - signature.all_input_core_dims
if any(dim not in output_sizes for dim in new_dims):
raise ValueError("when using dask='parallelized' with apply_ufunc, "
'output core dimensions not found on inputs must have '
'explicitly set sizes with ``output_sizes``: {}'
.format(new_dims))

args2 = []
for data, core_dims in zip(args, signature.input_core_dims):
if isinstance(data, dask_array_type):
# core dimensions cannot span multiple chunks
chunks = {axis: (data.shape[axis],)
for axis, dim in enumerate(core_dims, -len(core_dims))}
data = data.rechunk(chunks)
args2.append(data)

(out_ind,) = output_dims
atop_args = [ai for a in zip(args2, input_dims) for ai in a]
return da.atop(func, out_ind, *atop_args, dtype=dtype, concatenate=True,
new_axes=output_sizes)


def apply_array_ufunc(func, *args, **kwargs):
"""apply_variable_ufunc(func, *args, dask_array='forbidden')
"""apply_array_ufunc(func, *args, dask='forbidden')
"""
dask_array = kwargs.pop('dask_array', 'forbidden')
dask = kwargs.pop('dask', 'forbidden')
if kwargs:
raise TypeError('apply_array_ufunc() got unexpected keyword '
'arguments: %s' % list(kwargs))

if any(isinstance(arg, dask_array_type) for arg in args):
# TODO: add a mode dask_array='auto' when dask.array gets a function
# for applying arbitrary gufuncs
if dask_array == 'forbidden':
raise ValueError('encountered dask array, but did not set '
"dask_array='allowed'")
elif dask_array != 'allowed':
raise ValueError('unknown setting for dask array handling: %r'
% dask_array)
# fall through
if dask == 'forbidden':
raise ValueError('apply_ufunc encountered a dask array on an '
'argument, but handling for dask arrays has not '
'been enabled. Either set the ``dask`` argument '
'or load your data into memory first with '
'``.load()`` or ``.compute()``')
elif dask == 'parallelized':
raise ValueError("cannot use dask='parallelized' for apply_ufunc "
'unless at least one input is an xarray object')
elif dask == 'allowed':
pass
else:
raise ValueError('unknown setting for dask array handling: {}'
.format(dask))
return func(*args)


Expand All @@ -559,7 +611,9 @@ def apply_ufunc(func, *args, **kwargs):
dataset_fill_value : Any = _DEFAULT_FILL_VALUE,
keep_attrs : bool = False,
kwargs : Mapping = None,
dask_array : str = 'forbidden')
dask : str = 'forbidden',
output_dtypes : Optional[Sequence] = None,
output_sizes : Optional[Mapping[Any, int]] = None)

Apply a vectorized function for unlabeled arrays on xarray objects.

Expand Down Expand Up @@ -630,10 +684,20 @@ def apply_ufunc(func, *args, **kwargs):
Whether to copy attributes from the first argument to the output.
kwargs: dict, optional
Optional keyword arguments passed directly on to call ``func``.
dask_array: 'forbidden' or 'allowed', optional
Whether or not to allow applying the ufunc to objects containing lazy
data in the form of dask arrays. By default, this is forbidden, to
avoid implicitly converting lazy data.
dask: 'forbidden', 'allowed' or 'parallelized', optional
How to handle applying to objects containing lazy data in the form of
dask arrays:
- 'forbidden' (default): raise an error if a dask array is encountered.
- 'allowed': pass dask arrays directly on to ``func``.
- 'parallelized': automatically parallelize ``func`` if any of the
inputs are a dask array. If used, the ``output_dtypes`` argument must
also be provided. Multiple output arguments are not yet supported.
output_dtypes : list of dtypes, optional
Optional list of output dtypes. Only used if dask='parallelized'.
output_sizes : dict, optional
Optional mapping from dimension names to sizes for outputs. Only used if
dask='parallelized' and new dimensions (not found on inputs) appear on
outputs.

Returns
-------
Expand Down Expand Up @@ -710,7 +774,9 @@ def stack(objects, dim, new_coord):
exclude_dims = kwargs.pop('exclude_dims', frozenset())
dataset_fill_value = kwargs.pop('dataset_fill_value', _DEFAULT_FILL_VALUE)
kwargs_ = kwargs.pop('kwargs', None)
dask_array = kwargs.pop('dask_array', 'forbidden')
dask = kwargs.pop('dask', 'forbidden')
output_dtypes = kwargs.pop('output_dtypes', None)
output_sizes = kwargs.pop('output_sizes', None)
if kwargs:
raise TypeError('apply_ufunc() got unexpected keyword arguments: %s'
% list(kwargs))
Expand All @@ -727,12 +793,12 @@ def stack(objects, dim, new_coord):
if kwargs_:
func = functools.partial(func, **kwargs_)

array_ufunc = functools.partial(
apply_array_ufunc, func, dask_array=dask_array)

variables_ufunc = functools.partial(apply_variable_ufunc, array_ufunc,
variables_ufunc = functools.partial(apply_variable_ufunc, func,
signature=signature,
exclude_dims=exclude_dims)
exclude_dims=exclude_dims,
dask=dask,
output_dtypes=output_dtypes,
output_sizes=output_sizes)

if any(isinstance(a, GroupBy) for a in args):
# kwargs has already been added into func
Expand All @@ -744,7 +810,7 @@ def stack(objects, dim, new_coord):
dataset_join=dataset_join,
dataset_fill_value=dataset_fill_value,
keep_attrs=keep_attrs,
dask_array=dask_array)
dask=dask)
return apply_groupby_ufunc(this_apply, *args)
elif any(is_dict_like(a) for a in args):
return apply_dataset_ufunc(variables_ufunc, *args,
Expand All @@ -763,7 +829,7 @@ def stack(objects, dim, new_coord):
elif any(isinstance(a, Variable) for a in args):
return variables_ufunc(*args)
else:
return array_ufunc(*args)
return apply_array_ufunc(func, *args, dask=dask)


def where(cond, x, y):
Expand Down Expand Up @@ -805,4 +871,4 @@ def where(cond, x, y):
cond, x, y,
join='exact',
dataset_join='exact',
dask_array='allowed')
dask='allowed')
4 changes: 2 additions & 2 deletions xarray/core/ops.py
Original file line number Diff line number Diff line change
Expand Up @@ -148,7 +148,7 @@ def fillna(data, other, join="left", dataset_join="left"):

return apply_ufunc(duck_array_ops.fillna, data, other,
join=join,
dask_array="allowed",
dask="allowed",
dataset_join=dataset_join,
dataset_fill_value=np.nan,
keep_attrs=True)
Expand Down Expand Up @@ -176,7 +176,7 @@ def where_method(self, cond, other=dtypes.NA):
self, cond, other,
join=join,
dataset_join=join,
dask_array='allowed',
dask='allowed',
keep_attrs=True)


Expand Down
4 changes: 2 additions & 2 deletions xarray/tests/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -188,8 +188,8 @@ def raises_regex(error, pattern):
with pytest.raises(error) as excinfo:
yield
message = str(excinfo.value)
if not re.match(pattern, message):
raise AssertionError('exception %r did not match pattern %s'
if not re.search(pattern, message):
raise AssertionError('exception %r did not match pattern %r'
% (excinfo.value, pattern))


Expand Down
Loading