From 78334ab61d2b3dcbd6337c325c69d75f10dc2a14 Mon Sep 17 00:00:00 2001 From: qnnn <65326092+qnnn@users.noreply.github.com> Date: Thu, 20 Mar 2025 17:40:33 +0800 Subject: [PATCH] Enhance TokenTextSplitter with flexible punctuation detection, bug fixes, and validation Signed-off-by: qnnn <65326092+qnnn@users.noreply.github.com> --- .../splitter/TokenTextSplitter.java | 58 ++++++++++---- .../splitter/TokenTextSplitterTest.java | 80 ++++++++++++++++++- 2 files changed, 120 insertions(+), 18 deletions(-) diff --git a/spring-ai-core/src/main/java/org/springframework/ai/transformer/splitter/TokenTextSplitter.java b/spring-ai-core/src/main/java/org/springframework/ai/transformer/splitter/TokenTextSplitter.java index eddaef112bb..a37b598a456 100644 --- a/spring-ai-core/src/main/java/org/springframework/ai/transformer/splitter/TokenTextSplitter.java +++ b/spring-ai-core/src/main/java/org/springframework/ai/transformer/splitter/TokenTextSplitter.java @@ -1,5 +1,5 @@ /* - * Copyright 2023-2024 the original author or authors. + * Copyright 2023-2025 the original author or authors. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -18,6 +18,7 @@ import java.util.ArrayList; import java.util.List; +import java.util.function.Function; import com.knuddels.jtokkit.Encodings; import com.knuddels.jtokkit.api.Encoding; @@ -33,9 +34,15 @@ * @author Raphael Yu * @author Christian Tzolov * @author Ricken Bazolo + * @author Nan Chiu */ public class TokenTextSplitter extends TextSplitter { + public static final Function DEFAULT_FIND_LAST_PUNCTUATION = chunkText -> + Math.max(chunkText.lastIndexOf('.'), + Math.max(chunkText.lastIndexOf('?'), + Math.max(chunkText.lastIndexOf('!'), chunkText.lastIndexOf('\n')))); + private static final int DEFAULT_CHUNK_SIZE = 800; private static final int MIN_CHUNK_SIZE_CHARS = 350; @@ -64,21 +71,31 @@ public class TokenTextSplitter extends TextSplitter { private final boolean keepSeparator; + // Finds the last period or punctuation mark in the given chunk + private final Function findLastPunctuation; + public TokenTextSplitter() { - this(DEFAULT_CHUNK_SIZE, MIN_CHUNK_SIZE_CHARS, MIN_CHUNK_LENGTH_TO_EMBED, MAX_NUM_CHUNKS, KEEP_SEPARATOR); + this(DEFAULT_CHUNK_SIZE, MIN_CHUNK_SIZE_CHARS, MIN_CHUNK_LENGTH_TO_EMBED, MAX_NUM_CHUNKS, + KEEP_SEPARATOR, DEFAULT_FIND_LAST_PUNCTUATION); } public TokenTextSplitter(boolean keepSeparator) { - this(DEFAULT_CHUNK_SIZE, MIN_CHUNK_SIZE_CHARS, MIN_CHUNK_LENGTH_TO_EMBED, MAX_NUM_CHUNKS, keepSeparator); + this(DEFAULT_CHUNK_SIZE, MIN_CHUNK_SIZE_CHARS, MIN_CHUNK_LENGTH_TO_EMBED, MAX_NUM_CHUNKS, + keepSeparator, DEFAULT_FIND_LAST_PUNCTUATION); } public TokenTextSplitter(int chunkSize, int minChunkSizeChars, int minChunkLengthToEmbed, int maxNumChunks, - boolean keepSeparator) { + boolean keepSeparator, Function findLastPunctuation) { + Assert.isTrue(chunkSize > 0, "ChunkSize must be greater than 0"); + Assert.isTrue(minChunkSizeChars >= 0, "MinChunkSizeChars must be a positive value"); + Assert.isTrue(maxNumChunks > 0, "MaxNumChunks must be greater than 0"); + Assert.notNull(findLastPunctuation, "FindLastPunctuation must not be null"); this.chunkSize = chunkSize; this.minChunkSizeChars = minChunkSizeChars; this.minChunkLengthToEmbed = minChunkLengthToEmbed; this.maxNumChunks = maxNumChunks; this.keepSeparator = keepSeparator; + this.findLastPunctuation = findLastPunctuation; } public static Builder builder() { @@ -97,8 +114,7 @@ protected List doSplit(String text, int chunkSize) { List tokens = getEncodedTokens(text); List chunks = new ArrayList<>(); - int num_chunks = 0; - while (!tokens.isEmpty() && num_chunks < this.maxNumChunks) { + while (!tokens.isEmpty() && chunks.size() < this.maxNumChunks) { List chunk = tokens.subList(0, Math.min(chunkSize, tokens.size())); String chunkText = decodeTokens(chunk); @@ -109,10 +125,9 @@ protected List doSplit(String text, int chunkSize) { } // Find the last period or punctuation mark in the chunk - int lastPunctuation = Math.max(chunkText.lastIndexOf('.'), Math.max(chunkText.lastIndexOf('?'), - Math.max(chunkText.lastIndexOf('!'), chunkText.lastIndexOf('\n')))); + int lastPunctuation = findLastPunctuation.apply(chunkText); - if (lastPunctuation != -1 && lastPunctuation > this.minChunkSizeChars) { + if (lastPunctuation > this.minChunkSizeChars && lastPunctuation < chunkText.length() - 1) { // Truncate the chunk text at the punctuation mark chunkText = chunkText.substring(0, lastPunctuation + 1); } @@ -120,13 +135,15 @@ protected List doSplit(String text, int chunkSize) { String chunkTextToAppend = (this.keepSeparator) ? chunkText.trim() : chunkText.replace(System.lineSeparator(), " ").trim(); if (chunkTextToAppend.length() > this.minChunkLengthToEmbed) { + // Reserve capacity for remaining tokens + if (chunks.size() == this.maxNumChunks - 1) { + break; + } chunks.add(chunkTextToAppend); } // Remove the tokens corresponding to the chunk text from the remaining tokens tokens = tokens.subList(getEncodedTokens(chunkText).size(), tokens.size()); - - num_chunks++; } // Handle the remaining tokens @@ -154,15 +171,17 @@ private String decodeTokens(List tokens) { public static final class Builder { - private int chunkSize; + private int chunkSize = DEFAULT_CHUNK_SIZE; + + private int minChunkSizeChars = MIN_CHUNK_SIZE_CHARS; - private int minChunkSizeChars; + private int minChunkLengthToEmbed = MIN_CHUNK_LENGTH_TO_EMBED; - private int minChunkLengthToEmbed; + private int maxNumChunks = MAX_NUM_CHUNKS; - private int maxNumChunks; + private boolean keepSeparator = KEEP_SEPARATOR; - private boolean keepSeparator; + private Function findLastPunctuation = DEFAULT_FIND_LAST_PUNCTUATION; private Builder() { } @@ -192,9 +211,14 @@ public Builder withKeepSeparator(boolean keepSeparator) { return this; } + public Builder withFindLastPunctuation(Function findLastPunctuation) { + this.findLastPunctuation = findLastPunctuation; + return this; + } + public TokenTextSplitter build() { return new TokenTextSplitter(this.chunkSize, this.minChunkSizeChars, this.minChunkLengthToEmbed, - this.maxNumChunks, this.keepSeparator); + this.maxNumChunks, this.keepSeparator, this.findLastPunctuation); } } diff --git a/spring-ai-core/src/test/java/org/springframework/ai/transformer/splitter/TokenTextSplitterTest.java b/spring-ai-core/src/test/java/org/springframework/ai/transformer/splitter/TokenTextSplitterTest.java index e803c8a4e40..699cc15259b 100644 --- a/spring-ai-core/src/test/java/org/springframework/ai/transformer/splitter/TokenTextSplitterTest.java +++ b/spring-ai-core/src/test/java/org/springframework/ai/transformer/splitter/TokenTextSplitterTest.java @@ -1,5 +1,5 @@ /* - * Copyright 2023-2024 the original author or authors. + * Copyright 2023-2025 the original author or authors. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -28,6 +28,7 @@ /** * @author Ricken Bazolo + * @author Nan Chiu */ public class TokenTextSplitterTest { @@ -88,6 +89,7 @@ public void testTokenTextSplitterBuilderWithAllFields() { .withMinChunkLengthToEmbed(3) .withMaxNumChunks(50) .withKeepSeparator(true) + .withFindLastPunctuation(TokenTextSplitter.DEFAULT_FIND_LAST_PUNCTUATION) .build(); var chunks = tokenTextSplitter.apply(List.of(doc1, doc2)); @@ -112,4 +114,80 @@ public void testTokenTextSplitterBuilderWithAllFields() { assertThat(chunks.get(2).getMetadata()).containsKeys("key2", "key3").doesNotContainKeys("key1"); } + @Test + public void testTokenTextSplitterBuilderWithCustomFindLastPunctuationFunction() { + + var contentFormatter1 = DefaultContentFormatter.defaultConfig(); + var contentFormatter2 = DefaultContentFormatter.defaultConfig(); + + assertThat(contentFormatter1).isNotSameAs(contentFormatter2); + + var doc1 = new Document("In the end, writing arises when man realizes that memory is not enough.", + Map.of("key1", "value1", "key2", "value2")); + doc1.setContentFormatter(contentFormatter1); + + var doc2 = new Document("The most oppressive thing about the labyrinth is that you are constantly " + + "being forced to choose. It isn’t the lack of an exit, but the abundance of exits that is so disorienting.", + Map.of("key2", "value22", "key3", "value3")); + doc2.setContentFormatter(contentFormatter2); + + var tokenTextSplitter = TokenTextSplitter.builder() + .withMinChunkSizeChars(5) + .withFindLastPunctuation(text -> text.lastIndexOf(',')) + .build(); + + var chunks = tokenTextSplitter.apply(List.of(doc1, doc2)); + + assertThat(chunks.size()).isEqualTo(4); + + // Doc 1 + assertThat(chunks.get(0).getText()).isEqualTo("In the end,"); + assertThat(chunks.get(1).getText()).isEqualTo("writing arises when man realizes that memory is not enough."); + + // Doc 2 + assertThat(chunks.get(2).getText()).isEqualTo("The most oppressive thing about the labyrinth is that you are constantly being forced to choose. It isn’t the lack of an exit,"); + assertThat(chunks.get(3).getText()).isEqualTo("but the abundance of exits that is so disorienting."); + + // Verify that the same, merged metadata is copied to all chunks. + assertThat(chunks.get(0).getMetadata()).isEqualTo(chunks.get(1).getMetadata()); + assertThat(chunks.get(2).getMetadata()).isEqualTo(chunks.get(3).getMetadata()); + assertThat(chunks.get(0).getMetadata()).containsKeys("key1", "key2").doesNotContainKeys("key3"); + assertThat(chunks.get(2).getMetadata()).containsKeys("key2", "key3").doesNotContainKeys("key1"); + } + + @Test + public void testDoSplitRespectsMaxNumChunksAndAvoidsEarlyRemainingProcessing() { + var doc1 = new Document("In the end, writing arises when man realizes that memory is not enough."); + var tokenTextSplitter1 = TokenTextSplitter.builder() + .withChunkSize(5) + .withMinChunkSizeChars(1) + .withMaxNumChunks(2) + .build(); + List chunks1 = tokenTextSplitter1.apply(List.of(doc1)); + + // equal to MaxNumChunks + assertThat(chunks1.size()).isEqualTo(2); + + + var doc2 = new Document("d d d d d d d d. " + + "The most oppressive thing about the labyrinth is that you are constantly being forced to choose. " + + "d d d d d d d d. " + + "It isn’t the lack of an exit, but the abundance of exits that is so disorienting."); + var tokenTextSplitter2 = TokenTextSplitter.builder() + .withChunkSize(20) + .withMinChunkSizeChars(1) + .withMinChunkLengthToEmbed(20) + .withMaxNumChunks(2) + .build(); + List chunks2 = tokenTextSplitter2.apply(List.of(doc2)); + + assertThat(chunks2.size()).isEqualTo(2); + + // Doc 2 + assertThat(chunks2.get(0).getText()).isEqualTo( + "The most oppressive thing about the labyrinth is that you are constantly being forced to choose."); + assertThat(chunks2.get(1).getText()).isEqualTo( + "It isn’t the lack of an exit, but the abundance of exits that is so disorienting."); + } + }