Skip to content

[stdlib] Refactor and optimize StringSlice.replace() #3860

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Open
wants to merge 1 commit into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
11 changes: 6 additions & 5 deletions mojo/stdlib/stdlib/collections/string/string.mojo
Original file line number Diff line number Diff line change
Expand Up @@ -1263,14 +1263,15 @@ struct String(
"""Return the number of non-overlapping occurrences of substring
`substr` in the string.

If sub is empty, returns the number of empty strings between characters
which is the length of the string plus one.

Args:
substr: The substring to count.
substr: The substring to count.

Returns:
The number of occurrences of `substr`.
The number of occurrences of `substr`.

Notes:
If sub is empty, returns the number of empty strings between
characters which is the length of the string plus one.
"""
return self.as_string_slice().count(substr)

Expand Down
96 changes: 53 additions & 43 deletions mojo/stdlib/stdlib/collections/string/string_slice.mojo
Original file line number Diff line number Diff line change
Expand Up @@ -1099,66 +1099,75 @@ struct StringSlice[mut: Bool, //, origin: Origin[mut]](
"""
return rebind[Self.Immutable](self)

fn replace(self, old: StringSlice, new: StringSlice) -> String:
fn replace[
*, buffer_size: Int = 4096
](self, old: StringSlice, new: StringSlice) -> String:
"""Return a copy of the string with all occurrences of substring `old`
if replaced by `new`.

Parameters:
buffer_size: The max size of the stack buffer.

Args:
old: The substring to replace.
new: The substring to replace with.

Returns:
The string where all occurrences of `old` are replaced with `new`.
"""
if not old:
return self._interleave(new)

var occurrences = self.count(old)
if occurrences == -1:
return String(self)

var self_start = self.unsafe_ptr()
var self_ptr = self.unsafe_ptr()
var new_ptr = new.unsafe_ptr()

var self_len = self.byte_length()
var s_len = self.byte_length()
var old_len = old.byte_length()
var new_len = new.byte_length()

var res = String(capacity=self_len + (new_len - old_len) * occurrences)

for _ in range(occurrences):
var curr_offset = Int(self_ptr) - Int(self_start)

var idx = self.find(old, curr_offset)

debug_assert(idx >= 0, "expected to find occurrence during find")
if old_len == 0:
# NOTE: this is bigger than necessary if there are multi-byte
# sequences but is faster as it doesn't require counting cont. bytes
var capacity = s_len + new_len * self.byte_length()
var res = String(capacity=capacity)

# Copy preceding unchanged chars
for _ in range(curr_offset, idx):
res.append_byte(self_ptr[])
self_ptr += 1
@parameter
fn _replace_interleave[W: Writer](mut writer: W):
for s in self.codepoint_slices():
writer.write(new, s)

# Insert a copy of the new replacement string
for i in range(new_len):
res.append_byte(new_ptr[i])
if res._capacity_or_data.is_inline():
_replace_interleave(res)
else:
var buffer = _WriteBufferStack[buffer_size](res)
_replace_interleave(buffer)
buffer.flush()
return res^

self_ptr += old_len
# FIXME(#3792): this should use self.as_bytes().count(old) which will be
# faster because returning unicode offsets has overhead
var occurrences = self.count(old)
if occurrences == 0:
return String(self)

while self_ptr < self.unsafe_ptr() + self_len:
res.append_byte(self_ptr[])
self_ptr += 1
# NOTE: this is bigger than necessary if there are multi-byte
# sequences but is faster as it doesn't require counting cont. bytes
var res = String(capacity=s_len + (new_len - old_len) * occurrences)

return res^
@parameter
fn _replace_existing[W: Writer](mut writer: W):
var s_offset = 0
while s_offset < s_len:
# FIXME(#3548): this should use raw bytes self.as_bytes().find(...)
var idx = self.find(old, s_offset)
if idx == -1: # if not found copy remainder
writer.write(self[s_offset:])
break
# Copy preceding unchanged chars and insert new replacement string
writer.write(self[s_offset:idx], new)
s_offset = idx + old_len

fn _interleave(self, val: StringSlice) -> String:
var val_ptr = val.unsafe_ptr()
var self_ptr = self.unsafe_ptr()
var res = String(capacity=val.byte_length() * self.byte_length())
for i in range(self.byte_length()):
for j in range(val.byte_length()):
res.append_byte(val_ptr[j])
res.append_byte(self_ptr[i])
if res._capacity_or_data.is_inline():
_replace_existing(res)
else:
var buffer = _WriteBufferStack[buffer_size](res)
_replace_existing(buffer)
buffer.flush()
return res^

fn split(
Expand Down Expand Up @@ -2002,14 +2011,15 @@ struct StringSlice[mut: Bool, //, origin: Origin[mut]](
"""Return the number of non-overlapping occurrences of substring
`substr` in the string.

If sub is empty, returns the number of empty strings between characters
which is the length of the string plus one.

Args:
substr: The substring to count.

Returns:
The number of occurrences of `substr`.

Notes:
If sub is empty, returns the number of empty strings between
characters which is the length of the string plus one.
"""
if not substr:
return len(self) + 1
Expand Down
7 changes: 7 additions & 0 deletions mojo/stdlib/test/collections/string/test_string.mojo
Original file line number Diff line number Diff line change
Expand Up @@ -651,6 +651,13 @@ def test_replace():
var s3 = String("a complex test case with some spaces")
assert_equal("a complex test case with some spaces", s3.replace(" ", " "))

# Test Unicode codepoints
var s4 = String("Mꙩjꙩ")
assert_equal("Mojo", s4.replace("ꙩ", "o"))
assert_equal("🔥🔥🔥🔥", s4.replace("Mꙩjꙩ", "🔥🔥🔥🔥"))
assert_equal("🔥M🔥o🔥j🔥o", StaticString("Mojo").replace("", "🔥"))
assert_equal("🔥M🔥ꙩ🔥j🔥ꙩ", s4.replace("", "🔥"))


def test_rfind():
# Basic usage.
Expand Down