diff --git a/clients/da-vinci-client/src/main/java/com/linkedin/davinci/config/VeniceServerConfig.java b/clients/da-vinci-client/src/main/java/com/linkedin/davinci/config/VeniceServerConfig.java index 9dbfb919482..fe86959c380 100644 --- a/clients/da-vinci-client/src/main/java/com/linkedin/davinci/config/VeniceServerConfig.java +++ b/clients/da-vinci-client/src/main/java/com/linkedin/davinci/config/VeniceServerConfig.java @@ -3,6 +3,7 @@ import static com.linkedin.davinci.ingestion.utils.IsolatedIngestionUtils.INGESTION_ISOLATION_CONFIG_PREFIX; import static com.linkedin.davinci.store.rocksdb.RocksDBServerConfig.ROCKSDB_TOTAL_MEMTABLE_USAGE_CAP_IN_BYTES; import static com.linkedin.venice.ConfigConstants.DEFAULT_MAX_RECORD_SIZE_BYTES_BACKFILL; +import static com.linkedin.venice.ConfigKeys.ACL_IN_MEMORY_CACHE_TTL_MS; import static com.linkedin.venice.ConfigKeys.AUTOCREATE_DATA_PATH; import static com.linkedin.venice.ConfigKeys.BLOB_TRANSFER_DISABLED_OFFSET_LAG_THRESHOLD; import static com.linkedin.venice.ConfigKeys.BLOB_TRANSFER_MANAGER_ENABLED; @@ -584,6 +585,7 @@ public class VeniceServerConfig extends VeniceClusterConfig { private final int zstdDictCompressionLevel; private final long maxWaitAfterUnsubscribeMs; private final boolean deleteUnassignedPartitionsOnStartup; + private final int aclInMemoryCacheTTLMs; public VeniceServerConfig(VeniceProperties serverProperties) throws ConfigurationException { this(serverProperties, Collections.emptyMap()); @@ -988,6 +990,8 @@ public VeniceServerConfig(VeniceProperties serverProperties, Map extends SimpleChannelInboundHandler { + private static class CachedAcl { + AccessResult accessResult; + long timestamp; + + public CachedAcl(AccessResult accessResult, long timestamp) { + this.accessResult = accessResult; + this.timestamp = timestamp; + } + } + private static final Logger LOGGER = LogManager.getLogger(AbstractStoreAclHandler.class); + public static final String STORE_ACL_CHECK_RESULT = "STORE_ACL_CHECK_RESULT_ATTRIBUTE_KEY"; + public static final AttributeKey> STORE_ACL_CHECK_RESULT_ATTRIBUTE_KEY = + AttributeKey.valueOf(STORE_ACL_CHECK_RESULT); + private static final byte[] BAD_REQUEST_RESPONSE = "Unexpected! Original channel should not be null".getBytes(); + + private final int cacheTTLMs; + private final Time time; + private final boolean aclCacheEnabled; private final IdentityParser identityParser; private final ReadOnlyStoreRepository metadataRepository; @@ -38,12 +63,25 @@ public abstract class AbstractStoreAclHandler extends SimpleChanne public AbstractStoreAclHandler( IdentityParser identityParser, DynamicAccessController accessController, - ReadOnlyStoreRepository metadataRepository) { + ReadOnlyStoreRepository metadataRepository, + int cacheTTLMs) { + this(identityParser, accessController, metadataRepository, cacheTTLMs, new SystemTime()); + } + + public AbstractStoreAclHandler( + IdentityParser identityParser, + DynamicAccessController accessController, + ReadOnlyStoreRepository metadataRepository, + int cacheTTLMs, + Time time) { this.identityParser = identityParser; this.metadataRepository = metadataRepository; this.accessController = accessController .init(metadataRepository.getAllStores().stream().map(Store::getName).collect(Collectors.toList())); this.metadataRepository.registerStoreDataChangedListener(new AclCreationDeletionListener(accessController)); + this.cacheTTLMs = cacheTTLMs; + this.time = time; + this.aclCacheEnabled = cacheTTLMs > 0; } /** @@ -55,14 +93,18 @@ public AbstractStoreAclHandler( */ @Override public void channelRead0(ChannelHandlerContext ctx, HttpRequest req) throws SSLPeerUnverifiedException { - if (isAccessAlreadyApproved(ctx)) { + Channel originalChannel = ServerHandlerUtils.getOriginalChannel(ctx); + if (originalChannel == null) { + NettyUtils.setupResponseAndFlush(HttpResponseStatus.BAD_REQUEST, BAD_REQUEST_RESPONSE, false, ctx); + return; + } + if (isAccessAlreadyApproved(originalChannel)) { ReferenceCountUtil.retain(req); ctx.fireChannelRead(req); return; } String uri = req.uri(); - // Parse resource type and store name String[] requestParts = URI.create(uri).getPath().split("/"); REQUEST_TYPE requestType = validateRequest(requestParts); @@ -89,9 +131,41 @@ public void channelRead0(ChannelHandlerContext ctx, HttpRequest req) throws SSLP return; } - X509Certificate clientCert = extractClientCert(ctx); String method = req.method().name(); - AccessResult accessResult = checkAccess(uri, clientCert, storeName, method); + + if (!method.equals(HttpMethod.GET.name()) && !method.equals(HttpMethod.POST.name())) { + // Neither get nor post method, just let it pass + ReferenceCountUtil.retain(req); + ctx.fireChannelRead(req); + return; + } + AccessResult accessResult; + if (aclCacheEnabled) { + VeniceConcurrentHashMap storeAclCache = + originalChannel.attr(STORE_ACL_CHECK_RESULT_ATTRIBUTE_KEY).get(); + if (storeAclCache == null) { + originalChannel.attr(STORE_ACL_CHECK_RESULT_ATTRIBUTE_KEY).setIfAbsent(new VeniceConcurrentHashMap<>()); + storeAclCache = originalChannel.attr(STORE_ACL_CHECK_RESULT_ATTRIBUTE_KEY).get(); + } + + accessResult = storeAclCache.compute(storeName, (ignored, value) -> { + long currentTimestamp = time.getMilliseconds(); + if (value == null || currentTimestamp - value.timestamp > cacheTTLMs) { + try { + return new CachedAcl( + checkAccess(uri, extractClientCert(ctx), storeName, HttpMethod.GET.name()), + currentTimestamp); + } catch (Exception e) { + LOGGER.error("Error while checking access", e); + return new CachedAcl(AccessResult.ERROR_FORBIDDEN, currentTimestamp); + } + } else { + return value; + } + }).accessResult; + } else { + accessResult = checkAccess(uri, extractClientCert(ctx), storeName, HttpMethod.GET.name()); + } switch (accessResult) { case GRANTED: ReferenceCountUtil.retain(req); @@ -109,7 +183,7 @@ public void channelRead0(ChannelHandlerContext ctx, HttpRequest req) throws SSLP } } - protected boolean isAccessAlreadyApproved(ChannelHandlerContext ctx) { + protected boolean isAccessAlreadyApproved(Channel originalChannel) { return false; } diff --git a/internal/venice-common/src/main/java/com/linkedin/venice/listener/ServerHandlerUtils.java b/internal/venice-common/src/main/java/com/linkedin/venice/listener/ServerHandlerUtils.java index f1c14b84369..71cb6d03565 100644 --- a/internal/venice-common/src/main/java/com/linkedin/venice/listener/ServerHandlerUtils.java +++ b/internal/venice-common/src/main/java/com/linkedin/venice/listener/ServerHandlerUtils.java @@ -2,6 +2,7 @@ import com.linkedin.venice.exceptions.VeniceException; import com.linkedin.venice.utils.SslUtils; +import io.netty.channel.Channel; import io.netty.channel.ChannelHandlerContext; import io.netty.handler.ssl.SslHandler; import java.security.cert.X509Certificate; @@ -26,6 +27,27 @@ public static SslHandler extractSslHandler(ChannelHandlerContext ctx) { return sslHandler; } + /** + * Return the channel, which contains the ssl handler and it could be the current channel (http/1.x) or the parent channel (http/2). + */ + public static Channel getOriginalChannel(ChannelHandlerContext ctx) { + /** + * Try to extract ssl handler in current channel, which is mostly for http/1.1 request. + */ + SslHandler sslHandler = ctx.pipeline().get(SslHandler.class); + if (sslHandler != null) { + return ctx.channel(); + } + /** + * Try to extract ssl handler in parent channel, which is for http/2 request. + */ + if (ctx.channel().parent() != null && ctx.channel().parent().pipeline().get(SslHandler.class) != null) { + return ctx.channel().parent(); + } + + return null; + } + public static X509Certificate extractClientCert(ChannelHandlerContext ctx) throws SSLPeerUnverifiedException { SslHandler sslHandler = ServerHandlerUtils.extractSslHandler(ctx); if (sslHandler != null) { diff --git a/internal/venice-common/src/test/java/com/linkedin/venice/acl/handler/AbstractStoreAclHandlerTest.java b/internal/venice-common/src/test/java/com/linkedin/venice/acl/handler/AbstractStoreAclHandlerTest.java index 5fff574240d..4adbd7afed7 100644 --- a/internal/venice-common/src/test/java/com/linkedin/venice/acl/handler/AbstractStoreAclHandlerTest.java +++ b/internal/venice-common/src/test/java/com/linkedin/venice/acl/handler/AbstractStoreAclHandlerTest.java @@ -1,7 +1,10 @@ package com.linkedin.venice.acl.handler; +import static com.linkedin.venice.acl.handler.AbstractStoreAclHandler.STORE_ACL_CHECK_RESULT_ATTRIBUTE_KEY; import static org.mockito.Mockito.any; import static org.mockito.Mockito.argThat; +import static org.mockito.Mockito.doAnswer; +import static org.mockito.Mockito.doReturn; import static org.mockito.Mockito.mock; import static org.mockito.Mockito.never; import static org.mockito.Mockito.spy; @@ -14,6 +17,9 @@ import com.linkedin.venice.common.VeniceSystemStoreUtils; import com.linkedin.venice.helix.HelixReadOnlyStoreRepository; import com.linkedin.venice.meta.Store; +import com.linkedin.venice.utils.TestMockTime; +import com.linkedin.venice.utils.Time; +import com.linkedin.venice.utils.concurrent.VeniceConcurrentHashMap; import io.netty.channel.Channel; import io.netty.channel.ChannelHandlerContext; import io.netty.channel.ChannelPipeline; @@ -22,6 +28,7 @@ import io.netty.handler.codec.http.HttpRequest; import io.netty.handler.codec.http.HttpResponseStatus; import io.netty.handler.ssl.SslHandler; +import io.netty.util.Attribute; import java.net.SocketAddress; import java.security.cert.Certificate; import java.security.cert.X509Certificate; @@ -33,12 +40,15 @@ public class AbstractStoreAclHandlerTest { + private static int CACHE_TTL_MS = 1000; private IdentityParser identityParser; private DynamicAccessController accessController; private HelixReadOnlyStoreRepository metadataRepo; private ChannelHandlerContext ctx; + private Channel channel; private HttpRequest req; private Store store; + private Time mockTime; private boolean[] needsAcl = { true }; private boolean[] hasAccess = { false }; private boolean[] hasAcl = { false }; @@ -83,12 +93,18 @@ public void setUp() throws Exception { when(sslSession.getPeerCertificates()).thenReturn(new Certificate[] { cert }); // Host - Channel channel = mock(Channel.class); + channel = mock(Channel.class); when(ctx.channel()).thenReturn(channel); SocketAddress address = mock(SocketAddress.class); when(channel.remoteAddress()).thenReturn(address); + Attribute aclCacheAttr = mock(Attribute.class); + doAnswer((ignored) -> new VeniceConcurrentHashMap<>()).when(aclCacheAttr).get(); + doReturn(aclCacheAttr).when(channel).attr(STORE_ACL_CHECK_RESULT_ATTRIBUTE_KEY); + when(req.method()).thenReturn(HttpMethod.GET); + + mockTime = new TestMockTime(0); } @Test @@ -104,6 +120,38 @@ public void noAclNeeded() throws Exception { verify(ctx, times(32)).fireChannelRead(req); } + @Test + public void testAclCache() throws Exception { + hasAccess[0] = true; + hasStore[0] = true; + Attribute aclCacheAttr = mock(Attribute.class); + doReturn(new VeniceConcurrentHashMap<>()).when(aclCacheAttr).get(); + doReturn(aclCacheAttr).when(channel).attr(STORE_ACL_CHECK_RESULT_ATTRIBUTE_KEY); + enumerate(hasAcl); + hasAccess[0] = true; + hasStore[0] = true; + enumerate(hasAcl); + verify(accessController).hasAccess(any(), any(), any()); + // Simulate cache expiration + mockTime.sleep(CACHE_TTL_MS + 1); + hasAccess[0] = true; + hasStore[0] = true; + enumerate(hasAcl); + verify(accessController, times(2)).hasAccess(any(), any(), any()); + + // Test cache disabled + // New metadataRepo mock and aclHandler every update since thenThrow cannot be re-mocked. + hasAccess[0] = true; + hasStore[0] = true; + metadataRepo = mock(HelixReadOnlyStoreRepository.class); + AbstractStoreAclHandler aclHandler = + spy(new MockStoreAclHandler(identityParser, accessController, metadataRepo, -1, mockTime)); + update(); + aclHandler.channelRead0(ctx, req); + aclHandler.channelRead0(ctx, req); + verify(accessController, times(4)).hasAccess(any(), any(), any()); + } + @Test public void accessGranted() throws Exception { hasAccess[0] = true; @@ -299,7 +347,8 @@ private void enumerate(boolean[]... conditions) throws Exception { } // New metadataRepo mock and aclHandler every update since thenThrow cannot be re-mocked. metadataRepo = mock(HelixReadOnlyStoreRepository.class); - AbstractStoreAclHandler aclHandler = spy(new MockStoreAclHandler(identityParser, accessController, metadataRepo)); + AbstractStoreAclHandler aclHandler = + spy(new MockStoreAclHandler(identityParser, accessController, metadataRepo, mockTime)); update(); aclHandler.channelRead0(ctx, req); } @@ -334,8 +383,18 @@ private static class MockStoreAclHandler extends AbstractStoreAclHandler aclCacheAttr = mock(Attribute.class); + doReturn(new VeniceConcurrentHashMap<>()).when(aclCacheAttr).get(); + doReturn(aclCacheAttr).when(channel).attr(STORE_ACL_CHECK_RESULT_ATTRIBUTE_KEY); + when(req.method()).thenReturn(HttpMethod.GET); HttpHeaders headers = mock(HttpHeaders.class); @@ -189,7 +197,7 @@ private void enumerate(boolean[]... conditions) throws Exception { // New metadataRepo mock and aclHandler every update since thenThrow cannot be re-mocked. metadataRepo = mock(HelixReadOnlyStoreRepository.class); AbstractStoreAclHandler aclHandler = - spy(new RouterStoreAclHandler(identityParser, accessController, metadataRepo)); + spy(new RouterStoreAclHandler(identityParser, accessController, metadataRepo, 1000)); update(); LOGGER.info("hasStore: {}, isBadUri: {}", hasStore[0], isBadUri[0]); aclHandler.channelRead0(ctx, req); @@ -326,12 +334,17 @@ public void testAllRequestTypes() throws SSLPeerUnverifiedException, AclExceptio doReturn(HttpMethod.GET).when(request).method(); IdentityParser identityParser = mock(IdentityParser.class); doReturn("testPrincipalId").when(identityParser).parseIdentityFromCert(certificate); + + Attribute aclCacheAttr = mock(Attribute.class); + doAnswer((ignored) -> new VeniceConcurrentHashMap<>()).when(aclCacheAttr).get(); + doReturn(aclCacheAttr).when(channel).attr(STORE_ACL_CHECK_RESULT_ATTRIBUTE_KEY); + for (RouterResourceType resourceType: RouterResourceType.values()) { clearInvocations(ctx); MockAccessController mockAccessController = new MockAccessController(resourceType); MockAccessController spyMockAccessController = spy(mockAccessController); RouterStoreAclHandler storeAclHandler = - new RouterStoreAclHandler(identityParser, spyMockAccessController, metadataRepo); + new RouterStoreAclHandler(identityParser, spyMockAccessController, metadataRepo, 1000); doReturn(buildTestURI(resourceType)).when(request).uri(); storeAclHandler.channelRead0(ctx, request); diff --git a/services/venice-server/src/main/java/com/linkedin/venice/listener/HttpChannelInitializer.java b/services/venice-server/src/main/java/com/linkedin/venice/listener/HttpChannelInitializer.java index caf8eb7e4ea..31b4c4b2346 100644 --- a/services/venice-server/src/main/java/com/linkedin/venice/listener/HttpChannelInitializer.java +++ b/services/venice-server/src/main/java/com/linkedin/venice/listener/HttpChannelInitializer.java @@ -118,7 +118,12 @@ public HttpChannelInitializer( this.identityParser = ReflectUtils.callConstructor(identityParserClass, new Class[0], new Object[0]); this.storeAclHandler = storeAccessController.isPresent() - ? Optional.of(new ServerStoreAclHandler(identityParser, storeAccessController.get(), storeMetadataRepository)) + ? Optional.of( + new ServerStoreAclHandler( + identityParser, + storeAccessController.get(), + storeMetadataRepository, + serverConfig.getAclInMemoryCacheTTLMs())) : Optional.empty(); /** * If the store-level access handler is present, we don't want to fail fast if the access gets denied by {@link ServerAclHandler}. diff --git a/services/venice-server/src/main/java/com/linkedin/venice/listener/ServerAclHandler.java b/services/venice-server/src/main/java/com/linkedin/venice/listener/ServerAclHandler.java index 755822c58a9..242a9f2c6b3 100644 --- a/services/venice-server/src/main/java/com/linkedin/venice/listener/ServerAclHandler.java +++ b/services/venice-server/src/main/java/com/linkedin/venice/listener/ServerAclHandler.java @@ -16,6 +16,7 @@ import io.grpc.ServerCallHandler; import io.grpc.ServerInterceptor; import io.grpc.Status; +import io.netty.channel.Channel; import io.netty.channel.ChannelHandler; import io.netty.channel.ChannelHandlerContext; import io.netty.channel.SimpleChannelInboundHandler; @@ -65,11 +66,24 @@ public ServerAclHandler(StaticAccessController accessController, boolean failOnA */ @Override public void channelRead0(ChannelHandlerContext ctx, HttpRequest req) throws SSLPeerUnverifiedException { - X509Certificate clientCert = extractClientCert(ctx); + Channel originalChannel = ServerHandlerUtils.getOriginalChannel(ctx); + if (originalChannel == null) { + LOGGER.error("Got a non-ssl request on what should be an ssl only port: {}", req.uri()); + NettyUtils.setupResponseAndFlush(HttpResponseStatus.FORBIDDEN, new byte[0], false, ctx); + return; + } String method = req.method().name(); + /** + * Server ACL check is typically against static ACL, so once the access is confirmed, + * we don't need to update it again per connection. + */ + Boolean accessApproved = originalChannel.attr(SERVER_ACL_APPROVED_ATTRIBUTE_KEY).get(); + if (accessApproved == null) { + X509Certificate clientCert = extractClientCert(ctx); - boolean accessApproved = accessController.hasAccess(clientCert, VeniceComponent.SERVER, method); - ctx.channel().attr(SERVER_ACL_APPROVED_ATTRIBUTE_KEY).set(accessApproved); + accessApproved = accessController.hasAccess(clientCert, VeniceComponent.SERVER, method); + originalChannel.attr(SERVER_ACL_APPROVED_ATTRIBUTE_KEY).set(accessApproved); + } if (accessApproved || !failOnAccessRejection) { ReferenceCountUtil.retain(req); ctx.fireChannelRead(req); diff --git a/services/venice-server/src/main/java/com/linkedin/venice/listener/ServerStoreAclHandler.java b/services/venice-server/src/main/java/com/linkedin/venice/listener/ServerStoreAclHandler.java index 56cf2850f77..7fea1501976 100644 --- a/services/venice-server/src/main/java/com/linkedin/venice/listener/ServerStoreAclHandler.java +++ b/services/venice-server/src/main/java/com/linkedin/venice/listener/ServerStoreAclHandler.java @@ -12,13 +12,14 @@ import com.linkedin.venice.meta.ReadOnlyStoreRepository; import com.linkedin.venice.meta.Version; import com.linkedin.venice.protocols.VeniceClientRequest; +import com.linkedin.venice.utils.Time; import io.grpc.ForwardingServerCallListener; import io.grpc.Metadata; import io.grpc.ServerCall; import io.grpc.ServerCallHandler; import io.grpc.ServerInterceptor; import io.grpc.Status; -import io.netty.channel.ChannelHandlerContext; +import io.netty.channel.Channel; import io.netty.util.Attribute; import java.security.cert.X509Certificate; import java.util.EnumSet; @@ -56,8 +57,18 @@ public class ServerStoreAclHandler extends AbstractStoreAclHandler public ServerStoreAclHandler( IdentityParser identityParser, DynamicAccessController accessController, - ReadOnlyStoreRepository metadataRepository) { - super(identityParser, accessController, metadataRepository); + ReadOnlyStoreRepository metadataRepository, + int cacheTTLMs) { + super(identityParser, accessController, metadataRepository, cacheTTLMs); + } + + public ServerStoreAclHandler( + IdentityParser identityParser, + DynamicAccessController accessController, + ReadOnlyStoreRepository metadataRepository, + int cacheTTLMs, + Time time) { + super(identityParser, accessController, metadataRepository, cacheTTLMs, time); } @Override @@ -179,11 +190,11 @@ protected QueryAction validateRequest(String[] requestParts) { } @Override - protected boolean isAccessAlreadyApproved(ChannelHandlerContext ctx) { + protected boolean isAccessAlreadyApproved(Channel originalChannel) { /** * Access has been approved by {@link ServerAclHandler}. */ - Attribute serverAclApprovedAttr = ctx.channel().attr(ServerAclHandler.SERVER_ACL_APPROVED_ATTRIBUTE_KEY); + Attribute serverAclApprovedAttr = originalChannel.attr(ServerAclHandler.SERVER_ACL_APPROVED_ATTRIBUTE_KEY); return Boolean.TRUE.equals(serverAclApprovedAttr.get()); } diff --git a/services/venice-server/src/test/java/com/linkedin/venice/listener/ServerStoreAclHandlerTest.java b/services/venice-server/src/test/java/com/linkedin/venice/listener/ServerStoreAclHandlerTest.java index 9f2a0b74349..f633fa5d00e 100644 --- a/services/venice-server/src/test/java/com/linkedin/venice/listener/ServerStoreAclHandlerTest.java +++ b/services/venice-server/src/test/java/com/linkedin/venice/listener/ServerStoreAclHandlerTest.java @@ -1,10 +1,12 @@ package com.linkedin.venice.listener; +import static com.linkedin.venice.acl.handler.AbstractStoreAclHandler.STORE_ACL_CHECK_RESULT_ATTRIBUTE_KEY; import static io.grpc.Status.Code.INVALID_ARGUMENT; import static org.mockito.ArgumentMatchers.any; import static org.mockito.ArgumentMatchers.argThat; import static org.mockito.ArgumentMatchers.eq; import static org.mockito.Mockito.clearInvocations; +import static org.mockito.Mockito.doAnswer; import static org.mockito.Mockito.doReturn; import static org.mockito.Mockito.mock; import static org.mockito.Mockito.never; @@ -26,6 +28,9 @@ import com.linkedin.venice.meta.ServerAdminAction; import com.linkedin.venice.meta.Version; import com.linkedin.venice.protocols.VeniceClientRequest; +import com.linkedin.venice.utils.TestMockTime; +import com.linkedin.venice.utils.Time; +import com.linkedin.venice.utils.concurrent.VeniceConcurrentHashMap; import io.grpc.Attributes; import io.grpc.Grpc; import io.grpc.Metadata; @@ -148,6 +153,9 @@ public void testCheckWhetherAccessHasAlreadyApproved() { Channel channel = mock(Channel.class); ChannelHandlerContext ctx = mock(ChannelHandlerContext.class); doReturn(channel).when(ctx).channel(); + ChannelPipeline mockPipeline = mock(ChannelPipeline.class); + doReturn(mock(SslHandler.class)).when(mockPipeline).get(SslHandler.class); + doReturn(mockPipeline).when(ctx).pipeline(); Attribute accessAttr = mock(Attribute.class); doReturn(true).when(accessAttr).get(); doReturn(accessAttr).when(channel).attr(ServerAclHandler.SERVER_ACL_APPROVED_ATTRIBUTE_KEY); @@ -155,18 +163,21 @@ public void testCheckWhetherAccessHasAlreadyApproved() { ServerStoreAclHandler handler = new ServerStoreAclHandler( mock(IdentityParser.class), mock(DynamicAccessController.class), - mock(ReadOnlyStoreRepository.class)); + mock(ReadOnlyStoreRepository.class), + 1000); assertTrue( - handler.isAccessAlreadyApproved(ctx), + handler.isAccessAlreadyApproved(channel), "Should return true if it is already approved by previous acl handler"); doReturn(false).when(accessAttr).get(); assertFalse( - handler.isAccessAlreadyApproved(ctx), + handler.isAccessAlreadyApproved(channel), "Should return false if it is already denied by previous acl handler"); doReturn(null).when(accessAttr).get(); - assertFalse(handler.isAccessAlreadyApproved(ctx), "Should return false if it hasn't been processed by acl handler"); + assertFalse( + handler.isAccessAlreadyApproved(channel), + "Should return false if it hasn't been processed by acl handler"); } @Test @@ -190,7 +201,8 @@ public void testInterceptor() { ServerStoreAclHandler handler = new ServerStoreAclHandler( mock(IdentityParser.class), mock(DynamicAccessController.class), - mock(ReadOnlyStoreRepository.class)); + mock(ReadOnlyStoreRepository.class), + 1000); // next.intercept call should have been invoked handler.interceptCall(call, falseHeaders, next); @@ -204,6 +216,58 @@ public void testInterceptor() { verify(next, times(1)).startCall(call, trueHeaders); } + @Test + public void testAclCache() throws SSLPeerUnverifiedException, AclException, InterruptedException { + ReadOnlyStoreRepository metadataRepo = mock(ReadOnlyStoreRepository.class); + when(metadataRepo.hasStore(TEST_STORE_NAME)).thenReturn(true); + ChannelHandlerContext ctx = mock(ChannelHandlerContext.class); + HttpRequest request = mock(HttpRequest.class); + Channel channel = mock(Channel.class); + SocketAddress socketAddress = mock(SocketAddress.class); + doReturn("testRemoteHost").when(socketAddress).toString(); + doReturn(socketAddress).when(channel).remoteAddress(); + + Attribute accessAttr = mock(Attribute.class); + doReturn(false).when(accessAttr).get(); + doReturn(accessAttr).when(channel).attr(ServerAclHandler.SERVER_ACL_APPROVED_ATTRIBUTE_KEY); + + Attribute aclCacheAttr = mock(Attribute.class); + doReturn(new VeniceConcurrentHashMap<>()).when(aclCacheAttr).get(); + doReturn(aclCacheAttr).when(channel).attr(STORE_ACL_CHECK_RESULT_ATTRIBUTE_KEY); + + doReturn(channel).when(ctx).channel(); + SslHandler sslHandler = mock(SslHandler.class); + ChannelPipeline channelPipeline = mock(ChannelPipeline.class); + doReturn(sslHandler).when(channelPipeline).get(SslHandler.class); + SSLEngine sslEngine = mock(SSLEngine.class); + SSLSession sslSession = mock(SSLSession.class); + X509Certificate certificate = mock(X509Certificate.class); + Certificate[] certificates = new Certificate[1]; + certificates[0] = certificate; + doReturn(certificates).when(sslSession).getPeerCertificates(); + doReturn(sslSession).when(sslEngine).getSession(); + doReturn(sslEngine).when(sslHandler).engine(); + doReturn(channelPipeline).when(ctx).pipeline(); + doReturn(HttpMethod.GET).when(request).method(); + IdentityParser identityParser = mock(IdentityParser.class); + doReturn("testPrincipalId").when(identityParser).parseIdentityFromCert(certificate); + MockAccessController mockAccessController = new MockAccessController(QueryAction.STORAGE); + MockAccessController spyMockAccessController = spy(mockAccessController); + doReturn(buildTestURI(QueryAction.STORAGE)).when(request).uri(); + Time mockTime = new TestMockTime(0); + ServerStoreAclHandler storeAclHandler = + new ServerStoreAclHandler(identityParser, spyMockAccessController, metadataRepo, 1000, mockTime); + storeAclHandler.channelRead0(ctx, request); + verify(spyMockAccessController).hasAccess(any(), eq(TEST_STORE_NAME), any()); + // Verify that acl cache is populated by sending another request + storeAclHandler.channelRead0(ctx, request); + verify(spyMockAccessController, times(1)).hasAccess(any(), eq(TEST_STORE_NAME), any()); + // Make sure the cache is expired + mockTime.sleep(2000); + storeAclHandler.channelRead0(ctx, request); + verify(spyMockAccessController, times(2)).hasAccess(any(), eq(TEST_STORE_NAME), any()); + } + @Test public void testAllRequestTypes() throws SSLPeerUnverifiedException, AclException { ReadOnlyStoreRepository metadataRepo = mock(ReadOnlyStoreRepository.class); @@ -214,9 +278,15 @@ public void testAllRequestTypes() throws SSLPeerUnverifiedException, AclExceptio SocketAddress socketAddress = mock(SocketAddress.class); doReturn("testRemoteHost").when(socketAddress).toString(); doReturn(socketAddress).when(channel).remoteAddress(); + Attribute accessAttr = mock(Attribute.class); doReturn(false).when(accessAttr).get(); doReturn(accessAttr).when(channel).attr(ServerAclHandler.SERVER_ACL_APPROVED_ATTRIBUTE_KEY); + + Attribute aclCacheAttr = mock(Attribute.class); + doAnswer((ignored) -> new VeniceConcurrentHashMap<>()).when(aclCacheAttr).get(); + doReturn(aclCacheAttr).when(channel).attr(STORE_ACL_CHECK_RESULT_ATTRIBUTE_KEY); + doReturn(channel).when(ctx).channel(); SslHandler sslHandler = mock(SslHandler.class); ChannelPipeline channelPipeline = mock(ChannelPipeline.class); @@ -237,7 +307,7 @@ public void testAllRequestTypes() throws SSLPeerUnverifiedException, AclExceptio MockAccessController mockAccessController = new MockAccessController(queryAction); MockAccessController spyMockAccessController = spy(mockAccessController); ServerStoreAclHandler storeAclHandler = - new ServerStoreAclHandler(identityParser, spyMockAccessController, metadataRepo); + new ServerStoreAclHandler(identityParser, spyMockAccessController, metadataRepo, 1000); doReturn(buildTestURI(queryAction)).when(request).uri(); storeAclHandler.channelRead0(ctx, request); @@ -298,7 +368,8 @@ public void testInvalidRequest() { ServerStoreAclHandler handler = new ServerStoreAclHandler( mock(IdentityParser.class), mock(DynamicAccessController.class), - mock(ReadOnlyStoreRepository.class)); + mock(ReadOnlyStoreRepository.class), + 1000); // Happy path is tested in "testAllRequestTypes". Only test the invalid paths @@ -334,7 +405,8 @@ public void testValidateStoreAclForGRPC() throws SSLPeerUnverifiedException, Acl doReturn(certificates).when(sslSession).getPeerCertificates(); doReturn("identity").when(identityParser).parseIdentityFromCert(certificate); - ServerStoreAclHandler handler = new ServerStoreAclHandler(identityParser, accessController, metadataRepository); + ServerStoreAclHandler handler = + new ServerStoreAclHandler(identityParser, accessController, metadataRepository, 1000); // Empty store name VeniceClientRequest emptyStoreRequest = VeniceClientRequest.newBuilder().build();