From f5f47910a2b126d826f1e373c83a1c7ecf9d22c2 Mon Sep 17 00:00:00 2001 From: slinkydeveloper Date: Tue, 18 Feb 2025 18:09:29 +0100 Subject: [PATCH] Add cancel invocation/attach invocation/get invocation output --- .../dev/restate/client/kotlin/ingress.kt | 25 +-- .../main/java/dev/restate/client/Client.java | 29 ++- .../java/dev/restate/client/SendResponse.java | 2 +- .../dev/restate/client/base/BaseClient.java | 31 +-- .../common/{CallRequest.java => Request.java} | 68 +++++-- .../java/dev/restate/common/SendRequest.java | 178 +++--------------- .../src/main/resources/templates/Client.hbs | 27 ++- .../src/main/resources/templates/Requests.hbs | 18 +- .../src/main/resources/templates/Client.hbs | 34 +++- .../src/main/resources/templates/Requests.hbs | 18 +- .../dev/restate/sdk/kotlin/ContextImpl.kt | 60 ++++-- .../main/kotlin/dev/restate/sdk/kotlin/api.kt | 78 ++++++-- .../dev/restate/sdk/kotlin/awaitables.kt | 24 ++- .../java/dev/restate/sdk/CallAwaitable.java | 13 +- .../main/java/dev/restate/sdk/Context.java | 45 +++-- .../java/dev/restate/sdk/ContextImpl.java | 95 +++++++--- .../dev/restate/sdk/InvocationHandle.java | 29 +++ .../main/java/dev/restate/sdk/SendHandle.java | 22 --- .../src/main/java/dev/restate/sdk/Util.java | 11 ++ .../endpoint/definition/HandlerContext.java | 8 +- .../{Request.java => HandlerRequest.java} | 4 +- .../restate/sdk/core/HandlerContextImpl.java | 146 +++++++------- .../core/statemachine/CommandAccessor.java | 6 + .../sdk/core/statemachine/StateMachine.java | 4 + .../core/statemachine/StateMachineImpl.java | 30 +++ .../restate/sdk/core/javaapi/CallTest.java | 4 +- .../sdk/core/javaapi/JavaAPITests.java | 4 +- .../restate/sdk/core/kotlinapi/CallTest.kt | 11 +- .../sdk/core/kotlinapi/KotlinAPITests.kt | 4 +- .../dev/restate/sdk/testservices/ProxyImpl.kt | 21 ++- .../restate/sdk/testservices/interpreter.kt | 14 +- 31 files changed, 603 insertions(+), 460 deletions(-) rename common/src/main/java/dev/restate/common/{CallRequest.java => Request.java} (76%) create mode 100644 sdk-api/src/main/java/dev/restate/sdk/InvocationHandle.java delete mode 100644 sdk-api/src/main/java/dev/restate/sdk/SendHandle.java rename sdk-common/src/main/java/dev/restate/sdk/types/{Request.java => HandlerRequest.java} (88%) diff --git a/client-kotlin/src/main/kotlin/dev/restate/client/kotlin/ingress.kt b/client-kotlin/src/main/kotlin/dev/restate/client/kotlin/ingress.kt index 10a71883b..f682e3416 100644 --- a/client-kotlin/src/main/kotlin/dev/restate/client/kotlin/ingress.kt +++ b/client-kotlin/src/main/kotlin/dev/restate/client/kotlin/ingress.kt @@ -12,9 +12,8 @@ import dev.restate.client.Client import dev.restate.client.ClientRequestOptions import dev.restate.client.ClientResponse import dev.restate.client.SendResponse -import dev.restate.common.CallRequest import dev.restate.common.Output -import dev.restate.common.SendRequest +import dev.restate.common.Request import dev.restate.serde.Serde import kotlinx.coroutines.future.await @@ -26,24 +25,26 @@ fun clientRequestOptions(init: ClientRequestOptions.Builder.() -> Unit): ClientR return builder.build() } -suspend fun Client.callSuspend(callRequest: CallRequest): ClientResponse { - return this.callAsync(callRequest).await() +suspend fun Client.callSuspend(request: Request): ClientResponse { + return this.callAsync(request).await() } suspend fun Client.callSuspend( - callRequestBuilder: CallRequest.Builder + requestBuilder: Request.Builder ): ClientResponse { - return this.callAsync(callRequestBuilder).await() + return this.callAsync(requestBuilder).await() } -suspend fun Client.sendSuspend(sendRequest: SendRequest): ClientResponse { - return this.sendAsync(sendRequest).await() +suspend fun Client.sendSuspend( + request: Request +): ClientResponse> { + return this.sendAsync(request).await() } -suspend fun Client.sendSuspend( - sendRequestBuilder: SendRequest.Builder -): ClientResponse { - return this.sendAsync(sendRequestBuilder).await() +suspend fun Client.sendSuspend( + request: Request.Builder +): ClientResponse> { + return this.sendSuspend(request.build()) } suspend fun Client.AwakeableHandle.resolveSuspend( diff --git a/client/src/main/java/dev/restate/client/Client.java b/client/src/main/java/dev/restate/client/Client.java index a9514dcb4..12cb1f7ec 100644 --- a/client/src/main/java/dev/restate/client/Client.java +++ b/client/src/main/java/dev/restate/client/Client.java @@ -8,9 +8,8 @@ // https://github.com/restatedev/sdk-java/blob/main/LICENSE package dev.restate.client; -import dev.restate.common.CallRequest; import dev.restate.common.Output; -import dev.restate.common.SendRequest; +import dev.restate.common.Request; import dev.restate.common.Target; import dev.restate.serde.Serde; import dev.restate.serde.SerdeFactory; @@ -21,15 +20,14 @@ public interface Client { - CompletableFuture> callAsync(CallRequest request); + CompletableFuture> callAsync(Request request); default CompletableFuture> callAsync( - CallRequest.Builder request) { + Request.Builder request) { return callAsync(request.build()); } - default ClientResponse call(CallRequest request) - throws IngressException { + default ClientResponse call(Request request) throws IngressException { try { return callAsync(request).join(); } catch (CompletionException e) { @@ -40,19 +38,15 @@ default ClientResponse call(CallRequest request) } } - default ClientResponse call(CallRequest.Builder request) + default ClientResponse call(Request.Builder request) throws IngressException { return call(request.build()); } - CompletableFuture> sendAsync(SendRequest request); - - default CompletableFuture> sendAsync( - SendRequest.Builder request) { - return sendAsync(request.build()); - } + CompletableFuture>> sendAsync( + Request request); - default ClientResponse send(SendRequest request) + default ClientResponse> send(Request request) throws IngressException { try { return sendAsync(request).join(); @@ -64,7 +58,12 @@ default ClientResponse send(SendRequest request) } } - default ClientResponse send(SendRequest.Builder request) + default CompletableFuture>> sendAsync( + Request.Builder request) { + return sendAsync(request.build()); + } + + default ClientResponse> send(Request.Builder request) throws IngressException { return send(request.build()); } diff --git a/client/src/main/java/dev/restate/client/SendResponse.java b/client/src/main/java/dev/restate/client/SendResponse.java index ded27ed9e..c9c5bae07 100644 --- a/client/src/main/java/dev/restate/client/SendResponse.java +++ b/client/src/main/java/dev/restate/client/SendResponse.java @@ -8,7 +8,7 @@ // https://github.com/restatedev/sdk-java/blob/main/LICENSE package dev.restate.client; -public record SendResponse(SendStatus status, String invocationId) { +public record SendResponse(SendStatus status, Client.InvocationHandle invocationHandle) { public enum SendStatus { /** The request was sent for the first time. */ ACCEPTED, diff --git a/client/src/main/java/dev/restate/client/base/BaseClient.java b/client/src/main/java/dev/restate/client/base/BaseClient.java index 4d0262233..6f268d2bc 100644 --- a/client/src/main/java/dev/restate/client/base/BaseClient.java +++ b/client/src/main/java/dev/restate/client/base/BaseClient.java @@ -49,10 +49,9 @@ protected BaseClient(URI baseUri, SerdeFactory serdeFactory, ClientRequestOption } @Override - public CompletableFuture> callAsync( - CallRequest request) { - Serde reqSerde = this.serdeFactory.create(request.requestSerdeInfo()); - Serde resSerde = this.serdeFactory.create(request.responseSerdeInfo()); + public CompletableFuture> callAsync(Request request) { + Serde reqSerde = this.serdeFactory.create(request.requestTypeTag()); + Serde resSerde = this.serdeFactory.create(request.responseTypeTag()); URI requestUri = toRequestURI(request.target(), false, null); Stream> headersStream = @@ -75,10 +74,15 @@ public CompletableFuture> callAsync( } @Override - public CompletableFuture> sendAsync(SendRequest request) { - Serde reqSerde = this.serdeFactory.create(request.requestSerdeInfo()); - - URI requestUri = toRequestURI(request.target(), true, request.delay()); + public CompletableFuture>> sendAsync( + Request request) { + Serde reqSerde = this.serdeFactory.create(request.requestTypeTag()); + + URI requestUri = + toRequestURI( + request.target(), + true, + (request instanceof SendRequest sendRequest) ? sendRequest.delay() : null); Stream> headersStream = Stream.concat( baseOptions.headers().entrySet().stream(), request.headers().entrySet().stream()); @@ -146,7 +150,10 @@ public CompletableFuture> sendAsync(SendReque } return new ClientResponse<>( - statusCode, responseHeaders, new SendResponse(status, fields.get("invocationId"))); + statusCode, + responseHeaders, + new SendResponse<>( + status, invocationHandle(fields.get("invocationId"), request.responseTypeTag()))); }); } @@ -199,6 +206,8 @@ public CompletableFuture> rejectAsync( @Override public InvocationHandle invocationHandle( String invocationId, TypeTag resTypeTag) { + Serde resSerde = serdeFactory.create(resTypeTag); + return new InvocationHandle<>() { @Override public String invocationId() { @@ -207,8 +216,6 @@ public String invocationId() { @Override public CompletableFuture> attachAsync(ClientRequestOptions options) { - Serde resSerde = serdeFactory.create(resTypeTag); - URI requestUri = baseUri.resolve("/restate/invocation/" + invocationId + "/attach"); Stream> headersStream = Stream.concat( @@ -221,8 +228,6 @@ public CompletableFuture> attachAsync(ClientRequestOptions o @Override public CompletableFuture>> getOutputAsync( ClientRequestOptions options) { - Serde resSerde = serdeFactory.create(resTypeTag); - URI requestUri = baseUri.resolve("/restate/invocation/" + invocationId + "/output"); Stream> headersStream = Stream.concat( diff --git a/common/src/main/java/dev/restate/common/CallRequest.java b/common/src/main/java/dev/restate/common/Request.java similarity index 76% rename from common/src/main/java/dev/restate/common/CallRequest.java rename to common/src/main/java/dev/restate/common/Request.java index 12529c75b..3e15df886 100644 --- a/common/src/main/java/dev/restate/common/CallRequest.java +++ b/common/src/main/java/dev/restate/common/Request.java @@ -10,12 +10,13 @@ import dev.restate.serde.Serde; import dev.restate.serde.TypeTag; +import java.time.Duration; import java.util.LinkedHashMap; import java.util.Map; import java.util.Objects; import org.jspecify.annotations.Nullable; -public final class CallRequest { +public sealed class Request permits SendRequest { private final Target target; private final TypeTag reqTypeTag; @@ -24,7 +25,7 @@ public final class CallRequest { @Nullable private final String idempotencyKey; @Nullable private final LinkedHashMap headers; - private CallRequest( + Request( Target target, TypeTag reqTypeTag, TypeTag resTypeTag, @@ -43,11 +44,11 @@ public Target target() { return target; } - public TypeTag requestSerdeInfo() { + public TypeTag requestTypeTag() { return reqTypeTag; } - public TypeTag responseSerdeInfo() { + public TypeTag responseTypeTag() { return resTypeTag; } @@ -71,15 +72,15 @@ public static Builder of( return new Builder<>(target, reqTypeTag, resTypeTag, request); } - public static Builder withNoRequestBody(Target target, TypeTag resTypeTag) { - return new Builder<>(target, Serde.VOID, resTypeTag, null); - } - public static Builder withNoResponseBody( Target target, TypeTag reqTypeTag, Req request) { return new Builder<>(target, reqTypeTag, Serde.VOID, request); } + public static Builder withNoRequestBody(Target target, TypeTag resTypeTag) { + return new Builder<>(target, Serde.VOID, resTypeTag, null); + } + public static Builder ofRaw(Target target, byte[] request) { return new Builder<>(target, TypeTag.of(Serde.RAW), TypeTag.of(Serde.RAW), request); } @@ -92,6 +93,21 @@ public static final class Builder { @Nullable private String idempotencyKey; @Nullable private LinkedHashMap headers; + public Builder( + Target target, + TypeTag reqTypeTag, + TypeTag resTypeTag, + Req request, + @Nullable String idempotencyKey, + @Nullable LinkedHashMap headers) { + this.target = target; + this.reqTypeTag = reqTypeTag; + this.resTypeTag = resTypeTag; + this.request = request; + this.idempotencyKey = idempotencyKey; + this.headers = headers; + } + private Builder(Target target, TypeTag reqTypeTag, TypeTag resTypeTag, Req request) { this.target = target; this.reqTypeTag = reqTypeTag; @@ -149,8 +165,18 @@ public Builder setHeaders(@Nullable Map headers) { return headers(headers); } - public CallRequest build() { - return new CallRequest<>( + public SendRequest asSend() { + return new SendRequest<>( + target, reqTypeTag, resTypeTag, request, idempotencyKey, headers, null); + } + + public SendRequest asSendDelayed(Duration delay) { + return new SendRequest<>( + target, reqTypeTag, resTypeTag, request, idempotencyKey, headers, delay); + } + + public Request build() { + return new Request<>( this.target, this.reqTypeTag, this.resTypeTag, @@ -160,9 +186,29 @@ public CallRequest build() { } } + public Builder toBuilder() { + return new Builder<>( + this.target, + this.reqTypeTag, + this.resTypeTag, + this.request, + this.idempotencyKey, + this.headers); + } + + public SendRequest asSend() { + return new SendRequest<>( + target, reqTypeTag, resTypeTag, request, idempotencyKey, headers, null); + } + + public SendRequest asSendDelayed(Duration delay) { + return new SendRequest<>( + target, reqTypeTag, resTypeTag, request, idempotencyKey, headers, delay); + } + @Override public boolean equals(Object o) { - if (!(o instanceof CallRequest that)) return false; + if (!(o instanceof Request that)) return false; return Objects.equals(target, that.target) && Objects.equals(reqTypeTag, that.reqTypeTag) && Objects.equals(resTypeTag, that.resTypeTag) diff --git a/common/src/main/java/dev/restate/common/SendRequest.java b/common/src/main/java/dev/restate/common/SendRequest.java index c97122d88..61f908bfa 100644 --- a/common/src/main/java/dev/restate/common/SendRequest.java +++ b/common/src/main/java/dev/restate/common/SendRequest.java @@ -8,196 +8,72 @@ // https://github.com/restatedev/sdk-java/blob/main/LICENSE package dev.restate.common; -import dev.restate.serde.Serde; import dev.restate.serde.TypeTag; import java.time.Duration; import java.util.LinkedHashMap; -import java.util.Map; import java.util.Objects; import org.jspecify.annotations.Nullable; -public final class SendRequest { +public final class SendRequest extends Request { - private final Target target; - private final TypeTag reqTypeTag; - private final Req request; - @Nullable private final String idempotencyKey; - @Nullable private final LinkedHashMap headers; @Nullable private final Duration delay; - private SendRequest( + SendRequest( Target target, TypeTag reqTypeTag, + TypeTag resTypeTag, Req request, @Nullable String idempotencyKey, @Nullable LinkedHashMap headers, @Nullable Duration delay) { - this.target = target; - this.reqTypeTag = reqTypeTag; - this.request = request; - this.idempotencyKey = idempotencyKey; - this.headers = headers; + super(target, reqTypeTag, resTypeTag, request, idempotencyKey, headers); this.delay = delay; } - public Target target() { - return target; - } - - public TypeTag requestSerdeInfo() { - return reqTypeTag; - } - - public Req request() { - return request; - } - - public @Nullable String idempotencyKey() { - return idempotencyKey; - } - - public Map headers() { - if (headers == null) { - return Map.of(); - } - return headers; - } - public @Nullable Duration delay() { return delay; } - public static Builder of(Target target, TypeTag reqTypeTag, Req request) { - return new Builder<>(target, reqTypeTag, request); - } - - public static Builder withNoRequestBody(Target target) { - return new Builder<>(target, Serde.VOID, null); - } - - public static Builder ofRaw(Target target, byte[] request) { - return new Builder<>(target, TypeTag.of(Serde.RAW), request); - } - - public static final class Builder { - private final Target target; - private final TypeTag reqTypeTag; - private final Req request; - @Nullable private String idempotencyKey; - @Nullable private LinkedHashMap headers; - @Nullable private Duration delay; - - private Builder(Target target, TypeTag reqTypeTag, Req request) { - this.target = target; - this.reqTypeTag = reqTypeTag; - this.request = request; - } - - /** - * @param idempotencyKey Idempotency key to attach in the request. - * @return this instance, so the builder can be used fluently. - */ - public Builder idempotencyKey(String idempotencyKey) { - this.idempotencyKey = idempotencyKey; - return this; - } - - /** - * @param key header key - * @param value header value - * @return this instance, so the builder can be used fluently. - */ - public Builder header(String key, String value) { - if (this.headers == null) { - this.headers = new LinkedHashMap<>(); - } - this.headers.put(key, value); - return this; - } - - /** - * @param newHeaders headers to send together with the request. - * @return this instance, so the builder can be used fluently. - */ - public Builder headers(Map newHeaders) { - if (this.headers == null) { - this.headers = new LinkedHashMap<>(); - } - this.headers.putAll(newHeaders); - return this; - } - - /** - * @param delay time to wait before executing the call. The time is waited by Restate, and not - * by this service. - * @return this instance, so the builder can be used fluently. - */ - public Builder delay(Duration delay) { - this.delay = delay; - return this; - } - - public @Nullable String getIdempotencyKey() { - return idempotencyKey; - } - - public Builder setIdempotencyKey(@Nullable String idempotencyKey) { - return idempotencyKey(idempotencyKey); - } - - public @Nullable Map getHeaders() { - return headers; - } - - public Builder setHeaders(@Nullable Map headers) { - return headers(headers); - } - - public @Nullable Duration delay() { - return delay; - } - - public SendRequest build() { - return new SendRequest<>( - this.target, - this.reqTypeTag, - this.request, - this.idempotencyKey, - this.headers, - this.delay); - } - } - @Override public boolean equals(Object o) { - if (!(o instanceof SendRequest that)) return false; - return Objects.equals(target, that.target) - && Objects.equals(reqTypeTag, that.reqTypeTag) - && Objects.equals(request, that.request) - && Objects.equals(idempotencyKey, that.idempotencyKey) - && Objects.equals(headers, that.headers) + if (!(o instanceof SendRequest that)) return false; + return Objects.equals(target(), that.target()) + && Objects.equals(requestTypeTag(), that.requestTypeTag()) + && Objects.equals(responseTypeTag(), that.responseTypeTag()) + && Objects.equals(request(), that.request()) + && Objects.equals(idempotencyKey(), that.idempotencyKey()) + && Objects.equals(headers(), that.headers()) && Objects.equals(delay, that.delay); } @Override public int hashCode() { - return Objects.hash(target, reqTypeTag, request, idempotencyKey, headers, delay); + return Objects.hash( + target(), + requestTypeTag(), + responseTypeTag(), + request(), + idempotencyKey(), + headers(), + delay); } @Override public String toString() { - return "SendRequest{" + return "CallRequest{" + "target=" - + target + + target() + ", reqSerdeInfo=" - + reqTypeTag + + requestTypeTag() + + ", resSerdeInfo=" + + responseTypeTag() + ", request=" - + request + + request() + ", idempotencyKey='" - + idempotencyKey + + idempotencyKey() + '\'' + ", headers=" - + headers + + headers() + ", delay=" + delay + '}'; diff --git a/sdk-api-gen/src/main/resources/templates/Client.hbs b/sdk-api-gen/src/main/resources/templates/Client.hbs index c09e037c5..56134f365 100644 --- a/sdk-api-gen/src/main/resources/templates/Client.hbs +++ b/sdk-api-gen/src/main/resources/templates/Client.hbs @@ -67,9 +67,14 @@ public class {{generatedClassSimpleName}} { } {{#handlers}} - public dev.restate.sdk.SendHandle {{methodName}}({{^inputEmpty}}{{{inputFqcn}}} req{{/inputEmpty}}) { + public dev.restate.sdk.InvocationHandle<{{{boxedOutputFqcn}}}> {{methodName}}({{^inputEmpty}}{{{inputFqcn}}} req{{/inputEmpty}}) { + if (this.delay == null) { + return ContextClient.this.ctx.send( + {{../requestsClass}}.{{methodName}}({{#if ../isKeyed}}ContextClient.this.key{{^inputEmpty}}, {{/inputEmpty}}{{/if}}{{^inputEmpty}}req{{/inputEmpty}}).asSend() + ); + } return ContextClient.this.ctx.send( - {{../requestsClass}}.Send.{{methodName}}({{#if ../isKeyed}}ContextClient.this.key{{^inputEmpty}}, {{/inputEmpty}}{{/if}}{{^inputEmpty}}req{{/inputEmpty}}).delay(this.delay) + {{../requestsClass}}.{{methodName}}({{#if ../isKeyed}}ContextClient.this.key{{^inputEmpty}}, {{/inputEmpty}}{{/if}}{{^inputEmpty}}req{{/inputEmpty}}).asSendDelayed(this.delay) ); } {{/handlers}} @@ -98,13 +103,13 @@ public class {{generatedClassSimpleName}} { public dev.restate.client.SendResponse submit({{^inputEmpty}}{{{inputFqcn}}} req{{/inputEmpty}}) { return IngressClient.this.client.send( - {{../requestsClass}}.Send.{{methodName}}({{#if ../isKeyed}}this.key{{^inputEmpty}}, {{/inputEmpty}}{{/if}}{{^inputEmpty}}req{{/inputEmpty}}) + {{../requestsClass}}.{{methodName}}({{#if ../isKeyed}}this.key{{^inputEmpty}}, {{/inputEmpty}}{{/if}}{{^inputEmpty}}req{{/inputEmpty}}).asSend() ).response(); } public java.util.concurrent.CompletableFuture submitAsync({{^inputEmpty}}{{{inputFqcn}}} req{{/inputEmpty}}) { return IngressClient.this.client.sendAsync( - {{../requestsClass}}.Send.{{methodName}}({{#if ../isKeyed}}this.key{{^inputEmpty}}, {{/inputEmpty}}{{/if}}{{^inputEmpty}}req{{/inputEmpty}}) + {{../requestsClass}}.{{methodName}}({{#if ../isKeyed}}this.key{{^inputEmpty}}, {{/inputEmpty}}{{/if}}{{^inputEmpty}}req{{/inputEmpty}}).asSend() ).thenApply(dev.restate.client.ClientResponse::response); } {{else}} @@ -139,14 +144,24 @@ public class {{generatedClassSimpleName}} { {{#handlers}}{{^isWorkflow}} public dev.restate.client.SendResponse {{methodName}}({{^inputEmpty}}{{{inputFqcn}}} req{{/inputEmpty}}) { + if (this.delay == null) { + return IngressClient.this.client.send( + {{../requestsClass}}.{{methodName}}({{#if ../isKeyed}}IngressClient.this.key{{^inputEmpty}}, {{/inputEmpty}}{{/if}}{{^inputEmpty}}req{{/inputEmpty}}).asSend() + ).response(); + } return IngressClient.this.client.send( - {{../requestsClass}}.Send.{{methodName}}({{#if ../isKeyed}}IngressClient.this.key{{^inputEmpty}}, {{/inputEmpty}}{{/if}}{{^inputEmpty}}req{{/inputEmpty}}).delay(this.delay) + {{../requestsClass}}.{{methodName}}({{#if ../isKeyed}}IngressClient.this.key{{^inputEmpty}}, {{/inputEmpty}}{{/if}}{{^inputEmpty}}req{{/inputEmpty}}).asSendDelayed(this.delay) ).response(); } public java.util.concurrent.CompletableFuture {{methodName}}Async({{^inputEmpty}}{{{inputFqcn}}} req{{/inputEmpty}}) { + if (this.delay == null) { + return IngressClient.this.client.sendAsync( + {{../requestsClass}}.{{methodName}}({{#if ../isKeyed}}IngressClient.this.key{{^inputEmpty}}, {{/inputEmpty}}{{/if}}{{^inputEmpty}}req{{/inputEmpty}}).asSend() + ).thenApply(dev.restate.client.ClientResponse::response); + } return IngressClient.this.client.sendAsync( - {{../requestsClass}}.Send.{{methodName}}({{#if ../isKeyed}}IngressClient.this.key{{^inputEmpty}}, {{/inputEmpty}}{{/if}}{{^inputEmpty}}req{{/inputEmpty}}).delay(this.delay) + {{../requestsClass}}.{{methodName}}({{#if ../isKeyed}}IngressClient.this.key{{^inputEmpty}}, {{/inputEmpty}}{{/if}}{{^inputEmpty}}req{{/inputEmpty}}).asSendDelayed(this.delay) ).thenApply(dev.restate.client.ClientResponse::response); }{{/isWorkflow}}{{/handlers}} } diff --git a/sdk-api-gen/src/main/resources/templates/Requests.hbs b/sdk-api-gen/src/main/resources/templates/Requests.hbs index a9f1617d5..21063900f 100644 --- a/sdk-api-gen/src/main/resources/templates/Requests.hbs +++ b/sdk-api-gen/src/main/resources/templates/Requests.hbs @@ -5,8 +5,8 @@ public final class {{generatedClassSimpleName}} { private {{generatedClassSimpleName}}() {} {{#handlers}} - public static dev.restate.common.CallRequest.Builder<{{{boxedInputFqcn}}}, {{{boxedOutputFqcn}}}> {{methodName}}({{#if ../isKeyed}}String key{{^inputEmpty}}, {{/inputEmpty}}{{/if}}{{^inputEmpty}}{{{inputFqcn}}} req{{/inputEmpty}}) { - return dev.restate.common.CallRequest.of( + public static dev.restate.common.Request.Builder<{{{boxedInputFqcn}}}, {{{boxedOutputFqcn}}}> {{methodName}}({{#if ../isKeyed}}String key{{^inputEmpty}}, {{/inputEmpty}}{{/if}}{{^inputEmpty}}{{{inputFqcn}}} req{{/inputEmpty}}) { + return dev.restate.common.Request.of( {{{targetExpr this "key"}}}, {{inputSerdeRef}}, {{outputSerdeRef}}, @@ -14,18 +14,4 @@ public final class {{generatedClassSimpleName}} { } {{/handlers}} - - public final static class Send { - private Send() {} - - {{#handlers}} - public static dev.restate.common.SendRequest.Builder<{{{boxedInputFqcn}}}> {{methodName}}({{#if ../isKeyed}}String key{{^inputEmpty}}, {{/inputEmpty}}{{/if}}{{^inputEmpty}}{{{inputFqcn}}} req{{/inputEmpty}}) { - return dev.restate.common.SendRequest.of( - {{{targetExpr this "key"}}}, - {{inputSerdeRef}}, - {{#if inputEmpty}}null{{else}}req{{/if}}); - } - - {{/handlers}} - } } \ No newline at end of file diff --git a/sdk-api-kotlin-gen/src/main/resources/templates/Client.hbs b/sdk-api-kotlin-gen/src/main/resources/templates/Client.hbs index 50f9554b2..753051b96 100644 --- a/sdk-api-kotlin-gen/src/main/resources/templates/Client.hbs +++ b/sdk-api-kotlin-gen/src/main/resources/templates/Client.hbs @@ -1,13 +1,17 @@ {{#if originalClassPkg}}package {{originalClassPkg}};{{/if}} +{{#contextClientEnabled}} import dev.restate.sdk.kotlin.CallAwaitable -import dev.restate.sdk.kotlin.SendHandle +import dev.restate.sdk.kotlin.InvocationHandle import dev.restate.sdk.kotlin.Context -import dev.restate.sdk.types.StateKey +import dev.restate.sdk.kotlin.asSendDelayed +{{/contextClientEnabled}} import dev.restate.serde.Serde import dev.restate.common.Target import kotlin.time.Duration +{{#ingressClientEnabled}} import dev.restate.client.kotlin.* +{{/ingressClientEnabled}} object {{generatedClassSimpleName}} { @@ -30,7 +34,7 @@ object {{generatedClassSimpleName}} { {{#contextClientEnabled}} class ContextClient(private val ctx: Context{{#isKeyed}}, private val key: String{{/isKeyed}}){ {{#handlers}} - suspend fun {{methodName}}({{^inputEmpty}}req: {{{inputFqcn}}}, {{/inputEmpty}}init: dev.restate.common.CallRequest.Builder<{{{boxedInputFqcn}}}, {{{boxedOutputFqcn}}}>.() -> Unit = {}): CallAwaitable<{{{boxedOutputFqcn}}}> { + suspend fun {{methodName}}({{^inputEmpty}}req: {{{inputFqcn}}}, {{/inputEmpty}}init: dev.restate.common.Request.Builder<{{{boxedInputFqcn}}}, {{{boxedOutputFqcn}}}>.() -> Unit = {}): CallAwaitable<{{{boxedOutputFqcn}}}> { return this.ctx.call( {{../requestsClass}}.{{methodName}}({{#if ../isKeyed}}this.key, {{/if}}{{^inputEmpty}}req, {{/inputEmpty}}init) ) @@ -42,9 +46,14 @@ object {{generatedClassSimpleName}} { inner class Send internal constructor() { {{#handlers}} - suspend fun {{methodName}}({{^inputEmpty}}req: {{{inputFqcn}}}, {{/inputEmpty}}init: dev.restate.common.SendRequest.Builder<{{{boxedInputFqcn}}}>.() -> Unit = {}): SendHandle { + suspend fun {{methodName}}({{^inputEmpty}}req: {{{inputFqcn}}}, {{/inputEmpty}}delay: Duration? = null, init: dev.restate.common.Request.Builder<{{{boxedInputFqcn}}}, {{{boxedOutputFqcn}}}>.() -> Unit = {}): InvocationHandle<{{{boxedOutputFqcn}}}> { + if (delay != null) { + return this@ContextClient.ctx.send( + {{../requestsClass}}.{{methodName}}({{#if ../isKeyed}}this@ContextClient.key, {{/if}}{{^inputEmpty}}req, {{/inputEmpty}}init).asSendDelayed(delay) + ); + } return this@ContextClient.ctx.send( - {{../requestsClass}}.Send.{{methodName}}({{#if ../isKeyed}}this@ContextClient.key, {{/if}}{{^inputEmpty}}req, {{/inputEmpty}}init) + {{../requestsClass}}.{{methodName}}({{#if ../isKeyed}}this@ContextClient.key, {{/if}}{{^inputEmpty}}req, {{/inputEmpty}}init) ); }{{/handlers}} } @@ -62,13 +71,13 @@ object {{generatedClassSimpleName}} { {{outputSerdeRef}}); } - suspend fun submit({{^inputEmpty}}req: {{{inputFqcn}}}, {{/inputEmpty}}init: dev.restate.common.SendRequest.Builder<{{{boxedInputFqcn}}}>.() -> Unit = {}): dev.restate.client.SendResponse { + suspend fun submit({{^inputEmpty}}req: {{{inputFqcn}}}, {{/inputEmpty}}init: dev.restate.common.Request.Builder<{{{boxedInputFqcn}}}, {{{boxedOutputFqcn}}}>.() -> Unit = {}): dev.restate.client.SendResponse<{{{boxedOutputFqcn}}}> { return this@IngressClient.client.sendSuspend( - {{../requestsClass}}.Send.{{methodName}}({{#if ../isKeyed}}this.key, {{/if}}{{^inputEmpty}}req, {{/inputEmpty}}init) + {{../requestsClass}}.{{methodName}}({{#if ../isKeyed}}this.key, {{/if}}{{^inputEmpty}}req, {{/inputEmpty}}init) ).response(); } {{else}} - suspend fun {{methodName}}({{^inputEmpty}}req: {{{inputFqcn}}}, {{/inputEmpty}}init: dev.restate.common.CallRequest.Builder<{{{boxedInputFqcn}}}, {{{boxedOutputFqcn}}}>.() -> Unit = {}): {{{boxedOutputFqcn}}} { + suspend fun {{methodName}}({{^inputEmpty}}req: {{{inputFqcn}}}, {{/inputEmpty}}init: dev.restate.common.Request.Builder<{{{boxedInputFqcn}}}, {{{boxedOutputFqcn}}}>.() -> Unit = {}): {{{boxedOutputFqcn}}} { return this@IngressClient.client.callSuspend( {{../requestsClass}}.{{methodName}}({{#if ../isKeyed}}this.key, {{/if}}{{^inputEmpty}}req, {{/inputEmpty}}init) ).response(); @@ -81,9 +90,14 @@ object {{generatedClassSimpleName}} { inner class Send() { {{#handlers}}{{^isWorkflow}} - suspend fun {{methodName}}({{^inputEmpty}}req: {{{inputFqcn}}}, {{/inputEmpty}}init: dev.restate.common.SendRequest.Builder<{{{boxedInputFqcn}}}>.() -> Unit = {}): dev.restate.client.SendResponse { + suspend fun {{methodName}}({{^inputEmpty}}req: {{{inputFqcn}}}, {{/inputEmpty}}delay: Duration? = null, init: dev.restate.common.Request.Builder<{{{boxedInputFqcn}}}, {{{boxedOutputFqcn}}}>.() -> Unit = {}): dev.restate.client.SendResponse<{{{boxedOutputFqcn}}}> { + if (delay != null) { + return this@IngressClient.client.sendSuspend( + {{../requestsClass}}.{{methodName}}({{#if ../isKeyed}}this@IngressClient.key, {{/if}}{{^inputEmpty}}req, {{/inputEmpty}}init).asSendDelayed(delay) + ).response(); + } return this@IngressClient.client.sendSuspend( - {{../requestsClass}}.Send.{{methodName}}({{#if ../isKeyed}}this@IngressClient.key, {{/if}}{{^inputEmpty}}req, {{/inputEmpty}}init) + {{../requestsClass}}.{{methodName}}({{#if ../isKeyed}}this@IngressClient.key, {{/if}}{{^inputEmpty}}req, {{/inputEmpty}}init) ).response(); }{{/isWorkflow}}{{/handlers}} } diff --git a/sdk-api-kotlin-gen/src/main/resources/templates/Requests.hbs b/sdk-api-kotlin-gen/src/main/resources/templates/Requests.hbs index 1a4aa77dd..be07e022e 100644 --- a/sdk-api-kotlin-gen/src/main/resources/templates/Requests.hbs +++ b/sdk-api-kotlin-gen/src/main/resources/templates/Requests.hbs @@ -3,8 +3,8 @@ object {{generatedClassSimpleName}} { {{#handlers}} - fun {{methodName}}({{#if ../isKeyed}}key: String, {{/if}}{{^inputEmpty}}req: {{{inputFqcn}}}, {{/inputEmpty}}init: dev.restate.common.CallRequest.Builder<{{{boxedInputFqcn}}}, {{{boxedOutputFqcn}}}>.() -> Unit = {}): dev.restate.common.CallRequest<{{{boxedInputFqcn}}}, {{{boxedOutputFqcn}}}> { - val builder = dev.restate.common.CallRequest.of<{{{boxedInputFqcn}}}, {{{boxedOutputFqcn}}}>( + fun {{methodName}}({{#if ../isKeyed}}key: String, {{/if}}{{^inputEmpty}}req: {{{inputFqcn}}}, {{/inputEmpty}}init: dev.restate.common.Request.Builder<{{{boxedInputFqcn}}}, {{{boxedOutputFqcn}}}>.() -> Unit = {}): dev.restate.common.Request<{{{boxedInputFqcn}}}, {{{boxedOutputFqcn}}}> { + val builder = dev.restate.common.Request.of<{{{boxedInputFqcn}}}, {{{boxedOutputFqcn}}}>( {{{targetExpr this "key"}}}, {{inputSerdeRef}}, {{outputSerdeRef}}, @@ -14,18 +14,4 @@ object {{generatedClassSimpleName}} { } {{/handlers}} - - object Send { - {{#handlers}} - fun {{methodName}}({{#if ../isKeyed}}key: String, {{/if}}{{^inputEmpty}}req: {{{inputFqcn}}}, {{/inputEmpty}}init: dev.restate.common.SendRequest.Builder<{{{boxedInputFqcn}}}>.() -> Unit = {}): dev.restate.common.SendRequest<{{{boxedInputFqcn}}}> { - val builder = dev.restate.common.SendRequest.of<{{{boxedInputFqcn}}}>( - {{{targetExpr this "key"}}}, - {{inputSerdeRef}}, - {{#if inputEmpty}}null{{else}}req{{/if}}); - builder.init() - return builder.build() - } - - {{/handlers}} - } } \ No newline at end of file diff --git a/sdk-api-kotlin/src/main/kotlin/dev/restate/sdk/kotlin/ContextImpl.kt b/sdk-api-kotlin/src/main/kotlin/dev/restate/sdk/kotlin/ContextImpl.kt index f9fdb6a68..5620e1d7d 100644 --- a/sdk-api-kotlin/src/main/kotlin/dev/restate/sdk/kotlin/ContextImpl.kt +++ b/sdk-api-kotlin/src/main/kotlin/dev/restate/sdk/kotlin/ContextImpl.kt @@ -8,13 +8,13 @@ // https://github.com/restatedev/sdk-java/blob/main/LICENSE package dev.restate.sdk.kotlin -import dev.restate.common.CallRequest import dev.restate.common.Output +import dev.restate.common.Request import dev.restate.common.SendRequest import dev.restate.common.Slice import dev.restate.sdk.endpoint.definition.HandlerContext import dev.restate.sdk.types.DurablePromiseKey -import dev.restate.sdk.types.Request +import dev.restate.sdk.types.HandlerRequest import dev.restate.sdk.types.StateKey import dev.restate.sdk.types.TerminalException import dev.restate.serde.Serde @@ -36,7 +36,7 @@ internal constructor( return this.handlerContext.objectKey() } - override fun request(): Request { + override fun request(): HandlerRequest { return this.handlerContext.request() } @@ -67,35 +67,55 @@ internal constructor( override suspend fun timer(duration: Duration, name: String?): Awaitable = SingleAwaitableImpl(handlerContext.timer(duration.toJavaDuration(), name).await()).map {} - override suspend fun call(callRequest: CallRequest): CallAwaitable = - resolveSerde(callRequest.responseSerdeInfo()).let { responseSerde -> + override suspend fun call( + request: Request + ): CallAwaitable = + resolveSerde(request.responseTypeTag()).let { responseSerde -> val callHandle = handlerContext .call( - callRequest.target(), - resolveAndSerialize(callRequest.requestSerdeInfo(), callRequest.request()), - callRequest.idempotencyKey(), - callRequest.headers().entries) + request.target(), + resolveAndSerialize(request.requestTypeTag(), request.request()), + request.idempotencyKey(), + request.headers().entries) .await() val callAsyncResult = callHandle.callAsyncResult.map { - CompletableFuture.completedFuture(responseSerde.deserialize(it)) + CompletableFuture.completedFuture(responseSerde.deserialize(it)) } return@let CallAwaitableImpl(callAsyncResult, callHandle.invocationIdAsyncResult) } - override suspend fun send(sendRequest: SendRequest): SendHandle = - SendHandleImpl( - handlerContext - .send( - sendRequest.target(), - resolveAndSerialize(sendRequest.requestSerdeInfo(), sendRequest.request()), - sendRequest.idempotencyKey(), - sendRequest.headers().entries, - sendRequest.delay()) - .await()) + override suspend fun send( + request: Request + ): InvocationHandle = + resolveSerde(request.responseTypeTag()).let { responseSerde -> + val invocationIdAsyncResult = + handlerContext + .send( + request.target(), + resolveAndSerialize(request.requestTypeTag(), request.request()), + request.idempotencyKey(), + request.headers().entries, + (request as? SendRequest)?.delay()) + .await() + + object : BaseInvocationHandle(handlerContext, responseSerde) { + override suspend fun invocationId(): String = invocationIdAsyncResult.poll().await() + } + } + + override fun invocationHandle( + invocationId: String, + responseTypeTag: TypeTag + ): InvocationHandle = + resolveSerde(responseTypeTag).let { responseSerde -> + object : BaseInvocationHandle(handlerContext, responseSerde) { + override suspend fun invocationId(): String = invocationId + } + } override suspend fun runAsync( typeTag: TypeTag, diff --git a/sdk-api-kotlin/src/main/kotlin/dev/restate/sdk/kotlin/api.kt b/sdk-api-kotlin/src/main/kotlin/dev/restate/sdk/kotlin/api.kt index 431127690..b3224170b 100644 --- a/sdk-api-kotlin/src/main/kotlin/dev/restate/sdk/kotlin/api.kt +++ b/sdk-api-kotlin/src/main/kotlin/dev/restate/sdk/kotlin/api.kt @@ -8,12 +8,12 @@ // https://github.com/restatedev/sdk-java/blob/main/LICENSE package dev.restate.sdk.kotlin -import dev.restate.common.CallRequest import dev.restate.common.Output +import dev.restate.common.Request import dev.restate.common.SendRequest import dev.restate.sdk.kotlin.serialization.typeTag import dev.restate.sdk.types.DurablePromiseKey -import dev.restate.sdk.types.Request +import dev.restate.sdk.types.HandlerRequest import dev.restate.sdk.types.StateKey import dev.restate.sdk.types.TerminalException import dev.restate.serde.TypeTag @@ -21,7 +21,6 @@ import java.util.* import kotlin.random.Random import kotlin.time.Duration import kotlin.time.toJavaDuration -import kotlin.time.toKotlinDuration /** * This interface exposes the Restate functionalities to Restate services. It can be used to @@ -36,7 +35,7 @@ import kotlin.time.toKotlinDuration */ sealed interface Context { - fun request(): Request + fun request(): HandlerRequest /** * Causes the current execution of the function invocation to sleep for the given duration. @@ -66,7 +65,7 @@ sealed interface Context { * @param callOptions request options. * @return a [CallAwaitable] that wraps the result. */ - suspend fun call(callRequest: CallRequest): CallAwaitable + suspend fun call(request: Request): CallAwaitable /** * Invoke another Restate service method. @@ -78,10 +77,10 @@ sealed interface Context { * @param callOptions request options. * @return a [CallAwaitable] that wraps the result. */ - suspend fun call( - callRequestBuilder: CallRequest.Builder - ): CallAwaitable { - return call(callRequestBuilder.build()) + suspend fun call( + requestBuilder: Request.Builder + ): CallAwaitable { + return call(requestBuilder.build()) } /** @@ -93,7 +92,7 @@ sealed interface Context { * @param sendOptions request options. * @return a [SendHandle] to interact with the sent request. */ - suspend fun send(sendRequest: SendRequest): SendHandle + suspend fun send(request: Request): InvocationHandle /** * Invoke another Restate service without waiting for the response. @@ -104,10 +103,24 @@ sealed interface Context { * @param sendOptions request options. * @return a [SendHandle] to interact with the sent request. */ - suspend fun send(sendRequestBuilder: SendRequest.Builder): SendHandle { + suspend fun send( + sendRequestBuilder: Request.Builder + ): InvocationHandle { return send(sendRequestBuilder.build()) } + /** + * Get an [InvocationHandle] for an already existing invocation. This will let you interact with a + * running invocation, for example to cancel it or retrieve its result. + * + * @param invocationId The invocation to interact with. + * @param responseClazz The response class. + */ + fun invocationHandle( + invocationId: String, + responseTypeTag: TypeTag + ): InvocationHandle + /** * Execute a non-deterministic closure, recording the result value in the journal. The result * value will be re-played in case of re-invocation (e.g. because of failure recovery or @@ -207,6 +220,19 @@ sealed interface Context { fun random(): RestateRandom } +/** + * Get an [InvocationHandle] for an already existing invocation. This will let you interact with a + * running invocation, for example to cancel it or retrieve its result. + * + * @param invocationId The invocation to interact with. + * @param responseClazz The response class. + */ +inline fun Context.invocationHandle( + invocationId: String +): InvocationHandle { + return this.invocationHandle(invocationId, typeTag()) +} + /** * Execute a non-deterministic closure, recording the result value in the journal. The result value * will be re-played in case of re-invocation (e.g. because of failure recovery or suspension point) @@ -506,9 +532,19 @@ sealed interface CallAwaitable : Awaitable { suspend fun invocationId(): String } -/** The handle returned by a [Context.send]. */ -sealed interface SendHandle { +/** An invocation handle, that can be used to interact with a running invocation. */ +sealed interface InvocationHandle { + /** @return the invocation id of this invocation */ suspend fun invocationId(): String + + /** Cancel this invocation. */ + suspend fun cancel() + + /** Attach to this invocation. This will wait for the invocation to complete */ + suspend fun attach(): Awaitable + + /** @return the output of this invocation, if present. */ + suspend fun output(): Output } /** @@ -605,8 +641,14 @@ inline fun durablePromiseKey(name: String): DurablePromiseKey { return DurablePromiseKey.of(name, typeTag()) } -var SendRequest.Builder.delay: Duration? - get() = this.delay()?.toKotlinDuration() - set(value) { - this.delay(value?.toJavaDuration()) - } +fun Request.Builder.asSendDelayed( + duration: Duration +): SendRequest { + return this.asSendDelayed(duration.toJavaDuration()) +} + +fun Request.asSendDelayed( + duration: Duration +): SendRequest { + return this.asSendDelayed(duration.toJavaDuration()) +} diff --git a/sdk-api-kotlin/src/main/kotlin/dev/restate/sdk/kotlin/awaitables.kt b/sdk-api-kotlin/src/main/kotlin/dev/restate/sdk/kotlin/awaitables.kt index ecb1382b9..10f170f1e 100644 --- a/sdk-api-kotlin/src/main/kotlin/dev/restate/sdk/kotlin/awaitables.kt +++ b/sdk-api-kotlin/src/main/kotlin/dev/restate/sdk/kotlin/awaitables.kt @@ -8,8 +8,10 @@ // https://github.com/restatedev/sdk-java/blob/main/LICENSE package dev.restate.sdk.kotlin +import dev.restate.common.Output import dev.restate.common.Slice import dev.restate.sdk.endpoint.definition.AsyncResult +import dev.restate.sdk.endpoint.definition.HandlerContext import dev.restate.sdk.types.TerminalException import dev.restate.sdk.types.TimeoutException import dev.restate.serde.Serde @@ -167,11 +169,25 @@ internal constructor( } } -internal class SendHandleImpl -internal constructor(private val invocationIdAsyncResult: AsyncResult) : SendHandle { - override suspend fun invocationId(): String { - return invocationIdAsyncResult.poll().await() +internal abstract class BaseInvocationHandle +internal constructor( + private val handlerContext: HandlerContext, + private val responseSerde: Serde +) : InvocationHandle { + override suspend fun cancel() { + val ignored = handlerContext.cancelInvocation(invocationId()).await() } + + override suspend fun attach(): Awaitable = + SingleAwaitableImpl( + handlerContext.attachInvocation(invocationId()).await().map { + CompletableFuture.completedFuture(responseSerde.deserialize(it)) + }) + + override suspend fun output(): Output = + SingleAwaitableImpl(handlerContext.getInvocationOutput(invocationId()).await()) + .simpleMap { it.map { responseSerde.deserialize(it) } } + .await() } internal class AwakeableImpl diff --git a/sdk-api/src/main/java/dev/restate/sdk/CallAwaitable.java b/sdk-api/src/main/java/dev/restate/sdk/CallAwaitable.java index 645f406ce..145f5fce9 100644 --- a/sdk-api/src/main/java/dev/restate/sdk/CallAwaitable.java +++ b/sdk-api/src/main/java/dev/restate/sdk/CallAwaitable.java @@ -9,15 +9,21 @@ package dev.restate.sdk; import dev.restate.sdk.endpoint.definition.AsyncResult; +import dev.restate.sdk.endpoint.definition.HandlerContext; import java.util.concurrent.Executor; /** {@link Awaitable} returned by a call to another service. */ public final class CallAwaitable extends Awaitable { + private final HandlerContext context; private final AsyncResult asyncResult; private final Awaitable invocationIdAwaitable; - CallAwaitable(AsyncResult callAsyncResult, Awaitable invocationIdAwaitable) { + CallAwaitable( + HandlerContext context, + AsyncResult callAsyncResult, + Awaitable invocationIdAwaitable) { + this.context = context; this.asyncResult = callAsyncResult; this.invocationIdAwaitable = invocationIdAwaitable; } @@ -29,6 +35,11 @@ public String invocationId() { return this.invocationIdAwaitable.await(); } + /** Cancel this invocation */ + public void cancel() { + Util.awaitCompletableFuture(context.cancelInvocation(invocationId())); + } + @Override protected AsyncResult asyncResult() { return asyncResult; diff --git a/sdk-api/src/main/java/dev/restate/sdk/Context.java b/sdk-api/src/main/java/dev/restate/sdk/Context.java index 0a28a415f..d465cbd61 100644 --- a/sdk-api/src/main/java/dev/restate/sdk/Context.java +++ b/sdk-api/src/main/java/dev/restate/sdk/Context.java @@ -8,12 +8,12 @@ // https://github.com/restatedev/sdk-java/blob/main/LICENSE package dev.restate.sdk; -import dev.restate.common.CallRequest; -import dev.restate.common.SendRequest; +import dev.restate.common.Request; +import dev.restate.common.Slice; import dev.restate.common.function.ThrowingRunnable; import dev.restate.common.function.ThrowingSupplier; import dev.restate.sdk.types.AbortedExecutionException; -import dev.restate.sdk.types.Request; +import dev.restate.sdk.types.HandlerRequest; import dev.restate.sdk.types.RetryPolicy; import dev.restate.sdk.types.TerminalException; import dev.restate.serde.Serde; @@ -34,31 +34,50 @@ */ public interface Context { - Request request(); + HandlerRequest request(); /** * Invoke another Restate service method. * - * @param callRequest request + * @param request request * @return an {@link Awaitable} that wraps the Restate service method result. */ - CallAwaitable call(CallRequest callRequest); + CallAwaitable call(Request request); - /** Like {@link #call(CallRequest)} */ - default CallAwaitable call(CallRequest.Builder callRequestBuilder) { + /** Like {@link #call(Request)} */ + default CallAwaitable call(Request.Builder callRequestBuilder) { return call(callRequestBuilder.build()); } /** * Invoke another Restate service without waiting for the response. * - * @param sendRequest request - * @return an {@link SendHandle} that can be used to retrieve the invocation id + * @param request request + * @return an {@link InvocationHandle} that can be used to retrieve the invocation id, cancel the + * invocation, attach to its result. */ - SendHandle send(SendRequest sendRequest); + InvocationHandle send(Request request); - default SendHandle send(SendRequest.Builder sendRequest) { - return send(sendRequest.build()); + /** Like {@link #send(Request)} */ + default InvocationHandle send(Request.Builder request) { + return send(request.build()); + } + + InvocationHandle invocationHandle(String invocationId, TypeTag responseTypeTag); + + /** + * Get an {@link InvocationHandle} for an already existing invocation. This will let you interact + * with a running invocation, for example to cancel it or retrieve its result. + * + * @param invocationId The invocation to interact with. + * @param responseClazz The response class. + */ + default InvocationHandle invocationHandle(String invocationId, Class responseClazz) { + return invocationHandle(invocationId, TypeTag.of(responseClazz)); + } + + default InvocationHandle invocationHandle(String invocationId) { + return invocationHandle(invocationId, Serde.SLICE); } /** diff --git a/sdk-api/src/main/java/dev/restate/sdk/ContextImpl.java b/sdk-api/src/main/java/dev/restate/sdk/ContextImpl.java index cde5a9800..d5ad45fb5 100644 --- a/sdk-api/src/main/java/dev/restate/sdk/ContextImpl.java +++ b/sdk-api/src/main/java/dev/restate/sdk/ContextImpl.java @@ -13,7 +13,7 @@ import dev.restate.sdk.endpoint.definition.AsyncResult; import dev.restate.sdk.endpoint.definition.HandlerContext; import dev.restate.sdk.types.DurablePromiseKey; -import dev.restate.sdk.types.Request; +import dev.restate.sdk.types.HandlerRequest; import dev.restate.sdk.types.RetryPolicy; import dev.restate.sdk.types.StateKey; import dev.restate.sdk.types.TerminalException; @@ -45,7 +45,7 @@ public String key() { } @Override - public Request request() { + public HandlerRequest request() { return handlerContext.request(); } @@ -89,46 +89,97 @@ public Awaitable timer(String name, Duration duration) { } @Override - public CallAwaitable call(CallRequest callRequest) { + public CallAwaitable call(Request request) { Slice input = Util.executeOrFail( handlerContext, - serdeFactory.create(callRequest.requestSerdeInfo())::serialize, - callRequest.request()); + serdeFactory.create(request.requestTypeTag())::serialize, + request.request()); HandlerContext.CallResult result = Util.awaitCompletableFuture( handlerContext.call( - callRequest.target(), - input, - callRequest.idempotencyKey(), - callRequest.headers().entrySet())); + request.target(), input, request.idempotencyKey(), request.headers().entrySet())); return new CallAwaitable<>( + handlerContext, result .callAsyncResult() .map( s -> CompletableFuture.completedFuture( - serdeFactory.create(callRequest.responseSerdeInfo()).deserialize(s))), + serdeFactory.create(request.responseTypeTag()).deserialize(s))), Awaitable.fromAsyncResult(result.invocationIdAsyncResult(), serviceExecutor)); } @Override - public SendHandle send(SendRequest sendRequest) { + public InvocationHandle send(Request request) { Slice input = Util.executeOrFail( handlerContext, - serdeFactory.create(sendRequest.requestSerdeInfo())::serialize, - sendRequest.request()); - var invocationIdAsyncResult = - Util.awaitCompletableFuture( - handlerContext.send( - sendRequest.target(), - input, - sendRequest.idempotencyKey(), - sendRequest.headers().entrySet(), - sendRequest.delay())); - return new SendHandle(Awaitable.fromAsyncResult(invocationIdAsyncResult, serviceExecutor)); + serdeFactory.create(request.requestTypeTag())::serialize, + request.request()); + + var invocationIdAwaitable = + Awaitable.fromAsyncResult( + Util.awaitCompletableFuture( + handlerContext.send( + request.target(), + input, + request.idempotencyKey(), + request.headers().entrySet(), + (request instanceof SendRequest sendRequest) + ? sendRequest.delay() + : null)), + serviceExecutor); + + return new BaseInvocationHandle<>( + Util.executeOrFail(handlerContext, () -> serdeFactory.create(request.responseTypeTag()))) { + @Override + public String invocationId() { + return invocationIdAwaitable.await(); + } + }; + } + + @Override + public InvocationHandle invocationHandle(String invocationId, TypeTag responseTypeTag) { + return new BaseInvocationHandle<>( + Util.executeOrFail(handlerContext, () -> serdeFactory.create(responseTypeTag))) { + @Override + public String invocationId() { + return invocationId; + } + }; + } + + abstract class BaseInvocationHandle implements InvocationHandle { + private final Serde responseSerde; + + BaseInvocationHandle(Serde responseSerde) { + this.responseSerde = responseSerde; + } + + @Override + public void cancel() { + Util.awaitCompletableFuture(handlerContext.cancelInvocation(invocationId())); + } + + @Override + public Awaitable attach() { + return Awaitable.fromAsyncResult( + Util.awaitCompletableFuture(handlerContext.attachInvocation(invocationId())) + .map(s -> CompletableFuture.completedFuture(responseSerde.deserialize(s))), + serviceExecutor); + } + + @Override + public Output getOutput() { + return Awaitable.fromAsyncResult( + Util.awaitCompletableFuture(handlerContext.getInvocationOutput(invocationId())) + .map(o -> CompletableFuture.completedFuture(o.map(responseSerde::deserialize))), + serviceExecutor) + .await(); + } } @Override diff --git a/sdk-api/src/main/java/dev/restate/sdk/InvocationHandle.java b/sdk-api/src/main/java/dev/restate/sdk/InvocationHandle.java new file mode 100644 index 000000000..83c62706f --- /dev/null +++ b/sdk-api/src/main/java/dev/restate/sdk/InvocationHandle.java @@ -0,0 +1,29 @@ +// Copyright (c) 2023 - Restate Software, Inc., Restate GmbH +// +// This file is part of the Restate Java SDK, +// which is released under the MIT license. +// +// You can find a copy of the license in file LICENSE in the root +// directory of this repository or package, or at +// https://github.com/restatedev/sdk-java/blob/main/LICENSE +package dev.restate.sdk; + +import dev.restate.common.Output; + +public interface InvocationHandle { + /** + * @return the invocation id of this invocation + */ + String invocationId(); + + /** Cancel this invocation. */ + void cancel(); + + /** Attach to this invocation. This will wait for the invocation to complete */ + Awaitable attach(); + + /** + * @return the output of this invocation, if present. + */ + Output getOutput(); +} diff --git a/sdk-api/src/main/java/dev/restate/sdk/SendHandle.java b/sdk-api/src/main/java/dev/restate/sdk/SendHandle.java deleted file mode 100644 index 022281e9c..000000000 --- a/sdk-api/src/main/java/dev/restate/sdk/SendHandle.java +++ /dev/null @@ -1,22 +0,0 @@ -// Copyright (c) 2023 - Restate Software, Inc., Restate GmbH -// -// This file is part of the Restate Java SDK, -// which is released under the MIT license. -// -// You can find a copy of the license in file LICENSE in the root -// directory of this repository or package, or at -// https://github.com/restatedev/sdk-java/blob/main/LICENSE -package dev.restate.sdk; - -public final class SendHandle { - - private final Awaitable invocationIdAwaitable; - - SendHandle(Awaitable invocationIdAwaitable) { - this.invocationIdAwaitable = invocationIdAwaitable; - } - - public String invocationId() { - return this.invocationIdAwaitable.await(); - } -} diff --git a/sdk-api/src/main/java/dev/restate/sdk/Util.java b/sdk-api/src/main/java/dev/restate/sdk/Util.java index eb9a32dd5..13fa1a9fe 100644 --- a/sdk-api/src/main/java/dev/restate/sdk/Util.java +++ b/sdk-api/src/main/java/dev/restate/sdk/Util.java @@ -9,6 +9,7 @@ package dev.restate.sdk; import dev.restate.common.function.ThrowingFunction; +import dev.restate.common.function.ThrowingSupplier; import dev.restate.sdk.endpoint.definition.HandlerContext; import dev.restate.sdk.types.AbortedExecutionException; import java.util.concurrent.CancellationException; @@ -31,6 +32,16 @@ static R executeOrFail(HandlerContext handlerContext, ThrowingFunction R executeOrFail(HandlerContext handlerContext, ThrowingSupplier fn) { + try { + return fn.get(); + } catch (Throwable e) { + handlerContext.fail(e); + AbortedExecutionException.sneakyThrow(); + return null; + } + } + static @NonNull T awaitCompletableFuture(CompletableFuture future) { try { return future.get(); diff --git a/sdk-common/src/main/java/dev/restate/sdk/endpoint/definition/HandlerContext.java b/sdk-common/src/main/java/dev/restate/sdk/endpoint/definition/HandlerContext.java index ce3e929f1..61a204b48 100644 --- a/sdk-common/src/main/java/dev/restate/sdk/endpoint/definition/HandlerContext.java +++ b/sdk-common/src/main/java/dev/restate/sdk/endpoint/definition/HandlerContext.java @@ -27,7 +27,7 @@ public interface HandlerContext { String objectKey(); - Request request(); + HandlerRequest request(); // ----- IO // Note: These are not supposed to be exposed in the user's facing Context API. @@ -93,6 +93,12 @@ record Awakeable(String id, AsyncResult asyncResult) {} CompletableFuture> rejectPromise(String key, TerminalException reason); + CompletableFuture cancelInvocation(String invocationId); + + CompletableFuture> attachInvocation(String invocationId); + + CompletableFuture>> getInvocationOutput(String invocationId); + void fail(Throwable cause); // ----- Deferred diff --git a/sdk-common/src/main/java/dev/restate/sdk/types/Request.java b/sdk-common/src/main/java/dev/restate/sdk/types/HandlerRequest.java similarity index 88% rename from sdk-common/src/main/java/dev/restate/sdk/types/Request.java rename to sdk-common/src/main/java/dev/restate/sdk/types/HandlerRequest.java index 0ba774ccf..07ad19e2f 100644 --- a/sdk-common/src/main/java/dev/restate/sdk/types/Request.java +++ b/sdk-common/src/main/java/dev/restate/sdk/types/HandlerRequest.java @@ -13,8 +13,8 @@ import java.nio.ByteBuffer; import java.util.Map; -/** The Request object represents the incoming request to a handler. */ -public record Request( +/** This record encapsulates the inputs to a handler. */ +public record HandlerRequest( InvocationId invocationId, Context otelContext, Slice body, Map headers) { public byte[] bodyAsByteArray() { return body.toByteArray(); diff --git a/sdk-core/src/main/java/dev/restate/sdk/core/HandlerContextImpl.java b/sdk-core/src/main/java/dev/restate/sdk/core/HandlerContextImpl.java index 96b5b7ad5..529491604 100644 --- a/sdk-core/src/main/java/dev/restate/sdk/core/HandlerContextImpl.java +++ b/sdk-core/src/main/java/dev/restate/sdk/core/HandlerContextImpl.java @@ -37,7 +37,7 @@ class HandlerContextImpl implements HandlerContextInternal { private static final int CANCEL_HANDLE = 1; - private final Request request; + private final HandlerRequest handlerRequest; private final StateMachine stateMachine; private final @Nullable String objectKey; private final String fullyQualifiedHandlerName; @@ -51,7 +51,8 @@ class HandlerContextImpl implements HandlerContextInternal { StateMachine stateMachine, Context otelContext, StateMachine.Input input) { - this.request = new Request(input.invocationId(), otelContext, input.body(), input.headers()); + this.handlerRequest = + new HandlerRequest(input.invocationId(), otelContext, input.body(), input.headers()); this.objectKey = input.key(); this.stateMachine = stateMachine; this.fullyQualifiedHandlerName = fullyQualifiedHandlerName; @@ -59,14 +60,47 @@ class HandlerContextImpl implements HandlerContextInternal { this.scheduledRuns = new HashMap<>(); } + private static void parseSuccessOrFailure(NotificationValue s, CompletableFuture cf) { + if (s instanceof NotificationValue.Success success) { + cf.complete(success.slice()); + } else if (s instanceof NotificationValue.Failure failure) { + cf.completeExceptionally(failure.exception()); + } else { + throw ProtocolException.unexpectedNotificationVariant(s.getClass()); + } + } + + private static void parseEmptyOrSuccessOrFailure( + NotificationValue s, CompletableFuture> cf) { + if (s instanceof NotificationValue.Empty) { + cf.complete(Output.notReady()); + } else if (s instanceof NotificationValue.Success success) { + cf.complete(Output.ready(success.slice())); + } else if (s instanceof NotificationValue.Failure failure) { + cf.completeExceptionally(failure.exception()); + } else { + throw ProtocolException.unexpectedNotificationVariant(s.getClass()); + } + } + + private static void parseEmptyOrFailure(NotificationValue s, CompletableFuture cf) { + if (s instanceof NotificationValue.Empty) { + cf.complete(null); + } else if (s instanceof NotificationValue.Failure failure) { + cf.completeExceptionally(failure.exception()); + } else { + throw ProtocolException.unexpectedNotificationVariant(s.getClass()); + } + } + @Override public String objectKey() { return this.objectKey; } @Override - public Request request() { - return this.request; + public HandlerRequest request() { + return this.handlerRequest; } @Override @@ -166,17 +200,7 @@ public CompletableFuture call( AsyncResult callAsyncResult = AsyncResults.single( - this, - callHandle.resultHandle(), - (s, cf) -> { - if (s instanceof NotificationValue.Success success) { - cf.complete(success.slice()); - } else if (s instanceof NotificationValue.Failure failure) { - cf.completeExceptionally(failure.exception()); - } else { - throw ProtocolException.unexpectedNotificationVariant(s.getClass()); - } - }); + this, callHandle.resultHandle(), HandlerContextImpl::parseSuccessOrFailure); return new CallResult(invocationIdAsyncResult, callAsyncResult); }); @@ -219,18 +243,7 @@ public CompletableFuture> submitRun( () -> { int runHandle = this.stateMachine.run(name); this.scheduledRuns.put(runHandle, closure); - return AsyncResults.single( - this, - runHandle, - (s, cf) -> { - if (s instanceof NotificationValue.Success success) { - cf.complete(success.slice()); - } else if (s instanceof NotificationValue.Failure failure) { - cf.completeExceptionally(failure.exception()); - } else { - throw ProtocolException.unexpectedNotificationVariant(s.getClass()); - } - }); + return AsyncResults.single(this, runHandle, HandlerContextImpl::parseSuccessOrFailure); }); } @@ -242,17 +255,7 @@ public CompletableFuture awakeable() { return new Awakeable( awakeable.awakeableId(), AsyncResults.single( - this, - awakeable.handle(), - (s, cf) -> { - if (s instanceof NotificationValue.Success success) { - cf.complete(success.slice()); - } else if (s instanceof NotificationValue.Failure failure) { - cf.completeExceptionally(failure.exception()); - } else { - throw ProtocolException.unexpectedNotificationVariant(s.getClass()); - } - })); + this, awakeable.handle(), HandlerContextImpl::parseSuccessOrFailure)); }); } @@ -273,15 +276,7 @@ public CompletableFuture> promise(String key) { AsyncResults.single( this, this.stateMachine.promiseGet(key), - (s, cf) -> { - if (s instanceof NotificationValue.Success success) { - cf.complete(success.slice()); - } else if (s instanceof NotificationValue.Failure failure) { - cf.completeExceptionally(failure.exception()); - } else { - throw ProtocolException.unexpectedNotificationVariant(s.getClass()); - } - })); + HandlerContextImpl::parseSuccessOrFailure)); } @Override @@ -291,17 +286,7 @@ public CompletableFuture>> peekPromise(String key) { AsyncResults.single( this, this.stateMachine.promisePeek(key), - (s, cf) -> { - if (s instanceof NotificationValue.Empty) { - cf.complete(Output.notReady()); - } else if (s instanceof NotificationValue.Success success) { - cf.complete(Output.ready(success.slice())); - } else if (s instanceof NotificationValue.Failure failure) { - cf.completeExceptionally(failure.exception()); - } else { - throw ProtocolException.unexpectedNotificationVariant(s.getClass()); - } - })); + HandlerContextImpl::parseEmptyOrSuccessOrFailure)); } @Override @@ -311,15 +296,7 @@ public CompletableFuture> resolvePromise(String key, Slice pay AsyncResults.single( this, this.stateMachine.promiseComplete(key, payload), - (s, cf) -> { - if (s instanceof NotificationValue.Empty) { - cf.complete(null); - } else if (s instanceof NotificationValue.Failure failure) { - cf.completeExceptionally(failure.exception()); - } else { - throw ProtocolException.unexpectedNotificationVariant(s.getClass()); - } - })); + HandlerContextImpl::parseEmptyOrFailure)); } @Override @@ -329,15 +306,32 @@ public CompletableFuture> rejectPromise(String key, TerminalEx AsyncResults.single( this, this.stateMachine.promiseComplete(key, reason), - (s, cf) -> { - if (s instanceof NotificationValue.Empty) { - cf.complete(null); - } else if (s instanceof NotificationValue.Failure failure) { - cf.completeExceptionally(failure.exception()); - } else { - throw ProtocolException.unexpectedNotificationVariant(s.getClass()); - } - })); + HandlerContextImpl::parseEmptyOrFailure)); + } + + @Override + public CompletableFuture cancelInvocation(String invocationId) { + return this.catchExceptions(() -> this.stateMachine.cancelInvocation(invocationId)); + } + + @Override + public CompletableFuture> attachInvocation(String invocationId) { + return catchExceptions( + () -> + AsyncResults.single( + this, + this.stateMachine.attachInvocation(invocationId), + HandlerContextImpl::parseSuccessOrFailure)); + } + + @Override + public CompletableFuture>> getInvocationOutput(String invocationId) { + return catchExceptions( + () -> + AsyncResults.single( + this, + this.stateMachine.getInvocationOutput(invocationId), + HandlerContextImpl::parseEmptyOrSuccessOrFailure)); } @Override diff --git a/sdk-core/src/main/java/dev/restate/sdk/core/statemachine/CommandAccessor.java b/sdk-core/src/main/java/dev/restate/sdk/core/statemachine/CommandAccessor.java index 92521013f..1a50cd8dd 100644 --- a/sdk-core/src/main/java/dev/restate/sdk/core/statemachine/CommandAccessor.java +++ b/sdk-core/src/main/java/dev/restate/sdk/core/statemachine/CommandAccessor.java @@ -136,4 +136,10 @@ public void checkEntryHeader(Protocol.OneWayCallCommandMessage expected, Message CommandAccessor SEND_SIGNAL = Protocol.SendSignalCommandMessage::getName; + + CommandAccessor ATTACH_INVOCATION = + Protocol.AttachInvocationCommandMessage::getName; + + CommandAccessor GET_INVOCATION_OUTPUT = + Protocol.GetInvocationOutputCommandMessage::getName; } diff --git a/sdk-core/src/main/java/dev/restate/sdk/core/statemachine/StateMachine.java b/sdk-core/src/main/java/dev/restate/sdk/core/statemachine/StateMachine.java index 012f7e680..a9ee2eb1a 100644 --- a/sdk-core/src/main/java/dev/restate/sdk/core/statemachine/StateMachine.java +++ b/sdk-core/src/main/java/dev/restate/sdk/core/statemachine/StateMachine.java @@ -135,6 +135,10 @@ void proposeRunCompletion( void cancelInvocation(String targetInvocationId); + int attachInvocation(String invocationId); + + int getInvocationOutput(String invocationId); + void writeOutput(Slice value); void writeOutput(TerminalException exception); diff --git a/sdk-core/src/main/java/dev/restate/sdk/core/statemachine/StateMachineImpl.java b/sdk-core/src/main/java/dev/restate/sdk/core/statemachine/StateMachineImpl.java index 09a8e068c..7bc2644a2 100644 --- a/sdk-core/src/main/java/dev/restate/sdk/core/statemachine/StateMachineImpl.java +++ b/sdk-core/src/main/java/dev/restate/sdk/core/statemachine/StateMachineImpl.java @@ -577,6 +577,36 @@ public void cancelInvocation(String targetInvocationId) { this.stateContext); } + @Override + public int attachInvocation(String invocationId) { + LOG.debug("Executing 'Attach invocation {}'", invocationId); + var completionId = this.stateContext.getJournal().nextCompletionNotificationId(); + return this.stateContext.getCurrentState() + .processCompletableCommand( + Protocol.AttachInvocationCommandMessage.newBuilder() + .setInvocationId(invocationId) + .setResultCompletionId(completionId) + .build(), + CommandAccessor.ATTACH_INVOCATION, + new int[] {completionId}, + this.stateContext)[0]; + } + + @Override + public int getInvocationOutput(String invocationId) { + LOG.debug("Executing 'Get invocation output {}'", invocationId); + var completionId = this.stateContext.getJournal().nextCompletionNotificationId(); + return this.stateContext.getCurrentState() + .processCompletableCommand( + Protocol.GetInvocationOutputCommandMessage.newBuilder() + .setInvocationId(invocationId) + .setResultCompletionId(completionId) + .build(), + CommandAccessor.GET_INVOCATION_OUTPUT, + new int[] {completionId}, + this.stateContext)[0]; + } + @Override public void writeOutput(Slice value) { LOG.debug("Executing 'Write invocation output with success'"); diff --git a/sdk-core/src/test/java/dev/restate/sdk/core/javaapi/CallTest.java b/sdk-core/src/test/java/dev/restate/sdk/core/javaapi/CallTest.java index c8e69c67f..6669a6179 100644 --- a/sdk-core/src/test/java/dev/restate/sdk/core/javaapi/CallTest.java +++ b/sdk-core/src/test/java/dev/restate/sdk/core/javaapi/CallTest.java @@ -10,7 +10,7 @@ import static dev.restate.sdk.core.javaapi.JavaAPITests.testDefinitionForService; -import dev.restate.common.CallRequest; +import dev.restate.common.Request; import dev.restate.common.SendRequest; import dev.restate.common.Slice; import dev.restate.common.Target; @@ -43,6 +43,6 @@ protected TestInvocationBuilder implicitCancellation(Target target, Slice body) "ImplicitCancellation", Serde.VOID, Serde.RAW, - (context, unused) -> context.call(CallRequest.ofRaw(target, body.toByteArray())).await()); + (context, unused) -> context.call(Request.ofRaw(target, body.toByteArray())).await()); } } diff --git a/sdk-core/src/test/java/dev/restate/sdk/core/javaapi/JavaAPITests.java b/sdk-core/src/test/java/dev/restate/sdk/core/javaapi/JavaAPITests.java index cfd0a1fbd..95bbd9475 100644 --- a/sdk-core/src/test/java/dev/restate/sdk/core/javaapi/JavaAPITests.java +++ b/sdk-core/src/test/java/dev/restate/sdk/core/javaapi/JavaAPITests.java @@ -10,7 +10,7 @@ import static dev.restate.sdk.core.statemachine.ProtoUtils.GREETER_SERVICE_TARGET; -import dev.restate.common.CallRequest; +import dev.restate.common.Request; import dev.restate.common.function.ThrowingBiFunction; import dev.restate.sdk.*; import dev.restate.sdk.core.*; @@ -108,6 +108,6 @@ public static TestInvocationBuilder testDefinitionForWorkflow( public static Awaitable callGreeterGreetService(Context ctx, String parameter) { return ctx.call( - CallRequest.of(GREETER_SERVICE_TARGET, TestSerdes.STRING, TestSerdes.STRING, parameter)); + Request.of(GREETER_SERVICE_TARGET, TestSerdes.STRING, TestSerdes.STRING, parameter)); } } diff --git a/sdk-core/src/test/kotlin/dev/restate/sdk/core/kotlinapi/CallTest.kt b/sdk-core/src/test/kotlin/dev/restate/sdk/core/kotlinapi/CallTest.kt index 655bc2ddf..b13e6b095 100644 --- a/sdk-core/src/test/kotlin/dev/restate/sdk/core/kotlinapi/CallTest.kt +++ b/sdk-core/src/test/kotlin/dev/restate/sdk/core/kotlinapi/CallTest.kt @@ -8,8 +8,7 @@ // https://github.com/restatedev/sdk-java/blob/main/LICENSE package dev.restate.sdk.core.kotlinapi -import dev.restate.common.CallRequest -import dev.restate.common.SendRequest +import dev.restate.common.Request import dev.restate.common.Slice import dev.restate.common.Target import dev.restate.sdk.core.CallTestSuite @@ -25,16 +24,12 @@ class CallTest : CallTestSuite() { body: Slice ) = testDefinitionForService("OneWayCall") { ctx, _: Unit -> - val ignored = - ctx.send( - SendRequest.of(target, Serde.SLICE, body) - .idempotencyKey(idempotencyKey) - .headers(headers)) + val ignored = ctx.send(Request.of(target, Serde.SLICE, Serde.RAW, body)) } override fun implicitCancellation(target: Target, body: Slice) = testDefinitionForService("ImplicitCancellation") { ctx, _: Unit -> val ignored = - ctx.call(CallRequest.of(target, Serde.SLICE, Serde.RAW, body)).await() + ctx.call(Request.of(target, Serde.SLICE, Serde.RAW, body)).await() } } diff --git a/sdk-core/src/test/kotlin/dev/restate/sdk/core/kotlinapi/KotlinAPITests.kt b/sdk-core/src/test/kotlin/dev/restate/sdk/core/kotlinapi/KotlinAPITests.kt index 34c38b49d..a08008729 100644 --- a/sdk-core/src/test/kotlin/dev/restate/sdk/core/kotlinapi/KotlinAPITests.kt +++ b/sdk-core/src/test/kotlin/dev/restate/sdk/core/kotlinapi/KotlinAPITests.kt @@ -8,7 +8,7 @@ // https://github.com/restatedev/sdk-java/blob/main/LICENSE package dev.restate.sdk.core.kotlinapi -import dev.restate.common.CallRequest +import dev.restate.common.Request import dev.restate.sdk.core.* import dev.restate.sdk.core.TestDefinitions.TestExecutor import dev.restate.sdk.core.TestDefinitions.TestInvocationBuilder @@ -111,7 +111,7 @@ class KotlinAPITests : TestRunner() { suspend fun callGreeterGreetService(ctx: Context, parameter: String): Awaitable { return ctx.call( - CallRequest.of( + Request.of( ProtoUtils.GREETER_SERVICE_TARGET, TestSerdes.STRING, TestSerdes.STRING, parameter)) } } diff --git a/test-services/src/main/kotlin/dev/restate/sdk/testservices/ProxyImpl.kt b/test-services/src/main/kotlin/dev/restate/sdk/testservices/ProxyImpl.kt index 1fbaae552..02500d5f7 100644 --- a/test-services/src/main/kotlin/dev/restate/sdk/testservices/ProxyImpl.kt +++ b/test-services/src/main/kotlin/dev/restate/sdk/testservices/ProxyImpl.kt @@ -8,7 +8,7 @@ // https://github.com/restatedev/sdk-java/blob/main/LICENSE package dev.restate.sdk.testservices -import dev.restate.common.CallRequest +import dev.restate.common.Request import dev.restate.common.SendRequest import dev.restate.common.Target import dev.restate.sdk.kotlin.* @@ -18,7 +18,6 @@ import dev.restate.sdk.testservices.contracts.ProxyRequest import dev.restate.serde.Serde import kotlin.time.Duration import kotlin.time.Duration.Companion.milliseconds -import kotlin.time.toJavaDuration class ProxyImpl : Proxy { private fun ProxyRequest.toTarget(): Target { @@ -31,15 +30,15 @@ class ProxyImpl : Proxy { override suspend fun call(context: Context, request: ProxyRequest): ByteArray { return context - .call(CallRequest.of(request.toTarget(), Serde.RAW, Serde.RAW, request.message)) + .call(Request.of(request.toTarget(), Serde.RAW, Serde.RAW, request.message)) .await() } override suspend fun oneWayCall(context: Context, request: ProxyRequest): Unit { val ignored = context.send( - SendRequest.of(request.toTarget(), Serde.RAW, request.message) - .delay((request.delayMillis?.milliseconds ?: Duration.ZERO).toJavaDuration())) + SendRequest.of(request.toTarget(), Serde.RAW, Serde.SLICE, request.message) + .asSendDelayed((request.delayMillis?.milliseconds ?: Duration.ZERO))) } override suspend fun manyCalls(context: Context, requests: List) { @@ -48,14 +47,16 @@ class ProxyImpl : Proxy { for (request in requests) { if (request.oneWayCall) { context.send( - SendRequest.of(request.proxyRequest.toTarget(), Serde.RAW, request.proxyRequest.message) - .delay( - (request.proxyRequest.delayMillis?.milliseconds ?: Duration.ZERO) - .toJavaDuration())) + SendRequest.of( + request.proxyRequest.toTarget(), + Serde.RAW, + Serde.SLICE, + request.proxyRequest.message) + .asSendDelayed((request.proxyRequest.delayMillis?.milliseconds ?: Duration.ZERO))) } else { val awaitable = context.call( - CallRequest.of( + Request.of( request.proxyRequest.toTarget(), Serde.RAW, Serde.RAW, diff --git a/test-services/src/main/kotlin/dev/restate/sdk/testservices/interpreter.kt b/test-services/src/main/kotlin/dev/restate/sdk/testservices/interpreter.kt index b4eee6aa8..b21e0019a 100644 --- a/test-services/src/main/kotlin/dev/restate/sdk/testservices/interpreter.kt +++ b/test-services/src/main/kotlin/dev/restate/sdk/testservices/interpreter.kt @@ -8,7 +8,7 @@ // https://github.com/restatedev/sdk-java/blob/main/LICENSE package dev.restate.sdk.testservices -import dev.restate.common.CallRequest +import dev.restate.common.Request import dev.restate.common.SendRequest import dev.restate.common.Target import dev.restate.sdk.endpoint.definition.ServiceDefinition @@ -17,6 +17,7 @@ import dev.restate.sdk.testservices.contracts.* import dev.restate.sdk.testservices.contracts.Program import dev.restate.sdk.types.StateKey import dev.restate.sdk.types.TerminalException +import dev.restate.serde.Serde import kotlin.random.Random import kotlin.time.Duration.Companion.milliseconds @@ -94,7 +95,7 @@ class ObjectInterpreterImpl(private val layer: Int) : ObjectInterpreter { is CallObject -> { val awaitable = ctx.call( - CallRequest.of( + Request.of( interpretTarget(layer + 1, cmd.key.toString()), ObjectInterpreterMetadata.Serde.INTERPRET_INPUT, ObjectInterpreterMetadata.Serde.INTERPRET_OUTPUT, @@ -138,10 +139,9 @@ class ObjectInterpreterImpl(private val layer: Int) : ObjectInterpreter { ctx.awakeableHandle(theirPromiseIdForUsToResolve).resolve("ok") } is IncrementViaDelayedCall -> { - ServiceInterpreterHelperClient.fromContext(ctx).send().incrementIndirectly( - interpreterId(ctx)) { - delay = cmd.duration.milliseconds - } + ServiceInterpreterHelperClient.fromContext(ctx) + .send() + .incrementIndirectly(interpreterId(ctx), delay = cmd.duration.milliseconds) } is RecoverTerminalCall -> { var caught = false @@ -216,6 +216,7 @@ class ServiceInterpreterHelperImpl : ServiceInterpreterHelper { SendRequest.of( interpretTarget(id.layer, id.key), ObjectInterpreterMetadata.Serde.INTERPRET_INPUT, + Serde.SLICE, Program(listOf(IncrementStateCounter())))) } @@ -250,6 +251,7 @@ class ServiceInterpreterHelperImpl : ServiceInterpreterHelper { SendRequest.of( interpretTarget(req.interpreter.layer, req.interpreter.key), ObjectInterpreterMetadata.Serde.INTERPRET_INPUT, + Serde.SLICE, Program(listOf(IncrementStateCounter())))) } }