diff --git a/spring-boot-starter-replicate/src/main/java/io/graversen/replicate/common/TextConversation.java b/spring-boot-starter-replicate/src/main/java/io/graversen/replicate/common/TextConversation.java index f85bf7a..b3d7c9c 100644 --- a/spring-boot-starter-replicate/src/main/java/io/graversen/replicate/common/TextConversation.java +++ b/spring-boot-starter-replicate/src/main/java/io/graversen/replicate/common/TextConversation.java @@ -5,6 +5,7 @@ import java.util.ArrayList; import java.util.List; +import java.util.Optional; @Value public class TextConversation { @@ -46,4 +47,8 @@ public List getFirstMessages(@NonNull Integer conversationSize) { .limit(conversationSize) .toList(); } + + public Optional getLastMessage() { + return messages.isEmpty() ? Optional.empty() : Optional.of(messages.getLast()); + } } diff --git a/spring-boot-starter-replicate/src/main/java/io/graversen/replicate/facade/ReplicateFacade.java b/spring-boot-starter-replicate/src/main/java/io/graversen/replicate/facade/ReplicateFacade.java index a563f04..049c549 100644 --- a/spring-boot-starter-replicate/src/main/java/io/graversen/replicate/facade/ReplicateFacade.java +++ b/spring-boot-starter-replicate/src/main/java/io/graversen/replicate/facade/ReplicateFacade.java @@ -30,9 +30,9 @@ public CompletableFuture createPrediction( ) { return CompletableFuture .supplyAsync(doCreatePrediction(model, createPrediction), executorService) - .thenApplyAsync(checkAndEmitPredictionCreationTask, executorService) - .thenApplyAsync(pollPredictionStatusTask, executorService) - .whenCompleteAsync(emitPredictionResponseTask, executorService); + .thenApply(checkAndEmitPredictionCreationTask) + .thenApply(pollPredictionStatusTask) + .whenComplete(emitPredictionResponseTask); } public CompletableFuture createPrediction( @@ -41,9 +41,9 @@ public CompletableFuture createPrediction( ) { return CompletableFuture .supplyAsync(doCreatePrediction(model, createPrediction), executorService) - .thenApplyAsync(checkAndEmitPredictionCreationTask, executorService) - .thenApplyAsync(pollPredictionStatusTask, executorService) - .whenCompleteAsync(emitPredictionResponseTask, executorService); + .thenApply(checkAndEmitPredictionCreationTask) + .thenApply(pollPredictionStatusTask) + .whenComplete(emitPredictionResponseTask); } Supplier> doCreatePrediction( diff --git a/spring-boot-starter-replicate/src/main/java/io/graversen/replicate/service/Conversation.java b/spring-boot-starter-replicate/src/main/java/io/graversen/replicate/service/Conversation.java index 1409ba2..4085333 100644 --- a/spring-boot-starter-replicate/src/main/java/io/graversen/replicate/service/Conversation.java +++ b/spring-boot-starter-replicate/src/main/java/io/graversen/replicate/service/Conversation.java @@ -3,10 +3,8 @@ import io.graversen.replicate.common.ReplicateModel; import io.graversen.replicate.common.TextConversation; import io.graversen.replicate.common.TextMessage; -import lombok.AccessLevel; -import lombok.NonNull; -import lombok.RequiredArgsConstructor; -import lombok.Value; +import io.graversen.replicate.common.TextPredictionRoles; +import lombok.*; import java.util.UUID; @@ -45,7 +43,20 @@ public Conversation update(@NonNull TextConversation conversation) { ); } + @ToString.Include + public ConversationStates getState() { + if (conversation.getMessages().isEmpty()) { + return ConversationStates.IDLE; + } + + final var isLastMessageFromUser = conversation.getLastMessage() + .map(message -> message.getRole().equals(TextPredictionRoles.USER.asString())) + .orElse(false); + + return isLastMessageFromUser ? ConversationStates.WAITING_FOR_ASSISTANT : ConversationStates.WAITING_FOR_USER; + } + private static String createId() { - return String.format("c_%s", UUID.randomUUID().toString()); + return String.format("c_%s", UUID.randomUUID()); } } diff --git a/spring-boot-starter-replicate/src/main/java/io/graversen/replicate/service/ConversationStates.java b/spring-boot-starter-replicate/src/main/java/io/graversen/replicate/service/ConversationStates.java new file mode 100644 index 0000000..c1a0549 --- /dev/null +++ b/spring-boot-starter-replicate/src/main/java/io/graversen/replicate/service/ConversationStates.java @@ -0,0 +1,7 @@ +package io.graversen.replicate.service; + +public enum ConversationStates { + IDLE, + WAITING_FOR_ASSISTANT, + WAITING_FOR_USER +}