Commit b46fcd6

Automatic parallelization for dask arrays in apply_ufunc (#1517)
* dask='parallelized' for apply_ufunc
* Fix dask_array -> dask
* Delete unused code, better error messages
* fix tests
* really fix tests
* wip: raise errors
* Raise an error instead of automatic rechunking
1 parent 14b5f1c commit b46fcd6
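
In short: ``apply_ufunc`` gains a ``dask='parallelized'`` mode that applies a NumPy-only function to dask-backed xarray objects one chunk at a time. Below is a minimal usage sketch based on the docstring added in this commit; the ``demean`` function and the example data are illustrative only, and at this point ``apply_ufunc`` still lives in ``xarray.core.computation`` rather than the public API.

import numpy as np
import xarray as xr
from xarray.core.computation import apply_ufunc  # not yet publicly exported at this commit

def demean(a):
    # a plain numpy function, applied once per dask chunk
    return a - a.mean(axis=-1, keepdims=True)

# Hypothetical data: chunk only along 'x', since core dimensions ('time' here)
# must not span multiple chunks with dask='parallelized'.
arr = xr.DataArray(np.random.rand(4, 1000), dims=['x', 'time']).chunk({'x': 1})

result = apply_ufunc(demean, arr,
                     input_core_dims=[['time']],
                     output_core_dims=[['time']],
                     dask='parallelized',
                     output_dtypes=[arr.dtype])  # required for 'parallelized'
print(result.compute())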

4 files changed: +255 -87 lines changed


xarray/core/computation.py: +145 -67
@@ -64,32 +64,13 @@ def all_core_dims(self):
         return self._all_core_dims
 
     @property
-    def n_inputs(self):
+    def num_inputs(self):
         return len(self.input_core_dims)
 
     @property
-    def n_outputs(self):
+    def num_outputs(self):
         return len(self.output_core_dims)
 
-    @classmethod
-    def default(cls, n_inputs):
-        return cls([()] * n_inputs, [()])
-
-    @classmethod
-    def from_sequence(cls, nested):
-        if (not isinstance(nested, collections.Sequence) or
-                not len(nested) == 2 or
-                any(not isinstance(arg_list, collections.Sequence)
-                    for arg_list in nested) or
-                any(isinstance(arg, basestring) or
-                    not isinstance(arg, collections.Sequence)
-                    for arg_list in nested for arg in arg_list)):
-            raise TypeError('functions signatures not provided as a string '
-                            'must be a triply nested sequence providing the '
-                            'list of core dimensions for each variable, for '
-                            'both input and output.')
-        return cls(*nested)
-
     def __eq__(self, other):
         try:
             return (self.input_core_dims == other.input_core_dims and
@@ -190,7 +171,7 @@ def apply_dataarray_ufunc(func, *args, **kwargs):
     data_vars = [getattr(a, 'variable', a) for a in args]
     result_var = func(*data_vars)
 
-    if signature.n_outputs > 1:
+    if signature.num_outputs > 1:
         out = tuple(DataArray(variable, coords, name=name, fastpath=True)
                     for variable, coords in zip(result_var, result_coords))
     else:
@@ -269,10 +250,10 @@ def _as_variables_or_variable(arg):
 
 def _unpack_dict_tuples(
     result_vars,  # type: Mapping[Any, Tuple[Variable]]
-    n_outputs,  # type: int
+    num_outputs,  # type: int
 ):
     # type: (...) -> Tuple[Dict[Any, Variable]]
-    out = tuple(OrderedDict() for _ in range(n_outputs))
+    out = tuple(OrderedDict() for _ in range(num_outputs))
     for name, values in result_vars.items():
         for value, results_dict in zip(values, out):
             results_dict[name] = value
@@ -298,8 +279,8 @@ def apply_dict_of_variables_ufunc(func, *args, **kwargs):
     for name, variable_args in zip(names, grouped_by_name):
         result_vars[name] = func(*variable_args)
 
-    if signature.n_outputs > 1:
-        return _unpack_dict_tuples(result_vars, signature.n_outputs)
+    if signature.num_outputs > 1:
+        return _unpack_dict_tuples(result_vars, signature.num_outputs)
     else:
         return result_vars

@@ -335,8 +316,8 @@ def apply_dataset_ufunc(func, *args, **kwargs):
 
     if (dataset_join not in _JOINS_WITHOUT_FILL_VALUES and
             fill_value is _DEFAULT_FILL_VALUE):
-        raise TypeError('To apply an operation to datasets with different ',
-                        'data variables, you must supply the ',
+        raise TypeError('to apply an operation to datasets with different '
+                        'data variables with apply_ufunc, you must supply the '
                         'dataset_fill_value argument.')
 
     if kwargs:
@@ -353,7 +334,7 @@ def apply_dataset_ufunc(func, *args, **kwargs):
         func, *args, signature=signature, join=dataset_join,
         fill_value=fill_value)
 
-    if signature.n_outputs > 1:
+    if signature.num_outputs > 1:
         out = tuple(_fast_dataset(*args)
                     for args in zip(result_vars, list_of_coords))
     else:
@@ -388,12 +369,12 @@ def apply_groupby_ufunc(func, *args):
     from .variable import Variable
 
     groupbys = [arg for arg in args if isinstance(arg, GroupBy)]
-    if not groupbys:
-        raise ValueError('must have at least one groupby to iterate over')
+    assert groupbys, 'must have at least one groupby to iterate over'
     first_groupby = groupbys[0]
     if any(not first_groupby._group.equals(gb._group) for gb in groupbys[1:]):
-        raise ValueError('can only perform operations over multiple groupbys '
-                         'at once if they are all grouped the same way')
+        raise ValueError('apply_ufunc can only perform operations over '
+                         'multiple GroupBy objects at once if they are all '
+                         'grouped the same way')
 
     grouped_dim = first_groupby._group.name
     unique_values = first_groupby._unique_coord.values
@@ -430,7 +411,7 @@ def unified_dim_sizes(variables, exclude_dims=frozenset()):
     for var in variables:
         if len(set(var.dims)) < len(var.dims):
             raise ValueError('broadcasting cannot handle duplicate '
-                             'dimensions: %r' % list(var.dims))
+                             'dimensions on a variable: %r' % list(var.dims))
         for dim, size in zip(var.dims, var.shape):
             if dim not in exclude_dims:
                 if dim not in dim_sizes:
@@ -462,15 +443,17 @@ def broadcast_compat_data(variable, broadcast_dims, core_dims):
     set_old_dims = set(old_dims)
     missing_core_dims = [d for d in core_dims if d not in set_old_dims]
     if missing_core_dims:
-        raise ValueError('operation requires dimensions missing on input '
-                         'variable: %r' % missing_core_dims)
+        raise ValueError('operand to apply_ufunc has required core dimensions '
+                         '%r, but some of these are missing on the input '
+                         'variable: %r' % (list(core_dims), missing_core_dims))
 
     set_new_dims = set(new_dims)
     unexpected_dims = [d for d in old_dims if d not in set_new_dims]
     if unexpected_dims:
-        raise ValueError('operation encountered unexpected dimensions %r '
-                         'on input variable: these are core dimensions on '
-                         'other input or output variables' % unexpected_dims)
+        raise ValueError('operand to apply_ufunc encountered unexpected '
+                         'dimensions %r on an input variable: these are core '
+                         'dimensions on other input or output variables'
+                         % unexpected_dims)
 
     # for consistency with numpy, keep broadcast dimensions to the left
     old_broadcast_dims = tuple(d for d in broadcast_dims if d in set_old_dims)
@@ -500,6 +483,9 @@ def apply_variable_ufunc(func, *args, **kwargs):
 
     signature = kwargs.pop('signature')
     exclude_dims = kwargs.pop('exclude_dims', _DEFAULT_FROZEN_SET)
+    dask = kwargs.pop('dask', 'forbidden')
+    output_dtypes = kwargs.pop('output_dtypes', None)
+    output_sizes = kwargs.pop('output_sizes', None)
     if kwargs:
         raise TypeError('apply_variable_ufunc() got unexpected keyword '
                         'arguments: %s' % list(kwargs))
@@ -515,9 +501,28 @@ def apply_variable_ufunc(func, *args, **kwargs):
                   else arg
                   for arg, core_dims in zip(args, signature.input_core_dims)]
 
+    if any(isinstance(array, dask_array_type) for array in input_data):
+        if dask == 'forbidden':
+            raise ValueError('apply_ufunc encountered a dask array on an '
+                             'argument, but handling for dask arrays has not '
+                             'been enabled. Either set the ``dask`` argument '
+                             'or load your data into memory first with '
+                             '``.load()`` or ``.compute()``')
+        elif dask == 'parallelized':
+            input_dims = [broadcast_dims + input_dims
+                          for input_dims in signature.input_core_dims]
+            numpy_func = func
+            func = lambda *arrays: _apply_with_dask_atop(
+                numpy_func, arrays, input_dims, output_dims, signature,
+                output_dtypes, output_sizes)
+        elif dask == 'allowed':
+            pass
+        else:
+            raise ValueError('unknown setting for dask array handling in '
+                             'apply_ufunc: {}'.format(dask))
     result_data = func(*input_data)
 
-    if signature.n_outputs > 1:
+    if signature.num_outputs > 1:
         output = []
         for dims, data in zip(output_dims, result_data):
             output.append(Variable(dims, data))
@@ -527,24 +532,83 @@ def apply_variable_ufunc(func, *args, **kwargs):
         return Variable(dims, result_data)
 
 
+def _apply_with_dask_atop(func, args, input_dims, output_dims, signature,
+                          output_dtypes, output_sizes=None):
+    import dask.array as da
+
+    if signature.num_outputs > 1:
+        raise NotImplementedError('multiple outputs from apply_ufunc not yet '
+                                  "supported with dask='parallelized'")
+
+    if output_dtypes is None:
+        raise ValueError('output dtypes (output_dtypes) must be supplied to '
+                         "apply_ufunc when using dask='parallelized'")
+    if not isinstance(output_dtypes, list):
+        raise TypeError('output_dtypes must be a list of objects coercible to '
+                        'numpy dtypes, got {}'.format(output_dtypes))
+    if len(output_dtypes) != signature.num_outputs:
+        raise ValueError('apply_ufunc arguments output_dtypes and '
+                         'output_core_dims must have the same length: {} vs {}'
+                         .format(len(output_dtypes), signature.num_outputs))
+    (dtype,) = output_dtypes
+
+    if output_sizes is None:
+        output_sizes = {}
+
+    new_dims = signature.all_output_core_dims - signature.all_input_core_dims
+    if any(dim not in output_sizes for dim in new_dims):
+        raise ValueError("when using dask='parallelized' with apply_ufunc, "
+                         'output core dimensions not found on inputs must have '
+                         'explicitly set sizes with ``output_sizes``: {}'
+                         .format(new_dims))
+
+    for n, (data, core_dims) in enumerate(
+            zip(args, signature.input_core_dims)):
+        if isinstance(data, dask_array_type):
+            # core dimensions cannot span multiple chunks
+            for axis, dim in enumerate(core_dims, start=-len(core_dims)):
+                if len(data.chunks[axis]) != 1:
+                    raise ValueError(
+                        'dimension {!r} on {}th function argument to '
+                        "apply_ufunc with dask='parallelized' consists of "
+                        'multiple chunks, but is also a core dimension. To '
+                        'fix, rechunk into a single dask array chunk along '
+                        'this dimension, i.e., ``.rechunk({})``, but beware '
+                        'that this may significantly increase memory usage.'
+                        .format(dim, n, {dim: -1}))
+
+    (out_ind,) = output_dims
+    # skip leading dimensions that we did not insert with broadcast_compat_data
+    atop_args = [element
+                 for (arg, dims) in zip(args, input_dims)
+                 for element in (arg, dims[-getattr(arg, 'ndim', 0):])]
+    return da.atop(func, out_ind, *atop_args, dtype=dtype, concatenate=True,
+                   new_axes=output_sizes)
+
+
 def apply_array_ufunc(func, *args, **kwargs):
-    """apply_variable_ufunc(func, *args, dask_array='forbidden')
+    """apply_array_ufunc(func, *args, dask='forbidden')
     """
-    dask_array = kwargs.pop('dask_array', 'forbidden')
+    dask = kwargs.pop('dask', 'forbidden')
     if kwargs:
         raise TypeError('apply_array_ufunc() got unexpected keyword '
                         'arguments: %s' % list(kwargs))
 
     if any(isinstance(arg, dask_array_type) for arg in args):
-        # TODO: add a mode dask_array='auto' when dask.array gets a function
-        # for applying arbitrary gufuncs
-        if dask_array == 'forbidden':
-            raise ValueError('encountered dask array, but did not set '
-                             "dask_array='allowed'")
-        elif dask_array != 'allowed':
-            raise ValueError('unknown setting for dask array handling: %r'
-                             % dask_array)
-        # fall through
+        if dask == 'forbidden':
+            raise ValueError('apply_ufunc encountered a dask array on an '
+                             'argument, but handling for dask arrays has not '
+                             'been enabled. Either set the ``dask`` argument '
+                             'or load your data into memory first with '
+                             '``.load()`` or ``.compute()``')
+        elif dask == 'parallelized':
+            raise ValueError("cannot use dask='parallelized' for apply_ufunc "
+                             'unless at least one input is an xarray object')
+        elif dask == 'allowed':
+            pass
+        else:
+            raise ValueError('unknown setting for dask array handling: {}'
+                             .format(dask))
     return func(*args)
 
 
@@ -559,7 +623,9 @@ def apply_ufunc(func, *args, **kwargs):
                 dataset_fill_value : Any = _DEFAULT_FILL_VALUE,
                 keep_attrs : bool = False,
                 kwargs : Mapping = None,
-                dask_array : str = 'forbidden')
+                dask : str = 'forbidden',
+                output_dtypes : Optional[Sequence] = None,
+                output_sizes : Optional[Mapping[Any, int]] = None)
 
     Apply a vectorized function for unlabeled arrays on xarray objects.
 
@@ -581,8 +647,8 @@ def apply_ufunc(func, *args, **kwargs):
         Mix of labeled and/or unlabeled arrays to which to apply the function.
     input_core_dims : Sequence[Sequence], optional
         List of the same length as ``args`` giving the list of core dimensions
-        on each input argument that should be broadcast. By default, we assume
-        there are no core dimensions on any input arguments.
+        on each input argument that should not be broadcast. By default, we
+        assume there are no core dimensions on any input arguments.
 
         For example ,``input_core_dims=[[], ['time']]`` indicates that all
         dimensions on the first argument and all dimensions other than 'time'
@@ -630,10 +696,20 @@ def apply_ufunc(func, *args, **kwargs):
         Whether to copy attributes from the first argument to the output.
     kwargs: dict, optional
         Optional keyword arguments passed directly on to call ``func``.
-    dask_array: 'forbidden' or 'allowed', optional
-        Whether or not to allow applying the ufunc to objects containing lazy
-        data in the form of dask arrays. By default, this is forbidden, to
-        avoid implicitly converting lazy data.
+    dask: 'forbidden', 'allowed' or 'parallelized', optional
+        How to handle applying to objects containing lazy data in the form of
+        dask arrays:
+        - 'forbidden' (default): raise an error if a dask array is encountered.
+        - 'allowed': pass dask arrays directly on to ``func``.
+        - 'parallelized': automatically parallelize ``func`` if any of the
+          inputs are a dask array. If used, the ``output_dtypes`` argument must
+          also be provided. Multiple output arguments are not yet supported.
+    output_dtypes : list of dtypes, optional
+        Optional list of output dtypes. Only used if dask='parallelized'.
+    output_sizes : dict, optional
+        Optional mapping from dimension names to sizes for outputs. Only used if
+        dask='parallelized' and new dimensions (not found on inputs) appear on
+        outputs.
 
     Returns
     -------
@@ -710,7 +786,9 @@ def stack(objects, dim, new_coord):
     exclude_dims = kwargs.pop('exclude_dims', frozenset())
     dataset_fill_value = kwargs.pop('dataset_fill_value', _DEFAULT_FILL_VALUE)
     kwargs_ = kwargs.pop('kwargs', None)
-    dask_array = kwargs.pop('dask_array', 'forbidden')
+    dask = kwargs.pop('dask', 'forbidden')
+    output_dtypes = kwargs.pop('output_dtypes', None)
+    output_sizes = kwargs.pop('output_sizes', None)
     if kwargs:
         raise TypeError('apply_ufunc() got unexpected keyword arguments: %s'
                         % list(kwargs))
@@ -727,12 +805,12 @@ def stack(objects, dim, new_coord):
     if kwargs_:
         func = functools.partial(func, **kwargs_)
 
-    array_ufunc = functools.partial(
-        apply_array_ufunc, func, dask_array=dask_array)
-
-    variables_ufunc = functools.partial(apply_variable_ufunc, array_ufunc,
+    variables_ufunc = functools.partial(apply_variable_ufunc, func,
                                         signature=signature,
-                                        exclude_dims=exclude_dims)
+                                        exclude_dims=exclude_dims,
+                                        dask=dask,
+                                        output_dtypes=output_dtypes,
+                                        output_sizes=output_sizes)
 
     if any(isinstance(a, GroupBy) for a in args):
         # kwargs has already been added into func
@@ -744,7 +822,7 @@ def stack(objects, dim, new_coord):
                                        dataset_join=dataset_join,
                                        dataset_fill_value=dataset_fill_value,
                                        keep_attrs=keep_attrs,
-                                       dask_array=dask_array)
+                                       dask=dask)
         return apply_groupby_ufunc(this_apply, *args)
     elif any(is_dict_like(a) for a in args):
         return apply_dataset_ufunc(variables_ufunc, *args,
763841
elif any(isinstance(a, Variable) for a in args):
764842
return variables_ufunc(*args)
765843
else:
766-
return array_ufunc(*args)
844+
return apply_array_ufunc(func, *args, dask=dask)
767845

768846

769847
def where(cond, x, y):
@@ -805,4 +883,4 @@ def where(cond, x, y):
                        cond, x, y,
                        join='exact',
                        dataset_join='exact',
-                       dask_array='allowed')
+                       dask='allowed')
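
For reference, a rough sketch of what ``dask='parallelized'`` does under the hood: the wrapped function is handed to ``dask.array.atop`` (renamed ``blockwise`` in later dask releases), with each argument paired with its dimension names, and with core dimensions required to sit in a single chunk, as the error message added above explains. The ``demean`` function and random array below are illustrative, not part of this commit.

import dask.array as da
import numpy as np

# dask renamed atop -> blockwise in later releases; use whichever is available
blockwise = getattr(da, 'blockwise', None) or da.atop

def demean(block):
    # called once per block with a plain numpy array
    return block - block.mean(axis=-1, keepdims=True)

# 'x' is a broadcast dimension and may be chunked; 'time' is a core dimension
# and therefore spans a single chunk, mirroring the check added above
data = da.random.random((4, 1000), chunks=(1, 1000))

result = blockwise(demean, ('x', 'time'),   # output index
                   data, ('x', 'time'),     # each input paired with its index
                   dtype=data.dtype, concatenate=True)
print(result.compute().shape)  # (4, 1000)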

xarray/core/ops.py: +2 -2
@@ -148,7 +148,7 @@ def fillna(data, other, join="left", dataset_join="left"):
 
     return apply_ufunc(duck_array_ops.fillna, data, other,
                        join=join,
-                       dask_array="allowed",
+                       dask="allowed",
                        dataset_join=dataset_join,
                        dataset_fill_value=np.nan,
                        keep_attrs=True)
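
The change in this file is only the keyword rename: ``fillna`` (and ``where_method`` below) keep running lazily on dask-backed objects via ``dask='allowed'``. A small illustrative check with made-up data:

import numpy as np
import xarray as xr

# hypothetical chunked example data
arr = xr.DataArray(np.array([1.0, np.nan, 3.0]), dims=['x']).chunk({'x': 2})
filled = arr.fillna(0.0)   # dispatches through apply_ufunc(..., dask='allowed')

print(type(filled.data))          # still a dask array; nothing computed yet
print(filled.compute().values)    # [1. 0. 3.]
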
@@ -176,7 +176,7 @@ def where_method(self, cond, other=dtypes.NA):
                        self, cond, other,
                        join=join,
                        dataset_join=join,
-                       dask_array='allowed',
+                       dask='allowed',
                        keep_attrs=True)
 
 