Skip to content

ENH: Sorting of ExtensionArrays #19957

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 20 commits into from
Mar 22, 2018
Merged
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
21 changes: 21 additions & 0 deletions pandas/core/arrays/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -216,6 +216,27 @@ def isna(self):
"""
raise AbstractMethodError(self)

def argsort(self, axis=-1, kind='quicksort', order=None):
"""Returns the indices that would sort this array.

Parameters
----------
axis : int or None, optional
Axis along which to sort. ExtensionArrays are 1-dimensional,
so this is only included for compatibility with NumPy.
kind : {'quicksort', 'mergesort', 'heapsort'}, optional
Sorting algorithm.
order : str or list of str, optional
Included for NumPy compatibility.
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Is this compatibility needed because in the code we use np.argsort(values) which passes those keywords to the method?
(it is a bit unfortunate ..)

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

It's not necessary, and I can remove it.

I see now that Categorical.argsort has a different signature. I suppose we should match that.


Returns
-------
index_array : ndarray
Array of indices that sort ``self``.

"""
return np.array(self).argsort(kind=kind)

# ------------------------------------------------------------------------
# Indexing methods
# ------------------------------------------------------------------------
Expand Down
40 changes: 40 additions & 0 deletions pandas/tests/extension/base/methods.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,3 +31,43 @@ def test_count(self, data_missing):
def test_apply_simple_series(self, data):
    """``Series.apply`` with a plain callable hands back a ``Series``.

    The builtin ``id`` is used as the callable; only the type of the
    result is checked, not its contents.
    """
    applied = pd.Series(data).apply(id)
    assert isinstance(applied, pd.Series)

def test_argsort(self, data_for_sorting):
    """``Series.argsort`` on the ``data_for_sorting`` fixture.

    The positions that sort the fixture are expected to be ``[2, 0, 1]``.
    """
    ser = pd.Series(data_for_sorting)
    result = ser.argsort()
    expected_positions = np.array([2, 0, 1])
    self.assert_series_equal(result, pd.Series(expected_positions))

def test_argsort_missing(self, data_missing_for_sorting):
    """``Series.argsort`` on the ``data_missing_for_sorting`` fixture.

    The expected result is ``[1, -1, 0]``; the ``-1`` sits at the
    position of the fixture's missing entry.
    """
    ser = pd.Series(data_missing_for_sorting)
    result = ser.argsort()
    expected_positions = np.array([1, -1, 0])
    self.assert_series_equal(result, pd.Series(expected_positions))

@pytest.mark.parametrize('ascending', [True, False])
def test_sort_values(self, data_for_sorting, ascending):
    """``Series.sort_values`` in both directions on ``data_for_sorting``.

    Ascending order corresponds to positions ``[2, 0, 1]`` of the
    fixture; descending order is the same selection reversed.
    """
    ser = pd.Series(data_for_sorting)
    result = ser.sort_values(ascending=ascending)

    # [1, 0, 2] is [2, 0, 1] read back to front.
    positions = [2, 0, 1] if ascending else [1, 0, 2]
    expected = ser.iloc[positions]

    self.assert_series_equal(result, expected)

@pytest.mark.parametrize('ascending', [True, False])
def test_sort_values_missing(self, data_missing_for_sorting, ascending):
    """``Series.sort_values`` on a fixture containing a missing value.

    In both directions the entry at position 1 (the missing one in the
    fixture) is expected last; only the two remaining entries swap.
    """
    ser = pd.Series(data_missing_for_sorting)
    result = ser.sort_values(ascending=ascending)

    positions = [2, 0, 1] if ascending else [0, 2, 1]
    expected = ser.iloc[positions]
    self.assert_series_equal(result, expected)

@pytest.mark.parametrize('ascending', [True, False])
def test_sort_values_frame(self, data_for_sorting, ascending):
    """``DataFrame.sort_values`` on an integer column plus the fixture column.

    Column ``A`` is ``[1, 2, 1]`` and column ``B`` is the
    ``data_for_sorting`` fixture, documented as ``[B, C, A]`` with
    ``A < B < C``.  The rows are therefore::

        0 -> (1, B)    1 -> (2, C)    2 -> (1, A)

    Parameters
    ----------
    data_for_sorting : fixture
        Length-3 array with the sort order described above.
    ascending : bool
        Sort direction, applied to both sort keys.
    """
    df = pd.DataFrame({"A": [1, 2, 1],
                       "B": data_for_sorting})
    # Forward the parametrized direction; previously it was ignored and
    # both parametrizations exercised the ascending sort only.
    result = df.sort_values(['A', 'B'], ascending=ascending)

    if ascending:
        # (1, A), (1, B), (2, C)
        order, sorted_a = [2, 0, 1], [1, 1, 2]
    else:
        # (2, C), (1, B), (1, A)
        order, sorted_a = [1, 0, 2], [2, 1, 1]

    expected = pd.DataFrame({"A": sorted_a,
                             'B': data_for_sorting.take(order)},
                            index=order)
    self.assert_frame_equal(result, expected)
12 changes: 12 additions & 0 deletions pandas/tests/extension/category/test_categorical.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,18 @@ def data_missing():
return Categorical([np.nan, 'A'])


@pytest.fixture
def data_for_sorting():
    """Ordered categorical ``[A, B, C]`` whose category order is C < A < B."""
    category_order = ['C', 'A', 'B']
    return Categorical(['A', 'B', 'C'],
                       categories=category_order,
                       ordered=True)


@pytest.fixture
def data_missing_for_sorting():
    """Ordered categorical ``[A, None, B]`` whose category order is B < A."""
    category_order = ['B', 'A']
    return Categorical(['A', None, 'B'],
                       categories=category_order,
                       ordered=True)


@pytest.fixture
def na_value():
    """Return ``np.nan`` as the missing-value scalar for these tests."""
    return np.nan
Expand Down
20 changes: 20 additions & 0 deletions pandas/tests/extension/conftest.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,26 @@ def all_data(request, data, data_missing):
return data_missing


@pytest.fixture
def data_for_sorting():
    """Placeholder fixture: a length-3 array with a known sort order.

    An overriding fixture should supply three items laid out as
    ``[B, C, A]`` where ``A < B < C``.  This default only raises.
    """
    raise NotImplementedError


@pytest.fixture
def data_missing_for_sorting():
    """Placeholder fixture: a length-3 array with one missing value.

    An overriding fixture should supply three items laid out as
    ``[B, NA, A]`` where ``A < B`` and ``NA`` is missing.  This default
    only raises.
    """
    raise NotImplementedError


@pytest.fixture
def na_cmp():
"""Binary operator for comparing NA values.
Expand Down
43 changes: 35 additions & 8 deletions pandas/tests/extension/decimal/test_decimal.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,20 @@ def data_missing():
return DecimalArray([decimal.Decimal('NaN'), decimal.Decimal(1)])


@pytest.fixture
def data_for_sorting():
    """``DecimalArray`` holding the decimals 1, 2 and 0, in that order."""
    values = [decimal.Decimal(text) for text in ('1', '2', '0')]
    return DecimalArray(values)


@pytest.fixture
def data_missing_for_sorting():
    """``DecimalArray`` holding the decimals 1, NaN and 0, in that order."""
    values = [decimal.Decimal(text) for text in ('1', 'NaN', '0')]
    return DecimalArray(values)


@pytest.fixture
def na_cmp():
    """Return a two-argument callable testing ``is_nan()`` on both operands."""
    def both_nan(x, y):
        return x.is_nan() and y.is_nan()

    return both_nan
Expand All @@ -35,19 +49,32 @@ def na_value():
return decimal.Decimal("NaN")


class TestDtype(base.BaseDtypeTests):
class BaseDecimal(object):
    """Provides an ``assert_series_equal`` that treats missing values apart.

    NOTE(review): presumably needed because ``Decimal('NaN')`` entries do
    not compare equal to one another -- confirm.
    """

    @staticmethod
    def assert_series_equal(left, right, *args, **kwargs):
        """Compare two Series in two passes.

        First the ``isna()`` masks of ``left`` and ``right`` must match;
        then the entries not flagged as missing are compared, forwarding
        ``*args`` and ``**kwargs`` to ``tm.assert_series_equal``.
        """
        missing_left, missing_right = left.isna(), right.isna()

        # Pass 1: missing values must sit at the same positions.
        tm.assert_series_equal(missing_left, missing_right)

        # Pass 2: everything that is present must match.
        present_left = left[~missing_left]
        present_right = right[~missing_right]
        return tm.assert_series_equal(present_left, present_right,
                                      *args, **kwargs)


class TestDtype(BaseDecimal, base.BaseDtypeTests):
pass


class TestInterface(base.BaseInterfaceTests):
class TestInterface(BaseDecimal, base.BaseInterfaceTests):
pass


class TestConstructors(base.BaseConstructorsTests):
class TestConstructors(BaseDecimal, base.BaseConstructorsTests):
pass


class TestReshaping(base.BaseReshapingTests):
class TestReshaping(BaseDecimal, base.BaseReshapingTests):

def test_align(self, data, na_value):
# Have to override since assert_series_equal doesn't
Expand Down Expand Up @@ -88,15 +115,15 @@ def test_align_frame(self, data, na_value):
assert e2.loc[0, 'A'].is_nan()


class TestGetitem(base.BaseGetitemTests):
class TestGetitem(BaseDecimal, base.BaseGetitemTests):
pass


class TestMissing(base.BaseMissingTests):
class TestMissing(BaseDecimal, base.BaseMissingTests):
pass


class TestMethods(base.BaseMethodsTests):
class TestMethods(BaseDecimal, base.BaseMethodsTests):
@pytest.mark.parametrize('dropna', [True, False])
@pytest.mark.xfail(reason="value_counts not implemented yet.")
def test_value_counts(self, all_data, dropna):
Expand All @@ -112,7 +139,7 @@ def test_value_counts(self, all_data, dropna):
tm.assert_series_equal(result, expected)


class TestCasting(base.BaseCastingTests):
class TestCasting(BaseDecimal, base.BaseCastingTests):
pass


Expand Down
30 changes: 30 additions & 0 deletions pandas/tests/extension/json/test_json.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,16 @@ def data_missing():
return JSONArray([{}, {'a': 10}])


@pytest.fixture
def data_for_sorting():
    """``JSONArray`` of three non-empty dictionaries."""
    records = [
        {'b': 1},
        {'c': 4},
        {'a': 2, 'c': 3},
    ]
    return JSONArray(records)


@pytest.fixture
def data_missing_for_sorting():
    """``JSONArray`` of three dictionaries; the middle one is empty."""
    records = [
        {'b': 1},
        {},
        {'c': 4},
    ]
    return JSONArray(records)


@pytest.fixture
def na_value():
    """Return an empty dict as the missing-value scalar for these tests."""
    return {}
Expand Down Expand Up @@ -68,6 +78,26 @@ class TestMethods(base.BaseMethodsTests):
def test_value_counts(self, all_data, dropna):
pass

@pytest.mark.skip(reason="Dictionaries are not orderable.")
def test_argsort(self):
    """Skipped stub; intentionally has no body."""

@pytest.mark.skip(reason="Dictionaries are not orderable.")
def test_argsort_missing(self):
    """Skipped stub; intentionally has no body."""

@pytest.mark.skip(reason="Dictionaries are not orderable.")
def test_sort_values(self):
    """Skipped stub; intentionally has no body."""

@pytest.mark.skip(reason="Dictionaries are not orderable.")
def test_sort_values_missing(self):
    """Skipped stub; intentionally has no body."""

@pytest.mark.skip(reason="Dictionaries are not orderable.")
def test_sort_values_frame(self):
    """Skipped stub; intentionally has no body."""


class TestCasting(base.BaseCastingTests):
pass