@@ -64,32 +64,13 @@ def all_core_dims(self):
         return self._all_core_dims
 
     @property
-    def n_inputs(self):
+    def num_inputs(self):
         return len(self.input_core_dims)
 
     @property
-    def n_outputs(self):
+    def num_outputs(self):
         return len(self.output_core_dims)
 
-    @classmethod
-    def default(cls, n_inputs):
-        return cls([()] * n_inputs, [()])
-
-    @classmethod
-    def from_sequence(cls, nested):
-        if (not isinstance(nested, collections.Sequence) or
-                not len(nested) == 2 or
-                any(not isinstance(arg_list, collections.Sequence)
-                    for arg_list in nested) or
-                any(isinstance(arg, basestring) or
-                    not isinstance(arg, collections.Sequence)
-                    for arg_list in nested for arg in arg_list)):
-            raise TypeError('functions signatures not provided as a string '
-                            'must be a triply nested sequence providing the '
-                            'list of core dimensions for each variable, for '
-                            'both input and output.')
-        return cls(*nested)
-
     def __eq__(self, other):
         try:
             return (self.input_core_dims == other.input_core_dims and
@@ -190,7 +171,7 @@ def apply_dataarray_ufunc(func, *args, **kwargs):
     data_vars = [getattr(a, 'variable', a) for a in args]
     result_var = func(*data_vars)
 
-    if signature.n_outputs > 1:
+    if signature.num_outputs > 1:
         out = tuple(DataArray(variable, coords, name=name, fastpath=True)
                     for variable, coords in zip(result_var, result_coords))
     else:
@@ -269,10 +250,10 @@ def _as_variables_or_variable(arg):
 
 def _unpack_dict_tuples(
     result_vars,  # type: Mapping[Any, Tuple[Variable]]
-    n_outputs,  # type: int
+    num_outputs,  # type: int
 ):
     # type: (...) -> Tuple[Dict[Any, Variable]]
-    out = tuple(OrderedDict() for _ in range(n_outputs))
+    out = tuple(OrderedDict() for _ in range(num_outputs))
     for name, values in result_vars.items():
         for value, results_dict in zip(values, out):
             results_dict[name] = value
@@ -298,8 +279,8 @@ def apply_dict_of_variables_ufunc(func, *args, **kwargs):
     for name, variable_args in zip(names, grouped_by_name):
         result_vars[name] = func(*variable_args)
 
-    if signature.n_outputs > 1:
-        return _unpack_dict_tuples(result_vars, signature.n_outputs)
+    if signature.num_outputs > 1:
+        return _unpack_dict_tuples(result_vars, signature.num_outputs)
     else:
         return result_vars
 
@@ -335,8 +316,8 @@ def apply_dataset_ufunc(func, *args, **kwargs):
 
     if (dataset_join not in _JOINS_WITHOUT_FILL_VALUES and
             fill_value is _DEFAULT_FILL_VALUE):
-        raise TypeError('To apply an operation to datasets with different ',
-                        'data variables, you must supply the ',
+        raise TypeError('to apply an operation to datasets with different '
+                        'data variables with apply_ufunc, you must supply the '
                         'dataset_fill_value argument.')
 
     if kwargs:
@@ -353,7 +334,7 @@ def apply_dataset_ufunc(func, *args, **kwargs):
         func, *args, signature=signature, join=dataset_join,
         fill_value=fill_value)
 
-    if signature.n_outputs > 1:
+    if signature.num_outputs > 1:
         out = tuple(_fast_dataset(*args)
                     for args in zip(result_vars, list_of_coords))
     else:
@@ -388,12 +369,12 @@ def apply_groupby_ufunc(func, *args):
     from .variable import Variable
 
     groupbys = [arg for arg in args if isinstance(arg, GroupBy)]
-    if not groupbys:
-        raise ValueError('must have at least one groupby to iterate over')
+    assert groupbys, 'must have at least one groupby to iterate over'
     first_groupby = groupbys[0]
     if any(not first_groupby._group.equals(gb._group) for gb in groupbys[1:]):
-        raise ValueError('can only perform operations over multiple groupbys '
-                         'at once if they are all grouped the same way')
+        raise ValueError('apply_ufunc can only perform operations over '
+                         'multiple GroupBy objects at once if they are all '
+                         'grouped the same way')
 
     grouped_dim = first_groupby._group.name
     unique_values = first_groupby._unique_coord.values
@@ -430,7 +411,7 @@ def unified_dim_sizes(variables, exclude_dims=frozenset()):
     for var in variables:
         if len(set(var.dims)) < len(var.dims):
             raise ValueError('broadcasting cannot handle duplicate '
-                             'dimensions: %r' % list(var.dims))
+                             'dimensions on a variable: %r' % list(var.dims))
         for dim, size in zip(var.dims, var.shape):
             if dim not in exclude_dims:
                 if dim not in dim_sizes:
@@ -462,15 +443,17 @@ def broadcast_compat_data(variable, broadcast_dims, core_dims):
     set_old_dims = set(old_dims)
     missing_core_dims = [d for d in core_dims if d not in set_old_dims]
     if missing_core_dims:
-        raise ValueError('operation requires dimensions missing on input '
-                         'variable: %r' % missing_core_dims)
+        raise ValueError('operand to apply_ufunc has required core dimensions '
+                         '%r, but some of these are missing on the input '
+                         'variable: %r' % (list(core_dims), missing_core_dims))
 
     set_new_dims = set(new_dims)
     unexpected_dims = [d for d in old_dims if d not in set_new_dims]
     if unexpected_dims:
-        raise ValueError('operation encountered unexpected dimensions %r '
-                         'on input variable: these are core dimensions on '
-                         'other input or output variables' % unexpected_dims)
+        raise ValueError('operand to apply_ufunc encountered unexpected '
+                         'dimensions %r on an input variable: these are core '
+                         'dimensions on other input or output variables'
+                         % unexpected_dims)
 
     # for consistency with numpy, keep broadcast dimensions to the left
     old_broadcast_dims = tuple(d for d in broadcast_dims if d in set_old_dims)
@@ -500,6 +483,9 @@ def apply_variable_ufunc(func, *args, **kwargs):
 
     signature = kwargs.pop('signature')
     exclude_dims = kwargs.pop('exclude_dims', _DEFAULT_FROZEN_SET)
+    dask = kwargs.pop('dask', 'forbidden')
+    output_dtypes = kwargs.pop('output_dtypes', None)
+    output_sizes = kwargs.pop('output_sizes', None)
     if kwargs:
         raise TypeError('apply_variable_ufunc() got unexpected keyword '
                         'arguments: %s' % list(kwargs))
@@ -515,9 +501,28 @@ def apply_variable_ufunc(func, *args, **kwargs):
                   else arg
                   for arg, core_dims in zip(args, signature.input_core_dims)]
 
+    if any(isinstance(array, dask_array_type) for array in input_data):
+        if dask == 'forbidden':
+            raise ValueError('apply_ufunc encountered a dask array on an '
+                             'argument, but handling for dask arrays has not '
+                             'been enabled. Either set the ``dask`` argument '
+                             'or load your data into memory first with '
+                             '``.load()`` or ``.compute()``')
+        elif dask == 'parallelized':
+            input_dims = [broadcast_dims + input_dims
+                          for input_dims in signature.input_core_dims]
+            numpy_func = func
+            func = lambda *arrays: _apply_with_dask_atop(
+                numpy_func, arrays, input_dims, output_dims, signature,
+                output_dtypes, output_sizes)
+        elif dask == 'allowed':
+            pass
+        else:
+            raise ValueError('unknown setting for dask array handling in '
+                             'apply_ufunc: {}'.format(dask))
     result_data = func(*input_data)
 
-    if signature.n_outputs > 1:
+    if signature.num_outputs > 1:
         output = []
         for dims, data in zip(output_dims, result_data):
             output.append(Variable(dims, data))
@@ -527,24 +532,83 @@ def apply_variable_ufunc(func, *args, **kwargs):
         return Variable(dims, result_data)
 
 
+def _apply_with_dask_atop(func, args, input_dims, output_dims, signature,
+                          output_dtypes, output_sizes=None):
+    import dask.array as da
+
+    if signature.num_outputs > 1:
+        raise NotImplementedError('multiple outputs from apply_ufunc not yet '
+                                  "supported with dask='parallelized'")
+
+    if output_dtypes is None:
+        raise ValueError('output dtypes (output_dtypes) must be supplied to '
+                         "apply_ufunc when using dask='parallelized'")
+    if not isinstance(output_dtypes, list):
+        raise TypeError('output_dtypes must be a list of objects coercible to '
+                        'numpy dtypes, got {}'.format(output_dtypes))
+    if len(output_dtypes) != signature.num_outputs:
+        raise ValueError('apply_ufunc arguments output_dtypes and '
+                         'output_core_dims must have the same length: {} vs {}'
+                         .format(len(output_dtypes), signature.num_outputs))
+    (dtype,) = output_dtypes
+
+    if output_sizes is None:
+        output_sizes = {}
+
+    new_dims = signature.all_output_core_dims - signature.all_input_core_dims
+    if any(dim not in output_sizes for dim in new_dims):
+        raise ValueError("when using dask='parallelized' with apply_ufunc, "
+                         'output core dimensions not found on inputs must have '
+                         'explicitly set sizes with ``output_sizes``: {}'
+                         .format(new_dims))
+
+    for n, (data, core_dims) in enumerate(
+            zip(args, signature.input_core_dims)):
+        if isinstance(data, dask_array_type):
+            # core dimensions cannot span multiple chunks
+            for axis, dim in enumerate(core_dims, start=-len(core_dims)):
+                if len(data.chunks[axis]) != 1:
+                    raise ValueError(
+                        'dimension {!r} on {}th function argument to '
+                        "apply_ufunc with dask='parallelized' consists of "
+                        'multiple chunks, but is also a core dimension. To '
+                        'fix, rechunk into a single dask array chunk along '
+                        'this dimension, i.e., ``.rechunk({})``, but beware '
+                        'that this may significantly increase memory usage.'
+                        .format(dim, n, {dim: -1}))
+
+    (out_ind,) = output_dims
+    # skip leading dimensions that we did not insert with broadcast_compat_data
+    atop_args = [element
+                 for (arg, dims) in zip(args, input_dims)
+                 for element in (arg, dims[-getattr(arg, 'ndim', 0):])]
+    return da.atop(func, out_ind, *atop_args, dtype=dtype, concatenate=True,
+                   new_axes=output_sizes)
+
+
 def apply_array_ufunc(func, *args, **kwargs):
-    """apply_variable_ufunc(func, *args, dask_array='forbidden')
+    """apply_array_ufunc(func, *args, dask='forbidden')
     """
-    dask_array = kwargs.pop('dask_array', 'forbidden')
+    dask = kwargs.pop('dask', 'forbidden')
     if kwargs:
         raise TypeError('apply_array_ufunc() got unexpected keyword '
                         'arguments: %s' % list(kwargs))
 
     if any(isinstance(arg, dask_array_type) for arg in args):
-        # TODO: add a mode dask_array='auto' when dask.array gets a function
-        # for applying arbitrary gufuncs
-        if dask_array == 'forbidden':
-            raise ValueError('encountered dask array, but did not set '
-                             "dask_array='allowed'")
-        elif dask_array != 'allowed':
-            raise ValueError('unknown setting for dask array handling: %r'
-                             % dask_array)
-        # fall through
+        if dask == 'forbidden':
+            raise ValueError('apply_ufunc encountered a dask array on an '
+                             'argument, but handling for dask arrays has not '
+                             'been enabled. Either set the ``dask`` argument '
+                             'or load your data into memory first with '
+                             '``.load()`` or ``.compute()``')
+        elif dask == 'parallelized':
+            raise ValueError("cannot use dask='parallelized' for apply_ufunc "
+                             'unless at least one input is an xarray object')
+        elif dask == 'allowed':
+            pass
+        else:
+            raise ValueError('unknown setting for dask array handling: {}'
+                             .format(dask))
     return func(*args)
 
 
@@ -559,7 +623,9 @@ def apply_ufunc(func, *args, **kwargs):
559
623
dataset_fill_value : Any = _DEFAULT_FILL_VALUE,
560
624
keep_attrs : bool = False,
561
625
kwargs : Mapping = None,
562
- dask_array : str = 'forbidden')
626
+ dask : str = 'forbidden',
627
+ output_dtypes : Optional[Sequence] = None,
628
+ output_sizes : Optional[Mapping[Any, int]] = None)
563
629
564
630
Apply a vectorized function for unlabeled arrays on xarray objects.
565
631
@@ -581,8 +647,8 @@ def apply_ufunc(func, *args, **kwargs):
         Mix of labeled and/or unlabeled arrays to which to apply the function.
     input_core_dims : Sequence[Sequence], optional
         List of the same length as ``args`` giving the list of core dimensions
-        on each input argument that should be broadcast. By default, we assume
-        there are no core dimensions on any input arguments.
+        on each input argument that should not be broadcast. By default, we
+        assume there are no core dimensions on any input arguments.
 
         For example, ``input_core_dims=[[], ['time']]`` indicates that all
         dimensions on the first argument and all dimensions other than 'time'
630
696
Whether to copy attributes from the first argument to the output.
631
697
kwargs: dict, optional
632
698
Optional keyword arguments passed directly on to call ``func``.
633
- dask_array: 'forbidden' or 'allowed', optional
634
- Whether or not to allow applying the ufunc to objects containing lazy
635
- data in the form of dask arrays. By default, this is forbidden, to
636
- avoid implicitly converting lazy data.
699
+ dask: 'forbidden', 'allowed' or 'parallelized', optional
700
+ How to handle applying to objects containing lazy data in the form of
701
+ dask arrays:
702
+ - 'forbidden' (default): raise an error if a dask array is encountered.
703
+ - 'allowed': pass dask arrays directly on to ``func``.
704
+ - 'parallelized': automatically parallelize ``func`` if any of the
705
+ inputs are a dask array. If used, the ``output_dtypes`` argument must
706
+ also be provided. Multiple output arguments are not yet supported.
707
+ output_dtypes : list of dtypes, optional
708
+ Optional list of output dtypes. Only used if dask='parallelized'.
709
+ output_sizes : dict, optional
710
+ Optional mapping from dimension names to sizes for outputs. Only used if
711
+ dask='parallelized' and new dimensions (not found on inputs) appear on
712
+ outputs.
637
713
638
714
Returns
639
715
-------
@@ -710,7 +786,9 @@ def stack(objects, dim, new_coord):
     exclude_dims = kwargs.pop('exclude_dims', frozenset())
     dataset_fill_value = kwargs.pop('dataset_fill_value', _DEFAULT_FILL_VALUE)
     kwargs_ = kwargs.pop('kwargs', None)
-    dask_array = kwargs.pop('dask_array', 'forbidden')
+    dask = kwargs.pop('dask', 'forbidden')
+    output_dtypes = kwargs.pop('output_dtypes', None)
+    output_sizes = kwargs.pop('output_sizes', None)
     if kwargs:
         raise TypeError('apply_ufunc() got unexpected keyword arguments: %s'
                         % list(kwargs))
@@ -727,12 +805,12 @@ def stack(objects, dim, new_coord):
     if kwargs_:
         func = functools.partial(func, **kwargs_)
 
-    array_ufunc = functools.partial(
-        apply_array_ufunc, func, dask_array=dask_array)
-
-    variables_ufunc = functools.partial(apply_variable_ufunc, array_ufunc,
+    variables_ufunc = functools.partial(apply_variable_ufunc, func,
                                         signature=signature,
-                                        exclude_dims=exclude_dims)
+                                        exclude_dims=exclude_dims,
+                                        dask=dask,
+                                        output_dtypes=output_dtypes,
+                                        output_sizes=output_sizes)
 
     if any(isinstance(a, GroupBy) for a in args):
         # kwargs has already been added into func
@@ -744,7 +822,7 @@ def stack(objects, dim, new_coord):
             dataset_join=dataset_join,
             dataset_fill_value=dataset_fill_value,
             keep_attrs=keep_attrs,
-            dask_array=dask_array)
+            dask=dask)
         return apply_groupby_ufunc(this_apply, *args)
     elif any(is_dict_like(a) for a in args):
         return apply_dataset_ufunc(variables_ufunc, *args,
@@ -763,7 +841,7 @@ def stack(objects, dim, new_coord):
     elif any(isinstance(a, Variable) for a in args):
         return variables_ufunc(*args)
     else:
-        return array_ufunc(*args)
+        return apply_array_ufunc(func, *args, dask=dask)
 
 
 def where(cond, x, y):
@@ -805,4 +883,4 @@ def where(cond, x, y):
                        cond, x, y,
                        join='exact',
                        dataset_join='exact',
-                       dask_array='allowed')
+                       dask='allowed')
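
Usage sketch (not part of the diff): a minimal example of the new dask='parallelized' mode documented in the apply_ufunc docstring above. It assumes a dask-backed DataArray and imports apply_ufunc directly from the module this commit edits; the detrend helper and all variable names are illustrative only.

import numpy as np
import xarray as xr
from xarray.core.computation import apply_ufunc

# dask-backed DataArray, chunked along 'x' but kept in a single chunk along
# the core dimension 'time' (core dimensions may not span multiple chunks)
arr = xr.DataArray(np.random.rand(4, 100),
                   dims=['x', 'time']).chunk({'x': 2})

def detrend(data):
    # receives NumPy blocks; core dimensions are moved to the last axes
    return data - data.mean(axis=-1, keepdims=True)

result = apply_ufunc(detrend, arr,
                     input_core_dims=[['time']],
                     output_core_dims=[['time']],
                     dask='parallelized',
                     output_dtypes=[arr.dtype])
print(result.compute().dims)  # ('x', 'time')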
0 commit comments
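For reference, the new _apply_with_dask_atop helper ultimately delegates to dask.array.atop. The standalone sketch below shows roughly what that dispatch does, assuming a dask version that still provides da.atop (later renamed da.blockwise); names are illustrative, and the real helper additionally validates dtypes, chunking and output sizes as shown in the diff.

import dask.array as da

darr = da.random.random((4, 100), chunks=(2, 100))

def detrend(block):
    # operates on NumPy blocks; the core dimension is the last axis
    return block - block.mean(axis=-1, keepdims=True)

# arguments alternate (array, index tuple) and dimension names act as indices;
# concatenate=True hands each call a contiguous slab along the core dimension
lazy = da.atop(detrend, ('x', 'time'), darr, ('x', 'time'),
               dtype=darr.dtype, concatenate=True)
print(lazy.compute().shape)  # (4, 100)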