Skip to content

Commit

Permalink
[router][server] ACL optimization (#1521)
Browse files Browse the repository at this point in the history
* [router][server] ACL optimization

High-level idea:
1. For a keep-alive connection, the client cert never changes,
   so re-validating every request for the same store on the same
   connection is redundant.
2. In Server, there is an Acl Handler called `ServerAclHandler`, which
   is used to validate whether the connection is from Venice Router or
   not via static ACL. For each connection, Server will buffer the ACL
   check result in this attribute key: `SERVER_ACL_APPROVED_ATTRIBUTE_KEY`.
   And the ACL check result won't change during the lifetime of the Server
   instance.
3. In both Router and Server, there is a store-level ACL check, which
   can change during the lifetime of the Router/Server (ACL added/removed).
   The caching idea is a little different from #2, and it will maintain
   a cache map in the original connection, and for all the requests coming
   from this particular connection, it will check the acl check cache map first,
   and it will follow the previous result if the cache entry is not expired.
   If there is no such entry or the cache entry is expired, it will resort
   to the underlying access control to update the cache map.
In theory, this greatly reduces the number of ACL checks made against
the access controller.

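New config:
acl.in.memory.cache.ttl.ms: 60000 (by default; note that VeniceServerConfig in this
commit falls back to -1, i.e. ACL caching disabled, when the key is not set)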
  • Loading branch information
gaojieliu authored Feb 12, 2025
1 parent aa7dd9a commit b68eb16
Show file tree
Hide file tree
Showing 13 changed files with 328 additions and 32 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
import static com.linkedin.davinci.ingestion.utils.IsolatedIngestionUtils.INGESTION_ISOLATION_CONFIG_PREFIX;
import static com.linkedin.davinci.store.rocksdb.RocksDBServerConfig.ROCKSDB_TOTAL_MEMTABLE_USAGE_CAP_IN_BYTES;
import static com.linkedin.venice.ConfigConstants.DEFAULT_MAX_RECORD_SIZE_BYTES_BACKFILL;
import static com.linkedin.venice.ConfigKeys.ACL_IN_MEMORY_CACHE_TTL_MS;
import static com.linkedin.venice.ConfigKeys.AUTOCREATE_DATA_PATH;
import static com.linkedin.venice.ConfigKeys.BLOB_TRANSFER_DISABLED_OFFSET_LAG_THRESHOLD;
import static com.linkedin.venice.ConfigKeys.BLOB_TRANSFER_MANAGER_ENABLED;
Expand Down Expand Up @@ -584,6 +585,7 @@ public class VeniceServerConfig extends VeniceClusterConfig {
private final int zstdDictCompressionLevel;
private final long maxWaitAfterUnsubscribeMs;
private final boolean deleteUnassignedPartitionsOnStartup;
private final int aclInMemoryCacheTTLMs;

public VeniceServerConfig(VeniceProperties serverProperties) throws ConfigurationException {
this(serverProperties, Collections.emptyMap());
Expand Down Expand Up @@ -988,6 +990,8 @@ public VeniceServerConfig(VeniceProperties serverProperties, Map<String, Map<Str

deleteUnassignedPartitionsOnStartup =
serverProperties.getBoolean(SERVER_DELETE_UNASSIGNED_PARTITIONS_ON_STARTUP, false);
aclInMemoryCacheTTLMs = serverProperties.getInt(ACL_IN_MEMORY_CACHE_TTL_MS, -1); // acl caching is disabled by
// default
}

long extractIngestionMemoryLimit(
Expand Down Expand Up @@ -1801,4 +1805,8 @@ public long getMaxWaitAfterUnsubscribeMs() {
/**
 * @return the value read from {@code SERVER_DELETE_UNASSIGNED_PARTITIONS_ON_STARTUP} at construction time;
 *         {@code false} when the property is not set.
 */
public boolean isDeleteUnassignedPartitionsOnStartupEnabled() {
return deleteUnassignedPartitionsOnStartup;
}

/**
 * @return the value read from {@code ACL_IN_MEMORY_CACHE_TTL_MS} at construction time, in milliseconds;
 *         {@code -1} when the property is not set, which leaves ACL caching disabled.
 */
public int getAclInMemoryCacheTTLMs() {
return aclInMemoryCacheTTLMs;
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -2442,4 +2442,10 @@ private ConfigKeys() {
public static final String CONTROLLER_DEFERRED_VERSION_SWAP_SLEEP_MS = "controller.deferred.version.swap.sleep.ms";
public static final String CONTROLLER_DEFERRED_VERSION_SWAP_SERVICE_ENABLED =
"controller.deferred.version.swap.service.enabled";

/*
 * Both Router and Server maintain a per-connection in-memory cache of store-level ACL check results; this config
 * controls the TTL, in milliseconds, of each cache entry. A non-positive value disables the cache.
 */
public static final String ACL_IN_MEMORY_CACHE_TTL_MS = "acl.in.memory.cache.ttl.ms";
}
Original file line number Diff line number Diff line change
Expand Up @@ -7,14 +7,21 @@
import com.linkedin.venice.acl.DynamicAccessController;
import com.linkedin.venice.authorization.IdentityParser;
import com.linkedin.venice.common.VeniceSystemStoreUtils;
import com.linkedin.venice.listener.ServerHandlerUtils;
import com.linkedin.venice.meta.ReadOnlyStoreRepository;
import com.linkedin.venice.meta.Store;
import com.linkedin.venice.utils.NettyUtils;
import com.linkedin.venice.utils.SystemTime;
import com.linkedin.venice.utils.Time;
import com.linkedin.venice.utils.concurrent.VeniceConcurrentHashMap;
import io.netty.channel.Channel;
import io.netty.channel.ChannelHandler;
import io.netty.channel.ChannelHandlerContext;
import io.netty.channel.SimpleChannelInboundHandler;
import io.netty.handler.codec.http.HttpMethod;
import io.netty.handler.codec.http.HttpRequest;
import io.netty.handler.codec.http.HttpResponseStatus;
import io.netty.util.AttributeKey;
import io.netty.util.ReferenceCountUtil;
import java.net.URI;
import java.security.cert.X509Certificate;
Expand All @@ -29,7 +36,25 @@
*/
@ChannelHandler.Sharable
public abstract class AbstractStoreAclHandler<REQUEST_TYPE> extends SimpleChannelInboundHandler<HttpRequest> {
/**
 * Immutable record of one store-level ACL decision and the time at which it was computed.
 *
 * Entries are kept per store name in the connection-scoped cache map. When an entry expires it is replaced by a
 * new instance rather than updated in place, so both fields are final.
 */
private static final class CachedAcl {
  /** Outcome of the access check that produced this entry. */
  final AccessResult accessResult;
  /** Time in milliseconds, as reported by the handler's {@code Time} source, at which the check ran. */
  final long timestamp;

  public CachedAcl(AccessResult accessResult, long timestamp) {
    this.accessResult = accessResult;
    this.timestamp = timestamp;
  }
}

private static final Logger LOGGER = LogManager.getLogger(AbstractStoreAclHandler.class);
public static final String STORE_ACL_CHECK_RESULT = "STORE_ACL_CHECK_RESULT_ATTRIBUTE_KEY";
public static final AttributeKey<VeniceConcurrentHashMap<String, CachedAcl>> STORE_ACL_CHECK_RESULT_ATTRIBUTE_KEY =
AttributeKey.valueOf(STORE_ACL_CHECK_RESULT);
private static final byte[] BAD_REQUEST_RESPONSE = "Unexpected! Original channel should not be null".getBytes();

private final int cacheTTLMs;
private final Time time;
private final boolean aclCacheEnabled;

private final IdentityParser identityParser;
private final ReadOnlyStoreRepository metadataRepository;
Expand All @@ -38,12 +63,25 @@ public abstract class AbstractStoreAclHandler<REQUEST_TYPE> extends SimpleChanne
public AbstractStoreAclHandler(
IdentityParser identityParser,
DynamicAccessController accessController,
ReadOnlyStoreRepository metadataRepository) {
ReadOnlyStoreRepository metadataRepository,
int cacheTTLMs) {
this(identityParser, accessController, metadataRepository, cacheTTLMs, new SystemTime());
}

/**
 * Builds the handler with an explicit clock, which lets callers control cache expiry.
 *
 * @param identityParser parser stored for later use by this handler
 * @param accessController controller that is initialized with the names of all stores currently known to
 *        {@code metadataRepository}; the instance returned by {@code init} is the one this handler keeps
 * @param metadataRepository store repository; a listener is registered on it so store changes reach the
 *        access controller
 * @param cacheTTLMs lifetime of a cached ACL result in milliseconds; the cache is only enabled when this is
 *        strictly positive
 * @param time clock used to timestamp and expire cache entries
 */
public AbstractStoreAclHandler(
    IdentityParser identityParser,
    DynamicAccessController accessController,
    ReadOnlyStoreRepository metadataRepository,
    int cacheTTLMs,
    Time time) {
  // Plain state first.
  this.identityParser = identityParser;
  this.metadataRepository = metadataRepository;
  this.time = time;
  this.cacheTTLMs = cacheTTLMs;
  // Zero or negative TTL means "do not cache".
  this.aclCacheEnabled = cacheTTLMs > 0;

  // Seed the access controller with every store that already exists ...
  this.accessController = accessController
      .init(metadataRepository.getAllStores().stream().map(Store::getName).collect(Collectors.toList()));
  // ... and keep it informed about stores that appear or disappear later.
  metadataRepository.registerStoreDataChangedListener(new AclCreationDeletionListener(accessController));
}

/**
Expand All @@ -55,14 +93,18 @@ public AbstractStoreAclHandler(
*/
@Override
public void channelRead0(ChannelHandlerContext ctx, HttpRequest req) throws SSLPeerUnverifiedException {
if (isAccessAlreadyApproved(ctx)) {
Channel originalChannel = ServerHandlerUtils.getOriginalChannel(ctx);
if (originalChannel == null) {
NettyUtils.setupResponseAndFlush(HttpResponseStatus.BAD_REQUEST, BAD_REQUEST_RESPONSE, false, ctx);
return;
}
if (isAccessAlreadyApproved(originalChannel)) {
ReferenceCountUtil.retain(req);
ctx.fireChannelRead(req);
return;
}

String uri = req.uri();

// Parse resource type and store name
String[] requestParts = URI.create(uri).getPath().split("/");
REQUEST_TYPE requestType = validateRequest(requestParts);
Expand All @@ -89,9 +131,41 @@ public void channelRead0(ChannelHandlerContext ctx, HttpRequest req) throws SSLP
return;
}

X509Certificate clientCert = extractClientCert(ctx);
String method = req.method().name();
AccessResult accessResult = checkAccess(uri, clientCert, storeName, method);

if (!method.equals(HttpMethod.GET.name()) && !method.equals(HttpMethod.POST.name())) {
// Neither get nor post method, just let it pass
ReferenceCountUtil.retain(req);
ctx.fireChannelRead(req);
return;
}
AccessResult accessResult;
if (aclCacheEnabled) {
VeniceConcurrentHashMap<String, CachedAcl> storeAclCache =
originalChannel.attr(STORE_ACL_CHECK_RESULT_ATTRIBUTE_KEY).get();
if (storeAclCache == null) {
originalChannel.attr(STORE_ACL_CHECK_RESULT_ATTRIBUTE_KEY).setIfAbsent(new VeniceConcurrentHashMap<>());
storeAclCache = originalChannel.attr(STORE_ACL_CHECK_RESULT_ATTRIBUTE_KEY).get();
}

accessResult = storeAclCache.compute(storeName, (ignored, value) -> {
long currentTimestamp = time.getMilliseconds();
if (value == null || currentTimestamp - value.timestamp > cacheTTLMs) {
try {
return new CachedAcl(
checkAccess(uri, extractClientCert(ctx), storeName, HttpMethod.GET.name()),
currentTimestamp);
} catch (Exception e) {
LOGGER.error("Error while checking access", e);
return new CachedAcl(AccessResult.ERROR_FORBIDDEN, currentTimestamp);
}
} else {
return value;
}
}).accessResult;
} else {
accessResult = checkAccess(uri, extractClientCert(ctx), storeName, HttpMethod.GET.name());
}
switch (accessResult) {
case GRANTED:
ReferenceCountUtil.retain(req);
Expand All @@ -109,7 +183,7 @@ public void channelRead0(ChannelHandlerContext ctx, HttpRequest req) throws SSLP
}
}

protected boolean isAccessAlreadyApproved(ChannelHandlerContext ctx) {
protected boolean isAccessAlreadyApproved(Channel originalChannel) {
return false;
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@

import com.linkedin.venice.exceptions.VeniceException;
import com.linkedin.venice.utils.SslUtils;
import io.netty.channel.Channel;
import io.netty.channel.ChannelHandlerContext;
import io.netty.handler.ssl.SslHandler;
import java.security.cert.X509Certificate;
Expand All @@ -26,6 +27,27 @@ public static SslHandler extractSslHandler(ChannelHandlerContext ctx) {
return sslHandler;
}

/**
 * Finds the channel whose pipeline carries the {@link SslHandler}.
 *
 * @param ctx handler context of the request being processed
 * @return the context's own channel when its pipeline contains an {@link SslHandler} (http/1.x); otherwise the
 *         parent channel when that one's pipeline contains it (http/2); otherwise {@code null}
 */
public static Channel getOriginalChannel(ChannelHandlerContext ctx) {
  // http/1.x: the ssl handler sits in the current channel's pipeline.
  if (ctx.pipeline().get(SslHandler.class) != null) {
    return ctx.channel();
  }

  // http/2: the ssl handler sits in the parent channel's pipeline.
  Channel parentChannel = ctx.channel().parent();
  boolean parentHasSslHandler = parentChannel != null && parentChannel.pipeline().get(SslHandler.class) != null;
  return parentHasSslHandler ? parentChannel : null;
}

public static X509Certificate extractClientCert(ChannelHandlerContext ctx) throws SSLPeerUnverifiedException {
SslHandler sslHandler = ServerHandlerUtils.extractSslHandler(ctx);
if (sslHandler != null) {
Expand Down
Original file line number Diff line number Diff line change
@@ -1,7 +1,10 @@
package com.linkedin.venice.acl.handler;

import static com.linkedin.venice.acl.handler.AbstractStoreAclHandler.STORE_ACL_CHECK_RESULT_ATTRIBUTE_KEY;
import static org.mockito.Mockito.any;
import static org.mockito.Mockito.argThat;
import static org.mockito.Mockito.doAnswer;
import static org.mockito.Mockito.doReturn;
import static org.mockito.Mockito.mock;
import static org.mockito.Mockito.never;
import static org.mockito.Mockito.spy;
Expand All @@ -14,6 +17,9 @@
import com.linkedin.venice.common.VeniceSystemStoreUtils;
import com.linkedin.venice.helix.HelixReadOnlyStoreRepository;
import com.linkedin.venice.meta.Store;
import com.linkedin.venice.utils.TestMockTime;
import com.linkedin.venice.utils.Time;
import com.linkedin.venice.utils.concurrent.VeniceConcurrentHashMap;
import io.netty.channel.Channel;
import io.netty.channel.ChannelHandlerContext;
import io.netty.channel.ChannelPipeline;
Expand All @@ -22,6 +28,7 @@
import io.netty.handler.codec.http.HttpRequest;
import io.netty.handler.codec.http.HttpResponseStatus;
import io.netty.handler.ssl.SslHandler;
import io.netty.util.Attribute;
import java.net.SocketAddress;
import java.security.cert.Certificate;
import java.security.cert.X509Certificate;
Expand All @@ -33,12 +40,15 @@


public class AbstractStoreAclHandlerTest {
private static int CACHE_TTL_MS = 1000;
private IdentityParser identityParser;
private DynamicAccessController accessController;
private HelixReadOnlyStoreRepository metadataRepo;
private ChannelHandlerContext ctx;
private Channel channel;
private HttpRequest req;
private Store store;
private Time mockTime;
private boolean[] needsAcl = { true };
private boolean[] hasAccess = { false };
private boolean[] hasAcl = { false };
Expand Down Expand Up @@ -83,12 +93,18 @@ public void setUp() throws Exception {
when(sslSession.getPeerCertificates()).thenReturn(new Certificate[] { cert });

// Host
Channel channel = mock(Channel.class);
channel = mock(Channel.class);
when(ctx.channel()).thenReturn(channel);
SocketAddress address = mock(SocketAddress.class);
when(channel.remoteAddress()).thenReturn(address);

Attribute<VeniceConcurrentHashMap> aclCacheAttr = mock(Attribute.class);
doAnswer((ignored) -> new VeniceConcurrentHashMap<>()).when(aclCacheAttr).get();
doReturn(aclCacheAttr).when(channel).attr(STORE_ACL_CHECK_RESULT_ATTRIBUTE_KEY);

when(req.method()).thenReturn(HttpMethod.GET);

mockTime = new TestMockTime(0);
}

@Test
Expand All @@ -104,6 +120,38 @@ public void noAclNeeded() throws Exception {
verify(ctx, times(32)).fireChannelRead(req);
}

/**
 * Checks that store-level ACL results are reused while the cache entry is fresh, re-evaluated once the TTL has
 * elapsed, and evaluated on every request when the handler is built with a negative TTL.
 */
@Test
public void testAclCache() throws Exception {
hasAccess[0] = true;
hasStore[0] = true;
// Hand out one shared map for the channel attribute so entries written by one handler invocation are visible
// to the next; the stub installed in setUp returns a fresh map on every get().
Attribute<VeniceConcurrentHashMap> aclCacheAttr = mock(Attribute.class);
doReturn(new VeniceConcurrentHashMap<>()).when(aclCacheAttr).get();
doReturn(aclCacheAttr).when(channel).attr(STORE_ACL_CHECK_RESULT_ATTRIBUTE_KEY);
enumerate(hasAcl);
// NOTE(review): the flags are re-set before each pass, presumably because enumerate() mutates them - confirm.
hasAccess[0] = true;
hasStore[0] = true;
enumerate(hasAcl);
// Across both passes the access controller must have been consulted exactly once.
verify(accessController).hasAccess(any(), any(), any());
// Simulate cache expiration
mockTime.sleep(CACHE_TTL_MS + 1);
hasAccess[0] = true;
hasStore[0] = true;
enumerate(hasAcl);
// After the TTL has passed, exactly one additional lookup is expected.
verify(accessController, times(2)).hasAccess(any(), any(), any());

// Test cache disabled
// New metadataRepo mock and aclHandler every update since thenThrow cannot be re-mocked.
hasAccess[0] = true;
hasStore[0] = true;
metadataRepo = mock(HelixReadOnlyStoreRepository.class);
// A TTL of -1 builds a handler with caching turned off.
AbstractStoreAclHandler aclHandler =
spy(new MockStoreAclHandler(identityParser, accessController, metadataRepo, -1, mockTime));
update();
aclHandler.channelRead0(ctx, req);
aclHandler.channelRead0(ctx, req);
// Each of the two requests must reach the access controller: 2 earlier lookups + 2 uncached ones.
verify(accessController, times(4)).hasAccess(any(), any(), any());
}

@Test
public void accessGranted() throws Exception {
hasAccess[0] = true;
Expand Down Expand Up @@ -299,7 +347,8 @@ private void enumerate(boolean[]... conditions) throws Exception {
}
// New metadataRepo mock and aclHandler every update since thenThrow cannot be re-mocked.
metadataRepo = mock(HelixReadOnlyStoreRepository.class);
AbstractStoreAclHandler aclHandler = spy(new MockStoreAclHandler(identityParser, accessController, metadataRepo));
AbstractStoreAclHandler aclHandler =
spy(new MockStoreAclHandler(identityParser, accessController, metadataRepo, mockTime));
update();
aclHandler.channelRead0(ctx, req);
}
Expand Down Expand Up @@ -334,8 +383,18 @@ private static class MockStoreAclHandler extends AbstractStoreAclHandler<TestReq
public MockStoreAclHandler(
IdentityParser identityParser,
DynamicAccessController accessController,
HelixReadOnlyStoreRepository metadataRepository) {
super(identityParser, accessController, metadataRepository);
HelixReadOnlyStoreRepository metadataRepository,
Time mockTime) {
super(identityParser, accessController, metadataRepository, CACHE_TTL_MS, mockTime);
}

/**
 * Variant that lets a test pick the cache TTL itself (for example {@code -1} to run with caching disabled)
 * instead of using {@code CACHE_TTL_MS}; all arguments are forwarded unchanged to the superclass constructor.
 */
public MockStoreAclHandler(
IdentityParser identityParser,
DynamicAccessController accessController,
HelixReadOnlyStoreRepository metadataRepository,
int cacheTTLMs,
Time mockTime) {
super(identityParser, accessController, metadataRepository, cacheTTLMs, mockTime);
}

@Override
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -681,7 +681,11 @@ public boolean startInner() throws Exception {

RouterSslVerificationHandler routerSslVerificationHandler = new RouterSslVerificationHandler(securityStats);
RouterStoreAclHandler aclHandler = accessController.isPresent()
? new RouterStoreAclHandler(identityParser, accessController.get(), metadataRepository)
? new RouterStoreAclHandler(
identityParser,
accessController.get(),
metadataRepository,
config.getAclInMemoryCacheTTLMs())
: null;
final SslInitializer sslInitializer;
if (sslFactory.isPresent()) {
Expand Down
Loading

0 comments on commit b68eb16

Please sign in to comment.