Skip to content

Commit 3f5b6a3

Browse files
committed
⚡ Generate TextConversation from tokenized conversation
1 parent 7a49567 commit 3f5b6a3

File tree

3 files changed

+87
-0
lines changed

3 files changed

+87
-0
lines changed

spring-boot-starter-replicate/src/main/java/io/graversen/replicate/common/TextConversation.java

Lines changed: 18 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -3,20 +3,38 @@
33
import lombok.NonNull;
44
import lombok.Value;
55

6+
import java.util.ArrayList;
67
import java.util.List;
78

89
@Value
910
public class TextConversation {
1011
@NonNull String systemMessage;
1112
@NonNull List<TextMessage> messages;
1213

14+
public static TextConversation of(@NonNull String systemMessage) {
15+
return new TextConversation(
16+
systemMessage,
17+
List.of()
18+
);
19+
}
20+
1321
public static TextConversation of(@NonNull String systemMessage, @NonNull String userMessage) {
1422
return new TextConversation(
1523
systemMessage,
1624
List.of(TextMessage.user(userMessage))
1725
);
1826
}
1927

28+
public TextConversation append(@NonNull TextMessage message) {
29+
final var mutableMessages = new ArrayList<>(getMessages());
30+
mutableMessages.add(message);
31+
32+
return new TextConversation(
33+
getSystemMessage(),
34+
List.copyOf(mutableMessages)
35+
);
36+
}
37+
2038
public List<TextMessage> getLastMessages(@NonNull Integer conversationSize) {
2139
return messages.stream()
2240
.skip(Math.max(0, messages.size() - conversationSize))

spring-boot-starter-replicate/src/main/java/io/graversen/replicate/llama3/Llama3Tokenizer.java

Lines changed: 37 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -6,9 +6,11 @@
66
import lombok.NonNull;
77
import lombok.experimental.UtilityClass;
88

9+
import java.util.ArrayList;
910
import java.util.LinkedList;
1011
import java.util.Objects;
1112
import java.util.function.Consumer;
13+
import java.util.regex.Pattern;
1214

1315
@UtilityClass
1416
public class Llama3Tokenizer {
@@ -60,6 +62,41 @@ public static Llama3TextCompletion generateTextCompletion(@NonNull TextConversat
6062
return new Llama3TextCompletion(textCompletion);
6163
}
6264

65+
public static TextConversation parseTextCompletion(@NonNull String tokenizedConversation) {
66+
String systemMessage = null;
67+
final var messages = new ArrayList<TextMessage>();
68+
69+
final String[] tokens = tokenizedConversation.split(Pattern.quote(END_OF_TEXT_ID));
70+
71+
for (String token : tokens) {
72+
token = token.trim();
73+
74+
if (token.isEmpty()) {
75+
continue;
76+
}
77+
78+
if (token.startsWith(BEGIN_OF_TEXT)) {
79+
String systemToken = token.substring(BEGIN_OF_TEXT.length()).trim();
80+
if (systemToken.startsWith(START_HEADER_ID + ROLE_SYSTEM)) {
81+
systemMessage = systemToken.substring(systemToken.indexOf(END_HEADER_ID) + END_HEADER_ID.length()).trim();
82+
}
83+
} else if (token.startsWith(START_HEADER_ID)) {
84+
String role = token.substring(START_HEADER_ID.length(), token.indexOf(END_HEADER_ID)).trim();
85+
String messageContent = token.substring(token.indexOf(END_HEADER_ID) + END_HEADER_ID.length()).trim();
86+
87+
if (!role.isEmpty() && !messageContent.isEmpty()) {
88+
messages.add(new TextMessage(role, messageContent));
89+
}
90+
}
91+
}
92+
93+
if (systemMessage == null) {
94+
throw new IllegalArgumentException("System message not found in the tokenized conversation.");
95+
}
96+
97+
return new TextConversation(systemMessage, messages);
98+
}
99+
63100
public static Integer approximateConversationContextSize(@NonNull TextConversation conversation, @Nullable Integer tokenSize) {
64101
final var conversationTextCompletion = generateTextCompletion(conversation);
65102
return getTokens(conversationTextCompletion.getText(), tokenSize);

spring-boot-starter-replicate/src/test/java/io/graversen/replicate/llama3/Llama3TokenizerTest.java

Lines changed: 32 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -19,6 +19,38 @@ class Llama3TokenizerTest {
1919
new TextMessage("assistant", "Why don't scientists trust atoms? Because they make up everything!")
2020
);
2121

22+
@Test
23+
void parseTextCompletion_exampleConversation() {
24+
final var tokenizedConversation =
25+
"<|begin_of_text|><|start_header_id|>system<|end_header_id|>\n" +
26+
"You are a helpful assistant.<|eot_id|><|start_header_id|>user<|end_header_id|>\n" +
27+
"Hello. How are you?<|eot_id|><|start_header_id|>assistant<|end_header_id|>\n" +
28+
"Hello! I'm doing great, thanks for asking! I'm here to help you with anything you need, so please feel free to ask me any questions or share what's on your mind. How about you? How's your day going so far?<|eot_id|><|start_header_id|>user<|end_header_id|>\n" +
29+
"It is going good thanks<|eot_id|><|start_header_id|>assistant<|end_header_id|>\n" +
30+
"That's wonderful to hear! I'm glad to know that your day is going well. If you don't mind me asking, what's been the highlight of your day so far? Is there anything exciting or interesting that's happened? I'm all ears and happy to listen!<|eot_id|><|start_header_id|>user<|end_header_id|>\n" +
31+
"Nah fam<|eot_id|>";
32+
33+
final var conversation = Llama3Tokenizer.parseTextCompletion(tokenizedConversation);
34+
35+
assertEquals("You are a helpful assistant.", conversation.getSystemMessage());
36+
assertEquals(5, conversation.getMessages().size());
37+
38+
assertEquals("user", conversation.getMessages().get(0).getRole());
39+
assertEquals("Hello. How are you?", conversation.getMessages().get(0).getText());
40+
41+
assertEquals("assistant", conversation.getMessages().get(1).getRole());
42+
assertEquals("Hello! I'm doing great, thanks for asking! I'm here to help you with anything you need, so please feel free to ask me any questions or share what's on your mind. How about you? How's your day going so far?", conversation.getMessages().get(1).getText());
43+
44+
assertEquals("user", conversation.getMessages().get(2).getRole());
45+
assertEquals("It is going good thanks", conversation.getMessages().get(2).getText());
46+
47+
assertEquals("assistant", conversation.getMessages().get(3).getRole());
48+
assertEquals("That's wonderful to hear! I'm glad to know that your day is going well. If you don't mind me asking, what's been the highlight of your day so far? Is there anything exciting or interesting that's happened? I'm all ears and happy to listen!", conversation.getMessages().get(3).getText());
49+
50+
assertEquals("user", conversation.getMessages().get(4).getRole());
51+
assertEquals("Nah fam", conversation.getMessages().get(4).getText());
52+
}
53+
2254
@Test
2355
public void fitToContextWindow_defaultWindow() {
2456
final var conversation = new TextConversation(systemMessage, messages);

0 commit comments

Comments
 (0)