Skip to content

Commit fd28ff1

Browse files
committed
dask='parallelized' for apply_ufunc'
1 parent f9464fd commit fd28ff1

File tree

3 files changed

+204
-40
lines changed

3 files changed

+204
-40
lines changed

xarray/core/computation.py

+98-24
Original file line numberDiff line numberDiff line change
@@ -500,6 +500,9 @@ def apply_variable_ufunc(func, *args, **kwargs):
500500

501501
signature = kwargs.pop('signature')
502502
exclude_dims = kwargs.pop('exclude_dims', _DEFAULT_FROZEN_SET)
503+
dask = kwargs.pop('dask', 'forbidden')
504+
output_dtypes = kwargs.pop('output_dtypes', None)
505+
output_sizes = kwargs.pop('output_sizes', None)
503506
if kwargs:
504507
raise TypeError('apply_variable_ufunc() got unexpected keyword '
505508
'arguments: %s' % list(kwargs))
@@ -515,6 +518,22 @@ def apply_variable_ufunc(func, *args, **kwargs):
515518
else arg
516519
for arg, core_dims in zip(args, signature.input_core_dims)]
517520

521+
if any(isinstance(array, dask_array_type) for array in input_data):
522+
if dask == 'forbidden':
523+
raise ValueError('encountered dask array, but did not set '
524+
"dask='allowed'")
525+
elif dask == 'parallelized':
526+
input_dims = [broadcast_dims + input_dims
527+
for input_dims in signature.input_core_dims]
528+
numpy_func = func
529+
func = lambda *arrays: _apply_with_dask_atop(
530+
numpy_func, arrays, input_dims, output_dims, signature,
531+
output_dtypes, output_sizes)
532+
elif dask == 'allowed':
533+
pass
534+
else:
535+
raise ValueError('unknown setting for dask array handling: {}'
536+
.format(dask))
518537
result_data = func(*input_data)
519538

520539
if signature.n_outputs > 1:
@@ -527,24 +546,65 @@ def apply_variable_ufunc(func, *args, **kwargs):
527546
return Variable(dims, result_data)
528547

529548

549+
def _apply_with_dask_atop(func, args, input_dims, output_dims, signature,
550+
output_dtypes, output_sizes=None):
551+
import dask.array as da
552+
553+
if signature.n_outputs > 1:
554+
raise NotImplementedError(
555+
"multiple outputs not yet supported with dask='parallelized'")
556+
557+
if output_dtypes is None:
558+
raise ValueError(
559+
"output dtypes (output_dtypes) required when using dask='parallelized'")
560+
if len(output_dtypes) != signature.n_outputs:
561+
raise ValueError('wrong number of output dtypes')
562+
(dtype,) = output_dtypes
563+
564+
if output_sizes is None:
565+
output_sizes = {}
566+
567+
new_dims = signature.all_output_core_dims - signature.all_input_core_dims
568+
if any(dim not in output_sizes for dim in new_dims):
569+
raise ValueError('output core dimensions not found on inputs must have '
570+
'explicitly set sizes with ``output_sizes``: {}'
571+
.format(new_dims))
572+
573+
args2 = []
574+
for data, core_dims in zip(args, signature.input_core_dims):
575+
if isinstance(data, dask_array_type):
576+
# core dimensions cannot span multiple chunks
577+
chunks = {axis: (data.shape[axis],)
578+
for axis, dim in enumerate(core_dims, -len(core_dims))}
579+
data = data.rechunk(chunks)
580+
args2.append(data)
581+
582+
(out_ind,) = output_dims
583+
atop_args = [ai for a in zip(args2, input_dims) for ai in a]
584+
return da.atop(func, out_ind, *atop_args, dtype=dtype, concatenate=True,
585+
new_axes=output_sizes)
586+
587+
530588
def apply_array_ufunc(func, *args, **kwargs):
531-
"""apply_variable_ufunc(func, *args, dask_array='forbidden')
589+
"""apply_array_ufunc(func, *args, dask='forbidden')
532590
"""
533-
dask_array = kwargs.pop('dask_array', 'forbidden')
591+
dask = kwargs.pop('dask', 'forbidden')
534592
if kwargs:
535593
raise TypeError('apply_array_ufunc() got unexpected keyword '
536594
'arguments: %s' % list(kwargs))
537595

538596
if any(isinstance(arg, dask_array_type) for arg in args):
539-
# TODO: add a mode dask_array='auto' when dask.array gets a function
540-
# for applying arbitrary gufuncs
541-
if dask_array == 'forbidden':
597+
if dask == 'forbidden':
542598
raise ValueError('encountered dask array, but did not set '
543-
"dask_array='allowed'")
544-
elif dask_array != 'allowed':
545-
raise ValueError('unknown setting for dask array handling: %r'
546-
% dask_array)
547-
# fall through
599+
"dask='allowed'")
600+
elif dask == 'parallelized':
601+
raise ValueError("cannot use dask='parallelized' unless at least "
602+
'one input is an xarray object')
603+
elif dask == 'allowed':
604+
pass
605+
else:
606+
raise ValueError('unknown setting for dask array handling: {}'
607+
.format(dask))
548608
return func(*args)
549609

550610

@@ -559,7 +619,9 @@ def apply_ufunc(func, *args, **kwargs):
559619
dataset_fill_value : Any = _DEFAULT_FILL_VALUE,
560620
keep_attrs : bool = False,
561621
kwargs : Mapping = None,
562-
dask_array : str = 'forbidden')
622+
dask_array : str = 'forbidden',
623+
output_dtypes : Optional[Sequence] = None,
624+
output_sizes : Optional[Mapping[Any, int]] = None)
563625
564626
Apply a vectorized function for unlabeled arrays on xarray objects.
565627
@@ -630,10 +692,20 @@ def apply_ufunc(func, *args, **kwargs):
630692
Whether to copy attributes from the first argument to the output.
631693
kwargs: dict, optional
632694
Optional keyword arguments passed directly on to call ``func``.
633-
dask_array: 'forbidden' or 'allowed', optional
634-
Whether or not to allow applying the ufunc to objects containing lazy
635-
data in the form of dask arrays. By default, this is forbidden, to
636-
avoid implicitly converting lazy data.
695+
dask: 'forbidden', 'allowed' or 'parallelized', optional
696+
How to handle applying to objects containing lazy data in the form of
697+
dask arrays:
698+
- 'forbidden' (default): raise an error if a dask array is encountered.
699+
- 'allowed': pass dask arrays directly on to ``func``.
700+
- 'parallelized': automatically parallelize ``func`` if any of the
701+
inputs are a dask array. If used, the ``otypes`` argument must also be
702+
provided. Multiple output arguments are not yet supported.
703+
output_dtypes : list of dtypes, optional
704+
Optional list of output dtypes. Only used if dask='parallelized'.
705+
output_sizes : dict, optional
706+
Optional mapping from dimension names to sizes for outputs. Only used if
707+
dask='parallelized' and new dimensions (not found on inputs) appear on
708+
outputs.
637709
638710
Returns
639711
-------
@@ -710,7 +782,9 @@ def stack(objects, dim, new_coord):
710782
exclude_dims = kwargs.pop('exclude_dims', frozenset())
711783
dataset_fill_value = kwargs.pop('dataset_fill_value', _DEFAULT_FILL_VALUE)
712784
kwargs_ = kwargs.pop('kwargs', None)
713-
dask_array = kwargs.pop('dask_array', 'forbidden')
785+
dask = kwargs.pop('dask', 'forbidden')
786+
output_dtypes = kwargs.pop('output_dtypes', None)
787+
output_sizes = kwargs.pop('output_sizes', None)
714788
if kwargs:
715789
raise TypeError('apply_ufunc() got unexpected keyword arguments: %s'
716790
% list(kwargs))
@@ -727,12 +801,12 @@ def stack(objects, dim, new_coord):
727801
if kwargs_:
728802
func = functools.partial(func, **kwargs_)
729803

730-
array_ufunc = functools.partial(
731-
apply_array_ufunc, func, dask_array=dask_array)
732-
733-
variables_ufunc = functools.partial(apply_variable_ufunc, array_ufunc,
804+
variables_ufunc = functools.partial(apply_variable_ufunc, func,
734805
signature=signature,
735-
exclude_dims=exclude_dims)
806+
exclude_dims=exclude_dims,
807+
dask=dask,
808+
output_dtypes=output_dtypes,
809+
output_sizes=output_sizes)
736810

737811
if any(isinstance(a, GroupBy) for a in args):
738812
# kwargs has already been added into func
@@ -744,7 +818,7 @@ def stack(objects, dim, new_coord):
744818
dataset_join=dataset_join,
745819
dataset_fill_value=dataset_fill_value,
746820
keep_attrs=keep_attrs,
747-
dask_array=dask_array)
821+
dask=dask)
748822
return apply_groupby_ufunc(this_apply, *args)
749823
elif any(is_dict_like(a) for a in args):
750824
return apply_dataset_ufunc(variables_ufunc, *args,
@@ -763,7 +837,7 @@ def stack(objects, dim, new_coord):
763837
elif any(isinstance(a, Variable) for a in args):
764838
return variables_ufunc(*args)
765839
else:
766-
return array_ufunc(*args)
840+
return apply_array_ufunc(func, *args, dask=dask)
767841

768842

769843
def where(cond, x, y):
@@ -805,4 +879,4 @@ def where(cond, x, y):
805879
cond, x, y,
806880
join='exact',
807881
dataset_join='exact',
808-
dask_array='allowed')
882+
dask='allowed')

xarray/tests/__init__.py

+2-2
Original file line numberDiff line numberDiff line change
@@ -188,8 +188,8 @@ def raises_regex(error, pattern):
188188
with pytest.raises(error) as excinfo:
189189
yield
190190
message = str(excinfo.value)
191-
if not re.match(pattern, message):
192-
raise AssertionError('exception %r did not match pattern %s'
191+
if not re.search(pattern, message):
192+
raise AssertionError('exception %r did not match pattern %r'
193193
% (excinfo.value, pattern))
194194

195195

xarray/tests/test_computation.py

+104-14
Original file line numberDiff line numberDiff line change
@@ -13,7 +13,7 @@
1313
join_dict_keys, ordered_set_intersection, ordered_set_union,
1414
unified_dim_sizes, apply_ufunc)
1515

16-
from . import requires_dask
16+
from . import requires_dask, raises_regex
1717

1818

1919
def assert_identical(a, b):
@@ -344,7 +344,7 @@ def func(*x):
344344
output_core_dims=[[dim]],
345345
exclude_dims={dim})
346346
if isinstance(result, (xr.Dataset, xr.DataArray)):
347-
# note: this will fail dim is not a coordinate on any input
347+
# note: this will fail if dim is not a coordinate on any input
348348
new_coord = np.concatenate([obj.coords[dim] for obj in objects])
349349
result.coords[dim] = new_coord
350350
return result
@@ -530,25 +530,17 @@ def add(a, b, join, dataset_join):
530530
assert_identical(actual, expected)
531531

532532

533-
class _NoCacheVariable(xr.Variable):
534-
"""Subclass of Variable for testing that does not cache values."""
535-
# TODO: remove this class when we change the default behavior for caching
536-
# dask.array objects.
537-
def _data_cached(self):
538-
return np.asarray(self._data)
539-
540-
541533
@requires_dask
542534
def test_apply_dask():
543535
import dask.array as da
544536

545537
array = da.ones((2,), chunks=2)
546-
variable = _NoCacheVariable('x', array)
538+
variable = xr.Variable('x', array)
547539
coords = xr.DataArray(variable).coords.variables
548540
data_array = xr.DataArray(variable, coords, fastpath=True)
549541
dataset = xr.Dataset({'y': variable})
550542

551-
# encountered dask array, but did not set dask_array='allowed'
543+
# encountered dask array, but did not set dask='allowed'
552544
with pytest.raises(ValueError):
553545
apply_ufunc(identity, array)
554546
with pytest.raises(ValueError):
@@ -560,10 +552,10 @@ def test_apply_dask():
560552

561553
# unknown setting for dask array handling
562554
with pytest.raises(ValueError):
563-
apply_ufunc(identity, array, dask_array='auto')
555+
apply_ufunc(identity, array, dask='unknown')
564556

565557
def dask_safe_identity(x):
566-
return apply_ufunc(identity, x, dask_array='allowed')
558+
return apply_ufunc(identity, x, dask='allowed')
567559

568560
assert array is dask_safe_identity(array)
569561

@@ -580,6 +572,104 @@ def dask_safe_identity(x):
580572
assert_identical(dataset, actual)
581573

582574

575+
@requires_dask
576+
def test_apply_dask_parallelized():
577+
import dask.array as da
578+
579+
array = da.ones((2, 2), chunks=(1, 1))
580+
data_array = xr.DataArray(array, dims=('x', 'y'))
581+
582+
actual = apply_ufunc(identity, data_array, dask='parallelized',
583+
output_dtypes=[float])
584+
assert isinstance(actual.data, da.Array)
585+
assert actual.data.chunks == array.chunks
586+
assert_identical(data_array, actual)
587+
588+
# check rechunking of core dimensions
589+
actual = apply_ufunc(identity, data_array, dask='parallelized',
590+
output_dtypes=[float],
591+
input_core_dims=[('y',)],
592+
output_core_dims=[('y',)])
593+
assert isinstance(actual.data, da.Array)
594+
assert actual.data.chunks == ((1, 1), (2,))
595+
assert_identical(data_array, actual)
596+
597+
598+
@requires_dask
599+
def test_apply_dask_parallelized_errors():
600+
import dask.array as da
601+
602+
array = da.ones((2, 2), chunks=(1, 1))
603+
data_array = xr.DataArray(array, dims=('x', 'y'))
604+
605+
with pytest.raises(NotImplementedError):
606+
apply_ufunc(identity, data_array, output_core_dims=[['z'], ['z']],
607+
dask='parallelized')
608+
with raises_regex(ValueError, 'dtypes'):
609+
apply_ufunc(identity, data_array, dask='parallelized')
610+
with raises_regex(ValueError, 'wrong number'):
611+
apply_ufunc(identity, data_array, dask='parallelized',
612+
output_dtypes=[float, float])
613+
with raises_regex(ValueError, 'output_sizes'):
614+
apply_ufunc(identity, data_array, output_core_dims=[['z']],
615+
output_dtypes=[float], dask='parallelized')
616+
with raises_regex(ValueError, 'at least one input is an xarray object'):
617+
apply_ufunc(identity, array, dask='parallelized')
618+
619+
620+
@requires_dask
621+
def test_apply_dask_multiple_inputs():
622+
import dask.array as da
623+
624+
def covariance(x, y):
625+
return ((x - x.mean(axis=-1, keepdims=True))
626+
* (y - y.mean(axis=-1, keepdims=True))).mean(axis=-1)
627+
628+
rs = np.random.RandomState(42)
629+
array1 = da.from_array(rs.randn(4, 4), chunks=(2, 2))
630+
array2 = da.from_array(rs.randn(4, 4), chunks=(2, 2))
631+
data_array_1 = xr.DataArray(array1, dims=('x', 'y'))
632+
data_array_2 = xr.DataArray(array2, dims=('x', 'y'))
633+
634+
expected = apply_ufunc(
635+
covariance, data_array_1.compute(), data_array_2.compute(),
636+
input_core_dims=[['y'], ['y']])
637+
allowed = apply_ufunc(
638+
covariance, data_array_1, data_array_2, input_core_dims=[['y'], ['y']],
639+
dask='allowed')
640+
assert isinstance(allowed.data, da.Array)
641+
xr.testing.assert_allclose(expected, allowed.compute())
642+
643+
parallelized = apply_ufunc(
644+
covariance, data_array_1, data_array_2, input_core_dims=[['y'], ['y']],
645+
dask='parallelized', output_dtypes=[float])
646+
assert isinstance(parallelized.data, da.Array)
647+
xr.testing.assert_allclose(expected, parallelized.compute())
648+
649+
650+
@requires_dask
651+
def test_apply_dask_new_output_dimension():
652+
import dask.array as da
653+
654+
array = da.ones((2, 2), chunks=(1, 1))
655+
data_array = xr.DataArray(array, dims=('x', 'y'))
656+
657+
def stack_negative(obj):
658+
def func(x):
659+
return xr.core.npcompat.stack([x, -x], axis=-1)
660+
return apply_ufunc(func, obj, output_core_dims=[['sign']],
661+
dask='parallelized', output_dtypes=[obj.dtype],
662+
output_sizes={'sign': 2})
663+
664+
expected = stack_negative(data_array.compute())
665+
666+
actual = stack_negative(data_array)
667+
assert actual.dims == ('x', 'y', 'sign')
668+
assert actual.shape == (2, 2, 2)
669+
assert isinstance(actual.data, da.Array)
670+
assert_identical(expected, actual)
671+
672+
583673
def test_where():
584674
cond = xr.DataArray([True, False], dims='x')
585675
actual = xr.where(cond, 1, 0)

0 commit comments

Comments
 (0)