Skip to content

Enhance TokenTextSplitter with flexible punctuation detection, bug fixes, and validation #2526

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Open
wants to merge 1 commit into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
/*
* Copyright 2023-2024 the original author or authors.
* Copyright 2023-2025 the original author or authors.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
Expand All @@ -18,6 +18,7 @@

import java.util.ArrayList;
import java.util.List;
import java.util.function.Function;

import com.knuddels.jtokkit.Encodings;
import com.knuddels.jtokkit.api.Encoding;
Expand All @@ -33,9 +34,15 @@
* @author Raphael Yu
* @author Christian Tzolov
* @author Ricken Bazolo
* @author Nan Chiu
*/
public class TokenTextSplitter extends TextSplitter {

public static final Function<String, Integer> DEFAULT_FIND_LAST_PUNCTUATION = chunkText ->
Math.max(chunkText.lastIndexOf('.'),
Math.max(chunkText.lastIndexOf('?'),
Math.max(chunkText.lastIndexOf('!'), chunkText.lastIndexOf('\n'))));

private static final int DEFAULT_CHUNK_SIZE = 800;

private static final int MIN_CHUNK_SIZE_CHARS = 350;
Expand Down Expand Up @@ -64,21 +71,31 @@ public class TokenTextSplitter extends TextSplitter {

private final boolean keepSeparator;

// Finds the last period or punctuation mark in the given chunk
private final Function<String, Integer> findLastPunctuation;

public TokenTextSplitter() {
this(DEFAULT_CHUNK_SIZE, MIN_CHUNK_SIZE_CHARS, MIN_CHUNK_LENGTH_TO_EMBED, MAX_NUM_CHUNKS, KEEP_SEPARATOR);
this(DEFAULT_CHUNK_SIZE, MIN_CHUNK_SIZE_CHARS, MIN_CHUNK_LENGTH_TO_EMBED, MAX_NUM_CHUNKS,
KEEP_SEPARATOR, DEFAULT_FIND_LAST_PUNCTUATION);
}

public TokenTextSplitter(boolean keepSeparator) {
this(DEFAULT_CHUNK_SIZE, MIN_CHUNK_SIZE_CHARS, MIN_CHUNK_LENGTH_TO_EMBED, MAX_NUM_CHUNKS, keepSeparator);
this(DEFAULT_CHUNK_SIZE, MIN_CHUNK_SIZE_CHARS, MIN_CHUNK_LENGTH_TO_EMBED, MAX_NUM_CHUNKS,
keepSeparator, DEFAULT_FIND_LAST_PUNCTUATION);
}

public TokenTextSplitter(int chunkSize, int minChunkSizeChars, int minChunkLengthToEmbed, int maxNumChunks,
boolean keepSeparator) {
boolean keepSeparator, Function<String, Integer> findLastPunctuation) {
Assert.isTrue(chunkSize > 0, "ChunkSize must be greater than 0");
Assert.isTrue(minChunkSizeChars >= 0, "MinChunkSizeChars must be a positive value");
Assert.isTrue(maxNumChunks > 0, "MaxNumChunks must be greater than 0");
Assert.notNull(findLastPunctuation, "FindLastPunctuation must not be null");
this.chunkSize = chunkSize;
this.minChunkSizeChars = minChunkSizeChars;
this.minChunkLengthToEmbed = minChunkLengthToEmbed;
this.maxNumChunks = maxNumChunks;
this.keepSeparator = keepSeparator;
this.findLastPunctuation = findLastPunctuation;
}

public static Builder builder() {
Expand All @@ -97,8 +114,7 @@ protected List<String> doSplit(String text, int chunkSize) {

List<Integer> tokens = getEncodedTokens(text);
List<String> chunks = new ArrayList<>();
int num_chunks = 0;
while (!tokens.isEmpty() && num_chunks < this.maxNumChunks) {
while (!tokens.isEmpty() && chunks.size() < this.maxNumChunks) {
List<Integer> chunk = tokens.subList(0, Math.min(chunkSize, tokens.size()));
String chunkText = decodeTokens(chunk);

Expand All @@ -109,24 +125,25 @@ protected List<String> doSplit(String text, int chunkSize) {
}

// Find the last period or punctuation mark in the chunk
int lastPunctuation = Math.max(chunkText.lastIndexOf('.'), Math.max(chunkText.lastIndexOf('?'),
Math.max(chunkText.lastIndexOf('!'), chunkText.lastIndexOf('\n'))));
int lastPunctuation = findLastPunctuation.apply(chunkText);

if (lastPunctuation != -1 && lastPunctuation > this.minChunkSizeChars) {
if (lastPunctuation > this.minChunkSizeChars && lastPunctuation < chunkText.length() - 1) {
// Truncate the chunk text at the punctuation mark
chunkText = chunkText.substring(0, lastPunctuation + 1);
}

String chunkTextToAppend = (this.keepSeparator) ? chunkText.trim()
: chunkText.replace(System.lineSeparator(), " ").trim();
if (chunkTextToAppend.length() > this.minChunkLengthToEmbed) {
// Reserve capacity for remaining tokens
if (chunks.size() == this.maxNumChunks - 1) {
break;
}
chunks.add(chunkTextToAppend);
}

// Remove the tokens corresponding to the chunk text from the remaining tokens
tokens = tokens.subList(getEncodedTokens(chunkText).size(), tokens.size());

num_chunks++;
}

// Handle the remaining tokens
Expand Down Expand Up @@ -154,15 +171,17 @@ private String decodeTokens(List<Integer> tokens) {

public static final class Builder {

private int chunkSize;
private int chunkSize = DEFAULT_CHUNK_SIZE;

private int minChunkSizeChars = MIN_CHUNK_SIZE_CHARS;

private int minChunkSizeChars;
private int minChunkLengthToEmbed = MIN_CHUNK_LENGTH_TO_EMBED;

private int minChunkLengthToEmbed;
private int maxNumChunks = MAX_NUM_CHUNKS;

private int maxNumChunks;
private boolean keepSeparator = KEEP_SEPARATOR;

private boolean keepSeparator;
private Function<String, Integer> findLastPunctuation = DEFAULT_FIND_LAST_PUNCTUATION;

private Builder() {
}
Expand Down Expand Up @@ -192,9 +211,14 @@ public Builder withKeepSeparator(boolean keepSeparator) {
return this;
}

public Builder withFindLastPunctuation(Function<String, Integer> findLastPunctuation) {
this.findLastPunctuation = findLastPunctuation;
return this;
}

public TokenTextSplitter build() {
return new TokenTextSplitter(this.chunkSize, this.minChunkSizeChars, this.minChunkLengthToEmbed,
this.maxNumChunks, this.keepSeparator);
this.maxNumChunks, this.keepSeparator, this.findLastPunctuation);
}

}
Expand Down
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
/*
* Copyright 2023-2024 the original author or authors.
* Copyright 2023-2025 the original author or authors.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
Expand Down Expand Up @@ -28,6 +28,7 @@

/**
* @author Ricken Bazolo
* @author Nan Chiu
*/
public class TokenTextSplitterTest {

Expand Down Expand Up @@ -88,6 +89,7 @@ public void testTokenTextSplitterBuilderWithAllFields() {
.withMinChunkLengthToEmbed(3)
.withMaxNumChunks(50)
.withKeepSeparator(true)
.withFindLastPunctuation(TokenTextSplitter.DEFAULT_FIND_LAST_PUNCTUATION)
.build();

var chunks = tokenTextSplitter.apply(List.of(doc1, doc2));
Expand All @@ -112,4 +114,80 @@ public void testTokenTextSplitterBuilderWithAllFields() {
assertThat(chunks.get(2).getMetadata()).containsKeys("key2", "key3").doesNotContainKeys("key1");
}

@Test
public void testTokenTextSplitterBuilderWithCustomFindLastPunctuationFunction() {

var contentFormatter1 = DefaultContentFormatter.defaultConfig();
var contentFormatter2 = DefaultContentFormatter.defaultConfig();

assertThat(contentFormatter1).isNotSameAs(contentFormatter2);

var doc1 = new Document("In the end, writing arises when man realizes that memory is not enough.",
Map.of("key1", "value1", "key2", "value2"));
doc1.setContentFormatter(contentFormatter1);

var doc2 = new Document("The most oppressive thing about the labyrinth is that you are constantly "
+ "being forced to choose. It isn’t the lack of an exit, but the abundance of exits that is so disorienting.",
Map.of("key2", "value22", "key3", "value3"));
doc2.setContentFormatter(contentFormatter2);

var tokenTextSplitter = TokenTextSplitter.builder()
.withMinChunkSizeChars(5)
.withFindLastPunctuation(text -> text.lastIndexOf(','))
.build();

var chunks = tokenTextSplitter.apply(List.of(doc1, doc2));

assertThat(chunks.size()).isEqualTo(4);

// Doc 1
assertThat(chunks.get(0).getText()).isEqualTo("In the end,");
assertThat(chunks.get(1).getText()).isEqualTo("writing arises when man realizes that memory is not enough.");

// Doc 2
assertThat(chunks.get(2).getText()).isEqualTo("The most oppressive thing about the labyrinth is that you are constantly being forced to choose. It isn’t the lack of an exit,");
assertThat(chunks.get(3).getText()).isEqualTo("but the abundance of exits that is so disorienting.");

// Verify that the same, merged metadata is copied to all chunks.
assertThat(chunks.get(0).getMetadata()).isEqualTo(chunks.get(1).getMetadata());
assertThat(chunks.get(2).getMetadata()).isEqualTo(chunks.get(3).getMetadata());
assertThat(chunks.get(0).getMetadata()).containsKeys("key1", "key2").doesNotContainKeys("key3");
assertThat(chunks.get(2).getMetadata()).containsKeys("key2", "key3").doesNotContainKeys("key1");
}

@Test
public void testDoSplitRespectsMaxNumChunksAndAvoidsEarlyRemainingProcessing() {
var doc1 = new Document("In the end, writing arises when man realizes that memory is not enough.");
var tokenTextSplitter1 = TokenTextSplitter.builder()
.withChunkSize(5)
.withMinChunkSizeChars(1)
.withMaxNumChunks(2)
.build();
List<Document> chunks1 = tokenTextSplitter1.apply(List.of(doc1));

// equal to MaxNumChunks
assertThat(chunks1.size()).isEqualTo(2);


var doc2 = new Document("d d d d d d d d. " +
"The most oppressive thing about the labyrinth is that you are constantly being forced to choose. " +
"d d d d d d d d. " +
"It isn’t the lack of an exit, but the abundance of exits that is so disorienting.");
var tokenTextSplitter2 = TokenTextSplitter.builder()
.withChunkSize(20)
.withMinChunkSizeChars(1)
.withMinChunkLengthToEmbed(20)
.withMaxNumChunks(2)
.build();
List<Document> chunks2 = tokenTextSplitter2.apply(List.of(doc2));

assertThat(chunks2.size()).isEqualTo(2);

// Doc 2
assertThat(chunks2.get(0).getText()).isEqualTo(
"The most oppressive thing about the labyrinth is that you are constantly being forced to choose.");
assertThat(chunks2.get(1).getText()).isEqualTo(
"It isn’t the lack of an exit, but the abundance of exits that is so disorienting.");
}

}