Skip to content

Commit 19e3fd4

Browse files
authored
1 parent 49e014a commit 19e3fd4

File tree

6 files changed

+90
-0
lines changed

6 files changed

+90
-0
lines changed

mypyc/doc/str_operations.rst

+6
Original file line numberDiff line numberDiff line change
@@ -33,12 +33,18 @@ Methods
3333
* ``s.encode(encoding: str, errors: str)``
3434
* ``s1.endswith(s2: str)``
3535
* ``s1.endswith(t: tuple[str, ...])``
36+
* ``s1.find(s2: str)``
37+
* ``s1.find(s2: str, start: int)``
38+
* ``s1.find(s2: str, start: int, end: int)``
3639
* ``s.join(x: Iterable)``
3740
* ``s.partition(sep: str)``
3841
* ``s.removeprefix(prefix: str)``
3942
* ``s.removesuffix(suffix: str)``
4043
* ``s.replace(old: str, new: str)``
4144
* ``s.replace(old: str, new: str, count: int)``
45+
* ``s1.rfind(s2: str)``
46+
* ``s1.rfind(s2: str, start: int)``
47+
* ``s1.rfind(s2: str, start: int, end: int)``
4248
* ``s.rpartition(sep: str)``
4349
* ``s.rsplit()``
4450
* ``s.rsplit(sep: str)``

mypyc/lib-rt/CPy.h

+2
Original file line numberDiff line numberDiff line change
@@ -720,6 +720,8 @@ static inline char CPyDict_CheckSize(PyObject *dict, CPyTagged size) {
720720

721721
PyObject *CPyStr_Build(Py_ssize_t len, ...);
722722
PyObject *CPyStr_GetItem(PyObject *str, CPyTagged index);
723+
CPyTagged CPyStr_Find(PyObject *str, PyObject *substr, CPyTagged start, int direction);
724+
CPyTagged CPyStr_FindWithEnd(PyObject *str, PyObject *substr, CPyTagged start, CPyTagged end, int direction);
723725
PyObject *CPyStr_Split(PyObject *str, PyObject *sep, CPyTagged max_split);
724726
PyObject *CPyStr_RSplit(PyObject *str, PyObject *sep, CPyTagged max_split);
725727
PyObject *CPyStr_Replace(PyObject *str, PyObject *old_substr, PyObject *new_substr, CPyTagged max_replace);

mypyc/lib-rt/str_ops.c

+23
Original file line numberDiff line numberDiff line change
@@ -133,6 +133,29 @@ PyObject *CPyStr_Build(Py_ssize_t len, ...) {
133133
return res;
134134
}
135135

136+
CPyTagged CPyStr_Find(PyObject *str, PyObject *substr, CPyTagged start, int direction) {
137+
CPyTagged end = PyUnicode_GET_LENGTH(str) << 1;
138+
return CPyStr_FindWithEnd(str, substr, start, end, direction);
139+
}
140+
141+
CPyTagged CPyStr_FindWithEnd(PyObject *str, PyObject *substr, CPyTagged start, CPyTagged end, int direction) {
142+
Py_ssize_t temp_start = CPyTagged_AsSsize_t(start);
143+
if (temp_start == -1 && PyErr_Occurred()) {
144+
PyErr_SetString(PyExc_OverflowError, CPYTHON_LARGE_INT_ERRMSG);
145+
return CPY_INT_TAG;
146+
}
147+
Py_ssize_t temp_end = CPyTagged_AsSsize_t(end);
148+
if (temp_end == -1 && PyErr_Occurred()) {
149+
PyErr_SetString(PyExc_OverflowError, CPYTHON_LARGE_INT_ERRMSG);
150+
return CPY_INT_TAG;
151+
}
152+
Py_ssize_t index = PyUnicode_Find(str, substr, temp_start, temp_end, direction);
153+
if (unlikely(index == -2)) {
154+
return CPY_INT_TAG;
155+
}
156+
return index << 1;
157+
}
158+
136159
PyObject *CPyStr_Split(PyObject *str, PyObject *sep, CPyTagged max_split) {
137160
Py_ssize_t temp_max_split = CPyTagged_AsSsize_t(max_split);
138161
if (temp_max_split == -1 && PyErr_Occurred()) {

mypyc/primitives/str_ops.py

+23
Original file line numberDiff line numberDiff line change
@@ -95,6 +95,29 @@
9595
ordering=[1, 0],
9696
)
9797

98+
# str.find(...) and str.rfind(...)
99+
str_find_types: list[RType] = [str_rprimitive, str_rprimitive, int_rprimitive, int_rprimitive]
100+
str_find_functions = ["CPyStr_Find", "CPyStr_Find", "CPyStr_FindWithEnd"]
101+
str_find_constants: list[list[tuple[int, RType]]] = [[(0, c_int_rprimitive)], [], []]
102+
str_rfind_constants: list[list[tuple[int, RType]]] = [[(0, c_int_rprimitive)], [], []]
103+
for i in range(len(str_find_types) - 1):
104+
method_op(
105+
name="find",
106+
arg_types=str_find_types[0 : i + 2],
107+
return_type=int_rprimitive,
108+
c_function_name=str_find_functions[i],
109+
extra_int_constants=str_find_constants[i] + [(1, c_int_rprimitive)],
110+
error_kind=ERR_MAGIC,
111+
)
112+
method_op(
113+
name="rfind",
114+
arg_types=str_find_types[0 : i + 2],
115+
return_type=int_rprimitive,
116+
c_function_name=str_find_functions[i],
117+
extra_int_constants=str_rfind_constants[i] + [(-1, c_int_rprimitive)],
118+
error_kind=ERR_MAGIC,
119+
)
120+
98121
# str.join(obj)
99122
method_op(
100123
name="join",

mypyc/test-data/fixtures/ir.py

+2
Original file line numberDiff line numberDiff line change
@@ -102,6 +102,8 @@ def __getitem__(self, i: int) -> str: pass
102102
def __getitem__(self, i: slice) -> str: pass
103103
def __contains__(self, item: str) -> bool: pass
104104
def __iter__(self) -> Iterator[str]: ...
105+
def find(self, sub: str, start: Optional[int] = None, end: Optional[int] = None, /) -> int: ...
106+
def rfind(self, sub: str, start: Optional[int] = None, end: Optional[int] = None, /) -> int: ...
105107
def split(self, sep: Optional[str] = None, maxsplit: int = -1) -> List[str]: pass
106108
def rsplit(self, sep: Optional[str] = None, maxsplit: int = -1) -> List[str]: pass
107109
def splitlines(self, keepends: bool = False) -> List[str]: ...

mypyc/test-data/run-strings.test

+34
Original file line numberDiff line numberDiff line change
@@ -146,6 +146,20 @@ def contains(s: str, o: str) -> bool:
146146
def getitem(s: str, index: int) -> str:
147147
return s[index]
148148

149+
def find(s: str, substr: str, start: Optional[int] = None, end: Optional[int] = None) -> int:
150+
if start is not None:
151+
if end is not None:
152+
return s.find(substr, start, end)
153+
return s.find(substr, start)
154+
return s.find(substr)
155+
156+
def rfind(s: str, substr: str, start: Optional[int] = None, end: Optional[int] = None) -> int:
157+
if start is not None:
158+
if end is not None:
159+
return s.rfind(substr, start, end)
160+
return s.rfind(substr, start)
161+
return s.rfind(substr)
162+
149163
s = "abc"
150164

151165
def test_contains() -> None:
@@ -170,6 +184,26 @@ def test_getitem() -> None:
170184
with assertRaises(IndexError, "string index out of range"):
171185
getitem(s, -4)
172186

187+
def test_find() -> None:
188+
s = "abcab"
189+
assert find(s, "Hello") == -1
190+
assert find(s, "abc") == 0
191+
assert find(s, "b") == 1
192+
assert find(s, "b", 1) == 1
193+
assert find(s, "b", 1, 2) == 1
194+
assert find(s, "b", 3) == 4
195+
assert find(s, "b", 3, 5) == 4
196+
assert find(s, "b", 3, 4) == -1
197+
198+
assert rfind(s, "Hello") == -1
199+
assert rfind(s, "abc") == 0
200+
assert rfind(s, "b") == 4
201+
assert rfind(s, "b", 1) == 4
202+
assert rfind(s, "b", 1, 2) == 1
203+
assert rfind(s, "b", 3) == 4
204+
assert rfind(s, "b", 3, 5) == 4
205+
assert rfind(s, "b", 3, 4) == -1
206+
173207
def str_to_int(s: str, base: Optional[int] = None) -> int:
174208
if base:
175209
return int(s, base)

0 commit comments

Comments
 (0)