@@ -64,32 +64,13 @@ def all_core_dims(self):
6464        return  self ._all_core_dims 
6565
6666    @property  
67-     def  n_inputs (self ):
67+     def  num_inputs (self ):
6868        return  len (self .input_core_dims )
6969
7070    @property  
71-     def  n_outputs (self ):
71+     def  num_outputs (self ):
7272        return  len (self .output_core_dims )
7373
74-     @classmethod  
75-     def  default (cls , n_inputs ):
76-         return  cls ([()] *  n_inputs , [()])
77- 
78-     @classmethod  
79-     def  from_sequence (cls , nested ):
80-         if  (not  isinstance (nested , collections .Sequence ) or 
81-                 not  len (nested ) ==  2  or 
82-                 any (not  isinstance (arg_list , collections .Sequence )
83-                     for  arg_list  in  nested ) or 
84-                 any (isinstance (arg , basestring ) or 
85-                     not  isinstance (arg , collections .Sequence )
86-                     for  arg_list  in  nested  for  arg  in  arg_list )):
87-             raise  TypeError ('functions signatures not provided as a string ' 
88-                             'must be a triply nested sequence providing the ' 
89-                             'list of core dimensions for each variable, for ' 
90-                             'both input and output.' )
91-         return  cls (* nested )
92- 
9374    def  __eq__ (self , other ):
9475        try :
9576            return  (self .input_core_dims  ==  other .input_core_dims  and 
@@ -190,7 +171,7 @@ def apply_dataarray_ufunc(func, *args, **kwargs):
190171    data_vars  =  [getattr (a , 'variable' , a ) for  a  in  args ]
191172    result_var  =  func (* data_vars )
192173
193-     if  signature .n_outputs  >  1 :
174+     if  signature .num_outputs  >  1 :
194175        out  =  tuple (DataArray (variable , coords , name = name , fastpath = True )
195176                    for  variable , coords  in  zip (result_var , result_coords ))
196177    else :
@@ -269,10 +250,10 @@ def _as_variables_or_variable(arg):
269250
270251def  _unpack_dict_tuples (
271252        result_vars ,  # type: Mapping[Any, Tuple[Variable]] 
272-         n_outputs ,    # type: int 
253+         num_outputs ,    # type: int 
273254):
274255    # type: (...) -> Tuple[Dict[Any, Variable]] 
275-     out  =  tuple (OrderedDict () for  _  in  range (n_outputs ))
256+     out  =  tuple (OrderedDict () for  _  in  range (num_outputs ))
276257    for  name , values  in  result_vars .items ():
277258        for  value , results_dict  in  zip (values , out ):
278259            results_dict [name ] =  value 
@@ -298,8 +279,8 @@ def apply_dict_of_variables_ufunc(func, *args, **kwargs):
298279    for  name , variable_args  in  zip (names , grouped_by_name ):
299280        result_vars [name ] =  func (* variable_args )
300281
301-     if  signature .n_outputs  >  1 :
302-         return  _unpack_dict_tuples (result_vars , signature .n_outputs )
282+     if  signature .num_outputs  >  1 :
283+         return  _unpack_dict_tuples (result_vars , signature .num_outputs )
303284    else :
304285        return  result_vars 
305286
@@ -335,8 +316,8 @@ def apply_dataset_ufunc(func, *args, **kwargs):
335316
336317    if  (dataset_join  not  in   _JOINS_WITHOUT_FILL_VALUES  and 
337318            fill_value  is  _DEFAULT_FILL_VALUE ):
338-         raise  TypeError ('To  apply an operation to datasets with different ' , 
339-                         'data variables, you must supply the ' , 
319+         raise  TypeError ('to  apply an operation to datasets with different ' 
320+                         'data variables with apply_ufunc , you must supply the ' 
340321                        'dataset_fill_value argument.' )
341322
342323    if  kwargs :
@@ -353,7 +334,7 @@ def apply_dataset_ufunc(func, *args, **kwargs):
353334        func , * args , signature = signature , join = dataset_join ,
354335        fill_value = fill_value )
355336
356-     if  signature .n_outputs  >  1 :
337+     if  signature .num_outputs  >  1 :
357338        out  =  tuple (_fast_dataset (* args )
358339                    for  args  in  zip (result_vars , list_of_coords ))
359340    else :
@@ -388,12 +369,12 @@ def apply_groupby_ufunc(func, *args):
388369    from  .variable  import  Variable 
389370
390371    groupbys  =  [arg  for  arg  in  args  if  isinstance (arg , GroupBy )]
391-     if  not  groupbys :
392-         raise  ValueError ('must have at least one groupby to iterate over' )
372+     assert  groupbys , 'must have at least one groupby to iterate over' 
393373    first_groupby  =  groupbys [0 ]
394374    if  any (not  first_groupby ._group .equals (gb ._group ) for  gb  in  groupbys [1 :]):
395-         raise  ValueError ('can only perform operations over multiple groupbys ' 
396-                          'at once if they are all grouped the same way' )
375+         raise  ValueError ('apply_ufunc can only perform operations over ' 
376+                          'multiple GroupBy objets at once if they are all ' 
377+                          'grouped the same way' )
397378
398379    grouped_dim  =  first_groupby ._group .name 
399380    unique_values  =  first_groupby ._unique_coord .values 
@@ -430,7 +411,7 @@ def unified_dim_sizes(variables, exclude_dims=frozenset()):
430411    for  var  in  variables :
431412        if  len (set (var .dims )) <  len (var .dims ):
432413            raise  ValueError ('broadcasting cannot handle duplicate ' 
433-                              'dimensions: %r'  %  list (var .dims ))
414+                              'dimensions on a variable : %r'  %  list (var .dims ))
434415        for  dim , size  in  zip (var .dims , var .shape ):
435416            if  dim  not  in   exclude_dims :
436417                if  dim  not  in   dim_sizes :
@@ -462,15 +443,17 @@ def broadcast_compat_data(variable, broadcast_dims, core_dims):
462443    set_old_dims  =  set (old_dims )
463444    missing_core_dims  =  [d  for  d  in  core_dims  if  d  not  in   set_old_dims ]
464445    if  missing_core_dims :
465-         raise  ValueError ('operation requires dimensions missing on input ' 
466-                          'variable: %r'  %  missing_core_dims )
446+         raise  ValueError ('operand to apply_ufunc has required core dimensions ' 
447+                          '%r, but some of these are missing on the input ' 
448+                          'variable:  %r'  %  (list (core_dims ), missing_core_dims ))
467449
468450    set_new_dims  =  set (new_dims )
469451    unexpected_dims  =  [d  for  d  in  old_dims  if  d  not  in   set_new_dims ]
470452    if  unexpected_dims :
471-         raise  ValueError ('operation encountered unexpected dimensions %r ' 
472-                          'on input variable: these are core dimensions on ' 
473-                          'other input or output variables'  %  unexpected_dims )
453+         raise  ValueError ('operand to apply_ufunc encountered unexpected ' 
454+                          'dimensions %r on an input variable: these are core ' 
455+                          'dimensions on other input or output variables' 
456+                          %  unexpected_dims )
474457
475458    # for consistency with numpy, keep broadcast dimensions to the left 
476459    old_broadcast_dims  =  tuple (d  for  d  in  broadcast_dims  if  d  in  set_old_dims )
@@ -500,6 +483,9 @@ def apply_variable_ufunc(func, *args, **kwargs):
500483
501484    signature  =  kwargs .pop ('signature' )
502485    exclude_dims  =  kwargs .pop ('exclude_dims' , _DEFAULT_FROZEN_SET )
486+     dask  =  kwargs .pop ('dask' , 'forbidden' )
487+     output_dtypes  =  kwargs .pop ('output_dtypes' , None )
488+     output_sizes  =  kwargs .pop ('output_sizes' , None )
503489    if  kwargs :
504490        raise  TypeError ('apply_variable_ufunc() got unexpected keyword ' 
505491                        'arguments: %s'  %  list (kwargs ))
@@ -515,9 +501,28 @@ def apply_variable_ufunc(func, *args, **kwargs):
515501                  else  arg 
516502                  for  arg , core_dims  in  zip (args , signature .input_core_dims )]
517503
504+     if  any (isinstance (array , dask_array_type ) for  array  in  input_data ):
505+         if  dask  ==  'forbidden' :
506+             raise  ValueError ('apply_ufunc encountered a dask array on an ' 
507+                              'argument, but handling for dask arrays has not ' 
508+                              'been enabled. Either set the ``dask`` argument ' 
509+                              'or load your data into memory first with ' 
510+                              '``.load()`` or ``.compute()``' )
511+         elif  dask  ==  'parallelized' :
512+             input_dims  =  [broadcast_dims  +  input_dims 
513+                           for  input_dims  in  signature .input_core_dims ]
514+             numpy_func  =  func 
515+             func  =  lambda  * arrays : _apply_with_dask_atop (
516+                 numpy_func , arrays , input_dims , output_dims , signature ,
517+                 output_dtypes , output_sizes )
518+         elif  dask  ==  'allowed' :
519+             pass 
520+         else :
521+             raise  ValueError ('unknown setting for dask array handling in ' 
522+                              'apply_ufunc: {}' .format (dask ))
518523    result_data  =  func (* input_data )
519524
520-     if  signature .n_outputs  >  1 :
525+     if  signature .num_outputs  >  1 :
521526        output  =  []
522527        for  dims , data  in  zip (output_dims , result_data ):
523528            output .append (Variable (dims , data ))
@@ -527,24 +532,83 @@ def apply_variable_ufunc(func, *args, **kwargs):
527532        return  Variable (dims , result_data )
528533
529534
535+ def  _apply_with_dask_atop (func , args , input_dims , output_dims , signature ,
536+                           output_dtypes , output_sizes = None ):
537+     import  dask .array  as  da 
538+ 
539+     if  signature .num_outputs  >  1 :
540+         raise  NotImplementedError ('multiple outputs from apply_ufunc not yet ' 
541+                                   "supported with dask='parallelized'" )
542+ 
543+     if  output_dtypes  is  None :
544+         raise  ValueError ('output dtypes (output_dtypes) must be supplied to ' 
545+                          "apply_func when using dask='parallelized'" )
546+     if  not  isinstance (output_dtypes , list ):
547+         raise  TypeError ('output_dtypes must be a list of objects coercible to ' 
548+                         'numpy dtypes, got {}' .format (output_dtypes ))
549+     if  len (output_dtypes ) !=  signature .num_outputs :
550+         raise  ValueError ('apply_ufunc arguments output_dtypes and ' 
551+                          'output_core_dims must have the same length: {} vs {}' 
552+                          .format (len (output_dtypes ), signature .num_outputs ))
553+     (dtype ,) =  output_dtypes 
554+ 
555+     if  output_sizes  is  None :
556+         output_sizes  =  {}
557+ 
558+     new_dims  =  signature .all_output_core_dims  -  signature .all_input_core_dims 
559+     if  any (dim  not  in   output_sizes  for  dim  in  new_dims ):
560+         raise  ValueError ("when using dask='parallelized' with apply_ufunc, " 
561+                          'output core dimensions not found on inputs must have ' 
562+                          'explicitly set sizes with ``output_sizes``: {}' 
563+                          .format (new_dims ))
564+ 
565+     for  n , (data , core_dims ) in  enumerate (
566+             zip (args , signature .input_core_dims )):
567+         if  isinstance (data , dask_array_type ):
568+             # core dimensions cannot span multiple chunks 
569+             for  axis , dim  in  enumerate (core_dims , start = - len (core_dims )):
570+                 if  len (data .chunks [axis ]) !=  1 :
571+                     raise  ValueError (
572+                         'dimension {!r} on {}th function argument to ' 
573+                         "apply_ufunc with dask='parallelized' consists of " 
574+                         'multiple chunks, but is also a core dimension. To ' 
575+                         'fix, rechunk into a single dask array chunk along ' 
576+                         'this dimension, i.e., ``.rechunk({})``, but beware ' 
577+                         'that this may significantly increase memory usage.' 
578+                         .format (dim , n , {dim : - 1 }))
579+ 
580+     (out_ind ,) =  output_dims 
581+     # skip leading dimensions that we did not insert with broadcast_compat_data 
582+     atop_args  =  [element 
583+                  for  (arg , dims ) in  zip (args , input_dims )
584+                  for  element  in  (arg , dims [- getattr (arg , 'ndim' , 0 ):])]
585+     return  da .atop (func , out_ind , * atop_args , dtype = dtype , concatenate = True ,
586+                    new_axes = output_sizes )
587+ 
588+ 
530589def  apply_array_ufunc (func , * args , ** kwargs ):
531-     """apply_variable_ufunc (func, *args, dask_array ='forbidden') 
590+     """apply_array_ufunc (func, *args, dask ='forbidden') 
532591    """ 
533-     dask_array  =  kwargs .pop ('dask_array ' , 'forbidden' )
592+     dask  =  kwargs .pop ('dask ' , 'forbidden' )
534593    if  kwargs :
535594        raise  TypeError ('apply_array_ufunc() got unexpected keyword ' 
536595                        'arguments: %s'  %  list (kwargs ))
537596
538597    if  any (isinstance (arg , dask_array_type ) for  arg  in  args ):
539-         # TODO: add a mode dask_array='auto' when dask.array gets a function 
540-         # for applying arbitrary gufuncs 
541-         if  dask_array  ==  'forbidden' :
542-             raise  ValueError ('encountered dask array, but did not set ' 
543-                              "dask_array='allowed'" )
544-         elif  dask_array  !=  'allowed' :
545-             raise  ValueError ('unknown setting for dask array handling: %r' 
546-                              %  dask_array )
547-         # fall through 
598+         if  dask  ==  'forbidden' :
599+             raise  ValueError ('apply_ufunc encountered a dask array on an ' 
600+                              'argument, but handling for dask arrays has not ' 
601+                              'been enabled. Either set the ``dask`` argument ' 
602+                              'or load your data into memory first with ' 
603+                              '``.load()`` or ``.compute()``' )
604+         elif  dask  ==  'parallelized' :
605+             raise  ValueError ("cannot use dask='parallelized' for apply_ufunc " 
606+                              'unless at least one input is an xarray object' )
607+         elif  dask  ==  'allowed' :
608+             pass 
609+         else :
610+             raise  ValueError ('unknown setting for dask array handling: {}' 
611+                              .format (dask ))
548612    return  func (* args )
549613
550614
@@ -559,7 +623,9 @@ def apply_ufunc(func, *args, **kwargs):
559623                   dataset_fill_value : Any = _DEFAULT_FILL_VALUE, 
560624                   keep_attrs : bool = False, 
561625                   kwargs : Mapping = None, 
562-                    dask_array : str = 'forbidden') 
626+                    dask : str = 'forbidden', 
627+                    output_dtypes : Optional[Sequence] = None, 
628+                    output_sizes : Optional[Mapping[Any, int]] = None) 
563629
564630    Apply a vectorized function for unlabeled arrays on xarray objects. 
565631
@@ -581,8 +647,8 @@ def apply_ufunc(func, *args, **kwargs):
581647        Mix of labeled and/or unlabeled arrays to which to apply the function. 
582648    input_core_dims : Sequence[Sequence], optional 
583649        List of the same length as ``args`` giving the list of core dimensions 
584-         on each input argument that should be broadcast. By default, we assume  
585-         there are no core dimensions on any input arguments. 
650+         on each input argument that should not  be broadcast. By default, we 
651+         assume  there are no core dimensions on any input arguments. 
586652
587653        For example ,``input_core_dims=[[], ['time']]`` indicates that all 
588654        dimensions on the first argument and all dimensions other than 'time' 
@@ -630,10 +696,20 @@ def apply_ufunc(func, *args, **kwargs):
630696        Whether to copy attributes from the first argument to the output. 
631697    kwargs: dict, optional 
632698        Optional keyword arguments passed directly on to call ``func``. 
633-     dask_array: 'forbidden' or 'allowed', optional 
634-         Whether or not to allow applying the ufunc to objects containing lazy 
635-         data in the form of dask arrays. By default, this is forbidden, to 
636-         avoid implicitly converting lazy data. 
699+     dask: 'forbidden', 'allowed' or 'parallelized', optional 
700+         How to handle applying to objects containing lazy data in the form of 
701+         dask arrays: 
702+         - 'forbidden' (default): raise an error if a dask array is encountered. 
703+         - 'allowed': pass dask arrays directly on to ``func``. 
704+         - 'parallelized': automatically parallelize ``func`` if any of the 
705+           inputs are a dask array. If used, the ``output_dtypes`` argument must 
706+           also be provided. Multiple output arguments are not yet supported. 
707+     output_dtypes : list of dtypes, optional 
708+         Optional list of output dtypes. Only used if dask='parallelized'. 
709+     output_sizes : dict, optional 
710+         Optional mapping from dimension names to sizes for outputs. Only used if 
711+         dask='parallelized' and new dimensions (not found on inputs) appear on 
712+         outputs. 
637713
638714    Returns 
639715    ------- 
@@ -710,7 +786,9 @@ def stack(objects, dim, new_coord):
710786    exclude_dims  =  kwargs .pop ('exclude_dims' , frozenset ())
711787    dataset_fill_value  =  kwargs .pop ('dataset_fill_value' , _DEFAULT_FILL_VALUE )
712788    kwargs_  =  kwargs .pop ('kwargs' , None )
713-     dask_array  =  kwargs .pop ('dask_array' , 'forbidden' )
789+     dask  =  kwargs .pop ('dask' , 'forbidden' )
790+     output_dtypes  =  kwargs .pop ('output_dtypes' , None )
791+     output_sizes  =  kwargs .pop ('output_sizes' , None )
714792    if  kwargs :
715793        raise  TypeError ('apply_ufunc() got unexpected keyword arguments: %s' 
716794                        %  list (kwargs ))
@@ -727,12 +805,12 @@ def stack(objects, dim, new_coord):
727805    if  kwargs_ :
728806        func  =  functools .partial (func , ** kwargs_ )
729807
730-     array_ufunc  =  functools .partial (
731-         apply_array_ufunc , func , dask_array = dask_array )
732- 
733-     variables_ufunc  =  functools .partial (apply_variable_ufunc , array_ufunc ,
808+     variables_ufunc  =  functools .partial (apply_variable_ufunc , func ,
734809                                        signature = signature ,
735-                                         exclude_dims = exclude_dims )
810+                                         exclude_dims = exclude_dims ,
811+                                         dask = dask ,
812+                                         output_dtypes = output_dtypes ,
813+                                         output_sizes = output_sizes )
736814
737815    if  any (isinstance (a , GroupBy ) for  a  in  args ):
738816        # kwargs has already been added into func 
@@ -744,7 +822,7 @@ def stack(objects, dim, new_coord):
744822                                       dataset_join = dataset_join ,
745823                                       dataset_fill_value = dataset_fill_value ,
746824                                       keep_attrs = keep_attrs ,
747-                                        dask_array = dask_array )
825+                                        dask = dask )
748826        return  apply_groupby_ufunc (this_apply , * args )
749827    elif  any (is_dict_like (a ) for  a  in  args ):
750828        return  apply_dataset_ufunc (variables_ufunc , * args ,
@@ -763,7 +841,7 @@ def stack(objects, dim, new_coord):
763841    elif  any (isinstance (a , Variable ) for  a  in  args ):
764842        return  variables_ufunc (* args )
765843    else :
766-         return  array_ufunc ( * args )
844+         return  apply_array_ufunc ( func ,  * args ,  dask = dask )
767845
768846
769847def  where (cond , x , y ):
@@ -805,4 +883,4 @@ def where(cond, x, y):
805883                       cond , x , y ,
806884                       join = 'exact' ,
807885                       dataset_join = 'exact' ,
808-                        dask_array = 'allowed' )
886+                        dask = 'allowed' )
0 commit comments