Skip to content

Commit a4e79ea

Browse files
authored
[mypyc] Add basic optimization for sorted (#18902)
Ref: mypyc/mypyc#1089
1 parent c3ed5e0 commit a4e79ea

File tree

7 files changed

+68
-0
lines changed

7 files changed

+68
-0
lines changed

mypyc/doc/native_operations.rst

+1
Original file line numberDiff line numberDiff line change
@@ -36,6 +36,7 @@ Functions
3636
* ``delattr(obj, name)``
3737
* ``slice(start, stop, step)``
3838
* ``globals()``
39+
* ``sorted(obj)``
3940

4041
Method decorators
4142
-----------------

mypyc/lib-rt/CPy.h

+1
Original file line numberDiff line numberDiff line change
@@ -662,6 +662,7 @@ int CPyList_Insert(PyObject *list, CPyTagged index, PyObject *value);
662662
PyObject *CPyList_Extend(PyObject *o1, PyObject *o2);
663663
int CPyList_Remove(PyObject *list, PyObject *obj);
664664
CPyTagged CPyList_Index(PyObject *list, PyObject *obj);
665+
PyObject *CPySequence_Sort(PyObject *seq);
665666
PyObject *CPySequence_Multiply(PyObject *seq, CPyTagged t_size);
666667
PyObject *CPySequence_RMultiply(CPyTagged t_size, PyObject *seq);
667668
PyObject *CPySequence_InPlaceMultiply(PyObject *seq, CPyTagged t_size);

mypyc/lib-rt/list_ops.c

+12
Original file line numberDiff line numberDiff line change
@@ -319,6 +319,18 @@ CPyTagged CPyList_Index(PyObject *list, PyObject *obj) {
319319
return index << 1;
320320
}
321321

322+
PyObject *CPySequence_Sort(PyObject *seq) {
323+
PyObject *newlist = PySequence_List(seq);
324+
if (newlist == NULL)
325+
return NULL;
326+
int res = PyList_Sort(newlist);
327+
if (res < 0) {
328+
Py_DECREF(newlist);
329+
return NULL;
330+
}
331+
return newlist;
332+
}
333+
322334
PyObject *CPySequence_Multiply(PyObject *seq, CPyTagged t_size) {
323335
Py_ssize_t size = CPyTagged_AsSsize_t(t_size);
324336
if (size == -1 && PyErr_Occurred()) {

mypyc/primitives/list_ops.py

+9
Original file line numberDiff line numberDiff line change
@@ -27,6 +27,15 @@
2727
# Get the 'builtins.list' type object.
2828
load_address_op(name="builtins.list", type=object_rprimitive, src="PyList_Type")
2929

30+
# sorted(obj)
31+
function_op(
32+
name="builtins.sorted",
33+
arg_types=[object_rprimitive],
34+
return_type=list_rprimitive,
35+
c_function_name="CPySequence_Sort",
36+
error_kind=ERR_MAGIC,
37+
)
38+
3039
# list(obj)
3140
to_list = function_op(
3241
name="builtins.list",

mypyc/test-data/fixtures/ir.py

+1
Original file line numberDiff line numberDiff line change
@@ -384,6 +384,7 @@ def pow(base: __SupportsPow2[T_contra, T_co], exp: T_contra, mod: None = None) -
384384
def pow(base: __SupportsPow3NoneOnly[T_contra, T_co], exp: T_contra, mod: None = None) -> T_co: ...
385385
@overload
386386
def pow(base: __SupportsPow3[T_contra, _M, T_co], exp: T_contra, mod: _M) -> T_co: ...
387+
def sorted(iterable: Iterable[_T]) -> list[_T]: ...
387388
def exit() -> None: ...
388389
def min(x: _T, y: _T) -> _T: ...
389390
def max(x: _T, y: _T) -> _T: ...

mypyc/test-data/irbuild-lists.test

+22
Original file line numberDiff line numberDiff line change
@@ -561,3 +561,25 @@ L3:
561561
goto L1
562562
L4:
563563
return 1
564+
565+
[case testSorted]
566+
from typing import List, Any
567+
def list_sort(a: List[int]) -> None:
568+
a.sort()
569+
def sort_iterable(a: Any) -> None:
570+
sorted(a)
571+
[out]
572+
def list_sort(a):
573+
a :: list
574+
r0 :: i32
575+
r1 :: bit
576+
L0:
577+
r0 = PyList_Sort(a)
578+
r1 = r0 >= 0 :: signed
579+
return 1
580+
def sort_iterable(a):
581+
a :: object
582+
r0 :: list
583+
L0:
584+
r0 = CPySequence_Sort(a)
585+
return 1

mypyc/test-data/run-lists.test

+22
Original file line numberDiff line numberDiff line change
@@ -489,3 +489,25 @@ def test_index_with_literal() -> None:
489489
assert d is d2
490490
d = a[-2].d
491491
assert d is d1
492+
493+
[case testSorted]
494+
from typing import List
495+
496+
def test_list_sort() -> None:
497+
l1 = [2, 1, 3]
498+
id_l1 = id(l1)
499+
l1.sort()
500+
assert l1 == [1, 2, 3]
501+
assert id_l1 == id(l1)
502+
503+
def test_sorted() -> None:
504+
res = [1, 2, 3]
505+
l1 = [2, 1, 3]
506+
id_l1 = id(l1)
507+
s_l1 = sorted(l1)
508+
assert s_l1 == res
509+
assert id_l1 != id(s_l1)
510+
assert l1 == [2, 1, 3]
511+
assert sorted((2, 1, 3)) == res
512+
assert sorted({2, 1, 3}) == res
513+
assert sorted({2: "", 1: "", 3: ""}) == res

0 commit comments

Comments
 (0)