@@ -500,6 +500,9 @@ def apply_variable_ufunc(func, *args, **kwargs):
500
500
501
501
signature = kwargs .pop ('signature' )
502
502
exclude_dims = kwargs .pop ('exclude_dims' , _DEFAULT_FROZEN_SET )
503
+ dask = kwargs .pop ('dask' , 'forbidden' )
504
+ output_dtypes = kwargs .pop ('output_dtypes' , None )
505
+ output_sizes = kwargs .pop ('output_sizes' , None )
503
506
if kwargs :
504
507
raise TypeError ('apply_variable_ufunc() got unexpected keyword '
505
508
'arguments: %s' % list (kwargs ))
@@ -515,6 +518,22 @@ def apply_variable_ufunc(func, *args, **kwargs):
515
518
else arg
516
519
for arg , core_dims in zip (args , signature .input_core_dims )]
517
520
521
+ if any (isinstance (array , dask_array_type ) for array in input_data ):
522
+ if dask == 'forbidden' :
523
+ raise ValueError ('encountered dask array, but did not set '
524
+ "dask='allowed'" )
525
+ elif dask == 'parallelized' :
526
+ input_dims = [broadcast_dims + input_dims
527
+ for input_dims in signature .input_core_dims ]
528
+ numpy_func = func
529
+ func = lambda * arrays : _apply_with_dask_atop (
530
+ numpy_func , arrays , input_dims , output_dims , signature ,
531
+ output_dtypes , output_sizes )
532
+ elif dask == 'allowed' :
533
+ pass
534
+ else :
535
+ raise ValueError ('unknown setting for dask array handling: {}'
536
+ .format (dask ))
518
537
result_data = func (* input_data )
519
538
520
539
if signature .n_outputs > 1 :
@@ -527,24 +546,65 @@ def apply_variable_ufunc(func, *args, **kwargs):
527
546
return Variable (dims , result_data )
528
547
529
548
549
+ def _apply_with_dask_atop (func , args , input_dims , output_dims , signature ,
550
+ output_dtypes , output_sizes = None ):
551
+ import dask .array as da
552
+
553
+ if signature .n_outputs > 1 :
554
+ raise NotImplementedError (
555
+ "multiple outputs not yet supported with dask='parallelized'" )
556
+
557
+ if output_dtypes is None :
558
+ raise ValueError (
559
+ "output dtypes (output_dtypes) required when using dask='parallelized'" )
560
+ if len (output_dtypes ) != signature .n_outputs :
561
+ raise ValueError ('wrong number of output dtypes' )
562
+ (dtype ,) = output_dtypes
563
+
564
+ if output_sizes is None :
565
+ output_sizes = {}
566
+
567
+ new_dims = signature .all_output_core_dims - signature .all_input_core_dims
568
+ if any (dim not in output_sizes for dim in new_dims ):
569
+ raise ValueError ('output core dimensions not found on inputs must have '
570
+ 'explicitly set sizes with ``output_sizes``: {}'
571
+ .format (new_dims ))
572
+
573
+ args2 = []
574
+ for data , core_dims in zip (args , signature .input_core_dims ):
575
+ if isinstance (data , dask_array_type ):
576
+ # core dimensions cannot span multiple chunks
577
+ chunks = {axis : (data .shape [axis ],)
578
+ for axis , dim in enumerate (core_dims , - len (core_dims ))}
579
+ data = data .rechunk (chunks )
580
+ args2 .append (data )
581
+
582
+ (out_ind ,) = output_dims
583
+ atop_args = [ai for a in zip (args2 , input_dims ) for ai in a ]
584
+ return da .atop (func , out_ind , * atop_args , dtype = dtype , concatenate = True ,
585
+ new_axes = output_sizes )
586
+
587
+
530
588
def apply_array_ufunc(func, *args, **kwargs):
    """apply_array_ufunc(func, *args, dask='forbidden')

    Call ``func(*args)`` on plain (non-xarray) arguments, after checking
    whether dask arrays among ``args`` are permitted.

    Parameters
    ----------
    func : callable
        Function to call with ``*args``.
    *args
        Positional arguments forwarded to ``func`` unchanged.
    dask : str, optional
        Keyword-only. Consulted only when at least one argument is a dask
        array: 'forbidden' (default) and 'parallelized' raise, 'allowed'
        passes the arrays straight through to ``func``.

    Returns
    -------
    Whatever ``func(*args)`` returns.

    Raises
    ------
    TypeError
        If any keyword argument other than ``dask`` is supplied.
    ValueError
        If a dask array is encountered and ``dask`` is 'forbidden',
        'parallelized', or not a recognised setting.
    """
    # The key must be exactly 'dask'; any other spelling leaves the caller's
    # keyword in ``kwargs`` and triggers the TypeError below.
    dask = kwargs.pop('dask', 'forbidden')
    if kwargs:
        raise TypeError('apply_array_ufunc() got unexpected keyword '
                        'arguments: %s' % list(kwargs))

    if any(isinstance(arg, dask_array_type) for arg in args):
        if dask == 'forbidden':
            raise ValueError('encountered dask array, but did not set '
                             "dask='allowed'")
        if dask == 'parallelized':
            # No dimension names are available on bare arrays here.
            raise ValueError("cannot use dask='parallelized' unless at least "
                             'one input is an xarray object')
        if dask != 'allowed':
            raise ValueError('unknown setting for dask array handling: {}'
                             .format(dask))
    return func(*args)
549
609
550
610
@@ -559,7 +619,9 @@ def apply_ufunc(func, *args, **kwargs):
559
619
dataset_fill_value : Any = _DEFAULT_FILL_VALUE,
560
620
keep_attrs : bool = False,
561
621
kwargs : Mapping = None,
562
- dask_array : str = 'forbidden')
622
+ dask : str = 'forbidden',
623
+ output_dtypes : Optional[Sequence] = None,
624
+ output_sizes : Optional[Mapping[Any, int]] = None)
563
625
564
626
Apply a vectorized function for unlabeled arrays on xarray objects.
565
627
@@ -630,10 +692,20 @@ def apply_ufunc(func, *args, **kwargs):
630
692
Whether to copy attributes from the first argument to the output.
631
693
kwargs: dict, optional
632
694
Optional keyword arguments passed directly on to call ``func``.
633
- dask_array: 'forbidden' or 'allowed', optional
634
- Whether or not to allow applying the ufunc to objects containing lazy
635
- data in the form of dask arrays. By default, this is forbidden, to
636
- avoid implicitly converting lazy data.
695
+ dask: 'forbidden', 'allowed' or 'parallelized', optional
696
+ How to handle applying to objects containing lazy data in the form of
697
+ dask arrays:
698
+ - 'forbidden' (default): raise an error if a dask array is encountered.
699
+ - 'allowed': pass dask arrays directly on to ``func``.
700
+ - 'parallelized': automatically parallelize ``func`` if any of the
701
+ inputs are a dask array. If used, the ``output_dtypes`` argument must also be
702
+ provided. Multiple output arguments are not yet supported.
703
+ output_dtypes : list of dtypes, optional
704
+ Optional list of output dtypes. Only used if dask='parallelized'.
705
+ output_sizes : dict, optional
706
+ Optional mapping from dimension names to sizes for outputs. Only used if
707
+ dask='parallelized' and new dimensions (not found on inputs) appear on
708
+ outputs.
637
709
638
710
Returns
639
711
-------
@@ -710,7 +782,9 @@ def stack(objects, dim, new_coord):
710
782
exclude_dims = kwargs .pop ('exclude_dims' , frozenset ())
711
783
dataset_fill_value = kwargs .pop ('dataset_fill_value' , _DEFAULT_FILL_VALUE )
712
784
kwargs_ = kwargs .pop ('kwargs' , None )
713
- dask_array = kwargs .pop ('dask_array' , 'forbidden' )
785
+ dask = kwargs .pop ('dask' , 'forbidden' )
786
+ output_dtypes = kwargs .pop ('output_dtypes' , None )
787
+ output_sizes = kwargs .pop ('output_sizes' , None )
714
788
if kwargs :
715
789
raise TypeError ('apply_ufunc() got unexpected keyword arguments: %s'
716
790
% list (kwargs ))
@@ -727,12 +801,12 @@ def stack(objects, dim, new_coord):
727
801
if kwargs_ :
728
802
func = functools .partial (func , ** kwargs_ )
729
803
730
- array_ufunc = functools .partial (
731
- apply_array_ufunc , func , dask_array = dask_array )
732
-
733
- variables_ufunc = functools .partial (apply_variable_ufunc , array_ufunc ,
804
+ variables_ufunc = functools .partial (apply_variable_ufunc , func ,
734
805
signature = signature ,
735
- exclude_dims = exclude_dims )
806
+ exclude_dims = exclude_dims ,
807
+ dask = dask ,
808
+ output_dtypes = output_dtypes ,
809
+ output_sizes = output_sizes )
736
810
737
811
if any (isinstance (a , GroupBy ) for a in args ):
738
812
# kwargs has already been added into func
@@ -744,7 +818,7 @@ def stack(objects, dim, new_coord):
744
818
dataset_join = dataset_join ,
745
819
dataset_fill_value = dataset_fill_value ,
746
820
keep_attrs = keep_attrs ,
747
- dask_array = dask_array )
821
+ dask = dask )
748
822
return apply_groupby_ufunc (this_apply , * args )
749
823
elif any (is_dict_like (a ) for a in args ):
750
824
return apply_dataset_ufunc (variables_ufunc , * args ,
@@ -763,7 +837,7 @@ def stack(objects, dim, new_coord):
763
837
elif any (isinstance (a , Variable ) for a in args ):
764
838
return variables_ufunc (* args )
765
839
else :
766
- return array_ufunc ( * args )
840
+ return apply_array_ufunc ( func , * args , dask = dask )
767
841
768
842
769
843
def where (cond , x , y ):
@@ -805,4 +879,4 @@ def where(cond, x, y):
805
879
cond , x , y ,
806
880
join = 'exact' ,
807
881
dataset_join = 'exact' ,
808
- dask_array = 'allowed' )
882
+ dask = 'allowed' )
0 commit comments