diff --git a/internal/venice-common/src/main/java/com/linkedin/venice/acl/handler/StoreAclHandler.java b/internal/venice-common/src/main/java/com/linkedin/venice/acl/handler/StoreAclHandler.java index 354a440a5cd..d8e5063f54d 100644 --- a/internal/venice-common/src/main/java/com/linkedin/venice/acl/handler/StoreAclHandler.java +++ b/internal/venice-common/src/main/java/com/linkedin/venice/acl/handler/StoreAclHandler.java @@ -108,7 +108,7 @@ public void channelRead0(ChannelHandlerContext ctx, HttpRequest req) throws SSLP X509Certificate clientCert = extractClientCert(ctx); String resourceName = requestParts[2]; - String storeName = extractStoreName(resourceName); + String storeName = extractStoreName(resourceName, queryAction); try { // Check ACL in case of non system store as system store contain public information @@ -132,7 +132,8 @@ public ServerCall.Listener interceptCall( @Override public void onMessage(ReqT message) { VeniceClientRequest request = (VeniceClientRequest) message; - String storeName = extractStoreName(request.getResourceName()); + // For now, GRPC only supports STORAGE query + String storeName = extractStoreName(request.getResourceName(), QueryAction.STORAGE); String method = request.getMethod(); BiConsumer grpcCloseConsumer = call::close; @@ -168,7 +169,7 @@ public void onMessage(ReqT message) { /** * Extract the store name from the incoming resource name. */ - protected String extractStoreName(String resourceName) { + protected String extractStoreName(String resourceName, QueryAction queryAction) { return resourceName; } diff --git a/services/venice-server/src/main/java/com/linkedin/venice/listener/ServerStoreAclHandler.java b/services/venice-server/src/main/java/com/linkedin/venice/listener/ServerStoreAclHandler.java index 5ac3ca446d4..f482b1219db 100644 --- a/services/venice-server/src/main/java/com/linkedin/venice/listener/ServerStoreAclHandler.java +++ b/services/venice-server/src/main/java/com/linkedin/venice/listener/ServerStoreAclHandler.java @@ -2,6 +2,7 @@ import com.linkedin.venice.acl.DynamicAccessController; import com.linkedin.venice.acl.handler.StoreAclHandler; +import com.linkedin.venice.meta.QueryAction; import com.linkedin.venice.meta.ReadOnlyStoreRepository; import com.linkedin.venice.meta.Version; import io.grpc.Metadata; @@ -58,11 +59,20 @@ public ServerCall.Listener interceptCall( } /** - * In Venice Server, the resource name is actually a Kafka topic name. + * In Venice Server, the resource name is actually a Kafka topic name for STORAGE/COMPUTE but store name for DICTIONARY. */ @Override - protected String extractStoreName(String resourceName) { - return Version.parseStoreFromKafkaTopicName(resourceName); + protected String extractStoreName(String resourceName, QueryAction queryAction) { + switch (queryAction) { + case STORAGE: + case COMPUTE: + return Version.parseStoreFromKafkaTopicName(resourceName); + case DICTIONARY: + return resourceName; + default: + throw new IllegalArgumentException( + String.format("Unexpected QueryAction: %s with resource name: %s", queryAction, resourceName)); + } } protected static boolean checkWhetherAccessHasAlreadyApproved(ChannelHandlerContext ctx) { diff --git a/services/venice-server/src/test/java/com/linkedin/venice/listener/ServerStoreAclHandlerTest.java b/services/venice-server/src/test/java/com/linkedin/venice/listener/ServerStoreAclHandlerTest.java index af3c38808d4..f6b5d83785b 100644 --- a/services/venice-server/src/test/java/com/linkedin/venice/listener/ServerStoreAclHandlerTest.java +++ b/services/venice-server/src/test/java/com/linkedin/venice/listener/ServerStoreAclHandlerTest.java @@ -1,24 +1,134 @@ package com.linkedin.venice.listener; +import static org.mockito.ArgumentMatchers.any; +import static org.mockito.ArgumentMatchers.eq; import static org.mockito.Mockito.doReturn; import static org.mockito.Mockito.mock; +import static org.mockito.Mockito.never; +import static org.mockito.Mockito.spy; import static org.mockito.Mockito.times; import static org.mockito.Mockito.verify; import static org.testng.Assert.assertFalse; +import static org.testng.Assert.assertNotNull; import static org.testng.Assert.assertTrue; +import com.linkedin.venice.acl.AclException; import com.linkedin.venice.acl.DynamicAccessController; +import com.linkedin.venice.acl.handler.StoreAclHandler; +import com.linkedin.venice.meta.QueryAction; import com.linkedin.venice.meta.ReadOnlyStoreRepository; +import com.linkedin.venice.meta.ServerAdminAction; +import com.linkedin.venice.meta.Version; import io.grpc.Metadata; import io.grpc.ServerCall; import io.grpc.ServerCallHandler; import io.netty.channel.Channel; import io.netty.channel.ChannelHandlerContext; +import io.netty.channel.ChannelPipeline; +import io.netty.handler.codec.http.HttpMethod; +import io.netty.handler.codec.http.HttpRequest; +import io.netty.handler.ssl.SslHandler; import io.netty.util.Attribute; +import java.net.SocketAddress; +import java.security.cert.Certificate; +import java.security.cert.X509Certificate; +import java.util.List; +import java.util.Set; +import javax.net.ssl.SSLEngine; +import javax.net.ssl.SSLPeerUnverifiedException; +import javax.net.ssl.SSLSession; import org.testng.annotations.Test; public class ServerStoreAclHandlerTest { + // Store name can be in a version topic format + private static final String TEST_STORE_NAME = "testStore_v1"; + private static final String TEST_STORE_VERSION = Version.composeKafkaTopic(TEST_STORE_NAME, 1); + + /** + * Mock access controller to verify basic request parsing and handling for {@link ServerStoreAclHandler} + */ + private class MockAccessController implements DynamicAccessController { + private QueryAction queryAction; + + public MockAccessController(QueryAction queryAction) { + this.queryAction = queryAction; + } + + @Override + public boolean hasAccessToTopic(X509Certificate clientCert, String resource, String method) throws AclException { + assertNotNull(clientCert, queryAction.toString()); + validateStringArg(resource, "resource"); + validateStringArg(method, "method"); + return true; + } + + @Override + public boolean hasAccessToAdminOperation(X509Certificate clientCert, String operation) throws AclException { + assertNotNull(clientCert, queryAction.toString()); + validateStringArg(operation, "operation"); + return true; + } + + @Override + public boolean isAllowlistUsers(X509Certificate clientCert, String resource, String method) { + assertNotNull(clientCert, queryAction.toString()); + validateStringArg(resource, "resource"); + validateStringArg(method, "method"); + return true; + } + + @Override + public String getPrincipalId(X509Certificate clientCert) { + assertNotNull(clientCert, queryAction.toString()); + return "testPrincipalId"; + } + + @Override + public DynamicAccessController init(List resources) { + return this; + } + + @Override + public boolean hasAccess(X509Certificate clientCert, String resource, String method) throws AclException { + assertNotNull(clientCert); + validateStringArg(resource, "resource"); + validateStringArg(method, "method"); + return true; + } + + @Override + public boolean hasAcl(String resource) throws AclException { + validateStringArg(resource, "resource"); + return true; + } + + @Override + public void addAcl(String resource) throws AclException { + validateStringArg(resource, "resource"); + } + + @Override + public void removeAcl(String resource) throws AclException { + validateStringArg(resource, "resource"); + } + + @Override + public Set getAccessControlledResources() { + return null; + } + + @Override + public boolean isFailOpen() { + return false; + } + + private void validateStringArg(String arg, String argName) { + assertNotNull(arg, argName + " should not be null for query action " + queryAction.toString()); + assertFalse(arg.isEmpty(), argName + " should not be empty string for query action " + queryAction.toString()); + } + } + @Test public void testCheckWhetherAccessHasAlreadyApproved() { Channel channel = mock(Channel.class); @@ -74,4 +184,80 @@ public void testInterceptor() { handler.interceptCall(call, trueHeaders, next); verify(next, times(1)).startCall(call, trueHeaders); } + + @Test + public void testAllRequestTypes() throws SSLPeerUnverifiedException, AclException { + ReadOnlyStoreRepository metadataRepo = mock(ReadOnlyStoreRepository.class); + ChannelHandlerContext ctx = mock(ChannelHandlerContext.class); + HttpRequest request = mock(HttpRequest.class); + Channel channel = mock(Channel.class); + SocketAddress socketAddress = mock(SocketAddress.class); + doReturn("testRemoteHost").when(socketAddress).toString(); + doReturn(socketAddress).when(channel).remoteAddress(); + Attribute accessAttr = mock(Attribute.class); + doReturn(false).when(accessAttr).get(); + doReturn(accessAttr).when(channel).attr(ServerAclHandler.SERVER_ACL_APPROVED_ATTRIBUTE_KEY); + doReturn(channel).when(ctx).channel(); + SslHandler sslHandler = mock(SslHandler.class); + ChannelPipeline channelPipeline = mock(ChannelPipeline.class); + doReturn(sslHandler).when(channelPipeline).get(SslHandler.class); + SSLEngine sslEngine = mock(SSLEngine.class); + SSLSession sslSession = mock(SSLSession.class); + Certificate certificate = mock(X509Certificate.class); + Certificate[] certificates = new Certificate[1]; + certificates[0] = certificate; + doReturn(certificates).when(sslSession).getPeerCertificates(); + doReturn(sslSession).when(sslEngine).getSession(); + doReturn(sslEngine).when(sslHandler).engine(); + doReturn(channelPipeline).when(ctx).pipeline(); + doReturn(HttpMethod.GET).when(request).method(); + for (QueryAction queryAction: QueryAction.values()) { + MockAccessController mockAccessController = new MockAccessController(queryAction); + MockAccessController spyMockAccessController = spy(mockAccessController); + StoreAclHandler storeAclHandler = new ServerStoreAclHandler(spyMockAccessController, metadataRepo); + doReturn(buildTestURI(queryAction)).when(request).uri(); + storeAclHandler.channelRead0(ctx, request); + switch (queryAction) { + case ADMIN: + case CURRENT_VERSION: + case HEALTH: + case METADATA: + case TOPIC_PARTITION_INGESTION_CONTEXT: + verify(spyMockAccessController, never()).hasAccess(any(), any(), any()); + break; + case STORAGE: + case COMPUTE: + case DICTIONARY: + verify(spyMockAccessController).hasAccess(any(), eq(TEST_STORE_NAME), any()); + break; + default: + throw new IllegalArgumentException("Invalid query action: " + queryAction); + } + } + } + + private String buildTestURI(QueryAction queryAction) { + switch (queryAction) { + case STORAGE: + return "/" + QueryAction.STORAGE.toString().toLowerCase() + "/" + TEST_STORE_VERSION + "/1/ABCDEFG"; + case HEALTH: + return "/" + QueryAction.HEALTH.toString().toLowerCase(); + case COMPUTE: + return "/" + QueryAction.COMPUTE.toString().toLowerCase() + "/" + TEST_STORE_VERSION; + case DICTIONARY: + return "/" + QueryAction.DICTIONARY.toString().toLowerCase() + "/" + TEST_STORE_NAME + "/1"; + case ADMIN: + return "/" + QueryAction.ADMIN.toString().toLowerCase() + "/" + TEST_STORE_VERSION + "/" + + ServerAdminAction.DUMP_INGESTION_STATE; + case METADATA: + return "/" + QueryAction.METADATA.toString().toLowerCase() + "/" + TEST_STORE_NAME; + case CURRENT_VERSION: + return "/" + QueryAction.CURRENT_VERSION.toString().toLowerCase() + "/" + TEST_STORE_NAME; + case TOPIC_PARTITION_INGESTION_CONTEXT: + return "/" + QueryAction.TOPIC_PARTITION_INGESTION_CONTEXT.toString().toLowerCase() + "/" + TEST_STORE_VERSION + + "/" + TEST_STORE_VERSION + "/1"; + default: + throw new IllegalArgumentException("Invalid query action: " + queryAction); + } + } }