Skip to content

Commit f9464fd

Browse files
authored
ENH: three argument version of where (#1496)
* ENH: three argument version of where (fixes GH576)
* Docstring fixup
* Use join=exact for three argument where
* Add where function
* Don't require pytest 3 for raises_regex
1 parent 4ebc8ab commit f9464fd

17 files changed

+325
-139
lines changed

doc/api.rst

+1
Original file line numberDiff line numberDiff line change
@@ -18,6 +18,7 @@ Top-level functions
1818
broadcast
1919
concat
2020
merge
21+
where
2122
set_options
2223
full_like
2324
zeros_like

doc/computation.rst

+7-1
Original file line numberDiff line numberDiff line change
@@ -25,7 +25,7 @@ numpy) over all array values:
2525
2626
.. ipython:: python
2727
28-
arr = xr.DataArray(np.random.randn(2, 3),
28+
arr = xr.DataArray(np.random.RandomState(0).randn(2, 3),
2929
[('x', ['a', 'b']), ('y', [10, 20, 30])])
3030
arr - 3
3131
abs(arr)
@@ -39,6 +39,12 @@ __ http://docs.scipy.org/doc/numpy/reference/ufuncs.html
3939
4040
np.sin(arr)
4141
42+
Use :py:func:`~xarray.where` to conditionally switch between values:
43+
44+
.. ipython:: python
45+
46+
xr.where(arr > 0, 'positive', 'negative')
47+
4248
Data arrays also implement many :py:class:`numpy.ndarray` methods:
4349

4450
.. ipython:: python

doc/whats-new.rst

+24-1
Original file line numberDiff line numberDiff line change
@@ -32,6 +32,29 @@ Enhancements
3232
- Speed-up (x 100) of :py:func:`~xarray.conventions.decode_cf_datetime`.
3333
By `Christian Chwala <https://github.com/cchwala>`_.
3434

35+
- New function :py:func:`~xarray.where` for conditionally switching between
36+
values in xarray objects, like :py:func:`numpy.where`:
37+
38+
.. ipython::
39+
:verbatim:
40+
41+
In [1]: import xarray as xr
42+
43+
In [2]: arr = xr.DataArray([[1, 2, 3], [4, 5, 6]], dims=('x', 'y'))
44+
45+
In [3]: xr.where(arr % 2, 'odd', 'even')
46+
Out[3]:
47+
<xarray.DataArray (x: 2, y: 3)>
48+
array([['odd', 'even', 'odd'],
49+
['even', 'odd', 'even']],
50+
dtype='<U4')
51+
Dimensions without coordinates: x, y
52+
53+
Equivalently, the :py:meth:`~xarray.Dataset.where` method also now supports
54+
the ``other`` argument, for filling with a value other than ``NaN``
55+
(:issue:`576`).
56+
By `Stephan Hoyer <https://github.com/shoyer>`_.
57+
3558
Bug fixes
3659
~~~~~~~~~
3760

@@ -49,7 +72,7 @@ Bug fixes
4972

5073
- Fix :py:func:`xarray.testing.assert_allclose` to actually use ``atol`` and
5174
``rtol`` arguments when called on ``DataArray`` objects.
52-
By `Stephan Hoyer <http://github.com/shoyer>`_.
75+
By `Stephan Hoyer <https://github.com/shoyer>`_.
5376

5477
.. _whats-new.0.9.6:
5578

xarray/__init__.py

+1
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,7 @@
66
from .core.alignment import align, broadcast, broadcast_arrays
77
from .core.common import full_like, zeros_like, ones_like
88
from .core.combine import concat, auto_combine
9+
from .core.computation import where
910
from .core.extensions import (register_dataarray_accessor,
1011
register_dataset_accessor)
1112
from .core.variable import as_variable, Variable, IndexVariable, Coordinate

xarray/core/accessors.py

+2-4
Original file line numberDiff line numberDiff line change
@@ -2,11 +2,9 @@
22
from __future__ import division
33
from __future__ import print_function
44

5-
from .common import is_datetime_like
5+
from .dtypes import is_datetime_like
66
from .pycompat import dask_array_type
77

8-
from functools import partial
9-
108
import numpy as np
119
import pandas as pd
1210

@@ -147,4 +145,4 @@ def f(self, dtype=dtype):
147145

148146
time = _tslib_field_accessor(
149147
"time", "Timestamps corresponding to datetimes", object
150-
)
148+
)

xarray/core/alignment.py

+4-3
Original file line numberDiff line numberDiff line change
@@ -7,8 +7,9 @@
77

88
import numpy as np
99

10-
from . import duck_array_ops, utils
11-
from .common import _maybe_promote
10+
from . import duck_array_ops
11+
from . import dtypes
12+
from . import utils
1213
from .indexing import get_indexer
1314
from .pycompat import iteritems, OrderedDict, suppress
1415
from .utils import is_full_slice, is_dict_like
@@ -368,7 +369,7 @@ def var_indexers(var, indexers):
368369
if any_not_full_slices(assign_to):
369370
# there are missing values to in-fill
370371
data = var[assign_from].data
371-
dtype, fill_value = _maybe_promote(var.dtype)
372+
dtype, fill_value = dtypes.maybe_promote(var.dtype)
372373

373374
if isinstance(data, np.ndarray):
374375
shape = tuple(new_sizes.get(dim, size)

xarray/core/common.py

+49-85
Original file line numberDiff line numberDiff line change
@@ -5,7 +5,9 @@
55
import pandas as pd
66

77
from .pycompat import basestring, suppress, dask_array_type, OrderedDict
8+
from . import dtypes
89
from . import formatting
10+
from . import ops
911
from .utils import SortedKeysDict, not_implemented, Frozen
1012

1113

@@ -557,64 +559,74 @@ def resample(self, freq, dim, how='mean', skipna=None, closed=None,
557559
result = result.rename({RESAMPLE_DIM: dim.name})
558560
return result
559561

560-
def where(self, cond, other=None, drop=False):
561-
"""Return an object of the same shape with all entries where cond is
562-
True and all other entries masked.
562+
def where(self, cond, other=dtypes.NA, drop=False):
563+
"""Filter elements from this object according to a condition.
563564
564565
This operation follows the normal broadcasting and alignment rules that
565566
xarray uses for binary arithmetic.
566567
567568
Parameters
568569
----------
569-
cond : boolean DataArray or Dataset
570-
other : unimplemented, optional
571-
Unimplemented placeholder for compatibility with future
572-
numpy / pandas versions
570+
cond : DataArray or Dataset with boolean dtype
571+
Locations at which to preserve this object's values.
572+
other : scalar, DataArray or Dataset, optional
573+
Value to use for locations in this object where ``cond`` is False.
574+
By default, these locations are filled with NA.
573575
drop : boolean, optional
574-
Coordinate labels that only correspond to NA values should be
575-
dropped
576+
If True, coordinate labels that only correspond to False values of
577+
the condition are dropped from the result. Mutually exclusive with
578+
``other``.
576579
577580
Returns
578581
-------
579-
same type as caller or if drop=True same type as caller with dimensions
580-
reduced for dim element where mask is True
582+
Same type as caller.
581583
582584
Examples
583585
--------
584586
585587
>>> import numpy as np
586588
>>> a = xr.DataArray(np.arange(25).reshape(5, 5), dims=('x', 'y'))
587-
>>> a.where((a > 6) & (a < 18))
589+
>>> a.where(a.x + a.y < 4)
588590
<xarray.DataArray (x: 5, y: 5)>
589-
array([[ nan, nan, nan, nan, nan],
590-
[ nan, nan, 7., 8., 9.],
591-
[ 10., 11., 12., 13., 14.],
592-
[ 15., 16., 17., nan, nan],
591+
array([[ 0., 1., 2., 3., nan],
592+
[ 5., 6., 7., nan, nan],
593+
[ 10., 11., nan, nan, nan],
594+
[ 15., nan, nan, nan, nan],
593595
[ nan, nan, nan, nan, nan]])
594-
Coordinates:
595-
* y (y) int64 0 1 2 3 4
596-
* x (x) int64 0 1 2 3 4
597-
>>> a.where((a > 6) & (a < 18), drop=True)
596+
Dimensions without coordinates: x, y
597+
>>> a.where(a.x + a.y < 5, -1)
598598
<xarray.DataArray (x: 5, y: 5)>
599-
array([[ nan, nan, 7., 8., 9.],
600-
[ 10., 11., 12., 13., 14.],
601-
[ 15., 16., 17., nan, nan],
602-
Coordinates:
603-
* x (x) int64 1 2 3
604-
* y (y) int64 0 1 2 3 4
599+
array([[ 0, 1, 2, 3, 4],
600+
[ 5, 6, 7, 8, -1],
601+
[10, 11, 12, -1, -1],
602+
[15, 16, -1, -1, -1],
603+
[20, -1, -1, -1, -1]])
604+
Dimensions without coordinates: x, y
605+
>>> a.where(a.x + a.y < 4, drop=True)
606+
<xarray.DataArray (x: 4, y: 4)>
607+
array([[ 0., 1., 2., 3.],
608+
[ 5., 6., 7., nan],
609+
[ 10., 11., nan, nan],
610+
[ 15., nan, nan, nan]])
611+
Dimensions without coordinates: x, y
612+
613+
See also
614+
--------
615+
numpy.where : corresponding numpy function
616+
where : equivalent function
605617
"""
606-
if other is not None:
607-
raise NotImplementedError("The optional argument 'other' has not "
608-
"yet been implemented")
618+
from .alignment import align
619+
from .dataarray import DataArray
620+
from .dataset import Dataset
609621

610622
if drop:
611-
from .dataarray import DataArray
612-
from .dataset import Dataset
613-
from .alignment import align
623+
if other is not dtypes.NA:
624+
raise ValueError('cannot set `other` if drop=True')
614625

615626
if not isinstance(cond, (Dataset, DataArray)):
616-
raise TypeError("Cond argument is %r but must be a %r or %r" %
627+
raise TypeError("cond argument is %r but must be a %r or %r" %
617628
(cond, Dataset, DataArray))
629+
618630
# align so we can use integer indexing
619631
self, cond = align(self, cond)
620632

@@ -627,16 +639,11 @@ def where(self, cond, other=None, drop=False):
627639
# clip the data corresponding to coordinate dims that are not used
628640
nonzeros = zip(clipcond.dims, np.nonzero(clipcond.values))
629641
indexers = {k: np.unique(v) for k, v in nonzeros}
630-
outobj = self.isel(**indexers)
631-
outcond = cond.isel(**indexers)
632-
else:
633-
outobj = self
634-
outcond = cond
635642

636-
# preserve attributes
637-
out = outobj._where(outcond)
638-
out._copy_attrs_from(self)
639-
return out
643+
self = self.isel(**indexers)
644+
cond = cond.isel(**indexers)
645+
646+
return ops.where_method(self, cond, other)
640647

641648
def close(self):
642649
"""Close any files linked to this object
@@ -658,42 +665,6 @@ def __exit__(self, exc_type, exc_value, traceback):
658665
__or__ = __div__ = __eq__ = __ne__ = not_implemented
659666

660667

661-
def _maybe_promote(dtype):
662-
"""Simpler equivalent of pandas.core.common._maybe_promote"""
663-
# N.B. these casting rules should match pandas
664-
if np.issubdtype(dtype, float):
665-
fill_value = np.nan
666-
elif np.issubdtype(dtype, int):
667-
# convert to floating point so NaN is valid
668-
dtype = float
669-
fill_value = np.nan
670-
elif np.issubdtype(dtype, complex):
671-
fill_value = np.nan + np.nan * 1j
672-
elif np.issubdtype(dtype, np.datetime64):
673-
fill_value = np.datetime64('NaT')
674-
elif np.issubdtype(dtype, np.timedelta64):
675-
fill_value = np.timedelta64('NaT')
676-
else:
677-
dtype = object
678-
fill_value = np.nan
679-
return np.dtype(dtype), fill_value
680-
681-
682-
def _possibly_convert_objects(values):
683-
"""Convert arrays of datetime.datetime and datetime.timedelta objects into
684-
datetime64 and timedelta64, according to the pandas convention.
685-
"""
686-
return np.asarray(pd.Series(values.ravel())).reshape(values.shape)
687-
688-
689-
def _get_fill_value(dtype):
690-
"""Return a fill value that appropriately promotes types when used with
691-
np.concatenate
692-
"""
693-
_, fill_value = _maybe_promote(dtype)
694-
return fill_value
695-
696-
697668
def full_like(other, fill_value, dtype=None):
698669
"""Return a new object with the same shape and type as a given object.
699670
@@ -761,10 +732,3 @@ def ones_like(other, dtype=None):
761732
"""Shorthand for full_like(other, 1, dtype)
762733
"""
763734
return full_like(other, 1, dtype)
764-
765-
766-
def is_datetime_like(dtype):
767-
"""Check if a dtype is a subclass of the numpy datetime types
768-
"""
769-
return (np.issubdtype(dtype, np.datetime64) or
770-
np.issubdtype(dtype, np.timedelta64))

0 commit comments

Comments
 (0)