Skip to content

Commit 49c3fa4

Browse files
authored
[mypyc] Optimize str.rsplit (#18673)
Use `PyUnicode_RSplit` to optimize `str.rsplit` calls. Although not present in the documentation, it has actually been part of the stable API since Python 3.2. https://github.com/python/cpython/blob/v3.13.2/Doc/data/stable_abi.dat#L799 https://github.com/python/cpython/blob/main/Include/unicodeobject.h#L841-L858
1 parent f404b16 commit 49c3fa4

File tree

6 files changed

+42
-2
lines changed

6 files changed

+42
-2
lines changed

mypyc/doc/str_operations.rst

+3
Original file line numberDiff line numberDiff line change
@@ -36,6 +36,9 @@ Methods
3636
* ``s.removesuffix(suffix: str)``
3737
* ``s.replace(old: str, new: str)``
3838
* ``s.replace(old: str, new: str, count: int)``
39+
* ``s.rsplit()``
40+
* ``s.rsplit(sep: str)``
41+
* ``s.rsplit(sep: str, maxsplit: int)``
3942
* ``s.split()``
4043
* ``s.split(sep: str)``
4144
* ``s.split(sep: str, maxsplit: int)``

mypyc/lib-rt/CPy.h

+1
Original file line numberDiff line numberDiff line change
@@ -721,6 +721,7 @@ static inline char CPyDict_CheckSize(PyObject *dict, CPyTagged size) {
721721
PyObject *CPyStr_Build(Py_ssize_t len, ...);
722722
PyObject *CPyStr_GetItem(PyObject *str, CPyTagged index);
723723
PyObject *CPyStr_Split(PyObject *str, PyObject *sep, CPyTagged max_split);
724+
PyObject *CPyStr_RSplit(PyObject *str, PyObject *sep, CPyTagged max_split);
724725
PyObject *CPyStr_Replace(PyObject *str, PyObject *old_substr, PyObject *new_substr, CPyTagged max_replace);
725726
PyObject *CPyStr_Append(PyObject *o1, PyObject *o2);
726727
PyObject *CPyStr_GetSlice(PyObject *obj, CPyTagged start, CPyTagged end);

mypyc/lib-rt/str_ops.c

+9
Original file line numberDiff line numberDiff line change
@@ -142,6 +142,15 @@ PyObject *CPyStr_Split(PyObject *str, PyObject *sep, CPyTagged max_split) {
142142
return PyUnicode_Split(str, sep, temp_max_split);
143143
}
144144

145+
PyObject *CPyStr_RSplit(PyObject *str, PyObject *sep, CPyTagged max_split) {
146+
Py_ssize_t temp_max_split = CPyTagged_AsSsize_t(max_split);
147+
if (temp_max_split == -1 && PyErr_Occurred()) {
148+
PyErr_SetString(PyExc_OverflowError, CPYTHON_LARGE_INT_ERRMSG);
149+
return NULL;
150+
}
151+
return PyUnicode_RSplit(str, sep, temp_max_split);
152+
}
153+
145154
PyObject *CPyStr_Replace(PyObject *str, PyObject *old_substr,
146155
PyObject *new_substr, CPyTagged max_replace) {
147156
Py_ssize_t temp_max_replace = CPyTagged_AsSsize_t(max_replace);

mypyc/primitives/str_ops.py

+10-1
Original file line numberDiff line numberDiff line change
@@ -136,9 +136,10 @@
136136
error_kind=ERR_NEVER,
137137
)
138138

139-
# str.split(...)
139+
# str.split(...) and str.rsplit(...)
140140
str_split_types: list[RType] = [str_rprimitive, str_rprimitive, int_rprimitive]
141141
str_split_functions = ["PyUnicode_Split", "PyUnicode_Split", "CPyStr_Split"]
142+
str_rsplit_functions = ["PyUnicode_RSplit", "PyUnicode_RSplit", "CPyStr_RSplit"]
142143
str_split_constants: list[list[tuple[int, RType]]] = [
143144
[(0, pointer_rprimitive), (-1, c_int_rprimitive)],
144145
[(-1, c_int_rprimitive)],
@@ -153,6 +154,14 @@
153154
extra_int_constants=str_split_constants[i],
154155
error_kind=ERR_MAGIC,
155156
)
157+
method_op(
158+
name="rsplit",
159+
arg_types=str_split_types[0 : i + 1],
160+
return_type=list_rprimitive,
161+
c_function_name=str_rsplit_functions[i],
162+
extra_int_constants=str_split_constants[i],
163+
error_kind=ERR_MAGIC,
164+
)
156165

157166
# str.replace(old, new)
158167
method_op(

mypyc/test-data/fixtures/ir.py

+2-1
Original file line numberDiff line numberDiff line change
@@ -102,7 +102,8 @@ def __getitem__(self, i: int) -> str: pass
102102
def __getitem__(self, i: slice) -> str: pass
103103
def __contains__(self, item: str) -> bool: pass
104104
def __iter__(self) -> Iterator[str]: ...
105-
def split(self, sep: Optional[str] = None, max: Optional[int] = None) -> List[str]: pass
105+
def split(self, sep: Optional[str] = None, maxsplit: int = -1) -> List[str]: pass
106+
def rsplit(self, sep: Optional[str] = None, maxsplit: int = -1) -> List[str]: pass
106107
def strip (self, item: str) -> str: pass
107108
def join(self, x: Iterable[str]) -> str: pass
108109
def format(self, *args: Any, **kwargs: Any) -> str: ...

mypyc/test-data/run-strings.test

+17
Original file line numberDiff line numberDiff line change
@@ -61,6 +61,14 @@ def do_split(s: str, sep: Optional[str] = None, max_split: Optional[int] = None)
6161
return s.split(sep)
6262
return s.split()
6363

64+
def do_rsplit(s: str, sep: Optional[str] = None, max_split: Optional[int] = None) -> List[str]:
65+
if sep is not None:
66+
if max_split is not None:
67+
return s.rsplit(sep, max_split)
68+
else:
69+
return s.rsplit(sep)
70+
return s.rsplit()
71+
6472
ss = "abc abcd abcde abcdef"
6573

6674
def test_split() -> None:
@@ -72,6 +80,15 @@ def test_split() -> None:
7280
assert do_split(ss, " ", 1) == ["abc", "abcd abcde abcdef"]
7381
assert do_split(ss, " ", 2) == ["abc", "abcd", "abcde abcdef"]
7482

83+
def test_rsplit() -> None:
84+
assert do_rsplit(ss) == ["abc", "abcd", "abcde", "abcdef"]
85+
assert do_rsplit(ss, " ") == ["abc", "abcd", "abcde", "abcdef"]
86+
assert do_rsplit(ss, "-") == ["abc abcd abcde abcdef"]
87+
assert do_rsplit(ss, " ", -1) == ["abc", "abcd", "abcde", "abcdef"]
88+
assert do_rsplit(ss, " ", 0) == ["abc abcd abcde abcdef"]
89+
assert do_rsplit(ss, " ", 1) == ["abc abcd abcde", "abcdef"] # different to do_split
90+
assert do_rsplit(ss, " ", 2) == ["abc abcd", "abcde", "abcdef"] # different to do_split
91+
7592
def getitem(s: str, index: int) -> str:
7693
return s[index]
7794

0 commit comments

Comments
 (0)