From f870d57fc15a1005805db31ab5cd24f085506718 Mon Sep 17 00:00:00 2001 From: Francesco Guardiani Date: Thu, 8 Feb 2024 14:49:28 +0100 Subject: [PATCH] Stop assuming input and output of invocations implement MessageLite (#214) --- .../restate/sdk/common/syscalls/Result.java | 24 +++++++++ .../restate/sdk/common/syscalls/Syscalls.java | 7 +-- .../java/dev/restate/sdk/core/Entries.java | 18 +++---- .../sdk/core/ExecutorSwitchingSyscalls.java | 9 ++-- .../restate/sdk/core/GrpcUnaryRpcHandler.java | 51 +++++++++++-------- .../dev/restate/sdk/core/RestateEndpoint.java | 5 +- .../dev/restate/sdk/core/SyscallsImpl.java | 13 ++--- .../restate/sdk/core/SyscallsInternal.java | 7 ++- 8 files changed, 75 insertions(+), 59 deletions(-) diff --git a/sdk-common/src/main/java/dev/restate/sdk/common/syscalls/Result.java b/sdk-common/src/main/java/dev/restate/sdk/common/syscalls/Result.java index 9b6353d1a..a98fcb916 100644 --- a/sdk-common/src/main/java/dev/restate/sdk/common/syscalls/Result.java +++ b/sdk-common/src/main/java/dev/restate/sdk/common/syscalls/Result.java @@ -9,6 +9,7 @@ package dev.restate.sdk.common.syscalls; import dev.restate.sdk.common.TerminalException; +import java.util.function.Function; import javax.annotation.Nullable; /** @@ -47,6 +48,29 @@ private Result() {} @Nullable public abstract TerminalException getFailure(); + // --- Helper methods + + /** + * Map this result success value. If the mapper throws an exception, this exception will be + * converted to {@link TerminalException} and return a new failed {@link Result}. + */ + public Result mapSuccess(Function mapper) { + if (this.isSuccess()) { + try { + return Result.success(mapper.apply(this.getValue())); + } catch (TerminalException e) { + return Result.failure(e); + } catch (Exception e) { + return Result.failure( + new TerminalException(TerminalException.Code.UNKNOWN, e.getMessage())); + } + } + //noinspection unchecked + return (Result) this; + } + + // --- Factory methods + @SuppressWarnings("unchecked") public static Result empty() { return (Result) Empty.INSTANCE; diff --git a/sdk-common/src/main/java/dev/restate/sdk/common/syscalls/Syscalls.java b/sdk-common/src/main/java/dev/restate/sdk/common/syscalls/Syscalls.java index 8e20f7e86..f90841a5e 100644 --- a/sdk-common/src/main/java/dev/restate/sdk/common/syscalls/Syscalls.java +++ b/sdk-common/src/main/java/dev/restate/sdk/common/syscalls/Syscalls.java @@ -9,7 +9,6 @@ package dev.restate.sdk.common.syscalls; import com.google.protobuf.ByteString; -import com.google.protobuf.MessageLite; import dev.restate.sdk.common.InvocationId; import dev.restate.sdk.common.TerminalException; import io.grpc.Context; @@ -18,7 +17,6 @@ import java.util.List; import java.util.Map; import java.util.Objects; -import java.util.function.Function; import javax.annotation.Nullable; /** @@ -53,10 +51,9 @@ static Syscalls current() { // Note: These are not supposed to be exposed to RestateContext, but they should be used through // gRPC APIs. - void pollInput( - Function mapper, SyscallCallback> callback); + void pollInput(SyscallCallback> callback); - void writeOutput(T value, SyscallCallback callback); + void writeOutput(ByteString value, SyscallCallback callback); void writeOutput(TerminalException exception, SyscallCallback callback); diff --git a/sdk-core/src/main/java/dev/restate/sdk/core/Entries.java b/sdk-core/src/main/java/dev/restate/sdk/core/Entries.java index e05d54661..6ea38ff8f 100644 --- a/sdk-core/src/main/java/dev/restate/sdk/core/Entries.java +++ b/sdk-core/src/main/java/dev/restate/sdk/core/Entries.java @@ -49,14 +49,12 @@ void updateUserStateStorageWithCompletion( E expected, CompletionMessage actual, UserStateStore userStateStore) {} } - static final class PollInputEntry - extends CompletableJournalEntry { + static final class PollInputEntry + extends CompletableJournalEntry { - private final Function> valueParser; + static final PollInputEntry INSTANCE = new PollInputEntry(); - PollInputEntry(Function> valueParser) { - this.valueParser = valueParser; - } + private PollInputEntry() {} @Override public void trace(PollInputStreamEntryMessage expected, Span span) { @@ -69,9 +67,9 @@ public boolean hasResult(PollInputStreamEntryMessage actual) { } @Override - public Result parseEntryResult(PollInputStreamEntryMessage actual) { + public Result parseEntryResult(PollInputStreamEntryMessage actual) { if (actual.getResultCase() == PollInputStreamEntryMessage.ResultCase.VALUE) { - return valueParser.apply(actual.getValue()); + return Result.success(actual.getValue()); } else if (actual.getResultCase() == PollInputStreamEntryMessage.ResultCase.FAILURE) { return Result.failure(Util.toRestateException(actual.getFailure())); } else { @@ -80,9 +78,9 @@ public Result parseEntryResult(PollInputStreamEntryMessage actual) { } @Override - public Result parseCompletionResult(CompletionMessage actual) { + public Result parseCompletionResult(CompletionMessage actual) { if (actual.getResultCase() == CompletionMessage.ResultCase.VALUE) { - return valueParser.apply(actual.getValue()); + return Result.success(actual.getValue()); } else if (actual.getResultCase() == CompletionMessage.ResultCase.FAILURE) { return Result.failure(Util.toRestateException(actual.getFailure())); } diff --git a/sdk-core/src/main/java/dev/restate/sdk/core/ExecutorSwitchingSyscalls.java b/sdk-core/src/main/java/dev/restate/sdk/core/ExecutorSwitchingSyscalls.java index 1ddaabd2e..3e339b722 100644 --- a/sdk-core/src/main/java/dev/restate/sdk/core/ExecutorSwitchingSyscalls.java +++ b/sdk-core/src/main/java/dev/restate/sdk/core/ExecutorSwitchingSyscalls.java @@ -9,7 +9,6 @@ package dev.restate.sdk.core; import com.google.protobuf.ByteString; -import com.google.protobuf.MessageLite; import dev.restate.sdk.common.InvocationId; import dev.restate.sdk.common.TerminalException; import dev.restate.sdk.common.syscalls.Deferred; @@ -20,7 +19,6 @@ import java.time.Duration; import java.util.Map; import java.util.concurrent.Executor; -import java.util.function.Function; class ExecutorSwitchingSyscalls implements SyscallsInternal { @@ -33,13 +31,12 @@ class ExecutorSwitchingSyscalls implements SyscallsInternal { } @Override - public void pollInput( - Function mapper, SyscallCallback> callback) { - syscallsExecutor.execute(() -> syscalls.pollInput(mapper, callback)); + public void pollInput(SyscallCallback> callback) { + syscallsExecutor.execute(() -> syscalls.pollInput(callback)); } @Override - public void writeOutput(T value, SyscallCallback callback) { + public void writeOutput(ByteString value, SyscallCallback callback) { syscallsExecutor.execute(() -> syscalls.writeOutput(value, callback)); } diff --git a/sdk-core/src/main/java/dev/restate/sdk/core/GrpcUnaryRpcHandler.java b/sdk-core/src/main/java/dev/restate/sdk/core/GrpcUnaryRpcHandler.java index ff3c7853a..b0fa5572d 100644 --- a/sdk-core/src/main/java/dev/restate/sdk/core/GrpcUnaryRpcHandler.java +++ b/sdk-core/src/main/java/dev/restate/sdk/core/GrpcUnaryRpcHandler.java @@ -8,36 +8,37 @@ // https://github.com/restatedev/sdk-java/blob/main/LICENSE package dev.restate.sdk.core; -import com.google.protobuf.MessageLite; +import com.google.protobuf.ByteString; import dev.restate.sdk.common.InvocationId; import dev.restate.sdk.common.TerminalException; import dev.restate.sdk.common.syscalls.SyscallCallback; import dev.restate.sdk.common.syscalls.Syscalls; import io.grpc.*; +import java.io.IOException; import java.util.concurrent.CompletableFuture; import java.util.concurrent.Executor; import javax.annotation.Nullable; import org.apache.logging.log4j.LogManager; import org.apache.logging.log4j.Logger; -class GrpcUnaryRpcHandler implements RpcHandler { +class GrpcUnaryRpcHandler implements RpcHandler { private static final Logger LOG = LogManager.getLogger(GrpcUnaryRpcHandler.class); private final SyscallsInternal syscalls; - private final RestateServerCallListener restateListener; + private final RestateServerCallListener restateListener; private final CompletableFuture serverCallReady; - private final MethodDescriptor methodDescriptor; + private final MethodDescriptor methodDescriptor; GrpcUnaryRpcHandler( - ServerMethodDefinition method, + ServerMethodDefinition method, SyscallsInternal syscalls, @Nullable Executor userCodeExecutor) { this.syscalls = syscalls; this.methodDescriptor = method.getMethodDescriptor(); this.serverCallReady = new CompletableFuture<>(); - RestateServerCall serverCall = - new RestateServerCall(this.methodDescriptor, this.syscalls, this.serverCallReady); + RestateServerCall serverCall = + new RestateServerCall<>(method.getMethodDescriptor(), this.syscalls, this.serverCallReady); // This gRPC context will be propagated to the user thread. // Note: from now on we cannot modify this context anymore! @@ -47,13 +48,13 @@ class GrpcUnaryRpcHandler implements RpcHandler { .withValue(Syscalls.SYSCALLS_KEY, this.syscalls); // Create the listener - RestateServerCallListener listener = + RestateServerCallListener listener = new GrpcServerCallListenerAdaptor<>( context, serverCall, new Metadata(), method.getServerCallHandler()); // Wrap in the executor switcher, if needed if (userCodeExecutor != null) { - listener = new ExecutorSwitchingServerCallListener(listener, userCodeExecutor); + listener = new ExecutorSwitchingServerCallListener<>(listener, userCodeExecutor); } this.restateListener = listener; @@ -69,7 +70,7 @@ public void start() { SyscallCallback.of( pollInputReadyResult -> { if (pollInputReadyResult.isSuccess()) { - final MessageLite message = pollInputReadyResult.getValue(); + final Req message = pollInputReadyResult.getValue(); LOG.trace("Read input message:\n{}", message); // In theory, we never need this, as once we reach this point of the code the server @@ -198,20 +199,20 @@ private void closeWithException(Throwable e) { } } - private static class ExecutorSwitchingServerCallListener - implements RestateServerCallListener { + private static class ExecutorSwitchingServerCallListener + implements RestateServerCallListener { - private final RestateServerCallListener listener; + private final RestateServerCallListener listener; private final Executor userExecutor; private ExecutorSwitchingServerCallListener( - RestateServerCallListener listener, Executor userExecutor) { + RestateServerCallListener listener, Executor userExecutor) { this.listener = listener; this.userExecutor = userExecutor; } @Override - public void invoke(MessageLite message) { + public void invoke(Req message) { userExecutor.execute(() -> listener.invoke(message)); } @@ -254,9 +255,9 @@ public void ready() { *
  • Trampolining back to state machine executor is provided by the syscalls wrapper. * */ - static class RestateServerCall extends ServerCall { + static class RestateServerCall extends ServerCall { - private final MethodDescriptor methodDescriptor; + private final MethodDescriptor methodDescriptor; private final SyscallsInternal syscalls; // This variable don't need to be volatile because it's accessed only by #request() @@ -264,7 +265,7 @@ static class RestateServerCall extends ServerCall { private final CompletableFuture serverCallReady; RestateServerCall( - MethodDescriptor methodDescriptor, + MethodDescriptor methodDescriptor, SyscallsInternal syscalls, CompletableFuture serverCallReady) { this.methodDescriptor = methodDescriptor; @@ -308,9 +309,17 @@ public void sendHeaders(Metadata headers) { } @Override - public void sendMessage(MessageLite message) { + public void sendMessage(Res message) { + ByteString output; + try { + output = ByteString.readFrom(methodDescriptor.streamResponse(message)); + } catch (IOException e) { + syscalls.fail(e); + return; + } + syscalls.writeOutput( - message, + output, SyscallCallback.ofVoid( () -> LOG.trace("Wrote output message:\n{}", message), syscalls::fail)); } @@ -346,7 +355,7 @@ public boolean isCancelled() { } @Override - public MethodDescriptor getMethodDescriptor() { + public MethodDescriptor getMethodDescriptor() { return methodDescriptor; } } diff --git a/sdk-core/src/main/java/dev/restate/sdk/core/RestateEndpoint.java b/sdk-core/src/main/java/dev/restate/sdk/core/RestateEndpoint.java index 83893b6be..289419413 100644 --- a/sdk-core/src/main/java/dev/restate/sdk/core/RestateEndpoint.java +++ b/sdk-core/src/main/java/dev/restate/sdk/core/RestateEndpoint.java @@ -8,7 +8,6 @@ // https://github.com/restatedev/sdk-java/blob/main/LICENSE package dev.restate.sdk.core; -import com.google.protobuf.MessageLite; import dev.restate.generated.service.discovery.Discovery; import dev.restate.sdk.common.ServiceAdapter; import dev.restate.sdk.common.ServicesBundle; @@ -64,9 +63,7 @@ public InvocationHandler resolve( throw ProtocolException.methodNotFound(serviceName, methodName); } String fullyQualifiedServiceMethod = serviceName + "/" + methodName; - ServerMethodDefinition method = - (ServerMethodDefinition) - svc.getMethod(fullyQualifiedServiceMethod); + ServerMethodDefinition method = svc.getMethod(fullyQualifiedServiceMethod); if (method == null) { throw ProtocolException.methodNotFound(serviceName, methodName); } diff --git a/sdk-core/src/main/java/dev/restate/sdk/core/SyscallsImpl.java b/sdk-core/src/main/java/dev/restate/sdk/core/SyscallsImpl.java index cd1aec0d0..bb18ad658 100644 --- a/sdk-core/src/main/java/dev/restate/sdk/core/SyscallsImpl.java +++ b/sdk-core/src/main/java/dev/restate/sdk/core/SyscallsImpl.java @@ -9,7 +9,6 @@ package dev.restate.sdk.core; import com.google.protobuf.ByteString; -import com.google.protobuf.MessageLite; import com.google.rpc.Code; import dev.restate.generated.sdk.java.Java; import dev.restate.generated.service.protocol.Protocol; @@ -48,27 +47,23 @@ public InvocationId invocationId() { } @Override - public void pollInput( - Function mapper, SyscallCallback> callback) { + public void pollInput(SyscallCallback> callback) { wrapAndPropagateExceptions( () -> { LOG.trace("pollInput"); this.stateMachine.processCompletableJournalEntry( - PollInputStreamEntryMessage.getDefaultInstance(), - new PollInputEntry<>(protoDeserializer(mapper)), - callback); + PollInputStreamEntryMessage.getDefaultInstance(), PollInputEntry.INSTANCE, callback); }, callback); } @Override - public void writeOutput(T value, SyscallCallback callback) { + public void writeOutput(ByteString value, SyscallCallback callback) { wrapAndPropagateExceptions( () -> { LOG.trace("writeOutput success"); this.writeOutput( - Protocol.OutputStreamEntryMessage.newBuilder().setValue(value.toByteString()).build(), - callback); + Protocol.OutputStreamEntryMessage.newBuilder().setValue(value).build(), callback); }, callback); } diff --git a/sdk-core/src/main/java/dev/restate/sdk/core/SyscallsInternal.java b/sdk-core/src/main/java/dev/restate/sdk/core/SyscallsInternal.java index 05d004cee..f25115097 100644 --- a/sdk-core/src/main/java/dev/restate/sdk/core/SyscallsInternal.java +++ b/sdk-core/src/main/java/dev/restate/sdk/core/SyscallsInternal.java @@ -9,7 +9,6 @@ package dev.restate.sdk.core; import com.google.protobuf.ByteString; -import com.google.protobuf.MessageLite; import dev.restate.sdk.common.syscalls.Deferred; import dev.restate.sdk.common.syscalls.Result; import dev.restate.sdk.common.syscalls.SyscallCallback; @@ -35,16 +34,16 @@ default Deferred createAllDeferred(List> children) { // -- Helper for pollInput - default void pollInputAndResolve( + default void pollInputAndResolve( Function mapper, SyscallCallback> callback) { this.pollInput( - mapper, SyscallCallback.of( deferredValue -> this.resolveDeferred( deferredValue, SyscallCallback.ofVoid( - () -> callback.onSuccess(deferredValue.toResult()), callback::onCancel)), + () -> callback.onSuccess(deferredValue.toResult().mapSuccess(mapper)), + callback::onCancel)), callback::onCancel)); }