diff --git a/benchmarks/jmh/src/jmh/java/com/linecorp/armeria/server/RoutersBenchmark.java b/benchmarks/jmh/src/jmh/java/com/linecorp/armeria/server/RoutersBenchmark.java index 24660c1fc83..50c84e161d4 100644 --- a/benchmarks/jmh/src/jmh/java/com/linecorp/armeria/server/RoutersBenchmark.java +++ b/benchmarks/jmh/src/jmh/java/com/linecorp/armeria/server/RoutersBenchmark.java @@ -66,27 +66,27 @@ public class RoutersBenchmark { new ServiceConfig(route1, route1, SERVICE, defaultLogName, defaultServiceName, defaultServiceNaming, 0, 0, false, AccessLogWriter.disabled(), CommonPools.blockingTaskExecutor(), - SuccessFunction.always(), 0, multipartUploadsLocation, ImmutableList.of(), - HttpHeaders.of(), ctx -> RequestId.random(), serviceErrorHandler, - NOOP_CONTEXT_HOOK), + SuccessFunction.always(), 0, multipartUploadsLocation, + CommonPools.workerGroup(), ImmutableList.of(), HttpHeaders.of(), + ctx -> RequestId.random(), serviceErrorHandler, NOOP_CONTEXT_HOOK), new ServiceConfig(route2, route2, SERVICE, defaultLogName, defaultServiceName, defaultServiceNaming, 0, 0, false, AccessLogWriter.disabled(), CommonPools.blockingTaskExecutor(), - SuccessFunction.always(), 0, multipartUploadsLocation, ImmutableList.of(), - HttpHeaders.of(), ctx -> RequestId.random(), serviceErrorHandler, - NOOP_CONTEXT_HOOK)); + SuccessFunction.always(), 0, multipartUploadsLocation, + CommonPools.workerGroup(), ImmutableList.of(), HttpHeaders.of(), + ctx -> RequestId.random(), serviceErrorHandler, NOOP_CONTEXT_HOOK)); FALLBACK_SERVICE = new ServiceConfig(Route.ofCatchAll(), Route.ofCatchAll(), SERVICE, defaultLogName, defaultServiceName, defaultServiceNaming, 0, 0, false, AccessLogWriter.disabled(), - CommonPools.blockingTaskExecutor(), - SuccessFunction.always(), 0, multipartUploadsLocation, + CommonPools.blockingTaskExecutor(), SuccessFunction.always(), 0, + multipartUploadsLocation, CommonPools.workerGroup(), ImmutableList.of(), HttpHeaders.of(), ctx -> RequestId.random(), serviceErrorHandler, NOOP_CONTEXT_HOOK); HOST = new VirtualHost( "localhost", "localhost", 0, null, SERVICES, FALLBACK_SERVICE, RejectedRouteHandler.DISABLED, unused -> NOPLogger.NOP_LOGGER, defaultServiceNaming, defaultLogName, 0, 0, false, AccessLogWriter.disabled(), CommonPools.blockingTaskExecutor(), 0, SuccessFunction.ofDefault(), - multipartUploadsLocation, ImmutableList.of(), + multipartUploadsLocation, CommonPools.workerGroup(), ImmutableList.of(), ctx -> RequestId.random()); ROUTER = Routers.ofVirtualHost(HOST, SERVICES, RejectedRouteHandler.DISABLED); } diff --git a/brave/src/test/java/com/linecorp/armeria/common/brave/RequestContextCurrentTraceContextTest.java b/brave/src/test/java/com/linecorp/armeria/common/brave/RequestContextCurrentTraceContextTest.java index ad901c47d39..9b7e5ed5713 100644 --- a/brave/src/test/java/com/linecorp/armeria/common/brave/RequestContextCurrentTraceContextTest.java +++ b/brave/src/test/java/com/linecorp/armeria/common/brave/RequestContextCurrentTraceContextTest.java @@ -17,8 +17,6 @@ package com.linecorp.armeria.common.brave; import static org.assertj.core.api.Assertions.assertThat; -import static org.mockito.ArgumentMatchers.any; -import static org.mockito.Mockito.doAnswer; import static org.mockito.Mockito.when; import org.junit.jupiter.api.BeforeEach; @@ -26,7 +24,6 @@ import org.mockito.Mock; import org.mockito.junit.jupiter.MockitoSettings; import org.mockito.quality.Strictness; -import org.mockito.stubbing.Answer; import com.linecorp.armeria.common.HttpMethod; import com.linecorp.armeria.common.HttpRequest; @@ -55,10 +52,7 @@ class RequestContextCurrentTraceContextTest { @BeforeEach void setUp() { when(eventLoop.inEventLoop()).thenReturn(true); - doAnswer((Answer) invocation -> { - invocation.getArgument(0).run(); - return null; - }).when(eventLoop).execute(any()); + when(eventLoop.next()).thenReturn(eventLoop); ctx = ServiceRequestContext.builder(HttpRequest.of(HttpMethod.GET, "/")) .eventLoop(eventLoop) diff --git a/core/src/main/java/com/linecorp/armeria/common/HttpRequest.java b/core/src/main/java/com/linecorp/armeria/common/HttpRequest.java index 7ec7a3a360f..8967b6cc0c0 100644 --- a/core/src/main/java/com/linecorp/armeria/common/HttpRequest.java +++ b/core/src/main/java/com/linecorp/armeria/common/HttpRequest.java @@ -810,4 +810,10 @@ default HttpRequest peekError(Consumer action) { requireNonNull(action, "action"); return of(headers(), HttpMessage.super.peekError(action)); } + + @Override + default HttpRequest subscribeOn(EventExecutor eventExecutor) { + requireNonNull(eventExecutor, "eventExecutor"); + return of(headers(), HttpMessage.super.subscribeOn(eventExecutor)); + } } diff --git a/core/src/main/java/com/linecorp/armeria/common/HttpResponse.java b/core/src/main/java/com/linecorp/armeria/common/HttpResponse.java index 7f0c267da42..a92db4b2789 100644 --- a/core/src/main/java/com/linecorp/armeria/common/HttpResponse.java +++ b/core/src/main/java/com/linecorp/armeria/common/HttpResponse.java @@ -1178,4 +1178,9 @@ default HttpResponse recover(Class causeClass, } }); } + + @Override + default HttpResponse subscribeOn(EventExecutor eventExecutor) { + return of(HttpMessage.super.subscribeOn(eventExecutor)); + } } diff --git a/core/src/main/java/com/linecorp/armeria/common/stream/StreamMessage.java b/core/src/main/java/com/linecorp/armeria/common/stream/StreamMessage.java index e0a4c22879f..a203f9c4116 100644 --- a/core/src/main/java/com/linecorp/armeria/common/stream/StreamMessage.java +++ b/core/src/main/java/com/linecorp/armeria/common/stream/StreamMessage.java @@ -1136,4 +1136,22 @@ default InputStream toInputStream(Function httpDa default StreamMessage endWith(Function<@Nullable Throwable, ? extends @Nullable T> finalizer) { return new SurroundingPublisher<>(null, this, finalizer); } + + /** + * Calls {@link #subscribe(Subscriber, EventExecutor)} to the upstream + * {@link StreamMessage} using the specified {@link EventExecutor} and relays the stream + * transparently downstream. This may be useful if one would like to hide an + * {@link EventExecutor} from an upstream {@link Publisher}. + * + *

For example:

{@code
+     * Subscriber mySubscriber = null;
+     * StreamMessage upstream = ...; // publisher callbacks are invoked by eventLoop1
+     * upstream.subscribeOn(eventLoop1)
+     *         .subscribe(mySubscriber, eventLoop2); // mySubscriber callbacks are invoked with eventLoop2
+     * }
+ */ + default StreamMessage subscribeOn(EventExecutor eventExecutor) { + requireNonNull(eventExecutor, "eventExecutor"); + return new SubscribeOnStreamMessage<>(this, eventExecutor); + } } diff --git a/core/src/main/java/com/linecorp/armeria/common/stream/SubscribeOnStreamMessage.java b/core/src/main/java/com/linecorp/armeria/common/stream/SubscribeOnStreamMessage.java new file mode 100644 index 00000000000..b5a2f0bf3bb --- /dev/null +++ b/core/src/main/java/com/linecorp/armeria/common/stream/SubscribeOnStreamMessage.java @@ -0,0 +1,117 @@ +/* + * Copyright 2023 LINE Corporation + * + * LINE Corporation licenses this file to you under the Apache License, + * version 2.0 (the "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at: + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations + * under the License. + */ + +package com.linecorp.armeria.common.stream; + +import java.util.concurrent.CompletableFuture; + +import org.reactivestreams.Subscriber; +import org.reactivestreams.Subscription; + +import io.netty.util.concurrent.EventExecutor; + +final class SubscribeOnStreamMessage implements StreamMessage { + + private final StreamMessage upstream; + private final EventExecutor upstreamExecutor; + + SubscribeOnStreamMessage(StreamMessage upstream, EventExecutor upstreamExecutor) { + this.upstream = upstream; + this.upstreamExecutor = upstreamExecutor; + } + + @Override + public boolean isOpen() { + return upstream.isOpen(); + } + + @Override + public boolean isEmpty() { + return upstream.isEmpty(); + } + + @Override + public long demand() { + return upstream.demand(); + } + + @Override + public CompletableFuture whenComplete() { + return upstream.whenComplete(); + } + + @Override + public EventExecutor defaultSubscriberExecutor() { + return upstreamExecutor; + } + + @Override + public void subscribe(Subscriber subscriber, EventExecutor downstreamExecutor, + SubscriptionOption... options) { + final Subscriber subscriber0; + if (upstreamExecutor == downstreamExecutor) { + subscriber0 = subscriber; + } else { + subscriber0 = new SchedulingSubscriber<>(downstreamExecutor, subscriber); + } + if (upstreamExecutor.inEventLoop()) { + upstream.subscribe(subscriber0, downstreamExecutor, options); + } else { + upstreamExecutor.execute(() -> upstream.subscribe(subscriber0, upstreamExecutor, options)); + } + } + + @Override + public void abort() { + upstream.abort(); + } + + @Override + public void abort(Throwable cause) { + upstream.abort(cause); + } + + static class SchedulingSubscriber implements Subscriber { + + private final Subscriber downstream; + private final EventExecutor downstreamExecutor; + + SchedulingSubscriber(EventExecutor downstreamExecutor, Subscriber downstream) { + this.downstream = downstream; + this.downstreamExecutor = downstreamExecutor; + } + + @Override + public void onSubscribe(Subscription s) { + downstreamExecutor.execute(() -> downstream.onSubscribe(s)); + } + + @Override + public void onNext(T t) { + downstreamExecutor.execute(() -> downstream.onNext(t)); + } + + @Override + public void onError(Throwable t) { + downstreamExecutor.execute(() -> downstream.onError(t)); + } + + @Override + public void onComplete() { + downstreamExecutor.execute(downstream::onComplete); + } + } +} diff --git a/core/src/main/java/com/linecorp/armeria/internal/server/DefaultServiceRequestContext.java b/core/src/main/java/com/linecorp/armeria/internal/server/DefaultServiceRequestContext.java index 7753112453c..b1fb086287d 100644 --- a/core/src/main/java/com/linecorp/armeria/internal/server/DefaultServiceRequestContext.java +++ b/core/src/main/java/com/linecorp/armeria/internal/server/DefaultServiceRequestContext.java @@ -72,6 +72,7 @@ import io.micrometer.core.instrument.MeterRegistry; import io.netty.buffer.ByteBufAllocator; import io.netty.channel.Channel; +import io.netty.channel.EventLoop; import io.netty.util.AttributeKey; /** @@ -90,6 +91,7 @@ public final class DefaultServiceRequestContext DefaultServiceRequestContext.class, HttpHeaders.class, "additionalResponseTrailers"); private final Channel ch; + private final EventLoop eventLoop; private final ServiceConfig cfg; private final RoutingContext routingContext; private final RoutingResult routingResult; @@ -141,22 +143,24 @@ public final class DefaultServiceRequestContext * e.g. {@code System.currentTimeMillis() * 1000}. */ public DefaultServiceRequestContext( - ServiceConfig cfg, Channel ch, MeterRegistry meterRegistry, SessionProtocol sessionProtocol, - RequestId id, RoutingContext routingContext, RoutingResult routingResult, ExchangeType exchangeType, + ServiceConfig cfg, Channel ch, EventLoop eventLoop, MeterRegistry meterRegistry, + SessionProtocol sessionProtocol, RequestId id, RoutingContext routingContext, + RoutingResult routingResult, ExchangeType exchangeType, HttpRequest req, @Nullable SSLSession sslSession, ProxiedAddresses proxiedAddresses, InetAddress clientAddress, InetSocketAddress remoteAddress, InetSocketAddress localAddress, long requestStartTimeNanos, long requestStartTimeMicros, Supplier contextHook) { - this(cfg, ch, meterRegistry, sessionProtocol, id, routingContext, routingResult, exchangeType, - req, sslSession, proxiedAddresses, clientAddress, remoteAddress, localAddress, + this(cfg, ch, eventLoop, meterRegistry, sessionProtocol, id, routingContext, routingResult, + exchangeType, req, sslSession, proxiedAddresses, clientAddress, remoteAddress, localAddress, null /* requestCancellationScheduler */, requestStartTimeNanos, requestStartTimeMicros, HttpHeaders.of(), HttpHeaders.of(), contextHook); } public DefaultServiceRequestContext( - ServiceConfig cfg, Channel ch, MeterRegistry meterRegistry, SessionProtocol sessionProtocol, - RequestId id, RoutingContext routingContext, RoutingResult routingResult, ExchangeType exchangeType, + ServiceConfig cfg, Channel ch, EventLoop eventLoop, MeterRegistry meterRegistry, + SessionProtocol sessionProtocol, RequestId id, RoutingContext routingContext, + RoutingResult routingResult, ExchangeType exchangeType, HttpRequest req, @Nullable SSLSession sslSession, ProxiedAddresses proxiedAddresses, InetAddress clientAddress, InetSocketAddress remoteAddress, InetSocketAddress localAddress, @Nullable CancellationScheduler requestCancellationScheduler, @@ -170,6 +174,7 @@ public DefaultServiceRequestContext( requireNonNull(req, "req"), null, null, contextHook); this.ch = requireNonNull(ch, "ch"); + this.eventLoop = requireNonNull(eventLoop, "eventLoop"); this.cfg = requireNonNull(cfg, "cfg"); this.routingContext = routingContext; this.routingResult = routingResult; @@ -178,7 +183,9 @@ public DefaultServiceRequestContext( } else { this.requestCancellationScheduler = CancellationScheduler.ofServer(TimeUnit.MILLISECONDS.toNanos(cfg.requestTimeoutMillis())); - this.requestCancellationScheduler.init(eventLoop()); + // the cancellation scheduler uses channelEventLoop since #start is called + // from the netty pipeline logic + this.requestCancellationScheduler.init(ch.eventLoop()); } this.sslSession = sslSession; this.proxiedAddresses = requireNonNull(proxiedAddresses, "proxiedAddresses"); @@ -301,7 +308,7 @@ public ContextAwareEventLoop eventLoop() { if (contextAwareEventLoop != null) { return contextAwareEventLoop; } - return contextAwareEventLoop = ContextAwareEventLoop.of(this, ch.eventLoop()); + return contextAwareEventLoop = ContextAwareEventLoop.of(this, eventLoop); } @Override diff --git a/core/src/main/java/com/linecorp/armeria/server/AbstractAnnotatedServiceConfigSetters.java b/core/src/main/java/com/linecorp/armeria/server/AbstractAnnotatedServiceConfigSetters.java index 489642f3e40..80ca978f5a3 100644 --- a/core/src/main/java/com/linecorp/armeria/server/AbstractAnnotatedServiceConfigSetters.java +++ b/core/src/main/java/com/linecorp/armeria/server/AbstractAnnotatedServiceConfigSetters.java @@ -47,6 +47,8 @@ import com.linecorp.armeria.server.annotation.ResponseConverterFunction; import com.linecorp.armeria.server.logging.AccessLogWriter; +import io.netty.channel.EventLoopGroup; + @UnstableApi abstract class AbstractAnnotatedServiceConfigSetters implements AnnotatedServiceConfigSetters { @@ -280,6 +282,18 @@ public AbstractAnnotatedServiceConfigSetters multipartUploadsLocation(Path multi return this; } + @Override + public ServiceConfigSetters serviceWorkerGroup(EventLoopGroup serviceWorkerGroup, boolean shutdownOnStop) { + defaultServiceConfigSetters.serviceWorkerGroup(serviceWorkerGroup, shutdownOnStop); + return this; + } + + @Override + public ServiceConfigSetters serviceWorkerGroup(int numThreads) { + defaultServiceConfigSetters.serviceWorkerGroup(numThreads); + return this; + } + @Override public AbstractAnnotatedServiceConfigSetters requestIdGenerator( Function requestIdGenerator) { diff --git a/core/src/main/java/com/linecorp/armeria/server/AbstractServiceBindingBuilder.java b/core/src/main/java/com/linecorp/armeria/server/AbstractServiceBindingBuilder.java index a570f69a198..9506174428b 100644 --- a/core/src/main/java/com/linecorp/armeria/server/AbstractServiceBindingBuilder.java +++ b/core/src/main/java/com/linecorp/armeria/server/AbstractServiceBindingBuilder.java @@ -35,6 +35,8 @@ import com.linecorp.armeria.common.util.BlockingTaskExecutor; import com.linecorp.armeria.server.logging.AccessLogWriter; +import io.netty.channel.EventLoopGroup; + /** * A builder class for binding an {@link HttpService} fluently. * @@ -175,6 +177,19 @@ public AbstractServiceBindingBuilder multipartUploadsLocation(Path multipartUplo return this; } + @Override + public AbstractServiceBindingBuilder serviceWorkerGroup(EventLoopGroup serviceWorkerGroup, + boolean shutdownOnStop) { + defaultServiceConfigSetters.serviceWorkerGroup(serviceWorkerGroup, shutdownOnStop); + return this; + } + + @Override + public AbstractServiceBindingBuilder serviceWorkerGroup(int numThreads) { + defaultServiceConfigSetters.serviceWorkerGroup(numThreads); + return this; + } + @Override public AbstractServiceBindingBuilder requestIdGenerator( Function requestIdGenerator) { diff --git a/core/src/main/java/com/linecorp/armeria/server/AggregatedHttpResponseHandler.java b/core/src/main/java/com/linecorp/armeria/server/AggregatedHttpResponseHandler.java index ea620a13bb7..ec59df678e8 100644 --- a/core/src/main/java/com/linecorp/armeria/server/AggregatedHttpResponseHandler.java +++ b/core/src/main/java/com/linecorp/armeria/server/AggregatedHttpResponseHandler.java @@ -57,7 +57,7 @@ final class AggregatedHttpResponseHandler extends AbstractHttpResponseHandler @Override public Void apply(@Nullable AggregatedHttpResponse response, @Nullable Throwable cause) { - final EventLoop eventLoop = reqCtx.eventLoop(); + final EventLoop eventLoop = ctx.channel().eventLoop(); if (eventLoop.inEventLoop()) { apply0(response, cause); } else { diff --git a/core/src/main/java/com/linecorp/armeria/server/AnnotatedServiceBindingBuilder.java b/core/src/main/java/com/linecorp/armeria/server/AnnotatedServiceBindingBuilder.java index b0da5f3049f..d232a8855d8 100644 --- a/core/src/main/java/com/linecorp/armeria/server/AnnotatedServiceBindingBuilder.java +++ b/core/src/main/java/com/linecorp/armeria/server/AnnotatedServiceBindingBuilder.java @@ -33,6 +33,8 @@ import com.linecorp.armeria.server.annotation.ResponseConverterFunction; import com.linecorp.armeria.server.logging.AccessLogWriter; +import io.netty.channel.EventLoopGroup; + /** * A builder class for binding an {@link HttpService} fluently. This class can be instantiated through * {@link ServerBuilder#annotatedService()}. @@ -260,6 +262,17 @@ public AnnotatedServiceBindingBuilder contextHook(Supplier requestIdGenerator) { diff --git a/core/src/main/java/com/linecorp/armeria/server/DefaultServiceConfigSetters.java b/core/src/main/java/com/linecorp/armeria/server/DefaultServiceConfigSetters.java index 13ae406edc4..b5d1e948a9e 100644 --- a/core/src/main/java/com/linecorp/armeria/server/DefaultServiceConfigSetters.java +++ b/core/src/main/java/com/linecorp/armeria/server/DefaultServiceConfigSetters.java @@ -41,9 +41,12 @@ import com.linecorp.armeria.common.SuccessFunction; import com.linecorp.armeria.common.annotation.Nullable; import com.linecorp.armeria.common.util.BlockingTaskExecutor; +import com.linecorp.armeria.common.util.EventLoopGroups; import com.linecorp.armeria.internal.server.annotation.AnnotatedService; import com.linecorp.armeria.server.logging.AccessLogWriter; +import io.netty.channel.EventLoopGroup; + /** * A default implementation of {@link ServiceConfigSetters} that stores service related settings * and provides a method {@link DefaultServiceConfigSetters#toServiceConfigBuilder(Route, String, HttpService)} @@ -76,6 +79,8 @@ final class DefaultServiceConfigSetters implements ServiceConfigSetters { @Nullable private Path multipartUploadsLocation; @Nullable + private EventLoopGroup serviceWorkerGroup; + @Nullable private ServiceErrorHandler serviceErrorHandler; private Supplier contextHook = NOOP_CONTEXT_HOOK; private final List shutdownSupports = new ArrayList<>(); @@ -232,6 +237,22 @@ public ServiceConfigSetters multipartUploadsLocation(Path multipartUploadsLocati return this; } + @Override + public ServiceConfigSetters serviceWorkerGroup(EventLoopGroup serviceWorkerGroup, + boolean shutdownOnStop) { + this.serviceWorkerGroup = requireNonNull(serviceWorkerGroup, "serviceWorkerGroup"); + if (shutdownOnStop) { + shutdownSupports.add(ShutdownSupport.of(serviceWorkerGroup)); + } + return this; + } + + @Override + public ServiceConfigSetters serviceWorkerGroup(int numThreads) { + final EventLoopGroup workerGroup = EventLoopGroups.newEventLoopGroup(numThreads); + return serviceWorkerGroup(workerGroup, true); + } + @Override public ServiceConfigSetters requestIdGenerator( Function requestIdGenerator) { @@ -356,6 +377,10 @@ ServiceConfigBuilder toServiceConfigBuilder(Route route, String contextPath, Htt if (multipartUploadsLocation != null) { serviceConfigBuilder.multipartUploadsLocation(multipartUploadsLocation); } + if (serviceWorkerGroup != null) { + serviceConfigBuilder.serviceWorkerGroup(serviceWorkerGroup, false); + // Set the serviceWorkerGroup as false because it's shut down in ShutdownSupport. + } if (requestIdGenerator != null) { serviceConfigBuilder.requestIdGenerator(requestIdGenerator); } diff --git a/core/src/main/java/com/linecorp/armeria/server/HttpServerHandler.java b/core/src/main/java/com/linecorp/armeria/server/HttpServerHandler.java index 045842f9ae6..6e4d0b762ab 100644 --- a/core/src/main/java/com/linecorp/armeria/server/HttpServerHandler.java +++ b/core/src/main/java/com/linecorp/armeria/server/HttpServerHandler.java @@ -76,6 +76,7 @@ import io.netty.channel.ChannelInboundHandlerAdapter; import io.netty.channel.ChannelPipeline; import io.netty.channel.EventLoop; +import io.netty.channel.EventLoopGroup; import io.netty.channel.socket.ChannelInputShutdownReadComplete; import io.netty.handler.codec.http2.Http2Connection; import io.netty.handler.codec.http2.Http2Exception; @@ -343,12 +344,14 @@ private void handleRequest(ChannelHandlerContext ctx, DecodedHttpRequest req) th final InetSocketAddress localAddress = firstNonNull(localAddress(channel), UNKNOWN_ADDR); final ProxiedAddresses proxiedAddresses = determineProxiedAddresses(remoteAddress, headers); final InetAddress clientAddress = config.clientAddressMapper().apply(proxiedAddresses).getAddress(); + final EventLoop channelEventLoop = channel.eventLoop(); final RoutingContext routingCtx = req.routingContext(); final RoutingStatus routingStatus = routingCtx.status(); if (!routingStatus.routeMustExist()) { final ServiceRequestContext reqCtx = newEarlyRespondingRequestContext( - channel, req, proxiedAddresses, clientAddress, remoteAddress, localAddress, routingCtx); + channel, req, proxiedAddresses, clientAddress, remoteAddress, localAddress, routingCtx, + channelEventLoop); // Handle 'OPTIONS * HTTP/1.1'. if (routingStatus == RoutingStatus.OPTIONS) { @@ -367,17 +370,88 @@ private void handleRequest(ChannelHandlerContext ctx, DecodedHttpRequest req) th final RoutingResult routingResult = routed.routingResult(); final ServiceConfig serviceCfg = routed.value(); final HttpService service = serviceCfg.service(); - + final EventLoop serviceEventLoop; + final boolean needsDirectExecution; + final EventLoopGroup serviceWorkerGroup = serviceCfg.serviceWorkerGroup(); + if (serviceWorkerGroup == config.workerGroup()) { + serviceEventLoop = channelEventLoop; + needsDirectExecution = true; + } else { + serviceEventLoop = serviceWorkerGroup.next(); + needsDirectExecution = serviceEventLoop == channelEventLoop; + } final DefaultServiceRequestContext reqCtx = new DefaultServiceRequestContext( - serviceCfg, channel, config.meterRegistry(), protocol, + serviceCfg, channel, serviceEventLoop, config.meterRegistry(), protocol, nextRequestId(routingCtx, serviceCfg), routingCtx, routingResult, req.exchangeType(), req, sslSession, proxiedAddresses, clientAddress, remoteAddress, localAddress, req.requestStartTimeNanos(), req.requestStartTimeMicros(), serviceCfg.contextHook()); + final HttpResponse res; + req.init(reqCtx); + if (needsDirectExecution) { + res = serve0(req, serviceCfg, service, reqCtx); + } else { + res = HttpResponse.of(() -> serve0(req.subscribeOn(serviceEventLoop), serviceCfg, service, reqCtx), + serviceEventLoop) + .subscribeOn(serviceEventLoop); + } + + // Keep track of the number of unfinished requests and + // clean up the request stream when response stream ends. + final boolean isTransientService = + serviceCfg.service().as(TransientService.class) != null; + if (!isTransientService) { + gracefulShutdownSupport.inc(); + } + unfinishedRequests.put(req, res); + + if (service.shouldCachePath(routingCtx.path(), routingCtx.query(), routed.route())) { + reqCtx.log().whenComplete().thenAccept(log -> { + final int statusCode = log.responseHeaders().status().code(); + if (statusCode >= 200 && statusCode < 400) { + RequestTargetCache.putForServer(req.path(), routingCtx.requestTarget()); + } + }); + } + + final RequestAndResponseCompleteHandler handler = + new RequestAndResponseCompleteHandler(channelEventLoop, ctx, reqCtx, req, + isTransientService); + req.whenComplete().handle(handler.requestCompleteHandler); + + // A future which is completed when the all response objects are written to channel and + // the returned promises are done. + final CompletableFuture resWriteFuture = new CompletableFuture<>(); + resWriteFuture.handle(handler.responseCompleteHandler); + + // Set the response to the request in order to be able to immediately abort the response + // when the peer cancels the stream. + req.setResponse(res); + + if (req.isHttp1WebSocket()) { + assert responseEncoder instanceof Http1ObjectEncoder; + final WebSocketHttp1ResponseSubscriber resSubscriber = + new WebSocketHttp1ResponseSubscriber(ctx, responseEncoder, reqCtx, req, resWriteFuture); + res.subscribe(resSubscriber, channelEventLoop, SubscriptionOption.WITH_POOLED_OBJECTS); + } else if (reqCtx.exchangeType().isResponseStreaming()) { + final AbstractHttpResponseSubscriber resSubscriber = + new HttpResponseSubscriber(ctx, responseEncoder, reqCtx, req, resWriteFuture); + res.subscribe(resSubscriber, channelEventLoop, SubscriptionOption.WITH_POOLED_OBJECTS); + } else { + final AggregatedHttpResponseHandler resHandler = + new AggregatedHttpResponseHandler(ctx, responseEncoder, reqCtx, req, resWriteFuture); + res.aggregate(AggregationOptions.usePooledObjects(ctx.alloc(), channelEventLoop)) + .handle(resHandler); + } + } + + private HttpResponse serve0(HttpRequest req, + ServiceConfig serviceCfg, + HttpService service, + DefaultServiceRequestContext reqCtx) { try (SafeCloseable ignored = reqCtx.push()) { HttpResponse serviceResponse; try { - req.init(reqCtx); serviceResponse = service.serve(reqCtx, req); } catch (Throwable cause) { // No need to consume further since the response is ready. @@ -394,56 +468,7 @@ private void handleRequest(ChannelHandlerContext ctx, DecodedHttpRequest req) th // Recover the failed response with the error handler. return serviceCfg.errorHandler().onServiceException(reqCtx, cause); }); - final HttpResponse res = serviceResponse; - final EventLoop eventLoop = channel.eventLoop(); - - // Keep track of the number of unfinished requests and - // clean up the request stream when response stream ends. - final boolean isTransientService = - serviceCfg.service().as(TransientService.class) != null; - if (!isTransientService) { - gracefulShutdownSupport.inc(); - } - unfinishedRequests.put(req, res); - - if (service.shouldCachePath(routingCtx.path(), routingCtx.query(), routed.route())) { - reqCtx.log().whenComplete().thenAccept(log -> { - final int statusCode = log.responseHeaders().status().code(); - if (statusCode >= 200 && statusCode < 400) { - RequestTargetCache.putForServer(req.path(), routingCtx.requestTarget()); - } - }); - } - - final RequestAndResponseCompleteHandler handler = - new RequestAndResponseCompleteHandler(eventLoop, ctx, reqCtx, req, - isTransientService); - req.whenComplete().handle(handler.requestCompleteHandler); - - // A future which is completed when the all response objects are written to channel and - // the returned promises are done. - final CompletableFuture resWriteFuture = new CompletableFuture<>(); - resWriteFuture.handle(handler.responseCompleteHandler); - - // Set the response to the request in order to be able to immediately abort the response - // when the peer cancels the stream. - req.setResponse(res); - - if (req.isHttp1WebSocket()) { - assert responseEncoder instanceof Http1ObjectEncoder; - final WebSocketHttp1ResponseSubscriber resSubscriber = - new WebSocketHttp1ResponseSubscriber(ctx, responseEncoder, reqCtx, req, resWriteFuture); - res.subscribe(resSubscriber, eventLoop, SubscriptionOption.WITH_POOLED_OBJECTS); - } else if (reqCtx.exchangeType().isResponseStreaming()) { - final AbstractHttpResponseSubscriber resSubscriber = - new HttpResponseSubscriber(ctx, responseEncoder, reqCtx, req, resWriteFuture); - res.subscribe(resSubscriber, eventLoop, SubscriptionOption.WITH_POOLED_OBJECTS); - } else { - final AggregatedHttpResponseHandler resHandler = - new AggregatedHttpResponseHandler(ctx, responseEncoder, reqCtx, req, resWriteFuture); - res.aggregate(AggregationOptions.usePooledObjects(ctx.alloc(), eventLoop)) - .handle(resHandler); - } + return serviceResponse; } } @@ -612,14 +637,15 @@ private ServiceRequestContext newEarlyRespondingRequestContext(Channel channel, InetAddress clientAddress, InetSocketAddress remoteAddress, InetSocketAddress localAddress, - RoutingContext routingCtx) { + RoutingContext routingCtx, + EventLoop eventLoop) { final ServiceConfig serviceConfig = routingCtx.virtualHost().fallbackServiceConfig(); final RoutingResult routingResult = RoutingResult.builder() .path(routingCtx.path()) .build(); return new DefaultServiceRequestContext( serviceConfig, - channel, NoopMeterRegistry.get(), protocol(), + channel, eventLoop, NoopMeterRegistry.get(), protocol(), nextRequestId(routingCtx, serviceConfig), routingCtx, routingResult, req.exchangeType(), req, sslSession, proxiedAddresses, clientAddress, remoteAddress, localAddress, System.nanoTime(), SystemInfo.currentTimeMicros(), NOOP_CONTEXT_HOOK); diff --git a/core/src/main/java/com/linecorp/armeria/server/ServerBuilder.java b/core/src/main/java/com/linecorp/armeria/server/ServerBuilder.java index e3db202697b..ed91d00fbc6 100644 --- a/core/src/main/java/com/linecorp/armeria/server/ServerBuilder.java +++ b/core/src/main/java/com/linecorp/armeria/server/ServerBuilder.java @@ -190,7 +190,7 @@ public final class ServerBuilder implements TlsSetters, ServiceConfigsBuilder { private final VirtualHostBuilder defaultVirtualHostBuilder = new VirtualHostBuilder(this, true); private final List virtualHostBuilders = new ArrayList<>(); - private EventLoopGroup workerGroup = CommonPools.workerGroup(); + EventLoopGroup workerGroup = CommonPools.workerGroup(); private boolean shutdownWorkerGroupOnStop; private Executor startStopExecutor = START_STOP_EXECUTOR; private final Map, Object> channelOptions = new Object2ObjectArrayMap<>(); @@ -537,6 +537,34 @@ public ServerBuilder workerGroup(int numThreads) { return this; } + /** + * Sets the worker {@link EventLoopGroup} which is responsible for running + * {@link Service#serve(ServiceRequestContext, Request)}. + * If not set, the value set via {@linkplain #workerGroup(EventLoopGroup, boolean)} + * or {@linkplain #workerGroup(int)} is used. + * + * @param shutdownOnStop whether to shut down the worker {@link EventLoopGroup} + * when the {@link Server} stops + */ + @UnstableApi + public ServerBuilder serviceWorkerGroup(EventLoopGroup serviceWorkerGroup, boolean shutdownOnStop) { + virtualHostTemplate.serviceWorkerGroup(serviceWorkerGroup, shutdownOnStop); + return this; + } + + /** + * Uses a newly created {@link EventLoopGroup} with the specified number of threads for + * running {@link Service#serve(ServiceRequestContext, Request)}. + * The worker {@link EventLoopGroup} will be shut down when the {@link Server} stops. + * + * @param numThreads the number of event loop threads + */ + @UnstableApi + public ServerBuilder serviceWorkerGroup(int numThreads) { + virtualHostTemplate.serviceWorkerGroup(EventLoopGroups.newEventLoopGroup(numThreads), true); + return this; + } + /** * Sets the {@link Executor} which will invoke the callbacks of {@link Server#start()}, * {@link Server#stop()} and {@link ServerListener}. @@ -2200,8 +2228,8 @@ private DefaultServerConfig buildServerConfig(List serverPorts) { final Map, Object> newChildChannelOptions = ChannelUtil.applyDefaultChannelOptions( childChannelOptions, idleTimeoutMillis, pingIntervalMillis); - final BlockingTaskExecutor blockingTaskExecutor = defaultVirtualHost.blockingTaskExecutor(); + return new DefaultServerConfig( ports, setSslContextIfAbsent(defaultVirtualHost, defaultSslContext), virtualHosts, workerGroup, shutdownWorkerGroupOnStop, startStopExecutor, maxNumConnections, diff --git a/core/src/main/java/com/linecorp/armeria/server/ServerConfig.java b/core/src/main/java/com/linecorp/armeria/server/ServerConfig.java index 0e14ce03e23..0bbb20aa235 100644 --- a/core/src/main/java/com/linecorp/armeria/server/ServerConfig.java +++ b/core/src/main/java/com/linecorp/armeria/server/ServerConfig.java @@ -103,8 +103,7 @@ public interface ServerConfig { List serviceConfigs(); /** - * Returns the worker {@link EventLoopGroup} which is responsible for performing socket I/O and running - * {@link Service#serve(ServiceRequestContext, Request)}. + * Returns the worker {@link EventLoopGroup} which is responsible for performing socket I/O. */ EventLoopGroup workerGroup(); diff --git a/core/src/main/java/com/linecorp/armeria/server/ServiceBindingBuilder.java b/core/src/main/java/com/linecorp/armeria/server/ServiceBindingBuilder.java index 7ea369b49c5..60e77711e4e 100644 --- a/core/src/main/java/com/linecorp/armeria/server/ServiceBindingBuilder.java +++ b/core/src/main/java/com/linecorp/armeria/server/ServiceBindingBuilder.java @@ -35,6 +35,8 @@ import com.linecorp.armeria.common.util.BlockingTaskExecutor; import com.linecorp.armeria.server.logging.AccessLogWriter; +import io.netty.channel.EventLoopGroup; + /** * A builder class for binding an {@link HttpService} fluently. This class can be instantiated through * {@link ServerBuilder#route()}. You can also configure an {@link HttpService} using @@ -291,6 +293,17 @@ public ServiceBindingBuilder multipartUploadsLocation(Path multipartUploadsLocat return (ServiceBindingBuilder) super.multipartUploadsLocation(multipartUploadsLocation); } + @Override + public ServiceBindingBuilder serviceWorkerGroup(EventLoopGroup serviceWorkerGroup, + boolean shutdownOnStop) { + return (ServiceBindingBuilder) super.serviceWorkerGroup(serviceWorkerGroup, shutdownOnStop); + } + + @Override + public ServiceBindingBuilder serviceWorkerGroup(int numThreads) { + return (ServiceBindingBuilder) super.serviceWorkerGroup(numThreads); + } + @Override public ServiceBindingBuilder requestIdGenerator( Function requestIdGenerator) { diff --git a/core/src/main/java/com/linecorp/armeria/server/ServiceConfig.java b/core/src/main/java/com/linecorp/armeria/server/ServiceConfig.java index 3aaf0475889..0a27ab662eb 100644 --- a/core/src/main/java/com/linecorp/armeria/server/ServiceConfig.java +++ b/core/src/main/java/com/linecorp/armeria/server/ServiceConfig.java @@ -46,6 +46,8 @@ import com.linecorp.armeria.server.cors.CorsService; import com.linecorp.armeria.server.logging.AccessLogWriter; +import io.netty.channel.EventLoopGroup; + /** * An {@link HttpService} configuration. * @@ -79,6 +81,8 @@ public final class ServiceConfig { private final long requestAutoAbortDelayMillis; private final Path multipartUploadsLocation; + private final EventLoopGroup serviceWorkerGroup; + private final List shutdownSupports; private final HttpHeaders defaultHeaders; private final Function requestIdGenerator; @@ -94,7 +98,8 @@ public final class ServiceConfig { boolean verboseResponses, AccessLogWriter accessLogWriter, BlockingTaskExecutor blockingTaskExecutor, SuccessFunction successFunction, long requestAutoAbortDelayMillis, - Path multipartUploadsLocation, List shutdownSupports, + Path multipartUploadsLocation, EventLoopGroup serviceWorkerGroup, + List shutdownSupports, HttpHeaders defaultHeaders, Function requestIdGenerator, ServiceErrorHandler serviceErrorHandler, Supplier contextHook) { @@ -102,7 +107,7 @@ public final class ServiceConfig { requestTimeoutMillis, maxRequestLength, verboseResponses, accessLogWriter, extractTransientServiceOptions(service), blockingTaskExecutor, successFunction, requestAutoAbortDelayMillis, - multipartUploadsLocation, shutdownSupports, defaultHeaders, + multipartUploadsLocation, serviceWorkerGroup, shutdownSupports, defaultHeaders, requestIdGenerator, serviceErrorHandler, contextHook); } @@ -119,6 +124,7 @@ private ServiceConfig(@Nullable VirtualHost virtualHost, Route route, SuccessFunction successFunction, long requestAutoAbortDelayMillis, Path multipartUploadsLocation, + EventLoopGroup serviceWorkerGroup, List shutdownSupports, HttpHeaders defaultHeaders, Function requestIdGenerator, ServiceErrorHandler serviceErrorHandler, @@ -139,6 +145,7 @@ private ServiceConfig(@Nullable VirtualHost virtualHost, Route route, this.successFunction = requireNonNull(successFunction, "successFunction"); this.requestAutoAbortDelayMillis = requestAutoAbortDelayMillis; this.multipartUploadsLocation = requireNonNull(multipartUploadsLocation, "multipartUploadsLocation"); + this.serviceWorkerGroup = requireNonNull(serviceWorkerGroup, "serviceWorkerGroup"); this.shutdownSupports = ImmutableList.copyOf(requireNonNull(shutdownSupports, "shutdownSupports")); this.defaultHeaders = defaultHeaders; @SuppressWarnings("unchecked") @@ -185,7 +192,7 @@ ServiceConfig withVirtualHost(VirtualHost virtualHost) { defaultServiceNaming, requestTimeoutMillis, maxRequestLength, verboseResponses, accessLogWriter, transientServiceOptions, blockingTaskExecutor, successFunction, requestAutoAbortDelayMillis, - multipartUploadsLocation, shutdownSupports, defaultHeaders, + multipartUploadsLocation, serviceWorkerGroup, shutdownSupports, defaultHeaders, requestIdGenerator, serviceErrorHandler, contextHook); } @@ -196,7 +203,7 @@ ServiceConfig withDecoratedService(Function shutdownSupports() { return shutdownSupports; } + /** + * Returns the {@link EventLoopGroup} dedicated to the execution of services' methods. + */ + @UnstableApi + public EventLoopGroup serviceWorkerGroup() { + return serviceWorkerGroup; + } + /** * Returns the default headers for an {@link HttpResponse} served by the {@link #service()}. */ diff --git a/core/src/main/java/com/linecorp/armeria/server/ServiceConfigBuilder.java b/core/src/main/java/com/linecorp/armeria/server/ServiceConfigBuilder.java index 4a78a286860..e36affcf759 100644 --- a/core/src/main/java/com/linecorp/armeria/server/ServiceConfigBuilder.java +++ b/core/src/main/java/com/linecorp/armeria/server/ServiceConfigBuilder.java @@ -42,10 +42,13 @@ import com.linecorp.armeria.common.SuccessFunction; import com.linecorp.armeria.common.annotation.Nullable; import com.linecorp.armeria.common.util.BlockingTaskExecutor; +import com.linecorp.armeria.common.util.EventLoopGroups; import com.linecorp.armeria.internal.common.websocket.WebSocketUtil; import com.linecorp.armeria.internal.server.websocket.DefaultWebSocketService; import com.linecorp.armeria.server.logging.AccessLogWriter; +import io.netty.channel.EventLoopGroup; + final class ServiceConfigBuilder implements ServiceConfigSetters { private final Route route; @@ -76,6 +79,8 @@ final class ServiceConfigBuilder implements ServiceConfigSetters { @Nullable private Path multipartUploadsLocation; @Nullable + private EventLoopGroup serviceWorkerGroup; + @Nullable private ServiceErrorHandler serviceErrorHandler; private Supplier contextHook = NOOP_CONTEXT_HOOK; private final List shutdownSupports = new ArrayList<>(); @@ -289,6 +294,22 @@ public ServiceConfigBuilder defaultServiceNaming(ServiceNaming defaultServiceNam return this; } + @Override + public ServiceConfigBuilder serviceWorkerGroup(EventLoopGroup serviceWorkerGroup, + boolean shutdownOnStop) { + this.serviceWorkerGroup = requireNonNull(serviceWorkerGroup, "serviceWorkerGroup"); + if (shutdownOnStop) { + shutdownSupports.add(ShutdownSupport.of(serviceWorkerGroup)); + } + return this; + } + + @Override + public ServiceConfigBuilder serviceWorkerGroup(int numThreads) { + final EventLoopGroup workerGroup = EventLoopGroups.newEventLoopGroup(numThreads); + return serviceWorkerGroup(workerGroup, true); + } + void shutdownSupports(List shutdownSupports) { requireNonNull(shutdownSupports, "shutdownSupports"); this.shutdownSupports.addAll(shutdownSupports); @@ -308,6 +329,7 @@ ServiceConfig build(ServiceNaming defaultServiceNaming, SuccessFunction defaultSuccessFunction, long defaultRequestAutoAbortDelayMillis, Path defaultMultipartUploadsLocation, + EventLoopGroup defaultServiceWorkerGroup, HttpHeaders virtualHostDefaultHeaders, Function defaultRequestIdGenerator, ServiceErrorHandler defaultServiceErrorHandler, @@ -366,6 +388,7 @@ ServiceConfig build(ServiceNaming defaultServiceNaming, successFunction != null ? successFunction : defaultSuccessFunction, requestAutoAbortDelayMillis, multipartUploadsLocation != null ? multipartUploadsLocation : defaultMultipartUploadsLocation, + serviceWorkerGroup != null ? serviceWorkerGroup : defaultServiceWorkerGroup, ImmutableList.copyOf(shutdownSupports), mergeDefaultHeaders(virtualHostDefaultHeaders.toBuilder(), defaultHeaders.build()), requestIdGenerator != null ? requestIdGenerator : defaultRequestIdGenerator, errorHandler, @@ -385,6 +408,7 @@ public String toString() { .add("blockingTaskExecutor", blockingTaskExecutor) .add("successFunction", successFunction) .add("multipartUploadsLocation", multipartUploadsLocation) + .add("serviceWorkerGroup", serviceWorkerGroup) .add("shutdownSupports", shutdownSupports) .add("defaultHeaders", defaultHeaders) .add("serviceErrorHandler", serviceErrorHandler) diff --git a/core/src/main/java/com/linecorp/armeria/server/ServiceConfigSetters.java b/core/src/main/java/com/linecorp/armeria/server/ServiceConfigSetters.java index d7b50b0b586..a1066477998 100644 --- a/core/src/main/java/com/linecorp/armeria/server/ServiceConfigSetters.java +++ b/core/src/main/java/com/linecorp/armeria/server/ServiceConfigSetters.java @@ -37,6 +37,8 @@ import com.linecorp.armeria.common.util.BlockingTaskExecutor; import com.linecorp.armeria.server.logging.AccessLogWriter; +import io.netty.channel.EventLoopGroup; + interface ServiceConfigSetters { /** @@ -209,6 +211,27 @@ ServiceConfigSetters blockingTaskExecutor(BlockingTaskExecutor blockingTaskExecu @UnstableApi ServiceConfigSetters multipartUploadsLocation(Path multipartUploadsLocation); + /** + * Sets a {@linkplain EventLoopGroup worker group} to be used when serving a {@link Service}. + * + * @param serviceWorkerGroup the {@linkplain ScheduledExecutorService executor} to be used. + * @param shutdownOnStop whether to shut down the {@link ScheduledExecutorService} when the {@link Server} + * stops. + */ + @UnstableApi + ServiceConfigSetters serviceWorkerGroup(EventLoopGroup serviceWorkerGroup, + boolean shutdownOnStop); + + /** + * Uses a newly created {@link EventLoopGroup} with the specified number of threads dedicated to + * the execution of service codes. + * The {@link EventLoopGroup} will be shut down when the {@link Server} stops. + * + * @param numThreads the number of threads in the executor + */ + @UnstableApi + ServiceConfigSetters serviceWorkerGroup(int numThreads); + /** * Sets the {@link Function} which generates a {@link RequestId}. * diff --git a/core/src/main/java/com/linecorp/armeria/server/ServiceRequestContextBuilder.java b/core/src/main/java/com/linecorp/armeria/server/ServiceRequestContextBuilder.java index 98a782d544d..4e0c71bb18f 100644 --- a/core/src/main/java/com/linecorp/armeria/server/ServiceRequestContextBuilder.java +++ b/core/src/main/java/com/linecorp/armeria/server/ServiceRequestContextBuilder.java @@ -236,10 +236,12 @@ public ServiceRequestContext build() { requestCancellationScheduler.initAndStart(eventLoop(), noopCancellationTask); } + final EventLoop serviceWorkerGroup = eventLoop(); + // Build the context with the properties set by a user and the fake objects. final Channel ch = fakeChannel(); return new DefaultServiceRequestContext( - serviceCfg, ch, meterRegistry(), sessionProtocol(), id(), routingCtx, + serviceCfg, ch, serviceWorkerGroup, meterRegistry(), sessionProtocol(), id(), routingCtx, routingResult, exchangeType, req, sslSession(), proxiedAddresses, clientAddress, remoteAddress(), localAddress(), requestCancellationScheduler, diff --git a/core/src/main/java/com/linecorp/armeria/server/VirtualHost.java b/core/src/main/java/com/linecorp/armeria/server/VirtualHost.java index 49eaa6eaed2..96027f72c0f 100644 --- a/core/src/main/java/com/linecorp/armeria/server/VirtualHost.java +++ b/core/src/main/java/com/linecorp/armeria/server/VirtualHost.java @@ -39,6 +39,7 @@ import com.linecorp.armeria.common.RequestId; import com.linecorp.armeria.common.SuccessFunction; import com.linecorp.armeria.common.annotation.Nullable; +import com.linecorp.armeria.common.annotation.UnstableApi; import com.linecorp.armeria.common.logging.RequestLog; import com.linecorp.armeria.common.logging.RequestLogBuilder; import com.linecorp.armeria.common.metric.MeterIdPrefix; @@ -46,6 +47,7 @@ import com.linecorp.armeria.server.logging.AccessLogWriter; import io.micrometer.core.instrument.MeterRegistry; +import io.netty.channel.EventLoopGroup; import io.netty.handler.ssl.SslContext; import io.netty.util.Mapping; @@ -95,6 +97,7 @@ public final class VirtualHost { private final long requestAutoAbortDelayMillis; private final SuccessFunction successFunction; private final Path multipartUploadsLocation; + private final EventLoopGroup serviceWorkerGroup; private final List shutdownSupports; private final Function requestIdGenerator; @@ -113,6 +116,7 @@ public final class VirtualHost { long requestAutoAbortDelayMillis, SuccessFunction successFunction, Path multipartUploadsLocation, + EventLoopGroup serviceWorkerGroup, List shutdownSupports, Function requestIdGenerator) { originalDefaultHostname = defaultHostname; @@ -136,6 +140,7 @@ public final class VirtualHost { this.requestAutoAbortDelayMillis = requestAutoAbortDelayMillis; this.successFunction = successFunction; this.multipartUploadsLocation = multipartUploadsLocation; + this.serviceWorkerGroup = serviceWorkerGroup; this.shutdownSupports = shutdownSupports; @SuppressWarnings("unchecked") final Function castRequestIdGenerator = @@ -162,7 +167,8 @@ VirtualHost withNewSslContext(SslContext sslContext) { host -> accessLogger, defaultServiceNaming, defaultLogName, requestTimeoutMillis, maxRequestLength, verboseResponses, accessLogWriter, blockingTaskExecutor, requestAutoAbortDelayMillis, - successFunction, multipartUploadsLocation, shutdownSupports, requestIdGenerator); + successFunction, multipartUploadsLocation, serviceWorkerGroup, + shutdownSupports, requestIdGenerator); } /** @@ -400,6 +406,16 @@ public boolean shutdownBlockingTaskExecutorOnStop() { return false; } + /** + * Returns the service {@link EventLoopGroup}. + * + * @see ServiceConfig#serviceWorkerGroup() + */ + @UnstableApi + public EventLoopGroup serviceWorkerGroup() { + return serviceWorkerGroup; + } + /** * Returns the {@link SuccessFunction} that determines whether a request was * handled successfully or not. @@ -530,7 +546,7 @@ VirtualHost decorate(@Nullable Function accessLogger, defaultServiceNaming, defaultLogName, requestTimeoutMillis, maxRequestLength, verboseResponses, accessLogWriter, blockingTaskExecutor, requestAutoAbortDelayMillis, successFunction, multipartUploadsLocation, - shutdownSupports, requestIdGenerator); + serviceWorkerGroup, shutdownSupports, requestIdGenerator); } @Override diff --git a/core/src/main/java/com/linecorp/armeria/server/VirtualHostAnnotatedServiceBindingBuilder.java b/core/src/main/java/com/linecorp/armeria/server/VirtualHostAnnotatedServiceBindingBuilder.java index 9628dcf3826..2e903e9ae93 100644 --- a/core/src/main/java/com/linecorp/armeria/server/VirtualHostAnnotatedServiceBindingBuilder.java +++ b/core/src/main/java/com/linecorp/armeria/server/VirtualHostAnnotatedServiceBindingBuilder.java @@ -31,6 +31,8 @@ import com.linecorp.armeria.server.annotation.ResponseConverterFunction; import com.linecorp.armeria.server.logging.AccessLogWriter; +import io.netty.channel.EventLoopGroup; + /** * A builder class for binding an {@link HttpService} to a virtual host fluently. This class can be instantiated * through {@link VirtualHostBuilder#annotatedService()}. @@ -268,6 +270,18 @@ public VirtualHostAnnotatedServiceBindingBuilder contextHook( return (VirtualHostAnnotatedServiceBindingBuilder) super.contextHook(contextHook); } + @Override + public VirtualHostAnnotatedServiceBindingBuilder serviceWorkerGroup(EventLoopGroup serviceWorkerGroup, + boolean shutdownOnStop) { + return (VirtualHostAnnotatedServiceBindingBuilder) super.serviceWorkerGroup(serviceWorkerGroup, + shutdownOnStop); + } + + @Override + public VirtualHostAnnotatedServiceBindingBuilder serviceWorkerGroup(int numThreads) { + return (VirtualHostAnnotatedServiceBindingBuilder) super.serviceWorkerGroup(numThreads); + } + /** * Registers the given service to the {@linkplain VirtualHostBuilder}. * diff --git a/core/src/main/java/com/linecorp/armeria/server/VirtualHostBuilder.java b/core/src/main/java/com/linecorp/armeria/server/VirtualHostBuilder.java index 0d952db5b4e..a64b58e125a 100644 --- a/core/src/main/java/com/linecorp/armeria/server/VirtualHostBuilder.java +++ b/core/src/main/java/com/linecorp/armeria/server/VirtualHostBuilder.java @@ -82,6 +82,7 @@ import com.linecorp.armeria.common.logging.RequestLog; import com.linecorp.armeria.common.logging.RequestLogBuilder; import com.linecorp.armeria.common.util.BlockingTaskExecutor; +import com.linecorp.armeria.common.util.EventLoopGroups; import com.linecorp.armeria.common.util.SystemInfo; import com.linecorp.armeria.internal.common.util.SelfSignedCertificate; import com.linecorp.armeria.internal.server.RouteDecoratingService; @@ -164,6 +165,8 @@ public final class VirtualHostBuilder implements TlsSetters, ServiceConfigsBuild @Nullable private Path multipartUploadsLocation; @Nullable + private EventLoopGroup serviceWorkerGroup; + @Nullable private Function requestIdGenerator; @Nullable private ServiceErrorHandler errorHandler; @@ -1194,6 +1197,36 @@ public VirtualHostBuilder requestIdGenerator( return this; } + /** + * Sets the {@link EventLoopGroup} dedicated to the execution of services' methods. + * If not set, the work group of the belonging channel is used. + * + * @param shutdownOnStop whether to shut down the {@link EventLoopGroup} when the + * {@link Server} stops + */ + @UnstableApi + public VirtualHostBuilder serviceWorkerGroup(EventLoopGroup serviceWorkerGroup, + boolean shutdownOnStop) { + this.serviceWorkerGroup = requireNonNull(serviceWorkerGroup, "serviceWorkerGroup"); + if (shutdownOnStop) { + shutdownSupports.add(ShutdownSupport.of(serviceWorkerGroup)); + } + return this; + } + + /** + * Uses a newly created {@link EventLoopGroup} with the specified number of threads dedicated to + * the execution of services' methods. + * The worker {@link EventLoopGroup} will be shut down when the {@link Server} stops. + * + * @param numThreads the number of threads in the executor + */ + @UnstableApi + public VirtualHostBuilder serviceWorkerGroup(int numThreads) { + final EventLoopGroup workerGroup = EventLoopGroups.newEventLoopGroup(numThreads); + return serviceWorkerGroup(workerGroup, true); + } + /** * Sets the {@link RequestConverterFunction}s, {@link ResponseConverterFunction} * and {@link ExceptionHandlerFunction}s for creating an {@link AnnotatedServiceExtensions}. @@ -1326,6 +1359,15 @@ VirtualHost build(VirtualHostBuilder template, DependencyInjector dependencyInje final Supplier contextHook = mergeHooks(template.contextHook, this.contextHook); + final EventLoopGroup serviceWorkerGroup; + if (this.serviceWorkerGroup != null) { + serviceWorkerGroup = this.serviceWorkerGroup; + } else if (template.serviceWorkerGroup != null) { + serviceWorkerGroup = template.serviceWorkerGroup; + } else { + serviceWorkerGroup = serverBuilder.workerGroup; + } + assert defaultServiceNaming != null; assert rejectedRouteHandler != null; assert accessLoggerMapper != null; @@ -1352,7 +1394,7 @@ VirtualHost build(VirtualHostBuilder template, DependencyInjector dependencyInje return cfgBuilder.build(defaultServiceNaming, requestTimeoutMillis, maxRequestLength, verboseResponses, accessLogWriter, blockingTaskExecutor, successFunction, requestAutoAbortDelayMillis, - multipartUploadsLocation, defaultHeaders, + multipartUploadsLocation, serviceWorkerGroup, defaultHeaders, requestIdGenerator, defaultErrorHandler, unhandledExceptionsReporter, baseContextPath, contextHook); }).collect(toImmutableList()); @@ -1361,7 +1403,7 @@ VirtualHost build(VirtualHostBuilder template, DependencyInjector dependencyInje new ServiceConfigBuilder(RouteBuilder.FALLBACK_ROUTE, "/", FallbackService.INSTANCE) .build(defaultServiceNaming, requestTimeoutMillis, maxRequestLength, verboseResponses, accessLogWriter, blockingTaskExecutor, successFunction, - requestAutoAbortDelayMillis, multipartUploadsLocation, + requestAutoAbortDelayMillis, multipartUploadsLocation, serviceWorkerGroup, defaultHeaders, requestIdGenerator, defaultErrorHandler, unhandledExceptionsReporter, "/", contextHook); @@ -1375,7 +1417,7 @@ VirtualHost build(VirtualHostBuilder template, DependencyInjector dependencyInje accessLoggerMapper, defaultServiceNaming, defaultLogName, requestTimeoutMillis, maxRequestLength, verboseResponses, accessLogWriter, blockingTaskExecutor, requestAutoAbortDelayMillis, successFunction, multipartUploadsLocation, - builder.build(), requestIdGenerator); + serviceWorkerGroup, builder.build(), requestIdGenerator); final Function decorator = getRouteDecoratingService(template, baseContextPath); diff --git a/core/src/main/java/com/linecorp/armeria/server/VirtualHostContextPathAnnotatedServiceConfigSetters.java b/core/src/main/java/com/linecorp/armeria/server/VirtualHostContextPathAnnotatedServiceConfigSetters.java index 0d0365dd0c0..e58dc62d644 100644 --- a/core/src/main/java/com/linecorp/armeria/server/VirtualHostContextPathAnnotatedServiceConfigSetters.java +++ b/core/src/main/java/com/linecorp/armeria/server/VirtualHostContextPathAnnotatedServiceConfigSetters.java @@ -32,6 +32,8 @@ import com.linecorp.armeria.server.annotation.ResponseConverterFunction; import com.linecorp.armeria.server.logging.AccessLogWriter; +import io.netty.channel.EventLoopGroup; + /** * A {@link VirtualHostContextPathAnnotatedServiceConfigSetters} builder which configures * an {@link AnnotatedService} under a set of context paths. @@ -244,6 +246,18 @@ public VirtualHostContextPathAnnotatedServiceConfigSetters requestAutoAbortDelay super.requestAutoAbortDelayMillis(delayMillis); } + @Override + public VirtualHostContextPathAnnotatedServiceConfigSetters serviceWorkerGroup( + EventLoopGroup serviceWorkerGroup, boolean shutdownOnStop) { + return (VirtualHostContextPathAnnotatedServiceConfigSetters) + super.serviceWorkerGroup(serviceWorkerGroup, shutdownOnStop); + } + + @Override + public VirtualHostContextPathAnnotatedServiceConfigSetters serviceWorkerGroup(int numThreads) { + return (VirtualHostContextPathAnnotatedServiceConfigSetters) super.serviceWorkerGroup(numThreads); + } + @Override public VirtualHostContextPathAnnotatedServiceConfigSetters multipartUploadsLocation( Path multipartUploadsLocation) { diff --git a/core/src/main/java/com/linecorp/armeria/server/VirtualHostContextPathServiceBindingBuilder.java b/core/src/main/java/com/linecorp/armeria/server/VirtualHostContextPathServiceBindingBuilder.java index 42e53ae482d..7d930cf5560 100644 --- a/core/src/main/java/com/linecorp/armeria/server/VirtualHostContextPathServiceBindingBuilder.java +++ b/core/src/main/java/com/linecorp/armeria/server/VirtualHostContextPathServiceBindingBuilder.java @@ -32,6 +32,8 @@ import com.linecorp.armeria.common.util.BlockingTaskExecutor; import com.linecorp.armeria.server.logging.AccessLogWriter; +import io.netty.channel.EventLoopGroup; + /** * A builder class for binding an {@link HttpService} fluently. This class can be instantiated through * {@link VirtualHostBuilder#contextPath(String...)}. @@ -178,6 +180,18 @@ public VirtualHostContextPathServiceBindingBuilder multipartUploadsLocation( super.multipartUploadsLocation(multipartUploadsLocation); } + @Override + public VirtualHostContextPathServiceBindingBuilder serviceWorkerGroup(EventLoopGroup serviceWorkerGroup, + boolean shutdownOnStop) { + return (VirtualHostContextPathServiceBindingBuilder) super.serviceWorkerGroup(serviceWorkerGroup, + shutdownOnStop); + } + + @Override + public VirtualHostContextPathServiceBindingBuilder serviceWorkerGroup(int numThreads) { + return (VirtualHostContextPathServiceBindingBuilder) super.serviceWorkerGroup(numThreads); + } + @Override public VirtualHostContextPathServiceBindingBuilder requestIdGenerator( Function requestIdGenerator) { diff --git a/core/src/main/java/com/linecorp/armeria/server/VirtualHostServiceBindingBuilder.java b/core/src/main/java/com/linecorp/armeria/server/VirtualHostServiceBindingBuilder.java index b0180a95ac2..8a8f10464eb 100644 --- a/core/src/main/java/com/linecorp/armeria/server/VirtualHostServiceBindingBuilder.java +++ b/core/src/main/java/com/linecorp/armeria/server/VirtualHostServiceBindingBuilder.java @@ -34,6 +34,8 @@ import com.linecorp.armeria.common.util.BlockingTaskExecutor; import com.linecorp.armeria.server.logging.AccessLogWriter; +import io.netty.channel.EventLoopGroup; + /** * A builder class for binding an {@link HttpService} fluently. This class can be instantiated through * {@link VirtualHostBuilder#route()}. You can also configure an {@link HttpService} using @@ -315,6 +317,17 @@ public VirtualHostServiceBindingBuilder multipartUploadsLocation(Path multipartU return (VirtualHostServiceBindingBuilder) super.multipartUploadsLocation(multipartUploadsLocation); } + @Override + public VirtualHostServiceBindingBuilder serviceWorkerGroup(EventLoopGroup serviceWorkerGroup, + boolean shutdownOnStop) { + return (VirtualHostServiceBindingBuilder) super.serviceWorkerGroup(serviceWorkerGroup, shutdownOnStop); + } + + @Override + public VirtualHostServiceBindingBuilder serviceWorkerGroup(int numThreads) { + return (VirtualHostServiceBindingBuilder) super.serviceWorkerGroup(numThreads); + } + @Override public VirtualHostServiceBindingBuilder requestIdGenerator( Function requestIdGenerator) { diff --git a/core/src/test/java/com/linecorp/armeria/common/stream/SubscribeOnStreamMessageTest.java b/core/src/test/java/com/linecorp/armeria/common/stream/SubscribeOnStreamMessageTest.java new file mode 100644 index 00000000000..602abdd5ec6 --- /dev/null +++ b/core/src/test/java/com/linecorp/armeria/common/stream/SubscribeOnStreamMessageTest.java @@ -0,0 +1,124 @@ +/* + * Copyright 2023 LINE Corporation + * + * LINE Corporation licenses this file to you under the Apache License, + * version 2.0 (the "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at: + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations + * under the License. + */ + +package com.linecorp.armeria.common.stream; + +import static org.assertj.core.api.Assertions.assertThat; +import static org.awaitility.Awaitility.await; + +import java.util.Deque; +import java.util.concurrent.ConcurrentLinkedDeque; + +import org.junit.jupiter.api.Test; +import org.junit.jupiter.api.extension.RegisterExtension; +import org.reactivestreams.Subscriber; +import org.reactivestreams.Subscription; + +import com.google.common.collect.ImmutableList; + +import com.linecorp.armeria.testing.junit5.common.EventLoopExtension; + +import io.netty.util.concurrent.EventExecutor; + +class SubscribeOnStreamMessageTest { + + @RegisterExtension + static EventLoopExtension eventLoop1 = new EventLoopExtension(); + + @RegisterExtension + static EventLoopExtension eventLoop2 = new EventLoopExtension(); + + @Test + void completeCase() { + final EventLoopCheckingStreamMessage upstream = + new EventLoopCheckingStreamMessage<>(eventLoop1.get()); + final Deque acc = new ConcurrentLinkedDeque<>(); + upstream.subscribeOn(eventLoop1.get()) + .subscribe(new EventLoopCheckingSubscriber<>(acc), eventLoop2.get()); + upstream.write(1); + upstream.close(); + + await().untilAsserted(() -> assertThat(acc).containsExactlyElementsOf( + ImmutableList.of("onSubscribe", "1", "onComplete"))); + } + + @Test + void errorCase() { + final EventLoopCheckingStreamMessage upstream = + new EventLoopCheckingStreamMessage<>(eventLoop1.get()); + final Deque acc = new ConcurrentLinkedDeque<>(); + upstream.subscribeOn(eventLoop1.get()) + .subscribe(new EventLoopCheckingSubscriber<>(acc), eventLoop2.get()); + upstream.write(1); + upstream.close(new Throwable()); + + await().untilAsserted(() -> assertThat(acc).containsExactlyElementsOf( + ImmutableList.of("onSubscribe", "1", "onError"))); + } + + static class EventLoopCheckingStreamMessage extends DefaultStreamMessage { + + private final EventExecutor eventLoop; + + EventLoopCheckingStreamMessage(EventExecutor eventLoop) { + this.eventLoop = eventLoop; + } + + @Override + protected void subscribe0(EventExecutor executor, SubscriptionOption[] options) { + assertThat(eventLoop.inEventLoop()).isTrue(); + } + + @Override + protected void onRequest(long n) { + assert eventLoop.inEventLoop(); + } + } + + static class EventLoopCheckingSubscriber implements Subscriber { + + private final Deque acc; + + EventLoopCheckingSubscriber(Deque acc) { + this.acc = acc; + } + + @Override + public void onSubscribe(Subscription s) { + assertThat(eventLoop2.get().inEventLoop()).isTrue(); + s.request(Long.MAX_VALUE); + acc.add("onSubscribe"); + } + + @Override + public void onNext(T t) { + assertThat(eventLoop2.get().inEventLoop()).isTrue(); + acc.add(t.toString()); + } + + @Override + public void onError(Throwable t) { + assertThat(eventLoop2.get().inEventLoop()).isTrue(); + acc.add("onError"); + } + + @Override + public void onComplete() { + assertThat(eventLoop2.get().inEventLoop()).isTrue(); + acc.add("onComplete"); + } + } +} diff --git a/core/src/test/java/com/linecorp/armeria/internal/server/annotation/ServiceWorkerGroupTest.java b/core/src/test/java/com/linecorp/armeria/internal/server/annotation/ServiceWorkerGroupTest.java new file mode 100644 index 00000000000..ddf856e0548 --- /dev/null +++ b/core/src/test/java/com/linecorp/armeria/internal/server/annotation/ServiceWorkerGroupTest.java @@ -0,0 +1,283 @@ +/* + * Copyright 2021 LINE Corporation + * + * LINE Corporation licenses this file to you under the Apache License, + * version 2.0 (the "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at: + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations + * under the License. + */ + +package com.linecorp.armeria.internal.server.annotation; + +import static org.assertj.core.api.Assertions.assertThat; +import static org.awaitility.Awaitility.await; + +import java.util.Queue; +import java.util.concurrent.ArrayBlockingQueue; + +import org.junit.jupiter.api.AfterAll; +import org.junit.jupiter.api.BeforeEach; +import org.junit.jupiter.api.Test; +import org.junit.jupiter.api.extension.RegisterExtension; +import org.junit.jupiter.params.ParameterizedTest; +import org.junit.jupiter.params.provider.ValueSource; +import org.reactivestreams.Publisher; +import org.reactivestreams.Subscriber; +import org.reactivestreams.Subscription; + +import com.linecorp.armeria.common.AggregatedHttpResponse; +import com.linecorp.armeria.common.HttpMethod; +import com.linecorp.armeria.common.HttpObject; +import com.linecorp.armeria.common.HttpResponse; +import com.linecorp.armeria.common.HttpStatus; +import com.linecorp.armeria.common.RequestHeaders; +import com.linecorp.armeria.common.ResponseHeaders; +import com.linecorp.armeria.common.annotation.Nullable; +import com.linecorp.armeria.common.logging.RequestLogProperty; +import com.linecorp.armeria.common.util.ThreadFactories; +import com.linecorp.armeria.server.Server; +import com.linecorp.armeria.server.ServerBuilder; +import com.linecorp.armeria.server.ServiceRequestContext; +import com.linecorp.armeria.server.annotation.Get; +import com.linecorp.armeria.server.annotation.Post; +import com.linecorp.armeria.testing.junit5.server.ServerExtension; + +import io.netty.channel.DefaultEventLoop; +import io.netty.channel.EventLoop; + +class ServiceWorkerGroupTest { + + static final EventLoop aExecutor = + new DefaultEventLoop(ThreadFactories.builder("test-a") + .eventLoop(false) + .build()); + + static final EventLoop defaultExecutor = new DefaultEventLoop( + ThreadFactories.builder("test-default") + .eventLoop(false) + .build()); + + static final EventLoop workerExecutor = new DefaultEventLoop( + ThreadFactories.builder("test-worker") + .eventLoop(false) + .build()); + + static final Queue threadQueue = new ArrayBlockingQueue<>(32); + + @RegisterExtension + static final ServerExtension server = new ServerExtension() { + @Override + protected void configure(ServerBuilder sb) { + sb.annotatedService().serviceWorkerGroup(aExecutor, true) + .build(new MyAnnotatedServiceA()); + sb.annotatedService(new MyAnnotatedServiceDefault()); + sb.service("/ctxLog", (ctx, req) -> { + for (RequestLogProperty property: RequestLogProperty.values()) { + ctx.log().whenAvailable(property).thenRun(() -> { + threadQueue.add(Thread.currentThread()); + }); + } + return HttpResponse.of(200); + }); + sb.service("/subscribe", (ctx, req) -> { + final SignallingPublisher publisher = new SignallingPublisher(); + req.subscribe(new Subscriber() { + @Override + public void onSubscribe(Subscription s) { + threadQueue.add(Thread.currentThread()); + } + + @Override + public void onNext(HttpObject httpObject) { + threadQueue.add(Thread.currentThread()); + } + + @Override + public void onError(Throwable t) { + threadQueue.add(Thread.currentThread()); + } + + @Override + public void onComplete() { + threadQueue.add(Thread.currentThread()); + publisher.markReady(); + } + }); + return HttpResponse.of(publisher); + }); + + sb.serviceWorkerGroup(defaultExecutor, true); + } + }; + + static class SignallingPublisher implements Publisher { + + private boolean ready; + @Nullable + private Subscriber subscriber; + private long request; + + @Override + public void subscribe(Subscriber s) { + this.subscriber = s; + threadQueue.add(Thread.currentThread()); + s.onSubscribe(new Subscription() { + @Override + public void request(long n) { + threadQueue.add(Thread.currentThread()); + request += n; + tryNotify(); + } + + @Override + public void cancel() { + } + }); + } + + void markReady() { + ready = true; + tryNotify(); + } + + private void tryNotify() { + final Subscriber subscriber0 = subscriber; + if (request > 0 && ready && subscriber0 != null) { + subscriber0.onNext(ResponseHeaders.builder(200).endOfStream(true).build()); + subscriber0.onComplete(); + } + } + } + + @BeforeEach + void beforeEach() { + threadQueue.clear(); + } + + @AfterAll + static void afterAll() { + aExecutor.shutdownGracefully(); + defaultExecutor.shutdownGracefully(); + workerExecutor.shutdownGracefully(); + } + + @ParameterizedTest + @ValueSource(strings = {"/a", "/aggregated"}) + void testServiceWorkerGroup(String path) throws InterruptedException { + final AggregatedHttpResponse res = server.blockingWebClient() + .execute(RequestHeaders.of(HttpMethod.GET, path)); + assertThat(res.status()).isSameAs(HttpStatus.OK); + assertThat(server.requestContextCaptor().size()).isEqualTo(1); + final EventLoop ctxEventLoop = server.requestContextCaptor().poll().eventLoop().withoutContext(); + assertThat(ctxEventLoop).isSameAs(aExecutor); + assertThat(threadQueue).allSatisfy(t -> assertThat(aExecutor.inEventLoop(t)).isTrue()); + } + + @Test + void aggregatingRequest() throws InterruptedException { + final AggregatedHttpResponse res = server.blockingWebClient().post("/aggregated-string", "hello"); + assertThat(res.status()).isSameAs(HttpStatus.OK); + assertThat(server.requestContextCaptor().size()).isEqualTo(1); + final EventLoop ctxEventLoop = server.requestContextCaptor().poll().eventLoop().withoutContext(); + assertThat(ctxEventLoop).isSameAs(aExecutor); + assertThat(threadQueue).allSatisfy(t -> assertThat(aExecutor.inEventLoop(t)).isTrue()); + } + + @Test + void testDefaultServiceWorkerGroup() throws InterruptedException { + final AggregatedHttpResponse res = server.blockingWebClient() + .execute(RequestHeaders.of(HttpMethod.GET, "/default")); + assertThat(res.status()).isSameAs(HttpStatus.OK); + assertThat(server.requestContextCaptor().size()).isEqualTo(1); + final EventLoop ctxEventLoop = server.requestContextCaptor().poll().eventLoop().withoutContext(); + assertThat(ctxEventLoop).isSameAs(defaultExecutor); + } + + @Test + void shutdownOnStopBehavior() { + final EventLoop eventLoop = new DefaultEventLoop(); + try (Server server = Server.builder() + .service("/", (ctx, req) -> HttpResponse.of(200)) + .serviceWorkerGroup(eventLoop, true) + .build()) { + server.start().join(); + assertThat(eventLoop.isShutdown()).isFalse(); + } + assertThat(eventLoop.isShutdown()).isTrue(); + } + + @Test + void defaultIsWorkerThread() { + try (Server server = Server.builder() + .service("/", (ctx, req) -> HttpResponse.of(200)) + .workerGroup(workerExecutor, false) + .build()) { + assertThat(server.serviceConfigs()).allSatisfy(cfg -> { + assertThat(cfg.serviceWorkerGroup()).isSameAs(workerExecutor); + }); + } + } + + @ParameterizedTest + @ValueSource(strings = {"/ctxLog", "/aggregatedCtxLog"}) + void contextLogExecutedByServiceWorkerThread(String path) { + final AggregatedHttpResponse aggRes = server.blockingWebClient().get(path); + assertThat(aggRes.status().code()).isEqualTo(200); + + await().untilAsserted(() -> assertThat(threadQueue).hasSize(RequestLogProperty.values().length)); + assertThat(threadQueue).allSatisfy(t -> assertThat(defaultExecutor.inEventLoop(t)).isTrue()); + } + + @Test + void subscribeWorkerThread() { + final AggregatedHttpResponse aggRes = server.blockingWebClient().get("/subscribe"); + assertThat(aggRes.status().code()).isEqualTo(200); + + await().untilAsserted(() -> assertThat(threadQueue).isNotEmpty()); + assertThat(threadQueue).allSatisfy(t -> assertThat(defaultExecutor.inEventLoop(t)).isTrue()); + } + + static class MyAnnotatedServiceA { + @Get("/a") + public HttpResponse httpResponseLoggingServiceTest() { + threadQueue.add(Thread.currentThread()); + return HttpResponse.of(HttpStatus.OK); + } + + @Get("/aggregated") + public String aggregated() { + threadQueue.add(Thread.currentThread()); + return "aggregated"; + } + + @Post("/aggregated-string") + public String aggregated(String request) { + threadQueue.add(Thread.currentThread()); + return request + ", World!"; + } + } + + static class MyAnnotatedServiceDefault { + @Get("/default") + public HttpResponse httpResponse(ServiceRequestContext ctx) { + return HttpResponse.of(HttpStatus.OK); + } + + @Get("/aggregatedCtxLog") + public String aggregated(ServiceRequestContext ctx) { + for (RequestLogProperty property: RequestLogProperty.values()) { + ctx.log().whenAvailable(property).thenRun(() -> { + threadQueue.add(Thread.currentThread()); + }); + } + return "aggregated"; + } + } +} diff --git a/core/src/test/java/com/linecorp/armeria/server/ServiceNamingTest.java b/core/src/test/java/com/linecorp/armeria/server/ServiceNamingTest.java index a4dec6a2dad..9fa6e530070 100644 --- a/core/src/test/java/com/linecorp/armeria/server/ServiceNamingTest.java +++ b/core/src/test/java/com/linecorp/armeria/server/ServiceNamingTest.java @@ -45,7 +45,8 @@ void fullTypeName_topClass() { null, null, ServiceNaming.fullTypeName(), 0, 0, false, AccessLogWriter.common(), CommonPools.blockingTaskExecutor(), SuccessFunction.always(), - 0, Files.newTemporaryFolder().toPath(), ImmutableList.of(), HttpHeaders.of(), + 0, Files.newTemporaryFolder().toPath(), CommonPools.workerGroup(), + ImmutableList.of(), HttpHeaders.of(), routingCtx -> RequestId.of(1L), ServerErrorHandler.ofDefault().asServiceErrorHandler(), NOOP_CONTEXT_HOOK); when(ctx.config()).thenReturn(config); @@ -61,7 +62,8 @@ void fullTypeName_nestedClass() { null, null, ServiceNaming.fullTypeName(), 0, 0, false, AccessLogWriter.common(), CommonPools.blockingTaskExecutor(), SuccessFunction.always(), - 0, Files.newTemporaryFolder().toPath(), ImmutableList.of(), HttpHeaders.of(), + 0, Files.newTemporaryFolder().toPath(), CommonPools.workerGroup(), + ImmutableList.of(), HttpHeaders.of(), routingCtx -> RequestId.of(1L), ServerErrorHandler.ofDefault().asServiceErrorHandler(), NOOP_CONTEXT_HOOK); when(ctx.config()).thenReturn(config); @@ -78,7 +80,8 @@ void fullTypeName_trimTrailingDollarSign() { null, null, ServiceNaming.fullTypeName(), 0, 0, false, AccessLogWriter.common(), CommonPools.blockingTaskExecutor(), SuccessFunction.always(), - 0, Files.newTemporaryFolder().toPath(), ImmutableList.of(), HttpHeaders.of(), + 0, Files.newTemporaryFolder().toPath(), CommonPools.workerGroup(), + ImmutableList.of(), HttpHeaders.of(), routingCtx -> RequestId.of(1L), ServerErrorHandler.ofDefault().asServiceErrorHandler(), NOOP_CONTEXT_HOOK); when(ctx.config()).thenReturn(config); @@ -94,7 +97,8 @@ void fullTypeName_trimTrailingDollarSignMany() { null, null, ServiceNaming.fullTypeName(), 0, 0, false, AccessLogWriter.common(), CommonPools.blockingTaskExecutor(), SuccessFunction.always(), - 0, Files.newTemporaryFolder().toPath(), ImmutableList.of(), HttpHeaders.of(), + 0, Files.newTemporaryFolder().toPath(), CommonPools.workerGroup(), + ImmutableList.of(), HttpHeaders.of(), routingCtx -> RequestId.of(1L), ServerErrorHandler.ofDefault().asServiceErrorHandler(), NOOP_CONTEXT_HOOK); when(ctx.config()).thenReturn(config); @@ -110,7 +114,8 @@ void fullTypeName_trimTrailingDollarSignOnly() { null, null, ServiceNaming.fullTypeName(), 0, 0, false, AccessLogWriter.common(), CommonPools.blockingTaskExecutor(), SuccessFunction.always(), - 0, Files.newTemporaryFolder().toPath(), ImmutableList.of(), HttpHeaders.of(), + 0, Files.newTemporaryFolder().toPath(), CommonPools.workerGroup(), + ImmutableList.of(), HttpHeaders.of(), routingCtx -> RequestId.of(1L), ServerErrorHandler.ofDefault().asServiceErrorHandler(), NOOP_CONTEXT_HOOK); when(ctx.config()).thenReturn(config); @@ -126,7 +131,8 @@ void simpleTypeName_topClass() { null, null, ServiceNaming.fullTypeName(), 0, 0, false, AccessLogWriter.common(), CommonPools.blockingTaskExecutor(), SuccessFunction.always(), - 0, Files.newTemporaryFolder().toPath(), ImmutableList.of(), HttpHeaders.of(), + 0, Files.newTemporaryFolder().toPath(), CommonPools.workerGroup(), + ImmutableList.of(), HttpHeaders.of(), routingCtx -> RequestId.of(1L), ServerErrorHandler.ofDefault().asServiceErrorHandler(), NOOP_CONTEXT_HOOK); when(ctx.config()).thenReturn(config); @@ -142,7 +148,8 @@ void simpleTypeName_nestedClass() { null, null, ServiceNaming.fullTypeName(), 0, 0, false, AccessLogWriter.common(), CommonPools.blockingTaskExecutor(), SuccessFunction.always(), - 0, Files.newTemporaryFolder().toPath(), ImmutableList.of(), HttpHeaders.of(), + 0, Files.newTemporaryFolder().toPath(), CommonPools.workerGroup(), + ImmutableList.of(), HttpHeaders.of(), routingCtx -> RequestId.of(1L), ServerErrorHandler.ofDefault().asServiceErrorHandler(), NOOP_CONTEXT_HOOK); when(ctx.config()).thenReturn(config); @@ -159,7 +166,8 @@ void simpleTypeName_trimTrailingDollarSign() { null, null, ServiceNaming.fullTypeName(), 0, 0, false, AccessLogWriter.common(), CommonPools.blockingTaskExecutor(), SuccessFunction.always(), - 0, Files.newTemporaryFolder().toPath(), ImmutableList.of(), HttpHeaders.of(), + 0, Files.newTemporaryFolder().toPath(), CommonPools.workerGroup(), + ImmutableList.of(), HttpHeaders.of(), routingCtx -> RequestId.of(1L), ServerErrorHandler.ofDefault().asServiceErrorHandler(), NOOP_CONTEXT_HOOK); when(ctx.config()).thenReturn(config); @@ -175,7 +183,8 @@ void simpleTypeName_trimTrailingDollarSignMany() { null, null, ServiceNaming.fullTypeName(), 0, 0, false, AccessLogWriter.common(), CommonPools.blockingTaskExecutor(), SuccessFunction.always(), - 0, Files.newTemporaryFolder().toPath(), ImmutableList.of(), HttpHeaders.of(), + 0, Files.newTemporaryFolder().toPath(), CommonPools.workerGroup(), + ImmutableList.of(), HttpHeaders.of(), routingCtx -> RequestId.of(1L), ServerErrorHandler.ofDefault().asServiceErrorHandler(), NOOP_CONTEXT_HOOK); when(ctx.config()).thenReturn(config); @@ -191,7 +200,8 @@ void simpleTypeName_trimTrailingDollarSignOnly() { null, null, ServiceNaming.fullTypeName(), 0, 0, false, AccessLogWriter.common(), CommonPools.blockingTaskExecutor(), SuccessFunction.always(), - 0, Files.newTemporaryFolder().toPath(), ImmutableList.of(), HttpHeaders.of(), + 0, Files.newTemporaryFolder().toPath(), CommonPools.workerGroup(), + ImmutableList.of(), HttpHeaders.of(), routingCtx -> RequestId.of(1L), ServerErrorHandler.ofDefault().asServiceErrorHandler(), NOOP_CONTEXT_HOOK); when(ctx.config()).thenReturn(config); @@ -207,7 +217,8 @@ void shorten_topClass() { null, null, ServiceNaming.fullTypeName(), 0, 0, false, AccessLogWriter.common(), CommonPools.blockingTaskExecutor(), SuccessFunction.always(), - 0, Files.newTemporaryFolder().toPath(), ImmutableList.of(), HttpHeaders.of(), + 0, Files.newTemporaryFolder().toPath(), CommonPools.workerGroup(), + ImmutableList.of(), HttpHeaders.of(), routingCtx -> RequestId.of(1L), ServerErrorHandler.ofDefault().asServiceErrorHandler(), NOOP_CONTEXT_HOOK); when(ctx.config()).thenReturn(config); @@ -223,7 +234,8 @@ void shorten_nestedClass() { null, null, ServiceNaming.fullTypeName(), 0, 0, false, AccessLogWriter.common(), CommonPools.blockingTaskExecutor(), SuccessFunction.always(), - 0, Files.newTemporaryFolder().toPath(), ImmutableList.of(), HttpHeaders.of(), + 0, Files.newTemporaryFolder().toPath(), CommonPools.workerGroup(), + ImmutableList.of(), HttpHeaders.of(), routingCtx -> RequestId.of(1L), ServerErrorHandler.ofDefault().asServiceErrorHandler(), NOOP_CONTEXT_HOOK); when(ctx.config()).thenReturn(config); @@ -240,7 +252,8 @@ void shorten_trimTrailingDollarSign() { null, null, ServiceNaming.fullTypeName(), 0, 0, false, AccessLogWriter.common(), CommonPools.blockingTaskExecutor(), SuccessFunction.always(), - 0, Files.newTemporaryFolder().toPath(), ImmutableList.of(), HttpHeaders.of(), + 0, Files.newTemporaryFolder().toPath(), CommonPools.workerGroup(), + ImmutableList.of(), HttpHeaders.of(), routingCtx -> RequestId.of(1L), ServerErrorHandler.ofDefault().asServiceErrorHandler(), NOOP_CONTEXT_HOOK); when(ctx.config()).thenReturn(config); @@ -257,7 +270,8 @@ void shorten_trimTrailingDollarSignMany() { null, null, ServiceNaming.fullTypeName(), 0, 0, false, AccessLogWriter.common(), CommonPools.blockingTaskExecutor(), SuccessFunction.always(), - 0, Files.newTemporaryFolder().toPath(), ImmutableList.of(), HttpHeaders.of(), + 0, Files.newTemporaryFolder().toPath(), CommonPools.workerGroup(), + ImmutableList.of(), HttpHeaders.of(), routingCtx -> RequestId.of(1L), ServerErrorHandler.ofDefault().asServiceErrorHandler(), NOOP_CONTEXT_HOOK); when(ctx.config()).thenReturn(config); @@ -274,7 +288,8 @@ void shorten_trimTrailingDollarSignOnly() { null, null, ServiceNaming.fullTypeName(), 0, 0, false, AccessLogWriter.common(), CommonPools.blockingTaskExecutor(), SuccessFunction.always(), - 0, Files.newTemporaryFolder().toPath(), ImmutableList.of(), HttpHeaders.of(), + 0, Files.newTemporaryFolder().toPath(), CommonPools.workerGroup(), + ImmutableList.of(), HttpHeaders.of(), routingCtx -> RequestId.of(1L), ServerErrorHandler.ofDefault().asServiceErrorHandler(), NOOP_CONTEXT_HOOK); when(ctx.config()).thenReturn(config); diff --git a/core/src/test/java/com/linecorp/armeria/server/ServiceTest.java b/core/src/test/java/com/linecorp/armeria/server/ServiceTest.java index f7f2cd8b08b..659308bae20 100644 --- a/core/src/test/java/com/linecorp/armeria/server/ServiceTest.java +++ b/core/src/test/java/com/linecorp/armeria/server/ServiceTest.java @@ -65,7 +65,8 @@ private static void assertDecoration(FooService inner, HttpService outer) throws AccessLogWriter.disabled(), CommonPools.blockingTaskExecutor(), SuccessFunction.always(), - 0, Files.newTemporaryFolder().toPath(), ImmutableList.of(), HttpHeaders.of(), + 0, Files.newTemporaryFolder().toPath(), CommonPools.workerGroup(), + ImmutableList.of(), HttpHeaders.of(), ctx -> RequestId.of(1L), ServerErrorHandler.ofDefault().asServiceErrorHandler(), NOOP_CONTEXT_HOOK); outer.serviceAdded(cfg);