Skip to content

Commit 2f9ed12

Browse files
committed
✨ Stateful conversation processing
1 parent 3f5b6a3 commit 2f9ed12

File tree

6 files changed

+214
-0
lines changed

6 files changed

+214
-0
lines changed
Lines changed: 81 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,81 @@
1+
package io.graversen.replicate.facade;
2+
3+
import io.graversen.replicate.common.TextConversation;
4+
import io.graversen.replicate.common.TextMessage;
5+
import io.graversen.replicate.llama3.Llama3Tokenizer;
6+
import io.graversen.replicate.service.Conversation;
7+
import io.graversen.replicate.service.ConversationService;
8+
import io.graversen.replicate.service.CreateConversation;
9+
import io.graversen.replicate.service.CreateTextPrediction;
10+
import lombok.NonNull;
11+
import lombok.RequiredArgsConstructor;
12+
import org.springframework.stereotype.Component;
13+
14+
import java.util.concurrent.CompletableFuture;
15+
import java.util.concurrent.ConcurrentHashMap;
16+
import java.util.concurrent.ConcurrentMap;
17+
import java.util.function.Function;
18+
19+
@Component
20+
@RequiredArgsConstructor
21+
public class ConversationFacade {
22+
private final ConcurrentMap<String, ConversationOptions> conversationOptions = new ConcurrentHashMap<>();
23+
24+
private final @NonNull ReplicateFacade replicateFacade;
25+
private final @NonNull ConversationService conversationService;
26+
27+
public Conversation create(@NonNull CreateConversation createConversation, @NonNull ConversationOptions options) {
28+
final var conversation = conversationService.create(createConversation);
29+
conversationOptions.put(conversation.getId(), options);
30+
return conversation;
31+
}
32+
33+
public CompletableFuture<Conversation> chat(@NonNull String id, @NonNull TextMessage message) {
34+
var conversation = conversationService.appendMessage(id, message);
35+
36+
if (conversation.isEmpty()) {
37+
return CompletableFuture.failedFuture(new IllegalArgumentException("Could not find Conversation: " + id));
38+
}
39+
40+
final var textConversation = conversation.get().getConversation();
41+
final var model = conversation.get().getModel();
42+
final var options = conversationOptions.get(conversation.get().getId());
43+
44+
final var createTextPrediction = new CreateTextPrediction(
45+
textConversation,
46+
options.getTemperature(),
47+
null,
48+
null,
49+
null,
50+
null,
51+
null
52+
);
53+
54+
final var pendingPrediction = replicateFacade.createPrediction(model, createTextPrediction);
55+
return pendingPrediction
56+
.thenApply(parseTextConversation())
57+
.thenApply(updateConversation(conversation.get().getId()));
58+
}
59+
60+
Function<TextConversation, Conversation> updateConversation(@NonNull String id) {
61+
return conversation -> conversationService.update(id, conversation)
62+
.orElseThrow(() -> new IllegalArgumentException("Could not find Conversation: " + id));
63+
}
64+
65+
Function<PredictionResponseAndModel, TextConversation> parseTextConversation() {
66+
return predictionResponseAndModel -> {
67+
// For now, assume Llama3 is used and apply its tokenization parsing
68+
// In the future, rely on a strategy pattern implementation selector for multi-model-family support
69+
70+
final var textInput = predictionResponseAndModel.getPredictionResponse().getInputKey("prompt")
71+
.orElseThrow(() -> new IllegalStateException("No text input found in prediction"));
72+
73+
var conversation = Llama3Tokenizer.parseTextCompletion(textInput);
74+
75+
final var textOutput = predictionResponseAndModel.getPredictionResponse().getTextOutput()
76+
.orElseThrow(() -> new IllegalStateException("No text output found in prediction"));
77+
78+
return conversation.append(TextMessage.assistant(textOutput));
79+
};
80+
}
81+
}
Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,11 @@
1+
package io.graversen.replicate.facade;
2+
3+
import lombok.Getter;
4+
import lombok.NonNull;
5+
import lombok.RequiredArgsConstructor;
6+
7+
@Getter
8+
@RequiredArgsConstructor
9+
public class ConversationOptions {
10+
private final @NonNull Double temperature;
11+
}
Lines changed: 51 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,51 @@
1+
package io.graversen.replicate.service;
2+
3+
import io.graversen.replicate.common.ReplicateModel;
4+
import io.graversen.replicate.common.TextConversation;
5+
import io.graversen.replicate.common.TextMessage;
6+
import lombok.AccessLevel;
7+
import lombok.NonNull;
8+
import lombok.RequiredArgsConstructor;
9+
import lombok.Value;
10+
11+
import java.util.UUID;
12+
13+
@Value
14+
@RequiredArgsConstructor(access = AccessLevel.PACKAGE)
15+
public class Conversation {
16+
@NonNull String id;
17+
@NonNull String userId;
18+
@NonNull ReplicateModel model;
19+
@NonNull TextConversation conversation;
20+
21+
public static Conversation createDefault(@NonNull String systemMessage, @NonNull ReplicateModel model) {
22+
return new Conversation(
23+
createId(),
24+
"default",
25+
model,
26+
TextConversation.of(systemMessage)
27+
);
28+
}
29+
30+
public Conversation appendMessage(@NonNull TextMessage message) {
31+
return new Conversation(
32+
getId(),
33+
getUserId(),
34+
getModel(),
35+
getConversation().append(message)
36+
);
37+
}
38+
39+
public Conversation update(@NonNull TextConversation conversation) {
40+
return new Conversation(
41+
getId(),
42+
getUserId(),
43+
getModel(),
44+
conversation
45+
);
46+
}
47+
48+
private static String createId() {
49+
return String.format("c_%s", UUID.randomUUID().toString());
50+
}
51+
}
Lines changed: 48 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,48 @@
1+
package io.graversen.replicate.service;
2+
3+
import io.graversen.replicate.common.TextConversation;
4+
import io.graversen.replicate.common.TextMessage;
5+
import lombok.NonNull;
6+
import lombok.RequiredArgsConstructor;
7+
import org.springframework.stereotype.Service;
8+
9+
import java.util.List;
10+
import java.util.Optional;
11+
import java.util.concurrent.ConcurrentHashMap;
12+
import java.util.concurrent.ConcurrentMap;
13+
14+
@Service
15+
@RequiredArgsConstructor
16+
public class ConversationService {
17+
private final ConcurrentMap<String, Conversation> conversations = new ConcurrentHashMap<>();
18+
19+
public Conversation create(@NonNull CreateConversation createConversation) {
20+
final var conversation = Conversation.createDefault(createConversation.getSystemMessage(), createConversation.getReplicateModel());
21+
conversations.put(conversation.getId(), conversation);
22+
return conversation;
23+
}
24+
25+
public Optional<Conversation> appendMessage(@NonNull String id, @NonNull TextMessage message) {
26+
final var conversationOrNull = conversations.computeIfPresent(id, (key, conversation) -> conversation.appendMessage(message));
27+
return Optional.ofNullable(conversationOrNull);
28+
}
29+
30+
public Optional<Conversation> update(@NonNull String id, @NonNull TextConversation conversation) {
31+
final var conversationOrNull = conversations.computeIfPresent(id, (key, value) -> value.update(conversation));
32+
return Optional.ofNullable(conversationOrNull);
33+
}
34+
35+
public Optional<Conversation> getById(@NonNull String id) {
36+
return Optional.ofNullable(conversations.get(id));
37+
}
38+
39+
public List<Conversation> getByUser(@NonNull String userId) {
40+
return conversations.values().stream()
41+
.filter(conversation -> conversation.getUserId().equals(userId))
42+
.toList();
43+
}
44+
45+
public List<Conversation> getAll() {
46+
return List.copyOf(conversations.values());
47+
}
48+
}
Lines changed: 13 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,13 @@
1+
package io.graversen.replicate.service;
2+
3+
import io.graversen.replicate.common.ReplicateModel;
4+
import lombok.Getter;
5+
import lombok.NonNull;
6+
import lombok.RequiredArgsConstructor;
7+
8+
@Getter
9+
@RequiredArgsConstructor
10+
public class CreateConversation {
11+
private final @NonNull String systemMessage;
12+
private final @NonNull ReplicateModel replicateModel;
13+
}

spring-boot-starter-replicate/src/main/java/io/graversen/replicate/service/PredictionResponse.java

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,7 @@
88
import java.time.Duration;
99
import java.time.OffsetDateTime;
1010
import java.util.ArrayList;
11+
import java.util.LinkedHashMap;
1112
import java.util.List;
1213
import java.util.Optional;
1314
import java.util.function.Function;
@@ -73,6 +74,15 @@ public Optional<String> getTextOutput() {
7374
}
7475
}
7576

77+
public Optional<String> getInputKey(@NonNull String key) {
78+
try {
79+
final var input = (LinkedHashMap<String, Object>) getInput();
80+
return Optional.of(((String) input.get(key)));
81+
} catch (Exception e) {
82+
return Optional.empty();
83+
}
84+
}
85+
7686
private Function<List<String>, String> composeTextResponse() {
7787
return strings -> strings.stream().filter(string -> !string.isBlank()).collect(Collectors.joining());
7888
}

0 commit comments

Comments
 (0)