Skip to content

Commit

Permalink
[mypyc] Optimize str.rsplit
Browse files Browse the repository at this point in the history
  • Loading branch information
cdce8p committed Feb 13, 2025
1 parent 1ec3f44 commit d12e5e2
Show file tree
Hide file tree
Showing 6 changed files with 41 additions and 1 deletion.
3 changes: 3 additions & 0 deletions mypyc/doc/str_operations.rst
Original file line number Diff line number Diff line change
Expand Up @@ -34,6 +34,9 @@ Methods
* ``s.join(x: Iterable)``
* ``s.replace(old: str, new: str)``
* ``s.replace(old: str, new: str, count: int)``
* ``s.rsplit()``
* ``s.rsplit(sep: str)``
* ``s.rsplit(sep: str, maxsplit: int)``
* ``s.split()``
* ``s.split(sep: str)``
* ``s.split(sep: str, maxsplit: int)``
Expand Down
1 change: 1 addition & 0 deletions mypyc/lib-rt/CPy.h
Original file line number Diff line number Diff line change
Expand Up @@ -721,6 +721,7 @@ static inline char CPyDict_CheckSize(PyObject *dict, CPyTagged size) {
PyObject *CPyStr_Build(Py_ssize_t len, ...);
PyObject *CPyStr_GetItem(PyObject *str, CPyTagged index);
PyObject *CPyStr_Split(PyObject *str, PyObject *sep, CPyTagged max_split);
PyObject *CPyStr_RSplit(PyObject *str, PyObject *sep, CPyTagged max_split);
PyObject *CPyStr_Replace(PyObject *str, PyObject *old_substr, PyObject *new_substr, CPyTagged max_replace);
PyObject *CPyStr_Append(PyObject *o1, PyObject *o2);
PyObject *CPyStr_GetSlice(PyObject *obj, CPyTagged start, CPyTagged end);
Expand Down
9 changes: 9 additions & 0 deletions mypyc/lib-rt/str_ops.c
Original file line number Diff line number Diff line change
Expand Up @@ -142,6 +142,15 @@ PyObject *CPyStr_Split(PyObject *str, PyObject *sep, CPyTagged max_split) {
return PyUnicode_Split(str, sep, temp_max_split);
}

// Runtime helper for str.rsplit(sep, maxsplit) where maxsplit arrives as a
// tagged integer.
//
// Converts max_split to a Py_ssize_t with CPyTagged_AsSsize_t and forwards
// str, sep and the converted limit to PyUnicode_RSplit, returning whatever
// that call returns. If the conversion yields -1 while an exception is
// pending, an OverflowError carrying CPYTHON_LARGE_INT_ERRMSG is set and
// NULL is returned instead.
PyObject *CPyStr_RSplit(PyObject *str, PyObject *sep, CPyTagged max_split) {
    Py_ssize_t split_limit = CPyTagged_AsSsize_t(max_split);

    // -1 on its own is a usable limit; it only signals a failed conversion
    // when an exception is pending as well.
    if (split_limit != -1 || !PyErr_Occurred()) {
        return PyUnicode_RSplit(str, sep, split_limit);
    }

    PyErr_SetString(PyExc_OverflowError, CPYTHON_LARGE_INT_ERRMSG);
    return NULL;
}

PyObject *CPyStr_Replace(PyObject *str, PyObject *old_substr,
PyObject *new_substr, CPyTagged max_replace) {
Py_ssize_t temp_max_replace = CPyTagged_AsSsize_t(max_replace);
Expand Down
9 changes: 9 additions & 0 deletions mypyc/primitives/str_ops.py
Original file line number Diff line number Diff line change
Expand Up @@ -121,6 +121,7 @@
# str.split(...) and str.rsplit(...)
# These lists are indexed in parallel when the primitives are registered:
# entry i of a *_functions list is the C function used for the call form that
# takes the argument types str_split_types[0 : i + 1].
str_split_types: list[RType] = [str_rprimitive, str_rprimitive, int_rprimitive]
str_split_functions = ["PyUnicode_Split", "PyUnicode_Split", "CPyStr_Split"]
# Same layout as str_split_functions, but naming the rsplit C functions.
str_rsplit_functions = ["PyUnicode_RSplit", "PyUnicode_RSplit", "CPyStr_RSplit"]
str_split_constants: list[list[tuple[int, RType]]] = [
[(0, pointer_rprimitive), (-1, c_int_rprimitive)],
[(-1, c_int_rprimitive)],
Expand All @@ -135,6 +136,14 @@
extra_int_constants=str_split_constants[i],
error_kind=ERR_MAGIC,
)
# Register the "rsplit" method primitive for one call form. It reuses the
# str_split_types and str_split_constants tables and takes only its C function
# name from str_rsplit_functions.
# NOTE(review): `i` is presumably the index of an enclosing loop over the
# call forms; the loop header is not visible in this hunk, so confirm.
method_op(
name="rsplit",
arg_types=str_split_types[0 : i + 1],
return_type=list_rprimitive,
c_function_name=str_rsplit_functions[i],
extra_int_constants=str_split_constants[i],
error_kind=ERR_MAGIC,
)

# str.replace(old, new)
method_op(
Expand Down
3 changes: 2 additions & 1 deletion mypyc/test-data/fixtures/ir.py
Original file line number Diff line number Diff line change
Expand Up @@ -102,7 +102,8 @@ def __getitem__(self, i: int) -> str: pass
def __getitem__(self, i: slice) -> str: pass
def __contains__(self, item: str) -> bool: pass
def __iter__(self) -> Iterator[str]: ...
# NOTE(review): two `split` signatures appear back to back in this diff view;
# presumably the first (with `max`) is the removed line and the second (with
# `maxsplit`) replaces it. Confirm against the actual fixture file.
def split(self, sep: Optional[str] = None, max: Optional[int] = None) -> List[str]: pass
def split(self, sep: Optional[str] = None, maxsplit: int = ...) -> List[str]: pass
# Stub for str.rsplit with the same parameter list as the `split` stub above.
def rsplit(self, sep: Optional[str] = None, maxsplit: int = ...) -> List[str]: pass
def strip (self, item: str) -> str: pass
def join(self, x: Iterable[str]) -> str: pass
def format(self, *args: Any, **kwargs: Any) -> str: ...
Expand Down
17 changes: 17 additions & 0 deletions mypyc/test-data/run-strings.test
Original file line number Diff line number Diff line change
Expand Up @@ -55,6 +55,14 @@ def do_split(s: str, sep: Optional[str] = None, max_split: Optional[int] = None)
return s.split(sep)
return s.split()

def do_rsplit(s: str, sep: Optional[str] = None, max_split: Optional[int] = None) -> List[str]:
    """Call ``s.rsplit`` using the call form selected by the arguments.

    Each branch issues a syntactically distinct call (zero, one, or two
    arguments) so that every form of ``str.rsplit`` gets exercised.

    Args:
        s: The string to split.
        sep: Separator to split on. When ``None``, ``s.rsplit()`` is called
            with no arguments.
        max_split: Maximum number of splits. It is only used when ``sep`` is
            also given; with ``sep`` left as ``None`` it is ignored.

    Returns:
        The list produced by the selected ``rsplit`` call.
    """
    if sep is not None:
        if max_split is not None:
            return s.rsplit(sep, max_split)
        else:
            return s.rsplit(sep)
    return s.rsplit()

ss = "abc abcd abcde abcdef"

def test_split() -> None:
Expand All @@ -66,6 +74,15 @@ def test_split() -> None:
assert do_split(ss, " ", 1) == ["abc", "abcd abcde abcdef"]
assert do_split(ss, " ", 2) == ["abc", "abcd", "abcde abcdef"]

def test_rsplit() -> None:
    """Check ``do_rsplit`` for each call form and several ``max_split`` values."""
    # No separator, an explicit separator, and a separator that never matches.
    assert do_rsplit(ss) == ["abc", "abcd", "abcde", "abcdef"]
    assert do_rsplit(ss, " ") == ["abc", "abcd", "abcde", "abcdef"]
    assert do_rsplit(ss, "-") == ["abc abcd abcde abcdef"]
    # Explicit split limits: -1 splits everywhere, 0 does not split at all.
    assert do_rsplit(ss, " ", -1) == ["abc", "abcd", "abcde", "abcdef"]
    assert do_rsplit(ss, " ", 0) == ["abc abcd abcde abcdef"]
    # With a positive limit the splits are taken from the right-hand end.
    assert do_rsplit(ss, " ", 1) == ["abc abcd abcde", "abcdef"]  # different to do_split
    assert do_rsplit(ss, " ", 2) == ["abc abcd", "abcde", "abcdef"]  # different to do_split

def getitem(s: str, index: int) -> str:
return s[index]

Expand Down

0 comments on commit d12e5e2

Please sign in to comment.