Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Support idempotent (GET) requests #14

Merged
merged 3 commits into from
Nov 29, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 2 additions & 2 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -64,7 +64,7 @@ codecs: [ CODEC_JSON ]
stream_types: [ STREAM_TYPE_UNARY ]
supports_tls: false
supports_trailers: false
supports_connect_get: false
supports_connect_get: true
supports_message_receive_limit: false
```

Expand All @@ -87,7 +87,7 @@ Diagnostic data from the server itself is output in the `out/out.log` file.

### Conformance tests status

Current status: 6/78 tests pass
Current status: 6/79 tests pass

Known issues:

Expand Down
5 changes: 4 additions & 1 deletion build.sbt
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,7 @@ lazy val noPublish = List(
lazy val Versions = new {
val grpc = "1.68.1"
val http4s = "0.23.29"
val logback = "1.5.12"
}

lazy val core = project
Expand Down Expand Up @@ -51,6 +52,8 @@ lazy val core = project
"org.http4s" %% "http4s-client" % Versions.http4s % Test,

"org.scalatest" %% "scalatest" % "3.2.19" % Test,

"ch.qos.logback" % "logback-classic" % Versions.logback % Test,
),
)

Expand All @@ -63,7 +66,7 @@ lazy val conformance = project
libraryDependencies ++= Seq(
"org.http4s" %% "http4s-ember-server" % Versions.http4s,

"ch.qos.logback" % "logback-classic" % "1.5.12" % Runtime,
"ch.qos.logback" % "logback-classic" % Versions.logback % Runtime,
),
)

Expand Down
2 changes: 1 addition & 1 deletion conformance-suite.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -6,5 +6,5 @@ features:
stream_types: [ STREAM_TYPE_UNARY ]
supports_tls: false
supports_trailers: false
supports_connect_get: false
supports_connect_get: true
supports_message_receive_limit: false
Original file line number Diff line number Diff line change
Expand Up @@ -19,14 +19,48 @@ class ConformanceServiceImpl[F[_] : Async] extends ConformanceServiceFs2GrpcTrai

private val logger = LoggerFactory.getLogger(getClass)

override def unary(request: UnaryRequest, ctx: Metadata): F[(UnaryResponse, Metadata)] = {
val responseDefinition = request.getResponseDefinition
override def unary(
request: UnaryRequest,
ctx: Metadata
): F[(UnaryResponse, Metadata)] = {
for
payload <- handleUnaryRequest(
request.getResponseDefinition,
Seq(request.toProtoAny),
ctx
)
yield (UnaryResponse(payload.some), new Metadata())
}

val trailers = new Metadata()
responseDefinition.responseTrailers.foreach { h =>
val key = Metadata.Key.of(h.name, Metadata.ASCII_STRING_MARSHALLER)
h.value.foreach(v => trailers.put(key, v))
}
override def idempotentUnary(
request: IdempotentUnaryRequest,
ctx: Metadata,
): F[(IdempotentUnaryResponse, Metadata)] = {
for
payload <- handleUnaryRequest(
request.getResponseDefinition,
Seq(request.toProtoAny),
ctx
)
yield (IdempotentUnaryResponse(payload.some), new Metadata())
}

private def handleUnaryRequest(
responseDefinition: UnaryResponseDefinition,
requests: Seq[com.google.protobuf.any.Any],
ctx: Metadata,
): F[ConformancePayload] = {
//val trailers = new Metadata()
//responseDefinition.responseTrailers.foreach { h =>
// val key = Metadata.Key.of(h.name, Metadata.ASCII_STRING_MARSHALLER)
// h.value.foreach(v => trailers.put(key, v))
//}

val requestInfo = ConformancePayload.RequestInfo(
requestHeaders = mkConformanceHeaders(ctx),
timeoutMs = extractTimeout(ctx),
requests = requests
)

val responseData = responseDefinition.response match {
case UnaryResponseDefinition.Response.ResponseData(bs) =>
Expand All @@ -37,39 +71,33 @@ class ConformanceServiceImpl[F[_] : Async] extends ConformanceServiceFs2GrpcTrai
val status = Status.fromCodeValue(code.value)
.withDescription(message.orNull)
.augmentDescription(
TextFormat.printToSingleLineUnicodeString(
ConformancePayload.RequestInfo(
requests = Seq(request.toProtoAny)
).toProtoAny
)
TextFormat.printToSingleLineUnicodeString(requestInfo.toProtoAny)
)

throw new StatusRuntimeException(status, trailers)
throw new StatusRuntimeException(status)
}

val timeout = Option(ctx.get(Metadata.Key.of("grpc-timeout", Metadata.ASCII_STRING_MARSHALLER)))
.map(v => v.substring(0, v.length - 1).toLong / 1000)

val payload = ConformancePayload(
data = responseData.getOrElse(ByteString.EMPTY),
requestInfo = ConformancePayload.RequestInfo(
requestHeaders = mkConformanceHeaders(ctx),
timeoutMs = timeout,
requests = Seq(request.toProtoAny),
connectGetInfo = None,
).some
)

Async[F].sleep(Duration(responseDefinition.responseDelayMs, TimeUnit.MILLISECONDS)) *>
Async[F].pure((UnaryResponse(payload.some), trailers))
Async[F].pure(ConformancePayload(
responseData.getOrElse(ByteString.EMPTY),
requestInfo.some
))
}

private def keyof(key: String): Metadata.Key[String] =
Metadata.Key.of(key, Metadata.ASCII_STRING_MARSHALLER)

private def mkConformanceHeaders(metadata: Metadata): Seq[Header] = {
metadata.keys().asScala.map { key =>
Header(key, metadata.getAll(Metadata.Key.of(key, Metadata.ASCII_STRING_MARSHALLER)).asScala.toSeq)
Header(key, metadata.getAll(keyof(key)).asScala.toSeq)
}.toSeq
}

private def extractTimeout(metadata: Metadata): Option[Long] = {
Option(metadata.get(keyof("grpc-timeout")))
.map(v => v.substring(0, v.length - 1).toLong / 1000)
}

override def serverStream(
request: ServerStreamRequest,
ctx: Metadata
Expand All @@ -91,8 +119,4 @@ class ConformanceServiceImpl[F[_] : Async] extends ConformanceServiceFs2GrpcTrai
ctx: Metadata
): F[(UnimplementedResponse, Metadata)] = ???

override def idempotentUnary(
request: IdempotentUnaryRequest,
ctx: Metadata
): F[(IdempotentUnaryResponse, Metadata)] = ???
}
Original file line number Diff line number Diff line change
Expand Up @@ -46,6 +46,7 @@ object Main extends IOApp.Simple {
p.withTypeRegistry(
TypeRegistry.default
.addMessage[connectrpc.conformance.v1.UnaryRequest]
.addMessage[connectrpc.conformance.v1.IdempotentUnaryRequest]
.addMessage[connectrpc.conformance.v1.ConformancePayload.RequestInfo]
)
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -5,15 +5,17 @@ import cats.effect.Async
import cats.effect.kernel.Resource
import cats.implicits.*
import fs2.compression.Compression
import fs2.{Chunk, Stream}
import io.grpc.*
import io.grpc.MethodDescriptor.MethodType
import io.grpc.stub.MetadataUtils
import org.http4s.*
import org.http4s.dsl.Http4sDsl
import org.http4s.headers.`Content-Type`
import org.ivovk.connect_rpc_scala.http.*
import org.ivovk.connect_rpc_scala.http.Headers.{`Connect-Timeout-Ms`, `X-Test-Case-Name`}
import org.ivovk.connect_rpc_scala.http.Headers.*
import org.ivovk.connect_rpc_scala.http.MessageCodec.given
import org.ivovk.connect_rpc_scala.http.QueryParams.*
import org.slf4j.{Logger, LoggerFactory}
import scalapb.grpc.ClientCalls
import scalapb.json4s.{JsonFormat, Printer}
Expand Down Expand Up @@ -54,6 +56,41 @@ object ConnectRpcHttpRoutes {
ipChannel <- InProcessChannelBridge.create(services, configuration.waitForShutdown)
yield
HttpRoutes.of[F] {
case [email protected] -> Root / serviceName / methodName :? EncodingQP(contentType) +& MessageQP(message) =>
val grpcMethod = grpcMethodName(serviceName, methodName)

codecRegistry.byContentType(contentType) match {
case Some(codec) =>
given MessageCodec[F] = codec

val media = Media[F](Stream.chunk(Chunk.array(message.getBytes)), req.headers)

methodRegistry.get(grpcMethod) match {
// Support GET-requests for all methods until https://github.com/scalapb/ScalaPB/pull/1774 is merged
case Some(entry) if entry.methodDescriptor.isSafe || true =>
entry.methodDescriptor.getType match
case MethodType.UNARY =>
handleUnary(dsl, entry, media, ipChannel)
case unsupported =>
NotImplemented(connectrpc.Error(
code = io.grpc.Status.UNIMPLEMENTED.toConnectCode,
message = s"Unsupported method type: $unsupported".some
))
case Some(_) =>
Forbidden(connectrpc.Error(
code = io.grpc.Status.PERMISSION_DENIED.toConnectCode,
message = s"Method supports calling using POST: $grpcMethod".some
))
case None =>
NotFound(connectrpc.Error(
code = io.grpc.Status.NOT_FOUND.toConnectCode,
message = s"Method not found: $grpcMethod".some
))
}
case None =>
UnsupportedMediaType(s"Unsupported content-type ${contentType.show}. " +
s"Supported content types: ${MediaTypes.allSupported.map(_.show).mkString(", ")}")
}
case [email protected] -> Root / serviceName / methodName =>
val grpcMethod = grpcMethodName(serviceName, methodName)
val contentType = req.headers.get[`Content-Type`].map(_.mediaType)
Expand All @@ -79,7 +116,8 @@ object ConnectRpcHttpRoutes {
))
}
case None =>
UnsupportedMediaType(s"Unsupported Content-Type header ${contentType.map(_.show).orNull}")
UnsupportedMediaType(s"Unsupported content-type ${contentType.map(_.show).orNull}. " +
s"Supported content types: ${MediaTypes.allSupported.map(_.show).mkString(", ")}")
}
}
}
Expand All @@ -88,7 +126,7 @@ object ConnectRpcHttpRoutes {
private def handleUnary[F[_] : Async](
dsl: Http4sDsl[F],
entry: RegistryEntry,
req: Request[F],
req: Media[F],
channel: Channel
)(using codec: MessageCodec[F]): F[Response[F]] = {
import dsl.*
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,17 @@
package org.ivovk.connect_rpc_scala.http

import org.http4s.MediaType

import scala.annotation.targetName

object MediaTypes {

@targetName("applicationJson")
val `application/json`: MediaType = MediaType.application.json

@targetName("applicationProto")
val `application/proto`: MediaType = MediaType.unsafeParse("application/proto")

val allSupported: Seq[MediaType] = List(`application/json`, `application/proto`)

}
Original file line number Diff line number Diff line change
Expand Up @@ -36,7 +36,7 @@ class JsonMessageCodec[F[_] : Sync : Compression](jsonPrinter: Printer) extends

private val logger: Logger = LoggerFactory.getLogger(getClass)

override val mediaType: MediaType = MediaType.application.`json`
override val mediaType: MediaType = MediaTypes.`application/json`

override def decode[A <: Message](m: Media[F])(using cmp: Companion[A]): DecodeResult[F, A] = {
val charset = m.charset.getOrElse(Charset.`UTF-8`).nioCharset
Expand Down Expand Up @@ -72,8 +72,7 @@ class ProtoMessageCodec[F[_] : Async : Compression] extends MessageCodec[F] {

private val logger: Logger = LoggerFactory.getLogger(getClass)

override val mediaType: MediaType =
MediaType.unsafeParse("application/proto")
override val mediaType: MediaType = MediaTypes.`application/proto`

override def decode[A <: Message](m: Media[F])(using cmp: Companion[A]): DecodeResult[F, A] = {
val f = toInputStreamResource(decompressed(m)).use { is =>
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,23 @@
package org.ivovk.connect_rpc_scala.http

import org.http4s.dsl.impl.QueryParamDecoderMatcher
import org.http4s.{Charset, MediaType, ParseFailure, QueryParamDecoder}

import java.net.URLDecoder

object QueryParams {

private val encodingQPDecoder = QueryParamDecoder[String].emap {
case "json" => Right(MediaTypes.`application/json`)
case "proto" => Right(MediaTypes.`application/proto`)
case other => Left(ParseFailure(other, "Invalid encoding"))
}

object EncodingQP extends QueryParamDecoderMatcher[MediaType]("encoding")(encodingQPDecoder)

private val messageQPDecoder = QueryParamDecoder[String]
.map(URLDecoder.decode(_, Charset.`UTF-8`.nioCharset))

object MessageQP extends QueryParamDecoderMatcher[String]("message")(messageQPDecoder)

}
13 changes: 13 additions & 0 deletions core/src/test/protobuf/TestService.proto
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,11 @@ package org.ivovk.connect_rpc_scala.test;

service TestService {
rpc Add(AddRequest) returns (AddResponse) {}

// This method can be called using GET request
rpc Get(GetRequest) returns (GetResponse) {
option idempotency_level = NO_SIDE_EFFECTS;
}
}

message AddRequest {
Expand All @@ -14,3 +19,11 @@ message AddRequest {
message AddResponse {
int32 sum = 1;
}

message GetRequest {
string key = 1;
}

message GetResponse {
string value = 1;
}
16 changes: 16 additions & 0 deletions core/src/test/resources/logback.xml
Original file line number Diff line number Diff line change
@@ -0,0 +1,16 @@
<?xml version="1.0" encoding="UTF-8" ?>
<!DOCTYPE configuration>

<configuration>

<appender name="STDOUT" class="ch.qos.logback.core.ConsoleAppender">
<encoder>
<pattern>%-4relative %-5level %logger{35} -%kvp- %msg%n</pattern>
</encoder>
</appender>

<root level="WARN">
<appender-ref ref="STDOUT"/>
</root>
<logger name="org.ivovk.connect_rpc_scala" level="TRACE"/>
</configuration>
Loading