diff --git a/components/client/pom.xml b/components/client/pom.xml
index 33fa8f660..792192974 100644
--- a/components/client/pom.xml
+++ b/components/client/pom.xml
@@ -32,6 +32,11 @@
linux-aarch_64
+
+ io.projectreactor.netty
+ reactor-netty
+
+
com.hotels.styx
styx-api
@@ -78,16 +83,19 @@
org.hamcrest
hamcrest
+ test
org.scalatest
scalatest_${scala.version}
+ test
org.mockito
mockito-core
+ test
@@ -96,6 +104,24 @@
test
+
+ io.mockk
+ mockk-jvm
+ test
+
+
+
+ com.squareup.okhttp3
+ mockwebserver
+ test
+
+
+
+ com.squareup.okhttp3
+ okhttp-tls
+ test
+
+
diff --git a/components/client/src/main/kotlin/com/hotels/styx/client/ReactorBackendServiceClient.kt b/components/client/src/main/kotlin/com/hotels/styx/client/ReactorBackendServiceClient.kt
new file mode 100644
index 000000000..c8762e9a0
--- /dev/null
+++ b/components/client/src/main/kotlin/com/hotels/styx/client/ReactorBackendServiceClient.kt
@@ -0,0 +1,330 @@
+/*
+ Copyright (C) 2013-2024 Expedia Inc.
+
+ Licensed under the Apache License, Version 2.0 (the "License");
+ you may not use this file except in compliance with the License.
+ You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+ Unless required by applicable law or agreed to in writing, software
+ distributed under the License is distributed on an "AS IS" BASIS,
+ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ See the License for the specific language governing permissions and
+ limitations under the License.
+ */
+package com.hotels.styx.client
+
+import com.hotels.styx.api.HttpHeaderNames.CONTENT_LENGTH
+import com.hotels.styx.api.HttpHeaderNames.HOST
+import com.hotels.styx.api.HttpHeaderNames.TRANSFER_ENCODING
+import com.hotels.styx.api.HttpInterceptor
+import com.hotels.styx.api.HttpMethod
+import com.hotels.styx.api.Id
+import com.hotels.styx.api.LiveHttpRequest
+import com.hotels.styx.api.LiveHttpResponse
+import com.hotels.styx.api.exceptions.NoAvailableHostsException
+import com.hotels.styx.api.extension.Origin
+import com.hotels.styx.api.extension.RemoteHost
+import com.hotels.styx.api.extension.loadbalancing.spi.LoadBalancer
+import com.hotels.styx.api.extension.retrypolicy.spi.RetryPolicy
+import com.hotels.styx.api.extension.service.StickySessionConfig
+import com.hotels.styx.client.StyxHeaderConfig.ORIGIN_ID_DEFAULT
+import com.hotels.styx.client.retry.RetryNTimes
+import com.hotels.styx.client.stickysession.StickySessionCookie.newStickySessionCookie
+import com.hotels.styx.client.stickysession.StickySessionLoadBalancingStrategy
+import com.hotels.styx.ext.newRequest
+import com.hotels.styx.ext.newResponse
+import com.hotels.styx.metrics.CentralisedMetrics
+import org.reactivestreams.Publisher
+import reactor.core.publisher.Mono
+import java.util.Objects.nonNull
+import java.util.Optional
+
+/**
+ * A configurable HTTP client with integration of Reactor Netty client
+ */
+class ReactorBackendServiceClient(
+ private val id: Id,
+ private val rewriteRuleset: RewriteRuleset,
+ private val originsRestrictionCookieName: String?,
+ private val stickySessionConfig: StickySessionConfig,
+ private val originIdHeader: CharSequence,
+ private val loadBalancer: LoadBalancer,
+ private val retryPolicy: RetryPolicy,
+ private val metrics: CentralisedMetrics,
+ private val overrideHostHeader: Boolean,
+) : BackendServiceClient {
+ override fun sendRequest(
+ request: LiveHttpRequest,
+ context: HttpInterceptor.Context,
+ ): Publisher = sendRequest(rewriteUrl(request), emptyList(), 0, context)
+
+ private fun sendRequest(
+ request: LiveHttpRequest,
+ previousOrigins: List,
+ attempt: Int,
+ context: HttpInterceptor.Context,
+ ): Publisher {
+ if (attempt >= MAX_RETRY_ATTEMPTS) {
+ return Mono.error(NoAvailableHostsException(id))
+ }
+ val remoteHost = selectOrigin(request)
+ return if (remoteHost.isPresent) {
+ val host = remoteHost.get()
+ val updatedRequest = shouldOverrideHostHeader(host, request)
+ val newPreviousOrigins = previousOrigins.toMutableList()
+ newPreviousOrigins.add(host)
+ Mono.from(host.hostClient().handle(updatedRequest, context))
+ .doOnNext { recordErrorStatusMetrics(it) }
+ .map { response ->
+ response.addStickySessionIdentifier(host.origin())
+ .removeUnexpectedResponseBody(updatedRequest)
+ .removeRedundantContentLengthHeader()
+ .addOriginId(host.id())
+ .let { LiveHttpResponse.Builder(it).request(updatedRequest).build() }
+ }
+ .onErrorResume { cause ->
+ val retryContext = RetryPolicyContext(id, attempt + 1, cause, updatedRequest, previousOrigins)
+ retry(updatedRequest, retryContext, newPreviousOrigins, attempt + 1, cause, context)
+ }
+ } else {
+ val retryContext = RetryPolicyContext(id, attempt + 1, null, request, previousOrigins)
+ retry(request, retryContext, previousOrigins, attempt + 1, NoAvailableHostsException(id), context)
+ }
+ }
+
+ private fun recordErrorStatusMetrics(response: LiveHttpResponse) {
+ val statusCode = response.status().code()
+ if (statusCode.isErrorStatus()) {
+ metrics.proxy.client.errorResponseFromOriginByStatus(statusCode).increment()
+ }
+ }
+
+ private fun Int.isErrorStatus() = this >= 400
+
+ private fun bodyNeedsToBeRemoved(
+ request: LiveHttpRequest,
+ response: LiveHttpResponse,
+ ) = isHeadRequest(request) || isBodilessResponse(response)
+
+ private fun responseWithoutBody(response: LiveHttpResponse) =
+ response.newResponse {
+ header(CONTENT_LENGTH, 0)
+ removeHeader(TRANSFER_ENCODING)
+ removeBody()
+ }
+
+ private fun isBodilessResponse(response: LiveHttpResponse): Boolean =
+ when (val code = response.status().code()) {
+ 204, 304 -> true
+ else -> code / 100 == 1
+ }
+
+ private fun isHeadRequest(request: LiveHttpRequest): Boolean = request.method() == HttpMethod.HEAD
+
+ private fun shouldOverrideHostHeader(
+ host: RemoteHost,
+ request: LiveHttpRequest,
+ ): LiveHttpRequest =
+ if (overrideHostHeader && !host.origin().host().isNullOrBlank()) {
+ request.newRequest { header(HOST, host.origin().host()) }
+ } else {
+ request
+ }
+
+ private fun LiveHttpResponse.addOriginId(originId: Id): LiveHttpResponse =
+ newResponse {
+ header(originIdHeader, originId)
+ }
+
+ private fun retry(
+ request: LiveHttpRequest,
+ retryContext: RetryPolicyContext,
+ previousOrigins: List,
+ attempt: Int,
+ cause: Throwable,
+ context: HttpInterceptor.Context,
+ ): Mono {
+ val lbContext: LoadBalancer.Preferences =
+ object : LoadBalancer.Preferences {
+ override fun preferredOrigins(): Optional = Optional.empty()
+
+ override fun avoidOrigins(): List = previousOrigins.map { it.origin() }
+ }
+ return if (retryPolicy.evaluate(retryContext, loadBalancer, lbContext).shouldRetry()) {
+ Mono.from(sendRequest(request, previousOrigins, attempt, context))
+ } else {
+ Mono.error(cause)
+ }
+ }
+
+ private fun LiveHttpResponse.removeUnexpectedResponseBody(request: LiveHttpRequest) =
+ if (bodyNeedsToBeRemoved(request, this)) {
+ responseWithoutBody(this)
+ } else {
+ this
+ }
+
+ private fun LiveHttpResponse.removeRedundantContentLengthHeader() =
+ if (contentLength().isPresent && chunked()) {
+ newResponse {
+ removeHeader(CONTENT_LENGTH)
+ }
+ } else {
+ this
+ }
+
+ private fun selectOrigin(rewrittenRequest: LiveHttpRequest): Optional {
+ val preferences =
+ object : LoadBalancer.Preferences {
+ override fun preferredOrigins(): Optional {
+ return if (nonNull(originsRestrictionCookieName)) {
+ rewrittenRequest.cookie(originsRestrictionCookieName)
+ .map { it.value() }
+ .or { rewrittenRequest.cookie("styx_origin_$id").map { it.value() } }
+ } else {
+ rewrittenRequest.cookie("styx_origin_$id").map { it.value() }
+ }
+ }
+
+ override fun avoidOrigins(): List = emptyList()
+ }
+ return loadBalancer.choose(preferences)
+ }
+
+ private fun LiveHttpResponse.addStickySessionIdentifier(origin: Origin): LiveHttpResponse =
+ if (loadBalancer is StickySessionLoadBalancingStrategy) {
+ val maxAge = stickySessionConfig.stickySessionTimeoutSeconds()
+ newResponse {
+ addCookies(newStickySessionCookie(id, origin.id(), maxAge))
+ }
+ } else {
+ this
+ }
+
+ private fun rewriteUrl(request: LiveHttpRequest): LiveHttpRequest = rewriteRuleset.rewrite(request)
+
+ private class RetryPolicyContext(
+ private val appId: Id,
+ private val retryCount: Int,
+ private val lastException: Throwable?,
+ private val request: LiveHttpRequest,
+ private val previouslyUsedOrigins: Iterable,
+ ) : RetryPolicy.Context {
+ override fun appId(): Id = appId
+
+ override fun currentRetryCount(): Int = retryCount
+
+ override fun lastException(): Optional = Optional.ofNullable(lastException)
+
+ override fun currentRequest(): LiveHttpRequest = request
+
+ override fun previousOrigins(): Iterable = previouslyUsedOrigins
+
+ override fun toString(): String =
+ buildString {
+ append("appId", appId)
+ append(", retryCount", retryCount)
+ append(", lastException", lastException)
+ append(", request", request.url())
+ append(", previouslyUsedOrigins", previouslyUsedOrigins)
+ }
+
+ fun hosts(): String = hosts(previouslyUsedOrigins)
+
+ companion object {
+ private fun hosts(origins: Iterable): String =
+ origins.asSequence().map { it.origin().hostAndPortString() }.joinToString(", ")
+ }
+ }
+
+ override fun toString(): String =
+ buildString {
+ append("id", id)
+ append(", stickySessionConfig", stickySessionConfig)
+ append(", retryPolicy", retryPolicy)
+ append(", rewriteRuleset", rewriteRuleset)
+ append(", loadBalancingStrategy", loadBalancer)
+ append(", overrideHostHeader", overrideHostHeader)
+ }
+
+ /**
+ * A builder for [ReactorBackendServiceClient].
+ */
+ class Builder(val id: Id) {
+ private var originStatsFactory: OriginStatsFactory? = null
+ private var loadBalancer: LoadBalancer? = null
+ private var metrics: CentralisedMetrics? = null
+ private var rewriteRuleset: RewriteRuleset = RewriteRuleset(emptyList())
+ private var originsRestrictionCookieName: String? = null
+ private var stickySessionConfig: StickySessionConfig = StickySessionConfig.stickySessionDisabled()
+ private var originIdHeader: CharSequence = ORIGIN_ID_DEFAULT
+ private var retryPolicy: RetryPolicy = RetryNTimes(3)
+ private var overrideHostHeader: Boolean = false
+
+ fun rewriteRules(rewriteRuleset: RewriteRuleset) =
+ apply {
+ this.rewriteRuleset = rewriteRuleset
+ }
+
+ fun originStatsFactory(originStatsFactory: OriginStatsFactory) =
+ apply {
+ this.originStatsFactory = originStatsFactory
+ }
+
+ fun originsRestrictionCookieName(originsRestrictionCookieName: String?) =
+ apply {
+ this.originsRestrictionCookieName = originsRestrictionCookieName
+ }
+
+ fun stickySessionConfig(stickySessionConfig: StickySessionConfig) =
+ apply {
+ this.stickySessionConfig = stickySessionConfig
+ }
+
+ fun originIdHeader(originIdHeader: CharSequence) =
+ apply {
+ this.originIdHeader = originIdHeader
+ }
+
+ fun loadBalancer(loadBalancer: LoadBalancer) =
+ apply {
+ this.loadBalancer = loadBalancer
+ }
+
+ fun retryPolicy(retryPolicy: RetryPolicy) =
+ apply {
+ this.retryPolicy = retryPolicy
+ }
+
+ fun metrics(metrics: CentralisedMetrics) =
+ apply {
+ this.metrics = metrics
+ }
+
+ fun overrideHostHeader(overrideHostHeader: Boolean) =
+ apply {
+ this.overrideHostHeader = overrideHostHeader
+ }
+
+ fun build(): ReactorBackendServiceClient =
+ ReactorBackendServiceClient(
+ id,
+ rewriteRuleset,
+ originsRestrictionCookieName,
+ stickySessionConfig,
+ originIdHeader,
+ checkNotNull(loadBalancer) { "loadBalancer is required" },
+ retryPolicy,
+ checkNotNull(metrics) { "metrics is required" },
+ overrideHostHeader,
+ )
+ }
+
+ companion object {
+ private const val MAX_RETRY_ATTEMPTS = 3
+
+ @JvmStatic fun newHttpClientBuilder(backendServiceId: Id): Builder = Builder(backendServiceId)
+ }
+}
diff --git a/components/client/src/main/kotlin/com/hotels/styx/client/ReactorConnectionPool.kt b/components/client/src/main/kotlin/com/hotels/styx/client/ReactorConnectionPool.kt
new file mode 100644
index 000000000..ff35dd6d9
--- /dev/null
+++ b/components/client/src/main/kotlin/com/hotels/styx/client/ReactorConnectionPool.kt
@@ -0,0 +1,107 @@
+/*
+ Copyright (C) 2013-2024 Expedia Inc.
+
+ Licensed under the Apache License, Version 2.0 (the "License");
+ you may not use this file except in compliance with the License.
+ You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+ Unless required by applicable law or agreed to in writing, software
+ distributed under the License is distributed on an "AS IS" BASIS,
+ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ See the License for the specific language governing permissions and
+ limitations under the License.
+ */
+package com.hotels.styx.client
+
+import com.hotels.styx.api.HttpVersion
+import com.hotels.styx.api.HttpVersion.HTTP_1_1
+import com.hotels.styx.api.HttpVersion.HTTP_2
+import com.hotels.styx.api.extension.Origin
+import com.hotels.styx.api.extension.service.BackendService
+import com.hotels.styx.api.extension.service.ConnectionPoolSettings
+import com.hotels.styx.api.extension.service.ConnectionPoolSettings.defaultConnectionPoolSettings
+import reactor.netty.http.HttpProtocol
+import reactor.netty.http.HttpProtocol.H2
+import reactor.netty.http.HttpProtocol.HTTP11
+import reactor.netty.http.client.Http2AllocationStrategy
+import reactor.netty.resources.ConnectionProvider
+import java.time.Duration
+
+class ReactorConnectionPool(
+ connectionPoolSettings: ConnectionPoolSettings = defaultConnectionPoolSettings(),
+ private val httpVersion: HttpVersion = HTTP_1_1,
+ private val backendService: BackendService = BackendService.Builder().build(),
+) {
+ val pendingAcquireTimeoutMillis: Int = connectionPoolSettings.pendingConnectionTimeoutMillis()
+ val connectTimeoutMillis: Int = connectionPoolSettings.connectTimeoutMillis()
+ val maxConnections = connectionPoolSettings.maxConnectionsPerHost()
+ val maxIdleTimeMillis: Int = backendService.responseTimeoutMillis()
+ val h2MaxConnections = connectionPoolSettings.http2ConnectionPoolSettings().maxConnections ?: DEFAULT_H2_MAX_CONNECTIONS
+ val h2MinConnections = connectionPoolSettings.http2ConnectionPoolSettings().minConnections ?: DEFAULT_H2_MIN_CONNECTIONS
+ val h2MaxConcurrentStreams =
+ connectionPoolSettings.http2ConnectionPoolSettings().maxStreamsPerConnection
+ ?: DEFAULT_H2_MAX_STREAMS_PER_CONNECTION
+ val pendingAcquireMaxCount: Int =
+ if (HTTP_2 == httpVersion) {
+ connectionPoolSettings.http2ConnectionPoolSettings().maxPendingStreamsPerHost
+ ?: DEFAULT_H2_MAX_PENDING_STREAMS_PER_HOST
+ } else {
+ connectionPoolSettings.maxPendingConnectionsPerHost()
+ }
+ val connectionExpirationSeconds: Long = connectionPoolSettings.connectionExpirationSeconds()
+
+ // TODO: needs some investigation on how inflight connections should be shut down
+ val disposeTimeoutMillis: Int = backendService.responseTimeoutMillis()
+
+ private val connectionProviderBuilder: ConnectionProvider.Builder = initiateConnectionProvider()
+
+ fun getConnectionProvider(origin: Origin): ConnectionProvider = connectionProviderBuilder.name(origin.id().toString()).build()
+
+ fun supportedHttpProtocols(): Array =
+ if (HTTP_2 == httpVersion) {
+ arrayOf(H2, HTTP11)
+ } else {
+ arrayOf(HTTP11)
+ }
+
+ fun isHttp2(): Boolean = HTTP_2 == httpVersion
+
+ private fun initiateConnectionProvider(): ConnectionProvider.Builder {
+ // Configuration details: https://projectreactor.io/docs/netty/release/reference/index.html#_connection_pool_2
+ val builder =
+ ConnectionProvider.builder(backendService.id().toString())
+ .metrics(true)
+ .pendingAcquireTimeout(Duration.ofMillis(pendingAcquireTimeoutMillis.toLong()))
+ .disposeTimeout(Duration.ofMillis(disposeTimeoutMillis.toLong()))
+ .maxIdleTime(Duration.ofMillis(maxIdleTimeMillis.toLong()))
+ .maxLifeTime(Duration.ofSeconds(connectionExpirationSeconds))
+ .evictInBackground(Duration.ofSeconds(DEFAULT_CONNECTION_EVICTION_FREQUENCY_SECONDS))
+ .maxConnections(maxConnections)
+ .pendingAcquireMaxCount(pendingAcquireMaxCount)
+ .lifo()
+
+ if (HTTP_2 == httpVersion) {
+ // Increasing the number of max/min connections could result in
+ // unnecessary connection creation and cause overheads on concurrency
+ val allocationStrategy =
+ Http2AllocationStrategy.builder()
+ .maxConcurrentStreams(h2MaxConcurrentStreams.toLong())
+ .maxConnections(h2MaxConnections)
+ .minConnections(h2MinConnections)
+ .build()
+ return builder.allocationStrategy(allocationStrategy)
+ }
+
+ return builder
+ }
+
+ companion object {
+ private const val DEFAULT_H2_MAX_CONNECTIONS = 10
+ private const val DEFAULT_H2_MIN_CONNECTIONS = 2
+ private const val DEFAULT_H2_MAX_STREAMS_PER_CONNECTION = 1024
+ private const val DEFAULT_H2_MAX_PENDING_STREAMS_PER_HOST = 200
+ private const val DEFAULT_CONNECTION_EVICTION_FREQUENCY_SECONDS = 30L
+ }
+}
diff --git a/components/client/src/main/kotlin/com/hotels/styx/client/ReactorHostHttpClient.kt b/components/client/src/main/kotlin/com/hotels/styx/client/ReactorHostHttpClient.kt
new file mode 100644
index 000000000..256563716
--- /dev/null
+++ b/components/client/src/main/kotlin/com/hotels/styx/client/ReactorHostHttpClient.kt
@@ -0,0 +1,397 @@
+/*
+ Copyright (C) 2013-2024 Expedia Inc.
+
+ Licensed under the Apache License, Version 2.0 (the "License");
+ you may not use this file except in compliance with the License.
+ You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+ Unless required by applicable law or agreed to in writing, software
+ distributed under the License is distributed on an "AS IS" BASIS,
+ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ See the License for the specific language governing permissions and
+ limitations under the License.
+ */
+package com.hotels.styx.client
+
+import com.hotels.styx.api.Buffers
+import com.hotels.styx.api.ByteStream
+import com.hotels.styx.api.HttpHeaderNames.HOST
+import com.hotels.styx.api.HttpHeaders
+import com.hotels.styx.api.HttpInterceptor
+import com.hotels.styx.api.HttpResponseStatus.statusWithCode
+import com.hotels.styx.api.HttpVersion.httpVersion
+import com.hotels.styx.api.LiveHttpRequest
+import com.hotels.styx.api.LiveHttpResponse
+import com.hotels.styx.api.exceptions.OriginUnreachableException
+import com.hotels.styx.api.exceptions.ResponseTimeoutException
+import com.hotels.styx.api.exceptions.TransportLostException
+import com.hotels.styx.api.extension.Origin
+import com.hotels.styx.api.extension.loadbalancing.spi.LoadBalancingMetric
+import com.hotels.styx.client.ReactorHostHttpClient.ErrorType.REQUEST
+import com.hotels.styx.client.ReactorHostHttpClient.ErrorType.RESPONSE
+import com.hotels.styx.client.applications.OriginStats
+import com.hotels.styx.client.connectionpool.LatencyTiming.finishRequestTiming
+import com.hotels.styx.client.connectionpool.LatencyTiming.startResponseTiming
+import com.hotels.styx.client.connectionpool.MaxPendingConnectionTimeoutException
+import com.hotels.styx.client.connectionpool.MaxPendingConnectionsExceededException
+import com.hotels.styx.common.logging.HttpRequestMessageLogger
+import com.hotels.styx.metrics.CentralisedMetrics
+import com.hotels.styx.metrics.Deleter
+import com.hotels.styx.metrics.ReactorNettyMeterFilter
+import com.hotels.styx.metrics.TimerMetric
+import io.netty.channel.Channel
+import io.netty.channel.ChannelOption.CONNECT_TIMEOUT_MILLIS
+import io.netty.channel.ChannelOption.SO_KEEPALIVE
+import io.netty.channel.ChannelOption.TCP_NODELAY
+import io.netty.handler.codec.DecoderException
+import io.netty.handler.codec.http.HttpMethod
+import io.netty.handler.ssl.SslHandshakeTimeoutException
+import io.netty.handler.timeout.ReadTimeoutException
+import io.netty.resolver.dns.DnsNameResolverException
+import org.reactivestreams.Publisher
+import org.slf4j.LoggerFactory.getLogger
+import reactor.core.publisher.Flux
+import reactor.core.publisher.Mono
+import reactor.netty.ByteBufFlux
+import reactor.netty.Connection
+import reactor.netty.Metrics.REGISTRY
+import reactor.netty.NettyInbound
+import reactor.netty.http.client.HttpClient
+import reactor.netty.http.client.HttpClientConfig
+import reactor.netty.http.client.HttpClientResponse
+import reactor.netty.http.client.PrematureCloseException
+import reactor.netty.internal.shaded.reactor.pool.PoolAcquirePendingLimitException
+import reactor.netty.internal.shaded.reactor.pool.PoolAcquireTimeoutException
+import reactor.netty.resources.ConnectionProvider
+import reactor.netty.resources.LoopResources
+import reactor.netty.tcp.SslProvider.SslContextSpec
+import java.net.SocketException
+import java.net.UnknownHostException
+import java.time.Duration
+import java.util.concurrent.atomic.AtomicInteger
+import java.util.function.Consumer
+import javax.annotation.concurrent.ThreadSafe
+
+/**
+ * A Reactor HTTP Client for proxying to an individual origin host.
+ */
+@ThreadSafe
+class ReactorHostHttpClient private constructor(
+ private val origin: Origin,
+ private val connectionPool: ReactorConnectionPool,
+ private val httpConfig: HttpConfig,
+ private val h2SslProvider: Consumer?,
+ private val h11SslHandler: Consumer?,
+ private val responseTimeoutMillis: Int,
+ private val httpRequestMessageLogger: HttpRequestMessageLogger?,
+ private val originStatsFactory: OriginStatsFactory,
+ private val metrics: CentralisedMetrics,
+ private val eventLoopGroup: LoopResources,
+ private val doOnResponse: ((HttpClientResponse, HttpInterceptor.Context?) -> Unit)? = null,
+) : HostHttpClient {
+ private val httpClient: HttpClient
+ private val ongoingRequestCount: AtomicInteger = AtomicInteger()
+ private val pendingAcquireCount: AtomicInteger = AtomicInteger()
+ private val stats: Stats = Stats()
+ private val connectionProvider: ConnectionProvider = connectionPool.getConnectionProvider(origin)
+
+ @Volatile
+ private var ongoingRequestsDeleter: Deleter? = null
+
+ @Volatile
+ private var originStats: OriginStats? = null
+
+ init {
+ require(h2SslProvider == null || h11SslHandler == null) {
+ "There can only be one type of SSL context"
+ }
+
+ REGISTRY.config().meterFilter(ReactorNettyMeterFilter(origin))
+ httpClient = HttpClient.create(connectionProvider).init()
+ httpClient.warmup().block()
+ }
+
+ override fun sendRequest(
+ request: LiveHttpRequest,
+ context: HttpInterceptor.Context?,
+ ): Publisher =
+ httpClient.addListeners(request, context)
+ .sendRequest(request, context)
+ .addStyxResponseListeners(request)
+
+ override fun close() {
+ connectionProvider.dispose()
+ ongoingRequestsDeleter?.delete()
+ }
+
+ override fun loadBalancingMetric(): LoadBalancingMetric = LoadBalancingMetric(stats.ongoingRequestCount())
+
+ internal fun configuration(): HttpClientConfig = httpClient.configuration()
+
+ private fun HttpClient.init(): HttpClient =
+ this.host(origin.host())
+ .port(origin.port())
+ .option(CONNECT_TIMEOUT_MILLIS, connectionPool.connectTimeoutMillis)
+ .option(TCP_NODELAY, true)
+ .option(SO_KEEPALIVE, true)
+ .protocol(*connectionPool.supportedHttpProtocols())
+ .addSslContext()
+ .compress(httpConfig.compress())
+ .httpResponseDecoder {
+ it.maxInitialLineLength(httpConfig.maxInitialLength())
+ .maxHeaderSize(httpConfig.maxHeadersSize())
+ .maxChunkSize(httpConfig.maxChunkSize())
+ }
+ .responseTimeout(Duration.ofMillis(responseTimeoutMillis.toLong()))
+ .metrics(true) { _ -> origin.id().toString() }
+ .runOn(eventLoopGroup)
+ .disableRetry(true)
+ .resolver {
+ it.cacheMaxTimeToLive(Duration.ofSeconds(DNS_MAX_CACHE_TIME_TO_LIVE_SECONDS))
+ .cacheNegativeTimeToLive(Duration.ofSeconds(DNS_NEGATIVE_TIME_TO_LIVE_SECONDS))
+ }
+ .doOnChannelInit { _, _, _ ->
+ if (ongoingRequestsDeleter == null) {
+ ongoingRequestsDeleter =
+ metrics.proxy.client.ongoingRequests(origin)
+ .register { loadBalancingMetric().ongoingActivities() }
+ }
+ }
+ .doOnConnect { pendingAcquireCount.incrementAndGet() }
+ .doOnConnected { pendingAcquireCount.decrementAndGet() }
+
+ private fun HttpClient.addSslContext(): HttpClient =
+ if (connectionPool.isHttp2() && h2SslProvider != null) {
+ secure(h2SslProvider)
+ } else if (!connectionPool.isHttp2() && h11SslHandler != null) {
+ doOnChannelInit { _, channel, _ ->
+ h11SslHandler.accept(channel)
+ }
+ } else {
+ this
+ }
+
+ private fun HttpClient.addListeners(
+ request: LiveHttpRequest,
+ context: HttpInterceptor.Context?,
+ ): HttpClient {
+ var requestLatencyTiming: TimerMetric.Stopper? = null
+ var timeToFirstByteTiming: TimerMetric.Stopper? = null
+
+ if (originStats == null) {
+ originStats = originStatsFactory.originStats(origin)
+ }
+
+ return this
+ .doOnRequest { _, _ ->
+ httpRequestMessageLogger?.logRequest(request, origin)
+ requestLatencyTiming = originStats!!.requestLatencyTimer().startTiming()
+ timeToFirstByteTiming = originStats!!.timeToFirstByteTimer().startTiming()
+ }
+ .doAfterRequest { _, _ ->
+ // Request timing started at Styx Server first receiving requests
+ finishRequestTiming(context)
+ requestLatencyTiming?.stop()
+ }
+ .doOnRequestError { _, throwable ->
+ logError(REQUEST, request, throwable)
+ requestLatencyTiming?.stop()
+ }
+ .doOnResponse { response, _ ->
+ updateHttpResponseCounters(originStats!!, response.status().code())
+ timeToFirstByteTiming?.stop()
+ doOnResponse?.invoke(response, context)
+ }
+ .doAfterResponseSuccess { _, _ ->
+ // Response timing stopping at Styx Server returning response to end users
+ startResponseTiming(metrics, context)
+ }
+ .doOnResponseError { _, throwable ->
+ logError(RESPONSE, request, throwable)
+ timeToFirstByteTiming?.stop()
+ }
+ }
+
+ private fun HttpClient.sendRequest(
+ request: LiveHttpRequest,
+ context: HttpInterceptor.Context?,
+ ): Mono {
+ context?.add(ORIGINID_CONTEXT_KEY, origin.id())
+ return this
+ .headers {
+ request.headers().forEach { name, value ->
+ it.add(name, value)
+ }
+ if (!request.header(HOST).isPresent) {
+ it.add(HOST, origin.hostAndPortString())
+ }
+ }
+ .request(HttpMethod(request.method().name()))
+ .uri(request.url().toString())
+ .send(
+ ByteBufFlux.fromInbound(
+ Flux.from(request.body())
+ .map(Buffers::toByteBuf)
+ .flatMapSequential { Flux.just(it) },
+ ).doOnCancel { originStats?.requestCancelled() },
+ )
+ .responseConnection { res: HttpClientResponse, conn: Connection ->
+ toStyxResponse(res, conn.inbound())
+ }
+ .single()
+ .onErrorMap { toStyxExceptions(it) }
+ }
+
+ private fun toStyxResponse(
+ response: HttpClientResponse,
+ nettyInbound: NettyInbound,
+ ): Mono {
+ val headersBuilder =
+ HttpHeaders.Builder().apply {
+ response.responseHeaders().forEach { add(it.key, it.value) }
+ }
+ val body = toStyxByteStream(nettyInbound)
+ return Mono.just(
+ LiveHttpResponse.Builder()
+ .status(statusWithCode(response.status().code(), response.status().toString()))
+ .version(httpVersion(response.version().text()))
+ .headers(headersBuilder.build())
+ .body(body)
+ .build(),
+ )
+ }
+
+ private fun toStyxByteStream(nettyInbound: NettyInbound): ByteStream =
+ ByteStream(
+ nettyInbound
+ .receive()
+ .retain()
+ .map(Buffers::fromByteBuf)
+ .doOnCancel { originStats?.requestCancelled() }
+ .onErrorMap { toStyxExceptions(it) },
+ )
+
+ private fun Mono.addStyxResponseListeners(request: LiveHttpRequest): Mono =
+ this.doOnNext { httpRequestMessageLogger?.logResponse(request, it) }
+ .doOnSubscribe { ongoingRequestCount.incrementAndGet() }
+ .doFinally { ongoingRequestCount.decrementAndGet() }
+
+ private fun logError(
+ type: ErrorType,
+ request: LiveHttpRequest,
+ throwable: Throwable,
+ ) = LOGGER.error(
+ """
+ |Error Handling ${type.name} request=$request exceptionClass=${throwable.javaClass.name} exceptionMessage=\"${throwable.message}\"
+ """.trimMargin(),
+ )
+
+ private fun updateHttpResponseCounters(
+ originStats: OriginStats,
+ statusCode: Int,
+ ) {
+ if (isServerError(statusCode)) {
+ originStats.requestError()
+ } else {
+ originStats.requestSuccess()
+ }
+ originStats.responseWithStatusCode(statusCode)
+ }
+
+ private fun isServerError(status: Int) = status >= 500
+
+ private fun toStyxExceptions(throwable: Throwable): Throwable =
+ when (throwable) {
+ is PoolAcquireTimeoutException ->
+ MaxPendingConnectionTimeoutException(origin, connectionPool.pendingAcquireTimeoutMillis)
+
+ is PoolAcquirePendingLimitException -> {
+ pendingAcquireCount.decrementAndGet()
+ MaxPendingConnectionsExceededException(
+ origin,
+ stats.pendingAcquireCount(),
+ connectionPool.pendingAcquireMaxCount,
+ )
+ }
+
+ is ReadTimeoutException -> ResponseTimeoutException(origin)
+
+ is SslHandshakeTimeoutException, is DnsNameResolverException, is UnknownHostException ->
+ OriginUnreachableException(origin, throwable.cause)
+
+ is SocketException, is PrematureCloseException ->
+ TransportLostException(configuration().remoteAddress().get(), origin)
+
+ is DecoderException, is IllegalArgumentException ->
+ BadHttpResponseException(origin, throwable.cause)
+
+ else -> throwable
+ }
+
+ inner class Stats {
+ fun ongoingRequestCount(): Int = ongoingRequestCount.get()
+
+ fun pendingAcquireCount(): Int = pendingAcquireCount.get()
+ }
+
+ private enum class ErrorType {
+ REQUEST,
+ RESPONSE,
+ }
+
+ /**
+ * A factory for creating ReactorHostHttpClient instances.
+ */
+ fun interface Factory {
+ fun create(
+ origin: Origin,
+ connectionPool: ReactorConnectionPool,
+ httpConfig: HttpConfig,
+ h2SslProvider: Consumer?,
+ h11SslHandler: Consumer?,
+ responseTimeoutMillis: Int,
+ httpRequestMessageLogger: HttpRequestMessageLogger?,
+ originStatsFactory: OriginStatsFactory,
+ metrics: CentralisedMetrics,
+ eventLoopGroup: LoopResources,
+ doOnResponse: ((HttpClientResponse, HttpInterceptor.Context?) -> Unit)?,
+ ): ReactorHostHttpClient
+ }
+
+ companion object : Factory {
+ private val LOGGER = getLogger(this::class.java)
+ private const val ORIGINID_CONTEXT_KEY = "styx.originid"
+ private const val X_HTTP2_STREAM_ID = "x-http2-stream-id"
+ private const val DNS_MAX_CACHE_TIME_TO_LIVE_SECONDS = 30L
+ private const val DNS_NEGATIVE_TIME_TO_LIVE_SECONDS = 5L
+
+ override fun create(
+ origin: Origin,
+ connectionPool: ReactorConnectionPool,
+ httpConfig: HttpConfig,
+ h2SslProvider: Consumer?,
+ h11SslHandler: Consumer?,
+ responseTimeoutMillis: Int,
+ httpRequestMessageLogger: HttpRequestMessageLogger?,
+ originStatsFactory: OriginStatsFactory,
+ metrics: CentralisedMetrics,
+ eventLoopGroup: LoopResources,
+ doOnResponse: ((HttpClientResponse, HttpInterceptor.Context?) -> Unit)?,
+ ): ReactorHostHttpClient =
+ ReactorHostHttpClient(
+ origin,
+ connectionPool,
+ httpConfig,
+ h2SslProvider,
+ h11SslHandler,
+ responseTimeoutMillis,
+ httpRequestMessageLogger,
+ originStatsFactory,
+ metrics,
+ eventLoopGroup,
+ doOnResponse,
+ )
+ }
+}
diff --git a/components/client/src/main/kotlin/com/hotels/styx/client/ReactorOriginsInventory.kt b/components/client/src/main/kotlin/com/hotels/styx/client/ReactorOriginsInventory.kt
new file mode 100644
index 000000000..479bb66f7
--- /dev/null
+++ b/components/client/src/main/kotlin/com/hotels/styx/client/ReactorOriginsInventory.kt
@@ -0,0 +1,252 @@
+/*
+ Copyright (C) 2013-2024 Expedia Inc.
+
+ Licensed under the Apache License, Version 2.0 (the "License");
+ you may not use this file except in compliance with the License.
+ You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+ Unless required by applicable law or agreed to in writing, software
+ distributed under the License is distributed on an "AS IS" BASIS,
+ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ See the License for the specific language governing permissions and
+ limitations under the License.
+ */
+package com.hotels.styx.client
+
+import com.google.common.eventbus.EventBus
+import com.hotels.styx.api.HttpInterceptor
+import com.hotels.styx.api.Id
+import com.hotels.styx.api.extension.Origin
+import com.hotels.styx.api.extension.service.BackendService
+import com.hotels.styx.api.extension.service.TlsSettings
+import com.hotels.styx.client.HttpConfig.defaultHttpConfig
+import com.hotels.styx.client.healthcheck.OriginHealthStatusMonitor
+import com.hotels.styx.client.healthcheck.monitors.NoOriginHealthStatusMonitor
+import com.hotels.styx.common.QueueDrainingEventProcessor
+import com.hotels.styx.common.StyxFutures.await
+import com.hotels.styx.common.logging.HttpRequestMessageLogger
+import com.hotels.styx.metrics.CentralisedMetrics
+import io.netty.channel.Channel
+import io.netty.handler.ssl.SslContext
+import reactor.netty.http.client.HttpClientResponse
+import reactor.netty.resources.LoopResources
+import reactor.netty.tcp.SslProvider
+import reactor.netty.tcp.SslProvider.SslContextSpec
+import java.net.InetSocketAddress
+import java.util.function.Consumer
+import javax.annotation.concurrent.ThreadSafe
+import javax.net.ssl.SNIHostName
+
+/**
+ * A Reactor version inventory of the origins configured for a single application
+ */
+@ThreadSafe
+class ReactorOriginsInventory(
+ eventBus: EventBus,
+ appId: Id,
+ originHealthStatusMonitor: OriginHealthStatusMonitor,
+ private val metrics: CentralisedMetrics,
+ private val connectionPool: ReactorConnectionPool,
+ private val hostClientFactory: ReactorHostHttpClient.Factory,
+ private val httpConfig: HttpConfig,
+ private val tlsSettings: TlsSettings?,
+ private val responseTimeoutMillis: Int,
+ private val httpRequestMessageLogger: HttpRequestMessageLogger?,
+ private val originStatsFactory: OriginStatsFactory,
+ private val eventLoopGroup: LoopResources,
+ private val sslContext: SslContext?,
+ private val doOnResponse: ((HttpClientResponse, HttpInterceptor.Context?) -> Unit)?,
+) : OriginsInventory(eventBus, originHealthStatusMonitor, appId, metrics) {
+ override val eventQueue: QueueDrainingEventProcessor = QueueDrainingEventProcessor(this, true)
+
+ override fun Origin.toMonitoredOrigin(): OriginsInventory.MonitoredOrigin = MonitoredOrigin(this)
+
+ override fun registerEvent() {
+ eventBus.register(this)
+ }
+
+ override fun addOriginStatusListener() {
+ originHealthStatusMonitor.addOriginStatusListener(this)
+ }
+
+ /**
+ * A builder for [ReactorOriginsInventory].
+ */
+ class Builder(val appId: Id) {
+ private var originHealthMonitor: OriginHealthStatusMonitor = NoOriginHealthStatusMonitor()
+ private var metrics: CentralisedMetrics? = null
+ private var eventBus = EventBus()
+ private var connectionPool: ReactorConnectionPool = ReactorConnectionPool()
+ private var hostClientFactory: ReactorHostHttpClient.Factory = ReactorHostHttpClient
+ private var initialOrigins: Set = emptySet()
+ private var httpConfig: HttpConfig = defaultHttpConfig()
+ private var tlsSettings: TlsSettings? = null
+ private var responseTimeoutMillis: Int = 60_000
+ private var httpRequestMessageLogger: HttpRequestMessageLogger? = null
+ private var originStatsFactory: OriginStatsFactory? = null
+ private var eventLoopGroup: LoopResources = LoopResources.create("$appId-client")
+ private var sslContext: SslContext? = null
+ private var doOnResponse: ((HttpClientResponse, HttpInterceptor.Context?) -> Unit)? = null
+
+ fun metrics(metrics: CentralisedMetrics) =
+ apply {
+ this.metrics = metrics
+ }
+
+ fun connectionPool(connectionPool: ReactorConnectionPool) =
+ apply {
+ this.connectionPool = connectionPool
+ }
+
+ fun hostClientFactory(hostClientFactory: ReactorHostHttpClient.Factory) =
+ apply {
+ this.hostClientFactory = hostClientFactory
+ }
+
+ fun originHealthMonitor(originHealthMonitor: OriginHealthStatusMonitor) =
+ apply {
+ this.originHealthMonitor = originHealthMonitor
+ }
+
+ fun eventBus(eventBus: EventBus) =
+ apply {
+ this.eventBus = eventBus
+ }
+
+ fun initialOrigins(origins: Set) =
+ apply {
+ initialOrigins = origins.toSet()
+ }
+
+ fun httpConfig(httpConfig: HttpConfig) =
+ apply {
+ this.httpConfig = httpConfig
+ }
+
+ fun tlsSettings(tlsSettings: TlsSettings?) =
+ apply {
+ this.tlsSettings = tlsSettings
+ }
+
+ fun responseTimeoutMillis(responseTimeoutMillis: Int) =
+ apply {
+ this.responseTimeoutMillis = responseTimeoutMillis
+ }
+
+ fun httpRequestMessageLogger(httpRequestMessageLogger: HttpRequestMessageLogger?) =
+ apply {
+ this.httpRequestMessageLogger = httpRequestMessageLogger
+ }
+
+ fun originStatsFactory(originStatsFactory: OriginStatsFactory) =
+ apply {
+ this.originStatsFactory = originStatsFactory
+ }
+
+ fun eventLoopGroup(eventLoopGroup: LoopResources) =
+ apply {
+ this.eventLoopGroup = eventLoopGroup
+ }
+
+ fun sslContext(sslContext: SslContext?) =
+ apply {
+ this.sslContext = sslContext
+ }
+
+ fun doOnResponse(doOnResponse: ((HttpClientResponse, HttpInterceptor.Context?) -> Unit)?) =
+ apply {
+ this.doOnResponse = doOnResponse
+ }
+
+ fun build(): ReactorOriginsInventory {
+ await(originHealthMonitor.start())
+ val originsInventory =
+ ReactorOriginsInventory(
+ eventBus,
+ appId,
+ originHealthMonitor,
+ checkNotNull(metrics) { "metrics is required" },
+ connectionPool,
+ hostClientFactory,
+ httpConfig,
+ tlsSettings,
+ responseTimeoutMillis,
+ httpRequestMessageLogger,
+ originStatsFactory ?: OriginStatsFactory.CachingOriginStatsFactory(metrics),
+ eventLoopGroup,
+ sslContext,
+ doOnResponse,
+ )
+ initialOrigins.takeIf { it.isNotEmpty() }?.apply {
+ originsInventory.setOrigins(initialOrigins)
+ }
+ return originsInventory
+ }
+ }
+
+ inner class MonitoredOrigin(origin: Origin) : OriginsInventory.MonitoredOrigin(origin) {
+ // SNI info must be passed while using HTTP/2
+ private val h2SslProvider: Consumer? =
+ if (sslContext != null && connectionPool.isHttp2()) {
+ Consumer { sslContextSpec ->
+ sslContextSpec.sslContext(sslContext)
+ .serverNames(SNIHostName(tlsSettings?.sniHost ?: origin.host()))
+ }
+ } else {
+ null
+ }
+
+ // By default, reactor-netty sends the remote hostname as SNI server name.
+ // This is a workaround to avoid reactor-netty injecting SNI host if tlsSettings.sendSni() is false.
+ private val h11SslHandler: Consumer? =
+ if (sslContext != null && !connectionPool.isHttp2()) {
+ Consumer { channel ->
+ val sslProviderBuilder = SslProvider.builder().sslContext(sslContext)
+
+ if (tlsSettings?.sendSni() == true) {
+ sslProviderBuilder
+ .serverNames(SNIHostName(tlsSettings.sniHost ?: origin.host()))
+ .build()
+ .addSslHandler(channel, InetSocketAddress(origin.host(), origin.port()), false)
+ } else {
+ sslProviderBuilder
+ .build()
+ .addSslHandler(channel, null, false)
+ }
+ }
+ } else {
+ null
+ }
+
+ override val hostClient: ReactorHostHttpClient =
+ hostClientFactory.create(
+ origin,
+ connectionPool,
+ httpConfig,
+ h2SslProvider,
+ h11SslHandler,
+ responseTimeoutMillis,
+ httpRequestMessageLogger,
+ originStatsFactory,
+ metrics,
+ eventLoopGroup,
+ doOnResponse,
+ )
+ }
+
+ companion object {
+ @JvmStatic
+ fun newOriginsInventoryBuilder(appId: Id): Builder = Builder(appId)
+
+ @JvmStatic
+ fun newOriginsInventoryBuilder(
+ metrics: CentralisedMetrics,
+ backendService: BackendService,
+ ): Builder =
+ Builder(backendService.id())
+ .metrics(metrics)
+ .initialOrigins(backendService.origins())
+ }
+}
diff --git a/components/client/src/test/unit/kotlin/com/hotels/styx/client/ReactorBackendServiceClientTest.kt b/components/client/src/test/unit/kotlin/com/hotels/styx/client/ReactorBackendServiceClientTest.kt
new file mode 100644
index 000000000..f83aa0de7
--- /dev/null
+++ b/components/client/src/test/unit/kotlin/com/hotels/styx/client/ReactorBackendServiceClientTest.kt
@@ -0,0 +1,808 @@
+/*
+ Copyright (C) 2013-2024 Expedia Inc.
+
+ Licensed under the Apache License, Version 2.0 (the "License");
+ you may not use this file except in compliance with the License.
+ You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+ Unless required by applicable law or agreed to in writing, software
+ distributed under the License is distributed on an "AS IS" BASIS,
+ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ See the License for the specific language governing permissions and
+ limitations under the License.
+ */
+package com.hotels.styx.client
+
+import com.google.common.net.HostAndPort
+import com.hotels.styx.api.Eventual
+import com.hotels.styx.api.HttpHandler
+import com.hotels.styx.api.HttpHeaderNames.CHUNKED
+import com.hotels.styx.api.HttpHeaderNames.CONTENT_LENGTH
+import com.hotels.styx.api.HttpHeaderNames.HOST
+import com.hotels.styx.api.HttpHeaderNames.TRANSFER_ENCODING
+import com.hotels.styx.api.HttpInterceptor
+import com.hotels.styx.api.HttpResponseStatus.BAD_REQUEST
+import com.hotels.styx.api.HttpResponseStatus.INTERNAL_SERVER_ERROR
+import com.hotels.styx.api.HttpResponseStatus.NOT_IMPLEMENTED
+import com.hotels.styx.api.HttpResponseStatus.OK
+import com.hotels.styx.api.HttpResponseStatus.UNAUTHORIZED
+import com.hotels.styx.api.Id.GENERIC_APP
+import com.hotels.styx.api.LiveHttpRequest
+import com.hotels.styx.api.LiveHttpRequest.get
+import com.hotels.styx.api.LiveHttpResponse
+import com.hotels.styx.api.LiveHttpResponse.response
+import com.hotels.styx.api.MeterRegistry
+import com.hotels.styx.api.MicrometerRegistry
+import com.hotels.styx.api.RequestCookie.requestCookie
+import com.hotels.styx.api.exceptions.NoAvailableHostsException
+import com.hotels.styx.api.exceptions.OriginUnreachableException
+import com.hotels.styx.api.extension.Origin
+import com.hotels.styx.api.extension.Origin.newOriginBuilder
+import com.hotels.styx.api.extension.RemoteHost
+import com.hotels.styx.api.extension.RemoteHost.remoteHost
+import com.hotels.styx.api.extension.loadbalancing.spi.LoadBalancer
+import com.hotels.styx.api.extension.retrypolicy.spi.RetryPolicy
+import com.hotels.styx.api.extension.service.BackendService
+import com.hotels.styx.api.extension.service.StickySessionConfig
+import com.hotels.styx.client.retry.RetryNTimes
+import com.hotels.styx.metrics.CentralisedMetrics
+import io.kotest.core.spec.style.StringSpec
+import io.kotest.matchers.shouldBe
+import io.kotest.matchers.shouldNotBe
+import io.kotest.matchers.types.shouldBeInstanceOf
+import io.micrometer.core.instrument.simple.SimpleMeterRegistry
+import io.mockk.Called
+import io.mockk.CapturingSlot
+import io.mockk.every
+import io.mockk.mockk
+import io.mockk.slot
+import io.mockk.verify
+import io.mockk.verifyOrder
+import org.reactivestreams.Publisher
+import reactor.core.publisher.Mono
+import reactor.test.StepVerifier
+import java.util.Optional
+import kotlin.jvm.optionals.getOrNull
+
+class ReactorBackendServiceClientTest : StringSpec() {
+ private val context: HttpInterceptor.Context = mockk()
+ private val backendService: BackendService =
+ backendBuilderWithOrigins(SOME_ORIGIN.port())
+ .stickySessionConfig(STICKY_SESSION_CONFIG)
+ .build()
+ private lateinit var meterRegistry: MeterRegistry
+ private lateinit var metrics: CentralisedMetrics
+
+ init {
+ beforeTest {
+ meterRegistry = MicrometerRegistry(SimpleMeterRegistry())
+ metrics = CentralisedMetrics(meterRegistry)
+ }
+
+ "sendRequest routes the request to host selected by load balancer" {
+ val hostClient = mockHostClient(Mono.just(response(OK).build()))
+
+ val backendServiceClient =
+ ReactorBackendServiceClient(
+ id = backendService.id(),
+ rewriteRuleset = RewriteRuleset(backendService.rewrites()),
+ originsRestrictionCookieName = null,
+ stickySessionConfig = backendService.stickySessionConfig(),
+ originIdHeader = StyxHeaderConfig.ORIGIN_ID_DEFAULT,
+ loadBalancer =
+ mockLoadBalancer(
+ Optional.of(
+ remoteHost(
+ SOME_ORIGIN,
+ toHandler(hostClient),
+ hostClient,
+ ),
+ ),
+ ),
+ retryPolicy = DEFAULT_RETRY_POLICY,
+ metrics = metrics,
+ overrideHostHeader = backendService.isOverrideHostHeader(),
+ )
+
+ val response = Mono.from(backendServiceClient.sendRequest(SOME_REQ, context)).block()
+
+ response!!.status() shouldBe OK
+ verify { hostClient.sendRequest(SOME_REQ, any()) }
+ }
+
+ "constructs retry context when load balancer does not find available origins" {
+ val retryContextSlot = slot()
+ val loadBalancerSlot = slot()
+ val lbPreferencesSlot = slot()
+ val retryPolicy =
+ mockRetryPolicy(
+ retryContextSlot,
+ loadBalancerSlot,
+ lbPreferencesSlot,
+ true,
+ true,
+ true,
+ )
+ val hostClient = mockHostClient(Mono.just(response(OK).build()))
+
+ val backendServiceClient =
+ ReactorBackendServiceClient(
+ id = backendService.id(),
+ rewriteRuleset = RewriteRuleset(backendService.rewrites()),
+ originsRestrictionCookieName = null,
+ stickySessionConfig = backendService.stickySessionConfig(),
+ originIdHeader = StyxHeaderConfig.ORIGIN_ID_DEFAULT,
+ loadBalancer =
+ mockLoadBalancer(
+ Optional.empty(),
+ Optional.of(remoteHost(SOME_ORIGIN, toHandler(hostClient), hostClient)),
+ ),
+ retryPolicy = retryPolicy,
+ metrics = metrics,
+ overrideHostHeader = backendService.isOverrideHostHeader(),
+ )
+
+ val response = Mono.from(backendServiceClient.sendRequest(SOME_REQ, context)).block()
+
+ retryContextSlot.isCaptured shouldBe true
+ retryContextSlot.captured.appId() shouldBe backendService.id()
+ retryContextSlot.captured.currentRetryCount() shouldBe 1
+ retryContextSlot.captured.lastException() shouldBe Optional.empty()
+
+ loadBalancerSlot.isCaptured shouldBe true
+ loadBalancerSlot.captured shouldNotBe null
+
+ lbPreferencesSlot.isCaptured shouldBe true
+ lbPreferencesSlot.captured.avoidOrigins() shouldBe emptyList()
+ lbPreferencesSlot.captured.preferredOrigins() shouldBe Optional.empty()
+
+ response.status() shouldBe OK
+ }
+
+ "retries when retry policy tells to retry" {
+ val retryContextSlot = slot()
+ val loadBalancerSlot = slot()
+ val lbPreferencesSlot = slot()
+ val retryPolicy =
+ mockRetryPolicy(
+ retryContextSlot,
+ loadBalancerSlot,
+ lbPreferencesSlot,
+ true,
+ false,
+ )
+
+ val hostClient1 =
+ mockHostClient(
+ Mono.error(
+ OriginUnreachableException(ORIGIN_1, RuntimeException("An error occurred")),
+ ),
+ )
+ val hostClient2 = mockHostClient(Mono.just(response(OK).build()))
+
+ val backendServiceClient =
+ ReactorBackendServiceClient(
+ id = backendService.id(),
+ rewriteRuleset = RewriteRuleset(backendService.rewrites()),
+ originsRestrictionCookieName = null,
+ stickySessionConfig = backendService.stickySessionConfig(),
+ originIdHeader = StyxHeaderConfig.ORIGIN_ID_DEFAULT,
+ loadBalancer =
+ mockLoadBalancer(
+ Optional.of(remoteHost(ORIGIN_1, toHandler(hostClient1), hostClient1)),
+ Optional.of(remoteHost(ORIGIN_2, toHandler(hostClient2), hostClient2)),
+ ),
+ retryPolicy = retryPolicy,
+ metrics = metrics,
+ overrideHostHeader = backendService.isOverrideHostHeader(),
+ )
+
+ val response = Mono.from(backendServiceClient.sendRequest(SOME_REQ, context)).block()
+
+ retryContextSlot.isCaptured shouldBe true
+ retryContextSlot.captured.appId() shouldBe backendService.id()
+ retryContextSlot.captured.currentRetryCount() shouldBe 1
+ retryContextSlot.captured.lastException().getOrNull().shouldBeInstanceOf()
+
+ loadBalancerSlot.isCaptured shouldBe true
+ loadBalancerSlot.captured shouldNotBe null
+
+ lbPreferencesSlot.isCaptured shouldBe true
+ lbPreferencesSlot.captured.avoidOrigins() shouldBe listOf(ORIGIN_1)
+ lbPreferencesSlot.captured.preferredOrigins() shouldBe Optional.empty()
+
+ response.status() shouldBe OK
+
+ verifyOrder {
+ hostClient1.sendRequest(SOME_REQ, any())
+ hostClient2.sendRequest(SOME_REQ, any())
+ }
+ }
+
+ "stops retries when retry policy tells to stop" {
+ val hostClient1 =
+ mockHostClient(
+ Mono.error(
+ OriginUnreachableException(ORIGIN_1, RuntimeException("An error occurred")),
+ ),
+ )
+ val hostClient2 =
+ mockHostClient(
+ Mono.error(
+ OriginUnreachableException(ORIGIN_2, RuntimeException("An error occurred")),
+ ),
+ )
+ val hostClient3 =
+ mockHostClient(
+ Mono.error(
+ OriginUnreachableException(ORIGIN_2, RuntimeException("An error occurred")),
+ ),
+ )
+
+ val backendServiceClient =
+ ReactorBackendServiceClient(
+ id = backendService.id(),
+ rewriteRuleset = RewriteRuleset(backendService.rewrites()),
+ originsRestrictionCookieName = null,
+ stickySessionConfig = backendService.stickySessionConfig(),
+ originIdHeader = StyxHeaderConfig.ORIGIN_ID_DEFAULT,
+ loadBalancer =
+ mockLoadBalancer(
+ Optional.of(remoteHost(ORIGIN_1, toHandler(hostClient1), hostClient1)),
+ Optional.of(remoteHost(ORIGIN_2, toHandler(hostClient2), hostClient2)),
+ Optional.of(remoteHost(ORIGIN_3, toHandler(hostClient3), hostClient3)),
+ ),
+ retryPolicy = mockRetryPolicy(true, false),
+ metrics = metrics,
+ overrideHostHeader = backendService.isOverrideHostHeader(),
+ )
+
+ StepVerifier.create(backendServiceClient.sendRequest(SOME_REQ, context))
+ .verifyError(OriginUnreachableException::class.java)
+
+ verifyOrder {
+ hostClient1.sendRequest(SOME_REQ, any())
+ hostClient2.sendRequest(SOME_REQ, any())
+ hostClient3.sendRequest(SOME_REQ, any()) wasNot Called
+ }
+ }
+
+ "retries at most 3 times" {
+ val hostClient1 =
+ mockHostClient(
+ Mono.error(
+ OriginUnreachableException(ORIGIN_1, RuntimeException("An error occurred")),
+ ),
+ )
+ val hostClient2 =
+ mockHostClient(
+ Mono.error(
+ OriginUnreachableException(ORIGIN_2, RuntimeException("An error occurred")),
+ ),
+ )
+ val hostClient3 =
+ mockHostClient(
+ Mono.error(
+ OriginUnreachableException(ORIGIN_3, RuntimeException("An error occurred")),
+ ),
+ )
+ val hostClient4 =
+ mockHostClient(
+ Mono.error(
+ OriginUnreachableException(ORIGIN_4, RuntimeException("An error occurred")),
+ ),
+ )
+
+ val backendServiceClient =
+ ReactorBackendServiceClient(
+ id = backendService.id(),
+ rewriteRuleset = RewriteRuleset(backendService.rewrites()),
+ originsRestrictionCookieName = null,
+ stickySessionConfig = backendService.stickySessionConfig(),
+ originIdHeader = StyxHeaderConfig.ORIGIN_ID_DEFAULT,
+ loadBalancer =
+ mockLoadBalancer(
+ Optional.of(remoteHost(ORIGIN_1, toHandler(hostClient1), hostClient1)),
+ Optional.of(remoteHost(ORIGIN_2, toHandler(hostClient2), hostClient2)),
+ Optional.of(remoteHost(ORIGIN_3, toHandler(hostClient3), hostClient3)),
+ Optional.of(remoteHost(ORIGIN_4, toHandler(hostClient4), hostClient4)),
+ ),
+ retryPolicy = mockRetryPolicy(true, true, true, true),
+ metrics = metrics,
+ overrideHostHeader = backendService.isOverrideHostHeader(),
+ )
+
+ StepVerifier.create(backendServiceClient.sendRequest(SOME_REQ, context))
+ .verifyError(NoAvailableHostsException::class.java)
+
+ verifyOrder {
+ hostClient1.sendRequest(SOME_REQ, any())
+ hostClient2.sendRequest(SOME_REQ, any())
+ hostClient3.sendRequest(SOME_REQ, any())
+ hostClient4.sendRequest(SOME_REQ, any()) wasNot Called
+ }
+ }
+
+ "increments response status metrics for bad response" {
+ val hostClient = mockHostClient(Mono.just(response(BAD_REQUEST).build()))
+
+ val backendServiceClient =
+ ReactorBackendServiceClient(
+ id = backendService.id(),
+ rewriteRuleset = RewriteRuleset(backendService.rewrites()),
+ originsRestrictionCookieName = null,
+ stickySessionConfig = backendService.stickySessionConfig(),
+ originIdHeader = StyxHeaderConfig.ORIGIN_ID_DEFAULT,
+ loadBalancer =
+ mockLoadBalancer(
+ Optional.of(remoteHost(SOME_ORIGIN, toHandler(hostClient), hostClient)),
+ ),
+ retryPolicy = DEFAULT_RETRY_POLICY,
+ metrics = metrics,
+ overrideHostHeader = backendService.isOverrideHostHeader(),
+ )
+
+ val response = Mono.from(backendServiceClient.sendRequest(SOME_REQ, context)).block()
+
+ response.status() shouldBe BAD_REQUEST
+ verify { hostClient.sendRequest(SOME_REQ, any()) }
+ meterRegistry.get("proxy.client.responseCode.errorStatus").tag("statusCode", "400").counter() shouldNotBe null
+ }
+
+ "increments response status metrics for 401" {
+ val hostClient = mockHostClient(Mono.just(response(UNAUTHORIZED).build()))
+
+ val backendServiceClient =
+ ReactorBackendServiceClient(
+ id = backendService.id(),
+ rewriteRuleset = RewriteRuleset(backendService.rewrites()),
+ originsRestrictionCookieName = null,
+ stickySessionConfig = backendService.stickySessionConfig(),
+ originIdHeader = StyxHeaderConfig.ORIGIN_ID_DEFAULT,
+ loadBalancer =
+ mockLoadBalancer(
+ Optional.of(remoteHost(SOME_ORIGIN, toHandler(hostClient), hostClient)),
+ ),
+ retryPolicy = DEFAULT_RETRY_POLICY,
+ metrics = metrics,
+ overrideHostHeader = backendService.isOverrideHostHeader(),
+ )
+
+ val response = Mono.from(backendServiceClient.sendRequest(SOME_REQ, context)).block()
+
+ response.status() shouldBe UNAUTHORIZED
+ verify { hostClient.sendRequest(SOME_REQ, any()) }
+ meterRegistry.get("proxy.client.responseCode.errorStatus").tag("statusCode", "401").counter() shouldNotBe null
+ }
+
+ "increments response status metrics for 500" {
+ val hostClient = mockHostClient(Mono.just(response(INTERNAL_SERVER_ERROR).build()))
+
+ val backendServiceClient =
+ ReactorBackendServiceClient(
+ id = backendService.id(),
+ rewriteRuleset = RewriteRuleset(backendService.rewrites()),
+ originsRestrictionCookieName = null,
+ stickySessionConfig = backendService.stickySessionConfig(),
+ originIdHeader = StyxHeaderConfig.ORIGIN_ID_DEFAULT,
+ loadBalancer =
+ mockLoadBalancer(
+ Optional.of(remoteHost(SOME_ORIGIN, toHandler(hostClient), hostClient)),
+ ),
+ retryPolicy = DEFAULT_RETRY_POLICY,
+ metrics = metrics,
+ overrideHostHeader = backendService.isOverrideHostHeader(),
+ )
+
+ val response = Mono.from(backendServiceClient.sendRequest(SOME_REQ, context)).block()
+
+ response.status() shouldBe INTERNAL_SERVER_ERROR
+ verify { hostClient.sendRequest(SOME_REQ, any()) }
+ meterRegistry.get("proxy.client.responseCode.errorStatus").tag("statusCode", "500").counter() shouldNotBe null
+ }
+
+ "increments response status metrics for 501" {
+ val hostClient = mockHostClient(Mono.just(response(NOT_IMPLEMENTED).build()))
+
+ val backendServiceClient =
+ ReactorBackendServiceClient(
+ id = backendService.id(),
+ rewriteRuleset = RewriteRuleset(backendService.rewrites()),
+ originsRestrictionCookieName = null,
+ stickySessionConfig = backendService.stickySessionConfig(),
+ originIdHeader = StyxHeaderConfig.ORIGIN_ID_DEFAULT,
+ loadBalancer =
+ mockLoadBalancer(
+ Optional.of(remoteHost(SOME_ORIGIN, toHandler(hostClient), hostClient)),
+ ),
+ retryPolicy = DEFAULT_RETRY_POLICY,
+ metrics = metrics,
+ overrideHostHeader = backendService.isOverrideHostHeader(),
+ )
+
+ val response = Mono.from(backendServiceClient.sendRequest(SOME_REQ, context)).block()
+
+ response.status() shouldBe NOT_IMPLEMENTED
+ verify { hostClient.sendRequest(SOME_REQ, any()) }
+ meterRegistry.get("proxy.client.responseCode.errorStatus").tag("statusCode", "501").counter() shouldNotBe null
+ }
+
+ "removes bad content length" {
+ val hostClient =
+ mockHostClient(
+ Mono.just(
+ response(OK)
+ .addHeader(CONTENT_LENGTH, 50)
+ .addHeader(TRANSFER_ENCODING, CHUNKED)
+ .build(),
+ ),
+ )
+
+ val backendServiceClient =
+ ReactorBackendServiceClient(
+ id = backendService.id(),
+ rewriteRuleset = RewriteRuleset(backendService.rewrites()),
+ originsRestrictionCookieName = null,
+ stickySessionConfig = backendService.stickySessionConfig(),
+ originIdHeader = StyxHeaderConfig.ORIGIN_ID_DEFAULT,
+ loadBalancer =
+ mockLoadBalancer(
+ Optional.of(remoteHost(SOME_ORIGIN, toHandler(hostClient), hostClient)),
+ ),
+ retryPolicy = DEFAULT_RETRY_POLICY,
+ metrics = metrics,
+ overrideHostHeader = backendService.isOverrideHostHeader(),
+ )
+
+ val response = Mono.from(backendServiceClient.sendRequest(SOME_REQ, context)).block()
+
+ response.status() shouldBe OK
+ response.contentLength().isPresent shouldBe false
+ response.header(TRANSFER_ENCODING).get() shouldBe "chunked"
+ }
+
+ "prefers sticky origins" {
+ val lbPreferencesSlot = slot()
+ val origin = originWithId("localhost:234", "App-X", "Origin-Y")
+ val hostClient = mockHostClient(Mono.just(response(OK).build()))
+
+ val backendServiceClient =
+ ReactorBackendServiceClient(
+ id = backendService.id(),
+ rewriteRuleset = RewriteRuleset(backendService.rewrites()),
+ originsRestrictionCookieName = null,
+ stickySessionConfig = backendService.stickySessionConfig(),
+ originIdHeader = StyxHeaderConfig.ORIGIN_ID_DEFAULT,
+ loadBalancer =
+ mockLoadBalancer(
+ lbPreferencesSlot,
+ Optional.of(remoteHost(origin, toHandler(hostClient), hostClient)),
+ ),
+ retryPolicy = DEFAULT_RETRY_POLICY,
+ metrics = metrics,
+ overrideHostHeader = backendService.isOverrideHostHeader(),
+ )
+
+ val response =
+ Mono.from(
+ backendServiceClient
+ .sendRequest(
+ get("/foo")
+ .cookies(requestCookie("styx_origin_$GENERIC_APP", "Origin-Y"))
+ .build(),
+ context,
+ ),
+ ).block()
+
+ response.status() shouldBe OK
+
+ lbPreferencesSlot.isCaptured shouldBe true
+ lbPreferencesSlot.captured.preferredOrigins() shouldBe Optional.of("Origin-Y")
+ }
+
+ "prefers restricted origins" {
+ val lbPreferencesSlot = slot()
+ val origin = originWithId("localhost:234", "App-X", "Origin-Y")
+ val hostClient = mockHostClient(Mono.just(response(OK).build()))
+
+ val backendServiceClient =
+ ReactorBackendServiceClient(
+ id = backendService.id(),
+ rewriteRuleset = RewriteRuleset(backendService.rewrites()),
+ originsRestrictionCookieName = "restrictedOrigin",
+ stickySessionConfig = backendService.stickySessionConfig(),
+ originIdHeader = StyxHeaderConfig.ORIGIN_ID_DEFAULT,
+ loadBalancer =
+ mockLoadBalancer(
+ lbPreferencesSlot,
+ Optional.of(remoteHost(origin, toHandler(hostClient), hostClient)),
+ ),
+ retryPolicy = DEFAULT_RETRY_POLICY,
+ metrics = metrics,
+ overrideHostHeader = backendService.isOverrideHostHeader(),
+ )
+
+ val response =
+ Mono.from(
+ backendServiceClient
+ .sendRequest(
+ get("/foo")
+ .cookies(requestCookie("restrictedOrigin", "Origin-Y"))
+ .build(),
+ context,
+ ),
+ ).block()
+
+ response.status() shouldBe OK
+
+ lbPreferencesSlot.isCaptured shouldBe true
+ lbPreferencesSlot.captured.preferredOrigins() shouldBe Optional.of("Origin-Y")
+ }
+
+ "prefers restricted origins over sticky origins when both are configured" {
+ val lbPreferencesSlot = slot()
+ val origin = originWithId("localhost:234", "App-X", "Origin-Y")
+ val hostClient = mockHostClient(Mono.just(response(OK).build()))
+
+ val backendServiceClient =
+ ReactorBackendServiceClient(
+ id = backendService.id(),
+ rewriteRuleset = RewriteRuleset(backendService.rewrites()),
+ originsRestrictionCookieName = "restrictedOrigin",
+ stickySessionConfig = backendService.stickySessionConfig(),
+ originIdHeader = StyxHeaderConfig.ORIGIN_ID_DEFAULT,
+ loadBalancer =
+ mockLoadBalancer(
+ lbPreferencesSlot,
+ Optional.of(remoteHost(origin, toHandler(hostClient), hostClient)),
+ ),
+ retryPolicy = DEFAULT_RETRY_POLICY,
+ metrics = metrics,
+ overrideHostHeader = backendService.isOverrideHostHeader(),
+ )
+
+ val response =
+ Mono.from(
+ backendServiceClient
+ .sendRequest(
+ get("/foo")
+ .cookies(
+ requestCookie("restrictedOrigin", "Origin-Y"),
+ requestCookie("styx_origin_$GENERIC_APP", "Origin-X"),
+ )
+ .build(),
+ context,
+ ),
+ ).block()
+
+ response.status() shouldBe OK
+
+ lbPreferencesSlot.isCaptured shouldBe true
+ lbPreferencesSlot.captured.preferredOrigins() shouldBe Optional.of("Origin-Y")
+ }
+
+ "host header is not over written when overrideHostHeader is false" {
+ val hostClient = mockHostClient(Mono.just(response(OK).build()))
+ val origin = newOriginBuilder(INCOMING_HOSTNAME, 9090).applicationId("app").build()
+ val httpHandler: HttpHandler = mockk()
+
+ every { httpHandler.handle(any(), any()) } returns Eventual.of(TEST_RESPONSE)
+
+ val backendServiceClient =
+ ReactorBackendServiceClient(
+ id = backendService.id(),
+ rewriteRuleset = RewriteRuleset(backendService.rewrites()),
+ originsRestrictionCookieName = "someCookie",
+ stickySessionConfig = backendService.stickySessionConfig(),
+ originIdHeader = StyxHeaderConfig.ORIGIN_ID_DEFAULT,
+ loadBalancer =
+ mockLoadBalancer(
+ Optional.of(remoteHost(origin, httpHandler, hostClient)),
+ ),
+ retryPolicy = DEFAULT_RETRY_POLICY,
+ metrics = metrics,
+ overrideHostHeader = false,
+ )
+
+ Mono.from(backendServiceClient.sendRequest(TEST_REQUEST, context)).block()
+
+ verify { httpHandler.handle(TEST_REQUEST, context) }
+ }
+
+ "host header is not over written when overrideHostHeader is true" {
+ val hostClient = mockHostClient(Mono.just(response(OK).build()))
+ val origin = newOriginBuilder(UPDATED_HOSTNAME, 9090).applicationId("app").build()
+ val httpHandler: HttpHandler = mockk()
+ val updatedRequestSlot = slot()
+
+ every { httpHandler.handle(capture(updatedRequestSlot), any()) } returns Eventual.of(TEST_RESPONSE)
+
+ val backendServiceClient =
+ ReactorBackendServiceClient(
+ id = backendService.id(),
+ rewriteRuleset = RewriteRuleset(backendService.rewrites()),
+ originsRestrictionCookieName = "someCookie",
+ stickySessionConfig = backendService.stickySessionConfig(),
+ originIdHeader = StyxHeaderConfig.ORIGIN_ID_DEFAULT,
+ loadBalancer =
+ mockLoadBalancer(
+ Optional.of(remoteHost(origin, httpHandler, hostClient)),
+ ),
+ retryPolicy = DEFAULT_RETRY_POLICY,
+ metrics = metrics,
+ overrideHostHeader = true,
+ )
+
+ Mono.from(backendServiceClient.sendRequest(TEST_REQUEST, context)).block()
+
+ updatedRequestSlot.captured.header(HOST).getOrNull() shouldBe UPDATED_HOSTNAME
+ }
+
+ "original requests is present in response when overrideHostHeader is false" {
+ val hostClient = mockHostClient(Mono.just(response(OK).build()))
+ val origin = newOriginBuilder(INCOMING_HOSTNAME, 9090).applicationId("app").build()
+ val httpHandler: HttpHandler = mockk()
+
+ every { httpHandler.handle(any(), any()) } returns Eventual.of(TEST_RESPONSE)
+
+ val backendServiceClient =
+ ReactorBackendServiceClient(
+ id = backendService.id(),
+ rewriteRuleset = RewriteRuleset(backendService.rewrites()),
+ originsRestrictionCookieName = "someCookie",
+ stickySessionConfig = backendService.stickySessionConfig(),
+ originIdHeader = StyxHeaderConfig.ORIGIN_ID_DEFAULT,
+ loadBalancer =
+ mockLoadBalancer(
+ Optional.of(remoteHost(origin, httpHandler, hostClient)),
+ ),
+ retryPolicy = DEFAULT_RETRY_POLICY,
+ metrics = metrics,
+ overrideHostHeader = false,
+ )
+
+ StepVerifier.create(backendServiceClient.sendRequest(TEST_REQUEST, context))
+ .expectNextMatches { response ->
+ response.request().header(HOST).map { it == INCOMING_HOSTNAME }.orElse(false)
+ }
+ .verifyComplete()
+ }
+
+ "original requests is present in response when overrideHostHeader is true" {
+ val hostClient = mockHostClient(Mono.just(response(OK).build()))
+ val origin = newOriginBuilder(UPDATED_HOSTNAME, 9090).applicationId("app").build()
+ val httpHandler: HttpHandler = mockk()
+
+ every { httpHandler.handle(any(), any()) } returns Eventual.of(TEST_RESPONSE)
+
+ val backendServiceClient =
+ ReactorBackendServiceClient(
+ id = backendService.id(),
+ rewriteRuleset = RewriteRuleset(backendService.rewrites()),
+ originsRestrictionCookieName = "someCookie",
+ stickySessionConfig = backendService.stickySessionConfig(),
+ originIdHeader = StyxHeaderConfig.ORIGIN_ID_DEFAULT,
+ loadBalancer =
+ mockLoadBalancer(
+ Optional.of(remoteHost(origin, httpHandler, hostClient)),
+ ),
+ retryPolicy = DEFAULT_RETRY_POLICY,
+ metrics = metrics,
+ overrideHostHeader = true,
+ )
+
+ StepVerifier.create(backendServiceClient.sendRequest(TEST_REQUEST, context))
+ .expectNextMatches { response ->
+ response.request().header(HOST).map { it == UPDATED_HOSTNAME }.orElse(false)
+ }
+ .verifyComplete()
+ }
+ }
+
+ private fun toHandler(hostClient: ReactorHostHttpClient): HttpHandler =
+ HttpHandler { request: LiveHttpRequest, ctx: HttpInterceptor.Context? ->
+ Eventual(
+ hostClient.sendRequest(
+ request,
+ ctx,
+ ),
+ )
+ }
+
+ private fun mockLoadBalancer(first: Optional): LoadBalancer {
+ val lbStategy: LoadBalancer = mockk()
+ every { lbStategy.choose(any()) } returns first
+ return lbStategy
+ }
+
+ private fun mockLoadBalancer(
+ first: Optional,
+ vararg remoteHostWrappers: Optional,
+ ): LoadBalancer {
+ val lbStategy: LoadBalancer = mockk()
+ every { lbStategy.choose(any()) } returns first andThenMany remoteHostWrappers.toList()
+ return lbStategy
+ }
+
+ private fun mockLoadBalancer(
+ lbPreferencesSlot: CapturingSlot,
+ first: Optional,
+ vararg remoteHostWrappers: Optional,
+ ): LoadBalancer {
+ val lbStategy: LoadBalancer = mockk()
+ every { lbStategy.choose(capture(lbPreferencesSlot)) } returns first andThenMany remoteHostWrappers.toList()
+ return lbStategy
+ }
+
+ private fun mockRetryPolicy(
+ first: Boolean,
+ vararg outcomes: Boolean,
+ ): RetryPolicy {
+ val retryPolicy: RetryPolicy = mockk()
+ val retryOutcome: RetryPolicy.Outcome = mockk()
+ every { retryOutcome.shouldRetry() } returns first andThenMany outcomes.toList()
+ val retryOutcomes: List = outcomes.map { retryOutcome }.toList()
+ every { retryPolicy.evaluate(any(), any(), any()) } returns retryOutcome andThenMany retryOutcomes
+ return retryPolicy
+ }
+
+ private fun mockRetryPolicy(
+ retryContextSlot: CapturingSlot,
+ loadBalancerSlot: CapturingSlot,
+ lbPreferencesSlot: CapturingSlot,
+ first: Boolean,
+ vararg outcomes: Boolean,
+ ): RetryPolicy {
+ val retryPolicy: RetryPolicy = mockk()
+ val retryOutcome: RetryPolicy.Outcome = mockk()
+ every { retryOutcome.shouldRetry() } returns first andThenMany outcomes.toList()
+ val retryOutcomes: List = outcomes.map { retryOutcome }.toList()
+ every {
+ retryPolicy.evaluate(capture(retryContextSlot), capture(loadBalancerSlot), capture(lbPreferencesSlot))
+ } returns retryOutcome andThenMany retryOutcomes
+ return retryPolicy
+ }
+
+ private fun mockHostClient(responsePublisher: Publisher): ReactorHostHttpClient {
+ val hostClient: ReactorHostHttpClient = mockk()
+ every { hostClient.sendRequest(any(), any()) } returns responsePublisher
+ return hostClient
+ }
+
+ private fun backendBuilderWithOrigins(originPort: Int): BackendService.Builder {
+ return BackendService.Builder()
+ .origins(newOriginBuilder("localhost", originPort).build())
+ }
+
+ private fun originWithId(
+ host: String,
+ appId: String,
+ originId: String,
+ ): Origin? {
+ val hap = HostAndPort.fromString(host)
+ return newOriginBuilder(hap.host, hap.port)
+ .applicationId(appId)
+ .id(originId)
+ .build()
+ }
+
+ companion object {
+ private val SOME_ORIGIN = newOriginBuilder("localhost", 9090).applicationId(GENERIC_APP).build()
+ private val SOME_REQ = get("/some-req").build()
+ private val DEFAULT_RETRY_POLICY = RetryNTimes(3)
+
+ private val ORIGIN_1 = newOriginBuilder("localhost", 9091).applicationId("app").id("app-01").build()
+ private val ORIGIN_2 = newOriginBuilder("localhost", 9092).applicationId("app").id("app-02").build()
+ private val ORIGIN_3 = newOriginBuilder("localhost", 9093).applicationId("app").id("app-03").build()
+ private val ORIGIN_4 = newOriginBuilder("localhost", 9094).applicationId("app").id("app-04").build()
+
+ private val STICKY_SESSION_CONFIG = StickySessionConfig.stickySessionDisabled()
+
+ private val TEST_REQUEST = get("/test").header(HOST, "www.expedia.com").build()
+ private val TEST_RESPONSE = response(OK).build()
+ private const val INCOMING_HOSTNAME = "www.expedia.com"
+ private const val UPDATED_HOSTNAME = "host.domain.com"
+ }
+}
diff --git a/components/client/src/test/unit/kotlin/com/hotels/styx/client/ReactorConnectionPoolTest.kt b/components/client/src/test/unit/kotlin/com/hotels/styx/client/ReactorConnectionPoolTest.kt
new file mode 100644
index 000000000..a0e296514
--- /dev/null
+++ b/components/client/src/test/unit/kotlin/com/hotels/styx/client/ReactorConnectionPoolTest.kt
@@ -0,0 +1,131 @@
+/*
+ Copyright (C) 2013-2024 Expedia Inc.
+
+ Licensed under the Apache License, Version 2.0 (the "License");
+ you may not use this file except in compliance with the License.
+ You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+ Unless required by applicable law or agreed to in writing, software
+ distributed under the License is distributed on an "AS IS" BASIS,
+ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ See the License for the specific language governing permissions and
+ limitations under the License.
+ */
+package com.hotels.styx.client
+
+import com.hotels.styx.api.HttpVersion
+import com.hotels.styx.api.extension.Origin
+import com.hotels.styx.api.extension.service.BackendService
+import com.hotels.styx.api.extension.service.ConnectionPoolSettings
+import com.hotels.styx.api.extension.service.Http2ConnectionPoolSettings
+import io.kotest.core.spec.style.StringSpec
+import io.kotest.matchers.shouldBe
+import reactor.netty.http.HttpProtocol.H2
+import reactor.netty.http.HttpProtocol.HTTP11
+import reactor.netty.http.client.Http2AllocationStrategy
+import reactor.netty.resources.ConnectionProvider
+import java.time.Duration
+import java.util.concurrent.TimeUnit.MILLISECONDS
+
+class ReactorConnectionPoolTest : StringSpec() {
+ private lateinit var connectionPool: ReactorConnectionPool
+
+ init {
+ "getConnectionProvider returns a ConnectionProvider for HTTP/1.1" {
+ connectionPool =
+ ReactorConnectionPool(
+ CONNECTION_POOL_SETTINGS,
+ HttpVersion.HTTP_1_1,
+ BACKEND_SERVICE,
+ )
+
+ val expected =
+ ConnectionProvider.builder(ORIGIN.id().toString())
+ .metrics(true)
+ .pendingAcquireTimeout(Duration.ofMillis(CONNECTION_POOL_SETTINGS.pendingConnectionTimeoutMillis().toLong()))
+ .disposeTimeout(Duration.ofMillis(BACKEND_SERVICE.responseTimeoutMillis().toLong()))
+ .maxIdleTime(Duration.ofMillis(BACKEND_SERVICE.responseTimeoutMillis().toLong()))
+ .maxLifeTime(Duration.ofSeconds(CONNECTION_POOL_SETTINGS.connectionExpirationSeconds()))
+ .maxConnections(CONNECTION_POOL_SETTINGS.maxConnectionsPerHost())
+ .pendingAcquireMaxCount(CONNECTION_POOL_SETTINGS.maxPendingConnectionsPerHost())
+ .lifo()
+ .build()
+
+ connectionPool.supportedHttpProtocols() shouldBe arrayOf(HTTP11)
+ connectionPool.pendingAcquireMaxCount shouldBe CONNECTION_POOL_SETTINGS.maxPendingConnectionsPerHost()
+ connectionPool.connectTimeoutMillis shouldBe CONNECTION_POOL_SETTINGS.connectTimeoutMillis()
+ connectionPool.maxConnections shouldBe CONNECTION_POOL_SETTINGS.maxConnectionsPerHost()
+ connectionPool.disposeTimeoutMillis shouldBe BACKEND_SERVICE.responseTimeoutMillis()
+ connectionPool.maxIdleTimeMillis shouldBe BACKEND_SERVICE.responseTimeoutMillis()
+ connectionPool.connectionExpirationSeconds shouldBe CONNECTION_POOL_SETTINGS.connectionExpirationSeconds()
+ assertConnectionProvider(connectionPool.getConnectionProvider(ORIGIN), expected)
+ }
+
+ "getConnectionProvider returns a ConnectionProvider for HTTP/2" {
+ connectionPool =
+ ReactorConnectionPool(
+ CONNECTION_POOL_SETTINGS,
+ HttpVersion.HTTP_2,
+ BACKEND_SERVICE,
+ )
+
+ val expected =
+ ConnectionProvider.builder(ORIGIN.id().toString())
+ .metrics(true)
+ .pendingAcquireTimeout(Duration.ofMillis(CONNECTION_POOL_SETTINGS.pendingConnectionTimeoutMillis().toLong()))
+ .disposeTimeout(Duration.ofMillis(BACKEND_SERVICE.responseTimeoutMillis().toLong()))
+ .maxIdleTime(Duration.ofMillis(BACKEND_SERVICE.responseTimeoutMillis().toLong()))
+ .maxLifeTime(Duration.ofSeconds(CONNECTION_POOL_SETTINGS.connectionExpirationSeconds()))
+ .maxConnections(CONNECTION_POOL_SETTINGS.maxConnectionsPerHost())
+ .pendingAcquireMaxCount(CONNECTION_POOL_SETTINGS.http2ConnectionPoolSettings().maxPendingStreamsPerHost!!)
+ .lifo()
+ .allocationStrategy(
+ Http2AllocationStrategy.builder()
+ .maxConcurrentStreams(CONNECTION_POOL_SETTINGS.http2ConnectionPoolSettings().maxStreamsPerConnection!!.toLong())
+ .maxConnections(CONNECTION_POOL_SETTINGS.http2ConnectionPoolSettings().maxConnections!!)
+ .minConnections(CONNECTION_POOL_SETTINGS.http2ConnectionPoolSettings().minConnections!!)
+ .build(),
+ )
+ .build()
+
+ connectionPool.supportedHttpProtocols() shouldBe arrayOf(H2, HTTP11)
+ connectionPool.pendingAcquireMaxCount shouldBe CONNECTION_POOL_SETTINGS.http2ConnectionPoolSettings().maxPendingStreamsPerHost
+ connectionPool.connectTimeoutMillis shouldBe CONNECTION_POOL_SETTINGS.connectTimeoutMillis()
+ connectionPool.maxConnections shouldBe CONNECTION_POOL_SETTINGS.maxConnectionsPerHost()
+ connectionPool.disposeTimeoutMillis shouldBe BACKEND_SERVICE.responseTimeoutMillis()
+ connectionPool.maxIdleTimeMillis shouldBe BACKEND_SERVICE.responseTimeoutMillis()
+ connectionPool.h2MaxConnections shouldBe CONNECTION_POOL_SETTINGS.http2ConnectionPoolSettings().maxConnections
+ connectionPool.h2MinConnections shouldBe CONNECTION_POOL_SETTINGS.http2ConnectionPoolSettings().minConnections
+ connectionPool.h2MaxConcurrentStreams shouldBe CONNECTION_POOL_SETTINGS.http2ConnectionPoolSettings().maxStreamsPerConnection
+ connectionPool.connectionExpirationSeconds shouldBe CONNECTION_POOL_SETTINGS.connectionExpirationSeconds()
+ assertConnectionProvider(connectionPool.getConnectionProvider(ORIGIN), expected)
+ }
+ }
+
+ private fun assertConnectionProvider(
+ actual: ConnectionProvider,
+ expected: ConnectionProvider,
+ ) {
+ actual.name() shouldBe expected.name()
+ actual.maxConnections() shouldBe expected.maxConnections()
+ actual.maxConnectionsPerHost() shouldBe expected.maxConnectionsPerHost()
+ }
+
+ companion object {
+ private val ORIGIN = Origin.newOriginBuilder("localhost", 886).build()
+ private val CONNECTION_POOL_SETTINGS: ConnectionPoolSettings =
+ ConnectionPoolSettings.Builder()
+ .connectTimeout(123, MILLISECONDS)
+ .pendingConnectionTimeout(543, MILLISECONDS)
+ .maxConnectionsPerHost(10)
+ .maxPendingConnectionsPerHost(87)
+ .http2ConnectionPoolSettings(Http2ConnectionPoolSettings(10, 4, 3, 6))
+ .build()
+ private val BACKEND_SERVICE: BackendService =
+ BackendService.newBackendServiceBuilder()
+ .responseTimeoutMillis(777)
+ .build()
+ }
+}
diff --git a/components/client/src/test/unit/kotlin/com/hotels/styx/client/ReactorHostHttpClientTest.kt b/components/client/src/test/unit/kotlin/com/hotels/styx/client/ReactorHostHttpClientTest.kt
new file mode 100644
index 000000000..28d782a76
--- /dev/null
+++ b/components/client/src/test/unit/kotlin/com/hotels/styx/client/ReactorHostHttpClientTest.kt
@@ -0,0 +1,744 @@
+/*
+ Copyright (C) 2013-2024 Expedia Inc.
+
+ Licensed under the Apache License, Version 2.0 (the "License");
+ you may not use this file except in compliance with the License.
+ You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+ Unless required by applicable law or agreed to in writing, software
+ distributed under the License is distributed on an "AS IS" BASIS,
+ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ See the License for the specific language governing permissions and
+ limitations under the License.
+ */
+package com.hotels.styx.client
+
+import com.hotels.styx.api.Buffers
+import com.hotels.styx.api.HttpHeaderNames.CHUNKED
+import com.hotels.styx.api.HttpHeaderNames.TRANSFER_ENCODING
+import com.hotels.styx.api.HttpInterceptor
+import com.hotels.styx.api.HttpRequest
+import com.hotels.styx.api.Id.GENERIC_APP
+import com.hotels.styx.api.MicrometerRegistry
+import com.hotels.styx.api.exceptions.ResponseTimeoutException
+import com.hotels.styx.api.extension.Origin
+import com.hotels.styx.api.extension.service.TlsSettings
+import com.hotels.styx.client.ssl.SslContextFactory
+import com.hotels.styx.common.logging.HttpRequestMessageLogger
+import com.hotels.styx.metrics.CentralisedMetrics
+import io.kotest.assertions.throwables.shouldNotThrowMessage
+import io.kotest.assertions.throwables.shouldThrowWithMessage
+import io.kotest.core.spec.style.StringSpec
+import io.kotest.matchers.shouldBe
+import io.micrometer.core.instrument.simple.SimpleMeterRegistry
+import io.mockk.clearAllMocks
+import io.mockk.every
+import io.mockk.mockk
+import io.mockk.verify
+import io.netty.channel.Channel
+import io.netty.handler.ssl.SslContextBuilder
+import io.netty.handler.ssl.util.InsecureTrustManagerFactory
+import kotlinx.coroutines.repackaged.net.bytebuddy.utility.RandomString
+import okhttp3.Protocol
+import okhttp3.mockwebserver.MockResponse
+import okhttp3.mockwebserver.MockWebServer
+import okhttp3.tls.HandshakeCertificates
+import okhttp3.tls.HeldCertificate
+import reactor.core.publisher.Mono
+import reactor.netty.http.Http2SslContextSpec
+import reactor.netty.http.HttpProtocol.H2
+import reactor.netty.http.HttpProtocol.HTTP11
+import reactor.netty.http.client.Http2AllocationStrategy
+import reactor.netty.resources.ConnectionProvider
+import reactor.netty.resources.LoopResources
+import reactor.netty.tcp.SslProvider
+import reactor.test.StepVerifier
+import java.nio.charset.StandardCharsets
+import java.time.Duration
+import java.util.concurrent.TimeUnit
+import java.util.concurrent.atomic.AtomicInteger
+import java.util.function.Consumer
+
+class ReactorHostHttpClientTest : StringSpec() {
+ private val context: HttpInterceptor.Context = mockk()
+ private val httpRequestMessageLogger: HttpRequestMessageLogger = mockk(relaxed = true)
+ private val connectionPool: ReactorConnectionPool = mockk(relaxed = true)
+ private lateinit var h2MockServerWithAlpn: MockWebServer
+ private lateinit var h2MockServerWithoutAlpn: MockWebServer
+ private lateinit var h11MockServer: MockWebServer
+ private lateinit var originStats: OriginStatsFactory
+ private lateinit var meterRegistry: MicrometerRegistry
+ private lateinit var metrics: CentralisedMetrics
+
+ init {
+ beforeSpec {
+ val localhostCertificate =
+ HeldCertificate.Builder()
+ .addSubjectAlternativeName(ORIGIN_1.host())
+ .addSubjectAlternativeName(ORIGIN_2.host())
+ .addSubjectAlternativeName(ORIGIN_3.host())
+ .build()
+ val serverCertificates =
+ HandshakeCertificates.Builder()
+ .heldCertificate(localhostCertificate)
+ .build()
+
+ h11MockServer = MockWebServer()
+ h11MockServer.useHttps(serverCertificates.sslSocketFactory(), false)
+ h11MockServer.protocols = listOf(Protocol.HTTP_1_1)
+ h11MockServer.start(ORIGIN_1.port())
+
+ h2MockServerWithAlpn = MockWebServer()
+ h2MockServerWithAlpn.useHttps(serverCertificates.sslSocketFactory(), false)
+ h2MockServerWithAlpn.protocols = listOf(Protocol.HTTP_2, Protocol.HTTP_1_1)
+ h2MockServerWithAlpn.start(ORIGIN_2.port())
+
+ h2MockServerWithoutAlpn = MockWebServer()
+ h2MockServerWithoutAlpn.useHttps(serverCertificates.sslSocketFactory(), false)
+ h2MockServerWithoutAlpn.protocols = listOf(Protocol.HTTP_2, Protocol.HTTP_1_1)
+ h2MockServerWithoutAlpn.protocolNegotiationEnabled = false
+ h2MockServerWithoutAlpn.start(ORIGIN_3.port())
+ }
+
+ beforeTest {
+ meterRegistry = MicrometerRegistry(SimpleMeterRegistry())
+ metrics = CentralisedMetrics(meterRegistry)
+ originStats = OriginStatsFactory.CachingOriginStatsFactory(metrics)
+ }
+
+ afterTest { clearAllMocks() }
+
+ afterSpec {
+ h11MockServer.shutdown()
+ h2MockServerWithAlpn.shutdown()
+ h2MockServerWithoutAlpn.shutdown()
+ LOOP_RESOURCES.dispose()
+ }
+
+ "HttpClient is created with connectionProvider passed" {
+ every { connectionPool.getConnectionProvider(any()) } returns connectionProvider
+ every { connectionPool.connectTimeoutMillis } returns 500
+ every { connectionPool.supportedHttpProtocols() } returns arrayOf(HTTP11)
+
+ val reactorHostClient =
+ ReactorHostHttpClient.create(
+ origin = ORIGIN_1,
+ connectionPool = connectionPool,
+ httpConfig = HTTP_CONFIG,
+ h2SslProvider = null,
+ h11SslHandler = H11_SSL_HANDLER,
+ responseTimeoutMillis = 1000,
+ httpRequestMessageLogger = httpRequestMessageLogger,
+ originStatsFactory = originStats,
+ metrics = metrics,
+ eventLoopGroup = LOOP_RESOURCES,
+ )
+
+ reactorHostClient.configuration().connectionProvider() shouldBe connectionProvider
+ }
+
+ "init sets responseTimeoutMillis" {
+ every { connectionPool.getConnectionProvider(any()) } returns connectionProvider
+ every { connectionPool.connectTimeoutMillis } returns 500
+ every { connectionPool.supportedHttpProtocols() } returns arrayOf(HTTP11)
+
+ val reactorHostClient =
+ ReactorHostHttpClient.create(
+ origin = ORIGIN_1,
+ connectionPool = connectionPool,
+ httpConfig = HTTP_CONFIG,
+ h2SslProvider = null,
+ h11SslHandler = H11_SSL_HANDLER,
+ responseTimeoutMillis = 1000,
+ httpRequestMessageLogger = httpRequestMessageLogger,
+ originStatsFactory = originStats,
+ metrics = metrics,
+ eventLoopGroup = LOOP_RESOURCES,
+ )
+
+ reactorHostClient.configuration().responseTimeout()!!.toMillis() shouldBe 1000
+ }
+
+ "init sets protocols" {
+ every { connectionPool.getConnectionProvider(any()) } returns connectionProvider
+ every { connectionPool.connectTimeoutMillis } returns 500
+ every { connectionPool.supportedHttpProtocols() } returns arrayOf(H2, HTTP11)
+ every { connectionPool.isHttp2() } returns true
+
+ val reactorHostClient =
+ ReactorHostHttpClient.create(
+ origin = ORIGIN_2,
+ connectionPool = connectionPool,
+ httpConfig = HTTP_CONFIG,
+ h2SslProvider = H2_SSL_PROVIDER,
+ h11SslHandler = null,
+ responseTimeoutMillis = 1000,
+ httpRequestMessageLogger = httpRequestMessageLogger,
+ originStatsFactory = originStats,
+ metrics = metrics,
+ eventLoopGroup = LOOP_RESOURCES,
+ )
+
+ reactorHostClient.configuration().protocols() shouldBe arrayOf(H2, HTTP11)
+ }
+
+ "init sets decoder configs" {
+ every { connectionPool.getConnectionProvider(any()) } returns connectionProvider
+ every { connectionPool.connectTimeoutMillis } returns 500
+ every { connectionPool.supportedHttpProtocols() } returns arrayOf(HTTP11)
+
+ val reactorHostClient =
+ ReactorHostHttpClient.create(
+ origin = ORIGIN_1,
+ connectionPool = connectionPool,
+ httpConfig = HTTP_CONFIG,
+ h2SslProvider = null,
+ h11SslHandler = H11_SSL_HANDLER,
+ responseTimeoutMillis = 1000,
+ httpRequestMessageLogger = httpRequestMessageLogger,
+ originStatsFactory = originStats,
+ metrics = metrics,
+ eventLoopGroup = LOOP_RESOURCES,
+ )
+
+ reactorHostClient.configuration().decoder().maxInitialLineLength() shouldBe HTTP_CONFIG.maxInitialLength()
+ reactorHostClient.configuration().decoder().maxHeaderSize() shouldBe HTTP_CONFIG.maxHeadersSize()
+ }
+
+ "init disables default retry" {
+ every { connectionPool.getConnectionProvider(any()) } returns connectionProvider
+ every { connectionPool.connectTimeoutMillis } returns 500
+ every { connectionPool.supportedHttpProtocols() } returns arrayOf(HTTP11)
+
+ val reactorHostClient =
+ ReactorHostHttpClient.create(
+ origin = ORIGIN_1,
+ connectionPool = connectionPool,
+ httpConfig = HTTP_CONFIG,
+ h2SslProvider = null,
+ h11SslHandler = H11_SSL_HANDLER,
+ responseTimeoutMillis = 1000,
+ httpRequestMessageLogger = httpRequestMessageLogger,
+ originStatsFactory = originStats,
+ metrics = metrics,
+ eventLoopGroup = LOOP_RESOURCES,
+ )
+
+ reactorHostClient.configuration().isRetryDisabled shouldBe true
+ }
+
+ "init sets event loop resources" {
+ every { connectionPool.getConnectionProvider(any()) } returns connectionProvider
+ every { connectionPool.connectTimeoutMillis } returns 500
+ every { connectionPool.supportedHttpProtocols() } returns arrayOf(HTTP11)
+
+ val reactorHostClient =
+ ReactorHostHttpClient.create(
+ origin = ORIGIN_1,
+ connectionPool = connectionPool,
+ httpConfig = HTTP_CONFIG,
+ h2SslProvider = null,
+ h11SslHandler = H11_SSL_HANDLER,
+ responseTimeoutMillis = 1000,
+ httpRequestMessageLogger = httpRequestMessageLogger,
+ originStatsFactory = originStats,
+ metrics = metrics,
+ eventLoopGroup = LOOP_RESOURCES,
+ )
+
+ reactorHostClient.configuration().loopResources() shouldBe LOOP_RESOURCES
+ }
+
+ "sendRequest calls an ALPN enabled http2 origin and returns a response using http2" {
+ val mockResponse =
+ MockResponse()
+ .setResponseCode(200)
+ .addHeader("X-Id", "123")
+ .setBody("xyz")
+ h2MockServerWithAlpn.enqueue(mockResponse)
+
+ every { connectionPool.getConnectionProvider(any()) } returns connectionProvider
+ every { connectionPool.connectTimeoutMillis } returns 500
+ every { connectionPool.supportedHttpProtocols() } returns arrayOf(H2, HTTP11)
+ every { connectionPool.isHttp2() } returns true
+
+ val reactorHostClient =
+ ReactorHostHttpClient.create(
+ origin = ORIGIN_2,
+ connectionPool = connectionPool,
+ httpConfig = HTTP_CONFIG,
+ h2SslProvider = H2_SSL_PROVIDER,
+ h11SslHandler = null,
+ responseTimeoutMillis = 1000,
+ httpRequestMessageLogger = httpRequestMessageLogger,
+ originStatsFactory = originStats,
+ metrics = metrics,
+ eventLoopGroup = LOOP_RESOURCES,
+ )
+
+ StepVerifier.create(reactorHostClient.sendRequest(mockRequest, context))
+ .assertNext {
+ it.status().code() shouldBe 200
+ it.header("X-Id").orElse(null) shouldBe "123"
+ it.header("x-http2-stream-id").isPresent shouldBe true
+ }
+ .verifyComplete()
+ }
+
+ "sendRequest calls an ALPN disabled http2 origin and returns a response which fallbacks to http1.1" {
+ val mockResponse =
+ MockResponse()
+ .setResponseCode(200)
+ .addHeader("X-Id", "123")
+ .setBody("xyz")
+ h2MockServerWithoutAlpn.enqueue(mockResponse)
+
+ every { connectionPool.getConnectionProvider(any()) } returns connectionProvider
+ every { connectionPool.connectTimeoutMillis } returns 500
+ every { connectionPool.supportedHttpProtocols() } returns arrayOf(H2, HTTP11)
+ every { connectionPool.isHttp2() } returns true
+
+ val reactorHostClient =
+ ReactorHostHttpClient.create(
+ origin = ORIGIN_3,
+ connectionPool = connectionPool,
+ httpConfig = HTTP_CONFIG,
+ h2SslProvider = H2_SSL_PROVIDER,
+ h11SslHandler = null,
+ responseTimeoutMillis = 1000,
+ httpRequestMessageLogger = httpRequestMessageLogger,
+ originStatsFactory = originStats,
+ metrics = metrics,
+ eventLoopGroup = LOOP_RESOURCES,
+ )
+
+ StepVerifier.create(reactorHostClient.sendRequest(mockRequest, context))
+ .assertNext {
+ it.status().code() shouldBe 200
+ it.header("X-Id").orElse(null) shouldBe "123"
+ it.header("x-http2-stream-id").isEmpty shouldBe true
+ }
+ .verifyComplete()
+ }
+
+ "sendRequest throws errors when hitting response timeout during a request to an http2 origin" {
+ val mockResponse =
+ MockResponse()
+ .setResponseCode(200)
+ .addHeader("X-Id", "123")
+ .setBody("xyz")
+ .setBodyDelay(5, TimeUnit.SECONDS)
+ h2MockServerWithAlpn.enqueue(mockResponse)
+
+ every { connectionPool.getConnectionProvider(any()) } returns connectionProvider
+ every { connectionPool.connectTimeoutMillis } returns 500
+ every { connectionPool.supportedHttpProtocols() } returns arrayOf(H2, HTTP11)
+ every { connectionPool.isHttp2() } returns true
+
+ val reactorHostClient =
+ ReactorHostHttpClient.create(
+ origin = ORIGIN_2,
+ connectionPool = connectionPool,
+ httpConfig = HTTP_CONFIG,
+ h2SslProvider = H2_SSL_PROVIDER,
+ h11SslHandler = null,
+ responseTimeoutMillis = 1000,
+ httpRequestMessageLogger = httpRequestMessageLogger,
+ originStatsFactory = originStats,
+ metrics = metrics,
+ eventLoopGroup = LOOP_RESOURCES,
+ )
+
+ val response = Mono.from(reactorHostClient.sendRequest(mockRequest, context)).block()
+ StepVerifier.create(response!!.body())
+ .verifyError(ResponseTimeoutException::class.java)
+ }
+
+ "sendRequest calls an ALPN enabled http2 origin and returns a response using http1.1" {
+ val mockResponse =
+ MockResponse()
+ .setResponseCode(200)
+ .addHeader("X-Id", "123")
+ .setBody("xyz")
+ h2MockServerWithAlpn.enqueue(mockResponse)
+
+ every { connectionPool.getConnectionProvider(any()) } returns connectionProvider
+ every { connectionPool.connectTimeoutMillis } returns 500
+ every { connectionPool.supportedHttpProtocols() } returns arrayOf(HTTP11)
+
+ val reactorHostClient =
+ ReactorHostHttpClient.create(
+ origin = ORIGIN_2,
+ connectionPool = connectionPool,
+ httpConfig = HTTP_CONFIG,
+ h2SslProvider = null,
+ h11SslHandler = H11_SSL_HANDLER,
+ responseTimeoutMillis = 1000,
+ httpRequestMessageLogger = httpRequestMessageLogger,
+ originStatsFactory = originStats,
+ metrics = metrics,
+ eventLoopGroup = LOOP_RESOURCES,
+ )
+
+ StepVerifier.create(reactorHostClient.sendRequest(mockRequest, context))
+ .assertNext {
+ it.status().code() shouldBe 200
+ it.header("X-Id").orElse(null) shouldBe "123"
+ it.header("x-http2-stream-id").isEmpty shouldBe true
+ }
+ .verifyComplete()
+ }
+
+ "sendRequest calls an ALPN disabled http2 origin and returns a response using http1.1" {
+ val mockResponse =
+ MockResponse()
+ .setResponseCode(200)
+ .addHeader("X-Id", "123")
+ .setBody("xyz")
+ h2MockServerWithoutAlpn.enqueue(mockResponse)
+
+ every { connectionPool.getConnectionProvider(any()) } returns connectionProvider
+ every { connectionPool.connectTimeoutMillis } returns 500
+ every { connectionPool.supportedHttpProtocols() } returns arrayOf(HTTP11)
+
+ val reactorHostClient =
+ ReactorHostHttpClient.create(
+ origin = ORIGIN_3,
+ connectionPool = connectionPool,
+ httpConfig = HTTP_CONFIG,
+ h2SslProvider = null,
+ h11SslHandler = H11_SSL_HANDLER,
+ responseTimeoutMillis = 1000,
+ httpRequestMessageLogger = httpRequestMessageLogger,
+ originStatsFactory = originStats,
+ metrics = metrics,
+ eventLoopGroup = LOOP_RESOURCES,
+ )
+
+ StepVerifier.create(reactorHostClient.sendRequest(mockRequest, context))
+ .assertNext {
+ it.status().code() shouldBe 200
+ it.header("X-Id").orElse(null) shouldBe "123"
+ it.header("x-http2-stream-id").isEmpty shouldBe true
+ }
+ .verifyComplete()
+ }
+
+ "sendRequest calls an http1.1 origin and returns a response" {
+ val mockResponse =
+ MockResponse()
+ .setResponseCode(200)
+ .addHeader("X-Id", "123")
+ .setBody("xyz")
+ h11MockServer.enqueue(mockResponse)
+
+ every { connectionPool.getConnectionProvider(any()) } returns connectionProvider
+ every { connectionPool.connectTimeoutMillis } returns 500
+ every { connectionPool.supportedHttpProtocols() } returns arrayOf(HTTP11)
+
+ val reactorHostClient =
+ ReactorHostHttpClient.create(
+ origin = ORIGIN_1,
+ connectionPool = connectionPool,
+ httpConfig = HTTP_CONFIG,
+ h2SslProvider = null,
+ h11SslHandler = H11_SSL_HANDLER,
+ responseTimeoutMillis = 1000,
+ httpRequestMessageLogger = httpRequestMessageLogger,
+ originStatsFactory = originStats,
+ metrics = metrics,
+ eventLoopGroup = LOOP_RESOURCES,
+ )
+
+ StepVerifier.create(reactorHostClient.sendRequest(mockRequest, context))
+ .assertNext {
+ it.status().code() shouldBe 200
+ it.header("X-Id").orElse(null) shouldBe "123"
+ it.header("x-http2-stream-id").isEmpty shouldBe true
+ }
+ .verifyComplete()
+ }
+
+ "sendRequest throws errors when hitting response timeout during a request to an http1.1 origin" {
+ val mockResponse =
+ MockResponse()
+ .setResponseCode(200)
+ .addHeader("X-Id", "123")
+ .setBody("xyz")
+ .setBodyDelay(5, TimeUnit.SECONDS)
+ h11MockServer.enqueue(mockResponse)
+
+ every { connectionPool.getConnectionProvider(any()) } returns connectionProvider
+ every { connectionPool.connectTimeoutMillis } returns 500
+ every { connectionPool.supportedHttpProtocols() } returns arrayOf(HTTP11)
+
+ val reactorHostClient =
+ ReactorHostHttpClient.create(
+ origin = ORIGIN_1,
+ connectionPool = connectionPool,
+ httpConfig = HTTP_CONFIG,
+ h2SslProvider = null,
+ h11SslHandler = H11_SSL_HANDLER,
+ responseTimeoutMillis = 1000,
+ httpRequestMessageLogger = httpRequestMessageLogger,
+ originStatsFactory = originStats,
+ metrics = metrics,
+ eventLoopGroup = LOOP_RESOURCES,
+ )
+
+ val response = Mono.from(reactorHostClient.sendRequest(mockRequest, context)).block()
+ StepVerifier.create(response!!.body())
+ .verifyError(ResponseTimeoutException::class.java)
+ }
+
+ "closes connection pool when the host http client is closed" {
+ val mockConnectionProvider: ConnectionProvider = mockk(relaxed = true)
+
+ every { connectionPool.getConnectionProvider(any()) } returns mockConnectionProvider
+ every { connectionPool.connectTimeoutMillis } returns 500
+ every { connectionPool.supportedHttpProtocols() } returns arrayOf(HTTP11)
+
+ val reactorHostClient =
+ ReactorHostHttpClient.create(
+ origin = mockk(relaxed = true),
+ connectionPool = connectionPool,
+ httpConfig = HTTP_CONFIG,
+ h2SslProvider = H2_SSL_PROVIDER,
+ h11SslHandler = null,
+ responseTimeoutMillis = 1000,
+ httpRequestMessageLogger = mockk(relaxed = true),
+ originStatsFactory = mockk(relaxed = true),
+ metrics = mockk(relaxed = true),
+ eventLoopGroup = mockk(relaxed = true),
+ )
+
+ every { connectionPool.supportedHttpProtocols() } returns arrayOf(H2, HTTP11)
+ every { connectionPool.isHttp2() } returns true
+
+ reactorHostClient.close()
+
+ verify(exactly = 1) {
+ mockConnectionProvider.dispose()
+ }
+ }
+
+ "sendRequest calls an http1.1 origin with POST and returns a response with body" {
+ val requestBody = RandomString.make(100_000)
+ val responseBody = RandomString.make(100_000)
+ val mockResponse =
+ MockResponse()
+ .setResponseCode(200)
+ .addHeader("X-Id", "123")
+ .setBody(responseBody)
+ h11MockServer.enqueue(mockResponse)
+
+ var request =
+ HttpRequest.post("/")
+ .header(TRANSFER_ENCODING, CHUNKED)
+ .body(requestBody, StandardCharsets.UTF_8)
+ .build()
+ .stream()
+
+ val requestBodySize = AtomicInteger()
+ request =
+ request.newBuilder()
+ .body { body -> body.doOnEach { requestBodySize.addAndGet(it.get()?.size() ?: 0) } }
+ .build()
+
+ every { connectionPool.getConnectionProvider(any()) } returns connectionProvider
+ every { connectionPool.connectTimeoutMillis } returns 500
+ every { connectionPool.supportedHttpProtocols() } returns arrayOf(HTTP11)
+
+ val reactorHostClient =
+ ReactorHostHttpClient.create(
+ origin = ORIGIN_1,
+ connectionPool = connectionPool,
+ httpConfig = HTTP_CONFIG,
+ h2SslProvider = null,
+ h11SslHandler = H11_SSL_HANDLER,
+ responseTimeoutMillis = 1000,
+ httpRequestMessageLogger = httpRequestMessageLogger,
+ originStatsFactory = originStats,
+ metrics = metrics,
+ eventLoopGroup = LOOP_RESOURCES,
+ )
+
+ StepVerifier.create(reactorHostClient.sendRequest(request, context))
+ .assertNext {
+ it.status().code() shouldBe 200
+ it.header("X-Id").orElse(null) shouldBe "123"
+ Mono.from(it.body()).map(Buffers::toByteBuf)
+ .map { chunk -> chunk.toString(StandardCharsets.UTF_8) }
+ .doOnNext { body -> body shouldBe responseBody }
+ .subscribe()
+ }
+ .verifyComplete()
+
+ requestBodySize.get() shouldBe 100_000
+ }
+
+ "SSL handler for HTTP/1.1 is added to channel if h11SslHandler is passed" {
+ val mockResponse =
+ MockResponse()
+ .setResponseCode(200)
+ .addHeader("X-Id", "123")
+ .setBody("xyz")
+ h11MockServer.enqueue(mockResponse)
+
+ every { connectionPool.getConnectionProvider(any()) } returns connectionProvider
+ every { connectionPool.connectTimeoutMillis } returns 500
+ every { connectionPool.supportedHttpProtocols() } returns arrayOf(HTTP11)
+
+ val sslHandler = mockk>()
+
+ val reactorHostClient =
+ ReactorHostHttpClient.create(
+ origin = ORIGIN_1,
+ connectionPool = connectionPool,
+ httpConfig = HTTP_CONFIG,
+ h2SslProvider = null,
+ h11SslHandler = sslHandler,
+ responseTimeoutMillis = 1000,
+ httpRequestMessageLogger = httpRequestMessageLogger,
+ originStatsFactory = originStats,
+ metrics = metrics,
+ eventLoopGroup = LOOP_RESOURCES,
+ )
+
+ runCatching {
+ Mono.from(reactorHostClient.sendRequest(mockRequest, context)).block()
+ }
+
+ verify {
+ sslHandler.accept(any())
+ }
+ }
+
+ "SSL provider for HTTP/2 is used if h2SslProvider is passed" {
+ val mockResponse =
+ MockResponse()
+ .setResponseCode(200)
+ .addHeader("X-Id", "123")
+ .setBody("xyz")
+ h11MockServer.enqueue(mockResponse)
+
+ every { connectionPool.getConnectionProvider(any()) } returns connectionProvider
+ every { connectionPool.connectTimeoutMillis } returns 500
+ every { connectionPool.supportedHttpProtocols() } returns arrayOf(HTTP11)
+
+ val sslProvider = mockk>(relaxed = true)
+
+ runCatching {
+ val reactorHostClient =
+ ReactorHostHttpClient.create(
+ origin = ORIGIN_2,
+ connectionPool = connectionPool,
+ httpConfig = HTTP_CONFIG,
+ h2SslProvider = sslProvider,
+ h11SslHandler = null,
+ responseTimeoutMillis = 1000,
+ httpRequestMessageLogger = httpRequestMessageLogger,
+ originStatsFactory = originStats,
+ metrics = metrics,
+ eventLoopGroup = LOOP_RESOURCES,
+ )
+
+ reactorHostClient.configuration().sslProvider() shouldBe sslProvider.accept(SslProvider.builder())
+ }
+ }
+
+ "Only 1 ssl context can be passed" {
+ val sslProvider = mockk>()
+ val sslHandler = mockk>()
+
+ shouldThrowWithMessage("There can only be one type of SSL context") {
+ ReactorHostHttpClient.create(
+ origin = ORIGIN_1,
+ connectionPool = connectionPool,
+ httpConfig = HTTP_CONFIG,
+ h2SslProvider = sslProvider,
+ h11SslHandler = sslHandler,
+ responseTimeoutMillis = 1000,
+ httpRequestMessageLogger = httpRequestMessageLogger,
+ originStatsFactory = originStats,
+ metrics = metrics,
+ eventLoopGroup = LOOP_RESOURCES,
+ )
+ }
+ }
+
+ "Supports http without any sslContext" {
+ shouldNotThrowMessage("There can only be one type of SSL context") {
+ ReactorHostHttpClient.create(
+ origin = ORIGIN_1,
+ connectionPool = connectionPool,
+ httpConfig = HTTP_CONFIG,
+ h2SslProvider = null,
+ h11SslHandler = null,
+ responseTimeoutMillis = 1000,
+ httpRequestMessageLogger = httpRequestMessageLogger,
+ originStatsFactory = originStats,
+ metrics = metrics,
+ eventLoopGroup = LOOP_RESOURCES,
+ )
+ }
+ }
+ }
+
+ companion object {
+ private val ORIGIN_1 =
+ Origin.newOriginBuilder("localhost", 58887)
+ .applicationId(GENERIC_APP)
+ .id("app-01")
+ .build()
+ private val ORIGIN_2 =
+ Origin.newOriginBuilder("localhost", 58888)
+ .applicationId(GENERIC_APP)
+ .id("app-01")
+ .build()
+ private val ORIGIN_3 =
+ Origin.newOriginBuilder("localhost", 58889)
+ .applicationId(GENERIC_APP)
+ .id("app-01")
+ .build()
+ private var mockRequest = HttpRequest.get("/").build().stream()
+ private val connectionProvider =
+ ConnectionProvider.builder("app-01")
+ .allocationStrategy(
+ Http2AllocationStrategy.builder()
+ .maxConcurrentStreams(10)
+ .maxConnections(1)
+ .build(),
+ )
+ .evictionPredicate { _, connectionMetadata ->
+ connectionMetadata.lifeTime() >= 300000 || connectionMetadata.idleTime() >= 60000
+ }
+ .pendingAcquireTimeout(Duration.ofMillis(5000))
+ .pendingAcquireMaxCount(1000)
+ .disposeTimeout(Duration.ofMillis(1000))
+ .build()
+ private val LOOP_RESOURCES = LoopResources.create("app-01", 1, 1, true)
+ private val HTTP_CONFIG = HttpConfig.defaultHttpConfig()
+ private val H11_SSL_HANDLER: Consumer =
+ Consumer { channel ->
+ SslProvider.builder()
+ .sslContext(SslContextFactory.get(TlsSettings.Builder().build()))
+ .build()
+ .addSslHandler(channel, null, false)
+ }
+ private val H2_SSL_PROVIDER: Consumer =
+ Consumer { sslContextSpec ->
+ sslContextSpec.sslContext(
+ Http2SslContextSpec.forClient()
+ .configure { sslContextBuilder: SslContextBuilder ->
+ sslContextBuilder.trustManager(
+ InsecureTrustManagerFactory.INSTANCE,
+ )
+ }
+ .sslContext(),
+ )
+ }
+ }
+}
diff --git a/components/client/src/test/unit/kotlin/com/hotels/styx/client/ReactorOriginsInventoryTest.kt b/components/client/src/test/unit/kotlin/com/hotels/styx/client/ReactorOriginsInventoryTest.kt
new file mode 100644
index 000000000..b7f7c99f8
--- /dev/null
+++ b/components/client/src/test/unit/kotlin/com/hotels/styx/client/ReactorOriginsInventoryTest.kt
@@ -0,0 +1,528 @@
+/*
+ Copyright (C) 2013-2024 Expedia Inc.
+
+ Licensed under the Apache License, Version 2.0 (the "License");
+ you may not use this file except in compliance with the License.
+ You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+ Unless required by applicable law or agreed to in writing, software
+ distributed under the License is distributed on an "AS IS" BASIS,
+ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ See the License for the specific language governing permissions and
+ limitations under the License.
+ */
+package com.hotels.styx.client
+
+import com.google.common.eventbus.EventBus
+import com.hotels.styx.api.Id.GENERIC_APP
+import com.hotels.styx.api.Id.id
+import com.hotels.styx.api.MeterRegistry
+import com.hotels.styx.api.Metrics
+import com.hotels.styx.api.MicrometerRegistry
+import com.hotels.styx.api.extension.Origin.newOriginBuilder
+import com.hotels.styx.api.extension.OriginsChangeListener
+import com.hotels.styx.api.extension.OriginsSnapshot
+import com.hotels.styx.client.OriginsInventory.OriginState.ACTIVE
+import com.hotels.styx.client.OriginsInventory.OriginState.DISABLED
+import com.hotels.styx.client.healthcheck.OriginHealthStatusMonitor
+import com.hotels.styx.client.origincommands.DisableOrigin
+import com.hotels.styx.client.origincommands.EnableOrigin
+import com.hotels.styx.metrics.CentralisedMetrics
+import io.kotest.core.spec.style.StringSpec
+import io.kotest.matchers.shouldBe
+import io.micrometer.core.instrument.Gauge
+import io.micrometer.core.instrument.Tags
+import io.micrometer.core.instrument.simple.SimpleMeterRegistry
+import io.mockk.clearAllMocks
+import io.mockk.every
+import io.mockk.mockk
+import io.mockk.verify
+import io.netty.handler.ssl.SslContext
+import reactor.netty.resources.LoopResources
+
+class ReactorOriginsInventoryTest : StringSpec() {
+ private lateinit var inventory: OriginsInventory
+ private lateinit var meterRegistry: MeterRegistry
+ private val eventBus: EventBus = mockk(relaxed = true)
+ private val monitor: OriginHealthStatusMonitor = mockk(relaxed = true)
+ private val connectionPool: ReactorConnectionPool = mockk(relaxed = true)
+ private val hostClientFactory: ReactorHostHttpClient.Factory = mockk(relaxed = true)
+ private val originStatsFactory: OriginStatsFactory = mockk(relaxed = true)
+ private val eventLoopGroup: LoopResources = mockk()
+ private val sslContext: SslContext = mockk()
+
+ init {
+ beforeTest {
+ meterRegistry = MicrometerRegistry(SimpleMeterRegistry())
+ inventory =
+ ReactorOriginsInventory.Builder(GENERIC_APP)
+ .eventLoopGroup(eventLoopGroup)
+ .eventBus(eventBus)
+ .originHealthMonitor(monitor)
+ .hostClientFactory(hostClientFactory)
+ .metrics(CentralisedMetrics(meterRegistry))
+ .connectionPool(connectionPool)
+ .originStatsFactory(originStatsFactory)
+ .sslContext(sslContext)
+ .build()
+ }
+
+ afterTest {
+ clearAllMocks()
+ }
+
+ "start monitoring new origins" {
+ inventory.setOrigins(ORIGIN_1, ORIGIN_2)
+
+ inventory.originCount(ACTIVE) shouldBe 2
+ gaugeValue(ORIGIN_1.applicationId().toString(), ORIGIN_1.id().toString()) shouldBe 1.0
+ gaugeValue(ORIGIN_2.applicationId().toString(), ORIGIN_2.id().toString()) shouldBe 1.0
+ verify(exactly = 1) {
+ monitor.monitor(setOf(ORIGIN_1))
+ monitor.monitor(setOf(ORIGIN_2))
+ eventBus.post(any())
+ }
+ }
+
+ "updates on origin port number change" {
+ val originV1 =
+ newOriginBuilder("acme.com", 80)
+ .applicationId(GENERIC_APP)
+ .id("acme-01")
+ .build()
+ val originV2 =
+ newOriginBuilder("acme.com", 443)
+ .applicationId(GENERIC_APP)
+ .id("acme-01")
+ .build()
+
+ inventory.setOrigins(originV1)
+
+ inventory.originCount(ACTIVE) shouldBe 1
+ gaugeValue("generic-app", "acme-01") shouldBe 1.0
+ verify(exactly = 1) {
+ monitor.monitor(setOf(originV1))
+ eventBus.post(any())
+ }
+
+ inventory.setOrigins(originV2)
+
+ inventory.originCount(ACTIVE) shouldBe 1
+ gaugeValue("generic-app", "acme-01") shouldBe 1.0
+ verify(exactly = 1) {
+ monitor.stopMonitoring(setOf(originV1))
+ monitor.monitor(setOf(originV2))
+ }
+ verify(exactly = 2) {
+ eventBus.post(any())
+ }
+ }
+
+ "updates on origin hostname change" {
+ val originV1 =
+ newOriginBuilder("acme01.com", 80)
+ .applicationId(GENERIC_APP)
+ .id("acme-01")
+ .build()
+ val originV2 =
+ newOriginBuilder("acme02.com", 80)
+ .applicationId(GENERIC_APP)
+ .id("acme-01")
+ .build()
+
+ inventory.setOrigins(originV1)
+
+ inventory.originCount(ACTIVE) shouldBe 1
+ gaugeValue("generic-app", "acme-01") shouldBe 1.0
+ verify(exactly = 1) {
+ monitor.monitor(setOf(originV1))
+ eventBus.post(any())
+ }
+
+ inventory.setOrigins(originV2)
+
+ inventory.originCount(ACTIVE) shouldBe 1
+ gaugeValue("generic-app", "acme-01") shouldBe 1.0
+ verify(exactly = 1) {
+ monitor.stopMonitoring(setOf(originV1))
+ monitor.monitor(setOf(originV2))
+ }
+ verify(exactly = 2) {
+ eventBus.post(any())
+ }
+ }
+
+ "stops and restarts monitoring modified origin" {
+ val originV1 =
+ newOriginBuilder("acme01.com", 80)
+ .applicationId(GENERIC_APP)
+ .id("acme-01")
+ .build()
+ val originV2 =
+ newOriginBuilder("acme02.com", 80)
+ .applicationId(GENERIC_APP)
+ .id("acme-01")
+ .build()
+
+ inventory.setOrigins(originV1)
+
+ verify(exactly = 1) {
+ monitor.monitor(setOf(originV1))
+ }
+
+ inventory.setOrigins(originV2)
+
+ verify(exactly = 1) {
+ monitor.stopMonitoring(setOf(originV1))
+ monitor.monitor(setOf(originV2))
+ }
+ }
+
+ "shuts connection provider on origin change" {
+ val originV1 =
+ newOriginBuilder("acme01.com", 80)
+ .applicationId(GENERIC_APP)
+ .id("acme-01")
+ .build()
+ val originV2 =
+ newOriginBuilder("acme02.com", 80)
+ .applicationId(GENERIC_APP)
+ .id("acme-01")
+ .build()
+
+ val hostClient1: ReactorHostHttpClient = mockk(relaxed = true)
+ val hostClient2: ReactorHostHttpClient = mockk(relaxed = true)
+ every {
+ hostClientFactory.create(
+ originV1, any(), any(), any(), any(), any(),
+ any(), any(), any(), any(), any(),
+ )
+ } returns hostClient1
+ every {
+ hostClientFactory.create(
+ originV2, any(), any(), any(), any(), any(),
+ any(), any(), any(), any(), any(),
+ )
+ } returns hostClient2
+
+ inventory.setOrigins(originV1)
+
+ verify(exactly = 1) {
+ hostClientFactory.create(
+ originV1, any(), any(), any(), any(), any(),
+ any(), any(), any(), any(), any(),
+ )
+ }
+
+ inventory.setOrigins(originV2)
+
+ verify(exactly = 1) {
+ hostClientFactory.create(
+ originV2, any(), any(), any(), any(), any(),
+ any(), any(), any(), any(), any(),
+ )
+ hostClient1.close()
+ }
+ }
+
+ "ignores unchanged origins" {
+ inventory.setOrigins(ORIGIN_1, ORIGIN_2)
+
+ inventory.originCount(ACTIVE) shouldBe 2
+ gaugeValue(ORIGIN_1.applicationId().toString(), ORIGIN_1.id().toString()) shouldBe 1.0
+ gaugeValue(ORIGIN_2.applicationId().toString(), ORIGIN_2.id().toString()) shouldBe 1.0
+ verify(exactly = 1) {
+ monitor.monitor(setOf(ORIGIN_1))
+ monitor.monitor(setOf(ORIGIN_2))
+ eventBus.post(any())
+ }
+
+ inventory.setOrigins(ORIGIN_1, ORIGIN_2)
+ inventory.originCount(ACTIVE) shouldBe 2
+ verify(exactly = 1) {
+ monitor.monitor(setOf(ORIGIN_1))
+ monitor.monitor(setOf(ORIGIN_2))
+ eventBus.post(any())
+ }
+ verify(exactly = 0) {
+ monitor.stopMonitoring(setOf(ORIGIN_1))
+ monitor.stopMonitoring(setOf(ORIGIN_2))
+ }
+ }
+
+ "stop monitoring on origins removal" {
+ inventory.setOrigins(ORIGIN_1, ORIGIN_2)
+
+ inventory.originCount(ACTIVE) shouldBe 2
+ gaugeValue(ORIGIN_1.applicationId().toString(), ORIGIN_1.id().toString()) shouldBe 1.0
+ gaugeValue(ORIGIN_2.applicationId().toString(), ORIGIN_2.id().toString()) shouldBe 1.0
+ verify(exactly = 1) {
+ monitor.monitor(setOf(ORIGIN_1))
+ monitor.monitor(setOf(ORIGIN_2))
+ eventBus.post(any())
+ }
+
+ inventory.setOrigins(ORIGIN_2)
+ inventory.originCount(ACTIVE) shouldBe 1
+ gaugeValue(ORIGIN_1.applicationId().toString(), ORIGIN_1.id().toString()) shouldBe null
+ gaugeValue(ORIGIN_2.applicationId().toString(), ORIGIN_2.id().toString()) shouldBe 1.0
+ verify(exactly = 1) {
+ monitor.stopMonitoring(setOf(ORIGIN_1))
+ }
+ verify(exactly = 2) {
+ eventBus.post(any())
+ }
+ }
+
+ "shuts connection provider on origin removal" {
+ val hostClient1: ReactorHostHttpClient = mockk(relaxed = true)
+ val hostClient2: ReactorHostHttpClient = mockk(relaxed = true)
+ every {
+ hostClientFactory.create(
+ ORIGIN_1, any(), any(), any(), any(), any(),
+ any(), any(), any(), any(), any(),
+ )
+ } returns hostClient1
+ every {
+ hostClientFactory.create(
+ ORIGIN_2, any(), any(), any(), any(), any(),
+ any(), any(), any(), any(), any(),
+ )
+ } returns hostClient2
+
+ inventory.setOrigins(ORIGIN_1, ORIGIN_2)
+
+ verify(exactly = 2) {
+ hostClientFactory.create(
+ any(), any(), any(), any(), any(), any(),
+ any(), any(), any(), any(), any(),
+ )
+ }
+
+ inventory.setOrigins(ORIGIN_2)
+
+ verify(exactly = 1) {
+ hostClient1.close()
+ }
+ }
+
+ "does not disable origins not belonging to the app" {
+ inventory.setOrigins(ORIGIN_1)
+
+ verify(exactly = 1) {
+ eventBus.post(any())
+ }
+
+ inventory.onCommand(DisableOrigin(id("some-other-app"), ORIGIN_1.id()))
+
+ inventory.originCount(ACTIVE) shouldBe 1
+ verify(exactly = 1) {
+ eventBus.post(any())
+ }
+ }
+
+ "does not enable origins not belonging to the app" {
+ inventory.setOrigins(ORIGIN_1)
+
+ verify(exactly = 1) {
+ eventBus.post(any())
+ }
+
+ inventory.onCommand(DisableOrigin(ORIGIN_1.applicationId(), ORIGIN_1.id()))
+ inventory.onCommand(EnableOrigin(id("some-other-app"), ORIGIN_1.id()))
+
+ inventory.originCount(ACTIVE) shouldBe 0
+ verify(exactly = 2) {
+ eventBus.post(any())
+ }
+ }
+
+ "removes from active set and stops health check monitoring on disabling an origin" {
+ inventory.setOrigins(ORIGIN_1)
+
+ inventory.onCommand(DisableOrigin(ORIGIN_1.applicationId(), ORIGIN_1.id()))
+
+ inventory.originCount(ACTIVE) shouldBe 0
+ inventory.originCount(DISABLED) shouldBe 1
+
+ gaugeValue(ORIGIN_1.applicationId().toString(), ORIGIN_1.id().toString()) shouldBe -1.0
+ verify(exactly = 1) {
+ monitor.stopMonitoring(setOf(ORIGIN_1))
+ }
+ verify(exactly = 2) {
+ eventBus.post(any())
+ }
+ }
+
+ "removes from inactive set and stops health check monitoring on disabling an origin" {
+ inventory.setOrigins(ORIGIN_1)
+
+ inventory.originUnhealthy(ORIGIN_1)
+ inventory.onCommand(DisableOrigin(ORIGIN_1.applicationId(), ORIGIN_1.id()))
+
+ inventory.originCount(ACTIVE) shouldBe 0
+ inventory.originCount(DISABLED) shouldBe 1
+
+ gaugeValue(ORIGIN_1.applicationId().toString(), ORIGIN_1.id().toString()) shouldBe -1.0
+ verify(exactly = 1) {
+ monitor.stopMonitoring(setOf(ORIGIN_1))
+ }
+ verify(exactly = 3) {
+ eventBus.post(any())
+ }
+ }
+
+ "re-initiates health check monitoring on enabling an origin" {
+ inventory.setOrigins(ORIGIN_1)
+
+ inventory.onCommand(DisableOrigin(ORIGIN_1.applicationId(), ORIGIN_1.id()))
+ inventory.onCommand(EnableOrigin(ORIGIN_1.applicationId(), ORIGIN_1.id()))
+
+ gaugeValue(ORIGIN_1.applicationId().toString(), ORIGIN_1.id().toString()) shouldBe 0.0
+ verify(exactly = 2) {
+ monitor.monitor(setOf(ORIGIN_1))
+ }
+ verify(exactly = 3) {
+ eventBus.post(any())
+ }
+ }
+
+ "updates on removing unhealthy origins from active set" {
+ inventory.setOrigins(ORIGIN_1)
+ inventory.originCount(ACTIVE) shouldBe 1
+
+ inventory.originUnhealthy(ORIGIN_1)
+
+ inventory.originCount(ACTIVE) shouldBe 0
+ gaugeValue(ORIGIN_1.applicationId().toString(), ORIGIN_1.id().toString()) shouldBe 0.0
+ verify(exactly = 2) {
+ eventBus.post(any())
+ }
+ }
+
+ "updates on adding healthy origins back to active set" {
+ inventory.setOrigins(ORIGIN_1)
+ inventory.originCount(ACTIVE) shouldBe 1
+
+ inventory.originUnhealthy(ORIGIN_1)
+ inventory.originHealthy(ORIGIN_1)
+
+ inventory.originCount(ACTIVE) shouldBe 1
+ gaugeValue(ORIGIN_1.applicationId().toString(), ORIGIN_1.id().toString()) shouldBe 1.0
+ verify(exactly = 3) {
+ eventBus.post(any())
+ }
+ }
+
+ "repeatedly reporting healthy does not affect current active origins" {
+ inventory.setOrigins(ORIGIN_1)
+ inventory.originCount(ACTIVE) shouldBe 1
+
+ inventory.originHealthy(ORIGIN_1)
+ inventory.originHealthy(ORIGIN_1)
+ inventory.originHealthy(ORIGIN_1)
+
+ inventory.originCount(ACTIVE) shouldBe 1
+ gaugeValue(ORIGIN_1.applicationId().toString(), ORIGIN_1.id().toString()) shouldBe 1.0
+ verify(exactly = 1) {
+ eventBus.post(any())
+ }
+ }
+
+ "repeatedly reporting unhealthy does not affect current active origins" {
+ inventory.setOrigins(ORIGIN_1)
+ inventory.originCount(ACTIVE) shouldBe 1
+
+ inventory.originUnhealthy(ORIGIN_1)
+ inventory.originUnhealthy(ORIGIN_1)
+ inventory.originUnhealthy(ORIGIN_1)
+
+ inventory.originCount(ACTIVE) shouldBe 0
+ gaugeValue(ORIGIN_1.applicationId().toString(), ORIGIN_1.id().toString()) shouldBe 0.0
+ verify(exactly = 2) {
+ eventBus.post(any())
+ }
+ }
+
+ "announces listeners on origin state change" {
+ val listener: OriginsChangeListener = mockk()
+
+ inventory.addOriginsChangeListener(listener)
+ inventory.setOrigins(ORIGIN_1)
+ inventory.originUnhealthy(ORIGIN_1)
+
+ verify(exactly = 2) {
+ listener.originsChanged(any())
+ }
+ }
+
+ "registers to event bus when created" {
+ verify(exactly = 1) {
+ eventBus.register(inventory)
+ }
+ }
+
+ "stops monitoring and unregisters when closed" {
+ val hostClient1: ReactorHostHttpClient = mockk(relaxed = true)
+ val hostClient2: ReactorHostHttpClient = mockk(relaxed = true)
+ every {
+ hostClientFactory.create(
+ ORIGIN_1, any(), any(), any(), any(), any(),
+ any(), any(), any(), any(), any(),
+ )
+ } returns hostClient1
+ every {
+ hostClientFactory.create(
+ ORIGIN_2, any(), any(), any(), any(), any(),
+ any(), any(), any(), any(), any(),
+ )
+ } returns hostClient2
+
+ inventory.setOrigins(ORIGIN_1, ORIGIN_2)
+
+ inventory.close()
+
+ gaugeValue(ORIGIN_1.applicationId().toString(), ORIGIN_1.id().toString()) shouldBe null
+ gaugeValue(ORIGIN_2.applicationId().toString(), ORIGIN_2.id().toString()) shouldBe null
+ verify(exactly = 2) {
+ eventBus.post(any())
+ }
+ verify(exactly = 1) {
+ monitor.stopMonitoring(setOf(ORIGIN_1))
+ monitor.stopMonitoring(setOf(ORIGIN_2))
+ hostClient1.close()
+ hostClient2.close()
+ eventBus.unregister(inventory)
+ }
+ }
+ }
+
+ private fun gaugeValue(
+ appId: String,
+ originId: String,
+ ): Double? {
+ val name = "proxy.client.originHealthStatus"
+ val tags = Tags.of(Metrics.APPID_TAG, appId, Metrics.ORIGINID_TAG, originId)
+ return gauge(name, tags)?.value()
+ }
+
+ private fun gauge(
+ name: String,
+ tags: Tags,
+ ): Gauge? = meterRegistry.find(name).tags(tags).gauge()
+
+ companion object {
+ private val ORIGIN_1 =
+ newOriginBuilder("localhost", 8001)
+ .applicationId(GENERIC_APP)
+ .id("app-01")
+ .build()
+ private val ORIGIN_2 =
+ newOriginBuilder("localhost", 8002)
+ .applicationId(GENERIC_APP)
+ .id("app-02")
+ .build()
+ }
+}
diff --git a/components/common/pom.xml b/components/common/pom.xml
index f49c0556d..88fbe1e07 100644
--- a/components/common/pom.xml
+++ b/components/common/pom.xml
@@ -38,6 +38,11 @@
reactor-core
+
+ io.projectreactor.netty
+ reactor-netty
+
+
org.hdrhistogram
HdrHistogram
diff --git a/components/common/src/main/kotlin/com/hotels/styx/ext/HttpMessageExt.kt b/components/common/src/main/kotlin/com/hotels/styx/ext/HttpMessageExt.kt
new file mode 100644
index 000000000..124fa99f1
--- /dev/null
+++ b/components/common/src/main/kotlin/com/hotels/styx/ext/HttpMessageExt.kt
@@ -0,0 +1,31 @@
+/*
+ Copyright (C) 2013-2024 Expedia Inc.
+
+ Licensed under the Apache License, Version 2.0 (the "License");
+ you may not use this file except in compliance with the License.
+ You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+ Unless required by applicable law or agreed to in writing, software
+ distributed under the License is distributed on an "AS IS" BASIS,
+ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ See the License for the specific language governing permissions and
+ limitations under the License.
+ */
+package com.hotels.styx.ext
+
+import com.hotels.styx.api.LiveHttpRequest
+import com.hotels.styx.api.LiveHttpResponse
+
+inline fun LiveHttpRequest.newRequest(block: LiveHttpRequest.Transformer.() -> Unit): LiveHttpRequest {
+ val transformer = newBuilder()
+ block(transformer)
+ return transformer.build()
+}
+
+inline fun LiveHttpResponse.newResponse(block: LiveHttpResponse.Transformer.() -> Unit): LiveHttpResponse {
+ val transformer = newBuilder()
+ block(transformer)
+ return transformer.build()
+}
diff --git a/components/common/src/main/kotlin/com/hotels/styx/metrics/ReactorNettyMeterFilter.kt b/components/common/src/main/kotlin/com/hotels/styx/metrics/ReactorNettyMeterFilter.kt
new file mode 100644
index 000000000..b0619905b
--- /dev/null
+++ b/components/common/src/main/kotlin/com/hotels/styx/metrics/ReactorNettyMeterFilter.kt
@@ -0,0 +1,46 @@
+/*
+ Copyright (C) 2013-2024 Expedia Inc.
+
+ Licensed under the Apache License, Version 2.0 (the "License");
+ you may not use this file except in compliance with the License.
+ You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+ Unless required by applicable law or agreed to in writing, software
+ distributed under the License is distributed on an "AS IS" BASIS,
+ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ See the License for the specific language governing permissions and
+ limitations under the License.
+ */
+package com.hotels.styx.metrics
+
+import com.hotels.styx.api.extension.Origin
+import io.micrometer.core.instrument.Meter
+import io.micrometer.core.instrument.Tags
+import io.micrometer.core.instrument.config.MeterFilter
+import reactor.netty.Metrics.NAME
+
+class ReactorNettyMeterFilter(private val origin: Origin) : MeterFilter {
+ override fun map(id: Meter.Id): Meter.Id =
+ if (id.name.startsWith(REACTOR_NETTY_PREFIX) &&
+ id.getTag(NAME)?.endsWith("${origin.id()}") == true
+ ) {
+ id.withTags(
+ Tags.of(
+ APP_ID,
+ origin.applicationId().toString(),
+ ORIGIN_ID,
+ origin.id().toString(),
+ ),
+ )
+ } else {
+ id
+ }
+
+ companion object {
+ private const val REACTOR_NETTY_PREFIX = "reactor.netty"
+ private const val APP_ID = "appId"
+ private const val ORIGIN_ID = "originId"
+ }
+}
diff --git a/components/proxy/src/main/kotlin/com/hotels/styx/proxy/ReactorBackendServiceClientFactory.kt b/components/proxy/src/main/kotlin/com/hotels/styx/proxy/ReactorBackendServiceClientFactory.kt
new file mode 100644
index 000000000..0f170324f
--- /dev/null
+++ b/components/proxy/src/main/kotlin/com/hotels/styx/proxy/ReactorBackendServiceClientFactory.kt
@@ -0,0 +1,120 @@
+/*
+ Copyright (C) 2013-2024 Expedia Inc.
+
+ Licensed under the Apache License, Version 2.0 (the "License");
+ you may not use this file except in compliance with the License.
+ You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+ Unless required by applicable law or agreed to in writing, software
+ distributed under the License is distributed on an "AS IS" BASIS,
+ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ See the License for the specific language governing permissions and
+ limitations under the License.
+ */
+package com.hotels.styx.proxy
+
+import com.hotels.styx.Environment
+import com.hotels.styx.api.configuration.Configuration
+import com.hotels.styx.api.extension.loadbalancing.spi.LoadBalancer
+import com.hotels.styx.api.extension.retrypolicy.spi.RetryPolicy
+import com.hotels.styx.api.extension.service.BackendService
+import com.hotels.styx.client.BackendServiceClient
+import com.hotels.styx.client.OriginRestrictionLoadBalancingStrategy
+import com.hotels.styx.client.OriginStatsFactory
+import com.hotels.styx.client.OriginsInventory
+import com.hotels.styx.client.ReactorBackendServiceClient
+import com.hotels.styx.client.RewriteRuleset
+import com.hotels.styx.client.loadbalancing.strategies.BusyActivitiesStrategy
+import com.hotels.styx.client.retry.RetryNTimes
+import com.hotels.styx.client.stickysession.StickySessionLoadBalancingStrategy
+import com.hotels.styx.serviceproviders.ServiceProvision
+import org.slf4j.LoggerFactory.getLogger
+
+/**
+ * A BackendServiceClientFactory implementation for creating {@link ReactorBackendServiceClient}
+ */
+class ReactorBackendServiceClientFactory(private val environment: Environment) : BackendServiceClientFactory {
+ override fun createClient(
+ backendService: BackendService,
+ originsInventory: OriginsInventory,
+ originStatsFactory: OriginStatsFactory,
+ ): BackendServiceClient {
+ val styxConfig: Configuration = environment.configuration()
+ val originRestrictionCookie = styxConfig["originRestrictionCookie"].orElse(null)
+ val stickySessionEnabled = backendService.stickySessionConfig().stickySessionEnabled()
+ val retryPolicy = loadRetryPolicy(styxConfig) ?: defaultRetryPolicy()
+ val configuredLbStrategy = loadLoadBalancer(styxConfig, originsInventory) ?: BusyActivitiesStrategy(originsInventory)
+ originsInventory.addOriginsChangeListener(configuredLbStrategy)
+
+ val loadBalancingStrategy =
+ decorateLoadBalancer(
+ configuredLbStrategy,
+ stickySessionEnabled,
+ originsInventory,
+ originRestrictionCookie,
+ )
+ return ReactorBackendServiceClient(
+ id = backendService.id(),
+ rewriteRuleset = RewriteRuleset(backendService.rewrites()),
+ originsRestrictionCookieName = originRestrictionCookie,
+ stickySessionConfig = backendService.stickySessionConfig(),
+ originIdHeader = environment.configuration().styxHeaderConfig().originIdHeaderName(),
+ loadBalancer = loadBalancingStrategy,
+ retryPolicy = retryPolicy,
+ metrics = environment.centralisedMetrics(),
+ overrideHostHeader = backendService.isOverrideHostHeader(),
+ )
+ }
+
+ private fun loadRetryPolicy(styxConfig: Configuration) =
+ ServiceProvision.loadRetryPolicy(
+ styxConfig,
+ environment,
+ "retrypolicy.policy.factory",
+ RetryPolicy::class.java,
+ ).orElse(null)
+
+ private fun loadLoadBalancer(
+ styxConfig: Configuration,
+ originsInventory: OriginsInventory,
+ ) = ServiceProvision.loadLoadBalancer(
+ styxConfig,
+ environment,
+ "loadBalancing.strategy.factory",
+ LoadBalancer::class.java,
+ originsInventory,
+ ).orElse(null)
+
+ private fun decorateLoadBalancer(
+ configuredLbStrategy: LoadBalancer,
+ stickySessionEnabled: Boolean,
+ originsInventory: OriginsInventory,
+ originRestrictionCookie: String?,
+ ): LoadBalancer =
+ if (stickySessionEnabled) {
+ StickySessionLoadBalancingStrategy(originsInventory, configuredLbStrategy)
+ } else if (originRestrictionCookie == null) {
+ LOGGER.info("originRestrictionCookie not specified - origin restriction disabled")
+ configuredLbStrategy
+ } else {
+ LOGGER.info(
+ """
+ originRestrictionCookie specified as $originRestrictionCookie
+ - origin restriction will apply when this cookie is sent
+ """.trimIndent(),
+ )
+ OriginRestrictionLoadBalancingStrategy(originsInventory, configuredLbStrategy)
+ }
+
+ companion object {
+ private val LOGGER = getLogger(this::class.java)
+
+ private fun defaultRetryPolicy(): RetryPolicy {
+ val retryOnce = RetryNTimes(1)
+ LOGGER.warn("No configured retry policy found, using $retryOnce")
+ return retryOnce
+ }
+ }
+}
diff --git a/components/proxy/src/main/java/com/hotels/styx/proxy/StyxBackendServiceClientFactory.kt b/components/proxy/src/main/kotlin/com/hotels/styx/proxy/StyxBackendServiceClientFactory.kt
similarity index 99%
rename from components/proxy/src/main/java/com/hotels/styx/proxy/StyxBackendServiceClientFactory.kt
rename to components/proxy/src/main/kotlin/com/hotels/styx/proxy/StyxBackendServiceClientFactory.kt
index 44c43e97c..8a3b26bc9 100644
--- a/components/proxy/src/main/java/com/hotels/styx/proxy/StyxBackendServiceClientFactory.kt
+++ b/components/proxy/src/main/kotlin/com/hotels/styx/proxy/StyxBackendServiceClientFactory.kt
@@ -1,5 +1,5 @@
/*
- Copyright (C) 2013-2023 Expedia Inc.
+ Copyright (C) 2013-2024 Expedia Inc.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
diff --git a/components/proxy/src/test/kotlin/com/hotels/styx/proxy/ReactorBackendServiceClientFactoryTest.kt b/components/proxy/src/test/kotlin/com/hotels/styx/proxy/ReactorBackendServiceClientFactoryTest.kt
new file mode 100644
index 000000000..b87fd405a
--- /dev/null
+++ b/components/proxy/src/test/kotlin/com/hotels/styx/proxy/ReactorBackendServiceClientFactoryTest.kt
@@ -0,0 +1,202 @@
+/*
+ Copyright (C) 2013-2024 Expedia Inc.
+
+ Licensed under the Apache License, Version 2.0 (the "License");
+ you may not use this file except in compliance with the License.
+ You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+ Unless required by applicable law or agreed to in writing, software
+ distributed under the License is distributed on an "AS IS" BASIS,
+ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ See the License for the specific language governing permissions and
+ limitations under the License.
+ */
+package com.hotels.styx.proxy
+
+import com.hotels.styx.Environment
+import com.hotels.styx.StyxConfig
+import com.hotels.styx.api.HttpInterceptor
+import com.hotels.styx.api.HttpResponseStatus.OK
+import com.hotels.styx.api.Id.GENERIC_APP
+import com.hotels.styx.api.Id.id
+import com.hotels.styx.api.LiveHttpRequest.get
+import com.hotels.styx.api.LiveHttpResponse
+import com.hotels.styx.api.LiveHttpResponse.response
+import com.hotels.styx.api.MicrometerRegistry
+import com.hotels.styx.api.RequestCookie.requestCookie
+import com.hotels.styx.api.configuration.Configuration.MapBackedConfiguration
+import com.hotels.styx.api.extension.Origin
+import com.hotels.styx.api.extension.loadbalancing.spi.LoadBalancingMetric
+import com.hotels.styx.api.extension.service.BackendService
+import com.hotels.styx.api.extension.service.BackendService.Companion.newBackendServiceBuilder
+import com.hotels.styx.api.extension.service.StickySessionConfig.newStickySessionConfigBuilder
+import com.hotels.styx.client.OriginStatsFactory
+import com.hotels.styx.client.OriginStatsFactory.CachingOriginStatsFactory
+import com.hotels.styx.client.ReactorBackendServiceClient
+import com.hotels.styx.client.ReactorConnectionPool
+import com.hotels.styx.client.ReactorHostHttpClient
+import com.hotels.styx.client.ReactorOriginsInventory
+import io.kotest.core.spec.style.StringSpec
+import io.kotest.matchers.shouldBe
+import io.kotest.matchers.types.shouldBeInstanceOf
+import io.micrometer.core.instrument.simple.SimpleMeterRegistry
+import io.mockk.every
+import io.mockk.mockk
+import io.netty.handler.ssl.SslContext
+import reactor.core.publisher.Flux
+import reactor.core.publisher.Mono
+import reactor.netty.resources.LoopResources
+import kotlin.jvm.optionals.getOrNull
+
+class ReactorBackendServiceClientFactoryTest : StringSpec() {
+ private lateinit var environment: Environment
+ private lateinit var backendService: BackendService
+ private val eventLoopGroup: LoopResources = mockk()
+ private val originStatsFactory: OriginStatsFactory = mockk()
+ private val sslContext: SslContext = mockk()
+ private val connectionPool: ReactorConnectionPool = mockk()
+ private val context: HttpInterceptor.Context = mockk()
+
+ init {
+ beforeTest {
+ environment = Environment.Builder().registry(MicrometerRegistry(SimpleMeterRegistry())).build()
+ backendService =
+ newBackendServiceBuilder()
+ .origins(Origin.newOriginBuilder("localhost", 8081).build())
+ .build()
+
+ every { connectionPool.isHttp2() } returns false
+ }
+
+ "createClient" {
+ val originsInventory =
+ ReactorOriginsInventory.Builder(backendService.id())
+ .metrics(environment.centralisedMetrics())
+ .eventLoopGroup(eventLoopGroup)
+ .initialOrigins(backendService.origins())
+ .originStatsFactory(originStatsFactory)
+ .sslContext(sslContext)
+ .connectionPool(connectionPool)
+ .build()
+ val client =
+ ReactorBackendServiceClientFactory(environment)
+ .createClient(backendService, originsInventory, originStatsFactory)
+
+ client.shouldBeInstanceOf()
+ }
+
+ "uses the origin specified in the sticky session cookie" {
+ val backendService =
+ newBackendServiceBuilder()
+ .origins(
+ Origin.newOriginBuilder("localhost", 9091).id("x").build(),
+ Origin.newOriginBuilder("localhost", 9092).id("y").build(),
+ Origin.newOriginBuilder("localhost", 9093).id("z").build(),
+ )
+ .stickySessionConfig(newStickySessionConfigBuilder().enabled(true).build())
+ .build()
+ val originsInventory =
+ ReactorOriginsInventory.Builder(backendService.id())
+ .metrics(environment.centralisedMetrics())
+ .eventLoopGroup(eventLoopGroup)
+ .initialOrigins(backendService.origins())
+ .hostClientFactory { origin, _, _, _, _, _, _, _, _, _, _ ->
+ if (origin.id() == id("x")) {
+ hostClient(response(OK).header("X-Origin-Id", "x").build())
+ } else if (origin.id() == id("y")) {
+ hostClient(response(OK).header("X-Origin-Id", "y").build())
+ } else {
+ hostClient(response(OK).header("X-Origin-Id", "z").build())
+ }
+ }
+ .originStatsFactory(originStatsFactory)
+ .sslContext(sslContext)
+ .connectionPool(connectionPool)
+ .build()
+
+ val client =
+ ReactorBackendServiceClientFactory(environment)
+ .createClient(backendService, originsInventory, CachingOriginStatsFactory(environment.centralisedMetrics()))
+
+ val requestX = get("/some-req").cookies(requestCookie(STICKY_COOKIE, id("x").toString())).build()
+ val requestY = get("/some-req").cookies(requestCookie(STICKY_COOKIE, id("y").toString())).build()
+ val requestZ = get("/some-req").cookies(requestCookie(STICKY_COOKIE, id("z").toString())).build()
+
+ val responseX = Mono.from(client.sendRequest(requestX, context)).block()
+ val responseY = Mono.from(client.sendRequest(requestY, context)).block()
+ val responseZ = Mono.from(client.sendRequest(requestZ, context)).block()
+
+ responseX!!.header("X-Origin-Id").getOrNull() shouldBe "x"
+ responseY!!.header("X-Origin-Id").getOrNull() shouldBe "y"
+ responseZ!!.header("X-Origin-Id").getOrNull() shouldBe "z"
+ }
+
+ "uses the origin specified in the origin restriction cookie" {
+ val config = MapBackedConfiguration()
+ config["originRestrictionCookie"] = ORIGINS_RESTRICTION_COOKIE
+
+ environment =
+ Environment.Builder()
+ .registry(MicrometerRegistry(SimpleMeterRegistry()))
+ .configuration(StyxConfig(config))
+ .build()
+
+ val backendService =
+ newBackendServiceBuilder()
+ .origins(
+ Origin.newOriginBuilder("localhost", 9091).id("x").build(),
+ Origin.newOriginBuilder("localhost", 9092).id("y").build(),
+ Origin.newOriginBuilder("localhost", 9093).id("z").build(),
+ )
+ .build()
+ val originsInventory =
+ ReactorOriginsInventory.Builder(backendService.id())
+ .metrics(environment.centralisedMetrics())
+ .eventLoopGroup(eventLoopGroup)
+ .initialOrigins(backendService.origins())
+ .hostClientFactory { origin, _, _, _, _, _, _, _, _, _, _ ->
+ if (origin.id() == id("x")) {
+ hostClient(response(OK).header("X-Origin-Id", "x").build())
+ } else if (origin.id() == id("y")) {
+ hostClient(response(OK).header("X-Origin-Id", "y").build())
+ } else {
+ hostClient(response(OK).header("X-Origin-Id", "z").build())
+ }
+ }
+ .originStatsFactory(originStatsFactory)
+ .sslContext(sslContext)
+ .connectionPool(connectionPool)
+ .build()
+
+ val client =
+ ReactorBackendServiceClientFactory(environment)
+ .createClient(backendService, originsInventory, CachingOriginStatsFactory(environment.centralisedMetrics()))
+
+ val requestX = get("/some-req").cookies(requestCookie(STICKY_COOKIE, id("x").toString())).build()
+ val requestY = get("/some-req").cookies(requestCookie(STICKY_COOKIE, id("y").toString())).build()
+ val requestZ = get("/some-req").cookies(requestCookie(STICKY_COOKIE, id("z").toString())).build()
+
+ val responseX = Mono.from(client.sendRequest(requestX, context)).block()
+ val responseY = Mono.from(client.sendRequest(requestY, context)).block()
+ val responseZ = Mono.from(client.sendRequest(requestZ, context)).block()
+
+ responseX!!.header("X-Origin-Id").getOrNull() shouldBe "x"
+ responseY!!.header("X-Origin-Id").getOrNull() shouldBe "y"
+ responseZ!!.header("X-Origin-Id").getOrNull() shouldBe "z"
+ }
+ }
+
+ private fun hostClient(response: LiveHttpResponse): ReactorHostHttpClient {
+ val mockClient: ReactorHostHttpClient = mockk()
+ every { mockClient.sendRequest(any(), any()) } returns Flux.just(response)
+ every { mockClient.loadBalancingMetric() } returns LoadBalancingMetric(1)
+ return mockClient
+ }
+
+ companion object {
+ private const val ORIGINS_RESTRICTION_COOKIE = "styx-origins-restriction"
+ private val STICKY_COOKIE = "styx_origin_$GENERIC_APP"
+ }
+}
diff --git a/pom.xml b/pom.xml
index 5e676f445..5e46f2925 100644
--- a/pom.xml
+++ b/pom.xml
@@ -108,7 +108,7 @@
4.1.106.Final
1.0.4
- 2023.0.2
+ 2023.0.3
3.0.5
1.14.11
@@ -131,6 +131,7 @@
2.2
5.10.0
1.13.9
+ 4.12.0
5.10.1
3.0.9
1.17.0
@@ -540,6 +541,20 @@
test
+
+ com.squareup.okhttp3
+ mockwebserver
+ ${okhttp.version}
+ test
+
+
+
+ com.squareup.okhttp3
+ okhttp-tls
+ ${okhttp.version}
+ test
+
+