diff --git a/components/client/pom.xml b/components/client/pom.xml index 33fa8f660..792192974 100644 --- a/components/client/pom.xml +++ b/components/client/pom.xml @@ -32,6 +32,11 @@ linux-aarch_64 + + io.projectreactor.netty + reactor-netty + + com.hotels.styx styx-api @@ -78,16 +83,19 @@ org.hamcrest hamcrest + test org.scalatest scalatest_${scala.version} + test org.mockito mockito-core + test @@ -96,6 +104,24 @@ test + + io.mockk + mockk-jvm + test + + + + com.squareup.okhttp3 + mockwebserver + test + + + + com.squareup.okhttp3 + okhttp-tls + test + + diff --git a/components/client/src/main/kotlin/com/hotels/styx/client/ReactorBackendServiceClient.kt b/components/client/src/main/kotlin/com/hotels/styx/client/ReactorBackendServiceClient.kt new file mode 100644 index 000000000..c8762e9a0 --- /dev/null +++ b/components/client/src/main/kotlin/com/hotels/styx/client/ReactorBackendServiceClient.kt @@ -0,0 +1,330 @@ +/* + Copyright (C) 2013-2024 Expedia Inc. + + Licensed under the Apache License, Version 2.0 (the "License"); + you may not use this file except in compliance with the License. + You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + + Unless required by applicable law or agreed to in writing, software + distributed under the License is distributed on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + See the License for the specific language governing permissions and + limitations under the License. + */ +package com.hotels.styx.client + +import com.hotels.styx.api.HttpHeaderNames.CONTENT_LENGTH +import com.hotels.styx.api.HttpHeaderNames.HOST +import com.hotels.styx.api.HttpHeaderNames.TRANSFER_ENCODING +import com.hotels.styx.api.HttpInterceptor +import com.hotels.styx.api.HttpMethod +import com.hotels.styx.api.Id +import com.hotels.styx.api.LiveHttpRequest +import com.hotels.styx.api.LiveHttpResponse +import com.hotels.styx.api.exceptions.NoAvailableHostsException +import com.hotels.styx.api.extension.Origin +import com.hotels.styx.api.extension.RemoteHost +import com.hotels.styx.api.extension.loadbalancing.spi.LoadBalancer +import com.hotels.styx.api.extension.retrypolicy.spi.RetryPolicy +import com.hotels.styx.api.extension.service.StickySessionConfig +import com.hotels.styx.client.StyxHeaderConfig.ORIGIN_ID_DEFAULT +import com.hotels.styx.client.retry.RetryNTimes +import com.hotels.styx.client.stickysession.StickySessionCookie.newStickySessionCookie +import com.hotels.styx.client.stickysession.StickySessionLoadBalancingStrategy +import com.hotels.styx.ext.newRequest +import com.hotels.styx.ext.newResponse +import com.hotels.styx.metrics.CentralisedMetrics +import org.reactivestreams.Publisher +import reactor.core.publisher.Mono +import java.util.Objects.nonNull +import java.util.Optional + +/** + * A configurable HTTP client with integration of Reactor Netty client + */ +class ReactorBackendServiceClient( + private val id: Id, + private val rewriteRuleset: RewriteRuleset, + private val originsRestrictionCookieName: String?, + private val stickySessionConfig: StickySessionConfig, + private val originIdHeader: CharSequence, + private val loadBalancer: LoadBalancer, + private val retryPolicy: RetryPolicy, + private val metrics: CentralisedMetrics, + private val overrideHostHeader: Boolean, +) : BackendServiceClient { + override fun sendRequest( + request: LiveHttpRequest, + context: HttpInterceptor.Context, + ): Publisher = sendRequest(rewriteUrl(request), emptyList(), 0, context) + + private fun sendRequest( + request: LiveHttpRequest, + previousOrigins: List, + attempt: Int, + context: HttpInterceptor.Context, + ): Publisher { + if (attempt >= MAX_RETRY_ATTEMPTS) { + return Mono.error(NoAvailableHostsException(id)) + } + val remoteHost = selectOrigin(request) + return if (remoteHost.isPresent) { + val host = remoteHost.get() + val updatedRequest = shouldOverrideHostHeader(host, request) + val newPreviousOrigins = previousOrigins.toMutableList() + newPreviousOrigins.add(host) + Mono.from(host.hostClient().handle(updatedRequest, context)) + .doOnNext { recordErrorStatusMetrics(it) } + .map { response -> + response.addStickySessionIdentifier(host.origin()) + .removeUnexpectedResponseBody(updatedRequest) + .removeRedundantContentLengthHeader() + .addOriginId(host.id()) + .let { LiveHttpResponse.Builder(it).request(updatedRequest).build() } + } + .onErrorResume { cause -> + val retryContext = RetryPolicyContext(id, attempt + 1, cause, updatedRequest, previousOrigins) + retry(updatedRequest, retryContext, newPreviousOrigins, attempt + 1, cause, context) + } + } else { + val retryContext = RetryPolicyContext(id, attempt + 1, null, request, previousOrigins) + retry(request, retryContext, previousOrigins, attempt + 1, NoAvailableHostsException(id), context) + } + } + + private fun recordErrorStatusMetrics(response: LiveHttpResponse) { + val statusCode = response.status().code() + if (statusCode.isErrorStatus()) { + metrics.proxy.client.errorResponseFromOriginByStatus(statusCode).increment() + } + } + + private fun Int.isErrorStatus() = this >= 400 + + private fun bodyNeedsToBeRemoved( + request: LiveHttpRequest, + response: LiveHttpResponse, + ) = isHeadRequest(request) || isBodilessResponse(response) + + private fun responseWithoutBody(response: LiveHttpResponse) = + response.newResponse { + header(CONTENT_LENGTH, 0) + removeHeader(TRANSFER_ENCODING) + removeBody() + } + + private fun isBodilessResponse(response: LiveHttpResponse): Boolean = + when (val code = response.status().code()) { + 204, 304 -> true + else -> code / 100 == 1 + } + + private fun isHeadRequest(request: LiveHttpRequest): Boolean = request.method() == HttpMethod.HEAD + + private fun shouldOverrideHostHeader( + host: RemoteHost, + request: LiveHttpRequest, + ): LiveHttpRequest = + if (overrideHostHeader && !host.origin().host().isNullOrBlank()) { + request.newRequest { header(HOST, host.origin().host()) } + } else { + request + } + + private fun LiveHttpResponse.addOriginId(originId: Id): LiveHttpResponse = + newResponse { + header(originIdHeader, originId) + } + + private fun retry( + request: LiveHttpRequest, + retryContext: RetryPolicyContext, + previousOrigins: List, + attempt: Int, + cause: Throwable, + context: HttpInterceptor.Context, + ): Mono { + val lbContext: LoadBalancer.Preferences = + object : LoadBalancer.Preferences { + override fun preferredOrigins(): Optional = Optional.empty() + + override fun avoidOrigins(): List = previousOrigins.map { it.origin() } + } + return if (retryPolicy.evaluate(retryContext, loadBalancer, lbContext).shouldRetry()) { + Mono.from(sendRequest(request, previousOrigins, attempt, context)) + } else { + Mono.error(cause) + } + } + + private fun LiveHttpResponse.removeUnexpectedResponseBody(request: LiveHttpRequest) = + if (bodyNeedsToBeRemoved(request, this)) { + responseWithoutBody(this) + } else { + this + } + + private fun LiveHttpResponse.removeRedundantContentLengthHeader() = + if (contentLength().isPresent && chunked()) { + newResponse { + removeHeader(CONTENT_LENGTH) + } + } else { + this + } + + private fun selectOrigin(rewrittenRequest: LiveHttpRequest): Optional { + val preferences = + object : LoadBalancer.Preferences { + override fun preferredOrigins(): Optional { + return if (nonNull(originsRestrictionCookieName)) { + rewrittenRequest.cookie(originsRestrictionCookieName) + .map { it.value() } + .or { rewrittenRequest.cookie("styx_origin_$id").map { it.value() } } + } else { + rewrittenRequest.cookie("styx_origin_$id").map { it.value() } + } + } + + override fun avoidOrigins(): List = emptyList() + } + return loadBalancer.choose(preferences) + } + + private fun LiveHttpResponse.addStickySessionIdentifier(origin: Origin): LiveHttpResponse = + if (loadBalancer is StickySessionLoadBalancingStrategy) { + val maxAge = stickySessionConfig.stickySessionTimeoutSeconds() + newResponse { + addCookies(newStickySessionCookie(id, origin.id(), maxAge)) + } + } else { + this + } + + private fun rewriteUrl(request: LiveHttpRequest): LiveHttpRequest = rewriteRuleset.rewrite(request) + + private class RetryPolicyContext( + private val appId: Id, + private val retryCount: Int, + private val lastException: Throwable?, + private val request: LiveHttpRequest, + private val previouslyUsedOrigins: Iterable, + ) : RetryPolicy.Context { + override fun appId(): Id = appId + + override fun currentRetryCount(): Int = retryCount + + override fun lastException(): Optional = Optional.ofNullable(lastException) + + override fun currentRequest(): LiveHttpRequest = request + + override fun previousOrigins(): Iterable = previouslyUsedOrigins + + override fun toString(): String = + buildString { + append("appId", appId) + append(", retryCount", retryCount) + append(", lastException", lastException) + append(", request", request.url()) + append(", previouslyUsedOrigins", previouslyUsedOrigins) + } + + fun hosts(): String = hosts(previouslyUsedOrigins) + + companion object { + private fun hosts(origins: Iterable): String = + origins.asSequence().map { it.origin().hostAndPortString() }.joinToString(", ") + } + } + + override fun toString(): String = + buildString { + append("id", id) + append(", stickySessionConfig", stickySessionConfig) + append(", retryPolicy", retryPolicy) + append(", rewriteRuleset", rewriteRuleset) + append(", loadBalancingStrategy", loadBalancer) + append(", overrideHostHeader", overrideHostHeader) + } + + /** + * A builder for [ReactorBackendServiceClient]. + */ + class Builder(val id: Id) { + private var originStatsFactory: OriginStatsFactory? = null + private var loadBalancer: LoadBalancer? = null + private var metrics: CentralisedMetrics? = null + private var rewriteRuleset: RewriteRuleset = RewriteRuleset(emptyList()) + private var originsRestrictionCookieName: String? = null + private var stickySessionConfig: StickySessionConfig = StickySessionConfig.stickySessionDisabled() + private var originIdHeader: CharSequence = ORIGIN_ID_DEFAULT + private var retryPolicy: RetryPolicy = RetryNTimes(3) + private var overrideHostHeader: Boolean = false + + fun rewriteRules(rewriteRuleset: RewriteRuleset) = + apply { + this.rewriteRuleset = rewriteRuleset + } + + fun originStatsFactory(originStatsFactory: OriginStatsFactory) = + apply { + this.originStatsFactory = originStatsFactory + } + + fun originsRestrictionCookieName(originsRestrictionCookieName: String?) = + apply { + this.originsRestrictionCookieName = originsRestrictionCookieName + } + + fun stickySessionConfig(stickySessionConfig: StickySessionConfig) = + apply { + this.stickySessionConfig = stickySessionConfig + } + + fun originIdHeader(originIdHeader: CharSequence) = + apply { + this.originIdHeader = originIdHeader + } + + fun loadBalancer(loadBalancer: LoadBalancer) = + apply { + this.loadBalancer = loadBalancer + } + + fun retryPolicy(retryPolicy: RetryPolicy) = + apply { + this.retryPolicy = retryPolicy + } + + fun metrics(metrics: CentralisedMetrics) = + apply { + this.metrics = metrics + } + + fun overrideHostHeader(overrideHostHeader: Boolean) = + apply { + this.overrideHostHeader = overrideHostHeader + } + + fun build(): ReactorBackendServiceClient = + ReactorBackendServiceClient( + id, + rewriteRuleset, + originsRestrictionCookieName, + stickySessionConfig, + originIdHeader, + checkNotNull(loadBalancer) { "loadBalancer is required" }, + retryPolicy, + checkNotNull(metrics) { "metrics is required" }, + overrideHostHeader, + ) + } + + companion object { + private const val MAX_RETRY_ATTEMPTS = 3 + + @JvmStatic fun newHttpClientBuilder(backendServiceId: Id): Builder = Builder(backendServiceId) + } +} diff --git a/components/client/src/main/kotlin/com/hotels/styx/client/ReactorConnectionPool.kt b/components/client/src/main/kotlin/com/hotels/styx/client/ReactorConnectionPool.kt new file mode 100644 index 000000000..ff35dd6d9 --- /dev/null +++ b/components/client/src/main/kotlin/com/hotels/styx/client/ReactorConnectionPool.kt @@ -0,0 +1,107 @@ +/* + Copyright (C) 2013-2024 Expedia Inc. + + Licensed under the Apache License, Version 2.0 (the "License"); + you may not use this file except in compliance with the License. + You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + + Unless required by applicable law or agreed to in writing, software + distributed under the License is distributed on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + See the License for the specific language governing permissions and + limitations under the License. + */ +package com.hotels.styx.client + +import com.hotels.styx.api.HttpVersion +import com.hotels.styx.api.HttpVersion.HTTP_1_1 +import com.hotels.styx.api.HttpVersion.HTTP_2 +import com.hotels.styx.api.extension.Origin +import com.hotels.styx.api.extension.service.BackendService +import com.hotels.styx.api.extension.service.ConnectionPoolSettings +import com.hotels.styx.api.extension.service.ConnectionPoolSettings.defaultConnectionPoolSettings +import reactor.netty.http.HttpProtocol +import reactor.netty.http.HttpProtocol.H2 +import reactor.netty.http.HttpProtocol.HTTP11 +import reactor.netty.http.client.Http2AllocationStrategy +import reactor.netty.resources.ConnectionProvider +import java.time.Duration + +class ReactorConnectionPool( + connectionPoolSettings: ConnectionPoolSettings = defaultConnectionPoolSettings(), + private val httpVersion: HttpVersion = HTTP_1_1, + private val backendService: BackendService = BackendService.Builder().build(), +) { + val pendingAcquireTimeoutMillis: Int = connectionPoolSettings.pendingConnectionTimeoutMillis() + val connectTimeoutMillis: Int = connectionPoolSettings.connectTimeoutMillis() + val maxConnections = connectionPoolSettings.maxConnectionsPerHost() + val maxIdleTimeMillis: Int = backendService.responseTimeoutMillis() + val h2MaxConnections = connectionPoolSettings.http2ConnectionPoolSettings().maxConnections ?: DEFAULT_H2_MAX_CONNECTIONS + val h2MinConnections = connectionPoolSettings.http2ConnectionPoolSettings().minConnections ?: DEFAULT_H2_MIN_CONNECTIONS + val h2MaxConcurrentStreams = + connectionPoolSettings.http2ConnectionPoolSettings().maxStreamsPerConnection + ?: DEFAULT_H2_MAX_STREAMS_PER_CONNECTION + val pendingAcquireMaxCount: Int = + if (HTTP_2 == httpVersion) { + connectionPoolSettings.http2ConnectionPoolSettings().maxPendingStreamsPerHost + ?: DEFAULT_H2_MAX_PENDING_STREAMS_PER_HOST + } else { + connectionPoolSettings.maxPendingConnectionsPerHost() + } + val connectionExpirationSeconds: Long = connectionPoolSettings.connectionExpirationSeconds() + + // TODO: needs some investigation on how inflight connections should be shut down + val disposeTimeoutMillis: Int = backendService.responseTimeoutMillis() + + private val connectionProviderBuilder: ConnectionProvider.Builder = initiateConnectionProvider() + + fun getConnectionProvider(origin: Origin): ConnectionProvider = connectionProviderBuilder.name(origin.id().toString()).build() + + fun supportedHttpProtocols(): Array = + if (HTTP_2 == httpVersion) { + arrayOf(H2, HTTP11) + } else { + arrayOf(HTTP11) + } + + fun isHttp2(): Boolean = HTTP_2 == httpVersion + + private fun initiateConnectionProvider(): ConnectionProvider.Builder { + // Configuration details: https://projectreactor.io/docs/netty/release/reference/index.html#_connection_pool_2 + val builder = + ConnectionProvider.builder(backendService.id().toString()) + .metrics(true) + .pendingAcquireTimeout(Duration.ofMillis(pendingAcquireTimeoutMillis.toLong())) + .disposeTimeout(Duration.ofMillis(disposeTimeoutMillis.toLong())) + .maxIdleTime(Duration.ofMillis(maxIdleTimeMillis.toLong())) + .maxLifeTime(Duration.ofSeconds(connectionExpirationSeconds)) + .evictInBackground(Duration.ofSeconds(DEFAULT_CONNECTION_EVICTION_FREQUENCY_SECONDS)) + .maxConnections(maxConnections) + .pendingAcquireMaxCount(pendingAcquireMaxCount) + .lifo() + + if (HTTP_2 == httpVersion) { + // Increasing the number of max/min connections could result in + // unnecessary connection creation and cause overheads on concurrency + val allocationStrategy = + Http2AllocationStrategy.builder() + .maxConcurrentStreams(h2MaxConcurrentStreams.toLong()) + .maxConnections(h2MaxConnections) + .minConnections(h2MinConnections) + .build() + return builder.allocationStrategy(allocationStrategy) + } + + return builder + } + + companion object { + private const val DEFAULT_H2_MAX_CONNECTIONS = 10 + private const val DEFAULT_H2_MIN_CONNECTIONS = 2 + private const val DEFAULT_H2_MAX_STREAMS_PER_CONNECTION = 1024 + private const val DEFAULT_H2_MAX_PENDING_STREAMS_PER_HOST = 200 + private const val DEFAULT_CONNECTION_EVICTION_FREQUENCY_SECONDS = 30L + } +} diff --git a/components/client/src/main/kotlin/com/hotels/styx/client/ReactorHostHttpClient.kt b/components/client/src/main/kotlin/com/hotels/styx/client/ReactorHostHttpClient.kt new file mode 100644 index 000000000..256563716 --- /dev/null +++ b/components/client/src/main/kotlin/com/hotels/styx/client/ReactorHostHttpClient.kt @@ -0,0 +1,397 @@ +/* + Copyright (C) 2013-2024 Expedia Inc. + + Licensed under the Apache License, Version 2.0 (the "License"); + you may not use this file except in compliance with the License. + You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + + Unless required by applicable law or agreed to in writing, software + distributed under the License is distributed on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + See the License for the specific language governing permissions and + limitations under the License. + */ +package com.hotels.styx.client + +import com.hotels.styx.api.Buffers +import com.hotels.styx.api.ByteStream +import com.hotels.styx.api.HttpHeaderNames.HOST +import com.hotels.styx.api.HttpHeaders +import com.hotels.styx.api.HttpInterceptor +import com.hotels.styx.api.HttpResponseStatus.statusWithCode +import com.hotels.styx.api.HttpVersion.httpVersion +import com.hotels.styx.api.LiveHttpRequest +import com.hotels.styx.api.LiveHttpResponse +import com.hotels.styx.api.exceptions.OriginUnreachableException +import com.hotels.styx.api.exceptions.ResponseTimeoutException +import com.hotels.styx.api.exceptions.TransportLostException +import com.hotels.styx.api.extension.Origin +import com.hotels.styx.api.extension.loadbalancing.spi.LoadBalancingMetric +import com.hotels.styx.client.ReactorHostHttpClient.ErrorType.REQUEST +import com.hotels.styx.client.ReactorHostHttpClient.ErrorType.RESPONSE +import com.hotels.styx.client.applications.OriginStats +import com.hotels.styx.client.connectionpool.LatencyTiming.finishRequestTiming +import com.hotels.styx.client.connectionpool.LatencyTiming.startResponseTiming +import com.hotels.styx.client.connectionpool.MaxPendingConnectionTimeoutException +import com.hotels.styx.client.connectionpool.MaxPendingConnectionsExceededException +import com.hotels.styx.common.logging.HttpRequestMessageLogger +import com.hotels.styx.metrics.CentralisedMetrics +import com.hotels.styx.metrics.Deleter +import com.hotels.styx.metrics.ReactorNettyMeterFilter +import com.hotels.styx.metrics.TimerMetric +import io.netty.channel.Channel +import io.netty.channel.ChannelOption.CONNECT_TIMEOUT_MILLIS +import io.netty.channel.ChannelOption.SO_KEEPALIVE +import io.netty.channel.ChannelOption.TCP_NODELAY +import io.netty.handler.codec.DecoderException +import io.netty.handler.codec.http.HttpMethod +import io.netty.handler.ssl.SslHandshakeTimeoutException +import io.netty.handler.timeout.ReadTimeoutException +import io.netty.resolver.dns.DnsNameResolverException +import org.reactivestreams.Publisher +import org.slf4j.LoggerFactory.getLogger +import reactor.core.publisher.Flux +import reactor.core.publisher.Mono +import reactor.netty.ByteBufFlux +import reactor.netty.Connection +import reactor.netty.Metrics.REGISTRY +import reactor.netty.NettyInbound +import reactor.netty.http.client.HttpClient +import reactor.netty.http.client.HttpClientConfig +import reactor.netty.http.client.HttpClientResponse +import reactor.netty.http.client.PrematureCloseException +import reactor.netty.internal.shaded.reactor.pool.PoolAcquirePendingLimitException +import reactor.netty.internal.shaded.reactor.pool.PoolAcquireTimeoutException +import reactor.netty.resources.ConnectionProvider +import reactor.netty.resources.LoopResources +import reactor.netty.tcp.SslProvider.SslContextSpec +import java.net.SocketException +import java.net.UnknownHostException +import java.time.Duration +import java.util.concurrent.atomic.AtomicInteger +import java.util.function.Consumer +import javax.annotation.concurrent.ThreadSafe + +/** + * A Reactor HTTP Client for proxying to an individual origin host. + */ +@ThreadSafe +class ReactorHostHttpClient private constructor( + private val origin: Origin, + private val connectionPool: ReactorConnectionPool, + private val httpConfig: HttpConfig, + private val h2SslProvider: Consumer?, + private val h11SslHandler: Consumer?, + private val responseTimeoutMillis: Int, + private val httpRequestMessageLogger: HttpRequestMessageLogger?, + private val originStatsFactory: OriginStatsFactory, + private val metrics: CentralisedMetrics, + private val eventLoopGroup: LoopResources, + private val doOnResponse: ((HttpClientResponse, HttpInterceptor.Context?) -> Unit)? = null, +) : HostHttpClient { + private val httpClient: HttpClient + private val ongoingRequestCount: AtomicInteger = AtomicInteger() + private val pendingAcquireCount: AtomicInteger = AtomicInteger() + private val stats: Stats = Stats() + private val connectionProvider: ConnectionProvider = connectionPool.getConnectionProvider(origin) + + @Volatile + private var ongoingRequestsDeleter: Deleter? = null + + @Volatile + private var originStats: OriginStats? = null + + init { + require(h2SslProvider == null || h11SslHandler == null) { + "There can only be one type of SSL context" + } + + REGISTRY.config().meterFilter(ReactorNettyMeterFilter(origin)) + httpClient = HttpClient.create(connectionProvider).init() + httpClient.warmup().block() + } + + override fun sendRequest( + request: LiveHttpRequest, + context: HttpInterceptor.Context?, + ): Publisher = + httpClient.addListeners(request, context) + .sendRequest(request, context) + .addStyxResponseListeners(request) + + override fun close() { + connectionProvider.dispose() + ongoingRequestsDeleter?.delete() + } + + override fun loadBalancingMetric(): LoadBalancingMetric = LoadBalancingMetric(stats.ongoingRequestCount()) + + internal fun configuration(): HttpClientConfig = httpClient.configuration() + + private fun HttpClient.init(): HttpClient = + this.host(origin.host()) + .port(origin.port()) + .option(CONNECT_TIMEOUT_MILLIS, connectionPool.connectTimeoutMillis) + .option(TCP_NODELAY, true) + .option(SO_KEEPALIVE, true) + .protocol(*connectionPool.supportedHttpProtocols()) + .addSslContext() + .compress(httpConfig.compress()) + .httpResponseDecoder { + it.maxInitialLineLength(httpConfig.maxInitialLength()) + .maxHeaderSize(httpConfig.maxHeadersSize()) + .maxChunkSize(httpConfig.maxChunkSize()) + } + .responseTimeout(Duration.ofMillis(responseTimeoutMillis.toLong())) + .metrics(true) { _ -> origin.id().toString() } + .runOn(eventLoopGroup) + .disableRetry(true) + .resolver { + it.cacheMaxTimeToLive(Duration.ofSeconds(DNS_MAX_CACHE_TIME_TO_LIVE_SECONDS)) + .cacheNegativeTimeToLive(Duration.ofSeconds(DNS_NEGATIVE_TIME_TO_LIVE_SECONDS)) + } + .doOnChannelInit { _, _, _ -> + if (ongoingRequestsDeleter == null) { + ongoingRequestsDeleter = + metrics.proxy.client.ongoingRequests(origin) + .register { loadBalancingMetric().ongoingActivities() } + } + } + .doOnConnect { pendingAcquireCount.incrementAndGet() } + .doOnConnected { pendingAcquireCount.decrementAndGet() } + + private fun HttpClient.addSslContext(): HttpClient = + if (connectionPool.isHttp2() && h2SslProvider != null) { + secure(h2SslProvider) + } else if (!connectionPool.isHttp2() && h11SslHandler != null) { + doOnChannelInit { _, channel, _ -> + h11SslHandler.accept(channel) + } + } else { + this + } + + private fun HttpClient.addListeners( + request: LiveHttpRequest, + context: HttpInterceptor.Context?, + ): HttpClient { + var requestLatencyTiming: TimerMetric.Stopper? = null + var timeToFirstByteTiming: TimerMetric.Stopper? = null + + if (originStats == null) { + originStats = originStatsFactory.originStats(origin) + } + + return this + .doOnRequest { _, _ -> + httpRequestMessageLogger?.logRequest(request, origin) + requestLatencyTiming = originStats!!.requestLatencyTimer().startTiming() + timeToFirstByteTiming = originStats!!.timeToFirstByteTimer().startTiming() + } + .doAfterRequest { _, _ -> + // Request timing started at Styx Server first receiving requests + finishRequestTiming(context) + requestLatencyTiming?.stop() + } + .doOnRequestError { _, throwable -> + logError(REQUEST, request, throwable) + requestLatencyTiming?.stop() + } + .doOnResponse { response, _ -> + updateHttpResponseCounters(originStats!!, response.status().code()) + timeToFirstByteTiming?.stop() + doOnResponse?.invoke(response, context) + } + .doAfterResponseSuccess { _, _ -> + // Response timing stopping at Styx Server returning response to end users + startResponseTiming(metrics, context) + } + .doOnResponseError { _, throwable -> + logError(RESPONSE, request, throwable) + timeToFirstByteTiming?.stop() + } + } + + private fun HttpClient.sendRequest( + request: LiveHttpRequest, + context: HttpInterceptor.Context?, + ): Mono { + context?.add(ORIGINID_CONTEXT_KEY, origin.id()) + return this + .headers { + request.headers().forEach { name, value -> + it.add(name, value) + } + if (!request.header(HOST).isPresent) { + it.add(HOST, origin.hostAndPortString()) + } + } + .request(HttpMethod(request.method().name())) + .uri(request.url().toString()) + .send( + ByteBufFlux.fromInbound( + Flux.from(request.body()) + .map(Buffers::toByteBuf) + .flatMapSequential { Flux.just(it) }, + ).doOnCancel { originStats?.requestCancelled() }, + ) + .responseConnection { res: HttpClientResponse, conn: Connection -> + toStyxResponse(res, conn.inbound()) + } + .single() + .onErrorMap { toStyxExceptions(it) } + } + + private fun toStyxResponse( + response: HttpClientResponse, + nettyInbound: NettyInbound, + ): Mono { + val headersBuilder = + HttpHeaders.Builder().apply { + response.responseHeaders().forEach { add(it.key, it.value) } + } + val body = toStyxByteStream(nettyInbound) + return Mono.just( + LiveHttpResponse.Builder() + .status(statusWithCode(response.status().code(), response.status().toString())) + .version(httpVersion(response.version().text())) + .headers(headersBuilder.build()) + .body(body) + .build(), + ) + } + + private fun toStyxByteStream(nettyInbound: NettyInbound): ByteStream = + ByteStream( + nettyInbound + .receive() + .retain() + .map(Buffers::fromByteBuf) + .doOnCancel { originStats?.requestCancelled() } + .onErrorMap { toStyxExceptions(it) }, + ) + + private fun Mono.addStyxResponseListeners(request: LiveHttpRequest): Mono = + this.doOnNext { httpRequestMessageLogger?.logResponse(request, it) } + .doOnSubscribe { ongoingRequestCount.incrementAndGet() } + .doFinally { ongoingRequestCount.decrementAndGet() } + + private fun logError( + type: ErrorType, + request: LiveHttpRequest, + throwable: Throwable, + ) = LOGGER.error( + """ + |Error Handling ${type.name} request=$request exceptionClass=${throwable.javaClass.name} exceptionMessage=\"${throwable.message}\" + """.trimMargin(), + ) + + private fun updateHttpResponseCounters( + originStats: OriginStats, + statusCode: Int, + ) { + if (isServerError(statusCode)) { + originStats.requestError() + } else { + originStats.requestSuccess() + } + originStats.responseWithStatusCode(statusCode) + } + + private fun isServerError(status: Int) = status >= 500 + + private fun toStyxExceptions(throwable: Throwable): Throwable = + when (throwable) { + is PoolAcquireTimeoutException -> + MaxPendingConnectionTimeoutException(origin, connectionPool.pendingAcquireTimeoutMillis) + + is PoolAcquirePendingLimitException -> { + pendingAcquireCount.decrementAndGet() + MaxPendingConnectionsExceededException( + origin, + stats.pendingAcquireCount(), + connectionPool.pendingAcquireMaxCount, + ) + } + + is ReadTimeoutException -> ResponseTimeoutException(origin) + + is SslHandshakeTimeoutException, is DnsNameResolverException, is UnknownHostException -> + OriginUnreachableException(origin, throwable.cause) + + is SocketException, is PrematureCloseException -> + TransportLostException(configuration().remoteAddress().get(), origin) + + is DecoderException, is IllegalArgumentException -> + BadHttpResponseException(origin, throwable.cause) + + else -> throwable + } + + inner class Stats { + fun ongoingRequestCount(): Int = ongoingRequestCount.get() + + fun pendingAcquireCount(): Int = pendingAcquireCount.get() + } + + private enum class ErrorType { + REQUEST, + RESPONSE, + } + + /** + * A factory for creating ReactorHostHttpClient instances. + */ + fun interface Factory { + fun create( + origin: Origin, + connectionPool: ReactorConnectionPool, + httpConfig: HttpConfig, + h2SslProvider: Consumer?, + h11SslHandler: Consumer?, + responseTimeoutMillis: Int, + httpRequestMessageLogger: HttpRequestMessageLogger?, + originStatsFactory: OriginStatsFactory, + metrics: CentralisedMetrics, + eventLoopGroup: LoopResources, + doOnResponse: ((HttpClientResponse, HttpInterceptor.Context?) -> Unit)?, + ): ReactorHostHttpClient + } + + companion object : Factory { + private val LOGGER = getLogger(this::class.java) + private const val ORIGINID_CONTEXT_KEY = "styx.originid" + private const val X_HTTP2_STREAM_ID = "x-http2-stream-id" + private const val DNS_MAX_CACHE_TIME_TO_LIVE_SECONDS = 30L + private const val DNS_NEGATIVE_TIME_TO_LIVE_SECONDS = 5L + + override fun create( + origin: Origin, + connectionPool: ReactorConnectionPool, + httpConfig: HttpConfig, + h2SslProvider: Consumer?, + h11SslHandler: Consumer?, + responseTimeoutMillis: Int, + httpRequestMessageLogger: HttpRequestMessageLogger?, + originStatsFactory: OriginStatsFactory, + metrics: CentralisedMetrics, + eventLoopGroup: LoopResources, + doOnResponse: ((HttpClientResponse, HttpInterceptor.Context?) -> Unit)?, + ): ReactorHostHttpClient = + ReactorHostHttpClient( + origin, + connectionPool, + httpConfig, + h2SslProvider, + h11SslHandler, + responseTimeoutMillis, + httpRequestMessageLogger, + originStatsFactory, + metrics, + eventLoopGroup, + doOnResponse, + ) + } +} diff --git a/components/client/src/main/kotlin/com/hotels/styx/client/ReactorOriginsInventory.kt b/components/client/src/main/kotlin/com/hotels/styx/client/ReactorOriginsInventory.kt new file mode 100644 index 000000000..479bb66f7 --- /dev/null +++ b/components/client/src/main/kotlin/com/hotels/styx/client/ReactorOriginsInventory.kt @@ -0,0 +1,252 @@ +/* + Copyright (C) 2013-2024 Expedia Inc. + + Licensed under the Apache License, Version 2.0 (the "License"); + you may not use this file except in compliance with the License. + You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + + Unless required by applicable law or agreed to in writing, software + distributed under the License is distributed on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + See the License for the specific language governing permissions and + limitations under the License. + */ +package com.hotels.styx.client + +import com.google.common.eventbus.EventBus +import com.hotels.styx.api.HttpInterceptor +import com.hotels.styx.api.Id +import com.hotels.styx.api.extension.Origin +import com.hotels.styx.api.extension.service.BackendService +import com.hotels.styx.api.extension.service.TlsSettings +import com.hotels.styx.client.HttpConfig.defaultHttpConfig +import com.hotels.styx.client.healthcheck.OriginHealthStatusMonitor +import com.hotels.styx.client.healthcheck.monitors.NoOriginHealthStatusMonitor +import com.hotels.styx.common.QueueDrainingEventProcessor +import com.hotels.styx.common.StyxFutures.await +import com.hotels.styx.common.logging.HttpRequestMessageLogger +import com.hotels.styx.metrics.CentralisedMetrics +import io.netty.channel.Channel +import io.netty.handler.ssl.SslContext +import reactor.netty.http.client.HttpClientResponse +import reactor.netty.resources.LoopResources +import reactor.netty.tcp.SslProvider +import reactor.netty.tcp.SslProvider.SslContextSpec +import java.net.InetSocketAddress +import java.util.function.Consumer +import javax.annotation.concurrent.ThreadSafe +import javax.net.ssl.SNIHostName + +/** + * A Reactor version inventory of the origins configured for a single application + */ +@ThreadSafe +class ReactorOriginsInventory( + eventBus: EventBus, + appId: Id, + originHealthStatusMonitor: OriginHealthStatusMonitor, + private val metrics: CentralisedMetrics, + private val connectionPool: ReactorConnectionPool, + private val hostClientFactory: ReactorHostHttpClient.Factory, + private val httpConfig: HttpConfig, + private val tlsSettings: TlsSettings?, + private val responseTimeoutMillis: Int, + private val httpRequestMessageLogger: HttpRequestMessageLogger?, + private val originStatsFactory: OriginStatsFactory, + private val eventLoopGroup: LoopResources, + private val sslContext: SslContext?, + private val doOnResponse: ((HttpClientResponse, HttpInterceptor.Context?) -> Unit)?, +) : OriginsInventory(eventBus, originHealthStatusMonitor, appId, metrics) { + override val eventQueue: QueueDrainingEventProcessor = QueueDrainingEventProcessor(this, true) + + override fun Origin.toMonitoredOrigin(): OriginsInventory.MonitoredOrigin = MonitoredOrigin(this) + + override fun registerEvent() { + eventBus.register(this) + } + + override fun addOriginStatusListener() { + originHealthStatusMonitor.addOriginStatusListener(this) + } + + /** + * A builder for [ReactorOriginsInventory]. + */ + class Builder(val appId: Id) { + private var originHealthMonitor: OriginHealthStatusMonitor = NoOriginHealthStatusMonitor() + private var metrics: CentralisedMetrics? = null + private var eventBus = EventBus() + private var connectionPool: ReactorConnectionPool = ReactorConnectionPool() + private var hostClientFactory: ReactorHostHttpClient.Factory = ReactorHostHttpClient + private var initialOrigins: Set = emptySet() + private var httpConfig: HttpConfig = defaultHttpConfig() + private var tlsSettings: TlsSettings? = null + private var responseTimeoutMillis: Int = 60_000 + private var httpRequestMessageLogger: HttpRequestMessageLogger? = null + private var originStatsFactory: OriginStatsFactory? = null + private var eventLoopGroup: LoopResources = LoopResources.create("$appId-client") + private var sslContext: SslContext? = null + private var doOnResponse: ((HttpClientResponse, HttpInterceptor.Context?) -> Unit)? = null + + fun metrics(metrics: CentralisedMetrics) = + apply { + this.metrics = metrics + } + + fun connectionPool(connectionPool: ReactorConnectionPool) = + apply { + this.connectionPool = connectionPool + } + + fun hostClientFactory(hostClientFactory: ReactorHostHttpClient.Factory) = + apply { + this.hostClientFactory = hostClientFactory + } + + fun originHealthMonitor(originHealthMonitor: OriginHealthStatusMonitor) = + apply { + this.originHealthMonitor = originHealthMonitor + } + + fun eventBus(eventBus: EventBus) = + apply { + this.eventBus = eventBus + } + + fun initialOrigins(origins: Set) = + apply { + initialOrigins = origins.toSet() + } + + fun httpConfig(httpConfig: HttpConfig) = + apply { + this.httpConfig = httpConfig + } + + fun tlsSettings(tlsSettings: TlsSettings?) = + apply { + this.tlsSettings = tlsSettings + } + + fun responseTimeoutMillis(responseTimeoutMillis: Int) = + apply { + this.responseTimeoutMillis = responseTimeoutMillis + } + + fun httpRequestMessageLogger(httpRequestMessageLogger: HttpRequestMessageLogger?) = + apply { + this.httpRequestMessageLogger = httpRequestMessageLogger + } + + fun originStatsFactory(originStatsFactory: OriginStatsFactory) = + apply { + this.originStatsFactory = originStatsFactory + } + + fun eventLoopGroup(eventLoopGroup: LoopResources) = + apply { + this.eventLoopGroup = eventLoopGroup + } + + fun sslContext(sslContext: SslContext?) = + apply { + this.sslContext = sslContext + } + + fun doOnResponse(doOnResponse: ((HttpClientResponse, HttpInterceptor.Context?) -> Unit)?) = + apply { + this.doOnResponse = doOnResponse + } + + fun build(): ReactorOriginsInventory { + await(originHealthMonitor.start()) + val originsInventory = + ReactorOriginsInventory( + eventBus, + appId, + originHealthMonitor, + checkNotNull(metrics) { "metrics is required" }, + connectionPool, + hostClientFactory, + httpConfig, + tlsSettings, + responseTimeoutMillis, + httpRequestMessageLogger, + originStatsFactory ?: OriginStatsFactory.CachingOriginStatsFactory(metrics), + eventLoopGroup, + sslContext, + doOnResponse, + ) + initialOrigins.takeIf { it.isNotEmpty() }?.apply { + originsInventory.setOrigins(initialOrigins) + } + return originsInventory + } + } + + inner class MonitoredOrigin(origin: Origin) : OriginsInventory.MonitoredOrigin(origin) { + // SNI info must be passed while using HTTP/2 + private val h2SslProvider: Consumer? = + if (sslContext != null && connectionPool.isHttp2()) { + Consumer { sslContextSpec -> + sslContextSpec.sslContext(sslContext) + .serverNames(SNIHostName(tlsSettings?.sniHost ?: origin.host())) + } + } else { + null + } + + // By default, reactor-netty sends the remote hostname as SNI server name. + // This is a workaround to avoid reactor-netty injecting SNI host if tlsSettings.sendSni() is false. + private val h11SslHandler: Consumer? = + if (sslContext != null && !connectionPool.isHttp2()) { + Consumer { channel -> + val sslProviderBuilder = SslProvider.builder().sslContext(sslContext) + + if (tlsSettings?.sendSni() == true) { + sslProviderBuilder + .serverNames(SNIHostName(tlsSettings.sniHost ?: origin.host())) + .build() + .addSslHandler(channel, InetSocketAddress(origin.host(), origin.port()), false) + } else { + sslProviderBuilder + .build() + .addSslHandler(channel, null, false) + } + } + } else { + null + } + + override val hostClient: ReactorHostHttpClient = + hostClientFactory.create( + origin, + connectionPool, + httpConfig, + h2SslProvider, + h11SslHandler, + responseTimeoutMillis, + httpRequestMessageLogger, + originStatsFactory, + metrics, + eventLoopGroup, + doOnResponse, + ) + } + + companion object { + @JvmStatic + fun newOriginsInventoryBuilder(appId: Id): Builder = Builder(appId) + + @JvmStatic + fun newOriginsInventoryBuilder( + metrics: CentralisedMetrics, + backendService: BackendService, + ): Builder = + Builder(backendService.id()) + .metrics(metrics) + .initialOrigins(backendService.origins()) + } +} diff --git a/components/client/src/test/unit/kotlin/com/hotels/styx/client/ReactorBackendServiceClientTest.kt b/components/client/src/test/unit/kotlin/com/hotels/styx/client/ReactorBackendServiceClientTest.kt new file mode 100644 index 000000000..f83aa0de7 --- /dev/null +++ b/components/client/src/test/unit/kotlin/com/hotels/styx/client/ReactorBackendServiceClientTest.kt @@ -0,0 +1,808 @@ +/* + Copyright (C) 2013-2024 Expedia Inc. + + Licensed under the Apache License, Version 2.0 (the "License"); + you may not use this file except in compliance with the License. + You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + + Unless required by applicable law or agreed to in writing, software + distributed under the License is distributed on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + See the License for the specific language governing permissions and + limitations under the License. + */ +package com.hotels.styx.client + +import com.google.common.net.HostAndPort +import com.hotels.styx.api.Eventual +import com.hotels.styx.api.HttpHandler +import com.hotels.styx.api.HttpHeaderNames.CHUNKED +import com.hotels.styx.api.HttpHeaderNames.CONTENT_LENGTH +import com.hotels.styx.api.HttpHeaderNames.HOST +import com.hotels.styx.api.HttpHeaderNames.TRANSFER_ENCODING +import com.hotels.styx.api.HttpInterceptor +import com.hotels.styx.api.HttpResponseStatus.BAD_REQUEST +import com.hotels.styx.api.HttpResponseStatus.INTERNAL_SERVER_ERROR +import com.hotels.styx.api.HttpResponseStatus.NOT_IMPLEMENTED +import com.hotels.styx.api.HttpResponseStatus.OK +import com.hotels.styx.api.HttpResponseStatus.UNAUTHORIZED +import com.hotels.styx.api.Id.GENERIC_APP +import com.hotels.styx.api.LiveHttpRequest +import com.hotels.styx.api.LiveHttpRequest.get +import com.hotels.styx.api.LiveHttpResponse +import com.hotels.styx.api.LiveHttpResponse.response +import com.hotels.styx.api.MeterRegistry +import com.hotels.styx.api.MicrometerRegistry +import com.hotels.styx.api.RequestCookie.requestCookie +import com.hotels.styx.api.exceptions.NoAvailableHostsException +import com.hotels.styx.api.exceptions.OriginUnreachableException +import com.hotels.styx.api.extension.Origin +import com.hotels.styx.api.extension.Origin.newOriginBuilder +import com.hotels.styx.api.extension.RemoteHost +import com.hotels.styx.api.extension.RemoteHost.remoteHost +import com.hotels.styx.api.extension.loadbalancing.spi.LoadBalancer +import com.hotels.styx.api.extension.retrypolicy.spi.RetryPolicy +import com.hotels.styx.api.extension.service.BackendService +import com.hotels.styx.api.extension.service.StickySessionConfig +import com.hotels.styx.client.retry.RetryNTimes +import com.hotels.styx.metrics.CentralisedMetrics +import io.kotest.core.spec.style.StringSpec +import io.kotest.matchers.shouldBe +import io.kotest.matchers.shouldNotBe +import io.kotest.matchers.types.shouldBeInstanceOf +import io.micrometer.core.instrument.simple.SimpleMeterRegistry +import io.mockk.Called +import io.mockk.CapturingSlot +import io.mockk.every +import io.mockk.mockk +import io.mockk.slot +import io.mockk.verify +import io.mockk.verifyOrder +import org.reactivestreams.Publisher +import reactor.core.publisher.Mono +import reactor.test.StepVerifier +import java.util.Optional +import kotlin.jvm.optionals.getOrNull + +class ReactorBackendServiceClientTest : StringSpec() { + private val context: HttpInterceptor.Context = mockk() + private val backendService: BackendService = + backendBuilderWithOrigins(SOME_ORIGIN.port()) + .stickySessionConfig(STICKY_SESSION_CONFIG) + .build() + private lateinit var meterRegistry: MeterRegistry + private lateinit var metrics: CentralisedMetrics + + init { + beforeTest { + meterRegistry = MicrometerRegistry(SimpleMeterRegistry()) + metrics = CentralisedMetrics(meterRegistry) + } + + "sendRequest routes the request to host selected by load balancer" { + val hostClient = mockHostClient(Mono.just(response(OK).build())) + + val backendServiceClient = + ReactorBackendServiceClient( + id = backendService.id(), + rewriteRuleset = RewriteRuleset(backendService.rewrites()), + originsRestrictionCookieName = null, + stickySessionConfig = backendService.stickySessionConfig(), + originIdHeader = StyxHeaderConfig.ORIGIN_ID_DEFAULT, + loadBalancer = + mockLoadBalancer( + Optional.of( + remoteHost( + SOME_ORIGIN, + toHandler(hostClient), + hostClient, + ), + ), + ), + retryPolicy = DEFAULT_RETRY_POLICY, + metrics = metrics, + overrideHostHeader = backendService.isOverrideHostHeader(), + ) + + val response = Mono.from(backendServiceClient.sendRequest(SOME_REQ, context)).block() + + response!!.status() shouldBe OK + verify { hostClient.sendRequest(SOME_REQ, any()) } + } + + "constructs retry context when load balancer does not find available origins" { + val retryContextSlot = slot() + val loadBalancerSlot = slot() + val lbPreferencesSlot = slot() + val retryPolicy = + mockRetryPolicy( + retryContextSlot, + loadBalancerSlot, + lbPreferencesSlot, + true, + true, + true, + ) + val hostClient = mockHostClient(Mono.just(response(OK).build())) + + val backendServiceClient = + ReactorBackendServiceClient( + id = backendService.id(), + rewriteRuleset = RewriteRuleset(backendService.rewrites()), + originsRestrictionCookieName = null, + stickySessionConfig = backendService.stickySessionConfig(), + originIdHeader = StyxHeaderConfig.ORIGIN_ID_DEFAULT, + loadBalancer = + mockLoadBalancer( + Optional.empty(), + Optional.of(remoteHost(SOME_ORIGIN, toHandler(hostClient), hostClient)), + ), + retryPolicy = retryPolicy, + metrics = metrics, + overrideHostHeader = backendService.isOverrideHostHeader(), + ) + + val response = Mono.from(backendServiceClient.sendRequest(SOME_REQ, context)).block() + + retryContextSlot.isCaptured shouldBe true + retryContextSlot.captured.appId() shouldBe backendService.id() + retryContextSlot.captured.currentRetryCount() shouldBe 1 + retryContextSlot.captured.lastException() shouldBe Optional.empty() + + loadBalancerSlot.isCaptured shouldBe true + loadBalancerSlot.captured shouldNotBe null + + lbPreferencesSlot.isCaptured shouldBe true + lbPreferencesSlot.captured.avoidOrigins() shouldBe emptyList() + lbPreferencesSlot.captured.preferredOrigins() shouldBe Optional.empty() + + response.status() shouldBe OK + } + + "retries when retry policy tells to retry" { + val retryContextSlot = slot() + val loadBalancerSlot = slot() + val lbPreferencesSlot = slot() + val retryPolicy = + mockRetryPolicy( + retryContextSlot, + loadBalancerSlot, + lbPreferencesSlot, + true, + false, + ) + + val hostClient1 = + mockHostClient( + Mono.error( + OriginUnreachableException(ORIGIN_1, RuntimeException("An error occurred")), + ), + ) + val hostClient2 = mockHostClient(Mono.just(response(OK).build())) + + val backendServiceClient = + ReactorBackendServiceClient( + id = backendService.id(), + rewriteRuleset = RewriteRuleset(backendService.rewrites()), + originsRestrictionCookieName = null, + stickySessionConfig = backendService.stickySessionConfig(), + originIdHeader = StyxHeaderConfig.ORIGIN_ID_DEFAULT, + loadBalancer = + mockLoadBalancer( + Optional.of(remoteHost(ORIGIN_1, toHandler(hostClient1), hostClient1)), + Optional.of(remoteHost(ORIGIN_2, toHandler(hostClient2), hostClient2)), + ), + retryPolicy = retryPolicy, + metrics = metrics, + overrideHostHeader = backendService.isOverrideHostHeader(), + ) + + val response = Mono.from(backendServiceClient.sendRequest(SOME_REQ, context)).block() + + retryContextSlot.isCaptured shouldBe true + retryContextSlot.captured.appId() shouldBe backendService.id() + retryContextSlot.captured.currentRetryCount() shouldBe 1 + retryContextSlot.captured.lastException().getOrNull().shouldBeInstanceOf() + + loadBalancerSlot.isCaptured shouldBe true + loadBalancerSlot.captured shouldNotBe null + + lbPreferencesSlot.isCaptured shouldBe true + lbPreferencesSlot.captured.avoidOrigins() shouldBe listOf(ORIGIN_1) + lbPreferencesSlot.captured.preferredOrigins() shouldBe Optional.empty() + + response.status() shouldBe OK + + verifyOrder { + hostClient1.sendRequest(SOME_REQ, any()) + hostClient2.sendRequest(SOME_REQ, any()) + } + } + + "stops retries when retry policy tells to stop" { + val hostClient1 = + mockHostClient( + Mono.error( + OriginUnreachableException(ORIGIN_1, RuntimeException("An error occurred")), + ), + ) + val hostClient2 = + mockHostClient( + Mono.error( + OriginUnreachableException(ORIGIN_2, RuntimeException("An error occurred")), + ), + ) + val hostClient3 = + mockHostClient( + Mono.error( + OriginUnreachableException(ORIGIN_2, RuntimeException("An error occurred")), + ), + ) + + val backendServiceClient = + ReactorBackendServiceClient( + id = backendService.id(), + rewriteRuleset = RewriteRuleset(backendService.rewrites()), + originsRestrictionCookieName = null, + stickySessionConfig = backendService.stickySessionConfig(), + originIdHeader = StyxHeaderConfig.ORIGIN_ID_DEFAULT, + loadBalancer = + mockLoadBalancer( + Optional.of(remoteHost(ORIGIN_1, toHandler(hostClient1), hostClient1)), + Optional.of(remoteHost(ORIGIN_2, toHandler(hostClient2), hostClient2)), + Optional.of(remoteHost(ORIGIN_3, toHandler(hostClient3), hostClient3)), + ), + retryPolicy = mockRetryPolicy(true, false), + metrics = metrics, + overrideHostHeader = backendService.isOverrideHostHeader(), + ) + + StepVerifier.create(backendServiceClient.sendRequest(SOME_REQ, context)) + .verifyError(OriginUnreachableException::class.java) + + verifyOrder { + hostClient1.sendRequest(SOME_REQ, any()) + hostClient2.sendRequest(SOME_REQ, any()) + hostClient3.sendRequest(SOME_REQ, any()) wasNot Called + } + } + + "retries at most 3 times" { + val hostClient1 = + mockHostClient( + Mono.error( + OriginUnreachableException(ORIGIN_1, RuntimeException("An error occurred")), + ), + ) + val hostClient2 = + mockHostClient( + Mono.error( + OriginUnreachableException(ORIGIN_2, RuntimeException("An error occurred")), + ), + ) + val hostClient3 = + mockHostClient( + Mono.error( + OriginUnreachableException(ORIGIN_3, RuntimeException("An error occurred")), + ), + ) + val hostClient4 = + mockHostClient( + Mono.error( + OriginUnreachableException(ORIGIN_4, RuntimeException("An error occurred")), + ), + ) + + val backendServiceClient = + ReactorBackendServiceClient( + id = backendService.id(), + rewriteRuleset = RewriteRuleset(backendService.rewrites()), + originsRestrictionCookieName = null, + stickySessionConfig = backendService.stickySessionConfig(), + originIdHeader = StyxHeaderConfig.ORIGIN_ID_DEFAULT, + loadBalancer = + mockLoadBalancer( + Optional.of(remoteHost(ORIGIN_1, toHandler(hostClient1), hostClient1)), + Optional.of(remoteHost(ORIGIN_2, toHandler(hostClient2), hostClient2)), + Optional.of(remoteHost(ORIGIN_3, toHandler(hostClient3), hostClient3)), + Optional.of(remoteHost(ORIGIN_4, toHandler(hostClient4), hostClient4)), + ), + retryPolicy = mockRetryPolicy(true, true, true, true), + metrics = metrics, + overrideHostHeader = backendService.isOverrideHostHeader(), + ) + + StepVerifier.create(backendServiceClient.sendRequest(SOME_REQ, context)) + .verifyError(NoAvailableHostsException::class.java) + + verifyOrder { + hostClient1.sendRequest(SOME_REQ, any()) + hostClient2.sendRequest(SOME_REQ, any()) + hostClient3.sendRequest(SOME_REQ, any()) + hostClient4.sendRequest(SOME_REQ, any()) wasNot Called + } + } + + "increments response status metrics for bad response" { + val hostClient = mockHostClient(Mono.just(response(BAD_REQUEST).build())) + + val backendServiceClient = + ReactorBackendServiceClient( + id = backendService.id(), + rewriteRuleset = RewriteRuleset(backendService.rewrites()), + originsRestrictionCookieName = null, + stickySessionConfig = backendService.stickySessionConfig(), + originIdHeader = StyxHeaderConfig.ORIGIN_ID_DEFAULT, + loadBalancer = + mockLoadBalancer( + Optional.of(remoteHost(SOME_ORIGIN, toHandler(hostClient), hostClient)), + ), + retryPolicy = DEFAULT_RETRY_POLICY, + metrics = metrics, + overrideHostHeader = backendService.isOverrideHostHeader(), + ) + + val response = Mono.from(backendServiceClient.sendRequest(SOME_REQ, context)).block() + + response.status() shouldBe BAD_REQUEST + verify { hostClient.sendRequest(SOME_REQ, any()) } + meterRegistry.get("proxy.client.responseCode.errorStatus").tag("statusCode", "400").counter() shouldNotBe null + } + + "increments response status metrics for 401" { + val hostClient = mockHostClient(Mono.just(response(UNAUTHORIZED).build())) + + val backendServiceClient = + ReactorBackendServiceClient( + id = backendService.id(), + rewriteRuleset = RewriteRuleset(backendService.rewrites()), + originsRestrictionCookieName = null, + stickySessionConfig = backendService.stickySessionConfig(), + originIdHeader = StyxHeaderConfig.ORIGIN_ID_DEFAULT, + loadBalancer = + mockLoadBalancer( + Optional.of(remoteHost(SOME_ORIGIN, toHandler(hostClient), hostClient)), + ), + retryPolicy = DEFAULT_RETRY_POLICY, + metrics = metrics, + overrideHostHeader = backendService.isOverrideHostHeader(), + ) + + val response = Mono.from(backendServiceClient.sendRequest(SOME_REQ, context)).block() + + response.status() shouldBe UNAUTHORIZED + verify { hostClient.sendRequest(SOME_REQ, any()) } + meterRegistry.get("proxy.client.responseCode.errorStatus").tag("statusCode", "401").counter() shouldNotBe null + } + + "increments response status metrics for 500" { + val hostClient = mockHostClient(Mono.just(response(INTERNAL_SERVER_ERROR).build())) + + val backendServiceClient = + ReactorBackendServiceClient( + id = backendService.id(), + rewriteRuleset = RewriteRuleset(backendService.rewrites()), + originsRestrictionCookieName = null, + stickySessionConfig = backendService.stickySessionConfig(), + originIdHeader = StyxHeaderConfig.ORIGIN_ID_DEFAULT, + loadBalancer = + mockLoadBalancer( + Optional.of(remoteHost(SOME_ORIGIN, toHandler(hostClient), hostClient)), + ), + retryPolicy = DEFAULT_RETRY_POLICY, + metrics = metrics, + overrideHostHeader = backendService.isOverrideHostHeader(), + ) + + val response = Mono.from(backendServiceClient.sendRequest(SOME_REQ, context)).block() + + response.status() shouldBe INTERNAL_SERVER_ERROR + verify { hostClient.sendRequest(SOME_REQ, any()) } + meterRegistry.get("proxy.client.responseCode.errorStatus").tag("statusCode", "500").counter() shouldNotBe null + } + + "increments response status metrics for 501" { + val hostClient = mockHostClient(Mono.just(response(NOT_IMPLEMENTED).build())) + + val backendServiceClient = + ReactorBackendServiceClient( + id = backendService.id(), + rewriteRuleset = RewriteRuleset(backendService.rewrites()), + originsRestrictionCookieName = null, + stickySessionConfig = backendService.stickySessionConfig(), + originIdHeader = StyxHeaderConfig.ORIGIN_ID_DEFAULT, + loadBalancer = + mockLoadBalancer( + Optional.of(remoteHost(SOME_ORIGIN, toHandler(hostClient), hostClient)), + ), + retryPolicy = DEFAULT_RETRY_POLICY, + metrics = metrics, + overrideHostHeader = backendService.isOverrideHostHeader(), + ) + + val response = Mono.from(backendServiceClient.sendRequest(SOME_REQ, context)).block() + + response.status() shouldBe NOT_IMPLEMENTED + verify { hostClient.sendRequest(SOME_REQ, any()) } + meterRegistry.get("proxy.client.responseCode.errorStatus").tag("statusCode", "501").counter() shouldNotBe null + } + + "removes bad content length" { + val hostClient = + mockHostClient( + Mono.just( + response(OK) + .addHeader(CONTENT_LENGTH, 50) + .addHeader(TRANSFER_ENCODING, CHUNKED) + .build(), + ), + ) + + val backendServiceClient = + ReactorBackendServiceClient( + id = backendService.id(), + rewriteRuleset = RewriteRuleset(backendService.rewrites()), + originsRestrictionCookieName = null, + stickySessionConfig = backendService.stickySessionConfig(), + originIdHeader = StyxHeaderConfig.ORIGIN_ID_DEFAULT, + loadBalancer = + mockLoadBalancer( + Optional.of(remoteHost(SOME_ORIGIN, toHandler(hostClient), hostClient)), + ), + retryPolicy = DEFAULT_RETRY_POLICY, + metrics = metrics, + overrideHostHeader = backendService.isOverrideHostHeader(), + ) + + val response = Mono.from(backendServiceClient.sendRequest(SOME_REQ, context)).block() + + response.status() shouldBe OK + response.contentLength().isPresent shouldBe false + response.header(TRANSFER_ENCODING).get() shouldBe "chunked" + } + + "prefers sticky origins" { + val lbPreferencesSlot = slot() + val origin = originWithId("localhost:234", "App-X", "Origin-Y") + val hostClient = mockHostClient(Mono.just(response(OK).build())) + + val backendServiceClient = + ReactorBackendServiceClient( + id = backendService.id(), + rewriteRuleset = RewriteRuleset(backendService.rewrites()), + originsRestrictionCookieName = null, + stickySessionConfig = backendService.stickySessionConfig(), + originIdHeader = StyxHeaderConfig.ORIGIN_ID_DEFAULT, + loadBalancer = + mockLoadBalancer( + lbPreferencesSlot, + Optional.of(remoteHost(origin, toHandler(hostClient), hostClient)), + ), + retryPolicy = DEFAULT_RETRY_POLICY, + metrics = metrics, + overrideHostHeader = backendService.isOverrideHostHeader(), + ) + + val response = + Mono.from( + backendServiceClient + .sendRequest( + get("/foo") + .cookies(requestCookie("styx_origin_$GENERIC_APP", "Origin-Y")) + .build(), + context, + ), + ).block() + + response.status() shouldBe OK + + lbPreferencesSlot.isCaptured shouldBe true + lbPreferencesSlot.captured.preferredOrigins() shouldBe Optional.of("Origin-Y") + } + + "prefers restricted origins" { + val lbPreferencesSlot = slot() + val origin = originWithId("localhost:234", "App-X", "Origin-Y") + val hostClient = mockHostClient(Mono.just(response(OK).build())) + + val backendServiceClient = + ReactorBackendServiceClient( + id = backendService.id(), + rewriteRuleset = RewriteRuleset(backendService.rewrites()), + originsRestrictionCookieName = "restrictedOrigin", + stickySessionConfig = backendService.stickySessionConfig(), + originIdHeader = StyxHeaderConfig.ORIGIN_ID_DEFAULT, + loadBalancer = + mockLoadBalancer( + lbPreferencesSlot, + Optional.of(remoteHost(origin, toHandler(hostClient), hostClient)), + ), + retryPolicy = DEFAULT_RETRY_POLICY, + metrics = metrics, + overrideHostHeader = backendService.isOverrideHostHeader(), + ) + + val response = + Mono.from( + backendServiceClient + .sendRequest( + get("/foo") + .cookies(requestCookie("restrictedOrigin", "Origin-Y")) + .build(), + context, + ), + ).block() + + response.status() shouldBe OK + + lbPreferencesSlot.isCaptured shouldBe true + lbPreferencesSlot.captured.preferredOrigins() shouldBe Optional.of("Origin-Y") + } + + "prefers restricted origins over sticky origins when both are configured" { + val lbPreferencesSlot = slot() + val origin = originWithId("localhost:234", "App-X", "Origin-Y") + val hostClient = mockHostClient(Mono.just(response(OK).build())) + + val backendServiceClient = + ReactorBackendServiceClient( + id = backendService.id(), + rewriteRuleset = RewriteRuleset(backendService.rewrites()), + originsRestrictionCookieName = "restrictedOrigin", + stickySessionConfig = backendService.stickySessionConfig(), + originIdHeader = StyxHeaderConfig.ORIGIN_ID_DEFAULT, + loadBalancer = + mockLoadBalancer( + lbPreferencesSlot, + Optional.of(remoteHost(origin, toHandler(hostClient), hostClient)), + ), + retryPolicy = DEFAULT_RETRY_POLICY, + metrics = metrics, + overrideHostHeader = backendService.isOverrideHostHeader(), + ) + + val response = + Mono.from( + backendServiceClient + .sendRequest( + get("/foo") + .cookies( + requestCookie("restrictedOrigin", "Origin-Y"), + requestCookie("styx_origin_$GENERIC_APP", "Origin-X"), + ) + .build(), + context, + ), + ).block() + + response.status() shouldBe OK + + lbPreferencesSlot.isCaptured shouldBe true + lbPreferencesSlot.captured.preferredOrigins() shouldBe Optional.of("Origin-Y") + } + + "host header is not over written when overrideHostHeader is false" { + val hostClient = mockHostClient(Mono.just(response(OK).build())) + val origin = newOriginBuilder(INCOMING_HOSTNAME, 9090).applicationId("app").build() + val httpHandler: HttpHandler = mockk() + + every { httpHandler.handle(any(), any()) } returns Eventual.of(TEST_RESPONSE) + + val backendServiceClient = + ReactorBackendServiceClient( + id = backendService.id(), + rewriteRuleset = RewriteRuleset(backendService.rewrites()), + originsRestrictionCookieName = "someCookie", + stickySessionConfig = backendService.stickySessionConfig(), + originIdHeader = StyxHeaderConfig.ORIGIN_ID_DEFAULT, + loadBalancer = + mockLoadBalancer( + Optional.of(remoteHost(origin, httpHandler, hostClient)), + ), + retryPolicy = DEFAULT_RETRY_POLICY, + metrics = metrics, + overrideHostHeader = false, + ) + + Mono.from(backendServiceClient.sendRequest(TEST_REQUEST, context)).block() + + verify { httpHandler.handle(TEST_REQUEST, context) } + } + + "host header is not over written when overrideHostHeader is true" { + val hostClient = mockHostClient(Mono.just(response(OK).build())) + val origin = newOriginBuilder(UPDATED_HOSTNAME, 9090).applicationId("app").build() + val httpHandler: HttpHandler = mockk() + val updatedRequestSlot = slot() + + every { httpHandler.handle(capture(updatedRequestSlot), any()) } returns Eventual.of(TEST_RESPONSE) + + val backendServiceClient = + ReactorBackendServiceClient( + id = backendService.id(), + rewriteRuleset = RewriteRuleset(backendService.rewrites()), + originsRestrictionCookieName = "someCookie", + stickySessionConfig = backendService.stickySessionConfig(), + originIdHeader = StyxHeaderConfig.ORIGIN_ID_DEFAULT, + loadBalancer = + mockLoadBalancer( + Optional.of(remoteHost(origin, httpHandler, hostClient)), + ), + retryPolicy = DEFAULT_RETRY_POLICY, + metrics = metrics, + overrideHostHeader = true, + ) + + Mono.from(backendServiceClient.sendRequest(TEST_REQUEST, context)).block() + + updatedRequestSlot.captured.header(HOST).getOrNull() shouldBe UPDATED_HOSTNAME + } + + "original requests is present in response when overrideHostHeader is false" { + val hostClient = mockHostClient(Mono.just(response(OK).build())) + val origin = newOriginBuilder(INCOMING_HOSTNAME, 9090).applicationId("app").build() + val httpHandler: HttpHandler = mockk() + + every { httpHandler.handle(any(), any()) } returns Eventual.of(TEST_RESPONSE) + + val backendServiceClient = + ReactorBackendServiceClient( + id = backendService.id(), + rewriteRuleset = RewriteRuleset(backendService.rewrites()), + originsRestrictionCookieName = "someCookie", + stickySessionConfig = backendService.stickySessionConfig(), + originIdHeader = StyxHeaderConfig.ORIGIN_ID_DEFAULT, + loadBalancer = + mockLoadBalancer( + Optional.of(remoteHost(origin, httpHandler, hostClient)), + ), + retryPolicy = DEFAULT_RETRY_POLICY, + metrics = metrics, + overrideHostHeader = false, + ) + + StepVerifier.create(backendServiceClient.sendRequest(TEST_REQUEST, context)) + .expectNextMatches { response -> + response.request().header(HOST).map { it == INCOMING_HOSTNAME }.orElse(false) + } + .verifyComplete() + } + + "original requests is present in response when overrideHostHeader is true" { + val hostClient = mockHostClient(Mono.just(response(OK).build())) + val origin = newOriginBuilder(UPDATED_HOSTNAME, 9090).applicationId("app").build() + val httpHandler: HttpHandler = mockk() + + every { httpHandler.handle(any(), any()) } returns Eventual.of(TEST_RESPONSE) + + val backendServiceClient = + ReactorBackendServiceClient( + id = backendService.id(), + rewriteRuleset = RewriteRuleset(backendService.rewrites()), + originsRestrictionCookieName = "someCookie", + stickySessionConfig = backendService.stickySessionConfig(), + originIdHeader = StyxHeaderConfig.ORIGIN_ID_DEFAULT, + loadBalancer = + mockLoadBalancer( + Optional.of(remoteHost(origin, httpHandler, hostClient)), + ), + retryPolicy = DEFAULT_RETRY_POLICY, + metrics = metrics, + overrideHostHeader = true, + ) + + StepVerifier.create(backendServiceClient.sendRequest(TEST_REQUEST, context)) + .expectNextMatches { response -> + response.request().header(HOST).map { it == UPDATED_HOSTNAME }.orElse(false) + } + .verifyComplete() + } + } + + private fun toHandler(hostClient: ReactorHostHttpClient): HttpHandler = + HttpHandler { request: LiveHttpRequest, ctx: HttpInterceptor.Context? -> + Eventual( + hostClient.sendRequest( + request, + ctx, + ), + ) + } + + private fun mockLoadBalancer(first: Optional): LoadBalancer { + val lbStategy: LoadBalancer = mockk() + every { lbStategy.choose(any()) } returns first + return lbStategy + } + + private fun mockLoadBalancer( + first: Optional, + vararg remoteHostWrappers: Optional, + ): LoadBalancer { + val lbStategy: LoadBalancer = mockk() + every { lbStategy.choose(any()) } returns first andThenMany remoteHostWrappers.toList() + return lbStategy + } + + private fun mockLoadBalancer( + lbPreferencesSlot: CapturingSlot, + first: Optional, + vararg remoteHostWrappers: Optional, + ): LoadBalancer { + val lbStategy: LoadBalancer = mockk() + every { lbStategy.choose(capture(lbPreferencesSlot)) } returns first andThenMany remoteHostWrappers.toList() + return lbStategy + } + + private fun mockRetryPolicy( + first: Boolean, + vararg outcomes: Boolean, + ): RetryPolicy { + val retryPolicy: RetryPolicy = mockk() + val retryOutcome: RetryPolicy.Outcome = mockk() + every { retryOutcome.shouldRetry() } returns first andThenMany outcomes.toList() + val retryOutcomes: List = outcomes.map { retryOutcome }.toList() + every { retryPolicy.evaluate(any(), any(), any()) } returns retryOutcome andThenMany retryOutcomes + return retryPolicy + } + + private fun mockRetryPolicy( + retryContextSlot: CapturingSlot, + loadBalancerSlot: CapturingSlot, + lbPreferencesSlot: CapturingSlot, + first: Boolean, + vararg outcomes: Boolean, + ): RetryPolicy { + val retryPolicy: RetryPolicy = mockk() + val retryOutcome: RetryPolicy.Outcome = mockk() + every { retryOutcome.shouldRetry() } returns first andThenMany outcomes.toList() + val retryOutcomes: List = outcomes.map { retryOutcome }.toList() + every { + retryPolicy.evaluate(capture(retryContextSlot), capture(loadBalancerSlot), capture(lbPreferencesSlot)) + } returns retryOutcome andThenMany retryOutcomes + return retryPolicy + } + + private fun mockHostClient(responsePublisher: Publisher): ReactorHostHttpClient { + val hostClient: ReactorHostHttpClient = mockk() + every { hostClient.sendRequest(any(), any()) } returns responsePublisher + return hostClient + } + + private fun backendBuilderWithOrigins(originPort: Int): BackendService.Builder { + return BackendService.Builder() + .origins(newOriginBuilder("localhost", originPort).build()) + } + + private fun originWithId( + host: String, + appId: String, + originId: String, + ): Origin? { + val hap = HostAndPort.fromString(host) + return newOriginBuilder(hap.host, hap.port) + .applicationId(appId) + .id(originId) + .build() + } + + companion object { + private val SOME_ORIGIN = newOriginBuilder("localhost", 9090).applicationId(GENERIC_APP).build() + private val SOME_REQ = get("/some-req").build() + private val DEFAULT_RETRY_POLICY = RetryNTimes(3) + + private val ORIGIN_1 = newOriginBuilder("localhost", 9091).applicationId("app").id("app-01").build() + private val ORIGIN_2 = newOriginBuilder("localhost", 9092).applicationId("app").id("app-02").build() + private val ORIGIN_3 = newOriginBuilder("localhost", 9093).applicationId("app").id("app-03").build() + private val ORIGIN_4 = newOriginBuilder("localhost", 9094).applicationId("app").id("app-04").build() + + private val STICKY_SESSION_CONFIG = StickySessionConfig.stickySessionDisabled() + + private val TEST_REQUEST = get("/test").header(HOST, "www.expedia.com").build() + private val TEST_RESPONSE = response(OK).build() + private const val INCOMING_HOSTNAME = "www.expedia.com" + private const val UPDATED_HOSTNAME = "host.domain.com" + } +} diff --git a/components/client/src/test/unit/kotlin/com/hotels/styx/client/ReactorConnectionPoolTest.kt b/components/client/src/test/unit/kotlin/com/hotels/styx/client/ReactorConnectionPoolTest.kt new file mode 100644 index 000000000..a0e296514 --- /dev/null +++ b/components/client/src/test/unit/kotlin/com/hotels/styx/client/ReactorConnectionPoolTest.kt @@ -0,0 +1,131 @@ +/* + Copyright (C) 2013-2024 Expedia Inc. + + Licensed under the Apache License, Version 2.0 (the "License"); + you may not use this file except in compliance with the License. + You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + + Unless required by applicable law or agreed to in writing, software + distributed under the License is distributed on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + See the License for the specific language governing permissions and + limitations under the License. + */ +package com.hotels.styx.client + +import com.hotels.styx.api.HttpVersion +import com.hotels.styx.api.extension.Origin +import com.hotels.styx.api.extension.service.BackendService +import com.hotels.styx.api.extension.service.ConnectionPoolSettings +import com.hotels.styx.api.extension.service.Http2ConnectionPoolSettings +import io.kotest.core.spec.style.StringSpec +import io.kotest.matchers.shouldBe +import reactor.netty.http.HttpProtocol.H2 +import reactor.netty.http.HttpProtocol.HTTP11 +import reactor.netty.http.client.Http2AllocationStrategy +import reactor.netty.resources.ConnectionProvider +import java.time.Duration +import java.util.concurrent.TimeUnit.MILLISECONDS + +class ReactorConnectionPoolTest : StringSpec() { + private lateinit var connectionPool: ReactorConnectionPool + + init { + "getConnectionProvider returns a ConnectionProvider for HTTP/1.1" { + connectionPool = + ReactorConnectionPool( + CONNECTION_POOL_SETTINGS, + HttpVersion.HTTP_1_1, + BACKEND_SERVICE, + ) + + val expected = + ConnectionProvider.builder(ORIGIN.id().toString()) + .metrics(true) + .pendingAcquireTimeout(Duration.ofMillis(CONNECTION_POOL_SETTINGS.pendingConnectionTimeoutMillis().toLong())) + .disposeTimeout(Duration.ofMillis(BACKEND_SERVICE.responseTimeoutMillis().toLong())) + .maxIdleTime(Duration.ofMillis(BACKEND_SERVICE.responseTimeoutMillis().toLong())) + .maxLifeTime(Duration.ofSeconds(CONNECTION_POOL_SETTINGS.connectionExpirationSeconds())) + .maxConnections(CONNECTION_POOL_SETTINGS.maxConnectionsPerHost()) + .pendingAcquireMaxCount(CONNECTION_POOL_SETTINGS.maxPendingConnectionsPerHost()) + .lifo() + .build() + + connectionPool.supportedHttpProtocols() shouldBe arrayOf(HTTP11) + connectionPool.pendingAcquireMaxCount shouldBe CONNECTION_POOL_SETTINGS.maxPendingConnectionsPerHost() + connectionPool.connectTimeoutMillis shouldBe CONNECTION_POOL_SETTINGS.connectTimeoutMillis() + connectionPool.maxConnections shouldBe CONNECTION_POOL_SETTINGS.maxConnectionsPerHost() + connectionPool.disposeTimeoutMillis shouldBe BACKEND_SERVICE.responseTimeoutMillis() + connectionPool.maxIdleTimeMillis shouldBe BACKEND_SERVICE.responseTimeoutMillis() + connectionPool.connectionExpirationSeconds shouldBe CONNECTION_POOL_SETTINGS.connectionExpirationSeconds() + assertConnectionProvider(connectionPool.getConnectionProvider(ORIGIN), expected) + } + + "getConnectionProvider returns a ConnectionProvider for HTTP/2" { + connectionPool = + ReactorConnectionPool( + CONNECTION_POOL_SETTINGS, + HttpVersion.HTTP_2, + BACKEND_SERVICE, + ) + + val expected = + ConnectionProvider.builder(ORIGIN.id().toString()) + .metrics(true) + .pendingAcquireTimeout(Duration.ofMillis(CONNECTION_POOL_SETTINGS.pendingConnectionTimeoutMillis().toLong())) + .disposeTimeout(Duration.ofMillis(BACKEND_SERVICE.responseTimeoutMillis().toLong())) + .maxIdleTime(Duration.ofMillis(BACKEND_SERVICE.responseTimeoutMillis().toLong())) + .maxLifeTime(Duration.ofSeconds(CONNECTION_POOL_SETTINGS.connectionExpirationSeconds())) + .maxConnections(CONNECTION_POOL_SETTINGS.maxConnectionsPerHost()) + .pendingAcquireMaxCount(CONNECTION_POOL_SETTINGS.http2ConnectionPoolSettings().maxPendingStreamsPerHost!!) + .lifo() + .allocationStrategy( + Http2AllocationStrategy.builder() + .maxConcurrentStreams(CONNECTION_POOL_SETTINGS.http2ConnectionPoolSettings().maxStreamsPerConnection!!.toLong()) + .maxConnections(CONNECTION_POOL_SETTINGS.http2ConnectionPoolSettings().maxConnections!!) + .minConnections(CONNECTION_POOL_SETTINGS.http2ConnectionPoolSettings().minConnections!!) + .build(), + ) + .build() + + connectionPool.supportedHttpProtocols() shouldBe arrayOf(H2, HTTP11) + connectionPool.pendingAcquireMaxCount shouldBe CONNECTION_POOL_SETTINGS.http2ConnectionPoolSettings().maxPendingStreamsPerHost + connectionPool.connectTimeoutMillis shouldBe CONNECTION_POOL_SETTINGS.connectTimeoutMillis() + connectionPool.maxConnections shouldBe CONNECTION_POOL_SETTINGS.maxConnectionsPerHost() + connectionPool.disposeTimeoutMillis shouldBe BACKEND_SERVICE.responseTimeoutMillis() + connectionPool.maxIdleTimeMillis shouldBe BACKEND_SERVICE.responseTimeoutMillis() + connectionPool.h2MaxConnections shouldBe CONNECTION_POOL_SETTINGS.http2ConnectionPoolSettings().maxConnections + connectionPool.h2MinConnections shouldBe CONNECTION_POOL_SETTINGS.http2ConnectionPoolSettings().minConnections + connectionPool.h2MaxConcurrentStreams shouldBe CONNECTION_POOL_SETTINGS.http2ConnectionPoolSettings().maxStreamsPerConnection + connectionPool.connectionExpirationSeconds shouldBe CONNECTION_POOL_SETTINGS.connectionExpirationSeconds() + assertConnectionProvider(connectionPool.getConnectionProvider(ORIGIN), expected) + } + } + + private fun assertConnectionProvider( + actual: ConnectionProvider, + expected: ConnectionProvider, + ) { + actual.name() shouldBe expected.name() + actual.maxConnections() shouldBe expected.maxConnections() + actual.maxConnectionsPerHost() shouldBe expected.maxConnectionsPerHost() + } + + companion object { + private val ORIGIN = Origin.newOriginBuilder("localhost", 886).build() + private val CONNECTION_POOL_SETTINGS: ConnectionPoolSettings = + ConnectionPoolSettings.Builder() + .connectTimeout(123, MILLISECONDS) + .pendingConnectionTimeout(543, MILLISECONDS) + .maxConnectionsPerHost(10) + .maxPendingConnectionsPerHost(87) + .http2ConnectionPoolSettings(Http2ConnectionPoolSettings(10, 4, 3, 6)) + .build() + private val BACKEND_SERVICE: BackendService = + BackendService.newBackendServiceBuilder() + .responseTimeoutMillis(777) + .build() + } +} diff --git a/components/client/src/test/unit/kotlin/com/hotels/styx/client/ReactorHostHttpClientTest.kt b/components/client/src/test/unit/kotlin/com/hotels/styx/client/ReactorHostHttpClientTest.kt new file mode 100644 index 000000000..28d782a76 --- /dev/null +++ b/components/client/src/test/unit/kotlin/com/hotels/styx/client/ReactorHostHttpClientTest.kt @@ -0,0 +1,744 @@ +/* + Copyright (C) 2013-2024 Expedia Inc. + + Licensed under the Apache License, Version 2.0 (the "License"); + you may not use this file except in compliance with the License. + You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + + Unless required by applicable law or agreed to in writing, software + distributed under the License is distributed on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + See the License for the specific language governing permissions and + limitations under the License. + */ +package com.hotels.styx.client + +import com.hotels.styx.api.Buffers +import com.hotels.styx.api.HttpHeaderNames.CHUNKED +import com.hotels.styx.api.HttpHeaderNames.TRANSFER_ENCODING +import com.hotels.styx.api.HttpInterceptor +import com.hotels.styx.api.HttpRequest +import com.hotels.styx.api.Id.GENERIC_APP +import com.hotels.styx.api.MicrometerRegistry +import com.hotels.styx.api.exceptions.ResponseTimeoutException +import com.hotels.styx.api.extension.Origin +import com.hotels.styx.api.extension.service.TlsSettings +import com.hotels.styx.client.ssl.SslContextFactory +import com.hotels.styx.common.logging.HttpRequestMessageLogger +import com.hotels.styx.metrics.CentralisedMetrics +import io.kotest.assertions.throwables.shouldNotThrowMessage +import io.kotest.assertions.throwables.shouldThrowWithMessage +import io.kotest.core.spec.style.StringSpec +import io.kotest.matchers.shouldBe +import io.micrometer.core.instrument.simple.SimpleMeterRegistry +import io.mockk.clearAllMocks +import io.mockk.every +import io.mockk.mockk +import io.mockk.verify +import io.netty.channel.Channel +import io.netty.handler.ssl.SslContextBuilder +import io.netty.handler.ssl.util.InsecureTrustManagerFactory +import kotlinx.coroutines.repackaged.net.bytebuddy.utility.RandomString +import okhttp3.Protocol +import okhttp3.mockwebserver.MockResponse +import okhttp3.mockwebserver.MockWebServer +import okhttp3.tls.HandshakeCertificates +import okhttp3.tls.HeldCertificate +import reactor.core.publisher.Mono +import reactor.netty.http.Http2SslContextSpec +import reactor.netty.http.HttpProtocol.H2 +import reactor.netty.http.HttpProtocol.HTTP11 +import reactor.netty.http.client.Http2AllocationStrategy +import reactor.netty.resources.ConnectionProvider +import reactor.netty.resources.LoopResources +import reactor.netty.tcp.SslProvider +import reactor.test.StepVerifier +import java.nio.charset.StandardCharsets +import java.time.Duration +import java.util.concurrent.TimeUnit +import java.util.concurrent.atomic.AtomicInteger +import java.util.function.Consumer + +class ReactorHostHttpClientTest : StringSpec() { + private val context: HttpInterceptor.Context = mockk() + private val httpRequestMessageLogger: HttpRequestMessageLogger = mockk(relaxed = true) + private val connectionPool: ReactorConnectionPool = mockk(relaxed = true) + private lateinit var h2MockServerWithAlpn: MockWebServer + private lateinit var h2MockServerWithoutAlpn: MockWebServer + private lateinit var h11MockServer: MockWebServer + private lateinit var originStats: OriginStatsFactory + private lateinit var meterRegistry: MicrometerRegistry + private lateinit var metrics: CentralisedMetrics + + init { + beforeSpec { + val localhostCertificate = + HeldCertificate.Builder() + .addSubjectAlternativeName(ORIGIN_1.host()) + .addSubjectAlternativeName(ORIGIN_2.host()) + .addSubjectAlternativeName(ORIGIN_3.host()) + .build() + val serverCertificates = + HandshakeCertificates.Builder() + .heldCertificate(localhostCertificate) + .build() + + h11MockServer = MockWebServer() + h11MockServer.useHttps(serverCertificates.sslSocketFactory(), false) + h11MockServer.protocols = listOf(Protocol.HTTP_1_1) + h11MockServer.start(ORIGIN_1.port()) + + h2MockServerWithAlpn = MockWebServer() + h2MockServerWithAlpn.useHttps(serverCertificates.sslSocketFactory(), false) + h2MockServerWithAlpn.protocols = listOf(Protocol.HTTP_2, Protocol.HTTP_1_1) + h2MockServerWithAlpn.start(ORIGIN_2.port()) + + h2MockServerWithoutAlpn = MockWebServer() + h2MockServerWithoutAlpn.useHttps(serverCertificates.sslSocketFactory(), false) + h2MockServerWithoutAlpn.protocols = listOf(Protocol.HTTP_2, Protocol.HTTP_1_1) + h2MockServerWithoutAlpn.protocolNegotiationEnabled = false + h2MockServerWithoutAlpn.start(ORIGIN_3.port()) + } + + beforeTest { + meterRegistry = MicrometerRegistry(SimpleMeterRegistry()) + metrics = CentralisedMetrics(meterRegistry) + originStats = OriginStatsFactory.CachingOriginStatsFactory(metrics) + } + + afterTest { clearAllMocks() } + + afterSpec { + h11MockServer.shutdown() + h2MockServerWithAlpn.shutdown() + h2MockServerWithoutAlpn.shutdown() + LOOP_RESOURCES.dispose() + } + + "HttpClient is created with connectionProvider passed" { + every { connectionPool.getConnectionProvider(any()) } returns connectionProvider + every { connectionPool.connectTimeoutMillis } returns 500 + every { connectionPool.supportedHttpProtocols() } returns arrayOf(HTTP11) + + val reactorHostClient = + ReactorHostHttpClient.create( + origin = ORIGIN_1, + connectionPool = connectionPool, + httpConfig = HTTP_CONFIG, + h2SslProvider = null, + h11SslHandler = H11_SSL_HANDLER, + responseTimeoutMillis = 1000, + httpRequestMessageLogger = httpRequestMessageLogger, + originStatsFactory = originStats, + metrics = metrics, + eventLoopGroup = LOOP_RESOURCES, + ) + + reactorHostClient.configuration().connectionProvider() shouldBe connectionProvider + } + + "init sets responseTimeoutMillis" { + every { connectionPool.getConnectionProvider(any()) } returns connectionProvider + every { connectionPool.connectTimeoutMillis } returns 500 + every { connectionPool.supportedHttpProtocols() } returns arrayOf(HTTP11) + + val reactorHostClient = + ReactorHostHttpClient.create( + origin = ORIGIN_1, + connectionPool = connectionPool, + httpConfig = HTTP_CONFIG, + h2SslProvider = null, + h11SslHandler = H11_SSL_HANDLER, + responseTimeoutMillis = 1000, + httpRequestMessageLogger = httpRequestMessageLogger, + originStatsFactory = originStats, + metrics = metrics, + eventLoopGroup = LOOP_RESOURCES, + ) + + reactorHostClient.configuration().responseTimeout()!!.toMillis() shouldBe 1000 + } + + "init sets protocols" { + every { connectionPool.getConnectionProvider(any()) } returns connectionProvider + every { connectionPool.connectTimeoutMillis } returns 500 + every { connectionPool.supportedHttpProtocols() } returns arrayOf(H2, HTTP11) + every { connectionPool.isHttp2() } returns true + + val reactorHostClient = + ReactorHostHttpClient.create( + origin = ORIGIN_2, + connectionPool = connectionPool, + httpConfig = HTTP_CONFIG, + h2SslProvider = H2_SSL_PROVIDER, + h11SslHandler = null, + responseTimeoutMillis = 1000, + httpRequestMessageLogger = httpRequestMessageLogger, + originStatsFactory = originStats, + metrics = metrics, + eventLoopGroup = LOOP_RESOURCES, + ) + + reactorHostClient.configuration().protocols() shouldBe arrayOf(H2, HTTP11) + } + + "init sets decoder configs" { + every { connectionPool.getConnectionProvider(any()) } returns connectionProvider + every { connectionPool.connectTimeoutMillis } returns 500 + every { connectionPool.supportedHttpProtocols() } returns arrayOf(HTTP11) + + val reactorHostClient = + ReactorHostHttpClient.create( + origin = ORIGIN_1, + connectionPool = connectionPool, + httpConfig = HTTP_CONFIG, + h2SslProvider = null, + h11SslHandler = H11_SSL_HANDLER, + responseTimeoutMillis = 1000, + httpRequestMessageLogger = httpRequestMessageLogger, + originStatsFactory = originStats, + metrics = metrics, + eventLoopGroup = LOOP_RESOURCES, + ) + + reactorHostClient.configuration().decoder().maxInitialLineLength() shouldBe HTTP_CONFIG.maxInitialLength() + reactorHostClient.configuration().decoder().maxHeaderSize() shouldBe HTTP_CONFIG.maxHeadersSize() + } + + "init disables default retry" { + every { connectionPool.getConnectionProvider(any()) } returns connectionProvider + every { connectionPool.connectTimeoutMillis } returns 500 + every { connectionPool.supportedHttpProtocols() } returns arrayOf(HTTP11) + + val reactorHostClient = + ReactorHostHttpClient.create( + origin = ORIGIN_1, + connectionPool = connectionPool, + httpConfig = HTTP_CONFIG, + h2SslProvider = null, + h11SslHandler = H11_SSL_HANDLER, + responseTimeoutMillis = 1000, + httpRequestMessageLogger = httpRequestMessageLogger, + originStatsFactory = originStats, + metrics = metrics, + eventLoopGroup = LOOP_RESOURCES, + ) + + reactorHostClient.configuration().isRetryDisabled shouldBe true + } + + "init sets event loop resources" { + every { connectionPool.getConnectionProvider(any()) } returns connectionProvider + every { connectionPool.connectTimeoutMillis } returns 500 + every { connectionPool.supportedHttpProtocols() } returns arrayOf(HTTP11) + + val reactorHostClient = + ReactorHostHttpClient.create( + origin = ORIGIN_1, + connectionPool = connectionPool, + httpConfig = HTTP_CONFIG, + h2SslProvider = null, + h11SslHandler = H11_SSL_HANDLER, + responseTimeoutMillis = 1000, + httpRequestMessageLogger = httpRequestMessageLogger, + originStatsFactory = originStats, + metrics = metrics, + eventLoopGroup = LOOP_RESOURCES, + ) + + reactorHostClient.configuration().loopResources() shouldBe LOOP_RESOURCES + } + + "sendRequest calls an ALPN enabled http2 origin and returns a response using http2" { + val mockResponse = + MockResponse() + .setResponseCode(200) + .addHeader("X-Id", "123") + .setBody("xyz") + h2MockServerWithAlpn.enqueue(mockResponse) + + every { connectionPool.getConnectionProvider(any()) } returns connectionProvider + every { connectionPool.connectTimeoutMillis } returns 500 + every { connectionPool.supportedHttpProtocols() } returns arrayOf(H2, HTTP11) + every { connectionPool.isHttp2() } returns true + + val reactorHostClient = + ReactorHostHttpClient.create( + origin = ORIGIN_2, + connectionPool = connectionPool, + httpConfig = HTTP_CONFIG, + h2SslProvider = H2_SSL_PROVIDER, + h11SslHandler = null, + responseTimeoutMillis = 1000, + httpRequestMessageLogger = httpRequestMessageLogger, + originStatsFactory = originStats, + metrics = metrics, + eventLoopGroup = LOOP_RESOURCES, + ) + + StepVerifier.create(reactorHostClient.sendRequest(mockRequest, context)) + .assertNext { + it.status().code() shouldBe 200 + it.header("X-Id").orElse(null) shouldBe "123" + it.header("x-http2-stream-id").isPresent shouldBe true + } + .verifyComplete() + } + + "sendRequest calls an ALPN disabled http2 origin and returns a response which fallbacks to http1.1" { + val mockResponse = + MockResponse() + .setResponseCode(200) + .addHeader("X-Id", "123") + .setBody("xyz") + h2MockServerWithoutAlpn.enqueue(mockResponse) + + every { connectionPool.getConnectionProvider(any()) } returns connectionProvider + every { connectionPool.connectTimeoutMillis } returns 500 + every { connectionPool.supportedHttpProtocols() } returns arrayOf(H2, HTTP11) + every { connectionPool.isHttp2() } returns true + + val reactorHostClient = + ReactorHostHttpClient.create( + origin = ORIGIN_3, + connectionPool = connectionPool, + httpConfig = HTTP_CONFIG, + h2SslProvider = H2_SSL_PROVIDER, + h11SslHandler = null, + responseTimeoutMillis = 1000, + httpRequestMessageLogger = httpRequestMessageLogger, + originStatsFactory = originStats, + metrics = metrics, + eventLoopGroup = LOOP_RESOURCES, + ) + + StepVerifier.create(reactorHostClient.sendRequest(mockRequest, context)) + .assertNext { + it.status().code() shouldBe 200 + it.header("X-Id").orElse(null) shouldBe "123" + it.header("x-http2-stream-id").isEmpty shouldBe true + } + .verifyComplete() + } + + "sendRequest throws errors when hitting response timeout during a request to an http2 origin" { + val mockResponse = + MockResponse() + .setResponseCode(200) + .addHeader("X-Id", "123") + .setBody("xyz") + .setBodyDelay(5, TimeUnit.SECONDS) + h2MockServerWithAlpn.enqueue(mockResponse) + + every { connectionPool.getConnectionProvider(any()) } returns connectionProvider + every { connectionPool.connectTimeoutMillis } returns 500 + every { connectionPool.supportedHttpProtocols() } returns arrayOf(H2, HTTP11) + every { connectionPool.isHttp2() } returns true + + val reactorHostClient = + ReactorHostHttpClient.create( + origin = ORIGIN_2, + connectionPool = connectionPool, + httpConfig = HTTP_CONFIG, + h2SslProvider = H2_SSL_PROVIDER, + h11SslHandler = null, + responseTimeoutMillis = 1000, + httpRequestMessageLogger = httpRequestMessageLogger, + originStatsFactory = originStats, + metrics = metrics, + eventLoopGroup = LOOP_RESOURCES, + ) + + val response = Mono.from(reactorHostClient.sendRequest(mockRequest, context)).block() + StepVerifier.create(response!!.body()) + .verifyError(ResponseTimeoutException::class.java) + } + + "sendRequest calls an ALPN enabled http2 origin and returns a response using http1.1" { + val mockResponse = + MockResponse() + .setResponseCode(200) + .addHeader("X-Id", "123") + .setBody("xyz") + h2MockServerWithAlpn.enqueue(mockResponse) + + every { connectionPool.getConnectionProvider(any()) } returns connectionProvider + every { connectionPool.connectTimeoutMillis } returns 500 + every { connectionPool.supportedHttpProtocols() } returns arrayOf(HTTP11) + + val reactorHostClient = + ReactorHostHttpClient.create( + origin = ORIGIN_2, + connectionPool = connectionPool, + httpConfig = HTTP_CONFIG, + h2SslProvider = null, + h11SslHandler = H11_SSL_HANDLER, + responseTimeoutMillis = 1000, + httpRequestMessageLogger = httpRequestMessageLogger, + originStatsFactory = originStats, + metrics = metrics, + eventLoopGroup = LOOP_RESOURCES, + ) + + StepVerifier.create(reactorHostClient.sendRequest(mockRequest, context)) + .assertNext { + it.status().code() shouldBe 200 + it.header("X-Id").orElse(null) shouldBe "123" + it.header("x-http2-stream-id").isEmpty shouldBe true + } + .verifyComplete() + } + + "sendRequest calls an ALPN disabled http2 origin and returns a response using http1.1" { + val mockResponse = + MockResponse() + .setResponseCode(200) + .addHeader("X-Id", "123") + .setBody("xyz") + h2MockServerWithoutAlpn.enqueue(mockResponse) + + every { connectionPool.getConnectionProvider(any()) } returns connectionProvider + every { connectionPool.connectTimeoutMillis } returns 500 + every { connectionPool.supportedHttpProtocols() } returns arrayOf(HTTP11) + + val reactorHostClient = + ReactorHostHttpClient.create( + origin = ORIGIN_3, + connectionPool = connectionPool, + httpConfig = HTTP_CONFIG, + h2SslProvider = null, + h11SslHandler = H11_SSL_HANDLER, + responseTimeoutMillis = 1000, + httpRequestMessageLogger = httpRequestMessageLogger, + originStatsFactory = originStats, + metrics = metrics, + eventLoopGroup = LOOP_RESOURCES, + ) + + StepVerifier.create(reactorHostClient.sendRequest(mockRequest, context)) + .assertNext { + it.status().code() shouldBe 200 + it.header("X-Id").orElse(null) shouldBe "123" + it.header("x-http2-stream-id").isEmpty shouldBe true + } + .verifyComplete() + } + + "sendRequest calls an http1.1 origin and returns a response" { + val mockResponse = + MockResponse() + .setResponseCode(200) + .addHeader("X-Id", "123") + .setBody("xyz") + h11MockServer.enqueue(mockResponse) + + every { connectionPool.getConnectionProvider(any()) } returns connectionProvider + every { connectionPool.connectTimeoutMillis } returns 500 + every { connectionPool.supportedHttpProtocols() } returns arrayOf(HTTP11) + + val reactorHostClient = + ReactorHostHttpClient.create( + origin = ORIGIN_1, + connectionPool = connectionPool, + httpConfig = HTTP_CONFIG, + h2SslProvider = null, + h11SslHandler = H11_SSL_HANDLER, + responseTimeoutMillis = 1000, + httpRequestMessageLogger = httpRequestMessageLogger, + originStatsFactory = originStats, + metrics = metrics, + eventLoopGroup = LOOP_RESOURCES, + ) + + StepVerifier.create(reactorHostClient.sendRequest(mockRequest, context)) + .assertNext { + it.status().code() shouldBe 200 + it.header("X-Id").orElse(null) shouldBe "123" + it.header("x-http2-stream-id").isEmpty shouldBe true + } + .verifyComplete() + } + + "sendRequest throws errors when hitting response timeout during a request to an http1.1 origin" { + val mockResponse = + MockResponse() + .setResponseCode(200) + .addHeader("X-Id", "123") + .setBody("xyz") + .setBodyDelay(5, TimeUnit.SECONDS) + h11MockServer.enqueue(mockResponse) + + every { connectionPool.getConnectionProvider(any()) } returns connectionProvider + every { connectionPool.connectTimeoutMillis } returns 500 + every { connectionPool.supportedHttpProtocols() } returns arrayOf(HTTP11) + + val reactorHostClient = + ReactorHostHttpClient.create( + origin = ORIGIN_1, + connectionPool = connectionPool, + httpConfig = HTTP_CONFIG, + h2SslProvider = null, + h11SslHandler = H11_SSL_HANDLER, + responseTimeoutMillis = 1000, + httpRequestMessageLogger = httpRequestMessageLogger, + originStatsFactory = originStats, + metrics = metrics, + eventLoopGroup = LOOP_RESOURCES, + ) + + val response = Mono.from(reactorHostClient.sendRequest(mockRequest, context)).block() + StepVerifier.create(response!!.body()) + .verifyError(ResponseTimeoutException::class.java) + } + + "closes connection pool when the host http client is closed" { + val mockConnectionProvider: ConnectionProvider = mockk(relaxed = true) + + every { connectionPool.getConnectionProvider(any()) } returns mockConnectionProvider + every { connectionPool.connectTimeoutMillis } returns 500 + every { connectionPool.supportedHttpProtocols() } returns arrayOf(HTTP11) + + val reactorHostClient = + ReactorHostHttpClient.create( + origin = mockk(relaxed = true), + connectionPool = connectionPool, + httpConfig = HTTP_CONFIG, + h2SslProvider = H2_SSL_PROVIDER, + h11SslHandler = null, + responseTimeoutMillis = 1000, + httpRequestMessageLogger = mockk(relaxed = true), + originStatsFactory = mockk(relaxed = true), + metrics = mockk(relaxed = true), + eventLoopGroup = mockk(relaxed = true), + ) + + every { connectionPool.supportedHttpProtocols() } returns arrayOf(H2, HTTP11) + every { connectionPool.isHttp2() } returns true + + reactorHostClient.close() + + verify(exactly = 1) { + mockConnectionProvider.dispose() + } + } + + "sendRequest calls an http1.1 origin with POST and returns a response with body" { + val requestBody = RandomString.make(100_000) + val responseBody = RandomString.make(100_000) + val mockResponse = + MockResponse() + .setResponseCode(200) + .addHeader("X-Id", "123") + .setBody(responseBody) + h11MockServer.enqueue(mockResponse) + + var request = + HttpRequest.post("/") + .header(TRANSFER_ENCODING, CHUNKED) + .body(requestBody, StandardCharsets.UTF_8) + .build() + .stream() + + val requestBodySize = AtomicInteger() + request = + request.newBuilder() + .body { body -> body.doOnEach { requestBodySize.addAndGet(it.get()?.size() ?: 0) } } + .build() + + every { connectionPool.getConnectionProvider(any()) } returns connectionProvider + every { connectionPool.connectTimeoutMillis } returns 500 + every { connectionPool.supportedHttpProtocols() } returns arrayOf(HTTP11) + + val reactorHostClient = + ReactorHostHttpClient.create( + origin = ORIGIN_1, + connectionPool = connectionPool, + httpConfig = HTTP_CONFIG, + h2SslProvider = null, + h11SslHandler = H11_SSL_HANDLER, + responseTimeoutMillis = 1000, + httpRequestMessageLogger = httpRequestMessageLogger, + originStatsFactory = originStats, + metrics = metrics, + eventLoopGroup = LOOP_RESOURCES, + ) + + StepVerifier.create(reactorHostClient.sendRequest(request, context)) + .assertNext { + it.status().code() shouldBe 200 + it.header("X-Id").orElse(null) shouldBe "123" + Mono.from(it.body()).map(Buffers::toByteBuf) + .map { chunk -> chunk.toString(StandardCharsets.UTF_8) } + .doOnNext { body -> body shouldBe responseBody } + .subscribe() + } + .verifyComplete() + + requestBodySize.get() shouldBe 100_000 + } + + "SSL handler for HTTP/1.1 is added to channel if h11SslHandler is passed" { + val mockResponse = + MockResponse() + .setResponseCode(200) + .addHeader("X-Id", "123") + .setBody("xyz") + h11MockServer.enqueue(mockResponse) + + every { connectionPool.getConnectionProvider(any()) } returns connectionProvider + every { connectionPool.connectTimeoutMillis } returns 500 + every { connectionPool.supportedHttpProtocols() } returns arrayOf(HTTP11) + + val sslHandler = mockk>() + + val reactorHostClient = + ReactorHostHttpClient.create( + origin = ORIGIN_1, + connectionPool = connectionPool, + httpConfig = HTTP_CONFIG, + h2SslProvider = null, + h11SslHandler = sslHandler, + responseTimeoutMillis = 1000, + httpRequestMessageLogger = httpRequestMessageLogger, + originStatsFactory = originStats, + metrics = metrics, + eventLoopGroup = LOOP_RESOURCES, + ) + + runCatching { + Mono.from(reactorHostClient.sendRequest(mockRequest, context)).block() + } + + verify { + sslHandler.accept(any()) + } + } + + "SSL provider for HTTP/2 is used if h2SslProvider is passed" { + val mockResponse = + MockResponse() + .setResponseCode(200) + .addHeader("X-Id", "123") + .setBody("xyz") + h11MockServer.enqueue(mockResponse) + + every { connectionPool.getConnectionProvider(any()) } returns connectionProvider + every { connectionPool.connectTimeoutMillis } returns 500 + every { connectionPool.supportedHttpProtocols() } returns arrayOf(HTTP11) + + val sslProvider = mockk>(relaxed = true) + + runCatching { + val reactorHostClient = + ReactorHostHttpClient.create( + origin = ORIGIN_2, + connectionPool = connectionPool, + httpConfig = HTTP_CONFIG, + h2SslProvider = sslProvider, + h11SslHandler = null, + responseTimeoutMillis = 1000, + httpRequestMessageLogger = httpRequestMessageLogger, + originStatsFactory = originStats, + metrics = metrics, + eventLoopGroup = LOOP_RESOURCES, + ) + + reactorHostClient.configuration().sslProvider() shouldBe sslProvider.accept(SslProvider.builder()) + } + } + + "Only 1 ssl context can be passed" { + val sslProvider = mockk>() + val sslHandler = mockk>() + + shouldThrowWithMessage("There can only be one type of SSL context") { + ReactorHostHttpClient.create( + origin = ORIGIN_1, + connectionPool = connectionPool, + httpConfig = HTTP_CONFIG, + h2SslProvider = sslProvider, + h11SslHandler = sslHandler, + responseTimeoutMillis = 1000, + httpRequestMessageLogger = httpRequestMessageLogger, + originStatsFactory = originStats, + metrics = metrics, + eventLoopGroup = LOOP_RESOURCES, + ) + } + } + + "Supports http without any sslContext" { + shouldNotThrowMessage("There can only be one type of SSL context") { + ReactorHostHttpClient.create( + origin = ORIGIN_1, + connectionPool = connectionPool, + httpConfig = HTTP_CONFIG, + h2SslProvider = null, + h11SslHandler = null, + responseTimeoutMillis = 1000, + httpRequestMessageLogger = httpRequestMessageLogger, + originStatsFactory = originStats, + metrics = metrics, + eventLoopGroup = LOOP_RESOURCES, + ) + } + } + } + + companion object { + private val ORIGIN_1 = + Origin.newOriginBuilder("localhost", 58887) + .applicationId(GENERIC_APP) + .id("app-01") + .build() + private val ORIGIN_2 = + Origin.newOriginBuilder("localhost", 58888) + .applicationId(GENERIC_APP) + .id("app-01") + .build() + private val ORIGIN_3 = + Origin.newOriginBuilder("localhost", 58889) + .applicationId(GENERIC_APP) + .id("app-01") + .build() + private var mockRequest = HttpRequest.get("/").build().stream() + private val connectionProvider = + ConnectionProvider.builder("app-01") + .allocationStrategy( + Http2AllocationStrategy.builder() + .maxConcurrentStreams(10) + .maxConnections(1) + .build(), + ) + .evictionPredicate { _, connectionMetadata -> + connectionMetadata.lifeTime() >= 300000 || connectionMetadata.idleTime() >= 60000 + } + .pendingAcquireTimeout(Duration.ofMillis(5000)) + .pendingAcquireMaxCount(1000) + .disposeTimeout(Duration.ofMillis(1000)) + .build() + private val LOOP_RESOURCES = LoopResources.create("app-01", 1, 1, true) + private val HTTP_CONFIG = HttpConfig.defaultHttpConfig() + private val H11_SSL_HANDLER: Consumer = + Consumer { channel -> + SslProvider.builder() + .sslContext(SslContextFactory.get(TlsSettings.Builder().build())) + .build() + .addSslHandler(channel, null, false) + } + private val H2_SSL_PROVIDER: Consumer = + Consumer { sslContextSpec -> + sslContextSpec.sslContext( + Http2SslContextSpec.forClient() + .configure { sslContextBuilder: SslContextBuilder -> + sslContextBuilder.trustManager( + InsecureTrustManagerFactory.INSTANCE, + ) + } + .sslContext(), + ) + } + } +} diff --git a/components/client/src/test/unit/kotlin/com/hotels/styx/client/ReactorOriginsInventoryTest.kt b/components/client/src/test/unit/kotlin/com/hotels/styx/client/ReactorOriginsInventoryTest.kt new file mode 100644 index 000000000..b7f7c99f8 --- /dev/null +++ b/components/client/src/test/unit/kotlin/com/hotels/styx/client/ReactorOriginsInventoryTest.kt @@ -0,0 +1,528 @@ +/* + Copyright (C) 2013-2024 Expedia Inc. + + Licensed under the Apache License, Version 2.0 (the "License"); + you may not use this file except in compliance with the License. + You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + + Unless required by applicable law or agreed to in writing, software + distributed under the License is distributed on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + See the License for the specific language governing permissions and + limitations under the License. + */ +package com.hotels.styx.client + +import com.google.common.eventbus.EventBus +import com.hotels.styx.api.Id.GENERIC_APP +import com.hotels.styx.api.Id.id +import com.hotels.styx.api.MeterRegistry +import com.hotels.styx.api.Metrics +import com.hotels.styx.api.MicrometerRegistry +import com.hotels.styx.api.extension.Origin.newOriginBuilder +import com.hotels.styx.api.extension.OriginsChangeListener +import com.hotels.styx.api.extension.OriginsSnapshot +import com.hotels.styx.client.OriginsInventory.OriginState.ACTIVE +import com.hotels.styx.client.OriginsInventory.OriginState.DISABLED +import com.hotels.styx.client.healthcheck.OriginHealthStatusMonitor +import com.hotels.styx.client.origincommands.DisableOrigin +import com.hotels.styx.client.origincommands.EnableOrigin +import com.hotels.styx.metrics.CentralisedMetrics +import io.kotest.core.spec.style.StringSpec +import io.kotest.matchers.shouldBe +import io.micrometer.core.instrument.Gauge +import io.micrometer.core.instrument.Tags +import io.micrometer.core.instrument.simple.SimpleMeterRegistry +import io.mockk.clearAllMocks +import io.mockk.every +import io.mockk.mockk +import io.mockk.verify +import io.netty.handler.ssl.SslContext +import reactor.netty.resources.LoopResources + +class ReactorOriginsInventoryTest : StringSpec() { + private lateinit var inventory: OriginsInventory + private lateinit var meterRegistry: MeterRegistry + private val eventBus: EventBus = mockk(relaxed = true) + private val monitor: OriginHealthStatusMonitor = mockk(relaxed = true) + private val connectionPool: ReactorConnectionPool = mockk(relaxed = true) + private val hostClientFactory: ReactorHostHttpClient.Factory = mockk(relaxed = true) + private val originStatsFactory: OriginStatsFactory = mockk(relaxed = true) + private val eventLoopGroup: LoopResources = mockk() + private val sslContext: SslContext = mockk() + + init { + beforeTest { + meterRegistry = MicrometerRegistry(SimpleMeterRegistry()) + inventory = + ReactorOriginsInventory.Builder(GENERIC_APP) + .eventLoopGroup(eventLoopGroup) + .eventBus(eventBus) + .originHealthMonitor(monitor) + .hostClientFactory(hostClientFactory) + .metrics(CentralisedMetrics(meterRegistry)) + .connectionPool(connectionPool) + .originStatsFactory(originStatsFactory) + .sslContext(sslContext) + .build() + } + + afterTest { + clearAllMocks() + } + + "start monitoring new origins" { + inventory.setOrigins(ORIGIN_1, ORIGIN_2) + + inventory.originCount(ACTIVE) shouldBe 2 + gaugeValue(ORIGIN_1.applicationId().toString(), ORIGIN_1.id().toString()) shouldBe 1.0 + gaugeValue(ORIGIN_2.applicationId().toString(), ORIGIN_2.id().toString()) shouldBe 1.0 + verify(exactly = 1) { + monitor.monitor(setOf(ORIGIN_1)) + monitor.monitor(setOf(ORIGIN_2)) + eventBus.post(any()) + } + } + + "updates on origin port number change" { + val originV1 = + newOriginBuilder("acme.com", 80) + .applicationId(GENERIC_APP) + .id("acme-01") + .build() + val originV2 = + newOriginBuilder("acme.com", 443) + .applicationId(GENERIC_APP) + .id("acme-01") + .build() + + inventory.setOrigins(originV1) + + inventory.originCount(ACTIVE) shouldBe 1 + gaugeValue("generic-app", "acme-01") shouldBe 1.0 + verify(exactly = 1) { + monitor.monitor(setOf(originV1)) + eventBus.post(any()) + } + + inventory.setOrigins(originV2) + + inventory.originCount(ACTIVE) shouldBe 1 + gaugeValue("generic-app", "acme-01") shouldBe 1.0 + verify(exactly = 1) { + monitor.stopMonitoring(setOf(originV1)) + monitor.monitor(setOf(originV2)) + } + verify(exactly = 2) { + eventBus.post(any()) + } + } + + "updates on origin hostname change" { + val originV1 = + newOriginBuilder("acme01.com", 80) + .applicationId(GENERIC_APP) + .id("acme-01") + .build() + val originV2 = + newOriginBuilder("acme02.com", 80) + .applicationId(GENERIC_APP) + .id("acme-01") + .build() + + inventory.setOrigins(originV1) + + inventory.originCount(ACTIVE) shouldBe 1 + gaugeValue("generic-app", "acme-01") shouldBe 1.0 + verify(exactly = 1) { + monitor.monitor(setOf(originV1)) + eventBus.post(any()) + } + + inventory.setOrigins(originV2) + + inventory.originCount(ACTIVE) shouldBe 1 + gaugeValue("generic-app", "acme-01") shouldBe 1.0 + verify(exactly = 1) { + monitor.stopMonitoring(setOf(originV1)) + monitor.monitor(setOf(originV2)) + } + verify(exactly = 2) { + eventBus.post(any()) + } + } + + "stops and restarts monitoring modified origin" { + val originV1 = + newOriginBuilder("acme01.com", 80) + .applicationId(GENERIC_APP) + .id("acme-01") + .build() + val originV2 = + newOriginBuilder("acme02.com", 80) + .applicationId(GENERIC_APP) + .id("acme-01") + .build() + + inventory.setOrigins(originV1) + + verify(exactly = 1) { + monitor.monitor(setOf(originV1)) + } + + inventory.setOrigins(originV2) + + verify(exactly = 1) { + monitor.stopMonitoring(setOf(originV1)) + monitor.monitor(setOf(originV2)) + } + } + + "shuts connection provider on origin change" { + val originV1 = + newOriginBuilder("acme01.com", 80) + .applicationId(GENERIC_APP) + .id("acme-01") + .build() + val originV2 = + newOriginBuilder("acme02.com", 80) + .applicationId(GENERIC_APP) + .id("acme-01") + .build() + + val hostClient1: ReactorHostHttpClient = mockk(relaxed = true) + val hostClient2: ReactorHostHttpClient = mockk(relaxed = true) + every { + hostClientFactory.create( + originV1, any(), any(), any(), any(), any(), + any(), any(), any(), any(), any(), + ) + } returns hostClient1 + every { + hostClientFactory.create( + originV2, any(), any(), any(), any(), any(), + any(), any(), any(), any(), any(), + ) + } returns hostClient2 + + inventory.setOrigins(originV1) + + verify(exactly = 1) { + hostClientFactory.create( + originV1, any(), any(), any(), any(), any(), + any(), any(), any(), any(), any(), + ) + } + + inventory.setOrigins(originV2) + + verify(exactly = 1) { + hostClientFactory.create( + originV2, any(), any(), any(), any(), any(), + any(), any(), any(), any(), any(), + ) + hostClient1.close() + } + } + + "ignores unchanged origins" { + inventory.setOrigins(ORIGIN_1, ORIGIN_2) + + inventory.originCount(ACTIVE) shouldBe 2 + gaugeValue(ORIGIN_1.applicationId().toString(), ORIGIN_1.id().toString()) shouldBe 1.0 + gaugeValue(ORIGIN_2.applicationId().toString(), ORIGIN_2.id().toString()) shouldBe 1.0 + verify(exactly = 1) { + monitor.monitor(setOf(ORIGIN_1)) + monitor.monitor(setOf(ORIGIN_2)) + eventBus.post(any()) + } + + inventory.setOrigins(ORIGIN_1, ORIGIN_2) + inventory.originCount(ACTIVE) shouldBe 2 + verify(exactly = 1) { + monitor.monitor(setOf(ORIGIN_1)) + monitor.monitor(setOf(ORIGIN_2)) + eventBus.post(any()) + } + verify(exactly = 0) { + monitor.stopMonitoring(setOf(ORIGIN_1)) + monitor.stopMonitoring(setOf(ORIGIN_2)) + } + } + + "stop monitoring on origins removal" { + inventory.setOrigins(ORIGIN_1, ORIGIN_2) + + inventory.originCount(ACTIVE) shouldBe 2 + gaugeValue(ORIGIN_1.applicationId().toString(), ORIGIN_1.id().toString()) shouldBe 1.0 + gaugeValue(ORIGIN_2.applicationId().toString(), ORIGIN_2.id().toString()) shouldBe 1.0 + verify(exactly = 1) { + monitor.monitor(setOf(ORIGIN_1)) + monitor.monitor(setOf(ORIGIN_2)) + eventBus.post(any()) + } + + inventory.setOrigins(ORIGIN_2) + inventory.originCount(ACTIVE) shouldBe 1 + gaugeValue(ORIGIN_1.applicationId().toString(), ORIGIN_1.id().toString()) shouldBe null + gaugeValue(ORIGIN_2.applicationId().toString(), ORIGIN_2.id().toString()) shouldBe 1.0 + verify(exactly = 1) { + monitor.stopMonitoring(setOf(ORIGIN_1)) + } + verify(exactly = 2) { + eventBus.post(any()) + } + } + + "shuts connection provider on origin removal" { + val hostClient1: ReactorHostHttpClient = mockk(relaxed = true) + val hostClient2: ReactorHostHttpClient = mockk(relaxed = true) + every { + hostClientFactory.create( + ORIGIN_1, any(), any(), any(), any(), any(), + any(), any(), any(), any(), any(), + ) + } returns hostClient1 + every { + hostClientFactory.create( + ORIGIN_2, any(), any(), any(), any(), any(), + any(), any(), any(), any(), any(), + ) + } returns hostClient2 + + inventory.setOrigins(ORIGIN_1, ORIGIN_2) + + verify(exactly = 2) { + hostClientFactory.create( + any(), any(), any(), any(), any(), any(), + any(), any(), any(), any(), any(), + ) + } + + inventory.setOrigins(ORIGIN_2) + + verify(exactly = 1) { + hostClient1.close() + } + } + + "does not disable origins not belonging to the app" { + inventory.setOrigins(ORIGIN_1) + + verify(exactly = 1) { + eventBus.post(any()) + } + + inventory.onCommand(DisableOrigin(id("some-other-app"), ORIGIN_1.id())) + + inventory.originCount(ACTIVE) shouldBe 1 + verify(exactly = 1) { + eventBus.post(any()) + } + } + + "does not enable origins not belonging to the app" { + inventory.setOrigins(ORIGIN_1) + + verify(exactly = 1) { + eventBus.post(any()) + } + + inventory.onCommand(DisableOrigin(ORIGIN_1.applicationId(), ORIGIN_1.id())) + inventory.onCommand(EnableOrigin(id("some-other-app"), ORIGIN_1.id())) + + inventory.originCount(ACTIVE) shouldBe 0 + verify(exactly = 2) { + eventBus.post(any()) + } + } + + "removes from active set and stops health check monitoring on disabling an origin" { + inventory.setOrigins(ORIGIN_1) + + inventory.onCommand(DisableOrigin(ORIGIN_1.applicationId(), ORIGIN_1.id())) + + inventory.originCount(ACTIVE) shouldBe 0 + inventory.originCount(DISABLED) shouldBe 1 + + gaugeValue(ORIGIN_1.applicationId().toString(), ORIGIN_1.id().toString()) shouldBe -1.0 + verify(exactly = 1) { + monitor.stopMonitoring(setOf(ORIGIN_1)) + } + verify(exactly = 2) { + eventBus.post(any()) + } + } + + "removes from inactive set and stops health check monitoring on disabling an origin" { + inventory.setOrigins(ORIGIN_1) + + inventory.originUnhealthy(ORIGIN_1) + inventory.onCommand(DisableOrigin(ORIGIN_1.applicationId(), ORIGIN_1.id())) + + inventory.originCount(ACTIVE) shouldBe 0 + inventory.originCount(DISABLED) shouldBe 1 + + gaugeValue(ORIGIN_1.applicationId().toString(), ORIGIN_1.id().toString()) shouldBe -1.0 + verify(exactly = 1) { + monitor.stopMonitoring(setOf(ORIGIN_1)) + } + verify(exactly = 3) { + eventBus.post(any()) + } + } + + "re-initiates health check monitoring on enabling an origin" { + inventory.setOrigins(ORIGIN_1) + + inventory.onCommand(DisableOrigin(ORIGIN_1.applicationId(), ORIGIN_1.id())) + inventory.onCommand(EnableOrigin(ORIGIN_1.applicationId(), ORIGIN_1.id())) + + gaugeValue(ORIGIN_1.applicationId().toString(), ORIGIN_1.id().toString()) shouldBe 0.0 + verify(exactly = 2) { + monitor.monitor(setOf(ORIGIN_1)) + } + verify(exactly = 3) { + eventBus.post(any()) + } + } + + "updates on removing unhealthy origins from active set" { + inventory.setOrigins(ORIGIN_1) + inventory.originCount(ACTIVE) shouldBe 1 + + inventory.originUnhealthy(ORIGIN_1) + + inventory.originCount(ACTIVE) shouldBe 0 + gaugeValue(ORIGIN_1.applicationId().toString(), ORIGIN_1.id().toString()) shouldBe 0.0 + verify(exactly = 2) { + eventBus.post(any()) + } + } + + "updates on adding healthy origins back to active set" { + inventory.setOrigins(ORIGIN_1) + inventory.originCount(ACTIVE) shouldBe 1 + + inventory.originUnhealthy(ORIGIN_1) + inventory.originHealthy(ORIGIN_1) + + inventory.originCount(ACTIVE) shouldBe 1 + gaugeValue(ORIGIN_1.applicationId().toString(), ORIGIN_1.id().toString()) shouldBe 1.0 + verify(exactly = 3) { + eventBus.post(any()) + } + } + + "repeatedly reporting healthy does not affect current active origins" { + inventory.setOrigins(ORIGIN_1) + inventory.originCount(ACTIVE) shouldBe 1 + + inventory.originHealthy(ORIGIN_1) + inventory.originHealthy(ORIGIN_1) + inventory.originHealthy(ORIGIN_1) + + inventory.originCount(ACTIVE) shouldBe 1 + gaugeValue(ORIGIN_1.applicationId().toString(), ORIGIN_1.id().toString()) shouldBe 1.0 + verify(exactly = 1) { + eventBus.post(any()) + } + } + + "repeatedly reporting unhealthy does not affect current active origins" { + inventory.setOrigins(ORIGIN_1) + inventory.originCount(ACTIVE) shouldBe 1 + + inventory.originUnhealthy(ORIGIN_1) + inventory.originUnhealthy(ORIGIN_1) + inventory.originUnhealthy(ORIGIN_1) + + inventory.originCount(ACTIVE) shouldBe 0 + gaugeValue(ORIGIN_1.applicationId().toString(), ORIGIN_1.id().toString()) shouldBe 0.0 + verify(exactly = 2) { + eventBus.post(any()) + } + } + + "announces listeners on origin state change" { + val listener: OriginsChangeListener = mockk() + + inventory.addOriginsChangeListener(listener) + inventory.setOrigins(ORIGIN_1) + inventory.originUnhealthy(ORIGIN_1) + + verify(exactly = 2) { + listener.originsChanged(any()) + } + } + + "registers to event bus when created" { + verify(exactly = 1) { + eventBus.register(inventory) + } + } + + "stops monitoring and unregisters when closed" { + val hostClient1: ReactorHostHttpClient = mockk(relaxed = true) + val hostClient2: ReactorHostHttpClient = mockk(relaxed = true) + every { + hostClientFactory.create( + ORIGIN_1, any(), any(), any(), any(), any(), + any(), any(), any(), any(), any(), + ) + } returns hostClient1 + every { + hostClientFactory.create( + ORIGIN_2, any(), any(), any(), any(), any(), + any(), any(), any(), any(), any(), + ) + } returns hostClient2 + + inventory.setOrigins(ORIGIN_1, ORIGIN_2) + + inventory.close() + + gaugeValue(ORIGIN_1.applicationId().toString(), ORIGIN_1.id().toString()) shouldBe null + gaugeValue(ORIGIN_2.applicationId().toString(), ORIGIN_2.id().toString()) shouldBe null + verify(exactly = 2) { + eventBus.post(any()) + } + verify(exactly = 1) { + monitor.stopMonitoring(setOf(ORIGIN_1)) + monitor.stopMonitoring(setOf(ORIGIN_2)) + hostClient1.close() + hostClient2.close() + eventBus.unregister(inventory) + } + } + } + + private fun gaugeValue( + appId: String, + originId: String, + ): Double? { + val name = "proxy.client.originHealthStatus" + val tags = Tags.of(Metrics.APPID_TAG, appId, Metrics.ORIGINID_TAG, originId) + return gauge(name, tags)?.value() + } + + private fun gauge( + name: String, + tags: Tags, + ): Gauge? = meterRegistry.find(name).tags(tags).gauge() + + companion object { + private val ORIGIN_1 = + newOriginBuilder("localhost", 8001) + .applicationId(GENERIC_APP) + .id("app-01") + .build() + private val ORIGIN_2 = + newOriginBuilder("localhost", 8002) + .applicationId(GENERIC_APP) + .id("app-02") + .build() + } +} diff --git a/components/common/pom.xml b/components/common/pom.xml index f49c0556d..88fbe1e07 100644 --- a/components/common/pom.xml +++ b/components/common/pom.xml @@ -38,6 +38,11 @@ reactor-core + + io.projectreactor.netty + reactor-netty + + org.hdrhistogram HdrHistogram diff --git a/components/common/src/main/kotlin/com/hotels/styx/ext/HttpMessageExt.kt b/components/common/src/main/kotlin/com/hotels/styx/ext/HttpMessageExt.kt new file mode 100644 index 000000000..124fa99f1 --- /dev/null +++ b/components/common/src/main/kotlin/com/hotels/styx/ext/HttpMessageExt.kt @@ -0,0 +1,31 @@ +/* + Copyright (C) 2013-2024 Expedia Inc. + + Licensed under the Apache License, Version 2.0 (the "License"); + you may not use this file except in compliance with the License. + You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + + Unless required by applicable law or agreed to in writing, software + distributed under the License is distributed on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + See the License for the specific language governing permissions and + limitations under the License. + */ +package com.hotels.styx.ext + +import com.hotels.styx.api.LiveHttpRequest +import com.hotels.styx.api.LiveHttpResponse + +inline fun LiveHttpRequest.newRequest(block: LiveHttpRequest.Transformer.() -> Unit): LiveHttpRequest { + val transformer = newBuilder() + block(transformer) + return transformer.build() +} + +inline fun LiveHttpResponse.newResponse(block: LiveHttpResponse.Transformer.() -> Unit): LiveHttpResponse { + val transformer = newBuilder() + block(transformer) + return transformer.build() +} diff --git a/components/common/src/main/kotlin/com/hotels/styx/metrics/ReactorNettyMeterFilter.kt b/components/common/src/main/kotlin/com/hotels/styx/metrics/ReactorNettyMeterFilter.kt new file mode 100644 index 000000000..b0619905b --- /dev/null +++ b/components/common/src/main/kotlin/com/hotels/styx/metrics/ReactorNettyMeterFilter.kt @@ -0,0 +1,46 @@ +/* + Copyright (C) 2013-2024 Expedia Inc. + + Licensed under the Apache License, Version 2.0 (the "License"); + you may not use this file except in compliance with the License. + You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + + Unless required by applicable law or agreed to in writing, software + distributed under the License is distributed on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + See the License for the specific language governing permissions and + limitations under the License. + */ +package com.hotels.styx.metrics + +import com.hotels.styx.api.extension.Origin +import io.micrometer.core.instrument.Meter +import io.micrometer.core.instrument.Tags +import io.micrometer.core.instrument.config.MeterFilter +import reactor.netty.Metrics.NAME + +class ReactorNettyMeterFilter(private val origin: Origin) : MeterFilter { + override fun map(id: Meter.Id): Meter.Id = + if (id.name.startsWith(REACTOR_NETTY_PREFIX) && + id.getTag(NAME)?.endsWith("${origin.id()}") == true + ) { + id.withTags( + Tags.of( + APP_ID, + origin.applicationId().toString(), + ORIGIN_ID, + origin.id().toString(), + ), + ) + } else { + id + } + + companion object { + private const val REACTOR_NETTY_PREFIX = "reactor.netty" + private const val APP_ID = "appId" + private const val ORIGIN_ID = "originId" + } +} diff --git a/components/proxy/src/main/kotlin/com/hotels/styx/proxy/ReactorBackendServiceClientFactory.kt b/components/proxy/src/main/kotlin/com/hotels/styx/proxy/ReactorBackendServiceClientFactory.kt new file mode 100644 index 000000000..0f170324f --- /dev/null +++ b/components/proxy/src/main/kotlin/com/hotels/styx/proxy/ReactorBackendServiceClientFactory.kt @@ -0,0 +1,120 @@ +/* + Copyright (C) 2013-2024 Expedia Inc. + + Licensed under the Apache License, Version 2.0 (the "License"); + you may not use this file except in compliance with the License. + You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + + Unless required by applicable law or agreed to in writing, software + distributed under the License is distributed on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + See the License for the specific language governing permissions and + limitations under the License. + */ +package com.hotels.styx.proxy + +import com.hotels.styx.Environment +import com.hotels.styx.api.configuration.Configuration +import com.hotels.styx.api.extension.loadbalancing.spi.LoadBalancer +import com.hotels.styx.api.extension.retrypolicy.spi.RetryPolicy +import com.hotels.styx.api.extension.service.BackendService +import com.hotels.styx.client.BackendServiceClient +import com.hotels.styx.client.OriginRestrictionLoadBalancingStrategy +import com.hotels.styx.client.OriginStatsFactory +import com.hotels.styx.client.OriginsInventory +import com.hotels.styx.client.ReactorBackendServiceClient +import com.hotels.styx.client.RewriteRuleset +import com.hotels.styx.client.loadbalancing.strategies.BusyActivitiesStrategy +import com.hotels.styx.client.retry.RetryNTimes +import com.hotels.styx.client.stickysession.StickySessionLoadBalancingStrategy +import com.hotels.styx.serviceproviders.ServiceProvision +import org.slf4j.LoggerFactory.getLogger + +/** + * A BackendServiceClientFactory implementation for creating {@link ReactorBackendServiceClient} + */ +class ReactorBackendServiceClientFactory(private val environment: Environment) : BackendServiceClientFactory { + override fun createClient( + backendService: BackendService, + originsInventory: OriginsInventory, + originStatsFactory: OriginStatsFactory, + ): BackendServiceClient { + val styxConfig: Configuration = environment.configuration() + val originRestrictionCookie = styxConfig["originRestrictionCookie"].orElse(null) + val stickySessionEnabled = backendService.stickySessionConfig().stickySessionEnabled() + val retryPolicy = loadRetryPolicy(styxConfig) ?: defaultRetryPolicy() + val configuredLbStrategy = loadLoadBalancer(styxConfig, originsInventory) ?: BusyActivitiesStrategy(originsInventory) + originsInventory.addOriginsChangeListener(configuredLbStrategy) + + val loadBalancingStrategy = + decorateLoadBalancer( + configuredLbStrategy, + stickySessionEnabled, + originsInventory, + originRestrictionCookie, + ) + return ReactorBackendServiceClient( + id = backendService.id(), + rewriteRuleset = RewriteRuleset(backendService.rewrites()), + originsRestrictionCookieName = originRestrictionCookie, + stickySessionConfig = backendService.stickySessionConfig(), + originIdHeader = environment.configuration().styxHeaderConfig().originIdHeaderName(), + loadBalancer = loadBalancingStrategy, + retryPolicy = retryPolicy, + metrics = environment.centralisedMetrics(), + overrideHostHeader = backendService.isOverrideHostHeader(), + ) + } + + private fun loadRetryPolicy(styxConfig: Configuration) = + ServiceProvision.loadRetryPolicy( + styxConfig, + environment, + "retrypolicy.policy.factory", + RetryPolicy::class.java, + ).orElse(null) + + private fun loadLoadBalancer( + styxConfig: Configuration, + originsInventory: OriginsInventory, + ) = ServiceProvision.loadLoadBalancer( + styxConfig, + environment, + "loadBalancing.strategy.factory", + LoadBalancer::class.java, + originsInventory, + ).orElse(null) + + private fun decorateLoadBalancer( + configuredLbStrategy: LoadBalancer, + stickySessionEnabled: Boolean, + originsInventory: OriginsInventory, + originRestrictionCookie: String?, + ): LoadBalancer = + if (stickySessionEnabled) { + StickySessionLoadBalancingStrategy(originsInventory, configuredLbStrategy) + } else if (originRestrictionCookie == null) { + LOGGER.info("originRestrictionCookie not specified - origin restriction disabled") + configuredLbStrategy + } else { + LOGGER.info( + """ + originRestrictionCookie specified as $originRestrictionCookie + - origin restriction will apply when this cookie is sent + """.trimIndent(), + ) + OriginRestrictionLoadBalancingStrategy(originsInventory, configuredLbStrategy) + } + + companion object { + private val LOGGER = getLogger(this::class.java) + + private fun defaultRetryPolicy(): RetryPolicy { + val retryOnce = RetryNTimes(1) + LOGGER.warn("No configured retry policy found, using $retryOnce") + return retryOnce + } + } +} diff --git a/components/proxy/src/main/java/com/hotels/styx/proxy/StyxBackendServiceClientFactory.kt b/components/proxy/src/main/kotlin/com/hotels/styx/proxy/StyxBackendServiceClientFactory.kt similarity index 99% rename from components/proxy/src/main/java/com/hotels/styx/proxy/StyxBackendServiceClientFactory.kt rename to components/proxy/src/main/kotlin/com/hotels/styx/proxy/StyxBackendServiceClientFactory.kt index 44c43e97c..8a3b26bc9 100644 --- a/components/proxy/src/main/java/com/hotels/styx/proxy/StyxBackendServiceClientFactory.kt +++ b/components/proxy/src/main/kotlin/com/hotels/styx/proxy/StyxBackendServiceClientFactory.kt @@ -1,5 +1,5 @@ /* - Copyright (C) 2013-2023 Expedia Inc. + Copyright (C) 2013-2024 Expedia Inc. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. diff --git a/components/proxy/src/test/kotlin/com/hotels/styx/proxy/ReactorBackendServiceClientFactoryTest.kt b/components/proxy/src/test/kotlin/com/hotels/styx/proxy/ReactorBackendServiceClientFactoryTest.kt new file mode 100644 index 000000000..b87fd405a --- /dev/null +++ b/components/proxy/src/test/kotlin/com/hotels/styx/proxy/ReactorBackendServiceClientFactoryTest.kt @@ -0,0 +1,202 @@ +/* + Copyright (C) 2013-2024 Expedia Inc. + + Licensed under the Apache License, Version 2.0 (the "License"); + you may not use this file except in compliance with the License. + You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + + Unless required by applicable law or agreed to in writing, software + distributed under the License is distributed on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + See the License for the specific language governing permissions and + limitations under the License. + */ +package com.hotels.styx.proxy + +import com.hotels.styx.Environment +import com.hotels.styx.StyxConfig +import com.hotels.styx.api.HttpInterceptor +import com.hotels.styx.api.HttpResponseStatus.OK +import com.hotels.styx.api.Id.GENERIC_APP +import com.hotels.styx.api.Id.id +import com.hotels.styx.api.LiveHttpRequest.get +import com.hotels.styx.api.LiveHttpResponse +import com.hotels.styx.api.LiveHttpResponse.response +import com.hotels.styx.api.MicrometerRegistry +import com.hotels.styx.api.RequestCookie.requestCookie +import com.hotels.styx.api.configuration.Configuration.MapBackedConfiguration +import com.hotels.styx.api.extension.Origin +import com.hotels.styx.api.extension.loadbalancing.spi.LoadBalancingMetric +import com.hotels.styx.api.extension.service.BackendService +import com.hotels.styx.api.extension.service.BackendService.Companion.newBackendServiceBuilder +import com.hotels.styx.api.extension.service.StickySessionConfig.newStickySessionConfigBuilder +import com.hotels.styx.client.OriginStatsFactory +import com.hotels.styx.client.OriginStatsFactory.CachingOriginStatsFactory +import com.hotels.styx.client.ReactorBackendServiceClient +import com.hotels.styx.client.ReactorConnectionPool +import com.hotels.styx.client.ReactorHostHttpClient +import com.hotels.styx.client.ReactorOriginsInventory +import io.kotest.core.spec.style.StringSpec +import io.kotest.matchers.shouldBe +import io.kotest.matchers.types.shouldBeInstanceOf +import io.micrometer.core.instrument.simple.SimpleMeterRegistry +import io.mockk.every +import io.mockk.mockk +import io.netty.handler.ssl.SslContext +import reactor.core.publisher.Flux +import reactor.core.publisher.Mono +import reactor.netty.resources.LoopResources +import kotlin.jvm.optionals.getOrNull + +class ReactorBackendServiceClientFactoryTest : StringSpec() { + private lateinit var environment: Environment + private lateinit var backendService: BackendService + private val eventLoopGroup: LoopResources = mockk() + private val originStatsFactory: OriginStatsFactory = mockk() + private val sslContext: SslContext = mockk() + private val connectionPool: ReactorConnectionPool = mockk() + private val context: HttpInterceptor.Context = mockk() + + init { + beforeTest { + environment = Environment.Builder().registry(MicrometerRegistry(SimpleMeterRegistry())).build() + backendService = + newBackendServiceBuilder() + .origins(Origin.newOriginBuilder("localhost", 8081).build()) + .build() + + every { connectionPool.isHttp2() } returns false + } + + "createClient" { + val originsInventory = + ReactorOriginsInventory.Builder(backendService.id()) + .metrics(environment.centralisedMetrics()) + .eventLoopGroup(eventLoopGroup) + .initialOrigins(backendService.origins()) + .originStatsFactory(originStatsFactory) + .sslContext(sslContext) + .connectionPool(connectionPool) + .build() + val client = + ReactorBackendServiceClientFactory(environment) + .createClient(backendService, originsInventory, originStatsFactory) + + client.shouldBeInstanceOf() + } + + "uses the origin specified in the sticky session cookie" { + val backendService = + newBackendServiceBuilder() + .origins( + Origin.newOriginBuilder("localhost", 9091).id("x").build(), + Origin.newOriginBuilder("localhost", 9092).id("y").build(), + Origin.newOriginBuilder("localhost", 9093).id("z").build(), + ) + .stickySessionConfig(newStickySessionConfigBuilder().enabled(true).build()) + .build() + val originsInventory = + ReactorOriginsInventory.Builder(backendService.id()) + .metrics(environment.centralisedMetrics()) + .eventLoopGroup(eventLoopGroup) + .initialOrigins(backendService.origins()) + .hostClientFactory { origin, _, _, _, _, _, _, _, _, _, _ -> + if (origin.id() == id("x")) { + hostClient(response(OK).header("X-Origin-Id", "x").build()) + } else if (origin.id() == id("y")) { + hostClient(response(OK).header("X-Origin-Id", "y").build()) + } else { + hostClient(response(OK).header("X-Origin-Id", "z").build()) + } + } + .originStatsFactory(originStatsFactory) + .sslContext(sslContext) + .connectionPool(connectionPool) + .build() + + val client = + ReactorBackendServiceClientFactory(environment) + .createClient(backendService, originsInventory, CachingOriginStatsFactory(environment.centralisedMetrics())) + + val requestX = get("/some-req").cookies(requestCookie(STICKY_COOKIE, id("x").toString())).build() + val requestY = get("/some-req").cookies(requestCookie(STICKY_COOKIE, id("y").toString())).build() + val requestZ = get("/some-req").cookies(requestCookie(STICKY_COOKIE, id("z").toString())).build() + + val responseX = Mono.from(client.sendRequest(requestX, context)).block() + val responseY = Mono.from(client.sendRequest(requestY, context)).block() + val responseZ = Mono.from(client.sendRequest(requestZ, context)).block() + + responseX!!.header("X-Origin-Id").getOrNull() shouldBe "x" + responseY!!.header("X-Origin-Id").getOrNull() shouldBe "y" + responseZ!!.header("X-Origin-Id").getOrNull() shouldBe "z" + } + + "uses the origin specified in the origin restriction cookie" { + val config = MapBackedConfiguration() + config["originRestrictionCookie"] = ORIGINS_RESTRICTION_COOKIE + + environment = + Environment.Builder() + .registry(MicrometerRegistry(SimpleMeterRegistry())) + .configuration(StyxConfig(config)) + .build() + + val backendService = + newBackendServiceBuilder() + .origins( + Origin.newOriginBuilder("localhost", 9091).id("x").build(), + Origin.newOriginBuilder("localhost", 9092).id("y").build(), + Origin.newOriginBuilder("localhost", 9093).id("z").build(), + ) + .build() + val originsInventory = + ReactorOriginsInventory.Builder(backendService.id()) + .metrics(environment.centralisedMetrics()) + .eventLoopGroup(eventLoopGroup) + .initialOrigins(backendService.origins()) + .hostClientFactory { origin, _, _, _, _, _, _, _, _, _, _ -> + if (origin.id() == id("x")) { + hostClient(response(OK).header("X-Origin-Id", "x").build()) + } else if (origin.id() == id("y")) { + hostClient(response(OK).header("X-Origin-Id", "y").build()) + } else { + hostClient(response(OK).header("X-Origin-Id", "z").build()) + } + } + .originStatsFactory(originStatsFactory) + .sslContext(sslContext) + .connectionPool(connectionPool) + .build() + + val client = + ReactorBackendServiceClientFactory(environment) + .createClient(backendService, originsInventory, CachingOriginStatsFactory(environment.centralisedMetrics())) + + val requestX = get("/some-req").cookies(requestCookie(STICKY_COOKIE, id("x").toString())).build() + val requestY = get("/some-req").cookies(requestCookie(STICKY_COOKIE, id("y").toString())).build() + val requestZ = get("/some-req").cookies(requestCookie(STICKY_COOKIE, id("z").toString())).build() + + val responseX = Mono.from(client.sendRequest(requestX, context)).block() + val responseY = Mono.from(client.sendRequest(requestY, context)).block() + val responseZ = Mono.from(client.sendRequest(requestZ, context)).block() + + responseX!!.header("X-Origin-Id").getOrNull() shouldBe "x" + responseY!!.header("X-Origin-Id").getOrNull() shouldBe "y" + responseZ!!.header("X-Origin-Id").getOrNull() shouldBe "z" + } + } + + private fun hostClient(response: LiveHttpResponse): ReactorHostHttpClient { + val mockClient: ReactorHostHttpClient = mockk() + every { mockClient.sendRequest(any(), any()) } returns Flux.just(response) + every { mockClient.loadBalancingMetric() } returns LoadBalancingMetric(1) + return mockClient + } + + companion object { + private const val ORIGINS_RESTRICTION_COOKIE = "styx-origins-restriction" + private val STICKY_COOKIE = "styx_origin_$GENERIC_APP" + } +} diff --git a/pom.xml b/pom.xml index 5e676f445..5e46f2925 100644 --- a/pom.xml +++ b/pom.xml @@ -108,7 +108,7 @@ 4.1.106.Final 1.0.4 - 2023.0.2 + 2023.0.3 3.0.5 1.14.11 @@ -131,6 +131,7 @@ 2.2 5.10.0 1.13.9 + 4.12.0 5.10.1 3.0.9 1.17.0 @@ -540,6 +541,20 @@ test + + com.squareup.okhttp3 + mockwebserver + ${okhttp.version} + test + + + + com.squareup.okhttp3 + okhttp-tls + ${okhttp.version} + test + +