From b2d57d95f56e7b021a1d9d68dbdbb48e9da9f881 Mon Sep 17 00:00:00 2001 From: Francesco Guardiani Date: Thu, 20 Feb 2025 12:24:52 +0100 Subject: [PATCH] Service protocol v4 (#447) --- buf.lock | 2 - buf.yaml | 8 - build.gradle.kts | 4 +- .../library-publishing-conventions.gradle.kts | 61 +- .../kotlin/test-jar-conventions.gradle.kts | 9 - client-kotlin/build.gradle.kts | 11 + .../dev/restate/client/kotlin/ingress.kt | 99 ++ client/build.gradle.kts | 20 + .../main/java/dev/restate/client/Client.java | 354 +++++++ .../restate/client/ClientRequestOptions.java | 100 ++ .../dev/restate/client/ClientResponse.java | 26 + .../dev/restate}/client/IngressException.java | 42 +- .../java/dev/restate/client/SendResponse.java | 18 + .../dev/restate/client/base/BaseClient.java | 535 +++++++++++ .../dev/restate/client/jdk/JdkClient.java | 129 +++ .../build.gradle.kts | 9 +- .../main/java/dev/restate}/common/Output.java | 2 +- .../main/java/dev/restate/common/Request.java | 243 +++++ .../java/dev/restate/common/SendRequest.java | 81 ++ .../main/java/dev/restate/common/Slice.java | 84 ++ .../main/java/dev/restate}/common/Target.java | 2 +- .../common/function/ThrowingBiConsumer.java | 2 +- .../common/function/ThrowingBiFunction.java | 2 +- .../common/function/ThrowingConsumer.java | 2 +- .../common/function/ThrowingFunction.java | 6 +- .../common/function/ThrowingRunnable.java | 2 +- .../common/function/ThrowingSupplier.java | 2 +- .../main/java/dev/restate/serde}/Serde.java | 127 +-- .../java/dev/restate/serde/SerdeFactory.java | 41 + .../main/java/dev/restate/serde/TypeRef.java | 36 + .../main/java/dev/restate/serde/TypeTag.java | 32 + examples/README.md | 2 +- examples/build.gradle.kts | 5 +- .../java/my/restate/sdk/examples/Counter.java | 17 +- .../restate/sdk/examples/LambdaHandler.java | 4 +- .../my/restate/sdk/examples/LoanWorkflow.java | 30 +- .../my/restate/sdk/examples/CounterKt.kt | 20 +- examples/src/main/resources/log4j2.properties | 4 +- gradle/libs.versions.toml | 6 +- gradle/wrapper/gradle-wrapper.properties | 2 +- .../dev/restate/sdk/gen/model/Service.java | 51 +- .../template/HandlebarsTemplateEngine.java | 46 +- sdk-api-gen/build.gradle.kts | 14 - .../dev/restate/sdk/gen/ElementConverter.java | 49 +- .../sdk/gen/MetaRestateAnnotation.java | 2 +- .../dev/restate/sdk/gen/ServiceProcessor.java | 38 +- .../src/main/resources/templates/Client.hbs | 171 ++-- .../{Definitions.hbs => Metadata.hbs} | 5 +- .../src/main/resources/templates/Requests.hbs | 17 + .../templates/ServiceDefinitionFactory.hbs | 30 +- sdk-api-kotlin-gen/build.gradle.kts | 14 - .../sdk/kotlin/gen/KElementConverter.kt | 30 +- .../sdk/kotlin/gen/MetaRestateAnnotation.kt | 2 +- .../sdk/kotlin/gen/ServiceProcessor.kt | 30 +- .../src/main/resources/templates/Client.hbs | 105 +- .../main/resources/templates/Definitions.hbs | 14 - .../src/main/resources/templates/Metadata.hbs | 15 + .../src/main/resources/templates/Requests.hbs | 17 + .../templates/ServiceDefinitionFactory.hbs | 29 +- sdk-api-kotlin/build.gradle.kts | 16 +- .../dev/restate/sdk/kotlin/Awaitables.kt | 191 ---- .../dev/restate/sdk/kotlin/ContextImpl.kt | 400 ++++---- .../dev/restate/sdk/kotlin/HandlerRunner.kt | 82 +- .../kotlin/dev/restate/sdk/kotlin/KtSerdes.kt | 174 ---- .../kotlin/dev/restate/sdk/kotlin/Util.kt | 42 +- .../main/kotlin/dev/restate/sdk/kotlin/api.kt | 235 +++-- .../dev/restate/sdk/kotlin/awaitables.kt | 230 +++++ .../restate/sdk/kotlin/endpoint/endpoint.kt | 10 +- .../kotlin/dev/restate/sdk/kotlin/ingress.kt | 92 -- .../KotlinSerializationSerdeFactory.kt | 212 ++++ .../restate/sdk/kotlin/serialization/api.kt | 28 + sdk-api/build.gradle.kts | 14 +- .../java/dev/restate/sdk/AnyAwaitable.java | 33 - .../main/java/dev/restate/sdk/Awaitable.java | 244 ++--- .../main/java/dev/restate/sdk/Awakeable.java | 40 +- .../java/dev/restate/sdk/AwakeableHandle.java | 17 +- .../java/dev/restate/sdk/CallAwaitable.java | 52 + .../main/java/dev/restate/sdk/Context.java | 371 +++++-- .../java/dev/restate/sdk/ContextImpl.java | 329 ++++--- .../java/dev/restate/sdk/DurablePromise.java | 4 +- .../java/dev/restate/sdk/HandlerRunner.java | 111 ++- .../dev/restate/sdk/InvocationHandle.java | 29 + .../main/java/dev/restate/sdk/JsonSerdes.java | 159 --- .../java/dev/restate/sdk/ObjectContext.java | 3 +- .../java/dev/restate/sdk/PreviewContext.java | 48 +- .../java/dev/restate/sdk/RestateRandom.java | 19 +- .../src/main/java/dev/restate/sdk/Select.java | 85 ++ .../dev/restate/sdk/SharedObjectContext.java | 4 +- .../restate/sdk/SharedWorkflowContext.java | 2 +- .../src/main/java/dev/restate/sdk/Util.java | 93 +- .../java/dev/restate/sdk/TestSerdesTest.java | 80 -- sdk-common/build.gradle.kts | 2 + .../sdk/annotation/CustomSerdeFactory.java | 21 + .../dev/restate/sdk/annotation/Service.java | 4 +- .../restate/sdk/annotation/VirtualObject.java | 4 +- .../dev/restate/sdk/annotation/Workflow.java | 6 +- .../sdk/client/CallRequestOptions.java | 84 -- .../java/dev/restate/sdk/client/Client.java | 316 ------ .../dev/restate/sdk/client/DefaultClient.java | 474 --------- .../restate/sdk/client/RequestOptions.java | 64 -- .../dev/restate/sdk/client/SendResponse.java | 57 -- .../java/dev/restate/sdk/common/Request.java | 89 -- .../dev/restate/sdk/common/RichSerde.java | 62 -- .../restate/sdk/common/syscalls/Deferred.java | 29 - .../common/syscalls/HandlerDefinition.java | 56 -- .../common/syscalls/HandlerSpecification.java | 134 --- .../restate/sdk/common/syscalls/Result.java | 175 ---- .../sdk/common/syscalls/SyscallCallback.java | 74 -- .../restate/sdk/common/syscalls/Syscalls.java | 110 --- .../dev/restate/sdk/endpoint/Endpoint.java | 166 ++++ .../restate/sdk/endpoint/HeadersAccessor.java | 38 + .../RequestIdentityVerifier.java | 19 +- .../sdk/endpoint/definition/AsyncResult.java | 39 + .../endpoint/definition/HandlerContext.java | 109 +++ .../definition/HandlerDefinition.java | 130 +++ .../definition}/HandlerRunner.java | 23 +- .../definition}/HandlerType.java | 2 +- .../definition}/ServiceDefinition.java | 31 +- .../ServiceDefinitionFactories.java | 87 ++ .../definition}/ServiceDefinitionFactory.java | 8 +- .../definition}/ServiceType.java | 2 +- .../AbortedExecutionException.java | 2 +- .../{common => types}/DurablePromiseKey.java | 24 +- .../dev/restate/sdk/types/HandlerRequest.java | 26 + .../sdk/{common => types}/InvocationId.java | 2 +- .../sdk/{common => types}/RetryPolicy.java | 2 +- .../sdk/{common => types}/StateKey.java | 18 +- .../{common => types}/TerminalException.java | 2 +- .../restate/sdk/types/TimeoutException.java | 10 +- sdk-core/build.gradle.kts | 65 +- .../dev/restate/sdk/core/AckStateMachine.java | 56 -- .../dev/restate/sdk/core/AsyncResults.java | 353 +++++++ .../BaseSuspendableCallbackStateMachine.java | 62 -- .../dev/restate/sdk/core/CallbackHandle.java | 46 - .../dev/restate/sdk/core/DeferredResults.java | 255 ----- ...ceProtocol.java => DiscoveryProtocol.java} | 51 +- .../restate/sdk/core/EndpointManifest.java | 99 +- .../sdk/core/EndpointRequestHandler.java | 217 +++++ .../java/dev/restate/sdk/core/Entries.java | 728 -------------- .../sdk/core/ExceptionCatchingSubscriber.java | 50 - .../dev/restate/sdk/core/ExceptionUtils.java | 63 ++ .../ExecutorSwitchingHandlerContextImpl.java | 191 ++++ .../sdk/core/ExecutorSwitchingSyscalls.java | 198 ---- .../restate/sdk/core/HandlerContextImpl.java | 519 ++++++++++ .../sdk/core/HandlerContextInternal.java | 65 ++ .../sdk/core/IncomingEntriesStateMachine.java | 54 -- .../restate/sdk/core/InputPublisherState.java | 33 - .../dev/restate/sdk/core/InvocationFlow.java | 23 - .../sdk/core/InvocationStateMachine.java | 904 ------------------ .../dev/restate/sdk/core/MessageEncoder.java | 66 -- .../dev/restate/sdk/core/MessageHeader.java | 101 -- .../dev/restate/sdk/core/MessageType.java | 206 ---- .../restate/sdk/core/ProtocolException.java | 105 +- .../sdk/core/ReadyResultStateMachine.java | 96 -- ...ointHandler.java => RequestProcessor.java} | 9 +- .../sdk/core/RequestProcessorImpl.java | 160 ++++ .../sdk/core/ResolvedEndpointHandlerImpl.java | 156 --- .../sdk/core/RestateContextDataProvider.java | 21 +- .../dev/restate/sdk/core/RestateEndpoint.java | 304 ------ .../core/StaticResponseRequestProcessor.java | 67 ++ .../dev/restate/sdk/core/SyscallsImpl.java | 454 --------- .../restate/sdk/core/SyscallsInternal.java | 43 - .../java/dev/restate/sdk/core/Tracing.java | 28 - .../main/java/dev/restate/sdk/core/Util.java | 175 ---- .../core/statemachine/AsyncResultsState.java | 131 +++ .../sdk/core/statemachine/ClosedState.java | 31 + .../core/statemachine/CommandAccessor.java | 145 +++ .../EagerState.java} | 56 +- .../{ => statemachine}/InvocationIdImpl.java | 4 +- .../{ => statemachine}/InvocationInput.java | 2 +- .../{ => statemachine}/InvocationState.java | 2 +- .../sdk/core/statemachine/Journal.java | 71 ++ .../{ => statemachine}/MessageDecoder.java | 92 +- .../sdk/core/statemachine/MessageEncoder.java | 61 ++ .../sdk/core/statemachine/MessageHeader.java | 53 + .../sdk/core/statemachine/MessageType.java | 361 +++++++ .../sdk/core/statemachine/NotificationId.java | 18 + .../core/statemachine/NotificationValue.java | 28 + .../core/statemachine/ProcessingState.java | 380 ++++++++ .../sdk/core/statemachine/ReplayingState.java | 284 ++++++ .../sdk/core/statemachine/RunState.java | 48 + .../core/statemachine/ServiceProtocol.java | 64 ++ .../sdk/core/statemachine/StartInfo.java | 18 +- .../restate/sdk/core/statemachine/State.java | 159 +++ .../sdk/core/statemachine/StateContext.java | 94 ++ .../sdk/core/statemachine/StateHolder.java | 38 + .../sdk/core/statemachine/StateMachine.java | 151 +++ .../core/statemachine/StateMachineImpl.java | 653 +++++++++++++ .../restate/sdk/core/statemachine/Util.java | 158 +++ .../WaitingReplayEntriesState.java | 68 ++ .../core/statemachine/WaitingStartState.java | 68 ++ .../main/sdk-proto/dev/restate/sdk/java.proto | 24 - .../dev/restate/service/discovery.proto | 3 +- .../dev/restate/service/protocol.proto | 510 +++++++--- .../service-invocation-protocol.md | 6 +- .../dev/restate/sdk/core/AssertUtils.java | 61 +- .../sdk/core/AsyncResultTestSuite.java | 323 +++++++ .../sdk/core/AwakeableIdTestSuite.java | 31 +- .../dev/restate/sdk/core/CallTestSuite.java | 80 ++ .../core/ComponentDiscoveryHandlerTest.java | 19 +- .../restate/sdk/core/DeferredTestSuite.java | 411 -------- .../restate/sdk/core/EagerStateTestSuite.java | 98 +- .../sdk/core/InvocationIdTestSuite.java | 12 +- .../restate/sdk/core/MessageDecoderTest.java | 71 -- .../restate/sdk/core/MessageHeaderTest.java | 27 - .../dev/restate/sdk/core/MockBidiStream.java | 110 +++ .../restate/sdk/core/MockMultiThreaded.java | 91 -- .../restate/sdk/core/MockRequestResponse.java | 89 ++ .../restate/sdk/core/MockSingleThread.java | 82 -- .../sdk/core/OnlyInputAndOutputTestSuite.java | 6 +- .../restate/sdk/core/PromiseTestSuite.java | 117 ++- .../java/dev/restate/sdk/core/ProtoUtils.java | 332 ------- .../dev/restate/sdk/core/RandomTestSuite.java | 16 +- .../restate/sdk/core/SideEffectTestSuite.java | 181 ++-- .../dev/restate/sdk/core/SleepTestSuite.java | 75 +- .../core/StateMachineFailuresTestSuite.java | 46 +- .../dev/restate/sdk/core/StateTestSuite.java | 113 +-- .../dev/restate/sdk/core/TestDefinitions.java | 73 +- .../java/dev/restate/sdk/core/TestRunner.java | 29 +- .../java/dev/restate/sdk/core/TestSerdes.java | 15 +- .../sdk/core/UserFailuresTestSuite.java | 62 +- .../sdk/core/javaapi/AsyncResultTest.java | 115 ++- .../sdk/core/javaapi}/AwakeableIdTest.java | 11 +- .../restate/sdk/core/javaapi/CallTest.java | 48 + .../core/javaapi/CodegenDiscoveryTest.java | 34 +- .../sdk/core/javaapi}/CodegenTest.java | 155 +-- .../sdk/core/javaapi}/EagerStateTest.java | 49 +- .../javaapi}/GreeterWithExplicitName.java | 3 +- .../javaapi}/GreeterWithoutExplicitName.java | 3 +- .../sdk/core/javaapi}/InvocationIdTest.java | 9 +- .../sdk/core/javaapi/JavaAPITests.java | 59 +- .../sdk/core/javaapi/MySerdeFactory.java | 38 + .../sdk/core/javaapi}/NameInferenceTest.java | 8 +- .../core/javaapi}/OnlyInputAndOutputTest.java | 9 +- .../sdk/core/javaapi}/PromiseTest.java | 33 +- .../restate/sdk/core/javaapi}/RandomTest.java | 21 +- .../sdk/core/javaapi}/SideEffectTest.java | 48 +- .../restate/sdk/core/javaapi}/SleepTest.java | 10 +- .../javaapi}/StateMachineFailuresTest.java | 14 +- .../restate/sdk/core/javaapi}/StateTest.java | 23 +- .../sdk/core/javaapi}/UserFailuresTest.java | 10 +- .../sdk/core}/lambda/LambdaHandlerTest.java | 38 +- .../testservices/JavaCounterService.java | 6 +- .../testservices/MyServicesHandler.java | 8 +- .../core/statemachine/MessageDecoderTest.java | 59 ++ .../sdk/core/statemachine/ProtoUtils.java | 502 ++++++++++ .../sdk/core/kotlinapi/AsyncResultTest.kt | 63 +- .../sdk/core/kotlinapi}/AwakeableIdTest.kt | 5 +- .../restate/sdk/core/kotlinapi/CallTest.kt | 39 + .../core/kotlinapi/CodegenDiscoveryTest.kt | 25 +- .../sdk/core/kotlinapi}/CodegenTest.kt | 175 ++-- .../sdk/core/kotlinapi}/EagerStateTest.kt | 6 +- .../sdk/core/kotlinapi}/InvocationIdTest.kt | 4 +- .../sdk/core/kotlinapi/KotlinAPITests.kt | 66 +- .../kotlinapi}/MyMetaServiceAnnotation.kt | 2 +- .../core/kotlinapi}/OnlyInputAndOutputTest.kt | 4 +- .../sdk/core/kotlinapi}/PromiseTest.kt | 19 +- .../restate/sdk/core/kotlinapi}/RandomTest.kt | 10 +- .../sdk/core/kotlinapi}/SideEffectTest.kt | 58 +- .../restate/sdk/core/kotlinapi}/SleepTest.kt | 5 +- .../kotlinapi}/StateMachineFailuresTest.kt | 17 +- .../restate/sdk/core/kotlinapi}/StateTest.kt | 28 +- .../sdk/core/kotlinapi}/UserFailuresTest.kt | 7 +- .../sdk/core/vertx/RestateHttpServerTest.kt | 156 +++ .../vertx/RestateHttpServerTestExecutor.kt | 26 +- .../sdk/core/vertx/RestateHttpServerTests.kt | 19 +- .../core/vertx/ThreadTrampoliningTestSuite.kt | 84 +- sdk-http-vertx/build.gradle.kts | 22 - .../vertx/HttpEndpointRequestHandler.java | 105 ++ .../http/vertx/HttpRequestFlowAdapter.java | 12 +- .../http/vertx/HttpResponseFlowAdapter.java | 14 +- .../http/vertx/RequestHttpServerHandler.java | 190 ---- .../vertx/RestateHttpEndpointBuilder.java | 183 ---- .../sdk/http/vertx/RestateHttpServer.java | 153 +++ .../vertx/testservices/BlockingGreeter.java | 37 - .../sdk/http/vertx/RestateHttpEndpointTest.kt | 230 ----- .../vertx/testservices/GreeterKtComponent.kt | 43 - .../test/resources/junit-platform.properties | 3 - sdk-java-http/build.gradle.kts | 14 + sdk-java-lambda/build.gradle.kts | 14 + sdk-kotlin-http/build.gradle.kts | 13 + sdk-kotlin-lambda/build.gradle.kts | 13 + sdk-lambda/build.gradle.kts | 14 - .../sdk/lambda/BaseRestateLambdaHandler.java | 13 +- .../lambda/LambdaEndpointRequestHandler.java | 117 +++ .../sdk/lambda/LambdaFlowAdapters.java | 25 +- .../sdk/lambda/RestateLambdaEndpoint.java | 199 ---- .../lambda/RestateLambdaEndpointBuilder.java | 83 -- .../testservices/KotlinCounterService.kt | 38 - .../src/test/resources/log4j2.properties | 8 - .../RestateRequestIdentityVerifier.java | 7 +- .../serde/jackson/JacksonSerdeFactory.java | 93 ++ .../sdk/serde/jackson/JacksonSerdes.java | 77 +- .../sdk/serde/jackson/JacksonSerdesTest.java | 2 +- .../sdk/serde/protobuf/ProtobufSerdes.java | 55 -- .../build.gradle.kts | 1 + .../kotlin/RestateHttpEndpointBeanTest.kt | 4 +- .../kotlin/SdkTestingIntegrationTest.kt | 2 +- sdk-spring-boot-starter/build.gradle.kts | 7 +- .../java/RestateHttpEndpointBeanTest.java | 4 +- .../java/SdkTestingIntegrationTest.java | 2 +- sdk-spring-boot/build.gradle.kts | 8 +- .../RestateClientAutoConfiguration.java | 6 +- .../sdk/springboot/RestateComponent.java | 2 +- .../springboot/RestateEndpointProperties.java | 8 +- .../springboot/RestateHttpEndpointBean.java | 22 +- ....java => RestateHttpServerProperties.java} | 4 +- .../RestateClientAutoConfigurationTest.java | 2 +- sdk-testing/build.gradle.kts | 1 + .../sdk/testing/ManualRestateRunner.java | 213 ----- .../restate/sdk/testing/RestateClient.java | 2 +- .../restate/sdk/testing/RestateExtension.java | 59 +- .../restate/sdk/testing/RestateRunner.java | 306 ++++-- .../sdk/testing/RestateRunnerBuilder.java | 108 --- .../java/dev/restate/sdk/testing/Counter.java | 5 +- .../sdk/testing/CounterOldExtensionTest.java | 36 - .../dev/restate/sdk/testing/CounterTest.java | 2 +- settings.gradle.kts | 21 +- test-services/build.gradle.kts | 13 +- .../sdk/testservices/AwakeableHolderImpl.kt | 10 +- .../testservices/BlockAndWaitWorkflowImpl.kt | 16 +- .../sdk/testservices/CancelTestImpl.kt | 11 +- .../restate/sdk/testservices/CounterImpl.kt | 10 +- .../restate/sdk/testservices/FailingImpl.kt | 2 +- .../restate/sdk/testservices/KillTestImpl.kt | 2 +- .../sdk/testservices/ListObjectImpl.kt | 7 +- .../dev/restate/sdk/testservices/Main.kt | 69 +- .../restate/sdk/testservices/MapObjectImpl.kt | 9 +- .../sdk/testservices/NonDeterministicImpl.kt | 9 +- .../dev/restate/sdk/testservices/ProxyImpl.kt | 43 +- .../sdk/testservices/TestUtilsServiceImpl.kt | 7 +- .../restate/sdk/testservices/interpreter.kt | 52 +- 332 files changed, 13824 insertions(+), 12353 deletions(-) delete mode 100644 buf.lock delete mode 100644 buf.yaml delete mode 100644 buildSrc/src/main/kotlin/test-jar-conventions.gradle.kts create mode 100644 client-kotlin/build.gradle.kts create mode 100644 client-kotlin/src/main/kotlin/dev/restate/client/kotlin/ingress.kt create mode 100644 client/build.gradle.kts create mode 100644 client/src/main/java/dev/restate/client/Client.java create mode 100644 client/src/main/java/dev/restate/client/ClientRequestOptions.java create mode 100644 client/src/main/java/dev/restate/client/ClientResponse.java rename {sdk-common/src/main/java/dev/restate/sdk => client/src/main/java/dev/restate}/client/IngressException.java (56%) create mode 100644 client/src/main/java/dev/restate/client/SendResponse.java create mode 100644 client/src/main/java/dev/restate/client/base/BaseClient.java create mode 100644 client/src/main/java/dev/restate/client/jdk/JdkClient.java rename {sdk-serde-protobuf => common}/build.gradle.kts (52%) rename {sdk-common/src/main/java/dev/restate/sdk => common/src/main/java/dev/restate}/common/Output.java (98%) create mode 100644 common/src/main/java/dev/restate/common/Request.java create mode 100644 common/src/main/java/dev/restate/common/SendRequest.java create mode 100644 common/src/main/java/dev/restate/common/Slice.java rename {sdk-common/src/main/java/dev/restate/sdk => common/src/main/java/dev/restate}/common/Target.java (98%) rename {sdk-common/src/main/java/dev/restate/sdk => common/src/main/java/dev/restate}/common/function/ThrowingBiConsumer.java (96%) rename {sdk-common/src/main/java/dev/restate/sdk => common/src/main/java/dev/restate}/common/function/ThrowingBiFunction.java (92%) rename {sdk-common/src/main/java/dev/restate/sdk => common/src/main/java/dev/restate}/common/function/ThrowingConsumer.java (92%) rename {sdk-common/src/main/java/dev/restate/sdk => common/src/main/java/dev/restate}/common/function/ThrowingFunction.java (91%) rename {sdk-common/src/main/java/dev/restate/sdk => common/src/main/java/dev/restate}/common/function/ThrowingRunnable.java (92%) rename {sdk-common/src/main/java/dev/restate/sdk => common/src/main/java/dev/restate}/common/function/ThrowingSupplier.java (93%) rename {sdk-common/src/main/java/dev/restate/sdk/common => common/src/main/java/dev/restate/serde}/Serde.java (52%) create mode 100644 common/src/main/java/dev/restate/serde/SerdeFactory.java create mode 100644 common/src/main/java/dev/restate/serde/TypeRef.java create mode 100644 common/src/main/java/dev/restate/serde/TypeTag.java rename sdk-api-gen/src/main/resources/templates/{Definitions.hbs => Metadata.hbs} (50%) create mode 100644 sdk-api-gen/src/main/resources/templates/Requests.hbs delete mode 100644 sdk-api-kotlin-gen/src/main/resources/templates/Definitions.hbs create mode 100644 sdk-api-kotlin-gen/src/main/resources/templates/Metadata.hbs create mode 100644 sdk-api-kotlin-gen/src/main/resources/templates/Requests.hbs delete mode 100644 sdk-api-kotlin/src/main/kotlin/dev/restate/sdk/kotlin/Awaitables.kt delete mode 100644 sdk-api-kotlin/src/main/kotlin/dev/restate/sdk/kotlin/KtSerdes.kt create mode 100644 sdk-api-kotlin/src/main/kotlin/dev/restate/sdk/kotlin/awaitables.kt rename sdk-common/src/main/java/dev/restate/sdk/common/syscalls/EnterSideEffectSyscallCallback.java => sdk-api-kotlin/src/main/kotlin/dev/restate/sdk/kotlin/endpoint/endpoint.kt (57%) delete mode 100644 sdk-api-kotlin/src/main/kotlin/dev/restate/sdk/kotlin/ingress.kt create mode 100644 sdk-api-kotlin/src/main/kotlin/dev/restate/sdk/kotlin/serialization/KotlinSerializationSerdeFactory.kt create mode 100644 sdk-api-kotlin/src/main/kotlin/dev/restate/sdk/kotlin/serialization/api.kt delete mode 100644 sdk-api/src/main/java/dev/restate/sdk/AnyAwaitable.java create mode 100644 sdk-api/src/main/java/dev/restate/sdk/CallAwaitable.java create mode 100644 sdk-api/src/main/java/dev/restate/sdk/InvocationHandle.java delete mode 100644 sdk-api/src/main/java/dev/restate/sdk/JsonSerdes.java create mode 100644 sdk-api/src/main/java/dev/restate/sdk/Select.java delete mode 100644 sdk-api/src/test/java/dev/restate/sdk/TestSerdesTest.java create mode 100644 sdk-common/src/main/java/dev/restate/sdk/annotation/CustomSerdeFactory.java delete mode 100644 sdk-common/src/main/java/dev/restate/sdk/client/CallRequestOptions.java delete mode 100644 sdk-common/src/main/java/dev/restate/sdk/client/Client.java delete mode 100644 sdk-common/src/main/java/dev/restate/sdk/client/DefaultClient.java delete mode 100644 sdk-common/src/main/java/dev/restate/sdk/client/RequestOptions.java delete mode 100644 sdk-common/src/main/java/dev/restate/sdk/client/SendResponse.java delete mode 100644 sdk-common/src/main/java/dev/restate/sdk/common/Request.java delete mode 100644 sdk-common/src/main/java/dev/restate/sdk/common/RichSerde.java delete mode 100644 sdk-common/src/main/java/dev/restate/sdk/common/syscalls/Deferred.java delete mode 100644 sdk-common/src/main/java/dev/restate/sdk/common/syscalls/HandlerDefinition.java delete mode 100644 sdk-common/src/main/java/dev/restate/sdk/common/syscalls/HandlerSpecification.java delete mode 100644 sdk-common/src/main/java/dev/restate/sdk/common/syscalls/Result.java delete mode 100644 sdk-common/src/main/java/dev/restate/sdk/common/syscalls/SyscallCallback.java delete mode 100644 sdk-common/src/main/java/dev/restate/sdk/common/syscalls/Syscalls.java create mode 100644 sdk-common/src/main/java/dev/restate/sdk/endpoint/Endpoint.java create mode 100644 sdk-common/src/main/java/dev/restate/sdk/endpoint/HeadersAccessor.java rename sdk-common/src/main/java/dev/restate/sdk/{auth => endpoint}/RequestIdentityVerifier.java (66%) create mode 100644 sdk-common/src/main/java/dev/restate/sdk/endpoint/definition/AsyncResult.java create mode 100644 sdk-common/src/main/java/dev/restate/sdk/endpoint/definition/HandlerContext.java create mode 100644 sdk-common/src/main/java/dev/restate/sdk/endpoint/definition/HandlerDefinition.java rename sdk-common/src/main/java/dev/restate/sdk/{common/syscalls => endpoint/definition}/HandlerRunner.java (54%) rename sdk-common/src/main/java/dev/restate/sdk/{common => endpoint/definition}/HandlerType.java (89%) rename sdk-common/src/main/java/dev/restate/sdk/{common/syscalls => endpoint/definition}/ServiceDefinition.java (67%) create mode 100644 sdk-common/src/main/java/dev/restate/sdk/endpoint/definition/ServiceDefinitionFactories.java rename sdk-common/src/main/java/dev/restate/sdk/{common/syscalls => endpoint/definition}/ServiceDefinitionFactory.java (61%) rename sdk-common/src/main/java/dev/restate/sdk/{common => endpoint/definition}/ServiceType.java (89%) rename sdk-common/src/main/java/dev/restate/sdk/{common => types}/AbortedExecutionException.java (96%) rename sdk-common/src/main/java/dev/restate/sdk/{common => types}/DurablePromiseKey.java (60%) create mode 100644 sdk-common/src/main/java/dev/restate/sdk/types/HandlerRequest.java rename sdk-common/src/main/java/dev/restate/sdk/{common => types}/InvocationId.java (95%) rename sdk-common/src/main/java/dev/restate/sdk/{common => types}/RetryPolicy.java (99%) rename sdk-common/src/main/java/dev/restate/sdk/{common => types}/StateKey.java (67%) rename sdk-common/src/main/java/dev/restate/sdk/{common => types}/TerminalException.java (97%) rename sdk-core/src/main/java/dev/restate/sdk/core/SuspendableCallback.java => sdk-common/src/main/java/dev/restate/sdk/types/TimeoutException.java (67%) delete mode 100644 sdk-core/src/main/java/dev/restate/sdk/core/AckStateMachine.java create mode 100644 sdk-core/src/main/java/dev/restate/sdk/core/AsyncResults.java delete mode 100644 sdk-core/src/main/java/dev/restate/sdk/core/BaseSuspendableCallbackStateMachine.java delete mode 100644 sdk-core/src/main/java/dev/restate/sdk/core/CallbackHandle.java delete mode 100644 sdk-core/src/main/java/dev/restate/sdk/core/DeferredResults.java rename sdk-core/src/main/java/dev/restate/sdk/core/{ServiceProtocol.java => DiscoveryProtocol.java} (71%) create mode 100644 sdk-core/src/main/java/dev/restate/sdk/core/EndpointRequestHandler.java delete mode 100644 sdk-core/src/main/java/dev/restate/sdk/core/Entries.java delete mode 100644 sdk-core/src/main/java/dev/restate/sdk/core/ExceptionCatchingSubscriber.java create mode 100644 sdk-core/src/main/java/dev/restate/sdk/core/ExceptionUtils.java create mode 100644 sdk-core/src/main/java/dev/restate/sdk/core/ExecutorSwitchingHandlerContextImpl.java delete mode 100644 sdk-core/src/main/java/dev/restate/sdk/core/ExecutorSwitchingSyscalls.java create mode 100644 sdk-core/src/main/java/dev/restate/sdk/core/HandlerContextImpl.java create mode 100644 sdk-core/src/main/java/dev/restate/sdk/core/HandlerContextInternal.java delete mode 100644 sdk-core/src/main/java/dev/restate/sdk/core/IncomingEntriesStateMachine.java delete mode 100644 sdk-core/src/main/java/dev/restate/sdk/core/InputPublisherState.java delete mode 100644 sdk-core/src/main/java/dev/restate/sdk/core/InvocationFlow.java delete mode 100644 sdk-core/src/main/java/dev/restate/sdk/core/InvocationStateMachine.java delete mode 100644 sdk-core/src/main/java/dev/restate/sdk/core/MessageEncoder.java delete mode 100644 sdk-core/src/main/java/dev/restate/sdk/core/MessageHeader.java delete mode 100644 sdk-core/src/main/java/dev/restate/sdk/core/MessageType.java delete mode 100644 sdk-core/src/main/java/dev/restate/sdk/core/ReadyResultStateMachine.java rename sdk-core/src/main/java/dev/restate/sdk/core/{ResolvedEndpointHandler.java => RequestProcessor.java} (63%) create mode 100644 sdk-core/src/main/java/dev/restate/sdk/core/RequestProcessorImpl.java delete mode 100644 sdk-core/src/main/java/dev/restate/sdk/core/ResolvedEndpointHandlerImpl.java delete mode 100644 sdk-core/src/main/java/dev/restate/sdk/core/RestateEndpoint.java create mode 100644 sdk-core/src/main/java/dev/restate/sdk/core/StaticResponseRequestProcessor.java delete mode 100644 sdk-core/src/main/java/dev/restate/sdk/core/SyscallsImpl.java delete mode 100644 sdk-core/src/main/java/dev/restate/sdk/core/SyscallsInternal.java delete mode 100644 sdk-core/src/main/java/dev/restate/sdk/core/Tracing.java delete mode 100644 sdk-core/src/main/java/dev/restate/sdk/core/Util.java create mode 100644 sdk-core/src/main/java/dev/restate/sdk/core/statemachine/AsyncResultsState.java create mode 100644 sdk-core/src/main/java/dev/restate/sdk/core/statemachine/ClosedState.java create mode 100644 sdk-core/src/main/java/dev/restate/sdk/core/statemachine/CommandAccessor.java rename sdk-core/src/main/java/dev/restate/sdk/core/{UserStateStore.java => statemachine/EagerState.java} (50%) rename sdk-core/src/main/java/dev/restate/sdk/core/{ => statemachine}/InvocationIdImpl.java (95%) rename sdk-core/src/main/java/dev/restate/sdk/core/{ => statemachine}/InvocationInput.java (95%) rename sdk-core/src/main/java/dev/restate/sdk/core/{ => statemachine}/InvocationState.java (90%) create mode 100644 sdk-core/src/main/java/dev/restate/sdk/core/statemachine/Journal.java rename sdk-core/src/main/java/dev/restate/sdk/core/{ => statemachine}/MessageDecoder.java (61%) create mode 100644 sdk-core/src/main/java/dev/restate/sdk/core/statemachine/MessageEncoder.java create mode 100644 sdk-core/src/main/java/dev/restate/sdk/core/statemachine/MessageHeader.java create mode 100644 sdk-core/src/main/java/dev/restate/sdk/core/statemachine/MessageType.java create mode 100644 sdk-core/src/main/java/dev/restate/sdk/core/statemachine/NotificationId.java create mode 100644 sdk-core/src/main/java/dev/restate/sdk/core/statemachine/NotificationValue.java create mode 100644 sdk-core/src/main/java/dev/restate/sdk/core/statemachine/ProcessingState.java create mode 100644 sdk-core/src/main/java/dev/restate/sdk/core/statemachine/ReplayingState.java create mode 100644 sdk-core/src/main/java/dev/restate/sdk/core/statemachine/RunState.java create mode 100644 sdk-core/src/main/java/dev/restate/sdk/core/statemachine/ServiceProtocol.java rename sdk-common/src/main/java/dev/restate/sdk/common/syscalls/ExitSideEffectSyscallCallback.java => sdk-core/src/main/java/dev/restate/sdk/core/statemachine/StartInfo.java (52%) create mode 100644 sdk-core/src/main/java/dev/restate/sdk/core/statemachine/State.java create mode 100644 sdk-core/src/main/java/dev/restate/sdk/core/statemachine/StateContext.java create mode 100644 sdk-core/src/main/java/dev/restate/sdk/core/statemachine/StateHolder.java create mode 100644 sdk-core/src/main/java/dev/restate/sdk/core/statemachine/StateMachine.java create mode 100644 sdk-core/src/main/java/dev/restate/sdk/core/statemachine/StateMachineImpl.java create mode 100644 sdk-core/src/main/java/dev/restate/sdk/core/statemachine/Util.java create mode 100644 sdk-core/src/main/java/dev/restate/sdk/core/statemachine/WaitingReplayEntriesState.java create mode 100644 sdk-core/src/main/java/dev/restate/sdk/core/statemachine/WaitingStartState.java delete mode 100644 sdk-core/src/main/sdk-proto/dev/restate/sdk/java.proto create mode 100644 sdk-core/src/test/java/dev/restate/sdk/core/AsyncResultTestSuite.java create mode 100644 sdk-core/src/test/java/dev/restate/sdk/core/CallTestSuite.java delete mode 100644 sdk-core/src/test/java/dev/restate/sdk/core/DeferredTestSuite.java delete mode 100644 sdk-core/src/test/java/dev/restate/sdk/core/MessageDecoderTest.java delete mode 100644 sdk-core/src/test/java/dev/restate/sdk/core/MessageHeaderTest.java create mode 100644 sdk-core/src/test/java/dev/restate/sdk/core/MockBidiStream.java delete mode 100644 sdk-core/src/test/java/dev/restate/sdk/core/MockMultiThreaded.java create mode 100644 sdk-core/src/test/java/dev/restate/sdk/core/MockRequestResponse.java delete mode 100644 sdk-core/src/test/java/dev/restate/sdk/core/MockSingleThread.java delete mode 100644 sdk-core/src/test/java/dev/restate/sdk/core/ProtoUtils.java rename sdk-api/src/test/java/dev/restate/sdk/DeferredTest.java => sdk-core/src/test/java/dev/restate/sdk/core/javaapi/AsyncResultTest.java (51%) rename {sdk-api/src/test/java/dev/restate/sdk => sdk-core/src/test/java/dev/restate/sdk/core/javaapi}/AwakeableIdTest.java (70%) create mode 100644 sdk-core/src/test/java/dev/restate/sdk/core/javaapi/CallTest.java rename sdk-api-gen/src/test/java/dev/restate/sdk/JavaCodegenTests.java => sdk-core/src/test/java/dev/restate/sdk/core/javaapi/CodegenDiscoveryTest.java (61%) rename {sdk-api-gen/src/test/java/dev/restate/sdk => sdk-core/src/test/java/dev/restate/sdk/core/javaapi}/CodegenTest.java (67%) rename {sdk-api/src/test/java/dev/restate/sdk => sdk-core/src/test/java/dev/restate/sdk/core/javaapi}/EagerStateTest.java (64%) rename {sdk-api-gen/src/test/java/dev/restate/sdk => sdk-core/src/test/java/dev/restate/sdk/core/javaapi}/GreeterWithExplicitName.java (88%) rename {sdk-api-gen/src/test/java/dev/restate/sdk => sdk-core/src/test/java/dev/restate/sdk/core/javaapi}/GreeterWithoutExplicitName.java (88%) rename {sdk-api/src/test/java/dev/restate/sdk => sdk-core/src/test/java/dev/restate/sdk/core/javaapi}/InvocationIdTest.java (77%) rename sdk-api/src/test/java/dev/restate/sdk/JavaBlockingTests.java => sdk-core/src/test/java/dev/restate/sdk/core/javaapi/JavaAPITests.java (59%) create mode 100644 sdk-core/src/test/java/dev/restate/sdk/core/javaapi/MySerdeFactory.java rename {sdk-api-gen/src/test/java/dev/restate/sdk => sdk-core/src/test/java/dev/restate/sdk/core/javaapi}/NameInferenceTest.java (70%) rename {sdk-api/src/test/java/dev/restate/sdk => sdk-core/src/test/java/dev/restate/sdk/core/javaapi}/OnlyInputAndOutputTest.java (76%) rename {sdk-api/src/test/java/dev/restate/sdk => sdk-core/src/test/java/dev/restate/sdk/core/javaapi}/PromiseTest.java (76%) rename {sdk-api/src/test/java/dev/restate/sdk => sdk-core/src/test/java/dev/restate/sdk/core/javaapi}/RandomTest.java (62%) rename {sdk-api/src/test/java/dev/restate/sdk => sdk-core/src/test/java/dev/restate/sdk/core/javaapi}/SideEffectTest.java (68%) rename {sdk-api/src/test/java/dev/restate/sdk => sdk-core/src/test/java/dev/restate/sdk/core/javaapi}/SleepTest.java (85%) rename {sdk-api/src/test/java/dev/restate/sdk => sdk-core/src/test/java/dev/restate/sdk/core/javaapi}/StateMachineFailuresTest.java (82%) rename {sdk-api/src/test/java/dev/restate/sdk => sdk-core/src/test/java/dev/restate/sdk/core/javaapi}/StateTest.java (73%) rename {sdk-api/src/test/java/dev/restate/sdk => sdk-core/src/test/java/dev/restate/sdk/core/javaapi}/UserFailuresTest.java (90%) rename {sdk-lambda/src/test/java/dev/restate/sdk => sdk-core/src/test/java/dev/restate/sdk/core}/lambda/LambdaHandlerTest.java (81%) rename {sdk-lambda/src/test/java/dev/restate/sdk => sdk-core/src/test/java/dev/restate/sdk/core}/lambda/testservices/JavaCounterService.java (88%) rename {sdk-lambda/src/test/java/dev/restate/sdk => sdk-core/src/test/java/dev/restate/sdk/core}/lambda/testservices/MyServicesHandler.java (65%) create mode 100644 sdk-core/src/test/java/dev/restate/sdk/core/statemachine/MessageDecoderTest.java create mode 100644 sdk-core/src/test/java/dev/restate/sdk/core/statemachine/ProtoUtils.java rename sdk-api-kotlin/src/test/kotlin/dev/restate/sdk/kotlin/DeferredTest.kt => sdk-core/src/test/kotlin/dev/restate/sdk/core/kotlinapi/AsyncResultTest.kt (62%) rename {sdk-api-kotlin/src/test/kotlin/dev/restate/sdk/kotlin => sdk-core/src/test/kotlin/dev/restate/sdk/core/kotlinapi}/AwakeableIdTest.kt (81%) create mode 100644 sdk-core/src/test/kotlin/dev/restate/sdk/core/kotlinapi/CallTest.kt rename sdk-api-kotlin-gen/src/test/kotlin/dev/restate/sdk/kotlin/KtCodegenTests.kt => sdk-core/src/test/kotlin/dev/restate/sdk/core/kotlinapi/CodegenDiscoveryTest.kt (70%) rename {sdk-api-kotlin-gen/src/test/kotlin/dev/restate/sdk/kotlin => sdk-core/src/test/kotlin/dev/restate/sdk/core/kotlinapi}/CodegenTest.kt (67%) rename {sdk-api-kotlin/src/test/kotlin/dev/restate/sdk/kotlin => sdk-core/src/test/kotlin/dev/restate/sdk/core/kotlinapi}/EagerStateTest.kt (93%) rename {sdk-api-kotlin/src/test/kotlin/dev/restate/sdk/kotlin => sdk-core/src/test/kotlin/dev/restate/sdk/core/kotlinapi}/InvocationIdTest.kt (84%) rename sdk-api-kotlin/src/test/kotlin/dev/restate/sdk/kotlin/KotlinCoroutinesTests.kt => sdk-core/src/test/kotlin/dev/restate/sdk/core/kotlinapi/KotlinAPITests.kt (55%) rename {sdk-api-kotlin-gen/src/test/kotlin/dev/restate/sdk/kotlin => sdk-core/src/test/kotlin/dev/restate/sdk/core/kotlinapi}/MyMetaServiceAnnotation.kt (92%) rename {sdk-api-kotlin/src/test/kotlin/dev/restate/sdk/kotlin => sdk-core/src/test/kotlin/dev/restate/sdk/core/kotlinapi}/OnlyInputAndOutputTest.kt (84%) rename {sdk-api-kotlin/src/test/kotlin/dev/restate/sdk/kotlin => sdk-core/src/test/kotlin/dev/restate/sdk/core/kotlinapi}/PromiseTest.kt (73%) rename {sdk-api-kotlin/src/test/kotlin/dev/restate/sdk/kotlin => sdk-core/src/test/kotlin/dev/restate/sdk/core/kotlinapi}/RandomTest.kt (66%) rename {sdk-api-kotlin/src/test/kotlin/dev/restate/sdk/kotlin => sdk-core/src/test/kotlin/dev/restate/sdk/core/kotlinapi}/SideEffectTest.kt (62%) rename {sdk-api-kotlin/src/test/kotlin/dev/restate/sdk/kotlin => sdk-core/src/test/kotlin/dev/restate/sdk/core/kotlinapi}/SleepTest.kt (86%) rename {sdk-api-kotlin/src/test/kotlin/dev/restate/sdk/kotlin => sdk-core/src/test/kotlin/dev/restate/sdk/core/kotlinapi}/StateMachineFailuresTest.kt (75%) rename {sdk-api-kotlin/src/test/kotlin/dev/restate/sdk/kotlin => sdk-core/src/test/kotlin/dev/restate/sdk/core/kotlinapi}/StateTest.kt (71%) rename {sdk-api-kotlin/src/test/kotlin/dev/restate/sdk/kotlin => sdk-core/src/test/kotlin/dev/restate/sdk/core/kotlinapi}/UserFailuresTest.kt (90%) create mode 100644 sdk-core/src/test/kotlin/dev/restate/sdk/core/vertx/RestateHttpServerTest.kt rename sdk-http-vertx/src/test/kotlin/dev/restate/sdk/http/vertx/HttpVertxTestExecutor.kt => sdk-core/src/test/kotlin/dev/restate/sdk/core/vertx/RestateHttpServerTestExecutor.kt (79%) rename sdk-http-vertx/src/test/kotlin/dev/restate/sdk/http/vertx/HttpVertxTests.kt => sdk-core/src/test/kotlin/dev/restate/sdk/core/vertx/RestateHttpServerTests.kt (61%) rename sdk-http-vertx/src/test/kotlin/dev/restate/sdk/http/vertx/VertxExecutorsTest.kt => sdk-core/src/test/kotlin/dev/restate/sdk/core/vertx/ThreadTrampoliningTestSuite.kt (54%) create mode 100644 sdk-http-vertx/src/main/java/dev/restate/sdk/http/vertx/HttpEndpointRequestHandler.java delete mode 100644 sdk-http-vertx/src/main/java/dev/restate/sdk/http/vertx/RequestHttpServerHandler.java delete mode 100644 sdk-http-vertx/src/main/java/dev/restate/sdk/http/vertx/RestateHttpEndpointBuilder.java create mode 100644 sdk-http-vertx/src/main/java/dev/restate/sdk/http/vertx/RestateHttpServer.java delete mode 100644 sdk-http-vertx/src/test/java/dev/restate/sdk/http/vertx/testservices/BlockingGreeter.java delete mode 100644 sdk-http-vertx/src/test/kotlin/dev/restate/sdk/http/vertx/RestateHttpEndpointTest.kt delete mode 100644 sdk-http-vertx/src/test/kotlin/dev/restate/sdk/http/vertx/testservices/GreeterKtComponent.kt delete mode 100644 sdk-http-vertx/src/test/resources/junit-platform.properties create mode 100644 sdk-java-http/build.gradle.kts create mode 100644 sdk-java-lambda/build.gradle.kts create mode 100644 sdk-kotlin-http/build.gradle.kts create mode 100644 sdk-kotlin-lambda/build.gradle.kts create mode 100644 sdk-lambda/src/main/java/dev/restate/sdk/lambda/LambdaEndpointRequestHandler.java delete mode 100644 sdk-lambda/src/main/java/dev/restate/sdk/lambda/RestateLambdaEndpoint.java delete mode 100644 sdk-lambda/src/main/java/dev/restate/sdk/lambda/RestateLambdaEndpointBuilder.java delete mode 100644 sdk-lambda/src/test/kotlin/dev/restate/sdk/lambda/testservices/KotlinCounterService.kt delete mode 100644 sdk-lambda/src/test/resources/log4j2.properties create mode 100644 sdk-serde-jackson/src/main/java/dev/restate/sdk/serde/jackson/JacksonSerdeFactory.java delete mode 100644 sdk-serde-protobuf/src/main/java/dev/restate/sdk/serde/protobuf/ProtobufSerdes.java rename sdk-spring-boot/src/main/java/dev/restate/sdk/springboot/{RestateEndpointHttpServerProperties.java => RestateHttpServerProperties.java} (85%) delete mode 100644 sdk-testing/src/main/java/dev/restate/sdk/testing/ManualRestateRunner.java delete mode 100644 sdk-testing/src/main/java/dev/restate/sdk/testing/RestateRunnerBuilder.java delete mode 100644 sdk-testing/src/test/java/dev/restate/sdk/testing/CounterOldExtensionTest.java diff --git a/buf.lock b/buf.lock deleted file mode 100644 index 4f98143f5..000000000 --- a/buf.lock +++ /dev/null @@ -1,2 +0,0 @@ -# Generated by buf. DO NOT EDIT. -version: v2 diff --git a/buf.yaml b/buf.yaml deleted file mode 100644 index ab3bd5be4..000000000 --- a/buf.yaml +++ /dev/null @@ -1,8 +0,0 @@ -version: v2 -name: buf.build/restatedev/service-protocol -lint: - use: - - DEFAULT -breaking: - use: - - FILE diff --git a/build.gradle.kts b/build.gradle.kts index 2bd9a6ec1..36d6338b7 100644 --- a/build.gradle.kts +++ b/build.gradle.kts @@ -1,3 +1,5 @@ +import com.github.jk1.license.render.ReportRenderer + plugins { alias(libs.plugins.dependency.license.report) alias(libs.plugins.nexus.publish) @@ -45,7 +47,7 @@ allprojects { tasks.named("check") { dependsOn("checkLicense") } licenseReport { - renderers = arrayOf(com.github.jk1.license.render.CsvReportRenderer()) + renderers = arrayOf(com.github.jk1.license.render.CsvReportRenderer()) excludeBoms = true diff --git a/buildSrc/src/main/kotlin/library-publishing-conventions.gradle.kts b/buildSrc/src/main/kotlin/library-publishing-conventions.gradle.kts index e57010476..a4f34688d 100644 --- a/buildSrc/src/main/kotlin/library-publishing-conventions.gradle.kts +++ b/buildSrc/src/main/kotlin/library-publishing-conventions.gradle.kts @@ -7,11 +7,68 @@ project.afterEvaluate { publishing { publications { create("maven") { + afterEvaluate { + val shadowJar = tasks.findByName("shadowJar") + if (shadowJar == null) { + from(components["java"]) + } + else { + apply(plugin = "com.gradleup.shadow") + + from(components["shadow"]) + artifact(tasks["sourcesJar"]!!) + artifact(tasks["javadocJar"]!!) + + afterEvaluate { + // Fix for avoiding inclusion of runtime dependencies marked as 'shadow' in MANIFEST Class-Path. + // https://github.com/johnrengelman/shadow/issues/324 + pom.withXml { + val rootNode = asElement() + val doc = rootNode.ownerDocument + + val dependenciesNode = + if (rootNode.getElementsByTagName("dependencies").length != 0) { + rootNode.getElementsByTagName("dependencies").item(0) + } else { + rootNode.appendChild( + doc.createElement("dependencies") + ) + } + + project.configurations["shade"].allDependencies.forEach { dep -> + dependenciesNode.appendChild( + doc.createElement("dependency").apply { + appendChild( + doc.createElement("groupId").apply { + textContent = dep.group + } + ) + appendChild( + doc.createElement("artifactId").apply { + textContent = dep.name + } + ) + appendChild( + doc.createElement("version").apply { + textContent = dep.version + } + ) + appendChild( + doc.createElement("scope").apply { + textContent = "runtime" + } + ) + } + ) + } + } + } + } + } + groupId = "dev.restate" artifactId = project.name - from(components["java"]) - pom { name = "Restate SDK :: ${project.name}" description = project.description!! diff --git a/buildSrc/src/main/kotlin/test-jar-conventions.gradle.kts b/buildSrc/src/main/kotlin/test-jar-conventions.gradle.kts deleted file mode 100644 index 7f00d2728..000000000 --- a/buildSrc/src/main/kotlin/test-jar-conventions.gradle.kts +++ /dev/null @@ -1,9 +0,0 @@ -configurations { register("testArchive") } - -tasks.register("testJar") { - archiveClassifier.set("tests") - - from(project.the()["test"].output) -} - -artifacts { add("testArchive", tasks["testJar"]) } \ No newline at end of file diff --git a/client-kotlin/build.gradle.kts b/client-kotlin/build.gradle.kts new file mode 100644 index 000000000..fb330dc5b --- /dev/null +++ b/client-kotlin/build.gradle.kts @@ -0,0 +1,11 @@ +plugins { + `kotlin-conventions` + `library-publishing-conventions` +} + +description = "Restate Client to interact with services from within other Kotlin applications" + +dependencies { + api(project(":client")) + implementation(libs.kotlinx.coroutines.core) +} diff --git a/client-kotlin/src/main/kotlin/dev/restate/client/kotlin/ingress.kt b/client-kotlin/src/main/kotlin/dev/restate/client/kotlin/ingress.kt new file mode 100644 index 000000000..f682e3416 --- /dev/null +++ b/client-kotlin/src/main/kotlin/dev/restate/client/kotlin/ingress.kt @@ -0,0 +1,99 @@ +// Copyright (c) 2023 - Restate Software, Inc., Restate GmbH +// +// This file is part of the Restate Java SDK, +// which is released under the MIT license. +// +// You can find a copy of the license in file LICENSE in the root +// directory of this repository or package, or at +// https://github.com/restatedev/sdk-java/blob/main/LICENSE +package dev.restate.client.kotlin + +import dev.restate.client.Client +import dev.restate.client.ClientRequestOptions +import dev.restate.client.ClientResponse +import dev.restate.client.SendResponse +import dev.restate.common.Output +import dev.restate.common.Request +import dev.restate.serde.Serde +import kotlinx.coroutines.future.await + +// Extension methods for the Client + +fun clientRequestOptions(init: ClientRequestOptions.Builder.() -> Unit): ClientRequestOptions { + val builder = ClientRequestOptions.builder() + builder.init() + return builder.build() +} + +suspend fun Client.callSuspend(request: Request): ClientResponse { + return this.callAsync(request).await() +} + +suspend fun Client.callSuspend( + requestBuilder: Request.Builder +): ClientResponse { + return this.callAsync(requestBuilder).await() +} + +suspend fun Client.sendSuspend( + request: Request +): ClientResponse> { + return this.sendAsync(request).await() +} + +suspend fun Client.sendSuspend( + request: Request.Builder +): ClientResponse> { + return this.sendSuspend(request.build()) +} + +suspend fun Client.AwakeableHandle.resolveSuspend( + serde: Serde, + payload: T, + options: ClientRequestOptions = ClientRequestOptions.DEFAULT +): ClientResponse { + return this.resolveAsync(serde, payload, options).await() +} + +suspend fun Client.AwakeableHandle.rejectSuspend( + reason: String, + options: ClientRequestOptions = ClientRequestOptions.DEFAULT +): ClientResponse { + return this.rejectAsync(reason, options).await() +} + +suspend fun Client.InvocationHandle.attachSuspend( + options: ClientRequestOptions = ClientRequestOptions.DEFAULT +): ClientResponse { + return this.attachAsync(options).await() +} + +suspend fun Client.InvocationHandle.getOutputSuspend( + options: ClientRequestOptions = ClientRequestOptions.DEFAULT +): ClientResponse> { + return this.getOutputAsync(options).await() +} + +suspend fun Client.IdempotentInvocationHandle.attachSuspend( + options: ClientRequestOptions = ClientRequestOptions.DEFAULT +): ClientResponse { + return this.attachAsync(options).await() +} + +suspend fun Client.IdempotentInvocationHandle.getOutputSuspend( + options: ClientRequestOptions = ClientRequestOptions.DEFAULT +): ClientResponse> { + return this.getOutputAsync(options).await() +} + +suspend fun Client.WorkflowHandle.attachSuspend( + options: ClientRequestOptions = ClientRequestOptions.DEFAULT +): ClientResponse { + return this.attachAsync(options).await() +} + +suspend fun Client.WorkflowHandle.getOutputSuspend( + options: ClientRequestOptions = ClientRequestOptions.DEFAULT +): ClientResponse> { + return this.getOutputAsync(options).await() +} diff --git a/client/build.gradle.kts b/client/build.gradle.kts new file mode 100644 index 000000000..d0bf73c89 --- /dev/null +++ b/client/build.gradle.kts @@ -0,0 +1,20 @@ +plugins { + `java-library` + `java-conventions` + `kotlin-conventions` + `library-publishing-conventions` +} + +description = "Restate Client to interact with services from within other Java applications" + +dependencies { + compileOnly(libs.jspecify) + + api(project(":common")) + + implementation(libs.jackson.core) + implementation(libs.log4j.api) + + testImplementation(libs.junit.jupiter) + testImplementation(libs.assertj) +} diff --git a/client/src/main/java/dev/restate/client/Client.java b/client/src/main/java/dev/restate/client/Client.java new file mode 100644 index 000000000..12cb1f7ec --- /dev/null +++ b/client/src/main/java/dev/restate/client/Client.java @@ -0,0 +1,354 @@ +// Copyright (c) 2023 - Restate Software, Inc., Restate GmbH +// +// This file is part of the Restate Java SDK, +// which is released under the MIT license. +// +// You can find a copy of the license in file LICENSE in the root +// directory of this repository or package, or at +// https://github.com/restatedev/sdk-java/blob/main/LICENSE +package dev.restate.client; + +import dev.restate.common.Output; +import dev.restate.common.Request; +import dev.restate.common.Target; +import dev.restate.serde.Serde; +import dev.restate.serde.SerdeFactory; +import dev.restate.serde.TypeTag; +import java.util.concurrent.CompletableFuture; +import java.util.concurrent.CompletionException; +import org.jspecify.annotations.NonNull; + +public interface Client { + + CompletableFuture> callAsync(Request request); + + default CompletableFuture> callAsync( + Request.Builder request) { + return callAsync(request.build()); + } + + default ClientResponse call(Request request) throws IngressException { + try { + return callAsync(request).join(); + } catch (CompletionException e) { + if (e.getCause() instanceof RuntimeException) { + throw (RuntimeException) e.getCause(); + } + throw new RuntimeException(e.getCause()); + } + } + + default ClientResponse call(Request.Builder request) + throws IngressException { + return call(request.build()); + } + + CompletableFuture>> sendAsync( + Request request); + + default ClientResponse> send(Request request) + throws IngressException { + try { + return sendAsync(request).join(); + } catch (CompletionException e) { + if (e.getCause() instanceof RuntimeException) { + throw (RuntimeException) e.getCause(); + } + throw new RuntimeException(e.getCause()); + } + } + + default CompletableFuture>> sendAsync( + Request.Builder request) { + return sendAsync(request.build()); + } + + default ClientResponse> send(Request.Builder request) + throws IngressException { + return send(request.build()); + } + + /** + * Create a new {@link AwakeableHandle} for the provided identifier. You can use it to {@link + * AwakeableHandle#resolve(TypeTag, Object)} or {@link AwakeableHandle#reject(String)} an + * Awakeable from the ingress. + */ + AwakeableHandle awakeableHandle(String id); + + /** + * This class represents a handle to an Awakeable. It can be used to complete awakeables from the + * ingress + */ + interface AwakeableHandle { + /** Same as {@link #resolve(TypeTag, Object)} but async with options. */ + CompletableFuture> resolveAsync( + TypeTag serde, @NonNull T payload, ClientRequestOptions options); + + /** Same as {@link #resolve(TypeTag, Object)} but async. */ + default CompletableFuture> resolveAsync( + TypeTag serde, @NonNull T payload) { + return resolveAsync(serde, payload, ClientRequestOptions.DEFAULT); + } + + /** Same as {@link #resolve(TypeTag, Object)} with options. */ + default ClientResponse resolve( + TypeTag serde, @NonNull T payload, ClientRequestOptions options) { + try { + return resolveAsync(serde, payload, options).join(); + } catch (CompletionException e) { + if (e.getCause() instanceof RuntimeException) { + throw (RuntimeException) e.getCause(); + } + throw new RuntimeException(e.getCause()); + } + } + + /** + * Complete with success the Awakeable. + * + * @param serde used to serialize the Awakeable result payload. + * @param payload the result payload. MUST NOT be null. + */ + default ClientResponse resolve(TypeTag serde, @NonNull T payload) { + return this.resolve(serde, payload, ClientRequestOptions.DEFAULT); + } + + /** Same as {@link #reject(String)} but async with options. */ + CompletableFuture> rejectAsync( + String reason, ClientRequestOptions options); + + /** Same as {@link #reject(String)} but async. */ + default CompletableFuture> rejectAsync(String reason) { + return rejectAsync(reason, ClientRequestOptions.DEFAULT); + } + + /** Same as {@link #reject(String)} with options. */ + default ClientResponse reject(String reason, ClientRequestOptions options) { + try { + return rejectAsync(reason, options).join(); + } catch (CompletionException e) { + if (e.getCause() instanceof RuntimeException) { + throw (RuntimeException) e.getCause(); + } + throw new RuntimeException(e.getCause()); + } + } + + /** + * Complete with failure the Awakeable. + * + * @param reason the rejection reason. MUST NOT be null. + */ + default ClientResponse reject(String reason) { + return this.reject(reason, ClientRequestOptions.DEFAULT); + } + } + + InvocationHandle invocationHandle(String invocationId, TypeTag resSerde); + + interface InvocationHandle { + + String invocationId(); + + CompletableFuture> attachAsync(ClientRequestOptions options); + + default CompletableFuture> attachAsync() { + return attachAsync(ClientRequestOptions.DEFAULT); + } + + default ClientResponse attach(ClientRequestOptions options) throws IngressException { + try { + return attachAsync(options).join(); + } catch (CompletionException e) { + if (e.getCause() instanceof RuntimeException) { + throw (RuntimeException) e.getCause(); + } + throw new RuntimeException(e.getCause()); + } + } + + default ClientResponse attach() throws IngressException { + return attach(ClientRequestOptions.DEFAULT); + } + + CompletableFuture>> getOutputAsync(ClientRequestOptions options); + + default CompletableFuture>> getOutputAsync() { + return getOutputAsync(ClientRequestOptions.DEFAULT); + } + + default ClientResponse> getOutput(ClientRequestOptions options) + throws IngressException { + try { + return getOutputAsync(options).join(); + } catch (CompletionException e) { + if (e.getCause() instanceof RuntimeException) { + throw (RuntimeException) e.getCause(); + } + throw new RuntimeException(e.getCause()); + } + } + + default ClientResponse> getOutput() throws IngressException { + return getOutput(ClientRequestOptions.DEFAULT); + } + } + + IdempotentInvocationHandle idempotentInvocationHandle( + Target target, String idempotencyKey, TypeTag resSerde); + + interface IdempotentInvocationHandle { + + CompletableFuture> attachAsync(ClientRequestOptions options); + + default CompletableFuture> attachAsync() { + return attachAsync(ClientRequestOptions.DEFAULT); + } + + default ClientResponse attach(ClientRequestOptions options) throws IngressException { + try { + return attachAsync(options).join(); + } catch (CompletionException e) { + if (e.getCause() instanceof RuntimeException) { + throw (RuntimeException) e.getCause(); + } + throw new RuntimeException(e.getCause()); + } + } + + default ClientResponse attach() throws IngressException { + return attach(ClientRequestOptions.DEFAULT); + } + + CompletableFuture>> getOutputAsync(ClientRequestOptions options); + + default CompletableFuture>> getOutputAsync() { + return getOutputAsync(ClientRequestOptions.DEFAULT); + } + + default ClientResponse> getOutput(ClientRequestOptions options) + throws IngressException { + try { + return getOutputAsync(options).join(); + } catch (CompletionException e) { + if (e.getCause() instanceof RuntimeException) { + throw (RuntimeException) e.getCause(); + } + throw new RuntimeException(e.getCause()); + } + } + + default ClientResponse> getOutput() throws IngressException { + return getOutput(ClientRequestOptions.DEFAULT); + } + } + + WorkflowHandle workflowHandle( + String workflowName, String workflowId, TypeTag resSerde); + + interface WorkflowHandle { + CompletableFuture> attachAsync(ClientRequestOptions options); + + default CompletableFuture> attachAsync() { + return attachAsync(ClientRequestOptions.DEFAULT); + } + + default ClientResponse attach(ClientRequestOptions options) throws IngressException { + try { + return attachAsync(options).join(); + } catch (CompletionException e) { + if (e.getCause() instanceof RuntimeException) { + throw (RuntimeException) e.getCause(); + } + throw new RuntimeException(e.getCause()); + } + } + + default ClientResponse attach() throws IngressException { + return attach(ClientRequestOptions.DEFAULT); + } + + CompletableFuture>> getOutputAsync(ClientRequestOptions options); + + default CompletableFuture>> getOutputAsync() { + return getOutputAsync(ClientRequestOptions.DEFAULT); + } + + default ClientResponse> getOutput(ClientRequestOptions options) + throws IngressException { + try { + return getOutputAsync(options).join(); + } catch (CompletionException e) { + if (e.getCause() instanceof RuntimeException) { + throw (RuntimeException) e.getCause(); + } + throw new RuntimeException(e.getCause()); + } + } + + default ClientResponse> getOutput() throws IngressException { + return getOutput(ClientRequestOptions.DEFAULT); + } + } + + /** + * Create a default JDK client. + * + * @param baseUri uri to connect to. + */ + static Client connect(String baseUri) { + return connect(baseUri, SerdeFactory.NOOP, ClientRequestOptions.DEFAULT); + } + + /** + * Create a default JDK client. + * + * @param baseUri uri to connect to + * @param options default options to use in all the requests. + */ + static Client connect(String baseUri, ClientRequestOptions options) { + return connect(baseUri, SerdeFactory.NOOP, options); + } + + /** + * Create a default JDK client. + * + * @param baseUri uri to connect to + * @param serdeFactory Serde factory to use. You must provide this when the provided {@link + * TypeTag} are not {@link Serde} instances. If you're just wrapping this client in a + * code-generated client, you don't need to provide this parameter. + */ + static Client connect(String baseUri, SerdeFactory serdeFactory) { + return connect(baseUri, serdeFactory, ClientRequestOptions.DEFAULT); + } + + /** + * Create a default JDK client. + * + * @param baseUri uri to connect to + * @param serdeFactory Serde factory to use. You must provide this when the provided {@link + * TypeTag} are not {@link Serde} instances. If you're just wrapping this client in a + * code-generated client, you don't need to provide this parameter. + * @param options default options to use in all the requests. + */ + static Client connect(String baseUri, SerdeFactory serdeFactory, ClientRequestOptions options) { + // We load through reflections to avoid CNF exceptions in JVMs + // where JDK's HttpClient is not available (see Android!) + try { + Class.forName("java.net.http.HttpClient"); + } catch (ClassNotFoundException e) { + throw new IllegalStateException( + "Cannot load the JdkClient, because the java.net.http.HttpClient is not available on this JVM. Please use another client", + e); + } + + try { + return (Client) + Class.forName("dev.restate.client.jdk.JdkClient") + .getMethod("of", String.class, SerdeFactory.class, ClientRequestOptions.class) + .invoke(null, baseUri, serdeFactory, options); + } catch (Exception e) { + throw new IllegalStateException("Cannot instantiate the client", e); + } + } +} diff --git a/client/src/main/java/dev/restate/client/ClientRequestOptions.java b/client/src/main/java/dev/restate/client/ClientRequestOptions.java new file mode 100644 index 000000000..d5861ee59 --- /dev/null +++ b/client/src/main/java/dev/restate/client/ClientRequestOptions.java @@ -0,0 +1,100 @@ +// Copyright (c) 2023 - Restate Software, Inc., Restate GmbH +// +// This file is part of the Restate Java SDK, +// which is released under the MIT license. +// +// You can find a copy of the license in file LICENSE in the root +// directory of this repository or package, or at +// https://github.com/restatedev/sdk-java/blob/main/LICENSE +package dev.restate.client; + +import java.util.*; +import org.jspecify.annotations.Nullable; + +public final class ClientRequestOptions { + + public static final ClientRequestOptions DEFAULT = new ClientRequestOptions(null); + + @Nullable private final Map headers; + + private ClientRequestOptions(@Nullable Map headers) { + this.headers = headers; + } + + public Map headers() { + if (headers == null) { + return Collections.emptyMap(); + } + return headers; + } + + public static Builder builder() { + return new Builder(); + } + + /** + * @param headers Headers to attach in the request. + */ + public static Builder withHeaders(Map headers) { + return builder().headers(headers); + } + + public static final class Builder { + @Nullable private Map headers; + + private Builder() {} + + /** + * @param key header key + * @param value header value + * @return this instance, so the builder can be used fluently. + */ + public Builder header(String key, String value) { + if (this.headers == null) { + this.headers = new HashMap<>(); + } + this.headers.put(key, value); + return this; + } + + /** + * @param newHeaders headers to send together with the request. + * @return this instance, so the builder can be used fluently. + */ + public Builder headers(Map newHeaders) { + if (this.headers == null) { + this.headers = new HashMap<>(); + } + this.headers.putAll(newHeaders); + return this; + } + + public @Nullable Map getHeaders() { + return headers; + } + + public Builder setHeaders(Map newHeaders) { + return headers(newHeaders); + } + + public ClientRequestOptions build() { + return new ClientRequestOptions(this.headers); + } + } + + @Override + public boolean equals(Object o) { + if (!(o instanceof ClientRequestOptions that)) return false; + return Objects.equals(headers, that.headers); + } + + @Override + public int hashCode() { + return Objects.hashCode(headers); + } + + @Override + public String toString() { + return "ClientRequestOptions{" + "headers=" + headers + '}'; + } +} diff --git a/client/src/main/java/dev/restate/client/ClientResponse.java b/client/src/main/java/dev/restate/client/ClientResponse.java new file mode 100644 index 000000000..a33099e63 --- /dev/null +++ b/client/src/main/java/dev/restate/client/ClientResponse.java @@ -0,0 +1,26 @@ +// Copyright (c) 2023 - Restate Software, Inc., Restate GmbH +// +// This file is part of the Restate Java SDK, +// which is released under the MIT license. +// +// You can find a copy of the license in file LICENSE in the root +// directory of this repository or package, or at +// https://github.com/restatedev/sdk-java/blob/main/LICENSE +package dev.restate.client; + +import java.util.Map; +import java.util.Set; +import org.jspecify.annotations.Nullable; + +public record ClientResponse(int statusCode, Headers headers, R response) { + public interface Headers { + @Nullable String get(String key); + + Set keys(); + + /** + * @return headers to lowercase keys map + */ + Map toLowercaseMap(); + } +} diff --git a/sdk-common/src/main/java/dev/restate/sdk/client/IngressException.java b/client/src/main/java/dev/restate/client/IngressException.java similarity index 56% rename from sdk-common/src/main/java/dev/restate/sdk/client/IngressException.java rename to client/src/main/java/dev/restate/client/IngressException.java index 830aa5ee0..e489723b6 100644 --- a/sdk-common/src/main/java/dev/restate/sdk/client/IngressException.java +++ b/client/src/main/java/dev/restate/client/IngressException.java @@ -6,10 +6,8 @@ // You can find a copy of the license in file LICENSE in the root // directory of this repository or package, or at // https://github.com/restatedev/sdk-java/blob/main/LICENSE -package dev.restate.sdk.client; +package dev.restate.client; -import java.net.http.HttpRequest; -import java.net.http.HttpResponse; import java.nio.charset.StandardCharsets; import org.jspecify.annotations.Nullable; @@ -20,48 +18,12 @@ public class IngressException extends RuntimeException { private final int statusCode; private final byte[] responseBody; - public IngressException( - String message, String requestMethod, String requestURI, Throwable cause) { - this(message, requestMethod, requestURI, -1, null, cause); - } - - IngressException(String message, HttpRequest request, Throwable cause) { - this(message, request.method(), request.uri().toString(), -1, null, cause); - } - - IngressException(String message, HttpRequest request) { - this(message, request, null); - } - - public IngressException( - String message, - String requestMethod, - String requestURI, - int statusCode, - byte[] responseBody) { - this(message, requestMethod, requestURI, statusCode, responseBody, null); - } - - IngressException(String message, HttpResponse response, Throwable cause) { - this( - message, - response.request().method(), - response.request().uri().toString(), - response.statusCode(), - response.body(), - cause); - } - - IngressException(String message, HttpResponse response) { - this(message, response, null); - } - public IngressException( String message, String requestMethod, String requestURI, int statusCode, - byte[] responseBody, + @Nullable byte[] responseBody, Throwable cause) { super(message, cause); this.statusCode = statusCode; diff --git a/client/src/main/java/dev/restate/client/SendResponse.java b/client/src/main/java/dev/restate/client/SendResponse.java new file mode 100644 index 000000000..c9c5bae07 --- /dev/null +++ b/client/src/main/java/dev/restate/client/SendResponse.java @@ -0,0 +1,18 @@ +// Copyright (c) 2023 - Restate Software, Inc., Restate GmbH +// +// This file is part of the Restate Java SDK, +// which is released under the MIT license. +// +// You can find a copy of the license in file LICENSE in the root +// directory of this repository or package, or at +// https://github.com/restatedev/sdk-java/blob/main/LICENSE +package dev.restate.client; + +public record SendResponse(SendStatus status, Client.InvocationHandle invocationHandle) { + public enum SendStatus { + /** The request was sent for the first time. */ + ACCEPTED, + /** The request was already sent beforehand. */ + PREVIOUSLY_ACCEPTED + } +} diff --git a/client/src/main/java/dev/restate/client/base/BaseClient.java b/client/src/main/java/dev/restate/client/base/BaseClient.java new file mode 100644 index 000000000..6f268d2bc --- /dev/null +++ b/client/src/main/java/dev/restate/client/base/BaseClient.java @@ -0,0 +1,535 @@ +// Copyright (c) 2023 - Restate Software, Inc., Restate GmbH +// +// This file is part of the Restate Java SDK, +// which is released under the MIT license. +// +// You can find a copy of the license in file LICENSE in the root +// directory of this repository or package, or at +// https://github.com/restatedev/sdk-java/blob/main/LICENSE +package dev.restate.client.base; + +import com.fasterxml.jackson.core.JsonFactory; +import com.fasterxml.jackson.core.JsonParser; +import com.fasterxml.jackson.core.JsonToken; +import dev.restate.client.*; +import dev.restate.common.*; +import dev.restate.serde.Serde; +import dev.restate.serde.SerdeFactory; +import dev.restate.serde.TypeTag; +import java.io.ByteArrayInputStream; +import java.io.IOException; +import java.io.InputStream; +import java.net.URI; +import java.net.URLEncoder; +import java.nio.charset.StandardCharsets; +import java.time.Duration; +import java.util.*; +import java.util.concurrent.CompletableFuture; +import java.util.stream.Stream; +import org.jetbrains.annotations.NotNull; +import org.jspecify.annotations.NonNull; +import org.jspecify.annotations.Nullable; + +/** + * Base client. This can be used to build {@link Client} implementations on top with the HTTP client + * of your choice. + */ +public abstract class BaseClient implements Client { + + private static final JsonFactory JSON_FACTORY = new JsonFactory(); + + private final URI baseUri; + private final SerdeFactory serdeFactory; + private final ClientRequestOptions baseOptions; + + protected BaseClient(URI baseUri, SerdeFactory serdeFactory, ClientRequestOptions baseOptions) { + this.baseUri = baseUri; + this.serdeFactory = serdeFactory; + this.baseOptions = baseOptions; + } + + @Override + public CompletableFuture> callAsync(Request request) { + Serde reqSerde = this.serdeFactory.create(request.requestTypeTag()); + Serde resSerde = this.serdeFactory.create(request.responseTypeTag()); + + URI requestUri = toRequestURI(request.target(), false, null); + Stream> headersStream = + Stream.concat( + baseOptions.headers().entrySet().stream(), request.headers().entrySet().stream()); + if (reqSerde.contentType() != null) { + headersStream = + Stream.concat( + headersStream, Stream.of(Map.entry("content-type", reqSerde.contentType()))); + } + if (request.idempotencyKey() != null) { + headersStream = + Stream.concat( + headersStream, Stream.of(Map.entry("idempotency-key", request.idempotencyKey()))); + } + Slice requestBody = reqSerde.serialize(request.request()); + + return doPostRequest( + requestUri, headersStream, requestBody, callResponseMapper("POST", requestUri, resSerde)); + } + + @Override + public CompletableFuture>> sendAsync( + Request request) { + Serde reqSerde = this.serdeFactory.create(request.requestTypeTag()); + + URI requestUri = + toRequestURI( + request.target(), + true, + (request instanceof SendRequest sendRequest) ? sendRequest.delay() : null); + Stream> headersStream = + Stream.concat( + baseOptions.headers().entrySet().stream(), request.headers().entrySet().stream()); + if (reqSerde.contentType() != null) { + headersStream = + Stream.concat( + headersStream, Stream.of(Map.entry("content-type", reqSerde.contentType()))); + } + if (request.idempotencyKey() != null) { + headersStream = + Stream.concat( + headersStream, Stream.of(Map.entry("idempotency-key", request.idempotencyKey()))); + } + Slice requestBody = reqSerde.serialize(request.request()); + + return doPostRequest( + requestUri, + headersStream, + requestBody, + (statusCode, responseHeaders, responseBody) -> { + if (statusCode >= 300) { + handleNonSuccessResponse( + "POST", requestUri.toString(), statusCode, responseHeaders, responseBody); + } + + if (responseBody == null) { + throw new IngressException( + "Expecting a response body, but got none", + "POST", + requestUri.toString(), + statusCode, + null, + null); + } + + Map fields; + try { + fields = + findStringFieldsInJsonObject( + new ByteArrayInputStream(responseBody.toByteArray()), "invocationId", "status"); + } catch (Exception e) { + throw new IngressException( + "Cannot deserialize the response", + "POST", + requestUri.toString(), + statusCode, + responseBody.toByteArray(), + e); + } + + String statusField = fields.get("status"); + SendResponse.SendStatus status; + if ("Accepted".equalsIgnoreCase(statusField)) { + status = SendResponse.SendStatus.ACCEPTED; + } else if ("PreviouslyAccepted".equalsIgnoreCase(statusField)) { + status = SendResponse.SendStatus.PREVIOUSLY_ACCEPTED; + } else { + throw new IngressException( + "Cannot deserialize the response status, got " + statusField, + "POST", + requestUri.toString(), + statusCode, + responseBody.toByteArray(), + null); + } + + return new ClientResponse<>( + statusCode, + responseHeaders, + new SendResponse<>( + status, invocationHandle(fields.get("invocationId"), request.responseTypeTag()))); + }); + } + + @Override + public AwakeableHandle awakeableHandle(String id) { + return new AwakeableHandle() { + @Override + public CompletableFuture> resolveAsync( + TypeTag serde, @NonNull T payload, ClientRequestOptions options) { + Serde reqSerde = serdeFactory.create(serde); + Slice requestBody = reqSerde.serialize(payload); + + URI requestUri = baseUri.resolve("/restate/awakeables/" + id + "/resolve"); + Stream> headersStream = + Stream.concat( + baseOptions.headers().entrySet().stream(), options.headers().entrySet().stream()); + if (reqSerde.contentType() != null) { + headersStream = + Stream.concat( + headersStream, Stream.of(Map.entry("content-type", reqSerde.contentType()))); + } + + return doPostRequest( + requestUri, + headersStream, + requestBody, + handleVoidResponse("POST", requestUri.toString())); + } + + @Override + public CompletableFuture> rejectAsync( + String reason, ClientRequestOptions options) { + URI requestUri = baseUri.resolve("/restate/awakeables/" + id + "/reject"); + Stream> headersStream = + Stream.concat( + baseOptions.headers().entrySet().stream(), + Stream.concat( + options.headers().entrySet().stream(), + Stream.of(Map.entry("content-type", "text/plain")))); + + return doPostRequest( + requestUri, + headersStream, + Slice.wrap(reason), + handleVoidResponse("POST", requestUri.toString())); + } + }; + } + + @Override + public InvocationHandle invocationHandle( + String invocationId, TypeTag resTypeTag) { + Serde resSerde = serdeFactory.create(resTypeTag); + + return new InvocationHandle<>() { + @Override + public String invocationId() { + return invocationId; + } + + @Override + public CompletableFuture> attachAsync(ClientRequestOptions options) { + URI requestUri = baseUri.resolve("/restate/invocation/" + invocationId + "/attach"); + Stream> headersStream = + Stream.concat( + baseOptions.headers().entrySet().stream(), options.headers().entrySet().stream()); + + return doGetRequest( + requestUri, headersStream, callResponseMapper("GET", requestUri, resSerde)); + } + + @Override + public CompletableFuture>> getOutputAsync( + ClientRequestOptions options) { + URI requestUri = baseUri.resolve("/restate/invocation/" + invocationId + "/output"); + Stream> headersStream = + Stream.concat( + baseOptions.headers().entrySet().stream(), options.headers().entrySet().stream()); + + return doGetRequest( + requestUri, headersStream, getOutputResponseMapper("GET", requestUri, resSerde)); + } + }; + } + + @Override + public IdempotentInvocationHandle idempotentInvocationHandle( + Target target, String idempotencyKey, TypeTag resTypeTag) { + return new IdempotentInvocationHandle<>() { + @Override + public CompletableFuture> attachAsync(ClientRequestOptions options) { + Serde resSerde = serdeFactory.create(resTypeTag); + + URI requestUri = + baseUri.resolve( + "/restate/invocation" + + targetToURI(target) + + "/" + + URLEncoder.encode(idempotencyKey, StandardCharsets.UTF_8) + + "/attach"); + Stream> headersStream = + Stream.concat( + baseOptions.headers().entrySet().stream(), options.headers().entrySet().stream()); + + return doGetRequest( + requestUri, headersStream, callResponseMapper("GET", requestUri, resSerde)); + } + + @Override + public CompletableFuture>> getOutputAsync( + ClientRequestOptions options) { + Serde resSerde = serdeFactory.create(resTypeTag); + + URI requestUri = + baseUri.resolve( + "/restate/invocation" + + targetToURI(target) + + "/" + + URLEncoder.encode(idempotencyKey, StandardCharsets.UTF_8) + + "/output"); + Stream> headersStream = + Stream.concat( + baseOptions.headers().entrySet().stream(), options.headers().entrySet().stream()); + + return doGetRequest( + requestUri, headersStream, getOutputResponseMapper("GET", requestUri, resSerde)); + } + }; + } + + @Override + public WorkflowHandle workflowHandle( + String workflowName, String workflowId, TypeTag resTypeTag) { + return new WorkflowHandle<>() { + @Override + public CompletableFuture> attachAsync(ClientRequestOptions options) { + Serde resSerde = serdeFactory.create(resTypeTag); + + URI requestUri = + baseUri.resolve( + "/restate/workflow/" + + workflowName + + "/" + + URLEncoder.encode(workflowId, StandardCharsets.UTF_8) + + "/attach"); + Stream> headersStream = + Stream.concat( + baseOptions.headers().entrySet().stream(), options.headers().entrySet().stream()); + + return doGetRequest( + requestUri, headersStream, callResponseMapper("GET", requestUri, resSerde)); + } + + @Override + public CompletableFuture>> getOutputAsync( + ClientRequestOptions options) { + Serde resSerde = serdeFactory.create(resTypeTag); + + URI requestUri = + baseUri.resolve( + "/restate/workflow/" + + workflowName + + "/" + + URLEncoder.encode(workflowId, StandardCharsets.UTF_8) + + "/output"); + Stream> headersStream = + Stream.concat( + baseOptions.headers().entrySet().stream(), options.headers().entrySet().stream()); + + return doGetRequest( + requestUri, headersStream, getOutputResponseMapper("GET", requestUri, resSerde)); + } + }; + } + + @FunctionalInterface + protected interface ResponseMapper { + ClientResponse mapResponse( + int statusCode, ClientResponse.Headers responseHeaders, @Nullable Slice responseBody); + } + + protected abstract CompletableFuture> doPostRequest( + URI target, + Stream> headers, + Slice payload, + ResponseMapper responseMapper); + + protected abstract CompletableFuture> doGetRequest( + URI target, Stream> headers, ResponseMapper responseMapper); + + private @NotNull ResponseMapper callResponseMapper( + String requestMethod, URI requestUri, Serde resSerde) { + return (statusCode, responseHeaders, responseBody) -> { + if (statusCode >= 300) { + handleNonSuccessResponse( + requestMethod, requestUri.toString(), statusCode, responseHeaders, responseBody); + } + + if (responseBody == null) { + throw new IngressException( + "Expecting a response body, but got none", + requestMethod, + requestUri.toString(), + statusCode, + null, + null); + } + try { + return new ClientResponse<>( + statusCode, responseHeaders, resSerde.deserialize(responseBody)); + } catch (Exception e) { + throw new IngressException( + "Cannot deserialize the response", + requestMethod, + requestUri.toString(), + statusCode, + responseBody.toByteArray(), + e); + } + }; + } + + private @NotNull ResponseMapper> getOutputResponseMapper( + String requestMethod, URI requestUri, Serde resSerde) { + return (statusCode, responseHeaders, responseBody) -> { + if (statusCode == 470) { + return new ClientResponse<>(statusCode, responseHeaders, Output.notReady()); + } + + if (statusCode >= 300) { + handleNonSuccessResponse( + "GET", requestUri.toString(), statusCode, responseHeaders, responseBody); + } + + if (responseBody == null) { + throw new IngressException( + "Expecting a response body, but got none", + requestMethod, + requestUri.toString(), + statusCode, + null, + null); + } + try { + return new ClientResponse<>( + statusCode, responseHeaders, Output.ready(resSerde.deserialize(responseBody))); + } catch (Exception e) { + throw new IngressException( + "Cannot deserialize the response", + requestMethod, + requestUri.toString(), + statusCode, + responseBody.toByteArray(), + e); + } + }; + } + + /** Contains prefix / but not postfix / */ + private String targetToURI(Target target) { + StringBuilder builder = new StringBuilder(); + builder.append("/").append(target.getService()); + if (target.getKey() != null) { + builder.append("/").append(URLEncoder.encode(target.getKey(), StandardCharsets.UTF_8)); + } + builder.append("/").append(target.getHandler()); + return builder.toString(); + } + + private URI toRequestURI(Target target, boolean isSend, Duration delay) { + StringBuilder builder = new StringBuilder(targetToURI(target)); + if (isSend) { + builder.append("/send"); + } + if (delay != null && !delay.isZero() && !delay.isNegative()) { + builder.append("?delay=").append(delay); + } + + return this.baseUri.resolve(builder.toString()); + } + + private ResponseMapper handleVoidResponse(String requestMethod, String requestURI) { + return (statusCode, responseHeaders, responseBody) -> { + if (statusCode >= 300) { + handleNonSuccessResponse( + requestMethod, requestURI, statusCode, responseHeaders, responseBody); + } + + return new ClientResponse<>(statusCode, responseHeaders, null); + }; + } + + private void handleNonSuccessResponse( + String requestMethod, + String requestURI, + int statusCode, + ClientResponse.Headers headers, + @Nullable Slice responseBody) { + String ct = headers.get("content-type"); + if (ct != null && ct.contains("application/json") && responseBody != null) { + String errorMessage; + // Let's try to parse the message field + try { + errorMessage = + findStringFieldInJsonObject( + new ByteArrayInputStream(responseBody.toByteArray()), "message"); + } catch (Exception e) { + throw new IngressException( + "Can't decode error response from ingress", + requestMethod, + requestURI, + statusCode, + responseBody.toByteArray(), + e); + } + throw new IngressException( + errorMessage, requestMethod, requestURI, statusCode, responseBody.toByteArray(), null); + } + + // Fallback error + throw new IngressException( + "Received non success status code", + requestMethod, + requestURI, + statusCode, + (responseBody != null) ? responseBody.toByteArray() : null, + null); + } + + private static String findStringFieldInJsonObject(InputStream body, String fieldName) + throws IOException { + try (JsonParser parser = JSON_FACTORY.createParser(body)) { + if (parser.nextToken() != JsonToken.START_OBJECT) { + throw new IllegalStateException( + "Expecting token " + JsonToken.START_OBJECT + ", got " + parser.getCurrentToken()); + } + for (String actualFieldName = parser.nextFieldName(); + actualFieldName != null; + actualFieldName = parser.nextFieldName()) { + if (actualFieldName.equalsIgnoreCase(fieldName)) { + return parser.nextTextValue(); + } else { + parser.nextValue(); + } + } + throw new IllegalStateException( + "Expecting field name \"" + fieldName + "\", got " + parser.getCurrentToken()); + } + } + + private static Map findStringFieldsInJsonObject( + InputStream body, String... fields) throws IOException { + Map resultMap = new HashMap<>(); + Set fieldSet = new HashSet<>(Set.of(fields)); + + try (JsonParser parser = JSON_FACTORY.createParser(body)) { + if (parser.nextToken() != JsonToken.START_OBJECT) { + throw new IllegalStateException( + "Expecting token " + JsonToken.START_OBJECT + ", got " + parser.getCurrentToken()); + } + for (String actualFieldName = parser.nextFieldName(); + actualFieldName != null; + actualFieldName = parser.nextFieldName()) { + if (fieldSet.remove(actualFieldName)) { + resultMap.put(actualFieldName, parser.nextTextValue()); + } else { + parser.nextValue(); + } + } + } + + if (!fieldSet.isEmpty()) { + throw new IllegalStateException( + "Expecting fields \"" + Arrays.toString(fields) + "\", cannot find fields " + fieldSet); + } + + return resultMap; + } +} diff --git a/client/src/main/java/dev/restate/client/jdk/JdkClient.java b/client/src/main/java/dev/restate/client/jdk/JdkClient.java new file mode 100644 index 000000000..667f2c0c8 --- /dev/null +++ b/client/src/main/java/dev/restate/client/jdk/JdkClient.java @@ -0,0 +1,129 @@ +// Copyright (c) 2023 - Restate Software, Inc., Restate GmbH +// +// This file is part of the Restate Java SDK, +// which is released under the MIT license. +// +// You can find a copy of the license in file LICENSE in the root +// directory of this repository or package, or at +// https://github.com/restatedev/sdk-java/blob/main/LICENSE +package dev.restate.client.jdk; + +import dev.restate.client.*; +import dev.restate.client.base.BaseClient; +import dev.restate.common.*; +import dev.restate.serde.SerdeFactory; +import java.net.URI; +import java.net.http.HttpClient; +import java.net.http.HttpHeaders; +import java.net.http.HttpRequest; +import java.net.http.HttpResponse; +import java.util.*; +import java.util.concurrent.CompletableFuture; +import java.util.stream.Collectors; +import java.util.stream.Stream; +import org.jspecify.annotations.Nullable; + +public class JdkClient extends BaseClient { + + private final HttpClient httpClient; + + private JdkClient( + URI baseUri, + SerdeFactory serdeFactory, + ClientRequestOptions baseOptions, + HttpClient httpClient) { + super(baseUri, serdeFactory, baseOptions); + this.httpClient = httpClient; + } + + @Override + protected CompletableFuture> doPostRequest( + URI target, + Stream> headers, + Slice payload, + ResponseMapper responseMapper) { + var reqBuilder = HttpRequest.newBuilder().uri(target); + headers.forEach(h -> reqBuilder.header(h.getKey(), h.getValue())); + reqBuilder.POST(HttpRequest.BodyPublishers.ofByteArray(payload.toByteArray())); + + return this.httpClient + .sendAsync(reqBuilder.build(), HttpResponse.BodyHandlers.ofByteArray()) + .handle( + (res, t) -> { + if (t != null) { + throw new IngressException( + "Error when executing the request: " + t.getMessage(), + "POST", + target.toString(), + -1, + null, + t); + } + + return responseMapper.mapResponse( + res.statusCode(), toHeaders(res.headers()), Slice.wrap(res.body())); + }); + } + + @Override + protected CompletableFuture> doGetRequest( + URI target, Stream> headers, ResponseMapper responseMapper) { + var reqBuilder = HttpRequest.newBuilder().uri(target); + headers.forEach(h -> reqBuilder.header(h.getKey(), h.getValue())); + reqBuilder.GET(); + + return this.httpClient + .sendAsync(reqBuilder.build(), HttpResponse.BodyHandlers.ofByteArray()) + .handle( + (res, t) -> { + if (t != null) { + throw new IngressException( + "Error when executing the request: " + t.getMessage(), + "POST", + target.toString(), + -1, + null, + t); + } + + return responseMapper.mapResponse( + res.statusCode(), toHeaders(res.headers()), Slice.wrap(res.body())); + }); + } + + private ClientResponse.Headers toHeaders(HttpHeaders httpHeaders) { + return new ClientResponse.Headers() { + @Override + public @Nullable String get(String key) { + return httpHeaders.firstValue(key).orElse(null); + } + + @Override + public Set keys() { + return httpHeaders.map().keySet(); + } + + @Override + public Map toLowercaseMap() { + return httpHeaders.map().entrySet().stream() + .map(e -> Map.entry(e.getKey().toLowerCase(), e.getValue().get(0))) + .collect(Collectors.toMap(Map.Entry::getKey, Map.Entry::getValue)); + } + }; + } + + /** Create a new JDK Client */ + public static JdkClient of( + HttpClient httpClient, + String baseUri, + SerdeFactory serdeFactory, + ClientRequestOptions options) { + return new JdkClient(URI.create(baseUri), serdeFactory, options, httpClient); + } + + /** Create a new JDK Client */ + public static JdkClient of( + String baseUri, SerdeFactory serdeFactory, ClientRequestOptions options) { + return new JdkClient(URI.create(baseUri), serdeFactory, options, HttpClient.newHttpClient()); + } +} diff --git a/sdk-serde-protobuf/build.gradle.kts b/common/build.gradle.kts similarity index 52% rename from sdk-serde-protobuf/build.gradle.kts rename to common/build.gradle.kts index b27120285..6ec4b424e 100644 --- a/sdk-serde-protobuf/build.gradle.kts +++ b/common/build.gradle.kts @@ -1,16 +1,15 @@ plugins { + `java-library` `java-conventions` `kotlin-conventions` - `java-library` `library-publishing-conventions` } -description = "Restate SDK Protobuf integration" +description = "Common types used by different Restate Java modules" dependencies { compileOnly(libs.jspecify) - api(libs.protobuf.java) - - implementation(project(":sdk-common")) + testImplementation(libs.junit.jupiter) + testImplementation(libs.assertj) } diff --git a/sdk-common/src/main/java/dev/restate/sdk/common/Output.java b/common/src/main/java/dev/restate/common/Output.java similarity index 98% rename from sdk-common/src/main/java/dev/restate/sdk/common/Output.java rename to common/src/main/java/dev/restate/common/Output.java index d862c48fc..224c6ca25 100644 --- a/sdk-common/src/main/java/dev/restate/sdk/common/Output.java +++ b/common/src/main/java/dev/restate/common/Output.java @@ -6,7 +6,7 @@ // You can find a copy of the license in file LICENSE in the root // directory of this repository or package, or at // https://github.com/restatedev/sdk-java/blob/main/LICENSE -package dev.restate.sdk.common; +package dev.restate.common; import java.util.Objects; import java.util.Optional; diff --git a/common/src/main/java/dev/restate/common/Request.java b/common/src/main/java/dev/restate/common/Request.java new file mode 100644 index 000000000..3e15df886 --- /dev/null +++ b/common/src/main/java/dev/restate/common/Request.java @@ -0,0 +1,243 @@ +// Copyright (c) 2023 - Restate Software, Inc., Restate GmbH +// +// This file is part of the Restate Java SDK, +// which is released under the MIT license. +// +// You can find a copy of the license in file LICENSE in the root +// directory of this repository or package, or at +// https://github.com/restatedev/sdk-java/blob/main/LICENSE +package dev.restate.common; + +import dev.restate.serde.Serde; +import dev.restate.serde.TypeTag; +import java.time.Duration; +import java.util.LinkedHashMap; +import java.util.Map; +import java.util.Objects; +import org.jspecify.annotations.Nullable; + +public sealed class Request permits SendRequest { + + private final Target target; + private final TypeTag reqTypeTag; + private final TypeTag resTypeTag; + private final Req request; + @Nullable private final String idempotencyKey; + @Nullable private final LinkedHashMap headers; + + Request( + Target target, + TypeTag reqTypeTag, + TypeTag resTypeTag, + Req request, + @Nullable String idempotencyKey, + @Nullable LinkedHashMap headers) { + this.target = target; + this.reqTypeTag = reqTypeTag; + this.resTypeTag = resTypeTag; + this.request = request; + this.idempotencyKey = idempotencyKey; + this.headers = headers; + } + + public Target target() { + return target; + } + + public TypeTag requestTypeTag() { + return reqTypeTag; + } + + public TypeTag responseTypeTag() { + return resTypeTag; + } + + public Req request() { + return request; + } + + public @Nullable String idempotencyKey() { + return idempotencyKey; + } + + public Map headers() { + if (headers == null) { + return Map.of(); + } + return headers; + } + + public static Builder of( + Target target, TypeTag reqTypeTag, TypeTag resTypeTag, Req request) { + return new Builder<>(target, reqTypeTag, resTypeTag, request); + } + + public static Builder withNoResponseBody( + Target target, TypeTag reqTypeTag, Req request) { + return new Builder<>(target, reqTypeTag, Serde.VOID, request); + } + + public static Builder withNoRequestBody(Target target, TypeTag resTypeTag) { + return new Builder<>(target, Serde.VOID, resTypeTag, null); + } + + public static Builder ofRaw(Target target, byte[] request) { + return new Builder<>(target, TypeTag.of(Serde.RAW), TypeTag.of(Serde.RAW), request); + } + + public static final class Builder { + private final Target target; + private final TypeTag reqTypeTag; + private final TypeTag resTypeTag; + private final Req request; + @Nullable private String idempotencyKey; + @Nullable private LinkedHashMap headers; + + public Builder( + Target target, + TypeTag reqTypeTag, + TypeTag resTypeTag, + Req request, + @Nullable String idempotencyKey, + @Nullable LinkedHashMap headers) { + this.target = target; + this.reqTypeTag = reqTypeTag; + this.resTypeTag = resTypeTag; + this.request = request; + this.idempotencyKey = idempotencyKey; + this.headers = headers; + } + + private Builder(Target target, TypeTag reqTypeTag, TypeTag resTypeTag, Req request) { + this.target = target; + this.reqTypeTag = reqTypeTag; + this.resTypeTag = resTypeTag; + this.request = request; + } + + /** + * @param idempotencyKey Idempotency key to attach in the request. + * @return this instance, so the builder can be used fluently. + */ + public Builder idempotencyKey(String idempotencyKey) { + this.idempotencyKey = idempotencyKey; + return this; + } + + /** + * @param key header key + * @param value header value + * @return this instance, so the builder can be used fluently. + */ + public Builder header(String key, String value) { + if (this.headers == null) { + this.headers = new LinkedHashMap<>(); + } + this.headers.put(key, value); + return this; + } + + /** + * @param newHeaders headers to send together with the request. + * @return this instance, so the builder can be used fluently. + */ + public Builder headers(Map newHeaders) { + if (this.headers == null) { + this.headers = new LinkedHashMap<>(); + } + this.headers.putAll(newHeaders); + return this; + } + + public @Nullable String getIdempotencyKey() { + return idempotencyKey; + } + + public Builder setIdempotencyKey(@Nullable String idempotencyKey) { + return idempotencyKey(idempotencyKey); + } + + public @Nullable Map getHeaders() { + return headers; + } + + public Builder setHeaders(@Nullable Map headers) { + return headers(headers); + } + + public SendRequest asSend() { + return new SendRequest<>( + target, reqTypeTag, resTypeTag, request, idempotencyKey, headers, null); + } + + public SendRequest asSendDelayed(Duration delay) { + return new SendRequest<>( + target, reqTypeTag, resTypeTag, request, idempotencyKey, headers, delay); + } + + public Request build() { + return new Request<>( + this.target, + this.reqTypeTag, + this.resTypeTag, + this.request, + this.idempotencyKey, + this.headers); + } + } + + public Builder toBuilder() { + return new Builder<>( + this.target, + this.reqTypeTag, + this.resTypeTag, + this.request, + this.idempotencyKey, + this.headers); + } + + public SendRequest asSend() { + return new SendRequest<>( + target, reqTypeTag, resTypeTag, request, idempotencyKey, headers, null); + } + + public SendRequest asSendDelayed(Duration delay) { + return new SendRequest<>( + target, reqTypeTag, resTypeTag, request, idempotencyKey, headers, delay); + } + + @Override + public boolean equals(Object o) { + if (!(o instanceof Request that)) return false; + return Objects.equals(target, that.target) + && Objects.equals(reqTypeTag, that.reqTypeTag) + && Objects.equals(resTypeTag, that.resTypeTag) + && Objects.equals(request, that.request) + && Objects.equals(idempotencyKey, that.idempotencyKey) + && Objects.equals(headers, that.headers); + } + + @Override + public int hashCode() { + return Objects.hash(target, reqTypeTag, resTypeTag, request, idempotencyKey, headers); + } + + @Override + public String toString() { + return "CallRequest{" + + "target=" + + target + + ", reqSerdeInfo=" + + reqTypeTag + + ", resSerdeInfo=" + + resTypeTag + + ", request=" + + request + + ", idempotencyKey='" + + idempotencyKey + + '\'' + + ", headers=" + + headers + + '}'; + } +} diff --git a/common/src/main/java/dev/restate/common/SendRequest.java b/common/src/main/java/dev/restate/common/SendRequest.java new file mode 100644 index 000000000..61f908bfa --- /dev/null +++ b/common/src/main/java/dev/restate/common/SendRequest.java @@ -0,0 +1,81 @@ +// Copyright (c) 2023 - Restate Software, Inc., Restate GmbH +// +// This file is part of the Restate Java SDK, +// which is released under the MIT license. +// +// You can find a copy of the license in file LICENSE in the root +// directory of this repository or package, or at +// https://github.com/restatedev/sdk-java/blob/main/LICENSE +package dev.restate.common; + +import dev.restate.serde.TypeTag; +import java.time.Duration; +import java.util.LinkedHashMap; +import java.util.Objects; +import org.jspecify.annotations.Nullable; + +public final class SendRequest extends Request { + + @Nullable private final Duration delay; + + SendRequest( + Target target, + TypeTag reqTypeTag, + TypeTag resTypeTag, + Req request, + @Nullable String idempotencyKey, + @Nullable LinkedHashMap headers, + @Nullable Duration delay) { + super(target, reqTypeTag, resTypeTag, request, idempotencyKey, headers); + this.delay = delay; + } + + public @Nullable Duration delay() { + return delay; + } + + @Override + public boolean equals(Object o) { + if (!(o instanceof SendRequest that)) return false; + return Objects.equals(target(), that.target()) + && Objects.equals(requestTypeTag(), that.requestTypeTag()) + && Objects.equals(responseTypeTag(), that.responseTypeTag()) + && Objects.equals(request(), that.request()) + && Objects.equals(idempotencyKey(), that.idempotencyKey()) + && Objects.equals(headers(), that.headers()) + && Objects.equals(delay, that.delay); + } + + @Override + public int hashCode() { + return Objects.hash( + target(), + requestTypeTag(), + responseTypeTag(), + request(), + idempotencyKey(), + headers(), + delay); + } + + @Override + public String toString() { + return "CallRequest{" + + "target=" + + target() + + ", reqSerdeInfo=" + + requestTypeTag() + + ", resSerdeInfo=" + + responseTypeTag() + + ", request=" + + request() + + ", idempotencyKey='" + + idempotencyKey() + + '\'' + + ", headers=" + + headers() + + ", delay=" + + delay + + '}'; + } +} diff --git a/common/src/main/java/dev/restate/common/Slice.java b/common/src/main/java/dev/restate/common/Slice.java new file mode 100644 index 000000000..769da72f5 --- /dev/null +++ b/common/src/main/java/dev/restate/common/Slice.java @@ -0,0 +1,84 @@ +// Copyright (c) 2023 - Restate Software, Inc., Restate GmbH +// +// This file is part of the Restate Java SDK, +// which is released under the MIT license. +// +// You can find a copy of the license in file LICENSE in the root +// directory of this repository or package, or at +// https://github.com/restatedev/sdk-java/blob/main/LICENSE +package dev.restate.common; + +import java.nio.ByteBuffer; +import java.nio.charset.StandardCharsets; + +public interface Slice { + + int readableBytes(); + + void copyTo(ByteBuffer target); + + void copyTo(byte[] target); + + void copyTo(byte[] target, int targetOffset); + + byte byteAt(int position); + + ByteBuffer asReadOnlyByteBuffer(); + + byte[] toByteArray(); + + static Slice wrap(ByteBuffer byteBuffer) { + return new Slice() { + @Override + public ByteBuffer asReadOnlyByteBuffer() { + return byteBuffer.slice(); + } + + @Override + public int readableBytes() { + return byteBuffer.remaining(); + } + + @Override + public void copyTo(byte[] target) { + copyTo(target, 0); + } + + @Override + public void copyTo(byte[] target, int targetOffset) { + byteBuffer.slice().get(target, targetOffset, target.length); + } + + @Override + public byte byteAt(int position) { + return byteBuffer.slice().get(position); + } + + @Override + public void copyTo(ByteBuffer buffer) { + buffer.put(byteBuffer.slice()); + } + + @Override + public byte[] toByteArray() { + if (byteBuffer.hasArray()) { + return byteBuffer.array(); + } + + byte[] dst = new byte[byteBuffer.remaining()]; + byteBuffer.slice().get(dst); + return dst; + } + }; + } + + static Slice wrap(byte[] bytes) { + return wrap(ByteBuffer.wrap(bytes)); + } + + static Slice wrap(String str) { + return wrap(str.getBytes(StandardCharsets.UTF_8)); + } + + Slice EMPTY = Slice.wrap(new byte[0]); +} diff --git a/sdk-common/src/main/java/dev/restate/sdk/common/Target.java b/common/src/main/java/dev/restate/common/Target.java similarity index 98% rename from sdk-common/src/main/java/dev/restate/sdk/common/Target.java rename to common/src/main/java/dev/restate/common/Target.java index 989c4d853..cdbd70e3a 100644 --- a/sdk-common/src/main/java/dev/restate/sdk/common/Target.java +++ b/common/src/main/java/dev/restate/common/Target.java @@ -6,7 +6,7 @@ // You can find a copy of the license in file LICENSE in the root // directory of this repository or package, or at // https://github.com/restatedev/sdk-java/blob/main/LICENSE -package dev.restate.sdk.common; +package dev.restate.common; import java.util.Objects; import org.jspecify.annotations.Nullable; diff --git a/sdk-common/src/main/java/dev/restate/sdk/common/function/ThrowingBiConsumer.java b/common/src/main/java/dev/restate/common/function/ThrowingBiConsumer.java similarity index 96% rename from sdk-common/src/main/java/dev/restate/sdk/common/function/ThrowingBiConsumer.java rename to common/src/main/java/dev/restate/common/function/ThrowingBiConsumer.java index 30bade486..3ecb9f85c 100644 --- a/sdk-common/src/main/java/dev/restate/sdk/common/function/ThrowingBiConsumer.java +++ b/common/src/main/java/dev/restate/common/function/ThrowingBiConsumer.java @@ -6,7 +6,7 @@ // You can find a copy of the license in file LICENSE in the root // directory of this repository or package, or at // https://github.com/restatedev/sdk-java/blob/main/LICENSE -package dev.restate.sdk.common.function; +package dev.restate.common.function; import java.util.function.BiConsumer; diff --git a/sdk-common/src/main/java/dev/restate/sdk/common/function/ThrowingBiFunction.java b/common/src/main/java/dev/restate/common/function/ThrowingBiFunction.java similarity index 92% rename from sdk-common/src/main/java/dev/restate/sdk/common/function/ThrowingBiFunction.java rename to common/src/main/java/dev/restate/common/function/ThrowingBiFunction.java index c00e6740e..b5d04ac06 100644 --- a/sdk-common/src/main/java/dev/restate/sdk/common/function/ThrowingBiFunction.java +++ b/common/src/main/java/dev/restate/common/function/ThrowingBiFunction.java @@ -6,7 +6,7 @@ // You can find a copy of the license in file LICENSE in the root // directory of this repository or package, or at // https://github.com/restatedev/sdk-java/blob/main/LICENSE -package dev.restate.sdk.common.function; +package dev.restate.common.function; /** Like {@link java.util.function.BiFunction} but can throw checked exceptions. */ @FunctionalInterface diff --git a/sdk-common/src/main/java/dev/restate/sdk/common/function/ThrowingConsumer.java b/common/src/main/java/dev/restate/common/function/ThrowingConsumer.java similarity index 92% rename from sdk-common/src/main/java/dev/restate/sdk/common/function/ThrowingConsumer.java rename to common/src/main/java/dev/restate/common/function/ThrowingConsumer.java index 361ebc682..1fa009b6d 100644 --- a/sdk-common/src/main/java/dev/restate/sdk/common/function/ThrowingConsumer.java +++ b/common/src/main/java/dev/restate/common/function/ThrowingConsumer.java @@ -6,7 +6,7 @@ // You can find a copy of the license in file LICENSE in the root // directory of this repository or package, or at // https://github.com/restatedev/sdk-java/blob/main/LICENSE -package dev.restate.sdk.common.function; +package dev.restate.common.function; import java.util.function.Consumer; diff --git a/sdk-common/src/main/java/dev/restate/sdk/common/function/ThrowingFunction.java b/common/src/main/java/dev/restate/common/function/ThrowingFunction.java similarity index 91% rename from sdk-common/src/main/java/dev/restate/sdk/common/function/ThrowingFunction.java rename to common/src/main/java/dev/restate/common/function/ThrowingFunction.java index 032adeac1..af341490f 100644 --- a/sdk-common/src/main/java/dev/restate/sdk/common/function/ThrowingFunction.java +++ b/common/src/main/java/dev/restate/common/function/ThrowingFunction.java @@ -6,7 +6,7 @@ // You can find a copy of the license in file LICENSE in the root // directory of this repository or package, or at // https://github.com/restatedev/sdk-java/blob/main/LICENSE -package dev.restate.sdk.common.function; +package dev.restate.common.function; import java.util.function.Function; @@ -19,6 +19,10 @@ static Function wrap(ThrowingFunction fn) { return fn.asFunction(); } + static ThrowingFunction identity() { + return t -> t; + } + default Function asFunction() { return t -> { try { diff --git a/sdk-common/src/main/java/dev/restate/sdk/common/function/ThrowingRunnable.java b/common/src/main/java/dev/restate/common/function/ThrowingRunnable.java similarity index 92% rename from sdk-common/src/main/java/dev/restate/sdk/common/function/ThrowingRunnable.java rename to common/src/main/java/dev/restate/common/function/ThrowingRunnable.java index 8f95b42c2..6ec3c7791 100644 --- a/sdk-common/src/main/java/dev/restate/sdk/common/function/ThrowingRunnable.java +++ b/common/src/main/java/dev/restate/common/function/ThrowingRunnable.java @@ -6,7 +6,7 @@ // You can find a copy of the license in file LICENSE in the root // directory of this repository or package, or at // https://github.com/restatedev/sdk-java/blob/main/LICENSE -package dev.restate.sdk.common.function; +package dev.restate.common.function; /** Like {@link Runnable} but can throw checked exceptions. */ @FunctionalInterface diff --git a/sdk-common/src/main/java/dev/restate/sdk/common/function/ThrowingSupplier.java b/common/src/main/java/dev/restate/common/function/ThrowingSupplier.java similarity index 93% rename from sdk-common/src/main/java/dev/restate/sdk/common/function/ThrowingSupplier.java rename to common/src/main/java/dev/restate/common/function/ThrowingSupplier.java index 775bc5d05..997c016c2 100644 --- a/sdk-common/src/main/java/dev/restate/sdk/common/function/ThrowingSupplier.java +++ b/common/src/main/java/dev/restate/common/function/ThrowingSupplier.java @@ -6,7 +6,7 @@ // You can find a copy of the license in file LICENSE in the root // directory of this repository or package, or at // https://github.com/restatedev/sdk-java/blob/main/LICENSE -package dev.restate.sdk.common.function; +package dev.restate.common.function; /** Like {@link java.util.function.Supplier} but can throw checked exceptions. */ @FunctionalInterface diff --git a/sdk-common/src/main/java/dev/restate/sdk/common/Serde.java b/common/src/main/java/dev/restate/serde/Serde.java similarity index 52% rename from sdk-common/src/main/java/dev/restate/sdk/common/Serde.java rename to common/src/main/java/dev/restate/serde/Serde.java index 042893fb1..cfac4d040 100644 --- a/sdk-common/src/main/java/dev/restate/sdk/common/Serde.java +++ b/common/src/main/java/dev/restate/serde/Serde.java @@ -6,10 +6,10 @@ // You can find a copy of the license in file LICENSE in the root // directory of this repository or package, or at // https://github.com/restatedev/sdk-java/blob/main/LICENSE -package dev.restate.sdk.common; +package dev.restate.serde; -import dev.restate.sdk.common.function.ThrowingFunction; -import java.nio.ByteBuffer; +import dev.restate.common.Slice; +import dev.restate.common.function.ThrowingFunction; import java.util.Objects; import org.jspecify.annotations.*; @@ -20,28 +20,16 @@ * in {@code sdk-api-kotlin}, {@code JacksonSerdes} in {@code sdk-serde-jackson}, {@code * ProtobufSerdes} in {@code sdk-serde-protobuf}. * + *

Implementations MUST be thread safe. + * *

You can create a custom one using {@link #using(String, ThrowingFunction, ThrowingFunction)}. */ @NullMarked -public interface Serde { - - byte[] serialize(T value); +public interface Serde extends TypeTag { - default ByteBuffer serializeToByteBuffer(T value) { - // This is safe because we don't mutate the generated byte[] afterward. - return ByteBuffer.wrap(serialize(value)); - } + Slice serialize(T value); - T deserialize(byte[] value); - - default T deserialize(ByteBuffer byteBuffer) { - if (byteBuffer.hasArray()) { - return deserialize(byteBuffer.array()); - } - byte[] bytes = new byte[byteBuffer.remaining()]; - byteBuffer.get(bytes); - return deserialize(bytes); - } + T deserialize(Slice value); // --- Metadata about the serialized/deserialized content @@ -54,6 +42,23 @@ default T deserialize(ByteBuffer byteBuffer) { return "application/octet-stream"; } + /** + * @return a Draft 2020-12 Json Schema. It should be self-contained, and MUST not contain refs to + * files or HTTP. The schema is currently used by Restate to introspect the service contract + * and generate an OpenAPI definition. + */ + default @Nullable Schema jsonSchema() { + return null; + } + + sealed interface Schema {} + + /** Schema to be serialized using internal Jackson mapper. */ + record JsonSchema(Object schema) implements Schema {} + + /** Schema already serialized to String. The string should be a valid json schema. */ + record StringifiedJsonSchema(String schema) implements Schema {} + /** * Like {@link #using(String, ThrowingFunction, ThrowingFunction)}, using content-type {@code * application/octet-stream}. @@ -62,13 +67,13 @@ default T deserialize(ByteBuffer byteBuffer) { ThrowingFunction serializer, ThrowingFunction deserializer) { return new Serde<>() { @Override - public byte[] serialize(T value) { - return serializer.asFunction().apply(Objects.requireNonNull(value)); + public Slice serialize(T value) { + return Slice.wrap(serializer.asFunction().apply(Objects.requireNonNull(value))); } @Override - public T deserialize(byte[] value) { - return deserializer.asFunction().apply(value); + public T deserialize(Slice value) { + return deserializer.asFunction().apply(value.toByteArray()); } }; } @@ -83,13 +88,13 @@ public T deserialize(byte[] value) { ThrowingFunction deserializer) { return new Serde<>() { @Override - public byte[] serialize(T value) { - return serializer.asFunction().apply(Objects.requireNonNull(value)); + public Slice serialize(T value) { + return Slice.wrap(serializer.asFunction().apply(Objects.requireNonNull(value))); } @Override - public T deserialize(byte[] value) { - return deserializer.asFunction().apply(value); + public T deserialize(Slice value) { + return deserializer.asFunction().apply(value.toByteArray()); } @Override @@ -102,22 +107,12 @@ public String contentType() { static Serde withContentType(String contentType, Serde inner) { return new Serde<>() { @Override - public byte[] serialize(T value) { + public Slice serialize(T value) { return inner.serialize(value); } @Override - public ByteBuffer serializeToByteBuffer(T value) { - return inner.serializeToByteBuffer(value); - } - - @Override - public T deserialize(ByteBuffer byteBuffer) { - return inner.deserialize(byteBuffer); - } - - @Override - public T deserialize(byte[] value) { + public T deserialize(Slice value) { return inner.deserialize(value); } @@ -132,22 +127,12 @@ public String contentType() { Serde<@Nullable Void> VOID = new Serde<>() { @Override - public byte[] serialize(Void value) { - return new byte[0]; - } - - @Override - public ByteBuffer serializeToByteBuffer(Void value) { - return ByteBuffer.allocate(0); - } - - @Override - public Void deserialize(byte[] value) { - return null; + public Slice serialize(Void value) { + return Slice.EMPTY; } @Override - public Void deserialize(ByteBuffer byteBuffer) { + public Void deserialize(Slice value) { return null; } @@ -161,45 +146,27 @@ public String contentType() { Serde RAW = new Serde<>() { @Override - public byte[] serialize(byte[] value) { - return Objects.requireNonNull(value); + public Slice serialize(byte[] value) { + return Slice.wrap(Objects.requireNonNull(value)); } @Override - public byte[] deserialize(byte[] value) { - return value; + public byte[] deserialize(Slice value) { + return value.toByteArray(); } }; - /** Pass through {@link Serde} for {@link ByteBuffer}. */ - Serde BYTE_BUFFER = + /** Passthrough serializer/deserializer */ + Serde SLICE = new Serde<>() { @Override - public byte[] serialize(ByteBuffer byteBuffer) { - if (byteBuffer == null) { - return new byte[] {}; - } - if (byteBuffer.hasArray()) { - return byteBuffer.array(); - } - byte[] bytes = new byte[byteBuffer.remaining()]; - byteBuffer.get(bytes); - return bytes; - } - - @Override - public ByteBuffer serializeToByteBuffer(ByteBuffer value) { + public Slice serialize(Slice value) { return value; } @Override - public ByteBuffer deserialize(byte[] value) { - return ByteBuffer.wrap(value); - } - - @Override - public ByteBuffer deserialize(ByteBuffer byteBuffer) { - return byteBuffer; + public Slice deserialize(Slice value) { + return value; } }; } diff --git a/common/src/main/java/dev/restate/serde/SerdeFactory.java b/common/src/main/java/dev/restate/serde/SerdeFactory.java new file mode 100644 index 000000000..43c7269e7 --- /dev/null +++ b/common/src/main/java/dev/restate/serde/SerdeFactory.java @@ -0,0 +1,41 @@ +// Copyright (c) 2023 - Restate Software, Inc., Restate GmbH +// +// This file is part of the Restate Java SDK, +// which is released under the MIT license. +// +// You can find a copy of the license in file LICENSE in the root +// directory of this repository or package, or at +// https://github.com/restatedev/sdk-java/blob/main/LICENSE +package dev.restate.serde; + +public interface SerdeFactory { + + Serde create(TypeRef typeRef); + + Serde create(Class clazz); + + default Serde create(TypeTag typeTag) { + if (typeTag instanceof TypeTag.Class tClass) { + return this.create(tClass.type()); + } else if (typeTag instanceof TypeRef tTypeRef) { + return this.create(tTypeRef); + } else { + return ((Serde) typeTag); + } + } + + SerdeFactory NOOP = + new SerdeFactory() { + @Override + public Serde create(TypeRef typeRef) { + throw new UnsupportedOperationException( + "No SerdeFactory class was configured. Please configure one."); + } + + @Override + public Serde create(Class clazz) { + throw new UnsupportedOperationException( + "No SerdeFactory class was configured. Please configure one."); + } + }; +} diff --git a/common/src/main/java/dev/restate/serde/TypeRef.java b/common/src/main/java/dev/restate/serde/TypeRef.java new file mode 100644 index 000000000..d616c46d9 --- /dev/null +++ b/common/src/main/java/dev/restate/serde/TypeRef.java @@ -0,0 +1,36 @@ +// Copyright (c) 2023 - Restate Software, Inc., Restate GmbH +// +// This file is part of the Restate Java SDK, +// which is released under the MIT license. +// +// You can find a copy of the license in file LICENSE in the root +// directory of this repository or package, or at +// https://github.com/restatedev/sdk-java/blob/main/LICENSE +package dev.restate.serde; + +import java.lang.reflect.ParameterizedType; +import java.lang.reflect.Type; + +/** + * This generic abstract class is used for obtaining full generics type information by sub-classing. + * Similar to Jackson's TypeReference. + * + * @param + */ +public abstract class TypeRef implements TypeTag { + private final Type type; + + protected TypeRef() { + Type superClass = this.getClass().getGenericSuperclass(); + if (superClass instanceof java.lang.Class) { + throw new IllegalArgumentException( + "Internal error: TypeRef constructed without actual type information"); + } else { + this.type = ((ParameterizedType) superClass).getActualTypeArguments()[0]; + } + } + + public Type getType() { + return this.type; + } +} diff --git a/common/src/main/java/dev/restate/serde/TypeTag.java b/common/src/main/java/dev/restate/serde/TypeTag.java new file mode 100644 index 000000000..1a79b0038 --- /dev/null +++ b/common/src/main/java/dev/restate/serde/TypeTag.java @@ -0,0 +1,32 @@ +// Copyright (c) 2023 - Restate Software, Inc., Restate GmbH +// +// This file is part of the Restate Java SDK, +// which is released under the MIT license. +// +// You can find a copy of the license in file LICENSE in the root +// directory of this repository or package, or at +// https://github.com/restatedev/sdk-java/blob/main/LICENSE +package dev.restate.serde; + +/** + * Type tag is used to carry types runtime information for serialization/deserialization. Subclasses + * include {@link Serde} and {@link TypeRef}. + * + * @param + */ +public interface TypeTag { + + record Class(java.lang.Class type) implements TypeTag {} + + static TypeTag of(java.lang.Class type) { + return new Class<>(type); + } + + static TypeTag of(dev.restate.serde.TypeRef type) { + return type; + } + + static TypeTag of(dev.restate.serde.Serde serde) { + return serde; + } +} diff --git a/examples/README.md b/examples/README.md index d93cc9330..9885c2f8f 100644 --- a/examples/README.md +++ b/examples/README.md @@ -54,4 +54,4 @@ curl http://localhost:8080/Counter/my-counter/add --json "1" curl http://localhost:8080/Counter/my-counter/get ``` -The command assumes that the Restate runtime is reachable under `localhost:8080`. +The commandAccessor assumes that the Restate runtime is reachable under `localhost:8080`. diff --git a/examples/build.gradle.kts b/examples/build.gradle.kts index 52b42a339..f430faf0f 100644 --- a/examples/build.gradle.kts +++ b/examples/build.gradle.kts @@ -1,5 +1,4 @@ import com.github.jengelman.gradle.plugins.shadow.tasks.ShadowJar -import com.github.jengelman.gradle.plugins.shadow.transformers.ServiceFileTransformer plugins { `java-conventions` @@ -13,6 +12,8 @@ dependencies { ksp(project(":sdk-api-kotlin-gen")) annotationProcessor(project(":sdk-api-gen")) + implementation(project(":client")) + implementation(project(":client-kotlin")) implementation(project(":sdk-api")) implementation(project(":sdk-lambda")) implementation(project(":sdk-http-vertx")) @@ -38,6 +39,6 @@ application { tasks.withType { this.enabled = false } -tasks.withType { transform(ServiceFileTransformer::class.java) } +tasks.withType { mergeServiceFiles() } tasks.withType { options.compilerArgs.add("-parameters") } diff --git a/examples/src/main/java/my/restate/sdk/examples/Counter.java b/examples/src/main/java/my/restate/sdk/examples/Counter.java index a612777c2..518529fcb 100644 --- a/examples/src/main/java/my/restate/sdk/examples/Counter.java +++ b/examples/src/main/java/my/restate/sdk/examples/Counter.java @@ -8,14 +8,14 @@ // https://github.com/restatedev/sdk-java/blob/main/LICENSE package my.restate.sdk.examples; -import dev.restate.sdk.JsonSerdes; -import dev.restate.sdk.ObjectContext; -import dev.restate.sdk.SharedObjectContext; +import dev.restate.sdk.*; import dev.restate.sdk.annotation.Handler; import dev.restate.sdk.annotation.Shared; import dev.restate.sdk.annotation.VirtualObject; -import dev.restate.sdk.common.StateKey; -import dev.restate.sdk.http.vertx.RestateHttpEndpointBuilder; +import dev.restate.sdk.endpoint.Endpoint; +import dev.restate.sdk.http.vertx.RestateHttpServer; +import dev.restate.sdk.types.StateKey; +import java.time.Duration; import org.apache.logging.log4j.LogManager; import org.apache.logging.log4j.Logger; @@ -25,7 +25,7 @@ public class Counter { private static final Logger LOG = LogManager.getLogger(Counter.class); - private static final StateKey TOTAL = StateKey.of("total", JsonSerdes.LONG); + private static final StateKey TOTAL = StateKey.of("total", Long.class); /** Reset the counter. */ @Handler @@ -38,6 +38,7 @@ public void reset(ObjectContext ctx) { public void add(ObjectContext ctx, long request) { long currentValue = ctx.get(TOTAL).orElse(0L); long newValue = currentValue + request; + ctx.sleep(Duration.ofSeconds(120)); ctx.set(TOTAL, newValue); } @@ -61,7 +62,9 @@ public CounterUpdateResult getAndAdd(ObjectContext ctx, long request) { } public static void main(String[] args) { - RestateHttpEndpointBuilder.builder().bind(new Counter()).buildAndListen(); + Endpoint endpoint = Endpoint.builder().bind(new Counter()).build(); + + RestateHttpServer.listen(endpoint); } public record CounterUpdateResult(long newValue, long oldValue) {} diff --git a/examples/src/main/java/my/restate/sdk/examples/LambdaHandler.java b/examples/src/main/java/my/restate/sdk/examples/LambdaHandler.java index b11303ff2..7938e00ca 100644 --- a/examples/src/main/java/my/restate/sdk/examples/LambdaHandler.java +++ b/examples/src/main/java/my/restate/sdk/examples/LambdaHandler.java @@ -8,15 +8,15 @@ // https://github.com/restatedev/sdk-java/blob/main/LICENSE package my.restate.sdk.examples; +import dev.restate.sdk.endpoint.Endpoint; import dev.restate.sdk.lambda.BaseRestateLambdaHandler; -import dev.restate.sdk.lambda.RestateLambdaEndpointBuilder; import java.util.Objects; import java.util.regex.Pattern; public class LambdaHandler extends BaseRestateLambdaHandler { @Override - public void register(RestateLambdaEndpointBuilder builder) { + public void register(Endpoint.Builder builder) { for (String serviceClass : Objects.requireNonNullElse( System.getenv("LAMBDA_FACTORY_SERVICE_CLASS"), Counter.class.getCanonicalName()) diff --git a/examples/src/main/java/my/restate/sdk/examples/LoanWorkflow.java b/examples/src/main/java/my/restate/sdk/examples/LoanWorkflow.java index c5cdab7b9..8b41dcbdc 100644 --- a/examples/src/main/java/my/restate/sdk/examples/LoanWorkflow.java +++ b/examples/src/main/java/my/restate/sdk/examples/LoanWorkflow.java @@ -9,22 +9,20 @@ package my.restate.sdk.examples; import dev.restate.sdk.Context; -import dev.restate.sdk.JsonSerdes; import dev.restate.sdk.SharedWorkflowContext; import dev.restate.sdk.WorkflowContext; import dev.restate.sdk.annotation.Handler; import dev.restate.sdk.annotation.Service; import dev.restate.sdk.annotation.Shared; import dev.restate.sdk.annotation.Workflow; -import dev.restate.sdk.common.DurablePromiseKey; -import dev.restate.sdk.common.StateKey; -import dev.restate.sdk.common.TerminalException; -import dev.restate.sdk.http.vertx.RestateHttpEndpointBuilder; -import dev.restate.sdk.serde.jackson.JacksonSerdes; +import dev.restate.sdk.endpoint.Endpoint; +import dev.restate.sdk.http.vertx.RestateHttpServer; +import dev.restate.sdk.types.DurablePromiseKey; +import dev.restate.sdk.types.StateKey; +import dev.restate.sdk.types.TerminalException; import java.math.BigDecimal; import java.time.Duration; import java.time.Instant; -import java.util.concurrent.TimeoutException; import org.apache.logging.log4j.LogManager; import org.apache.logging.log4j.Logger; @@ -48,14 +46,13 @@ public record LoanRequest( private static final Logger LOG = LogManager.getLogger(LoanWorkflow.class); - private static final StateKey STATUS = - StateKey.of("status", JacksonSerdes.of(Status.class)); + private static final StateKey STATUS = StateKey.of("status", Status.class); private static final StateKey LOAN_REQUEST = - StateKey.of("loanRequest", JacksonSerdes.of(LoanRequest.class)); + StateKey.of("loanRequest", LoanRequest.class); private static final DurablePromiseKey HUMAN_APPROVAL = - DurablePromiseKey.of("humanApproval", JsonSerdes.BOOLEAN); + DurablePromiseKey.of("humanApproval", Boolean.class); private static final StateKey TRANSFER_EXECUTION_TIME = - StateKey.of("transferExecutionTime", JsonSerdes.STRING); + StateKey.of("transferExecutionTime", String.class); // --- The main workflow method @@ -90,7 +87,7 @@ public String run(WorkflowContext ctx, LoanRequest loanRequest) { .transfer( new TransferRequest(loanRequest.customerBankAccount(), loanRequest.amount())) .await(Duration.ofDays(7)); - } catch (TerminalException | TimeoutException e) { + } catch (TerminalException e) { LOG.warn("Transaction failed", e); ctx.set(STATUS, Status.TRANSFER_FAILED); return "Failed"; @@ -124,10 +121,9 @@ public Status getStatus(SharedWorkflowContext ctx) { } public static void main(String[] args) { - RestateHttpEndpointBuilder.builder() - .bind(new LoanWorkflow()) - .bind(new MockBank()) - .buildAndListen(); + Endpoint endpoint = Endpoint.builder().bind(new LoanWorkflow()).bind(new MockBank()).build(); + + RestateHttpServer.listen(endpoint); // Register the service in the meantime! LOG.info("Now it's time to register this deployment"); diff --git a/examples/src/main/kotlin/my/restate/sdk/examples/CounterKt.kt b/examples/src/main/kotlin/my/restate/sdk/examples/CounterKt.kt index 1526d5a63..6372b05e3 100644 --- a/examples/src/main/kotlin/my/restate/sdk/examples/CounterKt.kt +++ b/examples/src/main/kotlin/my/restate/sdk/examples/CounterKt.kt @@ -11,14 +11,9 @@ package my.restate.sdk.examples import dev.restate.sdk.annotation.Handler import dev.restate.sdk.annotation.Shared import dev.restate.sdk.annotation.VirtualObject -import dev.restate.sdk.http.vertx.RestateHttpEndpointBuilder -import dev.restate.sdk.kotlin.HandlerRunner -import dev.restate.sdk.kotlin.KtStateKey -import dev.restate.sdk.kotlin.ObjectContext -import dev.restate.sdk.kotlin.SharedObjectContext -import io.vertx.core.Vertx -import io.vertx.core.VertxOptions -import kotlinx.coroutines.Dispatchers +import dev.restate.sdk.http.vertx.RestateHttpServer +import dev.restate.sdk.kotlin.* +import dev.restate.sdk.kotlin.endpoint.* import kotlinx.serialization.Serializable import org.apache.logging.log4j.LogManager import org.apache.logging.log4j.Logger @@ -29,7 +24,7 @@ import org.apache.logging.log4j.Logger class CounterKt { companion object { - private val TOTAL = KtStateKey.json("total") + private val TOTAL = stateKey("total") private val LOG: Logger = LogManager.getLogger(CounterKt::class.java) } @@ -62,9 +57,6 @@ class CounterKt { } fun main() { - RestateHttpEndpointBuilder.builder(Vertx.vertx(VertxOptions().setEventLoopPoolSize(8))) - .bind( - CounterKtServiceDefinitionFactory().create(CounterKt()), - HandlerRunner.Options(Dispatchers.Unconfined)) - .buildAndListen() + val endpoint = endpoint { bind(CounterKt()) } + RestateHttpServer.listen(endpoint) } diff --git a/examples/src/main/resources/log4j2.properties b/examples/src/main/resources/log4j2.properties index 363186141..871f44bc5 100644 --- a/examples/src/main/resources/log4j2.properties +++ b/examples/src/main/resources/log4j2.properties @@ -5,7 +5,7 @@ status = warn appender.console.type = Console appender.console.name = consoleLogger appender.console.layout.type = PatternLayout -appender.console.layout.pattern = %d{yyyy-MM-dd HH:mm:ss} %-5p %notEmpty{[%X{restateInvocationTarget}]}%notEmpty{[%X{restateInvocationId}]} %c - %m%n +appender.console.layout.pattern = %d{yyyy-MM-dd HH:mm:ss} %-5p %notEmpty{[%X{restateInvocationTarget}]}%notEmpty{[%X{restateInvocationId}]} %t %c - %m%n # Filter out logging during replay appender.console.filter.replay.type = ContextMapFilter @@ -22,5 +22,5 @@ logger.app.additivity = false logger.app.appenderRef.console.ref = consoleLogger # Root logger -rootLogger.level = error +rootLogger.level = warn rootLogger.appenderRef.stdout.ref = consoleLogger \ No newline at end of file diff --git a/gradle/libs.versions.toml b/gradle/libs.versions.toml index d4f203184..ab00b6862 100644 --- a/gradle/libs.versions.toml +++ b/gradle/libs.versions.toml @@ -10,7 +10,7 @@ kotlinx-coroutines = "1.9.0" junit = "5.10.2" spring-boot = "3.4.0" log4j = "2.24.2" -restate = "1.3.0-SNAPSHOT" +restate = "2.0.0-SNAPSHOT" [libraries] aws-lambda-core = "com.amazonaws:aws-lambda-java-core:1.2.3" @@ -47,7 +47,7 @@ kotlinx-serialization-json = { module = "org.jetbrains.kotlinx:kotlinx-serializa jspecify = "org.jspecify:jspecify:1.0.0" junit-jupiter = { module = "org.junit.jupiter:junit-jupiter", version.ref = "junit" } junit-api = { module = "org.junit.jupiter:junit-jupiter-api", version.ref = "junit" } -slf4j-nop = "org.slf4j:slf4j-nop:1.7.32" +slf4j-nop = "org.slf4j:slf4j-nop:2.0.16" spring-boot-starter = { module = "org.springframework.boot:spring-boot-starter", version.ref = "spring-boot" } spring-boot-starter-json = { module = "org.springframework.boot:spring-boot-starter-json", version.ref = "spring-boot" } spring-boot-starter-test = { module = "org.springframework.boot:spring-boot-starter-test", version.ref = "spring-boot" } @@ -55,7 +55,7 @@ testcontainers = "org.testcontainers:testcontainers:1.20.4" [plugins] dependency-license-report = "com.github.jk1.dependency-license-report:2.0" -shadow = "com.github.johnrengelman.shadow:8.1.1" +shadow = "com.gradleup.shadow:9.0.0-beta8" jib = "com.google.cloud.tools.jib:3.4.4" ksp = { id = "com.google.devtools.ksp", version.ref = "ksp" } protobuf = "com.google.protobuf:0.9.4" diff --git a/gradle/wrapper/gradle-wrapper.properties b/gradle/wrapper/gradle-wrapper.properties index 1af9e0930..e18bc253b 100644 --- a/gradle/wrapper/gradle-wrapper.properties +++ b/gradle/wrapper/gradle-wrapper.properties @@ -1,6 +1,6 @@ distributionBase=GRADLE_USER_HOME distributionPath=wrapper/dists -distributionUrl=https\://services.gradle.org/distributions/gradle-8.5-bin.zip +distributionUrl=https\://services.gradle.org/distributions/gradle-8.12.1-bin.zip networkTimeout=10000 validateDistributionUrl=true zipStoreBase=GRADLE_USER_HOME diff --git a/sdk-api-gen-common/src/main/java/dev/restate/sdk/gen/model/Service.java b/sdk-api-gen-common/src/main/java/dev/restate/sdk/gen/model/Service.java index c839c3994..722495bcc 100644 --- a/sdk-api-gen-common/src/main/java/dev/restate/sdk/gen/model/Service.java +++ b/sdk-api-gen-common/src/main/java/dev/restate/sdk/gen/model/Service.java @@ -8,7 +8,7 @@ // https://github.com/restatedev/sdk-java/blob/main/LICENSE package dev.restate.sdk.gen.model; -import dev.restate.sdk.common.ServiceType; +import dev.restate.sdk.endpoint.definition.ServiceType; import java.util.ArrayList; import java.util.Collection; import java.util.List; @@ -24,6 +24,9 @@ public class Service { private final ServiceType serviceType; private final List handlers; private final @Nullable String documentation; + private final boolean contextClientEnabled; + private final boolean ingressClientEnabled; + private final String serdeFactoryDecl; public Service( CharSequence targetPkg, @@ -31,14 +34,19 @@ public Service( String serviceName, ServiceType serviceType, List handlers, - @Nullable String documentation) { + @Nullable String documentation, + boolean contextClientEnabled, + boolean ingressClientEnabled, + String serdeFactoryDecl) { this.targetPkg = targetPkg; this.targetFqcn = targetFqcn; this.serviceName = serviceName; - this.serviceType = serviceType; this.handlers = handlers; this.documentation = documentation; + this.contextClientEnabled = contextClientEnabled; + this.ingressClientEnabled = ingressClientEnabled; + this.serdeFactoryDecl = serdeFactoryDecl; } public CharSequence getTargetPkg() { @@ -76,6 +84,18 @@ public List getMethods() { return documentation; } + public boolean isContextClientEnabled() { + return contextClientEnabled; + } + + public boolean isIngressClientEnabled() { + return ingressClientEnabled; + } + + public String getSerdeFactoryDecl() { + return serdeFactoryDecl; + } + public static Builder builder() { return new Builder(); } @@ -87,6 +107,9 @@ public static class Builder { private ServiceType serviceType; private final List handlers = new ArrayList<>(); private String documentation; + private boolean contextClientEnabled = true; + private boolean ingressClientEnabled = true; + private String serdeFactoryDecl; public Builder withTargetPkg(CharSequence targetPkg) { this.targetPkg = targetPkg; @@ -123,6 +146,21 @@ public Builder withDocumentation(String documentation) { return this; } + public Builder withContextClientEnabled(boolean contextClientEnabled) { + this.contextClientEnabled = contextClientEnabled; + return this; + } + + public Builder withIngressClientEnabled(boolean ingressClientEnabled) { + this.ingressClientEnabled = ingressClientEnabled; + return this; + } + + public Builder withSerdeFactoryDecl(String serdeFactoryDecl) { + this.serdeFactoryDecl = serdeFactoryDecl; + return this; + } + public CharSequence getTargetPkg() { return targetPkg; } @@ -164,13 +202,18 @@ public Service validateAndBuild() { throw new IllegalArgumentException("Cannot have two handlers with the same name"); } + Objects.requireNonNull(serdeFactoryDecl, "Serde factory should not be null"); + return new Service( Objects.requireNonNull(targetPkg), Objects.requireNonNull(targetFqcn), Objects.requireNonNull(serviceName), Objects.requireNonNull(serviceType), handlers, - documentation); + documentation, + contextClientEnabled, + ingressClientEnabled, + serdeFactoryDecl); } } } diff --git a/sdk-api-gen-common/src/main/java/dev/restate/sdk/gen/template/HandlebarsTemplateEngine.java b/sdk-api-gen-common/src/main/java/dev/restate/sdk/gen/template/HandlebarsTemplateEngine.java index 73b0ed37c..3673e3f6f 100644 --- a/sdk-api-gen-common/src/main/java/dev/restate/sdk/gen/template/HandlebarsTemplateEngine.java +++ b/sdk-api-gen-common/src/main/java/dev/restate/sdk/gen/template/HandlebarsTemplateEngine.java @@ -15,8 +15,8 @@ import com.github.jknack.handlebars.helper.StringHelpers; import com.github.jknack.handlebars.internal.text.StringEscapeUtils; import com.github.jknack.handlebars.io.TemplateLoader; -import dev.restate.sdk.common.ServiceType; -import dev.restate.sdk.common.function.ThrowingFunction; +import dev.restate.common.function.ThrowingFunction; +import dev.restate.sdk.endpoint.definition.ServiceType; import dev.restate.sdk.gen.model.Handler; import dev.restate.sdk.gen.model.HandlerType; import dev.restate.sdk.gen.model.Service; @@ -49,15 +49,16 @@ public HandlebarsTemplateEngine( return switch (h.serviceType) { case SERVICE -> String.format( - "Target.service(%s.SERVICE_NAME, \"%s\")", h.definitionsClass, h.name); + "dev.restate.common.Target.service(%s.SERVICE_NAME, \"%s\")", + h.metadataClass, h.name); case VIRTUAL_OBJECT -> String.format( - "Target.virtualObject(%s.SERVICE_NAME, %s, \"%s\")", - h.definitionsClass, options.param(0), h.name); + "dev.restate.common.Target.virtualObject(%s.SERVICE_NAME, %s, \"%s\")", + h.metadataClass, options.param(0), h.name); case WORKFLOW -> String.format( - "Target.workflow(%s.SERVICE_NAME, %s, \"%s\")", - h.definitionsClass, options.param(0), h.name); + "dev.restate.common.Target.workflow(%s.SERVICE_NAME, %s, \"%s\")", + h.metadataClass, options.param(0), h.name); }; }); handlebars.registerHelpers(StringEscapeUtils.class); @@ -107,6 +108,14 @@ static class ServiceTemplateModel { public final String serviceName; public final String documentation; + public final String metadataClass; + public final String requestsClass; + + public final boolean contextClientEnabled; + public final boolean ingressClientEnabled; + public final String serdeFactoryDecl; + public final String serdeFactoryRef; + public final String serviceType; public final boolean isWorkflow; public final boolean isObject; @@ -122,6 +131,14 @@ private ServiceTemplateModel( this.generatedClassSimpleName = this.generatedClassSimpleNamePrefix + baseTemplateName; this.serviceName = inner.getFullyQualifiedServiceName(); + this.metadataClass = this.generatedClassSimpleNamePrefix + "Metadata"; + this.requestsClass = this.generatedClassSimpleNamePrefix + "Requests"; + + this.contextClientEnabled = inner.isContextClientEnabled(); + this.ingressClientEnabled = inner.isIngressClientEnabled(); + this.serdeFactoryDecl = inner.getSerdeFactoryDecl(); + this.serdeFactoryRef = metadataClass + ".SERDE_FACTORY"; + this.documentation = inner.getDocumentation(); this.serviceType = inner.getServiceType().toString(); @@ -135,10 +152,7 @@ private ServiceTemplateModel( .map( h -> new HandlerTemplateModel( - h, - inner.getServiceType(), - this.generatedClassSimpleNamePrefix + "Definitions", - handlerNamesToPrefix)) + h, inner.getServiceType(), metadataClass, handlerNamesToPrefix)) .collect(Collectors.toList()); } } @@ -153,7 +167,7 @@ static class HandlerTemplateModel { public final boolean isExclusive; private final ServiceType serviceType; - private final String definitionsClass; + private final String metadataClass; public final String documentation; public final boolean inputEmpty; @@ -174,7 +188,7 @@ static class HandlerTemplateModel { private HandlerTemplateModel( Handler inner, ServiceType serviceType, - String definitionsClass, + String metadataClass, Set handlerNamesToPrefix) { this.name = inner.name().toString(); this.methodName = (handlerNamesToPrefix.contains(this.name) ? "_" : "") + this.name; @@ -185,7 +199,7 @@ private HandlerTemplateModel( this.isStateless = inner.handlerType() == HandlerType.STATELESS; this.serviceType = serviceType; - this.definitionsClass = definitionsClass; + this.metadataClass = metadataClass; this.documentation = inner.documentation(); this.inputEmpty = inner.inputType().isEmpty(); @@ -194,14 +208,14 @@ private HandlerTemplateModel( this.boxedInputFqcn = inner.inputType().boxed(); this.inputSerdeFieldName = this.name.toUpperCase() + "_INPUT"; this.inputAcceptContentType = inner.inputAccept(); - this.inputSerdeRef = definitionsClass + ".Serde." + this.inputSerdeFieldName; + this.inputSerdeRef = metadataClass + ".Serde." + this.inputSerdeFieldName; this.outputEmpty = inner.outputType().isEmpty(); this.outputFqcn = inner.outputType().name(); this.outputSerdeDecl = inner.outputType().serdeDecl(); this.boxedOutputFqcn = inner.outputType().boxed(); this.outputSerdeFieldName = this.name.toUpperCase() + "_OUTPUT"; - this.outputSerdeRef = definitionsClass + ".Serde." + this.outputSerdeFieldName; + this.outputSerdeRef = metadataClass + ".Serde." + this.outputSerdeFieldName; } } } diff --git a/sdk-api-gen/build.gradle.kts b/sdk-api-gen/build.gradle.kts index c4d33b9be..d2d1d492a 100644 --- a/sdk-api-gen/build.gradle.kts +++ b/sdk-api-gen/build.gradle.kts @@ -1,6 +1,5 @@ plugins { `java-conventions` - `test-jar-conventions` application `library-publishing-conventions` } @@ -12,17 +11,4 @@ dependencies { implementation(project(":sdk-api-gen-common")) implementation(project(":sdk-api")) - - testAnnotationProcessor(project(":sdk-api-gen")) - testImplementation(project(":sdk-core")) - testImplementation(libs.junit.jupiter) - testImplementation(libs.assertj) - testImplementation(libs.protobuf.java) - testImplementation(libs.log4j.core) - testImplementation(libs.jackson.databind) - testImplementation(project(":sdk-serde-jackson")) - testImplementation(libs.mutiny) - - // Import test suites from sdk-core - testImplementation(project(":sdk-core", "testArchive")) } diff --git a/sdk-api-gen/src/main/java/dev/restate/sdk/gen/ElementConverter.java b/sdk-api-gen/src/main/java/dev/restate/sdk/gen/ElementConverter.java index 3c8518c1d..33d93eb68 100644 --- a/sdk-api-gen/src/main/java/dev/restate/sdk/gen/ElementConverter.java +++ b/sdk-api-gen/src/main/java/dev/restate/sdk/gen/ElementConverter.java @@ -14,7 +14,7 @@ import dev.restate.sdk.SharedWorkflowContext; import dev.restate.sdk.WorkflowContext; import dev.restate.sdk.annotation.*; -import dev.restate.sdk.common.ServiceType; +import dev.restate.sdk.endpoint.definition.ServiceType; import dev.restate.sdk.gen.model.*; import dev.restate.sdk.gen.model.Handler; import dev.restate.sdk.gen.model.Service; @@ -24,6 +24,7 @@ import java.util.stream.Collectors; import javax.annotation.processing.Messager; import javax.lang.model.element.*; +import javax.lang.model.type.MirroredTypeException; import javax.lang.model.type.TypeKind; import javax.lang.model.type.TypeMirror; import javax.lang.model.util.Elements; @@ -34,8 +35,8 @@ class ElementConverter { private static final PayloadType EMPTY_PAYLOAD = - new PayloadType(true, "", "Void", "dev.restate.sdk.common.Serde.VOID"); - private static final String RAW_SERDE = "dev.restate.sdk.common.Serde.RAW"; + new PayloadType(true, "", "Void", "dev.restate.serde.Serde.VOID"); + private static final String RAW_SERDE = "dev.restate.serde.Serde.RAW"; private final Messager messager; private final Elements elements; @@ -95,6 +96,12 @@ Service fromTypeElement(MetaRestateAnnotation metaAnnotation, TypeElement elemen Diagnostic.Kind.WARNING, "The service " + serviceName + " has no handlers", element); } + String serdeFactoryDecl = "new dev.restate.sdk.serde.jackson.JacksonSerdeFactory()"; + CustomSerdeFactory customSerdeFactory = element.getAnnotation(CustomSerdeFactory.class); + if (customSerdeFactory != null) { + serdeFactoryDecl = "new " + getCustomSerdeClassCanonicalName(customSerdeFactory) + "()"; + } + try { return new Service.Builder() .withTargetPkg(targetPkg) @@ -103,6 +110,7 @@ Service fromTypeElement(MetaRestateAnnotation metaAnnotation, TypeElement elemen .withDocumentation(sanitizeJavadoc(elements.getDocComment(element))) .withServiceType(metaAnnotation.getServiceType()) .withHandlers(handlers) + .withSerdeFactoryDecl(serdeFactoryDecl) .validateAndBuild(); } catch (Exception e) { messager.printMessage( @@ -311,7 +319,7 @@ private PayloadType payloadFromTypeMirrorAndAnnotations( element); } - String serdeDecl = rawAnnotation != null ? RAW_SERDE : jsonSerdeDecl(ty); + String serdeDecl = rawAnnotation != null ? RAW_SERDE : serdeDecl(ty); if (rawAnnotation != null && !rawAnnotation .contentType() @@ -329,29 +337,17 @@ private PayloadType payloadFromTypeMirrorAndAnnotations( } private static String contentTypeDecoratedSerdeDecl(String serdeDecl, String contentType) { - return "dev.restate.sdk.common.Serde.withContentType(\"" - + contentType - + "\", " - + serdeDecl - + ")"; + return "dev.restate.serde.Serde.withContentType(\"" + contentType + "\", " + serdeDecl + ")"; } - private static String jsonSerdeDecl(TypeMirror ty) { + private static String serdeDecl(TypeMirror ty) { return switch (ty.getKind()) { - case BOOLEAN -> "dev.restate.sdk.JsonSerdes.BOOLEAN"; - case BYTE -> "dev.restate.sdk.JsonSerdes.BYTE"; - case SHORT -> "dev.restate.sdk.JsonSerdes.SHORT"; - case INT -> "dev.restate.sdk.JsonSerdes.INT"; - case LONG -> "dev.restate.sdk.JsonSerdes.LONG"; - case CHAR -> "dev.restate.sdk.JsonSerdes.CHAR"; - case FLOAT -> "dev.restate.sdk.JsonSerdes.FLOAT"; - case DOUBLE -> "dev.restate.sdk.JsonSerdes.DOUBLE"; - case VOID -> "dev.restate.sdk.common.Serde.VOID"; + case VOID -> "dev.restate.serde.Serde.VOID"; default -> // Default to Jackson type reference serde - "dev.restate.sdk.serde.jackson.JacksonSerdes.of(new com.fasterxml.jackson.core.type.TypeReference<" - + ty - + ">() {})"; + "SERDE_FACTORY.create(dev.restate.serde.TypeTag.of(new dev.restate.serde.TypeRef<" + + boxedType(ty) + + ">() {}))"; }; } @@ -370,6 +366,15 @@ private static String boxedType(TypeMirror ty) { }; } + private String getCustomSerdeClassCanonicalName(CustomSerdeFactory customSerdeFactoryAnnotation) { + try { + Class clazz = customSerdeFactoryAnnotation.value(); + return clazz.getCanonicalName(); + } catch (MirroredTypeException e) { + return e.getTypeMirror().toString(); + } + } + private static String sanitizeJavadoc(String documentation) { // TODO this needs probably a bit more work, but eventually people will use markdown for // javadocs anyway! diff --git a/sdk-api-gen/src/main/java/dev/restate/sdk/gen/MetaRestateAnnotation.java b/sdk-api-gen/src/main/java/dev/restate/sdk/gen/MetaRestateAnnotation.java index b3a24720a..886b9100c 100644 --- a/sdk-api-gen/src/main/java/dev/restate/sdk/gen/MetaRestateAnnotation.java +++ b/sdk-api-gen/src/main/java/dev/restate/sdk/gen/MetaRestateAnnotation.java @@ -8,7 +8,7 @@ // https://github.com/restatedev/sdk-java/blob/main/LICENSE package dev.restate.sdk.gen; -import dev.restate.sdk.common.ServiceType; +import dev.restate.sdk.endpoint.definition.ServiceType; import java.util.Map; import javax.lang.model.element.*; import org.jspecify.annotations.Nullable; diff --git a/sdk-api-gen/src/main/java/dev/restate/sdk/gen/ServiceProcessor.java b/sdk-api-gen/src/main/java/dev/restate/sdk/gen/ServiceProcessor.java index bfad019b7..af0dd0562 100644 --- a/sdk-api-gen/src/main/java/dev/restate/sdk/gen/ServiceProcessor.java +++ b/sdk-api-gen/src/main/java/dev/restate/sdk/gen/ServiceProcessor.java @@ -8,9 +8,9 @@ // https://github.com/restatedev/sdk-java/blob/main/LICENSE package dev.restate.sdk.gen; -import dev.restate.sdk.common.ServiceType; -import dev.restate.sdk.common.function.ThrowingFunction; -import dev.restate.sdk.common.syscalls.ServiceDefinitionFactory; +import dev.restate.common.function.ThrowingFunction; +import dev.restate.sdk.endpoint.definition.ServiceDefinitionFactory; +import dev.restate.sdk.endpoint.definition.ServiceType; import dev.restate.sdk.gen.model.Service; import dev.restate.sdk.gen.template.HandlebarsTemplateEngine; import java.io.*; @@ -31,9 +31,10 @@ @SupportedSourceVersion(SourceVersion.RELEASE_17) public class ServiceProcessor extends AbstractProcessor { - private HandlebarsTemplateEngine definitionsCodegen; + private HandlebarsTemplateEngine metadataCodegen; private HandlebarsTemplateEngine serviceDefinitionFactoryCodegen; private HandlebarsTemplateEngine clientCodegen; + private HandlebarsTemplateEngine requestsCodegen; private static final Set RESERVED_METHOD_NAMES = Set.of("send", "submit", "workflowHandle"); @@ -44,17 +45,17 @@ public synchronized void init(ProcessingEnvironment processingEnv) { FilerTemplateLoader filerTemplateLoader = new FilerTemplateLoader(processingEnv.getFiler()); - this.definitionsCodegen = + this.metadataCodegen = new HandlebarsTemplateEngine( - "Definitions", + "Metadata", filerTemplateLoader, Map.of( ServiceType.WORKFLOW, - "templates/Definitions.hbs", + "templates/Metadata.hbs", ServiceType.SERVICE, - "templates/Definitions.hbs", + "templates/Metadata.hbs", ServiceType.VIRTUAL_OBJECT, - "templates/Definitions.hbs"), + "templates/Metadata.hbs"), RESERVED_METHOD_NAMES); this.serviceDefinitionFactoryCodegen = new HandlebarsTemplateEngine( @@ -80,6 +81,18 @@ public synchronized void init(ProcessingEnvironment processingEnv) { ServiceType.VIRTUAL_OBJECT, "templates/Client.hbs"), RESERVED_METHOD_NAMES); + this.requestsCodegen = + new HandlebarsTemplateEngine( + "Requests", + filerTemplateLoader, + Map.of( + ServiceType.WORKFLOW, + "templates/Requests.hbs", + ServiceType.SERVICE, + "templates/Requests.hbs", + ServiceType.VIRTUAL_OBJECT, + "templates/Requests.hbs"), + RESERVED_METHOD_NAMES); } @Override @@ -115,9 +128,12 @@ public boolean process(Set annotations, RoundEnvironment try { ThrowingFunction fileCreator = name -> filer.createSourceFile(name, e.getKey()).openWriter(); - this.definitionsCodegen.generate(fileCreator, e.getValue()); + this.metadataCodegen.generate(fileCreator, e.getValue()); this.serviceDefinitionFactoryCodegen.generate(fileCreator, e.getValue()); - this.clientCodegen.generate(fileCreator, e.getValue()); + this.requestsCodegen.generate(fileCreator, e.getValue()); + if (e.getValue().isContextClientEnabled() || e.getValue().isIngressClientEnabled()) { + this.clientCodegen.generate(fileCreator, e.getValue()); + } } catch (Throwable ex) { throw new RuntimeException(ex); } diff --git a/sdk-api-gen/src/main/resources/templates/Client.hbs b/sdk-api-gen/src/main/resources/templates/Client.hbs index c57063ff5..90f75b38d 100644 --- a/sdk-api-gen/src/main/resources/templates/Client.hbs +++ b/sdk-api-gen/src/main/resources/templates/Client.hbs @@ -1,27 +1,37 @@ {{#if originalClassPkg}}package {{originalClassPkg}};{{/if}} import dev.restate.sdk.Awaitable; +import dev.restate.sdk.CallAwaitable; import dev.restate.sdk.Context; -import dev.restate.sdk.common.StateKey; -import dev.restate.sdk.common.Serde; -import dev.restate.sdk.common.Target; +import dev.restate.sdk.types.StateKey; +import dev.restate.serde.Serde; +import dev.restate.common.Target; import java.util.Optional; import java.time.Duration; public class {{generatedClassSimpleName}} { + {{#contextClientEnabled}} public static ContextClient fromContext(Context ctx{{#isKeyed}}, String key{{/isKeyed}}) { return new ContextClient(ctx{{#isKeyed}}, key{{/isKeyed}}); } + {{/contextClientEnabled}} - public static IngressClient fromClient(dev.restate.sdk.client.Client client{{#isKeyed}}, String key{{/isKeyed}}) { + {{#ingressClientEnabled}} + public static IngressClient fromClient(dev.restate.client.Client client{{#isKeyed}}, String key{{/isKeyed}}) { return new IngressClient(client{{#isKeyed}}, key{{/isKeyed}}); } public static IngressClient connect(String baseUri{{#isKeyed}}, String key{{/isKeyed}}) { - return new IngressClient(dev.restate.sdk.client.Client.connect(baseUri){{#isKeyed}}, key{{/isKeyed}}); + return new IngressClient(dev.restate.client.Client.connect(baseUri, {{metadataClass}}.SERDE_FACTORY){{#isKeyed}}, key{{/isKeyed}}); } + public static IngressClient connect(String baseUri, dev.restate.client.ClientRequestOptions requestOptions{{#isKeyed}}, String key{{/isKeyed}}) { + return new IngressClient(dev.restate.client.Client.connect(baseUri, {{metadataClass}}.SERDE_FACTORY, requestOptions){{#isKeyed}}, key{{/isKeyed}}); + } + {{/ingressClientEnabled}} + + {{#contextClientEnabled}} public static class ContextClient { private final Context ctx; @@ -33,118 +43,78 @@ public class {{generatedClassSimpleName}} { } {{#handlers}} - public Awaitable<{{{boxedOutputFqcn}}}> {{methodName}}({{^inputEmpty}}{{{inputFqcn}}} req{{/inputEmpty}}) { + public CallAwaitable<{{{boxedOutputFqcn}}}> {{methodName}}({{^inputEmpty}}{{{inputFqcn}}} req{{/inputEmpty}}) { return this.ctx.call( - {{{targetExpr this "this.key"}}}, - {{inputSerdeRef}}, - {{outputSerdeRef}}, - {{#if inputEmpty}}null{{else}}req{{/if}}); - }{{/handlers}} - - public Send send() { - return new Send(null); + {{../requestsClass}}.{{methodName}}({{#if ../isKeyed}}this.key{{^inputEmpty}}, {{/inputEmpty}}{{/if}}{{^inputEmpty}}req{{/inputEmpty}}) + ); } + {{/handlers}} - public Send send(Duration delay) { - return new Send(delay); + public Send send() { + return new Send(); } public class Send { - private final Duration delay; - - Send(Duration delay) { - this.delay = delay; - } - {{#handlers}} - public void {{methodName}}({{^inputEmpty}}{{{inputFqcn}}} req{{/inputEmpty}}) { - ContextClient.this.ctx.send( - {{{targetExpr this "ContextClient.this.key"}}}, - {{inputSerdeRef}}, - {{#if inputEmpty}}null{{else}}req{{/if}}, - delay); - }{{/handlers}} + public dev.restate.sdk.InvocationHandle<{{{boxedOutputFqcn}}}> {{methodName}}({{^inputEmpty}}{{{inputFqcn}}} req{{/inputEmpty}}) { + return ContextClient.this.ctx.send( + {{../requestsClass}}.{{methodName}}({{#if ../isKeyed}}ContextClient.this.key{{^inputEmpty}}, {{/inputEmpty}}{{/if}}{{^inputEmpty}}req{{/inputEmpty}}).asSend() + ); + } + public dev.restate.sdk.InvocationHandle<{{{boxedOutputFqcn}}}> {{methodName}}({{^inputEmpty}}{{{inputFqcn}}} req, {{/inputEmpty}}Duration delay) { + return ContextClient.this.ctx.send( + {{../requestsClass}}.{{methodName}}({{#if ../isKeyed}}ContextClient.this.key{{^inputEmpty}}, {{/inputEmpty}}{{/if}}{{^inputEmpty}}req{{/inputEmpty}}).asSendDelayed(delay) + ); + } + {{/handlers}} } } + {{/contextClientEnabled}} + {{#ingressClientEnabled}} public static class IngressClient { - private final dev.restate.sdk.client.Client client; + private final dev.restate.client.Client client; {{#isKeyed}}private final String key;{{/isKeyed}} - public IngressClient(dev.restate.sdk.client.Client client{{#isKeyed}}, String key{{/isKeyed}}) { + public IngressClient(dev.restate.client.Client client{{#isKeyed}}, String key{{/isKeyed}}) { this.client = client; {{#isKeyed}}this.key = key;{{/isKeyed}} } {{#handlers}}{{#if isWorkflow}} - public dev.restate.sdk.client.Client.WorkflowHandle<{{{boxedOutputFqcn}}}> workflowHandle() { + public dev.restate.client.Client.WorkflowHandle<{{{boxedOutputFqcn}}}> workflowHandle() { return IngressClient.this.client.workflowHandle( - {{generatedClassSimpleNamePrefix}}Definitions.SERVICE_NAME, + {{metadataClass}}.SERVICE_NAME, this.key, {{outputSerdeRef}}); } - public dev.restate.sdk.client.SendResponse submit({{^inputEmpty}}{{{inputFqcn}}} req{{/inputEmpty}}) { - return this.submit( - {{^inputEmpty}}req, {{/inputEmpty}} - dev.restate.sdk.client.RequestOptions.DEFAULT); - } - - public dev.restate.sdk.client.SendResponse submit({{^inputEmpty}}{{{inputFqcn}}} req, {{/inputEmpty}}dev.restate.sdk.client.RequestOptions requestOptions) { + public dev.restate.client.SendResponse submit({{^inputEmpty}}{{{inputFqcn}}} req{{/inputEmpty}}) { return IngressClient.this.client.send( - {{{targetExpr this "this.key"}}}, - {{inputSerdeRef}}, - {{#if inputEmpty}}null{{else}}req{{/if}}, - null, - requestOptions); + {{../requestsClass}}.{{methodName}}({{#if ../isKeyed}}this.key{{^inputEmpty}}, {{/inputEmpty}}{{/if}}{{^inputEmpty}}req{{/inputEmpty}}).asSend() + ).response(); } - public java.util.concurrent.CompletableFuture submitAsync({{^inputEmpty}}{{{inputFqcn}}} req{{/inputEmpty}}) { - return this.submitAsync( - {{^inputEmpty}}req, {{/inputEmpty}} - dev.restate.sdk.client.RequestOptions.DEFAULT); - } - - public java.util.concurrent.CompletableFuture submitAsync({{^inputEmpty}}{{{inputFqcn}}} req, {{/inputEmpty}}dev.restate.sdk.client.RequestOptions requestOptions) { + public java.util.concurrent.CompletableFuture submitAsync({{^inputEmpty}}{{{inputFqcn}}} req{{/inputEmpty}}) { return IngressClient.this.client.sendAsync( - {{{targetExpr this "this.key"}}}, - {{inputSerdeRef}}, - {{#if inputEmpty}}null{{else}}req{{/if}}, - null, - requestOptions); + {{../requestsClass}}.{{methodName}}({{#if ../isKeyed}}this.key{{^inputEmpty}}, {{/inputEmpty}}{{/if}}{{^inputEmpty}}req{{/inputEmpty}}).asSend() + ).thenApply(dev.restate.client.ClientResponse::response); } {{else}} public {{#if outputEmpty}}void{{else}}{{{outputFqcn}}}{{/if}} {{methodName}}({{^inputEmpty}}{{{inputFqcn}}} req{{/inputEmpty}}) { - {{^outputEmpty}}return {{/outputEmpty}}this.{{methodName}}( - {{^inputEmpty}}req, {{/inputEmpty}} - dev.restate.sdk.client.CallRequestOptions.DEFAULT); - } - - public {{#if outputEmpty}}void{{else}}{{{outputFqcn}}}{{/if}} {{methodName}}({{^inputEmpty}}{{{inputFqcn}}} req, {{/inputEmpty}}dev.restate.sdk.client.CallRequestOptions requestOptions) { {{^outputEmpty}}return {{/outputEmpty}}this.client.call( - {{{targetExpr this "this.key"}}}, - {{inputSerdeRef}}, - {{outputSerdeRef}}, - {{#if inputEmpty}}null{{else}}req{{/if}}, - requestOptions); + {{../requestsClass}}.{{methodName}}({{#if ../isKeyed}}this.key{{^inputEmpty}}, {{/inputEmpty}}{{/if}}{{^inputEmpty}}req{{/inputEmpty}}) + ).response(); } public {{#if outputEmpty}}java.util.concurrent.CompletableFuture{{else}}java.util.concurrent.CompletableFuture<{{{boxedOutputFqcn}}}>{{/if}} {{methodName}}Async({{^inputEmpty}}{{{inputFqcn}}} req{{/inputEmpty}}) { - return this.{{methodName}}Async( - {{^inputEmpty}}req, {{/inputEmpty}} - dev.restate.sdk.client.CallRequestOptions.DEFAULT); - } - - public {{#if outputEmpty}}java.util.concurrent.CompletableFuture{{else}}java.util.concurrent.CompletableFuture<{{{boxedOutputFqcn}}}>{{/if}} {{methodName}}Async({{^inputEmpty}}{{{inputFqcn}}} req, {{/inputEmpty}}dev.restate.sdk.client.CallRequestOptions requestOptions) { return this.client.callAsync( - {{{targetExpr this "this.key"}}}, - {{inputSerdeRef}}, - {{outputSerdeRef}}, - {{#if inputEmpty}}null{{else}}req{{/if}}, - requestOptions); - }{{/if}}{{/handlers}} + {{../requestsClass}}.{{methodName}}({{#if ../isKeyed}}this.key{{^inputEmpty}}, {{/inputEmpty}}{{/if}}{{^inputEmpty}}req{{/inputEmpty}}) + ).thenApply(dev.restate.client.ClientResponse::response); + } + {{/if}}{{/handlers}} public Send send() { return new Send(null); @@ -163,35 +133,28 @@ public class {{generatedClassSimpleName}} { } {{#handlers}}{{^isWorkflow}} - public dev.restate.sdk.client.SendResponse {{methodName}}({{^inputEmpty}}{{{inputFqcn}}} req{{/inputEmpty}}) { - return this.{{methodName}}( - {{^inputEmpty}}req, {{/inputEmpty}} - dev.restate.sdk.client.CallRequestOptions.DEFAULT); - } - - public dev.restate.sdk.client.SendResponse {{methodName}}({{^inputEmpty}}{{{inputFqcn}}} req, {{/inputEmpty}}dev.restate.sdk.client.CallRequestOptions requestOptions) { + public dev.restate.client.SendResponse {{methodName}}({{^inputEmpty}}{{{inputFqcn}}} req{{/inputEmpty}}) { + if (this.delay == null) { + return IngressClient.this.client.send( + {{../requestsClass}}.{{methodName}}({{#if ../isKeyed}}IngressClient.this.key{{^inputEmpty}}, {{/inputEmpty}}{{/if}}{{^inputEmpty}}req{{/inputEmpty}}).asSend() + ).response(); + } return IngressClient.this.client.send( - {{{targetExpr this "IngressClient.this.key"}}}, - {{inputSerdeRef}}, - {{#if inputEmpty}}null{{else}}req{{/if}}, - this.delay, - requestOptions); - } - - public java.util.concurrent.CompletableFuture {{methodName}}Async({{^inputEmpty}}{{{inputFqcn}}} req{{/inputEmpty}}) { - return this.{{methodName}}Async( - {{^inputEmpty}}req, {{/inputEmpty}} - dev.restate.sdk.client.CallRequestOptions.DEFAULT); + {{../requestsClass}}.{{methodName}}({{#if ../isKeyed}}IngressClient.this.key{{^inputEmpty}}, {{/inputEmpty}}{{/if}}{{^inputEmpty}}req{{/inputEmpty}}).asSendDelayed(this.delay) + ).response(); } - public java.util.concurrent.CompletableFuture {{methodName}}Async({{^inputEmpty}}{{{inputFqcn}}} req, {{/inputEmpty}}dev.restate.sdk.client.CallRequestOptions requestOptions) { + public java.util.concurrent.CompletableFuture {{methodName}}Async({{^inputEmpty}}{{{inputFqcn}}} req{{/inputEmpty}}) { + if (this.delay == null) { + return IngressClient.this.client.sendAsync( + {{../requestsClass}}.{{methodName}}({{#if ../isKeyed}}IngressClient.this.key{{^inputEmpty}}, {{/inputEmpty}}{{/if}}{{^inputEmpty}}req{{/inputEmpty}}).asSend() + ).thenApply(dev.restate.client.ClientResponse::response); + } return IngressClient.this.client.sendAsync( - {{{targetExpr this "IngressClient.this.key"}}}, - {{inputSerdeRef}}, - {{#if inputEmpty}}null{{else}}req{{/if}}, - this.delay, - requestOptions); + {{../requestsClass}}.{{methodName}}({{#if ../isKeyed}}IngressClient.this.key{{^inputEmpty}}, {{/inputEmpty}}{{/if}}{{^inputEmpty}}req{{/inputEmpty}}).asSendDelayed(this.delay) + ).thenApply(dev.restate.client.ClientResponse::response); }{{/isWorkflow}}{{/handlers}} } } + {{/ingressClientEnabled}} } \ No newline at end of file diff --git a/sdk-api-gen/src/main/resources/templates/Definitions.hbs b/sdk-api-gen/src/main/resources/templates/Metadata.hbs similarity index 50% rename from sdk-api-gen/src/main/resources/templates/Definitions.hbs rename to sdk-api-gen/src/main/resources/templates/Metadata.hbs index bb3932b4d..a6c59449c 100644 --- a/sdk-api-gen/src/main/resources/templates/Definitions.hbs +++ b/sdk-api-gen/src/main/resources/templates/Metadata.hbs @@ -3,13 +3,14 @@ public final class {{generatedClassSimpleName}} { public static final String SERVICE_NAME = "{{serviceName}}"; + public static final dev.restate.serde.SerdeFactory SERDE_FACTORY = {{serdeFactoryDecl}}; private {{generatedClassSimpleName}}() {} public final static class Serde { {{#handlers}} - public static final dev.restate.sdk.common.Serde<{{{boxedInputFqcn}}}> {{inputSerdeFieldName}} = {{{inputSerdeDecl}}}; - public static final dev.restate.sdk.common.Serde<{{{boxedOutputFqcn}}}> {{outputSerdeFieldName}} = {{{outputSerdeDecl}}}; + public static final dev.restate.serde.Serde<{{{boxedInputFqcn}}}> {{inputSerdeFieldName}} = {{{inputSerdeDecl}}}; + public static final dev.restate.serde.Serde<{{{boxedOutputFqcn}}}> {{outputSerdeFieldName}} = {{{outputSerdeDecl}}}; {{/handlers}} private Serde() {} diff --git a/sdk-api-gen/src/main/resources/templates/Requests.hbs b/sdk-api-gen/src/main/resources/templates/Requests.hbs new file mode 100644 index 000000000..21063900f --- /dev/null +++ b/sdk-api-gen/src/main/resources/templates/Requests.hbs @@ -0,0 +1,17 @@ +{{#if originalClassPkg}}package {{originalClassPkg}};{{/if}} + +public final class {{generatedClassSimpleName}} { + + private {{generatedClassSimpleName}}() {} + + {{#handlers}} + public static dev.restate.common.Request.Builder<{{{boxedInputFqcn}}}, {{{boxedOutputFqcn}}}> {{methodName}}({{#if ../isKeyed}}String key{{^inputEmpty}}, {{/inputEmpty}}{{/if}}{{^inputEmpty}}{{{inputFqcn}}} req{{/inputEmpty}}) { + return dev.restate.common.Request.of( + {{{targetExpr this "key"}}}, + {{inputSerdeRef}}, + {{outputSerdeRef}}, + {{#if inputEmpty}}null{{else}}req{{/if}}); + } + + {{/handlers}} +} \ No newline at end of file diff --git a/sdk-api-gen/src/main/resources/templates/ServiceDefinitionFactory.hbs b/sdk-api-gen/src/main/resources/templates/ServiceDefinitionFactory.hbs index 5c1a92024..710d54f7e 100644 --- a/sdk-api-gen/src/main/resources/templates/ServiceDefinitionFactory.hbs +++ b/sdk-api-gen/src/main/resources/templates/ServiceDefinitionFactory.hbs @@ -1,23 +1,29 @@ {{#if originalClassPkg}}package {{originalClassPkg}};{{/if}} -public class {{generatedClassSimpleName}} implements dev.restate.sdk.common.syscalls.ServiceDefinitionFactory<{{originalClassFqcn}}, dev.restate.sdk.HandlerRunner.Options> { +public class {{generatedClassSimpleName}} implements dev.restate.sdk.endpoint.definition.ServiceDefinitionFactory<{{originalClassFqcn}}> { @java.lang.Override - public dev.restate.sdk.common.syscalls.ServiceDefinition create({{originalClassFqcn}} bindableService) { - return dev.restate.sdk.common.syscalls.ServiceDefinition.of( - {{generatedClassSimpleNamePrefix}}Definitions.SERVICE_NAME, - {{#if isObject}}dev.restate.sdk.common.ServiceType.VIRTUAL_OBJECT{{else if isWorkflow}}dev.restate.sdk.common.ServiceType.WORKFLOW{{else}}dev.restate.sdk.common.ServiceType.SERVICE{{/if}}, + public dev.restate.sdk.endpoint.definition.ServiceDefinition create({{originalClassFqcn}} bindableService, dev.restate.sdk.endpoint.definition.HandlerRunner.Options overrideHandlerOptions) { + dev.restate.sdk.HandlerRunner.Options handlerRunnerOptions = dev.restate.sdk.HandlerRunner.Options.DEFAULT; + if (overrideHandlerOptions != null) { + if (overrideHandlerOptions instanceof dev.restate.sdk.HandlerRunner.Options) { + handlerRunnerOptions = (dev.restate.sdk.HandlerRunner.Options)overrideHandlerOptions; + } else { + throw new IllegalArgumentException("The provided options class MUST be instance of dev.restate.sdk.HandlerRunner.Options, but was " + overrideHandlerOptions.getClass()); + } + } + return dev.restate.sdk.endpoint.definition.ServiceDefinition.of( + {{metadataClass}}.SERVICE_NAME, + {{#if isObject}}dev.restate.sdk.endpoint.definition.ServiceType.VIRTUAL_OBJECT{{else if isWorkflow}}dev.restate.sdk.endpoint.definition.ServiceType.WORKFLOW{{else}}dev.restate.sdk.endpoint.definition.ServiceType.SERVICE{{/if}}, java.util.List.of( {{#handlers}} - dev.restate.sdk.common.syscalls.HandlerDefinition.of( - dev.restate.sdk.common.syscalls.HandlerSpecification.of( + dev.restate.sdk.endpoint.definition.HandlerDefinition.of( "{{name}}", - {{#if isExclusive}}dev.restate.sdk.common.HandlerType.EXCLUSIVE{{else if isWorkflow}}dev.restate.sdk.common.HandlerType.WORKFLOW{{else}}dev.restate.sdk.common.HandlerType.SHARED{{/if}}, + {{#if isExclusive}}dev.restate.sdk.endpoint.definition.HandlerType.EXCLUSIVE{{else if isWorkflow}}dev.restate.sdk.endpoint.definition.HandlerType.WORKFLOW{{else}}dev.restate.sdk.endpoint.definition.HandlerType.SHARED{{/if}}, {{inputSerdeRef}}, - {{outputSerdeRef}} - ){{#if inputAcceptContentType}}.withAcceptContentType("{{inputAcceptContentType}}"){{/if}}{{#if documentation}}.withDocumentation("{{escapeJava documentation}}"){{/if}}, - dev.restate.sdk.HandlerRunner.of(bindableService::{{name}}) - ){{#unless @last}},{{/unless}} + {{outputSerdeRef}}, + dev.restate.sdk.HandlerRunner.of(bindableService::{{name}}, {{serdeFactoryRef}}, handlerRunnerOptions) + ){{#if inputAcceptContentType}}.withAcceptContentType("{{inputAcceptContentType}}"){{/if}}{{#if documentation}}.withDocumentation("{{escapeJava documentation}}"){{/if}}{{#unless @last}},{{/unless}} {{/handlers}} ) ){{#if documentation}}.withDocumentation("{{escapeJava documentation}}"){{/if}}; diff --git a/sdk-api-kotlin-gen/build.gradle.kts b/sdk-api-kotlin-gen/build.gradle.kts index 52a42c27c..df744d4e7 100644 --- a/sdk-api-kotlin-gen/build.gradle.kts +++ b/sdk-api-kotlin-gen/build.gradle.kts @@ -1,6 +1,5 @@ plugins { `kotlin-conventions` - `test-jar-conventions` `library-publishing-conventions` alias(libs.plugins.ksp) } @@ -14,17 +13,4 @@ dependencies { implementation(project(":sdk-api-gen-common")) implementation(project(":sdk-api-kotlin")) - - kspTest(project(":sdk-api-kotlin-gen")) - testImplementation(project(":sdk-core")) - testImplementation(libs.junit.jupiter) - testImplementation(libs.assertj) - testImplementation(libs.protobuf.java) - testImplementation(libs.log4j.core) - testImplementation(libs.kotlinx.coroutines.core) - testImplementation(libs.kotlinx.serialization.core) - testImplementation(libs.mutiny) - - // Import test suites from sdk-core - testImplementation(project(":sdk-core", "testArchive")) } diff --git a/sdk-api-kotlin-gen/src/main/kotlin/dev/restate/sdk/kotlin/gen/KElementConverter.kt b/sdk-api-kotlin-gen/src/main/kotlin/dev/restate/sdk/kotlin/gen/KElementConverter.kt index aaa315833..72c10f470 100644 --- a/sdk-api-kotlin-gen/src/main/kotlin/dev/restate/sdk/kotlin/gen/KElementConverter.kt +++ b/sdk-api-kotlin-gen/src/main/kotlin/dev/restate/sdk/kotlin/gen/KElementConverter.kt @@ -17,9 +17,10 @@ import com.google.devtools.ksp.processing.KSPLogger import com.google.devtools.ksp.symbol.* import com.google.devtools.ksp.visitor.KSDefaultVisitor import dev.restate.sdk.annotation.Accept +import dev.restate.sdk.annotation.CustomSerdeFactory import dev.restate.sdk.annotation.Json import dev.restate.sdk.annotation.Raw -import dev.restate.sdk.common.ServiceType +import dev.restate.sdk.endpoint.definition.ServiceType import dev.restate.sdk.gen.model.Handler import dev.restate.sdk.gen.model.HandlerType import dev.restate.sdk.gen.model.PayloadType @@ -37,8 +38,12 @@ class KElementConverter( companion object { private val SUPPORTED_CLASS_KIND: Set = setOf(ClassKind.CLASS, ClassKind.INTERFACE) private val EMPTY_PAYLOAD: PayloadType = - PayloadType(true, "", "Unit", "dev.restate.sdk.kotlin.KtSerdes.UNIT") - private const val RAW_SERDE: String = "dev.restate.sdk.common.Serde.RAW" + PayloadType( + true, + "", + "Unit", + "dev.restate.sdk.kotlin.serialization.KotlinSerializationSerdeFactory.UNIT") + private const val RAW_SERDE: String = "dev.restate.serde.Serde.RAW" } override fun defaultHandler(node: KSNode, data: Service.Builder) {} @@ -94,6 +99,14 @@ class KElementConverter( "The class declaration $targetFqcn has no methods annotated as handlers", classDeclaration) } + + var serdeFactoryDecl = "dev.restate.sdk.kotlin.serialization.KotlinSerializationSerdeFactory()" + val customSerdeFactory: CustomSerdeFactory? = + classDeclaration.getAnnotationsByType(CustomSerdeFactory::class).firstOrNull() + if (customSerdeFactory != null) { + serdeFactoryDecl = "new " + customSerdeFactory.value + "()" + } + data.withSerdeFactoryDecl(serdeFactoryDecl) } @OptIn(KspExperimental::class) @@ -224,11 +237,7 @@ class KElementConverter( } private fun contentTypeDecoratedSerdeDecl(serdeDecl: String, contentType: String): String { - return ("dev.restate.sdk.common.Serde.withContentType(\"" + - contentType + - "\", " + - serdeDecl + - ")") + return ("dev.restate.serde.Serde.withContentType(\"" + contentType + "\", " + serdeDecl + ")") } private fun defaultHandlerType(serviceType: ServiceType): HandlerType { @@ -291,8 +300,9 @@ class KElementConverter( private fun jsonSerdeDecl(ty: KSType, qualifiedTypeName: String): String { return when (ty) { - builtIns.unitType -> "dev.restate.sdk.kotlin.KtSerdes.UNIT" - else -> "dev.restate.sdk.kotlin.KtSerdes.json<${boxedType(ty, qualifiedTypeName)}>()" + builtIns.unitType -> EMPTY_PAYLOAD.serdeDecl + else -> + "SERDE_FACTORY.create(dev.restate.sdk.kotlin.serialization.typeTag<${boxedType(ty, qualifiedTypeName)}>())" } } diff --git a/sdk-api-kotlin-gen/src/main/kotlin/dev/restate/sdk/kotlin/gen/MetaRestateAnnotation.kt b/sdk-api-kotlin-gen/src/main/kotlin/dev/restate/sdk/kotlin/gen/MetaRestateAnnotation.kt index d8555378f..aa1351be5 100644 --- a/sdk-api-kotlin-gen/src/main/kotlin/dev/restate/sdk/kotlin/gen/MetaRestateAnnotation.kt +++ b/sdk-api-kotlin-gen/src/main/kotlin/dev/restate/sdk/kotlin/gen/MetaRestateAnnotation.kt @@ -10,7 +10,7 @@ package dev.restate.sdk.kotlin.gen import com.google.devtools.ksp.symbol.KSAnnotated import com.google.devtools.ksp.symbol.KSName -import dev.restate.sdk.common.ServiceType +import dev.restate.sdk.endpoint.definition.ServiceType internal data class MetaRestateAnnotation( val annotationName: KSName, diff --git a/sdk-api-kotlin-gen/src/main/kotlin/dev/restate/sdk/kotlin/gen/ServiceProcessor.kt b/sdk-api-kotlin-gen/src/main/kotlin/dev/restate/sdk/kotlin/gen/ServiceProcessor.kt index 637583b1d..6f553049c 100644 --- a/sdk-api-kotlin-gen/src/main/kotlin/dev/restate/sdk/kotlin/gen/ServiceProcessor.kt +++ b/sdk-api-kotlin-gen/src/main/kotlin/dev/restate/sdk/kotlin/gen/ServiceProcessor.kt @@ -18,8 +18,8 @@ import com.google.devtools.ksp.symbol.ClassKind import com.google.devtools.ksp.symbol.KSAnnotated import com.google.devtools.ksp.symbol.KSClassDeclaration import com.google.devtools.ksp.symbol.Origin -import dev.restate.sdk.common.ServiceType -import dev.restate.sdk.common.syscalls.ServiceDefinitionFactory +import dev.restate.sdk.endpoint.definition.ServiceDefinitionFactory +import dev.restate.sdk.endpoint.definition.ServiceType import dev.restate.sdk.gen.model.Service import dev.restate.sdk.gen.template.HandlebarsTemplateEngine import java.io.BufferedWriter @@ -52,14 +52,23 @@ class ServiceProcessor(private val logger: KSPLogger, private val codeGenerator: ServiceType.WORKFLOW to "templates/Client", ServiceType.VIRTUAL_OBJECT to "templates/Client"), RESERVED_METHOD_NAMES) - private val definitionsCodegen: HandlebarsTemplateEngine = + private val metadataCodegen: HandlebarsTemplateEngine = HandlebarsTemplateEngine( - "Definitions", + "Metadata", ClassPathTemplateLoader(), mapOf( - ServiceType.SERVICE to "templates/Definitions", - ServiceType.WORKFLOW to "templates/Definitions", - ServiceType.VIRTUAL_OBJECT to "templates/Definitions"), + ServiceType.SERVICE to "templates/Metadata", + ServiceType.WORKFLOW to "templates/Metadata", + ServiceType.VIRTUAL_OBJECT to "templates/Metadata"), + RESERVED_METHOD_NAMES) + private val requestsCodegen: HandlebarsTemplateEngine = + HandlebarsTemplateEngine( + "Requests", + ClassPathTemplateLoader(), + mapOf( + ServiceType.SERVICE to "templates/Requests", + ServiceType.WORKFLOW to "templates/Requests", + ServiceType.VIRTUAL_OBJECT to "templates/Requests"), RESERVED_METHOD_NAMES) @OptIn(KspExperimental::class) @@ -103,8 +112,11 @@ class ServiceProcessor(private val logger: KSPLogger, private val codeGenerator: .writer(Charset.defaultCharset()) } this.bindableServiceFactoryCodegen.generate(fileCreator, service.second) - this.clientCodegen.generate(fileCreator, service.second) - this.definitionsCodegen.generate(fileCreator, service.second) + this.metadataCodegen.generate(fileCreator, service.second) + this.requestsCodegen.generate(fileCreator, service.second) + if (service.second.isContextClientEnabled || service.second.isIngressClientEnabled) { + this.clientCodegen.generate(fileCreator, service.second) + } } catch (ex: Throwable) { throw RuntimeException(ex) } diff --git a/sdk-api-kotlin-gen/src/main/resources/templates/Client.hbs b/sdk-api-kotlin-gen/src/main/resources/templates/Client.hbs index bfc889771..753051b96 100644 --- a/sdk-api-kotlin-gen/src/main/resources/templates/Client.hbs +++ b/sdk-api-kotlin-gen/src/main/resources/templates/Client.hbs @@ -1,97 +1,106 @@ {{#if originalClassPkg}}package {{originalClassPkg}};{{/if}} -import dev.restate.sdk.kotlin.Awaitable +{{#contextClientEnabled}} +import dev.restate.sdk.kotlin.CallAwaitable +import dev.restate.sdk.kotlin.InvocationHandle import dev.restate.sdk.kotlin.Context -import dev.restate.sdk.common.StateKey -import dev.restate.sdk.common.Serde -import dev.restate.sdk.common.Target +import dev.restate.sdk.kotlin.asSendDelayed +{{/contextClientEnabled}} +import dev.restate.serde.Serde +import dev.restate.common.Target import kotlin.time.Duration -import dev.restate.sdk.kotlin.callSuspend -import dev.restate.sdk.kotlin.sendSuspend +{{#ingressClientEnabled}} +import dev.restate.client.kotlin.* +{{/ingressClientEnabled}} object {{generatedClassSimpleName}} { + {{#contextClientEnabled}} fun fromContext(ctx: Context{{#isKeyed}}, key: String{{/isKeyed}}): ContextClient { return ContextClient(ctx{{#isKeyed}}, key{{/isKeyed}}) } + {{/contextClientEnabled}} - fun fromClient(client: dev.restate.sdk.client.Client{{#isKeyed}}, key: String{{/isKeyed}}): IngressClient { + {{#ingressClientEnabled}} + fun fromClient(client: dev.restate.client.Client{{#isKeyed}}, key: String{{/isKeyed}}): IngressClient { return IngressClient(client{{#isKeyed}}, key{{/isKeyed}}); } - fun connect(baseUri: String{{#isKeyed}}, key: String{{/isKeyed}}): IngressClient { - return IngressClient(dev.restate.sdk.client.Client.connect(baseUri){{#isKeyed}}, key{{/isKeyed}}); + fun connect(baseUri: String, {{#isKeyed}}key: String, {{/isKeyed}}requestOptions: dev.restate.client.ClientRequestOptions = dev.restate.client.ClientRequestOptions.DEFAULT): IngressClient { + return IngressClient(dev.restate.client.Client.connect(baseUri, {{metadataClass}}.SERDE_FACTORY, requestOptions){{#isKeyed}}, key{{/isKeyed}}); } + {{/ingressClientEnabled}} + {{#contextClientEnabled}} class ContextClient(private val ctx: Context{{#isKeyed}}, private val key: String{{/isKeyed}}){ {{#handlers}} - suspend fun {{methodName}}({{^inputEmpty}}req: {{{inputFqcn}}}{{/inputEmpty}}): Awaitable<{{{boxedOutputFqcn}}}> { - return this.ctx.callAsync( - {{{targetExpr this "this.key"}}}, - {{inputSerdeRef}}, - {{outputSerdeRef}}, - {{#if inputEmpty}}Unit{{else}}req{{/if}}) + suspend fun {{methodName}}({{^inputEmpty}}req: {{{inputFqcn}}}, {{/inputEmpty}}init: dev.restate.common.Request.Builder<{{{boxedInputFqcn}}}, {{{boxedOutputFqcn}}}>.() -> Unit = {}): CallAwaitable<{{{boxedOutputFqcn}}}> { + return this.ctx.call( + {{../requestsClass}}.{{methodName}}({{#if ../isKeyed}}this.key, {{/if}}{{^inputEmpty}}req, {{/inputEmpty}}init) + ) }{{/handlers}} - fun send(delay: Duration = Duration.ZERO): Send { - return Send(delay) + fun send(): Send { + return Send() } - inner class Send(private val delay: Duration) { + inner class Send internal constructor() { {{#handlers}} - suspend fun {{methodName}}({{^inputEmpty}}req: {{{inputFqcn}}}{{/inputEmpty}}) { - this@ContextClient.ctx.send( - {{{targetExpr this "this@ContextClient.key"}}}, - {{inputSerdeRef}}, - {{#if inputEmpty}}Unit{{else}}req{{/if}}, - delay); + suspend fun {{methodName}}({{^inputEmpty}}req: {{{inputFqcn}}}, {{/inputEmpty}}delay: Duration? = null, init: dev.restate.common.Request.Builder<{{{boxedInputFqcn}}}, {{{boxedOutputFqcn}}}>.() -> Unit = {}): InvocationHandle<{{{boxedOutputFqcn}}}> { + if (delay != null) { + return this@ContextClient.ctx.send( + {{../requestsClass}}.{{methodName}}({{#if ../isKeyed}}this@ContextClient.key, {{/if}}{{^inputEmpty}}req, {{/inputEmpty}}init).asSendDelayed(delay) + ); + } + return this@ContextClient.ctx.send( + {{../requestsClass}}.{{methodName}}({{#if ../isKeyed}}this@ContextClient.key, {{/if}}{{^inputEmpty}}req, {{/inputEmpty}}init) + ); }{{/handlers}} } } + {{/contextClientEnabled}} - class IngressClient(private val client: dev.restate.sdk.client.Client{{#isKeyed}}, private val key: String{{/isKeyed}}) { + {{#ingressClientEnabled}} + class IngressClient(private val client: dev.restate.client.Client{{#isKeyed}}, private val key: String{{/isKeyed}}) { {{#handlers}}{{#if isWorkflow}} - fun workflowHandle(): dev.restate.sdk.client.Client.WorkflowHandle<{{{boxedOutputFqcn}}}> { + fun workflowHandle(): dev.restate.client.Client.WorkflowHandle<{{{boxedOutputFqcn}}}> { return this@IngressClient.client.workflowHandle( - {{generatedClassSimpleNamePrefix}}Definitions.SERVICE_NAME, + {{metadataClass}}.SERVICE_NAME, this.key, {{outputSerdeRef}}); } - suspend fun submit({{^inputEmpty}}req: {{{inputFqcn}}}, {{/inputEmpty}}requestOptions: dev.restate.sdk.client.RequestOptions = dev.restate.sdk.client.RequestOptions.DEFAULT): dev.restate.sdk.client.SendResponse { + suspend fun submit({{^inputEmpty}}req: {{{inputFqcn}}}, {{/inputEmpty}}init: dev.restate.common.Request.Builder<{{{boxedInputFqcn}}}, {{{boxedOutputFqcn}}}>.() -> Unit = {}): dev.restate.client.SendResponse<{{{boxedOutputFqcn}}}> { return this@IngressClient.client.sendSuspend( - {{{targetExpr this "this.key"}}}, - {{inputSerdeRef}}, - {{#if inputEmpty}}Unit{{else}}req{{/if}}, - kotlin.time.Duration.ZERO, - requestOptions); + {{../requestsClass}}.{{methodName}}({{#if ../isKeyed}}this.key, {{/if}}{{^inputEmpty}}req, {{/inputEmpty}}init) + ).response(); } {{else}} - suspend fun {{methodName}}({{^inputEmpty}}req: {{{inputFqcn}}}, {{/inputEmpty}}requestOptions: dev.restate.sdk.client.CallRequestOptions = dev.restate.sdk.client.CallRequestOptions.DEFAULT): {{{boxedOutputFqcn}}} { + suspend fun {{methodName}}({{^inputEmpty}}req: {{{inputFqcn}}}, {{/inputEmpty}}init: dev.restate.common.Request.Builder<{{{boxedInputFqcn}}}, {{{boxedOutputFqcn}}}>.() -> Unit = {}): {{{boxedOutputFqcn}}} { return this@IngressClient.client.callSuspend( - {{{targetExpr this "this.key"}}}, - {{inputSerdeRef}}, - {{outputSerdeRef}}, - {{#if inputEmpty}}Unit{{else}}req{{/if}}, - requestOptions); + {{../requestsClass}}.{{methodName}}({{#if ../isKeyed}}this.key, {{/if}}{{^inputEmpty}}req, {{/inputEmpty}}init) + ).response(); } {{/if}}{{/handlers}} - fun send(delay: Duration = Duration.ZERO): Send { - return Send(delay) + fun send(): Send { + return Send() } - inner class Send(private val delay: Duration) { + inner class Send() { {{#handlers}}{{^isWorkflow}} - suspend fun {{methodName}}({{^inputEmpty}}req: {{{inputFqcn}}}, {{/inputEmpty}}requestOptions: dev.restate.sdk.client.CallRequestOptions = dev.restate.sdk.client.CallRequestOptions.DEFAULT): dev.restate.sdk.client.SendResponse { + suspend fun {{methodName}}({{^inputEmpty}}req: {{{inputFqcn}}}, {{/inputEmpty}}delay: Duration? = null, init: dev.restate.common.Request.Builder<{{{boxedInputFqcn}}}, {{{boxedOutputFqcn}}}>.() -> Unit = {}): dev.restate.client.SendResponse<{{{boxedOutputFqcn}}}> { + if (delay != null) { + return this@IngressClient.client.sendSuspend( + {{../requestsClass}}.{{methodName}}({{#if ../isKeyed}}this@IngressClient.key, {{/if}}{{^inputEmpty}}req, {{/inputEmpty}}init).asSendDelayed(delay) + ).response(); + } return this@IngressClient.client.sendSuspend( - {{{targetExpr this "this@IngressClient.key"}}}, - {{inputSerdeRef}}, - {{#if inputEmpty}}Unit{{else}}req{{/if}}, - delay, - requestOptions); + {{../requestsClass}}.{{methodName}}({{#if ../isKeyed}}this@IngressClient.key, {{/if}}{{^inputEmpty}}req, {{/inputEmpty}}init) + ).response(); }{{/isWorkflow}}{{/handlers}} } } + {{/ingressClientEnabled}} } \ No newline at end of file diff --git a/sdk-api-kotlin-gen/src/main/resources/templates/Definitions.hbs b/sdk-api-kotlin-gen/src/main/resources/templates/Definitions.hbs deleted file mode 100644 index 1da7a7fe6..000000000 --- a/sdk-api-kotlin-gen/src/main/resources/templates/Definitions.hbs +++ /dev/null @@ -1,14 +0,0 @@ -{{#if originalClassPkg}}package {{originalClassPkg}}{{/if}} - -object {{generatedClassSimpleName}} { - - const val SERVICE_NAME: String = "{{serviceName}}" - - object Serde { - {{#handlers}} - val {{inputSerdeFieldName}}: dev.restate.sdk.common.Serde<{{{boxedInputFqcn}}}> = {{{inputSerdeDecl}}} - val {{outputSerdeFieldName}}: dev.restate.sdk.common.Serde<{{{boxedOutputFqcn}}}> = {{{outputSerdeDecl}}} - {{/handlers}} - } - -} \ No newline at end of file diff --git a/sdk-api-kotlin-gen/src/main/resources/templates/Metadata.hbs b/sdk-api-kotlin-gen/src/main/resources/templates/Metadata.hbs new file mode 100644 index 000000000..46d006dcf --- /dev/null +++ b/sdk-api-kotlin-gen/src/main/resources/templates/Metadata.hbs @@ -0,0 +1,15 @@ +{{#if originalClassPkg}}package {{originalClassPkg}}{{/if}} + +object {{generatedClassSimpleName}} { + + const val SERVICE_NAME: String = "{{serviceName}}" + val SERDE_FACTORY: dev.restate.serde.SerdeFactory = {{serdeFactoryDecl}} + + object Serde { + {{#handlers}} + val {{inputSerdeFieldName}}: dev.restate.serde.Serde<{{{boxedInputFqcn}}}> = {{{inputSerdeDecl}}} + val {{outputSerdeFieldName}}: dev.restate.serde.Serde<{{{boxedOutputFqcn}}}> = {{{outputSerdeDecl}}} + {{/handlers}} + } + +} \ No newline at end of file diff --git a/sdk-api-kotlin-gen/src/main/resources/templates/Requests.hbs b/sdk-api-kotlin-gen/src/main/resources/templates/Requests.hbs new file mode 100644 index 000000000..be07e022e --- /dev/null +++ b/sdk-api-kotlin-gen/src/main/resources/templates/Requests.hbs @@ -0,0 +1,17 @@ +{{#if originalClassPkg}}package {{originalClassPkg}}{{/if}} + +object {{generatedClassSimpleName}} { + + {{#handlers}} + fun {{methodName}}({{#if ../isKeyed}}key: String, {{/if}}{{^inputEmpty}}req: {{{inputFqcn}}}, {{/inputEmpty}}init: dev.restate.common.Request.Builder<{{{boxedInputFqcn}}}, {{{boxedOutputFqcn}}}>.() -> Unit = {}): dev.restate.common.Request<{{{boxedInputFqcn}}}, {{{boxedOutputFqcn}}}> { + val builder = dev.restate.common.Request.of<{{{boxedInputFqcn}}}, {{{boxedOutputFqcn}}}>( + {{{targetExpr this "key"}}}, + {{inputSerdeRef}}, + {{outputSerdeRef}}, + {{#if inputEmpty}}null{{else}}req{{/if}}); + builder.init() + return builder.build() + } + + {{/handlers}} +} \ No newline at end of file diff --git a/sdk-api-kotlin-gen/src/main/resources/templates/ServiceDefinitionFactory.hbs b/sdk-api-kotlin-gen/src/main/resources/templates/ServiceDefinitionFactory.hbs index 1b6c39cb4..f0640fb62 100644 --- a/sdk-api-kotlin-gen/src/main/resources/templates/ServiceDefinitionFactory.hbs +++ b/sdk-api-kotlin-gen/src/main/resources/templates/ServiceDefinitionFactory.hbs @@ -1,22 +1,27 @@ {{#if originalClassPkg}}package {{originalClassPkg}}{{/if}} -class {{generatedClassSimpleName}}: dev.restate.sdk.common.syscalls.ServiceDefinitionFactory<{{originalClassFqcn}}, dev.restate.sdk.kotlin.HandlerRunner.Options> { +class {{generatedClassSimpleName}}: dev.restate.sdk.endpoint.definition.ServiceDefinitionFactory<{{originalClassFqcn}}> { - override fun create(bindableService: {{originalClassFqcn}}): dev.restate.sdk.common.syscalls.ServiceDefinition { - return dev.restate.sdk.common.syscalls.ServiceDefinition.of( - {{generatedClassSimpleNamePrefix}}Definitions.SERVICE_NAME, - {{#if isObject}}dev.restate.sdk.common.ServiceType.VIRTUAL_OBJECT{{else if isWorkflow}}dev.restate.sdk.common.ServiceType.WORKFLOW{{else}}dev.restate.sdk.common.ServiceType.SERVICE{{/if}}, + override fun create(bindableService: {{originalClassFqcn}}, overrideHandlerOptions: dev.restate.sdk.endpoint.definition.HandlerRunner.Options?): dev.restate.sdk.endpoint.definition.ServiceDefinition { + val handlerRunnerOptions = if (overrideHandlerOptions != null) { + check(overrideHandlerOptions is dev.restate.sdk.kotlin.HandlerRunner.Options) + overrideHandlerOptions as dev.restate.sdk.kotlin.HandlerRunner.Options + } else { + dev.restate.sdk.kotlin.HandlerRunner.Options.DEFAULT + } + + return dev.restate.sdk.endpoint.definition.ServiceDefinition.of( + {{metadataClass}}.SERVICE_NAME, + {{#if isObject}}dev.restate.sdk.endpoint.definition.ServiceType.VIRTUAL_OBJECT{{else if isWorkflow}}dev.restate.sdk.endpoint.definition.ServiceType.WORKFLOW{{else}}dev.restate.sdk.endpoint.definition.ServiceType.SERVICE{{/if}}, listOf( {{#handlers}} - dev.restate.sdk.common.syscalls.HandlerDefinition.of( - dev.restate.sdk.common.syscalls.HandlerSpecification.of( + dev.restate.sdk.endpoint.definition.HandlerDefinition.of( "{{name}}", - {{#if isExclusive}}dev.restate.sdk.common.HandlerType.EXCLUSIVE{{else if isWorkflow}}dev.restate.sdk.common.HandlerType.WORKFLOW{{else}}dev.restate.sdk.common.HandlerType.SHARED{{/if}}, + {{#if isExclusive}}dev.restate.sdk.endpoint.definition.HandlerType.EXCLUSIVE{{else if isWorkflow}}dev.restate.sdk.endpoint.definition.HandlerType.WORKFLOW{{else}}dev.restate.sdk.endpoint.definition.HandlerType.SHARED{{/if}}, {{inputSerdeRef}}, - {{outputSerdeRef}} - ){{#if inputAcceptContentType}}.withAcceptContentType("{{inputAcceptContentType}}"){{/if}}, - dev.restate.sdk.kotlin.HandlerRunner.of(bindableService::{{name}}) - ){{#unless @last}},{{/unless}} + {{outputSerdeRef}}, + dev.restate.sdk.kotlin.HandlerRunner.of({{serdeFactoryRef}}, handlerRunnerOptions, bindableService::{{name}}) + ){{#if inputAcceptContentType}}.withAcceptContentType("{{inputAcceptContentType}}"){{/if}}{{#unless @last}},{{/unless}} {{/handlers}} ) ) diff --git a/sdk-api-kotlin/build.gradle.kts b/sdk-api-kotlin/build.gradle.kts index 8786bb969..340f5ac14 100644 --- a/sdk-api-kotlin/build.gradle.kts +++ b/sdk-api-kotlin/build.gradle.kts @@ -1,27 +1,17 @@ plugins { `kotlin-conventions` - `test-jar-conventions` `library-publishing-conventions` } description = "Restate SDK Kotlin APIs" dependencies { - api(project(":sdk-common")) - implementation(libs.kotlinx.coroutines.core) implementation(libs.kotlinx.serialization.core) - implementation(libs.kotlinx.serialization.json) + api(libs.kotlinx.serialization.json) + + api(project(":sdk-common")) implementation(libs.log4j.api) implementation(libs.opentelemetry.kotlin) - - testImplementation(project(":sdk-core")) - testImplementation(libs.junit.jupiter) - testImplementation(libs.assertj) - testImplementation(libs.log4j.core) - testImplementation(libs.protobuf.java) - testImplementation(libs.mutiny) - - testImplementation(project(":sdk-core", "testArchive")) } diff --git a/sdk-api-kotlin/src/main/kotlin/dev/restate/sdk/kotlin/Awaitables.kt b/sdk-api-kotlin/src/main/kotlin/dev/restate/sdk/kotlin/Awaitables.kt deleted file mode 100644 index b2ce2164f..000000000 --- a/sdk-api-kotlin/src/main/kotlin/dev/restate/sdk/kotlin/Awaitables.kt +++ /dev/null @@ -1,191 +0,0 @@ -// Copyright (c) 2023 - Restate Software, Inc., Restate GmbH -// -// This file is part of the Restate Java SDK, -// which is released under the MIT license. -// -// You can find a copy of the license in file LICENSE in the root -// directory of this repository or package, or at -// https://github.com/restatedev/sdk-java/blob/main/LICENSE -package dev.restate.sdk.kotlin - -import dev.restate.sdk.common.Serde -import dev.restate.sdk.common.syscalls.Deferred -import dev.restate.sdk.common.syscalls.Result -import dev.restate.sdk.common.syscalls.Syscalls -import java.nio.ByteBuffer -import kotlinx.coroutines.CancellableContinuation -import kotlinx.coroutines.suspendCancellableCoroutine - -internal abstract class BaseAwaitableImpl -internal constructor(internal val syscalls: Syscalls) : Awaitable { - abstract fun deferred(): Deferred<*> - - abstract suspend fun awaitResult(): Result - - override val onAwait: SelectClause - get() = SelectClauseImpl(this) - - override suspend fun await(): T { - val res = awaitResult() - if (!res.isSuccess) { - throw res.failure!! - } - @Suppress("UNCHECKED_CAST") return res.value as T - } -} - -internal class SingleAwaitableImpl( - syscalls: Syscalls, - private val deferred: Deferred -) : BaseAwaitableImpl(syscalls) { - private var result: Result? = null - - override fun deferred(): Deferred<*> { - return deferred - } - - override suspend fun awaitResult(): Result { - if (!deferred().isCompleted) { - suspendCancellableCoroutine { cont: CancellableContinuation -> - syscalls.resolveDeferred(deferred(), completingUnitContinuation(cont)) - } - } - if (this.result == null) { - this.result = deferred.toResult() - } - return this.result!! - } -} - -internal abstract class BaseSingleMappedAwaitableImpl( - private val inner: BaseAwaitableImpl -) : BaseAwaitableImpl(inner.syscalls) { - private var mappedResult: Result? = null - - override fun deferred(): Deferred<*> { - return inner.deferred() - } - - abstract suspend fun map(res: Result): Result - - override suspend fun awaitResult(): Result { - if (mappedResult == null) { - this.mappedResult = map(inner.awaitResult()) - } - return mappedResult!! - } -} - -internal open class SingleSerdeAwaitableImpl -internal constructor( - syscalls: Syscalls, - deferred: Deferred, - private val serde: Serde, -) : - BaseSingleMappedAwaitableImpl( - SingleAwaitableImpl(syscalls, deferred), - ) { - @Suppress("UNCHECKED_CAST") - override suspend fun map(res: Result): Result { - return if (res.isSuccess) { - // This propagates exceptions as non-terminal - Result.success(serde.deserializeWrappingException(syscalls, res.value!!)) - } else { - res as Result - } - } -} - -internal class UnitAwakeableImpl(syscalls: Syscalls, deferred: Deferred) : - BaseSingleMappedAwaitableImpl(SingleAwaitableImpl(syscalls, deferred)) { - @Suppress("UNCHECKED_CAST") - override suspend fun map(res: Result): Result { - return if (res.isSuccess) { - Result.success(Unit) - } else { - res as Result - } - } -} - -internal class AnyAwaitableImpl -internal constructor(syscalls: Syscalls, private val awaitables: List>) : - BaseSingleMappedAwaitableImpl( - SingleAwaitableImpl( - syscalls, - syscalls.createAnyDeferred( - awaitables.map { (it as BaseAwaitableImpl<*>).deferred() }))), - AnyAwaitable { - - override suspend fun awaitIndex(): Int { - if (!deferred().isCompleted) { - suspendCancellableCoroutine { cont: CancellableContinuation -> - syscalls.resolveDeferred(deferred(), completingUnitContinuation(cont)) - } - } - - return deferred().toResult()!!.value as Int - } - - @Suppress("UNCHECKED_CAST") - override suspend fun map(res: Result): Result { - return if (res.isSuccess) - ((awaitables[res.value!!] as BaseAwaitableImpl<*>).awaitResult() as Result) - else (res as Result) - } -} - -internal fun wrapAllAwaitable(awaitables: List>): Awaitable { - val syscalls = (awaitables.get(0) as BaseAwaitableImpl<*>).syscalls - return UnitAwakeableImpl( - syscalls, - syscalls.createAllDeferred(awaitables.map { (it as BaseAwaitableImpl<*>).deferred() }), - ) -} - -internal fun wrapAnyAwaitable(awaitables: List>): AnyAwaitable { - val syscalls = (awaitables.get(0) as BaseAwaitableImpl<*>).syscalls - return AnyAwaitableImpl(syscalls, awaitables) -} - -internal class AwakeableImpl -internal constructor( - syscalls: Syscalls, - deferred: Deferred, - serde: Serde, - override val id: String -) : SingleSerdeAwaitableImpl(syscalls, deferred, serde), Awakeable {} - -internal class AwakeableHandleImpl(val syscalls: Syscalls, val id: String) : AwakeableHandle { - override suspend fun resolve(serde: Serde, payload: T) { - return suspendCancellableCoroutine { cont: CancellableContinuation -> - syscalls.resolveAwakeable( - id, serde.serializeWrappingException(syscalls, payload), completingUnitContinuation(cont)) - } - } - - override suspend fun reject(reason: String) { - return suspendCancellableCoroutine { cont: CancellableContinuation -> - syscalls.rejectAwakeable(id, reason, completingUnitContinuation(cont)) - } - } -} - -internal class SelectClauseImpl(override val awaitable: Awaitable) : SelectClause - -@PublishedApi -internal class SelectImplementation : SelectBuilder { - - private val clauses: MutableList, suspend (Any?) -> R>> = mutableListOf() - - @Suppress("UNCHECKED_CAST") - override fun SelectClause.invoke(block: suspend (T) -> R) { - clauses.add(this.awaitable as Awaitable<*> to block as suspend (Any?) -> R) - } - - suspend fun doSelect(): R { - val index = wrapAnyAwaitable(clauses.map { it.first }).awaitIndex() - val resolved = clauses[index] - return resolved.first.await().let { resolved.second(it) } - } -} diff --git a/sdk-api-kotlin/src/main/kotlin/dev/restate/sdk/kotlin/ContextImpl.kt b/sdk-api-kotlin/src/main/kotlin/dev/restate/sdk/kotlin/ContextImpl.kt index 84d519f1c..5620e1d7d 100644 --- a/sdk-api-kotlin/src/main/kotlin/dev/restate/sdk/kotlin/ContextImpl.kt +++ b/sdk-api-kotlin/src/main/kotlin/dev/restate/sdk/kotlin/ContextImpl.kt @@ -8,223 +8,164 @@ // https://github.com/restatedev/sdk-java/blob/main/LICENSE package dev.restate.sdk.kotlin -import dev.restate.sdk.common.* -import dev.restate.sdk.common.Target -import dev.restate.sdk.common.syscalls.Deferred -import dev.restate.sdk.common.syscalls.EnterSideEffectSyscallCallback -import dev.restate.sdk.common.syscalls.ExitSideEffectSyscallCallback -import dev.restate.sdk.common.syscalls.Syscalls -import java.nio.ByteBuffer -import kotlin.coroutines.resume +import dev.restate.common.Output +import dev.restate.common.Request +import dev.restate.common.SendRequest +import dev.restate.common.Slice +import dev.restate.sdk.endpoint.definition.HandlerContext +import dev.restate.sdk.types.DurablePromiseKey +import dev.restate.sdk.types.HandlerRequest +import dev.restate.sdk.types.StateKey +import dev.restate.sdk.types.TerminalException +import dev.restate.serde.Serde +import dev.restate.serde.SerdeFactory +import dev.restate.serde.TypeTag +import java.util.concurrent.CompletableFuture +import kotlin.jvm.optionals.getOrNull import kotlin.time.Duration import kotlin.time.toJavaDuration -import kotlinx.coroutines.CancellableContinuation -import kotlinx.coroutines.CancellationException -import kotlinx.coroutines.CompletableDeferred -import kotlinx.coroutines.suspendCancellableCoroutine - -internal class ContextImpl internal constructor(internal val syscalls: Syscalls) : WorkflowContext { +import kotlinx.coroutines.* +import kotlinx.coroutines.future.await + +internal class ContextImpl +internal constructor( + internal val handlerContext: HandlerContext, + internal val contextSerdeFactory: SerdeFactory +) : WorkflowContext { override fun key(): String { - return this.syscalls.objectKey() + return this.handlerContext.objectKey() } - override fun request(): Request { - return this.syscalls.request() + override fun request(): HandlerRequest { + return this.handlerContext.request() } - override suspend fun get(key: StateKey): T? { - val deferred: Deferred = - suspendCancellableCoroutine { cont: CancellableContinuation> -> - syscalls.get(key.name(), completingContinuation(cont)) - } - - if (!deferred.isCompleted) { - suspendCancellableCoroutine { cont: CancellableContinuation -> - syscalls.resolveDeferred(deferred, completingUnitContinuation(cont)) - } - } - - val readyResult = deferred.toResult()!! - if (!readyResult.isSuccess) { - throw readyResult.failure!! - } - if (readyResult.isEmpty) { - return null - } - return key.serde().deserializeWrappingException(syscalls, readyResult.value!!)!! - } - - override suspend fun stateKeys(): Collection { - val deferred: Deferred> = - suspendCancellableCoroutine { cont: CancellableContinuation>> -> - syscalls.getKeys(completingContinuation(cont)) - } - - if (!deferred.isCompleted) { - suspendCancellableCoroutine { cont: CancellableContinuation -> - syscalls.resolveDeferred(deferred, completingUnitContinuation(cont)) - } - } + override suspend fun get(key: StateKey): T? = + resolveSerde(key.serdeInfo()) + .let { serde -> + SingleAwaitableImpl(handlerContext.get(key.name()).await()).simpleMap { + it.getOrNull()?.let { serde.deserialize(it) } + } + } + .await() - val readyResult = deferred.toResult()!! - if (!readyResult.isSuccess) { - throw readyResult.failure!! - } - return readyResult.value!! - } + override suspend fun stateKeys(): Collection = + SingleAwaitableImpl(handlerContext.getKeys().await()).await() override suspend fun set(key: StateKey, value: T) { - val serializedValue = key.serde().serializeWrappingException(syscalls, value) - return suspendCancellableCoroutine { cont: CancellableContinuation -> - syscalls.set(key.name(), serializedValue, completingUnitContinuation(cont)) - } + handlerContext.set(key.name(), resolveAndSerialize(key.serdeInfo(), value)).await() } override suspend fun clear(key: StateKey<*>) { - return suspendCancellableCoroutine { cont: CancellableContinuation -> - syscalls.clear(key.name(), completingUnitContinuation(cont)) - } + handlerContext.clear(key.name()).await() } override suspend fun clearAll() { - return suspendCancellableCoroutine { cont: CancellableContinuation -> - syscalls.clearAll(completingUnitContinuation(cont)) - } - } + handlerContext.clearAll().await() + } + + override suspend fun timer(duration: Duration, name: String?): Awaitable = + SingleAwaitableImpl(handlerContext.timer(duration.toJavaDuration(), name).await()).map {} + + override suspend fun call( + request: Request + ): CallAwaitable = + resolveSerde(request.responseTypeTag()).let { responseSerde -> + val callHandle = + handlerContext + .call( + request.target(), + resolveAndSerialize(request.requestTypeTag(), request.request()), + request.idempotencyKey(), + request.headers().entries) + .await() + + val callAsyncResult = + callHandle.callAsyncResult.map { + CompletableFuture.completedFuture(responseSerde.deserialize(it)) + } + + return@let CallAwaitableImpl(callAsyncResult, callHandle.invocationIdAsyncResult) + } - override suspend fun timer(duration: Duration): Awaitable { - val deferred: Deferred = - suspendCancellableCoroutine { cont: CancellableContinuation> -> - syscalls.sleep(duration.toJavaDuration(), completingContinuation(cont)) + override suspend fun send( + request: Request + ): InvocationHandle = + resolveSerde(request.responseTypeTag()).let { responseSerde -> + val invocationIdAsyncResult = + handlerContext + .send( + request.target(), + resolveAndSerialize(request.requestTypeTag(), request.request()), + request.idempotencyKey(), + request.headers().entries, + (request as? SendRequest)?.delay()) + .await() + + object : BaseInvocationHandle(handlerContext, responseSerde) { + override suspend fun invocationId(): String = invocationIdAsyncResult.poll().await() } + } - return UnitAwakeableImpl(syscalls, deferred) - } - - override suspend fun callAsync( - target: Target, - inputSerde: Serde, - outputSerde: Serde, - parameter: T - ): Awaitable { - val input = inputSerde.serializeWrappingException(syscalls, parameter) - - val deferred: Deferred = - suspendCancellableCoroutine { cont: CancellableContinuation> -> - syscalls.call(target, input, completingContinuation(cont)) + override fun invocationHandle( + invocationId: String, + responseTypeTag: TypeTag + ): InvocationHandle = + resolveSerde(responseTypeTag).let { responseSerde -> + object : BaseInvocationHandle(handlerContext, responseSerde) { + override suspend fun invocationId(): String = invocationId } + } - return SingleSerdeAwaitableImpl(syscalls, deferred, outputSerde) - } - - override suspend fun send( - target: Target, - inputSerde: Serde, - parameter: T, - delay: Duration - ) { - val input = inputSerde.serializeWrappingException(syscalls, parameter) - - return suspendCancellableCoroutine { cont: CancellableContinuation -> - syscalls.send(target, input, delay.toJavaDuration(), completingUnitContinuation(cont)) - } - } - - override suspend fun runBlock( - serde: Serde, + override suspend fun runAsync( + typeTag: TypeTag, name: String, retryPolicy: RetryPolicy?, block: suspend () -> T - ): T { - val exitResult = - suspendCancellableCoroutine { cont: CancellableContinuation> - -> - syscalls.enterSideEffectBlock( - name, - object : EnterSideEffectSyscallCallback { - override fun onSuccess(t: ByteBuffer?) { - val deferred: CompletableDeferred = CompletableDeferred() - deferred.complete(t!!) - cont.resume(deferred) - } - - override fun onFailure(t: TerminalException) { - val deferred: CompletableDeferred = CompletableDeferred() - deferred.completeExceptionally(t) - cont.resume(deferred) - } - - override fun onCancel(t: Throwable?) { - cont.cancel(t) - } - - override fun onNotExecuted() { - cont.resume(CompletableDeferred()) - } - }) - } - - if (exitResult.isCompleted) { - return serde.deserializeWrappingException(syscalls, exitResult.await())!! - } - - var actionReturnValue: T? = null - var actionFailure: Throwable? = null - try { - actionReturnValue = block() - } catch (t: Throwable) { - actionFailure = t - } - - val exitCallback = - object : ExitSideEffectSyscallCallback { - override fun onSuccess(t: ByteBuffer?) { - exitResult.complete(t!!) - } - - override fun onFailure(t: TerminalException) { - exitResult.completeExceptionally(t) - } - - override fun onCancel(t: Throwable?) { - exitResult.cancel(CancellationException(message = null, cause = t)) - } + ): Awaitable { + var serde: Serde = resolveSerde(typeTag) + var coroutineCtx = currentCoroutineContext() + val javaRetryPolicy = + retryPolicy?.let { + dev.restate.sdk.types.RetryPolicy.exponential( + it.initialDelay.toJavaDuration(), it.exponentiationFactor) + .setMaxAttempts(it.maxAttempts) + .setMaxDelay(it.maxDelay?.toJavaDuration()) + .setMaxDuration(it.maxDuration?.toJavaDuration()) } - if (actionFailure != null) { - val javaRetryPolicy = - retryPolicy?.let { - dev.restate.sdk.common.RetryPolicy.exponential( - it.initialDelay.toJavaDuration(), it.exponentiationFactor) - .setMaxAttempts(it.maxAttempts) - .setMaxDelay(it.maxDelay?.toJavaDuration()) - .setMaxDuration(it.maxDuration?.toJavaDuration()) - } - syscalls.exitSideEffectBlockWithException(actionFailure, javaRetryPolicy, exitCallback) - } else { - syscalls.exitSideEffectBlock( - serde.serializeWrappingException(syscalls, actionReturnValue), exitCallback) - } - - return serde.deserializeWrappingException(syscalls, exitResult.await()) + val scope = CoroutineScope(coroutineCtx + CoroutineName("restate-run-$name")) + + val asyncResult = + handlerContext + .submitRun(name) { completer -> + scope.launch { + val result: Slice? + try { + result = serde.serialize(block()) + } catch (e: Throwable) { + completer.proposeFailure(e, javaRetryPolicy) + return@launch + } + completer.proposeSuccess(result) + } + } + .await() + return SingleAwaitableImpl(asyncResult).map { serde.deserialize(it) } } - override suspend fun awakeable(serde: Serde): Awakeable { - val (aid, deferredResult) = - suspendCancellableCoroutine { - cont: CancellableContinuation>> -> - syscalls.awakeable(completingContinuation(cont)) - } - - return AwakeableImpl(syscalls, deferredResult, serde, aid) + override suspend fun awakeable(typeTag: TypeTag): Awakeable { + val serde: Serde = resolveSerde(typeTag) + val awk = handlerContext.awakeable().await() + return AwakeableImpl(awk.asyncResult, serde, awk.id) } override fun awakeableHandle(id: String): AwakeableHandle { - return AwakeableHandleImpl(syscalls, id) + return AwakeableHandleImpl(this, id) } override fun random(): RestateRandom { - return RestateRandom(syscalls.request().invocationId().toRandomSeed(), syscalls) + return RestateRandom(handlerContext.request().invocationId().toRandomSeed()) } override fun promise(key: DurablePromiseKey): DurablePromise { @@ -237,76 +178,55 @@ internal class ContextImpl internal constructor(internal val syscalls: Syscalls) inner class DurablePromiseImpl(private val key: DurablePromiseKey) : DurablePromise { - override suspend fun awaitable(): Awaitable { - val deferred: Deferred = - suspendCancellableCoroutine { cont: CancellableContinuation> -> - syscalls.promise(key.name(), completingContinuation(cont)) - } - - return SingleSerdeAwaitableImpl(syscalls, deferred, key.serde()) - } - - override suspend fun peek(): Output { - val deferred: Deferred = - suspendCancellableCoroutine { cont: CancellableContinuation> -> - syscalls.peekPromise(key.name(), completingContinuation(cont)) - } + val serde: Serde = resolveSerde(key.serdeInfo()) - if (!deferred.isCompleted) { - suspendCancellableCoroutine { cont: CancellableContinuation -> - syscalls.resolveDeferred(deferred, completingUnitContinuation(cont)) + override suspend fun awaitable(): Awaitable = + SingleAwaitableImpl(handlerContext.promise(key.name()).await()).simpleMap { + serde.deserialize(it) } - } - val readyResult = deferred.toResult()!! - if (!readyResult.isSuccess) { - throw readyResult.failure!! - } - if (readyResult.isEmpty) { - return Output.notReady() - } - return Output.ready(key.serde().deserializeWrappingException(syscalls, readyResult.value!!)) - } + override suspend fun peek(): Output = + SingleAwaitableImpl(handlerContext.peekPromise(key.name()).await()) + .simpleMap { it.map { serde.deserialize(it) } } + .await() } inner class DurablePromiseHandleImpl(private val key: DurablePromiseKey) : DurablePromiseHandle { - override suspend fun resolve(payload: T) { - val input = key.serde().serializeWrappingException(syscalls, payload) - - val deferred: Deferred = - suspendCancellableCoroutine { cont: CancellableContinuation> -> - syscalls.resolvePromise(key.name(), input, completingContinuation(cont)) - } - - if (!deferred.isCompleted) { - suspendCancellableCoroutine { cont: CancellableContinuation -> - syscalls.resolveDeferred(deferred, completingUnitContinuation(cont)) - } - } + val serde: Serde = resolveSerde(key.serdeInfo()) - val readyResult = deferred.toResult()!! - if (!readyResult.isSuccess) { - throw readyResult.failure!! - } + override suspend fun resolve(payload: T) { + SingleAwaitableImpl( + handlerContext + .resolvePromise( + key.name(), serde.serializeWrappingException(handlerContext, payload)) + .await()) + .await() } override suspend fun reject(reason: String) { - val deferred: Deferred = - suspendCancellableCoroutine { cont: CancellableContinuation> -> - syscalls.rejectPromise(key.name(), reason, completingContinuation(cont)) - } + SingleAwaitableImpl( + handlerContext.rejectPromise(key.name(), TerminalException(reason)).await()) + .await() + } + } - if (!deferred.isCompleted) { - suspendCancellableCoroutine { cont: CancellableContinuation -> - syscalls.resolveDeferred(deferred, completingUnitContinuation(cont)) - } - } + internal fun resolveAndSerialize(typeTag: TypeTag, value: T): Slice { + return try { + val serde = contextSerdeFactory.create(typeTag) + serde.serialize(value) + } catch (e: Exception) { + handlerContext.fail(e) + throw CancellationException("Failed serialization", e) + } + } - val readyResult = deferred.toResult()!! - if (!readyResult.isSuccess) { - throw readyResult.failure!! - } + private fun resolveSerde(typeTag: TypeTag): Serde { + return try { + contextSerdeFactory.create(typeTag)!! + } catch (e: Exception) { + handlerContext.fail(e) + throw CancellationException("Cannot resolve serde", e) } } } diff --git a/sdk-api-kotlin/src/main/kotlin/dev/restate/sdk/kotlin/HandlerRunner.kt b/sdk-api-kotlin/src/main/kotlin/dev/restate/sdk/kotlin/HandlerRunner.kt index bb43ec57f..c17658fe1 100644 --- a/sdk-api-kotlin/src/main/kotlin/dev/restate/sdk/kotlin/HandlerRunner.kt +++ b/sdk-api-kotlin/src/main/kotlin/dev/restate/sdk/kotlin/HandlerRunner.kt @@ -8,12 +8,13 @@ // https://github.com/restatedev/sdk-java/blob/main/LICENSE package dev.restate.sdk.kotlin -import dev.restate.sdk.common.TerminalException -import dev.restate.sdk.common.syscalls.HandlerSpecification -import dev.restate.sdk.common.syscalls.SyscallCallback -import dev.restate.sdk.common.syscalls.Syscalls +import dev.restate.common.Slice +import dev.restate.sdk.endpoint.definition.HandlerContext +import dev.restate.sdk.types.TerminalException +import dev.restate.serde.Serde +import dev.restate.serde.SerdeFactory import io.opentelemetry.extension.kotlin.asContextElement -import java.nio.ByteBuffer +import java.util.concurrent.CompletableFuture import kotlin.coroutines.CoroutineContext import kotlinx.coroutines.CoroutineScope import kotlinx.coroutines.Dispatchers @@ -21,52 +22,64 @@ import kotlinx.coroutines.asContextElement import kotlinx.coroutines.launch import org.apache.logging.log4j.LogManager -/** Adapter class for [dev.restate.sdk.common.syscalls.HandlerRunner] to use the Kotlin API. */ +/** Adapter class for [dev.restate.sdk.endpoint.definition.HandlerRunner] to use the Kotlin API. */ class HandlerRunner internal constructor( private val runner: suspend (CTX, REQ) -> RES, -) : dev.restate.sdk.common.syscalls.HandlerRunner { + private val contextSerdeFactory: SerdeFactory, + private val options: Options +) : dev.restate.sdk.endpoint.definition.HandlerRunner { companion object { private val LOG = LogManager.getLogger(HandlerRunner::class.java) fun of( - runner: suspend (CTX, REQ) -> RES + contextSerdeFactory: SerdeFactory, + options: Options = Options.DEFAULT, + runner: suspend (CTX, REQ) -> RES, ): HandlerRunner { - return HandlerRunner(runner) + return HandlerRunner(runner, contextSerdeFactory, options) } - fun of(runner: suspend (CTX) -> RES): HandlerRunner { - return HandlerRunner { ctx: CTX, _: Unit -> runner(ctx) } + fun of( + contextSerdeFactory: SerdeFactory, + options: Options = Options.DEFAULT, + runner: suspend (CTX) -> RES, + ): HandlerRunner { + return HandlerRunner({ ctx: CTX, _: Unit -> runner(ctx) }, contextSerdeFactory, options) } } override fun run( - handlerSpecification: HandlerSpecification, - syscalls: Syscalls, - options: Options?, - callback: SyscallCallback - ) { - val ctx: Context = ContextImpl(syscalls) + handlerContext: HandlerContext, + requestSerde: Serde, + responseSerde: Serde, + ): CompletableFuture { + val ctx: Context = ContextImpl(handlerContext, contextSerdeFactory) val scope = CoroutineScope( - (options?.coroutineContext ?: Options.DEFAULT.coroutineContext) + - dev.restate.sdk.common.syscalls.HandlerRunner.SYSCALLS_THREAD_LOCAL - .asContextElement(syscalls) + - syscalls.request().otelContext()!!.asContextElement()) + options.coroutineContext + + dev.restate.sdk.endpoint.definition.HandlerRunner.HANDLER_CONTEXT_THREAD_LOCAL + .asContextElement(handlerContext) + + handlerContext.request().otelContext()!!.asContextElement()) + + val completableFuture = CompletableFuture() + scope.launch { - val serializedResult: ByteBuffer + val serializedResult: Slice try { // Parse input val req: REQ try { - req = handlerSpecification.requestSerde.deserialize(syscalls.request().bodyBuffer()) + req = requestSerde.deserialize(handlerContext.request().body) } catch (e: Throwable) { - LOG.warn("Error when deserializing input", e) - throw TerminalException( - TerminalException.BAD_REQUEST_CODE, "Cannot deserialize input: " + e.message) + LOG.warn("Error deserializing request", e) + completableFuture.completeExceptionally( + throw TerminalException( + TerminalException.BAD_REQUEST_CODE, "Cannot deserialize request: " + e.message)) + return@launch } // Execute user code @@ -74,23 +87,26 @@ internal constructor( // Serialize output try { - serializedResult = handlerSpecification.responseSerde.serializeToByteBuffer(res) + serializedResult = responseSerde.serialize(res) } catch (e: Throwable) { - LOG.warn("Error when serializing input", e) - throw TerminalException( - TerminalException.INTERNAL_SERVER_ERROR_CODE, "Cannot serialize output: $e") + LOG.warn("Error when serializing response", e) + completableFuture.completeExceptionally(e) + return@launch } } catch (e: Throwable) { - callback.onCancel(e) + completableFuture.completeExceptionally(e) return@launch } // Complete callback - callback.onSuccess(serializedResult) + completableFuture.complete(serializedResult) } + + return completableFuture } - class Options(val coroutineContext: CoroutineContext) { + data class Options(val coroutineContext: CoroutineContext) : + dev.restate.sdk.endpoint.definition.HandlerRunner.Options { companion object { val DEFAULT: Options = Options(Dispatchers.Default) } diff --git a/sdk-api-kotlin/src/main/kotlin/dev/restate/sdk/kotlin/KtSerdes.kt b/sdk-api-kotlin/src/main/kotlin/dev/restate/sdk/kotlin/KtSerdes.kt deleted file mode 100644 index ea94b20dd..000000000 --- a/sdk-api-kotlin/src/main/kotlin/dev/restate/sdk/kotlin/KtSerdes.kt +++ /dev/null @@ -1,174 +0,0 @@ -// Copyright (c) 2023 - Restate Software, Inc., Restate GmbH -// -// This file is part of the Restate Java SDK, -// which is released under the MIT license. -// -// You can find a copy of the license in file LICENSE in the root -// directory of this repository or package, or at -// https://github.com/restatedev/sdk-java/blob/main/LICENSE -package dev.restate.sdk.kotlin - -import dev.restate.sdk.common.DurablePromiseKey -import dev.restate.sdk.common.RichSerde -import dev.restate.sdk.common.Serde -import dev.restate.sdk.common.StateKey -import java.nio.ByteBuffer -import java.nio.charset.StandardCharsets -import kotlin.reflect.typeOf -import kotlinx.serialization.ExperimentalSerializationApi -import kotlinx.serialization.KSerializer -import kotlinx.serialization.Serializable -import kotlinx.serialization.builtins.ListSerializer -import kotlinx.serialization.builtins.serializer -import kotlinx.serialization.descriptors.PrimitiveKind -import kotlinx.serialization.descriptors.SerialDescriptor -import kotlinx.serialization.descriptors.StructureKind -import kotlinx.serialization.encodeToString -import kotlinx.serialization.json.Json -import kotlinx.serialization.json.JsonArray -import kotlinx.serialization.json.JsonElement -import kotlinx.serialization.json.JsonNull -import kotlinx.serialization.json.JsonTransformingSerializer -import kotlinx.serialization.serializer - -object KtStateKey { - - /** Creates a json [StateKey]. */ - inline fun json(name: String): StateKey { - return StateKey.of(name, KtSerdes.json()) - } -} - -object KtDurablePromiseKey { - - /** Creates a json [StateKey]. */ - inline fun json(name: String): DurablePromiseKey { - return DurablePromiseKey.of(name, KtSerdes.json()) - } -} - -object KtSerdes { - - /** Creates a [Serde] implementation using the `kotlinx.serialization` json module. */ - inline fun json(): Serde { - @Suppress("UNCHECKED_CAST") - return when (typeOf()) { - typeOf() -> UNIT as Serde - else -> json(serializer()) - } - } - - val UNIT: Serde = - object : Serde { - override fun serialize(value: Unit?): ByteArray { - return ByteArray(0) - } - - override fun serializeToByteBuffer(value: Unit?): ByteBuffer { - return ByteBuffer.allocate(0) - } - - override fun deserialize(value: ByteArray) { - return - } - - override fun deserialize(byteBuffer: ByteBuffer) { - return - } - - override fun contentType(): String? { - return null - } - } - - /** Creates a [Serde] implementation using the `kotlinx.serialization` json module. */ - inline fun json(serializer: KSerializer): Serde { - return object : RichSerde { - override fun serialize(value: T?): ByteArray { - if (value == null) { - return Json.encodeToString(JsonNull.serializer(), JsonNull).encodeToByteArray() - } - - return Json.encodeToString(serializer, value).encodeToByteArray() - } - - override fun deserialize(value: ByteArray): T { - return Json.decodeFromString(serializer, String(value, StandardCharsets.UTF_8)) - } - - override fun contentType(): String { - return "application/json" - } - - override fun jsonSchema(): String { - val schema: JsonSchema = serializer.descriptor.jsonSchema() - return Json.encodeToString(schema) - } - } - } - - @Serializable - @PublishedApi - internal data class JsonSchema( - @Serializable(with = StringListSerializer::class) val type: List? = null, - val format: String? = null, - ) { - companion object { - val INT = JsonSchema(type = listOf("number"), format = "int32") - - val LONG = JsonSchema(type = listOf("number"), format = "int64") - - val DOUBLE = JsonSchema(type = listOf("number"), format = "double") - - val FLOAT = JsonSchema(type = listOf("number"), format = "float") - - val STRING = JsonSchema(type = listOf("string")) - - val BOOLEAN = JsonSchema(type = listOf("boolean")) - - val OBJECT = JsonSchema(type = listOf("object")) - - val LIST = JsonSchema(type = listOf("array")) - - val ANY = JsonSchema() - } - } - - object StringListSerializer : - JsonTransformingSerializer>(ListSerializer(String.Companion.serializer())) { - override fun transformSerialize(element: JsonElement): JsonElement { - require(element is JsonArray) - return element.singleOrNull() ?: element - } - } - - /** - * Super simplistic json schema generation. We should replace this with an appropriate library. - */ - @OptIn(ExperimentalSerializationApi::class) - @PublishedApi - internal fun SerialDescriptor.jsonSchema(): JsonSchema { - var schema = - when (this.kind) { - PrimitiveKind.BOOLEAN -> JsonSchema.BOOLEAN - PrimitiveKind.BYTE -> JsonSchema.INT - PrimitiveKind.CHAR -> JsonSchema.STRING - PrimitiveKind.DOUBLE -> JsonSchema.DOUBLE - PrimitiveKind.FLOAT -> JsonSchema.FLOAT - PrimitiveKind.INT -> JsonSchema.INT - PrimitiveKind.LONG -> JsonSchema.LONG - PrimitiveKind.SHORT -> JsonSchema.INT - PrimitiveKind.STRING -> JsonSchema.STRING - StructureKind.LIST -> JsonSchema.LIST - StructureKind.MAP -> JsonSchema.OBJECT - else -> JsonSchema.ANY - } - - // Add nullability constraint - if (this.isNullable && schema.type != null) { - schema = schema.copy(type = schema.type.plus("null")) - } - - return schema - } -} diff --git a/sdk-api-kotlin/src/main/kotlin/dev/restate/sdk/kotlin/Util.kt b/sdk-api-kotlin/src/main/kotlin/dev/restate/sdk/kotlin/Util.kt index e9a6e360d..eee82127a 100644 --- a/sdk-api-kotlin/src/main/kotlin/dev/restate/sdk/kotlin/Util.kt +++ b/sdk-api-kotlin/src/main/kotlin/dev/restate/sdk/kotlin/Util.kt @@ -8,47 +8,19 @@ // https://github.com/restatedev/sdk-java/blob/main/LICENSE package dev.restate.sdk.kotlin -import dev.restate.sdk.common.Serde -import dev.restate.sdk.common.syscalls.SyscallCallback -import dev.restate.sdk.common.syscalls.Syscalls -import java.nio.ByteBuffer -import kotlin.coroutines.resume -import kotlinx.coroutines.CancellableContinuation +import dev.restate.common.Slice +import dev.restate.sdk.endpoint.definition.HandlerContext +import dev.restate.serde.Serde import kotlinx.coroutines.CancellationException -internal fun completingContinuation(cont: CancellableContinuation): SyscallCallback { - return SyscallCallback.of(cont::resume) { - cont.cancel(CancellationException("Restate internal error", it)) - } -} - -internal fun completingUnitContinuation( - cont: CancellableContinuation -): SyscallCallback { - return SyscallCallback.of( - { cont.resume(Unit) }, { cont.cancel(CancellationException("Restate internal error", it)) }) -} - internal fun Serde.serializeWrappingException( - syscalls: Syscalls, + handlerContext: HandlerContext, value: T? -): ByteBuffer { +): Slice { return try { - this.serializeToByteBuffer(value) + this.serialize(value) } catch (e: Exception) { - syscalls.fail(e) + handlerContext.fail(e) throw CancellationException("Failed serialization", e) } } - -internal fun Serde.deserializeWrappingException( - syscalls: Syscalls, - ByteBuffer: ByteBuffer -): T { - return try { - this.deserialize(ByteBuffer) - } catch (e: Exception) { - syscalls.fail(e) - throw CancellationException("Failed deserialization", e) - } -} diff --git a/sdk-api-kotlin/src/main/kotlin/dev/restate/sdk/kotlin/api.kt b/sdk-api-kotlin/src/main/kotlin/dev/restate/sdk/kotlin/api.kt index 2c8364aaf..b3224170b 100644 --- a/sdk-api-kotlin/src/main/kotlin/dev/restate/sdk/kotlin/api.kt +++ b/sdk-api-kotlin/src/main/kotlin/dev/restate/sdk/kotlin/api.kt @@ -8,16 +8,19 @@ // https://github.com/restatedev/sdk-java/blob/main/LICENSE package dev.restate.sdk.kotlin -import dev.restate.sdk.common.DurablePromiseKey -import dev.restate.sdk.common.Output -import dev.restate.sdk.common.Request -import dev.restate.sdk.common.Serde -import dev.restate.sdk.common.StateKey -import dev.restate.sdk.common.Target -import dev.restate.sdk.common.syscalls.Syscalls +import dev.restate.common.Output +import dev.restate.common.Request +import dev.restate.common.SendRequest +import dev.restate.sdk.kotlin.serialization.typeTag +import dev.restate.sdk.types.DurablePromiseKey +import dev.restate.sdk.types.HandlerRequest +import dev.restate.sdk.types.StateKey +import dev.restate.sdk.types.TerminalException +import dev.restate.serde.TypeTag import java.util.* import kotlin.random.Random import kotlin.time.Duration +import kotlin.time.toJavaDuration /** * This interface exposes the Restate functionalities to Restate services. It can be used to @@ -32,7 +35,7 @@ import kotlin.time.Duration */ sealed interface Context { - fun request(): Request + fun request(): HandlerRequest /** * Causes the current execution of the function invocation to sleep for the given duration. @@ -48,27 +51,21 @@ sealed interface Context { * [Awaitable.await]. * * @param duration for which to sleep. + * @param name name to be used for the timer */ - suspend fun timer(duration: Duration): Awaitable + suspend fun timer(duration: Duration, name: String? = null): Awaitable /** - * Invoke another Restate service method and wait for the response. Same as - * `call(methodDescriptor, parameter).await()`. + * Invoke another Restate service method. * * @param target the address of the callee * @param inputSerde Input serde * @param outputSerde Output serde * @param parameter the invocation request parameter. - * @return the invocation response. + * @param callOptions request options. + * @return a [CallAwaitable] that wraps the result. */ - suspend fun call( - target: Target, - inputSerde: Serde, - outputSerde: Serde, - parameter: T - ): R { - return callAsync(target, inputSerde, outputSerde, parameter).await() - } + suspend fun call(request: Request): CallAwaitable /** * Invoke another Restate service method. @@ -77,14 +74,25 @@ sealed interface Context { * @param inputSerde Input serde * @param outputSerde Output serde * @param parameter the invocation request parameter. - * @return an [Awaitable] that wraps the Restate service method result. + * @param callOptions request options. + * @return a [CallAwaitable] that wraps the result. */ - suspend fun callAsync( - target: Target, - inputSerde: Serde, - outputSerde: Serde, - parameter: T - ): Awaitable + suspend fun call( + requestBuilder: Request.Builder + ): CallAwaitable { + return call(requestBuilder.build()) + } + + /** + * Invoke another Restate service without waiting for the response. + * + * @param target the address of the callee + * @param inputSerde Input serde + * @param parameter the invocation request parameter. + * @param sendOptions request options. + * @return a [SendHandle] to interact with the sent request. + */ + suspend fun send(request: Request): InvocationHandle /** * Invoke another Restate service without waiting for the response. @@ -92,14 +100,26 @@ sealed interface Context { * @param target the address of the callee * @param inputSerde Input serde * @param parameter the invocation request parameter. - * @param delay time to wait before executing the call + * @param sendOptions request options. + * @return a [SendHandle] to interact with the sent request. + */ + suspend fun send( + sendRequestBuilder: Request.Builder + ): InvocationHandle { + return send(sendRequestBuilder.build()) + } + + /** + * Get an [InvocationHandle] for an already existing invocation. This will let you interact with a + * running invocation, for example to cancel it or retrieve its result. + * + * @param invocationId The invocation to interact with. + * @param responseClazz The response class. */ - suspend fun send( - target: Target, - inputSerde: Serde, - parameter: T, - delay: Duration = Duration.ZERO - ) + fun invocationHandle( + invocationId: String, + responseTypeTag: TypeTag + ): InvocationHandle /** * Execute a non-deterministic closure, recording the result value in the journal. The result @@ -142,18 +162,27 @@ sealed interface Context { * * To propagate failures to the run call-site, make sure to wrap them in [TerminalException]. * - * @param serde the type tag of the return value, used to serialize/deserialize it. + * @param typeTag the type tag of the return value, used to serialize/deserialize it. * @param name the name of the side effect. * @param block closure to execute. * @param T type of the return value. * @return value of the runBlock operation. */ suspend fun runBlock( - serde: Serde, + typeTag: TypeTag, + name: String = "", + retryPolicy: RetryPolicy? = null, + block: suspend () -> T + ): T { + return runAsync(typeTag, name, retryPolicy, block).await() + } + + suspend fun runAsync( + typeTag: TypeTag, name: String = "", retryPolicy: RetryPolicy? = null, block: suspend () -> T - ): T + ): Awaitable /** * Create an [Awakeable], addressable through [Awakeable.id]. @@ -166,7 +195,7 @@ sealed interface Context { * @return the [Awakeable] to await on. * @see Awakeable */ - suspend fun awakeable(serde: Serde): Awakeable + suspend fun awakeable(typeTag: TypeTag): Awakeable /** * Create a new [AwakeableHandle] for the provided identifier. You can use it to @@ -178,7 +207,7 @@ sealed interface Context { /** * Create a [RestateRandom] instance inherently predictable, seeded on the - * [dev.restate.sdk.common.InvocationId], which is not secret. + * [dev.restate.sdk.types.InvocationId], which is not secret. * * This instance is useful to generate identifiers, idempotency keys, and for uniform sampling * from a set of options. If a cryptographically secure value is needed, please generate that @@ -192,10 +221,23 @@ sealed interface Context { } /** - * Execute a non-deterministic closure, recording the result value in the journal using - * [KtSerdes.json]. The result value will be re-played in case of re-invocation (e.g. because of - * failure recovery or suspension point) without re-executing the closure. Use this feature if you - * want to perform non-deterministic operations. + * Get an [InvocationHandle] for an already existing invocation. This will let you interact with a + * running invocation, for example to cancel it or retrieve its result. + * + * @param invocationId The invocation to interact with. + * @param responseClazz The response class. + */ +inline fun Context.invocationHandle( + invocationId: String +): InvocationHandle { + return this.invocationHandle(invocationId, typeTag()) +} + +/** + * Execute a non-deterministic closure, recording the result value in the journal. The result value + * will be re-played in case of re-invocation (e.g. because of failure recovery or suspension point) + * without re-executing the closure. Use this feature if you want to perform non-deterministic + * operations. * *

The closure should tolerate retries, that is Restate might re-execute the closure multiple * times until it records a result. To control and limit the amount of retries, pass a [RetryPolicy] @@ -238,11 +280,19 @@ suspend inline fun Context.runBlock( retryPolicy: RetryPolicy? = null, noinline block: suspend () -> T ): T { - return this.runBlock(KtSerdes.json(), name, retryPolicy, block) + return this.runBlock(typeTag(), name, retryPolicy, block) +} + +suspend inline fun Context.runAsync( + name: String = "", + retryPolicy: RetryPolicy? = null, + noinline block: suspend () -> T +): Awaitable { + return this.runAsync(typeTag(), name, retryPolicy, block) } /** - * Create an [Awakeable] using [KtSerdes.json] deserializer, addressable through [Awakeable.id]. + * Create an [Awakeable], addressable through [Awakeable.id]. * * You can use this feature to implement external asynchronous systems interactions, for example you * can send a Kafka record including the [Awakeable.id], and then let another service consume from @@ -252,7 +302,7 @@ suspend inline fun Context.runBlock( * @see Awakeable */ suspend inline fun Context.awakeable(): Awakeable { - return this.awakeable(KtSerdes.json()) + return this.awakeable(typeTag()) } /** @@ -265,7 +315,7 @@ sealed interface SharedObjectContext : Context { fun key(): String /** - * Gets the state stored under key, deserializing the raw value using the [StateKey.serde]. + * Gets the state stored under key, deserializing the raw value using the [StateKey.serdeInfo]. * * @param key identifying the state to get and its type. * @return the value containing the stored state deserialized. @@ -281,6 +331,14 @@ sealed interface SharedObjectContext : Context { suspend fun stateKeys(): Collection } +inline fun stateKey(name: String): StateKey { + return StateKey.of(name, typeTag()) +} + +suspend inline fun SharedObjectContext.get(key: String): T? { + return this.get(StateKey.of(key, typeTag())) +} + /** * This interface can be used only within exclusive handlers of virtual objects. It extends * [Context] adding access to the virtual object instance key-value state storage. @@ -288,7 +346,7 @@ sealed interface SharedObjectContext : Context { sealed interface ObjectContext : SharedObjectContext { /** - * Sets the given value under the given key, serializing the value using the [StateKey.serde]. + * Sets the given value under the given key, serializing the value using the [StateKey.serdeInfo]. * * @param key identifying the value to store and its type. * @param value to store under the given key. @@ -306,6 +364,10 @@ sealed interface ObjectContext : SharedObjectContext { suspend fun clearAll() } +suspend inline fun ObjectContext.set(key: String, value: T) { + this.set(StateKey.of(key, typeTag()), value) +} + /** * This interface can be used only within shared handlers of workflow. It extends [Context] adding * access to the workflow instance key-value state storage and to the [DurablePromise] API. @@ -348,11 +410,10 @@ sealed interface SharedWorkflowContext : SharedObjectContext { */ sealed interface WorkflowContext : SharedWorkflowContext, ObjectContext -class RestateRandom(seed: Long, private val syscalls: Syscalls) : Random() { +class RestateRandom(seed: Long) : Random() { private val r = Random(seed) override fun nextBits(bitCount: Int): Int { - check(!syscalls.isInsideSideEffect) { "You can't use RestateRandom inside ctx.runBlock!" } return r.nextBits(bitCount) } @@ -374,9 +435,22 @@ class RestateRandom(seed: Long, private val syscalls: Syscalls) : Random() { sealed interface Awaitable { suspend fun await(): T + suspend fun await(duration: Duration): T + + suspend fun withTimeout(duration: Duration): Awaitable + /** Clause for [select] operator. */ val onAwait: SelectClause + suspend fun map(transform: suspend (value: T) -> R): Awaitable + + suspend fun map( + transformSuccess: suspend (value: T) -> R, + transformFailure: suspend (exception: TerminalException) -> R + ): Awaitable + + suspend fun mapFailure(transform: suspend (exception: TerminalException) -> T): Awaitable + companion object { fun all( first: Awaitable<*>, @@ -386,7 +460,11 @@ sealed interface Awaitable { return wrapAllAwaitable(listOf(first) + listOf(second) + others.asList()) } - fun any(first: Awaitable<*>, second: Awaitable<*>, vararg others: Awaitable<*>): AnyAwaitable { + fun any( + first: Awaitable<*>, + second: Awaitable<*>, + vararg others: Awaitable<*> + ): Awaitable { return wrapAnyAwaitable(listOf(first) + listOf(second) + others.asList()) } } @@ -421,11 +499,6 @@ suspend fun awaitAll(vararg awaitables: Awaitable): List { return awaitables.map { it.await() }.toList() } -sealed interface AnyAwaitable : Awaitable { - /** Same as [Awaitable.await], but returns the index of the first completed element. */ - suspend fun awaitIndex(): Int -} - /** * Like [kotlinx.coroutines.selects.select], but for [Awaitable] * @@ -436,13 +509,13 @@ sealed interface AnyAwaitable : Awaitable { * val result = select { * callAwaitable.onAwait { it.message } * timeout.onAwait { throw TimeoutException() } - * } + * }.await() * ``` */ -suspend inline fun select(crossinline builder: SelectBuilder.() -> Unit): R { +suspend inline fun select(crossinline builder: SelectBuilder.() -> Unit): Awaitable { val selectImpl = SelectImplementation() builder.invoke(selectImpl) - return selectImpl.doSelect() + return selectImpl.build() } sealed interface SelectBuilder { @@ -454,6 +527,26 @@ sealed interface SelectClause { val awaitable: Awaitable } +/** The [Awaitable] returned by a [Context.call]. */ +sealed interface CallAwaitable : Awaitable { + suspend fun invocationId(): String +} + +/** An invocation handle, that can be used to interact with a running invocation. */ +sealed interface InvocationHandle { + /** @return the invocation id of this invocation */ + suspend fun invocationId(): String + + /** Cancel this invocation. */ + suspend fun cancel() + + /** Attach to this invocation. This will wait for the invocation to complete */ + suspend fun attach(): Awaitable + + /** @return the output of this invocation, if present. */ + suspend fun output(): Output +} + /** * An [Awakeable] is a special type of [Awaitable] which can be arbitrarily completed by another * service, by addressing it with its [id]. @@ -475,11 +568,11 @@ sealed interface AwakeableHandle { /** * Complete with success the [Awakeable]. * - * @param serde used to serialize the [Awakeable] result payload. + * @param typeTag used to serialize the [Awakeable] result payload. * @param payload the result payload. * @see Awakeable */ - suspend fun resolve(serde: Serde, payload: T) + suspend fun resolve(typeTag: TypeTag, payload: T) /** * Complete with failure the [Awakeable]. @@ -491,13 +584,13 @@ sealed interface AwakeableHandle { } /** - * Complete with success the [Awakeable] using [KtSerdes.json] serializer. + * Complete with success the [Awakeable]. * * @param payload the result payload. * @see Awakeable */ suspend inline fun AwakeableHandle.resolve(payload: T) { - return this.resolve(KtSerdes.json(), payload) + return this.resolve(typeTag(), payload) } /** @@ -543,3 +636,19 @@ sealed interface DurablePromiseHandle { */ suspend fun reject(reason: String) } + +inline fun durablePromiseKey(name: String): DurablePromiseKey { + return DurablePromiseKey.of(name, typeTag()) +} + +fun Request.Builder.asSendDelayed( + duration: Duration +): SendRequest { + return this.asSendDelayed(duration.toJavaDuration()) +} + +fun Request.asSendDelayed( + duration: Duration +): SendRequest { + return this.asSendDelayed(duration.toJavaDuration()) +} diff --git a/sdk-api-kotlin/src/main/kotlin/dev/restate/sdk/kotlin/awaitables.kt b/sdk-api-kotlin/src/main/kotlin/dev/restate/sdk/kotlin/awaitables.kt new file mode 100644 index 000000000..10f170f1e --- /dev/null +++ b/sdk-api-kotlin/src/main/kotlin/dev/restate/sdk/kotlin/awaitables.kt @@ -0,0 +1,230 @@ +// Copyright (c) 2023 - Restate Software, Inc., Restate GmbH +// +// This file is part of the Restate Java SDK, +// which is released under the MIT license. +// +// You can find a copy of the license in file LICENSE in the root +// directory of this repository or package, or at +// https://github.com/restatedev/sdk-java/blob/main/LICENSE +package dev.restate.sdk.kotlin + +import dev.restate.common.Output +import dev.restate.common.Slice +import dev.restate.sdk.endpoint.definition.AsyncResult +import dev.restate.sdk.endpoint.definition.HandlerContext +import dev.restate.sdk.types.TerminalException +import dev.restate.sdk.types.TimeoutException +import dev.restate.serde.Serde +import dev.restate.serde.TypeTag +import java.util.concurrent.CompletableFuture +import java.util.concurrent.ExecutionException +import kotlin.time.Duration +import kotlin.time.toJavaDuration +import kotlinx.coroutines.CoroutineScope +import kotlinx.coroutines.currentCoroutineContext +import kotlinx.coroutines.future.await +import kotlinx.coroutines.launch + +internal abstract class BaseAwaitableImpl : Awaitable { + abstract fun asyncResult(): AsyncResult + + override val onAwait: SelectClause + get() = SelectClauseImpl(this) + + override suspend fun await(): T { + return asyncResult().poll().await() + } + + override suspend fun await(duration: Duration): T { + return withTimeout(duration).await() + } + + override suspend fun withTimeout(duration: Duration): Awaitable { + return (Awaitable.any( + this, + SingleAwaitableImpl(asyncResult().ctx().timer(duration.toJavaDuration(), null).await())) + as BaseAwaitableImpl<*>) + .simpleMap { + if (it == 1) { + throw TimeoutException("Timed out waiting for awaitable after $duration") + } + + try { + @Suppress("UNCHECKED_CAST") return@simpleMap this.asyncResult().poll().getNow(null) as T + } catch (e: ExecutionException) { + throw e.cause ?: e // unwrap original cause from ExecutionException + } + } + } + + fun simpleMap(transform: (T) -> R): Awaitable { + return SingleAwaitableImpl( + this.asyncResult().map { CompletableFuture.completedFuture(transform(it)) }) + } + + override suspend fun map(transform: suspend (T) -> R): Awaitable { + var ctx = currentCoroutineContext() + return SingleAwaitableImpl( + this.asyncResult().map { t -> + val completableFuture = CompletableFuture() + CoroutineScope(ctx).launch { + val r: R + try { + r = transform(t) + } catch (throwable: Throwable) { + completableFuture.completeExceptionally(throwable) + return@launch + } + completableFuture.complete(r) + } + completableFuture + }) + } + + override suspend fun map( + transformSuccess: suspend (T) -> R, + transformFailure: suspend (TerminalException) -> R + ): Awaitable { + var ctx = currentCoroutineContext() + return SingleAwaitableImpl( + this.asyncResult() + .map( + { t -> + val completableFuture = CompletableFuture() + CoroutineScope(ctx).launch { + val r: R + try { + r = transformSuccess(t) + } catch (throwable: Throwable) { + completableFuture.completeExceptionally(throwable) + return@launch + } + completableFuture.complete(r) + } + completableFuture + }, + { t -> + val completableFuture = CompletableFuture() + CoroutineScope(ctx).launch { + val r: R + try { + r = transformFailure(t) + } catch (throwable: Throwable) { + completableFuture.completeExceptionally(throwable) + return@launch + } + completableFuture.complete(r) + } + completableFuture + })) + } + + override suspend fun mapFailure(transform: suspend (TerminalException) -> T): Awaitable { + var ctx = currentCoroutineContext() + return SingleAwaitableImpl( + this.asyncResult().mapFailure { t -> + val completableFuture = CompletableFuture() + CoroutineScope(ctx).launch { + val newT: T + try { + newT = transform(t) + } catch (throwable: Throwable) { + completableFuture.completeExceptionally(throwable) + return@launch + } + completableFuture.complete(newT) + } + completableFuture + }) + } +} + +internal open class SingleAwaitableImpl(private val asyncResult: AsyncResult) : + BaseAwaitableImpl() { + override fun asyncResult(): AsyncResult { + return asyncResult + } +} + +internal fun wrapAllAwaitable(awaitables: List>): Awaitable { + val ctx = (awaitables.get(0) as BaseAwaitableImpl<*>).asyncResult().ctx() + return SingleAwaitableImpl( + ctx.createAllAsyncResult(awaitables.map { (it as BaseAwaitableImpl<*>).asyncResult() })) + .simpleMap {} +} + +internal fun wrapAnyAwaitable(awaitables: List>): BaseAwaitableImpl { + val ctx = (awaitables.get(0) as BaseAwaitableImpl<*>).asyncResult().ctx() + return SingleAwaitableImpl( + ctx.createAnyAsyncResult(awaitables.map { (it as BaseAwaitableImpl<*>).asyncResult() })) +} + +internal class CallAwaitableImpl +internal constructor( + callAsyncResult: AsyncResult, + private val invocationIdAsyncResult: AsyncResult +) : SingleAwaitableImpl(callAsyncResult), CallAwaitable { + override suspend fun invocationId(): String { + return invocationIdAsyncResult.poll().await() + } +} + +internal abstract class BaseInvocationHandle +internal constructor( + private val handlerContext: HandlerContext, + private val responseSerde: Serde +) : InvocationHandle { + override suspend fun cancel() { + val ignored = handlerContext.cancelInvocation(invocationId()).await() + } + + override suspend fun attach(): Awaitable = + SingleAwaitableImpl( + handlerContext.attachInvocation(invocationId()).await().map { + CompletableFuture.completedFuture(responseSerde.deserialize(it)) + }) + + override suspend fun output(): Output = + SingleAwaitableImpl(handlerContext.getInvocationOutput(invocationId()).await()) + .simpleMap { it.map { responseSerde.deserialize(it) } } + .await() +} + +internal class AwakeableImpl +internal constructor(asyncResult: AsyncResult, serde: Serde, override val id: String) : + SingleAwaitableImpl( + asyncResult.map { CompletableFuture.completedFuture(serde.deserialize(it)) }), + Awakeable + +internal class AwakeableHandleImpl(val contextImpl: ContextImpl, val id: String) : AwakeableHandle { + override suspend fun resolve(typeTag: TypeTag, payload: T) { + contextImpl.handlerContext + .resolveAwakeable(id, contextImpl.resolveAndSerialize(typeTag, payload)) + .await() + } + + override suspend fun reject(reason: String) { + return + contextImpl.handlerContext.rejectAwakeable(id, TerminalException(reason)).await() + } +} + +internal class SelectClauseImpl(override val awaitable: Awaitable) : SelectClause + +@PublishedApi +internal class SelectImplementation : SelectBuilder { + + private val clauses: MutableList, suspend (Any?) -> R>> = + mutableListOf() + + @Suppress("UNCHECKED_CAST") + override fun SelectClause.invoke(block: suspend (T) -> R) { + clauses.add(this.awaitable as BaseAwaitableImpl<*> to block as suspend (Any?) -> R) + } + + suspend fun build(): Awaitable { + return wrapAnyAwaitable(clauses.map { it.first }).map { index -> + clauses[index].let { resolved -> resolved.first.await().let { resolved.second(it) } } + } + } +} diff --git a/sdk-common/src/main/java/dev/restate/sdk/common/syscalls/EnterSideEffectSyscallCallback.java b/sdk-api-kotlin/src/main/kotlin/dev/restate/sdk/kotlin/endpoint/endpoint.kt similarity index 57% rename from sdk-common/src/main/java/dev/restate/sdk/common/syscalls/EnterSideEffectSyscallCallback.java rename to sdk-api-kotlin/src/main/kotlin/dev/restate/sdk/kotlin/endpoint/endpoint.kt index 9559e0b4d..d87ea1df4 100644 --- a/sdk-common/src/main/java/dev/restate/sdk/common/syscalls/EnterSideEffectSyscallCallback.java +++ b/sdk-api-kotlin/src/main/kotlin/dev/restate/sdk/kotlin/endpoint/endpoint.kt @@ -6,9 +6,13 @@ // You can find a copy of the license in file LICENSE in the root // directory of this repository or package, or at // https://github.com/restatedev/sdk-java/blob/main/LICENSE -package dev.restate.sdk.common.syscalls; +package dev.restate.sdk.kotlin.endpoint -public interface EnterSideEffectSyscallCallback extends ExitSideEffectSyscallCallback { +import dev.restate.sdk.endpoint.Endpoint - void onNotExecuted(); +/** Endpoint builder function. */ +fun endpoint(init: Endpoint.Builder.() -> Unit): Endpoint { + val builder = Endpoint.builder() + builder.init() + return builder.build() } diff --git a/sdk-api-kotlin/src/main/kotlin/dev/restate/sdk/kotlin/ingress.kt b/sdk-api-kotlin/src/main/kotlin/dev/restate/sdk/kotlin/ingress.kt deleted file mode 100644 index c61bb3541..000000000 --- a/sdk-api-kotlin/src/main/kotlin/dev/restate/sdk/kotlin/ingress.kt +++ /dev/null @@ -1,92 +0,0 @@ -// Copyright (c) 2023 - Restate Software, Inc., Restate GmbH -// -// This file is part of the Restate Java SDK, -// which is released under the MIT license. -// -// You can find a copy of the license in file LICENSE in the root -// directory of this repository or package, or at -// https://github.com/restatedev/sdk-java/blob/main/LICENSE -package dev.restate.sdk.kotlin - -import dev.restate.sdk.client.Client -import dev.restate.sdk.client.RequestOptions -import dev.restate.sdk.client.SendResponse -import dev.restate.sdk.common.Output -import dev.restate.sdk.common.Serde -import dev.restate.sdk.common.Target -import kotlin.time.Duration -import kotlin.time.toJavaDuration -import kotlinx.coroutines.future.await - -// Extension methods for the IngressClient - -suspend fun Client.callSuspend( - target: Target, - reqSerde: Serde, - resSerde: Serde, - req: Req, - options: RequestOptions = RequestOptions.DEFAULT -): Res { - return this.callAsync(target, reqSerde, resSerde, req, options).await() -} - -suspend fun Client.sendSuspend( - target: Target, - reqSerde: Serde, - req: Req, - delay: Duration = Duration.ZERO, - options: RequestOptions = RequestOptions.DEFAULT -): SendResponse { - return this.sendAsync(target, reqSerde, req, delay.toJavaDuration(), options).await() -} - -suspend fun Client.AwakeableHandle.resolveSuspend( - serde: Serde, - payload: T, - options: RequestOptions = RequestOptions.DEFAULT -) { - this.resolveAsync(serde, payload, options).await() -} - -suspend fun Client.AwakeableHandle.rejectSuspend( - reason: String, - options: RequestOptions = RequestOptions.DEFAULT -) { - this.rejectAsync(reason, options).await() -} - -suspend fun Client.InvocationHandle.attachSuspend( - options: RequestOptions = RequestOptions.DEFAULT -): T { - return this.attachAsync(options).await() -} - -suspend fun Client.InvocationHandle.getOutputSuspend( - options: RequestOptions = RequestOptions.DEFAULT -): Output { - return this.getOutputAsync(options).await() -} - -suspend fun Client.IdempotentInvocationHandle.attachSuspend( - options: RequestOptions = RequestOptions.DEFAULT -): T { - return this.attachAsync(options).await() -} - -suspend fun Client.IdempotentInvocationHandle.getOutputSuspend( - options: RequestOptions = RequestOptions.DEFAULT -): Output { - return this.getOutputAsync(options).await() -} - -suspend fun Client.WorkflowHandle.attachSuspend( - options: RequestOptions = RequestOptions.DEFAULT -): T { - return this.attachAsync(options).await() -} - -suspend fun Client.WorkflowHandle.getOutputSuspend( - options: RequestOptions = RequestOptions.DEFAULT -): Output { - return this.getOutputAsync(options).await() -} diff --git a/sdk-api-kotlin/src/main/kotlin/dev/restate/sdk/kotlin/serialization/KotlinSerializationSerdeFactory.kt b/sdk-api-kotlin/src/main/kotlin/dev/restate/sdk/kotlin/serialization/KotlinSerializationSerdeFactory.kt new file mode 100644 index 000000000..7a9f6faa7 --- /dev/null +++ b/sdk-api-kotlin/src/main/kotlin/dev/restate/sdk/kotlin/serialization/KotlinSerializationSerdeFactory.kt @@ -0,0 +1,212 @@ +// Copyright (c) 2023 - Restate Software, Inc., Restate GmbH +// +// This file is part of the Restate Java SDK, +// which is released under the MIT license. +// +// You can find a copy of the license in file LICENSE in the root +// directory of this repository or package, or at +// https://github.com/restatedev/sdk-java/blob/main/LICENSE +package dev.restate.sdk.kotlin.serialization + +import dev.restate.common.Slice +import dev.restate.serde.Serde +import dev.restate.serde.SerdeFactory +import dev.restate.serde.TypeRef +import dev.restate.serde.TypeTag +import java.nio.charset.StandardCharsets +import kotlin.reflect.KClass +import kotlin.reflect.KType +import kotlinx.serialization.* +import kotlinx.serialization.builtins.* +import kotlinx.serialization.descriptors.PrimitiveKind +import kotlinx.serialization.descriptors.SerialDescriptor +import kotlinx.serialization.descriptors.StructureKind +import kotlinx.serialization.encodeToString +import kotlinx.serialization.json.Json +import kotlinx.serialization.json.JsonArray +import kotlinx.serialization.json.JsonElement +import kotlinx.serialization.json.JsonNull +import kotlinx.serialization.json.JsonTransformingSerializer +import kotlinx.serialization.modules.SerializersModule + +class KotlinSerializationSerdeFactory +@JvmOverloads +constructor(private val json: Json = Json.Default) : SerdeFactory { + + @PublishedApi + internal class KtTypeTag( + internal val type: KClass<*>, + /** Reified type */ + internal val kotlinType: KType? + ) : TypeTag + + override fun create(typeTag: TypeTag): Serde { + if (typeTag is KtTypeTag) { + return create(typeTag) + } + return super.create(typeTag) + } + + @Suppress("UNCHECKED_CAST") + override fun create(typeRef: TypeRef): Serde { + if (typeRef.type == Unit::class.java) { + return UNIT as Serde + } + val serializer: KSerializer = + json.serializersModule.serializer(typeRef.type) as KSerializer + return jsonSerde(json, serializer) + } + + @Suppress("UNCHECKED_CAST") + override fun create(clazz: Class): Serde { + if (clazz == Unit::class.java) { + return UNIT as Serde + } + val serializer: KSerializer = json.serializersModule.serializer(clazz) as KSerializer + return jsonSerde(json, serializer) + } + + @Suppress("UNCHECKED_CAST") + @OptIn(InternalSerializationApi::class, ExperimentalSerializationApi::class) + private fun create(ktSerdeInfo: KtTypeTag): Serde { + if (ktSerdeInfo.type == Unit::class) { + return UNIT as Serde + } + val serializer: KSerializer = + json.serializersModule.serializerForKtTypeInfo(ktSerdeInfo) as KSerializer + return jsonSerde(json, serializer) + } + + companion object { + val UNIT: Serde = + object : Serde { + override fun serialize(value: Unit?): Slice { + return Slice.EMPTY + } + + override fun deserialize(value: Slice) { + return + } + + override fun contentType(): String? { + return null + } + } + + /** Creates a [Serde] implementation using the `kotlinx.serialization` json module. */ + fun jsonSerde(json: Json = Json.Default, serializer: KSerializer): Serde { + return object : Serde { + @Suppress("WRONG_NULLABILITY_FOR_JAVA_OVERRIDE") + override fun serialize(value: T?): Slice { + if (value == null) { + return Slice.wrap(json.encodeToString(JsonNull.serializer(), JsonNull)) + } + + return Slice.wrap(json.encodeToString(serializer, value)) + } + + override fun deserialize(value: Slice): T { + return json.decodeFromString( + serializer, String(value.toByteArray(), StandardCharsets.UTF_8)) + } + + override fun contentType(): String { + return "application/json" + } + + override fun jsonSchema(): Serde.Schema { + val schema: JsonSchema = serializer.descriptor.jsonSchema() + return Serde.StringifiedJsonSchema(Json.encodeToString(schema)) + } + } + } + + @Serializable + @PublishedApi + internal data class JsonSchema( + @Serializable(with = StringListSerializer::class) val type: List? = null, + val format: String? = null, + ) { + companion object { + val INT = JsonSchema(type = listOf("number"), format = "int32") + + val LONG = JsonSchema(type = listOf("number"), format = "int64") + + val DOUBLE = JsonSchema(type = listOf("number"), format = "double") + + val FLOAT = JsonSchema(type = listOf("number"), format = "float") + + val STRING = JsonSchema(type = listOf("string")) + + val BOOLEAN = JsonSchema(type = listOf("boolean")) + + val OBJECT = JsonSchema(type = listOf("object")) + + val LIST = JsonSchema(type = listOf("array")) + + val ANY = JsonSchema() + } + } + + object StringListSerializer : + JsonTransformingSerializer>(ListSerializer(String.Companion.serializer())) { + override fun transformSerialize(element: JsonElement): JsonElement { + require(element is JsonArray) + return element.singleOrNull() ?: element + } + } + + /** + * Super simplistic json schema generation. We should replace this with an appropriate library. + */ + @OptIn(ExperimentalSerializationApi::class) + @PublishedApi + internal fun SerialDescriptor.jsonSchema(): JsonSchema { + var schema = + when (this.kind) { + PrimitiveKind.BOOLEAN -> JsonSchema.BOOLEAN + PrimitiveKind.BYTE -> JsonSchema.INT + PrimitiveKind.CHAR -> JsonSchema.STRING + PrimitiveKind.DOUBLE -> JsonSchema.DOUBLE + PrimitiveKind.FLOAT -> JsonSchema.FLOAT + PrimitiveKind.INT -> JsonSchema.INT + PrimitiveKind.LONG -> JsonSchema.LONG + PrimitiveKind.SHORT -> JsonSchema.INT + PrimitiveKind.STRING -> JsonSchema.STRING + StructureKind.LIST -> JsonSchema.LIST + StructureKind.MAP -> JsonSchema.OBJECT + else -> JsonSchema.ANY + } + + // Add nullability constraint + if (this.isNullable && schema.type != null) { + schema = schema.copy(type = schema.type.plus("null")) + } + + return schema + } + } + + @InternalSerializationApi + @ExperimentalSerializationApi + /** Copy-pasted from ktor! */ + private fun SerializersModule.serializerForKtTypeInfo( + ktSerdeInfoInfo: KtTypeTag<*> + ): KSerializer<*> { + val module = this + return ktSerdeInfoInfo.kotlinType?.let { type -> + if (type.arguments.isEmpty()) { + null // fallback to a simple case because of + // https://github.com/Kotlin/kotlinx.serialization/issues/1870 + } else { + module.serializerOrNull(type) + } + } + ?: module.getContextual(ktSerdeInfoInfo.type)?.maybeNullable(ktSerdeInfoInfo) + ?: ktSerdeInfoInfo.type.serializer().maybeNullable(ktSerdeInfoInfo) + } + + private fun KSerializer.maybeNullable(typeInfo: KtTypeTag<*>): KSerializer<*> { + return if (typeInfo.kotlinType?.isMarkedNullable == true) this.nullable else this + } +} diff --git a/sdk-api-kotlin/src/main/kotlin/dev/restate/sdk/kotlin/serialization/api.kt b/sdk-api-kotlin/src/main/kotlin/dev/restate/sdk/kotlin/serialization/api.kt new file mode 100644 index 000000000..b44d87148 --- /dev/null +++ b/sdk-api-kotlin/src/main/kotlin/dev/restate/sdk/kotlin/serialization/api.kt @@ -0,0 +1,28 @@ +// Copyright (c) 2023 - Restate Software, Inc., Restate GmbH +// +// This file is part of the Restate Java SDK, +// which is released under the MIT license. +// +// You can find a copy of the license in file LICENSE in the root +// directory of this repository or package, or at +// https://github.com/restatedev/sdk-java/blob/main/LICENSE +package dev.restate.sdk.kotlin.serialization + +import dev.restate.serde.Serde +import dev.restate.serde.TypeTag +import kotlin.reflect.typeOf +import kotlinx.serialization.json.* +import kotlinx.serialization.serializer + +/** Creates a [Serde] implementation using the `kotlinx.serialization` json module. */ +inline fun jsonSerde(json: Json = Json.Default): Serde { + @Suppress("UNCHECKED_CAST") + return when (typeOf()) { + typeOf() -> KotlinSerializationSerdeFactory.UNIT as Serde + else -> KotlinSerializationSerdeFactory.jsonSerde(json, serializer()) + } +} + +/** Kotlin specific [TypeTag], using Kotlin's reified generics. */ +inline fun typeTag(): TypeTag = + KotlinSerializationSerdeFactory.KtTypeTag(T::class, typeOf()) diff --git a/sdk-api/build.gradle.kts b/sdk-api/build.gradle.kts index 4a3bab446..eb3044289 100644 --- a/sdk-api/build.gradle.kts +++ b/sdk-api/build.gradle.kts @@ -1,7 +1,6 @@ plugins { `java-conventions` `java-library` - `test-jar-conventions` `library-publishing-conventions` } @@ -11,18 +10,7 @@ dependencies { compileOnly(libs.jspecify) api(project(":sdk-common")) + api(project(":sdk-serde-jackson")) implementation(libs.log4j.api) - - implementation(libs.jackson.core) - - testImplementation(project(":sdk-core")) - testImplementation(libs.junit.jupiter) - testImplementation(libs.assertj) - testImplementation(libs.protobuf.java) - testImplementation(libs.log4j.core) - testImplementation(libs.mutiny) - - // Import test suites from sdk-core - testImplementation(project(":sdk-core", "testArchive")) } diff --git a/sdk-api/src/main/java/dev/restate/sdk/AnyAwaitable.java b/sdk-api/src/main/java/dev/restate/sdk/AnyAwaitable.java deleted file mode 100644 index 654a699cd..000000000 --- a/sdk-api/src/main/java/dev/restate/sdk/AnyAwaitable.java +++ /dev/null @@ -1,33 +0,0 @@ -// Copyright (c) 2023 - Restate Software, Inc., Restate GmbH -// -// This file is part of the Restate Java SDK, -// which is released under the MIT license. -// -// You can find a copy of the license in file LICENSE in the root -// directory of this repository or package, or at -// https://github.com/restatedev/sdk-java/blob/main/LICENSE -package dev.restate.sdk; - -import dev.restate.sdk.common.syscalls.Deferred; -import dev.restate.sdk.common.syscalls.Result; -import dev.restate.sdk.common.syscalls.Syscalls; -import java.util.List; - -public final class AnyAwaitable extends Awaitable.MappedAwaitable { - - @SuppressWarnings({"unchecked", "rawtypes"}) - AnyAwaitable(Syscalls syscalls, Deferred deferred, List> nested) { - super( - new SingleAwaitable<>(syscalls, deferred), - res -> - res.isSuccess() - ? (Result) nested.get(res.getValue()).awaitResult() - : (Result) res); - } - - /** Same as {@link #await()}, but returns the index. */ - public int awaitIndex() { - // This cast is safe b/c of the constructor - return (int) Util.blockOnResolve(this.syscalls, this.deferred()); - } -} diff --git a/sdk-api/src/main/java/dev/restate/sdk/Awaitable.java b/sdk-api/src/main/java/dev/restate/sdk/Awaitable.java index 9f0825920..674fd0b12 100644 --- a/sdk-api/src/main/java/dev/restate/sdk/Awaitable.java +++ b/sdk-api/src/main/java/dev/restate/sdk/Awaitable.java @@ -8,19 +8,18 @@ // https://github.com/restatedev/sdk-java/blob/main/LICENSE package dev.restate.sdk; -import dev.restate.sdk.common.AbortedExecutionException; -import dev.restate.sdk.common.TerminalException; -import dev.restate.sdk.common.function.ThrowingFunction; -import dev.restate.sdk.common.syscalls.Deferred; -import dev.restate.sdk.common.syscalls.Result; -import dev.restate.sdk.common.syscalls.Syscalls; +import dev.restate.common.function.ThrowingFunction; +import dev.restate.sdk.endpoint.definition.AsyncResult; +import dev.restate.sdk.endpoint.definition.HandlerContext; +import dev.restate.sdk.types.AbortedExecutionException; +import dev.restate.sdk.types.TerminalException; +import dev.restate.sdk.types.TimeoutException; import java.time.Duration; import java.util.ArrayList; import java.util.Arrays; import java.util.List; import java.util.concurrent.CompletableFuture; -import java.util.concurrent.TimeoutException; -import java.util.function.Function; +import java.util.concurrent.Executor; import java.util.stream.Collectors; /** @@ -37,15 +36,9 @@ */ public abstract class Awaitable { - protected final Syscalls syscalls; + protected abstract AsyncResult asyncResult(); - Awaitable(Syscalls syscalls) { - this.syscalls = syscalls; - } - - protected abstract Deferred deferred(); - - protected abstract Result awaitResult(); + protected abstract Executor serviceExecutor(); /** * Wait for the current awaitable to complete. Executing this method may trigger the suspension of @@ -57,42 +50,116 @@ public abstract class Awaitable { * @throws TerminalException if the awaitable is ready and contains a failure */ public final T await() throws TerminalException { - return Util.unwrapResult(this.awaitResult()); + return Util.awaitCompletableFuture(asyncResult().poll()); } /** * Same as {@link #await()}, but throws a {@link TimeoutException} if this {@link Awaitable} * doesn't complete before the provided {@code timeout}. */ - public final T await(Duration timeout) throws TerminalException, TimeoutException { - Deferred sleep = Util.blockOnSyscall(cb -> this.syscalls.sleep(timeout, cb)); - Awaitable sleepAwaitable = single(this.syscalls, sleep); - - int index = any(this, sleepAwaitable).awaitIndex(); + public final T await(Duration timeout) throws TerminalException { + return this.withTimeout(timeout).await(); + } - if (index == 1) { - throw new TimeoutException(); - } - // This await is no-op now - return this.await(); + /** + * @return an Awaitable that throws a {@link TerminalException} if this awaitable doesn't complete + * before the provided {@code timeout}. + */ + public final Awaitable withTimeout(Duration timeout) { + return any( + this, + fromAsyncResult( + Util.awaitCompletableFuture(asyncResult().ctx().timer(timeout, null)), + this.serviceExecutor())) + .mapWithoutExecutor( + i -> { + if (i == 1) { + throw new TimeoutException("Timed out waiting for awaitable after " + timeout); + } + return this.await(); + }); } /** Map the result of this {@link Awaitable}. */ public final Awaitable map(ThrowingFunction mapper) { - return new MappedAwaitable<>( - this, - result -> { - if (result.isSuccess()) { - return Result.success( - Util.executeMappingException(this.syscalls, mapper, result.getValue())); - } - //noinspection unchecked - return (Result) result; - }); + return fromAsyncResult( + asyncResult() + .map( + t -> + CompletableFuture.supplyAsync( + () -> { + try { + return mapper.apply(t); + } catch (Throwable e) { + Util.sneakyThrow(e); + return null; + } + }, + serviceExecutor()), + null), + this.serviceExecutor()); } - static Awaitable single(Syscalls syscalls, Deferred deferred) { - return new SingleAwaitable<>(syscalls, deferred); + public final Awaitable map( + ThrowingFunction successMapper, ThrowingFunction failureMapper) { + return fromAsyncResult( + asyncResult() + .map( + t -> + CompletableFuture.supplyAsync( + () -> { + try { + return successMapper.apply(t); + } catch (Throwable e) { + Util.sneakyThrow(e); + return null; + } + }, + serviceExecutor()), + t -> + CompletableFuture.supplyAsync( + () -> { + try { + return failureMapper.apply(t); + } catch (Throwable e) { + Util.sneakyThrow(e); + return null; + } + }, + serviceExecutor())), + this.serviceExecutor()); + } + + public final Awaitable mapFailure(ThrowingFunction failureMapper) { + return fromAsyncResult( + asyncResult() + .mapFailure( + t -> + CompletableFuture.supplyAsync( + () -> { + try { + return failureMapper.apply(t); + } catch (Throwable e) { + Util.sneakyThrow(e); + return null; + } + }, + serviceExecutor())), + this.serviceExecutor()); + } + + /** + * Map without executor switching. This is an optimization used only internally for operations + * safe to perform without switching executor. + */ + final Awaitable mapWithoutExecutor(ThrowingFunction mapper) { + return fromAsyncResult( + asyncResult().map(i -> CompletableFuture.completedFuture(mapper.apply(i)), null), + this.serviceExecutor()); + } + + static Awaitable fromAsyncResult(AsyncResult asyncResult, Executor serviceExecutor) { + return new SingleAwaitable<>(asyncResult, serviceExecutor); } /** @@ -101,17 +168,13 @@ static Awaitable single(Syscalls syscalls, Deferred deferred) { *

The behavior is the same as {@link * java.util.concurrent.CompletableFuture#anyOf(CompletableFuture[])}. */ - public static AnyAwaitable any(Awaitable first, Awaitable second, Awaitable... others) { + public static Awaitable any( + Awaitable first, Awaitable second, Awaitable... others) { List> awaitables = new ArrayList<>(2 + others.length); awaitables.add(first); awaitables.add(second); awaitables.addAll(Arrays.asList(others)); - - return new AnyAwaitable( - first.syscalls, - first.syscalls.createAnyDeferred( - awaitables.stream().map(Awaitable::deferred).collect(Collectors.toList())), - awaitables); + return any(awaitables); } /** @@ -122,18 +185,14 @@ public static AnyAwaitable any(Awaitable first, Awaitable second, Awaitabl *

The behavior is the same as {@link * java.util.concurrent.CompletableFuture#anyOf(CompletableFuture[])}. */ - public static AnyAwaitable any(List> awaitables) { + public static Awaitable any(List> awaitables) { if (awaitables.isEmpty()) { throw new IllegalArgumentException("Awaitable any doesn't support an empty list"); } - return new AnyAwaitable( - awaitables.get(0).syscalls, - awaitables - .get(0) - .syscalls - .createAnyDeferred( - awaitables.stream().map(Awaitable::deferred).collect(Collectors.toList())), - awaitables); + List> ars = + awaitables.stream().map(Awaitable::asyncResult).collect(Collectors.toList()); + HandlerContext ctx = ars.get(0).ctx(); + return fromAsyncResult(ctx.createAnyAsyncResult(ars), awaitables.get(0).serviceExecutor()); } /** @@ -144,12 +203,12 @@ public static AnyAwaitable any(List> awaitables) { */ public static Awaitable all( Awaitable first, Awaitable second, Awaitable... others) { - List> deferred = new ArrayList<>(2 + others.length); - deferred.add(first.deferred()); - deferred.add(second.deferred()); - Arrays.stream(others).map(Awaitable::deferred).forEach(deferred::add); + List> awaitables = new ArrayList<>(2 + others.length); + awaitables.add(first); + awaitables.add(second); + awaitables.addAll(Arrays.asList(others)); - return single(first.syscalls, first.syscalls.createAllDeferred(deferred)); + return all(awaitables); } /** @@ -165,68 +224,33 @@ public static Awaitable all(List> awaitables) { throw new IllegalArgumentException("Awaitable all doesn't support an empty list"); } if (awaitables.size() == 1) { - return awaitables.get(0).map(unused -> null); + return awaitables.get(0).mapWithoutExecutor(unused -> null); } else { - return single( - awaitables.get(0).syscalls, - awaitables - .get(0) - .syscalls - .createAllDeferred( - awaitables.stream().map(Awaitable::deferred).collect(Collectors.toList()))); - } - } - - static class SingleAwaitable extends Awaitable { - - private final Deferred deferred; - private Result result; - - SingleAwaitable(Syscalls syscalls, Deferred deferred) { - super(syscalls); - this.deferred = deferred; - } - - @Override - protected Deferred deferred() { - return this.deferred; - } - - @Override - protected Result awaitResult() { - if (!this.deferred.isCompleted()) { - Util.blockOnSyscall(cb -> syscalls.resolveDeferred(this.deferred, cb)); - } - if (this.result == null) { - this.result = this.deferred.toResult(); - } - return this.result; + List> ars = + awaitables.stream().map(Awaitable::asyncResult).collect(Collectors.toList()); + HandlerContext ctx = ars.get(0).ctx(); + return fromAsyncResult(ctx.createAllAsyncResult(ars), awaitables.get(0).serviceExecutor()); } } - static class MappedAwaitable extends Awaitable { + static final class SingleAwaitable extends Awaitable { - private final Awaitable inner; - private final Function, Result> mapper; - private Result mappedResult; + private final AsyncResult asyncResult; + private final Executor serviceExecutor; - MappedAwaitable(Awaitable inner, Function, Result> mapper) { - super(inner.syscalls); - this.inner = inner; - this.mapper = mapper; + SingleAwaitable(AsyncResult asyncResult, Executor serviceExecutor) { + this.asyncResult = asyncResult; + this.serviceExecutor = serviceExecutor; } @Override - protected Deferred deferred() { - return inner.deferred(); + protected AsyncResult asyncResult() { + return this.asyncResult; } @Override - public Result awaitResult() throws TerminalException { - if (mappedResult == null) { - this.mappedResult = this.mapper.apply(this.inner.awaitResult()); - } - return this.mappedResult; + protected Executor serviceExecutor() { + return serviceExecutor; } } } diff --git a/sdk-api/src/main/java/dev/restate/sdk/Awakeable.java b/sdk-api/src/main/java/dev/restate/sdk/Awakeable.java index 2000628ee..3ffd6ffcf 100644 --- a/sdk-api/src/main/java/dev/restate/sdk/Awakeable.java +++ b/sdk-api/src/main/java/dev/restate/sdk/Awakeable.java @@ -8,11 +8,11 @@ // https://github.com/restatedev/sdk-java/blob/main/LICENSE package dev.restate.sdk; -import dev.restate.sdk.common.Serde; -import dev.restate.sdk.common.syscalls.Deferred; -import dev.restate.sdk.common.syscalls.Result; -import dev.restate.sdk.common.syscalls.Syscalls; -import java.nio.ByteBuffer; +import dev.restate.common.Slice; +import dev.restate.sdk.endpoint.definition.AsyncResult; +import dev.restate.serde.Serde; +import java.util.concurrent.CompletableFuture; +import java.util.concurrent.Executor; /** * An {@link Awakeable} is a special type of {@link Awaitable} which can be arbitrarily completed by @@ -28,22 +28,18 @@ *

NOTE: This interface MUST NOT be accessed concurrently since it can lead to different * orderings of user actions, corrupting the execution of the invocation. */ -public final class Awakeable extends Awaitable.MappedAwaitable { +public final class Awakeable extends Awaitable { private final String identifier; + private final AsyncResult asyncResult; + private final Executor serviceExecutor; - Awakeable(Syscalls syscalls, Deferred deferred, Serde serde, String identifier) { - super( - Awaitable.single(syscalls, deferred), - res -> { - if (res.isSuccess()) { - return Result.success( - Util.deserializeWrappingException(syscalls, serde, res.getValue())); - } - //noinspection unchecked - return (Result) res; - }); + Awakeable( + AsyncResult asyncResult, Executor serviceExecutor, Serde serde, String identifier) { this.identifier = identifier; + this.asyncResult = + asyncResult.map(s -> CompletableFuture.completedFuture(serde.deserialize(s))); + this.serviceExecutor = serviceExecutor; } /** @@ -52,4 +48,14 @@ public final class Awakeable extends Awaitable.MappedAwaitable public String id() { return identifier; } + + @Override + protected AsyncResult asyncResult() { + return asyncResult; + } + + @Override + protected Executor serviceExecutor() { + return serviceExecutor; + } } diff --git a/sdk-api/src/main/java/dev/restate/sdk/AwakeableHandle.java b/sdk-api/src/main/java/dev/restate/sdk/AwakeableHandle.java index eb8a8ecfb..01dee46f9 100644 --- a/sdk-api/src/main/java/dev/restate/sdk/AwakeableHandle.java +++ b/sdk-api/src/main/java/dev/restate/sdk/AwakeableHandle.java @@ -8,7 +8,7 @@ // https://github.com/restatedev/sdk-java/blob/main/LICENSE package dev.restate.sdk; -import dev.restate.sdk.common.Serde; +import dev.restate.serde.TypeTag; /** This class represents a handle to an {@link Awakeable} created in another service. */ public interface AwakeableHandle { @@ -16,11 +16,22 @@ public interface AwakeableHandle { /** * Complete with success the {@link Awakeable}. * - * @param serde used to serialize the {@link Awakeable} result payload. + * @param typeTag used to serialize the {@link Awakeable} result payload. * @param payload the result payload. MUST NOT be null. * @see Awakeable */ - void resolve(Serde serde, T payload); + void resolve(TypeTag typeTag, T payload); + + /** + * Complete with success the {@link Awakeable}. + * + * @param clazz used to serialize the {@link Awakeable} result payload. + * @param payload the result payload. MUST NOT be null. + * @see Awakeable + */ + default void resolve(Class clazz, T payload) { + resolve(TypeTag.of(clazz), payload); + } /** * Complete with failure the {@link Awakeable}. diff --git a/sdk-api/src/main/java/dev/restate/sdk/CallAwaitable.java b/sdk-api/src/main/java/dev/restate/sdk/CallAwaitable.java new file mode 100644 index 000000000..145f5fce9 --- /dev/null +++ b/sdk-api/src/main/java/dev/restate/sdk/CallAwaitable.java @@ -0,0 +1,52 @@ +// Copyright (c) 2023 - Restate Software, Inc., Restate GmbH +// +// This file is part of the Restate Java SDK, +// which is released under the MIT license. +// +// You can find a copy of the license in file LICENSE in the root +// directory of this repository or package, or at +// https://github.com/restatedev/sdk-java/blob/main/LICENSE +package dev.restate.sdk; + +import dev.restate.sdk.endpoint.definition.AsyncResult; +import dev.restate.sdk.endpoint.definition.HandlerContext; +import java.util.concurrent.Executor; + +/** {@link Awaitable} returned by a call to another service. */ +public final class CallAwaitable extends Awaitable { + + private final HandlerContext context; + private final AsyncResult asyncResult; + private final Awaitable invocationIdAwaitable; + + CallAwaitable( + HandlerContext context, + AsyncResult callAsyncResult, + Awaitable invocationIdAwaitable) { + this.context = context; + this.asyncResult = callAsyncResult; + this.invocationIdAwaitable = invocationIdAwaitable; + } + + /** + * @return the unique identifier of this {@link CallAwaitable} instance. + */ + public String invocationId() { + return this.invocationIdAwaitable.await(); + } + + /** Cancel this invocation */ + public void cancel() { + Util.awaitCompletableFuture(context.cancelInvocation(invocationId())); + } + + @Override + protected AsyncResult asyncResult() { + return asyncResult; + } + + @Override + protected Executor serviceExecutor() { + return invocationIdAwaitable.serviceExecutor(); + } +} diff --git a/sdk-api/src/main/java/dev/restate/sdk/Context.java b/sdk-api/src/main/java/dev/restate/sdk/Context.java index 6161aaeab..dbde96f53 100644 --- a/sdk-api/src/main/java/dev/restate/sdk/Context.java +++ b/sdk-api/src/main/java/dev/restate/sdk/Context.java @@ -8,9 +8,16 @@ // https://github.com/restatedev/sdk-java/blob/main/LICENSE package dev.restate.sdk; -import dev.restate.sdk.common.*; -import dev.restate.sdk.common.function.ThrowingRunnable; -import dev.restate.sdk.common.function.ThrowingSupplier; +import dev.restate.common.Request; +import dev.restate.common.Slice; +import dev.restate.common.function.ThrowingRunnable; +import dev.restate.common.function.ThrowingSupplier; +import dev.restate.sdk.types.AbortedExecutionException; +import dev.restate.sdk.types.HandlerRequest; +import dev.restate.sdk.types.RetryPolicy; +import dev.restate.sdk.types.TerminalException; +import dev.restate.serde.Serde; +import dev.restate.serde.TypeTag; import java.time.Duration; /** @@ -27,54 +34,50 @@ */ public interface Context { - Request request(); + HandlerRequest request(); /** * Invoke another Restate service method. * - * @param target the address of the callee - * @param inputSerde Input serde - * @param outputSerde Output serde - * @param parameter the invocation request parameter. + * @param request request * @return an {@link Awaitable} that wraps the Restate service method result. */ - Awaitable call(Target target, Serde inputSerde, Serde outputSerde, T parameter); + CallAwaitable call(Request request); - /** Like {@link #call(Target, Serde, Serde, Object)} with raw input/output. */ - default Awaitable call(Target target, byte[] parameter) { - return call(target, Serde.RAW, Serde.RAW, parameter); + /** Like {@link #call(Request)} */ + default CallAwaitable call(Request.Builder requestBuilder) { + return call(requestBuilder.build()); } /** * Invoke another Restate service without waiting for the response. * - * @param target the address of the callee - * @param inputSerde Input serde - * @param parameter the invocation request parameter. + * @param request request + * @return an {@link InvocationHandle} that can be used to retrieve the invocation id, cancel the + * invocation, attach to its result. */ - void send(Target target, Serde inputSerde, T parameter); + InvocationHandle send(Request request); - /** Like {@link #send(Target, Serde, Object)} with bytes input. */ - default void send(Target target, byte[] parameter) { - send(target, Serde.RAW, parameter); + /** Like {@link #send(Request)} */ + default InvocationHandle send(Request.Builder requestBuilder) { + return send(requestBuilder.asSend()); } + InvocationHandle invocationHandle(String invocationId, TypeTag responseTypeTag); + /** - * Invoke another Restate service without waiting for the response after the provided {@code - * delay} has elapsed. - * - *

This method returns immediately, as the timer is executed and awaited on Restate. + * Get an {@link InvocationHandle} for an already existing invocation. This will let you interact + * with a running invocation, for example to cancel it or retrieve its result. * - * @param target the address of the callee - * @param inputSerde Input serde - * @param parameter the invocation request parameter. - * @param delay time to wait before executing the call. + * @param invocationId The invocation to interact with. + * @param responseClazz The response class. */ - void send(Target target, Serde inputSerde, T parameter, Duration delay); + default InvocationHandle invocationHandle(String invocationId, Class responseClazz) { + return invocationHandle(invocationId, TypeTag.of(responseClazz)); + } - /** Like {@link #send(Target, Serde, Object, Duration)} with bytes input. */ - default void send(Target target, byte[] parameter, Duration delay) { - send(target, Serde.RAW, parameter, delay); + default InvocationHandle invocationHandle(String invocationId) { + return invocationHandle(invocationId, Serde.SLICE); } /** @@ -92,7 +95,48 @@ default void sleep(Duration duration) { * * @param duration for which to sleep. */ - Awaitable timer(Duration duration); + default Awaitable timer(Duration duration) { + return timer(null, duration); + } + + /** + * Causes the start of a timer for the given duration. You can await on the timer end by invoking + * {@link Awaitable#await()}. + * + * @param name name used for observability + * @param duration for which to sleep. + */ + Awaitable timer(String name, Duration duration); + + /** + * Like {@link #run(String, TypeTag, ThrowingSupplier)}, but using a custom retry policy. + * + *

When a retry policy is not specified, the {@code run} will be retried using the Restate invoker retry policy, + * which by default retries indefinitely. + * + * @see RetryPolicy + */ + default T run( + String name, TypeTag typeTag, RetryPolicy retryPolicy, ThrowingSupplier action) + throws TerminalException { + return runAsync(name, typeTag, retryPolicy, action).await(); + } + + /** + * Like {@link #run(String, Class, ThrowingSupplier)}, but using a custom retry policy. + * + *

When a retry policy is not specified, the {@code run} will be retried using the Restate invoker retry policy, + * which by default retries indefinitely. + * + * @see RetryPolicy + */ + default T run( + String name, Class clazz, RetryPolicy retryPolicy, ThrowingSupplier action) + throws TerminalException { + return run(name, TypeTag.of(clazz), retryPolicy, action); + } /** * Execute a non-deterministic closure, recording the result value in the journal. The result @@ -105,7 +149,7 @@ default void sleep(Duration duration) { * *

The closure should tolerate retries, that is Restate might re-execute the closure multiple * times until it records a result. You can control and limit the amount of retries using {@link - * #run(String, Serde, RetryPolicy, ThrowingSupplier)}. + * #run(String, TypeTag, RetryPolicy, ThrowingSupplier)}. * *

Error handling: Errors occurring within this closure won't be propagated to the * caller, unless they are {@link TerminalException}. Consider the following code: @@ -136,27 +180,75 @@ default void sleep(Duration duration) { * TerminalException}. * * @param name name of the side effect. - * @param serde the type tag of the return value, used to serialize/deserialize it. + * @param typeTag the type tag of the return value, used to serialize/deserialize it. * @param action closure to execute. * @param type of the return value. * @return value of the run operation. */ - default T run(String name, Serde serde, ThrowingSupplier action) + default T run(String name, TypeTag typeTag, ThrowingSupplier action) throws TerminalException { - return run(name, serde, null, action); + return run(name, typeTag, null, action); } /** - * Like {@link #run(String, Serde, ThrowingSupplier)}, but using a custom retry policy. + * Execute a non-deterministic closure, recording the result value in the journal. The result + * value will be re-played in case of re-invocation (e.g. because of failure recovery or + * suspension point) without re-executing the closure. Use this feature if you want to perform + * non-deterministic operations. * - *

When a retry policy is not specified, the {@code run} will be retried using the Restate invoker retry policy, - * which by default retries indefinitely. + *

You can name this closure using the {@code name} parameter. This name will be available in + * the observability tools. * - * @see RetryPolicy + *

The closure should tolerate retries, that is Restate might re-execute the closure multiple + * times until it records a result. You can control and limit the amount of retries using {@link + * #run(String, Class, RetryPolicy, ThrowingSupplier)}. + * + *

Error handling: Errors occurring within this closure won't be propagated to the + * caller, unless they are {@link TerminalException}. Consider the following code: + * + *

{@code
+   * // Bad usage of try-catch outside the run
+   * try {
+   *     ctx.run(() -> {
+   *         throw new IllegalStateException();
+   *     }).await();
+   * } catch (IllegalStateException e) {
+   *     // This will never be executed,
+   *     // but the error will be retried by Restate,
+   *     // following the invocation retry policy.
+   * }
+   *
+   * // Good usage of try-catch outside the run
+   * try {
+   *     ctx.run(() -> {
+   *         throw new TerminalException("my error");
+   *     }).await();
+   * } catch (TerminalException e) {
+   *     // This is invoked
+   * }
+   * }
+ * + * To propagate run failures to the call-site, make sure to wrap them in {@link + * TerminalException}. + * + * @param name name of the side effect. + * @param clazz the class of the return value, used to serialize/deserialize it. + * @param action closure to execute. + * @param type of the return value. + * @return value of the run operation. */ - T run(String name, Serde serde, RetryPolicy retryPolicy, ThrowingSupplier action) - throws TerminalException; + default T run(String name, Class clazz, ThrowingSupplier action) + throws TerminalException { + return run(name, TypeTag.of(clazz), action); + } + + default T run(TypeTag typeTag, ThrowingSupplier action) throws TerminalException { + return run(null, typeTag, null, action); + } + + default T run(Class clazz, ThrowingSupplier action) throws TerminalException { + return run(TypeTag.of(clazz), action); + } /** * Like {@link #run(String, ThrowingRunnable)}, but using a custom retry policy. @@ -179,8 +271,18 @@ default void run(String name, RetryPolicy retryPolicy, ThrowingRunnable runnable }); } + /** Like {@link #run(String, Class, ThrowingSupplier)} without output. */ + default void run(String name, ThrowingRunnable runnable) throws TerminalException { + run(name, null, runnable); + } + + /** Like {@link #run(Class, ThrowingSupplier)} without output. */ + default void run(ThrowingRunnable runnable) throws TerminalException { + run(null, runnable); + } + /** - * Like {@link #run(Serde, ThrowingSupplier)}, but using a custom retry policy. + * Like {@link #runAsync(String, TypeTag, ThrowingSupplier)}, but using a custom retry policy. * *

When a retry policy is not specified, the {@code run} will be retried using the Restate invoker retry policy, @@ -188,13 +290,12 @@ default void run(String name, RetryPolicy retryPolicy, ThrowingRunnable runnable * * @see RetryPolicy */ - default T run(Serde serde, RetryPolicy retryPolicy, ThrowingSupplier action) - throws TerminalException { - return run(null, serde, retryPolicy, action); - } + Awaitable runAsync( + String name, TypeTag typeTag, RetryPolicy retryPolicy, ThrowingSupplier action) + throws TerminalException; /** - * Like {@link #run(ThrowingRunnable)}, but using a custom retry policy. + * Like {@link #runAsync(String, Class, ThrowingSupplier)}, but using a custom retry policy. * *

When a retry policy is not specified, the {@code run} will be retried using the Restate invoker retry policy, @@ -202,29 +303,173 @@ default T run(Serde serde, RetryPolicy retryPolicy, ThrowingSupplier a * * @see RetryPolicy */ - default void run(RetryPolicy retryPolicy, ThrowingRunnable runnable) throws TerminalException { - run(null, retryPolicy, runnable); + default Awaitable runAsync( + String name, Class clazz, RetryPolicy retryPolicy, ThrowingSupplier action) + throws TerminalException { + return runAsync(name, TypeTag.of(clazz), retryPolicy, action); } - /** Like {@link #run(String, Serde, ThrowingSupplier)}, but without returning a value. */ - default void run(String name, ThrowingRunnable runnable) throws TerminalException { - run( + /** + * Execute a non-deterministic closure, recording the result value in the journal. The result + * value will be re-played in case of re-invocation (e.g. because of failure recovery or + * suspension point) without re-executing the closure. Use this feature if you want to perform + * non-deterministic operations. + * + *

You can name this closure using the {@code name} parameter. This name will be available in + * the observability tools. + * + *

The closure should tolerate retries, that is Restate might re-execute the closure multiple + * times until it records a result. You can control and limit the amount of retries using {@link + * #runAsync(String, TypeTag, RetryPolicy, ThrowingSupplier)}. + * + *

Error handling: Errors occurring within this closure won't be propagated to the + * caller, unless they are {@link TerminalException}. Consider the following code: + * + *

{@code
+   * // Bad usage of try-catch outside the run
+   * try {
+   *     ctx.runAsync(() -> {
+   *         throw new IllegalStateException();
+   *     }).await();
+   * } catch (IllegalStateException e) {
+   *     // This will never be executed,
+   *     // but the error will be retried by Restate,
+   *     // following the invocation retry policy.
+   * }
+   *
+   * // Good usage of try-catch outside the run
+   * try {
+   *     ctx.runAsync(() -> {
+   *         throw new TerminalException("my error");
+   *     }).await();
+   * } catch (TerminalException e) {
+   *     // This is invoked
+   * }
+   * }
+ * + * To propagate run failures to the call-site, make sure to wrap them in {@link + * TerminalException}. + * + * @param name name of the side effect. + * @param typeTag the type tag of the return value, used to serialize/deserialize it. + * @param action closure to execute. + * @param type of the return value. + * @return value of the run operation. + */ + default Awaitable runAsync(String name, TypeTag typeTag, ThrowingSupplier action) + throws TerminalException { + return runAsync(name, typeTag, null, action); + } + + /** + * Execute a non-deterministic closure, recording the result value in the journal. The result + * value will be re-played in case of re-invocation (e.g. because of failure recovery or + * suspension point) without re-executing the closure. Use this feature if you want to perform + * non-deterministic operations. + * + *

You can name this closure using the {@code name} parameter. This name will be available in + * the observability tools. + * + *

The closure should tolerate retries, that is Restate might re-execute the closure multiple + * times until it records a result. You can control and limit the amount of retries using {@link + * #runAsync(String, Class, RetryPolicy, ThrowingSupplier)}. + * + *

Error handling: Errors occurring within this closure won't be propagated to the + * caller, unless they are {@link TerminalException}. Consider the following code: + * + *

{@code
+   * // Bad usage of try-catch outside the run
+   * try {
+   *     ctx.runAsync(() -> {
+   *         throw new IllegalStateException();
+   *     }).await();
+   * } catch (IllegalStateException e) {
+   *     // This will never be executed,
+   *     // but the error will be retried by Restate,
+   *     // following the invocation retry policy.
+   * }
+   *
+   * // Good usage of try-catch outside the run
+   * try {
+   *     ctx.runAsync(() -> {
+   *         throw new TerminalException("my error");
+   *     }).await();
+   * } catch (TerminalException e) {
+   *     // This is invoked
+   * }
+   * }
+ * + * To propagate run failures to the call-site, make sure to wrap them in {@link + * TerminalException}. + * + * @param name name of the side effect. + * @param clazz the class of the return value, used to serialize/deserialize it. + * @param action closure to execute. + * @param type of the return value. + * @return value of the run operation. + */ + default Awaitable runAsync(String name, Class clazz, ThrowingSupplier action) + throws TerminalException { + return runAsync(name, TypeTag.of(clazz), action); + } + + default Awaitable runAsync(TypeTag typeTag, ThrowingSupplier action) + throws TerminalException { + return runAsync(null, typeTag, null, action); + } + + default Awaitable runAsync(Class clazz, ThrowingSupplier action) + throws TerminalException { + return runAsync(TypeTag.of(clazz), action); + } + + /** + * Like {@link #runAsync(String, ThrowingRunnable)}, but using a custom retry policy. + * + *

When a retry policy is not specified, the {@code run} will be retried using the Restate invoker retry policy, + * which by default retries indefinitely. + * + * @see RetryPolicy + */ + default Awaitable runAsync(String name, RetryPolicy retryPolicy, ThrowingRunnable runnable) + throws TerminalException { + return runAsync( name, Serde.VOID, + retryPolicy, () -> { runnable.run(); return null; }); } - /** Like {@link #run(String, Serde, ThrowingSupplier)}, but without a name. */ - default T run(Serde serde, ThrowingSupplier action) throws TerminalException { - return run(null, serde, action); + /** Like {@link #runAsync(String, Class, ThrowingSupplier)} without output. */ + default Awaitable runAsync(String name, ThrowingRunnable runnable) + throws TerminalException { + return runAsync(name, null, runnable); } - /** Like {@link #run(String, ThrowingRunnable)}, but without a name. */ - default void run(ThrowingRunnable runnable) throws TerminalException { - run((String) null, runnable); + /** Like {@link #runAsync(Class, ThrowingSupplier)} without output. */ + default Awaitable runAsync(ThrowingRunnable runnable) throws TerminalException { + return runAsync(null, runnable); + } + + /** + * Create an {@link Awakeable}, addressable through {@link Awakeable#id()}. + * + *

You can use this feature to implement external asynchronous systems interactions, for + * example you can send a Kafka record including the {@link Awakeable#id()}, and then let another + * service consume from Kafka the responses of given external system interaction by using {@link + * #awakeableHandle(String)}. + * + * @param clazz the response type to use for deserializing the {@link Awakeable} result. When + * using generic types, use {@link #awakeable(TypeTag)} instead. + * @return the {@link Awakeable} to await on. + * @see Awakeable + */ + default Awakeable awakeable(Class clazz) { + return awakeable(TypeTag.of(clazz)); } /** @@ -235,15 +480,15 @@ default void run(ThrowingRunnable runnable) throws TerminalException { * service consume from Kafka the responses of given external system interaction by using {@link * #awakeableHandle(String)}. * - * @param serde the response type tag to use for deserializing the {@link Awakeable} result. + * @param typeTag the response type tag to use for deserializing the {@link Awakeable} result. * @return the {@link Awakeable} to await on. * @see Awakeable */ - Awakeable awakeable(Serde serde); + Awakeable awakeable(TypeTag typeTag); /** * Create a new {@link AwakeableHandle} for the provided identifier. You can use it to {@link - * AwakeableHandle#resolve(Serde, Object)} or {@link AwakeableHandle#reject(String)} the linked + * AwakeableHandle#resolve(TypeTag, Object)} or {@link AwakeableHandle#reject(String)} the linked * {@link Awakeable}. * * @see Awakeable diff --git a/sdk-api/src/main/java/dev/restate/sdk/ContextImpl.java b/sdk-api/src/main/java/dev/restate/sdk/ContextImpl.java index 3fb4ee16b..d5ad45fb5 100644 --- a/sdk-api/src/main/java/dev/restate/sdk/ContextImpl.java +++ b/sdk-api/src/main/java/dev/restate/sdk/ContextImpl.java @@ -8,208 +8,235 @@ // https://github.com/restatedev/sdk-java/blob/main/LICENSE package dev.restate.sdk; -import dev.restate.sdk.common.*; -import dev.restate.sdk.common.function.ThrowingSupplier; -import dev.restate.sdk.common.syscalls.Deferred; -import dev.restate.sdk.common.syscalls.EnterSideEffectSyscallCallback; -import dev.restate.sdk.common.syscalls.ExitSideEffectSyscallCallback; -import dev.restate.sdk.common.syscalls.Syscalls; -import java.nio.ByteBuffer; +import dev.restate.common.*; +import dev.restate.common.function.ThrowingSupplier; +import dev.restate.sdk.endpoint.definition.AsyncResult; +import dev.restate.sdk.endpoint.definition.HandlerContext; +import dev.restate.sdk.types.DurablePromiseKey; +import dev.restate.sdk.types.HandlerRequest; +import dev.restate.sdk.types.RetryPolicy; +import dev.restate.sdk.types.StateKey; +import dev.restate.sdk.types.TerminalException; +import dev.restate.serde.Serde; +import dev.restate.serde.SerdeFactory; +import dev.restate.serde.TypeTag; import java.time.Duration; import java.util.Collection; -import java.util.Map; import java.util.Optional; import java.util.concurrent.CompletableFuture; +import java.util.concurrent.Executor; import org.jspecify.annotations.NonNull; -import org.jspecify.annotations.Nullable; class ContextImpl implements ObjectContext, WorkflowContext { - final Syscalls syscalls; + private final HandlerContext handlerContext; + private final Executor serviceExecutor; + private final SerdeFactory serdeFactory; - ContextImpl(Syscalls syscalls) { - this.syscalls = syscalls; + ContextImpl(HandlerContext handlerContext, Executor serviceExecutor, SerdeFactory serdeFactory) { + this.handlerContext = handlerContext; + this.serviceExecutor = serviceExecutor; + this.serdeFactory = serdeFactory; } @Override public String key() { - return syscalls.objectKey(); + return handlerContext.objectKey(); } @Override - public Request request() { - return syscalls.request(); + public HandlerRequest request() { + return handlerContext.request(); } @Override public Optional get(StateKey key) { - Deferred deferred = Util.blockOnSyscall(cb -> syscalls.get(key.name(), cb)); - - if (!deferred.isCompleted()) { - Util.blockOnSyscall(cb -> syscalls.resolveDeferred(deferred, cb)); - } - - return Util.unwrapOptionalReadyResult(deferred.toResult()) - .map(bs -> Util.deserializeWrappingException(syscalls, key.serde(), bs)); + return Awaitable.fromAsyncResult( + Util.awaitCompletableFuture(handlerContext.get(key.name())), serviceExecutor) + .mapWithoutExecutor(opt -> opt.map(serdeFactory.create(key.serdeInfo())::deserialize)) + .await(); } @Override public Collection stateKeys() { - Deferred> deferred = Util.blockOnSyscall(syscalls::getKeys); - - if (!deferred.isCompleted()) { - Util.blockOnSyscall(cb -> syscalls.resolveDeferred(deferred, cb)); - } - - return Util.unwrapResult(deferred.toResult()); + return Util.awaitCompletableFuture( + Util.awaitCompletableFuture(handlerContext.getKeys()).poll()); } @Override public void clear(StateKey key) { - Util.blockOnSyscall(cb -> syscalls.clear(key.name(), cb)); + Util.awaitCompletableFuture(handlerContext.clear(key.name())); } @Override public void clearAll() { - Util.blockOnSyscall(syscalls::clearAll); + Util.awaitCompletableFuture(handlerContext.clearAll()); } @Override public void set(StateKey key, @NonNull T value) { - Util.blockOnSyscall( - cb -> - syscalls.set( - key.name(), Util.serializeWrappingException(syscalls, key.serde(), value), cb)); + Util.awaitCompletableFuture( + handlerContext.set( + key.name(), + Util.executeOrFail( + handlerContext, serdeFactory.create(key.serdeInfo())::serialize, value))); } @Override - public Awaitable timer(Duration duration) { - Deferred result = Util.blockOnSyscall(cb -> syscalls.sleep(duration, cb)); - return Awaitable.single(syscalls, result); + public Awaitable timer(String name, Duration duration) { + return Awaitable.fromAsyncResult( + Util.awaitCompletableFuture(handlerContext.timer(duration, name)), serviceExecutor); } @Override - public Awaitable call( - Target target, Serde inputSerde, Serde outputSerde, T parameter) { - ByteBuffer input = Util.serializeWrappingException(syscalls, inputSerde, parameter); - Deferred result = Util.blockOnSyscall(cb -> syscalls.call(target, input, cb)); - return Awaitable.single(syscalls, result) - .map(bs -> Util.deserializeWrappingException(syscalls, outputSerde, bs)); + public CallAwaitable call(Request request) { + Slice input = + Util.executeOrFail( + handlerContext, + serdeFactory.create(request.requestTypeTag())::serialize, + request.request()); + HandlerContext.CallResult result = + Util.awaitCompletableFuture( + handlerContext.call( + request.target(), input, request.idempotencyKey(), request.headers().entrySet())); + + return new CallAwaitable<>( + handlerContext, + result + .callAsyncResult() + .map( + s -> + CompletableFuture.completedFuture( + serdeFactory.create(request.responseTypeTag()).deserialize(s))), + Awaitable.fromAsyncResult(result.invocationIdAsyncResult(), serviceExecutor)); } @Override - public void send(Target target, Serde inputSerde, T parameter) { - ByteBuffer input = Util.serializeWrappingException(syscalls, inputSerde, parameter); - Util.blockOnSyscall(cb -> syscalls.send(target, input, null, cb)); + public InvocationHandle send(Request request) { + Slice input = + Util.executeOrFail( + handlerContext, + serdeFactory.create(request.requestTypeTag())::serialize, + request.request()); + + var invocationIdAwaitable = + Awaitable.fromAsyncResult( + Util.awaitCompletableFuture( + handlerContext.send( + request.target(), + input, + request.idempotencyKey(), + request.headers().entrySet(), + (request instanceof SendRequest sendRequest) + ? sendRequest.delay() + : null)), + serviceExecutor); + + return new BaseInvocationHandle<>( + Util.executeOrFail(handlerContext, () -> serdeFactory.create(request.responseTypeTag()))) { + @Override + public String invocationId() { + return invocationIdAwaitable.await(); + } + }; } @Override - public void send(Target target, Serde inputSerde, T parameter, Duration delay) { - ByteBuffer input = Util.serializeWrappingException(syscalls, inputSerde, parameter); - Util.blockOnSyscall(cb -> syscalls.send(target, input, delay, cb)); + public InvocationHandle invocationHandle(String invocationId, TypeTag responseTypeTag) { + return new BaseInvocationHandle<>( + Util.executeOrFail(handlerContext, () -> serdeFactory.create(responseTypeTag))) { + @Override + public String invocationId() { + return invocationId; + } + }; } - @Override - public T run( - String name, Serde serde, RetryPolicy retryPolicy, ThrowingSupplier action) { - CompletableFuture> enterFut = new CompletableFuture<>(); - syscalls.enterSideEffectBlock( - name, - new EnterSideEffectSyscallCallback() { - @Override - public void onNotExecuted() { - enterFut.complete(new CompletableFuture<>()); - } - - @Override - public void onSuccess(ByteBuffer result) { - enterFut.complete(CompletableFuture.completedFuture(result)); - } - - @Override - public void onFailure(TerminalException t) { - enterFut.complete(CompletableFuture.failedFuture(t)); - } - - @Override - public void onCancel(Throwable t) { - enterFut.cancel(true); - } - }); - - // If a failure was stored, it's simply thrown here - CompletableFuture exitFut = Util.awaitCompletableFuture(enterFut); - if (exitFut.isDone()) { - // We already have a result, we don't need to execute the action - return Util.deserializeWrappingException( - syscalls, serde, Util.awaitCompletableFuture(exitFut)); + abstract class BaseInvocationHandle implements InvocationHandle { + private final Serde responseSerde; + + BaseInvocationHandle(Serde responseSerde) { + this.responseSerde = responseSerde; } - ExitSideEffectSyscallCallback exitCallback = - new ExitSideEffectSyscallCallback() { - @Override - public void onSuccess(ByteBuffer result) { - exitFut.complete(result); - } - - @Override - public void onFailure(TerminalException t) { - exitFut.completeExceptionally(t); - } - - @Override - public void onCancel(@Nullable Throwable t) { - exitFut.cancel(true); - } - }; - - T res = null; - Throwable failure = null; - try { - res = action.get(); - } catch (Throwable e) { - failure = e; + @Override + public void cancel() { + Util.awaitCompletableFuture(handlerContext.cancelInvocation(invocationId())); } - if (failure != null) { - syscalls.exitSideEffectBlockWithException(failure, retryPolicy, exitCallback); - } else { - syscalls.exitSideEffectBlock( - Util.serializeWrappingException(syscalls, serde, res), exitCallback); + @Override + public Awaitable attach() { + return Awaitable.fromAsyncResult( + Util.awaitCompletableFuture(handlerContext.attachInvocation(invocationId())) + .map(s -> CompletableFuture.completedFuture(responseSerde.deserialize(s))), + serviceExecutor); } - return Util.deserializeWrappingException(syscalls, serde, Util.awaitCompletableFuture(exitFut)); + @Override + public Output getOutput() { + return Awaitable.fromAsyncResult( + Util.awaitCompletableFuture(handlerContext.getInvocationOutput(invocationId())) + .map(o -> CompletableFuture.completedFuture(o.map(responseSerde::deserialize))), + serviceExecutor) + .await(); + } } @Override - public Awakeable awakeable(Serde serde) throws TerminalException { - // Retrieve the awakeable - Map.Entry> awakeable = Util.blockOnSyscall(syscalls::awakeable); + public Awaitable runAsync( + String name, TypeTag typeTag, RetryPolicy retryPolicy, ThrowingSupplier action) { + Serde serde = serdeFactory.create(typeTag); + return Awaitable.fromAsyncResult( + Util.awaitCompletableFuture( + handlerContext.submitRun( + name, + runCompleter -> + serviceExecutor.execute( + () -> { + Slice result; + try { + result = serde.serialize(action.get()); + } catch (Throwable e) { + runCompleter.proposeFailure(e, retryPolicy); + return; + } + runCompleter.proposeSuccess(result); + }))), + serviceExecutor) + .mapWithoutExecutor(serde::deserialize); + } - return new Awakeable<>(syscalls, awakeable.getValue(), serde, awakeable.getKey()); + @Override + public Awakeable awakeable(TypeTag typeTag) throws TerminalException { + Serde serde = serdeFactory.create(typeTag); + // Retrieve the awakeable + HandlerContext.Awakeable awakeable = Util.awaitCompletableFuture(handlerContext.awakeable()); + return new Awakeable<>(awakeable.asyncResult(), serviceExecutor, serde, awakeable.id()); } @Override public AwakeableHandle awakeableHandle(String id) { return new AwakeableHandle() { @Override - public void resolve(Serde serde, @NonNull T payload) { - Util.blockOnSyscall( - cb -> - syscalls.resolveAwakeable( - id, Util.serializeWrappingException(syscalls, serde, payload), cb)); + public void resolve(TypeTag serde, @NonNull T payload) { + Util.awaitCompletableFuture( + handlerContext.resolveAwakeable( + id, + Util.executeOrFail( + handlerContext, serdeFactory.create(serde)::serialize, payload))); } @Override public void reject(String reason) { - Util.blockOnSyscall(cb -> syscalls.rejectAwakeable(id, reason, cb)); + Util.awaitCompletableFuture( + handlerContext.rejectAwakeable(id, new TerminalException(reason))); } }; } @Override public RestateRandom random() { - return new RestateRandom(this.request().invocationId().toRandomSeed(), this.syscalls); + return new RestateRandom(this.request().invocationId().toRandomSeed()); } @Override @@ -217,22 +244,16 @@ public DurablePromise promise(DurablePromiseKey key) { return new DurablePromise<>() { @Override public Awaitable awaitable() { - Deferred result = Util.blockOnSyscall(cb -> syscalls.promise(key.name(), cb)); - return Awaitable.single(syscalls, result) - .map(bs -> Util.deserializeWrappingException(syscalls, key.serde(), bs)); + AsyncResult result = Util.awaitCompletableFuture(handlerContext.promise(key.name())); + return Awaitable.fromAsyncResult(result, serviceExecutor) + .mapWithoutExecutor(serdeFactory.create(key.serdeInfo())::deserialize); } @Override public Output peek() { - Deferred deferred = - Util.blockOnSyscall(cb -> syscalls.peekPromise(key.name(), cb)); - - if (!deferred.isCompleted()) { - Util.blockOnSyscall(cb -> syscalls.resolveDeferred(deferred, cb)); - } - - return Util.unwrapOutputReadyResult(deferred.toResult()) - .map(bs -> Util.deserializeWrappingException(syscalls, key.serde(), bs)); + return Util.awaitCompletableFuture( + Util.awaitCompletableFuture(handlerContext.peekPromise(key.name())).poll()) + .map(serdeFactory.create(key.serdeInfo())::deserialize); } }; } @@ -242,31 +263,23 @@ public DurablePromiseHandle promiseHandle(DurablePromiseKey key) { return new DurablePromiseHandle<>() { @Override public void resolve(T payload) throws IllegalStateException { - Deferred deferred = - Util.blockOnSyscall( - cb -> - syscalls.resolvePromise( + Util.awaitCompletableFuture( + Util.awaitCompletableFuture( + handlerContext.resolvePromise( key.name(), - Util.serializeWrappingException(syscalls, key.serde(), payload), - cb)); - - if (!deferred.isCompleted()) { - Util.blockOnSyscall(cb -> syscalls.resolveDeferred(deferred, cb)); - } - - Util.unwrapResult(deferred.toResult()); + Util.executeOrFail( + handlerContext, + serdeFactory.create(key.serdeInfo())::serialize, + payload))) + .poll()); } @Override public void reject(String reason) throws IllegalStateException { - Deferred deferred = - Util.blockOnSyscall(cb -> syscalls.rejectPromise(key.name(), reason, cb)); - - if (!deferred.isCompleted()) { - Util.blockOnSyscall(cb -> syscalls.resolveDeferred(deferred, cb)); - } - - Util.unwrapResult(deferred.toResult()); + Util.awaitCompletableFuture( + Util.awaitCompletableFuture( + handlerContext.rejectPromise(key.name(), new TerminalException(reason))) + .poll()); } }; } diff --git a/sdk-api/src/main/java/dev/restate/sdk/DurablePromise.java b/sdk-api/src/main/java/dev/restate/sdk/DurablePromise.java index 5480fa904..285454587 100644 --- a/sdk-api/src/main/java/dev/restate/sdk/DurablePromise.java +++ b/sdk-api/src/main/java/dev/restate/sdk/DurablePromise.java @@ -8,8 +8,8 @@ // https://github.com/restatedev/sdk-java/blob/main/LICENSE package dev.restate.sdk; -import dev.restate.sdk.common.DurablePromiseKey; -import dev.restate.sdk.common.Output; +import dev.restate.common.Output; +import dev.restate.sdk.types.DurablePromiseKey; /** * A {@link DurablePromise} is a durable, distributed version of a {@link diff --git a/sdk-api/src/main/java/dev/restate/sdk/HandlerRunner.java b/sdk-api/src/main/java/dev/restate/sdk/HandlerRunner.java index f19e2dcd1..b93b40353 100644 --- a/sdk-api/src/main/java/dev/restate/sdk/HandlerRunner.java +++ b/sdk-api/src/main/java/dev/restate/sdk/HandlerRunner.java @@ -8,70 +8,73 @@ // https://github.com/restatedev/sdk-java/blob/main/LICENSE package dev.restate.sdk; -import dev.restate.sdk.common.TerminalException; -import dev.restate.sdk.common.function.ThrowingBiConsumer; -import dev.restate.sdk.common.function.ThrowingBiFunction; -import dev.restate.sdk.common.function.ThrowingConsumer; -import dev.restate.sdk.common.function.ThrowingFunction; -import dev.restate.sdk.common.syscalls.HandlerSpecification; -import dev.restate.sdk.common.syscalls.SyscallCallback; -import dev.restate.sdk.common.syscalls.Syscalls; +import dev.restate.common.Slice; +import dev.restate.common.function.ThrowingBiConsumer; +import dev.restate.common.function.ThrowingBiFunction; +import dev.restate.common.function.ThrowingConsumer; +import dev.restate.common.function.ThrowingFunction; +import dev.restate.sdk.endpoint.definition.HandlerContext; +import dev.restate.sdk.types.TerminalException; +import dev.restate.serde.Serde; +import dev.restate.serde.SerdeFactory; import io.opentelemetry.context.Scope; -import java.nio.ByteBuffer; +import java.util.concurrent.CompletableFuture; import java.util.concurrent.Executor; import java.util.concurrent.Executors; import org.apache.logging.log4j.LogManager; import org.apache.logging.log4j.Logger; import org.jspecify.annotations.Nullable; -/** Adapter class for {@link dev.restate.sdk.common.syscalls.HandlerRunner} to use the Java API. */ +/** + * Adapter class for {@link dev.restate.sdk.endpoint.definition.HandlerRunner} to use the Java API. + */ public class HandlerRunner - implements dev.restate.sdk.common.syscalls.HandlerRunner { + implements dev.restate.sdk.endpoint.definition.HandlerRunner { private final ThrowingBiFunction runner; + private final SerdeFactory contextSerdeFactory; + private final Options options; private static final Logger LOG = LogManager.getLogger(HandlerRunner.class); - HandlerRunner(ThrowingBiFunction runner) { + HandlerRunner( + ThrowingBiFunction runner, + SerdeFactory contextSerdeFactory, + @Nullable Options options) { //noinspection unchecked this.runner = (ThrowingBiFunction) runner; + this.contextSerdeFactory = contextSerdeFactory; + this.options = (options != null) ? options : Options.DEFAULT; } @Override - public void run( - HandlerSpecification handlerSpecification, - Syscalls syscalls, - @Nullable Options options, - SyscallCallback callback) { - if (options == null) { - options = Options.DEFAULT; - } + public CompletableFuture run( + HandlerContext handlerContext, Serde requestSerde, Serde responseSerde) { + CompletableFuture returnFuture = new CompletableFuture<>(); // Wrap the executor for setting/unsetting the thread local - Options finalOptions = options; - Executor wrapped = + Executor serviceExecutor = runnable -> - finalOptions.executor.execute( + options.executor.execute( () -> { - SYSCALLS_THREAD_LOCAL.set(syscalls); - try (Scope ignored = syscalls.request().otelContext().makeCurrent()) { + HANDLER_CONTEXT_THREAD_LOCAL.set(handlerContext); + try (Scope ignored = handlerContext.request().otelContext().makeCurrent()) { runnable.run(); } finally { - SYSCALLS_THREAD_LOCAL.remove(); + HANDLER_CONTEXT_THREAD_LOCAL.remove(); } }); - wrapped.execute( + serviceExecutor.execute( () -> { // Any context switching, if necessary, will be done by ResolvedEndpointHandler - Context ctx = new ContextImpl(syscalls); + Context ctx = new ContextImpl(handlerContext, serviceExecutor, contextSerdeFactory); // Parse input REQ req; try { - req = - handlerSpecification.getRequestSerde().deserialize(syscalls.request().bodyBuffer()); + req = requestSerde.deserialize(handlerContext.request().body()); } catch (Throwable e) { LOG.warn("Cannot deserialize input", e); - callback.onCancel( + returnFuture.completeExceptionally( new TerminalException( TerminalException.BAD_REQUEST_CODE, "Cannot deserialize input: " + e.getMessage())); @@ -83,17 +86,17 @@ public void run( try { res = this.runner.apply(ctx, req); } catch (Throwable e) { - callback.onCancel(e); + returnFuture.completeExceptionally(e); return; } // Serialize output - ByteBuffer serializedResult; + Slice serializedResult; try { - serializedResult = handlerSpecification.getResponseSerde().serializeToByteBuffer(res); + serializedResult = responseSerde.serialize(res); } catch (Throwable e) { LOG.warn("Cannot serialize output", e); - callback.onCancel( + returnFuture.completeExceptionally( new TerminalException( TerminalException.INTERNAL_SERVER_ERROR_CODE, "Cannot serialize output: " + e.getMessage())); @@ -101,38 +104,52 @@ public void run( } // Complete callback - callback.onSuccess(serializedResult); + returnFuture.complete(serializedResult); }); + + return returnFuture; } public static HandlerRunner of( - ThrowingBiFunction runner) { - return new HandlerRunner<>(runner); + ThrowingBiFunction runner, + SerdeFactory contextSerdeFactory, + @Nullable Options options) { + return new HandlerRunner<>(runner, contextSerdeFactory, options); } @SuppressWarnings("unchecked") public static HandlerRunner of( - ThrowingFunction runner) { - return new HandlerRunner<>((context, o) -> runner.apply((CTX) context)); + ThrowingFunction runner, + SerdeFactory contextSerdeFactory, + @Nullable Options options) { + return new HandlerRunner<>( + (context, o) -> runner.apply((CTX) context), contextSerdeFactory, options); } @SuppressWarnings("unchecked") public static HandlerRunner of( - ThrowingBiConsumer runner) { + ThrowingBiConsumer runner, + SerdeFactory contextSerdeFactory, + @Nullable Options options) { return new HandlerRunner<>( (context, o) -> { runner.accept((CTX) context, o); return null; - }); + }, + contextSerdeFactory, + options); } @SuppressWarnings("unchecked") - public static HandlerRunner of(ThrowingConsumer runner) { + public static HandlerRunner of( + ThrowingConsumer runner, SerdeFactory contextSerdeFactory, @Nullable Options options) { return new HandlerRunner<>( (ctx, o) -> { runner.accept((CTX) ctx); return null; - }); + }, + contextSerdeFactory, + options); } public static class Options { @@ -144,8 +161,12 @@ public static class Options { * You can run on virtual threads by using the executor {@code * Executors.newVirtualThreadPerTaskExecutor()}. */ - public Options(Executor executor) { + private Options(Executor executor) { this.executor = executor; } + + public static Options withExecutor(Executor executor) { + return new Options(executor); + } } } diff --git a/sdk-api/src/main/java/dev/restate/sdk/InvocationHandle.java b/sdk-api/src/main/java/dev/restate/sdk/InvocationHandle.java new file mode 100644 index 000000000..83c62706f --- /dev/null +++ b/sdk-api/src/main/java/dev/restate/sdk/InvocationHandle.java @@ -0,0 +1,29 @@ +// Copyright (c) 2023 - Restate Software, Inc., Restate GmbH +// +// This file is part of the Restate Java SDK, +// which is released under the MIT license. +// +// You can find a copy of the license in file LICENSE in the root +// directory of this repository or package, or at +// https://github.com/restatedev/sdk-java/blob/main/LICENSE +package dev.restate.sdk; + +import dev.restate.common.Output; + +public interface InvocationHandle { + /** + * @return the invocation id of this invocation + */ + String invocationId(); + + /** Cancel this invocation. */ + void cancel(); + + /** Attach to this invocation. This will wait for the invocation to complete */ + Awaitable attach(); + + /** + * @return the output of this invocation, if present. + */ + Output getOutput(); +} diff --git a/sdk-api/src/main/java/dev/restate/sdk/JsonSerdes.java b/sdk-api/src/main/java/dev/restate/sdk/JsonSerdes.java deleted file mode 100644 index cb018d415..000000000 --- a/sdk-api/src/main/java/dev/restate/sdk/JsonSerdes.java +++ /dev/null @@ -1,159 +0,0 @@ -// Copyright (c) 2023 - Restate Software, Inc., Restate GmbH -// -// This file is part of the Restate Java SDK, -// which is released under the MIT license. -// -// You can find a copy of the license in file LICENSE in the root -// directory of this repository or package, or at -// https://github.com/restatedev/sdk-java/blob/main/LICENSE -package dev.restate.sdk; - -import com.fasterxml.jackson.core.JsonFactory; -import com.fasterxml.jackson.core.JsonGenerator; -import com.fasterxml.jackson.core.JsonParser; -import com.fasterxml.jackson.core.JsonToken; -import dev.restate.sdk.common.RichSerde; -import dev.restate.sdk.common.Serde; -import dev.restate.sdk.common.function.ThrowingBiConsumer; -import dev.restate.sdk.common.function.ThrowingFunction; -import java.io.ByteArrayInputStream; -import java.io.ByteArrayOutputStream; -import java.io.IOException; -import java.util.Map; -import org.jspecify.annotations.NonNull; - -/** - * Collection of common serializers/deserializers. - * - *

To ser/de POJOs using JSON, you can use the module {@code sdk-serde-jackson}. - */ -public abstract class JsonSerdes { - - private JsonSerdes() {} - - /** {@link Serde} for {@link String}. This writes and reads {@link String} as JSON value. */ - public static final Serde<@NonNull String> STRING = - usingJackson( - "string", - JsonGenerator::writeString, - p -> { - if (p.nextToken() != JsonToken.VALUE_STRING) { - throw new IllegalStateException( - "Expecting token " + JsonToken.VALUE_STRING + ", got " + p.getCurrentToken()); - } - return p.getText(); - }); - - /** {@link Serde} for {@link Boolean}. This writes and reads {@link Boolean} as JSON value. */ - public static final Serde<@NonNull Boolean> BOOLEAN = - usingJackson( - "boolean", - JsonGenerator::writeBoolean, - p -> { - p.nextToken(); - return p.getBooleanValue(); - }); - - /** {@link Serde} for {@link Byte}. This writes and reads {@link Byte} as JSON value. */ - public static final Serde<@NonNull Byte> BYTE = - usingJackson( - "number", - JsonGenerator::writeNumber, - p -> { - p.nextToken(); - return p.getByteValue(); - }); - - /** {@link Serde} for {@link Short}. This writes and reads {@link Short} as JSON value. */ - public static final Serde<@NonNull Short> SHORT = - usingJackson( - "number", - JsonGenerator::writeNumber, - p -> { - p.nextToken(); - return p.getShortValue(); - }); - - /** {@link Serde} for {@link Integer}. This writes and reads {@link Integer} as JSON value. */ - public static final Serde<@NonNull Integer> INT = - usingJackson( - "number", - JsonGenerator::writeNumber, - p -> { - p.nextToken(); - return p.getIntValue(); - }); - - /** {@link Serde} for {@link Long}. This writes and reads {@link Long} as JSON value. */ - public static final Serde<@NonNull Long> LONG = - usingJackson( - "number", - JsonGenerator::writeNumber, - p -> { - p.nextToken(); - return p.getLongValue(); - }); - - /** {@link Serde} for {@link Float}. This writes and reads {@link Float} as JSON value. */ - public static final Serde<@NonNull Float> FLOAT = - usingJackson( - "number", - JsonGenerator::writeNumber, - p -> { - p.nextToken(); - return p.getFloatValue(); - }); - - /** {@link Serde} for {@link Double}. This writes and reads {@link Double} as JSON value. */ - public static final Serde<@NonNull Double> DOUBLE = - usingJackson( - "number", - JsonGenerator::writeNumber, - p -> { - p.nextToken(); - return p.getDoubleValue(); - }); - - // --- Helpers for jackson-core - - private static final JsonFactory JSON_FACTORY = new JsonFactory(); - - private static Serde usingJackson( - String type, - ThrowingBiConsumer serializer, - ThrowingFunction deserializer) { - return new RichSerde<>() { - - @Override - public Object jsonSchema() { - return Map.of("type", type); - } - - @Override - public byte[] serialize(T value) { - ByteArrayOutputStream outputStream = new ByteArrayOutputStream(); - try (JsonGenerator gen = JSON_FACTORY.createGenerator(outputStream)) { - serializer.asBiConsumer().accept(gen, value); - } catch (IOException e) { - throw new RuntimeException("Cannot create JsonGenerator", e); - } - return outputStream.toByteArray(); - } - - @Override - public T deserialize(byte[] value) { - ByteArrayInputStream inputStream = new ByteArrayInputStream(value); - try (JsonParser parser = JSON_FACTORY.createParser(inputStream)) { - return deserializer.asFunction().apply(parser); - } catch (IOException e) { - throw new RuntimeException("Cannot create JsonGenerator", e); - } - } - - @Override - public String contentType() { - return "application/json"; - } - }; - } -} diff --git a/sdk-api/src/main/java/dev/restate/sdk/ObjectContext.java b/sdk-api/src/main/java/dev/restate/sdk/ObjectContext.java index fdba8dc56..c2c090f82 100644 --- a/sdk-api/src/main/java/dev/restate/sdk/ObjectContext.java +++ b/sdk-api/src/main/java/dev/restate/sdk/ObjectContext.java @@ -8,7 +8,8 @@ // https://github.com/restatedev/sdk-java/blob/main/LICENSE package dev.restate.sdk; -import dev.restate.sdk.common.*; +import dev.restate.sdk.types.StateKey; +import dev.restate.serde.Serde; import org.jspecify.annotations.NonNull; /** diff --git a/sdk-api/src/main/java/dev/restate/sdk/PreviewContext.java b/sdk-api/src/main/java/dev/restate/sdk/PreviewContext.java index 755aa9dce..c72274d30 100644 --- a/sdk-api/src/main/java/dev/restate/sdk/PreviewContext.java +++ b/sdk-api/src/main/java/dev/restate/sdk/PreviewContext.java @@ -8,12 +8,6 @@ // https://github.com/restatedev/sdk-java/blob/main/LICENSE package dev.restate.sdk; -import dev.restate.sdk.common.RetryPolicy; -import dev.restate.sdk.common.Serde; -import dev.restate.sdk.common.TerminalException; -import dev.restate.sdk.common.function.ThrowingRunnable; -import dev.restate.sdk.common.function.ThrowingSupplier; - /** * Preview of new context features. Please note that the methods in this class may break between * minor releases, use with caution! @@ -21,44 +15,4 @@ *

In order to use these methods, you MUST enable the preview context, through the * endpoint builders using {@code enablePreviewContext()}. */ -public class PreviewContext { - - /** - * @deprecated Use {@link Context#run(String, Serde, RetryPolicy, ThrowingSupplier)} - */ - @Deprecated(since = "1.2", forRemoval = true) - public static T run( - Context ctx, String name, Serde serde, RetryPolicy retryPolicy, ThrowingSupplier action) - throws TerminalException { - return ctx.run(name, serde, retryPolicy, action); - } - - /** - * @deprecated Use {@link Context#run(String, RetryPolicy, ThrowingRunnable)} - */ - @Deprecated(since = "1.2", forRemoval = true) - public static void run( - Context ctx, String name, RetryPolicy retryPolicy, ThrowingRunnable runnable) - throws TerminalException { - ctx.run(name, retryPolicy, runnable); - } - - /** - * @deprecated Use {@link Context#run(Serde, RetryPolicy, ThrowingSupplier)} - */ - @Deprecated(since = "1.2", forRemoval = true) - public static T run( - Context ctx, Serde serde, RetryPolicy retryPolicy, ThrowingSupplier action) - throws TerminalException { - return ctx.run(serde, retryPolicy, action); - } - - /** - * @deprecated Use {@link Context#run(RetryPolicy, ThrowingRunnable)} - */ - @Deprecated(since = "1.2", forRemoval = true) - public static void run(Context ctx, RetryPolicy retryPolicy, ThrowingRunnable runnable) - throws TerminalException { - ctx.run(retryPolicy, runnable); - } -} +public class PreviewContext {} diff --git a/sdk-api/src/main/java/dev/restate/sdk/RestateRandom.java b/sdk-api/src/main/java/dev/restate/sdk/RestateRandom.java index 5c06ff5e0..c8656813a 100644 --- a/sdk-api/src/main/java/dev/restate/sdk/RestateRandom.java +++ b/sdk-api/src/main/java/dev/restate/sdk/RestateRandom.java @@ -8,10 +8,8 @@ // https://github.com/restatedev/sdk-java/blob/main/LICENSE package dev.restate.sdk; -import dev.restate.sdk.common.InvocationId; -import dev.restate.sdk.common.Serde; -import dev.restate.sdk.common.function.ThrowingSupplier; -import dev.restate.sdk.common.syscalls.Syscalls; +import dev.restate.common.function.ThrowingSupplier; +import dev.restate.sdk.types.InvocationId; import java.util.Random; import java.util.UUID; @@ -21,18 +19,17 @@ * *

This instance is useful to generate identifiers, idempotency keys, and for uniform sampling * from a set of options. If a cryptographically secure value is needed, please generate that - * externally using {@link ObjectContext#run(Serde, ThrowingSupplier)}. + * externally using {@link Context#run(String, Class, ThrowingSupplier)}. * - *

You MUST NOT use this object inside a {@link ObjectContext#run(Serde, ThrowingSupplier)}. + *

You MUST NOT use this object inside a {@link Context#run(String, Class, + * ThrowingSupplier)}/{@link Context#runAsync(String, Class, ThrowingSupplier)}. */ public class RestateRandom extends Random { - private final Syscalls syscalls; private boolean seedInitialized = false; - RestateRandom(long randomSeed, Syscalls syscalls) { + RestateRandom(long randomSeed) { super(randomSeed); - this.syscalls = syscalls; } /** @@ -56,10 +53,6 @@ public UUID nextUUID() { @Override protected int next(int bits) { - if (this.syscalls.isInsideSideEffect()) { - throw new IllegalStateException("You can't use RestateRandom inside ctx.run!"); - } - return super.next(bits); } } diff --git a/sdk-api/src/main/java/dev/restate/sdk/Select.java b/sdk-api/src/main/java/dev/restate/sdk/Select.java new file mode 100644 index 000000000..2317fdf52 --- /dev/null +++ b/sdk-api/src/main/java/dev/restate/sdk/Select.java @@ -0,0 +1,85 @@ +// Copyright (c) 2023 - Restate Software, Inc., Restate GmbH +// +// This file is part of the Restate Java SDK, +// which is released under the MIT license. +// +// You can find a copy of the license in file LICENSE in the root +// directory of this repository or package, or at +// https://github.com/restatedev/sdk-java/blob/main/LICENSE +package dev.restate.sdk; + +import dev.restate.common.function.ThrowingFunction; +import dev.restate.sdk.endpoint.definition.AsyncResult; +import dev.restate.sdk.endpoint.definition.HandlerContext; +import dev.restate.sdk.types.TerminalException; +import java.util.ArrayList; +import java.util.List; +import java.util.concurrent.CompletableFuture; +import java.util.concurrent.Executor; +import java.util.stream.Collectors; + +public final class Select extends Awaitable { + + private final List> awaitables; + private AsyncResult asyncResult; + + Select() { + this.awaitables = new ArrayList<>(); + } + + public static Select select() { + return new Select<>(); + } + + public Select or(Awaitable awaitable) { + this.awaitables.add(awaitable); + this.asyncResult = null; + return this; + } + + public Select when(Awaitable awaitable, ThrowingFunction successMapper) { + this.awaitables.add(awaitable.map(successMapper)); + this.asyncResult = null; + return this; + } + + public Select when( + Awaitable awaitable, + ThrowingFunction successMapper, + ThrowingFunction failureMapper) { + this.awaitables.add(awaitable.map(successMapper, failureMapper)); + this.asyncResult = null; + return this; + } + + @Override + protected AsyncResult asyncResult() { + if (this.asyncResult == null) { + recreateAsyncResult(); + } + return this.asyncResult; + } + + @Override + protected Executor serviceExecutor() { + checkNonEmpty(); + return awaitables.get(0).serviceExecutor(); + } + + private void checkNonEmpty() { + if (awaitables.isEmpty()) { + throw new IllegalArgumentException("Select is empty"); + } + } + + private void recreateAsyncResult() { + checkNonEmpty(); + List> awaitables = List.copyOf(this.awaitables); + List> ars = + awaitables.stream().map(Awaitable::asyncResult).collect(Collectors.toList()); + HandlerContext ctx = ars.get(0).ctx(); + //noinspection unchecked + this.asyncResult = + ctx.createAnyAsyncResult(ars).map(i -> (CompletableFuture) ars.get(i).poll()); + } +} diff --git a/sdk-api/src/main/java/dev/restate/sdk/SharedObjectContext.java b/sdk-api/src/main/java/dev/restate/sdk/SharedObjectContext.java index cb5990888..7e3122acb 100644 --- a/sdk-api/src/main/java/dev/restate/sdk/SharedObjectContext.java +++ b/sdk-api/src/main/java/dev/restate/sdk/SharedObjectContext.java @@ -8,8 +8,8 @@ // https://github.com/restatedev/sdk-java/blob/main/LICENSE package dev.restate.sdk; -import dev.restate.sdk.common.Serde; -import dev.restate.sdk.common.StateKey; +import dev.restate.sdk.types.StateKey; +import dev.restate.serde.Serde; import java.util.Collection; import java.util.Optional; diff --git a/sdk-api/src/main/java/dev/restate/sdk/SharedWorkflowContext.java b/sdk-api/src/main/java/dev/restate/sdk/SharedWorkflowContext.java index e4f266ef6..84246b3d5 100644 --- a/sdk-api/src/main/java/dev/restate/sdk/SharedWorkflowContext.java +++ b/sdk-api/src/main/java/dev/restate/sdk/SharedWorkflowContext.java @@ -8,7 +8,7 @@ // https://github.com/restatedev/sdk-java/blob/main/LICENSE package dev.restate.sdk; -import dev.restate.sdk.common.DurablePromiseKey; +import dev.restate.sdk.types.DurablePromiseKey; /** * This interface can be used only within shared handlers of workflow. It extends {@link Context} diff --git a/sdk-api/src/main/java/dev/restate/sdk/Util.java b/sdk-api/src/main/java/dev/restate/sdk/Util.java index 3854bde4c..13fa1a9fe 100644 --- a/sdk-api/src/main/java/dev/restate/sdk/Util.java +++ b/sdk-api/src/main/java/dev/restate/sdk/Util.java @@ -8,93 +8,54 @@ // https://github.com/restatedev/sdk-java/blob/main/LICENSE package dev.restate.sdk; -import dev.restate.sdk.common.AbortedExecutionException; -import dev.restate.sdk.common.Output; -import dev.restate.sdk.common.Serde; -import dev.restate.sdk.common.function.ThrowingFunction; -import dev.restate.sdk.common.syscalls.Deferred; -import dev.restate.sdk.common.syscalls.Result; -import dev.restate.sdk.common.syscalls.SyscallCallback; -import dev.restate.sdk.common.syscalls.Syscalls; -import java.nio.ByteBuffer; -import java.util.Optional; +import dev.restate.common.function.ThrowingFunction; +import dev.restate.common.function.ThrowingSupplier; +import dev.restate.sdk.endpoint.definition.HandlerContext; +import dev.restate.sdk.types.AbortedExecutionException; import java.util.concurrent.CancellationException; import java.util.concurrent.CompletableFuture; +import java.util.concurrent.CompletionException; import java.util.concurrent.ExecutionException; -import java.util.function.Consumer; +import org.jspecify.annotations.NonNull; class Util { private Util() {} - static T blockOnResolve(Syscalls syscalls, Deferred deferred) { - if (!deferred.isCompleted()) { - Util.blockOnSyscall(cb -> syscalls.resolveDeferred(deferred, cb)); - } - - return Util.unwrapResult(deferred.toResult()); - } - - static T awaitCompletableFuture(CompletableFuture future) { + static R executeOrFail(HandlerContext handlerContext, ThrowingFunction fn, T t) { try { - return future.get(); - } catch (InterruptedException | CancellationException e) { + return fn.apply(t); + } catch (Throwable e) { + handlerContext.fail(e); AbortedExecutionException.sneakyThrow(); - return null; // Previous statement throws an exception - } catch (ExecutionException e) { - throw (RuntimeException) e.getCause(); - } - } - - static T blockOnSyscall(Consumer> syscallExecutor) { - CompletableFuture fut = new CompletableFuture<>(); - syscallExecutor.accept(SyscallCallback.completingFuture(fut)); - return Util.awaitCompletableFuture(fut); - } - - static T unwrapResult(Result res) { - if (res.isSuccess()) { - return res.getValue(); - } - throw res.getFailure(); - } - - static Optional unwrapOptionalReadyResult(Result res) { - if (!res.isSuccess()) { - throw res.getFailure(); - } - if (res.isEmpty()) { - return Optional.empty(); - } - return Optional.of(res.getValue()); - } - - static Output unwrapOutputReadyResult(Result res) { - if (!res.isSuccess()) { - throw res.getFailure(); - } - if (res.isEmpty()) { - return Output.notReady(); + return null; } - return Output.ready(res.getValue()); } - static R executeMappingException(Syscalls syscalls, ThrowingFunction fn, T t) { + static R executeOrFail(HandlerContext handlerContext, ThrowingSupplier fn) { try { - return fn.apply(t); + return fn.get(); } catch (Throwable e) { - syscalls.fail(e); + handlerContext.fail(e); AbortedExecutionException.sneakyThrow(); return null; } } - static ByteBuffer serializeWrappingException(Syscalls syscalls, Serde serde, T value) { - return executeMappingException(syscalls, serde::serializeToByteBuffer, value); + static @NonNull T awaitCompletableFuture(CompletableFuture future) { + try { + return future.get(); + } catch (InterruptedException | CancellationException e) { + AbortedExecutionException.sneakyThrow(); + return null; // Previous statement throws an exception + } catch (ExecutionException | CompletionException e) { + sneakyThrow(e.getCause()); + return null; // Previous statement throws an exception + } } - static T deserializeWrappingException( - Syscalls syscalls, Serde serde, ByteBuffer byteString) { - return executeMappingException(syscalls, serde::deserialize, byteString); + @SuppressWarnings("unchecked") + public static void sneakyThrow(Throwable e) throws E { + throw (E) e; } } diff --git a/sdk-api/src/test/java/dev/restate/sdk/TestSerdesTest.java b/sdk-api/src/test/java/dev/restate/sdk/TestSerdesTest.java deleted file mode 100644 index 94d3838a8..000000000 --- a/sdk-api/src/test/java/dev/restate/sdk/TestSerdesTest.java +++ /dev/null @@ -1,80 +0,0 @@ -// Copyright (c) 2023 - Restate Software, Inc., Restate GmbH -// -// This file is part of the Restate Java SDK, -// which is released under the MIT license. -// -// You can find a copy of the license in file LICENSE in the root -// directory of this repository or package, or at -// https://github.com/restatedev/sdk-java/blob/main/LICENSE -package dev.restate.sdk; - -import static org.assertj.core.api.Assertions.assertThat; -import static org.assertj.core.api.Assertions.assertThatThrownBy; - -import dev.restate.sdk.common.Serde; -import java.nio.charset.StandardCharsets; -import java.util.Random; -import java.util.stream.Stream; -import org.junit.jupiter.params.ParameterizedTest; -import org.junit.jupiter.params.provider.Arguments; -import org.junit.jupiter.params.provider.MethodSource; - -class TestSerdesTest { - - private static Arguments roundtripCase(Serde serde, T value) { - return Arguments.of( - (value != null ? value.getClass().getSimpleName() : "Null") + ": " + value, serde, value); - } - - private static Stream roundtrip() { - var random = new Random(); - return Stream.of( - roundtripCase(JsonSerdes.STRING, ""), - roundtripCase(JsonSerdes.STRING, "Francesco1234"), - roundtripCase(JsonSerdes.STRING, "😀"), - roundtripCase(JsonSerdes.BOOLEAN, true), - roundtripCase(JsonSerdes.BOOLEAN, false), - roundtripCase(JsonSerdes.BYTE, Byte.MIN_VALUE), - roundtripCase(JsonSerdes.BYTE, Byte.MAX_VALUE), - roundtripCase(JsonSerdes.BYTE, (byte) random.nextInt()), - roundtripCase(JsonSerdes.SHORT, Short.MIN_VALUE), - roundtripCase(JsonSerdes.SHORT, Short.MAX_VALUE), - roundtripCase(JsonSerdes.SHORT, (short) random.nextInt()), - roundtripCase(JsonSerdes.INT, Integer.MIN_VALUE), - roundtripCase(JsonSerdes.INT, Integer.MAX_VALUE), - roundtripCase(JsonSerdes.INT, random.nextInt()), - roundtripCase(JsonSerdes.LONG, Long.MIN_VALUE), - roundtripCase(JsonSerdes.LONG, Long.MAX_VALUE), - roundtripCase(JsonSerdes.LONG, random.nextLong()), - roundtripCase(JsonSerdes.FLOAT, Float.MIN_VALUE), - roundtripCase(JsonSerdes.FLOAT, Float.MAX_VALUE), - roundtripCase(JsonSerdes.FLOAT, random.nextFloat()), - roundtripCase(JsonSerdes.DOUBLE, Double.MIN_VALUE), - roundtripCase(JsonSerdes.DOUBLE, Double.MAX_VALUE), - roundtripCase(JsonSerdes.DOUBLE, random.nextDouble())); - } - - @ParameterizedTest(name = "{0}") - @MethodSource - void roundtrip(String testName, Serde serde, T value) throws Throwable { - assertThat(serde.deserialize(serde.serialize(value))).isEqualTo(value); - } - - private static Stream failDeserialization() { - return Stream.of( - Arguments.of("String unquoted", JsonSerdes.STRING, "my string"), - Arguments.of("Not a boolean", JsonSerdes.BOOLEAN, "something"), - Arguments.of("Not a byte", JsonSerdes.BYTE, "something"), - Arguments.of("Not a short", JsonSerdes.SHORT, "something"), - Arguments.of("Not a int", JsonSerdes.INT, "something"), - Arguments.of("Not a long", JsonSerdes.LONG, "something"), - Arguments.of("Not a float", JsonSerdes.FLOAT, "something"), - Arguments.of("Not a double", JsonSerdes.DOUBLE, "something")); - } - - @ParameterizedTest(name = "{0}") - @MethodSource - void failDeserialization(String testName, Serde serde, String value) throws Throwable { - assertThatThrownBy(() -> serde.deserialize(value.getBytes(StandardCharsets.UTF_8))).isNotNull(); - } -} diff --git a/sdk-common/build.gradle.kts b/sdk-common/build.gradle.kts index da4b95e93..0d3887a08 100644 --- a/sdk-common/build.gradle.kts +++ b/sdk-common/build.gradle.kts @@ -13,8 +13,10 @@ dependencies { compileOnly(libs.jspecify) api(libs.opentelemetry.api) + api(project(":common")) implementation(libs.jackson.core) + implementation(libs.log4j.api) testImplementation(libs.junit.jupiter) testImplementation(libs.assertj) diff --git a/sdk-common/src/main/java/dev/restate/sdk/annotation/CustomSerdeFactory.java b/sdk-common/src/main/java/dev/restate/sdk/annotation/CustomSerdeFactory.java new file mode 100644 index 000000000..36da5768c --- /dev/null +++ b/sdk-common/src/main/java/dev/restate/sdk/annotation/CustomSerdeFactory.java @@ -0,0 +1,21 @@ +// Copyright (c) 2023 - Restate Software, Inc., Restate GmbH +// +// This file is part of the Restate Java SDK, +// which is released under the MIT license. +// +// You can find a copy of the license in file LICENSE in the root +// directory of this repository or package, or at +// https://github.com/restatedev/sdk-java/blob/main/LICENSE +package dev.restate.sdk.annotation; + +import dev.restate.serde.SerdeFactory; +import java.lang.annotation.ElementType; +import java.lang.annotation.Retention; +import java.lang.annotation.RetentionPolicy; +import java.lang.annotation.Target; + +@Target(ElementType.TYPE) +@Retention(RetentionPolicy.SOURCE) +public @interface CustomSerdeFactory { + Class value(); +} diff --git a/sdk-common/src/main/java/dev/restate/sdk/annotation/Service.java b/sdk-common/src/main/java/dev/restate/sdk/annotation/Service.java index 327dbd866..c6036a60e 100644 --- a/sdk-common/src/main/java/dev/restate/sdk/annotation/Service.java +++ b/sdk-common/src/main/java/dev/restate/sdk/annotation/Service.java @@ -8,12 +8,12 @@ // https://github.com/restatedev/sdk-java/blob/main/LICENSE package dev.restate.sdk.annotation; +import dev.restate.sdk.endpoint.definition.ServiceDefinitionFactory; import java.lang.annotation.*; /** * Annotation to define a class/interface as Restate Service. This triggers the code generation of - * the related Client class and the {@link - * dev.restate.sdk.common.syscalls.ServiceDefinitionFactory}. + * the related Client class and the {@link ServiceDefinitionFactory}. */ @Target(ElementType.TYPE) @Retention(RetentionPolicy.RUNTIME) diff --git a/sdk-common/src/main/java/dev/restate/sdk/annotation/VirtualObject.java b/sdk-common/src/main/java/dev/restate/sdk/annotation/VirtualObject.java index e1c1cb609..93edb8d53 100644 --- a/sdk-common/src/main/java/dev/restate/sdk/annotation/VirtualObject.java +++ b/sdk-common/src/main/java/dev/restate/sdk/annotation/VirtualObject.java @@ -8,12 +8,12 @@ // https://github.com/restatedev/sdk-java/blob/main/LICENSE package dev.restate.sdk.annotation; +import dev.restate.sdk.endpoint.definition.ServiceDefinitionFactory; import java.lang.annotation.*; /** * Annotation to define a class/interface as Restate VirtualObject. This triggers the code - * generation of the related Client class and the {@link - * dev.restate.sdk.common.syscalls.ServiceDefinitionFactory}. + * generation of the related Client class and the {@link ServiceDefinitionFactory}. */ @Target(ElementType.TYPE) @Retention(RetentionPolicy.RUNTIME) diff --git a/sdk-common/src/main/java/dev/restate/sdk/annotation/Workflow.java b/sdk-common/src/main/java/dev/restate/sdk/annotation/Workflow.java index cb20899f3..eb9b5e54c 100644 --- a/sdk-common/src/main/java/dev/restate/sdk/annotation/Workflow.java +++ b/sdk-common/src/main/java/dev/restate/sdk/annotation/Workflow.java @@ -8,13 +8,13 @@ // https://github.com/restatedev/sdk-java/blob/main/LICENSE package dev.restate.sdk.annotation; +import dev.restate.sdk.endpoint.definition.ServiceDefinitionFactory; import java.lang.annotation.*; /** * Annotation to define a class/interface as Restate Workflow. This triggers the code generation of - * the related Client class and the {@link - * dev.restate.sdk.common.syscalls.ServiceDefinitionFactory}. When defining a class/interface as - * workflow, you must annotate one of its methods too as {@link Workflow}. + * the related Client class and the {@link ServiceDefinitionFactory}. When defining a + * class/interface as workflow, you must annotate one of its methods too as {@link Workflow}. */ @Target({ElementType.METHOD, ElementType.TYPE}) @Retention(RetentionPolicy.RUNTIME) diff --git a/sdk-common/src/main/java/dev/restate/sdk/client/CallRequestOptions.java b/sdk-common/src/main/java/dev/restate/sdk/client/CallRequestOptions.java deleted file mode 100644 index 1870de676..000000000 --- a/sdk-common/src/main/java/dev/restate/sdk/client/CallRequestOptions.java +++ /dev/null @@ -1,84 +0,0 @@ -// Copyright (c) 2023 - Restate Software, Inc., Restate GmbH -// -// This file is part of the Restate Java SDK, -// which is released under the MIT license. -// -// You can find a copy of the license in file LICENSE in the root -// directory of this repository or package, or at -// https://github.com/restatedev/sdk-java/blob/main/LICENSE -package dev.restate.sdk.client; - -import java.util.HashMap; -import java.util.Map; -import java.util.Objects; - -public final class CallRequestOptions extends RequestOptions { - - public static final CallRequestOptions DEFAULT = new CallRequestOptions(); - - private final String idempotencyKey; - - public CallRequestOptions() { - this(new HashMap<>(), null); - } - - public CallRequestOptions(Map additionalHeaders, String idempotencyKey) { - super(additionalHeaders); - this.idempotencyKey = idempotencyKey; - } - - public CallRequestOptions withIdempotency(String idempotencyKey) { - return new CallRequestOptions(new HashMap<>(this.additionalHeaders), idempotencyKey); - } - - @Override - public CallRequestOptions withHeader(String name, String value) { - CallRequestOptions newOptions = this.copy(); - newOptions.additionalHeaders.put(name, value); - return newOptions; - } - - @Override - public CallRequestOptions withHeaders(Map additionalHeaders) { - CallRequestOptions newOptions = this.copy(); - newOptions.additionalHeaders.putAll(additionalHeaders); - return newOptions; - } - - public String getIdempotencyKey() { - return idempotencyKey; - } - - @Override - public CallRequestOptions copy() { - return new CallRequestOptions(new HashMap<>(this.additionalHeaders), this.idempotencyKey); - } - - @Override - public boolean equals(Object o) { - if (this == o) return true; - if (o == null || getClass() != o.getClass()) return false; - if (!super.equals(o)) return false; - - CallRequestOptions that = (CallRequestOptions) o; - return Objects.equals(idempotencyKey, that.idempotencyKey); - } - - @Override - public int hashCode() { - int result = super.hashCode(); - result = 31 * result + Objects.hashCode(idempotencyKey); - return result; - } - - @Override - public String toString() { - return "CallRequestOptions{" - + "idempotencyKey='" - + idempotencyKey - + '\'' - + ", additionalHeaders=" - + additionalHeaders - + '}'; - } -} diff --git a/sdk-common/src/main/java/dev/restate/sdk/client/Client.java b/sdk-common/src/main/java/dev/restate/sdk/client/Client.java deleted file mode 100644 index 9d799571d..000000000 --- a/sdk-common/src/main/java/dev/restate/sdk/client/Client.java +++ /dev/null @@ -1,316 +0,0 @@ -// Copyright (c) 2023 - Restate Software, Inc., Restate GmbH -// -// This file is part of the Restate Java SDK, -// which is released under the MIT license. -// -// You can find a copy of the license in file LICENSE in the root -// directory of this repository or package, or at -// https://github.com/restatedev/sdk-java/blob/main/LICENSE -package dev.restate.sdk.client; - -import dev.restate.sdk.common.Output; -import dev.restate.sdk.common.Serde; -import dev.restate.sdk.common.Target; -import java.net.http.HttpClient; -import java.time.Duration; -import java.util.Collections; -import java.util.Map; -import java.util.concurrent.CompletableFuture; -import java.util.concurrent.CompletionException; -import org.jspecify.annotations.NonNull; -import org.jspecify.annotations.Nullable; - -public interface Client { - - CompletableFuture callAsync( - Target target, Serde reqSerde, Serde resSerde, Req req, RequestOptions options); - - default CompletableFuture callAsync( - Target target, Serde reqSerde, Serde resSerde, Req req) { - return callAsync(target, reqSerde, resSerde, req, RequestOptions.DEFAULT); - } - - default Res call( - Target target, Serde reqSerde, Serde resSerde, Req req, RequestOptions options) - throws IngressException { - try { - return callAsync(target, reqSerde, resSerde, req, options).join(); - } catch (CompletionException e) { - if (e.getCause() instanceof RuntimeException) { - throw (RuntimeException) e.getCause(); - } - throw new RuntimeException(e.getCause()); - } - } - - default Res call(Target target, Serde reqSerde, Serde resSerde, Req req) - throws IngressException { - return call(target, reqSerde, resSerde, req, RequestOptions.DEFAULT); - } - - CompletableFuture sendAsync( - Target target, - Serde reqSerde, - Req req, - @Nullable Duration delay, - RequestOptions options); - - default CompletableFuture sendAsync( - Target target, Serde reqSerde, Req req, @Nullable Duration delay) { - return sendAsync(target, reqSerde, req, delay, RequestOptions.DEFAULT); - } - - default CompletableFuture sendAsync( - Target target, Serde reqSerde, Req req) { - return sendAsync(target, reqSerde, req, null, RequestOptions.DEFAULT); - } - - default SendResponse send( - Target target, Serde reqSerde, Req req, @Nullable Duration delay, RequestOptions options) - throws IngressException { - try { - return sendAsync(target, reqSerde, req, delay, options).join(); - } catch (CompletionException e) { - if (e.getCause() instanceof RuntimeException) { - throw (RuntimeException) e.getCause(); - } - throw new RuntimeException(e.getCause()); - } - } - - default SendResponse send( - Target target, Serde reqSerde, Req req, @Nullable Duration delay) - throws IngressException { - return send(target, reqSerde, req, delay, RequestOptions.DEFAULT); - } - - default SendResponse send(Target target, Serde reqSerde, Req req) - throws IngressException { - return send(target, reqSerde, req, null, RequestOptions.DEFAULT); - } - - /** - * Create a new {@link AwakeableHandle} for the provided identifier. You can use it to {@link - * AwakeableHandle#resolve(Serde, Object)} or {@link AwakeableHandle#reject(String)} an Awakeable - * from the ingress. - */ - AwakeableHandle awakeableHandle(String id); - - /** - * This class represents a handle to an Awakeable. It can be used to complete awakeables from the - * ingress - */ - interface AwakeableHandle { - /** Same as {@link #resolve(Serde, Object)} but async with options. */ - CompletableFuture resolveAsync( - Serde serde, @NonNull T payload, RequestOptions options); - - /** Same as {@link #resolve(Serde, Object)} but async. */ - default CompletableFuture resolveAsync(Serde serde, @NonNull T payload) { - return resolveAsync(serde, payload, RequestOptions.DEFAULT); - } - - /** Same as {@link #resolve(Serde, Object)} with options. */ - default void resolve(Serde serde, @NonNull T payload, RequestOptions options) { - try { - resolveAsync(serde, payload, options).join(); - } catch (CompletionException e) { - if (e.getCause() instanceof RuntimeException) { - throw (RuntimeException) e.getCause(); - } - throw new RuntimeException(e.getCause()); - } - } - - /** - * Complete with success the Awakeable. - * - * @param serde used to serialize the Awakeable result payload. - * @param payload the result payload. MUST NOT be null. - */ - default void resolve(Serde serde, @NonNull T payload) { - this.resolve(serde, payload, RequestOptions.DEFAULT); - } - - /** Same as {@link #reject(String)} but async with options. */ - CompletableFuture rejectAsync(String reason, RequestOptions options); - - /** Same as {@link #reject(String)} but async. */ - default CompletableFuture rejectAsync(String reason) { - return rejectAsync(reason, RequestOptions.DEFAULT); - } - - /** Same as {@link #reject(String)} with options. */ - default void reject(String reason, RequestOptions options) { - try { - rejectAsync(reason, options).join(); - } catch (CompletionException e) { - if (e.getCause() instanceof RuntimeException) { - throw (RuntimeException) e.getCause(); - } - throw new RuntimeException(e.getCause()); - } - } - - /** - * Complete with failure the Awakeable. - * - * @param reason the rejection reason. MUST NOT be null. - */ - default void reject(String reason) { - this.reject(reason, RequestOptions.DEFAULT); - } - } - - InvocationHandle invocationHandle(String invocationId, Serde resSerde); - - interface InvocationHandle { - - String invocationId(); - - CompletableFuture attachAsync(RequestOptions options); - - default CompletableFuture attachAsync() { - return attachAsync(RequestOptions.DEFAULT); - } - - default Res attach(RequestOptions options) throws IngressException { - try { - return attachAsync(options).join(); - } catch (CompletionException e) { - if (e.getCause() instanceof RuntimeException) { - throw (RuntimeException) e.getCause(); - } - throw new RuntimeException(e.getCause()); - } - } - - default Res attach() throws IngressException { - return attach(RequestOptions.DEFAULT); - } - - CompletableFuture> getOutputAsync(RequestOptions options); - - default CompletableFuture> getOutputAsync() { - return getOutputAsync(RequestOptions.DEFAULT); - } - - default Output getOutput(RequestOptions options) throws IngressException { - try { - return getOutputAsync(options).join(); - } catch (CompletionException e) { - if (e.getCause() instanceof RuntimeException) { - throw (RuntimeException) e.getCause(); - } - throw new RuntimeException(e.getCause()); - } - } - - default Output getOutput() throws IngressException { - return getOutput(RequestOptions.DEFAULT); - } - } - - IdempotentInvocationHandle idempotentInvocationHandle( - Target target, String idempotencyKey, Serde resSerde); - - interface IdempotentInvocationHandle { - - CompletableFuture attachAsync(RequestOptions options); - - default CompletableFuture attachAsync() { - return attachAsync(RequestOptions.DEFAULT); - } - - default Res attach(RequestOptions options) throws IngressException { - try { - return attachAsync(options).join(); - } catch (CompletionException e) { - if (e.getCause() instanceof RuntimeException) { - throw (RuntimeException) e.getCause(); - } - throw new RuntimeException(e.getCause()); - } - } - - default Res attach() throws IngressException { - return attach(RequestOptions.DEFAULT); - } - - CompletableFuture> getOutputAsync(RequestOptions options); - - default CompletableFuture> getOutputAsync() { - return getOutputAsync(RequestOptions.DEFAULT); - } - - default Output getOutput(RequestOptions options) throws IngressException { - try { - return getOutputAsync(options).join(); - } catch (CompletionException e) { - if (e.getCause() instanceof RuntimeException) { - throw (RuntimeException) e.getCause(); - } - throw new RuntimeException(e.getCause()); - } - } - - default Output getOutput() throws IngressException { - return getOutput(RequestOptions.DEFAULT); - } - } - - WorkflowHandle workflowHandle( - String workflowName, String workflowId, Serde resSerde); - - interface WorkflowHandle { - CompletableFuture attachAsync(RequestOptions options); - - default CompletableFuture attachAsync() { - return attachAsync(RequestOptions.DEFAULT); - } - - default Res attach(RequestOptions options) throws IngressException { - try { - return attachAsync(options).join(); - } catch (CompletionException e) { - if (e.getCause() instanceof RuntimeException) { - throw (RuntimeException) e.getCause(); - } - throw new RuntimeException(e.getCause()); - } - } - - default Res attach() throws IngressException { - return attach(RequestOptions.DEFAULT); - } - - CompletableFuture> getOutputAsync(RequestOptions options); - - default CompletableFuture> getOutputAsync() { - return getOutputAsync(RequestOptions.DEFAULT); - } - - default Output getOutput(RequestOptions options) throws IngressException { - try { - return getOutputAsync(options).join(); - } catch (CompletionException e) { - if (e.getCause() instanceof RuntimeException) { - throw (RuntimeException) e.getCause(); - } - throw new RuntimeException(e.getCause()); - } - } - - default Output getOutput() throws IngressException { - return getOutput(RequestOptions.DEFAULT); - } - } - - static Client connect(String baseUri) { - return connect(baseUri, Collections.emptyMap()); - } - - static Client connect(String baseUri, Map headers) { - return new DefaultClient(HttpClient.newHttpClient(), baseUri, headers); - } -} diff --git a/sdk-common/src/main/java/dev/restate/sdk/client/DefaultClient.java b/sdk-common/src/main/java/dev/restate/sdk/client/DefaultClient.java deleted file mode 100644 index c1e71749b..000000000 --- a/sdk-common/src/main/java/dev/restate/sdk/client/DefaultClient.java +++ /dev/null @@ -1,474 +0,0 @@ -// Copyright (c) 2023 - Restate Software, Inc., Restate GmbH -// -// This file is part of the Restate Java SDK, -// which is released under the MIT license. -// -// You can find a copy of the license in file LICENSE in the root -// directory of this repository or package, or at -// https://github.com/restatedev/sdk-java/blob/main/LICENSE -package dev.restate.sdk.client; - -import com.fasterxml.jackson.core.JsonFactory; -import com.fasterxml.jackson.core.JsonParser; -import com.fasterxml.jackson.core.JsonToken; -import dev.restate.sdk.common.Output; -import dev.restate.sdk.common.Serde; -import dev.restate.sdk.common.Target; -import java.io.ByteArrayInputStream; -import java.io.IOException; -import java.io.InputStream; -import java.net.URI; -import java.net.URLEncoder; -import java.net.http.HttpClient; -import java.net.http.HttpRequest; -import java.net.http.HttpResponse; -import java.nio.charset.StandardCharsets; -import java.time.Duration; -import java.util.*; -import java.util.concurrent.CompletableFuture; -import java.util.function.BiFunction; -import org.jetbrains.annotations.NotNull; -import org.jspecify.annotations.NonNull; - -public class DefaultClient implements Client { - - private static final JsonFactory JSON_FACTORY = new JsonFactory(); - - private final HttpClient httpClient; - private final URI baseUri; - private final Map headers; - - DefaultClient(HttpClient httpClient, String baseUri, Map headers) { - this.httpClient = httpClient; - this.baseUri = URI.create(baseUri); - this.headers = headers; - } - - @Override - public CompletableFuture callAsync( - Target target, - Serde reqSerde, - Serde resSerde, - Req req, - RequestOptions requestOptions) { - HttpRequest request = prepareHttpRequest(target, false, reqSerde, req, null, requestOptions); - return httpClient - .sendAsync(request, HttpResponse.BodyHandlers.ofByteArray()) - .handle( - (response, throwable) -> { - if (throwable != null) { - throw new IngressException("Error when executing the request", request, throwable); - } - - if (response.statusCode() >= 300) { - handleNonSuccessResponse(response); - } - - try { - return resSerde.deserialize(response.body()); - } catch (Exception e) { - throw new IngressException("Cannot deserialize the response", response, e); - } - }); - } - - @Override - public CompletableFuture sendAsync( - Target target, Serde reqSerde, Req req, Duration delay, RequestOptions options) { - HttpRequest request = prepareHttpRequest(target, true, reqSerde, req, delay, options); - return httpClient - .sendAsync(request, HttpResponse.BodyHandlers.ofByteArray()) - .handle( - (response, throwable) -> { - if (throwable != null) { - throw new IngressException("Error when executing the request", request, throwable); - } - if (response.statusCode() >= 300) { - handleNonSuccessResponse(response); - } - - Map fields; - try { - fields = - findStringFieldsInJsonObject( - new ByteArrayInputStream(response.body()), "invocationId", "status"); - } catch (Exception e) { - throw new IngressException("Cannot deserialize the response", response, e); - } - - String statusField = fields.get("status"); - SendResponse.SendStatus status; - if ("Accepted".equals(statusField)) { - status = SendResponse.SendStatus.ACCEPTED; - } else if ("PreviouslyAccepted".equals(statusField)) { - status = SendResponse.SendStatus.PREVIOUSLY_ACCEPTED; - } else { - throw new IngressException( - "Cannot deserialize the response status, got " + statusField, response); - } - - return new SendResponse(status, fields.get("invocationId")); - }); - } - - @Override - public AwakeableHandle awakeableHandle(String id) { - return new AwakeableHandle() { - private Void handleVoidResponse( - HttpRequest request, HttpResponse response, Throwable throwable) { - if (throwable != null) { - throw new IngressException("Error when executing the request", request, throwable); - } - - if (response.statusCode() >= 300) { - handleNonSuccessResponse(response); - } - - return null; - } - - @Override - public CompletableFuture resolveAsync( - Serde serde, @NonNull T payload, RequestOptions options) { - // Prepare request - var reqBuilder = - prepareBuilder(options).uri(baseUri.resolve("/restate/awakeables/" + id + "/resolve")); - - // Add content-type - if (serde.contentType() != null) { - reqBuilder.header("content-type", serde.contentType()); - } - - // Build and Send request - HttpRequest request = - reqBuilder - .POST(HttpRequest.BodyPublishers.ofByteArray(serde.serialize(payload))) - .build(); - return httpClient - .sendAsync(request, HttpResponse.BodyHandlers.ofByteArray()) - .handle((res, t) -> this.handleVoidResponse(request, res, t)); - } - - @Override - public CompletableFuture rejectAsync(String reason, RequestOptions options) { - // Prepare request - var reqBuilder = - HttpRequest.newBuilder() - .uri(baseUri.resolve("/restate/awakeables/" + id + "/reject")) - .header("content-type", "text-plain"); - - // Add headers - headers.forEach(reqBuilder::header); - options.getAdditionalHeaders().forEach(reqBuilder::header); - - // Build and Send request - HttpRequest request = reqBuilder.POST(HttpRequest.BodyPublishers.ofString(reason)).build(); - return httpClient - .sendAsync(request, HttpResponse.BodyHandlers.ofByteArray()) - .handle((res, t) -> this.handleVoidResponse(request, res, t)); - } - }; - } - - @Override - public InvocationHandle invocationHandle(String invocationId, Serde resSerde) { - return new InvocationHandle<>() { - @Override - public String invocationId() { - return invocationId; - } - - @Override - public CompletableFuture attachAsync(RequestOptions options) { - // Prepare request - var reqBuilder = - prepareBuilder(options) - .uri(baseUri.resolve("/restate/invocation/" + invocationId + "/attach")); - - // Build and Send request - HttpRequest request = reqBuilder.GET().build(); - return httpClient - .sendAsync(request, HttpResponse.BodyHandlers.ofByteArray()) - .handle(handleAttachResponse(request, resSerde)); - } - - @Override - public CompletableFuture> getOutputAsync(RequestOptions options) { - // Prepare request - var reqBuilder = - prepareBuilder(options) - .uri(baseUri.resolve("/restate/invocation/" + invocationId + "/output")); - - // Build and Send request - HttpRequest request = reqBuilder.GET().build(); - return httpClient - .sendAsync(request, HttpResponse.BodyHandlers.ofByteArray()) - .handle(handleGetOutputResponse(request, resSerde)); - } - }; - } - - @Override - public IdempotentInvocationHandle idempotentInvocationHandle( - Target target, String idempotencyKey, Serde resSerde) { - return new IdempotentInvocationHandle<>() { - @Override - public CompletableFuture attachAsync(RequestOptions options) { - // Prepare request - var uri = - baseUri.resolve( - "/restate/invocation" - + targetToURI(target) - + "/" - + URLEncoder.encode(idempotencyKey, StandardCharsets.UTF_8) - + "/attach"); - var reqBuilder = prepareBuilder(options).uri(uri); - - // Build and Send request - HttpRequest request = reqBuilder.GET().build(); - return httpClient - .sendAsync(request, HttpResponse.BodyHandlers.ofByteArray()) - .handle(handleAttachResponse(request, resSerde)); - } - - @Override - public CompletableFuture> getOutputAsync(RequestOptions options) { - // Prepare request - var uri = - baseUri.resolve( - "/restate/invocation" - + targetToURI(target) - + "/" - + URLEncoder.encode(idempotencyKey, StandardCharsets.UTF_8) - + "/output"); - var reqBuilder = prepareBuilder(options).uri(uri); - - // Build and Send request - HttpRequest request = reqBuilder.GET().build(); - return httpClient - .sendAsync(request, HttpResponse.BodyHandlers.ofByteArray()) - .handle(handleGetOutputResponse(request, resSerde)); - } - }; - } - - @Override - public WorkflowHandle workflowHandle( - String workflowName, String workflowId, Serde resSerde) { - return new WorkflowHandle<>() { - @Override - public CompletableFuture attachAsync(RequestOptions options) { - // Prepare request - var reqBuilder = - prepareBuilder(options) - .uri( - baseUri.resolve( - "/restate/workflow/" - + workflowName - + "/" - + URLEncoder.encode(workflowId, StandardCharsets.UTF_8) - + "/attach")); - - // Build and Send request - HttpRequest request = reqBuilder.GET().build(); - return httpClient - .sendAsync(request, HttpResponse.BodyHandlers.ofByteArray()) - .handle(handleAttachResponse(request, resSerde)); - } - - @Override - public CompletableFuture> getOutputAsync(RequestOptions options) { - // Prepare request - var reqBuilder = - prepareBuilder(options) - .uri( - baseUri.resolve( - "/restate/workflow/" - + workflowName - + "/" - + URLEncoder.encode(workflowId, StandardCharsets.UTF_8) - + "/output")); - - // Build and Send request - HttpRequest request = reqBuilder.GET().build(); - return httpClient - .sendAsync(request, HttpResponse.BodyHandlers.ofByteArray()) - .handle(handleGetOutputResponse(request, resSerde)); - } - }; - } - - private @NotNull - BiFunction, Throwable, Output> handleGetOutputResponse( - HttpRequest request, Serde resSerde) { - return (response, throwable) -> { - if (throwable != null) { - throw new IngressException("Error when executing the request", request, throwable); - } - - if (response.statusCode() == 470) { - return Output.notReady(); - } - - if (response.statusCode() >= 300) { - handleNonSuccessResponse(response); - } - - try { - return Output.ready(resSerde.deserialize(response.body())); - } catch (Exception e) { - throw new IngressException("Cannot deserialize the response", response, e); - } - }; - } - - private @NotNull BiFunction, Throwable, Res> handleAttachResponse( - HttpRequest request, Serde resSerde) { - return (response, throwable) -> { - if (throwable != null) { - throw new IngressException("Error when executing the request", request, throwable); - } - - if (response.statusCode() >= 300) { - handleNonSuccessResponse(response); - } - - try { - return resSerde.deserialize(response.body()); - } catch (Exception e) { - throw new IngressException("Cannot deserialize the response", response, e); - } - }; - } - - /** Contains prefix / but not postfix / */ - private String targetToURI(Target target) { - StringBuilder builder = new StringBuilder(); - builder.append("/").append(target.getService()); - if (target.getKey() != null) { - builder.append("/").append(URLEncoder.encode(target.getKey(), StandardCharsets.UTF_8)); - } - builder.append("/").append(target.getHandler()); - return builder.toString(); - } - - private URI toRequestURI(Target target, boolean isSend, Duration delay) { - StringBuilder builder = new StringBuilder(targetToURI(target)); - if (isSend) { - builder.append("/send"); - } - if (delay != null && !delay.isZero() && !delay.isNegative()) { - builder.append("?delay=").append(delay); - } - - return this.baseUri.resolve(builder.toString()); - } - - private HttpRequest.Builder prepareBuilder(RequestOptions options) { - var reqBuilder = HttpRequest.newBuilder(); - - // Add headers - this.headers.forEach(reqBuilder::header); - - // Add idempotency key and period - if (options instanceof CallRequestOptions) { - if (((CallRequestOptions) options).getIdempotencyKey() != null) { - reqBuilder.header("idempotency-key", ((CallRequestOptions) options).getIdempotencyKey()); - } - } - - // Add additional headers - options.getAdditionalHeaders().forEach(reqBuilder::header); - - return reqBuilder; - } - - private HttpRequest prepareHttpRequest( - Target target, - boolean isSend, - Serde reqSerde, - Req req, - Duration delay, - RequestOptions options) { - var reqBuilder = prepareBuilder(options).uri(toRequestURI(target, isSend, delay)); - - // Add content-type - if (reqSerde.contentType() != null) { - reqBuilder.header("content-type", reqSerde.contentType()); - } - - return reqBuilder.POST(HttpRequest.BodyPublishers.ofByteArray(reqSerde.serialize(req))).build(); - } - - private void handleNonSuccessResponse(HttpResponse response) { - if (response.headers().firstValue("content-type").orElse("").contains("application/json")) { - String errorMessage; - // Let's try to parse the message field - try { - errorMessage = - findStringFieldInJsonObject(new ByteArrayInputStream(response.body()), "message"); - } catch (Exception e) { - throw new IngressException("Can't decode error response from ingress", response, e); - } - throw new IngressException(errorMessage, response); - } - - // Fallback error - throw new IngressException("Received non success status code", response); - } - - private static String findStringFieldInJsonObject(InputStream body, String fieldName) - throws IOException { - try (JsonParser parser = JSON_FACTORY.createParser(body)) { - if (parser.nextToken() != JsonToken.START_OBJECT) { - throw new IllegalStateException( - "Expecting token " + JsonToken.START_OBJECT + ", got " + parser.getCurrentToken()); - } - for (String actualFieldName = parser.nextFieldName(); - actualFieldName != null; - actualFieldName = parser.nextFieldName()) { - if (actualFieldName.equalsIgnoreCase(fieldName)) { - return parser.nextTextValue(); - } else { - parser.nextValue(); - } - } - throw new IllegalStateException( - "Expecting field name \"" + fieldName + "\", got " + parser.getCurrentToken()); - } - } - - private static Map findStringFieldsInJsonObject( - InputStream body, String... fields) throws IOException { - Map resultMap = new HashMap<>(); - Set fieldSet = new HashSet<>(Set.of(fields)); - - try (JsonParser parser = JSON_FACTORY.createParser(body)) { - if (parser.nextToken() != JsonToken.START_OBJECT) { - throw new IllegalStateException( - "Expecting token " + JsonToken.START_OBJECT + ", got " + parser.getCurrentToken()); - } - for (String actualFieldName = parser.nextFieldName(); - actualFieldName != null; - actualFieldName = parser.nextFieldName()) { - if (fieldSet.remove(actualFieldName)) { - resultMap.put(actualFieldName, parser.nextTextValue()); - } else { - parser.nextValue(); - } - } - } - - if (!fieldSet.isEmpty()) { - throw new IllegalStateException( - "Expecting fields \"" + Arrays.toString(fields) + "\", cannot find fields " + fieldSet); - } - - return resultMap; - } - - public static DefaultClient of( - HttpClient httpClient, String baseUri, Map headers) { - return new DefaultClient(httpClient, baseUri, headers); - } -} diff --git a/sdk-common/src/main/java/dev/restate/sdk/client/RequestOptions.java b/sdk-common/src/main/java/dev/restate/sdk/client/RequestOptions.java deleted file mode 100644 index c979fd953..000000000 --- a/sdk-common/src/main/java/dev/restate/sdk/client/RequestOptions.java +++ /dev/null @@ -1,64 +0,0 @@ -// Copyright (c) 2023 - Restate Software, Inc., Restate GmbH -// -// This file is part of the Restate Java SDK, -// which is released under the MIT license. -// -// You can find a copy of the license in file LICENSE in the root -// directory of this repository or package, or at -// https://github.com/restatedev/sdk-java/blob/main/LICENSE -package dev.restate.sdk.client; - -import java.util.HashMap; -import java.util.Map; -import java.util.Objects; - -public class RequestOptions { - public static final RequestOptions DEFAULT = new RequestOptions(); - - final Map additionalHeaders; - - public RequestOptions() { - this(new HashMap<>()); - } - - public RequestOptions(Map additionalHeaders) { - this.additionalHeaders = additionalHeaders; - } - - public RequestOptions withHeader(String name, String value) { - RequestOptions newOptions = this.copy(); - newOptions.additionalHeaders.put(name, value); - return newOptions; - } - - public RequestOptions withHeaders(Map additionalHeaders) { - RequestOptions newOptions = this.copy(); - newOptions.additionalHeaders.putAll(additionalHeaders); - return newOptions; - } - - public Map getAdditionalHeaders() { - return additionalHeaders; - } - - public RequestOptions copy() { - return new RequestOptions(new HashMap<>(this.additionalHeaders)); - } - - @Override - public boolean equals(Object o) { - if (this == o) return true; - if (!(o instanceof RequestOptions that)) return false; - return Objects.equals(additionalHeaders, that.additionalHeaders); - } - - @Override - public int hashCode() { - return additionalHeaders.hashCode(); - } - - @Override - public String toString() { - return "RequestOptions{" + "additionalHeaders=" + additionalHeaders + '}'; - } -} diff --git a/sdk-common/src/main/java/dev/restate/sdk/client/SendResponse.java b/sdk-common/src/main/java/dev/restate/sdk/client/SendResponse.java deleted file mode 100644 index 220399930..000000000 --- a/sdk-common/src/main/java/dev/restate/sdk/client/SendResponse.java +++ /dev/null @@ -1,57 +0,0 @@ -// Copyright (c) 2023 - Restate Software, Inc., Restate GmbH -// -// This file is part of the Restate Java SDK, -// which is released under the MIT license. -// -// You can find a copy of the license in file LICENSE in the root -// directory of this repository or package, or at -// https://github.com/restatedev/sdk-java/blob/main/LICENSE -package dev.restate.sdk.client; - -import java.util.Objects; - -public class SendResponse { - - public enum SendStatus { - /** The request was sent for the first time. */ - ACCEPTED, - /** The request was already sent beforehand. */ - PREVIOUSLY_ACCEPTED - } - - private final SendStatus status; - private final String invocationId; - - public SendResponse(SendStatus status, String invocationId) { - this.status = status; - this.invocationId = invocationId; - } - - public SendStatus getStatus() { - return status; - } - - public String getInvocationId() { - return invocationId; - } - - @Override - public boolean equals(Object o) { - if (this == o) return true; - if (!(o instanceof SendResponse)) return false; - SendResponse that = (SendResponse) o; - return status == that.status && Objects.equals(invocationId, that.invocationId); - } - - @Override - public int hashCode() { - int result = Objects.hashCode(status); - result = 31 * result + Objects.hashCode(invocationId); - return result; - } - - @Override - public String toString() { - return "SendResponse{" + "status=" + status + ", invocationId='" + invocationId + '\'' + '}'; - } -} diff --git a/sdk-common/src/main/java/dev/restate/sdk/common/Request.java b/sdk-common/src/main/java/dev/restate/sdk/common/Request.java deleted file mode 100644 index ccbca4904..000000000 --- a/sdk-common/src/main/java/dev/restate/sdk/common/Request.java +++ /dev/null @@ -1,89 +0,0 @@ -// Copyright (c) 2023 - Restate Software, Inc., Restate GmbH -// -// This file is part of the Restate Java SDK, -// which is released under the MIT license. -// -// You can find a copy of the license in file LICENSE in the root -// directory of this repository or package, or at -// https://github.com/restatedev/sdk-java/blob/main/LICENSE -package dev.restate.sdk.common; - -import io.opentelemetry.context.Context; -import java.nio.ByteBuffer; -import java.util.Map; -import java.util.Objects; - -/** The Request object represents the incoming request to an handler. */ -public final class Request { - - private final InvocationId invocationId; - private final Context otelContext; - private final ByteBuffer body; - private final Map headers; - - public Request( - InvocationId invocationId, - Context otelContext, - ByteBuffer body, - Map headers) { - this.invocationId = invocationId; - this.otelContext = otelContext; - this.body = body; - this.headers = headers; - } - - /** - * @return this invocation id. - */ - public InvocationId invocationId() { - return invocationId; - } - - /** - * @return the attached OpenTelemetry {@link Context}. - */ - public Context otelContext() { - return otelContext; - } - - public byte[] body() { - return Serde.BYTE_BUFFER.serialize(body); - } - - public ByteBuffer bodyBuffer() { - return body.asReadOnlyBuffer(); - } - - /** - * @return the request headers, as received at the ingress. - */ - public Map headers() { - return headers; - } - - @Override - public boolean equals(Object o) { - if (this == o) return true; - if (o == null || getClass() != o.getClass()) return false; - - Request request = (Request) o; - return Objects.equals(invocationId, request.invocationId) - && Objects.equals(otelContext, request.otelContext) - && Objects.equals(body, request.body) - && Objects.equals(headers, request.headers); - } - - @Override - public int hashCode() { - int result = Objects.hashCode(invocationId); - result = 31 * result + Objects.hashCode(otelContext); - result = 31 * result + Objects.hashCode(body); - result = 31 * result + Objects.hashCode(headers); - return result; - } - - @Override - public String toString() { - return "Request{" + "invocationId=" + invocationId + ", headers=" + headers + '}'; - } -} diff --git a/sdk-common/src/main/java/dev/restate/sdk/common/RichSerde.java b/sdk-common/src/main/java/dev/restate/sdk/common/RichSerde.java deleted file mode 100644 index 2e7e0961c..000000000 --- a/sdk-common/src/main/java/dev/restate/sdk/common/RichSerde.java +++ /dev/null @@ -1,62 +0,0 @@ -// Copyright (c) 2023 - Restate Software, Inc., Restate GmbH -// -// This file is part of the Restate Java SDK, -// which is released under the MIT license. -// -// You can find a copy of the license in file LICENSE in the root -// directory of this repository or package, or at -// https://github.com/restatedev/sdk-java/blob/main/LICENSE -package dev.restate.sdk.common; - -import java.nio.ByteBuffer; -import org.jspecify.annotations.Nullable; - -/** - * Richer version of {@link Serde} containing schema information. - * - *

This API should be considered unstable to implement. - * - *

You can create one using {@link #withSchema(Object, Serde)}. - */ -public interface RichSerde extends Serde { - - /** - * @return a Draft 2020-12 Json Schema. It should be self-contained, and MUST not contain refs to - * files. If the schema shouldn't be serialized with Jackson, return a {@link String} - */ - Object jsonSchema(); - - static RichSerde withSchema(Object jsonSchema, Serde inner) { - return new RichSerde<>() { - @Override - public byte[] serialize(T value) { - return inner.serialize(value); - } - - @Override - public ByteBuffer serializeToByteBuffer(T value) { - return inner.serializeToByteBuffer(value); - } - - @Override - public T deserialize(ByteBuffer byteBuffer) { - return inner.deserialize(byteBuffer); - } - - @Override - public T deserialize(byte[] value) { - return inner.deserialize(value); - } - - @Override - public String contentType() { - return inner.contentType(); - } - - @Override - public Object jsonSchema() { - return jsonSchema; - } - }; - } -} diff --git a/sdk-common/src/main/java/dev/restate/sdk/common/syscalls/Deferred.java b/sdk-common/src/main/java/dev/restate/sdk/common/syscalls/Deferred.java deleted file mode 100644 index 9b93148fd..000000000 --- a/sdk-common/src/main/java/dev/restate/sdk/common/syscalls/Deferred.java +++ /dev/null @@ -1,29 +0,0 @@ -// Copyright (c) 2023 - Restate Software, Inc., Restate GmbH -// -// This file is part of the Restate Java SDK, -// which is released under the MIT license. -// -// You can find a copy of the license in file LICENSE in the root -// directory of this repository or package, or at -// https://github.com/restatedev/sdk-java/blob/main/LICENSE -package dev.restate.sdk.common.syscalls; - -import org.jspecify.annotations.Nullable; - -/** - * Interface to define interaction with deferred results. - * - *

Implementations of this class are provided by {@link Syscalls} and should not be - * overriden/wrapped. - * - *

To resolve a {@link Deferred}, use {@link Syscalls#resolveDeferred(Deferred, SyscallCallback)} - */ -public interface Deferred { - - boolean isCompleted(); - - /** - * @return {@code null} if {@link #isCompleted()} is false. - */ - @Nullable Result toResult(); -} diff --git a/sdk-common/src/main/java/dev/restate/sdk/common/syscalls/HandlerDefinition.java b/sdk-common/src/main/java/dev/restate/sdk/common/syscalls/HandlerDefinition.java deleted file mode 100644 index 449aa51d8..000000000 --- a/sdk-common/src/main/java/dev/restate/sdk/common/syscalls/HandlerDefinition.java +++ /dev/null @@ -1,56 +0,0 @@ -// Copyright (c) 2023 - Restate Software, Inc., Restate GmbH -// -// This file is part of the Restate Java SDK, -// which is released under the MIT license. -// -// You can find a copy of the license in file LICENSE in the root -// directory of this repository or package, or at -// https://github.com/restatedev/sdk-java/blob/main/LICENSE -package dev.restate.sdk.common.syscalls; - -import java.util.Objects; - -public final class HandlerDefinition { - - private final HandlerSpecification spec; - private final HandlerRunner runner; - - HandlerDefinition(HandlerSpecification spec, HandlerRunner runner) { - this.spec = spec; - this.runner = runner; - } - - public HandlerSpecification getSpec() { - return spec; - } - - public HandlerRunner getRunner() { - return runner; - } - - @Override - public boolean equals(Object o) { - if (this == o) return true; - if (o == null || getClass() != o.getClass()) return false; - - HandlerDefinition that = (HandlerDefinition) o; - return Objects.equals(spec, that.spec) && Objects.equals(runner, that.runner); - } - - @Override - public int hashCode() { - int result = Objects.hashCode(spec); - result = 31 * result + Objects.hashCode(runner); - return result; - } - - @Override - public String toString() { - return "HandlerDefinition{" + "spec=" + spec + ", handler=" + runner + '}'; - } - - public static HandlerDefinition of( - HandlerSpecification spec, HandlerRunner runner) { - return new HandlerDefinition<>(spec, runner); - } -} diff --git a/sdk-common/src/main/java/dev/restate/sdk/common/syscalls/HandlerSpecification.java b/sdk-common/src/main/java/dev/restate/sdk/common/syscalls/HandlerSpecification.java deleted file mode 100644 index 658b64a36..000000000 --- a/sdk-common/src/main/java/dev/restate/sdk/common/syscalls/HandlerSpecification.java +++ /dev/null @@ -1,134 +0,0 @@ -// Copyright (c) 2023 - Restate Software, Inc., Restate GmbH -// -// This file is part of the Restate Java SDK, -// which is released under the MIT license. -// -// You can find a copy of the license in file LICENSE in the root -// directory of this repository or package, or at -// https://github.com/restatedev/sdk-java/blob/main/LICENSE -package dev.restate.sdk.common.syscalls; - -import dev.restate.sdk.common.HandlerType; -import dev.restate.sdk.common.Serde; -import java.util.Collections; -import java.util.Map; -import java.util.Objects; -import org.jspecify.annotations.Nullable; - -public final class HandlerSpecification { - - private final String name; - private final HandlerType handlerType; - private final @Nullable String acceptContentType; - private final Serde requestSerde; - private final Serde responseSerde; - private final @Nullable String documentation; - private final Map metadata; - - HandlerSpecification( - String name, - HandlerType handlerType, - @Nullable String acceptContentType, - Serde requestSerde, - Serde responseSerde, - @Nullable String documentation, - Map metadata) { - this.name = name; - this.handlerType = handlerType; - this.acceptContentType = acceptContentType; - this.requestSerde = requestSerde; - this.responseSerde = responseSerde; - this.documentation = documentation; - this.metadata = metadata; - } - - public static HandlerSpecification of( - String method, HandlerType handlerType, Serde requestSerde, Serde responseSerde) { - return new HandlerSpecification<>( - method, handlerType, null, requestSerde, responseSerde, null, Collections.emptyMap()); - } - - public String getName() { - return name; - } - - public HandlerType getHandlerType() { - return handlerType; - } - - public @Nullable String getAcceptContentType() { - return acceptContentType; - } - - public Serde getRequestSerde() { - return requestSerde; - } - - public Serde getResponseSerde() { - return responseSerde; - } - - public @Nullable String getDocumentation() { - return documentation; - } - - public Map getMetadata() { - return metadata; - } - - public HandlerSpecification withAcceptContentType(String acceptContentType) { - return new HandlerSpecification<>( - name, handlerType, acceptContentType, requestSerde, responseSerde, documentation, metadata); - } - - public HandlerSpecification withDocumentation(@Nullable String documentation) { - return new HandlerSpecification<>( - name, handlerType, acceptContentType, requestSerde, responseSerde, documentation, metadata); - } - - public HandlerSpecification withMetadata(Map metadata) { - return new HandlerSpecification<>( - name, handlerType, acceptContentType, requestSerde, responseSerde, documentation, metadata); - } - - @Override - public boolean equals(Object o) { - if (this == o) return true; - if (!(o instanceof HandlerSpecification that)) return false; - return Objects.equals(name, that.name) - && handlerType == that.handlerType - && Objects.equals(acceptContentType, that.acceptContentType) - && Objects.equals(requestSerde, that.requestSerde) - && Objects.equals(responseSerde, that.responseSerde) - && Objects.equals(documentation, that.documentation) - && Objects.equals(metadata, that.metadata); - } - - @Override - public int hashCode() { - return Objects.hash( - name, handlerType, acceptContentType, requestSerde, responseSerde, documentation, metadata); - } - - @Override - public String toString() { - return "HandlerSpecification{" - + "name='" - + name - + '\'' - + ", handlerType=" - + handlerType - + ", acceptContentType='" - + acceptContentType - + '\'' - + ", requestContentType=" - + requestSerde.contentType() - + ", responseContentType=" - + responseSerde.contentType() - + ", documentation=" - + documentation - + ", metadata=" - + metadata - + '}'; - } -} diff --git a/sdk-common/src/main/java/dev/restate/sdk/common/syscalls/Result.java b/sdk-common/src/main/java/dev/restate/sdk/common/syscalls/Result.java deleted file mode 100644 index a68a519d6..000000000 --- a/sdk-common/src/main/java/dev/restate/sdk/common/syscalls/Result.java +++ /dev/null @@ -1,175 +0,0 @@ -// Copyright (c) 2023 - Restate Software, Inc., Restate GmbH -// -// This file is part of the Restate Java SDK, -// which is released under the MIT license. -// -// You can find a copy of the license in file LICENSE in the root -// directory of this repository or package, or at -// https://github.com/restatedev/sdk-java/blob/main/LICENSE -package dev.restate.sdk.common.syscalls; - -import dev.restate.sdk.common.TerminalException; -import java.util.function.Function; -import org.jspecify.annotations.Nullable; - -/** - * Result can be 3 valued: - * - *

    - *
  • Empty - *
  • Value - *
  • Failure - *
- * - * Empty and Value are used to distinguish the logical empty with the null result. - * - *

Failure in a ready result is always a user failure, and never a syscall failure, as opposed to - * {@link SyscallCallback#onCancel(Throwable)}. - * - * @param result type - */ -public abstract class Result { - - private Result() {} - - /** - * @return true if there is no failure. - */ - public abstract boolean isSuccess(); - - public abstract boolean isEmpty(); - - /** - * @return The success value, or null in case is empty. - */ - @Nullable - public abstract T getValue(); - - @Nullable - public abstract TerminalException getFailure(); - - // --- Helper methods - - /** - * Map this result success value. If the mapper throws an exception, this exception will be - * converted to {@link TerminalException} and return a new failed {@link Result}. - */ - public Result mapSuccess(Function mapper) { - if (this.isSuccess()) { - try { - return Result.success(mapper.apply(this.getValue())); - } catch (TerminalException e) { - return Result.failure(e); - } catch (Exception e) { - return Result.failure( - new TerminalException(TerminalException.INTERNAL_SERVER_ERROR_CODE, e.getMessage())); - } - } - //noinspection unchecked - return (Result) this; - } - - // --- Factory methods - - @SuppressWarnings("unchecked") - public static Result empty() { - return (Result) Empty.INSTANCE; - } - - public static Result success(T value) { - return new Success<>(value); - } - - public static Result failure(TerminalException t) { - return new Failure<>(t); - } - - static class Empty extends Result { - - public static final Empty INSTANCE = new Empty<>(); - - private Empty() {} - - @Override - public boolean isSuccess() { - return true; - } - - @Override - public boolean isEmpty() { - return true; - } - - @Nullable - @Override - public T getValue() { - return null; - } - - @Nullable - @Override - public TerminalException getFailure() { - return null; - } - } - - static class Success extends Result { - private final T value; - - private Success(T value) { - this.value = value; - } - - @Override - public boolean isSuccess() { - return true; - } - - @Override - public boolean isEmpty() { - return false; - } - - @Nullable - @Override - public T getValue() { - return value; - } - - @Nullable - @Override - public TerminalException getFailure() { - return null; - } - } - - static class Failure extends Result { - private final TerminalException cause; - - private Failure(TerminalException cause) { - this.cause = cause; - } - - @Override - public boolean isSuccess() { - return false; - } - - @Override - public boolean isEmpty() { - return false; - } - - @Nullable - @Override - public T getValue() { - return null; - } - - @Nullable - @Override - public TerminalException getFailure() { - return cause; - } - } -} diff --git a/sdk-common/src/main/java/dev/restate/sdk/common/syscalls/SyscallCallback.java b/sdk-common/src/main/java/dev/restate/sdk/common/syscalls/SyscallCallback.java deleted file mode 100644 index cc882b5f3..000000000 --- a/sdk-common/src/main/java/dev/restate/sdk/common/syscalls/SyscallCallback.java +++ /dev/null @@ -1,74 +0,0 @@ -// Copyright (c) 2023 - Restate Software, Inc., Restate GmbH -// -// This file is part of the Restate Java SDK, -// which is released under the MIT license. -// -// You can find a copy of the license in file LICENSE in the root -// directory of this repository or package, or at -// https://github.com/restatedev/sdk-java/blob/main/LICENSE -package dev.restate.sdk.common.syscalls; - -import java.util.concurrent.CompletableFuture; -import java.util.function.Consumer; -import java.util.function.Function; -import org.jspecify.annotations.Nullable; - -public interface SyscallCallback { - - void onSuccess(@Nullable T value); - - /** - * The internal state machine invokes this method when a syscall is interrupted due to a - * suspension, or a network error. - * - *

In case the user code is blocked on a lock, the implementation of this method should unblock - * it. - */ - void onCancel(Throwable t); - - static SyscallCallback of(Consumer onSuccess, Consumer onFailure) { - return new SyscallCallback<>() { - @Override - public void onSuccess(@Nullable T value) { - onSuccess.accept(value); - } - - @Override - public void onCancel(@Nullable Throwable t) { - onFailure.accept(t); - } - }; - } - - static SyscallCallback ofVoid(Runnable onSuccess, Consumer onFailure) { - return new SyscallCallback<>() { - @Override - public void onSuccess(@Nullable Void value) { - onSuccess.run(); - } - - @Override - public void onCancel(@Nullable Throwable t) { - onFailure.accept(t); - } - }; - } - - static SyscallCallback mappingTo(Function mapper, SyscallCallback callback) { - return new SyscallCallback<>() { - @Override - public void onSuccess(@Nullable T value) { - callback.onSuccess(mapper.apply(value)); - } - - @Override - public void onCancel(@Nullable Throwable t) { - callback.onCancel(t); - } - }; - } - - static SyscallCallback completingFuture(CompletableFuture fut) { - return of(fut::complete, t -> fut.cancel(true)); - } -} diff --git a/sdk-common/src/main/java/dev/restate/sdk/common/syscalls/Syscalls.java b/sdk-common/src/main/java/dev/restate/sdk/common/syscalls/Syscalls.java deleted file mode 100644 index 4373cb313..000000000 --- a/sdk-common/src/main/java/dev/restate/sdk/common/syscalls/Syscalls.java +++ /dev/null @@ -1,110 +0,0 @@ -// Copyright (c) 2023 - Restate Software, Inc., Restate GmbH -// -// This file is part of the Restate Java SDK, -// which is released under the MIT license. -// -// You can find a copy of the license in file LICENSE in the root -// directory of this repository or package, or at -// https://github.com/restatedev/sdk-java/blob/main/LICENSE -package dev.restate.sdk.common.syscalls; - -import dev.restate.sdk.common.Request; -import dev.restate.sdk.common.RetryPolicy; -import dev.restate.sdk.common.Target; -import dev.restate.sdk.common.TerminalException; -import java.nio.ByteBuffer; -import java.time.Duration; -import java.util.*; -import java.util.List; -import java.util.Map; -import org.jspecify.annotations.Nullable; - -/** - * Internal interface to access Restate functionalities. Users can use the ad-hoc RestateContext - * interfaces provided by the various implementations. - * - *

When using executor switching wrappers, the method's {@code callback} will be executed in the - * state machine executor. - */ -public interface Syscalls { - - String objectKey(); - - Request request(); - - /** - * @return true if it's inside a side effect block. - */ - boolean isInsideSideEffect(); - - // ----- IO - // Note: These are not supposed to be exposed to RestateContext, but they should be used through - // gRPC APIs. - - void writeOutput(ByteBuffer value, SyscallCallback callback); - - void writeOutput(TerminalException exception, SyscallCallback callback); - - // ----- State - - void get(String name, SyscallCallback> callback); - - void getKeys(SyscallCallback>> callback); - - void clear(String name, SyscallCallback callback); - - void clearAll(SyscallCallback callback); - - void set(String name, ByteBuffer value, SyscallCallback callback); - - // ----- Syscalls - - void sleep(Duration duration, SyscallCallback> callback); - - void call(Target target, ByteBuffer parameter, SyscallCallback> callback); - - void send( - Target target, - ByteBuffer parameter, - @Nullable Duration delay, - SyscallCallback requestCallback); - - void enterSideEffectBlock(@Nullable String name, EnterSideEffectSyscallCallback callback); - - void exitSideEffectBlock(ByteBuffer toWrite, ExitSideEffectSyscallCallback callback); - - /** - * @deprecated use {@link #exitSideEffectBlockWithException(Throwable, RetryPolicy, - * ExitSideEffectSyscallCallback)} instead. - */ - @Deprecated(since = "1.1.0", forRemoval = true) - void exitSideEffectBlockWithTerminalException( - TerminalException toWrite, ExitSideEffectSyscallCallback callback); - - void exitSideEffectBlockWithException( - Throwable toWrite, @Nullable RetryPolicy retryPolicy, ExitSideEffectSyscallCallback callback); - - void awakeable(SyscallCallback>> callback); - - void resolveAwakeable(String id, ByteBuffer payload, SyscallCallback requestCallback); - - void rejectAwakeable(String id, String reason, SyscallCallback requestCallback); - - void promise(String key, SyscallCallback> callback); - - void peekPromise(String key, SyscallCallback> callback); - - void resolvePromise(String key, ByteBuffer payload, SyscallCallback> callback); - - void rejectPromise(String key, String reason, SyscallCallback> callback); - - void fail(Throwable cause); - - // ----- Deferred - - void resolveDeferred(Deferred deferredToResolve, SyscallCallback callback); - - Deferred createAnyDeferred(List> children); - - Deferred createAllDeferred(List> children); -} diff --git a/sdk-common/src/main/java/dev/restate/sdk/endpoint/Endpoint.java b/sdk-common/src/main/java/dev/restate/sdk/endpoint/Endpoint.java new file mode 100644 index 000000000..49c1d6353 --- /dev/null +++ b/sdk-common/src/main/java/dev/restate/sdk/endpoint/Endpoint.java @@ -0,0 +1,166 @@ +// Copyright (c) 2023 - Restate Software, Inc., Restate GmbH +// +// This file is part of the Restate Java SDK, +// which is released under the MIT license. +// +// You can find a copy of the license in file LICENSE in the root +// directory of this repository or package, or at +// https://github.com/restatedev/sdk-java/blob/main/LICENSE +package dev.restate.sdk.endpoint; + +import dev.restate.sdk.endpoint.definition.HandlerRunner; +import dev.restate.sdk.endpoint.definition.ServiceDefinition; +import dev.restate.sdk.endpoint.definition.ServiceDefinitionFactories; +import io.opentelemetry.api.OpenTelemetry; +import java.util.*; +import java.util.function.Function; +import java.util.stream.Collectors; +import java.util.stream.Stream; + +/** Restate endpoint, encapsulating the configured services, together with additional options. */ +public final class Endpoint { + + private final Map services; + private final OpenTelemetry openTelemetry; + private final RequestIdentityVerifier requestIdentityVerifier; + private final boolean experimentalContextEnabled; + + private Endpoint( + Map services, + OpenTelemetry openTelemetry, + RequestIdentityVerifier requestIdentityVerifier, + boolean experimentalContextEnabled) { + this.services = services; + this.openTelemetry = openTelemetry; + this.requestIdentityVerifier = requestIdentityVerifier; + this.experimentalContextEnabled = experimentalContextEnabled; + } + + public static class Builder { + private final List services = new ArrayList<>(); + private RequestIdentityVerifier requestIdentityVerifier = RequestIdentityVerifier.noop(); + private OpenTelemetry openTelemetry = OpenTelemetry.noop(); + private boolean experimentalContextEnabled = false; + + /** + * Add a Restate service to the endpoint. This will automatically discover the generated factory + * based on the class name. + * + *

You can also manually instantiate the {@link ServiceDefinition} using {@link + * #bind(ServiceDefinition)}. + */ + public Builder bind(Object service) { + return this.bind(ServiceDefinitionFactories.discover(service).create(service, null)); + } + + /** + * Like {@link #bind(Object)}, but allows to provide options for the handler runner. This allows + * to configure for the Java API the executor where to run the handler code, or the Kotlin API + * the coroutine context. + * + *

Look at the respective documentations of the HandlerRunner class in the Java or in the + * Kotlin module. + * + * @see #bind(Object) + */ + public Builder bind(Object service, HandlerRunner.Options options) { + return this.bind(ServiceDefinitionFactories.discover(service).create(service, options)); + } + + /** Add a manual {@link ServiceDefinition} to the endpoint. */ + public Builder bind(ServiceDefinition serviceDefinition) { + this.services.add(serviceDefinition); + return this; + } + + /** + * Set the {@link OpenTelemetry} implementation for tracing and metrics. + * + * @see OpenTelemetry + */ + public Builder withOpenTelemetry(OpenTelemetry openTelemetry) { + this.openTelemetry = openTelemetry; + return this; + } + + /** Same as {@link #withOpenTelemetry(OpenTelemetry)}. */ + public void setOpenTelemetry(OpenTelemetry openTelemetry) { + withOpenTelemetry(openTelemetry); + } + + /** + * @return the configured {@link OpenTelemetry} + */ + public OpenTelemetry getOpenTelemetry() { + return this.openTelemetry; + } + + /** + * Set the request identity verifier for this endpoint. + * + *

For the Restate implementation to use with Restate Cloud, check the module {@code + * sdk-request-identity}. + */ + public Builder withRequestIdentityVerifier(RequestIdentityVerifier requestIdentityVerifier) { + this.requestIdentityVerifier = requestIdentityVerifier; + return this; + } + + /** Same as {@link #withRequestIdentityVerifier(RequestIdentityVerifier)}. */ + public void setRequestIdentityVerifier(RequestIdentityVerifier requestIdentityVerifier) { + this.withRequestIdentityVerifier(requestIdentityVerifier); + } + + /** + * @return the configured request identity verifier + */ + public RequestIdentityVerifier getRequestIdentityVerifier() { + return this.requestIdentityVerifier; + } + + public Builder enablePreviewContext() { + this.experimentalContextEnabled = true; + return this; + } + + public Endpoint build() { + return new Endpoint( + this.services.stream() + .collect(Collectors.toMap(ServiceDefinition::getServiceName, Function.identity())), + this.openTelemetry, + this.requestIdentityVerifier, + this.experimentalContextEnabled); + } + } + + public static Builder builder() { + return new Builder(); + } + + /** + * @see Builder#bind(Object) + */ + public static Builder bind(Object object) { + return new Builder().bind(object); + } + + public ServiceDefinition resolveService(String serviceName) { + return services.get(serviceName); + } + + public Stream getServiceDefinitions() { + return this.services.values().stream(); + } + + public OpenTelemetry getOpenTelemetry() { + return openTelemetry; + } + + public RequestIdentityVerifier getRequestIdentityVerifier() { + return requestIdentityVerifier; + } + + public boolean isExperimentalContextEnabled() { + return experimentalContextEnabled; + } +} diff --git a/sdk-common/src/main/java/dev/restate/sdk/endpoint/HeadersAccessor.java b/sdk-common/src/main/java/dev/restate/sdk/endpoint/HeadersAccessor.java new file mode 100644 index 000000000..0c21a4bda --- /dev/null +++ b/sdk-common/src/main/java/dev/restate/sdk/endpoint/HeadersAccessor.java @@ -0,0 +1,38 @@ +// Copyright (c) 2023 - Restate Software, Inc., Restate GmbH +// +// This file is part of the Restate Java SDK, +// which is released under the MIT license. +// +// You can find a copy of the license in file LICENSE in the root +// directory of this repository or package, or at +// https://github.com/restatedev/sdk-java/blob/main/LICENSE +package dev.restate.sdk.endpoint; + +import java.util.Map; +import org.jspecify.annotations.Nullable; + +/** Abstraction for headers map. */ +public interface HeadersAccessor { + Iterable keys(); + + @Nullable String get(String key); + + static HeadersAccessor wrap(Map input) { + return new HeadersAccessor() { + @Override + public Iterable keys() { + return input.keySet(); + } + + @Override + public String get(String key) { + for (var k : input.keySet()) { + if (k.equalsIgnoreCase(key)) { + return input.get(k); + } + } + return null; + } + }; + } +} diff --git a/sdk-common/src/main/java/dev/restate/sdk/auth/RequestIdentityVerifier.java b/sdk-common/src/main/java/dev/restate/sdk/endpoint/RequestIdentityVerifier.java similarity index 66% rename from sdk-common/src/main/java/dev/restate/sdk/auth/RequestIdentityVerifier.java rename to sdk-common/src/main/java/dev/restate/sdk/endpoint/RequestIdentityVerifier.java index 817c9d4eb..b34e0baa9 100644 --- a/sdk-common/src/main/java/dev/restate/sdk/auth/RequestIdentityVerifier.java +++ b/sdk-common/src/main/java/dev/restate/sdk/endpoint/RequestIdentityVerifier.java @@ -6,21 +6,20 @@ // You can find a copy of the license in file LICENSE in the root // directory of this repository or package, or at // https://github.com/restatedev/sdk-java/blob/main/LICENSE -package dev.restate.sdk.auth; - -import org.jspecify.annotations.Nullable; +package dev.restate.sdk.endpoint; /** Interface to verify requests. */ public interface RequestIdentityVerifier { - /** Abstraction for headers map. */ - @FunctionalInterface - interface Headers { - @Nullable String get(String key); - } - /** * @throws Exception if the request cannot be verified */ - void verifyRequest(Headers headers) throws Exception; + void verifyRequest(HeadersAccessor headers) throws Exception; + + /** + * @return a noop request identity verifier + */ + static RequestIdentityVerifier noop() { + return headers -> {}; + } } diff --git a/sdk-common/src/main/java/dev/restate/sdk/endpoint/definition/AsyncResult.java b/sdk-common/src/main/java/dev/restate/sdk/endpoint/definition/AsyncResult.java new file mode 100644 index 000000000..c6b547612 --- /dev/null +++ b/sdk-common/src/main/java/dev/restate/sdk/endpoint/definition/AsyncResult.java @@ -0,0 +1,39 @@ +// Copyright (c) 2023 - Restate Software, Inc., Restate GmbH +// +// This file is part of the Restate Java SDK, +// which is released under the MIT license. +// +// You can find a copy of the license in file LICENSE in the root +// directory of this repository or package, or at +// https://github.com/restatedev/sdk-java/blob/main/LICENSE +package dev.restate.sdk.endpoint.definition; + +import dev.restate.common.function.ThrowingFunction; +import dev.restate.sdk.types.TerminalException; +import java.util.concurrent.CompletableFuture; + +/** + * Interface to define interaction with deferred results. + * + *

Implementations of this class are provided by {@link HandlerContext} and should not be + * overriden/wrapped. + */ +public interface AsyncResult { + + CompletableFuture poll(); + + HandlerContext ctx(); + + AsyncResult map( + ThrowingFunction> successMapper, + ThrowingFunction> failureMapper); + + default AsyncResult map(ThrowingFunction> successMapper) { + return map(successMapper, null); + } + + default AsyncResult mapFailure( + ThrowingFunction> failureMapper) { + return map(null, failureMapper); + } +} diff --git a/sdk-common/src/main/java/dev/restate/sdk/endpoint/definition/HandlerContext.java b/sdk-common/src/main/java/dev/restate/sdk/endpoint/definition/HandlerContext.java new file mode 100644 index 000000000..61a204b48 --- /dev/null +++ b/sdk-common/src/main/java/dev/restate/sdk/endpoint/definition/HandlerContext.java @@ -0,0 +1,109 @@ +// Copyright (c) 2023 - Restate Software, Inc., Restate GmbH +// +// This file is part of the Restate Java SDK, +// which is released under the MIT license. +// +// You can find a copy of the license in file LICENSE in the root +// directory of this repository or package, or at +// https://github.com/restatedev/sdk-java/blob/main/LICENSE +package dev.restate.sdk.endpoint.definition; + +import dev.restate.common.Output; +import dev.restate.common.Slice; +import dev.restate.common.Target; +import dev.restate.sdk.types.*; +import java.time.Duration; +import java.util.*; +import java.util.List; +import java.util.concurrent.CompletableFuture; +import java.util.function.Consumer; +import org.jspecify.annotations.Nullable; + +/** + * Internal interface to access Restate functionalities. Users can use the ad-hoc RestateContext + * interfaces provided by the various implementations. + */ +public interface HandlerContext { + + String objectKey(); + + HandlerRequest request(); + + // ----- IO + // Note: These are not supposed to be exposed in the user's facing Context API. + + CompletableFuture writeOutput(Slice value); + + CompletableFuture writeOutput(TerminalException exception); + + // ----- State + + CompletableFuture>> get(String name); + + CompletableFuture>> getKeys(); + + CompletableFuture clear(String name); + + CompletableFuture clearAll(); + + CompletableFuture set(String name, Slice value); + + // ----- Syscalls + + CompletableFuture> timer(Duration duration, @Nullable String name); + + record CallResult( + AsyncResult invocationIdAsyncResult, AsyncResult callAsyncResult) {} + + CompletableFuture call( + Target target, + Slice parameter, + @Nullable String idempotencyKey, + @Nullable Collection> headers); + + CompletableFuture> send( + Target target, + Slice parameter, + @Nullable String idempotencyKey, + @Nullable Collection> headers, + @Nullable Duration delay); + + interface RunCompleter { + void proposeSuccess(Slice toWrite); + + void proposeFailure(Throwable toWrite, @Nullable RetryPolicy retryPolicy); + } + + CompletableFuture> submitRun( + @Nullable String name, Consumer closure); + + record Awakeable(String id, AsyncResult asyncResult) {} + + CompletableFuture awakeable(); + + CompletableFuture resolveAwakeable(String id, Slice payload); + + CompletableFuture rejectAwakeable(String id, TerminalException reason); + + CompletableFuture> promise(String key); + + CompletableFuture>> peekPromise(String key); + + CompletableFuture> resolvePromise(String key, Slice payload); + + CompletableFuture> rejectPromise(String key, TerminalException reason); + + CompletableFuture cancelInvocation(String invocationId); + + CompletableFuture> attachInvocation(String invocationId); + + CompletableFuture>> getInvocationOutput(String invocationId); + + void fail(Throwable cause); + + // ----- Deferred + + AsyncResult createAnyAsyncResult(List> children); + + AsyncResult createAllAsyncResult(List> children); +} diff --git a/sdk-common/src/main/java/dev/restate/sdk/endpoint/definition/HandlerDefinition.java b/sdk-common/src/main/java/dev/restate/sdk/endpoint/definition/HandlerDefinition.java new file mode 100644 index 000000000..82682f451 --- /dev/null +++ b/sdk-common/src/main/java/dev/restate/sdk/endpoint/definition/HandlerDefinition.java @@ -0,0 +1,130 @@ +// Copyright (c) 2023 - Restate Software, Inc., Restate GmbH +// +// This file is part of the Restate Java SDK, +// which is released under the MIT license. +// +// You can find a copy of the license in file LICENSE in the root +// directory of this repository or package, or at +// https://github.com/restatedev/sdk-java/blob/main/LICENSE +package dev.restate.sdk.endpoint.definition; + +import dev.restate.serde.Serde; +import java.util.Collections; +import java.util.Map; +import org.jspecify.annotations.Nullable; + +public final class HandlerDefinition { + + private final String name; + private final HandlerType handlerType; + private final @Nullable String acceptContentType; + private final Serde requestSerde; + private final Serde responseSerde; + private final @Nullable String documentation; + private final Map metadata; + private final HandlerRunner runner; + + HandlerDefinition( + String name, + HandlerType handlerType, + @Nullable String acceptContentType, + Serde requestSerde, + Serde responseSerde, + @Nullable String documentation, + Map metadata, + HandlerRunner runner) { + this.name = name; + this.handlerType = handlerType; + this.acceptContentType = acceptContentType; + this.requestSerde = requestSerde; + this.responseSerde = responseSerde; + this.documentation = documentation; + this.metadata = metadata; + this.runner = runner; + } + + public String getName() { + return name; + } + + public HandlerType getHandlerType() { + return handlerType; + } + + public @Nullable String getAcceptContentType() { + return acceptContentType; + } + + public Serde getRequestSerde() { + return requestSerde; + } + + public Serde getResponseSerde() { + return responseSerde; + } + + public @Nullable String getDocumentation() { + return documentation; + } + + public Map getMetadata() { + return metadata; + } + + public HandlerRunner getRunner() { + return runner; + } + + public HandlerDefinition withAcceptContentType(String acceptContentType) { + return new HandlerDefinition<>( + name, + handlerType, + acceptContentType, + requestSerde, + responseSerde, + documentation, + metadata, + runner); + } + + public HandlerDefinition withDocumentation(@Nullable String documentation) { + return new HandlerDefinition<>( + name, + handlerType, + acceptContentType, + requestSerde, + responseSerde, + documentation, + metadata, + runner); + } + + public HandlerDefinition withMetadata(Map metadata) { + return new HandlerDefinition<>( + name, + handlerType, + acceptContentType, + requestSerde, + responseSerde, + documentation, + metadata, + runner); + } + + public static HandlerDefinition of( + String handler, + HandlerType handlerType, + Serde requestSerde, + Serde responseSerde, + HandlerRunner runner) { + return new HandlerDefinition<>( + handler, + handlerType, + null, + requestSerde, + responseSerde, + null, + Collections.emptyMap(), + runner); + } +} diff --git a/sdk-common/src/main/java/dev/restate/sdk/common/syscalls/HandlerRunner.java b/sdk-common/src/main/java/dev/restate/sdk/endpoint/definition/HandlerRunner.java similarity index 54% rename from sdk-common/src/main/java/dev/restate/sdk/common/syscalls/HandlerRunner.java rename to sdk-common/src/main/java/dev/restate/sdk/endpoint/definition/HandlerRunner.java index 95915aa65..8084eb394 100644 --- a/sdk-common/src/main/java/dev/restate/sdk/common/syscalls/HandlerRunner.java +++ b/sdk-common/src/main/java/dev/restate/sdk/endpoint/definition/HandlerRunner.java @@ -6,25 +6,26 @@ // You can find a copy of the license in file LICENSE in the root // directory of this repository or package, or at // https://github.com/restatedev/sdk-java/blob/main/LICENSE -package dev.restate.sdk.common.syscalls; +package dev.restate.sdk.endpoint.definition; -import java.nio.ByteBuffer; -import org.jspecify.annotations.Nullable; +import dev.restate.common.Slice; +import dev.restate.serde.Serde; +import java.util.concurrent.CompletableFuture; -public interface HandlerRunner { +public interface HandlerRunner { /** - * Thread local to store {@link Syscalls}. + * Thread local to store {@link HandlerContext}. * *

Implementations of {@link HandlerRunner} should correctly propagate this thread local in * order for logging to work correctly. Could be improved if ScopedContext will ever be introduced in * log4j2. */ - ThreadLocal SYSCALLS_THREAD_LOCAL = new ThreadLocal<>(); + ThreadLocal HANDLER_CONTEXT_THREAD_LOCAL = new ThreadLocal<>(); - void run( - HandlerSpecification handlerSpecification, - Syscalls syscalls, - @Nullable O options, - SyscallCallback callback); + /** Marker interface of runner options. */ + interface Options {} + + CompletableFuture run( + HandlerContext handlerContext, Serde requestSerde, Serde responseSerde); } diff --git a/sdk-common/src/main/java/dev/restate/sdk/common/HandlerType.java b/sdk-common/src/main/java/dev/restate/sdk/endpoint/definition/HandlerType.java similarity index 89% rename from sdk-common/src/main/java/dev/restate/sdk/common/HandlerType.java rename to sdk-common/src/main/java/dev/restate/sdk/endpoint/definition/HandlerType.java index 37ed359d7..e1cee4298 100644 --- a/sdk-common/src/main/java/dev/restate/sdk/common/HandlerType.java +++ b/sdk-common/src/main/java/dev/restate/sdk/endpoint/definition/HandlerType.java @@ -6,7 +6,7 @@ // You can find a copy of the license in file LICENSE in the root // directory of this repository or package, or at // https://github.com/restatedev/sdk-java/blob/main/LICENSE -package dev.restate.sdk.common; +package dev.restate.sdk.endpoint.definition; public enum HandlerType { SHARED, diff --git a/sdk-common/src/main/java/dev/restate/sdk/common/syscalls/ServiceDefinition.java b/sdk-common/src/main/java/dev/restate/sdk/endpoint/definition/ServiceDefinition.java similarity index 67% rename from sdk-common/src/main/java/dev/restate/sdk/common/syscalls/ServiceDefinition.java rename to sdk-common/src/main/java/dev/restate/sdk/endpoint/definition/ServiceDefinition.java index e22da2110..7f6d45221 100644 --- a/sdk-common/src/main/java/dev/restate/sdk/common/syscalls/ServiceDefinition.java +++ b/sdk-common/src/main/java/dev/restate/sdk/endpoint/definition/ServiceDefinition.java @@ -6,26 +6,25 @@ // You can find a copy of the license in file LICENSE in the root // directory of this repository or package, or at // https://github.com/restatedev/sdk-java/blob/main/LICENSE -package dev.restate.sdk.common.syscalls; +package dev.restate.sdk.endpoint.definition; -import dev.restate.sdk.common.ServiceType; import java.util.*; import java.util.function.Function; import java.util.stream.Collectors; import org.jspecify.annotations.Nullable; -public final class ServiceDefinition { +public final class ServiceDefinition { private final String serviceName; private final ServiceType serviceType; - private final Map> handlers; + private final Map> handlers; private final @Nullable String documentation; private final Map metadata; private ServiceDefinition( String serviceName, ServiceType serviceType, - Map> handlers, + Map> handlers, @Nullable String documentation, Map metadata) { this.serviceName = serviceName; @@ -43,11 +42,11 @@ public ServiceType getServiceType() { return serviceType; } - public Collection> getHandlers() { + public Collection> getHandlers() { return handlers.values(); } - public HandlerDefinition getHandler(String name) { + public HandlerDefinition getHandler(String name) { return handlers.get(name); } @@ -59,19 +58,19 @@ public Map getMetadata() { return metadata; } - public ServiceDefinition withDocumentation(@Nullable String documentation) { - return new ServiceDefinition<>(serviceName, serviceType, handlers, documentation, metadata); + public ServiceDefinition withDocumentation(@Nullable String documentation) { + return new ServiceDefinition(serviceName, serviceType, handlers, documentation, metadata); } - public ServiceDefinition withMetadata(Map metadata) { - return new ServiceDefinition<>(serviceName, serviceType, handlers, documentation, metadata); + public ServiceDefinition withMetadata(Map metadata) { + return new ServiceDefinition(serviceName, serviceType, handlers, documentation, metadata); } @Override public boolean equals(Object object) { if (this == object) return true; if (object == null || getClass() != object.getClass()) return false; - ServiceDefinition that = (ServiceDefinition) object; + ServiceDefinition that = (ServiceDefinition) object; return Objects.equals(serviceName, that.serviceName) && serviceType == that.serviceType && Objects.equals(handlers, that.handlers); @@ -82,13 +81,13 @@ public int hashCode() { return Objects.hash(serviceName, serviceType, handlers); } - public static ServiceDefinition of( - String name, ServiceType ty, Collection> handlers) { - return new ServiceDefinition<>( + public static ServiceDefinition of( + String name, ServiceType ty, Collection> handlers) { + return new ServiceDefinition( name, ty, handlers.stream() - .collect(Collectors.toMap(h -> h.getSpec().getName(), Function.identity())), + .collect(Collectors.toMap(HandlerDefinition::getName, Function.identity())), null, Collections.emptyMap()); } diff --git a/sdk-common/src/main/java/dev/restate/sdk/endpoint/definition/ServiceDefinitionFactories.java b/sdk-common/src/main/java/dev/restate/sdk/endpoint/definition/ServiceDefinitionFactories.java new file mode 100644 index 000000000..5c2da56e5 --- /dev/null +++ b/sdk-common/src/main/java/dev/restate/sdk/endpoint/definition/ServiceDefinitionFactories.java @@ -0,0 +1,87 @@ +// Copyright (c) 2023 - Restate Software, Inc., Restate GmbH +// +// This file is part of the Restate Java SDK, +// which is released under the MIT license. +// +// You can find a copy of the license in file LICENSE in the root +// directory of this repository or package, or at +// https://github.com/restatedev/sdk-java/blob/main/LICENSE +package dev.restate.sdk.endpoint.definition; + +import java.util.*; +import org.apache.logging.log4j.LogManager; +import org.apache.logging.log4j.Logger; +import org.jspecify.annotations.Nullable; + +@SuppressWarnings("rawtypes") +public final class ServiceDefinitionFactories { + + private static class ServiceDefinitionFactorySingleton { + private static final ServiceDefinitionFactories INSTANCE = new ServiceDefinitionFactories(); + } + + private static final Logger LOG = LogManager.getLogger(ServiceDefinitionFactories.class); + + private final List factories; + + public ServiceDefinitionFactories() { + this.factories = new ArrayList<>(); + + var serviceLoaderIterator = ServiceLoader.load(ServiceDefinitionFactory.class).iterator(); + while (serviceLoaderIterator.hasNext()) { + try { + this.factories.add(serviceLoaderIterator.next()); + } catch (ServiceConfigurationError | Exception e) { + LOG.debug( + "Found service that cannot be loaded using service provider. " + + "You can ignore this message during development.\n" + + "This might be the result of using a compiler with incremental builds (e.g. IntelliJ IDEA) " + + "that updated a dirty META-INF file after removing/renaming an annotated service.", + e); + } + } + } + + /** Resolve the code generated {@link ServiceDefinitionFactory} */ + @SuppressWarnings("unchecked") + public static ServiceDefinitionFactory discover(Object service) { + Objects.requireNonNull(service, "service is null"); + if (service instanceof ServiceDefinitionFactory) { + // We got this already + return (ServiceDefinitionFactory) service; + } + if (service instanceof ServiceDefinition) { + // We got this already + return new ServiceDefinitionFactory<>() { + @Override + public ServiceDefinition create( + Object serviceObject, + @org.jetbrains.annotations.Nullable HandlerRunner.Options overrideHandlerOptions) { + return (ServiceDefinition) serviceObject; + } + + @Override + public boolean supports(Object serviceObject) { + return serviceObject == service; + } + }; + } + return Objects.requireNonNull( + ServiceDefinitionFactorySingleton.INSTANCE.discoverFactory(service), + () -> + "ServiceDefinitionFactory class not found for service " + + service.getClass().getCanonicalName() + + ". " + + "Make sure the annotation processor is correctly configured to generate the ServiceDefinitionFactory, " + + "and it generates the META-INF/services/" + + ServiceDefinitionFactory.class.getCanonicalName() + + " file containing the generated class. " + + "If you're using fat jars, make sure the jar plugin correctly squashes all the META-INF/services files. " + + "Found ServiceAdapter: " + + ServiceDefinitionFactorySingleton.INSTANCE.factories); + } + + private @Nullable ServiceDefinitionFactory discoverFactory(Object service) { + return this.factories.stream().filter(sa -> sa.supports(service)).findFirst().orElse(null); + } +} diff --git a/sdk-common/src/main/java/dev/restate/sdk/common/syscalls/ServiceDefinitionFactory.java b/sdk-common/src/main/java/dev/restate/sdk/endpoint/definition/ServiceDefinitionFactory.java similarity index 61% rename from sdk-common/src/main/java/dev/restate/sdk/common/syscalls/ServiceDefinitionFactory.java rename to sdk-common/src/main/java/dev/restate/sdk/endpoint/definition/ServiceDefinitionFactory.java index 53fa92a26..fe457eda0 100644 --- a/sdk-common/src/main/java/dev/restate/sdk/common/syscalls/ServiceDefinitionFactory.java +++ b/sdk-common/src/main/java/dev/restate/sdk/endpoint/definition/ServiceDefinitionFactory.java @@ -6,11 +6,13 @@ // You can find a copy of the license in file LICENSE in the root // directory of this repository or package, or at // https://github.com/restatedev/sdk-java/blob/main/LICENSE -package dev.restate.sdk.common.syscalls; +package dev.restate.sdk.endpoint.definition; -public interface ServiceDefinitionFactory { +import org.jspecify.annotations.Nullable; - ServiceDefinition create(T serviceObject); +public interface ServiceDefinitionFactory { + + ServiceDefinition create(T serviceObject, HandlerRunner.@Nullable Options overrideHandlerOptions); boolean supports(Object serviceObject); } diff --git a/sdk-common/src/main/java/dev/restate/sdk/common/ServiceType.java b/sdk-common/src/main/java/dev/restate/sdk/endpoint/definition/ServiceType.java similarity index 89% rename from sdk-common/src/main/java/dev/restate/sdk/common/ServiceType.java rename to sdk-common/src/main/java/dev/restate/sdk/endpoint/definition/ServiceType.java index 10a9b7768..0fd45efa8 100644 --- a/sdk-common/src/main/java/dev/restate/sdk/common/ServiceType.java +++ b/sdk-common/src/main/java/dev/restate/sdk/endpoint/definition/ServiceType.java @@ -6,7 +6,7 @@ // You can find a copy of the license in file LICENSE in the root // directory of this repository or package, or at // https://github.com/restatedev/sdk-java/blob/main/LICENSE -package dev.restate.sdk.common; +package dev.restate.sdk.endpoint.definition; public enum ServiceType { SERVICE, diff --git a/sdk-common/src/main/java/dev/restate/sdk/common/AbortedExecutionException.java b/sdk-common/src/main/java/dev/restate/sdk/types/AbortedExecutionException.java similarity index 96% rename from sdk-common/src/main/java/dev/restate/sdk/common/AbortedExecutionException.java rename to sdk-common/src/main/java/dev/restate/sdk/types/AbortedExecutionException.java index 80167f3df..60e877e85 100644 --- a/sdk-common/src/main/java/dev/restate/sdk/common/AbortedExecutionException.java +++ b/sdk-common/src/main/java/dev/restate/sdk/types/AbortedExecutionException.java @@ -6,7 +6,7 @@ // You can find a copy of the license in file LICENSE in the root // directory of this repository or package, or at // https://github.com/restatedev/sdk-java/blob/main/LICENSE -package dev.restate.sdk.common; +package dev.restate.sdk.types; /** You MUST NOT catch this exception. */ public final class AbortedExecutionException extends Throwable { diff --git a/sdk-common/src/main/java/dev/restate/sdk/common/DurablePromiseKey.java b/sdk-common/src/main/java/dev/restate/sdk/types/DurablePromiseKey.java similarity index 60% rename from sdk-common/src/main/java/dev/restate/sdk/common/DurablePromiseKey.java rename to sdk-common/src/main/java/dev/restate/sdk/types/DurablePromiseKey.java index 24269ef3c..5d865cfcd 100644 --- a/sdk-common/src/main/java/dev/restate/sdk/common/DurablePromiseKey.java +++ b/sdk-common/src/main/java/dev/restate/sdk/types/DurablePromiseKey.java @@ -6,7 +6,10 @@ // You can find a copy of the license in file LICENSE in the root // directory of this repository or package, or at // https://github.com/restatedev/sdk-java/blob/main/LICENSE -package dev.restate.sdk.common; +package dev.restate.sdk.types; + +import dev.restate.serde.Serde; +import dev.restate.serde.TypeTag; /** * This class holds information about durable promise's name and its type tag to be used for @@ -17,16 +20,21 @@ public final class DurablePromiseKey { private final String name; - private final Serde serde; + private final TypeTag typeTag; - private DurablePromiseKey(String name, Serde serde) { + private DurablePromiseKey(String name, TypeTag typeTag) { this.name = name; - this.serde = serde; + this.typeTag = typeTag; + } + + /** Create a new {@link DurablePromiseKey}. */ + public static DurablePromiseKey of(String name, TypeTag typeTag) { + return new DurablePromiseKey<>(name, typeTag); } /** Create a new {@link DurablePromiseKey}. */ - public static DurablePromiseKey of(String name, Serde serde) { - return new DurablePromiseKey<>(name, serde); + public static DurablePromiseKey of(String name, Class clazz) { + return new DurablePromiseKey<>(name, TypeTag.of(clazz)); } /** Create a new {@link DurablePromiseKey} for bytes state. */ @@ -38,7 +46,7 @@ public String name() { return name; } - public Serde serde() { - return serde; + public TypeTag serdeInfo() { + return typeTag; } } diff --git a/sdk-common/src/main/java/dev/restate/sdk/types/HandlerRequest.java b/sdk-common/src/main/java/dev/restate/sdk/types/HandlerRequest.java new file mode 100644 index 000000000..07ad19e2f --- /dev/null +++ b/sdk-common/src/main/java/dev/restate/sdk/types/HandlerRequest.java @@ -0,0 +1,26 @@ +// Copyright (c) 2023 - Restate Software, Inc., Restate GmbH +// +// This file is part of the Restate Java SDK, +// which is released under the MIT license. +// +// You can find a copy of the license in file LICENSE in the root +// directory of this repository or package, or at +// https://github.com/restatedev/sdk-java/blob/main/LICENSE +package dev.restate.sdk.types; + +import dev.restate.common.Slice; +import io.opentelemetry.context.Context; +import java.nio.ByteBuffer; +import java.util.Map; + +/** This record encapsulates the inputs to a handler. */ +public record HandlerRequest( + InvocationId invocationId, Context otelContext, Slice body, Map headers) { + public byte[] bodyAsByteArray() { + return body.toByteArray(); + } + + public ByteBuffer bodyAsBodyBuffer() { + return body.asReadOnlyByteBuffer(); + } +} diff --git a/sdk-common/src/main/java/dev/restate/sdk/common/InvocationId.java b/sdk-common/src/main/java/dev/restate/sdk/types/InvocationId.java similarity index 95% rename from sdk-common/src/main/java/dev/restate/sdk/common/InvocationId.java rename to sdk-common/src/main/java/dev/restate/sdk/types/InvocationId.java index 8b9b409b0..8422e9ce4 100644 --- a/sdk-common/src/main/java/dev/restate/sdk/common/InvocationId.java +++ b/sdk-common/src/main/java/dev/restate/sdk/types/InvocationId.java @@ -6,7 +6,7 @@ // You can find a copy of the license in file LICENSE in the root // directory of this repository or package, or at // https://github.com/restatedev/sdk-java/blob/main/LICENSE -package dev.restate.sdk.common; +package dev.restate.sdk.types; /** * This represents a stable identifier created by Restate for this invocation. It can be used as diff --git a/sdk-common/src/main/java/dev/restate/sdk/common/RetryPolicy.java b/sdk-common/src/main/java/dev/restate/sdk/types/RetryPolicy.java similarity index 99% rename from sdk-common/src/main/java/dev/restate/sdk/common/RetryPolicy.java rename to sdk-common/src/main/java/dev/restate/sdk/types/RetryPolicy.java index 554ca53a0..b920ce99d 100644 --- a/sdk-common/src/main/java/dev/restate/sdk/common/RetryPolicy.java +++ b/sdk-common/src/main/java/dev/restate/sdk/types/RetryPolicy.java @@ -6,7 +6,7 @@ // You can find a copy of the license in file LICENSE in the root // directory of this repository or package, or at // https://github.com/restatedev/sdk-java/blob/main/LICENSE -package dev.restate.sdk.common; +package dev.restate.sdk.types; import java.time.Duration; import java.util.Objects; diff --git a/sdk-common/src/main/java/dev/restate/sdk/common/StateKey.java b/sdk-common/src/main/java/dev/restate/sdk/types/StateKey.java similarity index 67% rename from sdk-common/src/main/java/dev/restate/sdk/common/StateKey.java rename to sdk-common/src/main/java/dev/restate/sdk/types/StateKey.java index e181eb856..946dd8b8b 100644 --- a/sdk-common/src/main/java/dev/restate/sdk/common/StateKey.java +++ b/sdk-common/src/main/java/dev/restate/sdk/types/StateKey.java @@ -6,7 +6,10 @@ // You can find a copy of the license in file LICENSE in the root // directory of this repository or package, or at // https://github.com/restatedev/sdk-java/blob/main/LICENSE -package dev.restate.sdk.common; +package dev.restate.sdk.types; + +import dev.restate.serde.Serde; +import dev.restate.serde.TypeTag; /** * This class holds information about state's name and its type tag to be used for serializing and @@ -17,18 +20,23 @@ public final class StateKey { private final String name; - private final Serde serde; + private final TypeTag serde; - private StateKey(String name, Serde serde) { + private StateKey(String name, TypeTag serde) { this.name = name; this.serde = serde; } /** Create a new {@link StateKey}. */ - public static StateKey of(String name, Serde serde) { + public static StateKey of(String name, TypeTag serde) { return new StateKey<>(name, serde); } + /** Create a new {@link StateKey}. */ + public static StateKey of(String name, Class clazz) { + return new StateKey<>(name, TypeTag.of(clazz)); + } + /** Create a new {@link StateKey} for bytes state. */ public static StateKey bytes(String name) { return new StateKey<>(name, Serde.RAW); @@ -38,7 +46,7 @@ public String name() { return name; } - public Serde serde() { + public TypeTag serdeInfo() { return serde; } } diff --git a/sdk-common/src/main/java/dev/restate/sdk/common/TerminalException.java b/sdk-common/src/main/java/dev/restate/sdk/types/TerminalException.java similarity index 97% rename from sdk-common/src/main/java/dev/restate/sdk/common/TerminalException.java rename to sdk-common/src/main/java/dev/restate/sdk/types/TerminalException.java index 7dd5a735a..ee6a246ee 100644 --- a/sdk-common/src/main/java/dev/restate/sdk/common/TerminalException.java +++ b/sdk-common/src/main/java/dev/restate/sdk/types/TerminalException.java @@ -6,7 +6,7 @@ // You can find a copy of the license in file LICENSE in the root // directory of this repository or package, or at // https://github.com/restatedev/sdk-java/blob/main/LICENSE -package dev.restate.sdk.common; +package dev.restate.sdk.types; /** When thrown in a Restate service method, it will complete the invocation with an error. */ public class TerminalException extends RuntimeException { diff --git a/sdk-core/src/main/java/dev/restate/sdk/core/SuspendableCallback.java b/sdk-common/src/main/java/dev/restate/sdk/types/TimeoutException.java similarity index 67% rename from sdk-core/src/main/java/dev/restate/sdk/core/SuspendableCallback.java rename to sdk-common/src/main/java/dev/restate/sdk/types/TimeoutException.java index 8939bb265..557d8d87d 100644 --- a/sdk-core/src/main/java/dev/restate/sdk/core/SuspendableCallback.java +++ b/sdk-common/src/main/java/dev/restate/sdk/types/TimeoutException.java @@ -6,11 +6,11 @@ // You can find a copy of the license in file LICENSE in the root // directory of this repository or package, or at // https://github.com/restatedev/sdk-java/blob/main/LICENSE -package dev.restate.sdk.core; +package dev.restate.sdk.types; -interface SuspendableCallback { +public class TimeoutException extends TerminalException { - void onSuspend(); - - void onError(Throwable e); + public TimeoutException(String message) { + super(409, message); + } } diff --git a/sdk-core/build.gradle.kts b/sdk-core/build.gradle.kts index bbd6be8d2..319913237 100644 --- a/sdk-core/build.gradle.kts +++ b/sdk-core/build.gradle.kts @@ -4,10 +4,11 @@ plugins { `java-library` `java-conventions` `kotlin-conventions` - `test-jar-conventions` `library-publishing-conventions` alias(libs.plugins.jsonschema2pojo) alias(libs.plugins.protobuf) + alias(libs.plugins.shadow) + alias(libs.plugins.ksp) // https://github.com/gradle/gradle/issues/20084#issuecomment-1060822638 id(libs.plugins.spotless.get().pluginId) apply false @@ -15,28 +16,56 @@ plugins { description = "Restate SDK Core" +val shade by configurations.creating +val implementation by configurations.getting + +implementation.extendsFrom(shade) + +val api by configurations.getting + +api.extendsFrom(shade) + dependencies { compileOnly(libs.jspecify) - implementation(project(":sdk-common")) + shadow(project(":sdk-common")) - implementation(libs.protobuf.java) - implementation(libs.log4j.api) + shadow(libs.log4j.api) + shadow(libs.opentelemetry.api) // We need this for the manifest - implementation(libs.jackson.annotations) - implementation(libs.jackson.databind) + shadow(libs.jackson.annotations) + shadow(libs.jackson.databind) + + // We shade protobuf java + shade(libs.protobuf.java) // We don't want a hard-dependency on it compileOnly(libs.log4j.core) - implementation(libs.opentelemetry.api) - testCompileOnly(libs.jspecify) + testAnnotationProcessor(project(":sdk-api-gen")) + kspTest(project(":sdk-api-kotlin-gen")) + testImplementation(libs.log4j.api) + testImplementation(project(":sdk-common")) + testImplementation(project(":client")) + testImplementation(project(":client-kotlin")) + testImplementation(project(":sdk-api")) + testImplementation(project(":sdk-api-kotlin")) + testImplementation(project(":sdk-http-vertx")) + testImplementation(project(":sdk-lambda")) + testImplementation(libs.jackson.annotations) + testImplementation(libs.jackson.databind) + testImplementation(libs.opentelemetry.api) + testImplementation(libs.protobuf.java) testImplementation(libs.mutiny) testImplementation(libs.junit.jupiter) testImplementation(libs.assertj) testImplementation(libs.log4j.core) + testImplementation(libs.kotlinx.coroutines.core) + testImplementation(libs.kotlinx.serialization.core) + testImplementation(libs.vertx.junit5) + testImplementation(libs.vertx.kotlin.coroutines) } // Configure source sets for protobuf plugin and jsonschema2pojo @@ -45,14 +74,14 @@ val generatedJ2SPDir = layout.buildDirectory.dir("generated/j2sp") sourceSets { main { java.srcDir(generatedJ2SPDir) - proto { srcDirs("src/main/sdk-proto", "src/main/service-protocol") } + proto { srcDirs("src/main/service-protocol") } } } // Configure jsonSchema2Pojo jsonSchema2Pojo { setSource(files("$projectDir/src/main/service-protocol/endpoint_manifest_schema.json")) - targetPackage = "dev.restate.sdk.core.manifest" + targetPackage = "dev.restate.sdk.core.generated.manifest" targetDirectory = generatedJ2SPDir.get().asFile useLongIntegers = false @@ -75,6 +104,22 @@ tasks { dependsOn(generateJsonSchema2Pojo, generateProto) } withType().configureEach { dependsOn(generateJsonSchema2Pojo, generateProto) } + + getByName("jar") { + enabled = false + dependsOn(shadowJar) + } + + shadowJar { + configurations = listOf(shade) + enableRelocation = true + archiveClassifier = null + relocate("com.google.protobuf", "dev.restate.shaded.com.google.protobuf") + dependencies { + project.configurations["shadow"].allDependencies.forEach { exclude(dependency(it)) } + exclude("**/google/protobuf/*.proto") + } + } } // spotless configuration for protobuf diff --git a/sdk-core/src/main/java/dev/restate/sdk/core/AckStateMachine.java b/sdk-core/src/main/java/dev/restate/sdk/core/AckStateMachine.java deleted file mode 100644 index c31f3f479..000000000 --- a/sdk-core/src/main/java/dev/restate/sdk/core/AckStateMachine.java +++ /dev/null @@ -1,56 +0,0 @@ -// Copyright (c) 2023 - Restate Software, Inc., Restate GmbH -// -// This file is part of the Restate Java SDK, -// which is released under the MIT license. -// -// You can find a copy of the license in file LICENSE in the root -// directory of this repository or package, or at -// https://github.com/restatedev/sdk-java/blob/main/LICENSE -package dev.restate.sdk.core; - -/** State machine tracking acks */ -class AckStateMachine extends BaseSuspendableCallbackStateMachine { - - interface AckCallback extends SuspendableCallback { - void onAck(); - } - - private int lastAcknowledgedEntry = -1; - - /** -1 means no side effect waiting to be acked. */ - private int lastEntryToAck = -1; - - void waitLastAck(AckCallback callback) { - if (lastEntryIsAcked()) { - callback.onAck(); - } else { - this.setCallback(callback); - } - } - - void tryHandleAck(int entryIndex) { - this.lastAcknowledgedEntry = Math.max(entryIndex, this.lastAcknowledgedEntry); - if (lastEntryIsAcked()) { - this.consumeCallback(AckCallback::onAck); - } - } - - void registerEntryToAck(int entryIndex) { - this.lastEntryToAck = Math.max(entryIndex, this.lastEntryToAck); - } - - private boolean lastEntryIsAcked() { - return this.lastEntryToAck <= this.lastAcknowledgedEntry; - } - - public int getLastEntryToAck() { - return lastEntryToAck; - } - - @Override - void abort(Throwable cause) { - super.abort(cause); - // We can't do anything else if the input stream is closed, so we just fail the callback, if any - this.tryFailCallback(); - } -} diff --git a/sdk-core/src/main/java/dev/restate/sdk/core/AsyncResults.java b/sdk-core/src/main/java/dev/restate/sdk/core/AsyncResults.java new file mode 100644 index 000000000..e0d06c828 --- /dev/null +++ b/sdk-core/src/main/java/dev/restate/sdk/core/AsyncResults.java @@ -0,0 +1,353 @@ +// Copyright (c) 2023 - Restate Software, Inc., Restate GmbH +// +// This file is part of the Restate Java SDK, +// which is released under the MIT license. +// +// You can find a copy of the license in file LICENSE in the root +// directory of this repository or package, or at +// https://github.com/restatedev/sdk-java/blob/main/LICENSE +package dev.restate.sdk.core; + +import dev.restate.common.function.ThrowingFunction; +import dev.restate.sdk.core.statemachine.NotificationValue; +import dev.restate.sdk.core.statemachine.StateMachine; +import dev.restate.sdk.endpoint.definition.AsyncResult; +import dev.restate.sdk.types.AbortedExecutionException; +import dev.restate.sdk.types.TerminalException; +import java.util.*; +import java.util.concurrent.CompletableFuture; +import java.util.concurrent.CompletionException; +import java.util.stream.Stream; + +abstract class AsyncResults { + + @FunctionalInterface + interface Completer { + void complete(NotificationValue value, CompletableFuture future); + } + + private AsyncResults() {} + + static AsyncResultInternal single( + HandlerContextInternal contextInternal, int handle, Completer completer) { + return new SingleAsyncResultInternal<>(handle, completer, contextInternal); + } + + static AsyncResultInternal any( + HandlerContextInternal contextInternal, List> any) { + return new AnyAsyncResult(contextInternal, any); + } + + static AsyncResultInternal all( + HandlerContextInternal contextInternal, List> all) { + return new AllAsyncResult(contextInternal, all); + } + + interface AsyncResultInternal extends AsyncResult { + boolean isDone(); + + void tryCancel(); + + void tryComplete(StateMachine stateMachine); + + CompletableFuture publicFuture(); + + Stream uncompletedLeaves(); + + HandlerContextInternal ctx(); + } + + abstract static class BaseAsyncResultInternal implements AsyncResultInternal { + protected final CompletableFuture publicFuture; + + BaseAsyncResultInternal(CompletableFuture publicFuture) { + this.publicFuture = publicFuture; + } + + @Override + public CompletableFuture poll() { + if (!this.isDone()) { + ctx().pollAsyncResult(this); + } + return this.publicFuture; + } + + @Override + public boolean isDone() { + return this.publicFuture.isDone(); + } + + @Override + public CompletableFuture publicFuture() { + return publicFuture; + } + + @Override + public AsyncResult map( + ThrowingFunction> successMapper, + ThrowingFunction> failureMapper) { + return new MappedSingleAsyncResultInternal<>(this, successMapper, failureMapper); + } + } + + static class SingleAsyncResultInternal extends BaseAsyncResultInternal { + + private final int handle; + private final Completer completer; + private final HandlerContextInternal contextInternal; + + private SingleAsyncResultInternal( + int handle, Completer completer, HandlerContextInternal contextInternal) { + super(new CompletableFuture<>()); + this.handle = handle; + this.completer = completer; + this.contextInternal = contextInternal; + } + + @Override + public void tryCancel() { + this.publicFuture.completeExceptionally( + new TerminalException(TerminalException.CANCELLED_CODE)); + } + + @Override + public void tryComplete(StateMachine stateMachine) { + stateMachine + .takeNotification(handle) + .ifPresent( + value -> { + try { + completer.complete(value, publicFuture); + } catch (Throwable e) { + contextInternal.fail(e); + publicFuture.completeExceptionally(AbortedExecutionException.INSTANCE); + } + }); + } + + @Override + public Stream uncompletedLeaves() { + if (publicFuture.isDone()) { + return Stream.empty(); + } + return Stream.of(handle); + } + + @Override + public HandlerContextInternal ctx() { + return this.contextInternal; + } + } + + static class MappedSingleAsyncResultInternal extends BaseAsyncResultInternal { + private final AsyncResultInternal asyncResult; + + MappedSingleAsyncResultInternal( + AsyncResultInternal asyncResult, + ThrowingFunction> successMapper, + ThrowingFunction> failureMapper) { + super(compose(asyncResult.ctx(), asyncResult.publicFuture(), successMapper, failureMapper)); + this.asyncResult = asyncResult; + } + + @Override + public boolean isDone() { + return asyncResult.isDone(); + } + + @Override + public void tryCancel() { + asyncResult.tryCancel(); + } + + @Override + public void tryComplete(StateMachine stateMachine) { + asyncResult.tryComplete(stateMachine); + } + + @Override + public Stream uncompletedLeaves() { + return asyncResult.uncompletedLeaves(); + } + + @Override + public HandlerContextInternal ctx() { + return asyncResult.ctx(); + } + + private static CompletableFuture compose( + HandlerContextInternal ctx, + CompletableFuture upstreamFuture, + ThrowingFunction> successMapper, + ThrowingFunction> failureMapper) { + CompletableFuture downstreamFuture = new CompletableFuture<>(); + + upstreamFuture.whenComplete( + (t, throwable) -> { + if (ExceptionUtils.isTerminalException(throwable)) { + // Upstream future failed with Terminal exception + if (failureMapper != null) { + try { + failureMapper + .apply((TerminalException) throwable) + .whenCompleteAsync( + (u, mapperT) -> { + if (ExceptionUtils.isTerminalException(mapperT)) { + downstreamFuture.completeExceptionally(mapperT); + } else if (mapperT != null) { + ctx.failWithoutContextSwitch(mapperT); + downstreamFuture.completeExceptionally( + AbortedExecutionException.INSTANCE); + } else { + downstreamFuture.complete(u); + } + }, + ctx.stateMachineExecutor()); + } catch (Throwable mapperT) { + if (ExceptionUtils.isTerminalException(mapperT)) { + downstreamFuture.completeExceptionally(mapperT); + } else { + ctx.failWithoutContextSwitch(mapperT); + downstreamFuture.completeExceptionally(AbortedExecutionException.INSTANCE); + } + } + } else { + downstreamFuture.completeExceptionally(throwable); + } + } else if (throwable != null) { + // Aborted exception/some other exception. Just propagate it through + downstreamFuture.completeExceptionally(throwable); + } else { + // Success case! + if (successMapper != null) { + try { + successMapper + .apply(t) + .whenCompleteAsync( + (u, mapperT) -> { + if (ExceptionUtils.isTerminalException(mapperT)) { + downstreamFuture.completeExceptionally(mapperT); + } else if (mapperT != null) { + ctx.failWithoutContextSwitch(mapperT); + downstreamFuture.completeExceptionally( + AbortedExecutionException.INSTANCE); + } else { + downstreamFuture.complete(u); + } + }, + ctx.stateMachineExecutor()); + } catch (Throwable mapperT) { + if (ExceptionUtils.isTerminalException(mapperT)) { + downstreamFuture.completeExceptionally(mapperT); + } else { + ctx.failWithoutContextSwitch(mapperT); + downstreamFuture.completeExceptionally(AbortedExecutionException.INSTANCE); + } + } + } else { + // Type checked by the API itself + //noinspection unchecked + downstreamFuture.complete((U) t); + } + } + }); + + return downstreamFuture; + } + } + + static class AnyAsyncResult extends BaseAsyncResultInternal { + + private final HandlerContextInternal handlerContextInternal; + private final List> asyncResults; + + AnyAsyncResult( + HandlerContextInternal handlerContextInternal, List> asyncResults) { + super(new CompletableFuture<>()); + this.handlerContextInternal = handlerContextInternal; + this.asyncResults = asyncResults; + } + + @Override + public void tryCancel() { + this.publicFuture.completeExceptionally( + new TerminalException(TerminalException.CANCELLED_CODE)); + } + + @Override + public void tryComplete(StateMachine stateMachine) { + asyncResults.forEach(ar -> ar.tryComplete(stateMachine)); + for (int i = 0; i < asyncResults.size(); i++) { + if (asyncResults.get(i).isDone()) { + publicFuture.complete(i); + return; + } + } + } + + @Override + public Stream uncompletedLeaves() { + if (isDone()) { + return Stream.empty(); + } + return asyncResults.stream().flatMap(AsyncResultInternal::uncompletedLeaves); + } + + @Override + public HandlerContextInternal ctx() { + return handlerContextInternal; + } + } + + static class AllAsyncResult extends BaseAsyncResultInternal { + + private final HandlerContextInternal handlerContextInternal; + private final List> asyncResults; + + AllAsyncResult( + HandlerContextInternal handlerContextInternal, List> asyncResults) { + super( + CompletableFuture.allOf( + asyncResults.stream() + .map(AsyncResultInternal::publicFuture) + .toArray(CompletableFuture[]::new))); + this.handlerContextInternal = handlerContextInternal; + this.asyncResults = asyncResults; + } + + @Override + public void tryCancel() { + this.publicFuture.completeExceptionally( + new TerminalException(TerminalException.CANCELLED_CODE)); + } + + @Override + public void tryComplete(StateMachine stateMachine) { + asyncResults.forEach(ar -> ar.tryComplete(stateMachine)); + asyncResults.stream() + .filter(ar -> ar.publicFuture().isCompletedExceptionally()) + .findFirst() + .ifPresent( + ar -> { + try { + ar.publicFuture().getNow(null); + } catch (CompletionException e) { + this.publicFuture.completeExceptionally(e.getCause()); + } + }); + } + + @Override + public Stream uncompletedLeaves() { + if (isDone()) { + return Stream.empty(); + } + return asyncResults.stream().flatMap(AsyncResultInternal::uncompletedLeaves); + } + + @Override + public HandlerContextInternal ctx() { + return handlerContextInternal; + } + } +} diff --git a/sdk-core/src/main/java/dev/restate/sdk/core/BaseSuspendableCallbackStateMachine.java b/sdk-core/src/main/java/dev/restate/sdk/core/BaseSuspendableCallbackStateMachine.java deleted file mode 100644 index 0c767aaee..000000000 --- a/sdk-core/src/main/java/dev/restate/sdk/core/BaseSuspendableCallbackStateMachine.java +++ /dev/null @@ -1,62 +0,0 @@ -// Copyright (c) 2023 - Restate Software, Inc., Restate GmbH -// -// This file is part of the Restate Java SDK, -// which is released under the MIT license. -// -// You can find a copy of the license in file LICENSE in the root -// directory of this repository or package, or at -// https://github.com/restatedev/sdk-java/blob/main/LICENSE -package dev.restate.sdk.core; - -import java.util.function.Consumer; - -// Implements the base logic for state machines containing suspensable callbacks. -abstract class BaseSuspendableCallbackStateMachine { - - private final CallbackHandle callbackHandle; - private final InputPublisherState inputPublisherState; - - BaseSuspendableCallbackStateMachine() { - this.callbackHandle = new CallbackHandle<>(); - this.inputPublisherState = new InputPublisherState(); - } - - void abort(Throwable cause) { - this.inputPublisherState.notifyClosed(cause); - } - - public void tryFailCallback() { - callbackHandle.consume( - cb -> { - if (inputPublisherState.isSuspended()) { - cb.onSuspend(); - } else if (inputPublisherState.isClosed()) { - cb.onError(inputPublisherState.getCloseCause()); - } - }); - } - - public void consumeCallback(Consumer consumer) { - this.callbackHandle.consume(consumer); - } - - public void consumeCallbackOrElse(Consumer consumer, Runnable elseRunnable) { - this.callbackHandle.consumeOrElse(consumer, elseRunnable); - } - - public void assertCallbackNotSet(String reason) { - if (!this.callbackHandle.isEmpty()) { - throw new IllegalStateException(reason); - } - } - - void setCallback(CB callback) { - if (inputPublisherState.isSuspended()) { - callback.onSuspend(); - } else if (inputPublisherState.isClosed()) { - callback.onError(inputPublisherState.getCloseCause()); - } else { - callbackHandle.set(callback); - } - } -} diff --git a/sdk-core/src/main/java/dev/restate/sdk/core/CallbackHandle.java b/sdk-core/src/main/java/dev/restate/sdk/core/CallbackHandle.java deleted file mode 100644 index 6f4fadf89..000000000 --- a/sdk-core/src/main/java/dev/restate/sdk/core/CallbackHandle.java +++ /dev/null @@ -1,46 +0,0 @@ -// Copyright (c) 2023 - Restate Software, Inc., Restate GmbH -// -// This file is part of the Restate Java SDK, -// which is released under the MIT license. -// -// You can find a copy of the license in file LICENSE in the root -// directory of this repository or package, or at -// https://github.com/restatedev/sdk-java/blob/main/LICENSE -package dev.restate.sdk.core; - -import java.util.function.Consumer; -import org.jspecify.annotations.Nullable; - -/** Handle for callbacks. */ -final class CallbackHandle { - - private @Nullable T cb = null; - - public void set(T t) { - this.cb = t; - } - - public boolean isEmpty() { - return this.cb == null; - } - - public void consume(Consumer consumer) { - if (this.cb != null) { - consumer.accept(pop()); - } - } - - public void consumeOrElse(Consumer consumer, Runnable elseRunnable) { - if (this.cb != null) { - consumer.accept(pop()); - } else { - elseRunnable.run(); - } - } - - private @Nullable T pop() { - T temp = this.cb; - this.cb = null; - return temp; - } -} diff --git a/sdk-core/src/main/java/dev/restate/sdk/core/DeferredResults.java b/sdk-core/src/main/java/dev/restate/sdk/core/DeferredResults.java deleted file mode 100644 index 294bc2b4d..000000000 --- a/sdk-core/src/main/java/dev/restate/sdk/core/DeferredResults.java +++ /dev/null @@ -1,255 +0,0 @@ -// Copyright (c) 2023 - Restate Software, Inc., Restate GmbH -// -// This file is part of the Restate Java SDK, -// which is released under the MIT license. -// -// You can find a copy of the license in file LICENSE in the root -// directory of this repository or package, or at -// https://github.com/restatedev/sdk-java/blob/main/LICENSE -package dev.restate.sdk.core; - -import dev.restate.sdk.common.syscalls.Deferred; -import dev.restate.sdk.common.syscalls.Result; -import java.util.*; -import java.util.function.Function; -import java.util.stream.Collectors; -import java.util.stream.Stream; -import org.jspecify.annotations.Nullable; - -abstract class DeferredResults { - - private DeferredResults() {} - - static DeferredInternal single(int entryIndex) { - return new ResolvableSingleDeferred<>(null, entryIndex); - } - - static DeferredInternal completedSingle(int entryIndex, Result result) { - return new ResolvableSingleDeferred<>(result, entryIndex); - } - - static DeferredInternal any(List> any) { - return new AnyDeferred(any); - } - - static DeferredInternal all(List> all) { - return new AllDeferred(all); - } - - interface DeferredInternal extends Deferred { - - @Nullable - @Override - Result toResult(); - - /** - * Look at the implementation of all and any for more details. - * - * @see AllDeferred#tryResolve(int) - * @see AnyDeferred#tryResolve(int) - */ - Stream> unprocessedLeafs(); - } - - interface SingleDeferredInternal extends DeferredInternal { - - int entryIndex(); - } - - private abstract static class BaseDeferred implements DeferredInternal { - - @Nullable private Result readyResult; - - BaseDeferred(@Nullable Result result) { - this.readyResult = result; - } - - @Override - public boolean isCompleted() { - return readyResult != null; - } - - public void resolve(Result result) { - this.readyResult = result; - } - - @Override - @Nullable - public Result toResult() { - return readyResult; - } - } - - static class ResolvableSingleDeferred extends BaseDeferred - implements SingleDeferredInternal { - - private final int entryIndex; - - private ResolvableSingleDeferred(@Nullable Result result, int entryIndex) { - super(result); - this.entryIndex = entryIndex; - } - - @Override - public int entryIndex() { - return entryIndex; - } - - @Override - public Stream> unprocessedLeafs() { - return Stream.of(this); - } - } - - abstract static class CombinatorDeferred extends BaseDeferred { - - // The reason to have these two data structures is to optimize the best case where we have a - // combinator with a large number of single deferred (which can be addressed by entry index), - // but little number of nested combinators (which cannot be addressed by an index, but needs to - // be iterated through). - protected final Map> unresolvedSingles; - protected final Set> unresolvedCombinators; - - CombinatorDeferred( - Map> unresolvedSingles, - Set> unresolvedCombinators) { - super(null); - - this.unresolvedSingles = unresolvedSingles; - this.unresolvedCombinators = unresolvedCombinators; - } - - /** - * This method implements the resolution logic, by trying to solve its leafs and inner - * combinator nodes. - * - *

In case the {@code newResolvedSingle} is unknown/invalid, this method will still try to - * walk through the inner combinator nodes in order to try resolve them. - * - * @return true if it's resolved, that is subsequent calls to {@link #isCompleted()} return - * true. - */ - abstract boolean tryResolve(int newResolvedSingle); - - /** Like {@link #tryResolve(int)}, but iteratively on the provided list. */ - boolean tryResolve(List resolvedSingle) { - boolean resolved = false; - for (int newResolvedSingle : resolvedSingle) { - resolved = tryResolve(newResolvedSingle); - } - return resolved; - } - - @Override - public Stream> unprocessedLeafs() { - return Stream.concat( - this.unresolvedSingles.values().stream(), - this.unresolvedCombinators.stream().flatMap(CombinatorDeferred::unprocessedLeafs)); - } - } - - static class AnyDeferred extends CombinatorDeferred implements Deferred { - - private final IdentityHashMap, Integer> indexMapping; - - private AnyDeferred(List> children) { - super( - children.stream() - .filter(d -> d instanceof SingleDeferredInternal) - .map(d -> (SingleDeferredInternal) d) - .collect(Collectors.toMap(SingleDeferredInternal::entryIndex, Function.identity())), - children.stream() - .filter(d -> d instanceof CombinatorDeferred) - .map(d -> (CombinatorDeferred) d) - .collect(Collectors.toSet())); - - // The index mapping relies on instance hashing - this.indexMapping = new IdentityHashMap<>(); - for (int i = 0; i < children.size(); i++) { - this.indexMapping.put(children.get(i), i); - } - } - - @SuppressWarnings("unchecked") - @Override - boolean tryResolve(int newResolvedSingle) { - if (this.isCompleted()) { - return true; - } - - SingleDeferredInternal resolvedSingle = this.unresolvedSingles.get(newResolvedSingle); - if (resolvedSingle != null) { - // Resolved - this.resolve(Result.success(this.indexMapping.get(resolvedSingle))); - return true; - } - - for (CombinatorDeferred combinator : this.unresolvedCombinators) { - if (combinator.tryResolve(newResolvedSingle)) { - // Resolved - this.resolve(Result.success(this.indexMapping.get(combinator))); - return true; - } - } - - return false; - } - } - - static class AllDeferred extends CombinatorDeferred { - - private AllDeferred(List> children) { - super( - children.stream() - .filter(d -> d instanceof SingleDeferredInternal) - .map(d -> (SingleDeferredInternal) d) - .collect( - Collectors.toMap( - SingleDeferredInternal::entryIndex, - Function.identity(), - (v1, v2) -> v1, - HashMap::new)), - children.stream() - .filter(d -> d instanceof CombinatorDeferred) - .map(d -> (CombinatorDeferred) d) - .collect(Collectors.toCollection(HashSet::new))); - } - - @SuppressWarnings("unchecked") - @Override - boolean tryResolve(int newResolvedSingle) { - if (this.isCompleted()) { - return true; - } - - SingleDeferredInternal resolvedSingle = this.unresolvedSingles.remove(newResolvedSingle); - if (resolvedSingle != null) { - if (!resolvedSingle.toResult().isSuccess()) { - this.resolve((Result) resolvedSingle.toResult()); - return true; - } - } - - Iterator> it = this.unresolvedCombinators.iterator(); - while (it.hasNext()) { - CombinatorDeferred combinator = it.next(); - if (combinator.tryResolve(newResolvedSingle)) { - // Resolved - it.remove(); - - if (!combinator.toResult().isSuccess()) { - this.resolve((Result) combinator.toResult()); - return true; - } - } - } - - if (this.unresolvedSingles.isEmpty() && this.unresolvedCombinators.isEmpty()) { - this.resolve(Result.empty()); - return true; - } - - return false; - } - } -} diff --git a/sdk-core/src/main/java/dev/restate/sdk/core/ServiceProtocol.java b/sdk-core/src/main/java/dev/restate/sdk/core/DiscoveryProtocol.java similarity index 71% rename from sdk-core/src/main/java/dev/restate/sdk/core/ServiceProtocol.java rename to sdk-core/src/main/java/dev/restate/sdk/core/DiscoveryProtocol.java index f410a6384..c894123d0 100644 --- a/sdk-core/src/main/java/dev/restate/sdk/core/ServiceProtocol.java +++ b/sdk-core/src/main/java/dev/restate/sdk/core/DiscoveryProtocol.java @@ -13,60 +13,19 @@ import com.fasterxml.jackson.databind.ObjectMapper; import com.fasterxml.jackson.databind.ser.impl.SimpleBeanPropertyFilter; import com.fasterxml.jackson.databind.ser.impl.SimpleFilterProvider; -import dev.restate.generated.service.discovery.Discovery; -import dev.restate.generated.service.protocol.Protocol; -import dev.restate.sdk.core.manifest.EndpointManifestSchema; -import dev.restate.sdk.core.manifest.Handler; -import dev.restate.sdk.core.manifest.Service; +import dev.restate.sdk.core.generated.discovery.Discovery; +import dev.restate.sdk.core.generated.manifest.EndpointManifestSchema; +import dev.restate.sdk.core.generated.manifest.Handler; +import dev.restate.sdk.core.generated.manifest.Service; import java.util.Objects; import java.util.Optional; -class ServiceProtocol { - static final Protocol.ServiceProtocolVersion MIN_SERVICE_PROTOCOL_VERSION = - Protocol.ServiceProtocolVersion.V2; - private static final Protocol.ServiceProtocolVersion MAX_SERVICE_PROTOCOL_VERSION = - Protocol.ServiceProtocolVersion.V2; - +class DiscoveryProtocol { static final Discovery.ServiceDiscoveryProtocolVersion MIN_SERVICE_DISCOVERY_PROTOCOL_VERSION = Discovery.ServiceDiscoveryProtocolVersion.V1; static final Discovery.ServiceDiscoveryProtocolVersion MAX_SERVICE_DISCOVERY_PROTOCOL_VERSION = Discovery.ServiceDiscoveryProtocolVersion.V2; - static Protocol.ServiceProtocolVersion parseServiceProtocolVersion(String version) { - version = version.trim(); - - if (version.equals("application/vnd.restate.invocation.v1")) { - return Protocol.ServiceProtocolVersion.V1; - } - if (version.equals("application/vnd.restate.invocation.v2")) { - return Protocol.ServiceProtocolVersion.V2; - } - return Protocol.ServiceProtocolVersion.SERVICE_PROTOCOL_VERSION_UNSPECIFIED; - } - - static String serviceProtocolVersionToHeaderValue(Protocol.ServiceProtocolVersion version) { - if (Objects.requireNonNull(version) == Protocol.ServiceProtocolVersion.V1) { - return "application/vnd.restate.invocation.v1"; - } - if (Objects.requireNonNull(version) == Protocol.ServiceProtocolVersion.V2) { - return "application/vnd.restate.invocation.v2"; - } - throw new IllegalArgumentException( - String.format("Service protocol version '%s' has no header value", version.getNumber())); - } - - static Protocol.ServiceProtocolVersion maxServiceProtocolVersion( - boolean ignoredExperimentalContextEnabled) { - return Protocol.ServiceProtocolVersion.V2; - } - - static boolean isSupported( - Protocol.ServiceProtocolVersion serviceProtocolVersion, boolean experimentalContextEnabled) { - return MIN_SERVICE_PROTOCOL_VERSION.getNumber() <= serviceProtocolVersion.getNumber() - && serviceProtocolVersion.getNumber() - <= maxServiceProtocolVersion(experimentalContextEnabled).getNumber(); - } - static boolean isSupported( Discovery.ServiceDiscoveryProtocolVersion serviceDiscoveryProtocolVersion) { return MIN_SERVICE_DISCOVERY_PROTOCOL_VERSION.getNumber() diff --git a/sdk-core/src/main/java/dev/restate/sdk/core/EndpointManifest.java b/sdk-core/src/main/java/dev/restate/sdk/core/EndpointManifest.java index 782e0c6e0..da52546c7 100644 --- a/sdk-core/src/main/java/dev/restate/sdk/core/EndpointManifest.java +++ b/sdk-core/src/main/java/dev/restate/sdk/core/EndpointManifest.java @@ -8,17 +8,17 @@ // https://github.com/restatedev/sdk-java/blob/main/LICENSE package dev.restate.sdk.core; -import static dev.restate.sdk.core.ServiceProtocol.*; +import static dev.restate.sdk.core.DiscoveryProtocol.MANIFEST_OBJECT_MAPPER; +import static dev.restate.sdk.core.statemachine.ServiceProtocol.MAX_SERVICE_PROTOCOL_VERSION; +import static dev.restate.sdk.core.statemachine.ServiceProtocol.MIN_SERVICE_PROTOCOL_VERSION; import com.fasterxml.jackson.core.JsonProcessingException; -import dev.restate.sdk.common.HandlerType; -import dev.restate.sdk.common.RichSerde; -import dev.restate.sdk.common.ServiceType; -import dev.restate.sdk.common.syscalls.HandlerDefinition; -import dev.restate.sdk.common.syscalls.HandlerSpecification; -import dev.restate.sdk.common.syscalls.ServiceDefinition; -import dev.restate.sdk.core.manifest.*; -import java.util.Objects; +import dev.restate.sdk.core.generated.manifest.*; +import dev.restate.sdk.endpoint.definition.HandlerDefinition; +import dev.restate.sdk.endpoint.definition.HandlerType; +import dev.restate.sdk.endpoint.definition.ServiceDefinition; +import dev.restate.sdk.endpoint.definition.ServiceType; +import dev.restate.serde.Serde; import java.util.stream.Collectors; import java.util.stream.Stream; @@ -29,15 +29,14 @@ final class EndpointManifest { private final EndpointManifestSchema manifest; - public EndpointManifest( + EndpointManifest( EndpointManifestSchema.ProtocolMode protocolMode, - Stream> components, + Stream components, boolean experimentalContextEnabled) { this.manifest = new EndpointManifestSchema() .withMinProtocolVersion(MIN_SERVICE_PROTOCOL_VERSION.getNumber()) - .withMaxProtocolVersion( - maxServiceProtocolVersion(experimentalContextEnabled).getNumber()) + .withMaxProtocolVersion(MAX_SERVICE_PROTOCOL_VERSION.getNumber()) .withProtocolMode(protocolMode) .withServices( components @@ -66,7 +65,7 @@ public EndpointManifest( .collect(Collectors.toList())); } - public EndpointManifestSchema manifest() { + EndpointManifestSchema manifest() { return this.manifest; } @@ -78,16 +77,15 @@ private static Service.Ty convertServiceType(ServiceType serviceType) { }; } - private static Handler convertHandler(HandlerDefinition handler) { - HandlerSpecification spec = handler.getSpec(); + private static Handler convertHandler(HandlerDefinition handler) { return new Handler() - .withName(spec.getName()) - .withTy(convertHandlerType(spec.getHandlerType())) - .withInput(convertHandlerInput(spec)) - .withOutput(convertHandlerOutput(spec)) - .withDocumentation(spec.getDocumentation()) + .withName(handler.getName()) + .withTy(convertHandlerType(handler.getHandlerType())) + .withInput(convertHandlerInput(handler)) + .withOutput(convertHandlerOutput(handler)) + .withDocumentation(handler.getDocumentation()) .withMetadata( - spec.getMetadata().entrySet().stream() + handler.getMetadata().entrySet().stream() .reduce( new Metadata(), (meta, entry) -> meta.withAdditionalProperty(entry.getKey(), entry.getValue()), @@ -97,55 +95,52 @@ private static Handler convertHandler(HandlerDefinition handler) { })); } - private static Input convertHandlerInput(HandlerSpecification spec) { + private static Input convertHandlerInput(HandlerDefinition def) { String acceptContentType = - spec.getAcceptContentType() != null - ? spec.getAcceptContentType() - : spec.getRequestSerde().contentType(); + def.getAcceptContentType() != null + ? def.getAcceptContentType() + : def.getRequestSerde().contentType(); Input input = acceptContentType == null ? EMPTY_INPUT : new Input().withRequired(true).withContentType(acceptContentType); - if (spec.getRequestSerde() instanceof RichSerde) { - Object jsonSchema = - Objects.requireNonNull(((RichSerde) spec.getRequestSerde()).jsonSchema()); - if (jsonSchema instanceof String) { - // We need to convert it to databind JSON value - try { - jsonSchema = MANIFEST_OBJECT_MAPPER.readTree((String) jsonSchema); - } catch (JsonProcessingException e) { - throw new RuntimeException("The schema generated by RichSerde is not a valid JSON", e); - } + Serde.Schema jsonSchema = def.getRequestSerde().jsonSchema(); + if (jsonSchema instanceof Serde.JsonSchema schema) { + input.setJsonSchema(schema.schema()); + } else if (jsonSchema instanceof Serde.StringifiedJsonSchema schema) { + // We need to convert it to databind JSON value + try { + input.setJsonSchema(MANIFEST_OBJECT_MAPPER.readTree(schema.schema())); + } catch (JsonProcessingException e) { + throw new RuntimeException( + "The schema generated by " + def.getRequestSerde() + " is not a valid JSON", e); } - input.setJsonSchema(jsonSchema); } return input; } - private static Output convertHandlerOutput(HandlerSpecification spec) { + private static Output convertHandlerOutput(HandlerDefinition def) { Output output = - spec.getResponseSerde().contentType() == null + def.getResponseSerde().contentType() == null ? EMPTY_OUTPUT : new Output() - .withContentType(spec.getResponseSerde().contentType()) + .withContentType(def.getResponseSerde().contentType()) .withSetContentTypeIfEmpty(false); - if (spec.getResponseSerde() instanceof RichSerde) { - Object jsonSchema = - Objects.requireNonNull(((RichSerde) spec.getResponseSerde()).jsonSchema()); - if (jsonSchema instanceof String) { - // We need to convert it to databind JSON value - try { - jsonSchema = MANIFEST_OBJECT_MAPPER.readTree((String) jsonSchema); - } catch (JsonProcessingException e) { - throw new RuntimeException("The schema generated by RichSerde is not a valid JSON", e); - } + Serde.Schema jsonSchema = def.getResponseSerde().jsonSchema(); + if (jsonSchema instanceof Serde.JsonSchema schema) { + output.setJsonSchema(schema.schema()); + } else if (jsonSchema instanceof Serde.StringifiedJsonSchema schema) { + // We need to convert it to databind JSON value + try { + output.setJsonSchema(MANIFEST_OBJECT_MAPPER.readTree(schema.schema())); + } catch (JsonProcessingException e) { + throw new RuntimeException( + "The schema generated by " + def.getResponseSerde() + " is not a valid JSON", e); } - output.setJsonSchema(jsonSchema); } - return output; } diff --git a/sdk-core/src/main/java/dev/restate/sdk/core/EndpointRequestHandler.java b/sdk-core/src/main/java/dev/restate/sdk/core/EndpointRequestHandler.java new file mode 100644 index 000000000..5f6ed1f44 --- /dev/null +++ b/sdk-core/src/main/java/dev/restate/sdk/core/EndpointRequestHandler.java @@ -0,0 +1,217 @@ +// Copyright (c) 2023 - Restate Software, Inc., Restate GmbH +// +// This file is part of the Restate Java SDK, +// which is released under the MIT license. +// +// You can find a copy of the license in file LICENSE in the root +// directory of this repository or package, or at +// https://github.com/restatedev/sdk-java/blob/main/LICENSE +package dev.restate.sdk.core; + +import dev.restate.common.Slice; +import dev.restate.sdk.core.generated.discovery.Discovery; +import dev.restate.sdk.core.generated.manifest.EndpointManifestSchema; +import dev.restate.sdk.core.generated.manifest.Service; +import dev.restate.sdk.core.statemachine.StateMachine; +import dev.restate.sdk.endpoint.Endpoint; +import dev.restate.sdk.endpoint.HeadersAccessor; +import dev.restate.sdk.endpoint.definition.HandlerDefinition; +import dev.restate.sdk.endpoint.definition.ServiceDefinition; +import io.opentelemetry.context.propagation.TextMapGetter; +import java.util.concurrent.Executor; +import java.util.concurrent.Executors; +import java.util.regex.Pattern; +import java.util.stream.Collectors; +import org.apache.logging.log4j.LogManager; +import org.apache.logging.log4j.Logger; +import org.apache.logging.log4j.ThreadContext; +import org.jspecify.annotations.NonNull; +import org.jspecify.annotations.Nullable; + +public final class EndpointRequestHandler { + + private static final Logger LOG = LogManager.getLogger(EndpointRequestHandler.class); + private static final String DISCOVER_PATH = "/discover"; + private static final String HEALTH_PATH = "/health"; + private static final Pattern SLASH = Pattern.compile(Pattern.quote("/")); + private static final String ACCEPT = "accept"; + private static final TextMapGetter OTEL_HEADERS_GETTER = + new TextMapGetter<>() { + @Override + public Iterable keys(HeadersAccessor carrier) { + return carrier.keys(); + } + + @Nullable + @Override + public String get(@Nullable HeadersAccessor carrier, @NonNull String key) { + if (carrier == null) { + return null; + } + return carrier.get(key); + } + }; + + private final Endpoint endpoint; + private final EndpointManifest deploymentManifest; + + private EndpointRequestHandler( + EndpointManifestSchema.ProtocolMode protocolMode, Endpoint endpoint) { + this.endpoint = endpoint; + this.deploymentManifest = + new EndpointManifest( + protocolMode, + this.endpoint.getServiceDefinitions(), + this.endpoint.isExperimentalContextEnabled()); + } + + public static EndpointRequestHandler forBidiStream(Endpoint endpoint) { + return new EndpointRequestHandler(EndpointManifestSchema.ProtocolMode.BIDI_STREAM, endpoint); + } + + public static EndpointRequestHandler forRequestResponse(Endpoint endpoint) { + return new EndpointRequestHandler( + EndpointManifestSchema.ProtocolMode.REQUEST_RESPONSE, endpoint); + } + + /** + * Interface to abstract setting the logging context variables. + * + *

In classic multithreaded environments, you can just use {@link + * LoggingContextSetter#THREAD_LOCAL_INSTANCE}, though the caller of {@link + * EndpointRequestHandler} must take care of the cleanup of the thread local map. + */ + @FunctionalInterface + public interface LoggingContextSetter { + + String INVOCATION_ID_KEY = "restateInvocationId"; + String INVOCATION_TARGET_KEY = "restateInvocationTarget"; + String INVOCATION_STATUS_KEY = "restateInvocationStatus"; + + LoggingContextSetter THREAD_LOCAL_INSTANCE = ThreadContext::put; + + void set(String key, String value); + } + + /** + * @param coreExecutor This executor MUST serialize the execution of all scheduled tasks. For + * example {@link Executors#newSingleThreadExecutor()} can be used. + * @return The request processor + * @throws ProtocolException in + */ + public RequestProcessor processorForRequest( + String path, + HeadersAccessor headersAccessor, + LoggingContextSetter loggingContextSetter, + Executor coreExecutor) + throws ProtocolException { + // Discovery request + if (path.endsWith(DISCOVER_PATH)) { + return this.handleDiscoveryRequest(headersAccessor); + } + + if (path.endsWith(HEALTH_PATH)) { + return new StaticResponseRequestProcessor( + 200, + "text/plain", + Slice.wrap( + "Serving services [" + + this.endpoint + .getServiceDefinitions() + .map(ServiceDefinition::getServiceName) + .collect(Collectors.joining(", ")) + + "]")); + } + + // Parse request + String[] pathSegments = SLASH.split(path); + if (pathSegments.length < 3) { + LOG.warn( + "Path doesn't match the pattern /invoke/ServiceName/HandlerName nor /discover nor /health: '{}'", + path); + throw new ProtocolException( + "Path doesn't match the pattern /invoke/ServiceName/HandlerName nor /discover nor /health", + 404); + } + String serviceName = pathSegments[pathSegments.length - 2]; + String handlerName = pathSegments[pathSegments.length - 1]; + + String fullyQualifiedServiceMethod = serviceName + "/" + handlerName; + + // Instantiate state machine + StateMachine stateMachine = StateMachine.init(headersAccessor, loggingContextSetter); + + // Resolve the service method definition + @SuppressWarnings("unchecked") + ServiceDefinition svc = this.endpoint.resolveService(serviceName); + if (svc == null) { + throw ProtocolException.methodNotFound(serviceName, handlerName); + } + HandlerDefinition handler = svc.getHandler(handlerName); + if (handler == null) { + throw ProtocolException.methodNotFound(serviceName, handlerName); + } + + // Verify request + if (endpoint.getRequestIdentityVerifier() != null) { + try { + endpoint.getRequestIdentityVerifier().verifyRequest(headersAccessor); + } catch (Exception e) { + throw ProtocolException.unauthorized(e); + } + } + + // Parse OTEL context and generate span + final io.opentelemetry.context.Context otelContext = + this.endpoint + .getOpenTelemetry() + .getPropagators() + .getTextMapPropagator() + .extract( + io.opentelemetry.context.Context.current(), headersAccessor, OTEL_HEADERS_GETTER); + + // Generate the span + // Span span = + // tracer + // .spanBuilder("Invoke handler") + // .setSpanKind(SpanKind.SERVER) + // .setParent(otelContext) + // .startSpan(); + + // Setup logging context + loggingContextSetter.set( + LoggingContextSetter.INVOCATION_TARGET_KEY, fullyQualifiedServiceMethod); + + return new RequestProcessorImpl( + fullyQualifiedServiceMethod, + stateMachine, + handler, + otelContext, + loggingContextSetter, + coreExecutor); + } + + StaticResponseRequestProcessor handleDiscoveryRequest(HeadersAccessor headersAccessor) + throws ProtocolException { + String acceptContentType = headersAccessor.get(ACCEPT); + + Discovery.ServiceDiscoveryProtocolVersion version = + DiscoveryProtocol.selectSupportedServiceDiscoveryProtocolVersion(acceptContentType); + if (!DiscoveryProtocol.isSupported(version)) { + throw new ProtocolException( + String.format( + "Unsupported Discovery version in the Accept header '%s'", acceptContentType), + ProtocolException.UNSUPPORTED_MEDIA_TYPE_CODE); + } + + EndpointManifestSchema response = this.deploymentManifest.manifest(); + LOG.info( + "Replying to discovery request with services [{}]", + response.getServices().stream().map(Service::getName).collect(Collectors.joining(","))); + + return new StaticResponseRequestProcessor( + 200, + DiscoveryProtocol.serviceDiscoveryProtocolVersionToHeaderValue(version), + Slice.wrap(DiscoveryProtocol.serializeManifest(version, response))); + } +} diff --git a/sdk-core/src/main/java/dev/restate/sdk/core/Entries.java b/sdk-core/src/main/java/dev/restate/sdk/core/Entries.java deleted file mode 100644 index 69d2cade9..000000000 --- a/sdk-core/src/main/java/dev/restate/sdk/core/Entries.java +++ /dev/null @@ -1,728 +0,0 @@ -// Copyright (c) 2023 - Restate Software, Inc., Restate GmbH -// -// This file is part of the Restate Java SDK, -// which is released under the MIT license. -// -// You can find a copy of the license in file LICENSE in the root -// directory of this repository or package, or at -// https://github.com/restatedev/sdk-java/blob/main/LICENSE -package dev.restate.sdk.core; - -import com.google.protobuf.ByteString; -import com.google.protobuf.InvalidProtocolBufferException; -import com.google.protobuf.MessageLite; -import com.google.protobuf.UnsafeByteOperations; -import dev.restate.generated.service.protocol.Protocol; -import dev.restate.generated.service.protocol.Protocol.*; -import dev.restate.sdk.common.syscalls.Result; -import io.opentelemetry.api.common.Attributes; -import io.opentelemetry.api.trace.Span; -import java.nio.ByteBuffer; -import java.util.Collection; -import java.util.Objects; -import java.util.function.Function; -import java.util.stream.Collectors; - -final class Entries { - static final String AWAKEABLE_IDENTIFIER_PREFIX = "prom_1"; - - private Entries() {} - - abstract static class JournalEntry { - abstract String getName(E expected); - - void checkEntryHeader(E expected, MessageLite actual) throws ProtocolException {} - - abstract void trace(E expected, Span span); - - void updateUserStateStoreWithEntry(E expected, UserStateStore userStateStore) {} - } - - abstract static class CompletableJournalEntry extends JournalEntry { - abstract boolean hasResult(E actual); - - abstract Result parseEntryResult(E actual); - - Result parseCompletionResult(CompletionMessage actual) { - throw ProtocolException.completionDoesNotMatch( - this.getClass().getName(), actual.getResultCase()); - } - - E tryCompleteWithUserStateStorage(E expected, UserStateStore userStateStore) { - return expected; - } - - void updateUserStateStorageWithCompletion( - E expected, CompletionMessage actual, UserStateStore userStateStore) {} - } - - static final class OutputEntry extends JournalEntry { - - static final OutputEntry INSTANCE = new OutputEntry(); - - private OutputEntry() {} - - @Override - String getName(OutputEntryMessage expected) { - return expected.getName(); - } - - @Override - void checkEntryHeader(OutputEntryMessage expected, MessageLite actual) - throws ProtocolException { - Util.assertEntryEquals(expected, actual); - } - - @Override - public void trace(OutputEntryMessage expected, Span span) { - span.addEvent("Output"); - } - } - - static final class GetStateEntry - extends CompletableJournalEntry { - - static final GetStateEntry INSTANCE = new GetStateEntry(); - - private GetStateEntry() {} - - @Override - void trace(GetStateEntryMessage expected, Span span) { - span.addEvent( - "GetState", Attributes.of(Tracing.RESTATE_STATE_KEY, expected.getKey().toString())); - } - - @Override - public boolean hasResult(GetStateEntryMessage actual) { - return actual.getResultCase() != GetStateEntryMessage.ResultCase.RESULT_NOT_SET; - } - - @Override - String getName(GetStateEntryMessage expected) { - return expected.getName(); - } - - @Override - void checkEntryHeader(GetStateEntryMessage expected, MessageLite actual) - throws ProtocolException { - if (!(actual instanceof GetStateEntryMessage)) { - throw ProtocolException.entryDoesNotMatch(expected, actual); - } - if (!expected.getKey().equals(((GetStateEntryMessage) actual).getKey())) { - throw ProtocolException.entryDoesNotMatch(expected, actual); - } - } - - @Override - public Result parseEntryResult(GetStateEntryMessage actual) { - if (actual.getResultCase() == GetStateEntryMessage.ResultCase.VALUE) { - return Result.success(actual.getValue().asReadOnlyByteBuffer()); - } else if (actual.getResultCase() == GetStateEntryMessage.ResultCase.FAILURE) { - return Result.failure(Util.toRestateException(actual.getFailure())); - } else if (actual.getResultCase() == GetStateEntryMessage.ResultCase.EMPTY) { - return Result.empty(); - } else { - throw new IllegalStateException("GetStateEntry has not been completed."); - } - } - - @Override - public Result parseCompletionResult(CompletionMessage actual) { - if (actual.getResultCase() == CompletionMessage.ResultCase.VALUE) { - return Result.success(actual.getValue().asReadOnlyByteBuffer()); - } else if (actual.getResultCase() == CompletionMessage.ResultCase.EMPTY) { - return Result.empty(); - } else if (actual.getResultCase() == CompletionMessage.ResultCase.FAILURE) { - return Result.failure(Util.toRestateException(actual.getFailure())); - } - return super.parseCompletionResult(actual); - } - - @Override - void updateUserStateStoreWithEntry( - GetStateEntryMessage expected, UserStateStore userStateStore) { - if (expected.hasEmpty()) { - userStateStore.clear(expected.getKey()); - } else { - userStateStore.set(expected.getKey(), expected.getValue().asReadOnlyByteBuffer()); - } - } - - @Override - GetStateEntryMessage tryCompleteWithUserStateStorage( - GetStateEntryMessage expected, UserStateStore userStateStore) { - UserStateStore.State value = userStateStore.get(expected.getKey()); - if (value instanceof UserStateStore.Value) { - return expected.toBuilder() - .setValue(UnsafeByteOperations.unsafeWrap(((UserStateStore.Value) value).getValue())) - .build(); - } else if (value instanceof UserStateStore.Empty) { - return expected.toBuilder().setEmpty(Empty.getDefaultInstance()).build(); - } - return expected; - } - - @Override - void updateUserStateStorageWithCompletion( - GetStateEntryMessage expected, CompletionMessage actual, UserStateStore userStateStore) { - if (actual.hasEmpty()) { - userStateStore.clear(expected.getKey()); - } else { - userStateStore.set(expected.getKey(), actual.getValue().asReadOnlyByteBuffer()); - } - } - } - - static final class GetStateKeysEntry - extends CompletableJournalEntry> { - - static final GetStateKeysEntry INSTANCE = new GetStateKeysEntry(); - - private GetStateKeysEntry() {} - - @Override - void trace(GetStateKeysEntryMessage expected, Span span) { - span.addEvent("GetStateKeys"); - } - - @Override - public boolean hasResult(GetStateKeysEntryMessage actual) { - return actual.getResultCase() != GetStateKeysEntryMessage.ResultCase.RESULT_NOT_SET; - } - - @Override - String getName(GetStateKeysEntryMessage expected) { - return expected.getName(); - } - - @Override - void checkEntryHeader(GetStateKeysEntryMessage expected, MessageLite actual) - throws ProtocolException { - if (!(actual instanceof GetStateKeysEntryMessage)) { - throw ProtocolException.entryDoesNotMatch(expected, actual); - } - } - - @Override - public Result> parseEntryResult(GetStateKeysEntryMessage actual) { - if (actual.getResultCase() == GetStateKeysEntryMessage.ResultCase.VALUE) { - return Result.success( - actual.getValue().getKeysList().stream() - .map(ByteString::toStringUtf8) - .collect(Collectors.toUnmodifiableList())); - } else if (actual.getResultCase() == GetStateKeysEntryMessage.ResultCase.FAILURE) { - return Result.failure(Util.toRestateException(actual.getFailure())); - } else { - throw new IllegalStateException("GetStateKeysEntryMessage has not been completed."); - } - } - - @Override - public Result> parseCompletionResult(CompletionMessage actual) { - if (actual.getResultCase() == CompletionMessage.ResultCase.VALUE) { - GetStateKeysEntryMessage.StateKeys stateKeys; - try { - stateKeys = GetStateKeysEntryMessage.StateKeys.parseFrom(actual.getValue()); - } catch (InvalidProtocolBufferException e) { - throw new ProtocolException( - "Cannot parse get state keys completion", - ProtocolException.PROTOCOL_VIOLATION_CODE, - e); - } - return Result.success( - stateKeys.getKeysList().stream() - .map(ByteString::toStringUtf8) - .collect(Collectors.toUnmodifiableList())); - } else if (actual.getResultCase() == CompletionMessage.ResultCase.FAILURE) { - return Result.failure(Util.toRestateException(actual.getFailure())); - } - return super.parseCompletionResult(actual); - } - - @Override - GetStateKeysEntryMessage tryCompleteWithUserStateStorage( - GetStateKeysEntryMessage expected, UserStateStore userStateStore) { - if (userStateStore.isComplete()) { - return expected.toBuilder() - .setValue( - GetStateKeysEntryMessage.StateKeys.newBuilder().addAllKeys(userStateStore.keys())) - .build(); - } - return expected; - } - } - - static final class ClearStateEntry extends JournalEntry { - - static final ClearStateEntry INSTANCE = new ClearStateEntry(); - - private ClearStateEntry() {} - - @Override - public void trace(ClearStateEntryMessage expected, Span span) { - span.addEvent( - "ClearState", Attributes.of(Tracing.RESTATE_STATE_KEY, expected.getKey().toString())); - } - - @Override - String getName(ClearStateEntryMessage expected) { - return expected.getName(); - } - - @Override - void checkEntryHeader(ClearStateEntryMessage expected, MessageLite actual) - throws ProtocolException { - Util.assertEntryEquals(expected, actual); - } - - @Override - void updateUserStateStoreWithEntry( - ClearStateEntryMessage expected, UserStateStore userStateStore) { - userStateStore.clear(expected.getKey()); - } - } - - static final class ClearAllStateEntry extends JournalEntry { - - static final ClearAllStateEntry INSTANCE = new ClearAllStateEntry(); - - private ClearAllStateEntry() {} - - @Override - public void trace(ClearAllStateEntryMessage expected, Span span) { - span.addEvent("ClearAllState"); - } - - @Override - String getName(ClearAllStateEntryMessage expected) { - return expected.getName(); - } - - @Override - void checkEntryHeader(ClearAllStateEntryMessage expected, MessageLite actual) - throws ProtocolException { - Util.assertEntryEquals(expected, actual); - } - - @Override - void updateUserStateStoreWithEntry( - ClearAllStateEntryMessage expected, UserStateStore userStateStore) { - userStateStore.clearAll(); - } - } - - static final class SetStateEntry extends JournalEntry { - - static final SetStateEntry INSTANCE = new SetStateEntry(); - - private SetStateEntry() {} - - @Override - public void trace(SetStateEntryMessage expected, Span span) { - span.addEvent( - "SetState", Attributes.of(Tracing.RESTATE_STATE_KEY, expected.getKey().toString())); - } - - @Override - String getName(SetStateEntryMessage expected) { - return expected.getName(); - } - - @Override - void checkEntryHeader(SetStateEntryMessage expected, MessageLite actual) - throws ProtocolException { - if (!(actual instanceof SetStateEntryMessage)) { - throw ProtocolException.entryDoesNotMatch(expected, actual); - } - if (!expected.getKey().equals(((SetStateEntryMessage) actual).getKey())) { - throw ProtocolException.entryDoesNotMatch(expected, actual); - } - } - - @Override - void updateUserStateStoreWithEntry( - SetStateEntryMessage expected, UserStateStore userStateStore) { - userStateStore.set(expected.getKey(), expected.getValue().asReadOnlyByteBuffer()); - } - } - - static final class SleepEntry extends CompletableJournalEntry { - - static final SleepEntry INSTANCE = new SleepEntry(); - - private SleepEntry() {} - - @Override - String getName(SleepEntryMessage expected) { - return expected.getName(); - } - - @Override - void trace(SleepEntryMessage expected, Span span) { - span.addEvent( - "Sleep", Attributes.of(Tracing.RESTATE_SLEEP_WAKE_UP_TIME, expected.getWakeUpTime())); - } - - @Override - public boolean hasResult(SleepEntryMessage actual) { - return actual.getResultCase() != Protocol.SleepEntryMessage.ResultCase.RESULT_NOT_SET; - } - - @Override - public Result parseEntryResult(SleepEntryMessage actual) { - if (actual.getResultCase() == SleepEntryMessage.ResultCase.FAILURE) { - return Result.failure(Util.toRestateException(actual.getFailure())); - } else if (actual.getResultCase() == SleepEntryMessage.ResultCase.EMPTY) { - return Result.empty(); - } else { - throw new IllegalStateException("SleepEntry has not been completed."); - } - } - - @Override - public Result parseCompletionResult(CompletionMessage actual) { - if (actual.getResultCase() == CompletionMessage.ResultCase.EMPTY) { - return Result.empty(); - } else if (actual.getResultCase() == CompletionMessage.ResultCase.FAILURE) { - return Result.failure(Util.toRestateException(actual.getFailure())); - } - return super.parseCompletionResult(actual); - } - } - - static final class InvokeEntry extends CompletableJournalEntry { - - private final Function> valueParser; - - InvokeEntry(Function> valueParser) { - this.valueParser = valueParser; - } - - @Override - void trace(CallEntryMessage expected, Span span) { - span.addEvent( - "Invoke", - Attributes.of( - Tracing.RESTATE_COORDINATION_CALL_SERVICE, - expected.getServiceName(), - Tracing.RESTATE_COORDINATION_CALL_METHOD, - expected.getHandlerName())); - } - - @Override - public boolean hasResult(CallEntryMessage actual) { - return actual.getResultCase() != Protocol.CallEntryMessage.ResultCase.RESULT_NOT_SET; - } - - @Override - String getName(CallEntryMessage expected) { - return expected.getName(); - } - - @Override - void checkEntryHeader(CallEntryMessage expected, MessageLite actual) throws ProtocolException { - if (!(actual instanceof CallEntryMessage actualInvoke)) { - throw ProtocolException.entryDoesNotMatch(expected, actual); - } - - if (!(Objects.equals(expected.getServiceName(), actualInvoke.getServiceName()) - && Objects.equals(expected.getHandlerName(), actualInvoke.getHandlerName()) - && Objects.equals(expected.getParameter(), actualInvoke.getParameter()) - && Objects.equals(expected.getKey(), actualInvoke.getKey()))) { - throw ProtocolException.entryDoesNotMatch(expected, actualInvoke); - } - } - - @Override - public Result parseEntryResult(CallEntryMessage actual) { - if (actual.hasValue()) { - return valueParser.apply(actual.getValue().asReadOnlyByteBuffer()); - } - return Result.failure(Util.toRestateException(actual.getFailure())); - } - - @Override - public Result parseCompletionResult(CompletionMessage actual) { - if (actual.hasValue()) { - return valueParser.apply(actual.getValue().asReadOnlyByteBuffer()); - } - if (actual.hasFailure()) { - return Result.failure(Util.toRestateException(actual.getFailure())); - } - return super.parseCompletionResult(actual); - } - } - - static final class OneWayCallEntry extends JournalEntry { - - static final OneWayCallEntry INSTANCE = new OneWayCallEntry(); - - private OneWayCallEntry() {} - - @Override - public void trace(OneWayCallEntryMessage expected, Span span) { - span.addEvent( - "BackgroundInvoke", - Attributes.of( - Tracing.RESTATE_COORDINATION_CALL_SERVICE, - expected.getServiceName(), - Tracing.RESTATE_COORDINATION_CALL_METHOD, - expected.getHandlerName())); - } - - @Override - String getName(OneWayCallEntryMessage expected) { - return expected.getName(); - } - - @Override - void checkEntryHeader(OneWayCallEntryMessage expected, MessageLite actual) - throws ProtocolException { - if (!(actual instanceof OneWayCallEntryMessage actualInvoke)) { - throw ProtocolException.entryDoesNotMatch(expected, actual); - } - - if (!(Objects.equals(expected.getServiceName(), actualInvoke.getServiceName()) - && Objects.equals(expected.getHandlerName(), actualInvoke.getHandlerName()) - && Objects.equals(expected.getParameter(), actualInvoke.getParameter()) - && Objects.equals(expected.getKey(), actualInvoke.getKey()))) { - throw ProtocolException.entryDoesNotMatch(expected, actualInvoke); - } - } - } - - static final class AwakeableEntry - extends CompletableJournalEntry { - static final AwakeableEntry INSTANCE = new AwakeableEntry(); - - private AwakeableEntry() {} - - @Override - String getName(AwakeableEntryMessage expected) { - return expected.getName(); - } - - @Override - void trace(AwakeableEntryMessage expected, Span span) { - span.addEvent("Awakeable"); - } - - @Override - public boolean hasResult(AwakeableEntryMessage actual) { - return actual.getResultCase() != Protocol.AwakeableEntryMessage.ResultCase.RESULT_NOT_SET; - } - - @Override - public Result parseEntryResult(AwakeableEntryMessage actual) { - if (actual.hasValue()) { - return Result.success(actual.getValue().asReadOnlyByteBuffer()); - } - return Result.failure(Util.toRestateException(actual.getFailure())); - } - - @Override - public Result parseCompletionResult(CompletionMessage actual) { - if (actual.hasValue()) { - return Result.success(actual.getValue().asReadOnlyByteBuffer()); - } - if (actual.hasFailure()) { - return Result.failure(Util.toRestateException(actual.getFailure())); - } - return super.parseCompletionResult(actual); - } - } - - static final class GetPromiseEntry - extends CompletableJournalEntry { - static final GetPromiseEntry INSTANCE = new GetPromiseEntry(); - - private GetPromiseEntry() {} - - @Override - String getName(GetPromiseEntryMessage expected) { - return expected.getName(); - } - - @Override - void trace(GetPromiseEntryMessage expected, Span span) { - span.addEvent("Promise"); - } - - @Override - void checkEntryHeader(GetPromiseEntryMessage expected, MessageLite actual) - throws ProtocolException { - if (!(actual instanceof GetPromiseEntryMessage)) { - throw ProtocolException.entryDoesNotMatch(expected, actual); - } - if (!expected.getKey().equals(((GetPromiseEntryMessage) actual).getKey())) { - throw ProtocolException.entryDoesNotMatch(expected, actual); - } - } - - @Override - public boolean hasResult(GetPromiseEntryMessage actual) { - return actual.getResultCase() != Protocol.GetPromiseEntryMessage.ResultCase.RESULT_NOT_SET; - } - - @Override - public Result parseEntryResult(GetPromiseEntryMessage actual) { - if (actual.hasValue()) { - return Result.success(actual.getValue().asReadOnlyByteBuffer()); - } - return Result.failure(Util.toRestateException(actual.getFailure())); - } - - @Override - public Result parseCompletionResult(CompletionMessage actual) { - if (actual.hasValue()) { - return Result.success(actual.getValue().asReadOnlyByteBuffer()); - } - if (actual.hasFailure()) { - return Result.failure(Util.toRestateException(actual.getFailure())); - } - return super.parseCompletionResult(actual); - } - } - - static final class PeekPromiseEntry - extends CompletableJournalEntry { - static final PeekPromiseEntry INSTANCE = new PeekPromiseEntry(); - - private PeekPromiseEntry() {} - - @Override - String getName(PeekPromiseEntryMessage expected) { - return expected.getName(); - } - - @Override - void trace(PeekPromiseEntryMessage expected, Span span) { - span.addEvent("PeekPromise"); - } - - @Override - void checkEntryHeader(PeekPromiseEntryMessage expected, MessageLite actual) - throws ProtocolException { - if (!(actual instanceof PeekPromiseEntryMessage)) { - throw ProtocolException.entryDoesNotMatch(expected, actual); - } - if (!expected.getKey().equals(((PeekPromiseEntryMessage) actual).getKey())) { - throw ProtocolException.entryDoesNotMatch(expected, actual); - } - } - - @Override - public boolean hasResult(PeekPromiseEntryMessage actual) { - return actual.getResultCase() != Protocol.PeekPromiseEntryMessage.ResultCase.RESULT_NOT_SET; - } - - @Override - public Result parseEntryResult(PeekPromiseEntryMessage actual) { - if (actual.getResultCase() == PeekPromiseEntryMessage.ResultCase.VALUE) { - return Result.success(actual.getValue().asReadOnlyByteBuffer()); - } else if (actual.getResultCase() == PeekPromiseEntryMessage.ResultCase.FAILURE) { - return Result.failure(Util.toRestateException(actual.getFailure())); - } else if (actual.getResultCase() == PeekPromiseEntryMessage.ResultCase.EMPTY) { - return Result.empty(); - } else { - throw new IllegalStateException("PeekPromiseEntry has not been completed."); - } - } - - @Override - public Result parseCompletionResult(CompletionMessage actual) { - if (actual.getResultCase() == CompletionMessage.ResultCase.VALUE) { - return Result.success(actual.getValue().asReadOnlyByteBuffer()); - } else if (actual.getResultCase() == CompletionMessage.ResultCase.EMPTY) { - return Result.empty(); - } else if (actual.getResultCase() == CompletionMessage.ResultCase.FAILURE) { - return Result.failure(Util.toRestateException(actual.getFailure())); - } - return super.parseCompletionResult(actual); - } - } - - static final class CompletePromiseEntry - extends CompletableJournalEntry { - - static final CompletePromiseEntry INSTANCE = new CompletePromiseEntry(); - - private CompletePromiseEntry() {} - - @Override - String getName(CompletePromiseEntryMessage expected) { - return expected.getName(); - } - - @Override - void trace(CompletePromiseEntryMessage expected, Span span) { - span.addEvent("CompletePromise"); - } - - @Override - void checkEntryHeader(CompletePromiseEntryMessage expected, MessageLite actual) - throws ProtocolException { - if (!(actual instanceof CompletePromiseEntryMessage)) { - throw ProtocolException.entryDoesNotMatch(expected, actual); - } - if (!expected.getKey().equals(((CompletePromiseEntryMessage) actual).getKey())) { - throw ProtocolException.entryDoesNotMatch(expected, actual); - } - if (!expected - .getCompletionCase() - .equals(((CompletePromiseEntryMessage) actual).getCompletionCase())) { - throw ProtocolException.entryDoesNotMatch(expected, actual); - } - } - - @Override - public boolean hasResult(CompletePromiseEntryMessage actual) { - return actual.getResultCase() - != Protocol.CompletePromiseEntryMessage.ResultCase.RESULT_NOT_SET; - } - - @Override - public Result parseEntryResult(CompletePromiseEntryMessage actual) { - if (actual.getResultCase() == CompletePromiseEntryMessage.ResultCase.FAILURE) { - return Result.failure(Util.toRestateException(actual.getFailure())); - } else if (actual.getResultCase() == CompletePromiseEntryMessage.ResultCase.EMPTY) { - return Result.empty(); - } else { - throw new IllegalStateException("CompletePromiseEntry has not been completed."); - } - } - - @Override - public Result parseCompletionResult(CompletionMessage actual) { - if (actual.getResultCase() == CompletionMessage.ResultCase.EMPTY) { - return Result.empty(); - } else if (actual.getResultCase() == CompletionMessage.ResultCase.FAILURE) { - return Result.failure(Util.toRestateException(actual.getFailure())); - } - return super.parseCompletionResult(actual); - } - } - - static final class CompleteAwakeableEntry extends JournalEntry { - - static final CompleteAwakeableEntry INSTANCE = new CompleteAwakeableEntry(); - - private CompleteAwakeableEntry() {} - - @Override - public void trace(CompleteAwakeableEntryMessage expected, Span span) { - span.addEvent("CompleteAwakeable"); - } - - @Override - String getName(CompleteAwakeableEntryMessage expected) { - return expected.getName(); - } - - @Override - void checkEntryHeader(CompleteAwakeableEntryMessage expected, MessageLite actual) - throws ProtocolException { - Util.assertEntryEquals(expected, actual); - } - } -} diff --git a/sdk-core/src/main/java/dev/restate/sdk/core/ExceptionCatchingSubscriber.java b/sdk-core/src/main/java/dev/restate/sdk/core/ExceptionCatchingSubscriber.java deleted file mode 100644 index e59f533b9..000000000 --- a/sdk-core/src/main/java/dev/restate/sdk/core/ExceptionCatchingSubscriber.java +++ /dev/null @@ -1,50 +0,0 @@ -// Copyright (c) 2023 - Restate Software, Inc., Restate GmbH -// -// This file is part of the Restate Java SDK, -// which is released under the MIT license. -// -// You can find a copy of the license in file LICENSE in the root -// directory of this repository or package, or at -// https://github.com/restatedev/sdk-java/blob/main/LICENSE -package dev.restate.sdk.core; - -import java.util.concurrent.Flow; - -class ExceptionCatchingSubscriber implements Flow.Subscriber { - - final Flow.Subscriber invocationInputSubscriber; - - public ExceptionCatchingSubscriber(Flow.Subscriber invocationInputSubscriber) { - this.invocationInputSubscriber = invocationInputSubscriber; - } - - @Override - public void onSubscribe(Flow.Subscription subscription) { - try { - invocationInputSubscriber.onSubscribe(subscription); - } catch (Throwable throwable) { - invocationInputSubscriber.onError(throwable); - throw throwable; - } - } - - @Override - public void onNext(T t) { - try { - invocationInputSubscriber.onNext(t); - } catch (Throwable throwable) { - invocationInputSubscriber.onError(throwable); - throw throwable; - } - } - - @Override - public void onError(Throwable throwable) { - invocationInputSubscriber.onError(throwable); - } - - @Override - public void onComplete() { - invocationInputSubscriber.onComplete(); - } -} diff --git a/sdk-core/src/main/java/dev/restate/sdk/core/ExceptionUtils.java b/sdk-core/src/main/java/dev/restate/sdk/core/ExceptionUtils.java new file mode 100644 index 000000000..99ad04a93 --- /dev/null +++ b/sdk-core/src/main/java/dev/restate/sdk/core/ExceptionUtils.java @@ -0,0 +1,63 @@ +// Copyright (c) 2023 - Restate Software, Inc., Restate GmbH +// +// This file is part of the Restate Java SDK, +// which is released under the MIT license. +// +// You can find a copy of the license in file LICENSE in the root +// directory of this repository or package, or at +// https://github.com/restatedev/sdk-java/blob/main/LICENSE +package dev.restate.sdk.core; + +import dev.restate.sdk.types.AbortedExecutionException; +import dev.restate.sdk.types.TerminalException; +import java.util.Optional; +import java.util.function.Predicate; + +public final class ExceptionUtils { + private ExceptionUtils() {} + + @SuppressWarnings("unchecked") + public static void sneakyThrow(Throwable e) throws E { + throw (E) e; + } + + /** + * Finds a throwable fulfilling the condition in the cause chain of the given throwable. If there + * is none, then the method returns an empty optional. + * + * @param throwable to check for the given condition + * @param condition condition that a cause needs to fulfill + * @return Some cause that fulfills the condition; otherwise an empty optional + */ + @SuppressWarnings("unchecked") + static Optional findCause( + Throwable throwable, Predicate condition) { + Throwable currentThrowable = throwable; + + while (currentThrowable != null) { + if (condition.test(currentThrowable)) { + return (Optional) Optional.of(currentThrowable); + } + + if (currentThrowable == currentThrowable.getCause()) { + break; + } else { + currentThrowable = currentThrowable.getCause(); + } + } + + return Optional.empty(); + } + + public static Optional findProtocolException(Throwable throwable) { + return findCause(throwable, t -> t instanceof ProtocolException); + } + + public static boolean containsSuspendedException(Throwable throwable) { + return findCause(throwable, t -> t == AbortedExecutionException.INSTANCE).isPresent(); + } + + public static boolean isTerminalException(Throwable throwable) { + return throwable instanceof TerminalException; + } +} diff --git a/sdk-core/src/main/java/dev/restate/sdk/core/ExecutorSwitchingHandlerContextImpl.java b/sdk-core/src/main/java/dev/restate/sdk/core/ExecutorSwitchingHandlerContextImpl.java new file mode 100644 index 000000000..ab1cdbd81 --- /dev/null +++ b/sdk-core/src/main/java/dev/restate/sdk/core/ExecutorSwitchingHandlerContextImpl.java @@ -0,0 +1,191 @@ +// Copyright (c) 2023 - Restate Software, Inc., Restate GmbH +// +// This file is part of the Restate Java SDK, +// which is released under the MIT license. +// +// You can find a copy of the license in file LICENSE in the root +// directory of this repository or package, or at +// https://github.com/restatedev/sdk-java/blob/main/LICENSE +package dev.restate.sdk.core; + +import dev.restate.common.Output; +import dev.restate.common.Slice; +import dev.restate.common.Target; +import dev.restate.sdk.core.statemachine.StateMachine; +import dev.restate.sdk.endpoint.definition.AsyncResult; +import dev.restate.sdk.types.*; +import io.opentelemetry.context.Context; +import java.time.Duration; +import java.util.Collection; +import java.util.Map; +import java.util.Optional; +import java.util.concurrent.CompletableFuture; +import java.util.concurrent.Executor; +import java.util.function.Consumer; +import java.util.function.Function; +import org.jspecify.annotations.Nullable; + +final class ExecutorSwitchingHandlerContextImpl extends HandlerContextImpl { + + private final Executor coreExecutor; + + ExecutorSwitchingHandlerContextImpl( + String fullyQualifiedHandlerName, + StateMachine stateMachine, + Context otelContext, + StateMachine.Input input, + Executor coreExecutor) { + super(fullyQualifiedHandlerName, stateMachine, otelContext, input); + this.coreExecutor = coreExecutor; + } + + @Override + public CompletableFuture>> get(String name) { + return CompletableFuture.supplyAsync(() -> super.get(name), coreExecutor) + .thenCompose(Function.identity()); + } + + @Override + public CompletableFuture>> getKeys() { + return CompletableFuture.supplyAsync(super::getKeys, coreExecutor) + .thenCompose(Function.identity()); + } + + @Override + public CompletableFuture clear(String name) { + return CompletableFuture.supplyAsync(() -> super.clear(name), coreExecutor) + .thenCompose(Function.identity()); + } + + @Override + public CompletableFuture clearAll() { + return CompletableFuture.supplyAsync(super::clearAll, coreExecutor) + .thenCompose(Function.identity()); + } + + @Override + public CompletableFuture set(String name, Slice value) { + return CompletableFuture.supplyAsync(() -> super.set(name, value), coreExecutor) + .thenCompose(Function.identity()); + } + + @Override + public CompletableFuture> timer(Duration duration, String name) { + return CompletableFuture.supplyAsync(() -> super.timer(duration, name), coreExecutor) + .thenCompose(Function.identity()); + } + + @Override + public CompletableFuture call( + Target target, + Slice parameter, + @Nullable String idempotencyKey, + @Nullable Collection> headers) { + return CompletableFuture.supplyAsync( + () -> super.call(target, parameter, idempotencyKey, headers), coreExecutor) + .thenCompose(Function.identity()); + } + + @Override + public CompletableFuture> send( + Target target, + Slice parameter, + @Nullable String idempotencyKey, + @Nullable Collection> headers, + @Nullable Duration delay) { + return CompletableFuture.supplyAsync( + () -> super.send(target, parameter, idempotencyKey, headers, delay), coreExecutor) + .thenCompose(Function.identity()); + } + + @Override + public CompletableFuture> submitRun( + @Nullable String name, Consumer closure) { + return CompletableFuture.supplyAsync(() -> super.submitRun(name, closure), coreExecutor) + .thenCompose(Function.identity()); + } + + @Override + public CompletableFuture awakeable() { + return CompletableFuture.supplyAsync(super::awakeable, coreExecutor) + .thenCompose(Function.identity()); + } + + @Override + public CompletableFuture resolveAwakeable(String id, Slice payload) { + return CompletableFuture.supplyAsync(() -> super.resolveAwakeable(id, payload), coreExecutor) + .thenCompose(Function.identity()); + } + + @Override + public CompletableFuture rejectAwakeable(String id, TerminalException reason) { + return CompletableFuture.supplyAsync(() -> super.rejectAwakeable(id, reason), coreExecutor) + .thenCompose(Function.identity()); + } + + @Override + public CompletableFuture> promise(String key) { + return CompletableFuture.supplyAsync(() -> super.promise(key), coreExecutor) + .thenCompose(Function.identity()); + } + + @Override + public CompletableFuture>> peekPromise(String key) { + return CompletableFuture.supplyAsync(() -> super.peekPromise(key), coreExecutor) + .thenCompose(Function.identity()); + } + + @Override + public CompletableFuture> resolvePromise(String key, Slice payload) { + return CompletableFuture.supplyAsync(() -> super.resolvePromise(key, payload), coreExecutor) + .thenCompose(Function.identity()); + } + + @Override + public CompletableFuture> rejectPromise(String key, TerminalException reason) { + return CompletableFuture.supplyAsync(() -> super.rejectPromise(key, reason), coreExecutor) + .thenCompose(Function.identity()); + } + + @Override + public void proposeRunSuccess(int runHandle, Slice toWrite) { + coreExecutor.execute(() -> super.proposeRunSuccess(runHandle, toWrite)); + } + + @Override + public void proposeRunFailure( + int runHandle, + Throwable toWrite, + Duration attemptDuration, + @Nullable RetryPolicy retryPolicy) { + coreExecutor.execute( + () -> super.proposeRunFailure(runHandle, toWrite, attemptDuration, retryPolicy)); + } + + @Override + public void pollAsyncResult(AsyncResults.AsyncResultInternal asyncResult) { + coreExecutor.execute(() -> super.pollAsyncResult(asyncResult)); + } + + @Override + public CompletableFuture writeOutput(Slice value) { + return CompletableFuture.supplyAsync(() -> super.writeOutput(value), coreExecutor) + .thenCompose(Function.identity()); + } + + @Override + public CompletableFuture writeOutput(TerminalException throwable) { + return CompletableFuture.supplyAsync(() -> super.writeOutput(throwable), coreExecutor) + .thenCompose(Function.identity()); + } + + @Override + public void close() { + coreExecutor.execute(super::close); + } + + @Override + public void fail(Throwable cause) { + coreExecutor.execute(() -> super.fail(cause)); + } +} diff --git a/sdk-core/src/main/java/dev/restate/sdk/core/ExecutorSwitchingSyscalls.java b/sdk-core/src/main/java/dev/restate/sdk/core/ExecutorSwitchingSyscalls.java deleted file mode 100644 index cffc78c46..000000000 --- a/sdk-core/src/main/java/dev/restate/sdk/core/ExecutorSwitchingSyscalls.java +++ /dev/null @@ -1,198 +0,0 @@ -// Copyright (c) 2023 - Restate Software, Inc., Restate GmbH -// -// This file is part of the Restate Java SDK, -// which is released under the MIT license. -// -// You can find a copy of the license in file LICENSE in the root -// directory of this repository or package, or at -// https://github.com/restatedev/sdk-java/blob/main/LICENSE -package dev.restate.sdk.core; - -import dev.restate.sdk.common.Request; -import dev.restate.sdk.common.RetryPolicy; -import dev.restate.sdk.common.Target; -import dev.restate.sdk.common.TerminalException; -import dev.restate.sdk.common.syscalls.Deferred; -import dev.restate.sdk.common.syscalls.EnterSideEffectSyscallCallback; -import dev.restate.sdk.common.syscalls.ExitSideEffectSyscallCallback; -import dev.restate.sdk.common.syscalls.SyscallCallback; -import java.nio.ByteBuffer; -import java.time.Duration; -import java.util.Collection; -import java.util.Map; -import java.util.concurrent.Executor; -import org.jspecify.annotations.Nullable; - -class ExecutorSwitchingSyscalls implements SyscallsInternal { - - private final SyscallsInternal syscalls; - private final Executor syscallsExecutor; - - ExecutorSwitchingSyscalls(SyscallsInternal syscalls, Executor syscallsExecutor) { - this.syscalls = syscalls; - this.syscallsExecutor = syscallsExecutor; - } - - @Override - public void writeOutput(ByteBuffer value, SyscallCallback callback) { - syscallsExecutor.execute(() -> syscalls.writeOutput(value, callback)); - } - - @Override - public void writeOutput(TerminalException throwable, SyscallCallback callback) { - syscallsExecutor.execute(() -> syscalls.writeOutput(throwable, callback)); - } - - @Override - public void get(String name, SyscallCallback> callback) { - syscallsExecutor.execute(() -> syscalls.get(name, callback)); - } - - @Override - public void getKeys(SyscallCallback>> callback) { - syscallsExecutor.execute(() -> syscalls.getKeys(callback)); - } - - @Override - public void clear(String name, SyscallCallback callback) { - syscallsExecutor.execute(() -> syscalls.clear(name, callback)); - } - - @Override - public void clearAll(SyscallCallback callback) { - syscallsExecutor.execute(() -> syscalls.clearAll(callback)); - } - - @Override - public void set(String name, ByteBuffer value, SyscallCallback callback) { - syscallsExecutor.execute(() -> syscalls.set(name, value, callback)); - } - - @Override - public void sleep(Duration duration, SyscallCallback> callback) { - syscallsExecutor.execute(() -> syscalls.sleep(duration, callback)); - } - - @Override - public void call( - Target target, ByteBuffer parameter, SyscallCallback> callback) { - syscallsExecutor.execute(() -> syscalls.call(target, parameter, callback)); - } - - @Override - public void send( - Target target, - ByteBuffer parameter, - @Nullable Duration delay, - SyscallCallback requestCallback) { - syscallsExecutor.execute(() -> syscalls.send(target, parameter, delay, requestCallback)); - } - - @Override - public void enterSideEffectBlock(String name, EnterSideEffectSyscallCallback callback) { - syscallsExecutor.execute(() -> syscalls.enterSideEffectBlock(name, callback)); - } - - @Override - public void exitSideEffectBlock(ByteBuffer toWrite, ExitSideEffectSyscallCallback callback) { - syscallsExecutor.execute(() -> syscalls.exitSideEffectBlock(toWrite, callback)); - } - - @Override - public void exitSideEffectBlockWithTerminalException( - TerminalException toWrite, ExitSideEffectSyscallCallback callback) { - syscallsExecutor.execute( - () -> syscalls.exitSideEffectBlockWithTerminalException(toWrite, callback)); - } - - @Override - public void exitSideEffectBlockWithException( - Throwable toWrite, - @Nullable RetryPolicy retryPolicy, - ExitSideEffectSyscallCallback callback) { - syscallsExecutor.execute( - () -> syscalls.exitSideEffectBlockWithException(toWrite, retryPolicy, callback)); - } - - @Override - public void awakeable(SyscallCallback>> callback) { - syscallsExecutor.execute(() -> syscalls.awakeable(callback)); - } - - @Override - public void resolveAwakeable( - String id, ByteBuffer payload, SyscallCallback requestCallback) { - syscallsExecutor.execute(() -> syscalls.resolveAwakeable(id, payload, requestCallback)); - } - - @Override - public void rejectAwakeable(String id, String reason, SyscallCallback requestCallback) { - syscallsExecutor.execute(() -> syscalls.rejectAwakeable(id, reason, requestCallback)); - } - - @Override - public void promise(String key, SyscallCallback> callback) { - syscallsExecutor.execute(() -> syscalls.promise(key, callback)); - } - - @Override - public void peekPromise(String key, SyscallCallback> callback) { - syscallsExecutor.execute(() -> syscalls.peekPromise(key, callback)); - } - - @Override - public void resolvePromise( - String key, ByteBuffer payload, SyscallCallback> callback) { - syscallsExecutor.execute(() -> syscalls.resolvePromise(key, payload, callback)); - } - - @Override - public void rejectPromise(String key, String reason, SyscallCallback> callback) { - syscallsExecutor.execute(() -> syscalls.rejectPromise(key, reason, callback)); - } - - @Override - public void resolveDeferred(Deferred deferredToResolve, SyscallCallback callback) { - syscallsExecutor.execute(() -> syscalls.resolveDeferred(deferredToResolve, callback)); - } - - @Override - public String getFullyQualifiedMethodName() { - // We can read this from another thread - return syscalls.getFullyQualifiedMethodName(); - } - - @Override - public InvocationState getInvocationState() { - // We can read this from another thread - return syscalls.getInvocationState(); - } - - @Override - public String objectKey() { - // This is immutable once set - return syscalls.objectKey(); - } - - @Override - public Request request() { - // This is immutable once set - return syscalls.request(); - } - - @Override - public boolean isInsideSideEffect() { - // We can read this from another thread - return syscalls.isInsideSideEffect(); - } - - @Override - public void close() { - syscallsExecutor.execute(syscalls::close); - } - - @Override - public void fail(Throwable cause) { - syscallsExecutor.execute(() -> syscalls.fail(cause)); - } -} diff --git a/sdk-core/src/main/java/dev/restate/sdk/core/HandlerContextImpl.java b/sdk-core/src/main/java/dev/restate/sdk/core/HandlerContextImpl.java new file mode 100644 index 000000000..529491604 --- /dev/null +++ b/sdk-core/src/main/java/dev/restate/sdk/core/HandlerContextImpl.java @@ -0,0 +1,519 @@ +// Copyright (c) 2023 - Restate Software, Inc., Restate GmbH +// +// This file is part of the Restate Java SDK, +// which is released under the MIT license. +// +// You can find a copy of the license in file LICENSE in the root +// directory of this repository or package, or at +// https://github.com/restatedev/sdk-java/blob/main/LICENSE +package dev.restate.sdk.core; + +import dev.restate.common.Output; +import dev.restate.common.Slice; +import dev.restate.common.Target; +import dev.restate.common.function.ThrowingRunnable; +import dev.restate.common.function.ThrowingSupplier; +import dev.restate.sdk.core.AsyncResults.AsyncResultInternal; +import dev.restate.sdk.core.statemachine.InvocationState; +import dev.restate.sdk.core.statemachine.NotificationValue; +import dev.restate.sdk.core.statemachine.StateMachine; +import dev.restate.sdk.endpoint.definition.AsyncResult; +import dev.restate.sdk.types.*; +import io.opentelemetry.context.Context; +import java.time.Duration; +import java.time.Instant; +import java.util.*; +import java.util.concurrent.CompletableFuture; +import java.util.concurrent.Executor; +import java.util.function.Consumer; +import java.util.stream.Stream; +import org.apache.logging.log4j.LogManager; +import org.apache.logging.log4j.Logger; +import org.jspecify.annotations.Nullable; + +class HandlerContextImpl implements HandlerContextInternal { + + private static final Logger LOG = LogManager.getLogger(HandlerContextImpl.class); + + private static final int CANCEL_HANDLE = 1; + + private final HandlerRequest handlerRequest; + private final StateMachine stateMachine; + private final @Nullable String objectKey; + private final String fullyQualifiedHandlerName; + + private CompletableFuture nextProcessedRun; + private final List> invocationIdsToCancel; + private final HashMap> scheduledRuns; + + HandlerContextImpl( + String fullyQualifiedHandlerName, + StateMachine stateMachine, + Context otelContext, + StateMachine.Input input) { + this.handlerRequest = + new HandlerRequest(input.invocationId(), otelContext, input.body(), input.headers()); + this.objectKey = input.key(); + this.stateMachine = stateMachine; + this.fullyQualifiedHandlerName = fullyQualifiedHandlerName; + this.invocationIdsToCancel = new ArrayList<>(); + this.scheduledRuns = new HashMap<>(); + } + + private static void parseSuccessOrFailure(NotificationValue s, CompletableFuture cf) { + if (s instanceof NotificationValue.Success success) { + cf.complete(success.slice()); + } else if (s instanceof NotificationValue.Failure failure) { + cf.completeExceptionally(failure.exception()); + } else { + throw ProtocolException.unexpectedNotificationVariant(s.getClass()); + } + } + + private static void parseEmptyOrSuccessOrFailure( + NotificationValue s, CompletableFuture> cf) { + if (s instanceof NotificationValue.Empty) { + cf.complete(Output.notReady()); + } else if (s instanceof NotificationValue.Success success) { + cf.complete(Output.ready(success.slice())); + } else if (s instanceof NotificationValue.Failure failure) { + cf.completeExceptionally(failure.exception()); + } else { + throw ProtocolException.unexpectedNotificationVariant(s.getClass()); + } + } + + private static void parseEmptyOrFailure(NotificationValue s, CompletableFuture cf) { + if (s instanceof NotificationValue.Empty) { + cf.complete(null); + } else if (s instanceof NotificationValue.Failure failure) { + cf.completeExceptionally(failure.exception()); + } else { + throw ProtocolException.unexpectedNotificationVariant(s.getClass()); + } + } + + @Override + public String objectKey() { + return this.objectKey; + } + + @Override + public HandlerRequest request() { + return this.handlerRequest; + } + + @Override + public String getFullyQualifiedMethodName() { + return this.fullyQualifiedHandlerName; + } + + @Override + public InvocationState getInvocationState() { + return this.stateMachine.state(); + } + + @Override + public Executor stateMachineExecutor() { + return Runnable::run; + } + + @Override + public CompletableFuture>> get(String name) { + return catchExceptions( + () -> + AsyncResults.single( + this, + this.stateMachine.stateGet(name), + (s, cf) -> { + if (s instanceof NotificationValue.Empty) { + cf.complete(Optional.empty()); + } else if (s instanceof NotificationValue.Success success) { + cf.complete(Optional.of(success.slice())); + } else { + throw ProtocolException.unexpectedNotificationVariant(s.getClass()); + } + })); + } + + @Override + public CompletableFuture>> getKeys() { + return catchExceptions( + () -> + AsyncResults.single( + this, + this.stateMachine.stateGetKeys(), + (s, cf) -> { + if (s instanceof NotificationValue.StateKeys stateKeys) { + cf.complete(stateKeys.stateKeys()); + } else { + throw ProtocolException.unexpectedNotificationVariant(s.getClass()); + } + })); + } + + @Override + public CompletableFuture clear(String name) { + return this.catchExceptions(() -> this.stateMachine.stateClear(name)); + } + + @Override + public CompletableFuture clearAll() { + return this.catchExceptions(this.stateMachine::stateClearAll); + } + + @Override + public CompletableFuture set(String name, Slice value) { + return this.catchExceptions(() -> this.stateMachine.stateSet(name, value)); + } + + @Override + public CompletableFuture> timer(Duration duration, String name) { + return catchExceptions( + () -> + AsyncResults.single( + this, + this.stateMachine.sleep(duration, name), + (s, cf) -> { + if (s instanceof NotificationValue.Empty) { + cf.complete(null); + } else { + throw ProtocolException.unexpectedNotificationVariant(s.getClass()); + } + })); + } + + @Override + public CompletableFuture call( + Target target, + Slice parameter, + @Nullable String idempotencyKey, + @Nullable Collection> headers) { + return catchExceptions( + () -> { + StateMachine.CallHandle callHandle = + this.stateMachine.call(target, parameter, idempotencyKey, headers); + + AsyncResultInternal invocationIdAsyncResult = + AsyncResults.single(this, callHandle.invocationIdHandle(), invocationIdCompleter()); + this.invocationIdsToCancel.add(invocationIdAsyncResult); + + AsyncResult callAsyncResult = + AsyncResults.single( + this, callHandle.resultHandle(), HandlerContextImpl::parseSuccessOrFailure); + + return new CallResult(invocationIdAsyncResult, callAsyncResult); + }); + } + + @Override + public CompletableFuture> send( + Target target, + Slice parameter, + @Nullable String idempotencyKey, + @Nullable Collection> headers, + @Nullable Duration delay) { + return catchExceptions( + () -> { + int sendHandle = + this.stateMachine.send(target, parameter, idempotencyKey, headers, delay); + + AsyncResultInternal invocationIdAsyncResult = + AsyncResults.single(this, sendHandle, invocationIdCompleter()); + this.invocationIdsToCancel.add(invocationIdAsyncResult); + + return invocationIdAsyncResult; + }); + } + + private static AsyncResults.Completer invocationIdCompleter() { + return (s, cf) -> { + if (s instanceof NotificationValue.InvocationId invocationId) { + cf.complete(invocationId.invocationId()); + } else { + throw ProtocolException.unexpectedNotificationVariant(s.getClass()); + } + }; + } + + @Override + public CompletableFuture> submitRun( + @Nullable String name, Consumer closure) { + return catchExceptions( + () -> { + int runHandle = this.stateMachine.run(name); + this.scheduledRuns.put(runHandle, closure); + return AsyncResults.single(this, runHandle, HandlerContextImpl::parseSuccessOrFailure); + }); + } + + @Override + public CompletableFuture awakeable() { + return catchExceptions( + () -> { + StateMachine.Awakeable awakeable = this.stateMachine.awakeable(); + return new Awakeable( + awakeable.awakeableId(), + AsyncResults.single( + this, awakeable.handle(), HandlerContextImpl::parseSuccessOrFailure)); + }); + } + + @Override + public CompletableFuture resolveAwakeable(String id, Slice payload) { + return this.catchExceptions(() -> this.stateMachine.completeAwakeable(id, payload)); + } + + @Override + public CompletableFuture rejectAwakeable(String id, TerminalException reason) { + return this.catchExceptions(() -> this.stateMachine.completeAwakeable(id, reason)); + } + + @Override + public CompletableFuture> promise(String key) { + return catchExceptions( + () -> + AsyncResults.single( + this, + this.stateMachine.promiseGet(key), + HandlerContextImpl::parseSuccessOrFailure)); + } + + @Override + public CompletableFuture>> peekPromise(String key) { + return catchExceptions( + () -> + AsyncResults.single( + this, + this.stateMachine.promisePeek(key), + HandlerContextImpl::parseEmptyOrSuccessOrFailure)); + } + + @Override + public CompletableFuture> resolvePromise(String key, Slice payload) { + return catchExceptions( + () -> + AsyncResults.single( + this, + this.stateMachine.promiseComplete(key, payload), + HandlerContextImpl::parseEmptyOrFailure)); + } + + @Override + public CompletableFuture> rejectPromise(String key, TerminalException reason) { + return catchExceptions( + () -> + AsyncResults.single( + this, + this.stateMachine.promiseComplete(key, reason), + HandlerContextImpl::parseEmptyOrFailure)); + } + + @Override + public CompletableFuture cancelInvocation(String invocationId) { + return this.catchExceptions(() -> this.stateMachine.cancelInvocation(invocationId)); + } + + @Override + public CompletableFuture> attachInvocation(String invocationId) { + return catchExceptions( + () -> + AsyncResults.single( + this, + this.stateMachine.attachInvocation(invocationId), + HandlerContextImpl::parseSuccessOrFailure)); + } + + @Override + public CompletableFuture>> getInvocationOutput(String invocationId) { + return catchExceptions( + () -> + AsyncResults.single( + this, + this.stateMachine.getInvocationOutput(invocationId), + HandlerContextImpl::parseEmptyOrSuccessOrFailure)); + } + + @Override + public CompletableFuture writeOutput(Slice value) { + return this.catchExceptions(() -> this.stateMachine.writeOutput(value)); + } + + @Override + public CompletableFuture writeOutput(TerminalException throwable) { + return this.catchExceptions(() -> this.stateMachine.writeOutput(throwable)); + } + + @Override + public void pollAsyncResult(AsyncResultInternal asyncResult) { + // We use the separate function for the recursion, + // as there's no need to jump back and forth between threads again. + this.pollAsyncResultInner(asyncResult); + } + + private void pollAsyncResultInner(AsyncResultInternal asyncResult) { + while (true) { + if (asyncResult.isDone()) { + return; + } + + // Let's look for the cancellation notification + var cancellationNotification = this.stateMachine.takeNotification(CANCEL_HANDLE); + if (cancellationNotification.isPresent()) { + LOG.info("Detected cancellation signal! Will start cancelling child invocations"); + + // Let's wait to cancel all + @SuppressWarnings({"rawtypes", "unchecked"}) + AsyncResultInternal allInvocationIds = + AsyncResults.all(this, (List) this.invocationIdsToCancel); + allInvocationIds + .publicFuture() + .whenComplete( + (ignored, throwable) -> { + if (throwable != null) { + // Already handled + return; + } + LOG.info("All child invocation ids retrieved"); + try { + for (var invocationIdAr : this.invocationIdsToCancel) { + this.stateMachine.cancelInvocation( + Objects.requireNonNull(invocationIdAr.publicFuture().getNow(null))); + } + asyncResult.tryCancel(); + } catch (Throwable e) { + // Not good! + this.failWithoutContextSwitch(e); + } + }); + // Let's resolve all the invocation IDs + pollAsyncResultInner(allInvocationIds); + return; + } + + // Let's start by trying to complete it + asyncResult.tryComplete(this.stateMachine); + + // Now let's take the unprocessed leaves + List uncompletedLeaves = + Stream.concat(asyncResult.uncompletedLeaves(), Stream.of(CANCEL_HANDLE)).toList(); + if (uncompletedLeaves.size() == 1) { + // Nothing else to do! + return; + } + + // Not ready yet, let's try to do some progress + StateMachine.DoProgressResponse response = this.stateMachine.doProgress(uncompletedLeaves); + + if (response instanceof StateMachine.DoProgressResponse.AnyCompleted) { + // Let it loop now + } else if (response instanceof StateMachine.DoProgressResponse.ReadFromInput) { + this.stateMachine + .waitNextInputSignal() + .thenAccept(v -> this.pollAsyncResultInner(asyncResult)); + return; + } else if (response instanceof StateMachine.DoProgressResponse.ExecuteRun) { + triggerScheduledRun(((StateMachine.DoProgressResponse.ExecuteRun) response).handle()); + // Let it loop now + } else if (response instanceof StateMachine.DoProgressResponse.WaitingPendingRun) { + this.waitNextProcessedRun().thenAccept(v -> this.pollAsyncResultInner(asyncResult)); + return; + } + } + } + + @Override + public void proposeRunSuccess(int runHandle, Slice toWrite) { + try { + this.stateMachine.proposeRunCompletion(runHandle, toWrite); + } catch (Exception e) { + this.failWithoutContextSwitch(e); + } + triggerNextProcessedRun(); + } + + @Override + public void proposeRunFailure( + int runHandle, + Throwable toWrite, + Duration attemptDuration, + @Nullable RetryPolicy retryPolicy) { + try { + this.stateMachine.proposeRunCompletion(runHandle, toWrite, attemptDuration, retryPolicy); + } catch (Exception e) { + this.failWithoutContextSwitch(e); + } + triggerNextProcessedRun(); + } + + private void triggerNextProcessedRun() { + if (this.nextProcessedRun != null) { + var fut = this.nextProcessedRun; + this.nextProcessedRun = null; + fut.complete(null); + } + } + + private void triggerScheduledRun(int handle) { + var consumer = + Objects.requireNonNull( + this.scheduledRuns.get(handle), "The given handle doesn't exist, this is an SDK bug"); + var startTime = Instant.now(); + consumer.accept( + new RunCompleter() { + @Override + public void proposeSuccess(Slice toWrite) { + proposeRunSuccess(handle, toWrite); + } + + @Override + public void proposeFailure(Throwable toWrite, @Nullable RetryPolicy retryPolicy) { + proposeRunFailure( + handle, toWrite, Duration.between(startTime, Instant.now()), retryPolicy); + } + }); + } + + private CompletableFuture waitNextProcessedRun() { + if (this.nextProcessedRun == null) { + this.nextProcessedRun = new CompletableFuture<>(); + } + return this.nextProcessedRun; + } + + @Override + public void close() { + this.stateMachine.end(); + } + + @Override + public void fail(Throwable cause) { + this.failWithoutContextSwitch(cause); + } + + @Override + public void failWithoutContextSwitch(Throwable cause) { + this.stateMachine.onError(cause); + } + + // -- Wrapper for failure propagation + + private CompletableFuture catchExceptions(ThrowingRunnable r) { + try { + r.run(); + return CompletableFuture.completedFuture(null); + } catch (Throwable e) { + this.failWithoutContextSwitch(e); + return CompletableFuture.failedFuture(AbortedExecutionException.INSTANCE); + } + } + + private CompletableFuture catchExceptions(ThrowingSupplier r) { + try { + return CompletableFuture.completedFuture(r.get()); + } catch (Throwable e) { + this.failWithoutContextSwitch(e); + return CompletableFuture.failedFuture(AbortedExecutionException.INSTANCE); + } + } +} diff --git a/sdk-core/src/main/java/dev/restate/sdk/core/HandlerContextInternal.java b/sdk-core/src/main/java/dev/restate/sdk/core/HandlerContextInternal.java new file mode 100644 index 000000000..266daf17a --- /dev/null +++ b/sdk-core/src/main/java/dev/restate/sdk/core/HandlerContextInternal.java @@ -0,0 +1,65 @@ +// Copyright (c) 2023 - Restate Software, Inc., Restate GmbH +// +// This file is part of the Restate Java SDK, +// which is released under the MIT license. +// +// You can find a copy of the license in file LICENSE in the root +// directory of this repository or package, or at +// https://github.com/restatedev/sdk-java/blob/main/LICENSE +package dev.restate.sdk.core; + +import dev.restate.common.Slice; +import dev.restate.sdk.core.AsyncResults.AsyncResultInternal; +import dev.restate.sdk.core.statemachine.InvocationState; +import dev.restate.sdk.endpoint.definition.AsyncResult; +import dev.restate.sdk.endpoint.definition.HandlerContext; +import dev.restate.sdk.types.RetryPolicy; +import java.time.Duration; +import java.util.List; +import java.util.concurrent.Executor; +import java.util.stream.Collectors; +import org.jspecify.annotations.Nullable; + +interface HandlerContextInternal extends HandlerContext { + + @Override + default AsyncResult createAnyAsyncResult(List> children) { + return AsyncResults.any( + this, + children.stream().map(dr -> (AsyncResultInternal) dr).collect(Collectors.toList())); + } + + @Override + default AsyncResult createAllAsyncResult(List> children) { + return AsyncResults.all( + this, + children.stream().map(dr -> (AsyncResultInternal) dr).collect(Collectors.toList())); + } + + void proposeRunSuccess(int runHandle, Slice toWrite); + + void proposeRunFailure( + int runHandle, + Throwable toWrite, + Duration attemptDuration, + @Nullable RetryPolicy retryPolicy); + + void pollAsyncResult(AsyncResultInternal asyncResult); + + // -- Lifecycle methods + + void close(); + + // -- State machine introspection (used by logging propagator) + + /** + * @return fully qualified method name in the form {fullyQualifiedServiceName}/{methodName} + */ + String getFullyQualifiedMethodName(); + + InvocationState getInvocationState(); + + Executor stateMachineExecutor(); + + void failWithoutContextSwitch(Throwable throwable); +} diff --git a/sdk-core/src/main/java/dev/restate/sdk/core/IncomingEntriesStateMachine.java b/sdk-core/src/main/java/dev/restate/sdk/core/IncomingEntriesStateMachine.java deleted file mode 100644 index 29ace7ad4..000000000 --- a/sdk-core/src/main/java/dev/restate/sdk/core/IncomingEntriesStateMachine.java +++ /dev/null @@ -1,54 +0,0 @@ -// Copyright (c) 2023 - Restate Software, Inc., Restate GmbH -// -// This file is part of the Restate Java SDK, -// which is released under the MIT license. -// -// You can find a copy of the license in file LICENSE in the root -// directory of this repository or package, or at -// https://github.com/restatedev/sdk-java/blob/main/LICENSE -package dev.restate.sdk.core; - -import com.google.protobuf.MessageLite; -import java.util.ArrayDeque; -import java.util.Queue; - -class IncomingEntriesStateMachine - extends BaseSuspendableCallbackStateMachine { - - interface OnEntryCallback extends SuspendableCallback { - void onEntry(MessageLite msg); - } - - private final Queue unprocessedMessages; - - IncomingEntriesStateMachine() { - this.unprocessedMessages = new ArrayDeque<>(); - } - - void offer(MessageLite msg) { - Util.assertIsEntry(msg); - this.consumeCallbackOrElse(cb -> cb.onEntry(msg), () -> this.unprocessedMessages.offer(msg)); - } - - void read(OnEntryCallback msgCallback) { - this.assertCallbackNotSet("Two concurrent reads were requested."); - - MessageLite popped = this.unprocessedMessages.poll(); - if (popped != null) { - msgCallback.onEntry(popped); - } else { - this.setCallback(msgCallback); - } - } - - boolean isEmpty() { - return this.unprocessedMessages.isEmpty(); - } - - @Override - void abort(Throwable cause) { - super.abort(cause); - // We can't do anything else if the input stream is closed, so we just fail the callback, if any - this.tryFailCallback(); - } -} diff --git a/sdk-core/src/main/java/dev/restate/sdk/core/InputPublisherState.java b/sdk-core/src/main/java/dev/restate/sdk/core/InputPublisherState.java deleted file mode 100644 index e71f66500..000000000 --- a/sdk-core/src/main/java/dev/restate/sdk/core/InputPublisherState.java +++ /dev/null @@ -1,33 +0,0 @@ -// Copyright (c) 2023 - Restate Software, Inc., Restate GmbH -// -// This file is part of the Restate Java SDK, -// which is released under the MIT license. -// -// You can find a copy of the license in file LICENSE in the root -// directory of this repository or package, or at -// https://github.com/restatedev/sdk-java/blob/main/LICENSE -package dev.restate.sdk.core; - -import dev.restate.sdk.common.AbortedExecutionException; -import org.jspecify.annotations.Nullable; - -class InputPublisherState { - - private @Nullable Throwable closeCause = null; - - void notifyClosed(Throwable cause) { - closeCause = cause; - } - - boolean isSuspended() { - return this.closeCause == AbortedExecutionException.INSTANCE; - } - - boolean isClosed() { - return this.closeCause != null; - } - - public Throwable getCloseCause() { - return closeCause; - } -} diff --git a/sdk-core/src/main/java/dev/restate/sdk/core/InvocationFlow.java b/sdk-core/src/main/java/dev/restate/sdk/core/InvocationFlow.java deleted file mode 100644 index 7315753f4..000000000 --- a/sdk-core/src/main/java/dev/restate/sdk/core/InvocationFlow.java +++ /dev/null @@ -1,23 +0,0 @@ -// Copyright (c) 2023 - Restate Software, Inc., Restate GmbH -// -// This file is part of the Restate Java SDK, -// which is released under the MIT license. -// -// You can find a copy of the license in file LICENSE in the root -// directory of this repository or package, or at -// https://github.com/restatedev/sdk-java/blob/main/LICENSE -package dev.restate.sdk.core; - -import java.nio.ByteBuffer; -import java.util.concurrent.Flow; - -public interface InvocationFlow { - - interface InvocationInputPublisher extends Flow.Publisher {} - - interface InvocationOutputPublisher extends Flow.Publisher {} - - interface InvocationInputSubscriber extends Flow.Subscriber {} - - interface InvocationOutputSubscriber extends Flow.Subscriber {} -} diff --git a/sdk-core/src/main/java/dev/restate/sdk/core/InvocationStateMachine.java b/sdk-core/src/main/java/dev/restate/sdk/core/InvocationStateMachine.java deleted file mode 100644 index faf770067..000000000 --- a/sdk-core/src/main/java/dev/restate/sdk/core/InvocationStateMachine.java +++ /dev/null @@ -1,904 +0,0 @@ -// Copyright (c) 2023 - Restate Software, Inc., Restate GmbH -// -// This file is part of the Restate Java SDK, -// which is released under the MIT license. -// -// You can find a copy of the license in file LICENSE in the root -// directory of this repository or package, or at -// https://github.com/restatedev/sdk-java/blob/main/LICENSE -package dev.restate.sdk.core; - -import static dev.restate.sdk.core.Util.durationMin; - -import com.google.protobuf.ByteString; -import com.google.protobuf.MessageLite; -import dev.restate.generated.sdk.java.Java; -import dev.restate.generated.service.protocol.Protocol; -import dev.restate.sdk.common.*; -import dev.restate.sdk.common.syscalls.*; -import io.opentelemetry.api.common.Attributes; -import io.opentelemetry.api.trace.Span; -import io.opentelemetry.context.Context; -import java.time.Duration; -import java.time.Instant; -import java.util.*; -import java.util.concurrent.Flow; -import java.util.function.Consumer; -import java.util.stream.Collectors; -import org.apache.logging.log4j.LogManager; -import org.apache.logging.log4j.Logger; -import org.jspecify.annotations.Nullable; - -class InvocationStateMachine implements Flow.Processor { - - private static final Logger LOG = LogManager.getLogger(InvocationStateMachine.class); - - private final String serviceName; - private final String fullyQualifiedHandlerName; - private final Span span; - private final RestateEndpoint.LoggingContextSetter loggingContextSetter; - private final Protocol.ServiceProtocolVersion negotiatedProtocolVersion; - - private volatile InvocationState invocationState = InvocationState.WAITING_START; - - // Used for the side effect guard - private Long sideEffectStart; - private volatile boolean insideSideEffect = false; - - // Obtained after WAITING_START - private ByteString id; - private String debugId; - private String key; - private int entriesToReplay; - private UserStateStore userStateStore; - - // Used inside syscalls.shouldRetry, which doesn't run on the syscalls executor sometimes - private Duration startMessageDurationSinceLastStoredEntry; - private int startMessageRetryCountSinceLastStoredEntry; - - // Those values track the progress in the journal - private int currentJournalEntryIndex = -1; - private String currentJournalEntryName = null; - private MessageType currentJournalEntryType = null; - - // Buffering of messages and completions - private final IncomingEntriesStateMachine incomingEntriesStateMachine; - private final AckStateMachine ackStateMachine; - private final ReadyResultStateMachine readyResultStateMachine; - - // Flow sub/pub - private Flow.Subscriber outputSubscriber; - private Flow.Subscription inputSubscription; - private final CallbackHandle> afterStartCallback; - - InvocationStateMachine( - String serviceName, - String fullyQualifiedHandlerName, - Span span, - RestateEndpoint.LoggingContextSetter loggingContextSetter, - Protocol.ServiceProtocolVersion negotiatedProtocolVersion) { - this.serviceName = serviceName; - this.fullyQualifiedHandlerName = fullyQualifiedHandlerName; - this.span = span; - this.loggingContextSetter = loggingContextSetter; - this.negotiatedProtocolVersion = negotiatedProtocolVersion; - - this.incomingEntriesStateMachine = new IncomingEntriesStateMachine(); - this.readyResultStateMachine = new ReadyResultStateMachine(); - this.ackStateMachine = new AckStateMachine(); - - this.afterStartCallback = new CallbackHandle<>(); - } - - // --- Getters - - public String getServiceName() { - return serviceName; - } - - public ByteString id() { - return id; - } - - public String objectKey() { - return key; - } - - public InvocationState getInvocationState() { - return this.invocationState; - } - - public boolean isInsideSideEffect() { - return this.insideSideEffect; - } - - public String getFullyQualifiedHandlerName() { - return this.fullyQualifiedHandlerName; - } - - // --- Output Publisher impl - - @Override - public void subscribe(Flow.Subscriber subscriber) { - this.outputSubscriber = subscriber; - this.outputSubscriber.onSubscribe( - new Flow.Subscription() { - @Override - public void request(long l) {} - - @Override - public void cancel() { - end(); - } - }); - } - - // --- Input Subscriber impl - - @Override - public void onSubscribe(Flow.Subscription subscription) { - this.inputSubscription = subscription; - } - - @Override - public void onNext(InvocationInput invocationInput) { - MessageLite msg = invocationInput.message(); - LOG.trace("Received input message {} {}", msg.getClass(), msg); - if (this.invocationState == InvocationState.WAITING_START) { - this.onStartMessage(msg); - } else if (msg instanceof Protocol.CompletionMessage) { - // We check the instance rather than the state, because the user code might still be - // replaying, but the network layer is already past it and is receiving completions from the - // runtime. - this.readyResultStateMachine.offerCompletion((Protocol.CompletionMessage) msg); - } else if (msg instanceof Protocol.EntryAckMessage) { - this.ackStateMachine.tryHandleAck(((Protocol.EntryAckMessage) msg).getEntryIndex()); - } else { - this.incomingEntriesStateMachine.offer(msg); - } - } - - @Override - public void onError(Throwable throwable) { - LOG.trace("Got failure from input publisher", throwable); - this.fail(throwable); - } - - @Override - public void onComplete() { - LOG.trace("Input publisher closed"); - this.readyResultStateMachine.abort(AbortedExecutionException.INSTANCE); - this.ackStateMachine.abort(AbortedExecutionException.INSTANCE); - } - - // --- Init routine to wait for the start message - - void startAndConsumeInput(SyscallCallback afterStartCallback) { - this.afterStartCallback.set(afterStartCallback); - this.inputSubscription.request(1); - } - - void onStartMessage(MessageLite msg) { - if (!(msg instanceof Protocol.StartMessage startMessage)) { - this.fail(ProtocolException.unexpectedMessage(Protocol.StartMessage.class, msg)); - return; - } - - // Unpack the StartMessage - this.id = startMessage.getId(); - this.debugId = startMessage.getDebugId(); - InvocationId invocationId = new InvocationIdImpl(startMessage.getDebugId()); - this.key = startMessage.getKey(); - this.entriesToReplay = startMessage.getKnownEntries(); - this.startMessageDurationSinceLastStoredEntry = - Duration.ofMillis(startMessage.getDurationSinceLastStoredEntry()); - this.startMessageRetryCountSinceLastStoredEntry = - startMessage.getRetryCountSinceLastStoredEntry(); - - // Set up the state cache - this.userStateStore = new UserStateStore(startMessage); - - // Tracing and logging setup - this.loggingContextSetter.set( - RestateEndpoint.LoggingContextSetter.INVOCATION_ID_KEY, startMessage.getDebugId()); - if (this.span.isRecording()) { - span.addEvent( - "Start", Attributes.of(Tracing.RESTATE_INVOCATION_ID, startMessage.getDebugId())); - } - LOG.info("Start invocation"); - - // Execute state transition - this.transitionState(InvocationState.REPLAYING); - if (this.entriesToReplay == 0) { - this.fail( - new ProtocolException( - "Expected at least one entry with Input, got " + this.entriesToReplay + " entries", - TerminalException.INTERNAL_SERVER_ERROR_CODE, - null)); - return; - } - - this.inputSubscription.request(Long.MAX_VALUE); - - // Now wait input entry - this.nextJournalEntry(null, MessageType.InputEntryMessage); - this.readEntry( - inputMsg -> { - if (!(inputMsg instanceof Protocol.InputEntryMessage inputEntry)) { - throw ProtocolException.unexpectedMessage(Protocol.InputEntryMessage.class, inputMsg); - } - - Request request = - new Request( - invocationId, - Context.root().with(span), - inputEntry.getValue().asReadOnlyByteBuffer(), - inputEntry.getHeadersList().stream() - .collect( - Collectors.toUnmodifiableMap( - Protocol.Header::getKey, Protocol.Header::getValue))); - - this.afterStartCallback.consume(cb -> cb.onSuccess(request)); - }, - this::fail); - } - - // --- Close state machine - - void end() { - if (this.invocationState != InvocationState.CLOSED) { - LOG.info("End invocation"); - this.closeWithMessage(Protocol.EndMessage.getDefaultInstance(), ProtocolException.CLOSED); - } - } - - void suspend(Collection suspensionIndexes) { - assert !suspensionIndexes.isEmpty() - : "Suspension indexes MUST be a non-empty collection, per protocol specification"; - LOG.info("Suspend invocation"); - this.closeWithMessage( - Protocol.SuspensionMessage.newBuilder().addAllEntryIndexes(suspensionIndexes).build(), - ProtocolException.CLOSED); - } - - void fail(Throwable cause) { - if (this.invocationState != InvocationState.CLOSED) { - LOG.warn("Invocation failed", cause); - this.closeWithMessage( - Util.toErrorMessage( - cause, - this.currentJournalEntryIndex, - this.currentJournalEntryName, - this.currentJournalEntryType), - cause); - } - } - - void failWithNextRetryDelay(Throwable cause, Duration nextRetryDelay) { - if (this.invocationState != InvocationState.CLOSED) { - LOG.warn("Invocation failed, will retry in {}", nextRetryDelay, cause); - this.closeWithMessage( - Util.toErrorMessage( - cause, - this.currentJournalEntryIndex, - this.currentJournalEntryName, - this.currentJournalEntryType) - .toBuilder() - .setNextRetryDelay(nextRetryDelay.toMillis()) - .build(), - cause); - } - } - - private void closeWithMessage(MessageLite closeMessage, Throwable cause) { - if (this.invocationState != InvocationState.CLOSED) { - this.transitionState(InvocationState.CLOSED); - - // Cancel inputSubscription and complete outputSubscriber - if (inputSubscription != null) { - this.inputSubscription.cancel(); - } - if (this.outputSubscriber != null) { - this.outputSubscriber.onNext(closeMessage); - this.outputSubscriber.onComplete(); - this.outputSubscriber = null; - } - - // Unblock any eventual waiting callbacks - this.afterStartCallback.consume(cb -> cb.onCancel(cause)); - this.readyResultStateMachine.abort(cause); - this.ackStateMachine.abort(cause); - this.incomingEntriesStateMachine.abort(cause); - this.span.end(); - } - } - - // --- Methods to implement Syscalls - - @SuppressWarnings("unchecked") - void processCompletableJournalEntry( - E expectedEntryMessage, - Entries.CompletableJournalEntry journalEntry, - SyscallCallback> callback) { - checkInsideSideEffectGuard(); - this.nextJournalEntry( - journalEntry.getName(expectedEntryMessage), MessageType.fromMessage(expectedEntryMessage)); - - if (this.invocationState == InvocationState.CLOSED) { - callback.onCancel(AbortedExecutionException.INSTANCE); - } else if (this.invocationState == InvocationState.REPLAYING) { - // Retrieve the entry - this.readEntry( - actualEntryMessage -> { - journalEntry.checkEntryHeader(expectedEntryMessage, actualEntryMessage); - - if (journalEntry.hasResult((E) actualEntryMessage)) { - // Entry is already completed - journalEntry.updateUserStateStoreWithEntry( - (E) actualEntryMessage, this.userStateStore); - Result readyResultInternal = journalEntry.parseEntryResult((E) actualEntryMessage); - callback.onSuccess( - DeferredResults.completedSingle( - this.currentJournalEntryIndex, readyResultInternal)); - } else { - // Entry is not completed yet - this.readyResultStateMachine.offerCompletionParser( - this.currentJournalEntryIndex, - completionMessage -> { - journalEntry.updateUserStateStorageWithCompletion( - (E) actualEntryMessage, completionMessage, this.userStateStore); - return journalEntry.parseCompletionResult(completionMessage); - }); - callback.onSuccess(DeferredResults.single(this.currentJournalEntryIndex)); - } - }, - callback::onCancel); - } else if (this.invocationState == InvocationState.PROCESSING) { - // Try complete with local storage - E entryToWrite = - journalEntry.tryCompleteWithUserStateStorage(expectedEntryMessage, userStateStore); - - if (span.isRecording()) { - journalEntry.trace(entryToWrite, span); - } - - // Write out the input entry - this.writeEntry(entryToWrite); - - if (journalEntry.hasResult(entryToWrite)) { - // Complete it with the result, as we already have it - callback.onSuccess( - DeferredResults.completedSingle( - this.currentJournalEntryIndex, journalEntry.parseEntryResult(entryToWrite))); - } else { - // Register the completion parser - this.readyResultStateMachine.offerCompletionParser( - this.currentJournalEntryIndex, - completionMessage -> { - journalEntry.updateUserStateStorageWithCompletion( - entryToWrite, completionMessage, this.userStateStore); - return journalEntry.parseCompletionResult(completionMessage); - }); - - // Call the onSuccess - callback.onSuccess(DeferredResults.single(this.currentJournalEntryIndex)); - } - } else { - throw new IllegalStateException( - "This method was invoked when the state machine is not ready to process user code. This is probably an SDK bug"); - } - } - - @SuppressWarnings("unchecked") - void processJournalEntry( - E expectedEntryMessage, - Entries.JournalEntry journalEntry, - SyscallCallback callback) { - checkInsideSideEffectGuard(); - this.nextJournalEntry( - journalEntry.getName(expectedEntryMessage), MessageType.fromMessage(expectedEntryMessage)); - - if (this.invocationState == InvocationState.CLOSED) { - callback.onCancel(AbortedExecutionException.INSTANCE); - } else if (this.invocationState == InvocationState.REPLAYING) { - // Retrieve the entry - this.readEntry( - actualEntryMessage -> { - journalEntry.checkEntryHeader(expectedEntryMessage, actualEntryMessage); - journalEntry.updateUserStateStoreWithEntry((E) actualEntryMessage, this.userStateStore); - callback.onSuccess(null); - }, - callback::onCancel); - } else if (this.invocationState == InvocationState.PROCESSING) { - if (span.isRecording()) { - journalEntry.trace(expectedEntryMessage, span); - } - - // Write new entry - this.writeEntry(expectedEntryMessage); - - // Update local storage - journalEntry.updateUserStateStoreWithEntry(expectedEntryMessage, this.userStateStore); - - // Invoke the ok callback - callback.onSuccess(null); - } else { - throw new IllegalStateException( - "This method was invoked when the state machine is not ready to process user code. This is probably an SDK bug"); - } - } - - void enterSideEffectBlock(String name, EnterSideEffectSyscallCallback callback) { - checkInsideSideEffectGuard(); - this.nextJournalEntry(name, MessageType.RunEntryMessage); - - if (this.invocationState == InvocationState.CLOSED) { - callback.onCancel(AbortedExecutionException.INSTANCE); - } else if (this.invocationState == InvocationState.REPLAYING) { - // Retrieve the entry - this.readEntry( - msg -> { - Util.assertEntryClass(Protocol.RunEntryMessage.class, msg); - - // We have a result already, complete the callback - completeSideEffectCallbackWithEntry((Protocol.RunEntryMessage) msg, callback); - }, - callback::onCancel); - } else if (this.invocationState == InvocationState.PROCESSING) { - insideSideEffect = true; - sideEffectStart = System.currentTimeMillis(); - if (span.isRecording()) { - span.addEvent("Enter SideEffect"); - } - callback.onNotExecuted(); - } else { - throw new IllegalStateException( - "This method was invoked when the state machine is not ready to process user code. This is probably an SDK bug"); - } - } - - void exitSideEffectBlock( - Protocol.RunEntryMessage sideEffectEntry, ExitSideEffectSyscallCallback callback) { - this.insideSideEffect = false; - this.sideEffectStart = null; - if (this.invocationState == InvocationState.CLOSED) { - callback.onCancel(AbortedExecutionException.INSTANCE); - } else if (this.invocationState == InvocationState.REPLAYING) { - throw new IllegalStateException( - "exitSideEffect has been invoked when the state machine is in replaying mode. " - + "This is probably an SDK bug and might be caused by a missing enterSideEffectBlock invocation before exitSideEffectBlock."); - } else if (this.invocationState == InvocationState.PROCESSING) { - if (span.isRecording()) { - span.addEvent("Exit SideEffect"); - } - - // For side effects, let's write out the name too, if available - if (this.currentJournalEntryName != null) { - sideEffectEntry = sideEffectEntry.toBuilder().setName(this.currentJournalEntryName).build(); - } - - // Write new entry - this.ackStateMachine.registerEntryToAck(this.currentJournalEntryIndex); - this.writeEntry(sideEffectEntry); - - // Wait for entry to be acked - Protocol.RunEntryMessage finalSideEffectEntry = sideEffectEntry; - this.ackStateMachine.waitLastAck( - new AckStateMachine.AckCallback() { - @Override - public void onAck() { - completeSideEffectCallbackWithEntry(finalSideEffectEntry, callback); - } - - @Override - public void onSuspend() { - suspend(List.of(ackStateMachine.getLastEntryToAck())); - callback.onCancel(AbortedExecutionException.INSTANCE); - } - - @Override - public void onError(Throwable e) { - callback.onCancel(e); - } - }); - } else { - throw new IllegalStateException( - "This method was invoked when the state machine is not ready to process user code. This is probably an SDK bug"); - } - } - - void exitSideEffectBlockWithThrowable( - Throwable runException, - @Nullable RetryPolicy retryPolicy, - ExitSideEffectSyscallCallback callback) - throws Throwable { - TerminalException toWrite; - if (runException instanceof TerminalException) { - LOG.trace("The run completed with a terminal exception"); - toWrite = (TerminalException) runException; - } else { - toWrite = this.rethrowOrConvertToTerminal(retryPolicy, runException); - } - - LOG.trace("exitSideEffectBlock with exception"); - this.exitSideEffectBlock( - Protocol.RunEntryMessage.newBuilder().setFailure(Util.toProtocolFailure(toWrite)).build(), - callback); - } - - private Duration getDurationSinceLastStoredEntry() { - // We need to check if this is the first entry we try to commit after replay, and only in this - // case we need to return the info we got from the start message - // - // Moreover, when the retry count is == 0, the durationSinceLastStoredEntry might not be zero. - // In fact, in that case the duration is the interval between the previously stored entry and - // the time to start/resume the invocation. - // For the sake of entry retries though, we're not interested in that time elapsed, so we 0 it - // here for simplicity of the downstream consumer (the retry policy). - return this.currentJournalEntryIndex == this.entriesToReplay - && startMessageRetryCountSinceLastStoredEntry > 0 - ? this.startMessageDurationSinceLastStoredEntry - : Duration.ZERO; - } - - private int getRetryCountSinceLastStoredEntry() { - // We need to check if this is the first entry we try to commit after replay, and only in this - // case we need to return the info we got from the start message - return this.currentJournalEntryIndex == this.entriesToReplay - ? this.startMessageRetryCountSinceLastStoredEntry - : 0; - } - - // This function rethrows the exception if a retry needs to happen. - private TerminalException rethrowOrConvertToTerminal( - @Nullable RetryPolicy retryPolicy, Throwable t) throws Throwable { - if (retryPolicy != null - && this.negotiatedProtocolVersion.getNumber() - < Protocol.ServiceProtocolVersion.V2.getNumber()) { - throw ProtocolException.unsupportedFeature( - this.negotiatedProtocolVersion, "run retry policy"); - } - - if (retryPolicy == null) { - LOG.trace("The run completed with an exception and no retry policy was provided"); - // Default behavior is always retry - throw t; - } - - Duration retryLoopDuration = - this.getDurationSinceLastStoredEntry() - .plus(Duration.between(Instant.ofEpochMilli(this.sideEffectStart), Instant.now())); - int retryCount = this.getRetryCountSinceLastStoredEntry() + 1; - - if ((retryPolicy.getMaxAttempts() != null && retryPolicy.getMaxAttempts() <= retryCount) - || (retryPolicy.getMaxDuration() != null - && retryPolicy.getMaxDuration().compareTo(retryLoopDuration) <= 0)) { - LOG.trace("The run completed with a retryable exception, but all attempts were exhausted"); - // We need to convert it to TerminalException - return new TerminalException(t.toString()); - } - - // Compute next retry delay and throw it! - Duration nextComputedDelay = - retryPolicy - .getInitialDelay() - .multipliedBy((long) Math.pow(retryPolicy.getExponentiationFactor(), retryCount)); - Duration nextRetryDelay = - retryPolicy.getMaxDelay() != null - ? durationMin(retryPolicy.getMaxDelay(), nextComputedDelay) - : nextComputedDelay; - - this.failWithNextRetryDelay(t, nextRetryDelay); - throw t; - } - - void completeSideEffectCallbackWithEntry( - Protocol.RunEntryMessage sideEffectEntry, ExitSideEffectSyscallCallback callback) { - if (sideEffectEntry.hasFailure()) { - callback.onFailure(Util.toRestateException(sideEffectEntry.getFailure())); - } else { - callback.onSuccess(sideEffectEntry.getValue().asReadOnlyByteBuffer()); - } - } - - // --- Deferred - - void resolveDeferred(Deferred deferredToResolve, SyscallCallback callback) { - if (deferredToResolve.isCompleted()) { - callback.onSuccess(null); - return; - } - - if (deferredToResolve instanceof DeferredResults.ResolvableSingleDeferred) { - this.resolveSingleDeferred( - (DeferredResults.ResolvableSingleDeferred) deferredToResolve, callback); - return; - } - - if (deferredToResolve instanceof DeferredResults.CombinatorDeferred) { - this.resolveCombinatorDeferred( - (DeferredResults.CombinatorDeferred) deferredToResolve, callback); - return; - } - - throw new IllegalArgumentException("Unexpected deferred class " + deferredToResolve.getClass()); - } - - void resolveSingleDeferred( - DeferredResults.ResolvableSingleDeferred deferred, SyscallCallback callback) { - this.readyResultStateMachine.onNewReadyResult( - new ReadyResultStateMachine.OnNewReadyResultCallback() { - @SuppressWarnings("unchecked") - @Override - public boolean onNewResult(Map> resultMap) { - Result resolved = (Result) resultMap.remove(deferred.entryIndex()); - if (resolved != null) { - deferred.resolve(resolved); - callback.onSuccess(null); - return true; - } - return false; - } - - @Override - public void onSuspend() { - suspend(List.of(deferred.entryIndex())); - callback.onCancel(AbortedExecutionException.INSTANCE); - } - - @Override - public void onError(Throwable e) { - callback.onCancel(e); - } - }); - } - - /** - * This method implements the algorithm to resolve deferred combinator trees, where inner nodes of - * the tree are ANY or ALL combinators, and leafs are {@link - * DeferredResults.ResolvableSingleDeferred}, created as result of completable syscalls. - * - *

The idea of the algorithm is the following: {@code rootDeferred} is the root of this tree, - * and has internal state that can be mutated through {@link - * DeferredResults.CombinatorDeferred#tryResolve(int)} to flag the tree as resolved. Every time a - * new leaf is resolved through {@link DeferredResults.ResolvableSingleDeferred#resolve(Result)}, - * we try to resolve the tree again. We start by checking if we have enough resolved leafs in the - * combinator tree to resolve it. If not, we register a callback to the {@link - * ReadyResultStateMachine} to wait on future completions. As soon as the tree is resolved, we - * record in the journal the order of the leafs we've seen so far, and we finish by calling the - * {@code callback}, giving back control to user code. - * - *

An important property of this algorithm is that we don't write multiple {@link - * Java.CombinatorAwaitableEntryMessage} per combinator nodes composing the tree, but we write one - * of them for the whole tree. Moreover, we write only when we resolve the combinator tree, and - * not beforehand when the user creates the combinator tree. The main reason for this property is - * that the Restate protocol doesn't allow the SDK to mutate Journal Entries after they're sent to - * the runtime, and the index of entries is enforced by their send order, meaning you cannot send - * entry 2 and then entry 1. The consequence of this property is that any algorithm recording - * combinator nodes one-by-one would require non-trivial replay logic, in order to handle the - * resolution order of the combinator nodes, and partially resolved trees (e.g. in case a - * suspension happens while we have recorded only a part of the combinator nodes). - * - *

There are some special cases: - * - *

    - *
  • In case of replay, we don't need to wait for any leaf to be resolved, because we write - * the combinator journal entry only when there is a subset of resolved leafs which - * completes the combinator tree. Moreover, the leaf journal entries precede the combinator - * entry because they are created first. - *
  • In case there are no {@link DeferredResults.SingleDeferredInternal - * SingleDeferredResultInternals}, it means every leaf has been resolved beforehand. In this - * case, we must be able to flag this combinator tree as resolved as well. - *
- */ - private void resolveCombinatorDeferred( - DeferredResults.CombinatorDeferred rootDeferred, SyscallCallback callback) { - // Calling .await() on a combinator deferred within a side effect is not allowed - // as resolving it creates or read a journal entry. - checkInsideSideEffectGuard(); - this.nextJournalEntry(null, MessageType.CombinatorAwaitableEntryMessage); - - if (Objects.equals(this.invocationState, InvocationState.REPLAYING)) { - // Retrieve the CombinatorAwaitableEntryMessage - this.readEntry( - actualMsg -> { - Util.assertEntryClass(Java.CombinatorAwaitableEntryMessage.class, actualMsg); - - if (!rootDeferred.tryResolve( - ((Java.CombinatorAwaitableEntryMessage) actualMsg).getEntryIndexList())) { - throw new IllegalStateException("Combinator message cannot be resolved."); - } - callback.onSuccess(null); - }, - callback::onCancel); - } else if (this.invocationState == InvocationState.PROCESSING) { - // Create map of singles to resolve - Map> resolvableSingles = new HashMap<>(); - - Set> unprocessedLeafs = - rootDeferred.unprocessedLeafs().collect(Collectors.toSet()); - - // If there are no leafs, it means the combinator must be resolvable - if (unprocessedLeafs.isEmpty()) { - // We don't need to provide a valid entry index, - // we just need to walk through the tree and mark all the combinators as completed. - if (!rootDeferred.tryResolve(-1)) { - throw new IllegalStateException( - "Combinator cannot be resolved, but every children have been resolved already. " - + "This is a symptom of an SDK bug, please contact the developers."); - } - - writeCombinatorEntry(Collections.emptyList(), callback); - return; - } - - List resolvedOrder = new ArrayList<>(); - - // Walk the tree and populate the resolvable singles, and keep the already known ready results - for (DeferredResults.SingleDeferredInternal singleDeferred : unprocessedLeafs) { - int entryIndex = singleDeferred.entryIndex(); - if (singleDeferred.isCompleted()) { - resolvedOrder.add(entryIndex); - - // Try to resolve the combinator now - if (rootDeferred.tryResolve(entryIndex)) { - writeCombinatorEntry(resolvedOrder, callback); - return; - } - } else { - // If not completed, then it's a ResolvableSingleDeferredResult - resolvableSingles.put( - entryIndex, (DeferredResults.ResolvableSingleDeferred) singleDeferred); - } - } - - // Not completed yet, we need to wait on the ReadyResultPublisher - this.readyResultStateMachine.onNewReadyResult( - new ReadyResultStateMachine.OnNewReadyResultCallback() { - @SuppressWarnings({"unchecked", "rawtypes"}) - @Override - public boolean onNewResult(Map> resultMap) { - Iterator>> it = - resolvableSingles.entrySet().iterator(); - while (it.hasNext()) { - Map.Entry> entry = it.next(); - int entryIndex = entry.getKey(); - - Result result = resultMap.remove(entryIndex); - if (result != null) { - resolvedOrder.add(entryIndex); - entry.getValue().resolve((Result) result); - it.remove(); - - // Try to resolve the combinator now - if (rootDeferred.tryResolve(entryIndex)) { - writeCombinatorEntry(resolvedOrder, callback); - return true; - } - } - } - - return false; - } - - @Override - public void onSuspend() { - suspend(resolvableSingles.keySet()); - callback.onCancel(AbortedExecutionException.INSTANCE); - } - - @Override - public void onError(Throwable e) { - callback.onCancel(e); - } - }); - } else { - throw new IllegalStateException( - "This method was invoked when the state machine is not ready to process user code. This is probably an SDK bug"); - } - } - - private void writeCombinatorEntry(List resolvedList, SyscallCallback callback) { - // Create and write the entry - Java.CombinatorAwaitableEntryMessage entry = - Java.CombinatorAwaitableEntryMessage.newBuilder().addAllEntryIndex(resolvedList).build(); - span.addEvent("Combinator"); - - // We register the combinator entry to wait for an ack - this.ackStateMachine.registerEntryToAck(this.currentJournalEntryIndex); - writeEntry(entry); - - // Let's wait for the ack - this.ackStateMachine.waitLastAck( - new AckStateMachine.AckCallback() { - @Override - public void onAck() { - callback.onSuccess(null); - } - - @Override - public void onSuspend() { - suspend(List.of(ackStateMachine.getLastEntryToAck())); - callback.onCancel(AbortedExecutionException.INSTANCE); - } - - @Override - public void onError(Throwable e) { - callback.onCancel(e); - } - }); - } - - // --- Internal callback - - private void transitionState(InvocationState newInvocationState) { - if (this.invocationState == InvocationState.CLOSED) { - // Cannot move out of the closed state - return; - } - LOG.debug("Transitioning state machine to {}", newInvocationState); - this.invocationState = newInvocationState; - this.loggingContextSetter.set( - RestateEndpoint.LoggingContextSetter.INVOCATION_STATUS_KEY, newInvocationState.toString()); - } - - private void tryTransitionProcessing() { - if (currentJournalEntryIndex == entriesToReplay - 1 - && this.invocationState == InvocationState.REPLAYING) { - if (!this.incomingEntriesStateMachine.isEmpty()) { - throw new IllegalStateException("Entries queue should be empty at this point"); - } - this.transitionState(InvocationState.PROCESSING); - } - } - - private void nextJournalEntry(String entryName, MessageType entryType) { - this.currentJournalEntryIndex++; - this.currentJournalEntryName = entryName; - this.currentJournalEntryType = entryType; - - LOG.debug( - "Current journal entry [{}]({}): {}", - this.currentJournalEntryIndex, - this.currentJournalEntryName, - this.currentJournalEntryType); - } - - private void checkInsideSideEffectGuard() { - if (this.insideSideEffect) { - throw ProtocolException.invalidSideEffectCall(); - } - } - - void readEntry(Consumer msgCallback, Consumer errorCallback) { - this.incomingEntriesStateMachine.read( - new IncomingEntriesStateMachine.OnEntryCallback() { - @Override - public void onEntry(MessageLite msg) { - tryTransitionProcessing(); - msgCallback.accept(msg); - } - - @Override - public void onSuspend() { - // This is not expected to happen, so we treat this case as closed - errorCallback.accept(ProtocolException.CLOSED); - } - - @Override - public void onError(Throwable e) { - errorCallback.accept(e); - } - }); - } - - private void writeEntry(MessageLite message) { - LOG.trace("Writing to output message {} {}", message.getClass(), message); - Objects.requireNonNull(this.outputSubscriber).onNext(message); - } - - @Override - public String toString() { - return "InvocationStateMachine[" + debugId + ']'; - } -} diff --git a/sdk-core/src/main/java/dev/restate/sdk/core/MessageEncoder.java b/sdk-core/src/main/java/dev/restate/sdk/core/MessageEncoder.java deleted file mode 100644 index 03ca339ce..000000000 --- a/sdk-core/src/main/java/dev/restate/sdk/core/MessageEncoder.java +++ /dev/null @@ -1,66 +0,0 @@ -// Copyright (c) 2023 - Restate Software, Inc., Restate GmbH -// -// This file is part of the Restate Java SDK, -// which is released under the MIT license. -// -// You can find a copy of the license in file LICENSE in the root -// directory of this repository or package, or at -// https://github.com/restatedev/sdk-java/blob/main/LICENSE -package dev.restate.sdk.core; - -import com.google.protobuf.MessageLite; -import java.nio.ByteBuffer; -import java.util.concurrent.Flow; - -class MessageEncoder implements InvocationFlow.InvocationOutputPublisher { - - private final Flow.Publisher inner; - - MessageEncoder(Flow.Publisher inner) { - this.inner = inner; - } - - @Override - public void subscribe(Flow.Subscriber subscriber) { - this.inner.subscribe( - new Flow.Subscriber<>() { - @Override - public void onSubscribe(Flow.Subscription subscription) { - subscriber.onSubscribe(subscription); - } - - @Override - public void onNext(MessageLite item) { - // We could pool those buffers somehow? - ByteBuffer buffer = ByteBuffer.allocate(MessageEncoder.encodeLength(item)); - MessageEncoder.encode(buffer, item); - subscriber.onNext(buffer); - } - - @Override - public void onError(Throwable throwable) { - subscriber.onError(throwable); - } - - @Override - public void onComplete() { - subscriber.onComplete(); - } - }); - } - - static int encodeLength(MessageLite msg) { - return 8 + msg.getSerializedSize(); - } - - static ByteBuffer encode(ByteBuffer buffer, MessageLite msg) { - MessageHeader header = MessageHeader.fromMessage(msg); - - buffer.putLong(header.encode()); - buffer.put(msg.toByteString().asReadOnlyByteBuffer()); - - buffer.flip(); - - return buffer; - } -} diff --git a/sdk-core/src/main/java/dev/restate/sdk/core/MessageHeader.java b/sdk-core/src/main/java/dev/restate/sdk/core/MessageHeader.java deleted file mode 100644 index ec92dff07..000000000 --- a/sdk-core/src/main/java/dev/restate/sdk/core/MessageHeader.java +++ /dev/null @@ -1,101 +0,0 @@ -// Copyright (c) 2023 - Restate Software, Inc., Restate GmbH -// -// This file is part of the Restate Java SDK, -// which is released under the MIT license. -// -// You can find a copy of the license in file LICENSE in the root -// directory of this repository or package, or at -// https://github.com/restatedev/sdk-java/blob/main/LICENSE -package dev.restate.sdk.core; - -import com.google.protobuf.MessageLite; -import dev.restate.generated.sdk.java.Java; -import dev.restate.generated.service.protocol.Protocol; - -public class MessageHeader { - - static final short DONE_FLAG = 0x0001; - static final int REQUIRES_ACK_FLAG = 0x8000; - - private final MessageType type; - private final int flags; - private final int length; - - public MessageHeader(MessageType type, int flags, int length) { - this.type = type; - this.flags = flags; - this.length = length; - } - - public MessageType getType() { - return type; - } - - public int getLength() { - return length; - } - - public long encode() { - long res = 0L; - res |= ((long) type.encode() << 48); - res |= ((long) flags << 32); - res |= length; - return res; - } - - public static MessageHeader parse(long encoded) throws ProtocolException { - var ty_code = (short) (encoded >> 48); - var flags = (short) (encoded >> 32); - var len = (int) encoded; - - return new MessageHeader(MessageType.decode(ty_code), flags, len); - } - - public static MessageHeader fromMessage(MessageLite msg) { - if (msg instanceof Protocol.GetStateEntryMessage) { - return fromCompletableMessage( - (Protocol.GetStateEntryMessage) msg, Entries.GetStateEntry.INSTANCE); - } else if (msg instanceof Protocol.GetStateKeysEntryMessage) { - return fromCompletableMessage( - (Protocol.GetStateKeysEntryMessage) msg, Entries.GetStateKeysEntry.INSTANCE); - } else if (msg instanceof Protocol.GetPromiseEntryMessage) { - return fromCompletableMessage( - (Protocol.GetPromiseEntryMessage) msg, Entries.GetPromiseEntry.INSTANCE); - } else if (msg instanceof Protocol.PeekPromiseEntryMessage) { - return fromCompletableMessage( - (Protocol.PeekPromiseEntryMessage) msg, Entries.PeekPromiseEntry.INSTANCE); - } else if (msg instanceof Protocol.CompletePromiseEntryMessage) { - return fromCompletableMessage( - (Protocol.CompletePromiseEntryMessage) msg, Entries.CompletePromiseEntry.INSTANCE); - } else if (msg instanceof Protocol.SleepEntryMessage) { - return fromCompletableMessage((Protocol.SleepEntryMessage) msg, Entries.SleepEntry.INSTANCE); - } else if (msg instanceof Protocol.CallEntryMessage) { - return new MessageHeader( - MessageType.CallEntryMessage, - ((Protocol.CallEntryMessage) msg).getResultCase() - != Protocol.CallEntryMessage.ResultCase.RESULT_NOT_SET - ? DONE_FLAG - : 0, - msg.getSerializedSize()); - } else if (msg instanceof Protocol.AwakeableEntryMessage) { - return fromCompletableMessage( - (Protocol.AwakeableEntryMessage) msg, Entries.AwakeableEntry.INSTANCE); - } else if (msg instanceof Protocol.RunEntryMessage) { - return new MessageHeader( - MessageType.RunEntryMessage, REQUIRES_ACK_FLAG, msg.getSerializedSize()); - } else if (msg instanceof Java.CombinatorAwaitableEntryMessage) { - return new MessageHeader( - MessageType.CombinatorAwaitableEntryMessage, REQUIRES_ACK_FLAG, msg.getSerializedSize()); - } - // Messages with no flags - return new MessageHeader(MessageType.fromMessage(msg), 0, msg.getSerializedSize()); - } - - private static > - MessageHeader fromCompletableMessage(MSG msg, E entry) { - return new MessageHeader( - MessageType.fromMessage(msg), - entry.hasResult(msg) ? DONE_FLAG : 0, - msg.getSerializedSize()); - } -} diff --git a/sdk-core/src/main/java/dev/restate/sdk/core/MessageType.java b/sdk-core/src/main/java/dev/restate/sdk/core/MessageType.java deleted file mode 100644 index d61e6164c..000000000 --- a/sdk-core/src/main/java/dev/restate/sdk/core/MessageType.java +++ /dev/null @@ -1,206 +0,0 @@ -// Copyright (c) 2023 - Restate Software, Inc., Restate GmbH -// -// This file is part of the Restate Java SDK, -// which is released under the MIT license. -// -// You can find a copy of the license in file LICENSE in the root -// directory of this repository or package, or at -// https://github.com/restatedev/sdk-java/blob/main/LICENSE -package dev.restate.sdk.core; - -import com.google.protobuf.MessageLite; -import com.google.protobuf.Parser; -import dev.restate.generated.sdk.java.Java; -import dev.restate.generated.service.protocol.Protocol; - -public enum MessageType { - StartMessage, - CompletionMessage, - SuspensionMessage, - ErrorMessage, - EndMessage, - EntryAckMessage, - - // IO - InputEntryMessage, - OutputEntryMessage, - - // State access - GetStateEntryMessage, - SetStateEntryMessage, - ClearStateEntryMessage, - ClearAllStateEntryMessage, - GetStateKeysEntryMessage, - GetPromiseEntryMessage, - PeekPromiseEntryMessage, - CompletePromiseEntryMessage, - - // Syscalls - SleepEntryMessage, - CallEntryMessage, - OneWayCallEntryMessage, - AwakeableEntryMessage, - CompleteAwakeableEntryMessage, - RunEntryMessage, - - // SDK specific - CombinatorAwaitableEntryMessage; - - public static final short START_MESSAGE_TYPE = 0x0000; - public static final short COMPLETION_MESSAGE_TYPE = 0x0001; - public static final short SUSPENSION_MESSAGE_TYPE = 0x0002; - public static final short ERROR_MESSAGE_TYPE = 0x0003; - public static final short ENTRY_ACK_MESSAGE_TYPE = 0x0004; - public static final short END_MESSAGE_TYPE = 0x0005; - public static final short INPUT_ENTRY_MESSAGE_TYPE = 0x0400; - public static final short OUTPUT_ENTRY_MESSAGE_TYPE = 0x0401; - public static final short GET_STATE_ENTRY_MESSAGE_TYPE = 0x0800; - public static final short SET_STATE_ENTRY_MESSAGE_TYPE = 0x0801; - public static final short CLEAR_STATE_ENTRY_MESSAGE_TYPE = 0x0802; - public static final short CLEAR_ALL_STATE_ENTRY_MESSAGE_TYPE = 0x0803; - public static final short GET_STATE_KEYS_ENTRY_MESSAGE_TYPE = 0x0804; - public static final short GET_PROMISE_ENTRY_MESSAGE_TYPE = 0x0808; - public static final short PEEK_PROMISE_ENTRY_MESSAGE_TYPE = 0x0809; - public static final short COMPLETE_PROMISE_ENTRY_MESSAGE_TYPE = 0x080A; - public static final short SLEEP_ENTRY_MESSAGE_TYPE = 0x0C00; - public static final short INVOKE_ENTRY_MESSAGE_TYPE = 0x0C01; - public static final short BACKGROUND_INVOKE_ENTRY_MESSAGE_TYPE = 0x0C02; - public static final short AWAKEABLE_ENTRY_MESSAGE_TYPE = 0x0C03; - public static final short COMPLETE_AWAKEABLE_ENTRY_MESSAGE_TYPE = 0x0C04; - public static final short COMBINATOR_AWAITABLE_ENTRY_MESSAGE_TYPE = (short) 0xFC00; - public static final short SIDE_EFFECT_ENTRY_MESSAGE_TYPE = (short) 0x0C05; - - public Parser messageParser() { - return switch (this) { - case StartMessage -> Protocol.StartMessage.parser(); - case CompletionMessage -> Protocol.CompletionMessage.parser(); - case SuspensionMessage -> Protocol.SuspensionMessage.parser(); - case EndMessage -> Protocol.EndMessage.parser(); - case ErrorMessage -> Protocol.ErrorMessage.parser(); - case EntryAckMessage -> Protocol.EntryAckMessage.parser(); - case InputEntryMessage -> Protocol.InputEntryMessage.parser(); - case OutputEntryMessage -> Protocol.OutputEntryMessage.parser(); - case GetStateEntryMessage -> Protocol.GetStateEntryMessage.parser(); - case SetStateEntryMessage -> Protocol.SetStateEntryMessage.parser(); - case ClearStateEntryMessage -> Protocol.ClearStateEntryMessage.parser(); - case ClearAllStateEntryMessage -> Protocol.ClearAllStateEntryMessage.parser(); - case GetStateKeysEntryMessage -> Protocol.GetStateKeysEntryMessage.parser(); - case GetPromiseEntryMessage -> Protocol.GetPromiseEntryMessage.parser(); - case PeekPromiseEntryMessage -> Protocol.PeekPromiseEntryMessage.parser(); - case CompletePromiseEntryMessage -> Protocol.CompletePromiseEntryMessage.parser(); - case SleepEntryMessage -> Protocol.SleepEntryMessage.parser(); - case CallEntryMessage -> Protocol.CallEntryMessage.parser(); - case OneWayCallEntryMessage -> Protocol.OneWayCallEntryMessage.parser(); - case AwakeableEntryMessage -> Protocol.AwakeableEntryMessage.parser(); - case CompleteAwakeableEntryMessage -> Protocol.CompleteAwakeableEntryMessage.parser(); - case CombinatorAwaitableEntryMessage -> Java.CombinatorAwaitableEntryMessage.parser(); - case RunEntryMessage -> Protocol.RunEntryMessage.parser(); - }; - } - - public short encode() { - return switch (this) { - case StartMessage -> START_MESSAGE_TYPE; - case CompletionMessage -> COMPLETION_MESSAGE_TYPE; - case SuspensionMessage -> SUSPENSION_MESSAGE_TYPE; - case EndMessage -> END_MESSAGE_TYPE; - case ErrorMessage -> ERROR_MESSAGE_TYPE; - case EntryAckMessage -> ENTRY_ACK_MESSAGE_TYPE; - case InputEntryMessage -> INPUT_ENTRY_MESSAGE_TYPE; - case OutputEntryMessage -> OUTPUT_ENTRY_MESSAGE_TYPE; - case GetStateEntryMessage -> GET_STATE_ENTRY_MESSAGE_TYPE; - case SetStateEntryMessage -> SET_STATE_ENTRY_MESSAGE_TYPE; - case ClearStateEntryMessage -> CLEAR_STATE_ENTRY_MESSAGE_TYPE; - case ClearAllStateEntryMessage -> CLEAR_ALL_STATE_ENTRY_MESSAGE_TYPE; - case GetStateKeysEntryMessage -> GET_STATE_KEYS_ENTRY_MESSAGE_TYPE; - case GetPromiseEntryMessage -> GET_PROMISE_ENTRY_MESSAGE_TYPE; - case PeekPromiseEntryMessage -> PEEK_PROMISE_ENTRY_MESSAGE_TYPE; - case CompletePromiseEntryMessage -> COMPLETE_PROMISE_ENTRY_MESSAGE_TYPE; - case SleepEntryMessage -> SLEEP_ENTRY_MESSAGE_TYPE; - case CallEntryMessage -> INVOKE_ENTRY_MESSAGE_TYPE; - case OneWayCallEntryMessage -> BACKGROUND_INVOKE_ENTRY_MESSAGE_TYPE; - case AwakeableEntryMessage -> AWAKEABLE_ENTRY_MESSAGE_TYPE; - case CompleteAwakeableEntryMessage -> COMPLETE_AWAKEABLE_ENTRY_MESSAGE_TYPE; - case CombinatorAwaitableEntryMessage -> COMBINATOR_AWAITABLE_ENTRY_MESSAGE_TYPE; - case RunEntryMessage -> SIDE_EFFECT_ENTRY_MESSAGE_TYPE; - }; - } - - public static MessageType decode(short value) throws ProtocolException { - return switch (value) { - case START_MESSAGE_TYPE -> StartMessage; - case COMPLETION_MESSAGE_TYPE -> CompletionMessage; - case SUSPENSION_MESSAGE_TYPE -> SuspensionMessage; - case END_MESSAGE_TYPE -> EndMessage; - case ERROR_MESSAGE_TYPE -> ErrorMessage; - case ENTRY_ACK_MESSAGE_TYPE -> EntryAckMessage; - case INPUT_ENTRY_MESSAGE_TYPE -> InputEntryMessage; - case OUTPUT_ENTRY_MESSAGE_TYPE -> OutputEntryMessage; - case GET_STATE_ENTRY_MESSAGE_TYPE -> GetStateEntryMessage; - case SET_STATE_ENTRY_MESSAGE_TYPE -> SetStateEntryMessage; - case CLEAR_STATE_ENTRY_MESSAGE_TYPE -> ClearStateEntryMessage; - case CLEAR_ALL_STATE_ENTRY_MESSAGE_TYPE -> ClearAllStateEntryMessage; - case GET_STATE_KEYS_ENTRY_MESSAGE_TYPE -> GetStateKeysEntryMessage; - case GET_PROMISE_ENTRY_MESSAGE_TYPE -> GetPromiseEntryMessage; - case PEEK_PROMISE_ENTRY_MESSAGE_TYPE -> PeekPromiseEntryMessage; - case COMPLETE_PROMISE_ENTRY_MESSAGE_TYPE -> CompletePromiseEntryMessage; - case SLEEP_ENTRY_MESSAGE_TYPE -> SleepEntryMessage; - case INVOKE_ENTRY_MESSAGE_TYPE -> CallEntryMessage; - case BACKGROUND_INVOKE_ENTRY_MESSAGE_TYPE -> OneWayCallEntryMessage; - case AWAKEABLE_ENTRY_MESSAGE_TYPE -> AwakeableEntryMessage; - case COMPLETE_AWAKEABLE_ENTRY_MESSAGE_TYPE -> CompleteAwakeableEntryMessage; - case COMBINATOR_AWAITABLE_ENTRY_MESSAGE_TYPE -> CombinatorAwaitableEntryMessage; - case SIDE_EFFECT_ENTRY_MESSAGE_TYPE -> RunEntryMessage; - default -> throw ProtocolException.unknownMessageType(value); - }; - } - - public static MessageType fromMessage(MessageLite msg) { - if (msg instanceof Protocol.SuspensionMessage) { - return MessageType.SuspensionMessage; - } else if (msg instanceof Protocol.ErrorMessage) { - return MessageType.ErrorMessage; - } else if (msg instanceof Protocol.EndMessage) { - return MessageType.EndMessage; - } else if (msg instanceof Protocol.EntryAckMessage) { - return MessageType.EntryAckMessage; - } else if (msg instanceof Protocol.InputEntryMessage) { - return MessageType.InputEntryMessage; - } else if (msg instanceof Protocol.OutputEntryMessage) { - return MessageType.OutputEntryMessage; - } else if (msg instanceof Protocol.GetStateEntryMessage) { - return MessageType.GetStateEntryMessage; - } else if (msg instanceof Protocol.SetStateEntryMessage) { - return MessageType.SetStateEntryMessage; - } else if (msg instanceof Protocol.ClearStateEntryMessage) { - return MessageType.ClearStateEntryMessage; - } else if (msg instanceof Protocol.ClearAllStateEntryMessage) { - return MessageType.ClearAllStateEntryMessage; - } else if (msg instanceof Protocol.GetStateKeysEntryMessage) { - return MessageType.GetStateKeysEntryMessage; - } else if (msg instanceof Protocol.GetPromiseEntryMessage) { - return MessageType.GetPromiseEntryMessage; - } else if (msg instanceof Protocol.PeekPromiseEntryMessage) { - return MessageType.PeekPromiseEntryMessage; - } else if (msg instanceof Protocol.CompletePromiseEntryMessage) { - return MessageType.CompletePromiseEntryMessage; - } else if (msg instanceof Protocol.SleepEntryMessage) { - return MessageType.SleepEntryMessage; - } else if (msg instanceof Protocol.CallEntryMessage) { - return MessageType.CallEntryMessage; - } else if (msg instanceof Protocol.OneWayCallEntryMessage) { - return MessageType.OneWayCallEntryMessage; - } else if (msg instanceof Protocol.AwakeableEntryMessage) { - return MessageType.AwakeableEntryMessage; - } else if (msg instanceof Protocol.CompleteAwakeableEntryMessage) { - return MessageType.CompleteAwakeableEntryMessage; - } else if (msg instanceof Java.CombinatorAwaitableEntryMessage) { - return MessageType.CombinatorAwaitableEntryMessage; - } else if (msg instanceof Protocol.RunEntryMessage) { - return MessageType.RunEntryMessage; - } else if (msg instanceof Protocol.CompletionMessage) { - throw new IllegalArgumentException("SDK should never send a CompletionMessage"); - } - throw new IllegalStateException(); - } -} diff --git a/sdk-core/src/main/java/dev/restate/sdk/core/ProtocolException.java b/sdk-core/src/main/java/dev/restate/sdk/core/ProtocolException.java index 72ab63f5b..ee2ddd84d 100644 --- a/sdk-core/src/main/java/dev/restate/sdk/core/ProtocolException.java +++ b/sdk-core/src/main/java/dev/restate/sdk/core/ProtocolException.java @@ -9,15 +9,15 @@ package dev.restate.sdk.core; import com.google.protobuf.MessageLite; -import dev.restate.generated.service.protocol.Protocol; -import dev.restate.sdk.common.TerminalException; +import dev.restate.sdk.core.statemachine.NotificationId; +import dev.restate.sdk.types.TerminalException; public class ProtocolException extends RuntimeException { static final int UNAUTHORIZED_CODE = 401; static final int NOT_FOUND_CODE = 404; - static final int UNSUPPORTED_MEDIA_TYPE_CODE = 415; - static final int INTERNAL_CODE = 500; + public static final int UNSUPPORTED_MEDIA_TYPE_CODE = 415; + public static final int INTERNAL_CODE = 500; static final int JOURNAL_MISMATCH_CODE = 570; static final int PROTOCOL_VIOLATION_CODE = 571; @@ -30,7 +30,7 @@ private ProtocolException(String message) { this(message, TerminalException.INTERNAL_SERVER_ERROR_CODE); } - ProtocolException(String message, int code) { + public ProtocolException(String message, int code) { this(message, code, null); } @@ -43,7 +43,7 @@ public int getCode() { return code; } - static ProtocolException unexpectedMessage( + public static ProtocolException unexpectedMessage( Class expected, MessageLite actual) { return new ProtocolException( "Unexpected message type received from the runtime. Expected: '" @@ -54,29 +54,90 @@ static ProtocolException unexpectedMessage( PROTOCOL_VIOLATION_CODE); } - static ProtocolException entryDoesNotMatch(MessageLite expected, MessageLite actual) { + public static ProtocolException unexpectedMessage(String type, MessageLite actual) { return new ProtocolException( - "Journal entry " + expected.getClass() + " does not match: " + expected + " != " + actual, + "Unexpected message type received from the runtime. Expected: '" + + type + + "', Actual: '" + + actual.getClass().getCanonicalName() + + "'", + PROTOCOL_VIOLATION_CODE); + } + + static ProtocolException unexpectedNotificationVariant(Class clazz) { + return new ProtocolException( + "Unexpected notification variant " + clazz.getName(), PROTOCOL_VIOLATION_CODE); + } + + public static ProtocolException commandDoesNotMatch(MessageLite expected, MessageLite actual) { + return new ProtocolException( + "Replayed journal doesn't match the handler code.\nThe handler code generated: " + + expected + + "\nwhile the replayed entry is: " + + actual, JOURNAL_MISMATCH_CODE); } - static ProtocolException completionDoesNotMatch( - String entry, Protocol.CompletionMessage.ResultCase actual) { + public static ProtocolException commandClassDoesNotMatch( + Class expectedClazz, MessageLite actual) { return new ProtocolException( - "Completion for entry " + entry + " doesn't expect completion variant " + actual, + "Replayed journal doesn't match the handler code.\nThe handler code generated: " + + expectedClazz.getName() + + "\nwhile the replayed entry is: " + + actual, JOURNAL_MISMATCH_CODE); } - static ProtocolException unknownMessageType(short type) { + public static ProtocolException commandsToProcessIsEmpty() { + return new ProtocolException("Expecting command queue to be non empty", JOURNAL_MISMATCH_CODE); + } + + public static ProtocolException unknownMessageType(short type) { return new ProtocolException( "MessageType " + Integer.toHexString(type) + " unknown", PROTOCOL_VIOLATION_CODE); } - static ProtocolException methodNotFound(String serviceName, String handlerName) { + public static ProtocolException methodNotFound(String serviceName, String handlerName) { return new ProtocolException( "Cannot find handler '" + serviceName + "/" + handlerName + "'", NOT_FOUND_CODE); } + public static ProtocolException badState(Object thisState) { + return new ProtocolException( + "Cannot process operation because the handler is in unexpected state: " + thisState, + INTERNAL_CODE); + } + + public static ProtocolException badNotificationMessage(String missingField) { + return new ProtocolException( + "Bad notification message, missing field " + missingField, PROTOCOL_VIOLATION_CODE); + } + + public static ProtocolException badRunNotificationId(NotificationId notificationId) { + return new ProtocolException( + "Bad run handle, should be mapped to a completion notification id, but was " + + notificationId, + PROTOCOL_VIOLATION_CODE); + } + + public static ProtocolException commandMissingField(Class clazz, String missingField) { + return new ProtocolException( + "Bad command " + clazz.getName() + ", missing field " + missingField, + PROTOCOL_VIOLATION_CODE); + } + + public static ProtocolException inputClosedWhileWaitingEntries() { + return new ProtocolException( + "The input was closed while still waiting to receive all the `known_entries`", + PROTOCOL_VIOLATION_CODE); + } + + public static ProtocolException closedWhileWaitingEntries() { + return new ProtocolException( + "The state machine was closed while still waiting to receive all the `known_entries`", + PROTOCOL_VIOLATION_CODE); + } + static ProtocolException invalidSideEffectCall() { return new ProtocolException( "A syscall was invoked from within a side effect closure.", @@ -84,18 +145,14 @@ static ProtocolException invalidSideEffectCall() { null); } - static ProtocolException unauthorized(Throwable e) { - return new ProtocolException("Unauthorized", UNAUTHORIZED_CODE, e); + public static ProtocolException idempotencyKeyIsEmpty() { + return new ProtocolException( + "The provided idempotency key is empty.", + TerminalException.INTERNAL_SERVER_ERROR_CODE, + null); } - static ProtocolException unsupportedFeature( - Protocol.ServiceProtocolVersion version, String name) { - return new ProtocolException( - "The feature \"" - + name - + "\" is not supported by the negotiated protocol version \"" - + version.getNumber() - + "\". This might be caused by rolling back a Restate setup to a previous version, while using the experimental context.", - UNSUPPORTED_MEDIA_TYPE_CODE); + public static ProtocolException unauthorized(Throwable e) { + return new ProtocolException("Unauthorized", UNAUTHORIZED_CODE, e); } } diff --git a/sdk-core/src/main/java/dev/restate/sdk/core/ReadyResultStateMachine.java b/sdk-core/src/main/java/dev/restate/sdk/core/ReadyResultStateMachine.java deleted file mode 100644 index 69b3d18df..000000000 --- a/sdk-core/src/main/java/dev/restate/sdk/core/ReadyResultStateMachine.java +++ /dev/null @@ -1,96 +0,0 @@ -// Copyright (c) 2023 - Restate Software, Inc., Restate GmbH -// -// This file is part of the Restate Java SDK, -// which is released under the MIT license. -// -// You can find a copy of the license in file LICENSE in the root -// directory of this repository or package, or at -// https://github.com/restatedev/sdk-java/blob/main/LICENSE -package dev.restate.sdk.core; - -import dev.restate.generated.service.protocol.Protocol; -import dev.restate.sdk.common.syscalls.Result; -import java.util.HashMap; -import java.util.Map; -import java.util.function.Function; -import org.apache.logging.log4j.LogManager; -import org.apache.logging.log4j.Logger; - -/** State machine tracking ready results */ -class ReadyResultStateMachine - extends BaseSuspendableCallbackStateMachine { - - private static final Logger LOG = LogManager.getLogger(ReadyResultStateMachine.class); - - interface OnNewReadyResultCallback extends SuspendableCallback { - boolean onNewResult(Map> resultMap); - } - - private final Map completions; - private final Map>> completionParsers; - private final Map> results; - - ReadyResultStateMachine() { - this.completions = new HashMap<>(); - this.completionParsers = new HashMap<>(); - this.results = new HashMap<>(); - } - - void offerCompletion(Protocol.CompletionMessage completionMessage) { - LOG.trace("Offered new completion {}", completionMessage); - - this.completions.put(completionMessage.getEntryIndex(), completionMessage); - this.tryParse(completionMessage.getEntryIndex()); - } - - void offerCompletionParser( - int entryIndex, Function> parser) { - LOG.trace("Offered new completion parser for index {}", entryIndex); - - this.completionParsers.put(entryIndex, parser); - this.tryParse(entryIndex); - } - - void onNewReadyResult(OnNewReadyResultCallback callback) { - this.assertCallbackNotSet("Two concurrent reads were requested."); - - this.tryProgress(callback); - } - - @Override - void abort(Throwable cause) { - super.abort(cause); - this.consumeCallback(this::tryProgress); - } - - private void tryParse(int entryIndex) { - Protocol.CompletionMessage completionMessage = this.completions.get(entryIndex); - if (completionMessage == null) { - return; - } - - Function> parser = - this.completionParsers.remove(entryIndex); - if (parser == null) { - return; - } - - this.completions.remove(entryIndex, completionMessage); - - // Parse to ready result - Result readyResult = parser.apply(completionMessage); - - // Push to the ready result queue - this.results.put(completionMessage.getEntryIndex(), readyResult); - - // We have a new result, let's try to progress - this.consumeCallback(this::tryProgress); - } - - private void tryProgress(OnNewReadyResultCallback cb) { - boolean resolved = cb.onNewResult(this.results); - if (!resolved) { - this.setCallback(cb); - } - } -} diff --git a/sdk-core/src/main/java/dev/restate/sdk/core/ResolvedEndpointHandler.java b/sdk-core/src/main/java/dev/restate/sdk/core/RequestProcessor.java similarity index 63% rename from sdk-core/src/main/java/dev/restate/sdk/core/ResolvedEndpointHandler.java rename to sdk-core/src/main/java/dev/restate/sdk/core/RequestProcessor.java index 1a31bb5fd..8822600a1 100644 --- a/sdk-core/src/main/java/dev/restate/sdk/core/ResolvedEndpointHandler.java +++ b/sdk-core/src/main/java/dev/restate/sdk/core/RequestProcessor.java @@ -8,10 +8,13 @@ // https://github.com/restatedev/sdk-java/blob/main/LICENSE package dev.restate.sdk.core; -import java.nio.ByteBuffer; +import dev.restate.common.Slice; +import dev.restate.sdk.endpoint.Endpoint; import java.util.concurrent.Flow; -/** Resolved handler for an invocation. */ -public interface ResolvedEndpointHandler extends Flow.Processor { +/** Resolved handler for an invocation. See {@link Endpoint} for more details. */ +public interface RequestProcessor extends Flow.Processor { + int statusCode(); + String responseContentType(); } diff --git a/sdk-core/src/main/java/dev/restate/sdk/core/RequestProcessorImpl.java b/sdk-core/src/main/java/dev/restate/sdk/core/RequestProcessorImpl.java new file mode 100644 index 000000000..24261f565 --- /dev/null +++ b/sdk-core/src/main/java/dev/restate/sdk/core/RequestProcessorImpl.java @@ -0,0 +1,160 @@ +// Copyright (c) 2023 - Restate Software, Inc., Restate GmbH +// +// This file is part of the Restate Java SDK, +// which is released under the MIT license. +// +// You can find a copy of the license in file LICENSE in the root +// directory of this repository or package, or at +// https://github.com/restatedev/sdk-java/blob/main/LICENSE +package dev.restate.sdk.core; + +import dev.restate.common.Slice; +import dev.restate.sdk.core.statemachine.StateMachine; +import dev.restate.sdk.endpoint.definition.HandlerDefinition; +import dev.restate.sdk.types.TerminalException; +import io.opentelemetry.context.Context; +import java.util.concurrent.CompletableFuture; +import java.util.concurrent.Executor; +import java.util.concurrent.Flow; +import org.apache.logging.log4j.LogManager; +import org.apache.logging.log4j.Logger; +import org.jspecify.annotations.Nullable; + +final class RequestProcessorImpl implements RequestProcessor { + + private static final Logger LOG = LogManager.getLogger(RequestProcessorImpl.class); + + private final String fullyQualifiedHandlerName; + private final StateMachine stateMachine; + private final HandlerDefinition handlerDefinition; + private final Context otelContext; + private final EndpointRequestHandler.LoggingContextSetter loggingContextSetter; + private final Executor syscallsExecutor; + + @SuppressWarnings("unchecked") + public RequestProcessorImpl( + String fullyQualifiedHandlerName, + StateMachine stateMachine, + HandlerDefinition handlerDefinition, + Context otelContext, + EndpointRequestHandler.LoggingContextSetter loggingContextSetter, + Executor syscallExecutor) { + this.fullyQualifiedHandlerName = fullyQualifiedHandlerName; + this.stateMachine = stateMachine; + this.otelContext = otelContext; + this.loggingContextSetter = loggingContextSetter; + this.handlerDefinition = (HandlerDefinition) handlerDefinition; + this.syscallsExecutor = syscallExecutor; + } + + // Flow methods implementation + + @Override + public void subscribe(Flow.Subscriber subscriber) { + LOG.trace("Start processing invocation"); + this.stateMachine.subscribe(subscriber); + stateMachine + .waitForReady() + .thenCompose(v -> this.onReady()) + .whenComplete( + (v, t) -> { + if (t != null) { + this.onError(t); + } + }); + } + + @Override + public void onSubscribe(Flow.Subscription subscription) { + this.stateMachine.onSubscribe(subscription); + } + + @Override + public void onNext(Slice item) { + this.stateMachine.onNext(item); + } + + @Override + public void onError(Throwable throwable) { + this.stateMachine.onError(throwable); + } + + @Override + public void onComplete() { + this.stateMachine.onComplete(); + } + + @Override + public int statusCode() { + return 200; + } + + @Override + public String responseContentType() { + return this.stateMachine.getResponseContentType(); + } + + private CompletableFuture onReady() { + StateMachine.Input input = stateMachine.input(); + + if (input == null) { + return CompletableFuture.failedFuture( + new IllegalStateException("State machine input is empty")); + } + + this.loggingContextSetter.set( + EndpointRequestHandler.LoggingContextSetter.INVOCATION_ID_KEY, + input.invocationId().toString()); + + // Prepare HandlerContext object + HandlerContextInternal contextInternal = + this.syscallsExecutor != null + ? new ExecutorSwitchingHandlerContextImpl( + fullyQualifiedHandlerName, stateMachine, otelContext, input, this.syscallsExecutor) + : new HandlerContextImpl(fullyQualifiedHandlerName, stateMachine, otelContext, input); + + CompletableFuture userCodeFuture = + this.handlerDefinition + .getRunner() + .run( + contextInternal, + handlerDefinition.getRequestSerde(), + handlerDefinition.getResponseSerde()); + + return userCodeFuture.handle( + (slice, t) -> { + if (t != null) { + this.end(contextInternal, t); + } else { + this.writeOutputAndEnd(contextInternal, slice); + } + return null; + }); + } + + private CompletableFuture writeOutputAndEnd( + HandlerContextInternal contextInternal, Slice output) { + return contextInternal.writeOutput(output).thenAccept(v -> this.end(contextInternal, null)); + } + + private CompletableFuture end( + HandlerContextInternal contextInternal, @Nullable Throwable exception) { + if (exception == null || ExceptionUtils.containsSuspendedException(exception)) { + contextInternal.close(); + } else { + LOG.warn("Error when processing the invocation", exception); + if (ExceptionUtils.isTerminalException(exception)) { + return contextInternal + .writeOutput((TerminalException) exception) + .thenAccept( + v -> { + LOG.trace("Closed correctly with non ok exception", exception); + contextInternal.close(); + }); + } else { + contextInternal.fail(exception); + } + } + return CompletableFuture.completedFuture(null); + } +} diff --git a/sdk-core/src/main/java/dev/restate/sdk/core/ResolvedEndpointHandlerImpl.java b/sdk-core/src/main/java/dev/restate/sdk/core/ResolvedEndpointHandlerImpl.java deleted file mode 100644 index 1a4d092f4..000000000 --- a/sdk-core/src/main/java/dev/restate/sdk/core/ResolvedEndpointHandlerImpl.java +++ /dev/null @@ -1,156 +0,0 @@ -// Copyright (c) 2023 - Restate Software, Inc., Restate GmbH -// -// This file is part of the Restate Java SDK, -// which is released under the MIT license. -// -// You can find a copy of the license in file LICENSE in the root -// directory of this repository or package, or at -// https://github.com/restatedev/sdk-java/blob/main/LICENSE -package dev.restate.sdk.core; - -import dev.restate.generated.service.protocol.Protocol; -import dev.restate.sdk.common.TerminalException; -import dev.restate.sdk.common.syscalls.*; -import java.nio.ByteBuffer; -import java.util.concurrent.Executor; -import java.util.concurrent.Flow; -import org.apache.logging.log4j.LogManager; -import org.apache.logging.log4j.Logger; -import org.jspecify.annotations.Nullable; - -final class ResolvedEndpointHandlerImpl implements ResolvedEndpointHandler { - - private static final Logger LOG = LogManager.getLogger(ResolvedEndpointHandlerImpl.class); - - private final Protocol.ServiceProtocolVersion serviceProtocolVersion; - private final InvocationStateMachine stateMachine; - private final InvocationFlow.InvocationInputSubscriber input; - private final InvocationFlow.InvocationOutputPublisher output; - private final HandlerSpecification spec; - private final HandlerRunner wrappedHandler; - private final @Nullable Object serviceOptions; - private final @Nullable Executor syscallsExecutor; - - @SuppressWarnings("unchecked") - public ResolvedEndpointHandlerImpl( - Protocol.ServiceProtocolVersion serviceProtocolVersion, - InvocationStateMachine stateMachine, - HandlerDefinition handler, - @Nullable Object serviceOptions, - @Nullable Executor syscallExecutor) { - this.serviceProtocolVersion = serviceProtocolVersion; - this.stateMachine = stateMachine; - this.input = new MessageDecoder(new ExceptionCatchingSubscriber<>(stateMachine)); - this.output = new MessageEncoder(stateMachine); - this.spec = (HandlerSpecification) handler.getSpec(); - this.wrappedHandler = - new HandlerRunnerWrapper<>((HandlerRunner) handler.getRunner()); - this.serviceOptions = serviceOptions; - this.syscallsExecutor = syscallExecutor; - } - - // Flow methods implementation - - @Override - public void subscribe(Flow.Subscriber subscriber) { - LOG.trace("Start processing invocation"); - this.output.subscribe(subscriber); - stateMachine.startAndConsumeInput( - SyscallCallback.of( - request -> { - // Prepare Syscalls object - SyscallsInternal syscalls = - this.syscallsExecutor != null - ? new ExecutorSwitchingSyscalls( - new SyscallsImpl(request, stateMachine), this.syscallsExecutor) - : new SyscallsImpl(request, stateMachine); - - // pollInput then invoke the wrappedHandler - wrappedHandler.run( - spec, - syscalls, - serviceOptions, - SyscallCallback.of( - o -> this.writeOutputAndEnd(syscalls, o), t -> this.end(syscalls, t))); - }, - t -> {})); - } - - @Override - public void onSubscribe(Flow.Subscription subscription) { - this.input.onSubscribe(subscription); - } - - @Override - public void onNext(ByteBuffer item) { - this.input.onNext(item); - } - - @Override - public void onError(Throwable throwable) { - this.input.onError(throwable); - } - - @Override - public void onComplete() { - this.input.onComplete(); - } - - @Override - public String responseContentType() { - return ServiceProtocol.serviceProtocolVersionToHeaderValue(serviceProtocolVersion); - } - - private void writeOutputAndEnd(SyscallsInternal syscalls, ByteBuffer output) { - syscalls.writeOutput( - output, - SyscallCallback.ofVoid( - () -> { - LOG.trace("Wrote output message"); - this.end(syscalls, null); - }, - syscalls::fail)); - } - - private void end(SyscallsInternal syscalls, @Nullable Throwable exception) { - if (exception == null || Util.containsSuspendedException(exception)) { - syscalls.close(); - } else { - LOG.warn("Error when processing the invocation", exception); - if (Util.isTerminalException(exception)) { - syscalls.writeOutput( - (TerminalException) exception, - SyscallCallback.ofVoid( - () -> { - LOG.trace("Closed correctly with non ok exception", exception); - syscalls.close(); - }, - syscalls::fail)); - } else { - syscalls.fail(exception); - } - } - } - - private static class HandlerRunnerWrapper implements HandlerRunner { - - private final HandlerRunner handler; - - private HandlerRunnerWrapper(HandlerRunner handler) { - this.handler = handler; - } - - @Override - public void run( - HandlerSpecification spec, - Syscalls syscalls, - @Nullable O options, - SyscallCallback callback) { - try { - this.handler.run(spec, syscalls, options, callback); - } catch (Throwable e) { - callback.onCancel(e); - } - } - } -} diff --git a/sdk-core/src/main/java/dev/restate/sdk/core/RestateContextDataProvider.java b/sdk-core/src/main/java/dev/restate/sdk/core/RestateContextDataProvider.java index 7a302436d..958117110 100644 --- a/sdk-core/src/main/java/dev/restate/sdk/core/RestateContextDataProvider.java +++ b/sdk-core/src/main/java/dev/restate/sdk/core/RestateContextDataProvider.java @@ -8,7 +8,7 @@ // https://github.com/restatedev/sdk-java/blob/main/LICENSE package dev.restate.sdk.core; -import dev.restate.sdk.common.syscalls.HandlerRunner; +import dev.restate.sdk.endpoint.definition.HandlerRunner; import java.util.Collections; import java.util.HashMap; import java.util.Map; @@ -16,7 +16,7 @@ /** * Log4j2 {@link ContextDataProvider} inferring context from {@link - * HandlerRunner#SYSCALLS_THREAD_LOCAL}. + * HandlerRunner#HANDLER_CONTEXT_THREAD_LOCAL}. * *

This is used to propagate the context to the user code, such that log statements from the user * will contain the restate logging context variables. @@ -24,8 +24,9 @@ public class RestateContextDataProvider implements ContextDataProvider { @Override public Map supplyContextData() { - SyscallsInternal syscalls = (SyscallsInternal) HandlerRunner.SYSCALLS_THREAD_LOCAL.get(); - if (syscalls == null) { + HandlerContextInternal handlerContextInternal = + (HandlerContextInternal) HandlerRunner.HANDLER_CONTEXT_THREAD_LOCAL.get(); + if (handlerContextInternal == null) { return Collections.emptyMap(); } @@ -33,14 +34,14 @@ public Map supplyContextData() { // https://github.com/apache/logging-log4j2/issues/2098 HashMap m = new HashMap<>(3); m.put( - RestateEndpoint.LoggingContextSetter.INVOCATION_ID_KEY, - syscalls.request().invocationId().toString()); + EndpointRequestHandler.LoggingContextSetter.INVOCATION_ID_KEY, + handlerContextInternal.request().invocationId().toString()); m.put( - RestateEndpoint.LoggingContextSetter.INVOCATION_TARGET_KEY, - syscalls.getFullyQualifiedMethodName()); + EndpointRequestHandler.LoggingContextSetter.INVOCATION_TARGET_KEY, + handlerContextInternal.getFullyQualifiedMethodName()); m.put( - RestateEndpoint.LoggingContextSetter.INVOCATION_STATUS_KEY, - syscalls.getInvocationState().toString()); + EndpointRequestHandler.LoggingContextSetter.INVOCATION_STATUS_KEY, + handlerContextInternal.getInvocationState().toString()); return m; } } diff --git a/sdk-core/src/main/java/dev/restate/sdk/core/RestateEndpoint.java b/sdk-core/src/main/java/dev/restate/sdk/core/RestateEndpoint.java deleted file mode 100644 index 93cb85ed9..000000000 --- a/sdk-core/src/main/java/dev/restate/sdk/core/RestateEndpoint.java +++ /dev/null @@ -1,304 +0,0 @@ -// Copyright (c) 2023 - Restate Software, Inc., Restate GmbH -// -// This file is part of the Restate Java SDK, -// which is released under the MIT license. -// -// You can find a copy of the license in file LICENSE in the root -// directory of this repository or package, or at -// https://github.com/restatedev/sdk-java/blob/main/LICENSE -package dev.restate.sdk.core; - -import dev.restate.generated.service.discovery.Discovery; -import dev.restate.generated.service.protocol.Protocol; -import dev.restate.sdk.auth.RequestIdentityVerifier; -import dev.restate.sdk.common.syscalls.HandlerDefinition; -import dev.restate.sdk.common.syscalls.ServiceDefinition; -import dev.restate.sdk.common.syscalls.ServiceDefinitionFactory; -import dev.restate.sdk.core.manifest.EndpointManifestSchema; -import dev.restate.sdk.core.manifest.Service; -import io.opentelemetry.api.OpenTelemetry; -import io.opentelemetry.api.trace.Span; -import io.opentelemetry.api.trace.SpanKind; -import io.opentelemetry.api.trace.Tracer; -import java.util.*; -import java.util.concurrent.Executor; -import java.util.function.Function; -import java.util.stream.Collectors; -import org.apache.logging.log4j.LogManager; -import org.apache.logging.log4j.Logger; -import org.apache.logging.log4j.ThreadContext; -import org.jspecify.annotations.Nullable; - -public class RestateEndpoint { - - private static final Logger LOG = LogManager.getLogger(RestateEndpoint.class); - - private final Map> services; - private final Tracer tracer; - private final RequestIdentityVerifier requestIdentityVerifier; - private final EndpointManifest deploymentManifest; - private final boolean experimentalContextEnabled; - - private RestateEndpoint( - EndpointManifestSchema.ProtocolMode protocolMode, - Map> services, - Tracer tracer, - RequestIdentityVerifier requestIdentityVerifier, - boolean experimentalContextEnabled) { - this.services = services; - this.tracer = tracer; - this.requestIdentityVerifier = requestIdentityVerifier; - this.deploymentManifest = - new EndpointManifest( - protocolMode, - services.values().stream().map(c -> c.service), - experimentalContextEnabled); - this.experimentalContextEnabled = experimentalContextEnabled; - - LOG.info("Registered services: {}", this.services.keySet()); - } - - public ResolvedEndpointHandler resolve( - String contentType, - String componentName, - String handlerName, - RequestIdentityVerifier.Headers headers, - io.opentelemetry.context.Context otelContext, - LoggingContextSetter loggingContextSetter, - @Nullable Executor syscallExecutor) - throws ProtocolException { - final Protocol.ServiceProtocolVersion serviceProtocolVersion = - ServiceProtocol.parseServiceProtocolVersion(contentType); - - if (!ServiceProtocol.isSupported(serviceProtocolVersion, this.experimentalContextEnabled)) { - throw new ProtocolException( - String.format( - "Service endpoint does not support the service protocol version '%s'.", contentType), - ProtocolException.UNSUPPORTED_MEDIA_TYPE_CODE); - } - - // Resolve the service method definition - @SuppressWarnings("unchecked") - ServiceAndOptions svc = (ServiceAndOptions) this.services.get(componentName); - if (svc == null) { - throw ProtocolException.methodNotFound(componentName, handlerName); - } - String fullyQualifiedServiceMethod = componentName + "/" + handlerName; - HandlerDefinition handler = svc.service.getHandler(handlerName); - if (handler == null) { - throw ProtocolException.methodNotFound(componentName, handlerName); - } - - // Verify request - if (requestIdentityVerifier != null) { - try { - requestIdentityVerifier.verifyRequest(headers); - } catch (Exception e) { - throw ProtocolException.unauthorized(e); - } - } - - // Generate the span - Span span = - tracer - .spanBuilder("Invoke method") - .setSpanKind(SpanKind.SERVER) - .setParent(otelContext) - .startSpan(); - - // Setup logging context - loggingContextSetter.set( - LoggingContextSetter.INVOCATION_TARGET_KEY, fullyQualifiedServiceMethod); - - // Instantiate state machine, syscall and grpc bridge - InvocationStateMachine stateMachine = - new InvocationStateMachine( - componentName, - fullyQualifiedServiceMethod, - span, - loggingContextSetter, - serviceProtocolVersion); - - return new ResolvedEndpointHandlerImpl( - serviceProtocolVersion, stateMachine, handler, svc.options, syscallExecutor); - } - - public DiscoveryResponse handleDiscoveryRequest(String acceptContentType) - throws ProtocolException { - Discovery.ServiceDiscoveryProtocolVersion version = - ServiceProtocol.selectSupportedServiceDiscoveryProtocolVersion(acceptContentType); - if (!ServiceProtocol.isSupported(version)) { - throw new ProtocolException( - String.format( - "Unsupported Discovery version in the Accept header '%s'", acceptContentType), - ProtocolException.UNSUPPORTED_MEDIA_TYPE_CODE); - } - - EndpointManifestSchema response = this.deploymentManifest.manifest(); - LOG.info( - "Replying to discovery request with services [{}]", - response.getServices().stream().map(Service::getName).collect(Collectors.joining(","))); - - return new DiscoveryResponse( - ServiceProtocol.serviceDiscoveryProtocolVersionToHeaderValue(version), - ServiceProtocol.serializeManifest(version, response)); - } - - // -- Builder - - public static Builder newBuilder(EndpointManifestSchema.ProtocolMode protocolMode) { - return new Builder(protocolMode); - } - - public static class Builder { - - private final List> services = new ArrayList<>(); - private final EndpointManifestSchema.ProtocolMode protocolMode; - private RequestIdentityVerifier requestIdentityVerifier; - private Tracer tracer = OpenTelemetry.noop().getTracer("NOOP"); - private boolean experimentalContextEnabled = false; - - public Builder(EndpointManifestSchema.ProtocolMode protocolMode) { - this.protocolMode = protocolMode; - } - - public Builder bind(ServiceDefinition component, @Nullable O options) { - this.services.add(new ServiceAndOptions<>(component, options)); - return this; - } - - public Builder withTracer(Tracer tracer) { - this.tracer = tracer; - return this; - } - - public Builder withRequestIdentityVerifier(RequestIdentityVerifier requestIdentityVerifier) { - this.requestIdentityVerifier = requestIdentityVerifier; - return this; - } - - public Builder enablePreviewContext() { - this.experimentalContextEnabled = true; - return this; - } - - public RestateEndpoint build() { - return new RestateEndpoint( - this.protocolMode, - this.services.stream() - .collect(Collectors.toMap(c -> c.service.getServiceName(), Function.identity())), - tracer, - requestIdentityVerifier, - experimentalContextEnabled); - } - } - - /** - * Interface to abstract setting the logging context variables. - * - *

In classic multithreaded environments, you can just use {@link - * LoggingContextSetter#THREAD_LOCAL_INSTANCE}, though the caller of {@link RestateEndpoint} must - * take care of the cleanup of the thread local map. - */ - @FunctionalInterface - public interface LoggingContextSetter { - - String INVOCATION_ID_KEY = "restateInvocationId"; - String INVOCATION_TARGET_KEY = "restateInvocationTarget"; - String INVOCATION_STATUS_KEY = "restateInvocationStatus"; - - LoggingContextSetter THREAD_LOCAL_INSTANCE = ThreadContext::put; - - void set(String key, String value); - } - - private static class ServiceDefinitionFactorySingleton { - private static final ServiceDefinitionFactoryDiscovery INSTANCE = - new ServiceDefinitionFactoryDiscovery(); - } - - @SuppressWarnings("rawtypes") - private static class ServiceDefinitionFactoryDiscovery { - - private final List factories; - - private ServiceDefinitionFactoryDiscovery() { - this.factories = new ArrayList<>(); - - var serviceLoaderIterator = ServiceLoader.load(ServiceDefinitionFactory.class).iterator(); - while (serviceLoaderIterator.hasNext()) { - try { - this.factories.add(serviceLoaderIterator.next()); - } catch (ServiceConfigurationError | Exception e) { - LOG.debug( - "Found service that cannot be loaded using service provider. " - + "You can ignore this message during development.\n" - + "This might be the result of using a compiler with incremental builds (e.g. IntelliJ IDEA) " - + "that updated a dirty META-INF file after removing/renaming an annotated service.", - e); - } - } - } - - private @Nullable ServiceDefinitionFactory discoverFactory(Object service) { - return this.factories.stream().filter(sa -> sa.supports(service)).findFirst().orElse(null); - } - } - - /** Resolve the code generated {@link ServiceDefinitionFactory} */ - @SuppressWarnings("unchecked") - public static ServiceDefinitionFactory discoverServiceDefinitionFactory( - Object service) { - if (service instanceof ServiceDefinitionFactory) { - // We got this already - return (ServiceDefinitionFactory) service; - } - if (service instanceof ServiceDefinition) { - // We got this already - return new ServiceDefinitionFactory<>() { - @Override - public ServiceDefinition create(Object serviceObject) { - return (ServiceDefinition) serviceObject; - } - - @Override - public boolean supports(Object serviceObject) { - return serviceObject == service; - } - }; - } - return Objects.requireNonNull( - ServiceDefinitionFactorySingleton.INSTANCE.discoverFactory(service), - () -> - "ServiceDefinitionFactory class not found for service " - + service.getClass().getCanonicalName() - + ". " - + "Make sure the annotation processor is correctly configured to generate the ServiceDefinitionFactory, " - + "and it generates the META-INF/services/" - + ServiceDefinitionFactory.class.getCanonicalName() - + " file containing the generated class. " - + "If you're using fat jars, make sure the jar plugin correctly squashes all the META-INF/services files. " - + "Found ServiceAdapter: " - + ServiceDefinitionFactorySingleton.INSTANCE.factories); - } - - private record ServiceAndOptions(ServiceDefinition service, O options) {} - - public static class DiscoveryResponse { - private final String contentType; - private final byte[] serializedManifest; - - private DiscoveryResponse(String contentType, byte[] serializedManifest) { - this.contentType = contentType; - this.serializedManifest = serializedManifest; - } - - public String getContentType() { - return contentType; - } - - public byte[] getSerializedManifest() { - return serializedManifest; - } - } -} diff --git a/sdk-core/src/main/java/dev/restate/sdk/core/StaticResponseRequestProcessor.java b/sdk-core/src/main/java/dev/restate/sdk/core/StaticResponseRequestProcessor.java new file mode 100644 index 000000000..afa07db5e --- /dev/null +++ b/sdk-core/src/main/java/dev/restate/sdk/core/StaticResponseRequestProcessor.java @@ -0,0 +1,67 @@ +// Copyright (c) 2023 - Restate Software, Inc., Restate GmbH +// +// This file is part of the Restate Java SDK, +// which is released under the MIT license. +// +// You can find a copy of the license in file LICENSE in the root +// directory of this repository or package, or at +// https://github.com/restatedev/sdk-java/blob/main/LICENSE +package dev.restate.sdk.core; + +import dev.restate.common.Slice; +import java.util.concurrent.Flow; + +class StaticResponseRequestProcessor implements RequestProcessor { + + private final int statusCode; + private final String responseContentType; + private final Slice responseBody; + + StaticResponseRequestProcessor(int statusCode, String responseContentType, Slice responseBody) { + this.statusCode = statusCode; + this.responseContentType = responseContentType; + this.responseBody = responseBody; + } + + @Override + public int statusCode() { + return this.statusCode; + } + + @Override + public String responseContentType() { + return this.responseContentType; + } + + @Override + public void subscribe(Flow.Subscriber subscriber) { + subscriber.onSubscribe( + new Flow.Subscription() { + @Override + public void request(long l) { + if (l <= 0) { + subscriber.onError( + new IllegalStateException("subscription request is negative: " + l)); + return; + } + subscriber.onNext(responseBody); + subscriber.onComplete(); + } + + @Override + public void cancel() {} + }); + } + + @Override + public void onSubscribe(Flow.Subscription subscription) {} + + @Override + public void onNext(Slice slice) {} + + @Override + public void onError(Throwable throwable) {} + + @Override + public void onComplete() {} +} diff --git a/sdk-core/src/main/java/dev/restate/sdk/core/SyscallsImpl.java b/sdk-core/src/main/java/dev/restate/sdk/core/SyscallsImpl.java deleted file mode 100644 index b657225d0..000000000 --- a/sdk-core/src/main/java/dev/restate/sdk/core/SyscallsImpl.java +++ /dev/null @@ -1,454 +0,0 @@ -// Copyright (c) 2023 - Restate Software, Inc., Restate GmbH -// -// This file is part of the Restate Java SDK, -// which is released under the MIT license. -// -// You can find a copy of the license in file LICENSE in the root -// directory of this repository or package, or at -// https://github.com/restatedev/sdk-java/blob/main/LICENSE -package dev.restate.sdk.core; - -import static dev.restate.sdk.core.Util.nioBufferToProtobufBuffer; - -import com.google.protobuf.ByteString; -import dev.restate.generated.service.protocol.Protocol; -import dev.restate.sdk.common.Request; -import dev.restate.sdk.common.RetryPolicy; -import dev.restate.sdk.common.Target; -import dev.restate.sdk.common.TerminalException; -import dev.restate.sdk.common.function.ThrowingRunnable; -import dev.restate.sdk.common.syscalls.*; -import dev.restate.sdk.core.DeferredResults.SingleDeferredInternal; -import dev.restate.sdk.core.Entries.*; -import java.nio.ByteBuffer; -import java.time.Duration; -import java.time.Instant; -import java.util.AbstractMap; -import java.util.Base64; -import java.util.Collection; -import java.util.Map; -import org.apache.logging.log4j.LogManager; -import org.apache.logging.log4j.Logger; -import org.jspecify.annotations.Nullable; - -public final class SyscallsImpl implements SyscallsInternal { - - private static final Logger LOG = LogManager.getLogger(SyscallsImpl.class); - - private final Request request; - private final InvocationStateMachine stateMachine; - - SyscallsImpl(Request request, InvocationStateMachine stateMachine) { - this.request = request; - this.stateMachine = stateMachine; - } - - @Override - public String objectKey() { - return this.stateMachine.objectKey(); - } - - @Override - public Request request() { - return this.request; - } - - @Override - public void writeOutput(ByteBuffer value, SyscallCallback callback) { - wrapAndPropagateExceptions( - () -> { - LOG.trace("writeOutput success"); - this.writeOutput( - Protocol.OutputEntryMessage.newBuilder() - .setValue(nioBufferToProtobufBuffer(value)) - .build(), - callback); - }, - callback); - } - - @Override - public void writeOutput(TerminalException throwable, SyscallCallback callback) { - wrapAndPropagateExceptions( - () -> { - LOG.trace("writeOutput failure"); - this.writeOutput( - Protocol.OutputEntryMessage.newBuilder() - .setFailure(Util.toProtocolFailure(throwable)) - .build(), - callback); - }, - callback); - } - - private void writeOutput(Protocol.OutputEntryMessage entry, SyscallCallback callback) { - wrapAndPropagateExceptions( - () -> this.stateMachine.processJournalEntry(entry, OutputEntry.INSTANCE, callback), - callback); - } - - @Override - public void get(String name, SyscallCallback> callback) { - wrapAndPropagateExceptions( - () -> { - LOG.trace("get {}", name); - this.stateMachine.processCompletableJournalEntry( - Protocol.GetStateEntryMessage.newBuilder() - .setKey(ByteString.copyFromUtf8(name)) - .build(), - GetStateEntry.INSTANCE, - callback); - }, - callback); - } - - @Override - public void getKeys(SyscallCallback>> callback) { - wrapAndPropagateExceptions( - () -> { - LOG.trace("get keys"); - this.stateMachine.processCompletableJournalEntry( - Protocol.GetStateKeysEntryMessage.newBuilder().build(), - GetStateKeysEntry.INSTANCE, - callback); - }, - callback); - } - - @Override - public void clear(String name, SyscallCallback callback) { - wrapAndPropagateExceptions( - () -> { - LOG.trace("clear {}", name); - this.stateMachine.processJournalEntry( - Protocol.ClearStateEntryMessage.newBuilder() - .setKey(ByteString.copyFromUtf8(name)) - .build(), - ClearStateEntry.INSTANCE, - callback); - }, - callback); - } - - @Override - public void clearAll(SyscallCallback callback) { - wrapAndPropagateExceptions( - () -> { - LOG.trace("clearAll"); - this.stateMachine.processJournalEntry( - Protocol.ClearAllStateEntryMessage.newBuilder().build(), - ClearAllStateEntry.INSTANCE, - callback); - }, - callback); - } - - @Override - public void set(String name, ByteBuffer value, SyscallCallback callback) { - wrapAndPropagateExceptions( - () -> { - LOG.trace("set {}", name); - this.stateMachine.processJournalEntry( - Protocol.SetStateEntryMessage.newBuilder() - .setKey(ByteString.copyFromUtf8(name)) - .setValue(nioBufferToProtobufBuffer(value)) - .build(), - SetStateEntry.INSTANCE, - callback); - }, - callback); - } - - @Override - public void sleep(Duration duration, SyscallCallback> callback) { - wrapAndPropagateExceptions( - () -> { - LOG.trace("sleep {}", duration); - this.stateMachine.processCompletableJournalEntry( - Protocol.SleepEntryMessage.newBuilder() - .setWakeUpTime(Instant.now().toEpochMilli() + duration.toMillis()) - .build(), - SleepEntry.INSTANCE, - callback); - }, - callback); - } - - @Override - public void call( - Target target, ByteBuffer parameter, SyscallCallback> callback) { - wrapAndPropagateExceptions( - () -> { - LOG.trace("call {}", target); - - Protocol.CallEntryMessage.Builder builder = - Protocol.CallEntryMessage.newBuilder() - .setServiceName(target.getService()) - .setHandlerName(target.getHandler()) - .setParameter(nioBufferToProtobufBuffer(parameter)); - if (target.getKey() != null) { - builder.setKey(target.getKey()); - } - - this.stateMachine.processCompletableJournalEntry( - builder.build(), new InvokeEntry<>(Result::success), callback); - }, - callback); - } - - @Override - public void send( - Target target, - ByteBuffer parameter, - @Nullable Duration delay, - SyscallCallback callback) { - wrapAndPropagateExceptions( - () -> { - LOG.trace("backgroundCall {}", target); - - Protocol.OneWayCallEntryMessage.Builder builder = - Protocol.OneWayCallEntryMessage.newBuilder() - .setServiceName(target.getService()) - .setHandlerName(target.getHandler()) - .setParameter(nioBufferToProtobufBuffer(parameter)); - if (target.getKey() != null) { - builder.setKey(target.getKey()); - } - if (delay != null && !delay.isZero()) { - builder.setInvokeTime(Instant.now().toEpochMilli() + delay.toMillis()); - } - - this.stateMachine.processJournalEntry( - builder.build(), OneWayCallEntry.INSTANCE, callback); - }, - callback); - } - - @Override - public void enterSideEffectBlock(String name, EnterSideEffectSyscallCallback callback) { - wrapAndPropagateExceptions( - () -> { - LOG.trace("enterSideEffectBlock"); - this.stateMachine.enterSideEffectBlock(name, callback); - }, - callback); - } - - @Override - public void exitSideEffectBlock(ByteBuffer toWrite, ExitSideEffectSyscallCallback callback) { - wrapAndPropagateExceptions( - () -> { - LOG.trace("exitSideEffectBlock with success"); - this.stateMachine.exitSideEffectBlock( - Protocol.RunEntryMessage.newBuilder() - .setValue(nioBufferToProtobufBuffer(toWrite)) - .build(), - callback); - }, - callback); - } - - @Override - public void exitSideEffectBlockWithTerminalException( - TerminalException toWrite, ExitSideEffectSyscallCallback callback) { - wrapAndPropagateExceptions( - () -> { - LOG.trace("exitSideEffectBlock with failure"); - this.stateMachine.exitSideEffectBlock( - Protocol.RunEntryMessage.newBuilder() - .setFailure(Util.toProtocolFailure(toWrite)) - .build(), - callback); - }, - callback); - } - - @Override - public void exitSideEffectBlockWithException( - Throwable runException, - @Nullable RetryPolicy retryPolicy, - ExitSideEffectSyscallCallback callback) { - wrapAndPropagateExceptions( - () -> { - LOG.trace("exitSideEffectBlock with exception"); - this.stateMachine.exitSideEffectBlockWithThrowable(runException, retryPolicy, callback); - }, - callback); - } - - @Override - public void awakeable(SyscallCallback>> callback) { - wrapAndPropagateExceptions( - () -> { - LOG.trace("awakeable"); - this.stateMachine.processCompletableJournalEntry( - Protocol.AwakeableEntryMessage.getDefaultInstance(), - AwakeableEntry.INSTANCE, - SyscallCallback.mappingTo( - deferredResult -> { - // Encode awakeable id - ByteString awakeableId = - stateMachine - .id() - .concat( - ByteString.copyFrom( - ByteBuffer.allocate(4) - .putInt( - ((SingleDeferredInternal) deferredResult) - .entryIndex()) - .flip())); - - return new AbstractMap.SimpleImmutableEntry<>( - Entries.AWAKEABLE_IDENTIFIER_PREFIX - + Base64.getUrlEncoder().encodeToString(awakeableId.toByteArray()), - deferredResult); - }, - callback)); - }, - callback); - } - - @Override - public void resolveAwakeable( - String serializedId, ByteBuffer payload, SyscallCallback callback) { - wrapAndPropagateExceptions( - () -> { - LOG.trace("resolveAwakeable"); - completeAwakeable( - serializedId, - Protocol.CompleteAwakeableEntryMessage.newBuilder() - .setValue(nioBufferToProtobufBuffer(payload)), - callback); - }, - callback); - } - - @Override - public void rejectAwakeable(String serializedId, String reason, SyscallCallback callback) { - wrapAndPropagateExceptions( - () -> { - LOG.trace("rejectAwakeable"); - completeAwakeable( - serializedId, - Protocol.CompleteAwakeableEntryMessage.newBuilder() - .setFailure( - Protocol.Failure.newBuilder() - .setCode(TerminalException.INTERNAL_SERVER_ERROR_CODE) - .setMessage(reason)), - callback); - }, - callback); - } - - private void completeAwakeable( - String serializedId, - Protocol.CompleteAwakeableEntryMessage.Builder builder, - SyscallCallback callback) { - Protocol.CompleteAwakeableEntryMessage expectedEntry = builder.setId(serializedId).build(); - this.stateMachine.processJournalEntry(expectedEntry, CompleteAwakeableEntry.INSTANCE, callback); - } - - @Override - public void promise(String key, SyscallCallback> callback) { - wrapAndPropagateExceptions( - () -> { - LOG.trace("promise"); - this.stateMachine.processCompletableJournalEntry( - Protocol.GetPromiseEntryMessage.newBuilder().setKey(key).build(), - GetPromiseEntry.INSTANCE, - callback); - }, - callback); - } - - @Override - public void peekPromise(String key, SyscallCallback> callback) { - wrapAndPropagateExceptions( - () -> { - LOG.trace("peekPromise"); - this.stateMachine.processCompletableJournalEntry( - Protocol.PeekPromiseEntryMessage.newBuilder().setKey(key).build(), - PeekPromiseEntry.INSTANCE, - callback); - }, - callback); - } - - @Override - public void resolvePromise( - String key, ByteBuffer payload, SyscallCallback> callback) { - wrapAndPropagateExceptions( - () -> { - LOG.trace("resolvePromise"); - this.stateMachine.processCompletableJournalEntry( - Protocol.CompletePromiseEntryMessage.newBuilder() - .setKey(key) - .setCompletionValue(nioBufferToProtobufBuffer(payload)) - .build(), - CompletePromiseEntry.INSTANCE, - callback); - }, - callback); - } - - @Override - public void rejectPromise(String key, String reason, SyscallCallback> callback) { - wrapAndPropagateExceptions( - () -> { - LOG.trace("resolvePromise"); - this.stateMachine.processCompletableJournalEntry( - Protocol.CompletePromiseEntryMessage.newBuilder() - .setKey(key) - .setCompletionFailure( - Protocol.Failure.newBuilder() - .setCode(TerminalException.INTERNAL_SERVER_ERROR_CODE) - .setMessage(reason)) - .build(), - CompletePromiseEntry.INSTANCE, - callback); - }, - callback); - } - - @Override - public void resolveDeferred(Deferred deferredToResolve, SyscallCallback callback) { - wrapAndPropagateExceptions( - () -> this.stateMachine.resolveDeferred(deferredToResolve, callback), callback); - } - - @Override - public String getFullyQualifiedMethodName() { - return this.stateMachine.getFullyQualifiedHandlerName(); - } - - @Override - public InvocationState getInvocationState() { - return this.stateMachine.getInvocationState(); - } - - @Override - public boolean isInsideSideEffect() { - return this.stateMachine.isInsideSideEffect(); - } - - @Override - public void close() { - this.stateMachine.end(); - } - - @Override - public void fail(Throwable cause) { - this.stateMachine.fail(cause); - } - - // -- Wrapper for failure propagation - - private void wrapAndPropagateExceptions(ThrowingRunnable r, SyscallCallback handler) { - try { - r.run(); - } catch (Throwable e) { - this.fail(e); - handler.onCancel(e); - } - } -} diff --git a/sdk-core/src/main/java/dev/restate/sdk/core/SyscallsInternal.java b/sdk-core/src/main/java/dev/restate/sdk/core/SyscallsInternal.java deleted file mode 100644 index 14db50f70..000000000 --- a/sdk-core/src/main/java/dev/restate/sdk/core/SyscallsInternal.java +++ /dev/null @@ -1,43 +0,0 @@ -// Copyright (c) 2023 - Restate Software, Inc., Restate GmbH -// -// This file is part of the Restate Java SDK, -// which is released under the MIT license. -// -// You can find a copy of the license in file LICENSE in the root -// directory of this repository or package, or at -// https://github.com/restatedev/sdk-java/blob/main/LICENSE -package dev.restate.sdk.core; - -import dev.restate.sdk.common.syscalls.Deferred; -import dev.restate.sdk.common.syscalls.Syscalls; -import dev.restate.sdk.core.DeferredResults.DeferredInternal; -import java.util.List; -import java.util.stream.Collectors; - -interface SyscallsInternal extends Syscalls { - - @Override - default Deferred createAnyDeferred(List> children) { - return DeferredResults.any( - children.stream().map(dr -> (DeferredInternal) dr).collect(Collectors.toList())); - } - - @Override - default Deferred createAllDeferred(List> children) { - return DeferredResults.all( - children.stream().map(dr -> (DeferredInternal) dr).collect(Collectors.toList())); - } - - // -- Lifecycle methods - - void close(); - - // -- State machine introspection (used by logging propagator) - - /** - * @return fully qualified method name in the form {fullyQualifiedServiceName}/{methodName} - */ - String getFullyQualifiedMethodName(); - - InvocationState getInvocationState(); -} diff --git a/sdk-core/src/main/java/dev/restate/sdk/core/Tracing.java b/sdk-core/src/main/java/dev/restate/sdk/core/Tracing.java deleted file mode 100644 index 5eec9b460..000000000 --- a/sdk-core/src/main/java/dev/restate/sdk/core/Tracing.java +++ /dev/null @@ -1,28 +0,0 @@ -// Copyright (c) 2023 - Restate Software, Inc., Restate GmbH -// -// This file is part of the Restate Java SDK, -// which is released under the MIT license. -// -// You can find a copy of the license in file LICENSE in the root -// directory of this repository or package, or at -// https://github.com/restatedev/sdk-java/blob/main/LICENSE -package dev.restate.sdk.core; - -import io.opentelemetry.api.common.AttributeKey; - -final class Tracing { - - private Tracing() {} - - static final AttributeKey RESTATE_INVOCATION_ID = - AttributeKey.stringKey("restate.invocation.id"); - - static final AttributeKey RESTATE_STATE_KEY = AttributeKey.stringKey("restate.state.key"); - static final AttributeKey RESTATE_SLEEP_WAKE_UP_TIME = - AttributeKey.longKey("restate.sleep.wake_up_time"); - - static final AttributeKey RESTATE_COORDINATION_CALL_SERVICE = - AttributeKey.stringKey("restate.coordination.call.service"); - static final AttributeKey RESTATE_COORDINATION_CALL_METHOD = - AttributeKey.stringKey("restate.coordination.call.method"); -} diff --git a/sdk-core/src/main/java/dev/restate/sdk/core/Util.java b/sdk-core/src/main/java/dev/restate/sdk/core/Util.java deleted file mode 100644 index 44253d620..000000000 --- a/sdk-core/src/main/java/dev/restate/sdk/core/Util.java +++ /dev/null @@ -1,175 +0,0 @@ -// Copyright (c) 2023 - Restate Software, Inc., Restate GmbH -// -// This file is part of the Restate Java SDK, -// which is released under the MIT license. -// -// You can find a copy of the license in file LICENSE in the root -// directory of this repository or package, or at -// https://github.com/restatedev/sdk-java/blob/main/LICENSE -package dev.restate.sdk.core; - -import com.google.protobuf.ByteString; -import com.google.protobuf.MessageLite; -import com.google.protobuf.UnsafeByteOperations; -import dev.restate.generated.sdk.java.Java; -import dev.restate.generated.service.protocol.Protocol; -import dev.restate.sdk.common.AbortedExecutionException; -import dev.restate.sdk.common.TerminalException; -import java.io.PrintWriter; -import java.io.StringWriter; -import java.nio.ByteBuffer; -import java.time.Duration; -import java.util.Objects; -import java.util.Optional; -import java.util.function.Predicate; -import org.jspecify.annotations.Nullable; - -public final class Util { - private Util() {} - - @SuppressWarnings("unchecked") - static void sneakyThrow(Throwable e) throws E { - throw (E) e; - } - - /** - * Finds a throwable fulfilling the condition in the cause chain of the given throwable. If there - * is none, then the method returns an empty optional. - * - * @param throwable to check for the given condition - * @param condition condition that a cause needs to fulfill - * @return Some cause that fulfills the condition; otherwise an empty optional - */ - @SuppressWarnings("unchecked") - static Optional findCause( - Throwable throwable, Predicate condition) { - Throwable currentThrowable = throwable; - - while (currentThrowable != null) { - if (condition.test(currentThrowable)) { - return (Optional) Optional.of(currentThrowable); - } - - if (currentThrowable == currentThrowable.getCause()) { - break; - } else { - currentThrowable = currentThrowable.getCause(); - } - } - - return Optional.empty(); - } - - public static Optional findProtocolException(Throwable throwable) { - return findCause(throwable, t -> t instanceof ProtocolException); - } - - public static boolean containsSuspendedException(Throwable throwable) { - return findCause(throwable, t -> t == AbortedExecutionException.INSTANCE).isPresent(); - } - - static Protocol.Failure toProtocolFailure(int code, String message) { - Protocol.Failure.Builder builder = Protocol.Failure.newBuilder().setCode(code); - if (message != null) { - builder.setMessage(message); - } - return builder.build(); - } - - static Protocol.Failure toProtocolFailure(Throwable throwable) { - if (throwable instanceof TerminalException) { - return toProtocolFailure(((TerminalException) throwable).getCode(), throwable.getMessage()); - } - return toProtocolFailure(TerminalException.INTERNAL_SERVER_ERROR_CODE, throwable.toString()); - } - - static Protocol.ErrorMessage toErrorMessage( - Throwable throwable, - int currentJournalIndex, - @Nullable String currentJournalEntryName, - @Nullable MessageType currentJournalEntryType) { - Protocol.ErrorMessage.Builder msg = - Protocol.ErrorMessage.newBuilder().setMessage(throwable.toString()); - - if (throwable instanceof ProtocolException) { - msg.setCode(((ProtocolException) throwable).getCode()); - } else { - msg.setCode(TerminalException.INTERNAL_SERVER_ERROR_CODE); - } - - // Convert stacktrace to string - StringWriter sw = new StringWriter(); - PrintWriter pw = new PrintWriter(sw); - pw.println("Stacktrace:"); - throwable.printStackTrace(pw); - msg.setDescription(sw.toString()); - - // Add journal entry info - if (currentJournalIndex >= 0) { - msg.setRelatedEntryIndex(currentJournalIndex); - } - if (currentJournalEntryName != null) { - msg.setRelatedEntryName(currentJournalEntryName); - } - if (currentJournalEntryType != null) { - msg.setRelatedEntryType(currentJournalEntryType.encode()); - } - - return msg.build(); - } - - static TerminalException toRestateException(Protocol.Failure failure) { - return new TerminalException(failure.getCode(), failure.getMessage()); - } - - static boolean isTerminalException(Throwable throwable) { - return throwable instanceof TerminalException; - } - - static void assertIsEntry(MessageLite msg) { - if (!isEntry(msg)) { - throw new IllegalStateException("Expected input to be entry: " + msg); - } - } - - static void assertEntryEquals(MessageLite expected, MessageLite actual) { - if (!Objects.equals(expected, actual)) { - throw ProtocolException.entryDoesNotMatch(expected, actual); - } - } - - static void assertEntryClass(Class clazz, MessageLite actual) { - if (!clazz.equals(actual.getClass())) { - throw ProtocolException.unexpectedMessage(clazz, actual); - } - } - - static boolean isEntry(MessageLite msg) { - return msg instanceof Protocol.InputEntryMessage - || msg instanceof Protocol.OutputEntryMessage - || msg instanceof Protocol.GetStateEntryMessage - || msg instanceof Protocol.GetStateKeysEntryMessage - || msg instanceof Protocol.SetStateEntryMessage - || msg instanceof Protocol.ClearStateEntryMessage - || msg instanceof Protocol.ClearAllStateEntryMessage - || msg instanceof Protocol.GetPromiseEntryMessage - || msg instanceof Protocol.PeekPromiseEntryMessage - || msg instanceof Protocol.CompletePromiseEntryMessage - || msg instanceof Protocol.SleepEntryMessage - || msg instanceof Protocol.CallEntryMessage - || msg instanceof Protocol.OneWayCallEntryMessage - || msg instanceof Protocol.AwakeableEntryMessage - || msg instanceof Protocol.CompleteAwakeableEntryMessage - || msg instanceof Java.CombinatorAwaitableEntryMessage - || msg instanceof Protocol.RunEntryMessage; - } - - /** NOTE! This method rewinds the buffer!!! */ - static ByteString nioBufferToProtobufBuffer(ByteBuffer nioBuffer) { - return UnsafeByteOperations.unsafeWrap(nioBuffer); - } - - static Duration durationMin(Duration a, Duration b) { - return (a.compareTo(b) <= 0) ? a : b; - } -} diff --git a/sdk-core/src/main/java/dev/restate/sdk/core/statemachine/AsyncResultsState.java b/sdk-core/src/main/java/dev/restate/sdk/core/statemachine/AsyncResultsState.java new file mode 100644 index 000000000..0594624c4 --- /dev/null +++ b/sdk-core/src/main/java/dev/restate/sdk/core/statemachine/AsyncResultsState.java @@ -0,0 +1,131 @@ +// Copyright (c) 2023 - Restate Software, Inc., Restate GmbH +// +// This file is part of the Restate Java SDK, +// which is released under the MIT license. +// +// You can find a copy of the license in file LICENSE in the root +// directory of this repository or package, or at +// https://github.com/restatedev/sdk-java/blob/main/LICENSE +package dev.restate.sdk.core.statemachine; + +import com.google.protobuf.ByteString; +import dev.restate.sdk.core.ProtocolException; +import dev.restate.sdk.core.generated.protocol.Protocol; +import java.util.*; + +final class AsyncResultsState { + public static final int CANCEL_NOTIFICATION_HANDLE = 1; + + private final Deque> toProcess; + private final Map ready; + private final Map handleMapping; + + private int nextNotificationHandle; + + public AsyncResultsState() { + this.toProcess = new ArrayDeque<>(); + this.ready = new HashMap<>(); + + this.handleMapping = new HashMap<>(); + // Prepare built in signal handles here + this.handleMapping.put(CANCEL_NOTIFICATION_HANDLE, new NotificationId.SignalId(1)); + + // First 15 are reserved for built-in signals! + nextNotificationHandle = 17; + } + + public void enqueue(Protocol.NotificationTemplate notification) { + var notificationId = + switch (notification.getIdCase()) { + case COMPLETION_ID -> new NotificationId.CompletionId(notification.getCompletionId()); + case SIGNAL_ID -> new NotificationId.SignalId(notification.getSignalId()); + case SIGNAL_NAME -> new NotificationId.SignalName(notification.getSignalName()); + case ID_NOT_SET -> throw ProtocolException.badNotificationMessage("id"); + }; + + var notificationValue = + switch (notification.getResultCase()) { + case VOID -> NotificationValue.Empty.INSTANCE; + case VALUE -> + new NotificationValue.Success( + Util.byteStringToSlice(notification.getValue().getContent())); + case FAILURE -> + new NotificationValue.Failure(Util.toRestateException(notification.getFailure())); + case INVOCATION_ID -> new NotificationValue.InvocationId(notification.getInvocationId()); + case STATE_KEYS -> + new NotificationValue.StateKeys( + notification.getStateKeys().getKeysList().stream() + .map(ByteString::toStringUtf8) + .toList()); + case RESULT_NOT_SET -> throw ProtocolException.badNotificationMessage("result"); + }; + + toProcess.addLast(Map.entry(notificationId, notificationValue)); + } + + public void insertReady(NotificationId id, NotificationValue value) { + ready.put(id, value); + } + + public int createHandleMapping(NotificationId notificationId) { + int assignedHandle = nextNotificationHandle; + nextNotificationHandle++; + handleMapping.put(assignedHandle, notificationId); + return assignedHandle; + } + + public boolean processNextUntilAnyFound(Set ids) { + while (!toProcess.isEmpty()) { + Map.Entry notif = toProcess.removeFirst(); + boolean anyFound = ids.contains(notif.getKey()); + ready.put(notif.getKey(), notif.getValue()); + if (anyFound) { + return true; + } + } + return false; + } + + public boolean isHandleCompleted(int handle) { + NotificationId id = handleMapping.get(handle); + return id != null && ready.containsKey(id); + } + + public boolean nonDeterministicFindId(NotificationId id) { + if (ready.containsKey(id)) { + return true; + } + return toProcess.stream().anyMatch(notif -> notif.getKey().equals(id)); + } + + public Set resolveNotificationHandles(List handles) { + Set result = new LinkedHashSet<>(); + for (int handle : handles) { + NotificationId id = handleMapping.get(handle); + if (id != null) { + result.add(id); + } + } + return result; + } + + public NotificationId mustResolveNotificationHandle(int handle) { + NotificationId id = handleMapping.get(handle); + if (id == null) { + throw new IllegalStateException("If there is a handle, there must be a corresponding id"); + } + return id; + } + + public Optional takeHandle(int handle) { + NotificationId id = handleMapping.get(handle); + if (id != null) { + NotificationValue result = ready.remove(id); + if (result != null) { + handleMapping.remove(handle); + return Optional.of(result); + } + } + return Optional.empty(); + } +} diff --git a/sdk-core/src/main/java/dev/restate/sdk/core/statemachine/ClosedState.java b/sdk-core/src/main/java/dev/restate/sdk/core/statemachine/ClosedState.java new file mode 100644 index 000000000..91c4a9d96 --- /dev/null +++ b/sdk-core/src/main/java/dev/restate/sdk/core/statemachine/ClosedState.java @@ -0,0 +1,31 @@ +// Copyright (c) 2023 - Restate Software, Inc., Restate GmbH +// +// This file is part of the Restate Java SDK, +// which is released under the MIT license. +// +// You can find a copy of the license in file LICENSE in the root +// directory of this repository or package, or at +// https://github.com/restatedev/sdk-java/blob/main/LICENSE +package dev.restate.sdk.core.statemachine; + +import java.time.Duration; +import org.jspecify.annotations.Nullable; + +final class ClosedState implements State { + + @Override + public void hitError( + Throwable throwable, @Nullable Duration nextRetryDelay, StateContext stateContext) { + // Ignore, as we closed already + } + + @Override + public void end(StateContext stateContext) { + // Ignore, as we closed already + } + + @Override + public InvocationState getInvocationState() { + return InvocationState.CLOSED; + } +} diff --git a/sdk-core/src/main/java/dev/restate/sdk/core/statemachine/CommandAccessor.java b/sdk-core/src/main/java/dev/restate/sdk/core/statemachine/CommandAccessor.java new file mode 100644 index 000000000..1a50cd8dd --- /dev/null +++ b/sdk-core/src/main/java/dev/restate/sdk/core/statemachine/CommandAccessor.java @@ -0,0 +1,145 @@ +// Copyright (c) 2023 - Restate Software, Inc., Restate GmbH +// +// This file is part of the Restate Java SDK, +// which is released under the MIT license. +// +// You can find a copy of the license in file LICENSE in the root +// directory of this repository or package, or at +// https://github.com/restatedev/sdk-java/blob/main/LICENSE +package dev.restate.sdk.core.statemachine; + +import com.google.protobuf.MessageLite; +import dev.restate.sdk.core.ProtocolException; +import dev.restate.sdk.core.generated.protocol.Protocol; +import java.util.Objects; + +interface CommandAccessor { + + String getName(E expected); + + default void checkEntryHeader(E expected, MessageLite actual) throws ProtocolException { + Util.assertEntryEquals(expected, actual); + } + + CommandAccessor INPUT = + new CommandAccessor<>() { + @Override + public String getName(Protocol.InputCommandMessage expected) { + return ""; + } + + @Override + public void checkEntryHeader(Protocol.InputCommandMessage expected, MessageLite actual) + throws ProtocolException { + // Nothing to check + } + }; + CommandAccessor OUTPUT = Protocol.OutputCommandMessage::getName; + CommandAccessor GET_EAGER_STATE = + new CommandAccessor<>() { + @Override + public void checkEntryHeader( + Protocol.GetEagerStateCommandMessage expected, MessageLite actual) + throws ProtocolException { + Util.assertEntryClass(Protocol.GetEagerStateCommandMessage.class, actual); + if (!expected.getKey().equals(((Protocol.GetEagerStateCommandMessage) actual).getKey())) { + throw ProtocolException.commandDoesNotMatch(expected, actual); + } + } + + @Override + public String getName(Protocol.GetEagerStateCommandMessage expected) { + return expected.getName(); + } + }; + CommandAccessor GET_LAZY_STATE = + Protocol.GetLazyStateCommandMessage::getName; + CommandAccessor GET_EAGER_STATE_KEYS = + new CommandAccessor<>() { + @Override + public void checkEntryHeader( + Protocol.GetEagerStateKeysCommandMessage expected, MessageLite actual) + throws ProtocolException { + Util.assertEntryClass(Protocol.GetEagerStateKeysCommandMessage.class, actual); + } + + @Override + public String getName(Protocol.GetEagerStateKeysCommandMessage expected) { + return expected.getName(); + } + }; + CommandAccessor GET_LAZY_STATE_KEYS = + Protocol.GetLazyStateKeysCommandMessage::getName; + CommandAccessor CLEAR_STATE = + Protocol.ClearStateCommandMessage::getName; + CommandAccessor CLEAR_ALL_STATE = + Protocol.ClearAllStateCommandMessage::getName; + CommandAccessor SET_STATE = + Protocol.SetStateCommandMessage::getName; + + CommandAccessor SLEEP = + new CommandAccessor<>() { + @Override + public void checkEntryHeader(Protocol.SleepCommandMessage expected, MessageLite actual) + throws ProtocolException { + Util.assertEntryClass(Protocol.SleepCommandMessage.class, actual); + if (!expected.getName().equals(((Protocol.SleepCommandMessage) actual).getName())) { + throw ProtocolException.commandDoesNotMatch(expected, actual); + } + } + + @Override + public String getName(Protocol.SleepCommandMessage expected) { + return expected.getName(); + } + }; + + CommandAccessor CALL = Protocol.CallCommandMessage::getName; + CommandAccessor ONE_WAY_CALL = + new CommandAccessor<>() { + @Override + public String getName(Protocol.OneWayCallCommandMessage expected) { + return ""; + } + + @Override + public void checkEntryHeader(Protocol.OneWayCallCommandMessage expected, MessageLite actual) + throws ProtocolException { + Util.assertEntryClass(Protocol.OneWayCallCommandMessage.class, actual); + var actualOneWayCall = (Protocol.OneWayCallCommandMessage) actual; + + if (!(Objects.equals(expected.getServiceName(), actualOneWayCall.getServiceName()) + && Objects.equals(expected.getHandlerName(), actualOneWayCall.getHandlerName()) + && Objects.equals(expected.getParameter(), actualOneWayCall.getParameter()) + && Objects.equals(expected.getKey(), actualOneWayCall.getKey()) + && Objects.equals(expected.getHeadersList(), actualOneWayCall.getHeadersList()) + && Objects.equals(expected.getIdempotencyKey(), actualOneWayCall.getIdempotencyKey()) + && Objects.equals(expected.getName(), actualOneWayCall.getName()) + && Objects.equals( + expected.getInvocationIdNotificationIdx(), + actualOneWayCall.getInvocationIdNotificationIdx()))) { + throw ProtocolException.commandDoesNotMatch(expected, actual); + } + } + }; + + CommandAccessor COMPLETE_AWAKEABLE = + Protocol.CompleteAwakeableCommandMessage::getName; + CommandAccessor RUN = Protocol.RunCommandMessage::getName; + + CommandAccessor GET_PROMISE = + Protocol.GetPromiseCommandMessage::getName; + CommandAccessor PEEK_PROMISE = + Protocol.PeekPromiseCommandMessage::getName; + CommandAccessor COMPLETE_PROMISE = + Protocol.CompletePromiseCommandMessage::getName; + + CommandAccessor SEND_SIGNAL = + Protocol.SendSignalCommandMessage::getName; + + CommandAccessor ATTACH_INVOCATION = + Protocol.AttachInvocationCommandMessage::getName; + + CommandAccessor GET_INVOCATION_OUTPUT = + Protocol.GetInvocationOutputCommandMessage::getName; +} diff --git a/sdk-core/src/main/java/dev/restate/sdk/core/UserStateStore.java b/sdk-core/src/main/java/dev/restate/sdk/core/statemachine/EagerState.java similarity index 50% rename from sdk-core/src/main/java/dev/restate/sdk/core/UserStateStore.java rename to sdk-core/src/main/java/dev/restate/sdk/core/statemachine/EagerState.java index 954086334..92865bbd9 100644 --- a/sdk-core/src/main/java/dev/restate/sdk/core/UserStateStore.java +++ b/sdk-core/src/main/java/dev/restate/sdk/core/statemachine/EagerState.java @@ -6,64 +6,53 @@ // You can find a copy of the license in file LICENSE in the root // directory of this repository or package, or at // https://github.com/restatedev/sdk-java/blob/main/LICENSE -package dev.restate.sdk.core; +package dev.restate.sdk.core.statemachine; import com.google.protobuf.ByteString; -import dev.restate.generated.service.protocol.Protocol; -import java.nio.ByteBuffer; +import dev.restate.common.Slice; +import dev.restate.sdk.core.generated.protocol.Protocol; import java.util.HashMap; import java.util.Set; +import org.jspecify.annotations.Nullable; -final class UserStateStore { +final class EagerState { - interface State {} + sealed interface State {} - static final class Unknown implements State { + record Unknown() implements State { private static final Unknown INSTANCE = new Unknown(); - - private Unknown() {} } - static final class Empty implements State { + record Empty() implements State { private static final Empty INSTANCE = new Empty(); - - private Empty() {} } - static final class Value implements State { - private final ByteBuffer value; - - private Value(ByteBuffer value) { - this.value = value; - } - - public ByteBuffer getValue() { - return value; - } - } + record Value(Slice value) implements State {} private boolean isPartial; - private final HashMap map; + private final HashMap map; - UserStateStore(Protocol.StartMessage startMessage) { + EagerState(Protocol.StartMessage startMessage) { this.isPartial = startMessage.getPartialState(); this.map = new HashMap<>(startMessage.getStateMapCount()); for (int i = 0; i < startMessage.getStateMapCount(); i++) { Protocol.StartMessage.StateEntry entry = startMessage.getStateMap(i); - this.map.put(entry.getKey(), new Value(entry.getValue().asReadOnlyByteBuffer())); + this.map.put( + entry.getKey(), + new NotificationValue.Success(Slice.wrap(entry.getValue().asReadOnlyByteBuffer()))); } } - public State get(ByteString key) { - return this.map.getOrDefault(key, isPartial ? Unknown.INSTANCE : Empty.INSTANCE); + public @Nullable NotificationValue get(ByteString key) { + return this.map.getOrDefault(key, isComplete() ? NotificationValue.Empty.INSTANCE : null); } - public void set(ByteString key, ByteBuffer value) { - this.map.put(key, new Value(value)); + public void set(ByteString key, Slice value) { + this.map.put(key, new NotificationValue.Success(value)); } public void clear(ByteString key) { - this.map.put(key, Empty.INSTANCE); + this.map.put(key, NotificationValue.Empty.INSTANCE); } public void clearAll() { @@ -75,7 +64,10 @@ public boolean isComplete() { return !isPartial; } - public Set keys() { - return this.map.keySet(); + public @Nullable Set keys() { + if (isComplete()) { + return this.map.keySet(); + } + return null; } } diff --git a/sdk-core/src/main/java/dev/restate/sdk/core/InvocationIdImpl.java b/sdk-core/src/main/java/dev/restate/sdk/core/statemachine/InvocationIdImpl.java similarity index 95% rename from sdk-core/src/main/java/dev/restate/sdk/core/InvocationIdImpl.java rename to sdk-core/src/main/java/dev/restate/sdk/core/statemachine/InvocationIdImpl.java index c9c19cf77..32321c6c0 100644 --- a/sdk-core/src/main/java/dev/restate/sdk/core/InvocationIdImpl.java +++ b/sdk-core/src/main/java/dev/restate/sdk/core/statemachine/InvocationIdImpl.java @@ -6,9 +6,9 @@ // You can find a copy of the license in file LICENSE in the root // directory of this repository or package, or at // https://github.com/restatedev/sdk-java/blob/main/LICENSE -package dev.restate.sdk.core; +package dev.restate.sdk.core.statemachine; -import dev.restate.sdk.common.InvocationId; +import dev.restate.sdk.types.InvocationId; import java.nio.charset.StandardCharsets; import java.security.MessageDigest; import java.security.NoSuchAlgorithmException; diff --git a/sdk-core/src/main/java/dev/restate/sdk/core/InvocationInput.java b/sdk-core/src/main/java/dev/restate/sdk/core/statemachine/InvocationInput.java similarity index 95% rename from sdk-core/src/main/java/dev/restate/sdk/core/InvocationInput.java rename to sdk-core/src/main/java/dev/restate/sdk/core/statemachine/InvocationInput.java index 95de60820..fafab731a 100644 --- a/sdk-core/src/main/java/dev/restate/sdk/core/InvocationInput.java +++ b/sdk-core/src/main/java/dev/restate/sdk/core/statemachine/InvocationInput.java @@ -6,7 +6,7 @@ // You can find a copy of the license in file LICENSE in the root // directory of this repository or package, or at // https://github.com/restatedev/sdk-java/blob/main/LICENSE -package dev.restate.sdk.core; +package dev.restate.sdk.core.statemachine; import com.google.protobuf.MessageLite; diff --git a/sdk-core/src/main/java/dev/restate/sdk/core/InvocationState.java b/sdk-core/src/main/java/dev/restate/sdk/core/statemachine/InvocationState.java similarity index 90% rename from sdk-core/src/main/java/dev/restate/sdk/core/InvocationState.java rename to sdk-core/src/main/java/dev/restate/sdk/core/statemachine/InvocationState.java index 2944d6c3e..3820f41da 100644 --- a/sdk-core/src/main/java/dev/restate/sdk/core/InvocationState.java +++ b/sdk-core/src/main/java/dev/restate/sdk/core/statemachine/InvocationState.java @@ -6,7 +6,7 @@ // You can find a copy of the license in file LICENSE in the root // directory of this repository or package, or at // https://github.com/restatedev/sdk-java/blob/main/LICENSE -package dev.restate.sdk.core; +package dev.restate.sdk.core.statemachine; public enum InvocationState { WAITING_START, diff --git a/sdk-core/src/main/java/dev/restate/sdk/core/statemachine/Journal.java b/sdk-core/src/main/java/dev/restate/sdk/core/statemachine/Journal.java new file mode 100644 index 000000000..e8b35581f --- /dev/null +++ b/sdk-core/src/main/java/dev/restate/sdk/core/statemachine/Journal.java @@ -0,0 +1,71 @@ +// Copyright (c) 2023 - Restate Software, Inc., Restate GmbH +// +// This file is part of the Restate Java SDK, +// which is released under the MIT license. +// +// You can find a copy of the license in file LICENSE in the root +// directory of this repository or package, or at +// https://github.com/restatedev/sdk-java/blob/main/LICENSE +package dev.restate.sdk.core.statemachine; + +import com.google.protobuf.MessageLite; + +class Journal { + private int commandIndex; + private int notificationIndex; + private int completionIndex; + private int signalIndex; + private MessageType currentEntryTy; + private String currentEntryName; + + Journal() { + this.commandIndex = -1; + this.notificationIndex = -1; + // Clever trick for protobuf here + this.completionIndex = 1; + // 1 to 16 are reserved! + this.signalIndex = 17; + this.currentEntryTy = MessageType.StartMessage; + this.currentEntryName = ""; + } + + public void commandTransition(String entryName, MessageLite expected) { + this.commandIndex++; + this.currentEntryName = entryName; + this.currentEntryTy = MessageType.fromMessage(expected); + } + + public void notificationTransition(MessageLite expected) { + this.notificationIndex++; + this.currentEntryName = ""; + this.currentEntryTy = null; + } + + public int getCommandIndex() { + return this.commandIndex; + } + + public MessageType getCurrentEntryTy() { + return currentEntryTy; + } + + public String getCurrentEntryName() { + return currentEntryName; + } + + public int getNotificationIndex() { + return this.notificationIndex; + } + + public int nextCompletionNotificationId() { + int next = this.completionIndex; + this.completionIndex++; + return next; + } + + public int nextSignalNotificationId() { + int next = this.signalIndex; + this.signalIndex++; + return next; + } +} diff --git a/sdk-core/src/main/java/dev/restate/sdk/core/MessageDecoder.java b/sdk-core/src/main/java/dev/restate/sdk/core/statemachine/MessageDecoder.java similarity index 61% rename from sdk-core/src/main/java/dev/restate/sdk/core/MessageDecoder.java rename to sdk-core/src/main/java/dev/restate/sdk/core/statemachine/MessageDecoder.java index 71a5260bb..abc518fcb 100644 --- a/sdk-core/src/main/java/dev/restate/sdk/core/MessageDecoder.java +++ b/sdk-core/src/main/java/dev/restate/sdk/core/statemachine/MessageDecoder.java @@ -6,17 +6,17 @@ // You can find a copy of the license in file LICENSE in the root // directory of this repository or package, or at // https://github.com/restatedev/sdk-java/blob/main/LICENSE -package dev.restate.sdk.core; +package dev.restate.sdk.core.statemachine; import com.google.protobuf.ByteString; import com.google.protobuf.InvalidProtocolBufferException; import com.google.protobuf.UnsafeByteOperations; -import java.nio.ByteBuffer; +import dev.restate.common.Slice; import java.util.ArrayDeque; import java.util.Queue; -import java.util.concurrent.Flow; +import org.jspecify.annotations.Nullable; -class MessageDecoder implements InvocationFlow.InvocationInputSubscriber { +public final class MessageDecoder { private enum State { WAITING_HEADER, @@ -24,10 +24,6 @@ private enum State { FAILED } - private final Flow.Subscriber inner; - private Flow.Subscription inputSubscription; - private long invocationInputRequests = 0; - private final Queue parsedMessages; private ByteString internalBuffer; @@ -35,8 +31,7 @@ private enum State { private MessageHeader lastParsedMessageHeader; private RuntimeException lastParsingFailure; - MessageDecoder(Flow.Subscriber inner) { - this.inner = inner; + public MessageDecoder() { this.parsedMessages = new ArrayDeque<>(); this.internalBuffer = ByteString.EMPTY; @@ -47,80 +42,19 @@ private enum State { // -- Subscriber methods - @Override - public void onSubscribe(Flow.Subscription byteBufferSubscription) { - this.inputSubscription = byteBufferSubscription; - this.inner.onSubscribe( - new Flow.Subscription() { - @Override - public void request(long n) { - // We ask for MAX VALUE, then we buffer in this class. - // This class could be implemented with more backpressure in mind, but for now this is - // fine. - byteBufferSubscription.request(Long.MAX_VALUE); - handleSubscriptionRequest(n); - } - - @Override - public void cancel() { - byteBufferSubscription.cancel(); - } - }); - } - - @Override - public void onNext(ByteBuffer item) { - this.offer(UnsafeByteOperations.unsafeWrap(item)); - tryProgress(); - } - - @Override - public void onError(Throwable throwable) { - if (this.inputSubscription == null) { - return; - } - this.inner.onError(throwable); - } - - @Override - public void onComplete() { - if (this.inputSubscription == null) { - return; - } - this.inner.onComplete(); + public void offer(Slice item) { + this.offer(UnsafeByteOperations.unsafeWrap(item.asReadOnlyByteBuffer())); } - private void handleSubscriptionRequest(long l) { - if (l == Long.MAX_VALUE) { - this.invocationInputRequests = l; - } else { - this.invocationInputRequests += l; - // Overflow check - if (this.invocationInputRequests < 0) { - this.invocationInputRequests = Long.MAX_VALUE; - } + public @Nullable InvocationInput next() { + if (this.state == State.FAILED) { + throw lastParsingFailure; } - - tryProgress(); + return this.parsedMessages.poll(); } - private void tryProgress() { - if (this.inputSubscription == null) { - return; - } - if (this.state == State.FAILED) { - this.inner.onError(lastParsingFailure); - this.inputSubscription.cancel(); - this.inputSubscription = null; - } - while (this.invocationInputRequests > 0) { - InvocationInput input = this.parsedMessages.poll(); - if (input == null) { - return; - } - this.invocationInputRequests--; - this.inner.onNext(input); - } + public boolean isNextAvailable() { + return !this.parsedMessages.isEmpty(); } // -- Internal methods to handle decoding diff --git a/sdk-core/src/main/java/dev/restate/sdk/core/statemachine/MessageEncoder.java b/sdk-core/src/main/java/dev/restate/sdk/core/statemachine/MessageEncoder.java new file mode 100644 index 000000000..f98b79db9 --- /dev/null +++ b/sdk-core/src/main/java/dev/restate/sdk/core/statemachine/MessageEncoder.java @@ -0,0 +1,61 @@ +// Copyright (c) 2023 - Restate Software, Inc., Restate GmbH +// +// This file is part of the Restate Java SDK, +// which is released under the MIT license. +// +// You can find a copy of the license in file LICENSE in the root +// directory of this repository or package, or at +// https://github.com/restatedev/sdk-java/blob/main/LICENSE +package dev.restate.sdk.core.statemachine; + +import com.google.protobuf.MessageLite; +import dev.restate.common.Slice; +import java.nio.ByteBuffer; +import java.util.concurrent.Flow; + +final class MessageEncoder implements Flow.Subscriber { + + private final Flow.Subscriber inner; + + MessageEncoder(Flow.Subscriber inner) { + this.inner = inner; + } + + @Override + public void onSubscribe(Flow.Subscription subscription) { + inner.onSubscribe(subscription); + } + + @Override + public void onNext(MessageLite item) { + // We could pool those buffers somehow? + ByteBuffer buffer = ByteBuffer.allocate(MessageEncoder.encodeLength(item)); + MessageEncoder.encode(buffer, item); + inner.onNext(Slice.wrap(buffer)); + } + + @Override + public void onError(Throwable throwable) { + inner.onError(throwable); + } + + @Override + public void onComplete() { + inner.onComplete(); + } + + static int encodeLength(MessageLite msg) { + return 8 + msg.getSerializedSize(); + } + + static ByteBuffer encode(ByteBuffer buffer, MessageLite msg) { + MessageHeader header = MessageHeader.fromMessage(msg); + + buffer.putLong(header.encode()); + buffer.put(msg.toByteString().asReadOnlyByteBuffer()); + + buffer.flip(); + + return buffer; + } +} diff --git a/sdk-core/src/main/java/dev/restate/sdk/core/statemachine/MessageHeader.java b/sdk-core/src/main/java/dev/restate/sdk/core/statemachine/MessageHeader.java new file mode 100644 index 000000000..c5a9e6123 --- /dev/null +++ b/sdk-core/src/main/java/dev/restate/sdk/core/statemachine/MessageHeader.java @@ -0,0 +1,53 @@ +// Copyright (c) 2023 - Restate Software, Inc., Restate GmbH +// +// This file is part of the Restate Java SDK, +// which is released under the MIT license. +// +// You can find a copy of the license in file LICENSE in the root +// directory of this repository or package, or at +// https://github.com/restatedev/sdk-java/blob/main/LICENSE +package dev.restate.sdk.core.statemachine; + +import com.google.protobuf.MessageLite; +import dev.restate.sdk.core.ProtocolException; + +public class MessageHeader { + + private final MessageType type; + private final int flags; + private final int length; + + public MessageHeader(MessageType type, int flags, int length) { + this.type = type; + this.flags = flags; + this.length = length; + } + + public MessageType getType() { + return type; + } + + public int getLength() { + return length; + } + + public long encode() { + long res = 0L; + res |= ((long) type.encode() << 48); + res |= ((long) flags << 32); + res |= length; + return res; + } + + public static MessageHeader parse(long encoded) throws ProtocolException { + var ty_code = (short) (encoded >> 48); + var flags = (short) (encoded >> 32); + var len = (int) encoded; + + return new MessageHeader(MessageType.decode(ty_code), flags, len); + } + + public static MessageHeader fromMessage(MessageLite msg) { + return new MessageHeader(MessageType.fromMessage(msg), 0, msg.getSerializedSize()); + } +} diff --git a/sdk-core/src/main/java/dev/restate/sdk/core/statemachine/MessageType.java b/sdk-core/src/main/java/dev/restate/sdk/core/statemachine/MessageType.java new file mode 100644 index 000000000..fdfea066c --- /dev/null +++ b/sdk-core/src/main/java/dev/restate/sdk/core/statemachine/MessageType.java @@ -0,0 +1,361 @@ +// Copyright (c) 2023 - Restate Software, Inc., Restate GmbH +// +// This file is part of the Restate Java SDK, +// which is released under the MIT license. +// +// You can find a copy of the license in file LICENSE in the root +// directory of this repository or package, or at +// https://github.com/restatedev/sdk-java/blob/main/LICENSE +package dev.restate.sdk.core.statemachine; + +import com.google.protobuf.MessageLite; +import com.google.protobuf.Parser; +import dev.restate.sdk.core.ProtocolException; +import dev.restate.sdk.core.generated.protocol.Protocol; + +public enum MessageType { + StartMessage, + SuspensionMessage, + ErrorMessage, + EndMessage, + ProposeRunCompletionMessage, + + InputCommandMessage, + OutputCommandMessage, + GetLazyStateCommandMessage, + GetLazyStateCompletionNotificationMessage, + SetStateCommandMessage, + ClearStateCommandMessage, + ClearAllStateCommandMessage, + GetLazyStateKeysCommandMessage, + GetLazyStateKeysCompletionNotificationMessage, + GetEagerStateCommandMessage, + GetEagerStateKeysCommandMessage, + GetPromiseCommandMessage, + GetPromiseCompletionNotificationMessage, + PeekPromiseCommandMessage, + PeekPromiseCompletionNotificationMessage, + CompletePromiseCommandMessage, + CompletePromiseCompletionNotificationMessage, + SleepCommandMessage, + SleepCompletionNotificationMessage, + CallCommandMessage, + CallInvocationIdCompletionNotificationMessage, + CallCompletionNotificationMessage, + OneWayCallCommandMessage, + SendSignalCommandMessage, + RunCommandMessage, + RunCompletionNotificationMessage, + AttachInvocationCommandMessage, + AttachInvocationCompletionNotificationMessage, + GetInvocationOutputCommandMessage, + GetInvocationOutputCompletionNotificationMessage, + CompleteAwakeableCommandMessage, + SignalNotificationMessage; + + public static final short StartMessage_TYPE = (short) 0x0000; + public static final short SuspensionMessage_TYPE = (short) 0x0001; + public static final short ErrorMessage_TYPE = (short) 0x0002; + public static final short EndMessage_TYPE = (short) 0x0003; + public static final short ProposeRunCompletionMessage_TYPE = (short) 0x0005; + public static final short InputCommandMessage_TYPE = (short) 0x0400; + public static final short OutputCommandMessage_TYPE = (short) 0x0401; + public static final short GetLazyStateCommandMessage_TYPE = (short) 0x0402; + public static final short GetLazyStateCompletionNotificationMessage_TYPE = (short) 0x8002; + public static final short SetStateCommandMessage_TYPE = (short) 0x0403; + public static final short ClearStateCommandMessage_TYPE = (short) 0x0404; + public static final short ClearAllStateCommandMessage_TYPE = (short) 0x0405; + public static final short GetLazyStateKeysCommandMessage_TYPE = (short) 0x0406; + public static final short GetLazyStateKeysCompletionNotificationMessage_TYPE = (short) 0x8006; + public static final short GetEagerStateCommandMessage_TYPE = (short) 0x0407; + public static final short GetEagerStateKeysCommandMessage_TYPE = (short) 0x0408; + public static final short GetPromiseCommandMessage_TYPE = (short) 0x0409; + public static final short GetPromiseCompletionNotificationMessage_TYPE = (short) 0x8009; + public static final short PeekPromiseCommandMessage_TYPE = (short) 0x040A; + public static final short PeekPromiseCompletionNotificationMessage_TYPE = (short) 0x800A; + public static final short CompletePromiseCommandMessage_TYPE = (short) 0x040B; + public static final short CompletePromiseCompletionNotificationMessage_TYPE = (short) 0x800B; + public static final short SleepCommandMessage_TYPE = (short) 0x040C; + public static final short SleepCompletionNotificationMessage_TYPE = (short) 0x800C; + public static final short CallCommandMessage_TYPE = (short) 0x040D; + public static final short CallInvocationIdCompletionNotificationMessage_TYPE = (short) 0x800E; + public static final short CallCompletionNotificationMessage_TYPE = (short) 0x800D; + public static final short OneWayCallCommandMessage_TYPE = (short) 0x040E; + public static final short SendSignalCommandMessage_TYPE = (short) 0x0410; + public static final short RunCommandMessage_TYPE = (short) 0x0411; + public static final short RunCompletionNotificationMessage_TYPE = (short) 0x8011; + public static final short AttachInvocationCommandMessage_TYPE = (short) 0x0412; + public static final short AttachInvocationCompletionNotificationMessage_TYPE = (short) 0x8012; + public static final short GetInvocationOutputCommandMessage_TYPE = (short) 0x0413; + public static final short GetInvocationOutputCompletionNotificationMessage_TYPE = (short) 0x8013; + public static final short CompleteAwakeableCommandMessage_TYPE = (short) 0x0414; + public static final short SignalNotificationMessage_TYPE = (short) 0xFBFF; + + public Parser messageParser() { + return switch (this) { + case StartMessage -> Protocol.StartMessage.parser(); + case SuspensionMessage -> Protocol.SuspensionMessage.parser(); + case ErrorMessage -> Protocol.ErrorMessage.parser(); + case EndMessage -> Protocol.EndMessage.parser(); + case ProposeRunCompletionMessage -> Protocol.ProposeRunCompletionMessage.parser(); + case InputCommandMessage -> Protocol.InputCommandMessage.parser(); + case OutputCommandMessage -> Protocol.OutputCommandMessage.parser(); + case GetLazyStateCommandMessage -> Protocol.GetLazyStateCommandMessage.parser(); + case SetStateCommandMessage -> Protocol.SetStateCommandMessage.parser(); + case ClearStateCommandMessage -> Protocol.ClearStateCommandMessage.parser(); + case ClearAllStateCommandMessage -> Protocol.ClearAllStateCommandMessage.parser(); + case GetLazyStateKeysCommandMessage -> Protocol.GetLazyStateKeysCommandMessage.parser(); + case GetEagerStateCommandMessage -> Protocol.GetEagerStateCommandMessage.parser(); + case GetEagerStateKeysCommandMessage -> Protocol.GetEagerStateKeysCommandMessage.parser(); + case GetPromiseCommandMessage -> Protocol.GetPromiseCommandMessage.parser(); + case PeekPromiseCommandMessage -> Protocol.PeekPromiseCommandMessage.parser(); + case CompletePromiseCommandMessage -> Protocol.CompletePromiseCommandMessage.parser(); + case SleepCommandMessage -> Protocol.SleepCommandMessage.parser(); + case CallCommandMessage -> Protocol.CallCommandMessage.parser(); + case OneWayCallCommandMessage -> Protocol.OneWayCallCommandMessage.parser(); + case SendSignalCommandMessage -> Protocol.SendSignalCommandMessage.parser(); + case RunCommandMessage -> Protocol.RunCommandMessage.parser(); + case AttachInvocationCommandMessage -> Protocol.AttachInvocationCommandMessage.parser(); + case GetInvocationOutputCommandMessage -> Protocol.GetInvocationOutputCommandMessage.parser(); + case CompleteAwakeableCommandMessage -> Protocol.CompleteAwakeableCommandMessage.parser(); + case GetLazyStateCompletionNotificationMessage, + SignalNotificationMessage, + GetLazyStateKeysCompletionNotificationMessage, + GetPromiseCompletionNotificationMessage, + PeekPromiseCompletionNotificationMessage, + CompletePromiseCompletionNotificationMessage, + SleepCompletionNotificationMessage, + CallInvocationIdCompletionNotificationMessage, + CallCompletionNotificationMessage, + RunCompletionNotificationMessage, + AttachInvocationCompletionNotificationMessage, + GetInvocationOutputCompletionNotificationMessage -> + Protocol.NotificationTemplate.parser(); + }; + } + + public short encode() { + return switch (this) { + case StartMessage -> StartMessage_TYPE; + case SuspensionMessage -> SuspensionMessage_TYPE; + case ErrorMessage -> ErrorMessage_TYPE; + case EndMessage -> EndMessage_TYPE; + case ProposeRunCompletionMessage -> ProposeRunCompletionMessage_TYPE; + case InputCommandMessage -> InputCommandMessage_TYPE; + case OutputCommandMessage -> OutputCommandMessage_TYPE; + case GetLazyStateCommandMessage -> GetLazyStateCommandMessage_TYPE; + case GetLazyStateCompletionNotificationMessage -> + GetLazyStateCompletionNotificationMessage_TYPE; + case SetStateCommandMessage -> SetStateCommandMessage_TYPE; + case ClearStateCommandMessage -> ClearStateCommandMessage_TYPE; + case ClearAllStateCommandMessage -> ClearAllStateCommandMessage_TYPE; + case GetLazyStateKeysCommandMessage -> GetLazyStateKeysCommandMessage_TYPE; + case GetLazyStateKeysCompletionNotificationMessage -> + GetLazyStateKeysCompletionNotificationMessage_TYPE; + case GetEagerStateCommandMessage -> GetEagerStateCommandMessage_TYPE; + case GetEagerStateKeysCommandMessage -> GetEagerStateKeysCommandMessage_TYPE; + case GetPromiseCommandMessage -> GetPromiseCommandMessage_TYPE; + case GetPromiseCompletionNotificationMessage -> GetPromiseCompletionNotificationMessage_TYPE; + case PeekPromiseCommandMessage -> PeekPromiseCommandMessage_TYPE; + case PeekPromiseCompletionNotificationMessage -> + PeekPromiseCompletionNotificationMessage_TYPE; + case CompletePromiseCommandMessage -> CompletePromiseCommandMessage_TYPE; + case CompletePromiseCompletionNotificationMessage -> + CompletePromiseCompletionNotificationMessage_TYPE; + case SleepCommandMessage -> SleepCommandMessage_TYPE; + case SleepCompletionNotificationMessage -> SleepCompletionNotificationMessage_TYPE; + case CallCommandMessage -> CallCommandMessage_TYPE; + case CallInvocationIdCompletionNotificationMessage -> + CallInvocationIdCompletionNotificationMessage_TYPE; + case CallCompletionNotificationMessage -> CallCompletionNotificationMessage_TYPE; + case OneWayCallCommandMessage -> OneWayCallCommandMessage_TYPE; + case SendSignalCommandMessage -> SendSignalCommandMessage_TYPE; + case RunCommandMessage -> RunCommandMessage_TYPE; + case RunCompletionNotificationMessage -> RunCompletionNotificationMessage_TYPE; + case AttachInvocationCommandMessage -> AttachInvocationCommandMessage_TYPE; + case AttachInvocationCompletionNotificationMessage -> + AttachInvocationCompletionNotificationMessage_TYPE; + case GetInvocationOutputCommandMessage -> GetInvocationOutputCommandMessage_TYPE; + case GetInvocationOutputCompletionNotificationMessage -> + GetInvocationOutputCompletionNotificationMessage_TYPE; + case CompleteAwakeableCommandMessage -> CompleteAwakeableCommandMessage_TYPE; + case SignalNotificationMessage -> SignalNotificationMessage_TYPE; + }; + } + + public boolean isCommand() { + return switch (this) { + case InputCommandMessage, + GetLazyStateCommandMessage, + OutputCommandMessage, + SetStateCommandMessage, + ClearStateCommandMessage, + ClearAllStateCommandMessage, + GetLazyStateKeysCommandMessage, + GetEagerStateCommandMessage, + GetEagerStateKeysCommandMessage, + GetPromiseCommandMessage, + PeekPromiseCommandMessage, + CompletePromiseCommandMessage, + SleepCommandMessage, + CallCommandMessage, + OneWayCallCommandMessage, + SendSignalCommandMessage, + RunCommandMessage, + AttachInvocationCommandMessage, + GetInvocationOutputCommandMessage, + CompleteAwakeableCommandMessage -> + true; + default -> false; + }; + } + + public boolean isNotification() { + return switch (this) { + case GetLazyStateCompletionNotificationMessage, + SignalNotificationMessage, + GetLazyStateKeysCompletionNotificationMessage, + GetPromiseCompletionNotificationMessage, + PeekPromiseCompletionNotificationMessage, + CompletePromiseCompletionNotificationMessage, + SleepCompletionNotificationMessage, + CallInvocationIdCompletionNotificationMessage, + CallCompletionNotificationMessage, + RunCompletionNotificationMessage, + AttachInvocationCompletionNotificationMessage, + GetInvocationOutputCompletionNotificationMessage -> + true; + default -> false; + }; + } + + public static MessageType decode(short value) throws ProtocolException { + return switch (value) { + case StartMessage_TYPE -> StartMessage; + case SuspensionMessage_TYPE -> SuspensionMessage; + case ErrorMessage_TYPE -> ErrorMessage; + case EndMessage_TYPE -> EndMessage; + case ProposeRunCompletionMessage_TYPE -> ProposeRunCompletionMessage; + case InputCommandMessage_TYPE -> InputCommandMessage; + case OutputCommandMessage_TYPE -> OutputCommandMessage; + case GetLazyStateCommandMessage_TYPE -> GetLazyStateCommandMessage; + case GetLazyStateCompletionNotificationMessage_TYPE -> + GetLazyStateCompletionNotificationMessage; + case SetStateCommandMessage_TYPE -> SetStateCommandMessage; + case ClearStateCommandMessage_TYPE -> ClearStateCommandMessage; + case ClearAllStateCommandMessage_TYPE -> ClearAllStateCommandMessage; + case GetLazyStateKeysCommandMessage_TYPE -> GetLazyStateKeysCommandMessage; + case GetLazyStateKeysCompletionNotificationMessage_TYPE -> + GetLazyStateKeysCompletionNotificationMessage; + case GetEagerStateCommandMessage_TYPE -> GetEagerStateCommandMessage; + case GetEagerStateKeysCommandMessage_TYPE -> GetEagerStateKeysCommandMessage; + case GetPromiseCommandMessage_TYPE -> GetPromiseCommandMessage; + case GetPromiseCompletionNotificationMessage_TYPE -> GetPromiseCompletionNotificationMessage; + case PeekPromiseCommandMessage_TYPE -> PeekPromiseCommandMessage; + case PeekPromiseCompletionNotificationMessage_TYPE -> + PeekPromiseCompletionNotificationMessage; + case CompletePromiseCommandMessage_TYPE -> CompletePromiseCommandMessage; + case CompletePromiseCompletionNotificationMessage_TYPE -> + CompletePromiseCompletionNotificationMessage; + case SleepCommandMessage_TYPE -> SleepCommandMessage; + case SleepCompletionNotificationMessage_TYPE -> SleepCompletionNotificationMessage; + case CallCommandMessage_TYPE -> CallCommandMessage; + case CallInvocationIdCompletionNotificationMessage_TYPE -> + CallInvocationIdCompletionNotificationMessage; + case CallCompletionNotificationMessage_TYPE -> CallCompletionNotificationMessage; + case OneWayCallCommandMessage_TYPE -> OneWayCallCommandMessage; + case SendSignalCommandMessage_TYPE -> SendSignalCommandMessage; + case RunCommandMessage_TYPE -> RunCommandMessage; + case RunCompletionNotificationMessage_TYPE -> RunCompletionNotificationMessage; + case AttachInvocationCommandMessage_TYPE -> AttachInvocationCommandMessage; + case AttachInvocationCompletionNotificationMessage_TYPE -> + AttachInvocationCompletionNotificationMessage; + case GetInvocationOutputCommandMessage_TYPE -> GetInvocationOutputCommandMessage; + case GetInvocationOutputCompletionNotificationMessage_TYPE -> + GetInvocationOutputCompletionNotificationMessage; + case CompleteAwakeableCommandMessage_TYPE -> CompleteAwakeableCommandMessage; + case SignalNotificationMessage_TYPE -> SignalNotificationMessage; + default -> throw ProtocolException.unknownMessageType(value); + }; + } + + public static MessageType fromMessage(MessageLite msg) { + if (msg instanceof Protocol.StartMessage) { + return MessageType.StartMessage; + } else if (msg instanceof Protocol.SuspensionMessage) { + return MessageType.SuspensionMessage; + } else if (msg instanceof Protocol.ErrorMessage) { + return MessageType.ErrorMessage; + } else if (msg instanceof Protocol.EndMessage) { + return MessageType.EndMessage; + } else if (msg instanceof Protocol.ProposeRunCompletionMessage) { + return MessageType.ProposeRunCompletionMessage; + } else if (msg instanceof Protocol.InputCommandMessage) { + return MessageType.InputCommandMessage; + } else if (msg instanceof Protocol.OutputCommandMessage) { + return MessageType.OutputCommandMessage; + } else if (msg instanceof Protocol.GetLazyStateCommandMessage) { + return MessageType.GetLazyStateCommandMessage; + } else if (msg instanceof Protocol.GetLazyStateCompletionNotificationMessage) { + return MessageType.GetLazyStateCompletionNotificationMessage; + } else if (msg instanceof Protocol.SetStateCommandMessage) { + return MessageType.SetStateCommandMessage; + } else if (msg instanceof Protocol.ClearStateCommandMessage) { + return MessageType.ClearStateCommandMessage; + } else if (msg instanceof Protocol.ClearAllStateCommandMessage) { + return MessageType.ClearAllStateCommandMessage; + } else if (msg instanceof Protocol.GetLazyStateKeysCommandMessage) { + return MessageType.GetLazyStateKeysCommandMessage; + } else if (msg instanceof Protocol.GetLazyStateKeysCompletionNotificationMessage) { + return MessageType.GetLazyStateKeysCompletionNotificationMessage; + } else if (msg instanceof Protocol.GetEagerStateCommandMessage) { + return MessageType.GetEagerStateCommandMessage; + } else if (msg instanceof Protocol.GetEagerStateKeysCommandMessage) { + return MessageType.GetEagerStateKeysCommandMessage; + } else if (msg instanceof Protocol.GetPromiseCommandMessage) { + return MessageType.GetPromiseCommandMessage; + } else if (msg instanceof Protocol.GetPromiseCompletionNotificationMessage) { + return MessageType.GetPromiseCompletionNotificationMessage; + } else if (msg instanceof Protocol.PeekPromiseCommandMessage) { + return MessageType.PeekPromiseCommandMessage; + } else if (msg instanceof Protocol.PeekPromiseCompletionNotificationMessage) { + return MessageType.PeekPromiseCompletionNotificationMessage; + } else if (msg instanceof Protocol.CompletePromiseCommandMessage) { + return MessageType.CompletePromiseCommandMessage; + } else if (msg instanceof Protocol.CompletePromiseCompletionNotificationMessage) { + return MessageType.CompletePromiseCompletionNotificationMessage; + } else if (msg instanceof Protocol.SleepCommandMessage) { + return MessageType.SleepCommandMessage; + } else if (msg instanceof Protocol.SleepCompletionNotificationMessage) { + return MessageType.SleepCompletionNotificationMessage; + } else if (msg instanceof Protocol.CallCommandMessage) { + return MessageType.CallCommandMessage; + } else if (msg instanceof Protocol.CallInvocationIdCompletionNotificationMessage) { + return MessageType.CallInvocationIdCompletionNotificationMessage; + } else if (msg instanceof Protocol.CallCompletionNotificationMessage) { + return MessageType.CallCompletionNotificationMessage; + } else if (msg instanceof Protocol.OneWayCallCommandMessage) { + return MessageType.OneWayCallCommandMessage; + } else if (msg instanceof Protocol.SendSignalCommandMessage) { + return MessageType.SendSignalCommandMessage; + } else if (msg instanceof Protocol.RunCommandMessage) { + return MessageType.RunCommandMessage; + } else if (msg instanceof Protocol.RunCompletionNotificationMessage) { + return MessageType.RunCompletionNotificationMessage; + } else if (msg instanceof Protocol.AttachInvocationCommandMessage) { + return MessageType.AttachInvocationCommandMessage; + } else if (msg instanceof Protocol.AttachInvocationCompletionNotificationMessage) { + return MessageType.AttachInvocationCompletionNotificationMessage; + } else if (msg instanceof Protocol.GetInvocationOutputCommandMessage) { + return MessageType.GetInvocationOutputCommandMessage; + } else if (msg instanceof Protocol.GetInvocationOutputCompletionNotificationMessage) { + return MessageType.GetInvocationOutputCompletionNotificationMessage; + } else if (msg instanceof Protocol.CompleteAwakeableCommandMessage) { + return MessageType.CompleteAwakeableCommandMessage; + } else if (msg instanceof Protocol.SignalNotificationMessage) { + return MessageType.SignalNotificationMessage; + } + + throw new IllegalStateException("Unexpected protobuf message"); + } +} diff --git a/sdk-core/src/main/java/dev/restate/sdk/core/statemachine/NotificationId.java b/sdk-core/src/main/java/dev/restate/sdk/core/statemachine/NotificationId.java new file mode 100644 index 000000000..5b3a0a1d4 --- /dev/null +++ b/sdk-core/src/main/java/dev/restate/sdk/core/statemachine/NotificationId.java @@ -0,0 +1,18 @@ +// Copyright (c) 2023 - Restate Software, Inc., Restate GmbH +// +// This file is part of the Restate Java SDK, +// which is released under the MIT license. +// +// You can find a copy of the license in file LICENSE in the root +// directory of this repository or package, or at +// https://github.com/restatedev/sdk-java/blob/main/LICENSE +package dev.restate.sdk.core.statemachine; + +public sealed interface NotificationId { + + record CompletionId(int id) implements NotificationId {} + + record SignalId(int id) implements NotificationId {} + + record SignalName(String name) implements NotificationId {} +} diff --git a/sdk-core/src/main/java/dev/restate/sdk/core/statemachine/NotificationValue.java b/sdk-core/src/main/java/dev/restate/sdk/core/statemachine/NotificationValue.java new file mode 100644 index 000000000..32ffc524e --- /dev/null +++ b/sdk-core/src/main/java/dev/restate/sdk/core/statemachine/NotificationValue.java @@ -0,0 +1,28 @@ +// Copyright (c) 2023 - Restate Software, Inc., Restate GmbH +// +// This file is part of the Restate Java SDK, +// which is released under the MIT license. +// +// You can find a copy of the license in file LICENSE in the root +// directory of this repository or package, or at +// https://github.com/restatedev/sdk-java/blob/main/LICENSE +package dev.restate.sdk.core.statemachine; + +import dev.restate.common.Slice; +import dev.restate.sdk.types.TerminalException; +import java.util.List; + +public sealed interface NotificationValue { + + record Empty() implements NotificationValue { + public static Empty INSTANCE = new Empty(); + } + + record Success(Slice slice) implements NotificationValue {} + + record Failure(TerminalException exception) implements NotificationValue {} + + record StateKeys(List stateKeys) implements NotificationValue {} + + record InvocationId(String invocationId) implements NotificationValue {} +} diff --git a/sdk-core/src/main/java/dev/restate/sdk/core/statemachine/ProcessingState.java b/sdk-core/src/main/java/dev/restate/sdk/core/statemachine/ProcessingState.java new file mode 100644 index 000000000..ea1159ca7 --- /dev/null +++ b/sdk-core/src/main/java/dev/restate/sdk/core/statemachine/ProcessingState.java @@ -0,0 +1,380 @@ +// Copyright (c) 2023 - Restate Software, Inc., Restate GmbH +// +// This file is part of the Restate Java SDK, +// which is released under the MIT license. +// +// You can find a copy of the license in file LICENSE in the root +// directory of this repository or package, or at +// https://github.com/restatedev/sdk-java/blob/main/LICENSE +package dev.restate.sdk.core.statemachine; + +import static dev.restate.sdk.core.statemachine.Util.durationMin; +import static dev.restate.sdk.core.statemachine.Util.sliceToByteString; + +import com.google.protobuf.ByteString; +import com.google.protobuf.MessageLite; +import dev.restate.common.Slice; +import dev.restate.sdk.core.ExceptionUtils; +import dev.restate.sdk.core.ProtocolException; +import dev.restate.sdk.core.generated.protocol.Protocol; +import dev.restate.sdk.core.statemachine.StateMachine.DoProgressResponse; +import dev.restate.sdk.types.AbortedExecutionException; +import dev.restate.sdk.types.RetryPolicy; +import dev.restate.sdk.types.TerminalException; +import java.time.Duration; +import java.util.List; +import java.util.Optional; +import java.util.concurrent.CompletableFuture; +import org.apache.logging.log4j.LogManager; +import org.apache.logging.log4j.Logger; +import org.jspecify.annotations.Nullable; + +final class ProcessingState implements State { + + private static final Logger LOG = LogManager.getLogger(ProcessingState.class); + + private final AsyncResultsState asyncResultsState; + private final RunState runState; + private boolean processingFirstEntry; + + ProcessingState(AsyncResultsState asyncResultsState, RunState runState) { + this.asyncResultsState = asyncResultsState; + this.runState = runState; + this.processingFirstEntry = true; + } + + @Override + public void onNewMessage( + InvocationInput invocationInput, + StateContext stateContext, + CompletableFuture waitForReadyFuture) { + if (invocationInput.header().getType().isNotification()) { + if (!(invocationInput.message() + instanceof Protocol.NotificationTemplate notificationTemplate)) { + throw ProtocolException.unexpectedMessage( + Protocol.NotificationTemplate.class, invocationInput.message()); + } + this.asyncResultsState.enqueue(notificationTemplate); + } else { + throw ProtocolException.unexpectedMessage("notification", invocationInput.message()); + } + } + + @Override + public DoProgressResponse doProgress(List awaitingOn, StateContext stateContext) { + if (awaitingOn.stream().anyMatch(this.asyncResultsState::isHandleCompleted)) { + return DoProgressResponse.AnyCompleted.INSTANCE; + } + + var notificationIds = asyncResultsState.resolveNotificationHandles(awaitingOn); + if (notificationIds.isEmpty()) { + return DoProgressResponse.AnyCompleted.INSTANCE; + } + + if (asyncResultsState.processNextUntilAnyFound(notificationIds)) { + return DoProgressResponse.AnyCompleted.INSTANCE; + } + + Integer maybeRunHandle = runState.tryExecuteRun(awaitingOn); + if (maybeRunHandle != null) { + return new DoProgressResponse.ExecuteRun(maybeRunHandle); + } + + if (stateContext.isInputClosed()) { + if (runState.anyExecuting(awaitingOn)) { + return DoProgressResponse.WaitingPendingRun.INSTANCE; + } + + this.hitSuspended(notificationIds, stateContext); + ExceptionUtils.sneakyThrow(AbortedExecutionException.INSTANCE); + } + + return DoProgressResponse.ReadFromInput.INSTANCE; + } + + @Override + public boolean isCompleted(int handle) { + return this.asyncResultsState.isHandleCompleted(handle); + } + + @Override + public Optional takeNotification(int handle, StateContext stateContext) { + return this.asyncResultsState.takeHandle(handle); + } + + @Override + public int processRunCommand(String name, StateContext stateContext) { + var completionId = stateContext.getJournal().nextCompletionNotificationId(); + var notificationId = new NotificationId.CompletionId(completionId); + + var runCmdBuilder = Protocol.RunCommandMessage.newBuilder().setResultCompletionId(completionId); + if (name != null) { + runCmdBuilder.setName(name); + } + + var notificationHandle = + this.processCompletableCommand( + runCmdBuilder.build(), CommandAccessor.RUN, new int[] {completionId}, stateContext)[0]; + + LOG.trace("Enqueued run notification for {} with id {}.", notificationHandle, notificationId); + runState.insertRunToExecute(notificationHandle); + + return notificationHandle; + } + + @Override + public int processStateGetCommand(String key, StateContext stateContext) { + this.flipFirstProcessingEntry(); + var completionId = stateContext.getJournal().nextCompletionNotificationId(); + var handle = + asyncResultsState.createHandleMapping(new NotificationId.CompletionId(completionId)); + + ByteString keyBytes = ByteString.copyFromUtf8(key); + var eagerStateQuery = stateContext.getEagerState().get(keyBytes); + if (eagerStateQuery == null) { + // Lazy state case + var commandMessage = + Protocol.GetLazyStateCommandMessage.newBuilder() + .setKey(keyBytes) + .setResultCompletionId(completionId) + .build(); + stateContext + .getJournal() + .commandTransition( + CommandAccessor.GET_LAZY_STATE.getName(commandMessage), commandMessage); + stateContext.writeMessageOut(commandMessage); + + return handle; + } + + // Eager state case + var commandMessageBuilder = Protocol.GetEagerStateCommandMessage.newBuilder().setKey(keyBytes); + if (eagerStateQuery instanceof NotificationValue.Success) { + commandMessageBuilder.setValue( + Protocol.Value.newBuilder() + .setContent(sliceToByteString(((NotificationValue.Success) eagerStateQuery).slice())) + .build()); + } else { + commandMessageBuilder.setVoid(Protocol.Void.getDefaultInstance()); + } + var commandMessage = commandMessageBuilder.build(); + stateContext + .getJournal() + .commandTransition(CommandAccessor.GET_EAGER_STATE.getName(commandMessage), commandMessage); + + asyncResultsState.insertReady(new NotificationId.CompletionId(completionId), eagerStateQuery); + stateContext.writeMessageOut(commandMessage); + + return handle; + } + + @Override + public int processStateGetKeysCommand(StateContext stateContext) { + this.flipFirstProcessingEntry(); + var completionId = stateContext.getJournal().nextCompletionNotificationId(); + var handle = + asyncResultsState.createHandleMapping(new NotificationId.CompletionId(completionId)); + + var eagerStateQuery = stateContext.getEagerState().keys(); + if (eagerStateQuery == null) { + // Lazy state case + var commandMessage = + Protocol.GetLazyStateKeysCommandMessage.newBuilder() + .setResultCompletionId(completionId) + .build(); + stateContext + .getJournal() + .commandTransition( + CommandAccessor.GET_LAZY_STATE_KEYS.getName(commandMessage), commandMessage); + stateContext.writeMessageOut(commandMessage); + + return handle; + } + + // Eager state case + var commandMessage = + Protocol.GetEagerStateKeysCommandMessage.newBuilder() + .setValue(Protocol.StateKeys.newBuilder().addAllKeys(eagerStateQuery).build()) + .build(); + stateContext + .getJournal() + .commandTransition( + CommandAccessor.GET_EAGER_STATE_KEYS.getName(commandMessage), commandMessage); + + asyncResultsState.insertReady( + new NotificationId.CompletionId(completionId), + new NotificationValue.StateKeys( + eagerStateQuery.stream().map(ByteString::toStringUtf8).toList())); + stateContext.writeMessageOut(commandMessage); + + return handle; + } + + @Override + public void processNonCompletableCommand( + E commandMessage, CommandAccessor commandAccessor, StateContext stateContext) { + stateContext + .getJournal() + .commandTransition(commandAccessor.getName(commandMessage), commandMessage); + this.flipFirstProcessingEntry(); + + stateContext.writeMessageOut(commandMessage); + } + + @Override + public int[] processCompletableCommand( + E commandMessage, + CommandAccessor commandAccessor, + int[] completionIds, + StateContext stateContext) { + stateContext + .getJournal() + .commandTransition(commandAccessor.getName(commandMessage), commandMessage); + this.flipFirstProcessingEntry(); + + stateContext.writeMessageOut(commandMessage); + + int[] handles = new int[completionIds.length]; + for (int i = 0; i < handles.length; i++) { + handles[i] = + asyncResultsState.createHandleMapping(new NotificationId.CompletionId(completionIds[i])); + } + + return handles; + } + + @Override + public int createSignalHandle(NotificationId notificationId, StateContext stateContext) { + return asyncResultsState.createHandleMapping(notificationId); + } + + @Override + public void proposeRunCompletion(int handle, Slice value, StateContext stateContext) { + var notificationId = asyncResultsState.mustResolveNotificationHandle(handle); + if (!(notificationId instanceof NotificationId.CompletionId)) { + throw ProtocolException.badRunNotificationId(notificationId); + } + + runState.notifyExecuted(handle); + + proposeRunCompletion( + handle, + Protocol.ProposeRunCompletionMessage.newBuilder() + .setResultCompletionId(((NotificationId.CompletionId) notificationId).id()) + .setValue(sliceToByteString(value)), + stateContext); + } + + @Override + public void proposeRunCompletion( + int handle, + Throwable runException, + Duration attemptDuration, + @Nullable RetryPolicy retryPolicy, + StateContext stateContext) { + var notificationId = asyncResultsState.mustResolveNotificationHandle(handle); + if (!(notificationId instanceof NotificationId.CompletionId)) { + throw ProtocolException.badRunNotificationId(notificationId); + } + + runState.notifyExecuted(handle); + + TerminalException toWrite; + if (runException instanceof TerminalException) { + LOG.trace("The run completed with a terminal exception"); + toWrite = (TerminalException) runException; + } else { + toWrite = + this.rethrowOrConvertToTerminal(runException, attemptDuration, retryPolicy, stateContext); + } + + proposeRunCompletion( + handle, + Protocol.ProposeRunCompletionMessage.newBuilder() + .setResultCompletionId(((NotificationId.CompletionId) notificationId).id()) + .setFailure(Util.toProtocolFailure(toWrite)), + stateContext); + } + + private void proposeRunCompletion( + int handle, + Protocol.ProposeRunCompletionMessage.Builder messageBuilder, + StateContext stateContext) { + if (!stateContext.maybeWriteMessageOut(messageBuilder.build())) { + LOG.warn( + "Cannot write proposed completion for run with handle {} because the output stream was already closed.", + handle); + } + } + + private Duration getDurationSinceLastStoredEntry(StateContext stateContext) { + // We need to check if this is the first entry we try to commit after replay, and only in this + // case we need to return the info we got from the start message + // + // Moreover, when the retry count is == 0, the durationSinceLastStoredEntry might not be zero. + // In fact, in that case the duration is the interval between the previously stored entry and + // the time to start/resume the invocation. + // For the sake of entry retries though, we're not interested in that time elapsed, so we 0 it + // here for simplicity of the downstream consumer (the retry policy). + return this.processingFirstEntry + && stateContext.getStartInfo().retryCountSinceLastStoredEntry() > 0 + ? stateContext.getStartInfo().durationSinceLastStoredEntry() + : Duration.ZERO; + } + + private int getRetryCountSinceLastStoredEntry(StateContext stateContext) { + // We need to check if this is the first entry we try to commit after replay, and only in this + // case we need to return the info we got from the start message + return this.processingFirstEntry + ? stateContext.getStartInfo().retryCountSinceLastStoredEntry() + : 0; + } + + // This function rethrows the exception if a retry needs to happen. + private TerminalException rethrowOrConvertToTerminal( + Throwable runException, + Duration attemptDuration, + @Nullable RetryPolicy retryPolicy, + StateContext stateContext) { + if (retryPolicy == null) { + LOG.trace("The run completed with an exception and no retry policy was provided"); + // Default behavior is always retry + ExceptionUtils.sneakyThrow(runException); + } + + Duration retryLoopDuration = + this.getDurationSinceLastStoredEntry(stateContext).plus(attemptDuration); + int retryCount = this.getRetryCountSinceLastStoredEntry(stateContext) + 1; + + if ((retryPolicy.getMaxAttempts() != null && retryPolicy.getMaxAttempts() <= retryCount) + || (retryPolicy.getMaxDuration() != null + && retryPolicy.getMaxDuration().compareTo(retryLoopDuration) <= 0)) { + LOG.trace("The run completed with a retryable exception, but all attempts were exhausted"); + // We need to convert it to TerminalException + return new TerminalException(runException.toString()); + } + + // Compute next retry delay and throw it! + Duration nextComputedDelay = + retryPolicy + .getInitialDelay() + .multipliedBy((long) Math.pow(retryPolicy.getExponentiationFactor(), retryCount)); + Duration nextRetryDelay = + retryPolicy.getMaxDelay() != null + ? durationMin(retryPolicy.getMaxDelay(), nextComputedDelay) + : nextComputedDelay; + + this.hitError(runException, nextRetryDelay, stateContext); + ExceptionUtils.sneakyThrow(runException); + return null; + } + + private void flipFirstProcessingEntry() { + this.processingFirstEntry = false; + } + + @Override + public InvocationState getInvocationState() { + return InvocationState.PROCESSING; + } +} diff --git a/sdk-core/src/main/java/dev/restate/sdk/core/statemachine/ReplayingState.java b/sdk-core/src/main/java/dev/restate/sdk/core/statemachine/ReplayingState.java new file mode 100644 index 000000000..14d46d3b6 --- /dev/null +++ b/sdk-core/src/main/java/dev/restate/sdk/core/statemachine/ReplayingState.java @@ -0,0 +1,284 @@ +// Copyright (c) 2023 - Restate Software, Inc., Restate GmbH +// +// This file is part of the Restate Java SDK, +// which is released under the MIT license. +// +// You can find a copy of the license in file LICENSE in the root +// directory of this repository or package, or at +// https://github.com/restatedev/sdk-java/blob/main/LICENSE +package dev.restate.sdk.core.statemachine; + +import static dev.restate.sdk.core.statemachine.Util.byteStringToSlice; + +import com.google.protobuf.ByteString; +import com.google.protobuf.MessageLite; +import dev.restate.sdk.core.ExceptionUtils; +import dev.restate.sdk.core.ProtocolException; +import dev.restate.sdk.core.generated.protocol.Protocol; +import dev.restate.sdk.core.statemachine.StateMachine.DoProgressResponse; +import dev.restate.sdk.types.AbortedExecutionException; +import java.util.*; +import java.util.concurrent.CompletableFuture; +import org.apache.logging.log4j.LogManager; +import org.apache.logging.log4j.Logger; + +final class ReplayingState implements State { + + private static final Logger LOG = LogManager.getLogger(ReplayingState.class); + + private final Deque commandsToProcess; + private final AsyncResultsState asyncResultsState; + private final RunState runState; + + ReplayingState(Deque commandsToProcess, AsyncResultsState asyncResultsState) { + this.commandsToProcess = commandsToProcess; + this.asyncResultsState = asyncResultsState; + this.runState = new RunState(); + } + + @Override + public void onNewMessage( + InvocationInput invocationInput, + StateContext stateContext, + CompletableFuture waitForReadyFuture) { + if (invocationInput.header().getType().isNotification()) { + if (!(invocationInput.message() + instanceof Protocol.NotificationTemplate notificationTemplate)) { + throw ProtocolException.unexpectedMessage( + Protocol.NotificationTemplate.class, invocationInput.message()); + } + this.asyncResultsState.enqueue(notificationTemplate); + } else { + throw ProtocolException.unexpectedMessage("notification", invocationInput.message()); + } + } + + @Override + public DoProgressResponse doProgress(List awaitingOn, StateContext stateContext) { + if (awaitingOn.stream().anyMatch(this.asyncResultsState::isHandleCompleted)) { + return DoProgressResponse.AnyCompleted.INSTANCE; + } + + var notificationIds = asyncResultsState.resolveNotificationHandles(awaitingOn); + if (notificationIds.isEmpty()) { + return DoProgressResponse.AnyCompleted.INSTANCE; + } + + if (asyncResultsState.processNextUntilAnyFound(notificationIds)) { + return DoProgressResponse.AnyCompleted.INSTANCE; + } + + if (stateContext.isInputClosed()) { + this.hitSuspended(notificationIds, stateContext); + ExceptionUtils.sneakyThrow(AbortedExecutionException.INSTANCE); + } + + return DoProgressResponse.ReadFromInput.INSTANCE; + } + + @Override + public boolean isCompleted(int handle) { + return this.asyncResultsState.isHandleCompleted(handle); + } + + @Override + public Optional takeNotification(int handle, StateContext stateContext) { + return this.asyncResultsState.takeHandle(handle); + } + + @Override + public StateMachine.Input processInputCommand(StateContext stateContext) { + Protocol.InputCommandMessage inputCommandMessage = + processNonCompletableCommandInner( + Protocol.InputCommandMessage.getDefaultInstance(), CommandAccessor.INPUT, stateContext); + + //noinspection unchecked + return new StateMachine.Input( + new InvocationIdImpl(stateContext.getStartInfo().debugId()), + byteStringToSlice(inputCommandMessage.getValue().getContent()), + Map.ofEntries( + inputCommandMessage.getHeadersList().stream() + .map(h -> Map.entry(h.getKey(), h.getValue())) + .toArray(Map.Entry[]::new)), + stateContext.getStartInfo().objectKey()); + } + + @Override + public int processRunCommand(String name, StateContext stateContext) { + var completionId = stateContext.getJournal().nextCompletionNotificationId(); + var notificationId = new NotificationId.CompletionId(completionId); + + var runCmdBuilder = Protocol.RunCommandMessage.newBuilder().setResultCompletionId(completionId); + if (name != null) { + runCmdBuilder.setName(name); + } + + var notificationHandle = + this.processCompletableCommand( + runCmdBuilder.build(), CommandAccessor.RUN, new int[] {completionId}, stateContext)[0]; + + if (asyncResultsState.nonDeterministicFindId(notificationId)) { + LOG.trace( + "Found notification for {} with id {} while replaying, the run closure won't be executed.", + notificationHandle, + notificationId); + } else { + LOG.trace( + "Run notification for {} with id {} not found while replaying, so we enqueue the run to be executed later.", + notificationHandle, + notificationId); + runState.insertRunToExecute(notificationHandle); + } + + return notificationHandle; + } + + @Override + public void processNonCompletableCommand( + E commandMessage, CommandAccessor commandAccessor, StateContext stateContext) { + processNonCompletableCommandInner(commandMessage, commandAccessor, stateContext); + } + + private E processNonCompletableCommandInner( + E commandMessage, CommandAccessor commandAccessor, StateContext stateContext) { + stateContext + .getJournal() + .commandTransition(commandAccessor.getName(commandMessage), commandMessage); + + MessageLite actual = takeNextCommandToProcess(); + commandAccessor.checkEntryHeader(commandMessage, actual); + + afterProcessingCommand(stateContext); + + // CheckEntryHeader checks that the class type + return (E) actual; + } + + @Override + public int[] processCompletableCommand( + E commandMessage, + CommandAccessor commandAccessor, + int[] completionIds, + StateContext stateContext) { + stateContext + .getJournal() + .commandTransition(commandAccessor.getName(commandMessage), commandMessage); + MessageLite actual = takeNextCommandToProcess(); + commandAccessor.checkEntryHeader(commandMessage, actual); + + int[] handles = new int[completionIds.length]; + for (int i = 0; i < handles.length; i++) { + handles[i] = + asyncResultsState.createHandleMapping(new NotificationId.CompletionId(completionIds[i])); + } + + afterProcessingCommand(stateContext); + + return handles; + } + + @Override + public int processStateGetCommand(String key, StateContext stateContext) { + var completionId = stateContext.getJournal().nextCompletionNotificationId(); + var handle = + asyncResultsState.createHandleMapping(new NotificationId.CompletionId(completionId)); + + stateContext + .getJournal() + .commandTransition("", Protocol.GetEagerStateCommandMessage.getDefaultInstance()); + MessageLite actual = takeNextCommandToProcess(); + + if (actual instanceof Protocol.GetEagerStateCommandMessage eagerStateCommandMessage) { + CommandAccessor.GET_EAGER_STATE.checkEntryHeader( + Protocol.GetEagerStateCommandMessage.newBuilder() + .setKey(ByteString.copyFromUtf8(key)) + .build(), + actual); + + asyncResultsState.insertReady( + new NotificationId.CompletionId(completionId), + switch (eagerStateCommandMessage.getResultCase()) { + case VOID -> NotificationValue.Empty.INSTANCE; + case VALUE -> + new NotificationValue.Success( + byteStringToSlice(eagerStateCommandMessage.getValue().getContent())); + case RESULT_NOT_SET -> + throw ProtocolException.commandMissingField( + Protocol.GetEagerStateCommandMessage.class, "result"); + }); + + } else if (actual instanceof Protocol.GetLazyStateCommandMessage) { + CommandAccessor.GET_LAZY_STATE.checkEntryHeader( + Protocol.GetLazyStateCommandMessage.newBuilder() + .setKey(ByteString.copyFromUtf8(key)) + .setResultCompletionId(completionId) + .build(), + actual); + } else { + throw ProtocolException.unexpectedMessage("get state", actual); + } + + afterProcessingCommand(stateContext); + + return handle; + } + + @Override + public int processStateGetKeysCommand(StateContext stateContext) { + var completionId = stateContext.getJournal().nextCompletionNotificationId(); + var handle = + asyncResultsState.createHandleMapping(new NotificationId.CompletionId(completionId)); + + stateContext + .getJournal() + .commandTransition("", Protocol.GetEagerStateKeysCommandMessage.getDefaultInstance()); + MessageLite actual = takeNextCommandToProcess(); + + if (actual instanceof Protocol.GetEagerStateKeysCommandMessage eagerStateCommandMessage) { + CommandAccessor.GET_EAGER_STATE_KEYS.checkEntryHeader( + Protocol.GetEagerStateKeysCommandMessage.getDefaultInstance(), actual); + + asyncResultsState.insertReady( + new NotificationId.CompletionId(completionId), + new NotificationValue.StateKeys( + eagerStateCommandMessage.getValue().getKeysList().stream() + .map(ByteString::toStringUtf8) + .toList())); + } else if (actual instanceof Protocol.GetLazyStateKeysCommandMessage) { + CommandAccessor.GET_LAZY_STATE_KEYS.checkEntryHeader( + Protocol.GetLazyStateKeysCommandMessage.newBuilder() + .setResultCompletionId(completionId) + .build(), + actual); + } else { + throw ProtocolException.unexpectedMessage("get state keys", actual); + } + + afterProcessingCommand(stateContext); + + return handle; + } + + @Override + public int createSignalHandle(NotificationId notificationId, StateContext stateContext) { + return asyncResultsState.createHandleMapping(notificationId); + } + + private void afterProcessingCommand(StateContext stateContext) { + if (commandsToProcess.isEmpty()) { + stateContext.getStateHolder().transition(new ProcessingState(asyncResultsState, runState)); + } + } + + private MessageLite takeNextCommandToProcess() { + if (commandsToProcess.isEmpty()) { + throw ProtocolException.commandsToProcessIsEmpty(); + } + return commandsToProcess.removeFirst(); + } + + @Override + public InvocationState getInvocationState() { + return InvocationState.REPLAYING; + } +} diff --git a/sdk-core/src/main/java/dev/restate/sdk/core/statemachine/RunState.java b/sdk-core/src/main/java/dev/restate/sdk/core/statemachine/RunState.java new file mode 100644 index 000000000..4bf7e0745 --- /dev/null +++ b/sdk-core/src/main/java/dev/restate/sdk/core/statemachine/RunState.java @@ -0,0 +1,48 @@ +// Copyright (c) 2023 - Restate Software, Inc., Restate GmbH +// +// This file is part of the Restate Java SDK, +// which is released under the MIT license. +// +// You can find a copy of the license in file LICENSE in the root +// directory of this repository or package, or at +// https://github.com/restatedev/sdk-java/blob/main/LICENSE +package dev.restate.sdk.core.statemachine; + +import java.util.Collection; +import java.util.HashSet; +import java.util.Set; +import org.jspecify.annotations.Nullable; + +final class RunState { + private final Set toExecute = new HashSet<>(); + private final Set executing = new HashSet<>(); + + public void insertRunToExecute(int handle) { + toExecute.add(handle); + } + + public @Nullable Integer tryExecuteRun(Collection anyHandle) { + for (int maybeRun : anyHandle) { + if (toExecute.contains(maybeRun)) { + toExecute.remove(maybeRun); + executing.add(maybeRun); + return maybeRun; + } + } + return null; + } + + public boolean anyExecuting(Collection anyHandle) { + for (int handle : anyHandle) { + if (executing.contains(handle)) { + return true; + } + } + return false; + } + + public void notifyExecuted(int executed) { + toExecute.remove(executed); + executing.remove(executed); + } +} diff --git a/sdk-core/src/main/java/dev/restate/sdk/core/statemachine/ServiceProtocol.java b/sdk-core/src/main/java/dev/restate/sdk/core/statemachine/ServiceProtocol.java new file mode 100644 index 000000000..ee5563ac5 --- /dev/null +++ b/sdk-core/src/main/java/dev/restate/sdk/core/statemachine/ServiceProtocol.java @@ -0,0 +1,64 @@ +// Copyright (c) 2023 - Restate Software, Inc., Restate GmbH +// +// This file is part of the Restate Java SDK, +// which is released under the MIT license. +// +// You can find a copy of the license in file LICENSE in the root +// directory of this repository or package, or at +// https://github.com/restatedev/sdk-java/blob/main/LICENSE +package dev.restate.sdk.core.statemachine; + +import dev.restate.sdk.core.generated.protocol.Protocol; +import java.util.Objects; + +public class ServiceProtocol { + public static final Protocol.ServiceProtocolVersion MIN_SERVICE_PROTOCOL_VERSION = + Protocol.ServiceProtocolVersion.V4; + public static final Protocol.ServiceProtocolVersion MAX_SERVICE_PROTOCOL_VERSION = + Protocol.ServiceProtocolVersion.V4; + + static final String CONTENT_TYPE = "content-type"; + + static Protocol.ServiceProtocolVersion parseServiceProtocolVersion(String version) { + if (version == null) { + return Protocol.ServiceProtocolVersion.SERVICE_PROTOCOL_VERSION_UNSPECIFIED; + } + version = version.trim(); + + if (version.equals("application/vnd.restate.invocation.v1")) { + return Protocol.ServiceProtocolVersion.V1; + } + if (version.equals("application/vnd.restate.invocation.v2")) { + return Protocol.ServiceProtocolVersion.V2; + } + if (version.equals("application/vnd.restate.invocation.v3")) { + return Protocol.ServiceProtocolVersion.V3; + } + if (version.equals("application/vnd.restate.invocation.v4")) { + return Protocol.ServiceProtocolVersion.V4; + } + return Protocol.ServiceProtocolVersion.SERVICE_PROTOCOL_VERSION_UNSPECIFIED; + } + + static String serviceProtocolVersionToHeaderValue(Protocol.ServiceProtocolVersion version) { + if (Objects.requireNonNull(version) == Protocol.ServiceProtocolVersion.V1) { + return "application/vnd.restate.invocation.v1"; + } + if (Objects.requireNonNull(version) == Protocol.ServiceProtocolVersion.V2) { + return "application/vnd.restate.invocation.v2"; + } + if (Objects.requireNonNull(version) == Protocol.ServiceProtocolVersion.V3) { + return "application/vnd.restate.invocation.v3"; + } + if (Objects.requireNonNull(version) == Protocol.ServiceProtocolVersion.V4) { + return "application/vnd.restate.invocation.v4"; + } + throw new IllegalArgumentException( + String.format("Service protocol version '%s' has no header value", version.getNumber())); + } + + static boolean isSupported(Protocol.ServiceProtocolVersion serviceProtocolVersion) { + return MIN_SERVICE_PROTOCOL_VERSION.getNumber() <= serviceProtocolVersion.getNumber() + && serviceProtocolVersion.getNumber() <= MAX_SERVICE_PROTOCOL_VERSION.getNumber(); + } +} diff --git a/sdk-common/src/main/java/dev/restate/sdk/common/syscalls/ExitSideEffectSyscallCallback.java b/sdk-core/src/main/java/dev/restate/sdk/core/statemachine/StartInfo.java similarity index 52% rename from sdk-common/src/main/java/dev/restate/sdk/common/syscalls/ExitSideEffectSyscallCallback.java rename to sdk-core/src/main/java/dev/restate/sdk/core/statemachine/StartInfo.java index 43d0ed893..ecd61da22 100644 --- a/sdk-common/src/main/java/dev/restate/sdk/common/syscalls/ExitSideEffectSyscallCallback.java +++ b/sdk-core/src/main/java/dev/restate/sdk/core/statemachine/StartInfo.java @@ -6,13 +6,15 @@ // You can find a copy of the license in file LICENSE in the root // directory of this repository or package, or at // https://github.com/restatedev/sdk-java/blob/main/LICENSE -package dev.restate.sdk.common.syscalls; +package dev.restate.sdk.core.statemachine; -import dev.restate.sdk.common.TerminalException; -import java.nio.ByteBuffer; +import com.google.protobuf.ByteString; +import java.time.Duration; -public interface ExitSideEffectSyscallCallback extends SyscallCallback { - - /** This is user failure. */ - void onFailure(TerminalException t); -} +record StartInfo( + ByteString id, + String debugId, + String objectKey, + int entriesToReplay, + int retryCountSinceLastStoredEntry, + Duration durationSinceLastStoredEntry) {} diff --git a/sdk-core/src/main/java/dev/restate/sdk/core/statemachine/State.java b/sdk-core/src/main/java/dev/restate/sdk/core/statemachine/State.java new file mode 100644 index 000000000..cb381388b --- /dev/null +++ b/sdk-core/src/main/java/dev/restate/sdk/core/statemachine/State.java @@ -0,0 +1,159 @@ +// Copyright (c) 2023 - Restate Software, Inc., Restate GmbH +// +// This file is part of the Restate Java SDK, +// which is released under the MIT license. +// +// You can find a copy of the license in file LICENSE in the root +// directory of this repository or package, or at +// https://github.com/restatedev/sdk-java/blob/main/LICENSE +package dev.restate.sdk.core.statemachine; + +import com.google.protobuf.MessageLite; +import dev.restate.common.Slice; +import dev.restate.sdk.core.ProtocolException; +import dev.restate.sdk.core.generated.protocol.Protocol; +import dev.restate.sdk.types.RetryPolicy; +import java.time.Duration; +import java.util.Collection; +import java.util.List; +import java.util.Optional; +import java.util.concurrent.CompletableFuture; +import org.apache.logging.log4j.LogManager; +import org.apache.logging.log4j.Logger; +import org.jspecify.annotations.Nullable; + +sealed interface State + permits ClosedState, + ProcessingState, + ReplayingState, + WaitingReplayEntriesState, + WaitingStartState { + + Logger LOG = LogManager.getLogger(State.class); + + default void onNewMessage( + InvocationInput invocationInput, + StateContext stateContext, + CompletableFuture waitForReadyFuture) { + throw ProtocolException.badState(this); + } + + default StateMachine.DoProgressResponse doProgress( + List anyHandle, StateContext stateContext) { + throw ProtocolException.badState(this); + } + + default boolean isCompleted(int handle) { + throw ProtocolException.badState(this); + } + + default Optional takeNotification(int handle, StateContext stateContext) { + throw ProtocolException.badState(this); + } + + default StateMachine.@Nullable Input processInputCommand(StateContext stateContext) { + throw ProtocolException.badState(this); + } + + default int processStateGetCommand(String key, StateContext stateContext) { + throw ProtocolException.badState(this); + } + + default int processStateGetKeysCommand(StateContext stateContext) { + throw ProtocolException.badState(this); + } + + default void processNonCompletableCommand( + E commandMessage, CommandAccessor commandAccessor, StateContext stateContext) { + throw ProtocolException.badState(this); + } + + default int[] processCompletableCommand( + E commandMessage, + CommandAccessor commandAccessor, + int[] completionIds, + StateContext stateContext) { + throw ProtocolException.badState(this); + } + + default int createSignalHandle(NotificationId notificationId, StateContext stateContext) { + throw ProtocolException.badState(this); + } + + default int processRunCommand(String name, StateContext stateContext) { + throw ProtocolException.badState(this); + } + + default void proposeRunCompletion(int handle, Slice value, StateContext stateContext) { + LOG.warn( + "Going to ignore proposed run completion with handle {} because the state machine is not in processing state.", + handle); + } + + default void proposeRunCompletion( + int handle, + Throwable exception, + Duration attemptDuration, + @Nullable RetryPolicy retryPolicy, + StateContext stateContext) { + LOG.warn( + "Going to ignore proposed run completion with handle {} because the state machine is not in processing state.", + handle); + } + + default void hitError( + Throwable throwable, @Nullable Duration nextRetryDelay, StateContext stateContext) { + LOG.warn("Invocation failed", throwable); + + var errorMessage = + Util.toErrorMessage( + throwable, + stateContext.getJournal().getCommandIndex(), + stateContext.getJournal().getCurrentEntryName(), + stateContext.getJournal().getCurrentEntryTy()); + if (nextRetryDelay != null) { + errorMessage = errorMessage.toBuilder().setNextRetryDelay(nextRetryDelay.toMillis()).build(); + } + + stateContext.maybeWriteMessageOut(errorMessage); + stateContext.getStateHolder().transition(new ClosedState()); + + stateContext.closeOutputSubscriber(); + } + + default void hitSuspended(Collection awaitingOn, StateContext stateContext) { + LOG.info("Invocation suspended awaiting on {}", awaitingOn); + + var suspensionMessageBuilder = Protocol.SuspensionMessage.newBuilder(); + for (var notificationId : awaitingOn) { + if (notificationId instanceof NotificationId.CompletionId completionId) { + suspensionMessageBuilder.addWaitingCompletions(completionId.id()); + } else if (notificationId instanceof NotificationId.SignalId signalId) { + suspensionMessageBuilder.addWaitingSignals(signalId.id()); + } else if (notificationId instanceof NotificationId.SignalName signalName) { + suspensionMessageBuilder.addWaitingNamedSignals(signalName.name()); + } + } + + stateContext.maybeWriteMessageOut(suspensionMessageBuilder.build()); + stateContext.getStateHolder().transition(new ClosedState()); + + stateContext.closeOutputSubscriber(); + } + + default void end(StateContext stateContext) { + LOG.info("Invocation ended"); + + stateContext.writeMessageOut(Protocol.EndMessage.getDefaultInstance()); + stateContext.getStateHolder().transition(new ClosedState()); + + stateContext.closeOutputSubscriber(); + } + + default void onInputClosed(StateContext stateContext) { + LOG.trace("Marking input closed"); + stateContext.markInputClosed(); + } + + InvocationState getInvocationState(); +} diff --git a/sdk-core/src/main/java/dev/restate/sdk/core/statemachine/StateContext.java b/sdk-core/src/main/java/dev/restate/sdk/core/statemachine/StateContext.java new file mode 100644 index 000000000..cbfcac20c --- /dev/null +++ b/sdk-core/src/main/java/dev/restate/sdk/core/statemachine/StateContext.java @@ -0,0 +1,94 @@ +// Copyright (c) 2023 - Restate Software, Inc., Restate GmbH +// +// This file is part of the Restate Java SDK, +// which is released under the MIT license. +// +// You can find a copy of the license in file LICENSE in the root +// directory of this repository or package, or at +// https://github.com/restatedev/sdk-java/blob/main/LICENSE +package dev.restate.sdk.core.statemachine; + +import com.google.protobuf.MessageLite; +import dev.restate.sdk.core.EndpointRequestHandler; +import java.util.Objects; +import java.util.concurrent.Flow; + +final class StateContext { + + private final StateHolder stateHolder; + private final Journal journal; + private EagerState eagerState; + private transient StartInfo startInfo; + private boolean inputClosed; + private Flow.Subscriber outputSubscriber; + + StateContext(EndpointRequestHandler.LoggingContextSetter loggingContextSetter) { + this.stateHolder = new StateHolder(loggingContextSetter); + this.journal = new Journal(); + this.inputClosed = false; + } + + public State getCurrentState() { + return stateHolder.getState(); + } + + public StateHolder getStateHolder() { + return stateHolder; + } + + public Journal getJournal() { + return journal; + } + + public StateContext setEagerState(EagerState eagerState) { + this.eagerState = eagerState; + return this; + } + + public StateContext setStartInfo(StartInfo startInfo) { + this.startInfo = startInfo; + return this; + } + + EagerState getEagerState() { + return Objects.requireNonNull(eagerState, "The state machine should be initialized"); + } + + StartInfo getStartInfo() { + return Objects.requireNonNull(startInfo, "The state machine should be initialized"); + } + + public void markInputClosed() { + this.inputClosed = true; + } + + public boolean isInputClosed() { + return this.inputClosed; + } + + public void writeMessageOut(MessageLite msg) { + Objects.requireNonNull( + this.outputSubscriber, + "Output subscriber should be configured before running the state machine") + .onNext(msg); + } + + public boolean maybeWriteMessageOut(MessageLite msg) { + if (this.outputSubscriber != null) { + this.outputSubscriber.onNext(msg); + return true; + } + return false; + } + + public void closeOutputSubscriber() { + if (this.outputSubscriber != null) { + this.outputSubscriber.onComplete(); + this.outputSubscriber = null; + } + } + + public void registerOutputSubscriber(Flow.Subscriber outputSubscriber) { + this.outputSubscriber = outputSubscriber; + } +} diff --git a/sdk-core/src/main/java/dev/restate/sdk/core/statemachine/StateHolder.java b/sdk-core/src/main/java/dev/restate/sdk/core/statemachine/StateHolder.java new file mode 100644 index 000000000..3d58f5a8e --- /dev/null +++ b/sdk-core/src/main/java/dev/restate/sdk/core/statemachine/StateHolder.java @@ -0,0 +1,38 @@ +// Copyright (c) 2023 - Restate Software, Inc., Restate GmbH +// +// This file is part of the Restate Java SDK, +// which is released under the MIT license. +// +// You can find a copy of the license in file LICENSE in the root +// directory of this repository or package, or at +// https://github.com/restatedev/sdk-java/blob/main/LICENSE +package dev.restate.sdk.core.statemachine; + +import dev.restate.sdk.core.EndpointRequestHandler; +import org.apache.logging.log4j.LogManager; +import org.apache.logging.log4j.Logger; + +final class StateHolder { + + Logger LOG = LogManager.getLogger(StateHolder.class); + + private State state; + private final EndpointRequestHandler.LoggingContextSetter loggingContextSetter; + + StateHolder(EndpointRequestHandler.LoggingContextSetter loggingContextSetter) { + this.loggingContextSetter = loggingContextSetter; + this.state = new WaitingStartState(); + } + + State getState() { + return state; + } + + void transition(State state) { + this.state = state; + LOG.debug("Transitioning state machine to {}", state.getInvocationState()); + this.loggingContextSetter.set( + EndpointRequestHandler.LoggingContextSetter.INVOCATION_STATUS_KEY, + state.getInvocationState().toString()); + } +} diff --git a/sdk-core/src/main/java/dev/restate/sdk/core/statemachine/StateMachine.java b/sdk-core/src/main/java/dev/restate/sdk/core/statemachine/StateMachine.java new file mode 100644 index 000000000..a9ee2eb1a --- /dev/null +++ b/sdk-core/src/main/java/dev/restate/sdk/core/statemachine/StateMachine.java @@ -0,0 +1,151 @@ +// Copyright (c) 2023 - Restate Software, Inc., Restate GmbH +// +// This file is part of the Restate Java SDK, +// which is released under the MIT license. +// +// You can find a copy of the license in file LICENSE in the root +// directory of this repository or package, or at +// https://github.com/restatedev/sdk-java/blob/main/LICENSE +package dev.restate.sdk.core.statemachine; + +import dev.restate.common.Slice; +import dev.restate.common.Target; +import dev.restate.sdk.core.EndpointRequestHandler; +import dev.restate.sdk.endpoint.HeadersAccessor; +import dev.restate.sdk.types.*; +import java.time.Duration; +import java.util.Collection; +import java.util.List; +import java.util.Map; +import java.util.Optional; +import java.util.concurrent.CompletableFuture; +import java.util.concurrent.Flow; +import org.jspecify.annotations.Nullable; + +/** + * More or less same as the VM trait + */ +public interface StateMachine extends Flow.Processor { + + static StateMachine init( + HeadersAccessor headersAccessor, + EndpointRequestHandler.LoggingContextSetter loggingContextSetter) { + return new StateMachineImpl(headersAccessor, loggingContextSetter); + } + + // --- Response metadata + + String getResponseContentType(); + + // --- Execution starting point + + CompletableFuture waitForReady(); + + // --- Await next input + + CompletableFuture waitNextInputSignal(); + + // --- Async results + + sealed interface DoProgressResponse { + record AnyCompleted() implements DoProgressResponse { + static AnyCompleted INSTANCE = new AnyCompleted(); + } + + record ReadFromInput() implements DoProgressResponse { + static ReadFromInput INSTANCE = new ReadFromInput(); + } + + record ExecuteRun(int handle) implements DoProgressResponse {} + + record WaitingPendingRun() implements DoProgressResponse { + static WaitingPendingRun INSTANCE = new WaitingPendingRun(); + } + } + + DoProgressResponse doProgress(List anyHandle); + + boolean isCompleted(int handle); + + Optional takeNotification(int handle); + + // --- Commands. The int return value is the handle of the operation. + + record Input( + InvocationId invocationId, Slice body, Map headers, @Nullable String key) {} + + @Nullable Input input(); + + int stateGet(String key); + + int stateGetKeys(); + + void stateSet(String key, Slice bytes); + + void stateClear(String key); + + void stateClearAll(); + + int sleep(Duration duration, String name); + + record CallHandle(int invocationIdHandle, int resultHandle) {} + + CallHandle call( + Target target, + Slice payload, + @Nullable String idempotencyKey, + @Nullable Collection> headers); + + int send( + Target target, + Slice payload, + @Nullable String idempotencyKey, + @Nullable Collection> headers, + @Nullable Duration delay); + + record Awakeable(String awakeableId, int handle) {} + + Awakeable awakeable(); + + void completeAwakeable(String awakeableId, Slice value); + + void completeAwakeable(String awakeableId, TerminalException exception); + + int createSignalHandle(String signalName); + + void completeSignal(String targetInvocationId, String signalName, Slice value); + + void completeSignal(String targetInvocationId, String signalName, TerminalException exception); + + int promiseGet(String key); + + int promisePeek(String key); + + int promiseComplete(String key, Slice value); + + int promiseComplete(String key, TerminalException exception); + + int run(String name); + + void proposeRunCompletion(int handle, Slice value); + + void proposeRunCompletion( + int handle, Throwable exception, Duration attemptDuration, RetryPolicy retryPolicy); + + void cancelInvocation(String targetInvocationId); + + int attachInvocation(String invocationId); + + int getInvocationOutput(String invocationId); + + void writeOutput(Slice value); + + void writeOutput(TerminalException exception); + + void end(); + + // -- Introspection + + InvocationState state(); +} diff --git a/sdk-core/src/main/java/dev/restate/sdk/core/statemachine/StateMachineImpl.java b/sdk-core/src/main/java/dev/restate/sdk/core/statemachine/StateMachineImpl.java new file mode 100644 index 000000000..7bc2644a2 --- /dev/null +++ b/sdk-core/src/main/java/dev/restate/sdk/core/statemachine/StateMachineImpl.java @@ -0,0 +1,653 @@ +// Copyright (c) 2023 - Restate Software, Inc., Restate GmbH +// +// This file is part of the Restate Java SDK, +// which is released under the MIT license. +// +// You can find a copy of the license in file LICENSE in the root +// directory of this repository or package, or at +// https://github.com/restatedev/sdk-java/blob/main/LICENSE +package dev.restate.sdk.core.statemachine; + +import static dev.restate.sdk.core.statemachine.Util.sliceToByteString; +import static dev.restate.sdk.core.statemachine.Util.toProtocolFailure; + +import com.google.protobuf.ByteString; +import dev.restate.common.Slice; +import dev.restate.common.Target; +import dev.restate.sdk.core.EndpointRequestHandler; +import dev.restate.sdk.core.ProtocolException; +import dev.restate.sdk.core.generated.protocol.Protocol; +import dev.restate.sdk.endpoint.HeadersAccessor; +import dev.restate.sdk.types.*; +import java.nio.ByteBuffer; +import java.time.Duration; +import java.time.Instant; +import java.util.*; +import java.util.concurrent.CompletableFuture; +import java.util.concurrent.Flow; +import java.util.function.Consumer; +import org.apache.logging.log4j.LogManager; +import org.apache.logging.log4j.Logger; +import org.jspecify.annotations.Nullable; + +class StateMachineImpl implements StateMachine { + + private static final Logger LOG = LogManager.getLogger(StateMachineImpl.class); + private static final String AWAKEABLE_IDENTIFIER_PREFIX = "sign_1"; + private static final int CANCEL_SIGNAL_ID = 1; + + private final Protocol.ServiceProtocolVersion serviceProtocolVersion; + + // Callbacks + private final CompletableFuture waitForReadyFuture = new CompletableFuture<>(); + private CompletableFuture waitNextProcessedInput; + + // Java Flow and message handling + private final MessageDecoder messageDecoder = new MessageDecoder(); + private Flow.@Nullable Subscription inputSubscription; + + // State machine context + private final StateContext stateContext; + + StateMachineImpl( + HeadersAccessor headersAccessor, + EndpointRequestHandler.LoggingContextSetter loggingContextSetter) { + String contentTypeHeader = headersAccessor.get(ServiceProtocol.CONTENT_TYPE); + + this.serviceProtocolVersion = ServiceProtocol.parseServiceProtocolVersion(contentTypeHeader); + + if (!ServiceProtocol.isSupported(this.serviceProtocolVersion)) { + throw new ProtocolException( + String.format( + "Service endpoint does not support the service protocol version '%s'.", + contentTypeHeader), + ProtocolException.UNSUPPORTED_MEDIA_TYPE_CODE); + } + + this.stateContext = new StateContext(loggingContextSetter); + } + + // -- Few callbacks + + @Override + public CompletableFuture waitForReady() { + return waitForReadyFuture; + } + + @Override + public CompletableFuture waitNextInputSignal() { + if (this.stateContext.isInputClosed()) { + return CompletableFuture.completedFuture(null); + } + if (waitNextProcessedInput == null) { + this.waitNextProcessedInput = new CompletableFuture<>(); + } + return this.waitNextProcessedInput; + } + + private void triggerWaitNextInputSignal() { + if (this.waitNextProcessedInput != null) { + CompletableFuture fut = this.waitNextProcessedInput; + this.waitNextProcessedInput = null; + fut.complete(null); + } + } + + // -- IO + + @Override + public void subscribe(Flow.Subscriber subscriber) { + var outputSubscriber = new MessageEncoder(subscriber); + this.stateContext.registerOutputSubscriber(outputSubscriber); + outputSubscriber.onSubscribe( + new Flow.Subscription() { + @Override + public void request(long l) {} + + @Override + public void cancel() { + end(); + } + }); + } + + // --- Input Subscriber impl + + @Override + public void onSubscribe(Flow.Subscription subscription) { + try { + this.inputSubscription = subscription; + this.inputSubscription.request(Long.MAX_VALUE); + } catch (Throwable e) { + this.onError(e); + } + } + + @Override + public void onNext(Slice slice) { + try { + LOG.trace("Received input slice"); + this.messageDecoder.offer(slice); + + boolean shouldTriggerInputListener = this.messageDecoder.isNextAvailable(); + InvocationInput invocationInput = this.messageDecoder.next(); + while (invocationInput != null) { + LOG.trace( + "Received input message {} {}", + invocationInput.message().getClass(), + invocationInput.message()); + + this.stateContext + .getCurrentState() + .onNewMessage(invocationInput, this.stateContext, this.waitForReadyFuture); + + invocationInput = this.messageDecoder.next(); + } + + if (shouldTriggerInputListener) { + this.triggerWaitNextInputSignal(); + } + + } catch (Throwable e) { + this.onError(e); + } + } + + @Override + public void onError(Throwable throwable) { + LOG.trace("Got failure", throwable); + this.stateContext.getCurrentState().hitError(throwable, null, this.stateContext); + cancelInputSubscription(); + } + + @Override + public void onComplete() { + LOG.trace("Input publisher closed"); + try { + this.stateContext.getCurrentState().onInputClosed(this.stateContext); + } catch (Throwable e) { + this.onError(e); + } + this.triggerWaitNextInputSignal(); + this.cancelInputSubscription(); + } + + // -- State machine + + @Override + public String getResponseContentType() { + return ServiceProtocol.serviceProtocolVersionToHeaderValue(serviceProtocolVersion); + } + + @Override + public DoProgressResponse doProgress(List anyHandle) { + return this.stateContext.getCurrentState().doProgress(anyHandle, this.stateContext); + } + + @Override + public boolean isCompleted(int handle) { + return this.stateContext.getCurrentState().isCompleted(handle); + } + + @Override + public Optional takeNotification(int handle) { + return this.stateContext.getCurrentState().takeNotification(handle, this.stateContext); + } + + @Override + public @Nullable Input input() { + return this.stateContext.getCurrentState().processInputCommand(this.stateContext); + } + + @Override + public int stateGet(String key) { + LOG.debug("Executing 'Get state {}'", key); + return this.stateContext.getCurrentState().processStateGetCommand(key, this.stateContext); + } + + @Override + public int stateGetKeys() { + LOG.debug("Executing 'Get state keys'"); + return this.stateContext.getCurrentState().processStateGetKeysCommand(this.stateContext); + } + + @Override + public void stateSet(String key, Slice bytes) { + LOG.debug("Executing 'Set state {}'", key); + ByteString keyBuffer = ByteString.copyFromUtf8(key); + this.stateContext.getEagerState().set(keyBuffer, bytes); + this.stateContext + .getCurrentState() + .processNonCompletableCommand( + Protocol.SetStateCommandMessage.newBuilder() + .setKey(keyBuffer) + .setValue(Protocol.Value.newBuilder().setContent(sliceToByteString(bytes)).build()) + .build(), + CommandAccessor.SET_STATE, + this.stateContext); + } + + @Override + public void stateClear(String key) { + LOG.debug("Executing 'Clear state {}'", key); + ByteString keyBuffer = ByteString.copyFromUtf8(key); + this.stateContext.getEagerState().clear(keyBuffer); + this.stateContext + .getCurrentState() + .processNonCompletableCommand( + Protocol.ClearStateCommandMessage.newBuilder().setKey(keyBuffer).build(), + CommandAccessor.CLEAR_STATE, + this.stateContext); + } + + @Override + public void stateClearAll() { + LOG.debug("Executing 'Clear all state'"); + this.stateContext.getEagerState().clearAll(); + this.stateContext + .getCurrentState() + .processNonCompletableCommand( + Protocol.ClearAllStateCommandMessage.getDefaultInstance(), + CommandAccessor.CLEAR_ALL_STATE, + this.stateContext); + } + + @Override + public int sleep(Duration duration, @Nullable String name) { + LOG.debug("Executing 'Sleeping for {}'", duration); + var completionId = this.stateContext.getJournal().nextCompletionNotificationId(); + + var sleepCommandBuilder = + Protocol.SleepCommandMessage.newBuilder() + .setWakeUpTime(Instant.now().toEpochMilli() + duration.toMillis()) + .setResultCompletionId(completionId); + if (name != null) { + sleepCommandBuilder.setName(name); + } + + return this.stateContext.getCurrentState() + .processCompletableCommand( + sleepCommandBuilder.build(), + CommandAccessor.SLEEP, + new int[] {completionId}, + this.stateContext)[0]; + } + + @Override + public CallHandle call( + Target target, + Slice payload, + @Nullable String idempotencyKey, + @Nullable Collection> headers) { + LOG.debug("Executing 'Call {}'", target); + if (idempotencyKey != null && idempotencyKey.isBlank()) { + throw ProtocolException.idempotencyKeyIsEmpty(); + } + + var invocationIdCompletionId = this.stateContext.getJournal().nextCompletionNotificationId(); + var callCompletionId = this.stateContext.getJournal().nextCompletionNotificationId(); + + var callCommandBuilder = + Protocol.CallCommandMessage.newBuilder() + .setServiceName(target.getService()) + .setHandlerName(target.getHandler()) + .setParameter(sliceToByteString(payload)) + .setInvocationIdNotificationIdx(invocationIdCompletionId) + .setResultCompletionId(callCompletionId); + if (target.getKey() != null) { + callCommandBuilder.setKey(target.getKey()); + } + if (idempotencyKey != null) { + callCommandBuilder.setIdempotencyKey(idempotencyKey); + } + if (headers != null) { + for (var header : headers) { + callCommandBuilder.addHeaders( + Protocol.Header.newBuilder() + .setKey(header.getKey()) + .setValue(header.getValue()) + .build()); + } + } + + var notificationHandles = + this.stateContext + .getCurrentState() + .processCompletableCommand( + callCommandBuilder.build(), + CommandAccessor.CALL, + new int[] {invocationIdCompletionId, callCompletionId}, + this.stateContext); + + return new CallHandle(notificationHandles[0], notificationHandles[1]); + } + + @Override + public int send( + Target target, + Slice payload, + @Nullable String idempotencyKey, + @Nullable Collection> headers, + @Nullable Duration delay) { + LOG.debug("Executing 'Send {}'", target); + if (idempotencyKey != null && idempotencyKey.isBlank()) { + throw ProtocolException.idempotencyKeyIsEmpty(); + } + + var invocationIdCompletionId = this.stateContext.getJournal().nextCompletionNotificationId(); + + var sendCommandBuilder = + Protocol.OneWayCallCommandMessage.newBuilder() + .setServiceName(target.getService()) + .setHandlerName(target.getHandler()) + .setParameter(sliceToByteString(payload)) + .setInvocationIdNotificationIdx(invocationIdCompletionId); + if (target.getKey() != null) { + sendCommandBuilder.setKey(target.getKey()); + } + if (idempotencyKey != null) { + sendCommandBuilder.setIdempotencyKey(idempotencyKey); + } + if (headers != null) { + for (var header : headers) { + sendCommandBuilder.addHeaders( + Protocol.Header.newBuilder() + .setKey(header.getKey()) + .setValue(header.getValue()) + .build()); + } + } + if (delay != null && !delay.isZero()) { + sendCommandBuilder.setInvokeTime(Instant.now().toEpochMilli() + delay.toMillis()); + } + + return this.stateContext.getCurrentState() + .processCompletableCommand( + sendCommandBuilder.build(), + CommandAccessor.ONE_WAY_CALL, + new int[] {invocationIdCompletionId}, + this.stateContext)[0]; + } + + @Override + public Awakeable awakeable() { + LOG.debug("Executing 'Create awakeable'"); + + var signalId = this.stateContext.getJournal().nextSignalNotificationId(); + + var signalHandle = + this.stateContext + .getCurrentState() + .createSignalHandle(new NotificationId.SignalId(signalId), this.stateContext); + + // Encode awakeable id + String awakeableId = + AWAKEABLE_IDENTIFIER_PREFIX + + Base64.getUrlEncoder() + .encodeToString( + this.stateContext + .getStartInfo() + .id() + .concat(ByteString.copyFrom(ByteBuffer.allocate(4).putInt(signalId).flip())) + .toByteArray()); + + return new Awakeable(awakeableId, signalHandle); + } + + @Override + public void completeAwakeable(String awakeableId, Slice value) { + LOG.debug("Executing 'Complete awakeable {} with success'", awakeableId); + completeAwakeable( + awakeableId, + builder -> + builder.setValue( + Protocol.Value.newBuilder().setContent(sliceToByteString(value)).build())); + } + + @Override + public void completeAwakeable(String awakeableId, TerminalException exception) { + LOG.debug("Executing 'Complete awakeable {} with failure'", awakeableId); + completeAwakeable(awakeableId, builder -> builder.setFailure(toProtocolFailure(exception))); + } + + private void completeAwakeable( + String awakeableId, Consumer filler) { + var builder = Protocol.CompleteAwakeableCommandMessage.newBuilder().setAwakeableId(awakeableId); + filler.accept(builder); + + this.stateContext + .getCurrentState() + .processNonCompletableCommand( + builder.build(), CommandAccessor.COMPLETE_AWAKEABLE, this.stateContext); + } + + @Override + public int createSignalHandle(String signalName) { + LOG.debug("Executing 'Create signal handle {}'", signalName); + + return this.stateContext + .getCurrentState() + .createSignalHandle(new NotificationId.SignalName(signalName), this.stateContext); + } + + @Override + public void completeSignal(String targetInvocationId, String signalName, Slice value) { + LOG.debug( + "Executing 'Complete signal {} to invocation {} with success'", + signalName, + targetInvocationId); + this.completeSignal( + targetInvocationId, + signalName, + builder -> + builder.setValue( + Protocol.Value.newBuilder().setContent(sliceToByteString(value)).build())); + } + + @Override + public void completeSignal( + String targetInvocationId, String signalName, TerminalException exception) { + LOG.debug( + "Executing 'Complete signal {} to invocation {} with failure'", + signalName, + targetInvocationId); + this.completeSignal( + targetInvocationId, + signalName, + builder -> builder.setFailure(toProtocolFailure(exception))); + } + + private void completeSignal( + String targetInvocationId, + String signalName, + Consumer filler) { + var builder = + Protocol.SendSignalCommandMessage.newBuilder() + .setTargetInvocationId(targetInvocationId) + .setName(signalName); + filler.accept(builder); + + this.stateContext + .getCurrentState() + .processNonCompletableCommand( + builder.build(), CommandAccessor.SEND_SIGNAL, this.stateContext); + } + + @Override + public int promiseGet(String key) { + LOG.debug("Executing 'Await promise {}'", key); + var completionId = this.stateContext.getJournal().nextCompletionNotificationId(); + return this.stateContext.getCurrentState() + .processCompletableCommand( + Protocol.GetPromiseCommandMessage.newBuilder() + .setKey(key) + .setResultCompletionId(completionId) + .build(), + CommandAccessor.GET_PROMISE, + new int[] {completionId}, + this.stateContext)[0]; + } + + @Override + public int promisePeek(String key) { + LOG.debug("Executing 'Peek promise {}'", key); + var completionId = this.stateContext.getJournal().nextCompletionNotificationId(); + return this.stateContext.getCurrentState() + .processCompletableCommand( + Protocol.PeekPromiseCommandMessage.newBuilder() + .setKey(key) + .setResultCompletionId(completionId) + .build(), + CommandAccessor.PEEK_PROMISE, + new int[] {completionId}, + this.stateContext)[0]; + } + + @Override + public int promiseComplete(String key, Slice value) { + LOG.debug("Executing 'Complete promise {} with success'", key); + return this.promiseComplete( + key, + builder -> + builder.setCompletionValue( + Protocol.Value.newBuilder().setContent(sliceToByteString(value)).build())); + } + + @Override + public int promiseComplete(String key, TerminalException exception) { + LOG.debug("Executing 'Complete promise {} with failure'", key); + return this.promiseComplete( + key, builder -> builder.setCompletionFailure(toProtocolFailure(exception))); + } + + private int promiseComplete( + String key, Consumer filler) { + var completionId = this.stateContext.getJournal().nextCompletionNotificationId(); + + var builder = + Protocol.CompletePromiseCommandMessage.newBuilder() + .setResultCompletionId(completionId) + .setKey(key); + filler.accept(builder); + + return this.stateContext.getCurrentState() + .processCompletableCommand( + builder.build(), + CommandAccessor.COMPLETE_PROMISE, + new int[] {completionId}, + this.stateContext)[0]; + } + + @Override + public int run(String name) { + LOG.debug("Executing 'Created run {}'", name); + return this.stateContext.getCurrentState().processRunCommand(name, this.stateContext); + } + + @Override + public void proposeRunCompletion(int handle, Slice value) { + LOG.debug("Executing 'Run completed with success'"); + this.stateContext.getCurrentState().proposeRunCompletion(handle, value, this.stateContext); + } + + @Override + public void proposeRunCompletion( + int handle, + Throwable exception, + Duration attemptDuration, + @Nullable RetryPolicy retryPolicy) { + LOG.debug("Executing 'Run completed with failure'"); + this.stateContext + .getCurrentState() + .proposeRunCompletion(handle, exception, attemptDuration, retryPolicy, this.stateContext); + } + + @Override + public void cancelInvocation(String targetInvocationId) { + LOG.debug("Executing 'Cancel invocation {}'", targetInvocationId); + this.stateContext + .getCurrentState() + .processNonCompletableCommand( + Protocol.SendSignalCommandMessage.newBuilder() + .setTargetInvocationId(targetInvocationId) + .setIdx(CANCEL_SIGNAL_ID) + .setVoid(Protocol.Void.getDefaultInstance()) + .build(), + CommandAccessor.SEND_SIGNAL, + this.stateContext); + } + + @Override + public int attachInvocation(String invocationId) { + LOG.debug("Executing 'Attach invocation {}'", invocationId); + var completionId = this.stateContext.getJournal().nextCompletionNotificationId(); + return this.stateContext.getCurrentState() + .processCompletableCommand( + Protocol.AttachInvocationCommandMessage.newBuilder() + .setInvocationId(invocationId) + .setResultCompletionId(completionId) + .build(), + CommandAccessor.ATTACH_INVOCATION, + new int[] {completionId}, + this.stateContext)[0]; + } + + @Override + public int getInvocationOutput(String invocationId) { + LOG.debug("Executing 'Get invocation output {}'", invocationId); + var completionId = this.stateContext.getJournal().nextCompletionNotificationId(); + return this.stateContext.getCurrentState() + .processCompletableCommand( + Protocol.GetInvocationOutputCommandMessage.newBuilder() + .setInvocationId(invocationId) + .setResultCompletionId(completionId) + .build(), + CommandAccessor.GET_INVOCATION_OUTPUT, + new int[] {completionId}, + this.stateContext)[0]; + } + + @Override + public void writeOutput(Slice value) { + LOG.debug("Executing 'Write invocation output with success'"); + this.stateContext + .getCurrentState() + .processNonCompletableCommand( + Protocol.OutputCommandMessage.newBuilder() + .setValue(Protocol.Value.newBuilder().setContent(sliceToByteString(value)).build()) + .build(), + CommandAccessor.OUTPUT, + this.stateContext); + } + + @Override + public void writeOutput(TerminalException exception) { + LOG.debug("Executing 'Write invocation output with failure'"); + this.stateContext + .getCurrentState() + .processNonCompletableCommand( + Protocol.OutputCommandMessage.newBuilder() + .setFailure(toProtocolFailure(exception)) + .build(), + CommandAccessor.OUTPUT, + this.stateContext); + } + + @Override + public void end() { + this.stateContext.getCurrentState().end(this.stateContext); + cancelInputSubscription(); + } + + @Override + public InvocationState state() { + return this.stateContext.getCurrentState().getInvocationState(); + } + + private void cancelInputSubscription() { + if (this.inputSubscription != null) { + this.inputSubscription.cancel(); + this.inputSubscription = null; + } + } +} diff --git a/sdk-core/src/main/java/dev/restate/sdk/core/statemachine/Util.java b/sdk-core/src/main/java/dev/restate/sdk/core/statemachine/Util.java new file mode 100644 index 000000000..f27335b0e --- /dev/null +++ b/sdk-core/src/main/java/dev/restate/sdk/core/statemachine/Util.java @@ -0,0 +1,158 @@ +// Copyright (c) 2023 - Restate Software, Inc., Restate GmbH +// +// This file is part of the Restate Java SDK, +// which is released under the MIT license. +// +// You can find a copy of the license in file LICENSE in the root +// directory of this repository or package, or at +// https://github.com/restatedev/sdk-java/blob/main/LICENSE +package dev.restate.sdk.core.statemachine; + +import com.google.protobuf.ByteString; +import com.google.protobuf.MessageLite; +import com.google.protobuf.UnsafeByteOperations; +import dev.restate.common.Slice; +import dev.restate.sdk.core.ProtocolException; +import dev.restate.sdk.core.generated.protocol.Protocol; +import dev.restate.sdk.types.TerminalException; +import java.io.PrintWriter; +import java.io.StringWriter; +import java.nio.ByteBuffer; +import java.time.Duration; +import java.util.Objects; +import org.jspecify.annotations.Nullable; + +public class Util { + + static Protocol.Failure toProtocolFailure(int code, String message) { + Protocol.Failure.Builder builder = Protocol.Failure.newBuilder().setCode(code); + if (message != null) { + builder.setMessage(message); + } + return builder.build(); + } + + static Protocol.Failure toProtocolFailure(Throwable throwable) { + if (throwable instanceof TerminalException) { + return toProtocolFailure(((TerminalException) throwable).getCode(), throwable.getMessage()); + } + return toProtocolFailure(TerminalException.INTERNAL_SERVER_ERROR_CODE, throwable.toString()); + } + + static Protocol.ErrorMessage toErrorMessage( + Throwable throwable, + int currentCommandIndex, + @Nullable String currentCommandName, + @Nullable MessageType currentCommandType) { + Protocol.ErrorMessage.Builder msg = Protocol.ErrorMessage.newBuilder(); + + if (throwable.getMessage() == null) { + // This happens only with few common exceptions, but anyway + msg.setMessage(throwable.toString()); + } else { + msg.setMessage(throwable.getMessage()); + } + + if (throwable instanceof ProtocolException) { + msg.setCode(((ProtocolException) throwable).getCode()); + } else { + msg.setCode(TerminalException.INTERNAL_SERVER_ERROR_CODE); + } + + // Convert stacktrace to string + StringWriter sw = new StringWriter(); + PrintWriter pw = new PrintWriter(sw); + throwable.printStackTrace(pw); + msg.setDescription(sw.toString()); + + // Add journal entry info + if (currentCommandIndex >= 0) { + msg.setRelatedCommandIndex(currentCommandIndex); + } + if (currentCommandName != null) { + msg.setRelatedCommandName(currentCommandName); + } + if (currentCommandType != null) { + msg.setRelatedCommandType(currentCommandType.encode()); + } + + return msg.build(); + } + + static TerminalException toRestateException(Protocol.Failure failure) { + return new TerminalException(failure.getCode(), failure.getMessage()); + } + + static void assertEntryEquals(MessageLite expected, MessageLite actual) { + if (!Objects.equals(expected, actual)) { + throw ProtocolException.commandDoesNotMatch(expected, actual); + } + } + + static void assertEntryClass(Class clazz, MessageLite actual) { + if (!clazz.equals(actual.getClass())) { + throw ProtocolException.commandClassDoesNotMatch(clazz, actual); + } + } + + /** NOTE! This method rewinds the buffer!!! */ + static ByteString nioBufferToProtobufBuffer(ByteBuffer nioBuffer) { + return UnsafeByteOperations.unsafeWrap(nioBuffer); + } + + /** NOTE! This method rewinds the buffer!!! */ + static ByteString sliceToByteString(Slice slice) { + return nioBufferToProtobufBuffer(slice.asReadOnlyByteBuffer()); + } + + static Slice byteStringToSlice(ByteString byteString) { + return new ByteStringSlice(byteString); + } + + static Duration durationMin(Duration a, Duration b) { + return (a.compareTo(b) <= 0) ? a : b; + } + + private static final class ByteStringSlice implements Slice { + private final ByteString byteString; + + public ByteStringSlice(ByteString bytes) { + this.byteString = Objects.requireNonNull(bytes); + } + + @Override + public ByteBuffer asReadOnlyByteBuffer() { + return byteString.asReadOnlyByteBuffer(); + } + + @Override + public int readableBytes() { + return byteString.size(); + } + + @Override + public void copyTo(byte[] target) { + copyTo(target, 0); + } + + @Override + public void copyTo(byte[] target, int targetOffset) { + byteString.copyTo(target, targetOffset); + } + + @Override + public byte byteAt(int position) { + return byteString.byteAt(position); + } + + @Override + public void copyTo(ByteBuffer buffer) { + byteString.copyTo(buffer); + } + + @Override + public byte[] toByteArray() { + return byteString.toByteArray(); + } + } +} diff --git a/sdk-core/src/main/java/dev/restate/sdk/core/statemachine/WaitingReplayEntriesState.java b/sdk-core/src/main/java/dev/restate/sdk/core/statemachine/WaitingReplayEntriesState.java new file mode 100644 index 000000000..8be9c36f6 --- /dev/null +++ b/sdk-core/src/main/java/dev/restate/sdk/core/statemachine/WaitingReplayEntriesState.java @@ -0,0 +1,68 @@ +// Copyright (c) 2023 - Restate Software, Inc., Restate GmbH +// +// This file is part of the Restate Java SDK, +// which is released under the MIT license. +// +// You can find a copy of the license in file LICENSE in the root +// directory of this repository or package, or at +// https://github.com/restatedev/sdk-java/blob/main/LICENSE +package dev.restate.sdk.core.statemachine; + +import com.google.protobuf.MessageLite; +import dev.restate.sdk.core.ProtocolException; +import dev.restate.sdk.core.generated.protocol.Protocol; +import java.util.ArrayDeque; +import java.util.Deque; +import java.util.concurrent.CompletableFuture; + +final class WaitingReplayEntriesState implements State { + + private int receivedEntries = 0; + private final Deque commandsToProcess = new ArrayDeque<>(); + private final AsyncResultsState asyncResultsState = new AsyncResultsState(); + + @Override + public void onNewMessage( + InvocationInput invocationInput, + StateContext stateContext, + CompletableFuture waitForReadyFuture) { + if (invocationInput.header().getType().isNotification()) { + if (!(invocationInput.message() + instanceof Protocol.NotificationTemplate notificationTemplate)) { + throw ProtocolException.unexpectedMessage( + Protocol.NotificationTemplate.class, invocationInput.message()); + } + + this.asyncResultsState.enqueue(notificationTemplate); + } else if (invocationInput.header().getType().isCommand()) { + this.commandsToProcess.add(invocationInput.message()); + } else { + throw ProtocolException.unexpectedMessage( + "command or notification", invocationInput.message()); + } + + this.receivedEntries++; + + if (stateContext.getStartInfo().entriesToReplay() == this.receivedEntries) { + stateContext + .getStateHolder() + .transition(new ReplayingState(commandsToProcess, asyncResultsState)); + waitForReadyFuture.complete(null); + } + } + + @Override + public void onInputClosed(StateContext stateContext) { + throw ProtocolException.inputClosedWhileWaitingEntries(); + } + + @Override + public void end(StateContext stateContext) { + throw ProtocolException.closedWhileWaitingEntries(); + } + + @Override + public InvocationState getInvocationState() { + return InvocationState.WAITING_START; + } +} diff --git a/sdk-core/src/main/java/dev/restate/sdk/core/statemachine/WaitingStartState.java b/sdk-core/src/main/java/dev/restate/sdk/core/statemachine/WaitingStartState.java new file mode 100644 index 000000000..9dac9d541 --- /dev/null +++ b/sdk-core/src/main/java/dev/restate/sdk/core/statemachine/WaitingStartState.java @@ -0,0 +1,68 @@ +// Copyright (c) 2023 - Restate Software, Inc., Restate GmbH +// +// This file is part of the Restate Java SDK, +// which is released under the MIT license. +// +// You can find a copy of the license in file LICENSE in the root +// directory of this repository or package, or at +// https://github.com/restatedev/sdk-java/blob/main/LICENSE +package dev.restate.sdk.core.statemachine; + +import dev.restate.sdk.core.ProtocolException; +import dev.restate.sdk.core.generated.protocol.Protocol; +import dev.restate.sdk.types.TerminalException; +import java.time.Duration; +import java.util.concurrent.CompletableFuture; + +final class WaitingStartState implements State { + + @Override + public void onNewMessage( + InvocationInput invocationInput, + StateContext stateContext, + CompletableFuture waitForReadyFuture) { + if (!(invocationInput.message() instanceof Protocol.StartMessage startMessage)) { + throw ProtocolException.unexpectedMessage( + Protocol.StartMessage.class, invocationInput.message()); + } + + // Sanity checks + if (startMessage.getKnownEntries() == 0) { + throw new ProtocolException( + "Expected at least one entry with Input, got 0 entries", + TerminalException.INTERNAL_SERVER_ERROR_CODE); + } + + // Register start info and eager state + stateContext.setStartInfo( + new StartInfo( + startMessage.getId(), + startMessage.getDebugId(), + startMessage.getKey(), + startMessage.getKnownEntries(), + startMessage.getRetryCountSinceLastStoredEntry(), + Duration.ofMillis(startMessage.getDurationSinceLastStoredEntry()))); + stateContext.setEagerState(new EagerState(startMessage)); + + // Tracing and logging setup + LOG.info("Start invocation"); + + // Execute state transition + stateContext.getStateHolder().transition(new WaitingReplayEntriesState()); + } + + @Override + public void onInputClosed(StateContext stateContext) { + throw ProtocolException.inputClosedWhileWaitingEntries(); + } + + @Override + public void end(StateContext stateContext) { + throw ProtocolException.closedWhileWaitingEntries(); + } + + @Override + public InvocationState getInvocationState() { + return InvocationState.WAITING_START; + } +} diff --git a/sdk-core/src/main/sdk-proto/dev/restate/sdk/java.proto b/sdk-core/src/main/sdk-proto/dev/restate/sdk/java.proto deleted file mode 100644 index e790a8229..000000000 --- a/sdk-core/src/main/sdk-proto/dev/restate/sdk/java.proto +++ /dev/null @@ -1,24 +0,0 @@ -// Copyright (c) 2023 - Restate Software, Inc., Restate GmbH -// -// This file is part of the Restate Java SDK, -// which is released under the MIT license. -// -// You can find a copy of the license in file LICENSE in the root -// directory of this repository or package, or at -// https://github.com/restatedev/sdk-java/blob/main/LICENSE -syntax = "proto3"; - -package dev.restate.sdk.java; - -import "google/protobuf/any.proto"; -import "dev/restate/service/protocol.proto"; - -option java_package = "dev.restate.generated.sdk.java"; - -// Type: 0xFC00 + 0 -message CombinatorAwaitableEntryMessage { - repeated uint32 entry_index = 1; - - // Entry name - string name = 12; -} diff --git a/sdk-core/src/main/service-protocol/dev/restate/service/discovery.proto b/sdk-core/src/main/service-protocol/dev/restate/service/discovery.proto index 01a0e1bef..bede0be33 100644 --- a/sdk-core/src/main/service-protocol/dev/restate/service/discovery.proto +++ b/sdk-core/src/main/service-protocol/dev/restate/service/discovery.proto @@ -11,8 +11,7 @@ syntax = "proto3"; package dev.restate.service.discovery; -option java_package = "dev.restate.generated.service.discovery"; -option go_package = "restate.dev/sdk-go/pb/service/discovery"; +option java_package = "dev.restate.sdk.core.generated.discovery"; // Service discovery protocol version. enum ServiceDiscoveryProtocolVersion { diff --git a/sdk-core/src/main/service-protocol/dev/restate/service/protocol.proto b/sdk-core/src/main/service-protocol/dev/restate/service/protocol.proto index 45aa41730..cf905808d 100644 --- a/sdk-core/src/main/service-protocol/dev/restate/service/protocol.proto +++ b/sdk-core/src/main/service-protocol/dev/restate/service/protocol.proto @@ -11,8 +11,7 @@ syntax = "proto3"; package dev.restate.service.protocol; -option java_package = "dev.restate.generated.service.protocol"; -option go_package = "restate.dev/sdk-go/pb/service/protocol"; +option java_package = "dev.restate.sdk.core.generated.protocol"; // Service protocol version. enum ServiceProtocolVersion { @@ -26,7 +25,11 @@ enum ServiceProtocolVersion { // * New entry to cancel invocations: CancelInvocationEntryMessage // * New entry to retrieve the invocation id: GetCallInvocationIdEntryMessage // * New field to set idempotency key for Call entries + // * New entry to attach to existing invocation: AttachInvocationEntryMessage + // * New entry to get output of existing invocation: GetInvocationOutputEntryMessage V3 = 3; + // Immutable journal. + V4 = 4; } // --- Core frames --- @@ -47,6 +50,7 @@ message StartMessage { // The user can use this id to address this invocation in admin and status introspection apis. string debug_id = 2; + // This is the sum of known commands + notifications uint32 known_entries = 3; // protolint:disable:next REPEATED_FIELD_NAMES_PLURALIZED @@ -70,28 +74,18 @@ message StartMessage { } // Type: 0x0000 + 1 -message CompletionMessage { - uint32 entry_index = 1; - - oneof result { - Empty empty = 13; - bytes value = 14; - Failure failure = 15; - }; -} - -// Type: 0x0000 + 2 // Implementations MUST send this message when suspending an invocation. +// +// These lists represent any of the notification_idx and/or notification_name the invocation is waiting on to progress. +// The runtime will resume the invocation as soon as either one of the given notification_idx or notification_name is completed. +// Between the two lists there MUST be at least one element. message SuspensionMessage { - // This list represents any of the entry_index the invocation is waiting on to progress. - // The runtime will resume the invocation as soon as one of the given entry_index is completed. - // This list MUST not be empty. - // False positive, entry_indexes is a valid plural of entry_indices. - // https://learn.microsoft.com/en-us/style-guide/a-z-word-list-term-collections/i/index-indexes-indices - repeated uint32 entry_indexes = 1; // protolint:disable:this REPEATED_FIELD_NAMES_PLURALIZED + repeated uint32 waiting_completions = 1; + repeated uint32 waiting_signals = 2; + repeated string waiting_named_signals = 3; } -// Type: 0x0000 + 3 +// Type: 0x0000 + 2 message ErrorMessage { // The code can be any HTTP status code, as described https://www.iana.org/assignments/http-status-codes/http-status-codes.xhtml. // In addition, we define the following error codes that MAY be used by the SDK for better error reporting: @@ -103,52 +97,82 @@ message ErrorMessage { // Contains a verbose error description, e.g. the exception stacktrace. string description = 3; - // Entry that caused the failure. This may be outside the current stored journal size. + // Command that caused the failure. This may be outside the current stored journal size. // If no specific entry caused the failure, the current replayed/processed entry can be used. - optional uint32 related_entry_index = 4; + optional uint32 related_command_index = 4; // Name of the entry that caused the failure. - optional string related_entry_name = 5; - // Entry type. - optional uint32 related_entry_type = 6; + optional string related_command_name = 5; + // Command type. + optional uint32 related_command_type = 6; // Delay before executing the next retry, specified as duration in milliseconds. // If provided, it will override the default retry policy used by Restate's invoker ONLY for the next retry attempt. optional uint64 next_retry_delay = 8; } +// Type: 0x0000 + 3 +// Implementations MUST send this message when the invocation lifecycle ends. +message EndMessage { +} + // Type: 0x0000 + 4 -message EntryAckMessage { - uint32 entry_index = 1; +message CommandAckMessage { + uint32 command_index = 1; } +// This is a special control message to propose ctx.run completions to the runtime. +// This won't be written to the journal immediately, but will appear later as a new notification (meaning the result was stored). +// // Type: 0x0000 + 5 -// Implementations MUST send this message when the invocation lifecycle ends. -message EndMessage { +message ProposeRunCompletionMessage { + uint32 result_completion_id = 1; + oneof result { + bytes value = 14; + Failure failure = 15; + }; } -// --- Journal Entries --- +// --- Commands and Notifications --- -// Every Completable JournalEntry has a result field, filled only and only if the entry is in DONE state. +// The Journal is modelled as commands and notifications. +// Commands define the operations executed, while notifications can be: +// * Completions to commands +// * Unnamed signals +// * Named signals // -// For every journal entry, fields 12, 13, 14 and 15 are reserved. -// -// The field 12 is used for name. The name is used by introspection/observability tools. -// -// Depending on the semantics of the corresponding syscall, the entry can represent the completion result field with any of these three types: +// An individual command can produce 0 or more completions, where the respective completion id(s) are defined in the command message. + +// A notification message follows the following duck-type: // -// * google.protobuf.Empty empty = 13 for cases when we need to propagate to user code the distinction between default value or no value. -// * bytes value = 14 for carrying the result value -// * Failure failure = 15 for carrying a failure +message NotificationTemplate { + reserved 12; + + oneof id { + uint32 completion_id = 1; + uint32 signal_id = 2; + string signal_name = 3; + } + + oneof result { + Void void = 4; + Value value = 5; + Failure failure = 6; + + // Used by specific commands + string invocation_id = 16; + StateKeys state_keys = 17; + }; +} // ------ Input and output ------ // Completable: No // Fallible: No // Type: 0x0400 + 0 -message InputEntryMessage { +message InputCommandMessage { repeated Header headers = 1; - bytes value = 14; + Value value = 14; // Entry name string name = 12; @@ -157,9 +181,9 @@ message InputEntryMessage { // Completable: No // Fallible: No // Type: 0x0400 + 1 -message OutputEntryMessage { +message OutputCommandMessage { oneof result { - bytes value = 14; + Value value = 14; Failure failure = 15; }; @@ -171,26 +195,34 @@ message OutputEntryMessage { // Completable: Yes // Fallible: No -// Type: 0x0800 + 0 -message GetStateEntryMessage { +// Type: 0x0400 + 2 +message GetLazyStateCommandMessage { bytes key = 1; + uint32 result_completion_id = 11; + string name = 12; +} + +// Notification for GetLazyStateCommandMessage +// Type: 0x8000 + 2 +message GetLazyStateCompletionNotificationMessage { + // See NotificationMessage above + reserved 2, 3, 6, 7, 8, 12; + + uint32 completion_id = 1; + oneof result { - Empty empty = 13; - bytes value = 14; - Failure failure = 15; + Void void = 4; + Value value = 5; }; - - // Entry name - string name = 12; } // Completable: No // Fallible: No -// Type: 0x0800 + 1 -message SetStateEntryMessage { +// Type: 0x0400 + 3 +message SetStateCommandMessage { bytes key = 1; - bytes value = 3; + Value value = 3; // Entry name string name = 12; @@ -198,8 +230,8 @@ message SetStateEntryMessage { // Completable: No // Fallible: No -// Type: 0x0800 + 2 -message ClearStateEntryMessage { +// Type: 0x0400 + 4 +message ClearStateCommandMessage { bytes key = 1; // Entry name @@ -208,39 +240,50 @@ message ClearStateEntryMessage { // Completable: No // Fallible: No -// Type: 0x0800 + 3 -message ClearAllStateEntryMessage { +// Type: 0x0400 + 5 +message ClearAllStateCommandMessage { // Entry name string name = 12; } // Completable: Yes // Fallible: No -// Type: 0x0800 + 4 -message GetStateKeysEntryMessage { - message StateKeys { - repeated bytes keys = 1; - } +// Type: 0x0400 + 6 +message GetLazyStateKeysCommandMessage { + uint32 result_completion_id = 11; + string name = 12; +} + +// Notification for GetLazyStateKeysCommandMessage +// Type: 0x8000 + 6 +message GetLazyStateKeysCompletionNotificationMessage { + // See NotificationMessage above + reserved 2 to 8, 12, 16; + + uint32 completion_id = 1; + StateKeys state_keys = 17; +} + +// Completable: No +// Fallible: No +// Type: 0x0400 + 7 +message GetEagerStateCommandMessage { + bytes key = 1; oneof result { - StateKeys value = 14; - Failure failure = 15; + Void void = 13; + Value value = 14; }; // Entry name string name = 12; } -// Completable: Yes +// Completable: No // Fallible: No -// Type: 0x0800 + 8 -message GetPromiseEntryMessage { - string key = 1; - - oneof result { - bytes value = 14; - Failure failure = 15; - }; +// Type: 0x0400 + 8 +message GetEagerStateKeysCommandMessage { + StateKeys value = 14; // Entry name string name = 12; @@ -248,66 +291,111 @@ message GetPromiseEntryMessage { // Completable: Yes // Fallible: No -// Type: 0x0800 + 9 -message PeekPromiseEntryMessage { +// Type: 0x0400 + 9 +message GetPromiseCommandMessage { string key = 1; + uint32 result_completion_id = 11; + string name = 12; +} + +// Notification for GetPromiseCommandMessage +// Type: 0x8000 + 9 +message GetPromiseCompletionNotificationMessage { + // See NotificationMessage above + reserved 2, 3, 4, 12, 16, 17; + + uint32 completion_id = 1; + oneof result { - Empty empty = 13; - bytes value = 14; - Failure failure = 15; + Value value = 5; + Failure failure = 6; }; +} - // Entry name +// Completable: Yes +// Fallible: No +// Type: 0x0400 + A +message PeekPromiseCommandMessage { + string key = 1; + + uint32 result_completion_id = 11; string name = 12; } +// Notification for PeekPromiseCommandMessage +// Type: 0x8000 + A +message PeekPromiseCompletionNotificationMessage { + // See NotificationMessage above + reserved 2, 3, 12, 16, 17; + + uint32 completion_id = 1; + + oneof result { + Void void = 4; + Value value = 5; + Failure failure = 6; + }; +} + // Completable: Yes // Fallible: No -// Type: 0x0800 + A -message CompletePromiseEntryMessage { +// Type: 0x0400 + B +message CompletePromiseCommandMessage { string key = 1; // The value to use to complete the promise oneof completion { - bytes completion_value = 2; + Value completion_value = 2; Failure completion_failure = 3; }; - oneof result { - // Returns empty if value was set successfully - Empty empty = 13; - // Returns a failure if the promise was already completed - Failure failure = 15; - } - - // Entry name + uint32 result_completion_id = 11; string name = 12; } +// Notification for CompletePromiseCommandMessage +// Type: 0x8000 + B +message CompletePromiseCompletionNotificationMessage { + // See NotificationMessage above + reserved 2, 3, 5, 7, 8, 12, 16, 17; + + uint32 completion_id = 1; + + oneof result { + Void void = 4; + Failure failure = 6; + }; +} + // ------ Syscalls ------ // Completable: Yes // Fallible: No -// Type: 0x0C00 + 0 -message SleepEntryMessage { +// Type: 0x0400 + C +message SleepCommandMessage { // Wake up time. // The time is set as duration since UNIX Epoch. uint64 wake_up_time = 1; - oneof result { - Empty empty = 13; - Failure failure = 15; - } - - // Entry name + uint32 result_completion_id = 11; string name = 12; } -// Completable: Yes +// Notification for SleepCommandMessage +// Type: 0x8000 + C +message SleepCompletionNotificationMessage { + // See NotificationMessage above + reserved 2, 3, 5, 6, 7, 8, 12, 16, 17; + + uint32 completion_id = 1; + Void void = 4; +} + +// Completable: Yes (two notifications: one with invocation id, then one with the actual result) // Fallible: Yes -// Type: 0x0C00 + 1 -message CallEntryMessage { +// Type: 0x0400 + D +message CallCommandMessage { string service_name = 1; string handler_name = 2; @@ -321,19 +409,39 @@ message CallEntryMessage { // If present, it must be non empty. optional string idempotency_key = 6; + uint32 invocation_id_notification_idx = 10; + uint32 result_completion_id = 11; + string name = 12; +} + +// Notification for CallCommandMessage and OneWayCallCommandMessage +// Type: 0x8000 + E +message CallInvocationIdCompletionNotificationMessage { + // See NotificationMessage above + reserved 2, 3, 4, 5, 6, 7, 8, 12, 17; + + uint32 completion_id = 1; + string invocation_id = 16; +} + +// Notification for CallCommandMessage +// Type: 0x8000 + D +message CallCompletionNotificationMessage { + // See NotificationMessage above + reserved 2, 3, 4, 12, 16, 17; + + uint32 completion_id = 1; + oneof result { - bytes value = 14; - Failure failure = 15; + Value value = 5; + Failure failure = 6; }; - - // Entry name - string name = 12; } -// Completable: No +// Completable: Yes (only one notification with invocation id) // Fallible: Yes -// Type: 0x0C00 + 2 -message OneWayCallEntryMessage { +// Type: 0x0400 + E +message OneWayCallCommandMessage { string service_name = 1; string handler_name = 2; @@ -353,88 +461,164 @@ message OneWayCallEntryMessage { // If present, it must be non empty. optional string idempotency_key = 7; - // Entry name + uint32 invocation_id_notification_idx = 10; string name = 12; } -// Completable: Yes -// Fallible: No -// Type: 0x0C00 + 3 -// Awakeables are addressed by an identifier exposed to the user. See the spec for more details. -message AwakeableEntryMessage { +// Completable: No +// Fallible: Yes +// Type: 0x04000 + 10 +message SendSignalCommandMessage { + string target_invocation_id = 1; + + oneof signal_id { + uint32 idx = 2; + string name = 3; + } + oneof result { - bytes value = 14; - Failure failure = 15; + Void void = 4; + Value value = 5; + Failure failure = 6; }; - // Entry name + // Cannot use the field 'name' here because used above + string entry_name = 12; +} + +// Proposals for Run completions are sent through ProposeRunCompletionMessage +// +// Completable: Yes +// Fallible: No +// Type: 0x0400 + 11 +message RunCommandMessage { + uint32 result_completion_id = 11; string name = 12; } -// Completable: No -// Fallible: Yes -// Type: 0x0C00 + 4 -message CompleteAwakeableEntryMessage { - // Identifier of the awakeable. See the spec for more details. - string id = 1; +// Notification for RunCommandMessage +// Type: 0x8000 + 11 +message RunCompletionNotificationMessage { + // See NotificationMessage above + reserved 2, 3, 4, 12, 16, 17; + + uint32 completion_id = 1; oneof result { - bytes value = 14; - Failure failure = 15; + Value value = 5; + Failure failure = 6; }; +} - // Entry name +// Completable: Yes +// Fallible: Yes +// Type: 0x0400 + 12 +message AttachInvocationCommandMessage { + oneof target { + // Target invocation id + string invocation_id = 1; + // Target idempotent request + IdempotentRequestTarget idempotent_request_target = 3; + // Target workflow target + WorkflowTarget workflow_target = 4; + } + + uint32 result_completion_id = 11; string name = 12; } -// Completable: No -// Fallible: No -// Type: 0x0C00 + 5 -// Flag: RequiresRuntimeAck -message RunEntryMessage { +// Notification for AttachInvocationCommandMessage +// Type: 0x8000 + 12 +message AttachInvocationCompletionNotificationMessage { + // See NotificationMessage above + reserved 2, 3, 4, 12, 16, 17; + + uint32 completion_id = 1; + oneof result { - bytes value = 14; - Failure failure = 15; + Value value = 5; + Failure failure = 6; }; - - // Entry name - string name = 12; } -// Completable: No +// Completable: Yes // Fallible: Yes -// Type: 0x0C00 + 6 -message CancelInvocationEntryMessage { +// Type: 0x0400 + 13 +message GetInvocationOutputCommandMessage { oneof target { - // Target invocation id to cancel + // Target invocation id string invocation_id = 1; - // Target index of the call/one way call journal entry in this journal. - uint32 call_entry_index = 2; + // Target idempotent request + IdempotentRequestTarget idempotent_request_target = 3; + // Target workflow target + WorkflowTarget workflow_target = 4; } - // Entry name + uint32 result_completion_id = 11; string name = 12; } -// Completable: Yes +// Notification for GetInvocationOutputCommandMessage +// Type: 0x8000 + 13 +message GetInvocationOutputCompletionNotificationMessage { + // See NotificationMessage above + reserved 2, 3, 12, 16, 17; + + uint32 completion_id = 1; + + oneof result { + Void void = 4; + Value value = 5; + Failure failure = 6; + }; +} + +// We have this for backward compatibility, because we need to parse both old and new awakeable id. +// Completable: No // Fallible: Yes -// Type: 0x0C00 + 7 -message GetCallInvocationIdEntryMessage { - // Index of the call/one way call journal entry in this journal. - uint32 call_entry_index = 1; +// Type: 0x0400 + 14 +message CompleteAwakeableCommandMessage { + string awakeable_id = 1; oneof result { - string value = 14; - Failure failure = 15; + Value value = 2; + Failure failure = 3; }; + // Cannot use the field 'name' here because used above string name = 12; } +// Notification message for signals +// Type: 0xFBFF +message SignalNotificationMessage { + // See NotificationMessage above + reserved 1, 12, 16, 17; + + oneof signal_id { + uint32 idx = 2; + string name = 3; + } + + oneof result { + Void void = 4; + Value value = 5; + Failure failure = 6; + }; +} + // --- Nested messages +message StateKeys { + repeated bytes keys = 1; +} + +message Value { + bytes content = 1; +} + // This failure object carries user visible errors, -// e.g. invocation failure return value or failure result of an InvokeEntryMessage. +// e.g. invocation failure return value or failure result of an InvokeCommandMessage. message Failure { // The code can be any HTTP status code, as described https://www.iana.org/assignments/http-status-codes/http-status-codes.xhtml. uint32 code = 1; @@ -447,5 +631,23 @@ message Header { string value = 2; } -message Empty { +message WorkflowTarget { + string workflow_name = 1; + string workflow_key = 2; } + +message IdempotentRequestTarget { + string service_name = 1; + optional string service_key = 2; + string handler_name = 3; + string idempotency_key = 4; +} + +message Void { +} + +enum BuiltInSignal { + UNKNOWN = 0; + CANCEL = 1; + reserved 2 to 15; +} \ No newline at end of file diff --git a/sdk-core/src/main/service-protocol/service-invocation-protocol.md b/sdk-core/src/main/service-protocol/service-invocation-protocol.md index 89e23cbdf..45fcc1179 100644 --- a/sdk-core/src/main/service-protocol/service-invocation-protocol.md +++ b/sdk-core/src/main/service-protocol/service-invocation-protocol.md @@ -9,7 +9,7 @@ The system is composed of two actors: - Restate Runtime - Service deployment, which is split into: - SDK, which contains the implementation of the Restate Protocol - - User business logic, which interacts with the SDK to access Restate system calls (or syscalls) + - User business logic, which interacts with the SDK to access Restate system calls (or handlerContext) Each invocation is modeled by the protocol as a state machine, where state transitions can be caused either by user code or by _Runtime events_. @@ -37,7 +37,7 @@ sequenceDiagram Note over Runtime,SDK: Replaying Runtime->>SDK: [...]EntryMessage(s) Note over Runtime,SDK: Processing - SDK->>Runtime: HTTP Response headers + SDK->>Runtime: HTTP Response headersAccessor loop SDK->>Runtime: [...]EntryMessage Runtime->>SDK: CompletionMessage and/or EntryAckMessage @@ -78,7 +78,7 @@ There are a couple of properties that we enforce through the design of the proto ### Syscalls Most Restate features, such as interaction with other services, accessing service instance state, and so on, are defined -as _Restate syscalls_ and exposed through the service protocol. The user interacts with these syscalls using the SDK +as _Restate syscalls_ and exposed through the service protocol. The user interacts with these handlerContext using the SDK APIs, which generate _Journal Entry_ messages that will be handled by the invocation state machine. Depending on the specific syscall, the Restate runtime generates as response either: diff --git a/sdk-core/src/test/java/dev/restate/sdk/core/AssertUtils.java b/sdk-core/src/test/java/dev/restate/sdk/core/AssertUtils.java index e144b8bcf..461fbbbee 100644 --- a/sdk-core/src/test/java/dev/restate/sdk/core/AssertUtils.java +++ b/sdk-core/src/test/java/dev/restate/sdk/core/AssertUtils.java @@ -14,18 +14,23 @@ import static org.assertj.core.api.InstanceOfAssertFactories.type; import com.google.protobuf.MessageLite; -import dev.restate.generated.service.protocol.Protocol; -import dev.restate.sdk.common.TerminalException; -import dev.restate.sdk.common.syscalls.ServiceDefinition; -import dev.restate.sdk.core.manifest.EndpointManifestSchema; -import dev.restate.sdk.core.manifest.Handler; -import dev.restate.sdk.core.manifest.Service; -import java.util.Arrays; +import dev.restate.common.Slice; +import dev.restate.sdk.core.generated.manifest.EndpointManifestSchema; +import dev.restate.sdk.core.generated.manifest.Handler; +import dev.restate.sdk.core.generated.manifest.Service; +import dev.restate.sdk.core.generated.protocol.Protocol; +import dev.restate.sdk.core.statemachine.InvocationInput; +import dev.restate.sdk.core.statemachine.MessageDecoder; +import dev.restate.sdk.endpoint.Endpoint; +import dev.restate.sdk.types.TerminalException; +import java.util.ArrayList; import java.util.List; import java.util.Optional; import java.util.function.Consumer; import java.util.stream.Collectors; +import java.util.stream.Stream; import org.assertj.core.api.AbstractObjectAssert; +import org.assertj.core.api.ListAssert; import org.assertj.core.api.ObjectAssert; public class AssertUtils { @@ -48,15 +53,19 @@ public static Consumer exactErrorMessage(Throwable e) { return errorMessage( msg -> assertThat(msg) - .returns(e.toString(), Protocol.ErrorMessage::getMessage) + .returns(e.getMessage(), Protocol.ErrorMessage::getMessage) .returns( - TerminalException.INTERNAL_SERVER_ERROR_CODE, Protocol.ErrorMessage::getCode)); + TerminalException.INTERNAL_SERVER_ERROR_CODE, Protocol.ErrorMessage::getCode) + .extracting(Protocol.ErrorMessage::getDescription, STRING) + .startsWith(e.getClass().getName())); } - public static Consumer errorMessageStartingWith(String str) { + public static Consumer errorDescriptionStartingWith(String str) { return errorMessage( msg -> - assertThat(msg).extracting(Protocol.ErrorMessage::getMessage, STRING).startsWith(str)); + assertThat(msg) + .extracting(Protocol.ErrorMessage::getDescription, STRING) + .startsWith(str)); } public static Consumer protocolExceptionErrorMessage(int code) { @@ -64,28 +73,36 @@ public static Consumer protocolExceptionErrorMessage(int co msg -> assertThat(msg) .returns(code, Protocol.ErrorMessage::getCode) - .extracting(Protocol.ErrorMessage::getMessage, STRING) + .extracting(Protocol.ErrorMessage::getDescription, STRING) .startsWith(ProtocolException.class.getCanonicalName())); } public static EndpointManifestSchemaAssert assertThatDiscovery(Object... services) { + Endpoint.Builder builder = Endpoint.builder(); + for (var svc : services) { + builder.bind(svc); + } + return new EndpointManifestSchemaAssert( new EndpointManifest( EndpointManifestSchema.ProtocolMode.BIDI_STREAM, - Arrays.stream(services) - .map( - svc -> { - if (svc instanceof ServiceDefinition) { - return (ServiceDefinition) svc; - } - - return RestateEndpoint.discoverServiceDefinitionFactory(svc).create(svc); - }), - false) + builder.build().getServiceDefinitions(), + true) .manifest(), EndpointManifestSchemaAssert.class); } + public static ListAssert assertThatDecodingMessages(Slice... slices) { + var messageDecoder = new MessageDecoder(); + Stream.of(slices).forEach(messageDecoder::offer); + + var outputList = new ArrayList(); + while (messageDecoder.isNextAvailable()) { + outputList.add(messageDecoder.next()); + } + return assertThat(outputList); + } + public static class EndpointManifestSchemaAssert extends AbstractObjectAssert { public EndpointManifestSchemaAssert( diff --git a/sdk-core/src/test/java/dev/restate/sdk/core/AsyncResultTestSuite.java b/sdk-core/src/test/java/dev/restate/sdk/core/AsyncResultTestSuite.java new file mode 100644 index 000000000..a2e7a4904 --- /dev/null +++ b/sdk-core/src/test/java/dev/restate/sdk/core/AsyncResultTestSuite.java @@ -0,0 +1,323 @@ +// Copyright (c) 2023 - Restate Software, Inc., Restate GmbH +// +// This file is part of the Restate Java SDK, +// which is released under the MIT license. +// +// You can find a copy of the license in file LICENSE in the root +// directory of this repository or package, or at +// https://github.com/restatedev/sdk-java/blob/main/LICENSE +package dev.restate.sdk.core; + +import static dev.restate.sdk.core.TestDefinitions.*; +import static dev.restate.sdk.core.statemachine.ProtoUtils.*; +import static org.assertj.core.api.Assertions.assertThat; + +import dev.restate.sdk.core.generated.protocol.Protocol; +import dev.restate.sdk.types.TerminalException; +import java.util.function.Supplier; +import java.util.stream.Stream; + +public abstract class AsyncResultTestSuite implements TestSuite { + + protected abstract TestInvocationBuilder reverseAwaitOrder(); + + protected abstract TestInvocationBuilder awaitTwiceTheSameAwaitable(); + + protected abstract TestInvocationBuilder awaitAll(); + + protected abstract TestInvocationBuilder awaitAny(); + + protected abstract TestInvocationBuilder combineAnyWithAll(); + + protected abstract TestInvocationBuilder awaitAnyIndex(); + + protected abstract TestInvocationBuilder awaitOnAlreadyResolvedAwaitables(); + + protected abstract TestInvocationBuilder awaitWithTimeout(); + + protected Stream anyTestDefinitions( + Supplier testInvocation) { + return Stream.of( + testInvocation + .get() + .withInput(startMessage(1), inputCmd()) + .expectingOutput( + callCmd(1, 2, GREETER_SERVICE_TARGET, "Francesco"), + callCmd(3, 4, GREETER_SERVICE_TARGET, "Till"), + suspensionMessage(2, 4)) + .named("No completions will suspend"), + testInvocation + .get() + .withInput( + startMessage(4), + inputCmd(), + callCmd(1, 2, GREETER_SERVICE_TARGET, "Francesco"), + callCmd(3, 4, GREETER_SERVICE_TARGET, "Till"), + callCompletion(4, "TILL")) + .expectingOutput(outputCmd("TILL"), END_MESSAGE) + .named("Only one completion completes any combinator"), + testInvocation + .get() + .withInput( + startMessage(4), + inputCmd(), + callCmd(1, 2, GREETER_SERVICE_TARGET, "Francesco"), + callCmd(3, 4, GREETER_SERVICE_TARGET, "Till"), + callCompletion(4, new TerminalException("My error"))) + .expectingOutput(outputCmd(new TerminalException("My error")), END_MESSAGE) + .named("Only one failure completes any combinator"), + testInvocation + .get() + .withInput( + startMessage(5), + inputCmd(), + callCmd(1, 2, GREETER_SERVICE_TARGET, "Francesco"), + callCompletion(2, "FRANCESCO"), + callCmd(3, 4, GREETER_SERVICE_TARGET, "Till"), + callCompletion(4, "TILL")) + .assertingOutput( + msgs -> { + assertThat(msgs).hasSize(2); + + assertThat(msgs).element(0).isIn(outputCmd("FRANCESCO"), outputCmd("TILL")); + assertThat(msgs).element(1).isEqualTo(END_MESSAGE); + }) + .named("Everything completed completes the any combinator"), + testInvocation + .get() + .withInput(startMessage(1), inputCmd(), callCompletion(2, "FRANCESCO")) + .onlyBidiStream() + .expectingOutput( + callCmd(1, 2, GREETER_SERVICE_TARGET, "Francesco"), + callCmd(3, 4, GREETER_SERVICE_TARGET, "Till"), + outputCmd("FRANCESCO"), + END_MESSAGE) + .named("Complete any asynchronously")); + } + + @Override + public Stream definitions() { + return Stream.concat( + // --- Any combinator + anyTestDefinitions(this::awaitAny), + Stream.of( + // --- Reverse await order + this.reverseAwaitOrder() + .withInput(startMessage(1), inputCmd()) + .expectingOutput( + callCmd(1, 2, GREETER_SERVICE_TARGET, "Francesco"), + callCmd(3, 4, GREETER_SERVICE_TARGET, "Till"), + suspensionMessage(4)) + .named("None completed"), + this.reverseAwaitOrder() + .withInput( + startMessage(1), + inputCmd(), + callCompletion(2, "FRANCESCO"), + callCompletion(4, "TILL")) + .onlyBidiStream() + .expectingOutput( + callCmd(1, 2, GREETER_SERVICE_TARGET, "Francesco"), + callCmd(3, 4, GREETER_SERVICE_TARGET, "Till"), + setStateCmd("A2", "TILL"), + outputCmd("FRANCESCO-TILL"), + END_MESSAGE) + .named("A1 and A2 completed later"), + this.reverseAwaitOrder() + .withInput( + startMessage(1), + inputCmd(), + callCompletion(4, "TILL"), + callCompletion(2, "FRANCESCO")) + .onlyBidiStream() + .expectingOutput( + callCmd(1, 2, GREETER_SERVICE_TARGET, "Francesco"), + callCmd(3, 4, GREETER_SERVICE_TARGET, "Till"), + setStateCmd("A2", "TILL"), + outputCmd("FRANCESCO-TILL"), + END_MESSAGE) + .named("A2 and A1 completed later in reverse order"), + this.reverseAwaitOrder() + .withInput(startMessage(1), inputCmd(), callCompletion(4, "TILL")) + .onlyBidiStream() + .expectingOutput( + callCmd(1, 2, GREETER_SERVICE_TARGET, "Francesco"), + callCmd(3, 4, GREETER_SERVICE_TARGET, "Till"), + setStateCmd("A2", "TILL"), + suspensionMessage(2)) + .named("Only A2 completed"), + this.reverseAwaitOrder() + .withInput(startMessage(1), inputCmd(), callCompletion(2, "FRANCESCO")) + .onlyBidiStream() + .expectingOutput( + callCmd(1, 2, GREETER_SERVICE_TARGET, "Francesco"), + callCmd(3, 4, GREETER_SERVICE_TARGET, "Till"), + suspensionMessage(4)) + .named("Only A1 completed"), + + // --- Await twice the same executable + this.awaitTwiceTheSameAwaitable() + .withInput(startMessage(1), inputCmd(), callCompletion(2, "FRANCESCO")) + .onlyBidiStream() + .expectingOutput( + callCmd(1, 2, GREETER_SERVICE_TARGET, "Francesco"), + outputCmd("FRANCESCO-FRANCESCO"), + END_MESSAGE), + + // --- All combinator + this.awaitAll() + .withInput(startMessage(1), inputCmd()) + .expectingOutput( + callCmd(1, 2, GREETER_SERVICE_TARGET, "Francesco"), + callCmd(3, 4, GREETER_SERVICE_TARGET, "Till"), + suspensionMessage(2, 4)) + .named("No completions will suspend"), + this.awaitAll() + .withInput( + startMessage(4), + inputCmd(), + callCmd(1, 2, GREETER_SERVICE_TARGET, "Francesco"), + callCmd(3, 4, GREETER_SERVICE_TARGET, "Till"), + callCompletion(4, "TILL")) + .expectingOutput(suspensionMessage(2)) + .named("Only one completion will suspend"), + this.awaitAll() + .withInput( + startMessage(3), + inputCmd(), + callCmd(1, 2, GREETER_SERVICE_TARGET, "Francesco"), + callCmd(3, 4, GREETER_SERVICE_TARGET, "Till"), + callCompletion(2, "FRANCESCO"), + callCompletion(4, "TILL")) + .expectingOutput(outputCmd("FRANCESCO-TILL"), END_MESSAGE) + .named("Everything completed completes the all combinator"), + this.awaitAll() + .withInput( + startMessage(1), + inputCmd(), + callCompletion(2, "FRANCESCO"), + callCompletion(4, "TILL")) + .onlyBidiStream() + .expectingOutput( + callCmd(1, 2, GREETER_SERVICE_TARGET, "Francesco"), + callCmd(3, 4, GREETER_SERVICE_TARGET, "Till"), + outputCmd("FRANCESCO-TILL"), + END_MESSAGE) + .named("Complete all asynchronously"), + this.awaitAll() + .withInput( + startMessage(1), + inputCmd(), + callCompletion(2, new IllegalStateException("My error"))) + .onlyBidiStream() + .expectingOutput( + callCmd(1, 2, GREETER_SERVICE_TARGET, "Francesco"), + callCmd(3, 4, GREETER_SERVICE_TARGET, "Till"), + outputCmd(new IllegalStateException("My error")), + END_MESSAGE) + .named("All fails on first failure"), + this.awaitAll() + .withInput( + startMessage(1), + inputCmd(), + callCompletion(2, "FRANCESCO"), + callCompletion(4, new IllegalStateException("My error"))) + .onlyBidiStream() + .expectingOutput( + callCmd(1, 2, GREETER_SERVICE_TARGET, "Francesco"), + callCmd(3, 4, GREETER_SERVICE_TARGET, "Till"), + outputCmd(new IllegalStateException("My error")), + END_MESSAGE) + .named("All fails on second failure"), + + // --- Compose any with all + this.combineAnyWithAll() + .withInput( + startMessage(5), + inputCmd(), + signalNotification(17, "1"), + signalNotification(18, "2"), + signalNotification(19, "3"), + signalNotification(20, "4")) + .expectingOutput(outputCmd("123"), END_MESSAGE), + this.combineAnyWithAll() + .withInput( + startMessage(5), + inputCmd(), + signalNotification(18, "2"), + signalNotification(17, "1"), + signalNotification(20, "4"), + signalNotification(19, "3")) + .expectingOutput(outputCmd("224"), END_MESSAGE) + .named("Inverted order"), + + // --- Await Any with index + this.awaitAnyIndex() + .withInput( + startMessage(5), + inputCmd(), + signalNotification(17, "1"), + signalNotification(18, "2"), + signalNotification(19, "3"), + signalNotification(20, "4")) + .expectingOutput(outputCmd("0"), END_MESSAGE), + this.awaitAnyIndex() + .withInput( + startMessage(5), + inputCmd(), + signalNotification(19, "3"), + signalNotification(18, "2"), + signalNotification(17, "1"), + signalNotification(20, "4")) + .expectingOutput(outputCmd("1"), END_MESSAGE) + .named("Complete all"), + + // --- Compose nested and resolved all should work + this.awaitOnAlreadyResolvedAwaitables() + .withInput( + startMessage(3), + inputCmd(), + signalNotification(17, "1"), + signalNotification(18, "2")) + .expectingOutput(outputCmd("12"), END_MESSAGE), + + // --- Await with timeout + this.awaitWithTimeout() + .withInput(startMessage(1), inputCmd(), callCompletion(2, "FRANCESCO")) + .onlyBidiStream() + .assertingOutput( + messages -> { + assertThat(messages).hasSize(4); + assertThat(messages) + .element(0) + .isEqualTo(callCmd(1, 2, GREETER_SERVICE_TARGET, "Francesco").build()); + assertThat(messages) + .element(1) + .isInstanceOf(Protocol.SleepCommandMessage.class); + assertThat(messages).element(2).isEqualTo(outputCmd("FRANCESCO")); + assertThat(messages).element(3).isEqualTo(END_MESSAGE); + }), + this.awaitWithTimeout() + .withInput( + startMessage(1), + inputCmd(), + Protocol.SleepCompletionNotificationMessage.newBuilder() + .setCompletionId(3) + .setVoid(Protocol.Void.getDefaultInstance()) + .build()) + .onlyBidiStream() + .assertingOutput( + messages -> { + assertThat(messages).hasSize(4); + assertThat(messages) + .element(0) + .isEqualTo(callCmd(1, 2, GREETER_SERVICE_TARGET, "Francesco").build()); + assertThat(messages) + .element(1) + .isInstanceOf(Protocol.SleepCommandMessage.class); + assertThat(messages).element(2).isEqualTo(outputCmd("timeout")); + assertThat(messages).element(3).isEqualTo(END_MESSAGE); + }) + .named("Fires timeout"))); + } +} diff --git a/sdk-core/src/test/java/dev/restate/sdk/core/AwakeableIdTestSuite.java b/sdk-core/src/test/java/dev/restate/sdk/core/AwakeableIdTestSuite.java index 3ea051ae9..6f9408423 100644 --- a/sdk-core/src/test/java/dev/restate/sdk/core/AwakeableIdTestSuite.java +++ b/sdk-core/src/test/java/dev/restate/sdk/core/AwakeableIdTestSuite.java @@ -8,14 +8,9 @@ // https://github.com/restatedev/sdk-java/blob/main/LICENSE package dev.restate.sdk.core; -import static dev.restate.sdk.core.ProtoUtils.inputMessage; -import static org.assertj.core.api.Assertions.assertThat; -import static org.assertj.core.api.InstanceOfAssertFactories.type; +import static dev.restate.sdk.core.statemachine.ProtoUtils.*; import com.google.protobuf.ByteString; -import dev.restate.generated.service.protocol.Protocol; -import dev.restate.generated.service.protocol.Protocol.AwakeableEntryMessage; -import dev.restate.generated.service.protocol.Protocol.StartMessage; import dev.restate.sdk.core.TestDefinitions.TestDefinition; import dev.restate.sdk.core.TestDefinitions.TestSuite; import java.nio.ByteBuffer; @@ -35,31 +30,17 @@ public Stream definitions() { ByteBuffer expectedAwakeableId = ByteBuffer.allocate(serializedId.length + 4); expectedAwakeableId.put(serializedId); - expectedAwakeableId.putInt(1); + expectedAwakeableId.putInt(17); expectedAwakeableId.flip(); String base64ExpectedAwakeableId = - Entries.AWAKEABLE_IDENTIFIER_PREFIX - + Base64.getUrlEncoder().encodeToString(expectedAwakeableId.array()); + "sign_1" + Base64.getUrlEncoder().encodeToString(expectedAwakeableId.array()); return Stream.of( returnAwakeableId() .withInput( - StartMessage.newBuilder() - .setDebugId(debugId) - .setId(ByteString.copyFrom(serializedId)) - .setKnownEntries(1), - inputMessage()) - .assertingOutput( - messages -> { - assertThat(messages).element(0).isInstanceOf(AwakeableEntryMessage.class); - assertThat(messages) - .element(1) - .asInstanceOf(type(Protocol.OutputEntryMessage.class)) - .extracting( - out -> - TestSerdes.STRING.deserialize(out.getValue().asReadOnlyByteBuffer())) - .isEqualTo(base64ExpectedAwakeableId); - })); + startMessage(1).setDebugId(debugId).setId(ByteString.copyFrom(serializedId)), + inputCmd()) + .expectingOutput(outputCmd(base64ExpectedAwakeableId), END_MESSAGE)); } private byte[] serializeUUID(UUID uuid) { diff --git a/sdk-core/src/test/java/dev/restate/sdk/core/CallTestSuite.java b/sdk-core/src/test/java/dev/restate/sdk/core/CallTestSuite.java new file mode 100644 index 000000000..46d4046a9 --- /dev/null +++ b/sdk-core/src/test/java/dev/restate/sdk/core/CallTestSuite.java @@ -0,0 +1,80 @@ +// Copyright (c) 2023 - Restate Software, Inc., Restate GmbH +// +// This file is part of the Restate Java SDK, +// which is released under the MIT license. +// +// You can find a copy of the license in file LICENSE in the root +// directory of this repository or package, or at +// https://github.com/restatedev/sdk-java/blob/main/LICENSE +package dev.restate.sdk.core; + +import static dev.restate.sdk.core.statemachine.ProtoUtils.*; + +import dev.restate.common.Slice; +import dev.restate.common.Target; +import dev.restate.sdk.core.TestDefinitions.TestDefinition; +import dev.restate.sdk.core.TestDefinitions.TestSuite; +import dev.restate.sdk.core.generated.protocol.Protocol; +import dev.restate.sdk.types.TerminalException; +import java.util.Map; +import java.util.stream.Stream; + +public abstract class CallTestSuite implements TestSuite { + + protected abstract TestDefinitions.TestInvocationBuilder oneWayCall( + Target target, String idempotencyKey, Map headers, Slice body); + + protected abstract TestDefinitions.TestInvocationBuilder implicitCancellation( + Target target, Slice body); + + private static String IDEMPOTENCY_KEY = "my-idempotency-key"; + private static Map HEADERS = Map.of("abc", "123", "fge", "456"); + private static Slice BODY = Slice.wrap("bla"); + + @Override + public Stream definitions() { + return Stream.of( + oneWayCall(GREETER_SERVICE_TARGET, IDEMPOTENCY_KEY, HEADERS, BODY) + .withInput(startMessage(1), inputCmd()) + .expectingOutput( + oneWayCallCmd(1, GREETER_SERVICE_TARGET, IDEMPOTENCY_KEY, HEADERS, BODY), + outputCmd(), + END_MESSAGE), + oneWayCall(GREETER_SERVICE_TARGET, IDEMPOTENCY_KEY, HEADERS, BODY) + .withInput( + startMessage(3), + inputCmd(), + oneWayCallCmd(1, GREETER_SERVICE_TARGET, IDEMPOTENCY_KEY, HEADERS, BODY), + callInvocationIdCompletion(1, "abc")) + .expectingOutput(outputCmd(), END_MESSAGE) + .named("With invocation ID completion"), + oneWayCall(GREETER_VIRTUAL_OBJECT_TARGET, IDEMPOTENCY_KEY, HEADERS, BODY) + .withInput(startMessage(1), inputCmd()) + .expectingOutput( + oneWayCallCmd(1, GREETER_VIRTUAL_OBJECT_TARGET, IDEMPOTENCY_KEY, HEADERS, BODY), + outputCmd(), + END_MESSAGE), + implicitCancellation(GREETER_SERVICE_TARGET, BODY) + .withInput( + startMessage(3), + inputCmd(), + callCmd(1, 2, GREETER_SERVICE_TARGET, BODY.toByteArray()), + CANCELLATION_SIGNAL) + .onlyBidiStream() + .expectingOutput(Protocol.SuspensionMessage.newBuilder().addWaitingCompletions(1)) + .named("Suspends on waiting the invocation id"), + implicitCancellation(GREETER_SERVICE_TARGET, BODY) + .withInput( + startMessage(4), + inputCmd(), + callCmd(1, 2, GREETER_SERVICE_TARGET, BODY.toByteArray()), + CANCELLATION_SIGNAL, + callInvocationIdCompletion(1, "my-id")) + .onlyBidiStream() + .expectingOutput( + sendCancelSignal("my-id"), + outputCmd(new TerminalException(TerminalException.CANCELLED_CODE)), + END_MESSAGE) + .named("Surfaces cancellation")); + } +} diff --git a/sdk-core/src/test/java/dev/restate/sdk/core/ComponentDiscoveryHandlerTest.java b/sdk-core/src/test/java/dev/restate/sdk/core/ComponentDiscoveryHandlerTest.java index 5b4ba80ca..8e2f5e799 100644 --- a/sdk-core/src/test/java/dev/restate/sdk/core/ComponentDiscoveryHandlerTest.java +++ b/sdk-core/src/test/java/dev/restate/sdk/core/ComponentDiscoveryHandlerTest.java @@ -10,14 +10,13 @@ import static org.assertj.core.api.Assertions.assertThat; -import dev.restate.sdk.common.HandlerType; -import dev.restate.sdk.common.Serde; -import dev.restate.sdk.common.ServiceType; -import dev.restate.sdk.common.syscalls.HandlerDefinition; -import dev.restate.sdk.common.syscalls.HandlerSpecification; -import dev.restate.sdk.common.syscalls.ServiceDefinition; -import dev.restate.sdk.core.manifest.EndpointManifestSchema; -import dev.restate.sdk.core.manifest.Service; +import dev.restate.sdk.core.generated.manifest.EndpointManifestSchema; +import dev.restate.sdk.core.generated.manifest.Service; +import dev.restate.sdk.endpoint.definition.HandlerDefinition; +import dev.restate.sdk.endpoint.definition.HandlerType; +import dev.restate.sdk.endpoint.definition.ServiceDefinition; +import dev.restate.sdk.endpoint.definition.ServiceType; +import dev.restate.serde.Serde; import java.util.List; import java.util.stream.Stream; import org.junit.jupiter.api.Test; @@ -35,9 +34,7 @@ void handleWithMultipleServices() { ServiceType.SERVICE, List.of( HandlerDefinition.of( - HandlerSpecification.of( - "greet", HandlerType.EXCLUSIVE, Serde.VOID, Serde.VOID), - null)))), + "greet", HandlerType.EXCLUSIVE, Serde.VOID, Serde.VOID, null)))), false); EndpointManifestSchema manifest = deploymentManifest.manifest(); diff --git a/sdk-core/src/test/java/dev/restate/sdk/core/DeferredTestSuite.java b/sdk-core/src/test/java/dev/restate/sdk/core/DeferredTestSuite.java deleted file mode 100644 index f012a4f37..000000000 --- a/sdk-core/src/test/java/dev/restate/sdk/core/DeferredTestSuite.java +++ /dev/null @@ -1,411 +0,0 @@ -// Copyright (c) 2023 - Restate Software, Inc., Restate GmbH -// -// This file is part of the Restate Java SDK, -// which is released under the MIT license. -// -// You can find a copy of the license in file LICENSE in the root -// directory of this repository or package, or at -// https://github.com/restatedev/sdk-java/blob/main/LICENSE -package dev.restate.sdk.core; - -import static dev.restate.sdk.core.ProtoUtils.*; -import static dev.restate.sdk.core.TestDefinitions.*; -import static org.assertj.core.api.Assertions.assertThat; -import static org.assertj.core.api.InstanceOfAssertFactories.list; -import static org.assertj.core.api.InstanceOfAssertFactories.type; - -import dev.restate.generated.sdk.java.Java; -import dev.restate.generated.service.protocol.Protocol; -import dev.restate.generated.service.protocol.Protocol.Empty; -import java.util.function.Supplier; -import java.util.stream.Stream; - -public abstract class DeferredTestSuite implements TestSuite { - - protected abstract TestInvocationBuilder reverseAwaitOrder(); - - protected abstract TestInvocationBuilder awaitTwiceTheSameAwaitable(); - - protected abstract TestInvocationBuilder awaitAll(); - - protected abstract TestInvocationBuilder awaitAny(); - - protected abstract TestInvocationBuilder combineAnyWithAll(); - - protected abstract TestInvocationBuilder awaitAnyIndex(); - - protected abstract TestInvocationBuilder awaitOnAlreadyResolvedAwaitables(); - - protected abstract TestInvocationBuilder awaitWithTimeout(); - - protected Stream anyTestDefinitions( - Supplier testInvocation) { - return Stream.of( - testInvocation - .get() - .withInput(startMessage(1), inputMessage()) - .expectingOutput( - invokeMessage(GREETER_SERVICE_TARGET, "Francesco"), - invokeMessage(GREETER_SERVICE_TARGET, "Till"), - suspensionMessage(1, 2)) - .named("No completions will suspend"), - testInvocation - .get() - .withInput( - startMessage(3), - inputMessage(), - invokeMessage(GREETER_SERVICE_TARGET, "Francesco"), - invokeMessage(GREETER_SERVICE_TARGET, "Till", "TILL"), - ackMessage(3)) - .expectingOutput(combinatorsMessage(2), outputMessage("TILL"), END_MESSAGE) - .named("Only one completion will generate the combinators message"), - testInvocation - .get() - .withInput( - startMessage(3), - inputMessage(), - invokeMessage(GREETER_SERVICE_TARGET, "Francesco"), - invokeMessage(GREETER_SERVICE_TARGET, "Till", "TILL")) - .expectingOutput(combinatorsMessage(2), suspensionMessage(3)) - .named("Completed without ack will suspend"), - testInvocation - .get() - .withInput( - startMessage(3), - inputMessage(), - invokeMessage(GREETER_SERVICE_TARGET, "Francesco"), - invokeMessage(GREETER_SERVICE_TARGET, "Till") - .setFailure(Util.toProtocolFailure(new IllegalStateException("My error"))), - ackMessage(3)) - .expectingOutput( - combinatorsMessage(2), - outputMessage(new IllegalStateException("My error")), - END_MESSAGE) - .named("Only one failure will generate the combinators message"), - testInvocation - .get() - .withInput( - startMessage(3), - inputMessage(), - invokeMessage(GREETER_SERVICE_TARGET, "Francesco", "FRANCESCO"), - invokeMessage(GREETER_SERVICE_TARGET, "Till", "TILL"), - ackMessage(3)) - .assertingOutput( - msgs -> { - assertThat(msgs).hasSize(3); - - assertThat(msgs) - .element(0, type(Java.CombinatorAwaitableEntryMessage.class)) - .extracting( - Java.CombinatorAwaitableEntryMessage::getEntryIndexList, - list(Integer.class)) - .hasSize(1) - .element(0) - .isIn(1, 2); - - assertThat(msgs) - .element(1) - .isIn(outputMessage("FRANCESCO"), outputMessage("TILL")); - assertThat(msgs).element(2).isEqualTo(END_MESSAGE); - }) - .named("Everything completed will generate the combinators message"), - testInvocation - .get() - .withInput( - startMessage(4), - inputMessage(), - invokeMessage(GREETER_SERVICE_TARGET, "Francesco", "FRANCESCO"), - invokeMessage(GREETER_SERVICE_TARGET, "Till", "TILL"), - combinatorsMessage(2)) - .expectingOutput(outputMessage("TILL"), END_MESSAGE) - .named("Replay the combinator"), - testInvocation - .get() - .withInput( - startMessage(1), inputMessage(), completionMessage(1, "FRANCESCO"), ackMessage(3)) - .onlyUnbuffered() - .expectingOutput( - invokeMessage(GREETER_SERVICE_TARGET, "Francesco"), - invokeMessage(GREETER_SERVICE_TARGET, "Till"), - combinatorsMessage(1), - outputMessage("FRANCESCO"), - END_MESSAGE) - .named("Complete any asynchronously")); - } - - @Override - public Stream definitions() { - return Stream.concat( - // --- Any combinator - anyTestDefinitions(this::awaitAny), - Stream.of( - // --- Reverse await order - this.reverseAwaitOrder() - .withInput(startMessage(1), inputMessage()) - .expectingOutput( - invokeMessage(GREETER_SERVICE_TARGET, "Francesco"), - invokeMessage(GREETER_SERVICE_TARGET, "Till"), - suspensionMessage(2)) - .named("None completed"), - this.reverseAwaitOrder() - .withInput( - startMessage(1), - inputMessage(), - completionMessage(1, "FRANCESCO"), - completionMessage(2, "TILL")) - .onlyUnbuffered() - .expectingOutput( - invokeMessage(GREETER_SERVICE_TARGET, "Francesco"), - invokeMessage(GREETER_SERVICE_TARGET, "Till"), - setStateMessage("A2", "TILL"), - outputMessage("FRANCESCO-TILL"), - END_MESSAGE) - .named("A1 and A2 completed later"), - this.reverseAwaitOrder() - .withInput( - startMessage(1), - inputMessage(), - completionMessage(2, "TILL"), - completionMessage(1, "FRANCESCO")) - .onlyUnbuffered() - .expectingOutput( - invokeMessage(GREETER_SERVICE_TARGET, "Francesco"), - invokeMessage(GREETER_SERVICE_TARGET, "Till"), - setStateMessage("A2", "TILL"), - outputMessage("FRANCESCO-TILL"), - END_MESSAGE) - .named("A2 and A1 completed later"), - this.reverseAwaitOrder() - .withInput(startMessage(1), inputMessage(), completionMessage(2, "TILL")) - .onlyUnbuffered() - .expectingOutput( - invokeMessage(GREETER_SERVICE_TARGET, "Francesco"), - invokeMessage(GREETER_SERVICE_TARGET, "Till"), - setStateMessage("A2", "TILL"), - suspensionMessage(1)) - .named("Only A2 completed"), - this.reverseAwaitOrder() - .withInput(startMessage(1), inputMessage(), completionMessage(1, "FRANCESCO")) - .onlyUnbuffered() - .expectingOutput( - invokeMessage(GREETER_SERVICE_TARGET, "Francesco"), - invokeMessage(GREETER_SERVICE_TARGET, "Till"), - suspensionMessage(2)) - .named("Only A1 completed"), - - // --- Await twice the same executable - this.awaitTwiceTheSameAwaitable() - .withInput(startMessage(1), inputMessage(), completionMessage(1, "FRANCESCO")) - .onlyUnbuffered() - .expectingOutput( - invokeMessage(GREETER_SERVICE_TARGET, "Francesco"), - outputMessage("FRANCESCO-FRANCESCO"), - END_MESSAGE), - - // --- All combinator - this.awaitAll() - .withInput(startMessage(1), inputMessage()) - .expectingOutput( - invokeMessage(GREETER_SERVICE_TARGET, "Francesco"), - invokeMessage(GREETER_SERVICE_TARGET, "Till"), - suspensionMessage(1, 2)) - .named("No completions will suspend"), - this.awaitAll() - .withInput( - startMessage(3), - inputMessage(), - invokeMessage(GREETER_SERVICE_TARGET, "Francesco"), - invokeMessage(GREETER_SERVICE_TARGET, "Till", "TILL")) - .expectingOutput(suspensionMessage(1)) - .named("Only one completion will suspend"), - this.awaitAll() - .withInput( - startMessage(3), - inputMessage(), - invokeMessage(GREETER_SERVICE_TARGET, "Francesco", "FRANCESCO"), - invokeMessage(GREETER_SERVICE_TARGET, "Till", "TILL"), - ackMessage(3)) - .assertingOutput( - msgs -> { - assertThat(msgs).hasSize(3); - - assertThat(msgs) - .element(0, type(Java.CombinatorAwaitableEntryMessage.class)) - .extracting( - Java.CombinatorAwaitableEntryMessage::getEntryIndexList, - list(Integer.class)) - .containsExactlyInAnyOrder(1, 2); - - assertThat(msgs).element(1).isEqualTo(outputMessage("FRANCESCO-TILL")); - assertThat(msgs).element(2).isEqualTo(END_MESSAGE); - }) - .named("Everything completed will generate the combinators message"), - this.awaitAll() - .withInput( - startMessage(4), - inputMessage(), - invokeMessage(GREETER_SERVICE_TARGET, "Francesco", "FRANCESCO"), - invokeMessage(GREETER_SERVICE_TARGET, "Till", "TILL"), - combinatorsMessage(1, 2)) - .expectingOutput(outputMessage("FRANCESCO-TILL"), END_MESSAGE) - .named("Replay the combinator"), - this.awaitAll() - .withInput( - startMessage(1), - inputMessage(), - completionMessage(1, "FRANCESCO"), - completionMessage(2, "TILL"), - ackMessage(3)) - .onlyUnbuffered() - .expectingOutput( - invokeMessage(GREETER_SERVICE_TARGET, "Francesco"), - invokeMessage(GREETER_SERVICE_TARGET, "Till"), - combinatorsMessage(1, 2), - outputMessage("FRANCESCO-TILL"), - END_MESSAGE) - .named("Complete all asynchronously"), - this.awaitAll() - .withInput( - startMessage(1), - inputMessage(), - completionMessage(1, new IllegalStateException("My error")), - ackMessage(3)) - .onlyUnbuffered() - .expectingOutput( - invokeMessage(GREETER_SERVICE_TARGET, "Francesco"), - invokeMessage(GREETER_SERVICE_TARGET, "Till"), - combinatorsMessage(1), - outputMessage(new IllegalStateException("My error")), - END_MESSAGE) - .named("All fails on first failure"), - this.awaitAll() - .withInput( - startMessage(1), - inputMessage(), - completionMessage(1, "FRANCESCO"), - completionMessage(2, new IllegalStateException("My error")), - ackMessage(3)) - .onlyUnbuffered() - .expectingOutput( - invokeMessage(GREETER_SERVICE_TARGET, "Francesco"), - invokeMessage(GREETER_SERVICE_TARGET, "Till"), - combinatorsMessage(1, 2), - outputMessage(new IllegalStateException("My error")), - END_MESSAGE) - .named("All fails on second failure"), - - // --- Compose any with all - this.combineAnyWithAll() - .withInput( - startMessage(6), - inputMessage(), - awakeable("1"), - awakeable("2"), - awakeable("3"), - awakeable("4"), - combinatorsMessage(2, 3)) - .expectingOutput(outputMessage("223"), END_MESSAGE), - this.combineAnyWithAll() - .withInput( - startMessage(6), - inputMessage(), - awakeable("1"), - awakeable("2"), - awakeable("3"), - awakeable("4"), - combinatorsMessage(3, 2)) - .expectingOutput(outputMessage("233"), END_MESSAGE) - .named("Inverted order"), - - // --- Await Any with index - this.awaitAnyIndex() - .withInput( - startMessage(6), - inputMessage(), - awakeable("1"), - awakeable("2"), - awakeable("3"), - awakeable("4"), - combinatorsMessage(1)) - .expectingOutput(outputMessage("0"), END_MESSAGE), - this.awaitAnyIndex() - .withInput( - startMessage(6), - inputMessage(), - awakeable("1"), - awakeable("2"), - awakeable("3"), - awakeable("4"), - combinatorsMessage(3, 2)) - .expectingOutput(outputMessage("1"), END_MESSAGE) - .named("Complete all"), - - // --- Compose nested and resolved all should work - this.awaitOnAlreadyResolvedAwaitables() - .withInput( - startMessage(3), - inputMessage(), - awakeable("1"), - awakeable("2"), - ackMessage(3), - ackMessage(4)) - .assertingOutput( - msgs -> { - assertThat(msgs).hasSize(4); - - assertThat(msgs) - .element(0, type(Java.CombinatorAwaitableEntryMessage.class)) - .extracting( - Java.CombinatorAwaitableEntryMessage::getEntryIndexList, - list(Integer.class)) - .containsExactlyInAnyOrder(1, 2); - - assertThat(msgs).element(1).isEqualTo(combinatorsMessage()); - assertThat(msgs).element(2).isEqualTo(outputMessage("12")); - assertThat(msgs).element(3).isEqualTo(END_MESSAGE); - }), - - // --- Await with timeout - this.awaitWithTimeout() - .withInput( - startMessage(1), - inputMessage(), - completionMessage(1, "FRANCESCO"), - ackMessage(3)) - .onlyUnbuffered() - .assertingOutput( - messages -> { - assertThat(messages).hasSize(5); - assertThat(messages) - .element(0) - .isEqualTo(invokeMessage(GREETER_SERVICE_TARGET, "Francesco").build()); - assertThat(messages) - .element(1) - .isInstanceOf(Protocol.SleepEntryMessage.class); - assertThat(messages).element(2).isEqualTo(combinatorsMessage(1)); - assertThat(messages).element(3).isEqualTo(outputMessage("FRANCESCO")); - assertThat(messages).element(4).isEqualTo(END_MESSAGE); - }), - this.awaitWithTimeout() - .withInput( - startMessage(1), - inputMessage(), - completionMessage(2).setEmpty(Empty.getDefaultInstance()), - ackMessage(3)) - .onlyUnbuffered() - .assertingOutput( - messages -> { - assertThat(messages).hasSize(5); - assertThat(messages) - .element(0) - .isEqualTo(invokeMessage(GREETER_SERVICE_TARGET, "Francesco").build()); - assertThat(messages) - .element(1) - .isInstanceOf(Protocol.SleepEntryMessage.class); - assertThat(messages).element(2).isEqualTo(combinatorsMessage(2)); - assertThat(messages).element(3).isEqualTo(outputMessage("timeout")); - assertThat(messages).element(4).isEqualTo(END_MESSAGE); - }) - .named("Fires timeout"))); - } -} diff --git a/sdk-core/src/test/java/dev/restate/sdk/core/EagerStateTestSuite.java b/sdk-core/src/test/java/dev/restate/sdk/core/EagerStateTestSuite.java index 2925db24d..0f3c34ebc 100644 --- a/sdk-core/src/test/java/dev/restate/sdk/core/EagerStateTestSuite.java +++ b/sdk-core/src/test/java/dev/restate/sdk/core/EagerStateTestSuite.java @@ -8,14 +8,12 @@ // https://github.com/restatedev/sdk-java/blob/main/LICENSE package dev.restate.sdk.core; -import static dev.restate.sdk.core.ProtoUtils.*; import static dev.restate.sdk.core.TestDefinitions.*; +import static dev.restate.sdk.core.generated.protocol.Protocol.*; +import static dev.restate.sdk.core.statemachine.ProtoUtils.*; import static org.assertj.core.api.AssertionsForClassTypes.entry; import com.google.protobuf.MessageLite; -import dev.restate.generated.service.protocol.Protocol; -import dev.restate.generated.service.protocol.Protocol.ClearAllStateEntryMessage; -import dev.restate.generated.service.protocol.Protocol.Empty; import java.util.Map; import java.util.stream.Stream; @@ -38,30 +36,29 @@ public abstract class EagerStateTestSuite implements TestSuite { private static final Map.Entry STATE_FRANCESCO = entry("STATE", "Francesco"); private static final Map.Entry ANOTHER_STATE_FRANCESCO = entry("ANOTHER_STATE", "Francesco"); - private static final MessageLite INPUT_TILL = inputMessage("Till"); - private static final MessageLite GET_STATE_FRANCESCO = getStateMessage("STATE", "Francesco"); + private static final MessageLite INPUT_TILL = inputCmd("Till"); + private static final MessageLite GET_STATE_FRANCESCO = getEagerStateCmd("STATE", "Francesco"); private static final MessageLite GET_STATE_FRANCESCO_TILL = - getStateMessage("STATE", "FrancescoTill"); - private static final MessageLite SET_STATE_FRANCESCO_TILL = - setStateMessage("STATE", "FrancescoTill"); - private static final MessageLite OUTPUT_FRANCESCO = outputMessage("Francesco"); - private static final MessageLite OUTPUT_FRANCESCO_TILL = outputMessage("FrancescoTill"); + getEagerStateCmd("STATE", "FrancescoTill"); + private static final MessageLite SET_STATE_FRANCESCO_TILL = setStateCmd("STATE", "FrancescoTill"); + private static final MessageLite OUTPUT_FRANCESCO = outputCmd("Francesco"); + private static final MessageLite OUTPUT_FRANCESCO_TILL = outputCmd("FrancescoTill"); @Override public Stream definitions() { return Stream.of( this.getEmpty() .withInput(startMessage(1).setPartialState(false), INPUT_TILL) - .expectingOutput(getStateEmptyMessage("STATE"), outputMessage("true"), END_MESSAGE) + .expectingOutput(getEagerStateEmptyCmd("STATE"), outputCmd("true"), END_MESSAGE) .named("With complete state"), this.getEmpty() .withInput(startMessage(1).setPartialState(true), INPUT_TILL) - .expectingOutput(getStateMessage("STATE"), suspensionMessage(1)) + .expectingOutput(getLazyStateCmd(1, "STATE"), suspensionMessage(1)) .named("With partial state"), this.getEmpty() .withInput( - startMessage(2).setPartialState(true), INPUT_TILL, getStateEmptyMessage("STATE")) - .expectingOutput(outputMessage("true"), END_MESSAGE) + startMessage(2).setPartialState(true), INPUT_TILL, getEagerStateEmptyCmd("STATE")) + .expectingOutput(outputCmd("true"), END_MESSAGE) .named("Resume with partial state"), this.get() .withInput( @@ -75,7 +72,7 @@ public Stream definitions() { .named("With partial state"), this.get() .withInput(startMessage(1).setPartialState(true), INPUT_TILL) - .expectingOutput(getStateMessage("STATE"), suspensionMessage(1)) + .expectingOutput(getLazyStateCmd(1, "STATE"), suspensionMessage(1)) .named("With partial state without the state entry"), this.getAppendAndGet() .withInput(startMessage(1, "my-greeter", STATE_FRANCESCO), INPUT_TILL) @@ -90,9 +87,10 @@ public Stream definitions() { .withInput( startMessage(1).setPartialState(true), INPUT_TILL, - completionMessage(1, "Francesco")) + getLazyStateCompletion(1, "Francesco")) + .onlyBidiStream() .expectingOutput( - getStateMessage("STATE"), + getLazyStateCmd(1, "STATE"), SET_STATE_FRANCESCO_TILL, GET_STATE_FRANCESCO_TILL, OUTPUT_FRANCESCO_TILL, @@ -102,8 +100,8 @@ public Stream definitions() { .withInput(startMessage(1, "my-greeter", STATE_FRANCESCO), INPUT_TILL) .expectingOutput( GET_STATE_FRANCESCO, - clearStateMessage("STATE"), - getStateEmptyMessage("STATE"), + clearStateCmd("STATE"), + getEagerStateEmptyCmd("STATE"), OUTPUT_FRANCESCO, END_MESSAGE) .named("With state in the state_map"), @@ -111,11 +109,12 @@ public Stream definitions() { .withInput( startMessage(1).setPartialState(true), INPUT_TILL, - completionMessage(1, "Francesco")) + getLazyStateCompletion(1, "Francesco")) + .onlyBidiStream() .expectingOutput( - getStateMessage("STATE"), - clearStateMessage("STATE"), - getStateEmptyMessage("STATE"), + getLazyStateCmd(1, "STATE"), + clearStateCmd("STATE"), + getEagerStateEmptyCmd("STATE"), OUTPUT_FRANCESCO, END_MESSAGE) .named("With partial state on the first get"), @@ -124,9 +123,9 @@ public Stream definitions() { startMessage(1, "my-greeter", STATE_FRANCESCO, ANOTHER_STATE_FRANCESCO), INPUT_TILL) .expectingOutput( GET_STATE_FRANCESCO, - ClearAllStateEntryMessage.getDefaultInstance(), - getStateEmptyMessage("STATE"), - getStateEmptyMessage("ANOTHER_STATE"), + ClearAllStateCommandMessage.getDefaultInstance(), + getEagerStateEmptyCmd("STATE"), + getEagerStateEmptyCmd("ANOTHER_STATE"), OUTPUT_FRANCESCO, END_MESSAGE) .named("With state in the state_map"), @@ -134,12 +133,13 @@ public Stream definitions() { .withInput( startMessage(1).setPartialState(true), INPUT_TILL, - completionMessage(1, STATE_FRANCESCO.getValue())) + getLazyStateCompletion(1, STATE_FRANCESCO.getValue())) + .onlyBidiStream() .expectingOutput( - getStateMessage("STATE"), - ClearAllStateEntryMessage.getDefaultInstance(), - getStateEmptyMessage("STATE"), - getStateEmptyMessage("ANOTHER_STATE"), + getLazyStateCmd(1, "STATE"), + ClearAllStateCommandMessage.getDefaultInstance(), + getEagerStateEmptyCmd("STATE"), + getEagerStateEmptyCmd("ANOTHER_STATE"), OUTPUT_FRANCESCO, END_MESSAGE) .named("With partial state on the first get"), @@ -147,44 +147,42 @@ public Stream definitions() { .withInput( startMessage(1, "my-greeter", STATE_FRANCESCO).setPartialState(true), INPUT_TILL, - completionMessage(1, stateKeys("a", "b"))) + GetLazyStateKeysCompletionNotificationMessage.newBuilder() + .setCompletionId(1) + .setStateKeys(stateKeys("a", "b"))) + .onlyBidiStream() .expectingOutput( - Protocol.GetStateKeysEntryMessage.getDefaultInstance(), - outputMessage("a,b"), + GetLazyStateKeysCommandMessage.newBuilder().setResultCompletionId(1), + outputCmd("a,b"), END_MESSAGE) .named("With partial state"), this.listKeys() .withInput( startMessage(1, "my-greeter", STATE_FRANCESCO).setPartialState(false), INPUT_TILL) .expectingOutput( - Protocol.GetStateKeysEntryMessage.newBuilder() + GetEagerStateKeysCommandMessage.newBuilder() .setValue(stateKeys(STATE_FRANCESCO.getKey())), - outputMessage(STATE_FRANCESCO.getKey()), + outputCmd(STATE_FRANCESCO.getKey()), END_MESSAGE) .named("With complete state"), this.listKeys() .withInput( startMessage(2).setPartialState(true), INPUT_TILL, - Protocol.GetStateKeysEntryMessage.newBuilder().setValue(stateKeys("3", "2", "1"))) - .expectingOutput(outputMessage("3,2,1"), END_MESSAGE) + GetEagerStateKeysCommandMessage.newBuilder().setValue(stateKeys("3", "2", "1"))) + .expectingOutput(outputCmd("3,2,1"), END_MESSAGE) .named("With replayed list"), this.consecutiveGetWithEmpty() - .withInput(startMessage(1).setPartialState(false), inputMessage()) + .withInput(startMessage(1).setPartialState(false), inputCmd()) .expectingOutput( - getStateMessage("key-0").setEmpty(Empty.getDefaultInstance()), - getStateMessage("key-0").setEmpty(Empty.getDefaultInstance()), - outputMessage(), + getEagerStateEmptyCmd("key-0"), + getEagerStateEmptyCmd("key-0"), + outputCmd(), END_MESSAGE), this.consecutiveGetWithEmpty() .withInput( - startMessage(2).setPartialState(false), - inputMessage(), - getStateMessage("key-0").setEmpty(Empty.getDefaultInstance())) - .expectingOutput( - getStateMessage("key-0").setEmpty(Empty.getDefaultInstance()), - outputMessage(), - END_MESSAGE) + startMessage(2).setPartialState(false), inputCmd(), getEagerStateEmptyCmd("key-0")) + .expectingOutput(getEagerStateEmptyCmd("key-0"), outputCmd(), END_MESSAGE) .named("With replay of the first get")); } } diff --git a/sdk-core/src/test/java/dev/restate/sdk/core/InvocationIdTestSuite.java b/sdk-core/src/test/java/dev/restate/sdk/core/InvocationIdTestSuite.java index ccb32ca5e..00bf88d9b 100644 --- a/sdk-core/src/test/java/dev/restate/sdk/core/InvocationIdTestSuite.java +++ b/sdk-core/src/test/java/dev/restate/sdk/core/InvocationIdTestSuite.java @@ -8,11 +8,9 @@ // https://github.com/restatedev/sdk-java/blob/main/LICENSE package dev.restate.sdk.core; -import static dev.restate.sdk.core.ProtoUtils.END_MESSAGE; -import static dev.restate.sdk.core.ProtoUtils.outputMessage; +import static dev.restate.sdk.core.statemachine.ProtoUtils.*; import com.google.protobuf.ByteString; -import dev.restate.generated.service.protocol.Protocol; import dev.restate.sdk.core.TestDefinitions.TestDefinition; import dev.restate.sdk.core.TestDefinitions.TestInvocationBuilder; import dev.restate.sdk.core.TestDefinitions.TestSuite; @@ -29,10 +27,8 @@ public Stream definitions() { return Stream.of( returnInvocationId() - .withInput( - Protocol.StartMessage.newBuilder().setDebugId(debugId).setId(id).setKnownEntries(1), - ProtoUtils.inputMessage()) - .onlyUnbuffered() - .expectingOutput(outputMessage(debugId), END_MESSAGE)); + .withInput(startMessage(1).setDebugId(debugId).setId(id), inputCmd()) + .onlyBidiStream() + .expectingOutput(outputCmd(debugId), END_MESSAGE)); } } diff --git a/sdk-core/src/test/java/dev/restate/sdk/core/MessageDecoderTest.java b/sdk-core/src/test/java/dev/restate/sdk/core/MessageDecoderTest.java deleted file mode 100644 index ed7526db1..000000000 --- a/sdk-core/src/test/java/dev/restate/sdk/core/MessageDecoderTest.java +++ /dev/null @@ -1,71 +0,0 @@ -// Copyright (c) 2023 - Restate Software, Inc., Restate GmbH -// -// This file is part of the Restate Java SDK, -// which is released under the MIT license. -// -// You can find a copy of the license in file LICENSE in the root -// directory of this repository or package, or at -// https://github.com/restatedev/sdk-java/blob/main/LICENSE -package dev.restate.sdk.core; - -import static dev.restate.sdk.core.ProtoUtils.inputMessage; -import static dev.restate.sdk.core.ProtoUtils.startMessage; -import static org.assertj.core.api.Assertions.assertThat; -import static org.assertj.core.api.Assertions.entry; - -import com.google.protobuf.MessageLite; -import io.smallrye.mutiny.Multi; -import io.smallrye.mutiny.helpers.test.AssertSubscriber; -import java.nio.ByteBuffer; -import java.util.List; -import org.junit.jupiter.api.Test; - -public class MessageDecoderTest { - - @Test - void oneMessage() { - AssertSubscriber assertSubscriber = AssertSubscriber.create(1); - - Multi.createFrom() - .item(ProtoUtils.messageToByteString(startMessage(1, "my-key", entry("key", "value")))) - .subscribe(new MessageDecoder(assertSubscriber)); - - assertThat(assertSubscriber.getLastItem().message()) - .isEqualTo(startMessage(1, "my-key", entry("key", "value")).build()); - } - - @Test - void multiMessage() { - AssertSubscriber assertSubscriber = AssertSubscriber.create(Long.MAX_VALUE); - - Multi.createFrom() - .items( - ProtoUtils.messageToByteString(startMessage(1, "my-key", entry("key", "value"))), - ProtoUtils.messageToByteString(inputMessage("my-value"))) - .subscribe(new MessageDecoder(assertSubscriber)); - - assertThat(assertSubscriber.getItems()) - .map(InvocationInput::message) - .containsExactly( - startMessage(1, "my-key", entry("key", "value")).build(), inputMessage("my-value")); - } - - @Test - void multiMessageInSingleBuffer() { - AssertSubscriber assertSubscriber = AssertSubscriber.create(Long.MAX_VALUE); - - List messages = - List.of(startMessage(1, "my-key", entry("key", "value")).build(), inputMessage("my-value")); - ByteBuffer byteBuffer = - ByteBuffer.allocate(messages.stream().mapToInt(MessageEncoder::encodeLength).sum()); - messages.stream().map(ProtoUtils::messageToByteString).forEach(byteBuffer::put); - byteBuffer.flip(); - - Multi.createFrom().item(byteBuffer).subscribe(new MessageDecoder(assertSubscriber)); - - assertThat(assertSubscriber.getItems()) - .map(InvocationInput::message) - .containsExactly( - startMessage(1, "my-key", entry("key", "value")).build(), inputMessage("my-value")); - } -} diff --git a/sdk-core/src/test/java/dev/restate/sdk/core/MessageHeaderTest.java b/sdk-core/src/test/java/dev/restate/sdk/core/MessageHeaderTest.java deleted file mode 100644 index f46cef6f5..000000000 --- a/sdk-core/src/test/java/dev/restate/sdk/core/MessageHeaderTest.java +++ /dev/null @@ -1,27 +0,0 @@ -// Copyright (c) 2023 - Restate Software, Inc., Restate GmbH -// -// This file is part of the Restate Java SDK, -// which is released under the MIT license. -// -// You can find a copy of the license in file LICENSE in the root -// directory of this repository or package, or at -// https://github.com/restatedev/sdk-java/blob/main/LICENSE -package dev.restate.sdk.core; - -import static org.assertj.core.api.Assertions.assertThat; - -import org.junit.jupiter.api.Test; - -public class MessageHeaderTest { - - @Test - void requiresAckFlag() { - assertThat( - new MessageHeader( - MessageType.CallEntryMessage, - MessageHeader.DONE_FLAG | MessageHeader.REQUIRES_ACK_FLAG, - 2) - .encode()) - .isEqualTo(0x0C01_8001_0000_0002L); - } -} diff --git a/sdk-core/src/test/java/dev/restate/sdk/core/MockBidiStream.java b/sdk-core/src/test/java/dev/restate/sdk/core/MockBidiStream.java new file mode 100644 index 000000000..50c540238 --- /dev/null +++ b/sdk-core/src/test/java/dev/restate/sdk/core/MockBidiStream.java @@ -0,0 +1,110 @@ +// Copyright (c) 2023 - Restate Software, Inc., Restate GmbH +// +// This file is part of the Restate Java SDK, +// which is released under the MIT license. +// +// You can find a copy of the license in file LICENSE in the root +// directory of this repository or package, or at +// https://github.com/restatedev/sdk-java/blob/main/LICENSE +package dev.restate.sdk.core; + +import static dev.restate.sdk.core.AssertUtils.assertThatDecodingMessages; + +import com.google.protobuf.MessageLite; +import dev.restate.common.Slice; +import dev.restate.sdk.core.generated.protocol.Protocol; +import dev.restate.sdk.core.statemachine.InvocationInput; +import dev.restate.sdk.core.statemachine.ProtoUtils; +import dev.restate.sdk.endpoint.Endpoint; +import dev.restate.sdk.endpoint.HeadersAccessor; +import dev.restate.sdk.endpoint.definition.ServiceDefinition; +import io.smallrye.mutiny.Multi; +import io.smallrye.mutiny.helpers.test.AssertSubscriber; +import io.smallrye.mutiny.subscription.DemandPacer; +import io.smallrye.mutiny.subscription.FixedDemandPacer; +import java.time.Duration; +import java.util.List; +import java.util.Map; +import java.util.concurrent.Executor; +import java.util.concurrent.Executors; +import org.apache.logging.log4j.ThreadContext; + +public final class MockBidiStream implements TestDefinitions.TestExecutor { + + public static final MockBidiStream INSTANCE = new MockBidiStream(); + + private MockBidiStream() {} + + @Override + public boolean buffered() { + return false; + } + + @Override + public void executeTest(TestDefinitions.TestDefinition definition) { + Executor coreExecutor = Executors.newSingleThreadExecutor(); + + // This test infra supports only services returning one service definition + ServiceDefinition serviceDefinition = definition.getServiceDefinition(); + + // Prepare server + Endpoint.Builder builder = + Endpoint.builder().bind(serviceDefinition, definition.getServiceOptions()); + if (definition.isEnablePreviewContext()) { + builder.enablePreviewContext(); + } + EndpointRequestHandler server = EndpointRequestHandler.forBidiStream(builder.build()); + + // Start invocation + RequestProcessor handler = + server.processorForRequest( + "/" + serviceDefinition.getServiceName() + "/" + definition.getMethod(), + HeadersAccessor.wrap( + Map.of("content-type", ProtoUtils.serviceProtocolContentTypeHeader())), + EndpointRequestHandler.LoggingContextSetter.THREAD_LOCAL_INSTANCE, + coreExecutor); + + // Wire invocation + AssertSubscriber assertSubscriber = AssertSubscriber.create(Long.MAX_VALUE); + + // Wire invocation and start it + Multi.createFrom() + .iterable(definition.getInput()) + .runSubscriptionOn(coreExecutor) + .map(ProtoUtils::invocationInputToByteString) + .map(Slice::wrap) + .paceDemand() + .using(inputPacer(definition.getInput())) + .emitOn(coreExecutor) + .subscribe(handler); + Multi.createFrom() + .publisher(handler) + .runSubscriptionOn(coreExecutor) + .subscribe(assertSubscriber); + + // Check completed + assertSubscriber.awaitCompletion(Duration.ofSeconds(10)); + + // Unwrap messages and decode them + //noinspection unchecked + assertThatDecodingMessages(assertSubscriber.getItems().toArray(Slice[]::new)) + .map(InvocationInput::message) + .satisfies(l -> definition.getOutputAssert().accept((List) l)); + + // Clean logging + ThreadContext.clearAll(); + } + + private DemandPacer inputPacer(List input) { + if (input.get(0).message() instanceof Protocol.StartMessage startMessage) { + int knownEntries = startMessage.getKnownEntries(); + if (knownEntries != input.size() - 1) { + // We're sending a journal to replay plus more stuff, let's pace after the replay ends + return new FixedDemandPacer(knownEntries + 1, Duration.ofMillis(200)); + } + } + // We're only sending a journal to replay, or we're not sending start message, let's just pace + // right in the middle + return new FixedDemandPacer(Math.min(1, input.size() / 2), Duration.ofMillis(100)); + } +} diff --git a/sdk-core/src/test/java/dev/restate/sdk/core/MockMultiThreaded.java b/sdk-core/src/test/java/dev/restate/sdk/core/MockMultiThreaded.java deleted file mode 100644 index 9698b963b..000000000 --- a/sdk-core/src/test/java/dev/restate/sdk/core/MockMultiThreaded.java +++ /dev/null @@ -1,91 +0,0 @@ -// Copyright (c) 2023 - Restate Software, Inc., Restate GmbH -// -// This file is part of the Restate Java SDK, -// which is released under the MIT license. -// -// You can find a copy of the license in file LICENSE in the root -// directory of this repository or package, or at -// https://github.com/restatedev/sdk-java/blob/main/LICENSE -package dev.restate.sdk.core; - -import static org.assertj.core.api.Assertions.assertThat; - -import com.google.protobuf.MessageLite; -import dev.restate.sdk.common.syscalls.ServiceDefinition; -import dev.restate.sdk.core.manifest.EndpointManifestSchema; -import io.smallrye.mutiny.Multi; -import io.smallrye.mutiny.helpers.test.AssertSubscriber; -import java.time.Duration; -import java.util.List; -import java.util.concurrent.Executor; -import java.util.concurrent.Executors; -import org.apache.logging.log4j.ThreadContext; - -public final class MockMultiThreaded implements TestDefinitions.TestExecutor { - - public static final MockMultiThreaded INSTANCE = new MockMultiThreaded(); - - private MockMultiThreaded() {} - - @Override - public boolean buffered() { - return false; - } - - @Override - public void executeTest(TestDefinitions.TestDefinition definition) { - Executor syscallsExecutor = Executors.newSingleThreadExecutor(); - - // This test infra supports only services returning one service definition - ServiceDefinition serviceDefinition = definition.getServiceDefinition(); - - // Prepare server - @SuppressWarnings("unchecked") - RestateEndpoint.Builder builder = - RestateEndpoint.newBuilder(EndpointManifestSchema.ProtocolMode.BIDI_STREAM) - .bind( - (ServiceDefinition) serviceDefinition, - definition.getServiceOptions()); - if (definition.isEnablePreviewContext()) { - builder.enablePreviewContext(); - } - RestateEndpoint server = builder.build(); - - // Start invocation - ResolvedEndpointHandler handler = - server.resolve( - ServiceProtocol.serviceProtocolVersionToHeaderValue( - ServiceProtocol.maxServiceProtocolVersion(definition.isEnablePreviewContext())), - serviceDefinition.getServiceName(), - definition.getMethod(), - k -> null, - io.opentelemetry.context.Context.current(), - RestateEndpoint.LoggingContextSetter.THREAD_LOCAL_INSTANCE, - syscallsExecutor); - - // Wire invocation - AssertSubscriber assertSubscriber = AssertSubscriber.create(Long.MAX_VALUE); - - // Wire invocation and start it - Multi.createFrom() - .iterable(definition.getInput()) - .runSubscriptionOn(syscallsExecutor) - .map(ProtoUtils::invocationInputToByteString) - .subscribe(handler); - Multi.createFrom() - .publisher(handler) - .runSubscriptionOn(syscallsExecutor) - .subscribe(new MessageDecoder(assertSubscriber)); - - // Check completed - assertSubscriber.awaitCompletion(Duration.ofSeconds(1)); - // Unwrap messages - //noinspection unchecked - assertThat(assertSubscriber.getItems()) - .map(InvocationInput::message) - .satisfies(l -> definition.getOutputAssert().accept((List) l)); - - // Clean logging - ThreadContext.clearAll(); - } -} diff --git a/sdk-core/src/test/java/dev/restate/sdk/core/MockRequestResponse.java b/sdk-core/src/test/java/dev/restate/sdk/core/MockRequestResponse.java new file mode 100644 index 000000000..e55e2c83d --- /dev/null +++ b/sdk-core/src/test/java/dev/restate/sdk/core/MockRequestResponse.java @@ -0,0 +1,89 @@ +// Copyright (c) 2023 - Restate Software, Inc., Restate GmbH +// +// This file is part of the Restate Java SDK, +// which is released under the MIT license. +// +// You can find a copy of the license in file LICENSE in the root +// directory of this repository or package, or at +// https://github.com/restatedev/sdk-java/blob/main/LICENSE +package dev.restate.sdk.core; + +import static dev.restate.sdk.core.AssertUtils.assertThatDecodingMessages; + +import com.google.protobuf.MessageLite; +import dev.restate.common.Slice; +import dev.restate.sdk.core.TestDefinitions.TestDefinition; +import dev.restate.sdk.core.TestDefinitions.TestExecutor; +import dev.restate.sdk.core.statemachine.InvocationInput; +import dev.restate.sdk.core.statemachine.ProtoUtils; +import dev.restate.sdk.endpoint.Endpoint; +import dev.restate.sdk.endpoint.HeadersAccessor; +import dev.restate.sdk.endpoint.definition.ServiceDefinition; +import io.smallrye.mutiny.Multi; +import io.smallrye.mutiny.helpers.test.AssertSubscriber; +import java.time.Duration; +import java.util.List; +import java.util.Map; +import java.util.concurrent.Executor; +import java.util.concurrent.Executors; +import org.apache.logging.log4j.ThreadContext; + +public final class MockRequestResponse implements TestExecutor { + + public static final MockRequestResponse INSTANCE = new MockRequestResponse(); + + private MockRequestResponse() {} + + @Override + public boolean buffered() { + return true; + } + + @Override + public void executeTest(TestDefinition definition) { + Executor syscallsExecutor = Executors.newSingleThreadExecutor(); + + ServiceDefinition serviceDefinition = definition.getServiceDefinition(); + + // Prepare server + Endpoint.Builder builder = + Endpoint.builder().bind(serviceDefinition, definition.getServiceOptions()); + if (definition.isEnablePreviewContext()) { + builder.enablePreviewContext(); + } + EndpointRequestHandler server = EndpointRequestHandler.forRequestResponse(builder.build()); + + // Start invocation + RequestProcessor handler = + server.processorForRequest( + "/" + serviceDefinition.getServiceName() + "/" + definition.getMethod(), + HeadersAccessor.wrap( + Map.of("content-type", ProtoUtils.serviceProtocolContentTypeHeader())), + EndpointRequestHandler.LoggingContextSetter.THREAD_LOCAL_INSTANCE, + syscallsExecutor); + + // Wire invocation + AssertSubscriber assertSubscriber = AssertSubscriber.create(Long.MAX_VALUE); + Multi.createFrom() + .iterable(definition.getInput()) + .runSubscriptionOn(syscallsExecutor) + .map(ProtoUtils::invocationInputToByteString) + .map(Slice::wrap) + .subscribe(handler); + Multi.createFrom() + .publisher(handler) + .runSubscriptionOn(syscallsExecutor) + .subscribe(assertSubscriber); + + // Check completed + assertSubscriber.awaitCompletion(Duration.ofSeconds(10000)); + // Unwrap messages and decode them + //noinspection unchecked + assertThatDecodingMessages(assertSubscriber.getItems().toArray(Slice[]::new)) + .map(InvocationInput::message) + .satisfies(l -> definition.getOutputAssert().accept((List) l)); + + // Clean logging + ThreadContext.clearAll(); + } +} diff --git a/sdk-core/src/test/java/dev/restate/sdk/core/MockSingleThread.java b/sdk-core/src/test/java/dev/restate/sdk/core/MockSingleThread.java deleted file mode 100644 index 55d2f3241..000000000 --- a/sdk-core/src/test/java/dev/restate/sdk/core/MockSingleThread.java +++ /dev/null @@ -1,82 +0,0 @@ -// Copyright (c) 2023 - Restate Software, Inc., Restate GmbH -// -// This file is part of the Restate Java SDK, -// which is released under the MIT license. -// -// You can find a copy of the license in file LICENSE in the root -// directory of this repository or package, or at -// https://github.com/restatedev/sdk-java/blob/main/LICENSE -package dev.restate.sdk.core; - -import static org.assertj.core.api.Assertions.assertThat; - -import com.google.protobuf.MessageLite; -import dev.restate.sdk.common.syscalls.ServiceDefinition; -import dev.restate.sdk.core.TestDefinitions.TestDefinition; -import dev.restate.sdk.core.TestDefinitions.TestExecutor; -import dev.restate.sdk.core.manifest.EndpointManifestSchema; -import io.smallrye.mutiny.Multi; -import io.smallrye.mutiny.helpers.test.AssertSubscriber; -import java.time.Duration; -import java.util.List; -import org.apache.logging.log4j.ThreadContext; - -public final class MockSingleThread implements TestExecutor { - - public static final MockSingleThread INSTANCE = new MockSingleThread(); - - private MockSingleThread() {} - - @Override - public boolean buffered() { - return true; - } - - @Override - public void executeTest(TestDefinition definition) { - ServiceDefinition serviceDefinition = definition.getServiceDefinition(); - - // Prepare server - @SuppressWarnings("unchecked") - RestateEndpoint.Builder builder = - RestateEndpoint.newBuilder(EndpointManifestSchema.ProtocolMode.BIDI_STREAM) - .bind( - (ServiceDefinition) serviceDefinition, - definition.getServiceOptions()); - if (definition.isEnablePreviewContext()) { - builder.enablePreviewContext(); - } - RestateEndpoint server = builder.build(); - - // Start invocation - ResolvedEndpointHandler handler = - server.resolve( - ServiceProtocol.serviceProtocolVersionToHeaderValue( - ServiceProtocol.maxServiceProtocolVersion(definition.isEnablePreviewContext())), - serviceDefinition.getServiceName(), - definition.getMethod(), - k -> null, - io.opentelemetry.context.Context.current(), - RestateEndpoint.LoggingContextSetter.THREAD_LOCAL_INSTANCE, - null); - - // Wire invocation - AssertSubscriber assertSubscriber = AssertSubscriber.create(Long.MAX_VALUE); - Multi.createFrom() - .iterable(definition.getInput()) - .map(ProtoUtils::invocationInputToByteString) - .subscribe(handler); - Multi.createFrom().publisher(handler).subscribe(new MessageDecoder(assertSubscriber)); - - // Check completed - assertSubscriber.awaitCompletion(Duration.ofSeconds(1)); - // Unwrap messages - //noinspection unchecked - assertThat(assertSubscriber.getItems()) - .map(InvocationInput::message) - .satisfies(l -> definition.getOutputAssert().accept((List) l)); - - // Clean logging - ThreadContext.clearAll(); - } -} diff --git a/sdk-core/src/test/java/dev/restate/sdk/core/OnlyInputAndOutputTestSuite.java b/sdk-core/src/test/java/dev/restate/sdk/core/OnlyInputAndOutputTestSuite.java index 9f6c988d3..dfcb89072 100644 --- a/sdk-core/src/test/java/dev/restate/sdk/core/OnlyInputAndOutputTestSuite.java +++ b/sdk-core/src/test/java/dev/restate/sdk/core/OnlyInputAndOutputTestSuite.java @@ -8,8 +8,8 @@ // https://github.com/restatedev/sdk-java/blob/main/LICENSE package dev.restate.sdk.core; -import static dev.restate.sdk.core.ProtoUtils.*; import static dev.restate.sdk.core.TestDefinitions.TestDefinition; +import static dev.restate.sdk.core.statemachine.ProtoUtils.*; import dev.restate.sdk.core.TestDefinitions.TestInvocationBuilder; import dev.restate.sdk.core.TestDefinitions.TestSuite; @@ -23,7 +23,7 @@ public abstract class OnlyInputAndOutputTestSuite implements TestSuite { public Stream definitions() { return Stream.of( this.noSyscallsGreeter() - .withInput(startMessage(1), inputMessage("Francesco")) - .expectingOutput(outputMessage("Hello Francesco"), END_MESSAGE)); + .withInput(startMessage(1), inputCmd("Francesco")) + .expectingOutput(outputCmd("Hello Francesco"), END_MESSAGE)); } } diff --git a/sdk-core/src/test/java/dev/restate/sdk/core/PromiseTestSuite.java b/sdk-core/src/test/java/dev/restate/sdk/core/PromiseTestSuite.java index 533c5d47b..6f5b20d32 100644 --- a/sdk-core/src/test/java/dev/restate/sdk/core/PromiseTestSuite.java +++ b/sdk-core/src/test/java/dev/restate/sdk/core/PromiseTestSuite.java @@ -8,11 +8,14 @@ // https://github.com/restatedev/sdk-java/blob/main/LICENSE package dev.restate.sdk.core; -import static dev.restate.sdk.core.ProtoUtils.*; import static dev.restate.sdk.core.TestDefinitions.*; +import static dev.restate.sdk.core.generated.protocol.Protocol.*; +import static dev.restate.sdk.core.statemachine.ProtoUtils.*; -import dev.restate.generated.service.protocol.Protocol; -import dev.restate.sdk.common.TerminalException; +import dev.restate.sdk.core.generated.protocol.Protocol; +import dev.restate.sdk.core.generated.protocol.Protocol.GetPromiseCompletionNotificationMessage; +import dev.restate.sdk.core.statemachine.ProtoUtils; +import dev.restate.sdk.types.TerminalException; import java.util.stream.Stream; public abstract class PromiseTestSuite implements TestSuite { @@ -37,95 +40,131 @@ public Stream definitions() { return Stream.of( // --- Await promise this.awaitPromise(PROMISE_KEY) - .withInput(startMessage(1), inputMessage(), completionMessage(1, "my value")) - .expectingOutput(getPromise(PROMISE_KEY), outputMessage("my value"), END_MESSAGE) + .withInput( + startMessage(1), + inputCmd(), + GetPromiseCompletionNotificationMessage.newBuilder() + .setCompletionId(1) + .setValue(value("my value"))) + .expectingOutput(getPromiseCmd(1, PROMISE_KEY), outputCmd("my value"), END_MESSAGE) .named("Completed with success"), this.awaitPromise(PROMISE_KEY) .withInput( startMessage(1), - inputMessage(), - completionMessage(1, new TerminalException("myerror"))) + inputCmd(), + GetPromiseCompletionNotificationMessage.newBuilder() + .setCompletionId(1) + .setFailure(ProtoUtils.failure(new TerminalException("myerror")))) .expectingOutput( - getPromise(PROMISE_KEY), - outputMessage(new TerminalException("myerror")), + getPromiseCmd(1, PROMISE_KEY), + outputCmd(new TerminalException("myerror")), END_MESSAGE) .named("Completed with failure"), // --- Peek promise this.awaitPeekPromise(PROMISE_KEY, "null") - .withInput(startMessage(1), inputMessage(), completionMessage(1, "my value")) - .expectingOutput(peekPromise(PROMISE_KEY), outputMessage("my value"), END_MESSAGE) + .withInput( + startMessage(1), + inputCmd(), + PeekPromiseCompletionNotificationMessage.newBuilder() + .setCompletionId(1) + .setValue(value("my value"))) + .expectingOutput(peekPromiseCmd(1, PROMISE_KEY), outputCmd("my value"), END_MESSAGE) .named("Completed with success"), this.awaitPeekPromise(PROMISE_KEY, "null") .withInput( startMessage(1), - inputMessage(), - completionMessage(1, new TerminalException("myerror"))) + inputCmd(), + PeekPromiseCompletionNotificationMessage.newBuilder() + .setCompletionId(1) + .setFailure(ProtoUtils.failure(new TerminalException("myerror")))) .expectingOutput( - peekPromise(PROMISE_KEY), - outputMessage(new TerminalException("myerror")), + peekPromiseCmd(1, PROMISE_KEY), + outputCmd(new TerminalException("myerror")), END_MESSAGE) .named("Completed with failure"), this.awaitPeekPromise(PROMISE_KEY, "null") .withInput( startMessage(1), - inputMessage(), - completionMessage(1).setEmpty(Protocol.Empty.getDefaultInstance())) - .expectingOutput(peekPromise(PROMISE_KEY), outputMessage("null"), END_MESSAGE) + inputCmd(), + PeekPromiseCompletionNotificationMessage.newBuilder() + .setCompletionId(1) + .setVoid(Protocol.Void.getDefaultInstance())) + .expectingOutput(peekPromiseCmd(1, PROMISE_KEY), outputCmd("null"), END_MESSAGE) .named("Completed with null"), // --- Promise is completed this.awaitIsPromiseCompleted(PROMISE_KEY) - .withInput(startMessage(1), inputMessage(), completionMessage(1, "my value")) + .withInput( + startMessage(1), + inputCmd(), + PeekPromiseCompletionNotificationMessage.newBuilder() + .setCompletionId(1) + .setValue(value("my value"))) + .onlyBidiStream() .expectingOutput( - peekPromise(PROMISE_KEY), outputMessage(TestSerdes.BOOLEAN, true), END_MESSAGE) + peekPromiseCmd(1, PROMISE_KEY), outputCmd(TestSerdes.BOOLEAN, true), END_MESSAGE) .named("Completed with success"), this.awaitIsPromiseCompleted(PROMISE_KEY) .withInput( startMessage(1), - inputMessage(), - completionMessage(1).setEmpty(Protocol.Empty.getDefaultInstance())) + inputCmd(), + PeekPromiseCompletionNotificationMessage.newBuilder() + .setCompletionId(1) + .setVoid(Protocol.Void.getDefaultInstance())) .expectingOutput( - peekPromise(PROMISE_KEY), outputMessage(TestSerdes.BOOLEAN, false), END_MESSAGE) + peekPromiseCmd(1, PROMISE_KEY), outputCmd(TestSerdes.BOOLEAN, false), END_MESSAGE) .named("Not completed"), // --- Promise resolve this.awaitResolvePromise(PROMISE_KEY, "my val") .withInput( startMessage(1), - inputMessage(), - completionMessage(1).setEmpty(Protocol.Empty.getDefaultInstance())) + inputCmd(), + CompletePromiseCompletionNotificationMessage.newBuilder() + .setCompletionId(1) + .setVoid(Protocol.Void.getDefaultInstance()) + .build()) .expectingOutput( - completePromise(PROMISE_KEY, "my val"), - outputMessage(TestSerdes.BOOLEAN, true), + completePromiseCmd(1, PROMISE_KEY, "my val"), + outputCmd(TestSerdes.BOOLEAN, true), END_MESSAGE) .named("resolve succeeds"), this.awaitResolvePromise(PROMISE_KEY, "my val") .withInput( startMessage(1), - inputMessage(), - completionMessage(1, new TerminalException("cannot write promise"))) + inputCmd(), + CompletePromiseCompletionNotificationMessage.newBuilder() + .setCompletionId(1) + .setFailure(failure(new TerminalException("cannot write promise"))) + .build()) .expectingOutput( - completePromise(PROMISE_KEY, "my val"), - outputMessage(TestSerdes.BOOLEAN, false), + completePromiseCmd(1, PROMISE_KEY, "my val"), + outputCmd(TestSerdes.BOOLEAN, false), END_MESSAGE) .named("resolve fails"), // --- Promise reject this.awaitRejectPromise(PROMISE_KEY, "my failure") .withInput( startMessage(1), - inputMessage(), - completionMessage(1).setEmpty(Protocol.Empty.getDefaultInstance())) + inputCmd(), + CompletePromiseCompletionNotificationMessage.newBuilder() + .setCompletionId(1) + .setVoid(Protocol.Void.getDefaultInstance()) + .build()) .expectingOutput( - completePromise(PROMISE_KEY, new TerminalException("my failure")), - outputMessage(TestSerdes.BOOLEAN, true), + completePromiseCmd(1, PROMISE_KEY, new TerminalException("my failure")), + outputCmd(TestSerdes.BOOLEAN, true), END_MESSAGE) .named("resolve succeeds"), this.awaitRejectPromise(PROMISE_KEY, "my failure") .withInput( startMessage(1), - inputMessage(), - completionMessage(1, new TerminalException("cannot write promise"))) + inputCmd(), + CompletePromiseCompletionNotificationMessage.newBuilder() + .setCompletionId(1) + .setFailure(failure(new TerminalException("cannot write promise"))) + .build()) .expectingOutput( - completePromise(PROMISE_KEY, new TerminalException("my failure")), - outputMessage(TestSerdes.BOOLEAN, false), + completePromiseCmd(1, PROMISE_KEY, new TerminalException("my failure")), + outputCmd(TestSerdes.BOOLEAN, false), END_MESSAGE) .named("resolve fails")); } diff --git a/sdk-core/src/test/java/dev/restate/sdk/core/ProtoUtils.java b/sdk-core/src/test/java/dev/restate/sdk/core/ProtoUtils.java deleted file mode 100644 index cd9ad425c..000000000 --- a/sdk-core/src/test/java/dev/restate/sdk/core/ProtoUtils.java +++ /dev/null @@ -1,332 +0,0 @@ -// Copyright (c) 2023 - Restate Software, Inc., Restate GmbH -// -// This file is part of the Restate Java SDK, -// which is released under the MIT license. -// -// You can find a copy of the license in file LICENSE in the root -// directory of this repository or package, or at -// https://github.com/restatedev/sdk-java/blob/main/LICENSE -package dev.restate.sdk.core; - -import com.google.protobuf.ByteString; -import com.google.protobuf.MessageLite; -import com.google.protobuf.MessageLiteOrBuilder; -import dev.restate.generated.sdk.java.Java; -import dev.restate.generated.service.discovery.Discovery; -import dev.restate.generated.service.protocol.Protocol; -import dev.restate.generated.service.protocol.Protocol.StartMessage.StateEntry; -import dev.restate.sdk.common.Serde; -import dev.restate.sdk.common.Target; -import io.smallrye.mutiny.Multi; -import io.smallrye.mutiny.helpers.test.AssertSubscriber; -import java.nio.ByteBuffer; -import java.util.Arrays; -import java.util.List; -import java.util.Map; -import java.util.stream.Collectors; - -public class ProtoUtils { - - public static String serviceProtocolContentTypeHeader() { - return ServiceProtocol.serviceProtocolVersionToHeaderValue( - ServiceProtocol.MIN_SERVICE_PROTOCOL_VERSION); - } - - public static String serviceProtocolContentTypeHeader(boolean enableContextPreview) { - return ServiceProtocol.serviceProtocolVersionToHeaderValue( - ServiceProtocol.maxServiceProtocolVersion(enableContextPreview)); - } - - public static String serviceProtocolDiscoveryContentTypeHeader() { - return ServiceProtocol.serviceDiscoveryProtocolVersionToHeaderValue( - Discovery.ServiceDiscoveryProtocolVersion.V1); - } - - /** - * Variant of {@link MessageHeader#fromMessage(MessageLite)} supporting StartMessage and - * CompletionMessage. - */ - public static MessageHeader headerFromMessage(MessageLite msg) { - if (msg instanceof Protocol.StartMessage) { - return new MessageHeader(MessageType.StartMessage, 0, msg.getSerializedSize()); - } else if (msg instanceof Protocol.CompletionMessage) { - return new MessageHeader(MessageType.CompletionMessage, (short) 0, msg.getSerializedSize()); - } - return MessageHeader.fromMessage(msg); - } - - public static ByteBuffer invocationInputToByteString(InvocationInput invocationInput) { - ByteBuffer buffer = ByteBuffer.allocate(MessageEncoder.encodeLength(invocationInput.message())); - - buffer.putLong(invocationInput.header().encode()); - buffer.put(invocationInput.message().toByteString().asReadOnlyByteBuffer()); - - buffer.flip(); - return buffer; - } - - public static ByteBuffer messageToByteString(MessageLiteOrBuilder msgOrBuilder) { - var msg = build(msgOrBuilder); - return invocationInputToByteString(InvocationInput.of(headerFromMessage(msg), msg)); - } - - public static List bufferToMessages(List byteBuffers) { - AssertSubscriber subscriber = AssertSubscriber.create(Long.MAX_VALUE); - Multi.createFrom().iterable(byteBuffers).subscribe(new MessageDecoder(subscriber)); - subscriber.awaitCompletion(); - return subscriber.getItems().stream() - .map(InvocationInput::message) - .collect(Collectors.toList()); - } - - public static Protocol.StartMessage.Builder startMessage(int entries) { - return Protocol.StartMessage.newBuilder() - .setId(ByteString.copyFromUtf8("abc")) - .setDebugId("abc") - .setKnownEntries(entries) - .setPartialState(true); - } - - public static Protocol.StartMessage.Builder startMessage(int entries, String key) { - return Protocol.StartMessage.newBuilder() - .setId(ByteString.copyFromUtf8("abc")) - .setDebugId("abc") - .setKnownEntries(entries) - .setKey(key) - .setPartialState(true); - } - - @SafeVarargs - public static Protocol.StartMessage.Builder startMessage( - int entries, String key, Map.Entry... stateEntries) { - return startMessage(entries, key) - .addAllStateMap( - Arrays.stream(stateEntries) - .map( - e -> - StateEntry.newBuilder() - .setKey(ByteString.copyFromUtf8(e.getKey())) - .setValue( - ByteString.copyFrom(TestSerdes.STRING.serialize(e.getValue()))) - .build()) - .collect(Collectors.toList())); - } - - public static Protocol.CompletionMessage.Builder completionMessage(int index) { - return Protocol.CompletionMessage.newBuilder().setEntryIndex(index); - } - - public static Protocol.CompletionMessage completionMessage( - int index, Serde serde, T value) { - return completionMessage(index).setValue(ByteString.copyFrom(serde.serialize(value))).build(); - } - - public static Protocol.CompletionMessage completionMessage(int index, String value) { - return completionMessage(index, TestSerdes.STRING, value); - } - - public static Protocol.CompletionMessage completionMessage( - int index, MessageLiteOrBuilder value) { - return completionMessage(index).setValue(build(value).toByteString()).build(); - } - - public static Protocol.CompletionMessage completionMessage(int index, Throwable e) { - return completionMessage(index).setFailure(Util.toProtocolFailure(e)).build(); - } - - public static Protocol.EntryAckMessage ackMessage(int index) { - return Protocol.EntryAckMessage.newBuilder().setEntryIndex(index).build(); - } - - public static Protocol.SuspensionMessage suspensionMessage(Integer... indexes) { - return Protocol.SuspensionMessage.newBuilder().addAllEntryIndexes(List.of(indexes)).build(); - } - - public static Protocol.InputEntryMessage inputMessage() { - return Protocol.InputEntryMessage.newBuilder().setValue(ByteString.EMPTY).build(); - } - - public static Protocol.InputEntryMessage inputMessage(byte[] value) { - return Protocol.InputEntryMessage.newBuilder().setValue(ByteString.copyFrom(value)).build(); - } - - public static Protocol.InputEntryMessage inputMessage(Serde serde, T value) { - return Protocol.InputEntryMessage.newBuilder() - .setValue(ByteString.copyFrom(serde.serialize(value))) - .build(); - } - - public static Protocol.InputEntryMessage inputMessage(String value) { - return inputMessage(TestSerdes.STRING, value); - } - - public static Protocol.InputEntryMessage inputMessage(int value) { - return inputMessage(TestSerdes.INT, value); - } - - public static Protocol.OutputEntryMessage outputMessage(Serde serde, T value) { - return Protocol.OutputEntryMessage.newBuilder() - .setValue(ByteString.copyFrom(serde.serialize(value))) - .build(); - } - - public static Protocol.OutputEntryMessage outputMessage(String value) { - return outputMessage(TestSerdes.STRING, value); - } - - public static Protocol.OutputEntryMessage outputMessage(int value) { - return outputMessage(TestSerdes.INT, value); - } - - public static Protocol.OutputEntryMessage outputMessage(byte[] b) { - return outputMessage(Serde.RAW, b); - } - - public static Protocol.OutputEntryMessage outputMessage() { - return Protocol.OutputEntryMessage.newBuilder().setValue(ByteString.EMPTY).build(); - } - - public static Protocol.OutputEntryMessage outputMessage(int code, String message) { - return Protocol.OutputEntryMessage.newBuilder() - .setFailure(Util.toProtocolFailure(code, message)) - .build(); - } - - public static Protocol.OutputEntryMessage outputMessage(Throwable e) { - return Protocol.OutputEntryMessage.newBuilder().setFailure(Util.toProtocolFailure(e)).build(); - } - - public static Protocol.GetStateEntryMessage.Builder getStateMessage(String key) { - return Protocol.GetStateEntryMessage.newBuilder().setKey(ByteString.copyFromUtf8(key)); - } - - public static Protocol.GetStateEntryMessage.Builder getStateMessage(String key, Throwable error) { - return getStateMessage(key).setFailure(Util.toProtocolFailure(error)); - } - - public static Protocol.GetStateEntryMessage getStateEmptyMessage(String key) { - return Protocol.GetStateEntryMessage.newBuilder() - .setKey(ByteString.copyFromUtf8(key)) - .setEmpty(Protocol.Empty.getDefaultInstance()) - .build(); - } - - public static Protocol.GetStateEntryMessage getStateMessage( - String key, Serde serde, T value) { - return getStateMessage(key).setValue(ByteString.copyFrom(serde.serialize(value))).build(); - } - - public static Protocol.GetStateEntryMessage getStateMessage(String key, String value) { - return getStateMessage(key, TestSerdes.STRING, value); - } - - public static Protocol.SetStateEntryMessage setStateMessage( - String key, Serde serde, T value) { - return Protocol.SetStateEntryMessage.newBuilder() - .setKey(ByteString.copyFromUtf8(key)) - .setValue(ByteString.copyFrom(serde.serialize(value))) - .build(); - } - - public static Protocol.SetStateEntryMessage setStateMessage(String key, String value) { - return setStateMessage(key, TestSerdes.STRING, value); - } - - public static Protocol.ClearStateEntryMessage clearStateMessage(String key) { - return Protocol.ClearStateEntryMessage.newBuilder() - .setKey(ByteString.copyFromUtf8(key)) - .build(); - } - - public static Protocol.CallEntryMessage.Builder invokeMessage(Target target) { - Protocol.CallEntryMessage.Builder builder = - Protocol.CallEntryMessage.newBuilder() - .setServiceName(target.getService()) - .setHandlerName(target.getHandler()); - if (target.getKey() != null) { - builder.setKey(target.getKey()); - } - - return builder; - } - - public static Protocol.CallEntryMessage.Builder invokeMessage(Target target, byte[] parameter) { - return invokeMessage(target, Serde.RAW, parameter); - } - - public static Protocol.CallEntryMessage.Builder invokeMessage( - Target target, Serde reqSerde, T parameter) { - return invokeMessage(target).setParameter(ByteString.copyFrom(reqSerde.serialize(parameter))); - } - - public static Protocol.CallEntryMessage invokeMessage( - Target target, Serde reqSerde, T parameter, Serde resSerde, R result) { - return invokeMessage(target, reqSerde, parameter) - .setValue(ByteString.copyFrom(resSerde.serialize(result))) - .build(); - } - - public static Protocol.CallEntryMessage.Builder invokeMessage(Target target, String parameter) { - return invokeMessage(target, TestSerdes.STRING, parameter); - } - - public static Protocol.CallEntryMessage invokeMessage( - Target target, String parameter, String result) { - return invokeMessage(target, TestSerdes.STRING, parameter, TestSerdes.STRING, result); - } - - public static Protocol.AwakeableEntryMessage.Builder awakeable() { - return Protocol.AwakeableEntryMessage.newBuilder(); - } - - public static Protocol.AwakeableEntryMessage awakeable(String value) { - return awakeable().setValue(ByteString.copyFrom(TestSerdes.STRING.serialize(value))).build(); - } - - public static Protocol.GetPromiseEntryMessage.Builder getPromise(String key) { - return Protocol.GetPromiseEntryMessage.newBuilder().setKey(key); - } - - public static Protocol.PeekPromiseEntryMessage.Builder peekPromise(String key) { - return Protocol.PeekPromiseEntryMessage.newBuilder().setKey(key); - } - - public static Protocol.CompletePromiseEntryMessage.Builder completePromise( - String key, String value) { - return Protocol.CompletePromiseEntryMessage.newBuilder() - .setKey(key) - .setCompletionValue(ByteString.copyFrom(TestSerdes.STRING.serialize(value))); - } - - public static Protocol.CompletePromiseEntryMessage.Builder completePromise( - String key, Throwable e) { - return Protocol.CompletePromiseEntryMessage.newBuilder() - .setKey(key) - .setCompletionFailure(Util.toProtocolFailure(e)); - } - - public static Java.CombinatorAwaitableEntryMessage combinatorsMessage(Integer... order) { - return Java.CombinatorAwaitableEntryMessage.newBuilder() - .addAllEntryIndex(Arrays.asList(order)) - .build(); - } - - public static final Protocol.EndMessage END_MESSAGE = Protocol.EndMessage.getDefaultInstance(); - - public static final Target GREETER_SERVICE_TARGET = Target.service("Greeter", "greeter"); - public static Target GREETER_VIRTUAL_OBJECT_TARGET = - Target.virtualObject("Greeter", "Francesco", "greeter"); - - public static Protocol.GetStateKeysEntryMessage.StateKeys.Builder stateKeys(String... keys) { - return Protocol.GetStateKeysEntryMessage.StateKeys.newBuilder() - .addAllKeys(Arrays.stream(keys).map(ByteString::copyFromUtf8).collect(Collectors.toList())); - } - - static MessageLite build(MessageLiteOrBuilder value) { - if (value instanceof MessageLite) { - return (MessageLite) value; - } else { - return ((MessageLite.Builder) value).build(); - } - } -} diff --git a/sdk-core/src/test/java/dev/restate/sdk/core/RandomTestSuite.java b/sdk-core/src/test/java/dev/restate/sdk/core/RandomTestSuite.java index 1ce400fde..7f65d7459 100644 --- a/sdk-core/src/test/java/dev/restate/sdk/core/RandomTestSuite.java +++ b/sdk-core/src/test/java/dev/restate/sdk/core/RandomTestSuite.java @@ -9,19 +9,18 @@ package dev.restate.sdk.core; import static dev.restate.sdk.core.AssertUtils.*; -import static dev.restate.sdk.core.ProtoUtils.*; +import static dev.restate.sdk.core.statemachine.ProtoUtils.*; import dev.restate.sdk.core.TestDefinitions.TestDefinition; import dev.restate.sdk.core.TestDefinitions.TestInvocationBuilder; import dev.restate.sdk.core.TestDefinitions.TestSuite; +import dev.restate.sdk.core.statemachine.ProtoUtils; import java.util.stream.Stream; public abstract class RandomTestSuite implements TestSuite { protected abstract TestInvocationBuilder randomShouldBeDeterministic(); - protected abstract TestInvocationBuilder randomInsideSideEffect(); - protected abstract int getExpectedInt(long seed); @Override @@ -30,14 +29,9 @@ public Stream definitions() { return Stream.of( this.randomShouldBeDeterministic() - .withInput(startMessage(1).setDebugId(debugId), ProtoUtils.inputMessage()) + .withInput(startMessage(1).setDebugId(debugId), ProtoUtils.inputCmd()) .expectingOutput( - outputMessage(getExpectedInt(new InvocationIdImpl(debugId).toRandomSeed())), - END_MESSAGE), - this.randomInsideSideEffect() - .withInput(startMessage(1).setDebugId(debugId), ProtoUtils.inputMessage()) - .assertingOutput( - containsOnly( - errorMessageStartingWith(IllegalStateException.class.getCanonicalName())))); + outputCmd(getExpectedInt(ProtoUtils.invocationIdToRandomSeed(debugId))), + END_MESSAGE)); } } diff --git a/sdk-core/src/test/java/dev/restate/sdk/core/SideEffectTestSuite.java b/sdk-core/src/test/java/dev/restate/sdk/core/SideEffectTestSuite.java index 10048b30d..8a6b24cfa 100644 --- a/sdk-core/src/test/java/dev/restate/sdk/core/SideEffectTestSuite.java +++ b/sdk-core/src/test/java/dev/restate/sdk/core/SideEffectTestSuite.java @@ -9,16 +9,15 @@ package dev.restate.sdk.core; import static dev.restate.sdk.core.AssertUtils.*; -import static dev.restate.sdk.core.ProtoUtils.*; import static dev.restate.sdk.core.TestDefinitions.TestInvocationBuilder; +import static dev.restate.sdk.core.statemachine.ProtoUtils.*; import static org.assertj.core.api.Assertions.assertThat; import static org.assertj.core.api.InstanceOfAssertFactories.STRING; -import static org.assertj.core.api.InstanceOfAssertFactories.type; -import com.google.protobuf.ByteString; -import dev.restate.generated.service.protocol.Protocol; -import dev.restate.sdk.common.RetryPolicy; -import dev.restate.sdk.common.TerminalException; +import dev.restate.sdk.core.generated.protocol.Protocol; +import dev.restate.sdk.core.statemachine.MessageType; +import dev.restate.sdk.types.RetryPolicy; +import dev.restate.sdk.types.TerminalException; import java.time.Duration; import java.util.stream.Stream; @@ -32,8 +31,6 @@ public abstract class SideEffectTestSuite implements TestDefinitions.TestSuite { protected abstract TestInvocationBuilder checkContextSwitching(); - protected abstract TestInvocationBuilder sideEffectGuard(); - protected abstract TestInvocationBuilder failingSideEffect(String name, String reason); protected abstract TestInvocationBuilder failingSideEffectWithRetryPolicy( @@ -43,58 +40,59 @@ protected abstract TestInvocationBuilder failingSideEffectWithRetryPolicy( public Stream definitions() { return Stream.of( this.sideEffect("Francesco") - .withInput(startMessage(1), inputMessage("Till")) - .expectingOutput( - Protocol.RunEntryMessage.newBuilder() - .setValue(ByteString.copyFrom(TestSerdes.STRING.serialize("Francesco"))), - suspensionMessage(1)) - .named("Without optimization suspends"), + .withInput(startMessage(1), inputCmd("Till")) + .expectingOutput(runCmd(1), proposeRunCompletion(1, "Francesco"), suspensionMessage(1)) + .named("Run and propose completion"), this.sideEffect("Francesco") - .withInput(startMessage(1), inputMessage("Till"), ackMessage(1)) - .expectingOutput( - Protocol.RunEntryMessage.newBuilder() - .setValue(ByteString.copyFrom(TestSerdes.STRING.serialize("Francesco"))), - outputMessage("Hello Francesco"), - END_MESSAGE) - .named("Without optimization and with acks returns"), + .withInput(startMessage(3), inputCmd("Till"), runCmd(1), runCompletion(1, "Francesco")) + .onlyBidiStream() + .expectingOutput(outputCmd("Hello Francesco"), END_MESSAGE) + .named("Replay from completion"), this.namedSideEffect("get-my-name", "Francesco") - .withInput(startMessage(1), inputMessage("Till")) + .withInput(startMessage(1), inputCmd("Till")) .expectingOutput( - Protocol.RunEntryMessage.newBuilder() - .setName("get-my-name") - .setValue(ByteString.copyFrom(TestSerdes.STRING.serialize("Francesco"))), + runCmd(1, "get-my-name"), + proposeRunCompletion(1, "Francesco"), suspensionMessage(1)), this.consecutiveSideEffect("Francesco") - .withInput(startMessage(1), inputMessage("Till")) - .expectingOutput( - Protocol.RunEntryMessage.newBuilder() - .setValue(ByteString.copyFrom(TestSerdes.STRING.serialize("Francesco"))), - suspensionMessage(1)) - .named("With optimization and without ack on first side effect will suspend"), - this.consecutiveSideEffect("Francesco") - .withInput(startMessage(1), inputMessage("Till"), ackMessage(1)) - .onlyUnbuffered() - .expectingOutput( - Protocol.RunEntryMessage.newBuilder() - .setValue(ByteString.copyFrom(TestSerdes.STRING.serialize("Francesco"))), - Protocol.RunEntryMessage.newBuilder() - .setValue(ByteString.copyFrom(TestSerdes.STRING.serialize("FRANCESCO"))), - suspensionMessage(2)) - .named("With optimization and ack on first side effect will suspend"), + .withInput(startMessage(3), inputCmd("Till"), runCmd(1), runCompletion(1, "Francesco")) + .expectingOutput(runCmd(2), proposeRunCompletion(2, "FRANCESCO"), suspensionMessage(2)) + .named("Suspends on second run"), this.consecutiveSideEffect("Francesco") - .withInput(startMessage(1), inputMessage("Till"), ackMessage(1), ackMessage(2)) - .onlyUnbuffered() - .expectingOutput( - Protocol.RunEntryMessage.newBuilder() - .setValue(ByteString.copyFrom(TestSerdes.STRING.serialize("Francesco"))), - Protocol.RunEntryMessage.newBuilder() - .setValue(ByteString.copyFrom(TestSerdes.STRING.serialize("FRANCESCO"))), - outputMessage("Hello FRANCESCO"), - END_MESSAGE) + .withInput( + startMessage(5), + inputCmd("Till"), + runCmd(1), + runCmd(2), + runCompletion(1, "Francesco"), + runCompletion(2, "FRANCESCO")) + .expectingOutput(outputCmd("Hello FRANCESCO"), END_MESSAGE) .named("With optimization and ack on first and second side effect will resume"), this.failingSideEffect("my-side-effect", "some failure") - .withInput(startMessage(1), inputMessage()) - .onlyUnbuffered() + .withInput(startMessage(1), inputCmd()) + .assertingOutput( + msgs -> + assertThat(msgs) + .satisfiesExactly( + msg -> assertThat(msg).isEqualTo(runCmd(1, "my-side-effect")), + errorMessage( + errorMessage -> + assertThat(errorMessage) + .returns( + TerminalException.INTERNAL_SERVER_ERROR_CODE, + Protocol.ErrorMessage::getCode) + .returns(1, Protocol.ErrorMessage::getRelatedCommandIndex) + .returns( + (int) MessageType.RunCommandMessage.encode(), + Protocol.ErrorMessage::getRelatedCommandType) + .returns( + "my-side-effect", + Protocol.ErrorMessage::getRelatedCommandName) + .extracting(Protocol.ErrorMessage::getMessage, STRING) + .contains("some failure")))) + .named("Fail on first attempt"), + this.failingSideEffect("my-side-effect", "some failure") + .withInput(startMessage(2), inputCmd(), runCmd(1, "my-side-effect")) .assertingOutput( containsOnly( errorMessage( @@ -103,68 +101,53 @@ public Stream definitions() { .returns( TerminalException.INTERNAL_SERVER_ERROR_CODE, Protocol.ErrorMessage::getCode) - .returns(1, Protocol.ErrorMessage::getRelatedEntryIndex) + .returns(1, Protocol.ErrorMessage::getRelatedCommandIndex) .returns( - (int) MessageType.RunEntryMessage.encode(), - Protocol.ErrorMessage::getRelatedEntryType) + (int) MessageType.RunCommandMessage.encode(), + Protocol.ErrorMessage::getRelatedCommandType) .returns( - "my-side-effect", Protocol.ErrorMessage::getRelatedEntryName) + "my-side-effect", Protocol.ErrorMessage::getRelatedCommandName) .extracting(Protocol.ErrorMessage::getMessage, STRING) - .contains("some failure")))), + .contains("some failure")))) + .named("Fail on second attempt"), this.failingSideEffectWithRetryPolicy( "some failure", RetryPolicy.exponential(Duration.ofMillis(100), 1.0f).setMaxAttempts(2)) - .withInput(startMessage(1).setRetryCountSinceLastStoredEntry(0), inputMessage()) - .enablePreviewContext() - .onlyUnbuffered() + .withInput(startMessage(1).setRetryCountSinceLastStoredEntry(0), inputCmd()) + .onlyBidiStream() .assertingOutput( - containsOnly( - errorMessage( - errorMessage -> - assertThat(errorMessage) - .returns( - TerminalException.INTERNAL_SERVER_ERROR_CODE, - Protocol.ErrorMessage::getCode) - .returns(1, Protocol.ErrorMessage::getRelatedEntryIndex) - .returns( - (int) MessageType.RunEntryMessage.encode(), - Protocol.ErrorMessage::getRelatedEntryType) - .returns(100L, Protocol.ErrorMessage::getNextRetryDelay) - .extracting(Protocol.ErrorMessage::getMessage, STRING) - .contains("java.lang.IllegalStateException: some failure")))) + msgs -> + assertThat(msgs) + .satisfiesExactly( + msg -> assertThat(msg).isEqualTo(runCmd(1)), + errorMessage( + errorMessage -> + assertThat(errorMessage) + .returns( + TerminalException.INTERNAL_SERVER_ERROR_CODE, + Protocol.ErrorMessage::getCode) + .returns(1, Protocol.ErrorMessage::getRelatedCommandIndex) + .returns( + (int) MessageType.RunCommandMessage.encode(), + Protocol.ErrorMessage::getRelatedCommandType) + .returns(100L, Protocol.ErrorMessage::getNextRetryDelay) + .extracting(Protocol.ErrorMessage::getMessage, STRING) + .contains("some failure")))) .named("Should fail as retryable error with the attached next retry delay"), this.failingSideEffectWithRetryPolicy( "some failure", RetryPolicy.exponential(Duration.ofMillis(100), 1.0f).setMaxAttempts(2)) - .withInput( - startMessage(1).setRetryCountSinceLastStoredEntry(1), inputMessage(), ackMessage(1)) - .enablePreviewContext() - .onlyUnbuffered() + .withInput(startMessage(2).setRetryCountSinceLastStoredEntry(1), inputCmd(), runCmd(1)) .expectingOutput( - Protocol.RunEntryMessage.newBuilder() - .setFailure( - Util.toProtocolFailure( - 500, "java.lang.IllegalStateException: some failure")), - outputMessage(500, "java.lang.IllegalStateException: some failure"), - END_MESSAGE) + proposeRunCompletion(1, 500, "java.lang.IllegalStateException: some failure"), + suspensionMessage(1)) .named("Should convert retryable error to terminal"), // --- Other tests this.checkContextSwitching() - .withInput(startMessage(1), inputMessage(), ackMessage(1)) - .onlyUnbuffered() - .assertingOutput( - actualOutputMessages -> { - assertThat(actualOutputMessages).hasSize(3); - assertThat(actualOutputMessages) - .element(0) - .asInstanceOf(type(Protocol.RunEntryMessage.class)) - .returns(true, Protocol.RunEntryMessage::hasValue); - assertThat(actualOutputMessages).element(1).isEqualTo(outputMessage("Hello")); - assertThat(actualOutputMessages).element(2).isEqualTo(END_MESSAGE); - }), - this.sideEffectGuard() - .withInput(startMessage(1), inputMessage("Till")) + .withInput(startMessage(1), inputCmd()) + .onlyBidiStream() .assertingOutput( - containsOnlyExactErrorMessage(ProtocolException.invalidSideEffectCall()))); + actualOutputMessages -> + assertThat(actualOutputMessages).element(2).isEqualTo(suspensionMessage(1)))); } } diff --git a/sdk-core/src/test/java/dev/restate/sdk/core/SleepTestSuite.java b/sdk-core/src/test/java/dev/restate/sdk/core/SleepTestSuite.java index efd6b0892..796889cc2 100644 --- a/sdk-core/src/test/java/dev/restate/sdk/core/SleepTestSuite.java +++ b/sdk-core/src/test/java/dev/restate/sdk/core/SleepTestSuite.java @@ -8,16 +8,16 @@ // https://github.com/restatedev/sdk-java/blob/main/LICENSE package dev.restate.sdk.core; -import static dev.restate.sdk.core.ProtoUtils.*; import static dev.restate.sdk.core.TestDefinitions.TestInvocationBuilder; +import static dev.restate.sdk.core.statemachine.ProtoUtils.*; import static org.assertj.core.api.Assertions.assertThat; import static org.assertj.core.api.InstanceOfAssertFactories.LONG; import static org.assertj.core.api.InstanceOfAssertFactories.type; import com.google.protobuf.MessageLiteOrBuilder; -import dev.restate.generated.service.protocol.Protocol; -import dev.restate.sdk.common.TerminalException; +import dev.restate.sdk.core.generated.protocol.Protocol; import java.time.Instant; +import java.util.function.Function; import java.util.stream.IntStream; import java.util.stream.Stream; @@ -33,13 +33,13 @@ public abstract class SleepTestSuite implements TestDefinitions.TestSuite { public Stream definitions() { return Stream.of( this.sleepGreeter() - .withInput(startMessage(1), inputMessage("Till")) + .withInput(startMessage(1), inputCmd("Till")) .assertingOutput( messageLites -> { assertThat(messageLites) .element(0) - .asInstanceOf(type(Protocol.SleepEntryMessage.class)) - .extracting(Protocol.SleepEntryMessage::getWakeUpTime, LONG) + .asInstanceOf(type(Protocol.SleepCommandMessage.class)) + .extracting(Protocol.SleepCommandMessage::getWakeUpTime, LONG) .isGreaterThanOrEqualTo(startTime + 1000) .isLessThanOrEqualTo(Instant.now().toEpochMilli() + 1000); @@ -51,18 +51,20 @@ public Stream definitions() { this.sleepGreeter() .withInput( startMessage(2), - inputMessage("Till"), - Protocol.SleepEntryMessage.newBuilder() + inputCmd("Till"), + Protocol.SleepCommandMessage.newBuilder() .setWakeUpTime(Instant.now().toEpochMilli()) - .setEmpty(Protocol.Empty.getDefaultInstance()) - .build()) - .expectingOutput(outputMessage("Hello"), END_MESSAGE) + .setResultCompletionId(1), + Protocol.SleepCompletionNotificationMessage.newBuilder() + .setCompletionId(1) + .setVoid(Protocol.Void.getDefaultInstance())) + .expectingOutput(outputCmd("Hello"), END_MESSAGE) .named("Sleep 1000 ms sleep completed"), this.sleepGreeter() .withInput( startMessage(2), - inputMessage("Till"), - Protocol.SleepEntryMessage.newBuilder() + inputCmd("Till"), + Protocol.SleepCommandMessage.newBuilder() .setWakeUpTime(Instant.now().toEpochMilli()) .build()) .expectingOutput(suspensionMessage(1)) @@ -70,42 +72,39 @@ public Stream definitions() { this.manySleeps() .withInput( Stream.concat( - Stream.of(startMessage(11), inputMessage("Till")), - IntStream.rangeClosed(1, 10) - .mapToObj( - i -> - (i % 3 == 0) - ? Protocol.SleepEntryMessage.newBuilder() + Stream.of(startMessage(14), inputCmd("Till")), + IntStream.rangeClosed(1, 10) + .mapToObj( + i -> + (i % 3 == 0) + ? Stream.of( + Protocol.SleepCommandMessage.newBuilder() .setWakeUpTime(Instant.now().toEpochMilli()) - .setEmpty(Protocol.Empty.getDefaultInstance()) - .build() - : Protocol.SleepEntryMessage.newBuilder() + .setResultCompletionId(i), + Protocol.SleepCompletionNotificationMessage.newBuilder() + .setCompletionId(i) + .setVoid(Protocol.Void.getDefaultInstance())) + : Stream.of( + Protocol.SleepCommandMessage.newBuilder() .setWakeUpTime(Instant.now().toEpochMilli()) - .build())) - .toArray(MessageLiteOrBuilder[]::new)) + .setResultCompletionId(i))) + .flatMap(Function.identity()))) .expectingOutput(suspensionMessage(1, 2, 4, 5, 7, 8, 10)) .named("Sleep 1000 ms sleep completed"), - this.sleepGreeter() - .withInput( - startMessage(2), - inputMessage("Till"), - Protocol.SleepEntryMessage.newBuilder() - .setWakeUpTime(Instant.now().toEpochMilli()) - .setFailure(Util.toProtocolFailure(409, "canceled")) - .build()) - .expectingOutput(outputMessage(409, "canceled"), END_MESSAGE) - .named("Failed sleep"), this.sleepGreeter() .withInput( startMessage(1), - inputMessage("Till"), - completionMessage(1, new TerminalException(409, "canceled"))) + inputCmd("Till"), + Protocol.SleepCompletionNotificationMessage.newBuilder() + .setCompletionId(1) + .setVoid(Protocol.Void.getDefaultInstance())) + .onlyBidiStream() .assertingOutput( messageLites -> { assertThat(messageLites) .element(0) - .isInstanceOf(Protocol.SleepEntryMessage.class); - assertThat(messageLites).element(1).isEqualTo(outputMessage(409, "canceled")); + .isInstanceOf(Protocol.SleepCommandMessage.class); + assertThat(messageLites).element(1).isEqualTo(outputCmd("Hello")); assertThat(messageLites).element(2).isEqualTo(END_MESSAGE); }) .named("Failing sleep")); diff --git a/sdk-core/src/test/java/dev/restate/sdk/core/StateMachineFailuresTestSuite.java b/sdk-core/src/test/java/dev/restate/sdk/core/StateMachineFailuresTestSuite.java index cbf1beaee..c35806785 100644 --- a/sdk-core/src/test/java/dev/restate/sdk/core/StateMachineFailuresTestSuite.java +++ b/sdk-core/src/test/java/dev/restate/sdk/core/StateMachineFailuresTestSuite.java @@ -8,18 +8,16 @@ // https://github.com/restatedev/sdk-java/blob/main/LICENSE package dev.restate.sdk.core; -import static dev.restate.sdk.core.AssertUtils.errorMessageStartingWith; -import static dev.restate.sdk.core.AssertUtils.protocolExceptionErrorMessage; -import static dev.restate.sdk.core.ProtoUtils.*; +import static dev.restate.sdk.core.AssertUtils.*; import static dev.restate.sdk.core.TestDefinitions.TestInvocationBuilder; +import static dev.restate.sdk.core.statemachine.ProtoUtils.*; import static org.assertj.core.api.Assertions.assertThat; -import dev.restate.generated.service.protocol.Protocol; -import dev.restate.sdk.common.Serde; +import dev.restate.sdk.core.generated.protocol.Protocol; +import dev.restate.serde.Serde; import java.nio.charset.StandardCharsets; import java.util.concurrent.atomic.AtomicInteger; import java.util.stream.Stream; -import org.assertj.core.api.Assertions; public abstract class StateMachineFailuresTestSuite implements TestDefinitions.TestSuite { @@ -48,10 +46,10 @@ public Stream definitions() { return Stream.of( this.getState(nonTerminalExceptionsSeenTest1) - .withInput(startMessage(2), inputMessage("Till"), getStateMessage("Something")) + .withInput(startMessage(2), inputCmd("Till"), getLazyStateCmd(1, "Something")) .assertingOutput( msgs -> { - Assertions.assertThat(msgs) + assertThat(msgs) .satisfiesExactly( protocolExceptionErrorMessage(ProtocolException.JOURNAL_MISMATCH_CODE)); assertThat(nonTerminalExceptionsSeenTest1).hasValue(0); @@ -60,27 +58,39 @@ public Stream definitions() { this.getState(nonTerminalExceptionsSeenTest2) .withInput( startMessage(2), - inputMessage("Till"), - getStateMessage("STATE", "This is not an integer")) + inputCmd("Till"), + getLazyStateCmd(1, "STATE"), + getLazyStateCompletion(1, "This is not an integer")) .assertingOutput( msgs -> { - Assertions.assertThat(msgs) + assertThat(msgs) .satisfiesExactly( - errorMessageStartingWith(NumberFormatException.class.getCanonicalName())); + errorDescriptionStartingWith( + NumberFormatException.class.getCanonicalName())); assertThat(nonTerminalExceptionsSeenTest2).hasValue(0); }) .named("Serde error"), this.sideEffectFailure(FAILING_SERIALIZATION_INTEGER_TYPE_TAG) - .withInput(startMessage(1), inputMessage("Till")) + .withInput(startMessage(1), inputCmd("Till")) .assertingOutput( - AssertUtils.containsOnly( - errorMessageStartingWith(IllegalStateException.class.getCanonicalName()))) + msgs -> + assertThat(msgs.get(1)) + .satisfies( + errorDescriptionStartingWith( + IllegalStateException.class.getCanonicalName()))) .named("Serde serialization error"), this.sideEffectFailure(FAILING_DESERIALIZATION_INTEGER_TYPE_TAG) - .withInput(startMessage(2), inputMessage("Till"), Protocol.RunEntryMessage.newBuilder()) + .withInput( + startMessage(3), + inputCmd("Till"), + runCmd(1), + Protocol.RunCompletionNotificationMessage.newBuilder() + .setCompletionId(1) + .setValue(Protocol.Value.getDefaultInstance()) + .build()) .assertingOutput( - AssertUtils.containsOnly( - errorMessageStartingWith(IllegalStateException.class.getCanonicalName()))) + containsOnly( + errorDescriptionStartingWith(IllegalStateException.class.getCanonicalName()))) .named("Serde deserialization error")); } } diff --git a/sdk-core/src/test/java/dev/restate/sdk/core/StateTestSuite.java b/sdk-core/src/test/java/dev/restate/sdk/core/StateTestSuite.java index bdafc9075..8387941d9 100644 --- a/sdk-core/src/test/java/dev/restate/sdk/core/StateTestSuite.java +++ b/sdk-core/src/test/java/dev/restate/sdk/core/StateTestSuite.java @@ -8,13 +8,13 @@ // https://github.com/restatedev/sdk-java/blob/main/LICENSE package dev.restate.sdk.core; -import static dev.restate.sdk.core.AssertUtils.containsOnlyExactErrorMessage; -import static dev.restate.sdk.core.ProtoUtils.*; -import static org.assertj.core.api.Assertions.assertThat; +import static dev.restate.sdk.core.AssertUtils.*; +import static dev.restate.sdk.core.statemachine.ProtoUtils.*; +import static org.assertj.core.api.AssertionsForClassTypes.assertThat; +import static org.assertj.core.api.InstanceOfAssertFactories.STRING; -import dev.restate.generated.service.protocol.Protocol; -import dev.restate.sdk.common.TerminalException; import dev.restate.sdk.core.TestDefinitions.TestInvocationBuilder; +import dev.restate.sdk.core.generated.protocol.Protocol; import java.util.stream.Stream; public abstract class StateTestSuite implements TestDefinitions.TestSuite { @@ -29,86 +29,79 @@ public abstract class StateTestSuite implements TestDefinitions.TestSuite { public Stream definitions() { return Stream.of( this.getState() - .withInput(startMessage(2), inputMessage("Till"), getStateMessage("STATE", "Francesco")) - .expectingOutput(outputMessage("Hello Francesco"), END_MESSAGE) + .withInput( + startMessage(3), + inputCmd("Till"), + getLazyStateCmd(1, "STATE"), + getLazyStateCompletion(1, "Francesco")) + .expectingOutput(outputCmd("Hello Francesco"), END_MESSAGE) .named("With GetStateEntry already completed"), this.getState() .withInput( - startMessage(2), - inputMessage("Till"), - getStateMessage("STATE").setEmpty(Protocol.Empty.getDefaultInstance())) - .expectingOutput(outputMessage("Hello Unknown"), END_MESSAGE) + startMessage(3), + inputCmd("Till"), + getLazyStateCmd(1, "STATE"), + getLazyStateCompletionEmpty(1)) + .expectingOutput(outputCmd("Hello Unknown"), END_MESSAGE) .named("With GetStateEntry already completed empty"), this.getState() - .withInput(startMessage(1), inputMessage("Till")) - .expectingOutput(getStateMessage("STATE"), suspensionMessage(1)) + .withInput(startMessage(1), inputCmd("Till")) + .expectingOutput(getLazyStateCmd(1, "STATE"), suspensionMessage(1)) .named("Without GetStateEntry"), this.getState() - .withInput(startMessage(2), inputMessage("Till"), getStateMessage("STATE")) + .withInput(startMessage(2), inputCmd("Till"), getLazyStateCmd(1, "STATE")) .expectingOutput(suspensionMessage(1)) .named("With GetStateEntry not completed"), this.getState() .withInput( startMessage(2), - inputMessage("Till"), - getStateMessage("STATE"), - completionMessage(1, "Francesco")) - .onlyUnbuffered() - .expectingOutput(outputMessage("Hello Francesco"), END_MESSAGE) + inputCmd("Till"), + getLazyStateCmd(1, "STATE"), + getLazyStateCompletion(1, "Francesco")) + .onlyBidiStream() + .expectingOutput(outputCmd("Hello Francesco"), END_MESSAGE) .named("With GetStateEntry and completed with later CompletionFrame"), this.getState() - .withInput(startMessage(1), inputMessage("Till"), completionMessage(1, "Francesco")) - .onlyUnbuffered() - .expectingOutput( - getStateMessage("STATE"), outputMessage("Hello Francesco"), END_MESSAGE) + .withInput(startMessage(1), inputCmd("Till"), getLazyStateCompletion(1, "Francesco")) + .onlyBidiStream() + .expectingOutput(getLazyStateCmd(1, "STATE"), outputCmd("Hello Francesco"), END_MESSAGE) .named("Without GetStateEntry and completed with later CompletionFrame"), - this.getState() - .withInput( - startMessage(2), - inputMessage("Till"), - getStateMessage("STATE", new TerminalException(409))) - .expectingOutput(outputMessage(new TerminalException(409)), END_MESSAGE) - .named("Failed GetStateEntry"), - this.getState() - .withInput( - startMessage(1), - inputMessage("Till"), - completionMessage(1, new TerminalException(409))) - .assertingOutput( - messageLites -> { - assertThat(messageLites) - .element(0) - .isInstanceOf(Protocol.GetStateEntryMessage.class); - assertThat(messageLites) - .element(1) - .isEqualTo(outputMessage(new TerminalException(409))); - assertThat(messageLites).element(2).isEqualTo(END_MESSAGE); - }) - .named("Failing GetStateEntry"), this.getAndSetState() .withInput( - startMessage(3), - inputMessage("Till"), - getStateMessage("STATE", "Francesco"), - setStateMessage("STATE", "Till")) - .expectingOutput(outputMessage("Hello Francesco"), END_MESSAGE) + startMessage(4), + inputCmd("Till"), + getLazyStateCmd(1, "STATE"), + getLazyStateCompletion(1, "Francesco"), + setStateCmd("STATE", "Till")) + .expectingOutput(outputCmd("Hello Francesco"), END_MESSAGE) .named("With GetState and SetState"), this.getAndSetState() - .withInput(startMessage(2), inputMessage("Till"), getStateMessage("STATE", "Francesco")) + .withInput( + startMessage(3), + inputCmd("Till"), + getLazyStateCmd(1, "STATE"), + getLazyStateCompletion(1, "Francesco")) .expectingOutput( - setStateMessage("STATE", "Till"), outputMessage("Hello Francesco"), END_MESSAGE) + setStateCmd("STATE", "Till"), outputCmd("Hello Francesco"), END_MESSAGE) .named("With GetState already completed"), this.getAndSetState() - .withInput(startMessage(1), inputMessage("Till"), completionMessage(1, "Francesco")) - .onlyUnbuffered() + .withInput(startMessage(1), inputCmd("Till"), getLazyStateCompletion(1, "Francesco")) + .onlyBidiStream() .expectingOutput( - getStateMessage("STATE"), - setStateMessage("STATE", "Till"), - outputMessage("Hello Francesco"), + getLazyStateCmd(1, "STATE"), + setStateCmd("STATE", "Till"), + outputCmd("Hello Francesco"), END_MESSAGE) .named("With GetState completed later"), this.setNullState() - .withInput(startMessage(1), inputMessage("Till")) - .assertingOutput(containsOnlyExactErrorMessage(new NullPointerException()))); + .withInput(startMessage(1), inputCmd("Till")) + .assertingOutput( + containsOnly( + errorMessage( + errorMessage -> + assertThat(errorMessage) + .extracting(Protocol.ErrorMessage::getDescription, STRING) + .startsWith(NullPointerException.class.getName())))) + .named("Set null state")); } } diff --git a/sdk-core/src/test/java/dev/restate/sdk/core/TestDefinitions.java b/sdk-core/src/test/java/dev/restate/sdk/core/TestDefinitions.java index 131d57dd5..b10fdc265 100644 --- a/sdk-core/src/test/java/dev/restate/sdk/core/TestDefinitions.java +++ b/sdk-core/src/test/java/dev/restate/sdk/core/TestDefinitions.java @@ -8,13 +8,17 @@ // https://github.com/restatedev/sdk-java/blob/main/LICENSE package dev.restate.sdk.core; -import static dev.restate.sdk.core.ProtoUtils.headerFromMessage; import static org.assertj.core.api.Assertions.assertThat; import com.google.protobuf.MessageLite; import com.google.protobuf.MessageLiteOrBuilder; -import dev.restate.generated.service.protocol.Protocol; -import dev.restate.sdk.common.syscalls.ServiceDefinition; +import dev.restate.sdk.core.generated.protocol.Protocol; +import dev.restate.sdk.core.statemachine.InvocationInput; +import dev.restate.sdk.core.statemachine.MessageHeader; +import dev.restate.sdk.core.statemachine.ProtoUtils; +import dev.restate.sdk.endpoint.definition.HandlerRunner; +import dev.restate.sdk.endpoint.definition.ServiceDefinition; +import dev.restate.sdk.endpoint.definition.ServiceDefinitionFactories; import java.util.*; import java.util.function.Consumer; import java.util.function.Supplier; @@ -28,9 +32,9 @@ public final class TestDefinitions { private TestDefinitions() {} public interface TestDefinition { - ServiceDefinition getServiceDefinition(); + ServiceDefinition getServiceDefinition(); - Object getServiceOptions(); + HandlerRunner.Options getServiceOptions(); String getMethod(); @@ -73,17 +77,17 @@ public static TestInvocationBuilder testInvocation(Supplier svcSupplier, public static TestInvocationBuilder testInvocation(Object service, String handler) { if (service instanceof ServiceDefinition) { - return new TestInvocationBuilder((ServiceDefinition) service, null, handler); + return new TestInvocationBuilder((ServiceDefinition) service, null, handler); } // In case it's code generated, discover the adapter - ServiceDefinition serviceDefinition = - RestateEndpoint.discoverServiceDefinitionFactory(service).create(service); + ServiceDefinition serviceDefinition = + ServiceDefinitionFactories.discover(service).create(service, null); return new TestInvocationBuilder(serviceDefinition, null, handler); } - public static TestInvocationBuilder testInvocation( - ServiceDefinition service, O options, String handler) { + public static TestInvocationBuilder testInvocation( + ServiceDefinition service, HandlerRunner.Options options, String handler) { return new TestInvocationBuilder(service, options, handler); } @@ -92,12 +96,13 @@ public static TestInvocationBuilder unsupported(String reason) { } public static class TestInvocationBuilder { - protected final @Nullable ServiceDefinition service; - protected final @Nullable Object options; + protected final @Nullable ServiceDefinition service; + protected final HandlerRunner.@Nullable Options options; protected final @Nullable String handler; protected final @Nullable String invalidReason; - TestInvocationBuilder(ServiceDefinition service, @Nullable Object options, String handler) { + TestInvocationBuilder( + ServiceDefinition service, HandlerRunner.@Nullable Options options, String handler) { this.service = service; this.options = options; this.handler = handler; @@ -113,7 +118,7 @@ public static class TestInvocationBuilder { this.invalidReason = invalidReason; } - public WithInputBuilder withInput(MessageLiteOrBuilder... messages) { + public WithInputBuilder withInput(Stream messages) { if (invalidReason != null) { return new WithInputBuilder(invalidReason); } @@ -122,14 +127,18 @@ public WithInputBuilder withInput(MessageLiteOrBuilder... messages) { service, options, handler, - Arrays.stream(messages) + messages .map( msgOrBuilder -> { MessageLite msg = ProtoUtils.build(msgOrBuilder); - return InvocationInput.of(headerFromMessage(msg), msg); + return InvocationInput.of(MessageHeader.fromMessage(msg), msg); }) .collect(Collectors.toList())); } + + public WithInputBuilder withInput(MessageLiteOrBuilder... messages) { + return withInput(Arrays.stream(messages)); + } } public static class WithInputBuilder extends TestInvocationBuilder { @@ -143,8 +152,8 @@ public static class WithInputBuilder extends TestInvocationBuilder { } WithInputBuilder( - ServiceDefinition service, - @Nullable Object options, + ServiceDefinition service, + HandlerRunner.@Nullable Options options, String method, List input) { super(service, options, method); @@ -159,14 +168,14 @@ public WithInputBuilder withInput(MessageLiteOrBuilder... messages) { .map( msgOrBuilder -> { MessageLite msg = ProtoUtils.build(msgOrBuilder); - return InvocationInput.of(headerFromMessage(msg), msg); + return InvocationInput.of(MessageHeader.fromMessage(msg), msg); }) - .collect(Collectors.toList())); + .toList()); } return this; } - public WithInputBuilder onlyUnbuffered() { + public WithInputBuilder onlyBidiStream() { this.onlyUnbuffered = true; return this; } @@ -183,7 +192,7 @@ public ExpectingOutputMessages expectingOutput(MessageLiteOrBuilder... messages) actual -> assertThat(actual) .asInstanceOf(InstanceOfAssertFactories.LIST) - .isEqualTo(builtMessages)); + .containsExactlyElementsOf(builtMessages)); } public ExpectingOutputMessages assertingOutput(Consumer> messages) { @@ -200,8 +209,8 @@ public ExpectingOutputMessages assertingOutput(Consumer> messa } public abstract static class BaseTestDefinition implements TestDefinition { - protected final @Nullable ServiceDefinition service; - protected final @Nullable Object options; + protected final @Nullable ServiceDefinition service; + protected final HandlerRunner.@Nullable Options options; protected final @Nullable String invalidReason; protected final String method; protected final List input; @@ -210,8 +219,8 @@ public abstract static class BaseTestDefinition implements TestDefinition { protected final String named; private BaseTestDefinition( - @Nullable ServiceDefinition service, - @Nullable Object options, + @Nullable ServiceDefinition service, + HandlerRunner.@Nullable Options options, @Nullable String invalidReason, String method, List input, @@ -229,12 +238,12 @@ private BaseTestDefinition( } @Override - public ServiceDefinition getServiceDefinition() { + public ServiceDefinition getServiceDefinition() { return Objects.requireNonNull(service); } @Override - public Object getServiceOptions() { + public HandlerRunner.Options getServiceOptions() { return options; } @@ -274,8 +283,8 @@ public static class ExpectingOutputMessages extends BaseTestDefinition { private final Consumer> messagesAssert; private ExpectingOutputMessages( - @Nullable ServiceDefinition service, - @Nullable Object options, + @Nullable ServiceDefinition service, + HandlerRunner.@Nullable Options options, @Nullable String invalidReason, String method, List input, @@ -295,8 +304,8 @@ private ExpectingOutputMessages( } ExpectingOutputMessages( - @Nullable ServiceDefinition service, - @Nullable Object options, + @Nullable ServiceDefinition service, + HandlerRunner.@Nullable Options options, @Nullable String invalidReason, String method, List input, diff --git a/sdk-core/src/test/java/dev/restate/sdk/core/TestRunner.java b/sdk-core/src/test/java/dev/restate/sdk/core/TestRunner.java index ea5278158..902002a1f 100644 --- a/sdk-core/src/test/java/dev/restate/sdk/core/TestRunner.java +++ b/sdk-core/src/test/java/dev/restate/sdk/core/TestRunner.java @@ -12,13 +12,17 @@ import static org.assertj.core.api.Assertions.entry; import static org.junit.jupiter.params.provider.Arguments.arguments; +import com.google.protobuf.InvalidProtocolBufferException; import dev.restate.sdk.core.TestDefinitions.TestDefinition; import dev.restate.sdk.core.TestDefinitions.TestExecutor; import dev.restate.sdk.core.TestDefinitions.TestSuite; +import dev.restate.sdk.core.statemachine.MessageType; import java.lang.reflect.Method; +import java.util.Arrays; import java.util.List; -import java.util.stream.Collectors; +import java.util.Objects; import java.util.stream.Stream; +import org.assertj.core.api.Assertions; import org.junit.jupiter.api.TestInstance; import org.junit.jupiter.api.extension.*; import org.junit.jupiter.api.parallel.Execution; @@ -36,7 +40,7 @@ public abstract class TestRunner { protected abstract Stream definitions(); final Stream source() { - List executors = executors().collect(Collectors.toList()); + List executors = executors().toList(); return definitions() .flatMap(ts -> ts.definitions().map(def -> entry(ts.getClass().getName(), def))) @@ -92,6 +96,27 @@ public void interceptTestTemplateMethod( } } + static { + registerMessageFormatters(); + } + + private static void registerMessageFormatters() { + Arrays.stream(MessageType.values()) + .map( + mt -> { + try { + return mt.messageParser().parseFrom(new byte[] {}).getClass(); + } catch (InvalidProtocolBufferException e) { + return null; + } + }) + .filter(Objects::nonNull) + .forEach( + messageClazz -> + Assertions.registerFormatterForType( + messageClazz, ml -> ml.getClass().getSimpleName() + " { " + ml + "}")); + } + @ExtendWith(DisableInvalidTestDefinition.class) @ParameterizedTest(name = "{index}: {0}") @MethodSource("source") diff --git a/sdk-core/src/test/java/dev/restate/sdk/core/TestSerdes.java b/sdk-core/src/test/java/dev/restate/sdk/core/TestSerdes.java index a5623be64..55e187fb2 100644 --- a/sdk-core/src/test/java/dev/restate/sdk/core/TestSerdes.java +++ b/sdk-core/src/test/java/dev/restate/sdk/core/TestSerdes.java @@ -12,9 +12,10 @@ import com.fasterxml.jackson.core.JsonGenerator; import com.fasterxml.jackson.core.JsonParser; import com.fasterxml.jackson.core.JsonToken; -import dev.restate.sdk.common.Serde; -import dev.restate.sdk.common.function.ThrowingBiConsumer; -import dev.restate.sdk.common.function.ThrowingFunction; +import dev.restate.common.Slice; +import dev.restate.common.function.ThrowingBiConsumer; +import dev.restate.common.function.ThrowingFunction; +import dev.restate.serde.Serde; import java.io.ByteArrayInputStream; import java.io.ByteArrayOutputStream; import java.io.IOException; @@ -112,19 +113,19 @@ private static Serde usingJackson( ThrowingFunction deserializer) { return new Serde<>() { @Override - public byte[] serialize(@Nullable T value) { + public Slice serialize(@Nullable T value) { ByteArrayOutputStream outputStream = new ByteArrayOutputStream(); try (JsonGenerator gen = JSON_FACTORY.createGenerator(outputStream)) { serializer.asBiConsumer().accept(gen, value); } catch (IOException e) { throw new RuntimeException("Cannot create JsonGenerator", e); } - return outputStream.toByteArray(); + return Slice.wrap(outputStream.toByteArray()); } @Override - public T deserialize(byte[] value) { - ByteArrayInputStream inputStream = new ByteArrayInputStream(value); + public T deserialize(Slice value) { + ByteArrayInputStream inputStream = new ByteArrayInputStream(value.toByteArray()); try (JsonParser parser = JSON_FACTORY.createParser(inputStream)) { return deserializer.asFunction().apply(parser); } catch (IOException e) { diff --git a/sdk-core/src/test/java/dev/restate/sdk/core/UserFailuresTestSuite.java b/sdk-core/src/test/java/dev/restate/sdk/core/UserFailuresTestSuite.java index beab7bce8..7766fc893 100644 --- a/sdk-core/src/test/java/dev/restate/sdk/core/UserFailuresTestSuite.java +++ b/sdk-core/src/test/java/dev/restate/sdk/core/UserFailuresTestSuite.java @@ -10,12 +10,13 @@ import static dev.restate.sdk.core.AssertUtils.containsOnlyExactErrorMessage; import static dev.restate.sdk.core.AssertUtils.exactErrorMessage; -import static dev.restate.sdk.core.ProtoUtils.*; import static dev.restate.sdk.core.TestDefinitions.*; +import static dev.restate.sdk.core.statemachine.ProtoUtils.*; import static org.assertj.core.api.Assertions.assertThat; -import dev.restate.generated.service.protocol.Protocol; -import dev.restate.sdk.common.TerminalException; +import dev.restate.sdk.core.generated.protocol.Protocol; +import dev.restate.sdk.core.statemachine.ProtoUtils; +import dev.restate.sdk.types.TerminalException; import java.util.concurrent.atomic.AtomicInteger; import java.util.stream.Stream; @@ -42,14 +43,14 @@ public Stream definitions() { return Stream.of( // Cases returning ErrorMessage this.throwIllegalStateException() - .withInput(startMessage(1), inputMessage()) + .withInput(startMessage(1), inputCmd()) .assertingOutput(containsOnlyExactErrorMessage(new IllegalStateException("Whatever"))), this.sideEffectThrowIllegalStateException(nonTerminalExceptionsSeen) - .withInput(startMessage(1), inputMessage()) + .withInput(startMessage(1), inputCmd()) .assertingOutput( msgs -> { - assertThat(msgs) - .satisfiesExactly(exactErrorMessage(new IllegalStateException("Whatever"))); + assertThat(msgs.get(1)) + .satisfies(exactErrorMessage(new IllegalStateException("Whatever"))); // Check the counter has not been incremented assertThat(nonTerminalExceptionsSeen).hasValue(0); @@ -57,32 +58,47 @@ public Stream definitions() { // Cases completing the invocation with OutputStreamEntry.failure this.throwTerminalException(TerminalException.INTERNAL_SERVER_ERROR_CODE, MY_ERROR) - .withInput(startMessage(1), inputMessage()) + .withInput(startMessage(1), inputCmd()) .expectingOutput( - outputMessage(TerminalException.INTERNAL_SERVER_ERROR_CODE, MY_ERROR), END_MESSAGE) + outputCmd(TerminalException.INTERNAL_SERVER_ERROR_CODE, MY_ERROR), END_MESSAGE) .named("With internal error"), this.throwTerminalException(501, WHATEVER) - .withInput(startMessage(1), inputMessage()) - .expectingOutput(outputMessage(501, WHATEVER), END_MESSAGE) + .withInput(startMessage(1), inputCmd()) + .expectingOutput(outputCmd(501, WHATEVER), END_MESSAGE) .named("With unknown error"), this.sideEffectThrowTerminalException( TerminalException.INTERNAL_SERVER_ERROR_CODE, MY_ERROR) - .withInput(startMessage(1), inputMessage(), ackMessage(1)) + .withInput(startMessage(1), inputCmd()) .expectingOutput( - Protocol.RunEntryMessage.newBuilder() + Protocol.RunCommandMessage.newBuilder().setResultCompletionId(1), + Protocol.ProposeRunCompletionMessage.newBuilder() + .setResultCompletionId(1) .setFailure( - Util.toProtocolFailure( - TerminalException.INTERNAL_SERVER_ERROR_CODE, MY_ERROR)), - outputMessage(TerminalException.INTERNAL_SERVER_ERROR_CODE, MY_ERROR), - END_MESSAGE) + ProtoUtils.failure(TerminalException.INTERNAL_SERVER_ERROR_CODE, MY_ERROR)), + suspensionMessage(1)) .named("With internal error"), this.sideEffectThrowTerminalException(501, WHATEVER) - .withInput(startMessage(1), inputMessage(), ackMessage(1)) + .withInput(startMessage(1), inputCmd()) .expectingOutput( - Protocol.RunEntryMessage.newBuilder() - .setFailure(Util.toProtocolFailure(501, WHATEVER)), - outputMessage(501, WHATEVER), - END_MESSAGE) - .named("With unknown error")); + Protocol.RunCommandMessage.newBuilder().setResultCompletionId(1), + Protocol.ProposeRunCompletionMessage.newBuilder() + .setResultCompletionId(1) + .setFailure(ProtoUtils.failure(501, WHATEVER)), + suspensionMessage(1)) + .named("With unknown error"), + this.sideEffectThrowTerminalException( + TerminalException.INTERNAL_SERVER_ERROR_CODE, MY_ERROR) + .withInput( + startMessage(3), + inputCmd(), + runCmd(1), + runCompletion(1, TerminalException.INTERNAL_SERVER_ERROR_CODE, MY_ERROR)) + .expectingOutput( + outputCmd(TerminalException.INTERNAL_SERVER_ERROR_CODE, MY_ERROR), END_MESSAGE) + .named("With internal error during replay"), + this.sideEffectThrowTerminalException(501, WHATEVER) + .withInput(startMessage(3), inputCmd(), runCmd(1), runCompletion(1, 501, WHATEVER)) + .expectingOutput(outputCmd(501, WHATEVER), END_MESSAGE) + .named("With unknown error during replay")); } } diff --git a/sdk-api/src/test/java/dev/restate/sdk/DeferredTest.java b/sdk-core/src/test/java/dev/restate/sdk/core/javaapi/AsyncResultTest.java similarity index 51% rename from sdk-api/src/test/java/dev/restate/sdk/DeferredTest.java rename to sdk-core/src/test/java/dev/restate/sdk/core/javaapi/AsyncResultTest.java index 1cacafd71..12069086f 100644 --- a/sdk-api/src/test/java/dev/restate/sdk/DeferredTest.java +++ b/sdk-core/src/test/java/dev/restate/sdk/core/javaapi/AsyncResultTest.java @@ -6,31 +6,38 @@ // You can find a copy of the license in file LICENSE in the root // directory of this repository or package, or at // https://github.com/restatedev/sdk-java/blob/main/LICENSE -package dev.restate.sdk; +package dev.restate.sdk.core.javaapi; -import static dev.restate.sdk.JavaBlockingTests.*; +import static dev.restate.sdk.core.javaapi.JavaAPITests.*; +import static dev.restate.sdk.core.statemachine.ProtoUtils.*; +import static org.assertj.core.api.Assertions.assertThat; -import dev.restate.sdk.common.Serde; -import dev.restate.sdk.common.StateKey; -import dev.restate.sdk.core.DeferredTestSuite; +import dev.restate.sdk.Awaitable; +import dev.restate.sdk.Select; +import dev.restate.sdk.core.AsyncResultTestSuite; +import dev.restate.sdk.core.TestDefinitions; import dev.restate.sdk.core.TestDefinitions.TestInvocationBuilder; +import dev.restate.sdk.core.TestSerdes; +import dev.restate.sdk.types.StateKey; +import dev.restate.sdk.types.TimeoutException; +import dev.restate.serde.Serde; import java.time.Duration; -import java.util.concurrent.TimeoutException; +import java.util.stream.Stream; -public class DeferredTest extends DeferredTestSuite { +public class AsyncResultTest extends AsyncResultTestSuite { @Override protected TestInvocationBuilder reverseAwaitOrder() { return testDefinitionForVirtualObject( "ReverseAwaitOrder", Serde.VOID, - JsonSerdes.STRING, + TestSerdes.STRING, (context, unused) -> { Awaitable a1 = callGreeterGreetService(context, "Francesco"); Awaitable a2 = callGreeterGreetService(context, "Till"); String a2Res = a2.await(); - context.set(StateKey.of("A2", JsonSerdes.STRING), a2Res); + context.set(StateKey.of("A2", TestSerdes.STRING), a2Res); String a1Res = a1.await(); @@ -43,7 +50,7 @@ protected TestInvocationBuilder awaitTwiceTheSameAwaitable() { return testDefinitionForService( "AwaitTwiceTheSameAwaitable", Serde.VOID, - JsonSerdes.STRING, + TestSerdes.STRING, (context, unused) -> { Awaitable a = callGreeterGreetService(context, "Francesco"); @@ -56,7 +63,7 @@ protected TestInvocationBuilder awaitAll() { return testDefinitionForService( "AwaitAll", Serde.VOID, - JsonSerdes.STRING, + TestSerdes.STRING, (context, unused) -> { Awaitable a1 = callGreeterGreetService(context, "Francesco"); Awaitable a2 = callGreeterGreetService(context, "Till"); @@ -72,12 +79,12 @@ protected TestInvocationBuilder awaitAny() { return testDefinitionForService( "AwaitAny", Serde.VOID, - JsonSerdes.STRING, + TestSerdes.STRING, (context, unused) -> { Awaitable a1 = callGreeterGreetService(context, "Francesco"); Awaitable a2 = callGreeterGreetService(context, "Till"); - return (String) Awaitable.any(a1, a2).await(); + return Select.select().or(a1).or(a2).await(); }); } @@ -86,19 +93,20 @@ protected TestInvocationBuilder combineAnyWithAll() { return testDefinitionForService( "CombineAnyWithAll", Serde.VOID, - JsonSerdes.STRING, + TestSerdes.STRING, (ctx, unused) -> { - Awaitable a1 = ctx.awakeable(JsonSerdes.STRING); - Awaitable a2 = ctx.awakeable(JsonSerdes.STRING); - Awaitable a3 = ctx.awakeable(JsonSerdes.STRING); - Awaitable a4 = ctx.awakeable(JsonSerdes.STRING); - - Awaitable a12 = Awaitable.any(a1, a2); - Awaitable a23 = Awaitable.any(a2, a3); - Awaitable a34 = Awaitable.any(a3, a4); - Awaitable.all(a12, a23, a34).await(); - - return a12.await() + (String) a23.await() + a34.await(); + Awaitable a1 = ctx.awakeable(String.class); + Awaitable a2 = ctx.awakeable(String.class); + Awaitable a3 = ctx.awakeable(String.class); + Awaitable a4 = ctx.awakeable(String.class); + + Awaitable a12 = Select.select().or(a1).or(a2); + Awaitable a23 = Select.select().or(a2).or(a3); + Awaitable a34 = Select.select().or(a3).or(a4); + Awaitable result = + Awaitable.all(a12, a23, a34).map(v -> a12.await() + a23.await() + a34.await()); + + return result.await(); }); } @@ -107,14 +115,14 @@ protected TestInvocationBuilder awaitAnyIndex() { return testDefinitionForService( "AwaitAnyIndex", Serde.VOID, - JsonSerdes.STRING, + TestSerdes.STRING, (ctx, unused) -> { - Awaitable a1 = ctx.awakeable(JsonSerdes.STRING); - Awaitable a2 = ctx.awakeable(JsonSerdes.STRING); - Awaitable a3 = ctx.awakeable(JsonSerdes.STRING); - Awaitable a4 = ctx.awakeable(JsonSerdes.STRING); + Awaitable a1 = ctx.awakeable(String.class); + Awaitable a2 = ctx.awakeable(String.class); + Awaitable a3 = ctx.awakeable(String.class); + Awaitable a4 = ctx.awakeable(String.class); - return String.valueOf(Awaitable.any(a1, Awaitable.all(a2, a3), a4).awaitIndex()); + return String.valueOf(Awaitable.any(a1, Awaitable.all(a2, a3), a4).await()); }); } @@ -123,10 +131,10 @@ protected TestInvocationBuilder awaitOnAlreadyResolvedAwaitables() { return testDefinitionForService( "AwaitOnAlreadyResolvedAwaitables", Serde.VOID, - JsonSerdes.STRING, + TestSerdes.STRING, (ctx, unused) -> { - Awaitable a1 = ctx.awakeable(JsonSerdes.STRING); - Awaitable a2 = ctx.awakeable(JsonSerdes.STRING); + Awaitable a1 = ctx.awakeable(String.class); + Awaitable a2 = ctx.awakeable(String.class); Awaitable a12 = Awaitable.all(a1, a2); Awaitable a12and1 = Awaitable.all(a12, a1); @@ -144,7 +152,7 @@ protected TestInvocationBuilder awaitWithTimeout() { return testDefinitionForService( "AwaitWithTimeout", Serde.VOID, - JsonSerdes.STRING, + TestSerdes.STRING, (ctx, unused) -> { Awaitable call = callGreeterGreetService(ctx, "Francesco"); @@ -158,4 +166,41 @@ protected TestInvocationBuilder awaitWithTimeout() { return result; }); } + + private TestInvocationBuilder checkAwaitableMapThread() { + return testDefinitionForService( + "CheckAwaitableThread", + Serde.VOID, + Serde.VOID, + (ctx, unused) -> { + var currentThreadName = Thread.currentThread().getName().split("-"); + var currentThreadPool = currentThreadName[0] + "-" + currentThreadName[1]; + + callGreeterGreetService(ctx, "Francesco") + .map( + u -> { + assertThat(Thread.currentThread().getName()).startsWith(currentThreadPool); + return null; + }) + .await(); + + return null; + }); + } + + @Override + public Stream definitions() { + return Stream.concat( + super.definitions(), + Stream.of( + this.checkAwaitableMapThread() + .withInput( + startMessage(3), + inputCmd(), + callCmd(1, 2, GREETER_SERVICE_TARGET, "Francesco"), + callCompletion(2, "FRANCESCO")) + .onlyBidiStream() + .expectingOutput(outputCmd(), END_MESSAGE) + .named("Check map constraints"))); + } } diff --git a/sdk-api/src/test/java/dev/restate/sdk/AwakeableIdTest.java b/sdk-core/src/test/java/dev/restate/sdk/core/javaapi/AwakeableIdTest.java similarity index 70% rename from sdk-api/src/test/java/dev/restate/sdk/AwakeableIdTest.java rename to sdk-core/src/test/java/dev/restate/sdk/core/javaapi/AwakeableIdTest.java index 567b3cdc7..3deafeaf4 100644 --- a/sdk-api/src/test/java/dev/restate/sdk/AwakeableIdTest.java +++ b/sdk-core/src/test/java/dev/restate/sdk/core/javaapi/AwakeableIdTest.java @@ -6,13 +6,14 @@ // You can find a copy of the license in file LICENSE in the root // directory of this repository or package, or at // https://github.com/restatedev/sdk-java/blob/main/LICENSE -package dev.restate.sdk; +package dev.restate.sdk.core.javaapi; -import static dev.restate.sdk.JavaBlockingTests.testDefinitionForService; +import static dev.restate.sdk.core.javaapi.JavaAPITests.testDefinitionForService; -import dev.restate.sdk.common.Serde; import dev.restate.sdk.core.AwakeableIdTestSuite; import dev.restate.sdk.core.TestDefinitions.TestInvocationBuilder; +import dev.restate.sdk.core.TestSerdes; +import dev.restate.serde.Serde; public class AwakeableIdTest extends AwakeableIdTestSuite { @@ -21,7 +22,7 @@ protected TestInvocationBuilder returnAwakeableId() { return testDefinitionForService( "ReturnAwakeableId", Serde.VOID, - JsonSerdes.STRING, - (context, unused) -> context.awakeable(JsonSerdes.STRING).id()); + TestSerdes.STRING, + (context, unused) -> context.awakeable(TestSerdes.STRING).id()); } } diff --git a/sdk-core/src/test/java/dev/restate/sdk/core/javaapi/CallTest.java b/sdk-core/src/test/java/dev/restate/sdk/core/javaapi/CallTest.java new file mode 100644 index 000000000..6669a6179 --- /dev/null +++ b/sdk-core/src/test/java/dev/restate/sdk/core/javaapi/CallTest.java @@ -0,0 +1,48 @@ +// Copyright (c) 2023 - Restate Software, Inc., Restate GmbH +// +// This file is part of the Restate Java SDK, +// which is released under the MIT license. +// +// You can find a copy of the license in file LICENSE in the root +// directory of this repository or package, or at +// https://github.com/restatedev/sdk-java/blob/main/LICENSE +package dev.restate.sdk.core.javaapi; + +import static dev.restate.sdk.core.javaapi.JavaAPITests.testDefinitionForService; + +import dev.restate.common.Request; +import dev.restate.common.SendRequest; +import dev.restate.common.Slice; +import dev.restate.common.Target; +import dev.restate.sdk.core.CallTestSuite; +import dev.restate.sdk.core.TestDefinitions.TestInvocationBuilder; +import dev.restate.serde.Serde; +import java.util.Map; + +public class CallTest extends CallTestSuite { + + @Override + protected TestInvocationBuilder oneWayCall( + Target target, String idempotencyKey, Map headers, Slice body) { + return testDefinitionForService( + "OneWayCall", + Serde.VOID, + Serde.VOID, + (context, unused) -> { + context.send( + SendRequest.ofRaw(target, body.toByteArray()) + .headers(headers) + .idempotencyKey(idempotencyKey)); + return null; + }); + } + + @Override + protected TestInvocationBuilder implicitCancellation(Target target, Slice body) { + return testDefinitionForService( + "ImplicitCancellation", + Serde.VOID, + Serde.RAW, + (context, unused) -> context.call(Request.ofRaw(target, body.toByteArray())).await()); + } +} diff --git a/sdk-api-gen/src/test/java/dev/restate/sdk/JavaCodegenTests.java b/sdk-core/src/test/java/dev/restate/sdk/core/javaapi/CodegenDiscoveryTest.java similarity index 61% rename from sdk-api-gen/src/test/java/dev/restate/sdk/JavaCodegenTests.java rename to sdk-core/src/test/java/dev/restate/sdk/core/javaapi/CodegenDiscoveryTest.java index 3290d962d..4dc70b252 100644 --- a/sdk-api-gen/src/test/java/dev/restate/sdk/JavaCodegenTests.java +++ b/sdk-core/src/test/java/dev/restate/sdk/core/javaapi/CodegenDiscoveryTest.java @@ -6,41 +6,25 @@ // You can find a copy of the license in file LICENSE in the root // directory of this repository or package, or at // https://github.com/restatedev/sdk-java/blob/main/LICENSE -package dev.restate.sdk; +package dev.restate.sdk.core.javaapi; import static dev.restate.sdk.core.AssertUtils.assertThatDiscovery; import static org.assertj.core.api.InstanceOfAssertFactories.type; -import dev.restate.sdk.core.MockMultiThreaded; -import dev.restate.sdk.core.MockSingleThread; -import dev.restate.sdk.core.TestDefinitions.TestExecutor; -import dev.restate.sdk.core.TestDefinitions.TestSuite; -import dev.restate.sdk.core.TestRunner; -import dev.restate.sdk.core.manifest.Handler; -import dev.restate.sdk.core.manifest.Input; -import dev.restate.sdk.core.manifest.Output; -import dev.restate.sdk.core.manifest.Service; -import java.util.stream.Stream; +import dev.restate.sdk.core.generated.manifest.Handler; +import dev.restate.sdk.core.generated.manifest.Input; +import dev.restate.sdk.core.generated.manifest.Output; +import dev.restate.sdk.core.generated.manifest.Service; import org.junit.jupiter.api.Test; -public class JavaCodegenTests extends TestRunner { - - @Override - protected Stream executors() { - return Stream.of(MockSingleThread.INSTANCE, MockMultiThreaded.INSTANCE); - } - - @Override - public Stream definitions() { - return Stream.of(new CodegenTest()); - } +public class CodegenDiscoveryTest { @Test void checkCustomInputContentType() { assertThatDiscovery(new CodegenTest.RawInputOutput()) .extractingService("RawInputOutput") .extractingHandler("rawInputWithCustomCt") - .extracting(dev.restate.sdk.core.manifest.Handler::getInput, type(Input.class)) + .extracting(Handler::getInput, type(Input.class)) .extracting(Input::getContentType) .isEqualTo("application/vnd.my.custom"); } @@ -50,7 +34,7 @@ void checkCustomInputAcceptContentType() { assertThatDiscovery(new CodegenTest.RawInputOutput()) .extractingService("RawInputOutput") .extractingHandler("rawInputWithCustomAccept") - .extracting(dev.restate.sdk.core.manifest.Handler::getInput, type(Input.class)) + .extracting(Handler::getInput, type(Input.class)) .extracting(Input::getContentType) .isEqualTo("application/*"); } @@ -60,7 +44,7 @@ void checkCustomOutputContentType() { assertThatDiscovery(new CodegenTest.RawInputOutput()) .extractingService("RawInputOutput") .extractingHandler("rawOutputWithCustomCT") - .extracting(dev.restate.sdk.core.manifest.Handler::getOutput, type(Output.class)) + .extracting(Handler::getOutput, type(Output.class)) .extracting(Output::getContentType) .isEqualTo("application/vnd.my.custom"); } diff --git a/sdk-api-gen/src/test/java/dev/restate/sdk/CodegenTest.java b/sdk-core/src/test/java/dev/restate/sdk/core/javaapi/CodegenTest.java similarity index 67% rename from sdk-api-gen/src/test/java/dev/restate/sdk/CodegenTest.java rename to sdk-core/src/test/java/dev/restate/sdk/core/javaapi/CodegenTest.java index 128b2f142..2cf316eaa 100644 --- a/sdk-api-gen/src/test/java/dev/restate/sdk/CodegenTest.java +++ b/sdk-core/src/test/java/dev/restate/sdk/core/javaapi/CodegenTest.java @@ -6,19 +6,19 @@ // You can find a copy of the license in file LICENSE in the root // directory of this repository or package, or at // https://github.com/restatedev/sdk-java/blob/main/LICENSE -package dev.restate.sdk; +package dev.restate.sdk.core.javaapi; -import static dev.restate.sdk.core.ProtoUtils.*; import static dev.restate.sdk.core.TestDefinitions.testInvocation; +import static dev.restate.sdk.core.statemachine.ProtoUtils.*; +import static org.assertj.core.api.Assertions.assertThat; -import com.google.protobuf.ByteString; +import dev.restate.common.Target; +import dev.restate.sdk.*; import dev.restate.sdk.annotation.*; -import dev.restate.sdk.annotation.Service; -import dev.restate.sdk.common.Serde; -import dev.restate.sdk.common.Target; -import dev.restate.sdk.core.ProtoUtils; import dev.restate.sdk.core.TestDefinitions; import dev.restate.sdk.core.TestDefinitions.TestSuite; +import dev.restate.sdk.core.TestSerdes; +import dev.restate.serde.Serde; import java.io.IOException; import java.nio.charset.StandardCharsets; import java.util.stream.Stream; @@ -124,6 +124,7 @@ public String submit(SharedWorkflowContext context, String request) { return CodegenTestWorkflowCornerCasesClient.connect("invalid", request) .workflowHandle() .getOutput() + .response() .getValue(); } } @@ -191,118 +192,132 @@ String greet(Context context, String request) throws IOException { } } + @Service + @CustomSerdeFactory(MySerdeFactory.class) + static class CustomSerde { + @Handler + String greet(Context context, String request) { + assertThat(request).isEqualTo("INPUT"); + return "output"; + } + } + @Override public Stream definitions() { return Stream.of( testInvocation(ServiceGreeter::new, "greet") - .withInput(startMessage(1), inputMessage("Francesco")) - .onlyUnbuffered() - .expectingOutput(outputMessage("Francesco"), END_MESSAGE), + .withInput(startMessage(1), inputCmd("Francesco")) + .onlyBidiStream() + .expectingOutput(outputCmd("Francesco"), END_MESSAGE), testInvocation(ObjectGreeter::new, "greet") - .withInput(startMessage(1, "slinkydeveloper"), inputMessage("Francesco")) - .onlyUnbuffered() - .expectingOutput(outputMessage("Francesco"), END_MESSAGE), + .withInput(startMessage(1, "slinkydeveloper"), inputCmd("Francesco")) + .onlyBidiStream() + .expectingOutput(outputCmd("Francesco"), END_MESSAGE), testInvocation(ObjectGreeter::new, "sharedGreet") - .withInput(startMessage(1, "slinkydeveloper"), inputMessage("Francesco")) - .onlyUnbuffered() - .expectingOutput(outputMessage("Francesco"), END_MESSAGE), + .withInput(startMessage(1, "slinkydeveloper"), inputCmd("Francesco")) + .onlyBidiStream() + .expectingOutput(outputCmd("Francesco"), END_MESSAGE), testInvocation(ObjectGreeterImplementedFromInterface::new, "greet") - .withInput(startMessage(1, "slinkydeveloper"), inputMessage("Francesco")) - .onlyUnbuffered() - .expectingOutput(outputMessage("Francesco"), END_MESSAGE), + .withInput(startMessage(1, "slinkydeveloper"), inputCmd("Francesco")) + .onlyBidiStream() + .expectingOutput(outputCmd("Francesco"), END_MESSAGE), testInvocation(Empty::new, "emptyInput") - .withInput(startMessage(1), inputMessage(), completionMessage(1, "Till")) - .onlyUnbuffered() + .withInput(startMessage(1), inputCmd(), callCompletion(2, "Till")) + .onlyBidiStream() .expectingOutput( - invokeMessage(Target.service("Empty", "emptyInput")), - outputMessage("Till"), + callCmd(1, 2, Target.service("Empty", "emptyInput")), + outputCmd("Till"), END_MESSAGE) .named("empty output"), testInvocation(Empty::new, "emptyOutput") - .withInput( - startMessage(1), - inputMessage("Francesco"), - completionMessage(1).setValue(ByteString.EMPTY)) - .onlyUnbuffered() + .withInput(startMessage(1), inputCmd("Francesco"), callCompletion(2, Serde.VOID, null)) + .onlyBidiStream() .expectingOutput( - invokeMessage(Target.service("Empty", "emptyOutput"), "Francesco"), - ProtoUtils.outputMessage(), + callCmd(1, 2, Target.service("Empty", "emptyOutput"), "Francesco"), + outputCmd(), END_MESSAGE) .named("empty output"), testInvocation(Empty::new, "emptyInputOutput") - .withInput( - startMessage(1), - inputMessage("Francesco"), - completionMessage(1).setValue(ByteString.EMPTY)) - .onlyUnbuffered() + .withInput(startMessage(1), inputCmd("Francesco"), callCompletion(2, Serde.VOID, null)) + .onlyBidiStream() .expectingOutput( - invokeMessage(Target.service("Empty", "emptyInputOutput")), - ProtoUtils.outputMessage(), + callCmd(1, 2, Target.service("Empty", "emptyInputOutput")), + outputCmd(), END_MESSAGE) .named("empty input and empty output"), testInvocation(PrimitiveTypes::new, "primitiveOutput") - .withInput(startMessage(1), inputMessage(), completionMessage(1, JsonSerdes.INT, 10)) - .onlyUnbuffered() + .withInput(startMessage(1), inputCmd(), callCompletion(2, TestSerdes.INT, 10)) + .onlyBidiStream() .expectingOutput( - invokeMessage( - Target.service("PrimitiveTypes", "primitiveOutput"), Serde.VOID, null), - outputMessage(JsonSerdes.INT, 10), + callCmd( + 1, 2, Target.service("PrimitiveTypes", "primitiveOutput"), Serde.VOID, null), + outputCmd(TestSerdes.INT, 10), END_MESSAGE) .named("primitive output"), testInvocation(PrimitiveTypes::new, "primitiveInput") - .withInput( - startMessage(1), inputMessage(10), completionMessage(1).setValue(ByteString.EMPTY)) - .onlyUnbuffered() + .withInput(startMessage(1), inputCmd(10), callCompletion(2, Serde.VOID, null)) + .onlyBidiStream() .expectingOutput( - invokeMessage( - Target.service("PrimitiveTypes", "primitiveInput"), JsonSerdes.INT, 10), - outputMessage(), + callCmd( + 1, 2, Target.service("PrimitiveTypes", "primitiveInput"), TestSerdes.INT, 10), + outputCmd(), END_MESSAGE) .named("primitive input"), testInvocation(RawInputOutput::new, "rawInput") .withInput( startMessage(1), - inputMessage("{{".getBytes(StandardCharsets.UTF_8)), - completionMessage(1, Serde.VOID, null)) - .onlyUnbuffered() + inputCmd("{{".getBytes(StandardCharsets.UTF_8)), + callCompletion(2, Serde.VOID, null)) + .onlyBidiStream() .expectingOutput( - invokeMessage( + callCmd( + 1, + 2, Target.service("RawInputOutput", "rawInput"), "{{".getBytes(StandardCharsets.UTF_8)), - outputMessage(), + outputCmd(), END_MESSAGE), testInvocation(RawInputOutput::new, "rawInputWithCustomCt") .withInput( startMessage(1), - inputMessage("{{".getBytes(StandardCharsets.UTF_8)), - completionMessage(1, Serde.VOID, null)) - .onlyUnbuffered() + inputCmd("{{".getBytes(StandardCharsets.UTF_8)), + callCompletion(2, Serde.VOID, null)) + .onlyBidiStream() .expectingOutput( - invokeMessage( + callCmd( + 1, + 2, Target.service("RawInputOutput", "rawInputWithCustomCt"), "{{".getBytes(StandardCharsets.UTF_8)), - outputMessage(), + outputCmd(), END_MESSAGE), testInvocation(RawInputOutput::new, "rawOutput") .withInput( startMessage(1), - inputMessage(), - completionMessage(1, Serde.RAW, "{{".getBytes(StandardCharsets.UTF_8))) - .onlyUnbuffered() + inputCmd(), + callCompletion(2, Serde.RAW, "{{".getBytes(StandardCharsets.UTF_8))) + .onlyBidiStream() .expectingOutput( - invokeMessage(Target.service("RawInputOutput", "rawOutput"), Serde.VOID, null), - outputMessage("{{".getBytes(StandardCharsets.UTF_8)), + callCmd(1, 2, Target.service("RawInputOutput", "rawOutput"), Serde.VOID, null), + outputCmd("{{".getBytes(StandardCharsets.UTF_8)), END_MESSAGE), testInvocation(RawInputOutput::new, "rawOutputWithCustomCT") .withInput( startMessage(1), - inputMessage(), - completionMessage(1, Serde.RAW, "{{".getBytes(StandardCharsets.UTF_8))) - .onlyUnbuffered() + inputCmd(), + callCompletion(2, Serde.RAW, "{{".getBytes(StandardCharsets.UTF_8))) + .onlyBidiStream() .expectingOutput( - invokeMessage( - Target.service("RawInputOutput", "rawOutputWithCustomCT"), Serde.VOID, null), - outputMessage("{{".getBytes(StandardCharsets.UTF_8)), - END_MESSAGE)); + callCmd( + 1, + 2, + Target.service("RawInputOutput", "rawOutputWithCustomCT"), + Serde.VOID, + null), + outputCmd("{{".getBytes(StandardCharsets.UTF_8)), + END_MESSAGE), + testInvocation(CustomSerde::new, "greet") + .withInput(startMessage(1), inputCmd(MySerdeFactory.SERDE, "input")) + .expectingOutput(outputCmd(MySerdeFactory.SERDE, "OUTPUT"), END_MESSAGE)); } } diff --git a/sdk-api/src/test/java/dev/restate/sdk/EagerStateTest.java b/sdk-core/src/test/java/dev/restate/sdk/core/javaapi/EagerStateTest.java similarity index 64% rename from sdk-api/src/test/java/dev/restate/sdk/EagerStateTest.java rename to sdk-core/src/test/java/dev/restate/sdk/core/javaapi/EagerStateTest.java index b7dd19722..53c6609b8 100644 --- a/sdk-api/src/test/java/dev/restate/sdk/EagerStateTest.java +++ b/sdk-core/src/test/java/dev/restate/sdk/core/javaapi/EagerStateTest.java @@ -6,15 +6,16 @@ // You can find a copy of the license in file LICENSE in the root // directory of this repository or package, or at // https://github.com/restatedev/sdk-java/blob/main/LICENSE -package dev.restate.sdk; +package dev.restate.sdk.core.javaapi; -import static dev.restate.sdk.JavaBlockingTests.testDefinitionForVirtualObject; +import static dev.restate.sdk.core.javaapi.JavaAPITests.testDefinitionForVirtualObject; import static org.assertj.core.api.Assertions.assertThat; -import dev.restate.sdk.common.Serde; -import dev.restate.sdk.common.StateKey; import dev.restate.sdk.core.EagerStateTestSuite; import dev.restate.sdk.core.TestDefinitions.TestInvocationBuilder; +import dev.restate.sdk.core.TestSerdes; +import dev.restate.sdk.types.StateKey; +import dev.restate.serde.Serde; public class EagerStateTest extends EagerStateTestSuite { @@ -23,9 +24,9 @@ protected TestInvocationBuilder getEmpty() { return testDefinitionForVirtualObject( "GetEmpty", Serde.VOID, - JsonSerdes.STRING, + TestSerdes.STRING, (ctx, unused) -> - String.valueOf(ctx.get(StateKey.of("STATE", JsonSerdes.STRING)).isEmpty())); + String.valueOf(ctx.get(StateKey.of("STATE", TestSerdes.STRING)).isEmpty())); } @Override @@ -33,21 +34,21 @@ protected TestInvocationBuilder get() { return testDefinitionForVirtualObject( "GetEmpty", Serde.VOID, - JsonSerdes.STRING, - (ctx, unused) -> ctx.get(StateKey.of("STATE", JsonSerdes.STRING)).get()); + TestSerdes.STRING, + (ctx, unused) -> ctx.get(StateKey.of("STATE", TestSerdes.STRING)).get()); } @Override protected TestInvocationBuilder getAppendAndGet() { return testDefinitionForVirtualObject( "GetAppendAndGet", - JsonSerdes.STRING, - JsonSerdes.STRING, + TestSerdes.STRING, + TestSerdes.STRING, (ctx, input) -> { - String oldState = ctx.get(StateKey.of("STATE", JsonSerdes.STRING)).get(); - ctx.set(StateKey.of("STATE", JsonSerdes.STRING), oldState + input); + String oldState = ctx.get(StateKey.of("STATE", TestSerdes.STRING)).get(); + ctx.set(StateKey.of("STATE", TestSerdes.STRING), oldState + input); - return ctx.get(StateKey.of("STATE", JsonSerdes.STRING)).get(); + return ctx.get(StateKey.of("STATE", TestSerdes.STRING)).get(); }); } @@ -56,12 +57,12 @@ protected TestInvocationBuilder getClearAndGet() { return testDefinitionForVirtualObject( "GetClearAndGet", Serde.VOID, - JsonSerdes.STRING, + TestSerdes.STRING, (ctx, input) -> { - String oldState = ctx.get(StateKey.of("STATE", JsonSerdes.STRING)).get(); + String oldState = ctx.get(StateKey.of("STATE", TestSerdes.STRING)).get(); - ctx.clear(StateKey.of("STATE", JsonSerdes.STRING)); - assertThat(ctx.get(StateKey.of("STATE", JsonSerdes.STRING))).isEmpty(); + ctx.clear(StateKey.of("STATE", TestSerdes.STRING)); + assertThat(ctx.get(StateKey.of("STATE", TestSerdes.STRING))).isEmpty(); return oldState; }); } @@ -71,13 +72,13 @@ protected TestInvocationBuilder getClearAllAndGet() { return testDefinitionForVirtualObject( "GetClearAllAndGet", Serde.VOID, - JsonSerdes.STRING, + TestSerdes.STRING, (ctx, input) -> { - String oldState = ctx.get(StateKey.of("STATE", JsonSerdes.STRING)).get(); + String oldState = ctx.get(StateKey.of("STATE", TestSerdes.STRING)).get(); ctx.clearAll(); - assertThat(ctx.get(StateKey.of("STATE", JsonSerdes.STRING))).isEmpty(); - assertThat(ctx.get(StateKey.of("ANOTHER_STATE", JsonSerdes.STRING))).isEmpty(); + assertThat(ctx.get(StateKey.of("STATE", TestSerdes.STRING))).isEmpty(); + assertThat(ctx.get(StateKey.of("ANOTHER_STATE", TestSerdes.STRING))).isEmpty(); return oldState; }); @@ -88,7 +89,7 @@ protected TestInvocationBuilder listKeys() { return testDefinitionForVirtualObject( "ListKeys", Serde.VOID, - JsonSerdes.STRING, + TestSerdes.STRING, (ctx, input) -> String.join(",", ctx.stateKeys())); } @@ -99,8 +100,8 @@ protected TestInvocationBuilder consecutiveGetWithEmpty() { Serde.VOID, Serde.VOID, (ctx, input) -> { - assertThat(ctx.get(StateKey.of("key-0", JsonSerdes.STRING))).isEmpty(); - assertThat(ctx.get(StateKey.of("key-0", JsonSerdes.STRING))).isEmpty(); + assertThat(ctx.get(StateKey.of("key-0", TestSerdes.STRING))).isEmpty(); + assertThat(ctx.get(StateKey.of("key-0", TestSerdes.STRING))).isEmpty(); return null; }); } diff --git a/sdk-api-gen/src/test/java/dev/restate/sdk/GreeterWithExplicitName.java b/sdk-core/src/test/java/dev/restate/sdk/core/javaapi/GreeterWithExplicitName.java similarity index 88% rename from sdk-api-gen/src/test/java/dev/restate/sdk/GreeterWithExplicitName.java rename to sdk-core/src/test/java/dev/restate/sdk/core/javaapi/GreeterWithExplicitName.java index b3fda9b23..879cd82a5 100644 --- a/sdk-api-gen/src/test/java/dev/restate/sdk/GreeterWithExplicitName.java +++ b/sdk-core/src/test/java/dev/restate/sdk/core/javaapi/GreeterWithExplicitName.java @@ -6,8 +6,9 @@ // You can find a copy of the license in file LICENSE in the root // directory of this repository or package, or at // https://github.com/restatedev/sdk-java/blob/main/LICENSE -package dev.restate.sdk; +package dev.restate.sdk.core.javaapi; +import dev.restate.sdk.Context; import dev.restate.sdk.annotation.Handler; import dev.restate.sdk.annotation.Service; diff --git a/sdk-api-gen/src/test/java/dev/restate/sdk/GreeterWithoutExplicitName.java b/sdk-core/src/test/java/dev/restate/sdk/core/javaapi/GreeterWithoutExplicitName.java similarity index 88% rename from sdk-api-gen/src/test/java/dev/restate/sdk/GreeterWithoutExplicitName.java rename to sdk-core/src/test/java/dev/restate/sdk/core/javaapi/GreeterWithoutExplicitName.java index 56d6b9c0a..77571d77d 100644 --- a/sdk-api-gen/src/test/java/dev/restate/sdk/GreeterWithoutExplicitName.java +++ b/sdk-core/src/test/java/dev/restate/sdk/core/javaapi/GreeterWithoutExplicitName.java @@ -6,8 +6,9 @@ // You can find a copy of the license in file LICENSE in the root // directory of this repository or package, or at // https://github.com/restatedev/sdk-java/blob/main/LICENSE -package dev.restate.sdk; +package dev.restate.sdk.core.javaapi; +import dev.restate.sdk.Context; import dev.restate.sdk.annotation.Handler; import dev.restate.sdk.annotation.Service; diff --git a/sdk-api/src/test/java/dev/restate/sdk/InvocationIdTest.java b/sdk-core/src/test/java/dev/restate/sdk/core/javaapi/InvocationIdTest.java similarity index 77% rename from sdk-api/src/test/java/dev/restate/sdk/InvocationIdTest.java rename to sdk-core/src/test/java/dev/restate/sdk/core/javaapi/InvocationIdTest.java index 6e3213ca6..97e10d934 100644 --- a/sdk-api/src/test/java/dev/restate/sdk/InvocationIdTest.java +++ b/sdk-core/src/test/java/dev/restate/sdk/core/javaapi/InvocationIdTest.java @@ -6,13 +6,14 @@ // You can find a copy of the license in file LICENSE in the root // directory of this repository or package, or at // https://github.com/restatedev/sdk-java/blob/main/LICENSE -package dev.restate.sdk; +package dev.restate.sdk.core.javaapi; -import static dev.restate.sdk.JavaBlockingTests.testDefinitionForService; +import static dev.restate.sdk.core.javaapi.JavaAPITests.testDefinitionForService; -import dev.restate.sdk.common.Serde; import dev.restate.sdk.core.InvocationIdTestSuite; import dev.restate.sdk.core.TestDefinitions.TestInvocationBuilder; +import dev.restate.sdk.core.TestSerdes; +import dev.restate.serde.Serde; public class InvocationIdTest extends InvocationIdTestSuite { @@ -21,7 +22,7 @@ protected TestInvocationBuilder returnInvocationId() { return testDefinitionForService( "ReturnInvocationId", Serde.VOID, - JsonSerdes.STRING, + TestSerdes.STRING, (ctx, unused) -> ctx.request().invocationId().toString()); } } diff --git a/sdk-api/src/test/java/dev/restate/sdk/JavaBlockingTests.java b/sdk-core/src/test/java/dev/restate/sdk/core/javaapi/JavaAPITests.java similarity index 59% rename from sdk-api/src/test/java/dev/restate/sdk/JavaBlockingTests.java rename to sdk-core/src/test/java/dev/restate/sdk/core/javaapi/JavaAPITests.java index cf1ec6b19..95bbd9475 100644 --- a/sdk-api/src/test/java/dev/restate/sdk/JavaBlockingTests.java +++ b/sdk-core/src/test/java/dev/restate/sdk/core/javaapi/JavaAPITests.java @@ -6,39 +6,39 @@ // You can find a copy of the license in file LICENSE in the root // directory of this repository or package, or at // https://github.com/restatedev/sdk-java/blob/main/LICENSE -package dev.restate.sdk; +package dev.restate.sdk.core.javaapi; -import static dev.restate.sdk.core.ProtoUtils.GREETER_SERVICE_TARGET; +import static dev.restate.sdk.core.statemachine.ProtoUtils.GREETER_SERVICE_TARGET; -import dev.restate.sdk.common.HandlerType; -import dev.restate.sdk.common.Serde; -import dev.restate.sdk.common.ServiceType; -import dev.restate.sdk.common.function.ThrowingBiFunction; -import dev.restate.sdk.common.syscalls.HandlerDefinition; -import dev.restate.sdk.common.syscalls.HandlerSpecification; -import dev.restate.sdk.common.syscalls.ServiceDefinition; -import dev.restate.sdk.core.MockMultiThreaded; -import dev.restate.sdk.core.MockSingleThread; -import dev.restate.sdk.core.TestDefinitions; +import dev.restate.common.Request; +import dev.restate.common.function.ThrowingBiFunction; +import dev.restate.sdk.*; +import dev.restate.sdk.core.*; import dev.restate.sdk.core.TestDefinitions.TestExecutor; import dev.restate.sdk.core.TestDefinitions.TestInvocationBuilder; import dev.restate.sdk.core.TestDefinitions.TestSuite; -import dev.restate.sdk.core.TestRunner; +import dev.restate.sdk.endpoint.definition.HandlerDefinition; +import dev.restate.sdk.endpoint.definition.HandlerType; +import dev.restate.sdk.endpoint.definition.ServiceDefinition; +import dev.restate.sdk.endpoint.definition.ServiceType; +import dev.restate.sdk.serde.jackson.JacksonSerdeFactory; +import dev.restate.serde.Serde; import java.util.List; import java.util.stream.Stream; -public class JavaBlockingTests extends TestRunner { +public class JavaAPITests extends TestRunner { @Override protected Stream executors() { - return Stream.of(MockSingleThread.INSTANCE, MockMultiThreaded.INSTANCE); + return Stream.of(MockRequestResponse.INSTANCE, MockBidiStream.INSTANCE); } @Override public Stream definitions() { return Stream.of( new AwakeableIdTest(), - new DeferredTest(), + new AsyncResultTest(), + new CallTest(), new EagerStateTest(), new StateTest(), new InvocationIdTest(), @@ -48,7 +48,8 @@ public Stream definitions() { new SleepTest(), new StateMachineFailuresTest(), new UserFailuresTest(), - new RandomTest()); + new RandomTest(), + new CodegenTest()); } public static TestInvocationBuilder testDefinitionForService( @@ -59,8 +60,11 @@ public static TestInvocationBuilder testDefinitionForService( ServiceType.SERVICE, List.of( HandlerDefinition.of( - HandlerSpecification.of("run", HandlerType.SHARED, reqSerde, resSerde), - HandlerRunner.of(runner)))), + "run", + HandlerType.SHARED, + reqSerde, + resSerde, + HandlerRunner.of(runner, new JacksonSerdeFactory(), null)))), "run"); } @@ -75,8 +79,11 @@ public static TestInvocationBuilder testDefinitionForVirtualObject( ServiceType.VIRTUAL_OBJECT, List.of( HandlerDefinition.of( - HandlerSpecification.of("run", HandlerType.EXCLUSIVE, reqSerde, resSerde), - HandlerRunner.of(runner)))), + "run", + HandlerType.EXCLUSIVE, + reqSerde, + resSerde, + HandlerRunner.of(runner, new JacksonSerdeFactory(), null)))), "run"); } @@ -91,12 +98,16 @@ public static TestInvocationBuilder testDefinitionForWorkflow( ServiceType.WORKFLOW, List.of( HandlerDefinition.of( - HandlerSpecification.of("run", HandlerType.WORKFLOW, reqSerde, resSerde), - HandlerRunner.of(runner)))), + "run", + HandlerType.WORKFLOW, + reqSerde, + resSerde, + HandlerRunner.of(runner, new JacksonSerdeFactory(), null)))), "run"); } public static Awaitable callGreeterGreetService(Context ctx, String parameter) { - return ctx.call(GREETER_SERVICE_TARGET, JsonSerdes.STRING, JsonSerdes.STRING, parameter); + return ctx.call( + Request.of(GREETER_SERVICE_TARGET, TestSerdes.STRING, TestSerdes.STRING, parameter)); } } diff --git a/sdk-core/src/test/java/dev/restate/sdk/core/javaapi/MySerdeFactory.java b/sdk-core/src/test/java/dev/restate/sdk/core/javaapi/MySerdeFactory.java new file mode 100644 index 000000000..9bdef03ac --- /dev/null +++ b/sdk-core/src/test/java/dev/restate/sdk/core/javaapi/MySerdeFactory.java @@ -0,0 +1,38 @@ +// Copyright (c) 2023 - Restate Software, Inc., Restate GmbH +// +// This file is part of the Restate Java SDK, +// which is released under the MIT license. +// +// You can find a copy of the license in file LICENSE in the root +// directory of this repository or package, or at +// https://github.com/restatedev/sdk-java/blob/main/LICENSE +package dev.restate.sdk.core.javaapi; + +import static org.assertj.core.api.Assertions.assertThat; + +import dev.restate.serde.Serde; +import dev.restate.serde.SerdeFactory; +import dev.restate.serde.TypeRef; +import java.nio.charset.StandardCharsets; + +@SuppressWarnings("unchecked") +public class MySerdeFactory implements SerdeFactory { + + static Serde SERDE = + Serde.using( + "mycontent/type", + s -> s.toUpperCase().getBytes(), + b -> new String(b, StandardCharsets.UTF_8).toUpperCase()); + + @Override + public Serde create(TypeRef typeRef) { + assertThat(typeRef.getType()).isEqualTo(String.class); + return (Serde) SERDE; + } + + @Override + public Serde create(Class clazz) { + assertThat(clazz).isEqualTo(String.class); + return (Serde) SERDE; + } +} diff --git a/sdk-api-gen/src/test/java/dev/restate/sdk/NameInferenceTest.java b/sdk-core/src/test/java/dev/restate/sdk/core/javaapi/NameInferenceTest.java similarity index 70% rename from sdk-api-gen/src/test/java/dev/restate/sdk/NameInferenceTest.java rename to sdk-core/src/test/java/dev/restate/sdk/core/javaapi/NameInferenceTest.java index 7fbdf76e7..6c17ac783 100644 --- a/sdk-api-gen/src/test/java/dev/restate/sdk/NameInferenceTest.java +++ b/sdk-core/src/test/java/dev/restate/sdk/core/javaapi/NameInferenceTest.java @@ -6,7 +6,7 @@ // You can find a copy of the license in file LICENSE in the root // directory of this repository or package, or at // https://github.com/restatedev/sdk-java/blob/main/LICENSE -package dev.restate.sdk; +package dev.restate.sdk.core.javaapi; import static org.assertj.core.api.Assertions.assertThat; @@ -16,10 +16,10 @@ public class NameInferenceTest { @Test void expectedName() { - assertThat(CodegenTestServiceGreeterDefinitions.SERVICE_NAME) + assertThat(CodegenTestServiceGreeterMetadata.SERVICE_NAME) .isEqualTo("CodegenTestServiceGreeter"); - assertThat(GreeterWithoutExplicitNameDefinitions.SERVICE_NAME) + assertThat(GreeterWithoutExplicitNameMetadata.SERVICE_NAME) .isEqualTo("GreeterWithoutExplicitName"); - assertThat(MyExplicitNameDefinitions.SERVICE_NAME).isEqualTo("MyExplicitName"); + assertThat(MyExplicitNameMetadata.SERVICE_NAME).isEqualTo("MyExplicitName"); } } diff --git a/sdk-api/src/test/java/dev/restate/sdk/OnlyInputAndOutputTest.java b/sdk-core/src/test/java/dev/restate/sdk/core/javaapi/OnlyInputAndOutputTest.java similarity index 76% rename from sdk-api/src/test/java/dev/restate/sdk/OnlyInputAndOutputTest.java rename to sdk-core/src/test/java/dev/restate/sdk/core/javaapi/OnlyInputAndOutputTest.java index ac0317676..506b31114 100644 --- a/sdk-api/src/test/java/dev/restate/sdk/OnlyInputAndOutputTest.java +++ b/sdk-core/src/test/java/dev/restate/sdk/core/javaapi/OnlyInputAndOutputTest.java @@ -6,12 +6,13 @@ // You can find a copy of the license in file LICENSE in the root // directory of this repository or package, or at // https://github.com/restatedev/sdk-java/blob/main/LICENSE -package dev.restate.sdk; +package dev.restate.sdk.core.javaapi; -import static dev.restate.sdk.JavaBlockingTests.testDefinitionForService; +import static dev.restate.sdk.core.javaapi.JavaAPITests.testDefinitionForService; import dev.restate.sdk.core.OnlyInputAndOutputTestSuite; import dev.restate.sdk.core.TestDefinitions; +import dev.restate.sdk.core.TestSerdes; public class OnlyInputAndOutputTest extends OnlyInputAndOutputTestSuite { @@ -19,8 +20,8 @@ public class OnlyInputAndOutputTest extends OnlyInputAndOutputTestSuite { protected TestDefinitions.TestInvocationBuilder noSyscallsGreeter() { return testDefinitionForService( "NoSyscallsGreeter", - JsonSerdes.STRING, - JsonSerdes.STRING, + TestSerdes.STRING, + TestSerdes.STRING, (ctx, input) -> "Hello " + input); } } diff --git a/sdk-api/src/test/java/dev/restate/sdk/PromiseTest.java b/sdk-core/src/test/java/dev/restate/sdk/core/javaapi/PromiseTest.java similarity index 76% rename from sdk-api/src/test/java/dev/restate/sdk/PromiseTest.java rename to sdk-core/src/test/java/dev/restate/sdk/core/javaapi/PromiseTest.java index 8baf63e81..e6af1b194 100644 --- a/sdk-api/src/test/java/dev/restate/sdk/PromiseTest.java +++ b/sdk-core/src/test/java/dev/restate/sdk/core/javaapi/PromiseTest.java @@ -6,16 +6,16 @@ // You can find a copy of the license in file LICENSE in the root // directory of this repository or package, or at // https://github.com/restatedev/sdk-java/blob/main/LICENSE -package dev.restate.sdk; +package dev.restate.sdk.core.javaapi; -import static dev.restate.sdk.JavaBlockingTests.*; +import static dev.restate.sdk.core.javaapi.JavaAPITests.*; -import dev.restate.sdk.common.DurablePromiseKey; -import dev.restate.sdk.common.Serde; -import dev.restate.sdk.common.TerminalException; import dev.restate.sdk.core.PromiseTestSuite; import dev.restate.sdk.core.TestDefinitions; import dev.restate.sdk.core.TestSerdes; +import dev.restate.sdk.types.DurablePromiseKey; +import dev.restate.sdk.types.TerminalException; +import dev.restate.serde.Serde; public class PromiseTest extends PromiseTestSuite { @Override @@ -23,12 +23,9 @@ protected TestDefinitions.TestInvocationBuilder awaitPromise(String promiseKey) return testDefinitionForWorkflow( "AwaitPromise", Serde.VOID, - JsonSerdes.STRING, + TestSerdes.STRING, (context, unused) -> - context - .promise(DurablePromiseKey.of(promiseKey, TestSerdes.STRING)) - .awaitable() - .await()); + context.promise(DurablePromiseKey.of(promiseKey, String.class)).awaitable().await()); } @Override @@ -37,10 +34,10 @@ protected TestDefinitions.TestInvocationBuilder awaitPeekPromise( return testDefinitionForWorkflow( "PeekPromise", Serde.VOID, - JsonSerdes.STRING, + TestSerdes.STRING, (context, unused) -> context - .promise(DurablePromiseKey.of(promiseKey, TestSerdes.STRING)) + .promise(DurablePromiseKey.of(promiseKey, String.class)) .peek() .orElse(emptyCaseReturnValue)); } @@ -50,9 +47,9 @@ protected TestDefinitions.TestInvocationBuilder awaitIsPromiseCompleted(String p return testDefinitionForWorkflow( "IsCompletedPromise", Serde.VOID, - JsonSerdes.BOOLEAN, + TestSerdes.BOOLEAN, (context, unused) -> - context.promise(DurablePromiseKey.of(promiseKey, TestSerdes.STRING)).peek().isReady()); + context.promise(DurablePromiseKey.of(promiseKey, String.class)).peek().isReady()); } @Override @@ -61,11 +58,11 @@ protected TestDefinitions.TestInvocationBuilder awaitResolvePromise( return testDefinitionForWorkflow( "ResolvePromise", Serde.VOID, - JsonSerdes.BOOLEAN, + TestSerdes.BOOLEAN, (context, unused) -> { try { context - .promiseHandle(DurablePromiseKey.of(promiseKey, TestSerdes.STRING)) + .promiseHandle(DurablePromiseKey.of(promiseKey, String.class)) .resolve(completionValue); return true; } catch (TerminalException e) { @@ -80,11 +77,11 @@ protected TestDefinitions.TestInvocationBuilder awaitRejectPromise( return testDefinitionForWorkflow( "RejectPromise", Serde.VOID, - JsonSerdes.BOOLEAN, + TestSerdes.BOOLEAN, (context, unused) -> { try { context - .promiseHandle(DurablePromiseKey.of(promiseKey, TestSerdes.STRING)) + .promiseHandle(DurablePromiseKey.of(promiseKey, String.class)) .reject(rejectReason); return true; } catch (TerminalException e) { diff --git a/sdk-api/src/test/java/dev/restate/sdk/RandomTest.java b/sdk-core/src/test/java/dev/restate/sdk/core/javaapi/RandomTest.java similarity index 62% rename from sdk-api/src/test/java/dev/restate/sdk/RandomTest.java rename to sdk-core/src/test/java/dev/restate/sdk/core/javaapi/RandomTest.java index f96e59e15..7c9815a08 100644 --- a/sdk-api/src/test/java/dev/restate/sdk/RandomTest.java +++ b/sdk-core/src/test/java/dev/restate/sdk/core/javaapi/RandomTest.java @@ -6,13 +6,14 @@ // You can find a copy of the license in file LICENSE in the root // directory of this repository or package, or at // https://github.com/restatedev/sdk-java/blob/main/LICENSE -package dev.restate.sdk; +package dev.restate.sdk.core.javaapi; -import static dev.restate.sdk.JavaBlockingTests.testDefinitionForService; +import static dev.restate.sdk.core.javaapi.JavaAPITests.testDefinitionForService; -import dev.restate.sdk.common.Serde; import dev.restate.sdk.core.RandomTestSuite; import dev.restate.sdk.core.TestDefinitions.TestInvocationBuilder; +import dev.restate.sdk.core.TestSerdes; +import dev.restate.serde.Serde; import java.util.Random; public class RandomTest extends RandomTestSuite { @@ -22,22 +23,10 @@ protected TestInvocationBuilder randomShouldBeDeterministic() { return testDefinitionForService( "RandomShouldBeDeterministic", Serde.VOID, - JsonSerdes.INT, + TestSerdes.INT, (ctx, unused) -> ctx.random().nextInt()); } - @Override - protected TestInvocationBuilder randomInsideSideEffect() { - return testDefinitionForService( - "RandomInsideSideEffect", - Serde.VOID, - JsonSerdes.INT, - (ctx, unused) -> { - ctx.run(() -> ctx.random().nextInt()); - throw new IllegalStateException("This should not unreachable"); - }); - } - @Override protected int getExpectedInt(long seed) { return new Random(seed).nextInt(); diff --git a/sdk-api/src/test/java/dev/restate/sdk/SideEffectTest.java b/sdk-core/src/test/java/dev/restate/sdk/core/javaapi/SideEffectTest.java similarity index 68% rename from sdk-api/src/test/java/dev/restate/sdk/SideEffectTest.java rename to sdk-core/src/test/java/dev/restate/sdk/core/javaapi/SideEffectTest.java index 1c6596d5c..77dd0385f 100644 --- a/sdk-api/src/test/java/dev/restate/sdk/SideEffectTest.java +++ b/sdk-core/src/test/java/dev/restate/sdk/core/javaapi/SideEffectTest.java @@ -6,15 +6,15 @@ // You can find a copy of the license in file LICENSE in the root // directory of this repository or package, or at // https://github.com/restatedev/sdk-java/blob/main/LICENSE -package dev.restate.sdk; +package dev.restate.sdk.core.javaapi; -import static dev.restate.sdk.JavaBlockingTests.testDefinitionForService; -import static dev.restate.sdk.core.ProtoUtils.GREETER_SERVICE_TARGET; +import static dev.restate.sdk.core.javaapi.JavaAPITests.testDefinitionForService; -import dev.restate.sdk.common.RetryPolicy; -import dev.restate.sdk.common.Serde; import dev.restate.sdk.core.SideEffectTestSuite; import dev.restate.sdk.core.TestDefinitions.TestInvocationBuilder; +import dev.restate.sdk.core.TestSerdes; +import dev.restate.sdk.types.RetryPolicy; +import dev.restate.serde.Serde; import java.util.Objects; public class SideEffectTest extends SideEffectTestSuite { @@ -24,9 +24,9 @@ protected TestInvocationBuilder sideEffect(String sideEffectOutput) { return testDefinitionForService( "SideEffect", Serde.VOID, - JsonSerdes.STRING, + TestSerdes.STRING, (ctx, unused) -> { - String result = ctx.run(JsonSerdes.STRING, () -> sideEffectOutput); + String result = ctx.run(String.class, () -> sideEffectOutput); return "Hello " + result; }); } @@ -34,11 +34,11 @@ protected TestInvocationBuilder sideEffect(String sideEffectOutput) { @Override protected TestInvocationBuilder namedSideEffect(String name, String sideEffectOutput) { return testDefinitionForService( - "SideEffect", + "NamedSideEffect", Serde.VOID, - JsonSerdes.STRING, + TestSerdes.STRING, (ctx, unused) -> { - String result = ctx.run(name, JsonSerdes.STRING, () -> sideEffectOutput); + String result = ctx.run(name, String.class, () -> sideEffectOutput); return "Hello " + result; }); } @@ -48,10 +48,10 @@ protected TestInvocationBuilder consecutiveSideEffect(String sideEffectOutput) { return testDefinitionForService( "ConsecutiveSideEffect", Serde.VOID, - JsonSerdes.STRING, + TestSerdes.STRING, (ctx, unused) -> { - String firstResult = ctx.run(JsonSerdes.STRING, () -> sideEffectOutput); - String secondResult = ctx.run(JsonSerdes.STRING, firstResult::toUpperCase); + String firstResult = ctx.run(String.class, () -> sideEffectOutput); + String secondResult = ctx.run(String.class, firstResult::toUpperCase); return "Hello " + secondResult; }); @@ -62,12 +62,11 @@ protected TestInvocationBuilder checkContextSwitching() { return testDefinitionForService( "CheckContextSwitching", Serde.VOID, - JsonSerdes.STRING, + TestSerdes.STRING, (ctx, unused) -> { String currentThread = Thread.currentThread().getName(); - String sideEffectThread = - ctx.run(JsonSerdes.STRING, () -> Thread.currentThread().getName()); + String sideEffectThread = ctx.run(String.class, () -> Thread.currentThread().getName()); if (!Objects.equals(currentThread, sideEffectThread)) { throw new IllegalStateException( @@ -81,24 +80,12 @@ protected TestInvocationBuilder checkContextSwitching() { }); } - @Override - protected TestInvocationBuilder sideEffectGuard() { - return testDefinitionForService( - "SideEffectGuard", - Serde.VOID, - JsonSerdes.STRING, - (ctx, unused) -> { - ctx.run(() -> ctx.send(GREETER_SERVICE_TARGET, new byte[] {})); - throw new IllegalStateException("This point should not be reached"); - }); - } - @Override protected TestInvocationBuilder failingSideEffect(String name, String reason) { return testDefinitionForService( "FailingSideEffect", Serde.VOID, - JsonSerdes.STRING, + TestSerdes.STRING, (ctx, unused) -> { ctx.run( name, @@ -115,9 +102,10 @@ protected TestInvocationBuilder failingSideEffectWithRetryPolicy( return testDefinitionForService( "FailingSideEffectWithRetryPolicy", Serde.VOID, - JsonSerdes.STRING, + TestSerdes.STRING, (ctx, unused) -> { ctx.run( + null, retryPolicy, () -> { throw new IllegalStateException(reason); diff --git a/sdk-api/src/test/java/dev/restate/sdk/SleepTest.java b/sdk-core/src/test/java/dev/restate/sdk/core/javaapi/SleepTest.java similarity index 85% rename from sdk-api/src/test/java/dev/restate/sdk/SleepTest.java rename to sdk-core/src/test/java/dev/restate/sdk/core/javaapi/SleepTest.java index 3ad5778fb..8db0adef6 100644 --- a/sdk-api/src/test/java/dev/restate/sdk/SleepTest.java +++ b/sdk-core/src/test/java/dev/restate/sdk/core/javaapi/SleepTest.java @@ -6,13 +6,15 @@ // You can find a copy of the license in file LICENSE in the root // directory of this repository or package, or at // https://github.com/restatedev/sdk-java/blob/main/LICENSE -package dev.restate.sdk; +package dev.restate.sdk.core.javaapi; -import static dev.restate.sdk.JavaBlockingTests.testDefinitionForService; +import static dev.restate.sdk.core.javaapi.JavaAPITests.testDefinitionForService; -import dev.restate.sdk.common.Serde; +import dev.restate.sdk.Awaitable; import dev.restate.sdk.core.SleepTestSuite; import dev.restate.sdk.core.TestDefinitions.TestInvocationBuilder; +import dev.restate.sdk.core.TestSerdes; +import dev.restate.serde.Serde; import java.time.Duration; import java.util.ArrayList; import java.util.List; @@ -24,7 +26,7 @@ protected TestInvocationBuilder sleepGreeter() { return testDefinitionForService( "SleepGreeter", Serde.VOID, - JsonSerdes.STRING, + TestSerdes.STRING, (ctx, unused) -> { ctx.sleep(Duration.ofSeconds(1)); return "Hello"; diff --git a/sdk-api/src/test/java/dev/restate/sdk/StateMachineFailuresTest.java b/sdk-core/src/test/java/dev/restate/sdk/core/javaapi/StateMachineFailuresTest.java similarity index 82% rename from sdk-api/src/test/java/dev/restate/sdk/StateMachineFailuresTest.java rename to sdk-core/src/test/java/dev/restate/sdk/core/javaapi/StateMachineFailuresTest.java index fcd435667..c6fec61ff 100644 --- a/sdk-api/src/test/java/dev/restate/sdk/StateMachineFailuresTest.java +++ b/sdk-core/src/test/java/dev/restate/sdk/core/javaapi/StateMachineFailuresTest.java @@ -6,13 +6,17 @@ // You can find a copy of the license in file LICENSE in the root // directory of this repository or package, or at // https://github.com/restatedev/sdk-java/blob/main/LICENSE -package dev.restate.sdk; +package dev.restate.sdk.core.javaapi; -import static dev.restate.sdk.JavaBlockingTests.testDefinitionForVirtualObject; +import static dev.restate.sdk.core.javaapi.JavaAPITests.testDefinitionForVirtualObject; -import dev.restate.sdk.common.*; import dev.restate.sdk.core.StateMachineFailuresTestSuite; import dev.restate.sdk.core.TestDefinitions.TestInvocationBuilder; +import dev.restate.sdk.core.TestSerdes; +import dev.restate.sdk.types.AbortedExecutionException; +import dev.restate.sdk.types.StateKey; +import dev.restate.sdk.types.TerminalException; +import dev.restate.serde.Serde; import java.nio.charset.StandardCharsets; import java.util.concurrent.atomic.AtomicInteger; @@ -30,7 +34,7 @@ protected TestInvocationBuilder getState(AtomicInteger nonTerminalExceptionsSeen return testDefinitionForVirtualObject( "GetState", Serde.VOID, - JsonSerdes.STRING, + TestSerdes.STRING, (ctx, unused) -> { try { ctx.get(STATE); @@ -55,7 +59,7 @@ protected TestInvocationBuilder sideEffectFailure(Serde serde) { return testDefinitionForVirtualObject( "SideEffectFailure", Serde.VOID, - JsonSerdes.STRING, + TestSerdes.STRING, (ctx, unused) -> { ctx.run(serde, () -> 0); return "Francesco"; diff --git a/sdk-api/src/test/java/dev/restate/sdk/StateTest.java b/sdk-core/src/test/java/dev/restate/sdk/core/javaapi/StateTest.java similarity index 73% rename from sdk-api/src/test/java/dev/restate/sdk/StateTest.java rename to sdk-core/src/test/java/dev/restate/sdk/core/javaapi/StateTest.java index 7c4a6dc3d..622be3d73 100644 --- a/sdk-api/src/test/java/dev/restate/sdk/StateTest.java +++ b/sdk-core/src/test/java/dev/restate/sdk/core/javaapi/StateTest.java @@ -6,14 +6,15 @@ // You can find a copy of the license in file LICENSE in the root // directory of this repository or package, or at // https://github.com/restatedev/sdk-java/blob/main/LICENSE -package dev.restate.sdk; +package dev.restate.sdk.core.javaapi; -import static dev.restate.sdk.JavaBlockingTests.testDefinitionForVirtualObject; +import static dev.restate.sdk.core.javaapi.JavaAPITests.testDefinitionForVirtualObject; -import dev.restate.sdk.common.Serde; -import dev.restate.sdk.common.StateKey; import dev.restate.sdk.core.StateTestSuite; import dev.restate.sdk.core.TestDefinitions.TestInvocationBuilder; +import dev.restate.sdk.core.TestSerdes; +import dev.restate.sdk.types.StateKey; +import dev.restate.serde.Serde; public class StateTest extends StateTestSuite { @@ -22,9 +23,9 @@ protected TestInvocationBuilder getState() { return testDefinitionForVirtualObject( "GetState", Serde.VOID, - JsonSerdes.STRING, + TestSerdes.STRING, (ctx, unused) -> { - String state = ctx.get(StateKey.of("STATE", JsonSerdes.STRING)).orElse("Unknown"); + String state = ctx.get(StateKey.of("STATE", String.class)).orElse("Unknown"); return "Hello " + state; }); @@ -34,12 +35,12 @@ protected TestInvocationBuilder getState() { protected TestInvocationBuilder getAndSetState() { return testDefinitionForVirtualObject( "GetState", - JsonSerdes.STRING, - JsonSerdes.STRING, + TestSerdes.STRING, + TestSerdes.STRING, (ctx, input) -> { - String state = ctx.get(StateKey.of("STATE", JsonSerdes.STRING)).get(); + String state = ctx.get(StateKey.of("STATE", String.class)).get(); - ctx.set(StateKey.of("STATE", JsonSerdes.STRING), input); + ctx.set(StateKey.of("STATE", String.class), input); return "Hello " + state; }); @@ -50,7 +51,7 @@ protected TestInvocationBuilder setNullState() { return testDefinitionForVirtualObject( "GetState", Serde.VOID, - JsonSerdes.STRING, + TestSerdes.STRING, (ctx, unused) -> { ctx.set( StateKey.of( diff --git a/sdk-api/src/test/java/dev/restate/sdk/UserFailuresTest.java b/sdk-core/src/test/java/dev/restate/sdk/core/javaapi/UserFailuresTest.java similarity index 90% rename from sdk-api/src/test/java/dev/restate/sdk/UserFailuresTest.java rename to sdk-core/src/test/java/dev/restate/sdk/core/javaapi/UserFailuresTest.java index 735ddbd62..90e14f307 100644 --- a/sdk-api/src/test/java/dev/restate/sdk/UserFailuresTest.java +++ b/sdk-core/src/test/java/dev/restate/sdk/core/javaapi/UserFailuresTest.java @@ -6,15 +6,15 @@ // You can find a copy of the license in file LICENSE in the root // directory of this repository or package, or at // https://github.com/restatedev/sdk-java/blob/main/LICENSE -package dev.restate.sdk; +package dev.restate.sdk.core.javaapi; -import static dev.restate.sdk.JavaBlockingTests.testDefinitionForService; +import static dev.restate.sdk.core.javaapi.JavaAPITests.testDefinitionForService; -import dev.restate.sdk.common.AbortedExecutionException; -import dev.restate.sdk.common.Serde; -import dev.restate.sdk.common.TerminalException; import dev.restate.sdk.core.TestDefinitions.TestInvocationBuilder; import dev.restate.sdk.core.UserFailuresTestSuite; +import dev.restate.sdk.types.AbortedExecutionException; +import dev.restate.sdk.types.TerminalException; +import dev.restate.serde.Serde; import java.util.concurrent.atomic.AtomicInteger; public class UserFailuresTest extends UserFailuresTestSuite { diff --git a/sdk-lambda/src/test/java/dev/restate/sdk/lambda/LambdaHandlerTest.java b/sdk-core/src/test/java/dev/restate/sdk/core/lambda/LambdaHandlerTest.java similarity index 81% rename from sdk-lambda/src/test/java/dev/restate/sdk/lambda/LambdaHandlerTest.java rename to sdk-core/src/test/java/dev/restate/sdk/core/lambda/LambdaHandlerTest.java index f5eba44a9..bef0d9189 100644 --- a/sdk-lambda/src/test/java/dev/restate/sdk/lambda/LambdaHandlerTest.java +++ b/sdk-core/src/test/java/dev/restate/sdk/core/lambda/LambdaHandlerTest.java @@ -6,9 +6,9 @@ // You can find a copy of the license in file LICENSE in the root // directory of this repository or package, or at // https://github.com/restatedev/sdk-java/blob/main/LICENSE -package dev.restate.sdk.lambda; +package dev.restate.sdk.core.lambda; -import static dev.restate.sdk.core.ProtoUtils.*; +import static dev.restate.sdk.core.statemachine.ProtoUtils.*; import static org.assertj.core.api.Assertions.assertThat; import com.amazonaws.services.lambda.runtime.ClientContext; @@ -20,31 +20,31 @@ import com.fasterxml.jackson.databind.ObjectMapper; import com.google.protobuf.ByteString; import com.google.protobuf.MessageLite; -import dev.restate.generated.service.protocol.Protocol; -import dev.restate.sdk.core.ProtoUtils; -import dev.restate.sdk.core.manifest.EndpointManifestSchema; -import dev.restate.sdk.core.manifest.Service; -import dev.restate.sdk.lambda.testservices.JavaCounterDefinitions; -import dev.restate.sdk.lambda.testservices.MyServicesHandler; +import dev.restate.sdk.core.generated.manifest.EndpointManifestSchema; +import dev.restate.sdk.core.generated.manifest.Service; +import dev.restate.sdk.core.generated.protocol.Protocol; +import dev.restate.sdk.core.lambda.testservices.JavaCounterMetadata; +import dev.restate.sdk.core.lambda.testservices.MyServicesHandler; +import dev.restate.sdk.core.statemachine.MessageHeader; +import dev.restate.sdk.core.statemachine.ProtoUtils; +import dev.restate.sdk.lambda.BaseRestateLambdaHandler; import java.io.ByteArrayOutputStream; import java.io.IOException; import java.nio.ByteBuffer; -import java.util.*; +import java.util.Base64; +import java.util.Map; import org.junit.jupiter.api.Test; -import org.junit.jupiter.params.ParameterizedTest; -import org.junit.jupiter.params.provider.ValueSource; class LambdaHandlerTest { - @ValueSource(strings = {JavaCounterDefinitions.SERVICE_NAME, "KtCounter"}) - @ParameterizedTest - public void testInvoke(String serviceName) throws IOException { + @Test + public void testInvoke() throws IOException { MyServicesHandler handler = new MyServicesHandler(); // Mock request APIGatewayProxyRequestEvent request = new APIGatewayProxyRequestEvent(); request.setHeaders(Map.of("content-type", ProtoUtils.serviceProtocolContentTypeHeader())); - request.setPath("/a/path/prefix/invoke/" + serviceName + "/get"); + request.setPath("/a/path/prefix/invoke/" + JavaCounterMetadata.SERVICE_NAME + "/get"); request.setHttpMethod("POST"); request.setIsBase64Encoded(true); request.setBody( @@ -57,7 +57,7 @@ public void testInvoke(String serviceName) throws IOException { .setKnownEntries(1) .setPartialState(true) .build(), - inputMessage()))); + inputCmd()))); // Send request APIGatewayProxyResponseEvent response = handler.handleRequest(request, mockContext()); @@ -69,7 +69,7 @@ public void testInvoke(String serviceName) throws IOException { assertThat(response.getIsBase64Encoded()).isTrue(); assertThat(response.getBody()) .asBase64Decoded() - .isEqualTo(serializeEntries(getStateMessage("counter").build(), suspensionMessage(1))); + .isEqualTo(serializeEntries(getLazyStateCmd(1, "counter").build(), suspensionMessage(1))); } @Test @@ -96,14 +96,14 @@ public void testDiscovery() throws IOException { assertThat(discoveryResponse.getServices()) .map(Service::getName) - .containsOnly(JavaCounterDefinitions.SERVICE_NAME, "KtCounter"); + .containsOnly(JavaCounterMetadata.SERVICE_NAME); } private static byte[] serializeEntries(MessageLite... msgs) throws IOException { ByteArrayOutputStream outputStream = new ByteArrayOutputStream(); for (MessageLite msg : msgs) { ByteBuffer headerBuf = ByteBuffer.allocate(8); - headerBuf.putLong(ProtoUtils.headerFromMessage(msg).encode()); + headerBuf.putLong(MessageHeader.fromMessage(msg).encode()); outputStream.write(headerBuf.array()); msg.writeTo(outputStream); } diff --git a/sdk-lambda/src/test/java/dev/restate/sdk/lambda/testservices/JavaCounterService.java b/sdk-core/src/test/java/dev/restate/sdk/core/lambda/testservices/JavaCounterService.java similarity index 88% rename from sdk-lambda/src/test/java/dev/restate/sdk/lambda/testservices/JavaCounterService.java rename to sdk-core/src/test/java/dev/restate/sdk/core/lambda/testservices/JavaCounterService.java index eb8c973b3..d71d66e15 100644 --- a/sdk-lambda/src/test/java/dev/restate/sdk/lambda/testservices/JavaCounterService.java +++ b/sdk-core/src/test/java/dev/restate/sdk/core/lambda/testservices/JavaCounterService.java @@ -6,13 +6,13 @@ // You can find a copy of the license in file LICENSE in the root // directory of this repository or package, or at // https://github.com/restatedev/sdk-java/blob/main/LICENSE -package dev.restate.sdk.lambda.testservices; +package dev.restate.sdk.core.lambda.testservices; import dev.restate.sdk.ObjectContext; import dev.restate.sdk.annotation.Handler; import dev.restate.sdk.annotation.VirtualObject; -import dev.restate.sdk.common.Serde; -import dev.restate.sdk.common.StateKey; +import dev.restate.sdk.types.StateKey; +import dev.restate.serde.Serde; import java.nio.charset.StandardCharsets; @VirtualObject(name = "JavaCounter") diff --git a/sdk-lambda/src/test/java/dev/restate/sdk/lambda/testservices/MyServicesHandler.java b/sdk-core/src/test/java/dev/restate/sdk/core/lambda/testservices/MyServicesHandler.java similarity index 65% rename from sdk-lambda/src/test/java/dev/restate/sdk/lambda/testservices/MyServicesHandler.java rename to sdk-core/src/test/java/dev/restate/sdk/core/lambda/testservices/MyServicesHandler.java index 4d11be280..d6b9b9afa 100644 --- a/sdk-lambda/src/test/java/dev/restate/sdk/lambda/testservices/MyServicesHandler.java +++ b/sdk-core/src/test/java/dev/restate/sdk/core/lambda/testservices/MyServicesHandler.java @@ -6,14 +6,14 @@ // You can find a copy of the license in file LICENSE in the root // directory of this repository or package, or at // https://github.com/restatedev/sdk-java/blob/main/LICENSE -package dev.restate.sdk.lambda.testservices; +package dev.restate.sdk.core.lambda.testservices; +import dev.restate.sdk.endpoint.Endpoint; import dev.restate.sdk.lambda.BaseRestateLambdaHandler; -import dev.restate.sdk.lambda.RestateLambdaEndpointBuilder; public class MyServicesHandler extends BaseRestateLambdaHandler { @Override - public void register(RestateLambdaEndpointBuilder builder) { - builder.bind(new JavaCounterService()).bind(KotlinCounterServiceKt.counter()); + public void register(Endpoint.Builder builder) { + builder.bind(new JavaCounterService()); } } diff --git a/sdk-core/src/test/java/dev/restate/sdk/core/statemachine/MessageDecoderTest.java b/sdk-core/src/test/java/dev/restate/sdk/core/statemachine/MessageDecoderTest.java new file mode 100644 index 000000000..04038dab0 --- /dev/null +++ b/sdk-core/src/test/java/dev/restate/sdk/core/statemachine/MessageDecoderTest.java @@ -0,0 +1,59 @@ +// Copyright (c) 2023 - Restate Software, Inc., Restate GmbH +// +// This file is part of the Restate Java SDK, +// which is released under the MIT license. +// +// You can find a copy of the license in file LICENSE in the root +// directory of this repository or package, or at +// https://github.com/restatedev/sdk-java/blob/main/LICENSE +package dev.restate.sdk.core.statemachine; + +import static dev.restate.sdk.core.AssertUtils.assertThatDecodingMessages; +import static dev.restate.sdk.core.statemachine.ProtoUtils.startMessage; +import static org.assertj.core.api.Assertions.entry; + +import com.google.protobuf.MessageLite; +import dev.restate.common.Slice; +import java.nio.ByteBuffer; +import java.util.List; +import org.junit.jupiter.api.Test; + +public class MessageDecoderTest { + + @Test + void oneMessage() { + assertThatDecodingMessages( + ProtoUtils.encodeMessageToSlice(startMessage(1, "my-key", entry("key", "value")))) + .map(InvocationInput::message) + .containsExactly(startMessage(1, "my-key", entry("key", "value")).build()); + } + + @Test + void multiMessage() { + assertThatDecodingMessages( + ProtoUtils.encodeMessageToSlice(startMessage(1, "my-key", entry("key", "value"))), + ProtoUtils.encodeMessageToSlice(ProtoUtils.inputCmd("my-value"))) + .map(InvocationInput::message) + .containsExactly( + startMessage(1, "my-key", entry("key", "value")).build(), + ProtoUtils.inputCmd("my-value")); + } + + @Test + void multiMessageInSingleBuffer() { + List messages = + List.of( + startMessage(1, "my-key", entry("key", "value")).build(), + ProtoUtils.inputCmd("my-value")); + ByteBuffer byteBuffer = + ByteBuffer.allocate(messages.stream().mapToInt(MessageEncoder::encodeLength).sum()); + messages.stream().map(ProtoUtils::encodeMessageToByteBuffer).forEach(byteBuffer::put); + byteBuffer.flip(); + + assertThatDecodingMessages(Slice.wrap(byteBuffer)) + .map(InvocationInput::message) + .containsExactly( + startMessage(1, "my-key", entry("key", "value")).build(), + ProtoUtils.inputCmd("my-value")); + } +} diff --git a/sdk-core/src/test/java/dev/restate/sdk/core/statemachine/ProtoUtils.java b/sdk-core/src/test/java/dev/restate/sdk/core/statemachine/ProtoUtils.java new file mode 100644 index 000000000..f02561b00 --- /dev/null +++ b/sdk-core/src/test/java/dev/restate/sdk/core/statemachine/ProtoUtils.java @@ -0,0 +1,502 @@ +// Copyright (c) 2023 - Restate Software, Inc., Restate GmbH +// +// This file is part of the Restate Java SDK, +// which is released under the MIT license. +// +// You can find a copy of the license in file LICENSE in the root +// directory of this repository or package, or at +// https://github.com/restatedev/sdk-java/blob/main/LICENSE +package dev.restate.sdk.core.statemachine; + +import com.google.protobuf.ByteString; +import com.google.protobuf.MessageLite; +import com.google.protobuf.MessageLiteOrBuilder; +import com.google.protobuf.UnsafeByteOperations; +import dev.restate.common.Slice; +import dev.restate.common.Target; +import dev.restate.sdk.core.TestSerdes; +import dev.restate.sdk.core.generated.protocol.Protocol; +import dev.restate.sdk.core.generated.protocol.Protocol.StartMessage.StateEntry; +import dev.restate.serde.Serde; +import java.nio.ByteBuffer; +import java.util.ArrayList; +import java.util.Arrays; +import java.util.List; +import java.util.Map; +import java.util.stream.Collectors; +import org.jspecify.annotations.Nullable; + +public class ProtoUtils { + + public static long invocationIdToRandomSeed(String invocationId) { + return new InvocationIdImpl(invocationId).toRandomSeed(); + } + + public static String serviceProtocolContentTypeHeader() { + return ServiceProtocol.serviceProtocolVersionToHeaderValue( + ServiceProtocol.MIN_SERVICE_PROTOCOL_VERSION); + } + + public static String serviceProtocolContentTypeHeader(boolean enableContextPreview) { + return ServiceProtocol.serviceProtocolVersionToHeaderValue( + ServiceProtocol.MAX_SERVICE_PROTOCOL_VERSION); + } + + public static String serviceProtocolDiscoveryContentTypeHeader() { + return "application/vnd.restate.endpointmanifest.v2+json"; + } + + public static ByteBuffer invocationInputToByteString(InvocationInput invocationInput) { + ByteBuffer buffer = ByteBuffer.allocate(MessageEncoder.encodeLength(invocationInput.message())); + + buffer.putLong(invocationInput.header().encode()); + buffer.put(invocationInput.message().toByteString().asReadOnlyByteBuffer()); + + buffer.flip(); + return buffer; + } + + public static ByteBuffer encodeMessageToByteBuffer(MessageLiteOrBuilder msgOrBuilder) { + var msg = build(msgOrBuilder); + return invocationInputToByteString(InvocationInput.of(MessageHeader.fromMessage(msg), msg)); + } + + public static Slice encodeMessageToSlice(MessageLiteOrBuilder msgOrBuilder) { + return Slice.wrap(encodeMessageToByteBuffer(msgOrBuilder)); + } + + public static List bufferToMessages(List byteBuffers) { + var messageDecoder = new MessageDecoder(); + byteBuffers.stream().map(Slice::wrap).forEach(messageDecoder::offer); + + var outputList = new ArrayList(); + while (messageDecoder.isNextAvailable()) { + outputList.add(messageDecoder.next()); + } + return outputList.stream().map(InvocationInput::message).collect(Collectors.toList()); + } + + public static Protocol.StartMessage.Builder startMessage(int entries) { + return Protocol.StartMessage.newBuilder() + .setId(ByteString.copyFromUtf8("abc")) + .setDebugId("abc") + .setKnownEntries(entries) + .setPartialState(true); + } + + public static Protocol.StartMessage.Builder startMessage(int entries, String key) { + return Protocol.StartMessage.newBuilder() + .setId(ByteString.copyFromUtf8("abc")) + .setDebugId("abc") + .setKnownEntries(entries) + .setKey(key) + .setPartialState(true); + } + + @SafeVarargs + public static Protocol.StartMessage.Builder startMessage( + int entries, String key, Map.Entry... stateEntries) { + return startMessage(entries, key) + .addAllStateMap( + Arrays.stream(stateEntries) + .map( + e -> + StateEntry.newBuilder() + .setKey(ByteString.copyFromUtf8(e.getKey())) + .setValue( + ByteString.copyFrom( + TestSerdes.STRING.serialize(e.getValue()).toByteArray())) + .build()) + .collect(Collectors.toList())); + } + + public static Protocol.SuspensionMessage suspensionMessage(Integer... completionIds) { + return Protocol.SuspensionMessage.newBuilder() + .addAllWaitingCompletions(List.of(completionIds)) + .addWaitingSignals(1) + .build(); + } + + public static Protocol.InputCommandMessage inputCmd() { + return Protocol.InputCommandMessage.newBuilder() + .setValue(Protocol.Value.newBuilder().setContent(ByteString.EMPTY)) + .build(); + } + + public static Protocol.InputCommandMessage inputCmd(byte[] value) { + return Protocol.InputCommandMessage.newBuilder() + .setValue(Protocol.Value.newBuilder().setContent(ByteString.copyFrom(value))) + .build(); + } + + public static Protocol.InputCommandMessage inputCmd(Serde serde, T value) { + return Protocol.InputCommandMessage.newBuilder().setValue(value(serde, value)).build(); + } + + public static Protocol.InputCommandMessage inputCmd(String value) { + return inputCmd(TestSerdes.STRING, value); + } + + public static Protocol.InputCommandMessage inputCmd(int value) { + return inputCmd(TestSerdes.INT, value); + } + + public static Protocol.OutputCommandMessage outputCmd(Serde serde, T value) { + return Protocol.OutputCommandMessage.newBuilder().setValue(value(serde, value)).build(); + } + + public static Protocol.OutputCommandMessage outputCmd(String value) { + return outputCmd(TestSerdes.STRING, value); + } + + public static Protocol.OutputCommandMessage outputCmd(int value) { + return outputCmd(TestSerdes.INT, value); + } + + public static Protocol.OutputCommandMessage outputCmd(byte[] b) { + return outputCmd(Serde.RAW, b); + } + + public static Protocol.OutputCommandMessage outputCmd() { + return Protocol.OutputCommandMessage.newBuilder() + .setValue(Protocol.Value.newBuilder().setContent(ByteString.empty()).build()) + .build(); + } + + public static Protocol.OutputCommandMessage outputCmd(int code, String message) { + return Protocol.OutputCommandMessage.newBuilder().setFailure(failure(code, message)).build(); + } + + public static Protocol.OutputCommandMessage outputCmd(Throwable e) { + return Protocol.OutputCommandMessage.newBuilder().setFailure(failure(e)).build(); + } + + public static Protocol.GetLazyStateCommandMessage.Builder getLazyStateCmd( + int completionId, String key) { + return Protocol.GetLazyStateCommandMessage.newBuilder() + .setKey(ByteString.copyFromUtf8(key)) + .setResultCompletionId(completionId); + } + + public static Protocol.GetEagerStateCommandMessage getEagerStateEmptyCmd(String key) { + return Protocol.GetEagerStateCommandMessage.newBuilder() + .setKey(ByteString.copyFromUtf8(key)) + .setVoid(Protocol.Void.getDefaultInstance()) + .build(); + } + + public static Protocol.GetEagerStateCommandMessage getEagerStateCmd( + String key, Serde serde, T value) { + return Protocol.GetEagerStateCommandMessage.newBuilder() + .setKey(ByteString.copyFromUtf8(key)) + .setValue(value(serde, value)) + .build(); + } + + public static Protocol.GetEagerStateCommandMessage getEagerStateCmd(String key, String value) { + return getEagerStateCmd(key, TestSerdes.STRING, value); + } + + public static Protocol.GetLazyStateCompletionNotificationMessage getLazyStateCompletion( + int completionId, Serde serde, T value) { + return Protocol.GetLazyStateCompletionNotificationMessage.newBuilder() + .setCompletionId(completionId) + .setValue( + Protocol.Value.newBuilder() + .setContent( + UnsafeByteOperations.unsafeWrap(serde.serialize(value).asReadOnlyByteBuffer()))) + .build(); + } + + public static Protocol.GetLazyStateCompletionNotificationMessage getLazyStateCompletion( + int completionId, String value) { + return getLazyStateCompletion(completionId, TestSerdes.STRING, value); + } + + public static Protocol.GetLazyStateCompletionNotificationMessage getLazyStateCompletionEmpty( + int completionId) { + return Protocol.GetLazyStateCompletionNotificationMessage.newBuilder() + .setCompletionId(completionId) + .setVoid(Protocol.Void.getDefaultInstance()) + .build(); + } + + public static Protocol.SetStateCommandMessage setStateCmd( + String key, Serde serde, T value) { + return Protocol.SetStateCommandMessage.newBuilder() + .setKey(ByteString.copyFromUtf8(key)) + .setValue( + Protocol.Value.newBuilder() + .setContent(ByteString.copyFrom(serde.serialize(value).toByteArray()))) + .build(); + } + + public static Protocol.SetStateCommandMessage setStateCmd(String key, String value) { + return setStateCmd(key, TestSerdes.STRING, value); + } + + public static Protocol.ClearStateCommandMessage clearStateCmd(String key) { + return Protocol.ClearStateCommandMessage.newBuilder() + .setKey(ByteString.copyFromUtf8(key)) + .build(); + } + + public static Protocol.CallCommandMessage.Builder callCmd( + int invocationIdCompletionId, int resultCompletionId, Target target) { + Protocol.CallCommandMessage.Builder builder = + Protocol.CallCommandMessage.newBuilder() + .setServiceName(target.getService()) + .setHandlerName(target.getHandler()); + if (target.getKey() != null) { + builder.setKey(target.getKey()); + } + builder + .setInvocationIdNotificationIdx(invocationIdCompletionId) + .setResultCompletionId(resultCompletionId); + + return builder; + } + + public static Protocol.CallCommandMessage.Builder callCmd( + int invocationIdCompletionId, int resultCompletionId, Target target, byte[] parameter) { + return callCmd(invocationIdCompletionId, resultCompletionId, target, Serde.RAW, parameter); + } + + public static Protocol.CallCommandMessage.Builder callCmd( + int invocationIdCompletionId, + int resultCompletionId, + Target target, + Serde reqSerde, + T parameter) { + return callCmd(invocationIdCompletionId, resultCompletionId, target) + .setParameter(ByteString.copyFrom(reqSerde.serialize(parameter).toByteArray())); + } + + public static Protocol.CallCommandMessage.Builder callCmd( + int invocationIdCompletionId, int resultCompletionId, Target target, String parameter) { + return callCmd( + invocationIdCompletionId, resultCompletionId, target, TestSerdes.STRING, parameter); + } + + public static Protocol.OneWayCallCommandMessage.Builder oneWayCallCmd( + int invocationIdCompletionId, + Target target, + @Nullable String idempotencyKey, + @Nullable Map headers, + Slice input) { + Protocol.OneWayCallCommandMessage.Builder builder = + Protocol.OneWayCallCommandMessage.newBuilder() + .setServiceName(target.getService()) + .setHandlerName(target.getHandler()); + if (target.getKey() != null) { + builder.setKey(target.getKey()); + } + if (idempotencyKey != null) { + builder.setIdempotencyKey(idempotencyKey); + } + if (headers != null) { + builder.addAllHeaders( + headers.entrySet().stream() + .map( + e -> + Protocol.Header.newBuilder() + .setKey(e.getKey()) + .setValue(e.getValue()) + .build()) + .toList()); + } + + builder + .setParameter(UnsafeByteOperations.unsafeWrap(input.asReadOnlyByteBuffer())) + .setInvocationIdNotificationIdx(invocationIdCompletionId); + + return builder; + } + + public static Protocol.CallCompletionNotificationMessage.Builder callCompletion( + int completionId, Serde reqSerde, T parameter) { + return Protocol.CallCompletionNotificationMessage.newBuilder() + .setCompletionId(completionId) + .setValue(value(reqSerde, parameter)); + } + + public static Protocol.CallCompletionNotificationMessage.Builder callCompletion( + int completionId, String result) { + return callCompletion(completionId, TestSerdes.STRING, result); + } + + public static Protocol.CallCompletionNotificationMessage.Builder callCompletion( + int completionId, Throwable failure) { + return Protocol.CallCompletionNotificationMessage.newBuilder() + .setCompletionId(completionId) + .setFailure(failure(failure)); + } + + public static + Protocol.CallInvocationIdCompletionNotificationMessage.Builder callInvocationIdCompletion( + int completionId, String invocationId) { + return Protocol.CallInvocationIdCompletionNotificationMessage.newBuilder() + .setCompletionId(completionId) + .setInvocationId(invocationId); + } + + public static Protocol.GetPromiseCommandMessage.Builder getPromiseCmd( + int completionId, String key) { + return Protocol.GetPromiseCommandMessage.newBuilder() + .setResultCompletionId(completionId) + .setKey(key); + } + + public static Protocol.PeekPromiseCommandMessage.Builder peekPromiseCmd( + int completionId, String key) { + return Protocol.PeekPromiseCommandMessage.newBuilder() + .setResultCompletionId(completionId) + .setKey(key); + } + + public static Protocol.CompletePromiseCommandMessage.Builder completePromiseCmd( + int completionId, String key, String value) { + return Protocol.CompletePromiseCommandMessage.newBuilder() + .setKey(key) + .setResultCompletionId(completionId) + .setCompletionValue(value(value)); + } + + public static Protocol.CompletePromiseCommandMessage.Builder completePromiseCmd( + int completionId, String key, Throwable e) { + return Protocol.CompletePromiseCommandMessage.newBuilder() + .setKey(key) + .setResultCompletionId(completionId) + .setCompletionFailure(failure(e)); + } + + public static Protocol.SignalNotificationMessage signalNotification( + int signalId, Serde serde, T value) { + return Protocol.SignalNotificationMessage.newBuilder() + .setIdx(signalId) + .setValue( + Protocol.Value.newBuilder() + .setContent( + UnsafeByteOperations.unsafeWrap(serde.serialize(value).asReadOnlyByteBuffer()))) + .build(); + } + + public static Protocol.SignalNotificationMessage signalNotification(int signalId, String value) { + return signalNotification(signalId, TestSerdes.STRING, value); + } + + public static Protocol.SignalNotificationMessage signalNotification( + String signalName, Serde serde, T value) { + return Protocol.SignalNotificationMessage.newBuilder() + .setName(signalName) + .setValue( + Protocol.Value.newBuilder() + .setContent( + UnsafeByteOperations.unsafeWrap(serde.serialize(value).asReadOnlyByteBuffer()))) + .build(); + } + + public static Protocol.SignalNotificationMessage signalNotification( + String signalName, String value) { + return signalNotification(signalName, TestSerdes.STRING, value); + } + + public static Protocol.RunCommandMessage runCmd(int completion) { + return Protocol.RunCommandMessage.newBuilder().setResultCompletionId(completion).build(); + } + + public static Protocol.RunCommandMessage runCmd(int completion, String name) { + return Protocol.RunCommandMessage.newBuilder() + .setResultCompletionId(completion) + .setName(name) + .build(); + } + + public static Protocol.RunCompletionNotificationMessage.Builder runCompletion( + int completionId, Serde reqSerde, T parameter) { + return Protocol.RunCompletionNotificationMessage.newBuilder() + .setCompletionId(completionId) + .setValue(value(reqSerde, parameter)); + } + + public static Protocol.RunCompletionNotificationMessage.Builder runCompletion( + int completionId, String result) { + return runCompletion(completionId, TestSerdes.STRING, result); + } + + public static Protocol.RunCompletionNotificationMessage.Builder runCompletion( + int completionId, int code, String message) { + return Protocol.RunCompletionNotificationMessage.newBuilder() + .setCompletionId(completionId) + .setFailure(failure(code, message)); + } + + public static Protocol.ProposeRunCompletionMessage.Builder proposeRunCompletion( + int completionId, Serde reqSerde, T parameter) { + return Protocol.ProposeRunCompletionMessage.newBuilder() + .setResultCompletionId(completionId) + .setValue(value(reqSerde, parameter).getContent()); + } + + public static Protocol.ProposeRunCompletionMessage.Builder proposeRunCompletion( + int completionId, String result) { + return proposeRunCompletion(completionId, TestSerdes.STRING, result); + } + + public static Protocol.ProposeRunCompletionMessage.Builder proposeRunCompletion( + int completionId, int code, String message) { + return Protocol.ProposeRunCompletionMessage.newBuilder() + .setResultCompletionId(completionId) + .setFailure(failure(code, message)); + } + + public static Protocol.SendSignalCommandMessage sendCancelSignal(String targetInvocationId) { + return Protocol.SendSignalCommandMessage.newBuilder() + .setTargetInvocationId(targetInvocationId) + .setIdx(1) + .setVoid(Protocol.Void.getDefaultInstance()) + .build(); + } + + public static Protocol.Failure failure(int code, String message) { + return Util.toProtocolFailure(code, message); + } + + public static Protocol.Failure failure(Throwable throwable) { + return Util.toProtocolFailure(throwable); + } + + public static Protocol.Value value(String jsonStringContent) { + return value(TestSerdes.STRING, jsonStringContent); + } + + public static Protocol.Value value(Serde serde, T value) { + return Protocol.Value.newBuilder() + .setContent(UnsafeByteOperations.unsafeWrap(serde.serialize(value).asReadOnlyByteBuffer())) + .build(); + } + + public static final Protocol.EndMessage END_MESSAGE = Protocol.EndMessage.getDefaultInstance(); + public static final Protocol.SignalNotificationMessage CANCELLATION_SIGNAL = + Protocol.SignalNotificationMessage.newBuilder() + .setVoid(Protocol.Void.getDefaultInstance()) + .setIdx(1) + .build(); + + public static final Target GREETER_SERVICE_TARGET = Target.service("Greeter", "greeter"); + public static Target GREETER_VIRTUAL_OBJECT_TARGET = + Target.virtualObject("Greeter", "Francesco", "greeter"); + + public static Protocol.StateKeys.Builder stateKeys(String... keys) { + return Protocol.StateKeys.newBuilder() + .addAllKeys(Arrays.stream(keys).map(ByteString::copyFromUtf8).collect(Collectors.toList())); + } + + public static MessageLite build(MessageLiteOrBuilder value) { + if (value instanceof MessageLite) { + return (MessageLite) value; + } else { + return ((MessageLite.Builder) value).build(); + } + } +} diff --git a/sdk-api-kotlin/src/test/kotlin/dev/restate/sdk/kotlin/DeferredTest.kt b/sdk-core/src/test/kotlin/dev/restate/sdk/core/kotlinapi/AsyncResultTest.kt similarity index 62% rename from sdk-api-kotlin/src/test/kotlin/dev/restate/sdk/kotlin/DeferredTest.kt rename to sdk-core/src/test/kotlin/dev/restate/sdk/core/kotlinapi/AsyncResultTest.kt index 9a3b94cba..80e17d927 100644 --- a/sdk-api-kotlin/src/test/kotlin/dev/restate/sdk/kotlin/DeferredTest.kt +++ b/sdk-core/src/test/kotlin/dev/restate/sdk/core/kotlinapi/AsyncResultTest.kt @@ -6,17 +6,20 @@ // You can find a copy of the license in file LICENSE in the root // directory of this repository or package, or at // https://github.com/restatedev/sdk-java/blob/main/LICENSE -package dev.restate.sdk.kotlin +package dev.restate.sdk.core.kotlinapi -import dev.restate.sdk.common.StateKey -import dev.restate.sdk.core.DeferredTestSuite +import dev.restate.sdk.core.AsyncResultTestSuite import dev.restate.sdk.core.TestDefinitions.* import dev.restate.sdk.core.TestSerdes -import dev.restate.sdk.kotlin.KotlinCoroutinesTests.Companion.callGreeterGreetService -import dev.restate.sdk.kotlin.KotlinCoroutinesTests.Companion.testDefinitionForVirtualObject +import dev.restate.sdk.core.kotlinapi.KotlinAPITests.Companion.callGreeterGreetService +import dev.restate.sdk.core.kotlinapi.KotlinAPITests.Companion.testDefinitionForVirtualObject +import dev.restate.sdk.kotlin.* +import dev.restate.sdk.types.StateKey +import dev.restate.sdk.types.TimeoutException import java.util.stream.Stream +import kotlin.time.Duration.Companion.days -class DeferredTest : DeferredTestSuite() { +class AsyncResultTest : AsyncResultTestSuite() { override fun reverseAwaitOrder(): TestInvocationBuilder = testDefinitionForVirtualObject("ReverseAwaitOrder") { ctx, _: Unit -> val a1: Awaitable = callGreeterGreetService(ctx, "Francesco") @@ -26,13 +29,13 @@ class DeferredTest : DeferredTestSuite() { ctx.set(StateKey.of("A2", TestSerdes.STRING), a2Res) val a1Res: String = a1.await() - "$a1Res-$a2Res" + return@testDefinitionForVirtualObject "$a1Res-$a2Res" } override fun awaitTwiceTheSameAwaitable(): TestInvocationBuilder = testDefinitionForVirtualObject("AwaitTwiceTheSameAwaitable") { ctx, _: Unit -> val a = callGreeterGreetService(ctx, "Francesco") - "${a.await()}-${a.await()}" + return@testDefinitionForVirtualObject "${a.await()}-${a.await()}" } override fun awaitAll(): TestInvocationBuilder = @@ -40,24 +43,30 @@ class DeferredTest : DeferredTestSuite() { val a1 = callGreeterGreetService(ctx, "Francesco") val a2 = callGreeterGreetService(ctx, "Till") - listOf(a1, a2).awaitAll().joinToString(separator = "-") + return@testDefinitionForVirtualObject listOf(a1, a2) + .awaitAll() + .joinToString(separator = "-") } override fun awaitAny(): TestInvocationBuilder = testDefinitionForVirtualObject("AwaitAny") { ctx, _: Unit -> val a1 = callGreeterGreetService(ctx, "Francesco") val a2 = callGreeterGreetService(ctx, "Till") - Awaitable.any(a1, a2).await() as String + + return@testDefinitionForVirtualObject Awaitable.any(a1, a2) + .map { it -> if (it == 0) a1.await() else a2.await() } + .await() } private fun awaitSelect(): TestInvocationBuilder = testDefinitionForVirtualObject("AwaitSelect") { ctx, _: Unit -> val a1 = callGreeterGreetService(ctx, "Francesco") val a2 = callGreeterGreetService(ctx, "Till") - select { - a1.onAwait { it } - a2.onAwait { it } - } + return@testDefinitionForVirtualObject select { + a1.onAwait { it } + a2.onAwait { it } + } + .await() } override fun combineAnyWithAll(): TestInvocationBuilder = @@ -67,12 +76,12 @@ class DeferredTest : DeferredTestSuite() { val a3 = ctx.awakeable(TestSerdes.STRING) val a4 = ctx.awakeable(TestSerdes.STRING) - val a12 = Awaitable.any(a1, a2) - val a23 = Awaitable.any(a2, a3) - val a34 = Awaitable.any(a3, a4) + val a12 = Awaitable.any(a1, a2).map { if (it == 0) a1.await() else a2.await() } + val a23 = Awaitable.any(a2, a3).map { if (it == 0) a2.await() else a3.await() } + val a34 = Awaitable.any(a3, a4).map { if (it == 0) a3.await() else a4.await() } Awaitable.all(a12, a23, a34).await() - a12.await().toString() + a23.await() as String + a34.await() + return@testDefinitionForVirtualObject a12.await() + a23.await() + a34.await() } override fun awaitAnyIndex(): TestInvocationBuilder = @@ -82,7 +91,9 @@ class DeferredTest : DeferredTestSuite() { val a3 = ctx.awakeable(TestSerdes.STRING) val a4 = ctx.awakeable(TestSerdes.STRING) - Awaitable.any(a1, Awaitable.all(a2, a3), a4).awaitIndex().toString() + return@testDefinitionForVirtualObject Awaitable.any(a1, Awaitable.all(a2, a3), a4) + .await() + .toString() } override fun awaitOnAlreadyResolvedAwaitables(): TestInvocationBuilder = @@ -95,12 +106,18 @@ class DeferredTest : DeferredTestSuite() { a12and1.await() a121and12.await() - a1.await() + a2.await() + return@testDefinitionForVirtualObject a1.await() + a2.await() } - override fun awaitWithTimeout(): TestInvocationBuilder { - return unsupported("This is a feature not available in sdk-api-kotlin") - } + override fun awaitWithTimeout(): TestInvocationBuilder = + testDefinitionForVirtualObject("AwaitWithTimeout") { ctx, _: Unit -> + val a1 = callGreeterGreetService(ctx, "Francesco") + return@testDefinitionForVirtualObject try { + a1.await(1.days) + } catch (_: TimeoutException) { + "timeout" + } + } override fun definitions(): Stream = Stream.concat(super.definitions(), super.anyTestDefinitions { awaitSelect() }) diff --git a/sdk-api-kotlin/src/test/kotlin/dev/restate/sdk/kotlin/AwakeableIdTest.kt b/sdk-core/src/test/kotlin/dev/restate/sdk/core/kotlinapi/AwakeableIdTest.kt similarity index 81% rename from sdk-api-kotlin/src/test/kotlin/dev/restate/sdk/kotlin/AwakeableIdTest.kt rename to sdk-core/src/test/kotlin/dev/restate/sdk/core/kotlinapi/AwakeableIdTest.kt index b2076c6bf..2ab5398ec 100644 --- a/sdk-api-kotlin/src/test/kotlin/dev/restate/sdk/kotlin/AwakeableIdTest.kt +++ b/sdk-core/src/test/kotlin/dev/restate/sdk/core/kotlinapi/AwakeableIdTest.kt @@ -6,11 +6,12 @@ // You can find a copy of the license in file LICENSE in the root // directory of this repository or package, or at // https://github.com/restatedev/sdk-java/blob/main/LICENSE -package dev.restate.sdk.kotlin +package dev.restate.sdk.core.kotlinapi import dev.restate.sdk.core.AwakeableIdTestSuite import dev.restate.sdk.core.TestDefinitions -import dev.restate.sdk.kotlin.KotlinCoroutinesTests.Companion.testDefinitionForService +import dev.restate.sdk.core.kotlinapi.KotlinAPITests.Companion.testDefinitionForService +import dev.restate.sdk.kotlin.* class AwakeableIdTest : AwakeableIdTestSuite() { diff --git a/sdk-core/src/test/kotlin/dev/restate/sdk/core/kotlinapi/CallTest.kt b/sdk-core/src/test/kotlin/dev/restate/sdk/core/kotlinapi/CallTest.kt new file mode 100644 index 000000000..6a4251d08 --- /dev/null +++ b/sdk-core/src/test/kotlin/dev/restate/sdk/core/kotlinapi/CallTest.kt @@ -0,0 +1,39 @@ +// Copyright (c) 2023 - Restate Software, Inc., Restate GmbH +// +// This file is part of the Restate Java SDK, +// which is released under the MIT license. +// +// You can find a copy of the license in file LICENSE in the root +// directory of this repository or package, or at +// https://github.com/restatedev/sdk-java/blob/main/LICENSE +package dev.restate.sdk.core.kotlinapi + +import dev.restate.common.Request +import dev.restate.common.Slice +import dev.restate.common.Target +import dev.restate.sdk.core.CallTestSuite +import dev.restate.sdk.core.kotlinapi.KotlinAPITests.Companion.testDefinitionForService +import dev.restate.serde.Serde + +class CallTest : CallTestSuite() { + + override fun oneWayCall( + target: Target, + idempotencyKey: String, + headers: Map, + body: Slice + ) = + testDefinitionForService("OneWayCall") { ctx, _: Unit -> + val ignored = + ctx.send( + Request.of(target, Serde.SLICE, Serde.RAW, body) + .headers(headers) + .idempotencyKey(idempotencyKey)) + } + + override fun implicitCancellation(target: Target, body: Slice) = + testDefinitionForService("ImplicitCancellation") { ctx, _: Unit -> + val ignored = + ctx.call(Request.of(target, Serde.SLICE, Serde.RAW, body)).await() + } +} diff --git a/sdk-api-kotlin-gen/src/test/kotlin/dev/restate/sdk/kotlin/KtCodegenTests.kt b/sdk-core/src/test/kotlin/dev/restate/sdk/core/kotlinapi/CodegenDiscoveryTest.kt similarity index 70% rename from sdk-api-kotlin-gen/src/test/kotlin/dev/restate/sdk/kotlin/KtCodegenTests.kt rename to sdk-core/src/test/kotlin/dev/restate/sdk/core/kotlinapi/CodegenDiscoveryTest.kt index 5dccd660b..29f3f51d8 100644 --- a/sdk-api-kotlin-gen/src/test/kotlin/dev/restate/sdk/kotlin/KtCodegenTests.kt +++ b/sdk-core/src/test/kotlin/dev/restate/sdk/core/kotlinapi/CodegenDiscoveryTest.kt @@ -6,30 +6,17 @@ // You can find a copy of the license in file LICENSE in the root // directory of this repository or package, or at // https://github.com/restatedev/sdk-java/blob/main/LICENSE -package dev.restate.sdk.kotlin +package dev.restate.sdk.core.kotlinapi import dev.restate.sdk.core.AssertUtils.assertThatDiscovery -import dev.restate.sdk.core.MockMultiThreaded -import dev.restate.sdk.core.MockSingleThread -import dev.restate.sdk.core.TestDefinitions -import dev.restate.sdk.core.TestDefinitions.TestExecutor -import dev.restate.sdk.core.TestRunner -import dev.restate.sdk.core.manifest.Handler -import dev.restate.sdk.core.manifest.Input -import dev.restate.sdk.core.manifest.Output -import dev.restate.sdk.core.manifest.Service -import java.util.stream.Stream +import dev.restate.sdk.core.generated.manifest.Handler +import dev.restate.sdk.core.generated.manifest.Input +import dev.restate.sdk.core.generated.manifest.Output +import dev.restate.sdk.core.generated.manifest.Service import org.assertj.core.api.InstanceOfAssertFactories.type import org.junit.jupiter.api.Test -class KtCodegenTests : TestRunner() { - override fun executors(): Stream { - return Stream.of(MockSingleThread.INSTANCE, MockMultiThreaded.INSTANCE) - } - - public override fun definitions(): Stream { - return Stream.of(CodegenTest()) - } +class CodegenDiscoveryTest { @Test fun checkCustomInputContentType() { diff --git a/sdk-api-kotlin-gen/src/test/kotlin/dev/restate/sdk/kotlin/CodegenTest.kt b/sdk-core/src/test/kotlin/dev/restate/sdk/core/kotlinapi/CodegenTest.kt similarity index 67% rename from sdk-api-kotlin-gen/src/test/kotlin/dev/restate/sdk/kotlin/CodegenTest.kt rename to sdk-core/src/test/kotlin/dev/restate/sdk/core/kotlinapi/CodegenTest.kt index 4072d46ee..9f3e65ca9 100644 --- a/sdk-api-kotlin-gen/src/test/kotlin/dev/restate/sdk/kotlin/CodegenTest.kt +++ b/sdk-core/src/test/kotlin/dev/restate/sdk/core/kotlinapi/CodegenTest.kt @@ -6,17 +6,18 @@ // You can find a copy of the license in file LICENSE in the root // directory of this repository or package, or at // https://github.com/restatedev/sdk-java/blob/main/LICENSE -package dev.restate.sdk.kotlin +package dev.restate.sdk.core.kotlinapi -import com.google.protobuf.ByteString +import dev.restate.common.Target import dev.restate.sdk.annotation.* -import dev.restate.sdk.common.Serde -import dev.restate.sdk.common.Target -import dev.restate.sdk.core.ProtoUtils.* import dev.restate.sdk.core.TestDefinitions import dev.restate.sdk.core.TestDefinitions.TestDefinition import dev.restate.sdk.core.TestDefinitions.testInvocation import dev.restate.sdk.core.TestSerdes +import dev.restate.sdk.core.statemachine.ProtoUtils.* +import dev.restate.sdk.kotlin.* +import dev.restate.sdk.kotlin.serialization.* +import dev.restate.serde.Serde import java.util.stream.Stream import kotlinx.serialization.Serializable @@ -121,7 +122,7 @@ class CodegenTest : TestDefinitions.TestSuite { @Exclusive suspend fun returnNull(context: ObjectContext, request: String?): String? { return CodegenTestCornerCasesClient.fromContext(context, context.key()) - .returnNull(request) + .returnNull(request) {} .await() } } @@ -142,6 +143,7 @@ class CodegenTest : TestDefinitions.TestSuite { return CodegenTestWorkflowCornerCasesClient.connect("invalid", request) .workflowHandle() .output + .response() .value } } @@ -202,143 +204,138 @@ class CodegenTest : TestDefinitions.TestSuite { } } - // Just needs to compile - @MyMetaServiceAnnotation(name = "MetaAnnotatedGreeter") - class MetaAnnotatedGreeter { - @Handler - suspend fun greet(context: Context, request: String): String { - return MetaAnnotatedGreeterClient.fromContext(context).greet(request).await() - } - } - override fun definitions(): Stream { return Stream.of( testInvocation({ ServiceGreeter() }, "greet") - .withInput(startMessage(1), inputMessage("Francesco")) - .onlyUnbuffered() - .expectingOutput(outputMessage("Francesco"), END_MESSAGE), + .withInput(startMessage(1), inputCmd("Francesco")) + .onlyBidiStream() + .expectingOutput(outputCmd("Francesco"), END_MESSAGE), testInvocation({ ObjectGreeter() }, "greet") - .withInput(startMessage(1, "slinkydeveloper"), inputMessage("Francesco")) - .onlyUnbuffered() - .expectingOutput(outputMessage("Francesco"), END_MESSAGE), + .withInput(startMessage(1, "slinkydeveloper"), inputCmd("Francesco")) + .onlyBidiStream() + .expectingOutput(outputCmd("Francesco"), END_MESSAGE), testInvocation({ ObjectGreeter() }, "sharedGreet") - .withInput(startMessage(1, "slinkydeveloper"), inputMessage("Francesco")) - .onlyUnbuffered() - .expectingOutput(outputMessage("Francesco"), END_MESSAGE), + .withInput(startMessage(1, "slinkydeveloper"), inputCmd("Francesco")) + .onlyBidiStream() + .expectingOutput(outputCmd("Francesco"), END_MESSAGE), testInvocation({ NestedDataClass() }, "greet") .withInput( startMessage(1, "slinkydeveloper"), - inputMessage(KtSerdes.json(), NestedDataClass.Input("123"))) - .onlyUnbuffered() + inputCmd(jsonSerde(), NestedDataClass.Input("123"))) + .onlyBidiStream() .expectingOutput( - outputMessage(KtSerdes.json(), NestedDataClass.Output("123")), END_MESSAGE), + outputCmd(jsonSerde(), NestedDataClass.Output("123")), + END_MESSAGE), testInvocation({ ObjectGreeterImplementedFromInterface() }, "greet") - .withInput(startMessage(1, "slinkydeveloper"), inputMessage("Francesco")) - .onlyUnbuffered() - .expectingOutput(outputMessage("Francesco"), END_MESSAGE), + .withInput(startMessage(1, "slinkydeveloper"), inputCmd("Francesco")) + .onlyBidiStream() + .expectingOutput(outputCmd("Francesco"), END_MESSAGE), testInvocation({ Empty() }, "emptyInput") - .withInput(startMessage(1), inputMessage(), completionMessage(1, "Till")) - .onlyUnbuffered() + .withInput(startMessage(1), inputCmd(), callCompletion(2, "Till")) + .onlyBidiStream() .expectingOutput( - invokeMessage(Target.service("Empty", "emptyInput")), - outputMessage("Till"), + callCmd(1, 2, Target.service("Empty", "emptyInput")), + outputCmd("Till"), END_MESSAGE) .named("empty output"), testInvocation({ Empty() }, "emptyOutput") - .withInput( - startMessage(1), - inputMessage("Francesco"), - completionMessage(1).setValue(ByteString.EMPTY)) - .onlyUnbuffered() + .withInput(startMessage(1), inputCmd("Francesco"), callCompletion(2, Serde.VOID, null)) + .onlyBidiStream() .expectingOutput( - invokeMessage(Target.service("Empty", "emptyOutput"), "Francesco"), - outputMessage(), + callCmd(1, 2, Target.service("Empty", "emptyOutput"), "Francesco"), + outputCmd(), END_MESSAGE) .named("empty output"), testInvocation({ Empty() }, "emptyInputOutput") - .withInput( - startMessage(1), - inputMessage("Francesco"), - completionMessage(1).setValue(ByteString.EMPTY)) - .onlyUnbuffered() + .withInput(startMessage(1), inputCmd("Francesco"), callCompletion(2, Serde.VOID, null)) + .onlyBidiStream() .expectingOutput( - invokeMessage(Target.service("Empty", "emptyInputOutput")), - outputMessage(), + callCmd(1, 2, Target.service("Empty", "emptyInputOutput")), + outputCmd(), END_MESSAGE) .named("empty input and empty output"), testInvocation({ PrimitiveTypes() }, "primitiveOutput") - .withInput(startMessage(1), inputMessage(), completionMessage(1, TestSerdes.INT, 10)) - .onlyUnbuffered() + .withInput(startMessage(1), inputCmd(), callCompletion(2, TestSerdes.INT, 10)) + .onlyBidiStream() .expectingOutput( - invokeMessage( - Target.service("PrimitiveTypes", "primitiveOutput"), Serde.VOID, null), - outputMessage(TestSerdes.INT, 10), + callCmd( + 1, 2, Target.service("PrimitiveTypes", "primitiveOutput"), Serde.VOID, null), + outputCmd(TestSerdes.INT, 10), END_MESSAGE) .named("primitive output"), testInvocation({ PrimitiveTypes() }, "primitiveInput") - .withInput( - startMessage(1), inputMessage(10), completionMessage(1).setValue(ByteString.EMPTY)) - .onlyUnbuffered() + .withInput(startMessage(1), inputCmd(10), callCompletion(2, Serde.VOID, null)) + .onlyBidiStream() .expectingOutput( - invokeMessage( - Target.service("PrimitiveTypes", "primitiveInput"), TestSerdes.INT, 10), - outputMessage(), + callCmd( + 1, 2, Target.service("PrimitiveTypes", "primitiveInput"), TestSerdes.INT, 10), + outputCmd(), END_MESSAGE) .named("primitive input"), testInvocation({ RawInputOutput() }, "rawInput") .withInput( startMessage(1), - inputMessage("{{".toByteArray()), - completionMessage(1, KtSerdes.UNIT, null)) - .onlyUnbuffered() + inputCmd("{{".toByteArray()), + callCompletion(2, KotlinSerializationSerdeFactory.UNIT, Unit)) + .onlyBidiStream() .expectingOutput( - invokeMessage(Target.service("RawInputOutput", "rawInput"), "{{".toByteArray()), - outputMessage(), + callCmd(1, 2, Target.service("RawInputOutput", "rawInput"), "{{".toByteArray()), + outputCmd(), END_MESSAGE), testInvocation({ RawInputOutput() }, "rawInputWithCustomCt") .withInput( startMessage(1), - inputMessage("{{".toByteArray()), - completionMessage(1, KtSerdes.UNIT, null)) - .onlyUnbuffered() + inputCmd("{{".toByteArray()), + callCompletion(2, KotlinSerializationSerdeFactory.UNIT, Unit)) + .onlyBidiStream() .expectingOutput( - invokeMessage( - Target.service("RawInputOutput", "rawInputWithCustomCt"), "{{".toByteArray()), - outputMessage(), + callCmd( + 1, + 2, + Target.service("RawInputOutput", "rawInputWithCustomCt"), + "{{".toByteArray()), + outputCmd(), END_MESSAGE), testInvocation({ RawInputOutput() }, "rawOutput") .withInput( - startMessage(1), - inputMessage(), - completionMessage(1, Serde.RAW, "{{".toByteArray())) - .onlyUnbuffered() + startMessage(1), inputCmd(), callCompletion(2, Serde.RAW, "{{".toByteArray())) + .onlyBidiStream() .expectingOutput( - invokeMessage(Target.service("RawInputOutput", "rawOutput"), KtSerdes.UNIT, null), - outputMessage("{{".toByteArray()), + callCmd( + 1, + 2, + Target.service("RawInputOutput", "rawOutput"), + KotlinSerializationSerdeFactory.UNIT, + Unit), + outputCmd("{{".toByteArray()), END_MESSAGE), testInvocation({ RawInputOutput() }, "rawOutputWithCustomCT") .withInput( - startMessage(1), - inputMessage(), - completionMessage(1, Serde.RAW, "{{".toByteArray())) - .onlyUnbuffered() + startMessage(1), inputCmd(), callCompletion(2, Serde.RAW, "{{".toByteArray())) + .onlyBidiStream() .expectingOutput( - invokeMessage( - Target.service("RawInputOutput", "rawOutputWithCustomCT"), KtSerdes.UNIT, null), - outputMessage("{{".toByteArray()), + callCmd( + 1, + 2, + Target.service("RawInputOutput", "rawOutputWithCustomCT"), + KotlinSerializationSerdeFactory.UNIT, + Unit), + outputCmd("{{".toByteArray()), END_MESSAGE), testInvocation({ CornerCases() }, "returnNull") .withInput( startMessage(1, "mykey"), - inputMessage(KtSerdes.json().serialize(null)), - completionMessage(1, KtSerdes.json(), null)) - .onlyUnbuffered() + inputCmd(jsonSerde(), null), + callCompletion(2, jsonSerde(), null)) + .onlyBidiStream() .expectingOutput( - invokeMessage( + callCmd( + 1, + 2, Target.virtualObject("CodegenTestCornerCases", "mykey", "returnNull"), - KtSerdes.json(), + jsonSerde(), null), - outputMessage(KtSerdes.json(), null), + outputCmd(jsonSerde(), null), END_MESSAGE), ) } diff --git a/sdk-api-kotlin/src/test/kotlin/dev/restate/sdk/kotlin/EagerStateTest.kt b/sdk-core/src/test/kotlin/dev/restate/sdk/core/kotlinapi/EagerStateTest.kt similarity index 93% rename from sdk-api-kotlin/src/test/kotlin/dev/restate/sdk/kotlin/EagerStateTest.kt rename to sdk-core/src/test/kotlin/dev/restate/sdk/core/kotlinapi/EagerStateTest.kt index 5e9fb2f21..934a0c88a 100644 --- a/sdk-api-kotlin/src/test/kotlin/dev/restate/sdk/kotlin/EagerStateTest.kt +++ b/sdk-core/src/test/kotlin/dev/restate/sdk/core/kotlinapi/EagerStateTest.kt @@ -6,13 +6,13 @@ // You can find a copy of the license in file LICENSE in the root // directory of this repository or package, or at // https://github.com/restatedev/sdk-java/blob/main/LICENSE -package dev.restate.sdk.kotlin +package dev.restate.sdk.core.kotlinapi -import dev.restate.sdk.common.StateKey import dev.restate.sdk.core.EagerStateTestSuite import dev.restate.sdk.core.TestDefinitions.TestInvocationBuilder import dev.restate.sdk.core.TestSerdes -import dev.restate.sdk.kotlin.KotlinCoroutinesTests.Companion.testDefinitionForVirtualObject +import dev.restate.sdk.core.kotlinapi.KotlinAPITests.Companion.testDefinitionForVirtualObject +import dev.restate.sdk.types.StateKey import org.assertj.core.api.AssertionsForClassTypes.assertThat class EagerStateTest : EagerStateTestSuite() { diff --git a/sdk-api-kotlin/src/test/kotlin/dev/restate/sdk/kotlin/InvocationIdTest.kt b/sdk-core/src/test/kotlin/dev/restate/sdk/core/kotlinapi/InvocationIdTest.kt similarity index 84% rename from sdk-api-kotlin/src/test/kotlin/dev/restate/sdk/kotlin/InvocationIdTest.kt rename to sdk-core/src/test/kotlin/dev/restate/sdk/core/kotlinapi/InvocationIdTest.kt index ab80d4a63..a4e504f52 100644 --- a/sdk-api-kotlin/src/test/kotlin/dev/restate/sdk/kotlin/InvocationIdTest.kt +++ b/sdk-core/src/test/kotlin/dev/restate/sdk/core/kotlinapi/InvocationIdTest.kt @@ -6,11 +6,11 @@ // You can find a copy of the license in file LICENSE in the root // directory of this repository or package, or at // https://github.com/restatedev/sdk-java/blob/main/LICENSE -package dev.restate.sdk.kotlin +package dev.restate.sdk.core.kotlinapi import dev.restate.sdk.core.InvocationIdTestSuite import dev.restate.sdk.core.TestDefinitions.TestInvocationBuilder -import dev.restate.sdk.kotlin.KotlinCoroutinesTests.Companion.testDefinitionForService +import dev.restate.sdk.core.kotlinapi.KotlinAPITests.Companion.testDefinitionForService class InvocationIdTest : InvocationIdTestSuite() { diff --git a/sdk-api-kotlin/src/test/kotlin/dev/restate/sdk/kotlin/KotlinCoroutinesTests.kt b/sdk-core/src/test/kotlin/dev/restate/sdk/core/kotlinapi/KotlinAPITests.kt similarity index 55% rename from sdk-api-kotlin/src/test/kotlin/dev/restate/sdk/kotlin/KotlinCoroutinesTests.kt rename to sdk-core/src/test/kotlin/dev/restate/sdk/core/kotlinapi/KotlinAPITests.kt index 3a3d9b65e..a08008729 100644 --- a/sdk-api-kotlin/src/test/kotlin/dev/restate/sdk/kotlin/KotlinCoroutinesTests.kt +++ b/sdk-core/src/test/kotlin/dev/restate/sdk/core/kotlinapi/KotlinAPITests.kt @@ -6,28 +6,32 @@ // You can find a copy of the license in file LICENSE in the root // directory of this repository or package, or at // https://github.com/restatedev/sdk-java/blob/main/LICENSE -package dev.restate.sdk.kotlin +package dev.restate.sdk.core.kotlinapi -import dev.restate.sdk.common.HandlerType -import dev.restate.sdk.common.ServiceType -import dev.restate.sdk.common.syscalls.HandlerDefinition -import dev.restate.sdk.common.syscalls.HandlerSpecification -import dev.restate.sdk.common.syscalls.ServiceDefinition +import dev.restate.common.Request import dev.restate.sdk.core.* import dev.restate.sdk.core.TestDefinitions.TestExecutor import dev.restate.sdk.core.TestDefinitions.TestInvocationBuilder +import dev.restate.sdk.core.statemachine.ProtoUtils +import dev.restate.sdk.endpoint.definition.HandlerDefinition +import dev.restate.sdk.endpoint.definition.HandlerType +import dev.restate.sdk.endpoint.definition.ServiceDefinition +import dev.restate.sdk.endpoint.definition.ServiceType +import dev.restate.sdk.kotlin.* +import dev.restate.sdk.kotlin.serialization.* import java.util.stream.Stream import kotlinx.coroutines.Dispatchers -class KotlinCoroutinesTests : TestRunner() { +class KotlinAPITests : TestRunner() { override fun executors(): Stream { - return Stream.of(MockSingleThread.INSTANCE, MockMultiThreaded.INSTANCE) + return Stream.of(MockRequestResponse.INSTANCE, MockBidiStream.INSTANCE) } public override fun definitions(): Stream { return Stream.of( AwakeableIdTest(), - DeferredTest(), + AsyncResultTest(), + CallTest(), EagerStateTest(), StateTest(), InvocationIdTest(), @@ -37,7 +41,8 @@ class KotlinCoroutinesTests : TestRunner() { SleepTest(), StateMachineFailuresTest(), UserFailuresTest(), - RandomTest()) + RandomTest(), + CodegenTest()) } companion object { @@ -51,10 +56,14 @@ class KotlinCoroutinesTests : TestRunner() { ServiceType.SERVICE, listOf( HandlerDefinition.of( - HandlerSpecification.of( - "run", HandlerType.SHARED, KtSerdes.json(), KtSerdes.json()), - HandlerRunner.of(runner)))), - HandlerRunner.Options(Dispatchers.Unconfined), + "run", + HandlerType.SHARED, + jsonSerde(), + jsonSerde(), + HandlerRunner.of( + KotlinSerializationSerdeFactory(), + HandlerRunner.Options(Dispatchers.Unconfined), + runner)))), "run") } @@ -68,10 +77,14 @@ class KotlinCoroutinesTests : TestRunner() { ServiceType.VIRTUAL_OBJECT, listOf( HandlerDefinition.of( - HandlerSpecification.of( - "run", HandlerType.EXCLUSIVE, KtSerdes.json(), KtSerdes.json()), - HandlerRunner.of(runner)))), - HandlerRunner.Options(Dispatchers.Unconfined), + "run", + HandlerType.EXCLUSIVE, + jsonSerde(), + jsonSerde(), + HandlerRunner.of( + KotlinSerializationSerdeFactory(), + HandlerRunner.Options(Dispatchers.Unconfined), + runner)))), "run") } @@ -85,16 +98,21 @@ class KotlinCoroutinesTests : TestRunner() { ServiceType.WORKFLOW, listOf( HandlerDefinition.of( - HandlerSpecification.of( - "run", HandlerType.WORKFLOW, KtSerdes.json(), KtSerdes.json()), - HandlerRunner.of(runner)))), - HandlerRunner.Options(Dispatchers.Unconfined), + "run", + HandlerType.WORKFLOW, + jsonSerde(), + jsonSerde(), + HandlerRunner.of( + KotlinSerializationSerdeFactory(), + HandlerRunner.Options(Dispatchers.Unconfined), + runner)))), "run") } suspend fun callGreeterGreetService(ctx: Context, parameter: String): Awaitable { - return ctx.callAsync( - ProtoUtils.GREETER_SERVICE_TARGET, TestSerdes.STRING, TestSerdes.STRING, parameter) + return ctx.call( + Request.of( + ProtoUtils.GREETER_SERVICE_TARGET, TestSerdes.STRING, TestSerdes.STRING, parameter)) } } } diff --git a/sdk-api-kotlin-gen/src/test/kotlin/dev/restate/sdk/kotlin/MyMetaServiceAnnotation.kt b/sdk-core/src/test/kotlin/dev/restate/sdk/core/kotlinapi/MyMetaServiceAnnotation.kt similarity index 92% rename from sdk-api-kotlin-gen/src/test/kotlin/dev/restate/sdk/kotlin/MyMetaServiceAnnotation.kt rename to sdk-core/src/test/kotlin/dev/restate/sdk/core/kotlinapi/MyMetaServiceAnnotation.kt index f63875a0b..e8c6606a2 100644 --- a/sdk-api-kotlin-gen/src/test/kotlin/dev/restate/sdk/kotlin/MyMetaServiceAnnotation.kt +++ b/sdk-core/src/test/kotlin/dev/restate/sdk/core/kotlinapi/MyMetaServiceAnnotation.kt @@ -6,7 +6,7 @@ // You can find a copy of the license in file LICENSE in the root // directory of this repository or package, or at // https://github.com/restatedev/sdk-java/blob/main/LICENSE -package dev.restate.sdk.kotlin +package dev.restate.sdk.core.kotlinapi import dev.restate.sdk.annotation.Service diff --git a/sdk-api-kotlin/src/test/kotlin/dev/restate/sdk/kotlin/OnlyInputAndOutputTest.kt b/sdk-core/src/test/kotlin/dev/restate/sdk/core/kotlinapi/OnlyInputAndOutputTest.kt similarity index 84% rename from sdk-api-kotlin/src/test/kotlin/dev/restate/sdk/kotlin/OnlyInputAndOutputTest.kt rename to sdk-core/src/test/kotlin/dev/restate/sdk/core/kotlinapi/OnlyInputAndOutputTest.kt index 321c36901..d8bf351ab 100644 --- a/sdk-api-kotlin/src/test/kotlin/dev/restate/sdk/kotlin/OnlyInputAndOutputTest.kt +++ b/sdk-core/src/test/kotlin/dev/restate/sdk/core/kotlinapi/OnlyInputAndOutputTest.kt @@ -6,11 +6,11 @@ // You can find a copy of the license in file LICENSE in the root // directory of this repository or package, or at // https://github.com/restatedev/sdk-java/blob/main/LICENSE -package dev.restate.sdk.kotlin +package dev.restate.sdk.core.kotlinapi import dev.restate.sdk.core.OnlyInputAndOutputTestSuite import dev.restate.sdk.core.TestDefinitions.TestInvocationBuilder -import dev.restate.sdk.kotlin.KotlinCoroutinesTests.Companion.testDefinitionForService +import dev.restate.sdk.core.kotlinapi.KotlinAPITests.Companion.testDefinitionForService class OnlyInputAndOutputTest : OnlyInputAndOutputTestSuite() { diff --git a/sdk-api-kotlin/src/test/kotlin/dev/restate/sdk/kotlin/PromiseTest.kt b/sdk-core/src/test/kotlin/dev/restate/sdk/core/kotlinapi/PromiseTest.kt similarity index 73% rename from sdk-api-kotlin/src/test/kotlin/dev/restate/sdk/kotlin/PromiseTest.kt rename to sdk-core/src/test/kotlin/dev/restate/sdk/core/kotlinapi/PromiseTest.kt index b4f0cceb5..0470f6618 100644 --- a/sdk-api-kotlin/src/test/kotlin/dev/restate/sdk/kotlin/PromiseTest.kt +++ b/sdk-core/src/test/kotlin/dev/restate/sdk/core/kotlinapi/PromiseTest.kt @@ -6,17 +6,18 @@ // You can find a copy of the license in file LICENSE in the root // directory of this repository or package, or at // https://github.com/restatedev/sdk-java/blob/main/LICENSE -package dev.restate.sdk.kotlin +package dev.restate.sdk.core.kotlinapi -import dev.restate.sdk.common.TerminalException import dev.restate.sdk.core.PromiseTestSuite import dev.restate.sdk.core.TestDefinitions -import dev.restate.sdk.kotlin.KotlinCoroutinesTests.Companion.testDefinitionForWorkflow +import dev.restate.sdk.core.kotlinapi.KotlinAPITests.Companion.testDefinitionForWorkflow +import dev.restate.sdk.kotlin.* +import dev.restate.sdk.types.TerminalException class PromiseTest : PromiseTestSuite() { override fun awaitPromise(promiseKey: String): TestDefinitions.TestInvocationBuilder = testDefinitionForWorkflow("AwaitPromise") { ctx, _: Unit -> - ctx.promise(KtDurablePromiseKey.json(promiseKey)).awaitable().await() + ctx.promise(durablePromiseKey(promiseKey)).awaitable().await() } override fun awaitPeekPromise( @@ -24,14 +25,12 @@ class PromiseTest : PromiseTestSuite() { emptyCaseReturnValue: String ): TestDefinitions.TestInvocationBuilder = testDefinitionForWorkflow("AwaitPeekPromise") { ctx, _: Unit -> - ctx.promise(KtDurablePromiseKey.json(promiseKey)) - .peek() - .orElse(emptyCaseReturnValue) + ctx.promise(durablePromiseKey(promiseKey)).peek().orElse(emptyCaseReturnValue) } override fun awaitIsPromiseCompleted(promiseKey: String): TestDefinitions.TestInvocationBuilder = testDefinitionForWorkflow("IsCompletedPromise") { ctx, _: Unit -> - ctx.promise(KtDurablePromiseKey.json(promiseKey)).peek().isReady + ctx.promise(durablePromiseKey(promiseKey)).peek().isReady } override fun awaitResolvePromise( @@ -40,7 +39,7 @@ class PromiseTest : PromiseTestSuite() { ): TestDefinitions.TestInvocationBuilder = testDefinitionForWorkflow("ResolvePromise") { ctx, _: Unit -> try { - ctx.promiseHandle(KtDurablePromiseKey.json(promiseKey)).resolve(completionValue) + ctx.promiseHandle(durablePromiseKey(promiseKey)).resolve(completionValue) return@testDefinitionForWorkflow true } catch (e: TerminalException) { return@testDefinitionForWorkflow false @@ -53,7 +52,7 @@ class PromiseTest : PromiseTestSuite() { ): TestDefinitions.TestInvocationBuilder = testDefinitionForWorkflow("RejectPromise") { ctx, _: Unit -> try { - ctx.promiseHandle(KtDurablePromiseKey.json(promiseKey)).reject(rejectReason) + ctx.promiseHandle(durablePromiseKey(promiseKey)).reject(rejectReason) return@testDefinitionForWorkflow true } catch (e: TerminalException) { return@testDefinitionForWorkflow false diff --git a/sdk-api-kotlin/src/test/kotlin/dev/restate/sdk/kotlin/RandomTest.kt b/sdk-core/src/test/kotlin/dev/restate/sdk/core/kotlinapi/RandomTest.kt similarity index 66% rename from sdk-api-kotlin/src/test/kotlin/dev/restate/sdk/kotlin/RandomTest.kt rename to sdk-core/src/test/kotlin/dev/restate/sdk/core/kotlinapi/RandomTest.kt index 18a36caa1..99b7de8d8 100644 --- a/sdk-api-kotlin/src/test/kotlin/dev/restate/sdk/kotlin/RandomTest.kt +++ b/sdk-core/src/test/kotlin/dev/restate/sdk/core/kotlinapi/RandomTest.kt @@ -6,11 +6,11 @@ // You can find a copy of the license in file LICENSE in the root // directory of this repository or package, or at // https://github.com/restatedev/sdk-java/blob/main/LICENSE -package dev.restate.sdk.kotlin +package dev.restate.sdk.core.kotlinapi import dev.restate.sdk.core.RandomTestSuite import dev.restate.sdk.core.TestDefinitions.TestInvocationBuilder -import dev.restate.sdk.kotlin.KotlinCoroutinesTests.Companion.testDefinitionForService +import dev.restate.sdk.core.kotlinapi.KotlinAPITests.Companion.testDefinitionForService import kotlin.random.Random class RandomTest : RandomTestSuite() { @@ -19,12 +19,6 @@ class RandomTest : RandomTestSuite() { ctx.random().nextInt() } - override fun randomInsideSideEffect(): TestInvocationBuilder = - testDefinitionForService("RandomInsideSideEffect") { ctx, _: Unit -> - ctx.runBlock { ctx.random().nextInt() } - throw IllegalStateException("This should not unreachable") - } - override fun getExpectedInt(seed: Long): Int { return Random(seed).nextInt() } diff --git a/sdk-api-kotlin/src/test/kotlin/dev/restate/sdk/kotlin/SideEffectTest.kt b/sdk-core/src/test/kotlin/dev/restate/sdk/core/kotlinapi/SideEffectTest.kt similarity index 62% rename from sdk-api-kotlin/src/test/kotlin/dev/restate/sdk/kotlin/SideEffectTest.kt rename to sdk-core/src/test/kotlin/dev/restate/sdk/core/kotlinapi/SideEffectTest.kt index 27a2eae01..3511f8035 100644 --- a/sdk-api-kotlin/src/test/kotlin/dev/restate/sdk/kotlin/SideEffectTest.kt +++ b/sdk-core/src/test/kotlin/dev/restate/sdk/core/kotlinapi/SideEffectTest.kt @@ -6,18 +6,19 @@ // You can find a copy of the license in file LICENSE in the root // directory of this repository or package, or at // https://github.com/restatedev/sdk-java/blob/main/LICENSE -package dev.restate.sdk.kotlin +package dev.restate.sdk.core.kotlinapi -import dev.restate.sdk.common.HandlerType -import dev.restate.sdk.common.ServiceType -import dev.restate.sdk.common.syscalls.HandlerDefinition -import dev.restate.sdk.common.syscalls.HandlerSpecification -import dev.restate.sdk.common.syscalls.ServiceDefinition -import dev.restate.sdk.core.ProtoUtils.GREETER_SERVICE_TARGET import dev.restate.sdk.core.SideEffectTestSuite import dev.restate.sdk.core.TestDefinitions import dev.restate.sdk.core.TestDefinitions.TestInvocationBuilder -import dev.restate.sdk.kotlin.KotlinCoroutinesTests.Companion.testDefinitionForService +import dev.restate.sdk.core.kotlinapi.KotlinAPITests.Companion.testDefinitionForService +import dev.restate.sdk.endpoint.definition.HandlerDefinition +import dev.restate.sdk.endpoint.definition.HandlerType +import dev.restate.sdk.endpoint.definition.ServiceDefinition +import dev.restate.sdk.endpoint.definition.ServiceType +import dev.restate.sdk.kotlin.* +import dev.restate.sdk.kotlin.serialization.* +import dev.restate.sdk.types.RetryPolicy import java.util.* import kotlin.coroutines.coroutineContext import kotlin.time.toKotlinDuration @@ -52,35 +53,32 @@ class SideEffectTest : SideEffectTestSuite() { ServiceType.SERVICE, listOf( HandlerDefinition.of( - HandlerSpecification.of( - "run", HandlerType.SHARED, KtSerdes.UNIT, KtSerdes.json()), - HandlerRunner.of { ctx: Context, _: Unit -> - val sideEffectCoroutine = - ctx.runBlock { coroutineContext[CoroutineName]!!.name } - check(sideEffectCoroutine == "CheckContextSwitchingTestCoroutine") { - "Side effect thread is not running within the same coroutine context of the handler method: $sideEffectCoroutine" - } - "Hello" - }))), - HandlerRunner.Options( - Dispatchers.Unconfined + CoroutineName("CheckContextSwitchingTestCoroutine")), + "run", + HandlerType.SHARED, + jsonSerde(), + jsonSerde(), + HandlerRunner.of( + KotlinSerializationSerdeFactory(), + HandlerRunner.Options( + Dispatchers.Unconfined + + CoroutineName("CheckContextSwitchingTestCoroutine"))) { + ctx: Context, + _: Unit -> + val sideEffectCoroutine = + ctx.runBlock { coroutineContext[CoroutineName]!!.name } + check(sideEffectCoroutine == "CheckContextSwitchingTestCoroutine") { + "Side effect thread is not running within the same coroutine context of the handler method: $sideEffectCoroutine" + } + "Hello" + }))), "run") - override fun sideEffectGuard(): TestInvocationBuilder = - testDefinitionForService("SideEffectGuard") { ctx, _: Unit -> - ctx.runBlock { ctx.send(GREETER_SERVICE_TARGET, KtSerdes.json(), "something") } - throw IllegalStateException("This point should not be reached") - } - override fun failingSideEffect(name: String, reason: String): TestInvocationBuilder = testDefinitionForService("FailingSideEffect") { ctx, _: Unit -> ctx.runBlock(name) { throw IllegalStateException(reason) } } - override fun failingSideEffectWithRetryPolicy( - reason: String, - retryPolicy: dev.restate.sdk.common.RetryPolicy? - ) = + override fun failingSideEffectWithRetryPolicy(reason: String, retryPolicy: RetryPolicy?) = testDefinitionForService("FailingSideEffectWithRetryPolicy") { ctx, _: Unit -> ctx.runBlock( retryPolicy = diff --git a/sdk-api-kotlin/src/test/kotlin/dev/restate/sdk/kotlin/SleepTest.kt b/sdk-core/src/test/kotlin/dev/restate/sdk/core/kotlinapi/SleepTest.kt similarity index 86% rename from sdk-api-kotlin/src/test/kotlin/dev/restate/sdk/kotlin/SleepTest.kt rename to sdk-core/src/test/kotlin/dev/restate/sdk/core/kotlinapi/SleepTest.kt index 4881b1b51..1bc1b8bbb 100644 --- a/sdk-api-kotlin/src/test/kotlin/dev/restate/sdk/kotlin/SleepTest.kt +++ b/sdk-core/src/test/kotlin/dev/restate/sdk/core/kotlinapi/SleepTest.kt @@ -6,11 +6,12 @@ // You can find a copy of the license in file LICENSE in the root // directory of this repository or package, or at // https://github.com/restatedev/sdk-java/blob/main/LICENSE -package dev.restate.sdk.kotlin +package dev.restate.sdk.core.kotlinapi import dev.restate.sdk.core.SleepTestSuite import dev.restate.sdk.core.TestDefinitions -import dev.restate.sdk.kotlin.KotlinCoroutinesTests.Companion.testDefinitionForService +import dev.restate.sdk.core.kotlinapi.KotlinAPITests.Companion.testDefinitionForService +import dev.restate.sdk.kotlin.* import kotlin.time.Duration.Companion.milliseconds class SleepTest : SleepTestSuite() { diff --git a/sdk-api-kotlin/src/test/kotlin/dev/restate/sdk/kotlin/StateMachineFailuresTest.kt b/sdk-core/src/test/kotlin/dev/restate/sdk/core/kotlinapi/StateMachineFailuresTest.kt similarity index 75% rename from sdk-api-kotlin/src/test/kotlin/dev/restate/sdk/kotlin/StateMachineFailuresTest.kt rename to sdk-core/src/test/kotlin/dev/restate/sdk/core/kotlinapi/StateMachineFailuresTest.kt index b32e459cb..c324abe6c 100644 --- a/sdk-api-kotlin/src/test/kotlin/dev/restate/sdk/kotlin/StateMachineFailuresTest.kt +++ b/sdk-core/src/test/kotlin/dev/restate/sdk/core/kotlinapi/StateMachineFailuresTest.kt @@ -6,15 +6,16 @@ // You can find a copy of the license in file LICENSE in the root // directory of this repository or package, or at // https://github.com/restatedev/sdk-java/blob/main/LICENSE -package dev.restate.sdk.kotlin +package dev.restate.sdk.core.kotlinapi -import dev.restate.sdk.common.Serde -import dev.restate.sdk.common.StateKey -import dev.restate.sdk.common.TerminalException import dev.restate.sdk.core.StateMachineFailuresTestSuite import dev.restate.sdk.core.TestDefinitions.TestInvocationBuilder -import dev.restate.sdk.kotlin.KotlinCoroutinesTests.Companion.testDefinitionForService -import dev.restate.sdk.kotlin.KotlinCoroutinesTests.Companion.testDefinitionForVirtualObject +import dev.restate.sdk.core.kotlinapi.KotlinAPITests.Companion.testDefinitionForService +import dev.restate.sdk.core.kotlinapi.KotlinAPITests.Companion.testDefinitionForVirtualObject +import dev.restate.sdk.types.AbortedExecutionException +import dev.restate.sdk.types.StateKey +import dev.restate.sdk.types.TerminalException +import dev.restate.serde.Serde import java.nio.charset.StandardCharsets import java.util.concurrent.atomic.AtomicInteger import kotlinx.coroutines.CancellationException @@ -35,6 +36,10 @@ class StateMachineFailuresTest : StateMachineFailuresTestSuite() { try { ctx.get(STATE) } catch (e: Throwable) { + // A user should never catch Throwable!!! + if (AbortedExecutionException.INSTANCE == e) { + throw e + } // A user should never catch Throwable!!! if (e !is CancellationException && e !is TerminalException) { nonTerminalExceptionsSeen.addAndGet(1) diff --git a/sdk-api-kotlin/src/test/kotlin/dev/restate/sdk/kotlin/StateTest.kt b/sdk-core/src/test/kotlin/dev/restate/sdk/core/kotlinapi/StateTest.kt similarity index 71% rename from sdk-api-kotlin/src/test/kotlin/dev/restate/sdk/kotlin/StateTest.kt rename to sdk-core/src/test/kotlin/dev/restate/sdk/core/kotlinapi/StateTest.kt index 1defffd09..2f26a993f 100644 --- a/sdk-api-kotlin/src/test/kotlin/dev/restate/sdk/kotlin/StateTest.kt +++ b/sdk-core/src/test/kotlin/dev/restate/sdk/core/kotlinapi/StateTest.kt @@ -6,14 +6,16 @@ // You can find a copy of the license in file LICENSE in the root // directory of this repository or package, or at // https://github.com/restatedev/sdk-java/blob/main/LICENSE -package dev.restate.sdk.kotlin +package dev.restate.sdk.core.kotlinapi -import dev.restate.sdk.common.StateKey -import dev.restate.sdk.core.ProtoUtils.* import dev.restate.sdk.core.StateTestSuite import dev.restate.sdk.core.TestDefinitions.* import dev.restate.sdk.core.TestSerdes -import dev.restate.sdk.kotlin.KotlinCoroutinesTests.Companion.testDefinitionForVirtualObject +import dev.restate.sdk.core.kotlinapi.KotlinAPITests.Companion.testDefinitionForVirtualObject +import dev.restate.sdk.core.statemachine.ProtoUtils.* +import dev.restate.sdk.kotlin.* +import dev.restate.sdk.kotlin.serialization.jsonSerde +import dev.restate.sdk.types.StateKey import java.util.stream.Stream import kotlinx.serialization.Serializable @@ -41,7 +43,7 @@ class StateTest : StateTestSuite() { @Serializable data class Data(var a: Int, val b: String) private companion object { - val DATA: StateKey = StateKey.of("STATE", KtSerdes.json()) + val DATA = stateKey("STATE") } private fun getAndSetStateUsingKtSerdes(): TestInvocationBuilder = @@ -60,19 +62,19 @@ class StateTest : StateTestSuite() { getAndSetStateUsingKtSerdes() .withInput( startMessage(3), - inputMessage(), - getStateMessage("STATE", KtSerdes.json(), Data(1, "Till")), - setStateMessage("STATE", KtSerdes.json(), Data(2, "Till"))) - .expectingOutput(outputMessage("Hello " + Data(2, "Till")), END_MESSAGE) + inputCmd(), + getEagerStateCmd("STATE", jsonSerde(), Data(1, "Till")), + setStateCmd("STATE", jsonSerde(), Data(2, "Till"))) + .expectingOutput(outputCmd("Hello " + Data(2, "Till")), END_MESSAGE) .named("With GetState and SetState"), getAndSetStateUsingKtSerdes() .withInput( startMessage(2), - inputMessage(), - getStateMessage("STATE", KtSerdes.json(), Data(1, "Till"))) + inputCmd(), + getEagerStateCmd("STATE", jsonSerde(), Data(1, "Till"))) .expectingOutput( - setStateMessage("STATE", KtSerdes.json(), Data(2, "Till")), - outputMessage("Hello " + Data(2, "Till")), + setStateCmd("STATE", jsonSerde(), Data(2, "Till")), + outputCmd("Hello " + Data(2, "Till")), END_MESSAGE) .named("With GetState already completed"), )) diff --git a/sdk-api-kotlin/src/test/kotlin/dev/restate/sdk/kotlin/UserFailuresTest.kt b/sdk-core/src/test/kotlin/dev/restate/sdk/core/kotlinapi/UserFailuresTest.kt similarity index 90% rename from sdk-api-kotlin/src/test/kotlin/dev/restate/sdk/kotlin/UserFailuresTest.kt rename to sdk-core/src/test/kotlin/dev/restate/sdk/core/kotlinapi/UserFailuresTest.kt index 88c18eb56..d740d08dd 100644 --- a/sdk-api-kotlin/src/test/kotlin/dev/restate/sdk/kotlin/UserFailuresTest.kt +++ b/sdk-core/src/test/kotlin/dev/restate/sdk/core/kotlinapi/UserFailuresTest.kt @@ -6,12 +6,13 @@ // You can find a copy of the license in file LICENSE in the root // directory of this repository or package, or at // https://github.com/restatedev/sdk-java/blob/main/LICENSE -package dev.restate.sdk.kotlin +package dev.restate.sdk.core.kotlinapi -import dev.restate.sdk.common.TerminalException import dev.restate.sdk.core.TestDefinitions.TestInvocationBuilder import dev.restate.sdk.core.UserFailuresTestSuite -import dev.restate.sdk.kotlin.KotlinCoroutinesTests.Companion.testDefinitionForService +import dev.restate.sdk.core.kotlinapi.KotlinAPITests.Companion.testDefinitionForService +import dev.restate.sdk.kotlin.* +import dev.restate.sdk.types.TerminalException import java.util.concurrent.atomic.AtomicInteger import kotlin.coroutines.cancellation.CancellationException diff --git a/sdk-core/src/test/kotlin/dev/restate/sdk/core/vertx/RestateHttpServerTest.kt b/sdk-core/src/test/kotlin/dev/restate/sdk/core/vertx/RestateHttpServerTest.kt new file mode 100644 index 000000000..11597860b --- /dev/null +++ b/sdk-core/src/test/kotlin/dev/restate/sdk/core/vertx/RestateHttpServerTest.kt @@ -0,0 +1,156 @@ +// Copyright (c) 2023 - Restate Software, Inc., Restate GmbH +// +// This file is part of the Restate Java SDK, +// which is released under the MIT license. +// +// You can find a copy of the license in file LICENSE in the root +// directory of this repository or package, or at +// https://github.com/restatedev/sdk-java/blob/main/LICENSE +package dev.restate.sdk.core.vertx + +import com.fasterxml.jackson.databind.ObjectMapper +import com.google.protobuf.MessageLite +import dev.restate.sdk.core.generated.manifest.EndpointManifestSchema +import dev.restate.sdk.core.statemachine.ProtoUtils.* +import dev.restate.sdk.endpoint.definition.HandlerDefinition +import dev.restate.sdk.endpoint.definition.HandlerType +import dev.restate.sdk.endpoint.definition.ServiceDefinition +import dev.restate.sdk.endpoint.definition.ServiceType +import dev.restate.sdk.http.vertx.RestateHttpServer +import dev.restate.sdk.kotlin.HandlerRunner +import dev.restate.sdk.kotlin.ObjectContext +import dev.restate.sdk.kotlin.endpoint.endpoint +import dev.restate.sdk.kotlin.serialization.KotlinSerializationSerdeFactory +import dev.restate.sdk.kotlin.serialization.jsonSerde +import dev.restate.sdk.kotlin.stateKey +import io.netty.buffer.Unpooled +import io.netty.handler.codec.http.HttpResponseStatus +import io.vertx.core.Vertx +import io.vertx.core.buffer.Buffer +import io.vertx.core.http.* +import io.vertx.junit5.VertxExtension +import io.vertx.kotlin.coroutines.coAwait +import io.vertx.kotlin.coroutines.dispatcher +import kotlin.time.Duration.Companion.seconds +import kotlinx.coroutines.runBlocking +import org.apache.logging.log4j.LogManager +import org.assertj.core.api.Assertions.assertThat +import org.junit.jupiter.api.Test +import org.junit.jupiter.api.extension.ExtendWith +import org.junit.jupiter.api.parallel.Isolated + +@Isolated +@ExtendWith(VertxExtension::class) +internal class RestateHttpServerTest { + + companion object { + val HTTP_CLIENT_OPTIONS: HttpClientOptions = + HttpClientOptions() + // Set prior knowledge + .setProtocolVersion(HttpVersion.HTTP_2) + .setHttp2ClearTextUpgrade(false) + + private val LOG = LogManager.getLogger() + private val COUNTER = stateKey("counter") + + const val GREETER_NAME = "Greeter" + + fun greeter(): ServiceDefinition = + ServiceDefinition.of( + GREETER_NAME, + ServiceType.VIRTUAL_OBJECT, + listOf( + HandlerDefinition.of( + "greet", + HandlerType.EXCLUSIVE, + jsonSerde(), + jsonSerde(), + HandlerRunner.of(KotlinSerializationSerdeFactory()) { + ctx: ObjectContext, + request: String -> + LOG.info("Greet invoked!") + + val count = (ctx.get(COUNTER) ?: 0) + 1 + ctx.set(COUNTER, count) + + ctx.sleep(1.seconds) + + "Hello $request. Count: $count" + }))) + } + + @Test + fun return404(vertx: Vertx): Unit = + runBlocking(vertx.dispatcher()) { + val endpointPort: Int = + RestateHttpServer.fromEndpoint( + vertx, endpoint { bind(greeter()) }, HttpServerOptions().setPort(0)) + .listen() + .coAwait() + .actualPort() + + val client = vertx.createHttpClient(HTTP_CLIENT_OPTIONS) + + val request = + client + .request( + HttpMethod.POST, + endpointPort, + "localhost", + "/invoke/$GREETER_NAME/unknownMethod") + .coAwait() + + // Prepare request header + request + .setChunked(true) + .putHeader(HttpHeaders.CONTENT_TYPE, serviceProtocolContentTypeHeader()) + .putHeader(HttpHeaders.ACCEPT, serviceProtocolContentTypeHeader()) + request.write(encode(startMessage(0).build())) + + val response = request.response().coAwait() + + // Response status should be 404 + assertThat(response.statusCode()).isEqualTo(HttpResponseStatus.NOT_FOUND.code()) + + response.end().coAwait() + } + + @Test + fun serviceDiscovery(vertx: Vertx): Unit = + runBlocking(vertx.dispatcher()) { + val endpointPort: Int = + RestateHttpServer.fromEndpoint( + vertx, endpoint { bind(greeter()) }, HttpServerOptions().setPort(0)) + .listen() + .coAwait() + .actualPort() + + val client = vertx.createHttpClient(HTTP_CLIENT_OPTIONS) + + // Send request + val request = + client.request(HttpMethod.GET, endpointPort, "localhost", "/discover").coAwait() + request.putHeader(HttpHeaders.ACCEPT, serviceProtocolDiscoveryContentTypeHeader()) + request.end().coAwait() + + // Assert response + val response = request.response().coAwait() + + // Response status and content type header + assertThat(response.statusCode()).isEqualTo(HttpResponseStatus.OK.code()) + assertThat(response.getHeader(HttpHeaders.CONTENT_TYPE)) + .isEqualTo(serviceProtocolDiscoveryContentTypeHeader()) + + // Parse response + val responseBody = response.body().coAwait() + // Compute response and write it back + val discoveryResponse: EndpointManifestSchema = + ObjectMapper().readValue(responseBody.bytes, EndpointManifestSchema::class.java) + + assertThat(discoveryResponse.services).map { it.name }.containsOnly(GREETER_NAME) + } + + private fun encode(msg: MessageLite): Buffer { + return Buffer.buffer(Unpooled.wrappedBuffer(encodeMessageToByteBuffer(msg))) + } +} diff --git a/sdk-http-vertx/src/test/kotlin/dev/restate/sdk/http/vertx/HttpVertxTestExecutor.kt b/sdk-core/src/test/kotlin/dev/restate/sdk/core/vertx/RestateHttpServerTestExecutor.kt similarity index 79% rename from sdk-http-vertx/src/test/kotlin/dev/restate/sdk/http/vertx/HttpVertxTestExecutor.kt rename to sdk-core/src/test/kotlin/dev/restate/sdk/core/vertx/RestateHttpServerTestExecutor.kt index 08203fa9f..875856bb3 100644 --- a/sdk-http-vertx/src/test/kotlin/dev/restate/sdk/http/vertx/HttpVertxTestExecutor.kt +++ b/sdk-core/src/test/kotlin/dev/restate/sdk/core/vertx/RestateHttpServerTestExecutor.kt @@ -6,12 +6,14 @@ // You can find a copy of the license in file LICENSE in the root // directory of this repository or package, or at // https://github.com/restatedev/sdk-java/blob/main/LICENSE -package dev.restate.sdk.http.vertx +package dev.restate.sdk.core.vertx -import dev.restate.sdk.common.syscalls.ServiceDefinition -import dev.restate.sdk.core.ProtoUtils import dev.restate.sdk.core.TestDefinitions.TestDefinition import dev.restate.sdk.core.TestDefinitions.TestExecutor +import dev.restate.sdk.core.statemachine.ProtoUtils +import dev.restate.sdk.endpoint.Endpoint +import dev.restate.sdk.endpoint.definition.ServiceDefinition +import dev.restate.sdk.http.vertx.RestateHttpServer import io.netty.buffer.Unpooled import io.vertx.core.Vertx import io.vertx.core.buffer.Buffer @@ -28,7 +30,7 @@ import kotlinx.coroutines.launch import kotlinx.coroutines.runBlocking import kotlinx.coroutines.yield -class HttpVertxTestExecutor(private val vertx: Vertx) : TestExecutor { +class RestateHttpServerTestExecutor(private val vertx: Vertx) : TestExecutor { override fun buffered(): Boolean { return false } @@ -36,20 +38,20 @@ class HttpVertxTestExecutor(private val vertx: Vertx) : TestExecutor { override fun executeTest(definition: TestDefinition) { runBlocking(vertx.dispatcher()) { // Build server - val serverBuilder = - RestateHttpEndpointBuilder.builder(vertx) - .withOptions(HttpServerOptions().setPort(0)) - .bind( - definition.serviceDefinition as ServiceDefinition, definition.serviceOptions) + val endpointBuilder = + Endpoint.builder() + .bind(definition.serviceDefinition as ServiceDefinition, definition.serviceOptions) if (definition.isEnablePreviewContext()) { - serverBuilder.enablePreviewContext() + endpointBuilder.enablePreviewContext() } // Start server - val server = serverBuilder.build() + val server = + RestateHttpServer.fromEndpoint( + vertx, endpointBuilder.build(), HttpServerOptions().setPort(0)) server.listen().coAwait() - val client = vertx.createHttpClient(RestateHttpEndpointTest.HTTP_CLIENT_OPTIONS) + val client = vertx.createHttpClient(RestateHttpServerTest.Companion.HTTP_CLIENT_OPTIONS) val request = client diff --git a/sdk-http-vertx/src/test/kotlin/dev/restate/sdk/http/vertx/HttpVertxTests.kt b/sdk-core/src/test/kotlin/dev/restate/sdk/core/vertx/RestateHttpServerTests.kt similarity index 61% rename from sdk-http-vertx/src/test/kotlin/dev/restate/sdk/http/vertx/HttpVertxTests.kt rename to sdk-core/src/test/kotlin/dev/restate/sdk/core/vertx/RestateHttpServerTests.kt index ebdf29332..c030596cc 100644 --- a/sdk-http-vertx/src/test/kotlin/dev/restate/sdk/http/vertx/HttpVertxTests.kt +++ b/sdk-core/src/test/kotlin/dev/restate/sdk/core/vertx/RestateHttpServerTests.kt @@ -6,20 +6,19 @@ // You can find a copy of the license in file LICENSE in the root // directory of this repository or package, or at // https://github.com/restatedev/sdk-java/blob/main/LICENSE -package dev.restate.sdk.http.vertx +package dev.restate.sdk.core.vertx -import dev.restate.sdk.JavaBlockingTests -import dev.restate.sdk.JavaCodegenTests import dev.restate.sdk.core.TestDefinitions.TestExecutor import dev.restate.sdk.core.TestDefinitions.TestSuite -import dev.restate.sdk.kotlin.KotlinCoroutinesTests -import dev.restate.sdk.kotlin.KtCodegenTests +import dev.restate.sdk.core.TestRunner +import dev.restate.sdk.core.javaapi.JavaAPITests +import dev.restate.sdk.core.kotlinapi.KotlinAPITests import io.vertx.core.Vertx import java.util.stream.Stream import org.junit.jupiter.api.AfterAll import org.junit.jupiter.api.BeforeAll -class HttpVertxTests : dev.restate.sdk.core.TestRunner() { +class RestateHttpServerTests : TestRunner() { lateinit var vertx: Vertx @@ -34,14 +33,12 @@ class HttpVertxTests : dev.restate.sdk.core.TestRunner() { } override fun executors(): Stream { - return Stream.of(HttpVertxTestExecutor(vertx)) + return Stream.of(RestateHttpServerTestExecutor(vertx)) } override fun definitions(): Stream { return Stream.concat( - Stream.concat( - Stream.concat(JavaBlockingTests().definitions(), JavaCodegenTests().definitions()), - Stream.concat(KotlinCoroutinesTests().definitions(), KtCodegenTests().definitions())), - Stream.of(VertxExecutorsTest())) + Stream.concat(JavaAPITests().definitions(), KotlinAPITests().definitions()), + Stream.of(ThreadTrampoliningTestSuite())) } } diff --git a/sdk-http-vertx/src/test/kotlin/dev/restate/sdk/http/vertx/VertxExecutorsTest.kt b/sdk-core/src/test/kotlin/dev/restate/sdk/core/vertx/ThreadTrampoliningTestSuite.kt similarity index 54% rename from sdk-http-vertx/src/test/kotlin/dev/restate/sdk/http/vertx/VertxExecutorsTest.kt rename to sdk-core/src/test/kotlin/dev/restate/sdk/core/vertx/ThreadTrampoliningTestSuite.kt index 7924e933a..bdb7f0c49 100644 --- a/sdk-http-vertx/src/test/kotlin/dev/restate/sdk/http/vertx/VertxExecutorsTest.kt +++ b/sdk-core/src/test/kotlin/dev/restate/sdk/core/vertx/ThreadTrampoliningTestSuite.kt @@ -6,23 +6,21 @@ // You can find a copy of the license in file LICENSE in the root // directory of this repository or package, or at // https://github.com/restatedev/sdk-java/blob/main/LICENSE -package dev.restate.sdk.http.vertx +package dev.restate.sdk.core.vertx -import com.google.protobuf.ByteString -import dev.restate.generated.service.protocol.Protocol -import dev.restate.sdk.HandlerRunner -import dev.restate.sdk.common.HandlerType -import dev.restate.sdk.common.Serde -import dev.restate.sdk.common.ServiceType -import dev.restate.sdk.common.syscalls.HandlerDefinition -import dev.restate.sdk.common.syscalls.HandlerSpecification -import dev.restate.sdk.common.syscalls.ServiceDefinition -import dev.restate.sdk.core.ProtoUtils.* import dev.restate.sdk.core.TestDefinitions import dev.restate.sdk.core.TestDefinitions.testInvocation +import dev.restate.sdk.core.statemachine.ProtoUtils.* +import dev.restate.sdk.endpoint.definition.HandlerDefinition +import dev.restate.sdk.endpoint.definition.HandlerType +import dev.restate.sdk.endpoint.definition.ServiceDefinition +import dev.restate.sdk.endpoint.definition.ServiceType import dev.restate.sdk.kotlin.Context -import dev.restate.sdk.kotlin.KtSerdes +import dev.restate.sdk.kotlin.HandlerRunner import dev.restate.sdk.kotlin.runBlock +import dev.restate.sdk.kotlin.serialization.KotlinSerializationSerdeFactory +import dev.restate.sdk.serde.jackson.JacksonSerdeFactory +import dev.restate.serde.Serde import io.vertx.core.Vertx import java.util.stream.Stream import kotlin.coroutines.coroutineContext @@ -30,7 +28,7 @@ import kotlinx.coroutines.CoroutineName import kotlinx.coroutines.Dispatchers import org.apache.logging.log4j.LogManager -class VertxExecutorsTest : TestDefinitions.TestSuite { +class ThreadTrampoliningTestSuite : TestDefinitions.TestSuite { private val nonBlockingCoroutineName = CoroutineName("CheckContextSwitchingTestCoroutine") @@ -38,15 +36,12 @@ class VertxExecutorsTest : TestDefinitions.TestSuite { private val LOG = LogManager.getLogger() } - private suspend fun checkNonBlockingComponentTrampolineExecutor( - ctx: dev.restate.sdk.kotlin.Context - ) { + private suspend fun checkNonBlockingComponentTrampolineExecutor(ctx: Context) { LOG.info("I am on the thread I am before executing side effect") check(Vertx.currentContext() == null) check(coroutineContext[CoroutineName] == nonBlockingCoroutineName) ctx.runBlock { LOG.info("I am on the thread I am when executing side effect") - check(coroutineContext[CoroutineName] == nonBlockingCoroutineName) check(Vertx.currentContext() == null) } LOG.info("I am on the thread I am after executing side effect") @@ -60,10 +55,7 @@ class VertxExecutorsTest : TestDefinitions.TestSuite { ): Void? { val id = Thread.currentThread().id check(Vertx.currentContext() == null) - ctx.run { - check(Thread.currentThread().id == id) - check(Vertx.currentContext() == null) - } + ctx.run { check(Vertx.currentContext() == null) } check(Thread.currentThread().id == id) check(Vertx.currentContext() == null) return null @@ -77,41 +69,41 @@ class VertxExecutorsTest : TestDefinitions.TestSuite { ServiceType.SERVICE, listOf( HandlerDefinition.of( - HandlerSpecification.of( - "do", HandlerType.SHARED, KtSerdes.UNIT, KtSerdes.UNIT), - dev.restate.sdk.kotlin.HandlerRunner.of { ctx: Context, _: Unit -> - checkNonBlockingComponentTrampolineExecutor(ctx) - }))), - dev.restate.sdk.kotlin.HandlerRunner.Options( - Dispatchers.Default + nonBlockingCoroutineName), + "do", + HandlerType.SHARED, + KotlinSerializationSerdeFactory.UNIT, + KotlinSerializationSerdeFactory.UNIT, + HandlerRunner.of( + KotlinSerializationSerdeFactory(), + HandlerRunner.Options( + Dispatchers.Default + nonBlockingCoroutineName)) { + ctx: Context, + _: Unit -> + checkNonBlockingComponentTrampolineExecutor(ctx) + }))), "do") - .withInput(startMessage(1), inputMessage(), ackMessage(1)) - .onlyUnbuffered() + .withInput(startMessage(1), inputCmd()) + .onlyBidiStream() .expectingOutput( - Protocol.RunEntryMessage.newBuilder().setValue(ByteString.EMPTY), - outputMessage(), - END_MESSAGE), + runCmd(1), proposeRunCompletion(1, Serde.VOID, null), suspensionMessage(1)), testInvocation( ServiceDefinition.of( "CheckBlockingComponentTrampolineExecutor", ServiceType.SERVICE, listOf( HandlerDefinition.of( - HandlerSpecification.of( - "do", - HandlerType.SHARED, - Serde.VOID, - Serde.VOID, - ), + "do", + HandlerType.SHARED, + Serde.VOID, + Serde.VOID, dev.restate.sdk.HandlerRunner.of( - this::checkBlockingComponentTrampolineExecutor)))), - HandlerRunner.Options.DEFAULT, + this::checkBlockingComponentTrampolineExecutor, + JacksonSerdeFactory(), + null)))), "do") - .withInput(startMessage(1), inputMessage(), ackMessage(1)) - .onlyUnbuffered() + .withInput(startMessage(1), inputCmd()) + .onlyBidiStream() .expectingOutput( - Protocol.RunEntryMessage.newBuilder().setValue(ByteString.EMPTY), - outputMessage(), - END_MESSAGE)) + runCmd(1), proposeRunCompletion(1, Serde.VOID, null), suspensionMessage(1))) } } diff --git a/sdk-http-vertx/build.gradle.kts b/sdk-http-vertx/build.gradle.kts index 2bc54b3dd..c02fe1e7b 100644 --- a/sdk-http-vertx/build.gradle.kts +++ b/sdk-http-vertx/build.gradle.kts @@ -20,26 +20,4 @@ dependencies { implementation(libs.opentelemetry.api) implementation(libs.log4j.api) implementation(libs.reactiverse.contextual.logging) - - // Testing - testImplementation(project(":sdk-api")) - testImplementation(project(":sdk-serde-jackson")) - testAnnotationProcessor(project(":sdk-api-gen")) - testImplementation(project(":sdk-api-kotlin")) - testImplementation(project(":sdk-core", "testArchive")) - testImplementation(project(":sdk-api", "testArchive")) - testImplementation(project(":sdk-api-gen", "testArchive")) - testImplementation(project(":sdk-api-kotlin", "testArchive")) - testImplementation(project(":sdk-api-kotlin-gen", "testArchive")) - testImplementation(libs.junit.jupiter) - testImplementation(libs.assertj) - testImplementation(libs.vertx.junit5) - testImplementation(libs.mutiny) - - testImplementation(libs.protobuf.java) - testImplementation(libs.protobuf.kotlin) - testImplementation(libs.log4j.core) - - testImplementation(libs.kotlinx.coroutines.core) - testImplementation(libs.vertx.kotlin.coroutines) } diff --git a/sdk-http-vertx/src/main/java/dev/restate/sdk/http/vertx/HttpEndpointRequestHandler.java b/sdk-http-vertx/src/main/java/dev/restate/sdk/http/vertx/HttpEndpointRequestHandler.java new file mode 100644 index 000000000..b6dc5b695 --- /dev/null +++ b/sdk-http-vertx/src/main/java/dev/restate/sdk/http/vertx/HttpEndpointRequestHandler.java @@ -0,0 +1,105 @@ +// Copyright (c) 2023 - Restate Software, Inc., Restate GmbH +// +// This file is part of the Restate Java SDK, +// which is released under the MIT license. +// +// You can find a copy of the license in file LICENSE in the root +// directory of this repository or package, or at +// https://github.com/restatedev/sdk-java/blob/main/LICENSE +package dev.restate.sdk.http.vertx; + +import static io.netty.handler.codec.http.HttpHeaderNames.CONTENT_TYPE; +import static io.netty.handler.codec.http.HttpResponseStatus.*; + +import dev.restate.sdk.core.EndpointRequestHandler; +import dev.restate.sdk.core.ProtocolException; +import dev.restate.sdk.core.RequestProcessor; +import dev.restate.sdk.endpoint.Endpoint; +import dev.restate.sdk.endpoint.HeadersAccessor; +import dev.restate.sdk.version.Version; +import io.netty.util.AsciiString; +import io.reactiverse.contextual.logging.ContextualData; +import io.vertx.core.Context; +import io.vertx.core.Handler; +import io.vertx.core.http.HttpServerRequest; +import io.vertx.core.http.HttpServerResponse; +import io.vertx.core.http.impl.HttpServerRequestInternal; +import java.net.URI; +import java.util.concurrent.Executor; +import org.apache.logging.log4j.LogManager; +import org.apache.logging.log4j.Logger; +import org.jspecify.annotations.Nullable; + +public class HttpEndpointRequestHandler implements Handler { + + private static final Logger LOG = LogManager.getLogger(HttpEndpointRequestHandler.class); + + private static final AsciiString X_RESTATE_SERVER_KEY = AsciiString.cached("x-restate-server"); + private static final AsciiString X_RESTATE_SERVER_VALUE = + AsciiString.cached(Version.X_RESTATE_SERVER); + + private final EndpointRequestHandler endpoint; + + private HttpEndpointRequestHandler(Endpoint endpoint) { + this.endpoint = EndpointRequestHandler.forBidiStream(endpoint); + } + + @Override + public void handle(HttpServerRequest request) { + URI uri = URI.create(request.uri()); + Context vertxCurrentContext = ((HttpServerRequestInternal) request).context(); + + RequestProcessor requestProcessor; + try { + requestProcessor = + this.endpoint.processorForRequest( + uri.getPath(), + new HeadersAccessor() { + @Override + public Iterable keys() { + return request.headers().names(); + } + + @Override + public @Nullable String get(String key) { + return request.getHeader(key); + } + }, + ContextualData::put, + currentContextExecutor(vertxCurrentContext)); + } catch (ProtocolException e) { + LOG.warn("Error when handling the request", e); + request + .response() + .setStatusCode(e.getCode()) + .putHeader(CONTENT_TYPE, "text/plain") + .putHeader(X_RESTATE_SERVER_KEY, X_RESTATE_SERVER_VALUE) + .end(e.getMessage()); + return; + } + + // Prepare the header frame to send in the response. + // Vert.x will send them as soon as we send the first write + HttpServerResponse response = request.response(); + response.setStatusCode(requestProcessor.statusCode()); + response + .putHeader(CONTENT_TYPE, requestProcessor.responseContentType()) + .putHeader(X_RESTATE_SERVER_KEY, X_RESTATE_SERVER_VALUE); + // This is No-op for HTTP2 + response.setChunked(true); + + HttpRequestFlowAdapter requestFlowAdapter = new HttpRequestFlowAdapter(request); + HttpResponseFlowAdapter responseFlowAdapter = new HttpResponseFlowAdapter(response); + + requestFlowAdapter.subscribe(requestProcessor); + requestProcessor.subscribe(responseFlowAdapter); + } + + private Executor currentContextExecutor(Context currentContext) { + return runnable -> currentContext.runOnContext(v -> runnable.run()); + } + + public static HttpEndpointRequestHandler fromEndpoint(Endpoint endpoint) { + return new HttpEndpointRequestHandler(endpoint); + } +} diff --git a/sdk-http-vertx/src/main/java/dev/restate/sdk/http/vertx/HttpRequestFlowAdapter.java b/sdk-http-vertx/src/main/java/dev/restate/sdk/http/vertx/HttpRequestFlowAdapter.java index 852ffb298..bed7f9613 100644 --- a/sdk-http-vertx/src/main/java/dev/restate/sdk/http/vertx/HttpRequestFlowAdapter.java +++ b/sdk-http-vertx/src/main/java/dev/restate/sdk/http/vertx/HttpRequestFlowAdapter.java @@ -8,7 +8,7 @@ // https://github.com/restatedev/sdk-java/blob/main/LICENSE package dev.restate.sdk.http.vertx; -import dev.restate.sdk.core.InvocationFlow; +import dev.restate.common.Slice; import io.vertx.core.buffer.Buffer; import io.vertx.core.http.HttpServerRequest; import java.nio.ByteBuffer; @@ -18,13 +18,13 @@ import org.apache.logging.log4j.LogManager; import org.apache.logging.log4j.Logger; -class HttpRequestFlowAdapter implements InvocationFlow.InvocationInputPublisher { +class HttpRequestFlowAdapter implements Flow.Publisher { private static final Logger LOG = LogManager.getLogger(HttpRequestFlowAdapter.class); private final HttpServerRequest httpServerRequest; - private Flow.Subscriber inputMessagesSubscriber; + private Flow.Subscriber inputMessagesSubscriber; private long subscriberRequest = 0; private final Queue buffers; @@ -34,7 +34,7 @@ class HttpRequestFlowAdapter implements InvocationFlow.InvocationInputPublisher } @Override - public void subscribe(Flow.Subscriber subscriber) { + public void subscribe(Flow.Subscriber subscriber) { this.inputMessagesSubscriber = subscriber; this.inputMessagesSubscriber.onSubscribe( new Flow.Subscription() { @@ -78,7 +78,7 @@ private void handleSubscriptionRequest(long l) { private void handleIncomingBuffer(Buffer buffer) { // Fast path if (this.buffers.isEmpty() && this.subscriberRequest > 0) { - this.inputMessagesSubscriber.onNext(buffer.getByteBuf().nioBuffer()); + this.inputMessagesSubscriber.onNext(Slice.wrap(buffer.getByteBuf().nioBuffer())); this.subscriberRequest--; return; } @@ -105,7 +105,7 @@ private void tryProgress() { return; } this.subscriberRequest--; - inputMessagesSubscriber.onNext(input); + inputMessagesSubscriber.onNext(Slice.wrap(input)); } } } diff --git a/sdk-http-vertx/src/main/java/dev/restate/sdk/http/vertx/HttpResponseFlowAdapter.java b/sdk-http-vertx/src/main/java/dev/restate/sdk/http/vertx/HttpResponseFlowAdapter.java index d69bd0ec1..d22a774a2 100644 --- a/sdk-http-vertx/src/main/java/dev/restate/sdk/http/vertx/HttpResponseFlowAdapter.java +++ b/sdk-http-vertx/src/main/java/dev/restate/sdk/http/vertx/HttpResponseFlowAdapter.java @@ -8,17 +8,16 @@ // https://github.com/restatedev/sdk-java/blob/main/LICENSE package dev.restate.sdk.http.vertx; -import dev.restate.sdk.core.InvocationFlow; -import dev.restate.sdk.core.Util; +import dev.restate.common.Slice; +import dev.restate.sdk.core.ExceptionUtils; import io.netty.buffer.Unpooled; import io.vertx.core.buffer.Buffer; import io.vertx.core.http.HttpServerResponse; -import java.nio.ByteBuffer; import java.util.concurrent.Flow; import org.apache.logging.log4j.LogManager; import org.apache.logging.log4j.Logger; -class HttpResponseFlowAdapter implements InvocationFlow.InvocationOutputSubscriber { +class HttpResponseFlowAdapter implements Flow.Subscriber { private static final Logger LOG = LogManager.getLogger(HttpResponseFlowAdapter.class); @@ -39,14 +38,15 @@ public void onSubscribe(Flow.Subscription subscription) { } @Override - public void onNext(ByteBuffer byteBuffer) { + public void onNext(Slice slice) { if (this.httpServerResponse.ended()) { cancelSubscription(); return; } // If HTTP HEADERS frame have not been sent, Vert.x will send them - this.httpServerResponse.write(Buffer.buffer(Unpooled.wrappedBuffer(byteBuffer))); + this.httpServerResponse.write( + Buffer.buffer(Unpooled.wrappedBuffer(slice.asReadOnlyByteBuffer()))); } @Override @@ -69,7 +69,7 @@ private void propagateWireFailure(Throwable e) { private void propagatePublisherFailure(Throwable e) { if (!httpServerResponse.headWritten()) { // Try to write the failure in the head - Util.findProtocolException(e) + ExceptionUtils.findProtocolException(e) .ifPresentOrElse( pe -> httpServerResponse.setStatusCode(pe.getCode()), () -> httpServerResponse.setStatusCode(500)); diff --git a/sdk-http-vertx/src/main/java/dev/restate/sdk/http/vertx/RequestHttpServerHandler.java b/sdk-http-vertx/src/main/java/dev/restate/sdk/http/vertx/RequestHttpServerHandler.java deleted file mode 100644 index bc282a4e2..000000000 --- a/sdk-http-vertx/src/main/java/dev/restate/sdk/http/vertx/RequestHttpServerHandler.java +++ /dev/null @@ -1,190 +0,0 @@ -// Copyright (c) 2023 - Restate Software, Inc., Restate GmbH -// -// This file is part of the Restate Java SDK, -// which is released under the MIT license. -// -// You can find a copy of the license in file LICENSE in the root -// directory of this repository or package, or at -// https://github.com/restatedev/sdk-java/blob/main/LICENSE -package dev.restate.sdk.http.vertx; - -import static io.netty.handler.codec.http.HttpHeaderNames.ACCEPT; -import static io.netty.handler.codec.http.HttpHeaderNames.CONTENT_TYPE; -import static io.netty.handler.codec.http.HttpResponseStatus.*; - -import dev.restate.sdk.core.ProtocolException; -import dev.restate.sdk.core.ResolvedEndpointHandler; -import dev.restate.sdk.core.RestateEndpoint; -import dev.restate.sdk.version.Version; -import io.netty.util.AsciiString; -import io.opentelemetry.api.OpenTelemetry; -import io.opentelemetry.context.propagation.TextMapGetter; -import io.reactiverse.contextual.logging.ContextualData; -import io.vertx.core.Context; -import io.vertx.core.Handler; -import io.vertx.core.MultiMap; -import io.vertx.core.buffer.Buffer; -import io.vertx.core.http.HttpServerRequest; -import io.vertx.core.http.HttpServerResponse; -import io.vertx.core.http.impl.HttpServerRequestInternal; -import java.net.URI; -import java.util.concurrent.Executor; -import java.util.regex.Pattern; -import org.apache.logging.log4j.LogManager; -import org.apache.logging.log4j.Logger; -import org.jspecify.annotations.Nullable; - -class RequestHttpServerHandler implements Handler { - - private static final Logger LOG = LogManager.getLogger(RequestHttpServerHandler.class); - - private static final AsciiString X_RESTATE_SERVER_KEY = AsciiString.cached("x-restate-server"); - private static final AsciiString X_RESTATE_SERVER_VALUE = - AsciiString.cached(Version.X_RESTATE_SERVER); - - private static final Pattern SLASH = Pattern.compile(Pattern.quote("/")); - - private static final String DISCOVER_PATH = "/discover"; - private static final String HEALTH_PATH = "/health"; - - static final TextMapGetter OTEL_TEXT_MAP_GETTER = - new TextMapGetter<>() { - @Override - public Iterable keys(MultiMap carrier) { - return carrier.names(); - } - - @Nullable - @Override - public String get(@Nullable MultiMap carrier, String key) { - if (carrier == null) { - return null; - } - return carrier.get(key); - } - }; - - private final RestateEndpoint restateEndpoint; - private final OpenTelemetry openTelemetry; - - RequestHttpServerHandler(RestateEndpoint restateEndpoint, OpenTelemetry openTelemetry) { - this.restateEndpoint = restateEndpoint; - this.openTelemetry = openTelemetry; - } - - @Override - public void handle(HttpServerRequest request) { - URI uri = URI.create(request.uri()); - - // health check - if (HEALTH_PATH.equalsIgnoreCase(uri.getPath())) { - this.handleHealthRequest(request); - return; - } - - // Discovery request - if (DISCOVER_PATH.equalsIgnoreCase(uri.getPath())) { - this.handleDiscoveryRequest(request); - return; - } - - // Parse request - String[] pathSegments = SLASH.split(uri.getPath()); - if (pathSegments.length < 3) { - LOG.warn( - "Path doesn't match the pattern /invoke/ServiceName/HandlerName nor /discover nor /health: '{}'", - request.path()); - request.response().setStatusCode(NOT_FOUND.code()).end(); - return; - } - String serviceName = pathSegments[pathSegments.length - 2]; - String handlerName = pathSegments[pathSegments.length - 1]; - - // Parse OTEL context and generate span - final io.opentelemetry.context.Context otelContext = - openTelemetry - .getPropagators() - .getTextMapPropagator() - .extract( - io.opentelemetry.context.Context.current(), - request.headers(), - OTEL_TEXT_MAP_GETTER); - - Context vertxCurrentContext = ((HttpServerRequestInternal) request).context(); - - ResolvedEndpointHandler handler; - try { - handler = - restateEndpoint.resolve( - request.getHeader(CONTENT_TYPE), - serviceName, - handlerName, - request::getHeader, - otelContext, - ContextualData::put, - currentContextExecutor(vertxCurrentContext)); - } catch (ProtocolException e) { - LOG.warn("Error when handling the request", e); - request - .response() - .setStatusCode(e.getCode()) - .putHeader(CONTENT_TYPE, "text/plain") - .putHeader(X_RESTATE_SERVER_KEY, X_RESTATE_SERVER_VALUE) - .end(e.getMessage()); - return; - } - - LOG.debug("Handling request to {}/{}", serviceName, handlerName); - - // Prepare the header frame to send in the response. - // Vert.x will send them as soon as we send the first write - HttpServerResponse response = request.response(); - response.setStatusCode(OK.code()); - response - .putHeader(CONTENT_TYPE, handler.responseContentType()) - .putHeader(X_RESTATE_SERVER_KEY, X_RESTATE_SERVER_VALUE); - // This is No-op for HTTP2 - response.setChunked(true); - - HttpRequestFlowAdapter requestFlowAdapter = new HttpRequestFlowAdapter(request); - HttpResponseFlowAdapter responseFlowAdapter = new HttpResponseFlowAdapter(response); - - requestFlowAdapter.subscribe(handler); - handler.subscribe(responseFlowAdapter); - } - - private Executor currentContextExecutor(Context currentContext) { - return runnable -> currentContext.runOnContext(v -> runnable.run()); - } - - private void handleDiscoveryRequest(HttpServerRequest request) { - RestateEndpoint.DiscoveryResponse discoveryResponse; - try { - discoveryResponse = restateEndpoint.handleDiscoveryRequest(request.getHeader(ACCEPT)); - } catch (ProtocolException e) { - LOG.warn("Error when handling the discovery request", e); - request - .response() - .setStatusCode(e.getCode()) - .putHeader(CONTENT_TYPE, "text/plain") - .putHeader(X_RESTATE_SERVER_KEY, X_RESTATE_SERVER_VALUE) - .end(e.getMessage()); - return; - } - - request - .response() - .setStatusCode(OK.code()) - .putHeader(X_RESTATE_SERVER_KEY, X_RESTATE_SERVER_VALUE) - .putHeader(CONTENT_TYPE, discoveryResponse.getContentType()) - .end(Buffer.buffer(discoveryResponse.getSerializedManifest())); - } - - private void handleHealthRequest(HttpServerRequest request) { - request - .response() - .setStatusCode(OK.code()) - .putHeader(X_RESTATE_SERVER_KEY, X_RESTATE_SERVER_VALUE) - .end(); - } -} diff --git a/sdk-http-vertx/src/main/java/dev/restate/sdk/http/vertx/RestateHttpEndpointBuilder.java b/sdk-http-vertx/src/main/java/dev/restate/sdk/http/vertx/RestateHttpEndpointBuilder.java deleted file mode 100644 index 88942fc66..000000000 --- a/sdk-http-vertx/src/main/java/dev/restate/sdk/http/vertx/RestateHttpEndpointBuilder.java +++ /dev/null @@ -1,183 +0,0 @@ -// Copyright (c) 2023 - Restate Software, Inc., Restate GmbH -// -// This file is part of the Restate Java SDK, -// which is released under the MIT license. -// -// You can find a copy of the license in file LICENSE in the root -// directory of this repository or package, or at -// https://github.com/restatedev/sdk-java/blob/main/LICENSE -package dev.restate.sdk.http.vertx; - -import dev.restate.sdk.auth.RequestIdentityVerifier; -import dev.restate.sdk.common.syscalls.ServiceDefinition; -import dev.restate.sdk.core.RestateEndpoint; -import dev.restate.sdk.core.manifest.EndpointManifestSchema; -import io.opentelemetry.api.OpenTelemetry; -import io.vertx.core.Future; -import io.vertx.core.Vertx; -import io.vertx.core.http.Http2Settings; -import io.vertx.core.http.HttpServer; -import io.vertx.core.http.HttpServerOptions; -import java.util.*; -import java.util.concurrent.CompletionException; -import org.apache.logging.log4j.LogManager; -import org.apache.logging.log4j.Logger; - -/** - * Endpoint builder for a Restate HTTP Endpoint using Vert.x, to serve Restate services. - * - *

This endpoint supports the Restate HTTP/2 Streaming component Protocol. - * - *

Example usage: - * - *

- * public static void main(String[] args) {
- *   RestateHttpEndpointBuilder.builder()
- *           .bind(new Counter())
- *           .buildAndListen();
- * }
- * 
- */ -public class RestateHttpEndpointBuilder { - - private static final Logger LOG = LogManager.getLogger(RestateHttpEndpointBuilder.class); - - private final Vertx vertx; - private final RestateEndpoint.Builder endpointBuilder = - RestateEndpoint.newBuilder(EndpointManifestSchema.ProtocolMode.BIDI_STREAM); - private OpenTelemetry openTelemetry = OpenTelemetry.noop(); - private HttpServerOptions options = - new HttpServerOptions() - .setPort(Optional.ofNullable(System.getenv("PORT")).map(Integer::parseInt).orElse(9080)) - .setInitialSettings(new Http2Settings().setMaxConcurrentStreams(Integer.MAX_VALUE)); - - private RestateHttpEndpointBuilder(Vertx vertx) { - this.vertx = vertx; - } - - /** Create a new builder. */ - public static RestateHttpEndpointBuilder builder() { - return new RestateHttpEndpointBuilder(Vertx.vertx()); - } - - /** Create a new builder. */ - public static RestateHttpEndpointBuilder builder(Vertx vertx) { - return new RestateHttpEndpointBuilder(vertx); - } - - /** Add custom {@link HttpServerOptions} to the server used by the endpoint. */ - public RestateHttpEndpointBuilder withOptions(HttpServerOptions options) { - this.options = Objects.requireNonNull(options); - return this; - } - - /** - * Add a Restate service to the endpoint. This will automatically discover the generated factory - * based on the class name. - * - *

You can also manually instantiate the {@link ServiceDefinition} using {@link - * #bind(ServiceDefinition)}. - */ - public RestateHttpEndpointBuilder bind(Object service) { - return this.bind(RestateEndpoint.discoverServiceDefinitionFactory(service).create(service)); - } - - /** - * Add a Restate service to the endpoint. - * - *

To set the options, use {@link #bind(ServiceDefinition, Object)}. - */ - public RestateHttpEndpointBuilder bind(ServiceDefinition serviceDefinition) { - //noinspection unchecked - this.endpointBuilder.bind((ServiceDefinition) serviceDefinition, null); - return this; - } - - /** Add a Restate service to the endpoint, setting the options. */ - public RestateHttpEndpointBuilder bind(ServiceDefinition serviceDefinition, O options) { - this.endpointBuilder.bind(serviceDefinition, options); - return this; - } - - /** - * Set the {@link OpenTelemetry} implementation for tracing and metrics. - * - * @see OpenTelemetry - */ - public RestateHttpEndpointBuilder withOpenTelemetry(OpenTelemetry openTelemetry) { - this.openTelemetry = openTelemetry; - return this; - } - - /** - * Set the request identity verifier for this endpoint. - * - *

For the Restate implementation to use with Restate Cloud, check the module {@code - * sdk-request-identity}. - */ - public RestateHttpEndpointBuilder withRequestIdentityVerifier( - RequestIdentityVerifier requestIdentityVerifier) { - this.endpointBuilder.withRequestIdentityVerifier(requestIdentityVerifier); - return this; - } - - public RestateHttpEndpointBuilder enablePreviewContext() { - this.endpointBuilder.enablePreviewContext(); - return this; - } - - /** - * Build and listen on the specified port. - * - *

NOTE: this method will block for opening the socket and reserving the port. If you need a - * non-blocking variant, manually {@link #build()} the server and start listening it. - * - * @return The listening port - */ - public int buildAndListen(int port) { - return handleStart(build().listen(port)); - } - - /** - * Build and listen on the port specified by the environment variable {@code PORT}, or - * alternatively on the default {@code 9080} port. - * - *

NOTE: this method will block for opening the socket and reserving the port. If you need a - * non-blocking variant, manually {@link #build()} the server and start listening it. - * - * @return The listening port - */ - public int buildAndListen() { - return handleStart(build().listen()); - } - - /** Build the {@link HttpServer} serving the Restate service endpoint. */ - public HttpServer build() { - HttpServer server = vertx.createHttpServer(options); - - this.endpointBuilder.withTracer(this.openTelemetry.getTracer("restate-java-sdk-vertx")); - - server.requestHandler( - new RequestHttpServerHandler(this.endpointBuilder.build(), openTelemetry)); - - return server; - } - - private static int handleStart(Future fut) { - try { - HttpServer server = fut.toCompletionStage().toCompletableFuture().join(); - LOG.info("Restate HTTP Endpoint server started on port {}", server.actualPort()); - return server.actualPort(); - } catch (CompletionException e) { - LOG.error("Restate HTTP Endpoint server start failed", e.getCause()); - sneakyThrow(e.getCause()); - // This is never reached - return -1; - } - } - - @SuppressWarnings("unchecked") - private static void sneakyThrow(Throwable e) throws E { - throw (E) e; - } -} diff --git a/sdk-http-vertx/src/main/java/dev/restate/sdk/http/vertx/RestateHttpServer.java b/sdk-http-vertx/src/main/java/dev/restate/sdk/http/vertx/RestateHttpServer.java new file mode 100644 index 000000000..5b4ee28f9 --- /dev/null +++ b/sdk-http-vertx/src/main/java/dev/restate/sdk/http/vertx/RestateHttpServer.java @@ -0,0 +1,153 @@ +// Copyright (c) 2023 - Restate Software, Inc., Restate GmbH +// +// This file is part of the Restate Java SDK, +// which is released under the MIT license. +// +// You can find a copy of the license in file LICENSE in the root +// directory of this repository or package, or at +// https://github.com/restatedev/sdk-java/blob/main/LICENSE +package dev.restate.sdk.http.vertx; + +import dev.restate.sdk.endpoint.Endpoint; +import io.vertx.core.Future; +import io.vertx.core.Vertx; +import io.vertx.core.http.Http2Settings; +import io.vertx.core.http.HttpServer; +import io.vertx.core.http.HttpServerOptions; +import java.util.Optional; +import java.util.concurrent.CompletionException; +import org.apache.logging.log4j.LogManager; +import org.apache.logging.log4j.Logger; + +/** + * Endpoint builder for a Restate HTTP Endpoint using Vert.x, to serve Restate services. + * + *

This endpoint supports the Restate HTTP/2 Streaming component Protocol. + * + *

Example usage: + * + *

+ * public static void main(String[] args) {
+ *   Endpoint endpoint = Endpoint.builder()
+ *     .bind(new Counter())
+ *     .build();
+ *
+ *   RestateHttpServer.listen(endpoint);
+ * }
+ * 
+ */ +public class RestateHttpServer { + + private static final Logger LOG = LogManager.getLogger(RestateHttpServer.class); + + private static final int DEFAULT_PORT = + Optional.ofNullable(System.getenv("PORT")).map(Integer::parseInt).orElse(9080); + private static final HttpServerOptions DEFAULT_OPTIONS = + new HttpServerOptions() + .setInitialSettings(new Http2Settings().setMaxConcurrentStreams(Integer.MAX_VALUE)); + + /** + * Start serving the provided {@code endpoint} on the port specified by the environment variable + * {@code PORT}, or alternatively on the default {@code 9080} port. + * + *

NOTE: this method will block for opening the socket and reserving the port. If you need a + * non-blocking variant, manually create the server with {@link #fromEndpoint(Endpoint)} and start + * listening it. + * + * @return The listening port + */ + public static int listen(Endpoint endpoint) { + return handleStart(fromEndpoint(endpoint).listen(DEFAULT_PORT)); + } + + /** Like {@link #listen(Endpoint)} */ + public static int listen(Endpoint.Builder endpointBuilder) { + return listen(endpointBuilder.build()); + } + + /** + * Start serving the provided {@code endpoint} on the specified port. + * + *

NOTE: this method will block for opening the socket and reserving the port. If you need a + * non-blocking variant, manually create the server with {@link #fromEndpoint(Endpoint)} and start + * listening it. + * + * @return The listening port + */ + public static int listen(Endpoint endpoint, int port) { + return handleStart(fromEndpoint(endpoint).listen(port)); + } + + /** Like {@link #listen(Endpoint, int)} */ + public static int listen(Endpoint.Builder endpointBuilder, int port) { + return listen(endpointBuilder.build(), port); + } + + /** Create a Vert.x {@link HttpServer} from the provided endpoint. */ + public static HttpServer fromEndpoint(Endpoint endpoint) { + return fromEndpoint(endpoint, DEFAULT_OPTIONS); + } + + /** Like {@link #fromEndpoint(Endpoint)} */ + public static HttpServer fromEndpoint(Endpoint.Builder endpointBuilder) { + return fromEndpoint(endpointBuilder.build()); + } + + /** + * Create a Vert.x {@link HttpServer} from the provided endpoint, with the given {@link + * HttpServerOptions}. + */ + public static HttpServer fromEndpoint(Endpoint endpoint, HttpServerOptions options) { + return fromEndpoint(Vertx.vertx(), endpoint, options); + } + + /** Like {@link #fromEndpoint(Endpoint, HttpServerOptions)} */ + public static HttpServer fromEndpoint( + Endpoint.Builder endpointBuilder, HttpServerOptions options) { + return fromEndpoint(endpointBuilder.build(), options); + } + + /** Create a Vert.x {@link HttpServer} from the provided endpoint. */ + public static HttpServer fromEndpoint(Vertx vertx, Endpoint endpoint) { + return fromEndpoint(vertx, endpoint, DEFAULT_OPTIONS); + } + + /** Like {@link #fromEndpoint(Vertx, Endpoint)} */ + public static HttpServer fromEndpoint(Vertx vertx, Endpoint.Builder endpointBuilder) { + return fromEndpoint(vertx, endpointBuilder.build()); + } + + /** + * Create a Vert.x {@link HttpServer} from the provided endpoint, with the given {@link + * HttpServerOptions}. + */ + public static HttpServer fromEndpoint(Vertx vertx, Endpoint endpoint, HttpServerOptions options) { + HttpServer server = vertx.createHttpServer(options); + server.requestHandler(HttpEndpointRequestHandler.fromEndpoint(endpoint)); + return server; + } + + /** Like {@link #fromEndpoint(Vertx, Endpoint, HttpServerOptions)} */ + public static HttpServer fromEndpoint( + Vertx vertx, Endpoint.Builder endpointBuilder, HttpServerOptions options) { + return fromEndpoint(vertx, endpointBuilder.build(), options); + } + + private static int handleStart(Future fut) { + try { + HttpServer server = fut.toCompletionStage().toCompletableFuture().join(); + LOG.info("Restate HTTP Endpoint server started on port {}", server.actualPort()); + return server.actualPort(); + } catch (CompletionException e) { + LOG.error("Restate HTTP Endpoint server start failed", e.getCause()); + sneakyThrow(e.getCause()); + // This is never reached + return -1; + } + } + + @SuppressWarnings("unchecked") + private static void sneakyThrow(Throwable e) throws E { + throw (E) e; + } +} diff --git a/sdk-http-vertx/src/test/java/dev/restate/sdk/http/vertx/testservices/BlockingGreeter.java b/sdk-http-vertx/src/test/java/dev/restate/sdk/http/vertx/testservices/BlockingGreeter.java deleted file mode 100644 index cd8280bdd..000000000 --- a/sdk-http-vertx/src/test/java/dev/restate/sdk/http/vertx/testservices/BlockingGreeter.java +++ /dev/null @@ -1,37 +0,0 @@ -// Copyright (c) 2023 - Restate Software, Inc., Restate GmbH -// -// This file is part of the Restate Java SDK, -// which is released under the MIT license. -// -// You can find a copy of the license in file LICENSE in the root -// directory of this repository or package, or at -// https://github.com/restatedev/sdk-java/blob/main/LICENSE -package dev.restate.sdk.http.vertx.testservices; - -import dev.restate.sdk.JsonSerdes; -import dev.restate.sdk.ObjectContext; -import dev.restate.sdk.annotation.Handler; -import dev.restate.sdk.annotation.VirtualObject; -import dev.restate.sdk.common.StateKey; -import java.time.Duration; -import org.apache.logging.log4j.LogManager; -import org.apache.logging.log4j.Logger; - -@VirtualObject -public class BlockingGreeter { - - private static final Logger LOG = LogManager.getLogger(BlockingGreeter.class); - public static final StateKey COUNTER = StateKey.of("counter", JsonSerdes.LONG); - - @Handler - public String greet(ObjectContext context, String request) { - LOG.info("Greet invoked!"); - - var count = context.get(COUNTER).orElse(0L) + 1; - context.set(COUNTER, count); - - context.sleep(Duration.ofSeconds(1)); - - return "Hello " + request + ". Count: " + count; - } -} diff --git a/sdk-http-vertx/src/test/kotlin/dev/restate/sdk/http/vertx/RestateHttpEndpointTest.kt b/sdk-http-vertx/src/test/kotlin/dev/restate/sdk/http/vertx/RestateHttpEndpointTest.kt deleted file mode 100644 index 42fd3dde6..000000000 --- a/sdk-http-vertx/src/test/kotlin/dev/restate/sdk/http/vertx/RestateHttpEndpointTest.kt +++ /dev/null @@ -1,230 +0,0 @@ -// Copyright (c) 2023 - Restate Software, Inc., Restate GmbH -// -// This file is part of the Restate Java SDK, -// which is released under the MIT license. -// -// You can find a copy of the license in file LICENSE in the root -// directory of this repository or package, or at -// https://github.com/restatedev/sdk-java/blob/main/LICENSE -package dev.restate.sdk.http.vertx - -import com.fasterxml.jackson.databind.ObjectMapper -import com.google.protobuf.ByteString -import com.google.protobuf.MessageLite -import dev.restate.generated.service.protocol.Protocol.* -import dev.restate.sdk.JsonSerdes -import dev.restate.sdk.core.ProtoUtils.* -import dev.restate.sdk.core.manifest.EndpointManifestSchema -import dev.restate.sdk.http.vertx.testservices.BlockingGreeter -import dev.restate.sdk.http.vertx.testservices.greeter -import io.netty.buffer.Unpooled -import io.netty.handler.codec.http.HttpResponseStatus -import io.vertx.core.Vertx -import io.vertx.core.buffer.Buffer -import io.vertx.core.http.* -import io.vertx.junit5.Timeout -import io.vertx.junit5.VertxExtension -import io.vertx.kotlin.coroutines.coAwait -import io.vertx.kotlin.coroutines.dispatcher -import io.vertx.kotlin.coroutines.receiveChannelHandler -import java.nio.ByteBuffer -import java.util.concurrent.TimeUnit -import kotlin.time.Duration.Companion.seconds -import kotlinx.coroutines.delay -import kotlinx.coroutines.runBlocking -import org.assertj.core.api.Assertions.assertThat -import org.junit.jupiter.api.Test -import org.junit.jupiter.api.extension.ExtendWith -import org.junit.jupiter.api.parallel.Isolated - -@Isolated -@ExtendWith(VertxExtension::class) -internal class RestateHttpEndpointTest { - - companion object { - val HTTP_CLIENT_OPTIONS: HttpClientOptions = - HttpClientOptions() - // Set prior knowledge - .setProtocolVersion(HttpVersion.HTTP_2) - .setHttp2ClearTextUpgrade(false) - } - - @Timeout(value = 1, timeUnit = TimeUnit.SECONDS) - @Test - fun endpointWithNonBlockingService(vertx: Vertx): Unit = - greetTest(vertx, "KtGreeter") { it.bind(greeter()) } - - @Timeout(value = 1, timeUnit = TimeUnit.SECONDS) - @Test - fun endpointWithBlockingService(vertx: Vertx): Unit = - greetTest(vertx, BlockingGreeter::class.simpleName!!) { it.bind(BlockingGreeter()) } - - private fun greetTest( - vertx: Vertx, - componentName: String, - consumeBuilderFn: (RestateHttpEndpointBuilder) -> RestateHttpEndpointBuilder - ): Unit = - runBlocking(vertx.dispatcher()) { - val endpointBuilder = RestateHttpEndpointBuilder.builder(vertx) - consumeBuilderFn(endpointBuilder) - - val endpointPort: Int = - endpointBuilder - .withOptions(HttpServerOptions().setPort(0)) - .build() - .listen() - .coAwait() - .actualPort() - - val client = vertx.createHttpClient(HTTP_CLIENT_OPTIONS) - - val request = - client - .request(HttpMethod.POST, endpointPort, "localhost", "/invoke/$componentName/greet") - .coAwait() - - // Prepare request header - request - .setChunked(true) - .putHeader(HttpHeaders.CONTENT_TYPE, serviceProtocolContentTypeHeader()) - - // Send start message and PollInputStreamEntry - request.write(encode(startMessage(1).build())) - request.write(encode(inputMessage("Francesco"))) - - val response = request.response().coAwait() - - // Start the input decoder - val inputChannel = vertx.receiveChannelHandler() - response.handler { - bufferToMessages(listOf(ByteBuffer.wrap(it.bytes))).forEach(inputChannel::handle) - } - response.resume() - - // Wait for Get State Entry - val getStateEntry = inputChannel.receive() - - assertThat(getStateEntry).isInstanceOf(GetStateEntryMessage::class.java) - assertThat(getStateEntry as GetStateEntryMessage) - .returns(ByteString.copyFromUtf8("counter"), GetStateEntryMessage::getKey) - - // Send completion - request.write( - encode( - completionMessage(1) - .setValue(ByteString.copyFrom(JsonSerdes.LONG.serialize(2))) - .build())) - - // Wait for Set State Entry - val setStateEntry = inputChannel.receive() - - assertThat(setStateEntry).isInstanceOf(SetStateEntryMessage::class.java) - assertThat(setStateEntry as SetStateEntryMessage) - .returns(ByteString.copyFromUtf8("counter"), SetStateEntryMessage::getKey) - .returns(ByteString.copyFromUtf8("3"), SetStateEntryMessage::getValue) - - // Wait for the sleep and complete it - val sleepEntry = inputChannel.receive() - - assertThat(sleepEntry).isInstanceOf(SleepEntryMessage::class.java) - - // Wait a bit, then send the completion - delay(1.seconds) - request.write( - encode( - CompletionMessage.newBuilder() - .setEntryIndex(3) - .setEmpty(Empty.getDefaultInstance()) - .build())) - - // Now wait for response - val outputEntry = inputChannel.receive() - - assertThat(outputEntry).isInstanceOf(OutputEntryMessage::class.java) - assertThat(outputEntry).isEqualTo(outputMessage("Hello Francesco. Count: 3")) - - // Wait for closing request and response - request.end().coAwait() - } - - @Test - fun return404(vertx: Vertx): Unit = - runBlocking(vertx.dispatcher()) { - val endpointPort: Int = - RestateHttpEndpointBuilder.builder(vertx) - .bind(BlockingGreeter()) - .withOptions(HttpServerOptions().setPort(0)) - .build() - .listen() - .coAwait() - .actualPort() - - val client = vertx.createHttpClient(HTTP_CLIENT_OPTIONS) - - val request = - client - .request( - HttpMethod.POST, - endpointPort, - "localhost", - "/invoke/" + BlockingGreeter::class.java.simpleName + "/unknownMethod") - .coAwait() - - // Prepare request header - request - .setChunked(true) - .putHeader(HttpHeaders.CONTENT_TYPE, serviceProtocolContentTypeHeader()) - .putHeader(HttpHeaders.ACCEPT, serviceProtocolContentTypeHeader()) - request.write(encode(startMessage(0).build())) - - val response = request.response().coAwait() - - // Response status should be 404 - assertThat(response.statusCode()).isEqualTo(HttpResponseStatus.NOT_FOUND.code()) - - response.end().coAwait() - } - - @Test - fun serviceDiscovery(vertx: Vertx): Unit = - runBlocking(vertx.dispatcher()) { - val endpointPort: Int = - RestateHttpEndpointBuilder.builder(vertx) - .bind(BlockingGreeter()) - .withOptions(HttpServerOptions().setPort(0)) - .build() - .listen() - .coAwait() - .actualPort() - - val client = vertx.createHttpClient(HTTP_CLIENT_OPTIONS) - - // Send request - val request = - client.request(HttpMethod.GET, endpointPort, "localhost", "/discover").coAwait() - request.putHeader(HttpHeaders.ACCEPT, serviceProtocolDiscoveryContentTypeHeader()) - request.end().coAwait() - - // Assert response - val response = request.response().coAwait() - - // Response status and content type header - assertThat(response.statusCode()).isEqualTo(HttpResponseStatus.OK.code()) - assertThat(response.getHeader(HttpHeaders.CONTENT_TYPE)) - .isEqualTo(serviceProtocolDiscoveryContentTypeHeader()) - - // Parse response - val responseBody = response.body().coAwait() - // Compute response and write it back - val discoveryResponse: EndpointManifestSchema = - ObjectMapper().readValue(responseBody.bytes, EndpointManifestSchema::class.java) - - assertThat(discoveryResponse.services) - .map { it.name } - .containsOnly(BlockingGreeter::class.java.simpleName) - } - - private fun encode(msg: MessageLite): Buffer { - return Buffer.buffer(Unpooled.wrappedBuffer(messageToByteString(msg))) - } -} diff --git a/sdk-http-vertx/src/test/kotlin/dev/restate/sdk/http/vertx/testservices/GreeterKtComponent.kt b/sdk-http-vertx/src/test/kotlin/dev/restate/sdk/http/vertx/testservices/GreeterKtComponent.kt deleted file mode 100644 index f2bc69c16..000000000 --- a/sdk-http-vertx/src/test/kotlin/dev/restate/sdk/http/vertx/testservices/GreeterKtComponent.kt +++ /dev/null @@ -1,43 +0,0 @@ -// Copyright (c) 2023 - Restate Software, Inc., Restate GmbH -// -// This file is part of the Restate Java SDK, -// which is released under the MIT license. -// -// You can find a copy of the license in file LICENSE in the root -// directory of this repository or package, or at -// https://github.com/restatedev/sdk-java/blob/main/LICENSE -package dev.restate.sdk.http.vertx.testservices - -import dev.restate.sdk.common.HandlerType -import dev.restate.sdk.common.ServiceType -import dev.restate.sdk.common.StateKey -import dev.restate.sdk.common.syscalls.HandlerDefinition -import dev.restate.sdk.common.syscalls.HandlerSpecification -import dev.restate.sdk.common.syscalls.ServiceDefinition -import dev.restate.sdk.kotlin.HandlerRunner -import dev.restate.sdk.kotlin.KtSerdes -import dev.restate.sdk.kotlin.ObjectContext -import kotlin.time.Duration.Companion.seconds -import org.apache.logging.log4j.LogManager - -private val LOG = LogManager.getLogger() -private val COUNTER: StateKey = BlockingGreeter.COUNTER - -fun greeter(): ServiceDefinition<*> = - ServiceDefinition.of( - "KtGreeter", - ServiceType.VIRTUAL_OBJECT, - listOf( - HandlerDefinition.of( - HandlerSpecification.of( - "greet", HandlerType.EXCLUSIVE, KtSerdes.json(), KtSerdes.json()), - HandlerRunner.of { ctx: ObjectContext, request: String -> - LOG.info("Greet invoked!") - - val count = (ctx.get(COUNTER) ?: 0) + 1 - ctx.set(COUNTER, count) - - ctx.sleep(1.seconds) - - "Hello $request. Count: $count" - }))) diff --git a/sdk-http-vertx/src/test/resources/junit-platform.properties b/sdk-http-vertx/src/test/resources/junit-platform.properties deleted file mode 100644 index 3e799af08..000000000 --- a/sdk-http-vertx/src/test/resources/junit-platform.properties +++ /dev/null @@ -1,3 +0,0 @@ -junit.jupiter.execution.parallel.enabled = true -junit.jupiter.execution.parallel.config.strategy = dynamic -junit.jupiter.execution.parallel.mode.default = same_thread \ No newline at end of file diff --git a/sdk-java-http/build.gradle.kts b/sdk-java-http/build.gradle.kts new file mode 100644 index 000000000..cb8f1e80a --- /dev/null +++ b/sdk-java-http/build.gradle.kts @@ -0,0 +1,14 @@ +plugins { + `java-conventions` + `java-library` + `library-publishing-conventions` +} + +description = "Restate SDK Java HTTP starter" + +dependencies { + api(project(":sdk-api")) + api(project(":sdk-http-vertx")) + api(project(":client")) + implementation(libs.log4j.core) +} diff --git a/sdk-java-lambda/build.gradle.kts b/sdk-java-lambda/build.gradle.kts new file mode 100644 index 000000000..f21bf782b --- /dev/null +++ b/sdk-java-lambda/build.gradle.kts @@ -0,0 +1,14 @@ +plugins { + `java-conventions` + `java-library` + `library-publishing-conventions` +} + +description = "Restate SDK Java Lambda starter" + +dependencies { + api(project(":sdk-api")) + api(project(":sdk-lambda")) + api(project(":client")) + implementation(libs.log4j.core) +} diff --git a/sdk-kotlin-http/build.gradle.kts b/sdk-kotlin-http/build.gradle.kts new file mode 100644 index 000000000..b260671e9 --- /dev/null +++ b/sdk-kotlin-http/build.gradle.kts @@ -0,0 +1,13 @@ +plugins { + `kotlin-conventions` + `library-publishing-conventions` +} + +description = "Restate SDK Kotlin HTTP starter" + +dependencies { + api(project(":sdk-api-kotlin")) + api(project(":sdk-http-vertx")) + api(project(":client-kotlin")) + implementation(libs.log4j.core) +} diff --git a/sdk-kotlin-lambda/build.gradle.kts b/sdk-kotlin-lambda/build.gradle.kts new file mode 100644 index 000000000..880424265 --- /dev/null +++ b/sdk-kotlin-lambda/build.gradle.kts @@ -0,0 +1,13 @@ +plugins { + `kotlin-conventions` + `library-publishing-conventions` +} + +description = "Restate SDK Kotlin Lambda starter" + +dependencies { + api(project(":sdk-api-kotlin")) + api(project(":sdk-lambda")) + api(project(":client-kotlin")) + implementation(libs.log4j.core) +} diff --git a/sdk-lambda/build.gradle.kts b/sdk-lambda/build.gradle.kts index 71cdfece0..798eef391 100644 --- a/sdk-lambda/build.gradle.kts +++ b/sdk-lambda/build.gradle.kts @@ -17,18 +17,4 @@ dependencies { implementation(libs.opentelemetry.api) implementation(libs.log4j.api) - - testAnnotationProcessor(project(":sdk-api-gen")) - testImplementation(project(":sdk-api")) - testImplementation(project(":sdk-api-kotlin")) - testImplementation(project(":sdk-core", "testArchive")) - testImplementation(project(":sdk-serde-jackson")) - testImplementation(libs.junit.jupiter) - testImplementation(libs.assertj) - - testImplementation(libs.protobuf.java) - testImplementation(libs.protobuf.kotlin) - testImplementation(libs.log4j.core) - - testImplementation(libs.kotlinx.coroutines.core) } diff --git a/sdk-lambda/src/main/java/dev/restate/sdk/lambda/BaseRestateLambdaHandler.java b/sdk-lambda/src/main/java/dev/restate/sdk/lambda/BaseRestateLambdaHandler.java index 2fd0c8d74..9309f3075 100644 --- a/sdk-lambda/src/main/java/dev/restate/sdk/lambda/BaseRestateLambdaHandler.java +++ b/sdk-lambda/src/main/java/dev/restate/sdk/lambda/BaseRestateLambdaHandler.java @@ -12,6 +12,7 @@ import com.amazonaws.services.lambda.runtime.RequestHandler; import com.amazonaws.services.lambda.runtime.events.APIGatewayProxyRequestEvent; import com.amazonaws.services.lambda.runtime.events.APIGatewayProxyResponseEvent; +import dev.restate.sdk.endpoint.Endpoint; import org.apache.logging.log4j.CloseableThreadContext; /** @@ -30,22 +31,22 @@ public abstract class BaseRestateLambdaHandler private static final String AWS_REQUEST_ID = "AWSRequestId"; - private final RestateLambdaEndpoint restateLambdaEndpoint; + private final LambdaEndpointRequestHandler lambdaEndpointRequestHandler; protected BaseRestateLambdaHandler() { - RestateLambdaEndpointBuilder builder = RestateLambdaEndpoint.builder(); - register(builder); - this.restateLambdaEndpoint = builder.build(); + Endpoint.Builder endpointBuilder = Endpoint.builder(); + register(endpointBuilder); + this.lambdaEndpointRequestHandler = new LambdaEndpointRequestHandler(endpointBuilder.build()); } /** Configure your services in this method. */ - public abstract void register(RestateLambdaEndpointBuilder builder); + public abstract void register(Endpoint.Builder builder); @Override public APIGatewayProxyResponseEvent handleRequest( APIGatewayProxyRequestEvent input, Context context) { try (var requestId = CloseableThreadContext.put(AWS_REQUEST_ID, context.getAwsRequestId())) { - return restateLambdaEndpoint.handleRequest(input, context); + return lambdaEndpointRequestHandler.handleRequest(input, context); } } } diff --git a/sdk-lambda/src/main/java/dev/restate/sdk/lambda/LambdaEndpointRequestHandler.java b/sdk-lambda/src/main/java/dev/restate/sdk/lambda/LambdaEndpointRequestHandler.java new file mode 100644 index 000000000..b3540577f --- /dev/null +++ b/sdk-lambda/src/main/java/dev/restate/sdk/lambda/LambdaEndpointRequestHandler.java @@ -0,0 +1,117 @@ +// Copyright (c) 2023 - Restate Software, Inc., Restate GmbH +// +// This file is part of the Restate Java SDK, +// which is released under the MIT license. +// +// You can find a copy of the license in file LICENSE in the root +// directory of this repository or package, or at +// https://github.com/restatedev/sdk-java/blob/main/LICENSE +package dev.restate.sdk.lambda; + +import static dev.restate.sdk.lambda.LambdaFlowAdapters.*; + +import com.amazonaws.services.lambda.runtime.Context; +import com.amazonaws.services.lambda.runtime.events.APIGatewayProxyRequestEvent; +import com.amazonaws.services.lambda.runtime.events.APIGatewayProxyResponseEvent; +import dev.restate.common.Slice; +import dev.restate.sdk.core.EndpointRequestHandler; +import dev.restate.sdk.core.ProtocolException; +import dev.restate.sdk.core.RequestProcessor; +import dev.restate.sdk.endpoint.Endpoint; +import dev.restate.sdk.endpoint.HeadersAccessor; +import dev.restate.sdk.version.Version; +import java.util.*; +import java.util.concurrent.Executor; +import java.util.concurrent.Executors; +import org.apache.logging.log4j.LogManager; +import org.apache.logging.log4j.Logger; +import org.apache.logging.log4j.ThreadContext; + +/** Restate Lambda Endpoint. */ +public final class LambdaEndpointRequestHandler { + + private static final Logger LOG = LogManager.getLogger(LambdaEndpointRequestHandler.class); + + private final EndpointRequestHandler endpoint; + + LambdaEndpointRequestHandler(Endpoint endpoint) { + this.endpoint = EndpointRequestHandler.forRequestResponse(endpoint); + } + + /** Handle a Lambda request as Restate Lambda endpoint. */ + public APIGatewayProxyResponseEvent handleRequest( + APIGatewayProxyRequestEvent input, Context context) { + // Remove trailing path separator + String path = + input.getPath().endsWith("/") + ? input.getPath().substring(0, input.getPath().length() - 1) + : input.getPath(); + + // Parse request body + final Slice requestBody = parseInputBody(input); + final Executor coreExecutor = Executors.newSingleThreadExecutor(); + + RequestProcessor requestProcessor; + try { + requestProcessor = + this.endpoint.processorForRequest( + path, + HeadersAccessor.wrap(input.getHeaders()), + EndpointRequestHandler.LoggingContextSetter.THREAD_LOCAL_INSTANCE, + coreExecutor); + } catch (ProtocolException e) { + // We can handle protocol exceptions by returning back the correct response + LOG.warn("Error when handling the request", e); + return new APIGatewayProxyResponseEvent() + .withStatusCode(e.getCode()) + .withHeaders( + Map.of("content-type", "text/plain", "x-restate-server", Version.X_RESTATE_SERVER)) + .withBody(e.getMessage()); + } + + BufferedPublisher publisher = new BufferedPublisher(requestBody); + ResultSubscriber subscriber = new ResultSubscriber(); + + // Wire handler + coreExecutor.execute(() -> publisher.subscribe(requestProcessor)); + requestProcessor.subscribe(subscriber); + + // Await the result + byte[] responseBody; + try { + responseBody = subscriber.getResult(); + } catch (Error | RuntimeException e) { + throw e; + } catch (Throwable e) { + throw new RuntimeException(e); + } + + // Clear logging + ThreadContext.clearAll(); + + final APIGatewayProxyResponseEvent response = new APIGatewayProxyResponseEvent(); + response.setHeaders( + Map.of( + "content-type", + requestProcessor.responseContentType(), + "x-restate-server", + Version.X_RESTATE_SERVER)); + response.setIsBase64Encoded(true); + response.setStatusCode(requestProcessor.statusCode()); + response.setBody(Base64.getEncoder().encodeToString(responseBody)); + return response; + } + + // --- Utils + + private static Slice parseInputBody(APIGatewayProxyRequestEvent input) { + if (input.getBody() == null) { + return Slice.EMPTY; + } + if (!input.getIsBase64Encoded()) { + throw new IllegalArgumentException( + "Input is not Base64 encoded. This is most likely an SDK bug, please contact the developers."); + } + return Slice.wrap(Base64.getDecoder().decode(input.getBody())); + } +} diff --git a/sdk-lambda/src/main/java/dev/restate/sdk/lambda/LambdaFlowAdapters.java b/sdk-lambda/src/main/java/dev/restate/sdk/lambda/LambdaFlowAdapters.java index 669a3927b..a35375102 100644 --- a/sdk-lambda/src/main/java/dev/restate/sdk/lambda/LambdaFlowAdapters.java +++ b/sdk-lambda/src/main/java/dev/restate/sdk/lambda/LambdaFlowAdapters.java @@ -8,10 +8,9 @@ // https://github.com/restatedev/sdk-java/blob/main/LICENSE package dev.restate.sdk.lambda; -import dev.restate.sdk.core.InvocationFlow; +import dev.restate.common.Slice; import java.io.ByteArrayOutputStream; import java.io.IOException; -import java.nio.ByteBuffer; import java.nio.channels.Channels; import java.nio.channels.WritableByteChannel; import java.util.concurrent.CompletableFuture; @@ -20,7 +19,7 @@ class LambdaFlowAdapters { - static class ResultSubscriber implements InvocationFlow.InvocationOutputSubscriber { + static class ResultSubscriber implements Flow.Subscriber { private final CompletableFuture completionFuture; private final ByteArrayOutputStream outputStream; @@ -38,9 +37,9 @@ public void onSubscribe(Flow.Subscription subscription) { } @Override - public void onNext(ByteBuffer item) { + public void onNext(Slice item) { try { - this.channel.write(item); + this.channel.write(item.asReadOnlyByteBuffer()); } catch (IOException e) { this.completionFuture.completeExceptionally(e); } @@ -66,24 +65,24 @@ public byte[] getResult() throws Throwable { } } - static class BufferedPublisher implements InvocationFlow.InvocationInputPublisher { + static class BufferedPublisher implements Flow.Publisher { - private ByteBuffer buffer; + private Slice slice; - BufferedPublisher(ByteBuffer buffer) { - this.buffer = buffer.asReadOnlyBuffer(); + BufferedPublisher(Slice slice) { + this.slice = slice; } @Override - public void subscribe(Flow.Subscriber subscriber) { + public void subscribe(Flow.Subscriber subscriber) { subscriber.onSubscribe( new Flow.Subscription() { @Override public void request(long l) { - if (buffer != null) { - subscriber.onNext(buffer); + if (slice != null) { + subscriber.onNext(slice); subscriber.onComplete(); - buffer = null; + slice = null; } } diff --git a/sdk-lambda/src/main/java/dev/restate/sdk/lambda/RestateLambdaEndpoint.java b/sdk-lambda/src/main/java/dev/restate/sdk/lambda/RestateLambdaEndpoint.java deleted file mode 100644 index 12c6ef2b6..000000000 --- a/sdk-lambda/src/main/java/dev/restate/sdk/lambda/RestateLambdaEndpoint.java +++ /dev/null @@ -1,199 +0,0 @@ -// Copyright (c) 2023 - Restate Software, Inc., Restate GmbH -// -// This file is part of the Restate Java SDK, -// which is released under the MIT license. -// -// You can find a copy of the license in file LICENSE in the root -// directory of this repository or package, or at -// https://github.com/restatedev/sdk-java/blob/main/LICENSE -package dev.restate.sdk.lambda; - -import static dev.restate.sdk.lambda.LambdaFlowAdapters.*; - -import com.amazonaws.services.lambda.runtime.Context; -import com.amazonaws.services.lambda.runtime.events.APIGatewayProxyRequestEvent; -import com.amazonaws.services.lambda.runtime.events.APIGatewayProxyResponseEvent; -import dev.restate.sdk.core.ProtocolException; -import dev.restate.sdk.core.ResolvedEndpointHandler; -import dev.restate.sdk.core.RestateEndpoint; -import dev.restate.sdk.version.Version; -import io.opentelemetry.api.OpenTelemetry; -import io.opentelemetry.context.propagation.TextMapGetter; -import java.nio.ByteBuffer; -import java.util.*; -import java.util.regex.Pattern; -import org.apache.logging.log4j.LogManager; -import org.apache.logging.log4j.Logger; -import org.apache.logging.log4j.ThreadContext; - -/** Restate Lambda Endpoint. */ -public final class RestateLambdaEndpoint { - - private static final Logger LOG = LogManager.getLogger(RestateLambdaEndpoint.class); - - private static final Pattern SLASH = Pattern.compile(Pattern.quote("/")); - private static final String INVOKE_PATH_SEGMENT = "invoke"; - private static final String DISCOVER_PATH = "/discover"; - - private static final TextMapGetter> OTEL_HEADERS_GETTER = - new TextMapGetter<>() { - @Override - public Iterable keys(Map carrier) { - return carrier.keySet(); - } - - @Override - public String get(Map carrier, String key) { - if (carrier == null) { - return null; - } - return carrier.get(key); - } - }; - - private final RestateEndpoint restateEndpoint; - private final OpenTelemetry openTelemetry; - - RestateLambdaEndpoint(RestateEndpoint restateEndpoint, OpenTelemetry openTelemetry) { - this.restateEndpoint = restateEndpoint; - this.openTelemetry = openTelemetry; - } - - /** Create a new builder. */ - public static RestateLambdaEndpointBuilder builder() { - return new RestateLambdaEndpointBuilder(); - } - - /** Handle a Lambda request as Restate Lambda endpoint. */ - public APIGatewayProxyResponseEvent handleRequest( - APIGatewayProxyRequestEvent input, Context context) { - // Remove trailing path separator - String path = - input.getPath().endsWith("/") - ? input.getPath().substring(0, input.getPath().length() - 1) - : input.getPath(); - - try { - if (path.endsWith(DISCOVER_PATH)) { - return this.handleDiscovery(input.getHeaders().get("accept")); - } - return this.handleInvoke(input); - } catch (ProtocolException e) { - // We can handle protocol exceptions by returning back the correct response - LOG.warn("Error when handling the request", e); - return new APIGatewayProxyResponseEvent() - .withStatusCode(e.getCode()) - .withHeaders( - Map.of("content-type", "text/plain", "x-restate-server", Version.X_RESTATE_SERVER)) - .withBody(e.getMessage()); - } - } - - // --- Invoke request - - private APIGatewayProxyResponseEvent handleInvoke(APIGatewayProxyRequestEvent input) { - // Parse request - String[] pathSegments = SLASH.split(input.getPath()); - if (pathSegments.length < 3 - || !INVOKE_PATH_SEGMENT.equalsIgnoreCase(pathSegments[pathSegments.length - 3])) { - LOG.warn("Path doesn't match the pattern /invoke/SvcName/MethodName: '{}'", input.getPath()); - return new APIGatewayProxyResponseEvent().withStatusCode(404); - } - String serviceName = pathSegments[pathSegments.length - 2]; - String handlerName = pathSegments[pathSegments.length - 1]; - - // Parse OTEL context and generate span - final io.opentelemetry.context.Context otelContext = - openTelemetry - .getPropagators() - .getTextMapPropagator() - .extract( - io.opentelemetry.context.Context.current(), - input.getHeaders(), - OTEL_HEADERS_GETTER); - - // Parse request body - final ByteBuffer requestBody = parseInputBody(input); - - // Resolve handler - ResolvedEndpointHandler handler; - try { - handler = - this.restateEndpoint.resolve( - input.getHeaders().get("content-type"), - serviceName, - handlerName, - input.getHeaders()::get, - otelContext, - RestateEndpoint.LoggingContextSetter.THREAD_LOCAL_INSTANCE, - null); - } catch (ProtocolException e) { - LOG.warn("Error when resolving the grpc handler", e); - return new APIGatewayProxyResponseEvent().withStatusCode(e.getCode()); - } - - BufferedPublisher publisher = new BufferedPublisher(requestBody); - ResultSubscriber subscriber = new ResultSubscriber(); - - // Wire handler - publisher.subscribe(handler); - handler.subscribe(subscriber); - - // Await the result - byte[] responseBody; - try { - responseBody = subscriber.getResult(); - } catch (Error | RuntimeException e) { - throw e; - } catch (Throwable e) { - throw new RuntimeException(e); - } - - // Clear logging - ThreadContext.clearAll(); - - final APIGatewayProxyResponseEvent response = new APIGatewayProxyResponseEvent(); - response.setHeaders( - Map.of( - "content-type", - handler.responseContentType(), - "x-restate-server", - Version.X_RESTATE_SERVER)); - response.setIsBase64Encoded(true); - response.setStatusCode(200); - response.setBody(Base64.getEncoder().encodeToString(responseBody)); - return response; - } - - // --- Service discovery - - private APIGatewayProxyResponseEvent handleDiscovery(String acceptVersionsString) { - RestateEndpoint.DiscoveryResponse discoveryResponse = - this.restateEndpoint.handleDiscoveryRequest(acceptVersionsString); - - final APIGatewayProxyResponseEvent response = new APIGatewayProxyResponseEvent(); - response.setHeaders( - Map.of( - "content-type", - discoveryResponse.getContentType(), - "x-restate-server", - Version.X_RESTATE_SERVER)); - response.setIsBase64Encoded(true); - response.setStatusCode(200); - response.setBody(Base64.getEncoder().encodeToString(discoveryResponse.getSerializedManifest())); - return response; - } - - // --- Utils - - private static ByteBuffer parseInputBody(APIGatewayProxyRequestEvent input) { - if (input.getBody() == null) { - return ByteBuffer.wrap(new byte[] {}); - } - if (!input.getIsBase64Encoded()) { - throw new IllegalArgumentException( - "Input is not Base64 encoded. This is most likely an SDK bug, please contact the developers."); - } - return ByteBuffer.wrap(Base64.getDecoder().decode(input.getBody())); - } -} diff --git a/sdk-lambda/src/main/java/dev/restate/sdk/lambda/RestateLambdaEndpointBuilder.java b/sdk-lambda/src/main/java/dev/restate/sdk/lambda/RestateLambdaEndpointBuilder.java deleted file mode 100644 index f1e61a795..000000000 --- a/sdk-lambda/src/main/java/dev/restate/sdk/lambda/RestateLambdaEndpointBuilder.java +++ /dev/null @@ -1,83 +0,0 @@ -// Copyright (c) 2023 - Restate Software, Inc., Restate GmbH -// -// This file is part of the Restate Java SDK, -// which is released under the MIT license. -// -// You can find a copy of the license in file LICENSE in the root -// directory of this repository or package, or at -// https://github.com/restatedev/sdk-java/blob/main/LICENSE -package dev.restate.sdk.lambda; - -import dev.restate.sdk.auth.RequestIdentityVerifier; -import dev.restate.sdk.common.syscalls.ServiceDefinition; -import dev.restate.sdk.core.RestateEndpoint; -import dev.restate.sdk.core.manifest.EndpointManifestSchema; -import io.opentelemetry.api.OpenTelemetry; - -/** Endpoint builder for a Restate AWS Lambda Endpoint, to serve Restate service. */ -public final class RestateLambdaEndpointBuilder { - - private final RestateEndpoint.Builder restateEndpoint = - RestateEndpoint.newBuilder(EndpointManifestSchema.ProtocolMode.REQUEST_RESPONSE); - private OpenTelemetry openTelemetry = OpenTelemetry.noop(); - - /** - * Add a Restate service to the endpoint. This will automatically discover the generated factory - * based on the class name. - * - *

You can also manually instantiate the {@link ServiceDefinition} using {@link - * #bind(ServiceDefinition)}. - */ - public RestateLambdaEndpointBuilder bind(Object service) { - return this.bind(RestateEndpoint.discoverServiceDefinitionFactory(service).create(service)); - } - - /** - * Add a Restate service to the endpoint. - * - *

To set the options, use {@link #bind(ServiceDefinition, Object)}. - */ - public RestateLambdaEndpointBuilder bind(ServiceDefinition service) { - //noinspection unchecked - this.restateEndpoint.bind((ServiceDefinition) service, null); - return this; - } - - /** Add a Restate service to the endpoint, setting the options. */ - public RestateLambdaEndpointBuilder bind(ServiceDefinition serviceDefinition, O options) { - this.restateEndpoint.bind(serviceDefinition, options); - return this; - } - - /** - * Add a {@link OpenTelemetry} implementation for tracing and metrics. - * - * @see OpenTelemetry - */ - public RestateLambdaEndpointBuilder withOpenTelemetry(OpenTelemetry openTelemetry) { - this.openTelemetry = openTelemetry; - return this; - } - - /** - * Set the request identity verifier for this endpoint. - * - *

For the Restate implementation to use with Restate Cloud, check the module {@code - * sdk-request-identity}. - */ - public RestateLambdaEndpointBuilder withRequestIdentityVerifier( - RequestIdentityVerifier requestIdentityVerifier) { - this.restateEndpoint.withRequestIdentityVerifier(requestIdentityVerifier); - return this; - } - - public RestateLambdaEndpointBuilder enablePreviewContext() { - this.restateEndpoint.enablePreviewContext(); - return this; - } - - /** Build the {@link RestateLambdaEndpoint} serving the Restate service endpoint. */ - public RestateLambdaEndpoint build() { - return new RestateLambdaEndpoint(this.restateEndpoint.build(), this.openTelemetry); - } -} diff --git a/sdk-lambda/src/test/kotlin/dev/restate/sdk/lambda/testservices/KotlinCounterService.kt b/sdk-lambda/src/test/kotlin/dev/restate/sdk/lambda/testservices/KotlinCounterService.kt deleted file mode 100644 index 1ece43b88..000000000 --- a/sdk-lambda/src/test/kotlin/dev/restate/sdk/lambda/testservices/KotlinCounterService.kt +++ /dev/null @@ -1,38 +0,0 @@ -// Copyright (c) 2023 - Restate Software, Inc., Restate GmbH -// -// This file is part of the Restate Java SDK, -// which is released under the MIT license. -// -// You can find a copy of the license in file LICENSE in the root -// directory of this repository or package, or at -// https://github.com/restatedev/sdk-java/blob/main/LICENSE -package dev.restate.sdk.lambda.testservices - -import dev.restate.sdk.common.HandlerType -import dev.restate.sdk.common.Serde -import dev.restate.sdk.common.ServiceType -import dev.restate.sdk.common.StateKey -import dev.restate.sdk.common.syscalls.HandlerDefinition -import dev.restate.sdk.common.syscalls.HandlerSpecification -import dev.restate.sdk.common.syscalls.ServiceDefinition -import dev.restate.sdk.kotlin.HandlerRunner -import dev.restate.sdk.kotlin.KtSerdes -import dev.restate.sdk.kotlin.ObjectContext -import java.nio.charset.StandardCharsets - -private val COUNTER: StateKey = - StateKey.of( - "counter", - Serde.using( - { l: Long -> l.toString().toByteArray(StandardCharsets.UTF_8) }, - { v: ByteArray? -> String(v!!, StandardCharsets.UTF_8).toLong() })) - -fun counter(): ServiceDefinition<*> = - ServiceDefinition.of( - "KtCounter", - ServiceType.VIRTUAL_OBJECT, - listOf( - HandlerDefinition.of( - HandlerSpecification.of( - "get", HandlerType.EXCLUSIVE, KtSerdes.UNIT, KtSerdes.json()), - HandlerRunner.of { ctx: ObjectContext, _: Unit -> ctx.get(COUNTER) ?: -1 }))) diff --git a/sdk-lambda/src/test/resources/log4j2.properties b/sdk-lambda/src/test/resources/log4j2.properties deleted file mode 100644 index 5933d7fa8..000000000 --- a/sdk-lambda/src/test/resources/log4j2.properties +++ /dev/null @@ -1,8 +0,0 @@ -rootLogger.level = DEBUG -rootLogger.appenderRef.testlogger.ref = TestLogger - -appender.testlogger.name = TestLogger -appender.testlogger.type = CONSOLE -appender.testlogger.target = SYSTEM_ERR -appender.testlogger.layout.type = PatternLayout -appender.testlogger.layout.pattern = %-4r [%t] %-5p %c %x - %m%n \ No newline at end of file diff --git a/sdk-request-identity/src/main/java/dev/restate/sdk/auth/signing/RestateRequestIdentityVerifier.java b/sdk-request-identity/src/main/java/dev/restate/sdk/auth/signing/RestateRequestIdentityVerifier.java index d7d7ea08f..191795e80 100644 --- a/sdk-request-identity/src/main/java/dev/restate/sdk/auth/signing/RestateRequestIdentityVerifier.java +++ b/sdk-request-identity/src/main/java/dev/restate/sdk/auth/signing/RestateRequestIdentityVerifier.java @@ -15,7 +15,8 @@ import com.nimbusds.jose.jwk.OctetKeyPair; import com.nimbusds.jose.util.Base64URL; import com.nimbusds.jwt.SignedJWT; -import dev.restate.sdk.auth.RequestIdentityVerifier; +import dev.restate.sdk.endpoint.HeadersAccessor; +import dev.restate.sdk.endpoint.RequestIdentityVerifier; import java.util.Arrays; import java.util.List; import java.util.stream.Collectors; @@ -34,7 +35,7 @@ private RestateRequestIdentityVerifier(List verifier) { } @Override - public void verifyRequest(Headers headers) throws Exception { + public void verifyRequest(HeadersAccessor headers) throws Exception { String signatureScheme = expectHeader(headers, SIGNATURE_SCHEME_HEADER); switch (signatureScheme) { case SIGNATURE_SCHEME_V1: @@ -53,7 +54,7 @@ public void verifyRequest(Headers headers) throws Exception { } } - private String expectHeader(Headers headers, String key) { + private String expectHeader(HeadersAccessor headers, String key) { String value = headers.get(key); if (value == null) { throw new IllegalArgumentException("Missing header " + key); diff --git a/sdk-serde-jackson/src/main/java/dev/restate/sdk/serde/jackson/JacksonSerdeFactory.java b/sdk-serde-jackson/src/main/java/dev/restate/sdk/serde/jackson/JacksonSerdeFactory.java new file mode 100644 index 000000000..480ff5a42 --- /dev/null +++ b/sdk-serde-jackson/src/main/java/dev/restate/sdk/serde/jackson/JacksonSerdeFactory.java @@ -0,0 +1,93 @@ +// Copyright (c) 2023 - Restate Software, Inc., Restate GmbH +// +// This file is part of the Restate Java SDK, +// which is released under the MIT license. +// +// You can find a copy of the license in file LICENSE in the root +// directory of this repository or package, or at +// https://github.com/restatedev/sdk-java/blob/main/LICENSE +package dev.restate.sdk.serde.jackson; + +import static dev.restate.sdk.serde.jackson.JacksonSerdes.sneakyThrow; + +import com.fasterxml.jackson.core.JsonProcessingException; +import com.fasterxml.jackson.databind.JavaType; +import com.fasterxml.jackson.databind.ObjectMapper; +import com.github.victools.jsonschema.generator.SchemaGenerator; +import dev.restate.common.Slice; +import dev.restate.serde.Serde; +import dev.restate.serde.SerdeFactory; +import dev.restate.serde.TypeRef; +import java.io.IOException; +import java.lang.reflect.Type; +import org.jspecify.annotations.NonNull; + +public class JacksonSerdeFactory implements SerdeFactory { + + public static final JacksonSerdeFactory DEFAULT = new JacksonSerdeFactory(); + + private final ObjectMapper mapper; + private final SchemaGenerator schemaGenerator; + + public JacksonSerdeFactory() { + this(JacksonSerdes.defaultMapper); + } + + public JacksonSerdeFactory(ObjectMapper mapper) { + this(mapper, JacksonSerdes.schemaGenerator); + } + + public JacksonSerdeFactory(ObjectMapper mapper, SchemaGenerator schemaGenerator) { + this.mapper = mapper; + this.schemaGenerator = schemaGenerator; + } + + @Override + public Serde create(TypeRef typeRef) { + return create( + mapper.constructType(typeRef.getType()), typeRef.getType(), schemaGenerator, mapper); + } + + @Override + public Serde create(Class clazz) { + return create(mapper.constructType(clazz), clazz, schemaGenerator, mapper); + } + + static Serde create( + JavaType constructedType, + Type originalType, + SchemaGenerator schemaGenerator, + ObjectMapper mapper) { + return new Serde<>() { + @Override + public Schema jsonSchema() { + return new Serde.JsonSchema(schemaGenerator.generateSchema(originalType)); + } + + @Override + public Slice serialize(T value) { + try { + return Slice.wrap(mapper.writeValueAsBytes(value)); + } catch (JsonProcessingException e) { + sneakyThrow(e); + return null; + } + } + + @Override + public T deserialize(@NonNull Slice value) { + try { + return mapper.readValue(value.toByteArray(), constructedType); + } catch (IOException e) { + sneakyThrow(e); + return null; + } + } + + @Override + public String contentType() { + return "application/json"; + } + }; + } +} diff --git a/sdk-serde-jackson/src/main/java/dev/restate/sdk/serde/jackson/JacksonSerdes.java b/sdk-serde-jackson/src/main/java/dev/restate/sdk/serde/jackson/JacksonSerdes.java index 4e48eca1f..36daa91aa 100644 --- a/sdk-serde-jackson/src/main/java/dev/restate/sdk/serde/jackson/JacksonSerdes.java +++ b/sdk-serde-jackson/src/main/java/dev/restate/sdk/serde/jackson/JacksonSerdes.java @@ -8,18 +8,14 @@ // https://github.com/restatedev/sdk-java/blob/main/LICENSE package dev.restate.sdk.serde.jackson; -import com.fasterxml.jackson.core.JsonProcessingException; import com.fasterxml.jackson.core.type.TypeReference; import com.fasterxml.jackson.databind.JsonNode; import com.fasterxml.jackson.databind.ObjectMapper; import com.github.victools.jsonschema.generator.*; import com.github.victools.jsonschema.module.jackson.JacksonModule; import com.github.victools.jsonschema.module.jackson.JacksonOption; -import dev.restate.sdk.common.RichSerde; -import dev.restate.sdk.common.Serde; -import java.io.IOException; +import dev.restate.serde.Serde; import java.util.stream.StreamSupport; -import org.jspecify.annotations.Nullable; /** * {@link Serde} implementations for Jackson. @@ -46,8 +42,8 @@ public final class JacksonSerdes { private JacksonSerdes() {} - private static final ObjectMapper defaultMapper; - private static final SchemaGenerator schemaGenerator; + static final ObjectMapper defaultMapper; + static final SchemaGenerator schemaGenerator; static { defaultMapper = new ObjectMapper(); @@ -101,37 +97,7 @@ public static Serde of(Class clazz) { /** Serialize/Deserialize class using the provided object mapper. */ public static Serde of(ObjectMapper mapper, Class clazz) { - return new RichSerde<>() { - @Override - public @Nullable Object jsonSchema() { - return schemaGenerator.generateSchema(clazz); - } - - @Override - public byte[] serialize(T value) { - try { - return mapper.writeValueAsBytes(value); - } catch (JsonProcessingException e) { - sneakyThrow(e); - return null; - } - } - - @Override - public T deserialize(byte[] value) { - try { - return mapper.readValue(value, clazz); - } catch (IOException e) { - sneakyThrow(e); - return null; - } - } - - @Override - public String contentType() { - return "application/json"; - } - }; + return JacksonSerdeFactory.create(mapper.constructType(clazz), clazz, schemaGenerator, mapper); } /** Serialize/Deserialize {@link TypeReference} using the default object mapper. */ @@ -141,41 +107,12 @@ public static Serde of(TypeReference typeReference) { /** Serialize/Deserialize {@link TypeReference} using the default object mapper. */ public static Serde of(ObjectMapper mapper, TypeReference typeReference) { - return new RichSerde<>() { - @Override - public @Nullable Object jsonSchema() { - return schemaGenerator.generateSchema(typeReference.getType()); - } - - @Override - public byte[] serialize(T value) { - try { - return mapper.writeValueAsBytes(value); - } catch (JsonProcessingException e) { - sneakyThrow(e); - return null; - } - } - - @Override - public T deserialize(byte[] value) { - try { - return mapper.readValue(value, typeReference); - } catch (IOException e) { - sneakyThrow(e); - return null; - } - } - - @Override - public String contentType() { - return "application/json"; - } - }; + return JacksonSerdeFactory.create( + mapper.constructType(typeReference), typeReference.getType(), schemaGenerator, mapper); } @SuppressWarnings("unchecked") - private static void sneakyThrow(Object exception) throws E { + static void sneakyThrow(Object exception) throws E { throw (E) exception; } } diff --git a/sdk-serde-jackson/src/test/java/dev/restate/sdk/serde/jackson/JacksonSerdesTest.java b/sdk-serde-jackson/src/test/java/dev/restate/sdk/serde/jackson/JacksonSerdesTest.java index 8d422ab7f..0a07b06f0 100644 --- a/sdk-serde-jackson/src/test/java/dev/restate/sdk/serde/jackson/JacksonSerdesTest.java +++ b/sdk-serde-jackson/src/test/java/dev/restate/sdk/serde/jackson/JacksonSerdesTest.java @@ -13,7 +13,7 @@ import com.fasterxml.jackson.annotation.JsonCreator; import com.fasterxml.jackson.annotation.JsonProperty; import com.fasterxml.jackson.core.type.TypeReference; -import dev.restate.sdk.common.Serde; +import dev.restate.serde.Serde; import java.util.List; import java.util.Objects; import java.util.Set; diff --git a/sdk-serde-protobuf/src/main/java/dev/restate/sdk/serde/protobuf/ProtobufSerdes.java b/sdk-serde-protobuf/src/main/java/dev/restate/sdk/serde/protobuf/ProtobufSerdes.java deleted file mode 100644 index 65189e774..000000000 --- a/sdk-serde-protobuf/src/main/java/dev/restate/sdk/serde/protobuf/ProtobufSerdes.java +++ /dev/null @@ -1,55 +0,0 @@ -// Copyright (c) 2023 - Restate Software, Inc., Restate GmbH -// -// This file is part of the Restate Java SDK, -// which is released under the MIT license. -// -// You can find a copy of the license in file LICENSE in the root -// directory of this repository or package, or at -// https://github.com/restatedev/sdk-java/blob/main/LICENSE -package dev.restate.sdk.serde.protobuf; - -import com.google.protobuf.*; -import dev.restate.sdk.common.Serde; -import java.nio.ByteBuffer; -import java.util.Objects; -import org.jspecify.annotations.Nullable; - -/** Collection of serializers/deserializers for Protobuf */ -public abstract class ProtobufSerdes { - - private ProtobufSerdes() {} - - public static Serde of(Parser parser) { - return new Serde<>() { - @Override - public byte[] serialize(@Nullable T value) { - return Objects.requireNonNull(value).toByteArray(); - } - - @Override - public T deserialize(byte[] value) { - try { - return parser.parseFrom(value); - } catch (InvalidProtocolBufferException e) { - throw new RuntimeException("Cannot deserialize Protobuf object", e); - } - } - - // -- We reimplement the ByteBuffer variants here as it might be more efficient to use them. - - @Override - public ByteBuffer serializeToByteBuffer(@Nullable T value) { - return Objects.requireNonNull(value).toByteString().asReadOnlyByteBuffer(); - } - - @Override - public T deserialize(ByteBuffer byteBuffer) { - try { - return parser.parseFrom(UnsafeByteOperations.unsafeWrap(byteBuffer.rewind())); - } catch (InvalidProtocolBufferException e) { - throw new RuntimeException("Cannot deserialize Protobuf object", e); - } - } - }; - } -} diff --git a/sdk-spring-boot-kotlin-starter/build.gradle.kts b/sdk-spring-boot-kotlin-starter/build.gradle.kts index e34a2265b..7962733d8 100644 --- a/sdk-spring-boot-kotlin-starter/build.gradle.kts +++ b/sdk-spring-boot-kotlin-starter/build.gradle.kts @@ -11,6 +11,7 @@ dependencies { compileOnly(libs.jspecify) api(project(":sdk-api-kotlin")) + api(project(":client-kotlin")) api(project(":sdk-spring-boot")) kspTest(project(":sdk-api-kotlin-gen")) diff --git a/sdk-spring-boot-kotlin-starter/src/test/kotlin/dev/restate/sdk/springboot/kotlin/RestateHttpEndpointBeanTest.kt b/sdk-spring-boot-kotlin-starter/src/test/kotlin/dev/restate/sdk/springboot/kotlin/RestateHttpEndpointBeanTest.kt index a29ce0ded..5bb7c6387 100644 --- a/sdk-spring-boot-kotlin-starter/src/test/kotlin/dev/restate/sdk/springboot/kotlin/RestateHttpEndpointBeanTest.kt +++ b/sdk-spring-boot-kotlin-starter/src/test/kotlin/dev/restate/sdk/springboot/kotlin/RestateHttpEndpointBeanTest.kt @@ -9,8 +9,8 @@ package dev.restate.sdk.springboot.kotlin import com.fasterxml.jackson.databind.ObjectMapper -import dev.restate.sdk.core.manifest.EndpointManifestSchema -import dev.restate.sdk.core.manifest.Service +import dev.restate.sdk.core.generated.manifest.EndpointManifestSchema +import dev.restate.sdk.core.generated.manifest.Service import dev.restate.sdk.springboot.RestateHttpEndpointBean import java.io.IOException import java.net.URI diff --git a/sdk-spring-boot-kotlin-starter/src/test/kotlin/dev/restate/sdk/springboot/kotlin/SdkTestingIntegrationTest.kt b/sdk-spring-boot-kotlin-starter/src/test/kotlin/dev/restate/sdk/springboot/kotlin/SdkTestingIntegrationTest.kt index 6dc2ed51c..09f5d72f5 100644 --- a/sdk-spring-boot-kotlin-starter/src/test/kotlin/dev/restate/sdk/springboot/kotlin/SdkTestingIntegrationTest.kt +++ b/sdk-spring-boot-kotlin-starter/src/test/kotlin/dev/restate/sdk/springboot/kotlin/SdkTestingIntegrationTest.kt @@ -8,7 +8,7 @@ // https://github.com/restatedev/sdk-java/blob/main/LICENSE package dev.restate.sdk.springboot.kotlin -import dev.restate.sdk.client.Client +import dev.restate.client.Client import dev.restate.sdk.testing.BindService import dev.restate.sdk.testing.RestateClient import dev.restate.sdk.testing.RestateTest diff --git a/sdk-spring-boot-starter/build.gradle.kts b/sdk-spring-boot-starter/build.gradle.kts index 8fa1d4c3d..4b421287d 100644 --- a/sdk-spring-boot-starter/build.gradle.kts +++ b/sdk-spring-boot-starter/build.gradle.kts @@ -1,7 +1,6 @@ plugins { `java-conventions` `java-library` - `test-jar-conventions` `library-publishing-conventions` alias(libs.plugins.spring.dependency.management) } @@ -18,6 +17,12 @@ dependencies { exclude(group = "com.fasterxml.jackson.core") exclude(group = "com.fasterxml.jackson.datatype") } + api(project(":client")) { + // Let spring bring jackson in + exclude(group = "com.fasterxml.jackson") + exclude(group = "com.fasterxml.jackson.core") + exclude(group = "com.fasterxml.jackson.datatype") + } api(project(":sdk-serde-jackson")) { // Let spring bring jackson in exclude(group = "com.fasterxml.jackson") diff --git a/sdk-spring-boot-starter/src/test/java/dev/restate/sdk/springboot/java/RestateHttpEndpointBeanTest.java b/sdk-spring-boot-starter/src/test/java/dev/restate/sdk/springboot/java/RestateHttpEndpointBeanTest.java index 1e1d7d6f7..0de42bf06 100644 --- a/sdk-spring-boot-starter/src/test/java/dev/restate/sdk/springboot/java/RestateHttpEndpointBeanTest.java +++ b/sdk-spring-boot-starter/src/test/java/dev/restate/sdk/springboot/java/RestateHttpEndpointBeanTest.java @@ -11,7 +11,7 @@ import static org.assertj.core.api.Assertions.assertThat; import com.fasterxml.jackson.databind.ObjectMapper; -import dev.restate.sdk.core.manifest.EndpointManifestSchema; +import dev.restate.sdk.core.generated.manifest.EndpointManifestSchema; import dev.restate.sdk.springboot.RestateHttpEndpointBean; import java.io.IOException; import java.net.URI; @@ -52,7 +52,7 @@ public void httpEndpointShouldBeRunning() throws IOException, InterruptedExcepti new ObjectMapper().readValue(response.body(), EndpointManifestSchema.class); assertThat(endpointManifest.getServices()) - .map(dev.restate.sdk.core.manifest.Service::getName) + .map(dev.restate.sdk.core.generated.manifest.Service::getName) .containsOnly("greeter"); } } diff --git a/sdk-spring-boot-starter/src/test/java/dev/restate/sdk/springboot/java/SdkTestingIntegrationTest.java b/sdk-spring-boot-starter/src/test/java/dev/restate/sdk/springboot/java/SdkTestingIntegrationTest.java index a1f1537c7..34782faea 100644 --- a/sdk-spring-boot-starter/src/test/java/dev/restate/sdk/springboot/java/SdkTestingIntegrationTest.java +++ b/sdk-spring-boot-starter/src/test/java/dev/restate/sdk/springboot/java/SdkTestingIntegrationTest.java @@ -10,7 +10,7 @@ import static org.assertj.core.api.Assertions.assertThat; -import dev.restate.sdk.client.Client; +import dev.restate.client.Client; import dev.restate.sdk.testing.*; import org.junit.jupiter.api.Test; import org.junit.jupiter.api.Timeout; diff --git a/sdk-spring-boot/build.gradle.kts b/sdk-spring-boot/build.gradle.kts index 779093a66..b6ba11c49 100644 --- a/sdk-spring-boot/build.gradle.kts +++ b/sdk-spring-boot/build.gradle.kts @@ -1,7 +1,6 @@ plugins { `java-conventions` `java-library` - `test-jar-conventions` `library-publishing-conventions` alias(libs.plugins.spring.dependency.management) } @@ -18,6 +17,13 @@ dependencies { exclude(group = "com.fasterxml.jackson.datatype") } + api(project(":client")) { + // Let spring bring jackson in + exclude(group = "com.fasterxml.jackson") + exclude(group = "com.fasterxml.jackson.core") + exclude(group = "com.fasterxml.jackson.datatype") + } + implementation(project(":sdk-http-vertx")) { // Let spring bring jackson in exclude(group = "com.fasterxml.jackson") diff --git a/sdk-spring-boot/src/main/java/dev/restate/sdk/springboot/RestateClientAutoConfiguration.java b/sdk-spring-boot/src/main/java/dev/restate/sdk/springboot/RestateClientAutoConfiguration.java index 93a5b3100..ce982a81d 100644 --- a/sdk-spring-boot/src/main/java/dev/restate/sdk/springboot/RestateClientAutoConfiguration.java +++ b/sdk-spring-boot/src/main/java/dev/restate/sdk/springboot/RestateClientAutoConfiguration.java @@ -8,7 +8,8 @@ // https://github.com/restatedev/sdk-java/blob/main/LICENSE package dev.restate.sdk.springboot; -import dev.restate.sdk.client.Client; +import dev.restate.client.Client; +import dev.restate.client.ClientRequestOptions; import java.util.Collections; import java.util.Map; import org.springframework.boot.context.properties.EnableConfigurationProperties; @@ -30,6 +31,7 @@ public Client client(RestateClientProperties restateClientProperties) { if (headers == null) { headers = Collections.emptyMap(); } - return Client.connect(restateClientProperties.getBaseUri(), headers); + return Client.connect( + restateClientProperties.getBaseUri(), ClientRequestOptions.withHeaders(headers).build()); } } diff --git a/sdk-spring-boot/src/main/java/dev/restate/sdk/springboot/RestateComponent.java b/sdk-spring-boot/src/main/java/dev/restate/sdk/springboot/RestateComponent.java index 2b80e0a52..4185ff0bb 100644 --- a/sdk-spring-boot/src/main/java/dev/restate/sdk/springboot/RestateComponent.java +++ b/sdk-spring-boot/src/main/java/dev/restate/sdk/springboot/RestateComponent.java @@ -17,7 +17,7 @@ * dev.restate.sdk.annotation.Workflow} to bind them to the Restate HTTP Endpoint. * *

You can configure the Restate HTTP Endpoint using {@link RestateEndpointProperties} and {@link - * RestateEndpointHttpServerProperties}. + * RestateHttpServerProperties}. */ @Target(ElementType.TYPE) @Retention(RetentionPolicy.RUNTIME) diff --git a/sdk-spring-boot/src/main/java/dev/restate/sdk/springboot/RestateEndpointProperties.java b/sdk-spring-boot/src/main/java/dev/restate/sdk/springboot/RestateEndpointProperties.java index 30b2b0bed..d7f5e1c78 100644 --- a/sdk-spring-boot/src/main/java/dev/restate/sdk/springboot/RestateEndpointProperties.java +++ b/sdk-spring-boot/src/main/java/dev/restate/sdk/springboot/RestateEndpointProperties.java @@ -8,8 +8,8 @@ // https://github.com/restatedev/sdk-java/blob/main/LICENSE package dev.restate.sdk.springboot; -import dev.restate.sdk.auth.RequestIdentityVerifier; -import dev.restate.sdk.http.vertx.RestateHttpEndpointBuilder; +import dev.restate.sdk.endpoint.Endpoint; +import dev.restate.sdk.endpoint.RequestIdentityVerifier; import org.springframework.boot.context.properties.ConfigurationProperties; import org.springframework.boot.context.properties.bind.ConstructorBinding; import org.springframework.boot.context.properties.bind.DefaultValue; @@ -28,14 +28,14 @@ public RestateEndpointProperties( } /** - * @see RestateHttpEndpointBuilder#enablePreviewContext() + * @see Endpoint.Builder#enablePreviewContext() */ public boolean isEnablePreviewContext() { return enablePreviewContext; } /** - * @see RestateHttpEndpointBuilder#withRequestIdentityVerifier(RequestIdentityVerifier) + * @see Endpoint.Builder#withRequestIdentityVerifier(RequestIdentityVerifier) */ public String getIdentityKey() { return identityKey; diff --git a/sdk-spring-boot/src/main/java/dev/restate/sdk/springboot/RestateHttpEndpointBean.java b/sdk-spring-boot/src/main/java/dev/restate/sdk/springboot/RestateHttpEndpointBean.java index e10de70ac..6757f9990 100644 --- a/sdk-spring-boot/src/main/java/dev/restate/sdk/springboot/RestateHttpEndpointBean.java +++ b/sdk-spring-boot/src/main/java/dev/restate/sdk/springboot/RestateHttpEndpointBean.java @@ -9,7 +9,8 @@ package dev.restate.sdk.springboot; import dev.restate.sdk.auth.signing.RestateRequestIdentityVerifier; -import dev.restate.sdk.http.vertx.RestateHttpEndpointBuilder; +import dev.restate.sdk.endpoint.Endpoint; +import dev.restate.sdk.http.vertx.RestateHttpServer; import io.vertx.core.http.HttpServer; import java.util.Map; import org.slf4j.Logger; @@ -26,17 +27,14 @@ * @see Component */ @Component -@EnableConfigurationProperties({ - RestateEndpointHttpServerProperties.class, - RestateEndpointProperties.class -}) +@EnableConfigurationProperties({RestateHttpServerProperties.class, RestateEndpointProperties.class}) public class RestateHttpEndpointBean implements InitializingBean, SmartLifecycle { private final Logger logger = LoggerFactory.getLogger(getClass()); private final ApplicationContext applicationContext; private final RestateEndpointProperties restateEndpointProperties; - private final RestateEndpointHttpServerProperties restateEndpointHttpServerProperties; + private final RestateHttpServerProperties restateHttpServerProperties; private volatile boolean running; @@ -45,10 +43,10 @@ public class RestateHttpEndpointBean implements InitializingBean, SmartLifecycle public RestateHttpEndpointBean( ApplicationContext applicationContext, RestateEndpointProperties restateEndpointProperties, - RestateEndpointHttpServerProperties restateEndpointHttpServerProperties) { + RestateHttpServerProperties restateHttpServerProperties) { this.applicationContext = applicationContext; this.restateEndpointProperties = restateEndpointProperties; - this.restateEndpointHttpServerProperties = restateEndpointHttpServerProperties; + this.restateHttpServerProperties = restateHttpServerProperties; } @Override @@ -62,7 +60,7 @@ public void afterPropertiesSet() { return; } - var builder = RestateHttpEndpointBuilder.builder(); + var builder = Endpoint.builder(); for (Object component : restateComponents.values()) { builder = builder.bind(component); } @@ -76,7 +74,7 @@ public void afterPropertiesSet() { RestateRequestIdentityVerifier.fromKey(restateEndpointProperties.getIdentityKey())); } - this.server = builder.build(); + this.server = RestateHttpServer.fromEndpoint(builder.build()); } @Override @@ -84,7 +82,7 @@ public void start() { if (this.server != null) { try { this.server - .listen(this.restateEndpointHttpServerProperties.getPort()) + .listen(this.restateHttpServerProperties.getPort()) .toCompletionStage() .toCompletableFuture() .get(); @@ -92,7 +90,7 @@ public void start() { } catch (Exception e) { logger.error( "Error when starting Restate Spring HTTP server on port {}", - this.restateEndpointHttpServerProperties.getPort(), + this.restateHttpServerProperties.getPort(), e); } this.running = true; diff --git a/sdk-spring-boot/src/main/java/dev/restate/sdk/springboot/RestateEndpointHttpServerProperties.java b/sdk-spring-boot/src/main/java/dev/restate/sdk/springboot/RestateHttpServerProperties.java similarity index 85% rename from sdk-spring-boot/src/main/java/dev/restate/sdk/springboot/RestateEndpointHttpServerProperties.java rename to sdk-spring-boot/src/main/java/dev/restate/sdk/springboot/RestateHttpServerProperties.java index 5ef97d30c..b9a223d5b 100644 --- a/sdk-spring-boot/src/main/java/dev/restate/sdk/springboot/RestateEndpointHttpServerProperties.java +++ b/sdk-spring-boot/src/main/java/dev/restate/sdk/springboot/RestateHttpServerProperties.java @@ -14,12 +14,12 @@ import org.springframework.boot.context.properties.bind.Name; @ConfigurationProperties(prefix = "restate.sdk.http") -public class RestateEndpointHttpServerProperties { +public class RestateHttpServerProperties { private final int port; @ConstructorBinding - public RestateEndpointHttpServerProperties(@Name("port") @DefaultValue(value = "9080") int port) { + public RestateHttpServerProperties(@Name("port") @DefaultValue(value = "9080") int port) { this.port = port; } diff --git a/sdk-spring-boot/src/test/java/dev/restate/sdk/springboot/RestateClientAutoConfigurationTest.java b/sdk-spring-boot/src/test/java/dev/restate/sdk/springboot/RestateClientAutoConfigurationTest.java index 101b4cbdc..fee246132 100644 --- a/sdk-spring-boot/src/test/java/dev/restate/sdk/springboot/RestateClientAutoConfigurationTest.java +++ b/sdk-spring-boot/src/test/java/dev/restate/sdk/springboot/RestateClientAutoConfigurationTest.java @@ -10,7 +10,7 @@ import static org.assertj.core.api.Assertions.assertThat; -import dev.restate.sdk.client.Client; +import dev.restate.client.Client; import org.junit.jupiter.api.Test; import org.springframework.beans.factory.annotation.Autowired; import org.springframework.boot.test.context.SpringBootTest; diff --git a/sdk-testing/build.gradle.kts b/sdk-testing/build.gradle.kts index d4e1e992a..df84ed11f 100644 --- a/sdk-testing/build.gradle.kts +++ b/sdk-testing/build.gradle.kts @@ -9,6 +9,7 @@ description = "Restate SDK testing tools" dependencies { api(project(":sdk-common")) + api(project(":client")) api(libs.junit.api) api(libs.testcontainers) diff --git a/sdk-testing/src/main/java/dev/restate/sdk/testing/ManualRestateRunner.java b/sdk-testing/src/main/java/dev/restate/sdk/testing/ManualRestateRunner.java deleted file mode 100644 index 60dfe37a7..000000000 --- a/sdk-testing/src/main/java/dev/restate/sdk/testing/ManualRestateRunner.java +++ /dev/null @@ -1,213 +0,0 @@ -// Copyright (c) 2023 - Restate Software, Inc., Restate GmbH -// -// This file is part of the Restate Java SDK, -// which is released under the MIT license. -// -// You can find a copy of the license in file LICENSE in the root -// directory of this repository or package, or at -// https://github.com/restatedev/sdk-java/blob/main/LICENSE -package dev.restate.sdk.testing; - -import dev.restate.admin.api.DeploymentApi; -import dev.restate.admin.client.ApiClient; -import dev.restate.admin.client.ApiException; -import dev.restate.admin.model.RegisterDeploymentRequest; -import dev.restate.admin.model.RegisterDeploymentRequestAnyOf; -import dev.restate.admin.model.RegisterDeploymentResponse; -import io.vertx.core.http.HttpServer; -import java.net.MalformedURLException; -import java.net.URL; -import java.util.Map; -import java.util.concurrent.ExecutionException; -import java.util.stream.Collectors; -import org.apache.logging.log4j.LogManager; -import org.apache.logging.log4j.Logger; -import org.junit.jupiter.api.extension.ExtensionContext; -import org.testcontainers.Testcontainers; -import org.testcontainers.containers.GenericContainer; -import org.testcontainers.containers.wait.strategy.Wait; -import org.testcontainers.containers.wait.strategy.WaitAllStrategy; -import org.testcontainers.images.builder.Transferable; -import org.testcontainers.utility.DockerImageName; - -/** - * Manual runner for the Restate test infra, starting the Restate server container together with the - * provided services and automatically registering them. To start the infra use {@link #run()} and - * to stop it use {@link #stop()}. - * - *

Use {@link RestateRunnerBuilder#buildManualRunner()} to build an instance of this class. - * - *

If you use JUnit 5, we suggest using {@link RestateRunner} instead. - */ -public class ManualRestateRunner - implements AutoCloseable, ExtensionContext.Store.CloseableResource { - - private static final Logger LOG = LogManager.getLogger(ManualRestateRunner.class); - - private static final String RESTATE_RUNTIME = "runtime"; - public static final int RESTATE_INGRESS_ENDPOINT_PORT = 8080; - public static final int RESTATE_ADMIN_ENDPOINT_PORT = 9070; - - private final HttpServer server; - private final GenericContainer runtimeContainer; - - ManualRestateRunner( - HttpServer server, - String runtimeContainerImage, - Map additionalEnv, - String configFile) { - this.server = server; - this.runtimeContainer = new GenericContainer<>(DockerImageName.parse(runtimeContainerImage)); - - // Configure runtimeContainer - this.runtimeContainer - // We expose these ports only to enable port checks - .withExposedPorts(RESTATE_INGRESS_ENDPOINT_PORT, RESTATE_ADMIN_ENDPOINT_PORT) - // Let's have a high logging level by default to avoid spamming too much, it can be - // overriden by the user - .withEnv("RUST_LOG", "warn") - .withEnv(additionalEnv) - // These envs should not be overriden by additionalEnv - .withEnv("RESTATE_META__REST_ADDRESS", "0.0.0.0:" + RESTATE_ADMIN_ENDPOINT_PORT) - .withEnv( - "RESTATE_WORKER__INGRESS__BIND_ADDRESS", "0.0.0.0:" + RESTATE_INGRESS_ENDPOINT_PORT) - .withNetworkAliases(RESTATE_RUNTIME) - // Configure wait strategy on health paths - .waitingFor( - new WaitAllStrategy() - .withStrategy(Wait.forHttp("/health").forPort(RESTATE_ADMIN_ENDPOINT_PORT)) - .withStrategy( - Wait.forHttp("/restate/health").forPort(RESTATE_INGRESS_ENDPOINT_PORT))) - .withLogConsumer( - outputFrame -> { - switch (outputFrame.getType()) { - case STDOUT, STDERR -> - LOG.debug("[restate] {}", outputFrame.getUtf8StringWithoutLineEnding()); - case END -> LOG.debug("[restate] END"); - } - }); - - if (configFile != null) { - this.runtimeContainer.withCopyToContainer(Transferable.of(configFile), "/config.yaml"); - this.runtimeContainer.withEnv("RESTATE_CONFIG", "/config.yaml"); - } - } - - /** - * @deprecated Use {@link #start()} instead. - */ - @Deprecated(forRemoval = true) - public void run() { - this.start(); - } - - /** Run restate, run the embedded service endpoint server, and register the services. */ - public void start() { - // Start listening the local server - try { - server.listen(0).toCompletionStage().toCompletableFuture().get(); - } catch (InterruptedException | ExecutionException e) { - throw new RuntimeException(e); - } - - // Expose the server port - int serviceEndpointPort = server.actualPort(); - LOG.debug("Started embedded service endpoint server on port {}", serviceEndpointPort); - Testcontainers.exposeHostPorts(serviceEndpointPort); - - // Now create the runtime container and deploy it - this.runtimeContainer.start(); - LOG.debug("Started Restate container"); - - // Register services now - ApiClient client = getAdminClient(); - try { - RegisterDeploymentResponse response = - new DeploymentApi(client) - .createDeployment( - new RegisterDeploymentRequest( - new RegisterDeploymentRequestAnyOf() - .uri("http://host.testcontainers.internal:" + serviceEndpointPort))); - LOG.debug( - "Registered services {}", - response.getServices().stream() - .map(dev.restate.admin.model.ServiceMetadata::getName) - .collect(Collectors.toList())); - } catch (ApiException e) { - throw new RuntimeException(e); - } - } - - /** - * Get restate ingress url to send HTTP/gRPC requests to services. - * - * @throws IllegalStateException if the restate container is not running. - */ - public URL getRestateUrl() { - try { - return new URL( - "http", - runtimeContainer.getHost(), - runtimeContainer.getMappedPort(RESTATE_INGRESS_ENDPOINT_PORT), - "/"); - } catch (MalformedURLException e) { - throw new RuntimeException(e); - } - } - - /** - * Get restate admin url to send HTTP requests to the admin API. - * - * @throws IllegalStateException if the restate container is not running. - */ - public URL getAdminUrl() { - try { - return new URL( - "http", - runtimeContainer.getHost(), - runtimeContainer.getMappedPort(RESTATE_ADMIN_ENDPOINT_PORT), - "/"); - } catch (MalformedURLException e) { - throw new RuntimeException(e); - } - } - - /** Get the restate container. */ - public GenericContainer getRestateContainer() { - return this.runtimeContainer; - } - - /** Stop restate and the embedded service endpoint server. */ - public void stop() { - this.close(); - } - - /** Like {@link #stop()}. */ - @Override - public void close() { - runtimeContainer.stop(); - LOG.debug("Stopped Restate container"); - server.close().toCompletionStage().toCompletableFuture().join(); - LOG.debug("Stopped Embedded Service endpoint server"); - } - - // -- Methods used by the JUnit5 extension - - ApiClient getAdminClient() { - return new ApiClient() - .setHost(runtimeContainer.getHost()) - .setPort(runtimeContainer.getMappedPort(RESTATE_ADMIN_ENDPOINT_PORT)); - } - - URL getIngressUrl() { - try { - return new URL( - "http", - runtimeContainer.getHost(), - runtimeContainer.getMappedPort(RESTATE_INGRESS_ENDPOINT_PORT), - "/"); - } catch (MalformedURLException e) { - throw new RuntimeException(e); - } - } -} diff --git a/sdk-testing/src/main/java/dev/restate/sdk/testing/RestateClient.java b/sdk-testing/src/main/java/dev/restate/sdk/testing/RestateClient.java index 806f68954..711be65e6 100644 --- a/sdk-testing/src/main/java/dev/restate/sdk/testing/RestateClient.java +++ b/sdk-testing/src/main/java/dev/restate/sdk/testing/RestateClient.java @@ -8,7 +8,7 @@ // https://github.com/restatedev/sdk-java/blob/main/LICENSE package dev.restate.sdk.testing; -import dev.restate.sdk.client.Client; +import dev.restate.client.Client; import java.lang.annotation.ElementType; import java.lang.annotation.Retention; import java.lang.annotation.RetentionPolicy; diff --git a/sdk-testing/src/main/java/dev/restate/sdk/testing/RestateExtension.java b/sdk-testing/src/main/java/dev/restate/sdk/testing/RestateExtension.java index 13a7a66e4..ef1b333cf 100644 --- a/sdk-testing/src/main/java/dev/restate/sdk/testing/RestateExtension.java +++ b/sdk-testing/src/main/java/dev/restate/sdk/testing/RestateExtension.java @@ -8,6 +8,11 @@ // https://github.com/restatedev/sdk-java/blob/main/LICENSE package dev.restate.sdk.testing; +import dev.restate.client.Client; +import dev.restate.sdk.endpoint.Endpoint; +import java.net.URI; +import java.net.URISyntaxException; +import java.net.URL; import java.util.List; import java.util.regex.Pattern; import org.junit.jupiter.api.extension.*; @@ -24,33 +29,55 @@ public class RestateExtension implements BeforeAllCallback, ParameterResolver { @Override public void beforeAll(ExtensionContext extensionContext) { - extensionContext - .getStore(NAMESPACE) - .getOrComputeIfAbsent( - RUNNER, ignored -> initializeRestateRunner(extensionContext), RestateRunner.class) - .beforeAll(extensionContext); + getOrCreateRunner(extensionContext).start(); } @Override public boolean supportsParameter( ParameterContext parameterContext, ExtensionContext extensionContext) throws ParameterResolutionException { - return extensionContext - .getStore(NAMESPACE) - .getOrComputeIfAbsent( - RUNNER, ignored -> initializeRestateRunner(extensionContext), RestateRunner.class) - .supportsParameter(parameterContext, extensionContext); + return (parameterContext.isAnnotated(RestateAdminClient.class) + && dev.restate.admin.client.ApiClient.class.isAssignableFrom( + parameterContext.getParameter().getType())) + || (parameterContext.isAnnotated(RestateClient.class) + && Client.class.isAssignableFrom(parameterContext.getParameter().getType())) + || (parameterContext.isAnnotated(RestateURL.class) + && (String.class.isAssignableFrom(parameterContext.getParameter().getType()) + || URL.class.isAssignableFrom(parameterContext.getParameter().getType()))); } @Override public Object resolveParameter( ParameterContext parameterContext, ExtensionContext extensionContext) throws ParameterResolutionException { + RestateRunner runner = getOrCreateRunner(extensionContext); + if (parameterContext.isAnnotated(RestateAdminClient.class)) { + return runner.getAdminClient(); + } else if (parameterContext.isAnnotated(RestateClient.class)) { + URL url = runner.getIngressUrl(); + return Client.connect(url.toString()); + } else if (parameterContext.isAnnotated(RestateURL.class)) { + URL url = runner.getIngressUrl(); + if (parameterContext.getParameter().getType().equals(String.class)) { + return url.toString(); + } + if (parameterContext.getParameter().getType().equals(URI.class)) { + try { + return url.toURI(); + } catch (URISyntaxException e) { + throw new RuntimeException(e); + } + } + return url; + } + throw new ParameterResolutionException("The parameter is not supported"); + } + + private RestateRunner getOrCreateRunner(ExtensionContext extensionContext) { return extensionContext .getStore(NAMESPACE) .getOrComputeIfAbsent( - RUNNER, ignored -> initializeRestateRunner(extensionContext), RestateRunner.class) - .resolveParameter(parameterContext, extensionContext); + RUNNER, ignored -> initializeRestateRunner(extensionContext), RestateRunner.class); } private RestateRunner initializeRestateRunner(ExtensionContext extensionContext) { @@ -73,8 +100,10 @@ private RestateRunner initializeRestateRunner(ExtensionContext extensionContext) "Expecting @RestateTest annotation on the test class")); // Build runner discovering services to bind - var runnerBuilder = RestateRunnerBuilder.create(); - servicesToBind.forEach(runnerBuilder::bind); + Endpoint.Builder endpointBuilder = Endpoint.builder(); + servicesToBind.forEach(endpointBuilder::bind); + + var runnerBuilder = RestateRunner.from(endpointBuilder.build()); runnerBuilder.withRestateContainerImage(testAnnotation.containerImage()); if (testAnnotation.environment() != null) { for (String env : testAnnotation.environment()) { @@ -88,6 +117,6 @@ private RestateRunner initializeRestateRunner(ExtensionContext extensionContext) runnerBuilder.withAdditionalEnv(splitEnv[0], splitEnv[1]); } } - return runnerBuilder.buildRunner(); + return runnerBuilder.build(); } } diff --git a/sdk-testing/src/main/java/dev/restate/sdk/testing/RestateRunner.java b/sdk-testing/src/main/java/dev/restate/sdk/testing/RestateRunner.java index 4e9eb1ab3..2a27dfdb6 100644 --- a/sdk-testing/src/main/java/dev/restate/sdk/testing/RestateRunner.java +++ b/sdk-testing/src/main/java/dev/restate/sdk/testing/RestateRunner.java @@ -8,111 +8,247 @@ // https://github.com/restatedev/sdk-java/blob/main/LICENSE package dev.restate.sdk.testing; -import dev.restate.sdk.client.Client; -import java.net.URI; -import java.net.URISyntaxException; +import dev.restate.admin.api.DeploymentApi; +import dev.restate.admin.client.ApiClient; +import dev.restate.admin.client.ApiException; +import dev.restate.admin.model.RegisterDeploymentRequest; +import dev.restate.admin.model.RegisterDeploymentRequestAnyOf; +import dev.restate.admin.model.RegisterDeploymentResponse; +import dev.restate.sdk.endpoint.Endpoint; +import dev.restate.sdk.http.vertx.RestateHttpServer; +import io.vertx.core.http.HttpServer; +import java.net.MalformedURLException; import java.net.URL; -import org.junit.jupiter.api.extension.*; +import java.util.HashMap; +import java.util.Map; +import java.util.concurrent.ExecutionException; +import java.util.stream.Collectors; +import org.apache.logging.log4j.LogManager; +import org.apache.logging.log4j.Logger; +import org.junit.jupiter.api.extension.ExtensionContext; +import org.testcontainers.Testcontainers; +import org.testcontainers.containers.GenericContainer; +import org.testcontainers.containers.wait.strategy.Wait; +import org.testcontainers.containers.wait.strategy.WaitAllStrategy; +import org.testcontainers.images.builder.Transferable; +import org.testcontainers.utility.DockerImageName; /** - * Restate runner for JUnit 5. Example: + * Manual runner for the Restate test infra, starting the Restate server container together with the + * provided services and automatically registering them. To start the infra use {@link #start()} and + * to stop it use {@link #stop()}. * - *

- * {@code @RegisterExtension}
- * private final static RestateRunner restateRunner = RestateRunnerBuilder.create()
- *         .withService(new MyService())
- *         .buildRunner();
- * 
- * - *

The runner will deploy the services locally, execute Restate as container using Testcontainers, and register the services. - * - *

This extension is scoped per test class, meaning that the restate runner will be shared among - * test methods. - * - *

Use the annotations {@link RestateClient}, {@link RestateURL} and {@link RestateAdminClient} - * to interact with the deployed server: - * - *

- * {@code @Test}
- * void initialCountIsZero({@code @RestateClient} Client client) {
- *     var client = CounterClient.fromClient(ingressClient, "my-counter");
- *
- *     // Use client as usual
- *     long response = client.get();
- *     assertThat(response).isEqualTo(0L);
- * }
- * - * @deprecated We now recommend using {@link RestateTest}. + *

If you use JUnit 5, we suggest using {@link RestateTest} instead. */ -@Deprecated -public class RestateRunner implements BeforeAllCallback, ParameterResolver { +public class RestateRunner implements AutoCloseable, ExtensionContext.Store.CloseableResource { - static final ExtensionContext.Namespace NAMESPACE = - ExtensionContext.Namespace.create(RestateRunner.class); - static final String DEPLOYER_KEY = "Deployer"; + private static final Logger LOG = LogManager.getLogger(RestateRunner.class); - private final ManualRestateRunner deployer; + private static final String RESTATE_RUNTIME = "runtime"; + public static final int RESTATE_INGRESS_ENDPOINT_PORT = 8080; + public static final int RESTATE_ADMIN_ENDPOINT_PORT = 9070; - RestateRunner(ManualRestateRunner deployer) { - this.deployer = deployer; + private final HttpServer server; + private final GenericContainer runtimeContainer; + + RestateRunner( + Endpoint endpoint, + String runtimeContainerImage, + Map additionalEnv, + String configFile) { + this.server = RestateHttpServer.fromEndpoint(endpoint); + this.runtimeContainer = new GenericContainer<>(DockerImageName.parse(runtimeContainerImage)); + + // Configure runtimeContainer + this.runtimeContainer + // We expose these ports only to enable port checks + .withExposedPorts(RESTATE_INGRESS_ENDPOINT_PORT, RESTATE_ADMIN_ENDPOINT_PORT) + // Let's have a high logging level by default to avoid spamming too much, it can be + // overriden by the user + .withEnv("RUST_LOG", "warn") + .withEnv(additionalEnv) + // These envs should not be overriden by additionalEnv + .withEnv("RESTATE_META__REST_ADDRESS", "0.0.0.0:" + RESTATE_ADMIN_ENDPOINT_PORT) + .withEnv( + "RESTATE_WORKER__INGRESS__BIND_ADDRESS", "0.0.0.0:" + RESTATE_INGRESS_ENDPOINT_PORT) + .withNetworkAliases(RESTATE_RUNTIME) + // Configure wait strategy on health paths + .waitingFor( + new WaitAllStrategy() + .withStrategy(Wait.forHttp("/health").forPort(RESTATE_ADMIN_ENDPOINT_PORT)) + .withStrategy( + Wait.forHttp("/restate/health").forPort(RESTATE_INGRESS_ENDPOINT_PORT))) + .withLogConsumer( + outputFrame -> { + switch (outputFrame.getType()) { + case STDOUT, STDERR -> + LOG.debug("[restate] {}", outputFrame.getUtf8StringWithoutLineEnding()); + case END -> LOG.debug("[restate] END"); + } + }); + + if (configFile != null) { + this.runtimeContainer.withCopyToContainer(Transferable.of(configFile), "/config.yaml"); + this.runtimeContainer.withEnv("RESTATE_CONFIG", "/config.yaml"); + } } - @Override - public void beforeAll(ExtensionContext context) { - deployer.start(); - context.getStore(NAMESPACE).put(DEPLOYER_KEY, deployer); + /** Create from {@link Endpoint}. */ + public static Builder from(Endpoint endpoint) { + return new Builder(endpoint); } - @Override - public boolean supportsParameter( - ParameterContext parameterContext, ExtensionContext extensionContext) - throws ParameterResolutionException { - return supportsParameter(parameterContext); + /** + * Builder for {@link RestateRunner}. + * + * @see RestateRunner + */ + public static class Builder { + + private static final String DEFAULT_RESTATE_CONTAINER = "docker.io/restatedev/restate:latest"; + private final Endpoint endpoint; + private String restateContainerImage = DEFAULT_RESTATE_CONTAINER; + private final Map additionalEnv = new HashMap<>(); + private String configFile; + + Builder(Endpoint endpoint) { + this.endpoint = endpoint; + } + + /** Override the container image to use for the Restate runtime. */ + public Builder withRestateContainerImage(String restateContainerImage) { + this.restateContainerImage = restateContainerImage; + return this; + } + + /** Add additional environment variables to the Restate container. */ + public Builder withAdditionalEnv(String key, String value) { + this.additionalEnv.put(key, value); + return this; + } + + /** Mount a config file in the Restate container. */ + public Builder withConfigFile(String configFile) { + this.configFile = configFile; + return this; + } + + /** + * @return a {@link RestateRunner} to start and stop the test infra manually. + */ + public RestateRunner build() { + return new RestateRunner( + endpoint, this.restateContainerImage, this.additionalEnv, this.configFile); + } } - static boolean supportsParameter(ParameterContext parameterContext) { - return (parameterContext.isAnnotated(RestateAdminClient.class) - && dev.restate.admin.client.ApiClient.class.isAssignableFrom( - parameterContext.getParameter().getType())) - || (parameterContext.isAnnotated(RestateClient.class) - && Client.class.isAssignableFrom(parameterContext.getParameter().getType())) - || (parameterContext.isAnnotated(RestateURL.class) - && (String.class.isAssignableFrom(parameterContext.getParameter().getType()) - || URL.class.isAssignableFrom(parameterContext.getParameter().getType()))); + /** Run restate, run the embedded service endpoint server, and register the services. */ + public void start() { + // Start listening the local server + try { + server.listen(0).toCompletionStage().toCompletableFuture().get(); + } catch (InterruptedException | ExecutionException e) { + throw new RuntimeException(e); + } + + // Expose the server port + int serviceEndpointPort = server.actualPort(); + LOG.debug("Started embedded service endpoint server on port {}", serviceEndpointPort); + Testcontainers.exposeHostPorts(serviceEndpointPort); + + // Now create the runtime container and deploy it + this.runtimeContainer.start(); + LOG.debug("Started Restate container"); + + // Register services now + ApiClient client = getAdminClient(); + try { + RegisterDeploymentResponse response = + new DeploymentApi(client) + .createDeployment( + new RegisterDeploymentRequest( + new RegisterDeploymentRequestAnyOf() + .uri("http://host.testcontainers.internal:" + serviceEndpointPort))); + LOG.debug( + "Registered services {}", + response.getServices().stream() + .map(dev.restate.admin.model.ServiceMetadata::getName) + .collect(Collectors.toList())); + } catch (ApiException e) { + throw new RuntimeException(e); + } } - @Override - public Object resolveParameter( - ParameterContext parameterContext, ExtensionContext extensionContext) - throws ParameterResolutionException { - if (parameterContext.isAnnotated(RestateAdminClient.class)) { - return getDeployer(extensionContext).getAdminClient(); - } else if (parameterContext.isAnnotated(RestateClient.class)) { - return resolveClient(extensionContext); - } else if (parameterContext.isAnnotated(RestateURL.class)) { - URL url = getDeployer(extensionContext).getIngressUrl(); - if (parameterContext.getParameter().getType().equals(String.class)) { - return url.toString(); - } - if (parameterContext.getParameter().getType().equals(URI.class)) { - try { - return url.toURI(); - } catch (URISyntaxException e) { - throw new RuntimeException(e); - } - } - return url; + /** + * Get restate ingress url to send HTTP/gRPC requests to services. + * + * @throws IllegalStateException if the restate container is not running. + */ + public URL getRestateUrl() { + try { + return new URL( + "http", + runtimeContainer.getHost(), + runtimeContainer.getMappedPort(RESTATE_INGRESS_ENDPOINT_PORT), + "/"); + } catch (MalformedURLException e) { + throw new RuntimeException(e); + } + } + + /** + * Get restate admin url to send HTTP requests to the admin API. + * + * @throws IllegalStateException if the restate container is not running. + */ + public URL getAdminUrl() { + try { + return new URL( + "http", + runtimeContainer.getHost(), + runtimeContainer.getMappedPort(RESTATE_ADMIN_ENDPOINT_PORT), + "/"); + } catch (MalformedURLException e) { + throw new RuntimeException(e); } - throw new ParameterResolutionException("The parameter is not supported"); } - private Client resolveClient(ExtensionContext extensionContext) { - URL url = getDeployer(extensionContext).getIngressUrl(); - return Client.connect(url.toString()); + /** Get the restate container. */ + public GenericContainer getRestateContainer() { + return this.runtimeContainer; + } + + /** Stop restate and the embedded service endpoint server. */ + public void stop() { + this.close(); + } + + /** Like {@link #stop()}. */ + @Override + public void close() { + runtimeContainer.stop(); + LOG.debug("Stopped Restate container"); + server.close().toCompletionStage().toCompletableFuture().join(); + LOG.debug("Stopped Embedded Service endpoint server"); } - private ManualRestateRunner getDeployer(ExtensionContext extensionContext) { - return (ManualRestateRunner) extensionContext.getStore(NAMESPACE).get(DEPLOYER_KEY); + // -- Methods used by the JUnit5 extension + + ApiClient getAdminClient() { + return new ApiClient() + .setHost(runtimeContainer.getHost()) + .setPort(runtimeContainer.getMappedPort(RESTATE_ADMIN_ENDPOINT_PORT)); + } + + URL getIngressUrl() { + try { + return new URL( + "http", + runtimeContainer.getHost(), + runtimeContainer.getMappedPort(RESTATE_INGRESS_ENDPOINT_PORT), + "/"); + } catch (MalformedURLException e) { + throw new RuntimeException(e); + } } } diff --git a/sdk-testing/src/main/java/dev/restate/sdk/testing/RestateRunnerBuilder.java b/sdk-testing/src/main/java/dev/restate/sdk/testing/RestateRunnerBuilder.java deleted file mode 100644 index e2e21be68..000000000 --- a/sdk-testing/src/main/java/dev/restate/sdk/testing/RestateRunnerBuilder.java +++ /dev/null @@ -1,108 +0,0 @@ -// Copyright (c) 2023 - Restate Software, Inc., Restate GmbH -// -// This file is part of the Restate Java SDK, -// which is released under the MIT license. -// -// You can find a copy of the license in file LICENSE in the root -// directory of this repository or package, or at -// https://github.com/restatedev/sdk-java/blob/main/LICENSE -package dev.restate.sdk.testing; - -import dev.restate.sdk.common.syscalls.ServiceDefinition; -import dev.restate.sdk.http.vertx.RestateHttpEndpointBuilder; -import java.util.HashMap; -import java.util.Map; - -/** - * Builder for {@link RestateRunner}. - * - * @see RestateRunner - */ -public class RestateRunnerBuilder { - - private static final String DEFAULT_RESTATE_CONTAINER = "docker.io/restatedev/restate:latest"; - private final RestateHttpEndpointBuilder endpointBuilder; - private String restateContainerImage = DEFAULT_RESTATE_CONTAINER; - private final Map additionalEnv = new HashMap<>(); - private String configFile; - - RestateRunnerBuilder(RestateHttpEndpointBuilder endpointBuilder) { - this.endpointBuilder = endpointBuilder; - } - - /** Override the container image to use for the Restate runtime. */ - public RestateRunnerBuilder withRestateContainerImage(String restateContainerImage) { - this.restateContainerImage = restateContainerImage; - return this; - } - - /** Add additional environment variables to the Restate container. */ - public RestateRunnerBuilder withAdditionalEnv(String key, String value) { - this.additionalEnv.put(key, value); - return this; - } - - /** Mount a config file in the Restate container. */ - public RestateRunnerBuilder withConfigFile(String configFile) { - this.configFile = configFile; - return this; - } - - /** - * Add a Restate service to the endpoint. This will automatically discover the generated factory - * based on the class name. - * - *

You can also manually instantiate the {@link ServiceDefinition} using {@link - * #bind(ServiceDefinition)}. - */ - public RestateRunnerBuilder bind(Object service) { - this.endpointBuilder.bind(service); - return this; - } - - /** - * Add a Restate service to the endpoint. - * - *

To set the options, use {@link #bind(ServiceDefinition, Object)}. - */ - public RestateRunnerBuilder bind(ServiceDefinition serviceDefinition) { - //noinspection unchecked - this.endpointBuilder.bind((ServiceDefinition) serviceDefinition, null); - return this; - } - - /** Add a Restate service to the endpoint, setting the options. */ - public RestateRunnerBuilder bind(ServiceDefinition serviceDefinition, O options) { - this.endpointBuilder.bind(serviceDefinition, options); - return this; - } - - /** - * @return a {@link ManualRestateRunner} to start and stop the test infra manually. - */ - public ManualRestateRunner buildManualRunner() { - return new ManualRestateRunner( - this.endpointBuilder.build(), - this.restateContainerImage, - this.additionalEnv, - this.configFile); - } - - /** - * @return a {@link RestateRunner} to be used as JUnit 5 Extension. - * @deprecated If you use JUnit 5, use {@link RestateTest} - */ - @Deprecated - public RestateRunner buildRunner() { - return new RestateRunner(this.buildManualRunner()); - } - - public static RestateRunnerBuilder create() { - return new RestateRunnerBuilder(RestateHttpEndpointBuilder.builder()); - } - - /** Create from {@link RestateHttpEndpointBuilder}. */ - public static RestateRunnerBuilder of(RestateHttpEndpointBuilder endpointBuilder) { - return new RestateRunnerBuilder(endpointBuilder); - } -} diff --git a/sdk-testing/src/test/java/dev/restate/sdk/testing/Counter.java b/sdk-testing/src/test/java/dev/restate/sdk/testing/Counter.java index 483ea6066..a821090f6 100644 --- a/sdk-testing/src/test/java/dev/restate/sdk/testing/Counter.java +++ b/sdk-testing/src/test/java/dev/restate/sdk/testing/Counter.java @@ -8,11 +8,10 @@ // https://github.com/restatedev/sdk-java/blob/main/LICENSE package dev.restate.sdk.testing; -import dev.restate.sdk.JsonSerdes; import dev.restate.sdk.ObjectContext; import dev.restate.sdk.annotation.Handler; import dev.restate.sdk.annotation.VirtualObject; -import dev.restate.sdk.common.StateKey; +import dev.restate.sdk.types.StateKey; import org.apache.logging.log4j.LogManager; import org.apache.logging.log4j.Logger; @@ -21,7 +20,7 @@ public class Counter { private static final Logger LOG = LogManager.getLogger(Counter.class); - private static final StateKey TOTAL = StateKey.of("total", JsonSerdes.LONG); + private static final StateKey TOTAL = StateKey.of("total", Long.class); @Handler public void reset(ObjectContext ctx) { diff --git a/sdk-testing/src/test/java/dev/restate/sdk/testing/CounterOldExtensionTest.java b/sdk-testing/src/test/java/dev/restate/sdk/testing/CounterOldExtensionTest.java deleted file mode 100644 index b50b81301..000000000 --- a/sdk-testing/src/test/java/dev/restate/sdk/testing/CounterOldExtensionTest.java +++ /dev/null @@ -1,36 +0,0 @@ -// Copyright (c) 2023 - Restate Software, Inc., Restate GmbH -// -// This file is part of the Restate Java SDK, -// which is released under the MIT license. -// -// You can find a copy of the license in file LICENSE in the root -// directory of this repository or package, or at -// https://github.com/restatedev/sdk-java/blob/main/LICENSE -package dev.restate.sdk.testing; - -import static org.assertj.core.api.Assertions.assertThat; - -import dev.restate.sdk.client.Client; -import org.junit.jupiter.api.Test; -import org.junit.jupiter.api.Timeout; -import org.junit.jupiter.api.extension.RegisterExtension; - -class CounterOldExtensionTest { - - @RegisterExtension - private static final RestateRunner RESTATE_RUNNER = - RestateRunnerBuilder.create() - .withRestateContainerImage( - "ghcr.io/restatedev/restate:main") // test against the latest main Restate image - .bind(new Counter()) - .buildRunner(); - - @Test - @Timeout(value = 10) - void testGreet(@RestateClient Client ingressClient) { - var client = CounterClient.fromClient(ingressClient, "my-counter"); - - long response = client.get(); - assertThat(response).isEqualTo(0L); - } -} diff --git a/sdk-testing/src/test/java/dev/restate/sdk/testing/CounterTest.java b/sdk-testing/src/test/java/dev/restate/sdk/testing/CounterTest.java index 590165884..6b3510ae5 100644 --- a/sdk-testing/src/test/java/dev/restate/sdk/testing/CounterTest.java +++ b/sdk-testing/src/test/java/dev/restate/sdk/testing/CounterTest.java @@ -10,7 +10,7 @@ import static org.assertj.core.api.Assertions.assertThat; -import dev.restate.sdk.client.Client; +import dev.restate.client.Client; import org.junit.jupiter.api.Test; import org.junit.jupiter.api.Timeout; diff --git a/settings.gradle.kts b/settings.gradle.kts index 4dbfd42d0..6a00091d7 100644 --- a/settings.gradle.kts +++ b/settings.gradle.kts @@ -9,16 +9,18 @@ rootProject.name = "sdk-java" -plugins { id("org.gradle.toolchains.foojay-resolver-convention") version "0.7.0" } +plugins { id("org.gradle.toolchains.foojay-resolver-convention") version "0.9.0" } include( "admin-client", + "common", + "client", + "client-kotlin", "sdk-common", "sdk-api", "sdk-api-kotlin", "sdk-core", "sdk-serde-jackson", - "sdk-serde-protobuf", "sdk-request-identity", "sdk-http-vertx", "sdk-lambda", @@ -27,10 +29,19 @@ include( "sdk-api-gen", "sdk-api-kotlin-gen", "sdk-spring-boot", - "sdk-spring-boot-starter", - "sdk-spring-boot-kotlin-starter", + + // Other modules we don't publish "examples", "sdk-aggregated-javadocs", - "test-services") + "test-services", + + // Meta modules + "sdk-java-http", + "sdk-java-lambda", + "sdk-kotlin-http", + "sdk-kotlin-lambda", + "sdk-spring-boot-starter", + "sdk-spring-boot-kotlin-starter", +) dependencyResolutionManagement { repositories { mavenCentral() } } diff --git a/test-services/build.gradle.kts b/test-services/build.gradle.kts index 70c99771a..f9a4713b6 100644 --- a/test-services/build.gradle.kts +++ b/test-services/build.gradle.kts @@ -11,17 +11,20 @@ plugins { dependencies { ksp(project(":sdk-api-kotlin-gen")) - implementation(project(":sdk-api-kotlin")) - implementation(project(":sdk-http-vertx")) - implementation(project(":sdk-serde-jackson")) + implementation(project(":sdk-kotlin-http")) implementation(project(":sdk-request-identity")) implementation(libs.kotlinx.serialization.core) implementation(libs.kotlinx.serialization.json) - implementation(libs.log4j.core) - implementation(libs.kotlinx.coroutines.core) + + // You might be wondering why I'm repeating these dependencies here. Well, don't, it's gradle. + implementation(project(":sdk-common")) + implementation(libs.log4j.api) + implementation(libs.opentelemetry.api) + implementation(libs.jackson.annotations) + implementation(libs.jackson.databind) } // Configuration of jib container images parameters diff --git a/test-services/src/main/kotlin/dev/restate/sdk/testservices/AwakeableHolderImpl.kt b/test-services/src/main/kotlin/dev/restate/sdk/testservices/AwakeableHolderImpl.kt index b61683ada..9af687419 100644 --- a/test-services/src/main/kotlin/dev/restate/sdk/testservices/AwakeableHolderImpl.kt +++ b/test-services/src/main/kotlin/dev/restate/sdk/testservices/AwakeableHolderImpl.kt @@ -8,16 +8,14 @@ // https://github.com/restatedev/sdk-java/blob/main/LICENSE package dev.restate.sdk.testservices -import dev.restate.sdk.common.StateKey -import dev.restate.sdk.common.TerminalException -import dev.restate.sdk.kotlin.KtStateKey -import dev.restate.sdk.kotlin.ObjectContext -import dev.restate.sdk.kotlin.resolve +import dev.restate.sdk.kotlin.* import dev.restate.sdk.testservices.contracts.AwakeableHolder +import dev.restate.sdk.types.StateKey +import dev.restate.sdk.types.TerminalException class AwakeableHolderImpl : AwakeableHolder { companion object { - private val ID_KEY: StateKey = KtStateKey.json("id") + private val ID_KEY: StateKey = stateKey("id") } override suspend fun hold(context: ObjectContext, id: String) { diff --git a/test-services/src/main/kotlin/dev/restate/sdk/testservices/BlockAndWaitWorkflowImpl.kt b/test-services/src/main/kotlin/dev/restate/sdk/testservices/BlockAndWaitWorkflowImpl.kt index 8c6ac1289..f3f14bd4b 100644 --- a/test-services/src/main/kotlin/dev/restate/sdk/testservices/BlockAndWaitWorkflowImpl.kt +++ b/test-services/src/main/kotlin/dev/restate/sdk/testservices/BlockAndWaitWorkflowImpl.kt @@ -8,20 +8,16 @@ // https://github.com/restatedev/sdk-java/blob/main/LICENSE package dev.restate.sdk.testservices -import dev.restate.sdk.common.DurablePromiseKey -import dev.restate.sdk.common.StateKey -import dev.restate.sdk.common.TerminalException -import dev.restate.sdk.kotlin.KtSerdes -import dev.restate.sdk.kotlin.KtStateKey -import dev.restate.sdk.kotlin.SharedWorkflowContext -import dev.restate.sdk.kotlin.WorkflowContext +import dev.restate.sdk.kotlin.* import dev.restate.sdk.testservices.contracts.BlockAndWaitWorkflow +import dev.restate.sdk.types.DurablePromiseKey +import dev.restate.sdk.types.StateKey +import dev.restate.sdk.types.TerminalException class BlockAndWaitWorkflowImpl : BlockAndWaitWorkflow { companion object { - private val MY_DURABLE_PROMISE: DurablePromiseKey = - DurablePromiseKey.of("durable-promise", KtSerdes.json()) - private val MY_STATE: StateKey = KtStateKey.json("my-state") + private val MY_DURABLE_PROMISE: DurablePromiseKey = durablePromiseKey("durable-promise") + private val MY_STATE: StateKey = stateKey("my-state") } override suspend fun run(context: WorkflowContext, input: String): String { diff --git a/test-services/src/main/kotlin/dev/restate/sdk/testservices/CancelTestImpl.kt b/test-services/src/main/kotlin/dev/restate/sdk/testservices/CancelTestImpl.kt index 4279cceb9..538510aa6 100644 --- a/test-services/src/main/kotlin/dev/restate/sdk/testservices/CancelTestImpl.kt +++ b/test-services/src/main/kotlin/dev/restate/sdk/testservices/CancelTestImpl.kt @@ -8,22 +8,19 @@ // https://github.com/restatedev/sdk-java/blob/main/LICENSE package dev.restate.sdk.testservices -import dev.restate.sdk.common.StateKey -import dev.restate.sdk.common.TerminalException -import dev.restate.sdk.kotlin.Awakeable -import dev.restate.sdk.kotlin.KtStateKey -import dev.restate.sdk.kotlin.ObjectContext -import dev.restate.sdk.kotlin.awakeable +import dev.restate.sdk.kotlin.* import dev.restate.sdk.testservices.contracts.AwakeableHolderClient import dev.restate.sdk.testservices.contracts.BlockingOperation import dev.restate.sdk.testservices.contracts.CancelTest import dev.restate.sdk.testservices.contracts.CancelTestBlockingServiceClient +import dev.restate.sdk.types.StateKey +import dev.restate.sdk.types.TerminalException import kotlin.time.Duration.Companion.days class CancelTestImpl { class RunnerImpl : CancelTest.Runner { companion object { - private val CANCELED_STATE: StateKey = KtStateKey.json("canceled") + private val CANCELED_STATE: StateKey = stateKey("canceled") } override suspend fun startTest(context: ObjectContext, operation: BlockingOperation) { diff --git a/test-services/src/main/kotlin/dev/restate/sdk/testservices/CounterImpl.kt b/test-services/src/main/kotlin/dev/restate/sdk/testservices/CounterImpl.kt index a6393f1f8..2715a5993 100644 --- a/test-services/src/main/kotlin/dev/restate/sdk/testservices/CounterImpl.kt +++ b/test-services/src/main/kotlin/dev/restate/sdk/testservices/CounterImpl.kt @@ -8,13 +8,11 @@ // https://github.com/restatedev/sdk-java/blob/main/LICENSE package dev.restate.sdk.testservices -import dev.restate.sdk.common.StateKey -import dev.restate.sdk.common.TerminalException -import dev.restate.sdk.kotlin.KtStateKey -import dev.restate.sdk.kotlin.ObjectContext -import dev.restate.sdk.kotlin.SharedObjectContext +import dev.restate.sdk.kotlin.* import dev.restate.sdk.testservices.contracts.Counter import dev.restate.sdk.testservices.contracts.CounterUpdateResponse +import dev.restate.sdk.types.StateKey +import dev.restate.sdk.types.TerminalException import org.apache.logging.log4j.LogManager import org.apache.logging.log4j.Logger @@ -23,7 +21,7 @@ class CounterImpl : Counter { companion object { private val logger: Logger = LogManager.getLogger(CounterImpl::class.java) - private val COUNTER_KEY: StateKey = KtStateKey.json("counter") + private val COUNTER_KEY: StateKey = stateKey("counter") } override suspend fun reset(context: ObjectContext) { diff --git a/test-services/src/main/kotlin/dev/restate/sdk/testservices/FailingImpl.kt b/test-services/src/main/kotlin/dev/restate/sdk/testservices/FailingImpl.kt index 362812a44..45463903a 100644 --- a/test-services/src/main/kotlin/dev/restate/sdk/testservices/FailingImpl.kt +++ b/test-services/src/main/kotlin/dev/restate/sdk/testservices/FailingImpl.kt @@ -8,12 +8,12 @@ // https://github.com/restatedev/sdk-java/blob/main/LICENSE package dev.restate.sdk.testservices -import dev.restate.sdk.common.TerminalException import dev.restate.sdk.kotlin.ObjectContext import dev.restate.sdk.kotlin.retryPolicy import dev.restate.sdk.kotlin.runBlock import dev.restate.sdk.testservices.contracts.Failing import dev.restate.sdk.testservices.contracts.FailingClient +import dev.restate.sdk.types.TerminalException import java.util.concurrent.atomic.AtomicInteger import kotlin.time.Duration.Companion.milliseconds import org.apache.logging.log4j.LogManager diff --git a/test-services/src/main/kotlin/dev/restate/sdk/testservices/KillTestImpl.kt b/test-services/src/main/kotlin/dev/restate/sdk/testservices/KillTestImpl.kt index 2d2eca564..581b2c4c4 100644 --- a/test-services/src/main/kotlin/dev/restate/sdk/testservices/KillTestImpl.kt +++ b/test-services/src/main/kotlin/dev/restate/sdk/testservices/KillTestImpl.kt @@ -8,12 +8,12 @@ // https://github.com/restatedev/sdk-java/blob/main/LICENSE package dev.restate.sdk.testservices -import dev.restate.sdk.common.Serde import dev.restate.sdk.kotlin.Context import dev.restate.sdk.kotlin.ObjectContext import dev.restate.sdk.testservices.contracts.AwakeableHolderClient import dev.restate.sdk.testservices.contracts.KillTest import dev.restate.sdk.testservices.contracts.KillTestSingletonClient +import dev.restate.serde.Serde class KillTestImpl { class RunnerImpl : KillTest.Runner { diff --git a/test-services/src/main/kotlin/dev/restate/sdk/testservices/ListObjectImpl.kt b/test-services/src/main/kotlin/dev/restate/sdk/testservices/ListObjectImpl.kt index 46b9a302b..17c438ad7 100644 --- a/test-services/src/main/kotlin/dev/restate/sdk/testservices/ListObjectImpl.kt +++ b/test-services/src/main/kotlin/dev/restate/sdk/testservices/ListObjectImpl.kt @@ -8,15 +8,14 @@ // https://github.com/restatedev/sdk-java/blob/main/LICENSE package dev.restate.sdk.testservices -import dev.restate.sdk.common.StateKey -import dev.restate.sdk.kotlin.KtStateKey -import dev.restate.sdk.kotlin.ObjectContext +import dev.restate.sdk.kotlin.* import dev.restate.sdk.testservices.contracts.ListObject +import dev.restate.sdk.types.StateKey class ListObjectImpl : ListObject { companion object { private val LIST_KEY: StateKey> = - KtStateKey.json>( + stateKey( "list", ) } diff --git a/test-services/src/main/kotlin/dev/restate/sdk/testservices/Main.kt b/test-services/src/main/kotlin/dev/restate/sdk/testservices/Main.kt index 67a244efe..8ba125106 100644 --- a/test-services/src/main/kotlin/dev/restate/sdk/testservices/Main.kt +++ b/test-services/src/main/kotlin/dev/restate/sdk/testservices/Main.kt @@ -9,28 +9,29 @@ package dev.restate.sdk.testservices import dev.restate.sdk.auth.signing.RestateRequestIdentityVerifier -import dev.restate.sdk.http.vertx.RestateHttpEndpointBuilder +import dev.restate.sdk.http.vertx.RestateHttpServer +import dev.restate.sdk.kotlin.endpoint.endpoint import dev.restate.sdk.testservices.contracts.* val KNOWN_SERVICES_FACTORIES: Map Any> = mapOf( - AwakeableHolderDefinitions.SERVICE_NAME to { AwakeableHolderImpl() }, - BlockAndWaitWorkflowDefinitions.SERVICE_NAME to { BlockAndWaitWorkflowImpl() }, - CancelTestBlockingServiceDefinitions.SERVICE_NAME to { CancelTestImpl.BlockingService() }, - CancelTestRunnerDefinitions.SERVICE_NAME to { CancelTestImpl.RunnerImpl() }, - CounterDefinitions.SERVICE_NAME to { CounterImpl() }, - FailingDefinitions.SERVICE_NAME to { FailingImpl() }, - KillTestRunnerDefinitions.SERVICE_NAME to { KillTestImpl.RunnerImpl() }, - KillTestSingletonDefinitions.SERVICE_NAME to { KillTestImpl.SingletonImpl() }, - ListObjectDefinitions.SERVICE_NAME to { ListObjectImpl() }, - MapObjectDefinitions.SERVICE_NAME to { MapObjectImpl() }, - NonDeterministicDefinitions.SERVICE_NAME to { NonDeterministicImpl() }, - ProxyDefinitions.SERVICE_NAME to { ProxyImpl() }, - TestUtilsServiceDefinitions.SERVICE_NAME to { TestUtilsServiceImpl() }, + AwakeableHolderMetadata.SERVICE_NAME to { AwakeableHolderImpl() }, + BlockAndWaitWorkflowMetadata.SERVICE_NAME to { BlockAndWaitWorkflowImpl() }, + CancelTestBlockingServiceMetadata.SERVICE_NAME to { CancelTestImpl.BlockingService() }, + CancelTestRunnerMetadata.SERVICE_NAME to { CancelTestImpl.RunnerImpl() }, + CounterMetadata.SERVICE_NAME to { CounterImpl() }, + FailingMetadata.SERVICE_NAME to { FailingImpl() }, + KillTestRunnerMetadata.SERVICE_NAME to { KillTestImpl.RunnerImpl() }, + KillTestSingletonMetadata.SERVICE_NAME to { KillTestImpl.SingletonImpl() }, + ListObjectMetadata.SERVICE_NAME to { ListObjectImpl() }, + MapObjectMetadata.SERVICE_NAME to { MapObjectImpl() }, + NonDeterministicMetadata.SERVICE_NAME to { NonDeterministicImpl() }, + ProxyMetadata.SERVICE_NAME to { ProxyImpl() }, + TestUtilsServiceMetadata.SERVICE_NAME to { TestUtilsServiceImpl() }, interpreterName(0) to { ObjectInterpreterImpl.getInterpreterDefinition(0) }, interpreterName(1) to { ObjectInterpreterImpl.getInterpreterDefinition(1) }, interpreterName(2) to { ObjectInterpreterImpl.getInterpreterDefinition(2) }, - ServiceInterpreterHelperDefinitions.SERVICE_NAME to { ServiceInterpreterHelperImpl() }) + ServiceInterpreterHelperMetadata.SERVICE_NAME to { ServiceInterpreterHelperImpl() }) val NEEDS_EXPERIMENTAL_CONTEXT: Set = setOf() @@ -39,27 +40,29 @@ fun main(args: Array) { if (env == null) { env = "*" } - val restateHttpEndpointBuilder = RestateHttpEndpointBuilder.builder() - if (env == "*") { - KNOWN_SERVICES_FACTORIES.values.forEach { restateHttpEndpointBuilder.bind(it()) } - } else { - for (svc in env.split(",".toRegex()).dropLastWhile { it.isEmpty() }.toTypedArray()) { - val fqsn = svc.trim { it <= ' ' } - restateHttpEndpointBuilder.bind( - KNOWN_SERVICES_FACTORIES[fqsn]?.invoke() - ?: throw IllegalStateException("Service $fqsn not implemented")) + val endpoint = endpoint { + if (env == "*") { + for (svc in KNOWN_SERVICES_FACTORIES.values) { + bind(svc()) + } + } else { + for (svc in env.split(",".toRegex()).dropLastWhile { it.isEmpty() }.toTypedArray()) { + val fqsn = svc.trim { it <= ' ' } + bind( + KNOWN_SERVICES_FACTORIES[fqsn]?.invoke() + ?: throw IllegalStateException("Service $fqsn not implemented")) + } } - } - val requestSigningKey = System.getenv("E2E_REQUEST_SIGNING") - if (requestSigningKey != null) { - restateHttpEndpointBuilder.withRequestIdentityVerifier( - RestateRequestIdentityVerifier.fromKey(requestSigningKey)) - } + val requestSigningKey = System.getenv("E2E_REQUEST_SIGNING") + if (requestSigningKey != null) { + withRequestIdentityVerifier(RestateRequestIdentityVerifier.fromKey(requestSigningKey)) + } - if (env == "*" || NEEDS_EXPERIMENTAL_CONTEXT.any { env.contains(it) }) { - restateHttpEndpointBuilder.enablePreviewContext() + if (env == "*" || NEEDS_EXPERIMENTAL_CONTEXT.any { env.contains(it) }) { + enablePreviewContext() + } } - restateHttpEndpointBuilder.buildAndListen() + RestateHttpServer.listen(endpoint) } diff --git a/test-services/src/main/kotlin/dev/restate/sdk/testservices/MapObjectImpl.kt b/test-services/src/main/kotlin/dev/restate/sdk/testservices/MapObjectImpl.kt index 0bda7a1b6..71745347e 100644 --- a/test-services/src/main/kotlin/dev/restate/sdk/testservices/MapObjectImpl.kt +++ b/test-services/src/main/kotlin/dev/restate/sdk/testservices/MapObjectImpl.kt @@ -8,18 +8,17 @@ // https://github.com/restatedev/sdk-java/blob/main/LICENSE package dev.restate.sdk.testservices -import dev.restate.sdk.kotlin.KtStateKey -import dev.restate.sdk.kotlin.ObjectContext +import dev.restate.sdk.kotlin.* import dev.restate.sdk.testservices.contracts.Entry import dev.restate.sdk.testservices.contracts.MapObject class MapObjectImpl : MapObject { override suspend fun set(context: ObjectContext, entry: Entry) { - context.set(KtStateKey.json(entry.key), entry.value) + context.set(stateKey(entry.key), entry.value) } override suspend fun get(context: ObjectContext, key: String): String { - return context.get(KtStateKey.json(key)) ?: "" + return context.get(stateKey(key)) ?: "" } override suspend fun clearAll(context: ObjectContext): List { @@ -27,7 +26,7 @@ class MapObjectImpl : MapObject { // AH AH AH and here I wanna see if you really respect determinism!!! val result = mutableListOf() for (k in keys) { - result.add(Entry(k, context.get(KtStateKey.json(k))!!)) + result.add(Entry(k, context.get(stateKey(k))!!)) } context.clearAll() return result diff --git a/test-services/src/main/kotlin/dev/restate/sdk/testservices/NonDeterministicImpl.kt b/test-services/src/main/kotlin/dev/restate/sdk/testservices/NonDeterministicImpl.kt index 11ba74e10..56bceadc9 100644 --- a/test-services/src/main/kotlin/dev/restate/sdk/testservices/NonDeterministicImpl.kt +++ b/test-services/src/main/kotlin/dev/restate/sdk/testservices/NonDeterministicImpl.kt @@ -8,18 +8,17 @@ // https://github.com/restatedev/sdk-java/blob/main/LICENSE package dev.restate.sdk.testservices -import dev.restate.sdk.common.StateKey -import dev.restate.sdk.kotlin.KtStateKey -import dev.restate.sdk.kotlin.ObjectContext +import dev.restate.sdk.kotlin.* import dev.restate.sdk.testservices.contracts.CounterClient import dev.restate.sdk.testservices.contracts.NonDeterministic +import dev.restate.sdk.types.StateKey import java.util.concurrent.ConcurrentHashMap import kotlin.time.Duration.Companion.milliseconds class NonDeterministicImpl : NonDeterministic { private val invocationCounts: ConcurrentHashMap = ConcurrentHashMap() - private val STATE_A: StateKey = KtStateKey.json("a") - private val STATE_B: StateKey = KtStateKey.json("b") + private val STATE_A: StateKey = stateKey("a") + private val STATE_B: StateKey = stateKey("b") override suspend fun eitherSleepOrCall(context: ObjectContext) { if (doLeftAction(context)) { diff --git a/test-services/src/main/kotlin/dev/restate/sdk/testservices/ProxyImpl.kt b/test-services/src/main/kotlin/dev/restate/sdk/testservices/ProxyImpl.kt index 348ecdbf9..02500d5f7 100644 --- a/test-services/src/main/kotlin/dev/restate/sdk/testservices/ProxyImpl.kt +++ b/test-services/src/main/kotlin/dev/restate/sdk/testservices/ProxyImpl.kt @@ -8,14 +8,14 @@ // https://github.com/restatedev/sdk-java/blob/main/LICENSE package dev.restate.sdk.testservices -import dev.restate.sdk.common.Serde -import dev.restate.sdk.common.Target -import dev.restate.sdk.kotlin.Awaitable -import dev.restate.sdk.kotlin.Context -import dev.restate.sdk.kotlin.awaitAll +import dev.restate.common.Request +import dev.restate.common.SendRequest +import dev.restate.common.Target +import dev.restate.sdk.kotlin.* import dev.restate.sdk.testservices.contracts.ManyCallRequest import dev.restate.sdk.testservices.contracts.Proxy import dev.restate.sdk.testservices.contracts.ProxyRequest +import dev.restate.serde.Serde import kotlin.time.Duration import kotlin.time.Duration.Companion.milliseconds @@ -29,15 +29,16 @@ class ProxyImpl : Proxy { } override suspend fun call(context: Context, request: ProxyRequest): ByteArray { - return context.call(request.toTarget(), Serde.RAW, Serde.RAW, request.message) + return context + .call(Request.of(request.toTarget(), Serde.RAW, Serde.RAW, request.message)) + .await() } - override suspend fun oneWayCall(context: Context, request: ProxyRequest) { - context.send( - request.toTarget(), - Serde.RAW, - request.message, - request.delayMillis?.milliseconds ?: Duration.ZERO) + override suspend fun oneWayCall(context: Context, request: ProxyRequest): Unit { + val ignored = + context.send( + SendRequest.of(request.toTarget(), Serde.RAW, Serde.SLICE, request.message) + .asSendDelayed((request.delayMillis?.milliseconds ?: Duration.ZERO))) } override suspend fun manyCalls(context: Context, requests: List) { @@ -46,14 +47,20 @@ class ProxyImpl : Proxy { for (request in requests) { if (request.oneWayCall) { context.send( - request.proxyRequest.toTarget(), - Serde.RAW, - request.proxyRequest.message, - request.proxyRequest.delayMillis?.milliseconds ?: Duration.ZERO) + SendRequest.of( + request.proxyRequest.toTarget(), + Serde.RAW, + Serde.SLICE, + request.proxyRequest.message) + .asSendDelayed((request.proxyRequest.delayMillis?.milliseconds ?: Duration.ZERO))) } else { val awaitable = - context.callAsync( - request.proxyRequest.toTarget(), Serde.RAW, Serde.RAW, request.proxyRequest.message) + context.call( + Request.of( + request.proxyRequest.toTarget(), + Serde.RAW, + Serde.RAW, + request.proxyRequest.message)) if (request.awaitAtTheEnd) { toAwait.add(awaitable) } diff --git a/test-services/src/main/kotlin/dev/restate/sdk/testservices/TestUtilsServiceImpl.kt b/test-services/src/main/kotlin/dev/restate/sdk/testservices/TestUtilsServiceImpl.kt index 9853f54be..536834e6b 100644 --- a/test-services/src/main/kotlin/dev/restate/sdk/testservices/TestUtilsServiceImpl.kt +++ b/test-services/src/main/kotlin/dev/restate/sdk/testservices/TestUtilsServiceImpl.kt @@ -44,9 +44,10 @@ class TestUtilsServiceImpl : TestUtilsService { val timeout = ctx.timer(req.awaitTimeout.milliseconds) return select { - awakeable.onAwait { AwakeableResultResponse(it) } - timeout.onAwait { TimeoutResponse } - } + awakeable.onAwait { AwakeableResultResponse(it) } + timeout.onAwait { TimeoutResponse } + } + .await() } override suspend fun sleepConcurrently(context: Context, millisDuration: List) { diff --git a/test-services/src/main/kotlin/dev/restate/sdk/testservices/interpreter.kt b/test-services/src/main/kotlin/dev/restate/sdk/testservices/interpreter.kt index 7df1ea349..b21e0019a 100644 --- a/test-services/src/main/kotlin/dev/restate/sdk/testservices/interpreter.kt +++ b/test-services/src/main/kotlin/dev/restate/sdk/testservices/interpreter.kt @@ -8,18 +8,21 @@ // https://github.com/restatedev/sdk-java/blob/main/LICENSE package dev.restate.sdk.testservices -import dev.restate.sdk.common.StateKey -import dev.restate.sdk.common.Target -import dev.restate.sdk.common.TerminalException -import dev.restate.sdk.common.syscalls.ServiceDefinition +import dev.restate.common.Request +import dev.restate.common.SendRequest +import dev.restate.common.Target +import dev.restate.sdk.endpoint.definition.ServiceDefinition import dev.restate.sdk.kotlin.* import dev.restate.sdk.testservices.contracts.* import dev.restate.sdk.testservices.contracts.Program +import dev.restate.sdk.types.StateKey +import dev.restate.sdk.types.TerminalException +import dev.restate.serde.Serde import kotlin.random.Random import kotlin.time.Duration.Companion.milliseconds fun interpreterName(layer: Int): String { - return "${ObjectInterpreterDefinitions.SERVICE_NAME}L$layer" + return "${ObjectInterpreterMetadata.SERVICE_NAME}L$layer" } fun interpretTarget(layer: Int, key: String): Target { @@ -54,16 +57,16 @@ suspend fun checkAwaitableFails( } fun cmdStateKey(key: Int): StateKey { - return KtStateKey.json("key-$key") + return stateKey("key-$key") } class ObjectInterpreterImpl(private val layer: Int) : ObjectInterpreter { companion object { - private val COUNTER: StateKey = KtStateKey.json("counter") + private val COUNTER: StateKey = stateKey("counter") - fun getInterpreterDefinition(layer: Int): ServiceDefinition { + fun getInterpreterDefinition(layer: Int): ServiceDefinition { val originalDefinition = - ObjectInterpreterServiceDefinitionFactory().create(ObjectInterpreterImpl(layer)) + ObjectInterpreterServiceDefinitionFactory().create(ObjectInterpreterImpl(layer), null) return ServiceDefinition.of( interpreterName(layer), originalDefinition.serviceType, originalDefinition.handlers) } @@ -91,11 +94,12 @@ class ObjectInterpreterImpl(private val layer: Int) : ObjectInterpreter { } is CallObject -> { val awaitable = - ctx.callAsync( - interpretTarget(layer + 1, cmd.key.toString()), - ObjectInterpreterDefinitions.Serde.INTERPRET_INPUT, - ObjectInterpreterDefinitions.Serde.INTERPRET_OUTPUT, - cmd.program) + ctx.call( + Request.of( + interpretTarget(layer + 1, cmd.key.toString()), + ObjectInterpreterMetadata.Serde.INTERPRET_INPUT, + ObjectInterpreterMetadata.Serde.INTERPRET_OUTPUT, + cmd.program)) promises[i] = { awaitable.await() } } is CallService -> { @@ -136,8 +140,8 @@ class ObjectInterpreterImpl(private val layer: Int) : ObjectInterpreter { } is IncrementViaDelayedCall -> { ServiceInterpreterHelperClient.fromContext(ctx) - .send(delay = cmd.duration.milliseconds) - .incrementIndirectly(interpreterId(ctx)) + .send() + .incrementIndirectly(interpreterId(ctx), delay = cmd.duration.milliseconds) } is RecoverTerminalCall -> { var caught = false @@ -209,9 +213,11 @@ class ServiceInterpreterHelperImpl : ServiceInterpreterHelper { override suspend fun incrementIndirectly(ctx: Context, id: InterpreterId) { ctx.send( - interpretTarget(id.layer, id.key), - ObjectInterpreterDefinitions.Serde.INTERPRET_INPUT, - Program(listOf(IncrementStateCounter()))) + SendRequest.of( + interpretTarget(id.layer, id.key), + ObjectInterpreterMetadata.Serde.INTERPRET_INPUT, + Serde.SLICE, + Program(listOf(IncrementStateCounter())))) } override suspend fun resolveAwakeable(ctx: Context, id: String) { @@ -242,8 +248,10 @@ class ServiceInterpreterHelperImpl : ServiceInterpreterHelper { // 4. to thank our interpret, let us ask it to inc its state. // ctx.send( - interpretTarget(req.interpreter.layer, req.interpreter.key), - ObjectInterpreterDefinitions.Serde.INTERPRET_INPUT, - Program(listOf(IncrementStateCounter()))) + SendRequest.of( + interpretTarget(req.interpreter.layer, req.interpreter.key), + ObjectInterpreterMetadata.Serde.INTERPRET_INPUT, + Serde.SLICE, + Program(listOf(IncrementStateCounter())))) } }