From a54f418ced3bf3423efebf7766f24387ca84ee9d Mon Sep 17 00:00:00 2001 From: jrhee17 Date: Fri, 26 Jan 2024 16:48:54 +0900 Subject: [PATCH] Provide a way to run service code out of I/O event loops (#5233) Motivation: When service devs are not very familiar with asynchronous programming, it is very easy to drive Armeria's core event loops into havoc by blocking them. Framework devs may want to make sure Armeria at least handle I/O and function normally for a certain set of core services, such as `PrometheusExpositionService` or `HealthCheckService`, by isolating other non-core services from the I/O event loops. Modifications: * Add the `serviceWorkerGroup()` builder methods to `ServerBuilder` and `ServiceConfigSetters` so a user can specify the service worker groups as shown in the above example. * If `serviceWorkerGroup` is not specified, the `workerGroup` is used by default. * Change how Armeria assigns an event loop to a `ServiceRequestContext`. * If the `workerGroup` is different from `serviceWorkerGroup`, then an event loop from `serviceWorkerGroup` is used. * Otherwise, the IO event loop is used for executing services * Modified the constructor of `DefaultServiceRequestContext` so it accepts an `EventLoop`. * Modified `HttpServerHandler.handleRequest()`, so that `HttpService#serve` is executed from the `serviceWorkerGroup` * Modified so that we can guarantee that pending `RequestLogFuture`s are always scheduled from the context's event loop Result: - Closes #4099. - Users can add per-service/virtual host/server `serviceWorkerGroup` property that makes a service use a different `EventLoopGroup` than `ServerBuilder.workerGroup`. --------- Co-authored-by: kezhenxu94 --- .../armeria/server/RoutersBenchmark.java | 18 +- ...RequestContextCurrentTraceContextTest.java | 8 +- .../linecorp/armeria/common/HttpRequest.java | 6 + .../linecorp/armeria/common/HttpResponse.java | 5 + .../armeria/common/stream/StreamMessage.java | 18 ++ .../stream/SubscribeOnStreamMessage.java | 117 ++++++++ .../server/DefaultServiceRequestContext.java | 23 +- ...AbstractAnnotatedServiceConfigSetters.java | 14 + .../server/AbstractServiceBindingBuilder.java | 15 + .../server/AggregatedHttpResponseHandler.java | 2 +- .../AnnotatedServiceBindingBuilder.java | 13 + ...textPathAnnotatedServiceConfigSetters.java | 14 + .../ContextPathServiceBindingBuilder.java | 13 + .../server/DefaultServiceConfigSetters.java | 25 ++ .../armeria/server/HttpServerHandler.java | 138 +++++---- .../armeria/server/ServerBuilder.java | 32 +- .../linecorp/armeria/server/ServerConfig.java | 3 +- .../armeria/server/ServiceBindingBuilder.java | 13 + .../armeria/server/ServiceConfig.java | 25 +- .../armeria/server/ServiceConfigBuilder.java | 24 ++ .../armeria/server/ServiceConfigSetters.java | 23 ++ .../server/ServiceRequestContextBuilder.java | 4 +- .../linecorp/armeria/server/VirtualHost.java | 20 +- ...ualHostAnnotatedServiceBindingBuilder.java | 14 + .../armeria/server/VirtualHostBuilder.java | 48 ++- ...textPathAnnotatedServiceConfigSetters.java | 14 + ...lHostContextPathServiceBindingBuilder.java | 14 + .../VirtualHostServiceBindingBuilder.java | 13 + .../stream/SubscribeOnStreamMessageTest.java | 124 ++++++++ .../annotation/ServiceWorkerGroupTest.java | 283 ++++++++++++++++++ .../armeria/server/ServiceNamingTest.java | 45 ++- .../linecorp/armeria/server/ServiceTest.java | 3 +- 32 files changed, 1019 insertions(+), 112 deletions(-) create mode 100644 core/src/main/java/com/linecorp/armeria/common/stream/SubscribeOnStreamMessage.java create mode 100644 core/src/test/java/com/linecorp/armeria/common/stream/SubscribeOnStreamMessageTest.java create mode 100644 core/src/test/java/com/linecorp/armeria/internal/server/annotation/ServiceWorkerGroupTest.java diff --git a/benchmarks/jmh/src/jmh/java/com/linecorp/armeria/server/RoutersBenchmark.java b/benchmarks/jmh/src/jmh/java/com/linecorp/armeria/server/RoutersBenchmark.java index 24660c1fc83..50c84e161d4 100644 --- a/benchmarks/jmh/src/jmh/java/com/linecorp/armeria/server/RoutersBenchmark.java +++ b/benchmarks/jmh/src/jmh/java/com/linecorp/armeria/server/RoutersBenchmark.java @@ -66,27 +66,27 @@ public class RoutersBenchmark { new ServiceConfig(route1, route1, SERVICE, defaultLogName, defaultServiceName, defaultServiceNaming, 0, 0, false, AccessLogWriter.disabled(), CommonPools.blockingTaskExecutor(), - SuccessFunction.always(), 0, multipartUploadsLocation, ImmutableList.of(), - HttpHeaders.of(), ctx -> RequestId.random(), serviceErrorHandler, - NOOP_CONTEXT_HOOK), + SuccessFunction.always(), 0, multipartUploadsLocation, + CommonPools.workerGroup(), ImmutableList.of(), HttpHeaders.of(), + ctx -> RequestId.random(), serviceErrorHandler, NOOP_CONTEXT_HOOK), new ServiceConfig(route2, route2, SERVICE, defaultLogName, defaultServiceName, defaultServiceNaming, 0, 0, false, AccessLogWriter.disabled(), CommonPools.blockingTaskExecutor(), - SuccessFunction.always(), 0, multipartUploadsLocation, ImmutableList.of(), - HttpHeaders.of(), ctx -> RequestId.random(), serviceErrorHandler, - NOOP_CONTEXT_HOOK)); + SuccessFunction.always(), 0, multipartUploadsLocation, + CommonPools.workerGroup(), ImmutableList.of(), HttpHeaders.of(), + ctx -> RequestId.random(), serviceErrorHandler, NOOP_CONTEXT_HOOK)); FALLBACK_SERVICE = new ServiceConfig(Route.ofCatchAll(), Route.ofCatchAll(), SERVICE, defaultLogName, defaultServiceName, defaultServiceNaming, 0, 0, false, AccessLogWriter.disabled(), - CommonPools.blockingTaskExecutor(), - SuccessFunction.always(), 0, multipartUploadsLocation, + CommonPools.blockingTaskExecutor(), SuccessFunction.always(), 0, + multipartUploadsLocation, CommonPools.workerGroup(), ImmutableList.of(), HttpHeaders.of(), ctx -> RequestId.random(), serviceErrorHandler, NOOP_CONTEXT_HOOK); HOST = new VirtualHost( "localhost", "localhost", 0, null, SERVICES, FALLBACK_SERVICE, RejectedRouteHandler.DISABLED, unused -> NOPLogger.NOP_LOGGER, defaultServiceNaming, defaultLogName, 0, 0, false, AccessLogWriter.disabled(), CommonPools.blockingTaskExecutor(), 0, SuccessFunction.ofDefault(), - multipartUploadsLocation, ImmutableList.of(), + multipartUploadsLocation, CommonPools.workerGroup(), ImmutableList.of(), ctx -> RequestId.random()); ROUTER = Routers.ofVirtualHost(HOST, SERVICES, RejectedRouteHandler.DISABLED); } diff --git a/brave/src/test/java/com/linecorp/armeria/common/brave/RequestContextCurrentTraceContextTest.java b/brave/src/test/java/com/linecorp/armeria/common/brave/RequestContextCurrentTraceContextTest.java index ad901c47d39..9b7e5ed5713 100644 --- a/brave/src/test/java/com/linecorp/armeria/common/brave/RequestContextCurrentTraceContextTest.java +++ b/brave/src/test/java/com/linecorp/armeria/common/brave/RequestContextCurrentTraceContextTest.java @@ -17,8 +17,6 @@ package com.linecorp.armeria.common.brave; import static org.assertj.core.api.Assertions.assertThat; -import static org.mockito.ArgumentMatchers.any; -import static org.mockito.Mockito.doAnswer; import static org.mockito.Mockito.when; import org.junit.jupiter.api.BeforeEach; @@ -26,7 +24,6 @@ import org.mockito.Mock; import org.mockito.junit.jupiter.MockitoSettings; import org.mockito.quality.Strictness; -import org.mockito.stubbing.Answer; import com.linecorp.armeria.common.HttpMethod; import com.linecorp.armeria.common.HttpRequest; @@ -55,10 +52,7 @@ class RequestContextCurrentTraceContextTest { @BeforeEach void setUp() { when(eventLoop.inEventLoop()).thenReturn(true); - doAnswer((Answer) invocation -> { - invocation.getArgument(0).run(); - return null; - }).when(eventLoop).execute(any()); + when(eventLoop.next()).thenReturn(eventLoop); ctx = ServiceRequestContext.builder(HttpRequest.of(HttpMethod.GET, "/")) .eventLoop(eventLoop) diff --git a/core/src/main/java/com/linecorp/armeria/common/HttpRequest.java b/core/src/main/java/com/linecorp/armeria/common/HttpRequest.java index 7ec7a3a360f..8967b6cc0c0 100644 --- a/core/src/main/java/com/linecorp/armeria/common/HttpRequest.java +++ b/core/src/main/java/com/linecorp/armeria/common/HttpRequest.java @@ -810,4 +810,10 @@ default HttpRequest peekError(Consumer action) { requireNonNull(action, "action"); return of(headers(), HttpMessage.super.peekError(action)); } + + @Override + default HttpRequest subscribeOn(EventExecutor eventExecutor) { + requireNonNull(eventExecutor, "eventExecutor"); + return of(headers(), HttpMessage.super.subscribeOn(eventExecutor)); + } } diff --git a/core/src/main/java/com/linecorp/armeria/common/HttpResponse.java b/core/src/main/java/com/linecorp/armeria/common/HttpResponse.java index 7f0c267da42..a92db4b2789 100644 --- a/core/src/main/java/com/linecorp/armeria/common/HttpResponse.java +++ b/core/src/main/java/com/linecorp/armeria/common/HttpResponse.java @@ -1178,4 +1178,9 @@ default HttpResponse recover(Class causeClass, } }); } + + @Override + default HttpResponse subscribeOn(EventExecutor eventExecutor) { + return of(HttpMessage.super.subscribeOn(eventExecutor)); + } } diff --git a/core/src/main/java/com/linecorp/armeria/common/stream/StreamMessage.java b/core/src/main/java/com/linecorp/armeria/common/stream/StreamMessage.java index e0a4c22879f..a203f9c4116 100644 --- a/core/src/main/java/com/linecorp/armeria/common/stream/StreamMessage.java +++ b/core/src/main/java/com/linecorp/armeria/common/stream/StreamMessage.java @@ -1136,4 +1136,22 @@ default InputStream toInputStream(Function httpDa default StreamMessage endWith(Function<@Nullable Throwable, ? extends @Nullable T> finalizer) { return new SurroundingPublisher<>(null, this, finalizer); } + + /** + * Calls {@link #subscribe(Subscriber, EventExecutor)} to the upstream + * {@link StreamMessage} using the specified {@link EventExecutor} and relays the stream + * transparently downstream. This may be useful if one would like to hide an + * {@link EventExecutor} from an upstream {@link Publisher}. + * + *

For example:

{@code
+     * Subscriber mySubscriber = null;
+     * StreamMessage upstream = ...; // publisher callbacks are invoked by eventLoop1
+     * upstream.subscribeOn(eventLoop1)
+     *         .subscribe(mySubscriber, eventLoop2); // mySubscriber callbacks are invoked with eventLoop2
+     * }
+ */ + default StreamMessage subscribeOn(EventExecutor eventExecutor) { + requireNonNull(eventExecutor, "eventExecutor"); + return new SubscribeOnStreamMessage<>(this, eventExecutor); + } } diff --git a/core/src/main/java/com/linecorp/armeria/common/stream/SubscribeOnStreamMessage.java b/core/src/main/java/com/linecorp/armeria/common/stream/SubscribeOnStreamMessage.java new file mode 100644 index 00000000000..b5a2f0bf3bb --- /dev/null +++ b/core/src/main/java/com/linecorp/armeria/common/stream/SubscribeOnStreamMessage.java @@ -0,0 +1,117 @@ +/* + * Copyright 2023 LINE Corporation + * + * LINE Corporation licenses this file to you under the Apache License, + * version 2.0 (the "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at: + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations + * under the License. + */ + +package com.linecorp.armeria.common.stream; + +import java.util.concurrent.CompletableFuture; + +import org.reactivestreams.Subscriber; +import org.reactivestreams.Subscription; + +import io.netty.util.concurrent.EventExecutor; + +final class SubscribeOnStreamMessage implements StreamMessage { + + private final StreamMessage upstream; + private final EventExecutor upstreamExecutor; + + SubscribeOnStreamMessage(StreamMessage upstream, EventExecutor upstreamExecutor) { + this.upstream = upstream; + this.upstreamExecutor = upstreamExecutor; + } + + @Override + public boolean isOpen() { + return upstream.isOpen(); + } + + @Override + public boolean isEmpty() { + return upstream.isEmpty(); + } + + @Override + public long demand() { + return upstream.demand(); + } + + @Override + public CompletableFuture whenComplete() { + return upstream.whenComplete(); + } + + @Override + public EventExecutor defaultSubscriberExecutor() { + return upstreamExecutor; + } + + @Override + public void subscribe(Subscriber subscriber, EventExecutor downstreamExecutor, + SubscriptionOption... options) { + final Subscriber subscriber0; + if (upstreamExecutor == downstreamExecutor) { + subscriber0 = subscriber; + } else { + subscriber0 = new SchedulingSubscriber<>(downstreamExecutor, subscriber); + } + if (upstreamExecutor.inEventLoop()) { + upstream.subscribe(subscriber0, downstreamExecutor, options); + } else { + upstreamExecutor.execute(() -> upstream.subscribe(subscriber0, upstreamExecutor, options)); + } + } + + @Override + public void abort() { + upstream.abort(); + } + + @Override + public void abort(Throwable cause) { + upstream.abort(cause); + } + + static class SchedulingSubscriber implements Subscriber { + + private final Subscriber downstream; + private final EventExecutor downstreamExecutor; + + SchedulingSubscriber(EventExecutor downstreamExecutor, Subscriber downstream) { + this.downstream = downstream; + this.downstreamExecutor = downstreamExecutor; + } + + @Override + public void onSubscribe(Subscription s) { + downstreamExecutor.execute(() -> downstream.onSubscribe(s)); + } + + @Override + public void onNext(T t) { + downstreamExecutor.execute(() -> downstream.onNext(t)); + } + + @Override + public void onError(Throwable t) { + downstreamExecutor.execute(() -> downstream.onError(t)); + } + + @Override + public void onComplete() { + downstreamExecutor.execute(downstream::onComplete); + } + } +} diff --git a/core/src/main/java/com/linecorp/armeria/internal/server/DefaultServiceRequestContext.java b/core/src/main/java/com/linecorp/armeria/internal/server/DefaultServiceRequestContext.java index 7753112453c..b1fb086287d 100644 --- a/core/src/main/java/com/linecorp/armeria/internal/server/DefaultServiceRequestContext.java +++ b/core/src/main/java/com/linecorp/armeria/internal/server/DefaultServiceRequestContext.java @@ -72,6 +72,7 @@ import io.micrometer.core.instrument.MeterRegistry; import io.netty.buffer.ByteBufAllocator; import io.netty.channel.Channel; +import io.netty.channel.EventLoop; import io.netty.util.AttributeKey; /** @@ -90,6 +91,7 @@ public final class DefaultServiceRequestContext DefaultServiceRequestContext.class, HttpHeaders.class, "additionalResponseTrailers"); private final Channel ch; + private final EventLoop eventLoop; private final ServiceConfig cfg; private final RoutingContext routingContext; private final RoutingResult routingResult; @@ -141,22 +143,24 @@ public final class DefaultServiceRequestContext * e.g. {@code System.currentTimeMillis() * 1000}. */ public DefaultServiceRequestContext( - ServiceConfig cfg, Channel ch, MeterRegistry meterRegistry, SessionProtocol sessionProtocol, - RequestId id, RoutingContext routingContext, RoutingResult routingResult, ExchangeType exchangeType, + ServiceConfig cfg, Channel ch, EventLoop eventLoop, MeterRegistry meterRegistry, + SessionProtocol sessionProtocol, RequestId id, RoutingContext routingContext, + RoutingResult routingResult, ExchangeType exchangeType, HttpRequest req, @Nullable SSLSession sslSession, ProxiedAddresses proxiedAddresses, InetAddress clientAddress, InetSocketAddress remoteAddress, InetSocketAddress localAddress, long requestStartTimeNanos, long requestStartTimeMicros, Supplier contextHook) { - this(cfg, ch, meterRegistry, sessionProtocol, id, routingContext, routingResult, exchangeType, - req, sslSession, proxiedAddresses, clientAddress, remoteAddress, localAddress, + this(cfg, ch, eventLoop, meterRegistry, sessionProtocol, id, routingContext, routingResult, + exchangeType, req, sslSession, proxiedAddresses, clientAddress, remoteAddress, localAddress, null /* requestCancellationScheduler */, requestStartTimeNanos, requestStartTimeMicros, HttpHeaders.of(), HttpHeaders.of(), contextHook); } public DefaultServiceRequestContext( - ServiceConfig cfg, Channel ch, MeterRegistry meterRegistry, SessionProtocol sessionProtocol, - RequestId id, RoutingContext routingContext, RoutingResult routingResult, ExchangeType exchangeType, + ServiceConfig cfg, Channel ch, EventLoop eventLoop, MeterRegistry meterRegistry, + SessionProtocol sessionProtocol, RequestId id, RoutingContext routingContext, + RoutingResult routingResult, ExchangeType exchangeType, HttpRequest req, @Nullable SSLSession sslSession, ProxiedAddresses proxiedAddresses, InetAddress clientAddress, InetSocketAddress remoteAddress, InetSocketAddress localAddress, @Nullable CancellationScheduler requestCancellationScheduler, @@ -170,6 +174,7 @@ public DefaultServiceRequestContext( requireNonNull(req, "req"), null, null, contextHook); this.ch = requireNonNull(ch, "ch"); + this.eventLoop = requireNonNull(eventLoop, "eventLoop"); this.cfg = requireNonNull(cfg, "cfg"); this.routingContext = routingContext; this.routingResult = routingResult; @@ -178,7 +183,9 @@ public DefaultServiceRequestContext( } else { this.requestCancellationScheduler = CancellationScheduler.ofServer(TimeUnit.MILLISECONDS.toNanos(cfg.requestTimeoutMillis())); - this.requestCancellationScheduler.init(eventLoop()); + // the cancellation scheduler uses channelEventLoop since #start is called + // from the netty pipeline logic + this.requestCancellationScheduler.init(ch.eventLoop()); } this.sslSession = sslSession; this.proxiedAddresses = requireNonNull(proxiedAddresses, "proxiedAddresses"); @@ -301,7 +308,7 @@ public ContextAwareEventLoop eventLoop() { if (contextAwareEventLoop != null) { return contextAwareEventLoop; } - return contextAwareEventLoop = ContextAwareEventLoop.of(this, ch.eventLoop()); + return contextAwareEventLoop = ContextAwareEventLoop.of(this, eventLoop); } @Override diff --git a/core/src/main/java/com/linecorp/armeria/server/AbstractAnnotatedServiceConfigSetters.java b/core/src/main/java/com/linecorp/armeria/server/AbstractAnnotatedServiceConfigSetters.java index 489642f3e40..80ca978f5a3 100644 --- a/core/src/main/java/com/linecorp/armeria/server/AbstractAnnotatedServiceConfigSetters.java +++ b/core/src/main/java/com/linecorp/armeria/server/AbstractAnnotatedServiceConfigSetters.java @@ -47,6 +47,8 @@ import com.linecorp.armeria.server.annotation.ResponseConverterFunction; import com.linecorp.armeria.server.logging.AccessLogWriter; +import io.netty.channel.EventLoopGroup; + @UnstableApi abstract class AbstractAnnotatedServiceConfigSetters implements AnnotatedServiceConfigSetters { @@ -280,6 +282,18 @@ public AbstractAnnotatedServiceConfigSetters multipartUploadsLocation(Path multi return this; } + @Override + public ServiceConfigSetters serviceWorkerGroup(EventLoopGroup serviceWorkerGroup, boolean shutdownOnStop) { + defaultServiceConfigSetters.serviceWorkerGroup(serviceWorkerGroup, shutdownOnStop); + return this; + } + + @Override + public ServiceConfigSetters serviceWorkerGroup(int numThreads) { + defaultServiceConfigSetters.serviceWorkerGroup(numThreads); + return this; + } + @Override public AbstractAnnotatedServiceConfigSetters requestIdGenerator( Function requestIdGenerator) { diff --git a/core/src/main/java/com/linecorp/armeria/server/AbstractServiceBindingBuilder.java b/core/src/main/java/com/linecorp/armeria/server/AbstractServiceBindingBuilder.java index a570f69a198..9506174428b 100644 --- a/core/src/main/java/com/linecorp/armeria/server/AbstractServiceBindingBuilder.java +++ b/core/src/main/java/com/linecorp/armeria/server/AbstractServiceBindingBuilder.java @@ -35,6 +35,8 @@ import com.linecorp.armeria.common.util.BlockingTaskExecutor; import com.linecorp.armeria.server.logging.AccessLogWriter; +import io.netty.channel.EventLoopGroup; + /** * A builder class for binding an {@link HttpService} fluently. * @@ -175,6 +177,19 @@ public AbstractServiceBindingBuilder multipartUploadsLocation(Path multipartUplo return this; } + @Override + public AbstractServiceBindingBuilder serviceWorkerGroup(EventLoopGroup serviceWorkerGroup, + boolean shutdownOnStop) { + defaultServiceConfigSetters.serviceWorkerGroup(serviceWorkerGroup, shutdownOnStop); + return this; + } + + @Override + public AbstractServiceBindingBuilder serviceWorkerGroup(int numThreads) { + defaultServiceConfigSetters.serviceWorkerGroup(numThreads); + return this; + } + @Override public AbstractServiceBindingBuilder requestIdGenerator( Function requestIdGenerator) { diff --git a/core/src/main/java/com/linecorp/armeria/server/AggregatedHttpResponseHandler.java b/core/src/main/java/com/linecorp/armeria/server/AggregatedHttpResponseHandler.java index ea620a13bb7..ec59df678e8 100644 --- a/core/src/main/java/com/linecorp/armeria/server/AggregatedHttpResponseHandler.java +++ b/core/src/main/java/com/linecorp/armeria/server/AggregatedHttpResponseHandler.java @@ -57,7 +57,7 @@ final class AggregatedHttpResponseHandler extends AbstractHttpResponseHandler @Override public Void apply(@Nullable AggregatedHttpResponse response, @Nullable Throwable cause) { - final EventLoop eventLoop = reqCtx.eventLoop(); + final EventLoop eventLoop = ctx.channel().eventLoop(); if (eventLoop.inEventLoop()) { apply0(response, cause); } else { diff --git a/core/src/main/java/com/linecorp/armeria/server/AnnotatedServiceBindingBuilder.java b/core/src/main/java/com/linecorp/armeria/server/AnnotatedServiceBindingBuilder.java index b0da5f3049f..d232a8855d8 100644 --- a/core/src/main/java/com/linecorp/armeria/server/AnnotatedServiceBindingBuilder.java +++ b/core/src/main/java/com/linecorp/armeria/server/AnnotatedServiceBindingBuilder.java @@ -33,6 +33,8 @@ import com.linecorp.armeria.server.annotation.ResponseConverterFunction; import com.linecorp.armeria.server.logging.AccessLogWriter; +import io.netty.channel.EventLoopGroup; + /** * A builder class for binding an {@link HttpService} fluently. This class can be instantiated through * {@link ServerBuilder#annotatedService()}. @@ -260,6 +262,17 @@ public AnnotatedServiceBindingBuilder contextHook(Supplier requestIdGenerator) { diff --git a/core/src/main/java/com/linecorp/armeria/server/DefaultServiceConfigSetters.java b/core/src/main/java/com/linecorp/armeria/server/DefaultServiceConfigSetters.java index 13ae406edc4..b5d1e948a9e 100644 --- a/core/src/main/java/com/linecorp/armeria/server/DefaultServiceConfigSetters.java +++ b/core/src/main/java/com/linecorp/armeria/server/DefaultServiceConfigSetters.java @@ -41,9 +41,12 @@ import com.linecorp.armeria.common.SuccessFunction; import com.linecorp.armeria.common.annotation.Nullable; import com.linecorp.armeria.common.util.BlockingTaskExecutor; +import com.linecorp.armeria.common.util.EventLoopGroups; import com.linecorp.armeria.internal.server.annotation.AnnotatedService; import com.linecorp.armeria.server.logging.AccessLogWriter; +import io.netty.channel.EventLoopGroup; + /** * A default implementation of {@link ServiceConfigSetters} that stores service related settings * and provides a method {@link DefaultServiceConfigSetters#toServiceConfigBuilder(Route, String, HttpService)} @@ -76,6 +79,8 @@ final class DefaultServiceConfigSetters implements ServiceConfigSetters { @Nullable private Path multipartUploadsLocation; @Nullable + private EventLoopGroup serviceWorkerGroup; + @Nullable private ServiceErrorHandler serviceErrorHandler; private Supplier contextHook = NOOP_CONTEXT_HOOK; private final List shutdownSupports = new ArrayList<>(); @@ -232,6 +237,22 @@ public ServiceConfigSetters multipartUploadsLocation(Path multipartUploadsLocati return this; } + @Override + public ServiceConfigSetters serviceWorkerGroup(EventLoopGroup serviceWorkerGroup, + boolean shutdownOnStop) { + this.serviceWorkerGroup = requireNonNull(serviceWorkerGroup, "serviceWorkerGroup"); + if (shutdownOnStop) { + shutdownSupports.add(ShutdownSupport.of(serviceWorkerGroup)); + } + return this; + } + + @Override + public ServiceConfigSetters serviceWorkerGroup(int numThreads) { + final EventLoopGroup workerGroup = EventLoopGroups.newEventLoopGroup(numThreads); + return serviceWorkerGroup(workerGroup, true); + } + @Override public ServiceConfigSetters requestIdGenerator( Function requestIdGenerator) { @@ -356,6 +377,10 @@ ServiceConfigBuilder toServiceConfigBuilder(Route route, String contextPath, Htt if (multipartUploadsLocation != null) { serviceConfigBuilder.multipartUploadsLocation(multipartUploadsLocation); } + if (serviceWorkerGroup != null) { + serviceConfigBuilder.serviceWorkerGroup(serviceWorkerGroup, false); + // Set the serviceWorkerGroup as false because it's shut down in ShutdownSupport. + } if (requestIdGenerator != null) { serviceConfigBuilder.requestIdGenerator(requestIdGenerator); } diff --git a/core/src/main/java/com/linecorp/armeria/server/HttpServerHandler.java b/core/src/main/java/com/linecorp/armeria/server/HttpServerHandler.java index 045842f9ae6..6e4d0b762ab 100644 --- a/core/src/main/java/com/linecorp/armeria/server/HttpServerHandler.java +++ b/core/src/main/java/com/linecorp/armeria/server/HttpServerHandler.java @@ -76,6 +76,7 @@ import io.netty.channel.ChannelInboundHandlerAdapter; import io.netty.channel.ChannelPipeline; import io.netty.channel.EventLoop; +import io.netty.channel.EventLoopGroup; import io.netty.channel.socket.ChannelInputShutdownReadComplete; import io.netty.handler.codec.http2.Http2Connection; import io.netty.handler.codec.http2.Http2Exception; @@ -343,12 +344,14 @@ private void handleRequest(ChannelHandlerContext ctx, DecodedHttpRequest req) th final InetSocketAddress localAddress = firstNonNull(localAddress(channel), UNKNOWN_ADDR); final ProxiedAddresses proxiedAddresses = determineProxiedAddresses(remoteAddress, headers); final InetAddress clientAddress = config.clientAddressMapper().apply(proxiedAddresses).getAddress(); + final EventLoop channelEventLoop = channel.eventLoop(); final RoutingContext routingCtx = req.routingContext(); final RoutingStatus routingStatus = routingCtx.status(); if (!routingStatus.routeMustExist()) { final ServiceRequestContext reqCtx = newEarlyRespondingRequestContext( - channel, req, proxiedAddresses, clientAddress, remoteAddress, localAddress, routingCtx); + channel, req, proxiedAddresses, clientAddress, remoteAddress, localAddress, routingCtx, + channelEventLoop); // Handle 'OPTIONS * HTTP/1.1'. if (routingStatus == RoutingStatus.OPTIONS) { @@ -367,17 +370,88 @@ private void handleRequest(ChannelHandlerContext ctx, DecodedHttpRequest req) th final RoutingResult routingResult = routed.routingResult(); final ServiceConfig serviceCfg = routed.value(); final HttpService service = serviceCfg.service(); - + final EventLoop serviceEventLoop; + final boolean needsDirectExecution; + final EventLoopGroup serviceWorkerGroup = serviceCfg.serviceWorkerGroup(); + if (serviceWorkerGroup == config.workerGroup()) { + serviceEventLoop = channelEventLoop; + needsDirectExecution = true; + } else { + serviceEventLoop = serviceWorkerGroup.next(); + needsDirectExecution = serviceEventLoop == channelEventLoop; + } final DefaultServiceRequestContext reqCtx = new DefaultServiceRequestContext( - serviceCfg, channel, config.meterRegistry(), protocol, + serviceCfg, channel, serviceEventLoop, config.meterRegistry(), protocol, nextRequestId(routingCtx, serviceCfg), routingCtx, routingResult, req.exchangeType(), req, sslSession, proxiedAddresses, clientAddress, remoteAddress, localAddress, req.requestStartTimeNanos(), req.requestStartTimeMicros(), serviceCfg.contextHook()); + final HttpResponse res; + req.init(reqCtx); + if (needsDirectExecution) { + res = serve0(req, serviceCfg, service, reqCtx); + } else { + res = HttpResponse.of(() -> serve0(req.subscribeOn(serviceEventLoop), serviceCfg, service, reqCtx), + serviceEventLoop) + .subscribeOn(serviceEventLoop); + } + + // Keep track of the number of unfinished requests and + // clean up the request stream when response stream ends. + final boolean isTransientService = + serviceCfg.service().as(TransientService.class) != null; + if (!isTransientService) { + gracefulShutdownSupport.inc(); + } + unfinishedRequests.put(req, res); + + if (service.shouldCachePath(routingCtx.path(), routingCtx.query(), routed.route())) { + reqCtx.log().whenComplete().thenAccept(log -> { + final int statusCode = log.responseHeaders().status().code(); + if (statusCode >= 200 && statusCode < 400) { + RequestTargetCache.putForServer(req.path(), routingCtx.requestTarget()); + } + }); + } + + final RequestAndResponseCompleteHandler handler = + new RequestAndResponseCompleteHandler(channelEventLoop, ctx, reqCtx, req, + isTransientService); + req.whenComplete().handle(handler.requestCompleteHandler); + + // A future which is completed when the all response objects are written to channel and + // the returned promises are done. + final CompletableFuture resWriteFuture = new CompletableFuture<>(); + resWriteFuture.handle(handler.responseCompleteHandler); + + // Set the response to the request in order to be able to immediately abort the response + // when the peer cancels the stream. + req.setResponse(res); + + if (req.isHttp1WebSocket()) { + assert responseEncoder instanceof Http1ObjectEncoder; + final WebSocketHttp1ResponseSubscriber resSubscriber = + new WebSocketHttp1ResponseSubscriber(ctx, responseEncoder, reqCtx, req, resWriteFuture); + res.subscribe(resSubscriber, channelEventLoop, SubscriptionOption.WITH_POOLED_OBJECTS); + } else if (reqCtx.exchangeType().isResponseStreaming()) { + final AbstractHttpResponseSubscriber resSubscriber = + new HttpResponseSubscriber(ctx, responseEncoder, reqCtx, req, resWriteFuture); + res.subscribe(resSubscriber, channelEventLoop, SubscriptionOption.WITH_POOLED_OBJECTS); + } else { + final AggregatedHttpResponseHandler resHandler = + new AggregatedHttpResponseHandler(ctx, responseEncoder, reqCtx, req, resWriteFuture); + res.aggregate(AggregationOptions.usePooledObjects(ctx.alloc(), channelEventLoop)) + .handle(resHandler); + } + } + + private HttpResponse serve0(HttpRequest req, + ServiceConfig serviceCfg, + HttpService service, + DefaultServiceRequestContext reqCtx) { try (SafeCloseable ignored = reqCtx.push()) { HttpResponse serviceResponse; try { - req.init(reqCtx); serviceResponse = service.serve(reqCtx, req); } catch (Throwable cause) { // No need to consume further since the response is ready. @@ -394,56 +468,7 @@ private void handleRequest(ChannelHandlerContext ctx, DecodedHttpRequest req) th // Recover the failed response with the error handler. return serviceCfg.errorHandler().onServiceException(reqCtx, cause); }); - final HttpResponse res = serviceResponse; - final EventLoop eventLoop = channel.eventLoop(); - - // Keep track of the number of unfinished requests and - // clean up the request stream when response stream ends. - final boolean isTransientService = - serviceCfg.service().as(TransientService.class) != null; - if (!isTransientService) { - gracefulShutdownSupport.inc(); - } - unfinishedRequests.put(req, res); - - if (service.shouldCachePath(routingCtx.path(), routingCtx.query(), routed.route())) { - reqCtx.log().whenComplete().thenAccept(log -> { - final int statusCode = log.responseHeaders().status().code(); - if (statusCode >= 200 && statusCode < 400) { - RequestTargetCache.putForServer(req.path(), routingCtx.requestTarget()); - } - }); - } - - final RequestAndResponseCompleteHandler handler = - new RequestAndResponseCompleteHandler(eventLoop, ctx, reqCtx, req, - isTransientService); - req.whenComplete().handle(handler.requestCompleteHandler); - - // A future which is completed when the all response objects are written to channel and - // the returned promises are done. - final CompletableFuture resWriteFuture = new CompletableFuture<>(); - resWriteFuture.handle(handler.responseCompleteHandler); - - // Set the response to the request in order to be able to immediately abort the response - // when the peer cancels the stream. - req.setResponse(res); - - if (req.isHttp1WebSocket()) { - assert responseEncoder instanceof Http1ObjectEncoder; - final WebSocketHttp1ResponseSubscriber resSubscriber = - new WebSocketHttp1ResponseSubscriber(ctx, responseEncoder, reqCtx, req, resWriteFuture); - res.subscribe(resSubscriber, eventLoop, SubscriptionOption.WITH_POOLED_OBJECTS); - } else if (reqCtx.exchangeType().isResponseStreaming()) { - final AbstractHttpResponseSubscriber resSubscriber = - new HttpResponseSubscriber(ctx, responseEncoder, reqCtx, req, resWriteFuture); - res.subscribe(resSubscriber, eventLoop, SubscriptionOption.WITH_POOLED_OBJECTS); - } else { - final AggregatedHttpResponseHandler resHandler = - new AggregatedHttpResponseHandler(ctx, responseEncoder, reqCtx, req, resWriteFuture); - res.aggregate(AggregationOptions.usePooledObjects(ctx.alloc(), eventLoop)) - .handle(resHandler); - } + return serviceResponse; } } @@ -612,14 +637,15 @@ private ServiceRequestContext newEarlyRespondingRequestContext(Channel channel, InetAddress clientAddress, InetSocketAddress remoteAddress, InetSocketAddress localAddress, - RoutingContext routingCtx) { + RoutingContext routingCtx, + EventLoop eventLoop) { final ServiceConfig serviceConfig = routingCtx.virtualHost().fallbackServiceConfig(); final RoutingResult routingResult = RoutingResult.builder() .path(routingCtx.path()) .build(); return new DefaultServiceRequestContext( serviceConfig, - channel, NoopMeterRegistry.get(), protocol(), + channel, eventLoop, NoopMeterRegistry.get(), protocol(), nextRequestId(routingCtx, serviceConfig), routingCtx, routingResult, req.exchangeType(), req, sslSession, proxiedAddresses, clientAddress, remoteAddress, localAddress, System.nanoTime(), SystemInfo.currentTimeMicros(), NOOP_CONTEXT_HOOK); diff --git a/core/src/main/java/com/linecorp/armeria/server/ServerBuilder.java b/core/src/main/java/com/linecorp/armeria/server/ServerBuilder.java index e3db202697b..ed91d00fbc6 100644 --- a/core/src/main/java/com/linecorp/armeria/server/ServerBuilder.java +++ b/core/src/main/java/com/linecorp/armeria/server/ServerBuilder.java @@ -190,7 +190,7 @@ public final class ServerBuilder implements TlsSetters, ServiceConfigsBuilder { private final VirtualHostBuilder defaultVirtualHostBuilder = new VirtualHostBuilder(this, true); private final List virtualHostBuilders = new ArrayList<>(); - private EventLoopGroup workerGroup = CommonPools.workerGroup(); + EventLoopGroup workerGroup = CommonPools.workerGroup(); private boolean shutdownWorkerGroupOnStop; private Executor startStopExecutor = START_STOP_EXECUTOR; private final Map, Object> channelOptions = new Object2ObjectArrayMap<>(); @@ -537,6 +537,34 @@ public ServerBuilder workerGroup(int numThreads) { return this; } + /** + * Sets the worker {@link EventLoopGroup} which is responsible for running + * {@link Service#serve(ServiceRequestContext, Request)}. + * If not set, the value set via {@linkplain #workerGroup(EventLoopGroup, boolean)} + * or {@linkplain #workerGroup(int)} is used. + * + * @param shutdownOnStop whether to shut down the worker {@link EventLoopGroup} + * when the {@link Server} stops + */ + @UnstableApi + public ServerBuilder serviceWorkerGroup(EventLoopGroup serviceWorkerGroup, boolean shutdownOnStop) { + virtualHostTemplate.serviceWorkerGroup(serviceWorkerGroup, shutdownOnStop); + return this; + } + + /** + * Uses a newly created {@link EventLoopGroup} with the specified number of threads for + * running {@link Service#serve(ServiceRequestContext, Request)}. + * The worker {@link EventLoopGroup} will be shut down when the {@link Server} stops. + * + * @param numThreads the number of event loop threads + */ + @UnstableApi + public ServerBuilder serviceWorkerGroup(int numThreads) { + virtualHostTemplate.serviceWorkerGroup(EventLoopGroups.newEventLoopGroup(numThreads), true); + return this; + } + /** * Sets the {@link Executor} which will invoke the callbacks of {@link Server#start()}, * {@link Server#stop()} and {@link ServerListener}. @@ -2200,8 +2228,8 @@ private DefaultServerConfig buildServerConfig(List serverPorts) { final Map, Object> newChildChannelOptions = ChannelUtil.applyDefaultChannelOptions( childChannelOptions, idleTimeoutMillis, pingIntervalMillis); - final BlockingTaskExecutor blockingTaskExecutor = defaultVirtualHost.blockingTaskExecutor(); + return new DefaultServerConfig( ports, setSslContextIfAbsent(defaultVirtualHost, defaultSslContext), virtualHosts, workerGroup, shutdownWorkerGroupOnStop, startStopExecutor, maxNumConnections, diff --git a/core/src/main/java/com/linecorp/armeria/server/ServerConfig.java b/core/src/main/java/com/linecorp/armeria/server/ServerConfig.java index 0e14ce03e23..0bbb20aa235 100644 --- a/core/src/main/java/com/linecorp/armeria/server/ServerConfig.java +++ b/core/src/main/java/com/linecorp/armeria/server/ServerConfig.java @@ -103,8 +103,7 @@ public interface ServerConfig { List serviceConfigs(); /** - * Returns the worker {@link EventLoopGroup} which is responsible for performing socket I/O and running - * {@link Service#serve(ServiceRequestContext, Request)}. + * Returns the worker {@link EventLoopGroup} which is responsible for performing socket I/O. */ EventLoopGroup workerGroup(); diff --git a/core/src/main/java/com/linecorp/armeria/server/ServiceBindingBuilder.java b/core/src/main/java/com/linecorp/armeria/server/ServiceBindingBuilder.java index 7ea369b49c5..60e77711e4e 100644 --- a/core/src/main/java/com/linecorp/armeria/server/ServiceBindingBuilder.java +++ b/core/src/main/java/com/linecorp/armeria/server/ServiceBindingBuilder.java @@ -35,6 +35,8 @@ import com.linecorp.armeria.common.util.BlockingTaskExecutor; import com.linecorp.armeria.server.logging.AccessLogWriter; +import io.netty.channel.EventLoopGroup; + /** * A builder class for binding an {@link HttpService} fluently. This class can be instantiated through * {@link ServerBuilder#route()}. You can also configure an {@link HttpService} using @@ -291,6 +293,17 @@ public ServiceBindingBuilder multipartUploadsLocation(Path multipartUploadsLocat return (ServiceBindingBuilder) super.multipartUploadsLocation(multipartUploadsLocation); } + @Override + public ServiceBindingBuilder serviceWorkerGroup(EventLoopGroup serviceWorkerGroup, + boolean shutdownOnStop) { + return (ServiceBindingBuilder) super.serviceWorkerGroup(serviceWorkerGroup, shutdownOnStop); + } + + @Override + public ServiceBindingBuilder serviceWorkerGroup(int numThreads) { + return (ServiceBindingBuilder) super.serviceWorkerGroup(numThreads); + } + @Override public ServiceBindingBuilder requestIdGenerator( Function requestIdGenerator) { diff --git a/core/src/main/java/com/linecorp/armeria/server/ServiceConfig.java b/core/src/main/java/com/linecorp/armeria/server/ServiceConfig.java index 3aaf0475889..0a27ab662eb 100644 --- a/core/src/main/java/com/linecorp/armeria/server/ServiceConfig.java +++ b/core/src/main/java/com/linecorp/armeria/server/ServiceConfig.java @@ -46,6 +46,8 @@ import com.linecorp.armeria.server.cors.CorsService; import com.linecorp.armeria.server.logging.AccessLogWriter; +import io.netty.channel.EventLoopGroup; + /** * An {@link HttpService} configuration. * @@ -79,6 +81,8 @@ public final class ServiceConfig { private final long requestAutoAbortDelayMillis; private final Path multipartUploadsLocation; + private final EventLoopGroup serviceWorkerGroup; + private final List shutdownSupports; private final HttpHeaders defaultHeaders; private final Function requestIdGenerator; @@ -94,7 +98,8 @@ public final class ServiceConfig { boolean verboseResponses, AccessLogWriter accessLogWriter, BlockingTaskExecutor blockingTaskExecutor, SuccessFunction successFunction, long requestAutoAbortDelayMillis, - Path multipartUploadsLocation, List shutdownSupports, + Path multipartUploadsLocation, EventLoopGroup serviceWorkerGroup, + List shutdownSupports, HttpHeaders defaultHeaders, Function requestIdGenerator, ServiceErrorHandler serviceErrorHandler, Supplier contextHook) { @@ -102,7 +107,7 @@ public final class ServiceConfig { requestTimeoutMillis, maxRequestLength, verboseResponses, accessLogWriter, extractTransientServiceOptions(service), blockingTaskExecutor, successFunction, requestAutoAbortDelayMillis, - multipartUploadsLocation, shutdownSupports, defaultHeaders, + multipartUploadsLocation, serviceWorkerGroup, shutdownSupports, defaultHeaders, requestIdGenerator, serviceErrorHandler, contextHook); } @@ -119,6 +124,7 @@ private ServiceConfig(@Nullable VirtualHost virtualHost, Route route, SuccessFunction successFunction, long requestAutoAbortDelayMillis, Path multipartUploadsLocation, + EventLoopGroup serviceWorkerGroup, List shutdownSupports, HttpHeaders defaultHeaders, Function requestIdGenerator, ServiceErrorHandler serviceErrorHandler, @@ -139,6 +145,7 @@ private ServiceConfig(@Nullable VirtualHost virtualHost, Route route, this.successFunction = requireNonNull(successFunction, "successFunction"); this.requestAutoAbortDelayMillis = requestAutoAbortDelayMillis; this.multipartUploadsLocation = requireNonNull(multipartUploadsLocation, "multipartUploadsLocation"); + this.serviceWorkerGroup = requireNonNull(serviceWorkerGroup, "serviceWorkerGroup"); this.shutdownSupports = ImmutableList.copyOf(requireNonNull(shutdownSupports, "shutdownSupports")); this.defaultHeaders = defaultHeaders; @SuppressWarnings("unchecked") @@ -185,7 +192,7 @@ ServiceConfig withVirtualHost(VirtualHost virtualHost) { defaultServiceNaming, requestTimeoutMillis, maxRequestLength, verboseResponses, accessLogWriter, transientServiceOptions, blockingTaskExecutor, successFunction, requestAutoAbortDelayMillis, - multipartUploadsLocation, shutdownSupports, defaultHeaders, + multipartUploadsLocation, serviceWorkerGroup, shutdownSupports, defaultHeaders, requestIdGenerator, serviceErrorHandler, contextHook); } @@ -196,7 +203,7 @@ ServiceConfig withDecoratedService(Function shutdownSupports() { return shutdownSupports; } + /** + * Returns the {@link EventLoopGroup} dedicated to the execution of services' methods. + */ + @UnstableApi + public EventLoopGroup serviceWorkerGroup() { + return serviceWorkerGroup; + } + /** * Returns the default headers for an {@link HttpResponse} served by the {@link #service()}. */ diff --git a/core/src/main/java/com/linecorp/armeria/server/ServiceConfigBuilder.java b/core/src/main/java/com/linecorp/armeria/server/ServiceConfigBuilder.java index 4a78a286860..e36affcf759 100644 --- a/core/src/main/java/com/linecorp/armeria/server/ServiceConfigBuilder.java +++ b/core/src/main/java/com/linecorp/armeria/server/ServiceConfigBuilder.java @@ -42,10 +42,13 @@ import com.linecorp.armeria.common.SuccessFunction; import com.linecorp.armeria.common.annotation.Nullable; import com.linecorp.armeria.common.util.BlockingTaskExecutor; +import com.linecorp.armeria.common.util.EventLoopGroups; import com.linecorp.armeria.internal.common.websocket.WebSocketUtil; import com.linecorp.armeria.internal.server.websocket.DefaultWebSocketService; import com.linecorp.armeria.server.logging.AccessLogWriter; +import io.netty.channel.EventLoopGroup; + final class ServiceConfigBuilder implements ServiceConfigSetters { private final Route route; @@ -76,6 +79,8 @@ final class ServiceConfigBuilder implements ServiceConfigSetters { @Nullable private Path multipartUploadsLocation; @Nullable + private EventLoopGroup serviceWorkerGroup; + @Nullable private ServiceErrorHandler serviceErrorHandler; private Supplier contextHook = NOOP_CONTEXT_HOOK; private final List shutdownSupports = new ArrayList<>(); @@ -289,6 +294,22 @@ public ServiceConfigBuilder defaultServiceNaming(ServiceNaming defaultServiceNam return this; } + @Override + public ServiceConfigBuilder serviceWorkerGroup(EventLoopGroup serviceWorkerGroup, + boolean shutdownOnStop) { + this.serviceWorkerGroup = requireNonNull(serviceWorkerGroup, "serviceWorkerGroup"); + if (shutdownOnStop) { + shutdownSupports.add(ShutdownSupport.of(serviceWorkerGroup)); + } + return this; + } + + @Override + public ServiceConfigBuilder serviceWorkerGroup(int numThreads) { + final EventLoopGroup workerGroup = EventLoopGroups.newEventLoopGroup(numThreads); + return serviceWorkerGroup(workerGroup, true); + } + void shutdownSupports(List shutdownSupports) { requireNonNull(shutdownSupports, "shutdownSupports"); this.shutdownSupports.addAll(shutdownSupports); @@ -308,6 +329,7 @@ ServiceConfig build(ServiceNaming defaultServiceNaming, SuccessFunction defaultSuccessFunction, long defaultRequestAutoAbortDelayMillis, Path defaultMultipartUploadsLocation, + EventLoopGroup defaultServiceWorkerGroup, HttpHeaders virtualHostDefaultHeaders, Function defaultRequestIdGenerator, ServiceErrorHandler defaultServiceErrorHandler, @@ -366,6 +388,7 @@ ServiceConfig build(ServiceNaming defaultServiceNaming, successFunction != null ? successFunction : defaultSuccessFunction, requestAutoAbortDelayMillis, multipartUploadsLocation != null ? multipartUploadsLocation : defaultMultipartUploadsLocation, + serviceWorkerGroup != null ? serviceWorkerGroup : defaultServiceWorkerGroup, ImmutableList.copyOf(shutdownSupports), mergeDefaultHeaders(virtualHostDefaultHeaders.toBuilder(), defaultHeaders.build()), requestIdGenerator != null ? requestIdGenerator : defaultRequestIdGenerator, errorHandler, @@ -385,6 +408,7 @@ public String toString() { .add("blockingTaskExecutor", blockingTaskExecutor) .add("successFunction", successFunction) .add("multipartUploadsLocation", multipartUploadsLocation) + .add("serviceWorkerGroup", serviceWorkerGroup) .add("shutdownSupports", shutdownSupports) .add("defaultHeaders", defaultHeaders) .add("serviceErrorHandler", serviceErrorHandler) diff --git a/core/src/main/java/com/linecorp/armeria/server/ServiceConfigSetters.java b/core/src/main/java/com/linecorp/armeria/server/ServiceConfigSetters.java index d7b50b0b586..a1066477998 100644 --- a/core/src/main/java/com/linecorp/armeria/server/ServiceConfigSetters.java +++ b/core/src/main/java/com/linecorp/armeria/server/ServiceConfigSetters.java @@ -37,6 +37,8 @@ import com.linecorp.armeria.common.util.BlockingTaskExecutor; import com.linecorp.armeria.server.logging.AccessLogWriter; +import io.netty.channel.EventLoopGroup; + interface ServiceConfigSetters { /** @@ -209,6 +211,27 @@ ServiceConfigSetters blockingTaskExecutor(BlockingTaskExecutor blockingTaskExecu @UnstableApi ServiceConfigSetters multipartUploadsLocation(Path multipartUploadsLocation); + /** + * Sets a {@linkplain EventLoopGroup worker group} to be used when serving a {@link Service}. + * + * @param serviceWorkerGroup the {@linkplain ScheduledExecutorService executor} to be used. + * @param shutdownOnStop whether to shut down the {@link ScheduledExecutorService} when the {@link Server} + * stops. + */ + @UnstableApi + ServiceConfigSetters serviceWorkerGroup(EventLoopGroup serviceWorkerGroup, + boolean shutdownOnStop); + + /** + * Uses a newly created {@link EventLoopGroup} with the specified number of threads dedicated to + * the execution of service codes. + * The {@link EventLoopGroup} will be shut down when the {@link Server} stops. + * + * @param numThreads the number of threads in the executor + */ + @UnstableApi + ServiceConfigSetters serviceWorkerGroup(int numThreads); + /** * Sets the {@link Function} which generates a {@link RequestId}. * diff --git a/core/src/main/java/com/linecorp/armeria/server/ServiceRequestContextBuilder.java b/core/src/main/java/com/linecorp/armeria/server/ServiceRequestContextBuilder.java index 98a782d544d..4e0c71bb18f 100644 --- a/core/src/main/java/com/linecorp/armeria/server/ServiceRequestContextBuilder.java +++ b/core/src/main/java/com/linecorp/armeria/server/ServiceRequestContextBuilder.java @@ -236,10 +236,12 @@ public ServiceRequestContext build() { requestCancellationScheduler.initAndStart(eventLoop(), noopCancellationTask); } + final EventLoop serviceWorkerGroup = eventLoop(); + // Build the context with the properties set by a user and the fake objects. final Channel ch = fakeChannel(); return new DefaultServiceRequestContext( - serviceCfg, ch, meterRegistry(), sessionProtocol(), id(), routingCtx, + serviceCfg, ch, serviceWorkerGroup, meterRegistry(), sessionProtocol(), id(), routingCtx, routingResult, exchangeType, req, sslSession(), proxiedAddresses, clientAddress, remoteAddress(), localAddress(), requestCancellationScheduler, diff --git a/core/src/main/java/com/linecorp/armeria/server/VirtualHost.java b/core/src/main/java/com/linecorp/armeria/server/VirtualHost.java index 49eaa6eaed2..96027f72c0f 100644 --- a/core/src/main/java/com/linecorp/armeria/server/VirtualHost.java +++ b/core/src/main/java/com/linecorp/armeria/server/VirtualHost.java @@ -39,6 +39,7 @@ import com.linecorp.armeria.common.RequestId; import com.linecorp.armeria.common.SuccessFunction; import com.linecorp.armeria.common.annotation.Nullable; +import com.linecorp.armeria.common.annotation.UnstableApi; import com.linecorp.armeria.common.logging.RequestLog; import com.linecorp.armeria.common.logging.RequestLogBuilder; import com.linecorp.armeria.common.metric.MeterIdPrefix; @@ -46,6 +47,7 @@ import com.linecorp.armeria.server.logging.AccessLogWriter; import io.micrometer.core.instrument.MeterRegistry; +import io.netty.channel.EventLoopGroup; import io.netty.handler.ssl.SslContext; import io.netty.util.Mapping; @@ -95,6 +97,7 @@ public final class VirtualHost { private final long requestAutoAbortDelayMillis; private final SuccessFunction successFunction; private final Path multipartUploadsLocation; + private final EventLoopGroup serviceWorkerGroup; private final List shutdownSupports; private final Function requestIdGenerator; @@ -113,6 +116,7 @@ public final class VirtualHost { long requestAutoAbortDelayMillis, SuccessFunction successFunction, Path multipartUploadsLocation, + EventLoopGroup serviceWorkerGroup, List shutdownSupports, Function requestIdGenerator) { originalDefaultHostname = defaultHostname; @@ -136,6 +140,7 @@ public final class VirtualHost { this.requestAutoAbortDelayMillis = requestAutoAbortDelayMillis; this.successFunction = successFunction; this.multipartUploadsLocation = multipartUploadsLocation; + this.serviceWorkerGroup = serviceWorkerGroup; this.shutdownSupports = shutdownSupports; @SuppressWarnings("unchecked") final Function castRequestIdGenerator = @@ -162,7 +167,8 @@ VirtualHost withNewSslContext(SslContext sslContext) { host -> accessLogger, defaultServiceNaming, defaultLogName, requestTimeoutMillis, maxRequestLength, verboseResponses, accessLogWriter, blockingTaskExecutor, requestAutoAbortDelayMillis, - successFunction, multipartUploadsLocation, shutdownSupports, requestIdGenerator); + successFunction, multipartUploadsLocation, serviceWorkerGroup, + shutdownSupports, requestIdGenerator); } /** @@ -400,6 +406,16 @@ public boolean shutdownBlockingTaskExecutorOnStop() { return false; } + /** + * Returns the service {@link EventLoopGroup}. + * + * @see ServiceConfig#serviceWorkerGroup() + */ + @UnstableApi + public EventLoopGroup serviceWorkerGroup() { + return serviceWorkerGroup; + } + /** * Returns the {@link SuccessFunction} that determines whether a request was * handled successfully or not. @@ -530,7 +546,7 @@ VirtualHost decorate(@Nullable Function accessLogger, defaultServiceNaming, defaultLogName, requestTimeoutMillis, maxRequestLength, verboseResponses, accessLogWriter, blockingTaskExecutor, requestAutoAbortDelayMillis, successFunction, multipartUploadsLocation, - shutdownSupports, requestIdGenerator); + serviceWorkerGroup, shutdownSupports, requestIdGenerator); } @Override diff --git a/core/src/main/java/com/linecorp/armeria/server/VirtualHostAnnotatedServiceBindingBuilder.java b/core/src/main/java/com/linecorp/armeria/server/VirtualHostAnnotatedServiceBindingBuilder.java index 9628dcf3826..2e903e9ae93 100644 --- a/core/src/main/java/com/linecorp/armeria/server/VirtualHostAnnotatedServiceBindingBuilder.java +++ b/core/src/main/java/com/linecorp/armeria/server/VirtualHostAnnotatedServiceBindingBuilder.java @@ -31,6 +31,8 @@ import com.linecorp.armeria.server.annotation.ResponseConverterFunction; import com.linecorp.armeria.server.logging.AccessLogWriter; +import io.netty.channel.EventLoopGroup; + /** * A builder class for binding an {@link HttpService} to a virtual host fluently. This class can be instantiated * through {@link VirtualHostBuilder#annotatedService()}. @@ -268,6 +270,18 @@ public VirtualHostAnnotatedServiceBindingBuilder contextHook( return (VirtualHostAnnotatedServiceBindingBuilder) super.contextHook(contextHook); } + @Override + public VirtualHostAnnotatedServiceBindingBuilder serviceWorkerGroup(EventLoopGroup serviceWorkerGroup, + boolean shutdownOnStop) { + return (VirtualHostAnnotatedServiceBindingBuilder) super.serviceWorkerGroup(serviceWorkerGroup, + shutdownOnStop); + } + + @Override + public VirtualHostAnnotatedServiceBindingBuilder serviceWorkerGroup(int numThreads) { + return (VirtualHostAnnotatedServiceBindingBuilder) super.serviceWorkerGroup(numThreads); + } + /** * Registers the given service to the {@linkplain VirtualHostBuilder}. * diff --git a/core/src/main/java/com/linecorp/armeria/server/VirtualHostBuilder.java b/core/src/main/java/com/linecorp/armeria/server/VirtualHostBuilder.java index 0d952db5b4e..a64b58e125a 100644 --- a/core/src/main/java/com/linecorp/armeria/server/VirtualHostBuilder.java +++ b/core/src/main/java/com/linecorp/armeria/server/VirtualHostBuilder.java @@ -82,6 +82,7 @@ import com.linecorp.armeria.common.logging.RequestLog; import com.linecorp.armeria.common.logging.RequestLogBuilder; import com.linecorp.armeria.common.util.BlockingTaskExecutor; +import com.linecorp.armeria.common.util.EventLoopGroups; import com.linecorp.armeria.common.util.SystemInfo; import com.linecorp.armeria.internal.common.util.SelfSignedCertificate; import com.linecorp.armeria.internal.server.RouteDecoratingService; @@ -164,6 +165,8 @@ public final class VirtualHostBuilder implements TlsSetters, ServiceConfigsBuild @Nullable private Path multipartUploadsLocation; @Nullable + private EventLoopGroup serviceWorkerGroup; + @Nullable private Function requestIdGenerator; @Nullable private ServiceErrorHandler errorHandler; @@ -1194,6 +1197,36 @@ public VirtualHostBuilder requestIdGenerator( return this; } + /** + * Sets the {@link EventLoopGroup} dedicated to the execution of services' methods. + * If not set, the work group of the belonging channel is used. + * + * @param shutdownOnStop whether to shut down the {@link EventLoopGroup} when the + * {@link Server} stops + */ + @UnstableApi + public VirtualHostBuilder serviceWorkerGroup(EventLoopGroup serviceWorkerGroup, + boolean shutdownOnStop) { + this.serviceWorkerGroup = requireNonNull(serviceWorkerGroup, "serviceWorkerGroup"); + if (shutdownOnStop) { + shutdownSupports.add(ShutdownSupport.of(serviceWorkerGroup)); + } + return this; + } + + /** + * Uses a newly created {@link EventLoopGroup} with the specified number of threads dedicated to + * the execution of services' methods. + * The worker {@link EventLoopGroup} will be shut down when the {@link Server} stops. + * + * @param numThreads the number of threads in the executor + */ + @UnstableApi + public VirtualHostBuilder serviceWorkerGroup(int numThreads) { + final EventLoopGroup workerGroup = EventLoopGroups.newEventLoopGroup(numThreads); + return serviceWorkerGroup(workerGroup, true); + } + /** * Sets the {@link RequestConverterFunction}s, {@link ResponseConverterFunction} * and {@link ExceptionHandlerFunction}s for creating an {@link AnnotatedServiceExtensions}. @@ -1326,6 +1359,15 @@ VirtualHost build(VirtualHostBuilder template, DependencyInjector dependencyInje final Supplier contextHook = mergeHooks(template.contextHook, this.contextHook); + final EventLoopGroup serviceWorkerGroup; + if (this.serviceWorkerGroup != null) { + serviceWorkerGroup = this.serviceWorkerGroup; + } else if (template.serviceWorkerGroup != null) { + serviceWorkerGroup = template.serviceWorkerGroup; + } else { + serviceWorkerGroup = serverBuilder.workerGroup; + } + assert defaultServiceNaming != null; assert rejectedRouteHandler != null; assert accessLoggerMapper != null; @@ -1352,7 +1394,7 @@ VirtualHost build(VirtualHostBuilder template, DependencyInjector dependencyInje return cfgBuilder.build(defaultServiceNaming, requestTimeoutMillis, maxRequestLength, verboseResponses, accessLogWriter, blockingTaskExecutor, successFunction, requestAutoAbortDelayMillis, - multipartUploadsLocation, defaultHeaders, + multipartUploadsLocation, serviceWorkerGroup, defaultHeaders, requestIdGenerator, defaultErrorHandler, unhandledExceptionsReporter, baseContextPath, contextHook); }).collect(toImmutableList()); @@ -1361,7 +1403,7 @@ VirtualHost build(VirtualHostBuilder template, DependencyInjector dependencyInje new ServiceConfigBuilder(RouteBuilder.FALLBACK_ROUTE, "/", FallbackService.INSTANCE) .build(defaultServiceNaming, requestTimeoutMillis, maxRequestLength, verboseResponses, accessLogWriter, blockingTaskExecutor, successFunction, - requestAutoAbortDelayMillis, multipartUploadsLocation, + requestAutoAbortDelayMillis, multipartUploadsLocation, serviceWorkerGroup, defaultHeaders, requestIdGenerator, defaultErrorHandler, unhandledExceptionsReporter, "/", contextHook); @@ -1375,7 +1417,7 @@ VirtualHost build(VirtualHostBuilder template, DependencyInjector dependencyInje accessLoggerMapper, defaultServiceNaming, defaultLogName, requestTimeoutMillis, maxRequestLength, verboseResponses, accessLogWriter, blockingTaskExecutor, requestAutoAbortDelayMillis, successFunction, multipartUploadsLocation, - builder.build(), requestIdGenerator); + serviceWorkerGroup, builder.build(), requestIdGenerator); final Function decorator = getRouteDecoratingService(template, baseContextPath); diff --git a/core/src/main/java/com/linecorp/armeria/server/VirtualHostContextPathAnnotatedServiceConfigSetters.java b/core/src/main/java/com/linecorp/armeria/server/VirtualHostContextPathAnnotatedServiceConfigSetters.java index 0d0365dd0c0..e58dc62d644 100644 --- a/core/src/main/java/com/linecorp/armeria/server/VirtualHostContextPathAnnotatedServiceConfigSetters.java +++ b/core/src/main/java/com/linecorp/armeria/server/VirtualHostContextPathAnnotatedServiceConfigSetters.java @@ -32,6 +32,8 @@ import com.linecorp.armeria.server.annotation.ResponseConverterFunction; import com.linecorp.armeria.server.logging.AccessLogWriter; +import io.netty.channel.EventLoopGroup; + /** * A {@link VirtualHostContextPathAnnotatedServiceConfigSetters} builder which configures * an {@link AnnotatedService} under a set of context paths. @@ -244,6 +246,18 @@ public VirtualHostContextPathAnnotatedServiceConfigSetters requestAutoAbortDelay super.requestAutoAbortDelayMillis(delayMillis); } + @Override + public VirtualHostContextPathAnnotatedServiceConfigSetters serviceWorkerGroup( + EventLoopGroup serviceWorkerGroup, boolean shutdownOnStop) { + return (VirtualHostContextPathAnnotatedServiceConfigSetters) + super.serviceWorkerGroup(serviceWorkerGroup, shutdownOnStop); + } + + @Override + public VirtualHostContextPathAnnotatedServiceConfigSetters serviceWorkerGroup(int numThreads) { + return (VirtualHostContextPathAnnotatedServiceConfigSetters) super.serviceWorkerGroup(numThreads); + } + @Override public VirtualHostContextPathAnnotatedServiceConfigSetters multipartUploadsLocation( Path multipartUploadsLocation) { diff --git a/core/src/main/java/com/linecorp/armeria/server/VirtualHostContextPathServiceBindingBuilder.java b/core/src/main/java/com/linecorp/armeria/server/VirtualHostContextPathServiceBindingBuilder.java index 42e53ae482d..7d930cf5560 100644 --- a/core/src/main/java/com/linecorp/armeria/server/VirtualHostContextPathServiceBindingBuilder.java +++ b/core/src/main/java/com/linecorp/armeria/server/VirtualHostContextPathServiceBindingBuilder.java @@ -32,6 +32,8 @@ import com.linecorp.armeria.common.util.BlockingTaskExecutor; import com.linecorp.armeria.server.logging.AccessLogWriter; +import io.netty.channel.EventLoopGroup; + /** * A builder class for binding an {@link HttpService} fluently. This class can be instantiated through * {@link VirtualHostBuilder#contextPath(String...)}. @@ -178,6 +180,18 @@ public VirtualHostContextPathServiceBindingBuilder multipartUploadsLocation( super.multipartUploadsLocation(multipartUploadsLocation); } + @Override + public VirtualHostContextPathServiceBindingBuilder serviceWorkerGroup(EventLoopGroup serviceWorkerGroup, + boolean shutdownOnStop) { + return (VirtualHostContextPathServiceBindingBuilder) super.serviceWorkerGroup(serviceWorkerGroup, + shutdownOnStop); + } + + @Override + public VirtualHostContextPathServiceBindingBuilder serviceWorkerGroup(int numThreads) { + return (VirtualHostContextPathServiceBindingBuilder) super.serviceWorkerGroup(numThreads); + } + @Override public VirtualHostContextPathServiceBindingBuilder requestIdGenerator( Function requestIdGenerator) { diff --git a/core/src/main/java/com/linecorp/armeria/server/VirtualHostServiceBindingBuilder.java b/core/src/main/java/com/linecorp/armeria/server/VirtualHostServiceBindingBuilder.java index b0180a95ac2..8a8f10464eb 100644 --- a/core/src/main/java/com/linecorp/armeria/server/VirtualHostServiceBindingBuilder.java +++ b/core/src/main/java/com/linecorp/armeria/server/VirtualHostServiceBindingBuilder.java @@ -34,6 +34,8 @@ import com.linecorp.armeria.common.util.BlockingTaskExecutor; import com.linecorp.armeria.server.logging.AccessLogWriter; +import io.netty.channel.EventLoopGroup; + /** * A builder class for binding an {@link HttpService} fluently. This class can be instantiated through * {@link VirtualHostBuilder#route()}. You can also configure an {@link HttpService} using @@ -315,6 +317,17 @@ public VirtualHostServiceBindingBuilder multipartUploadsLocation(Path multipartU return (VirtualHostServiceBindingBuilder) super.multipartUploadsLocation(multipartUploadsLocation); } + @Override + public VirtualHostServiceBindingBuilder serviceWorkerGroup(EventLoopGroup serviceWorkerGroup, + boolean shutdownOnStop) { + return (VirtualHostServiceBindingBuilder) super.serviceWorkerGroup(serviceWorkerGroup, shutdownOnStop); + } + + @Override + public VirtualHostServiceBindingBuilder serviceWorkerGroup(int numThreads) { + return (VirtualHostServiceBindingBuilder) super.serviceWorkerGroup(numThreads); + } + @Override public VirtualHostServiceBindingBuilder requestIdGenerator( Function requestIdGenerator) { diff --git a/core/src/test/java/com/linecorp/armeria/common/stream/SubscribeOnStreamMessageTest.java b/core/src/test/java/com/linecorp/armeria/common/stream/SubscribeOnStreamMessageTest.java new file mode 100644 index 00000000000..602abdd5ec6 --- /dev/null +++ b/core/src/test/java/com/linecorp/armeria/common/stream/SubscribeOnStreamMessageTest.java @@ -0,0 +1,124 @@ +/* + * Copyright 2023 LINE Corporation + * + * LINE Corporation licenses this file to you under the Apache License, + * version 2.0 (the "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at: + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations + * under the License. + */ + +package com.linecorp.armeria.common.stream; + +import static org.assertj.core.api.Assertions.assertThat; +import static org.awaitility.Awaitility.await; + +import java.util.Deque; +import java.util.concurrent.ConcurrentLinkedDeque; + +import org.junit.jupiter.api.Test; +import org.junit.jupiter.api.extension.RegisterExtension; +import org.reactivestreams.Subscriber; +import org.reactivestreams.Subscription; + +import com.google.common.collect.ImmutableList; + +import com.linecorp.armeria.testing.junit5.common.EventLoopExtension; + +import io.netty.util.concurrent.EventExecutor; + +class SubscribeOnStreamMessageTest { + + @RegisterExtension + static EventLoopExtension eventLoop1 = new EventLoopExtension(); + + @RegisterExtension + static EventLoopExtension eventLoop2 = new EventLoopExtension(); + + @Test + void completeCase() { + final EventLoopCheckingStreamMessage upstream = + new EventLoopCheckingStreamMessage<>(eventLoop1.get()); + final Deque acc = new ConcurrentLinkedDeque<>(); + upstream.subscribeOn(eventLoop1.get()) + .subscribe(new EventLoopCheckingSubscriber<>(acc), eventLoop2.get()); + upstream.write(1); + upstream.close(); + + await().untilAsserted(() -> assertThat(acc).containsExactlyElementsOf( + ImmutableList.of("onSubscribe", "1", "onComplete"))); + } + + @Test + void errorCase() { + final EventLoopCheckingStreamMessage upstream = + new EventLoopCheckingStreamMessage<>(eventLoop1.get()); + final Deque acc = new ConcurrentLinkedDeque<>(); + upstream.subscribeOn(eventLoop1.get()) + .subscribe(new EventLoopCheckingSubscriber<>(acc), eventLoop2.get()); + upstream.write(1); + upstream.close(new Throwable()); + + await().untilAsserted(() -> assertThat(acc).containsExactlyElementsOf( + ImmutableList.of("onSubscribe", "1", "onError"))); + } + + static class EventLoopCheckingStreamMessage extends DefaultStreamMessage { + + private final EventExecutor eventLoop; + + EventLoopCheckingStreamMessage(EventExecutor eventLoop) { + this.eventLoop = eventLoop; + } + + @Override + protected void subscribe0(EventExecutor executor, SubscriptionOption[] options) { + assertThat(eventLoop.inEventLoop()).isTrue(); + } + + @Override + protected void onRequest(long n) { + assert eventLoop.inEventLoop(); + } + } + + static class EventLoopCheckingSubscriber implements Subscriber { + + private final Deque acc; + + EventLoopCheckingSubscriber(Deque acc) { + this.acc = acc; + } + + @Override + public void onSubscribe(Subscription s) { + assertThat(eventLoop2.get().inEventLoop()).isTrue(); + s.request(Long.MAX_VALUE); + acc.add("onSubscribe"); + } + + @Override + public void onNext(T t) { + assertThat(eventLoop2.get().inEventLoop()).isTrue(); + acc.add(t.toString()); + } + + @Override + public void onError(Throwable t) { + assertThat(eventLoop2.get().inEventLoop()).isTrue(); + acc.add("onError"); + } + + @Override + public void onComplete() { + assertThat(eventLoop2.get().inEventLoop()).isTrue(); + acc.add("onComplete"); + } + } +} diff --git a/core/src/test/java/com/linecorp/armeria/internal/server/annotation/ServiceWorkerGroupTest.java b/core/src/test/java/com/linecorp/armeria/internal/server/annotation/ServiceWorkerGroupTest.java new file mode 100644 index 00000000000..ddf856e0548 --- /dev/null +++ b/core/src/test/java/com/linecorp/armeria/internal/server/annotation/ServiceWorkerGroupTest.java @@ -0,0 +1,283 @@ +/* + * Copyright 2021 LINE Corporation + * + * LINE Corporation licenses this file to you under the Apache License, + * version 2.0 (the "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at: + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations + * under the License. + */ + +package com.linecorp.armeria.internal.server.annotation; + +import static org.assertj.core.api.Assertions.assertThat; +import static org.awaitility.Awaitility.await; + +import java.util.Queue; +import java.util.concurrent.ArrayBlockingQueue; + +import org.junit.jupiter.api.AfterAll; +import org.junit.jupiter.api.BeforeEach; +import org.junit.jupiter.api.Test; +import org.junit.jupiter.api.extension.RegisterExtension; +import org.junit.jupiter.params.ParameterizedTest; +import org.junit.jupiter.params.provider.ValueSource; +import org.reactivestreams.Publisher; +import org.reactivestreams.Subscriber; +import org.reactivestreams.Subscription; + +import com.linecorp.armeria.common.AggregatedHttpResponse; +import com.linecorp.armeria.common.HttpMethod; +import com.linecorp.armeria.common.HttpObject; +import com.linecorp.armeria.common.HttpResponse; +import com.linecorp.armeria.common.HttpStatus; +import com.linecorp.armeria.common.RequestHeaders; +import com.linecorp.armeria.common.ResponseHeaders; +import com.linecorp.armeria.common.annotation.Nullable; +import com.linecorp.armeria.common.logging.RequestLogProperty; +import com.linecorp.armeria.common.util.ThreadFactories; +import com.linecorp.armeria.server.Server; +import com.linecorp.armeria.server.ServerBuilder; +import com.linecorp.armeria.server.ServiceRequestContext; +import com.linecorp.armeria.server.annotation.Get; +import com.linecorp.armeria.server.annotation.Post; +import com.linecorp.armeria.testing.junit5.server.ServerExtension; + +import io.netty.channel.DefaultEventLoop; +import io.netty.channel.EventLoop; + +class ServiceWorkerGroupTest { + + static final EventLoop aExecutor = + new DefaultEventLoop(ThreadFactories.builder("test-a") + .eventLoop(false) + .build()); + + static final EventLoop defaultExecutor = new DefaultEventLoop( + ThreadFactories.builder("test-default") + .eventLoop(false) + .build()); + + static final EventLoop workerExecutor = new DefaultEventLoop( + ThreadFactories.builder("test-worker") + .eventLoop(false) + .build()); + + static final Queue threadQueue = new ArrayBlockingQueue<>(32); + + @RegisterExtension + static final ServerExtension server = new ServerExtension() { + @Override + protected void configure(ServerBuilder sb) { + sb.annotatedService().serviceWorkerGroup(aExecutor, true) + .build(new MyAnnotatedServiceA()); + sb.annotatedService(new MyAnnotatedServiceDefault()); + sb.service("/ctxLog", (ctx, req) -> { + for (RequestLogProperty property: RequestLogProperty.values()) { + ctx.log().whenAvailable(property).thenRun(() -> { + threadQueue.add(Thread.currentThread()); + }); + } + return HttpResponse.of(200); + }); + sb.service("/subscribe", (ctx, req) -> { + final SignallingPublisher publisher = new SignallingPublisher(); + req.subscribe(new Subscriber() { + @Override + public void onSubscribe(Subscription s) { + threadQueue.add(Thread.currentThread()); + } + + @Override + public void onNext(HttpObject httpObject) { + threadQueue.add(Thread.currentThread()); + } + + @Override + public void onError(Throwable t) { + threadQueue.add(Thread.currentThread()); + } + + @Override + public void onComplete() { + threadQueue.add(Thread.currentThread()); + publisher.markReady(); + } + }); + return HttpResponse.of(publisher); + }); + + sb.serviceWorkerGroup(defaultExecutor, true); + } + }; + + static class SignallingPublisher implements Publisher { + + private boolean ready; + @Nullable + private Subscriber subscriber; + private long request; + + @Override + public void subscribe(Subscriber s) { + this.subscriber = s; + threadQueue.add(Thread.currentThread()); + s.onSubscribe(new Subscription() { + @Override + public void request(long n) { + threadQueue.add(Thread.currentThread()); + request += n; + tryNotify(); + } + + @Override + public void cancel() { + } + }); + } + + void markReady() { + ready = true; + tryNotify(); + } + + private void tryNotify() { + final Subscriber subscriber0 = subscriber; + if (request > 0 && ready && subscriber0 != null) { + subscriber0.onNext(ResponseHeaders.builder(200).endOfStream(true).build()); + subscriber0.onComplete(); + } + } + } + + @BeforeEach + void beforeEach() { + threadQueue.clear(); + } + + @AfterAll + static void afterAll() { + aExecutor.shutdownGracefully(); + defaultExecutor.shutdownGracefully(); + workerExecutor.shutdownGracefully(); + } + + @ParameterizedTest + @ValueSource(strings = {"/a", "/aggregated"}) + void testServiceWorkerGroup(String path) throws InterruptedException { + final AggregatedHttpResponse res = server.blockingWebClient() + .execute(RequestHeaders.of(HttpMethod.GET, path)); + assertThat(res.status()).isSameAs(HttpStatus.OK); + assertThat(server.requestContextCaptor().size()).isEqualTo(1); + final EventLoop ctxEventLoop = server.requestContextCaptor().poll().eventLoop().withoutContext(); + assertThat(ctxEventLoop).isSameAs(aExecutor); + assertThat(threadQueue).allSatisfy(t -> assertThat(aExecutor.inEventLoop(t)).isTrue()); + } + + @Test + void aggregatingRequest() throws InterruptedException { + final AggregatedHttpResponse res = server.blockingWebClient().post("/aggregated-string", "hello"); + assertThat(res.status()).isSameAs(HttpStatus.OK); + assertThat(server.requestContextCaptor().size()).isEqualTo(1); + final EventLoop ctxEventLoop = server.requestContextCaptor().poll().eventLoop().withoutContext(); + assertThat(ctxEventLoop).isSameAs(aExecutor); + assertThat(threadQueue).allSatisfy(t -> assertThat(aExecutor.inEventLoop(t)).isTrue()); + } + + @Test + void testDefaultServiceWorkerGroup() throws InterruptedException { + final AggregatedHttpResponse res = server.blockingWebClient() + .execute(RequestHeaders.of(HttpMethod.GET, "/default")); + assertThat(res.status()).isSameAs(HttpStatus.OK); + assertThat(server.requestContextCaptor().size()).isEqualTo(1); + final EventLoop ctxEventLoop = server.requestContextCaptor().poll().eventLoop().withoutContext(); + assertThat(ctxEventLoop).isSameAs(defaultExecutor); + } + + @Test + void shutdownOnStopBehavior() { + final EventLoop eventLoop = new DefaultEventLoop(); + try (Server server = Server.builder() + .service("/", (ctx, req) -> HttpResponse.of(200)) + .serviceWorkerGroup(eventLoop, true) + .build()) { + server.start().join(); + assertThat(eventLoop.isShutdown()).isFalse(); + } + assertThat(eventLoop.isShutdown()).isTrue(); + } + + @Test + void defaultIsWorkerThread() { + try (Server server = Server.builder() + .service("/", (ctx, req) -> HttpResponse.of(200)) + .workerGroup(workerExecutor, false) + .build()) { + assertThat(server.serviceConfigs()).allSatisfy(cfg -> { + assertThat(cfg.serviceWorkerGroup()).isSameAs(workerExecutor); + }); + } + } + + @ParameterizedTest + @ValueSource(strings = {"/ctxLog", "/aggregatedCtxLog"}) + void contextLogExecutedByServiceWorkerThread(String path) { + final AggregatedHttpResponse aggRes = server.blockingWebClient().get(path); + assertThat(aggRes.status().code()).isEqualTo(200); + + await().untilAsserted(() -> assertThat(threadQueue).hasSize(RequestLogProperty.values().length)); + assertThat(threadQueue).allSatisfy(t -> assertThat(defaultExecutor.inEventLoop(t)).isTrue()); + } + + @Test + void subscribeWorkerThread() { + final AggregatedHttpResponse aggRes = server.blockingWebClient().get("/subscribe"); + assertThat(aggRes.status().code()).isEqualTo(200); + + await().untilAsserted(() -> assertThat(threadQueue).isNotEmpty()); + assertThat(threadQueue).allSatisfy(t -> assertThat(defaultExecutor.inEventLoop(t)).isTrue()); + } + + static class MyAnnotatedServiceA { + @Get("/a") + public HttpResponse httpResponseLoggingServiceTest() { + threadQueue.add(Thread.currentThread()); + return HttpResponse.of(HttpStatus.OK); + } + + @Get("/aggregated") + public String aggregated() { + threadQueue.add(Thread.currentThread()); + return "aggregated"; + } + + @Post("/aggregated-string") + public String aggregated(String request) { + threadQueue.add(Thread.currentThread()); + return request + ", World!"; + } + } + + static class MyAnnotatedServiceDefault { + @Get("/default") + public HttpResponse httpResponse(ServiceRequestContext ctx) { + return HttpResponse.of(HttpStatus.OK); + } + + @Get("/aggregatedCtxLog") + public String aggregated(ServiceRequestContext ctx) { + for (RequestLogProperty property: RequestLogProperty.values()) { + ctx.log().whenAvailable(property).thenRun(() -> { + threadQueue.add(Thread.currentThread()); + }); + } + return "aggregated"; + } + } +} diff --git a/core/src/test/java/com/linecorp/armeria/server/ServiceNamingTest.java b/core/src/test/java/com/linecorp/armeria/server/ServiceNamingTest.java index a4dec6a2dad..9fa6e530070 100644 --- a/core/src/test/java/com/linecorp/armeria/server/ServiceNamingTest.java +++ b/core/src/test/java/com/linecorp/armeria/server/ServiceNamingTest.java @@ -45,7 +45,8 @@ void fullTypeName_topClass() { null, null, ServiceNaming.fullTypeName(), 0, 0, false, AccessLogWriter.common(), CommonPools.blockingTaskExecutor(), SuccessFunction.always(), - 0, Files.newTemporaryFolder().toPath(), ImmutableList.of(), HttpHeaders.of(), + 0, Files.newTemporaryFolder().toPath(), CommonPools.workerGroup(), + ImmutableList.of(), HttpHeaders.of(), routingCtx -> RequestId.of(1L), ServerErrorHandler.ofDefault().asServiceErrorHandler(), NOOP_CONTEXT_HOOK); when(ctx.config()).thenReturn(config); @@ -61,7 +62,8 @@ void fullTypeName_nestedClass() { null, null, ServiceNaming.fullTypeName(), 0, 0, false, AccessLogWriter.common(), CommonPools.blockingTaskExecutor(), SuccessFunction.always(), - 0, Files.newTemporaryFolder().toPath(), ImmutableList.of(), HttpHeaders.of(), + 0, Files.newTemporaryFolder().toPath(), CommonPools.workerGroup(), + ImmutableList.of(), HttpHeaders.of(), routingCtx -> RequestId.of(1L), ServerErrorHandler.ofDefault().asServiceErrorHandler(), NOOP_CONTEXT_HOOK); when(ctx.config()).thenReturn(config); @@ -78,7 +80,8 @@ void fullTypeName_trimTrailingDollarSign() { null, null, ServiceNaming.fullTypeName(), 0, 0, false, AccessLogWriter.common(), CommonPools.blockingTaskExecutor(), SuccessFunction.always(), - 0, Files.newTemporaryFolder().toPath(), ImmutableList.of(), HttpHeaders.of(), + 0, Files.newTemporaryFolder().toPath(), CommonPools.workerGroup(), + ImmutableList.of(), HttpHeaders.of(), routingCtx -> RequestId.of(1L), ServerErrorHandler.ofDefault().asServiceErrorHandler(), NOOP_CONTEXT_HOOK); when(ctx.config()).thenReturn(config); @@ -94,7 +97,8 @@ void fullTypeName_trimTrailingDollarSignMany() { null, null, ServiceNaming.fullTypeName(), 0, 0, false, AccessLogWriter.common(), CommonPools.blockingTaskExecutor(), SuccessFunction.always(), - 0, Files.newTemporaryFolder().toPath(), ImmutableList.of(), HttpHeaders.of(), + 0, Files.newTemporaryFolder().toPath(), CommonPools.workerGroup(), + ImmutableList.of(), HttpHeaders.of(), routingCtx -> RequestId.of(1L), ServerErrorHandler.ofDefault().asServiceErrorHandler(), NOOP_CONTEXT_HOOK); when(ctx.config()).thenReturn(config); @@ -110,7 +114,8 @@ void fullTypeName_trimTrailingDollarSignOnly() { null, null, ServiceNaming.fullTypeName(), 0, 0, false, AccessLogWriter.common(), CommonPools.blockingTaskExecutor(), SuccessFunction.always(), - 0, Files.newTemporaryFolder().toPath(), ImmutableList.of(), HttpHeaders.of(), + 0, Files.newTemporaryFolder().toPath(), CommonPools.workerGroup(), + ImmutableList.of(), HttpHeaders.of(), routingCtx -> RequestId.of(1L), ServerErrorHandler.ofDefault().asServiceErrorHandler(), NOOP_CONTEXT_HOOK); when(ctx.config()).thenReturn(config); @@ -126,7 +131,8 @@ void simpleTypeName_topClass() { null, null, ServiceNaming.fullTypeName(), 0, 0, false, AccessLogWriter.common(), CommonPools.blockingTaskExecutor(), SuccessFunction.always(), - 0, Files.newTemporaryFolder().toPath(), ImmutableList.of(), HttpHeaders.of(), + 0, Files.newTemporaryFolder().toPath(), CommonPools.workerGroup(), + ImmutableList.of(), HttpHeaders.of(), routingCtx -> RequestId.of(1L), ServerErrorHandler.ofDefault().asServiceErrorHandler(), NOOP_CONTEXT_HOOK); when(ctx.config()).thenReturn(config); @@ -142,7 +148,8 @@ void simpleTypeName_nestedClass() { null, null, ServiceNaming.fullTypeName(), 0, 0, false, AccessLogWriter.common(), CommonPools.blockingTaskExecutor(), SuccessFunction.always(), - 0, Files.newTemporaryFolder().toPath(), ImmutableList.of(), HttpHeaders.of(), + 0, Files.newTemporaryFolder().toPath(), CommonPools.workerGroup(), + ImmutableList.of(), HttpHeaders.of(), routingCtx -> RequestId.of(1L), ServerErrorHandler.ofDefault().asServiceErrorHandler(), NOOP_CONTEXT_HOOK); when(ctx.config()).thenReturn(config); @@ -159,7 +166,8 @@ void simpleTypeName_trimTrailingDollarSign() { null, null, ServiceNaming.fullTypeName(), 0, 0, false, AccessLogWriter.common(), CommonPools.blockingTaskExecutor(), SuccessFunction.always(), - 0, Files.newTemporaryFolder().toPath(), ImmutableList.of(), HttpHeaders.of(), + 0, Files.newTemporaryFolder().toPath(), CommonPools.workerGroup(), + ImmutableList.of(), HttpHeaders.of(), routingCtx -> RequestId.of(1L), ServerErrorHandler.ofDefault().asServiceErrorHandler(), NOOP_CONTEXT_HOOK); when(ctx.config()).thenReturn(config); @@ -175,7 +183,8 @@ void simpleTypeName_trimTrailingDollarSignMany() { null, null, ServiceNaming.fullTypeName(), 0, 0, false, AccessLogWriter.common(), CommonPools.blockingTaskExecutor(), SuccessFunction.always(), - 0, Files.newTemporaryFolder().toPath(), ImmutableList.of(), HttpHeaders.of(), + 0, Files.newTemporaryFolder().toPath(), CommonPools.workerGroup(), + ImmutableList.of(), HttpHeaders.of(), routingCtx -> RequestId.of(1L), ServerErrorHandler.ofDefault().asServiceErrorHandler(), NOOP_CONTEXT_HOOK); when(ctx.config()).thenReturn(config); @@ -191,7 +200,8 @@ void simpleTypeName_trimTrailingDollarSignOnly() { null, null, ServiceNaming.fullTypeName(), 0, 0, false, AccessLogWriter.common(), CommonPools.blockingTaskExecutor(), SuccessFunction.always(), - 0, Files.newTemporaryFolder().toPath(), ImmutableList.of(), HttpHeaders.of(), + 0, Files.newTemporaryFolder().toPath(), CommonPools.workerGroup(), + ImmutableList.of(), HttpHeaders.of(), routingCtx -> RequestId.of(1L), ServerErrorHandler.ofDefault().asServiceErrorHandler(), NOOP_CONTEXT_HOOK); when(ctx.config()).thenReturn(config); @@ -207,7 +217,8 @@ void shorten_topClass() { null, null, ServiceNaming.fullTypeName(), 0, 0, false, AccessLogWriter.common(), CommonPools.blockingTaskExecutor(), SuccessFunction.always(), - 0, Files.newTemporaryFolder().toPath(), ImmutableList.of(), HttpHeaders.of(), + 0, Files.newTemporaryFolder().toPath(), CommonPools.workerGroup(), + ImmutableList.of(), HttpHeaders.of(), routingCtx -> RequestId.of(1L), ServerErrorHandler.ofDefault().asServiceErrorHandler(), NOOP_CONTEXT_HOOK); when(ctx.config()).thenReturn(config); @@ -223,7 +234,8 @@ void shorten_nestedClass() { null, null, ServiceNaming.fullTypeName(), 0, 0, false, AccessLogWriter.common(), CommonPools.blockingTaskExecutor(), SuccessFunction.always(), - 0, Files.newTemporaryFolder().toPath(), ImmutableList.of(), HttpHeaders.of(), + 0, Files.newTemporaryFolder().toPath(), CommonPools.workerGroup(), + ImmutableList.of(), HttpHeaders.of(), routingCtx -> RequestId.of(1L), ServerErrorHandler.ofDefault().asServiceErrorHandler(), NOOP_CONTEXT_HOOK); when(ctx.config()).thenReturn(config); @@ -240,7 +252,8 @@ void shorten_trimTrailingDollarSign() { null, null, ServiceNaming.fullTypeName(), 0, 0, false, AccessLogWriter.common(), CommonPools.blockingTaskExecutor(), SuccessFunction.always(), - 0, Files.newTemporaryFolder().toPath(), ImmutableList.of(), HttpHeaders.of(), + 0, Files.newTemporaryFolder().toPath(), CommonPools.workerGroup(), + ImmutableList.of(), HttpHeaders.of(), routingCtx -> RequestId.of(1L), ServerErrorHandler.ofDefault().asServiceErrorHandler(), NOOP_CONTEXT_HOOK); when(ctx.config()).thenReturn(config); @@ -257,7 +270,8 @@ void shorten_trimTrailingDollarSignMany() { null, null, ServiceNaming.fullTypeName(), 0, 0, false, AccessLogWriter.common(), CommonPools.blockingTaskExecutor(), SuccessFunction.always(), - 0, Files.newTemporaryFolder().toPath(), ImmutableList.of(), HttpHeaders.of(), + 0, Files.newTemporaryFolder().toPath(), CommonPools.workerGroup(), + ImmutableList.of(), HttpHeaders.of(), routingCtx -> RequestId.of(1L), ServerErrorHandler.ofDefault().asServiceErrorHandler(), NOOP_CONTEXT_HOOK); when(ctx.config()).thenReturn(config); @@ -274,7 +288,8 @@ void shorten_trimTrailingDollarSignOnly() { null, null, ServiceNaming.fullTypeName(), 0, 0, false, AccessLogWriter.common(), CommonPools.blockingTaskExecutor(), SuccessFunction.always(), - 0, Files.newTemporaryFolder().toPath(), ImmutableList.of(), HttpHeaders.of(), + 0, Files.newTemporaryFolder().toPath(), CommonPools.workerGroup(), + ImmutableList.of(), HttpHeaders.of(), routingCtx -> RequestId.of(1L), ServerErrorHandler.ofDefault().asServiceErrorHandler(), NOOP_CONTEXT_HOOK); when(ctx.config()).thenReturn(config); diff --git a/core/src/test/java/com/linecorp/armeria/server/ServiceTest.java b/core/src/test/java/com/linecorp/armeria/server/ServiceTest.java index f7f2cd8b08b..659308bae20 100644 --- a/core/src/test/java/com/linecorp/armeria/server/ServiceTest.java +++ b/core/src/test/java/com/linecorp/armeria/server/ServiceTest.java @@ -65,7 +65,8 @@ private static void assertDecoration(FooService inner, HttpService outer) throws AccessLogWriter.disabled(), CommonPools.blockingTaskExecutor(), SuccessFunction.always(), - 0, Files.newTemporaryFolder().toPath(), ImmutableList.of(), HttpHeaders.of(), + 0, Files.newTemporaryFolder().toPath(), CommonPools.workerGroup(), + ImmutableList.of(), HttpHeaders.of(), ctx -> RequestId.of(1L), ServerErrorHandler.ofDefault().asServiceErrorHandler(), NOOP_CONTEXT_HOOK); outer.serviceAdded(cfg);