Skip to content

Commit

Permalink
Add cancel invocation/attach invocation/get invocation output
Browse files Browse the repository at this point in the history
  • Loading branch information
slinkydeveloper committed Feb 18, 2025
1 parent de5db4b commit f5f4791
Show file tree
Hide file tree
Showing 31 changed files with 603 additions and 460 deletions.
25 changes: 13 additions & 12 deletions client-kotlin/src/main/kotlin/dev/restate/client/kotlin/ingress.kt
Original file line number Diff line number Diff line change
Expand Up @@ -12,9 +12,8 @@ import dev.restate.client.Client
import dev.restate.client.ClientRequestOptions
import dev.restate.client.ClientResponse
import dev.restate.client.SendResponse
import dev.restate.common.CallRequest
import dev.restate.common.Output
import dev.restate.common.SendRequest
import dev.restate.common.Request
import dev.restate.serde.Serde
import kotlinx.coroutines.future.await

Expand All @@ -26,24 +25,26 @@ fun clientRequestOptions(init: ClientRequestOptions.Builder.() -> Unit): ClientR
return builder.build()
}

suspend fun <Req, Res> Client.callSuspend(callRequest: CallRequest<Req, Res>): ClientResponse<Res> {
return this.callAsync(callRequest).await()
suspend fun <Req, Res> Client.callSuspend(request: Request<Req, Res>): ClientResponse<Res> {
return this.callAsync(request).await()
}

suspend fun <Req, Res> Client.callSuspend(
callRequestBuilder: CallRequest.Builder<Req, Res>
requestBuilder: Request.Builder<Req, Res>
): ClientResponse<Res> {
return this.callAsync(callRequestBuilder).await()
return this.callAsync(requestBuilder).await()
}

suspend fun <Req> Client.sendSuspend(sendRequest: SendRequest<Req>): ClientResponse<SendResponse> {
return this.sendAsync(sendRequest).await()
suspend fun <Req, Res> Client.sendSuspend(
request: Request<Req, Res>
): ClientResponse<SendResponse<Res>> {
return this.sendAsync(request).await()
}

suspend fun <Req> Client.sendSuspend(
sendRequestBuilder: SendRequest.Builder<Req>
): ClientResponse<SendResponse> {
return this.sendAsync(sendRequestBuilder).await()
suspend fun <Req, Res> Client.sendSuspend(
request: Request.Builder<Req, Res>
): ClientResponse<SendResponse<Res>> {
return this.sendSuspend(request.build())
}

suspend fun <T : Any> Client.AwakeableHandle.resolveSuspend(
Expand Down
29 changes: 14 additions & 15 deletions client/src/main/java/dev/restate/client/Client.java
Original file line number Diff line number Diff line change
Expand Up @@ -8,9 +8,8 @@
// https://github.com/restatedev/sdk-java/blob/main/LICENSE
package dev.restate.client;

import dev.restate.common.CallRequest;
import dev.restate.common.Output;
import dev.restate.common.SendRequest;
import dev.restate.common.Request;
import dev.restate.common.Target;
import dev.restate.serde.Serde;
import dev.restate.serde.SerdeFactory;
Expand All @@ -21,15 +20,14 @@

public interface Client {

<Req, Res> CompletableFuture<ClientResponse<Res>> callAsync(CallRequest<Req, Res> request);
<Req, Res> CompletableFuture<ClientResponse<Res>> callAsync(Request<Req, Res> request);

default <Req, Res> CompletableFuture<ClientResponse<Res>> callAsync(
CallRequest.Builder<Req, Res> request) {
Request.Builder<Req, Res> request) {
return callAsync(request.build());
}

default <Req, Res> ClientResponse<Res> call(CallRequest<Req, Res> request)
throws IngressException {
default <Req, Res> ClientResponse<Res> call(Request<Req, Res> request) throws IngressException {
try {
return callAsync(request).join();
} catch (CompletionException e) {
Expand All @@ -40,19 +38,15 @@ default <Req, Res> ClientResponse<Res> call(CallRequest<Req, Res> request)
}
}

default <Req, Res> ClientResponse<Res> call(CallRequest.Builder<Req, Res> request)
default <Req, Res> ClientResponse<Res> call(Request.Builder<Req, Res> request)
throws IngressException {
return call(request.build());
}

<Req> CompletableFuture<ClientResponse<SendResponse>> sendAsync(SendRequest<Req> request);

default <Req> CompletableFuture<ClientResponse<SendResponse>> sendAsync(
SendRequest.Builder<Req> request) {
return sendAsync(request.build());
}
<Req, Res> CompletableFuture<ClientResponse<SendResponse<Res>>> sendAsync(
Request<Req, Res> request);

default <Req> ClientResponse<SendResponse> send(SendRequest<Req> request)
default <Req, Res> ClientResponse<SendResponse<Res>> send(Request<Req, Res> request)
throws IngressException {
try {
return sendAsync(request).join();
Expand All @@ -64,7 +58,12 @@ default <Req> ClientResponse<SendResponse> send(SendRequest<Req> request)
}
}

default <Req> ClientResponse<SendResponse> send(SendRequest.Builder<Req> request)
default <Req, Res> CompletableFuture<ClientResponse<SendResponse<Res>>> sendAsync(
Request.Builder<Req, Res> request) {
return sendAsync(request.build());
}

default <Req, Res> ClientResponse<SendResponse<Res>> send(Request.Builder<Req, Res> request)
throws IngressException {
return send(request.build());
}
Expand Down
2 changes: 1 addition & 1 deletion client/src/main/java/dev/restate/client/SendResponse.java
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@
// https://github.com/restatedev/sdk-java/blob/main/LICENSE
package dev.restate.client;

public record SendResponse(SendStatus status, String invocationId) {
public record SendResponse<Res>(SendStatus status, Client.InvocationHandle<Res> invocationHandle) {
public enum SendStatus {
/** The request was sent for the first time. */
ACCEPTED,
Expand Down
31 changes: 18 additions & 13 deletions client/src/main/java/dev/restate/client/base/BaseClient.java
Original file line number Diff line number Diff line change
Expand Up @@ -49,10 +49,9 @@ protected BaseClient(URI baseUri, SerdeFactory serdeFactory, ClientRequestOption
}

@Override
public <Req, Res> CompletableFuture<ClientResponse<Res>> callAsync(
CallRequest<Req, Res> request) {
Serde<Req> reqSerde = this.serdeFactory.create(request.requestSerdeInfo());
Serde<Res> resSerde = this.serdeFactory.create(request.responseSerdeInfo());
public <Req, Res> CompletableFuture<ClientResponse<Res>> callAsync(Request<Req, Res> request) {
Serde<Req> reqSerde = this.serdeFactory.create(request.requestTypeTag());
Serde<Res> resSerde = this.serdeFactory.create(request.responseTypeTag());

URI requestUri = toRequestURI(request.target(), false, null);
Stream<Map.Entry<String, String>> headersStream =
Expand All @@ -75,10 +74,15 @@ public <Req, Res> CompletableFuture<ClientResponse<Res>> callAsync(
}

@Override
public <Req> CompletableFuture<ClientResponse<SendResponse>> sendAsync(SendRequest<Req> request) {
Serde<Req> reqSerde = this.serdeFactory.create(request.requestSerdeInfo());

URI requestUri = toRequestURI(request.target(), true, request.delay());
public <Req, Res> CompletableFuture<ClientResponse<SendResponse<Res>>> sendAsync(
Request<Req, Res> request) {
Serde<Req> reqSerde = this.serdeFactory.create(request.requestTypeTag());

URI requestUri =
toRequestURI(
request.target(),
true,
(request instanceof SendRequest<Req, Res> sendRequest) ? sendRequest.delay() : null);
Stream<Map.Entry<String, String>> headersStream =
Stream.concat(
baseOptions.headers().entrySet().stream(), request.headers().entrySet().stream());
Expand Down Expand Up @@ -146,7 +150,10 @@ public <Req> CompletableFuture<ClientResponse<SendResponse>> sendAsync(SendReque
}

return new ClientResponse<>(
statusCode, responseHeaders, new SendResponse(status, fields.get("invocationId")));
statusCode,
responseHeaders,
new SendResponse<>(
status, invocationHandle(fields.get("invocationId"), request.responseTypeTag())));
});
}

Expand Down Expand Up @@ -199,6 +206,8 @@ public CompletableFuture<ClientResponse<Void>> rejectAsync(
@Override
public <Res> InvocationHandle<Res> invocationHandle(
String invocationId, TypeTag<Res> resTypeTag) {
Serde<Res> resSerde = serdeFactory.create(resTypeTag);

return new InvocationHandle<>() {
@Override
public String invocationId() {
Expand All @@ -207,8 +216,6 @@ public String invocationId() {

@Override
public CompletableFuture<ClientResponse<Res>> attachAsync(ClientRequestOptions options) {
Serde<Res> resSerde = serdeFactory.create(resTypeTag);

URI requestUri = baseUri.resolve("/restate/invocation/" + invocationId + "/attach");
Stream<Map.Entry<String, String>> headersStream =
Stream.concat(
Expand All @@ -221,8 +228,6 @@ public CompletableFuture<ClientResponse<Res>> attachAsync(ClientRequestOptions o
@Override
public CompletableFuture<ClientResponse<Output<Res>>> getOutputAsync(
ClientRequestOptions options) {
Serde<Res> resSerde = serdeFactory.create(resTypeTag);

URI requestUri = baseUri.resolve("/restate/invocation/" + invocationId + "/output");
Stream<Map.Entry<String, String>> headersStream =
Stream.concat(
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -10,12 +10,13 @@

import dev.restate.serde.Serde;
import dev.restate.serde.TypeTag;
import java.time.Duration;
import java.util.LinkedHashMap;
import java.util.Map;
import java.util.Objects;
import org.jspecify.annotations.Nullable;

public final class CallRequest<Req, Res> {
public sealed class Request<Req, Res> permits SendRequest {

private final Target target;
private final TypeTag<Req> reqTypeTag;
Expand All @@ -24,7 +25,7 @@ public final class CallRequest<Req, Res> {
@Nullable private final String idempotencyKey;
@Nullable private final LinkedHashMap<String, String> headers;

private CallRequest(
Request(
Target target,
TypeTag<Req> reqTypeTag,
TypeTag<Res> resTypeTag,
Expand All @@ -43,11 +44,11 @@ public Target target() {
return target;
}

public TypeTag<Req> requestSerdeInfo() {
public TypeTag<Req> requestTypeTag() {
return reqTypeTag;
}

public TypeTag<Res> responseSerdeInfo() {
public TypeTag<Res> responseTypeTag() {
return resTypeTag;
}

Expand All @@ -71,15 +72,15 @@ public static <Req, Res> Builder<Req, Res> of(
return new Builder<>(target, reqTypeTag, resTypeTag, request);
}

public static <Res> Builder<Void, Res> withNoRequestBody(Target target, TypeTag<Res> resTypeTag) {
return new Builder<>(target, Serde.VOID, resTypeTag, null);
}

public static <Req> Builder<Req, Void> withNoResponseBody(
Target target, TypeTag<Req> reqTypeTag, Req request) {
return new Builder<>(target, reqTypeTag, Serde.VOID, request);
}

public static <Res> Builder<Void, Res> withNoRequestBody(Target target, TypeTag<Res> resTypeTag) {
return new Builder<>(target, Serde.VOID, resTypeTag, null);
}

public static Builder<byte[], byte[]> ofRaw(Target target, byte[] request) {
return new Builder<>(target, TypeTag.of(Serde.RAW), TypeTag.of(Serde.RAW), request);
}
Expand All @@ -92,6 +93,21 @@ public static final class Builder<Req, Res> {
@Nullable private String idempotencyKey;
@Nullable private LinkedHashMap<String, String> headers;

public Builder(
Target target,
TypeTag<Req> reqTypeTag,
TypeTag<Res> resTypeTag,
Req request,
@Nullable String idempotencyKey,
@Nullable LinkedHashMap<String, String> headers) {
this.target = target;
this.reqTypeTag = reqTypeTag;
this.resTypeTag = resTypeTag;
this.request = request;
this.idempotencyKey = idempotencyKey;
this.headers = headers;
}

private Builder(Target target, TypeTag<Req> reqTypeTag, TypeTag<Res> resTypeTag, Req request) {
this.target = target;
this.reqTypeTag = reqTypeTag;
Expand Down Expand Up @@ -149,8 +165,18 @@ public Builder<Req, Res> setHeaders(@Nullable Map<String, String> headers) {
return headers(headers);
}

public CallRequest<Req, Res> build() {
return new CallRequest<>(
public SendRequest<Req, Res> asSend() {
return new SendRequest<>(
target, reqTypeTag, resTypeTag, request, idempotencyKey, headers, null);
}

public SendRequest<Req, Res> asSendDelayed(Duration delay) {
return new SendRequest<>(
target, reqTypeTag, resTypeTag, request, idempotencyKey, headers, delay);
}

public Request<Req, Res> build() {
return new Request<>(
this.target,
this.reqTypeTag,
this.resTypeTag,
Expand All @@ -160,9 +186,29 @@ public CallRequest<Req, Res> build() {
}
}

public Builder<Req, Res> toBuilder() {
return new Builder<>(
this.target,
this.reqTypeTag,
this.resTypeTag,
this.request,
this.idempotencyKey,
this.headers);
}

public SendRequest<Req, Res> asSend() {
return new SendRequest<>(
target, reqTypeTag, resTypeTag, request, idempotencyKey, headers, null);
}

public SendRequest<Req, Res> asSendDelayed(Duration delay) {
return new SendRequest<>(
target, reqTypeTag, resTypeTag, request, idempotencyKey, headers, delay);
}

@Override
public boolean equals(Object o) {
if (!(o instanceof CallRequest<?, ?> that)) return false;
if (!(o instanceof Request<?, ?> that)) return false;
return Objects.equals(target, that.target)
&& Objects.equals(reqTypeTag, that.reqTypeTag)
&& Objects.equals(resTypeTag, that.resTypeTag)
Expand Down
Loading

0 comments on commit f5f4791

Please sign in to comment.