Skip to content

Commit 5655490

Browse files
committed
Refactor and optimize String.replace()
Signed-off-by: martinvuyk <[email protected]>
1 parent 2ace785 commit 5655490

File tree

2 files changed

+46
-53
lines changed

2 files changed

+46
-53
lines changed

Diff for: stdlib/src/collections/string.mojo

+44-52
Original file line numberDiff line numberDiff line change
@@ -1670,14 +1670,15 @@ struct String(
16701670
"""Return the number of non-overlapping occurrences of substring
16711671
`substr` in the string.
16721672
1673-
If sub is empty, returns the number of empty strings between characters
1674-
which is the length of the string plus one.
1675-
16761673
Args:
1677-
substr: The substring to count.
1674+
substr: The substring to count.
16781675
16791676
Returns:
1680-
The number of occurrences of `substr`.
1677+
The number of occurrences of `substr`.
1678+
1679+
Notes:
1680+
If `substr` is empty, returns the number of empty strings between characters
1681+
which is the length of the string plus one.
16811682
"""
16821683
if not substr:
16831684
return len(self) + 1
@@ -1901,51 +1902,54 @@ struct String(
19011902
Returns:
19021903
The string where all occurrences of `old` are replaced with `new`.
19031904
"""
1904-
if not old:
1905-
return self._interleave(new)
1906-
1907-
var occurrences = self.count(old)
1908-
if occurrences == -1:
1909-
return self
1910-
1911-
var self_start = self.unsafe_ptr()
1912-
var self_ptr = self.unsafe_ptr()
1905+
var s_ptr = self.unsafe_ptr()
19131906
var new_ptr = new.unsafe_ptr()
19141907

1915-
var self_len = self.byte_length()
1908+
var s_len = self.byte_length()
19161909
var old_len = old.byte_length()
19171910
var new_len = new.byte_length()
19181911

1919-
var res = Self._buffer_type()
1920-
res.reserve(self_len + (old_len - new_len) * occurrences + 1)
1921-
1922-
for _ in range(occurrences):
1923-
var curr_offset = int(self_ptr) - int(self_start)
1924-
1925-
var idx = self.find(old, curr_offset)
1912+
if old_len == 0:
1913+
var capacity = s_len + new_len * self.byte_length() + 1
1914+
var res_ptr = UnsafePointer[Byte].alloc(capacity)
1915+
var offset = 0
1916+
for s in self:
1917+
memcpy(res_ptr + offset, new_ptr, new_len)
1918+
offset += new_len
1919+
memcpy(res_ptr + offset, s.unsafe_ptr(), s.byte_length())
1920+
offset += s.byte_length()
1921+
res_ptr[capacity - 1] = 0
1922+
return String(ptr=res_ptr, length=capacity)
1923+
1924+
# FIXME(#3792): this should use self.as_bytes().count(old), which will be
1925+
# faster: returning unicode offsets has overhead, and can yield fewer
1926+
# bytes than necessary and cause a segfault
1927+
var occurrences = self.count(old)
1928+
if occurrences == 0:
1929+
return self
19261930

1927-
debug_assert(idx >= 0, "expected to find occurrence during find")
1931+
var capacity = s_len + (new_len - old_len) * occurrences + 1
1932+
var res_ptr = UnsafePointer[Byte].alloc(capacity)
1933+
var s_offset = 0
1934+
var res_offset = 0
19281935

1936+
while s_offset < s_len:
1937+
# FIXME(#3548): this should use raw bytes self.as_bytes().find(...)
1938+
var idx = self.find(old, s_offset)
1939+
if idx == -1:
1940+
memcpy(res_ptr + res_offset, s_ptr + s_offset, s_len - s_offset)
1941+
break
19291942
# Copy preceding unchanged chars
1930-
for _ in range(curr_offset, idx):
1931-
res.append(self_ptr[])
1932-
self_ptr += 1
1933-
1943+
var length = idx - s_offset
1944+
memcpy(res_ptr + res_offset, s_ptr + s_offset, length)
1945+
res_offset += length
1946+
s_offset += length + old_len
19341947
# Insert a copy of the new replacement string
1935-
for i in range(new_len):
1936-
res.append(new_ptr[i])
1948+
memcpy(res_ptr + res_offset, new_ptr, new_len)
1949+
res_offset += new_len
19371950

1938-
self_ptr += old_len
1939-
1940-
while True:
1941-
var val = self_ptr[]
1942-
if val == 0:
1943-
break
1944-
res.append(self_ptr[])
1945-
self_ptr += 1
1946-
1947-
res.append(0)
1948-
return String(res^)
1951+
res_ptr[capacity - 1] = 0
1952+
return String(ptr=res_ptr, length=capacity)
19491953

19501954
fn strip(self, chars: StringSlice) -> StringSlice[__origin_of(self)]:
19511955
"""Return a copy of the string with leading and trailing characters
@@ -2030,18 +2034,6 @@ struct String(
20302034
"""
20312035
hasher._update_with_bytes(self.unsafe_ptr(), self.byte_length())
20322036

2033-
fn _interleave(self, val: String) -> String:
2034-
var res = Self._buffer_type()
2035-
var val_ptr = val.unsafe_ptr()
2036-
var self_ptr = self.unsafe_ptr()
2037-
res.reserve(val.byte_length() * self.byte_length() + 1)
2038-
for i in range(self.byte_length()):
2039-
for j in range(val.byte_length()):
2040-
res.append(val_ptr[j])
2041-
res.append(self_ptr[i])
2042-
res.append(0)
2043-
return String(res^)
2044-
20452037
fn lower(self) -> String:
20462038
"""Returns a copy of the string with all cased characters
20472039
converted to lowercase.

Diff for: stdlib/test/python/my_module.py

+2-1
Original file line numberDiff line numberDiff line change
@@ -25,7 +25,8 @@ def __init__(self, bar):
2525

2626
class AbstractPerson(ABC):
2727
@abstractmethod
28-
def method(self): ...
28+
def method(self):
29+
...
2930

3031

3132
def my_function(name):

0 commit comments

Comments
 (0)