diff --git a/driver-core/src/main/java/com/datastax/driver/core/Metrics.java b/driver-core/src/main/java/com/datastax/driver/core/Metrics.java index 6eacfdf0e28..7c9437319c4 100644 --- a/driver-core/src/main/java/com/datastax/driver/core/Metrics.java +++ b/driver-core/src/main/java/com/datastax/driver/core/Metrics.java @@ -130,7 +130,35 @@ public Integer getValue() { return value; } }); + private final Gauge>> perShardInflightRequestInfo = + registry.register( + "per-shard-inflight-request-info", + new Gauge>>() { + @Override + public Map> getValue() { + Map> result = new HashMap>(); + for (SessionManager session : manager.sessions) { + for (Map.Entry poolEntry : session.pools.entrySet()) { + HostConnectionPool hostConnectionPool = poolEntry.getValue(); + Map perShardInflightRequests = new HashMap(); + + for (int shardId = 0; + shardId < hostConnectionPool.connections.length; + shardId++) { + int shardInflightRequests = 0; + for (Connection connection : hostConnectionPool.connections[shardId]) { + shardInflightRequests += connection.inFlight.get(); + } + perShardInflightRequests.put(shardId, shardInflightRequests); + } + + result.put(poolEntry.getKey(), perShardInflightRequests); + } + } + return result; + } + }); private final Gauge executorQueueDepth; private final Gauge blockingExecutorQueueDepth; private final Gauge reconnectionSchedulerQueueSize; @@ -374,6 +402,10 @@ public Gauge> getShardAwarenessInfo() { return shardAwarenessInfo; } + public Gauge>> getPerShardInflightRequestInfo() { + return perShardInflightRequestInfo; + } + /** * Returns the number of bytes sent so far. * diff --git a/driver-core/src/main/java/com/datastax/driver/core/policies/RandomTwoChoicePolicy.java b/driver-core/src/main/java/com/datastax/driver/core/policies/RandomTwoChoicePolicy.java new file mode 100644 index 00000000000..a9830244be3 --- /dev/null +++ b/driver-core/src/main/java/com/datastax/driver/core/policies/RandomTwoChoicePolicy.java @@ -0,0 +1,157 @@ +package com.datastax.driver.core.policies; + +import com.datastax.driver.core.*; +import com.google.common.collect.AbstractIterator; +import java.nio.ByteBuffer; +import java.util.*; +import java.util.concurrent.atomic.AtomicInteger; + +/** + * A wrapper load balancing policy that adds "Power of 2 Choice" algorithm to a child policy. + * + *

This policy encapsulates another policy. The resulting policy works in the following way: + * + *

    + *
  • the {@code distance} method is inherited from the child policy. + *
  • the {@code newQueryPlan} method will compare first two hosts (by number of inflight + * requests) returned from the {@code newQueryPlan} method of the child policy, and the host + * with fewer number of inflight requests will be returned the first. It will allow to always + * avoid the worst option (comparing by number of inflight requests). + *
  • besides the first two hosts returned by the child policy's {@code newQueryPlan} method, the + * ordering of the rest of the hosts will remain the same. + *
+ * + *

If you wrap the {@code RandomTwoChoicePolicy} policy with {@code TokenAwarePolicy}, it will + * compare the first two replicas by the number of inflight requests, and the worse option will + * always be avoided. In that case, it is recommended to use the TokenAwarePolicy with {@code + * ReplicaOrdering.RANDOM strategy}, which will return the replicas in a shuffled order and thus + * will make the "Power of 2 Choice" algorithm more efficient. + */ +public class RandomTwoChoicePolicy implements ChainableLoadBalancingPolicy { + private final LoadBalancingPolicy childPolicy; + private volatile Metrics metrics; + private volatile Metadata clusterMetadata; + private volatile ProtocolVersion protocolVersion; + private volatile CodecRegistry codecRegistry; + + /** + * Creates a new {@code RandomTwoChoicePolicy}. + * + * @param childPolicy the load balancing policy to wrap with "Power of 2 Choice" algorithm. + */ + public RandomTwoChoicePolicy(LoadBalancingPolicy childPolicy) { + this.childPolicy = childPolicy; + } + + @Override + public LoadBalancingPolicy getChildPolicy() { + return childPolicy; + } + + @Override + public void init(Cluster cluster, Collection hosts) { + this.metrics = cluster.getMetrics(); + this.clusterMetadata = cluster.getMetadata(); + this.protocolVersion = cluster.getConfiguration().getProtocolOptions().getProtocolVersion(); + this.codecRegistry = cluster.getConfiguration().getCodecRegistry(); + childPolicy.init(cluster, hosts); + } + + /** + * {@inheritDoc} + * + *

This implementation always returns distances as reported by the wrapped policy. + */ + @Override + public HostDistance distance(Host host) { + return childPolicy.distance(host); + } + + /** + * {@inheritDoc} + * + *

The returned plan will compare (by the number of inflight requests) the first 2 hosts + * returned by the child policy's {@code newQueryPlan} method, and the host with fewer inflight + * requests will be returned the first. The rest of the child policy's query plan will be left + * intact. + */ + @Override + public Iterator newQueryPlan(String loggedKeyspace, Statement statement) { + String keyspace = statement.getKeyspace(); + if (keyspace == null) keyspace = loggedKeyspace; + + ByteBuffer routingKey = statement.getRoutingKey(protocolVersion, codecRegistry); + if (routingKey == null || keyspace == null) { + return childPolicy.newQueryPlan(loggedKeyspace, statement); + } + + final Token t = clusterMetadata.newToken(statement.getPartitioner(), routingKey); + final Iterator childIterator = childPolicy.newQueryPlan(keyspace, statement); + + final Host host1 = childIterator.hasNext() ? childIterator.next() : null; + final Host host2 = childIterator.hasNext() ? childIterator.next() : null; + + final AtomicInteger host1ShardInflightRequests = new AtomicInteger(0); + final AtomicInteger host2ShardInflightRequests = new AtomicInteger(0); + + if (host1 != null) { + final int host1ShardId = host1.getShardingInfo().shardId(t); + host1ShardInflightRequests.set( + metrics.getPerShardInflightRequestInfo().getValue().get(host1).get(host1ShardId)); + } + + if (host2 != null) { + final int host2ShardId = host2.getShardingInfo().shardId(t); + host2ShardInflightRequests.set( + metrics.getPerShardInflightRequestInfo().getValue().get(host2).get(host2ShardId)); + } + + return new AbstractIterator() { + private final Host firstChosenHost = + host1ShardInflightRequests.get() < host2ShardInflightRequests.get() ? host1 : host2; + private final Host secondChosenHost = + host1ShardInflightRequests.get() < host2ShardInflightRequests.get() ? host2 : host1; + private int index = 0; + + @Override + protected Host computeNext() { + if (index == 0) { + index++; + return firstChosenHost; + } else if (index == 1) { + index++; + return secondChosenHost; + } else if (childIterator.hasNext()) { + return childIterator.next(); + } + + return endOfData(); + } + }; + } + + @Override + public void onAdd(Host host) { + childPolicy.onAdd(host); + } + + @Override + public void onUp(Host host) { + childPolicy.onUp(host); + } + + @Override + public void onDown(Host host) { + childPolicy.onDown(host); + } + + @Override + public void onRemove(Host host) { + childPolicy.onRemove(host); + } + + @Override + public void close() { + childPolicy.close(); + } +} diff --git a/driver-core/src/test/java/com/datastax/driver/core/policies/RandomTwoChoicePolicyTest.java b/driver-core/src/test/java/com/datastax/driver/core/policies/RandomTwoChoicePolicyTest.java new file mode 100644 index 00000000000..c895c36ef28 --- /dev/null +++ b/driver-core/src/test/java/com/datastax/driver/core/policies/RandomTwoChoicePolicyTest.java @@ -0,0 +1,110 @@ +package com.datastax.driver.core.policies; + +import static org.assertj.core.api.Assertions.assertThat; +import static org.mockito.Mockito.mock; +import static org.mockito.Mockito.when; + +import com.codahale.metrics.Gauge; +import com.datastax.driver.core.*; +import java.nio.ByteBuffer; +import java.util.*; +import org.assertj.core.util.Sets; +import org.testng.annotations.BeforeMethod; +import org.testng.annotations.Test; + +public class RandomTwoChoicePolicyTest { + private final ByteBuffer routingKey = ByteBuffer.wrap(new byte[] {1, 2, 3, 4}); + private final RegularStatement statement = + new SimpleStatement("irrelevant").setRoutingKey(routingKey).setKeyspace("keyspace"); + private final Host host1 = mock(Host.class); + private final Host host2 = mock(Host.class); + private final Host host3 = mock(Host.class); + private Cluster cluster; + + @SuppressWarnings("unchecked") + private final Gauge>> gauge = + mock((Class>>>) (Object) Gauge.class); + + @BeforeMethod(groups = "unit") + public void initMocks() { + CodecRegistry codecRegistry = new CodecRegistry(); + cluster = mock(Cluster.class); + Configuration configuration = mock(Configuration.class); + ProtocolOptions protocolOptions = mock(ProtocolOptions.class); + Metadata metadata = mock(Metadata.class); + Metrics metrics = mock(Metrics.class); + Token t = mock(Token.class); + ShardingInfo shardingInfo = mock(ShardingInfo.class); + + when(metrics.getPerShardInflightRequestInfo()).thenReturn(gauge); + when(cluster.getConfiguration()).thenReturn(configuration); + when(configuration.getCodecRegistry()).thenReturn(codecRegistry); + when(configuration.getProtocolOptions()).thenReturn(protocolOptions); + when(protocolOptions.getProtocolVersion()).thenReturn(ProtocolVersion.DEFAULT); + when(cluster.getMetadata()).thenReturn(metadata); + when(cluster.getMetrics()).thenReturn(metrics); + when(metadata.getReplicas(Metadata.quote("keyspace"), null, routingKey)) + .thenReturn(Sets.newLinkedHashSet(host1, host2, host3)); + when(metadata.newToken(null, routingKey)).thenReturn(t); + when(host1.getShardingInfo()).thenReturn(shardingInfo); + when(host2.getShardingInfo()).thenReturn(shardingInfo); + when(host3.getShardingInfo()).thenReturn(shardingInfo); + when(shardingInfo.shardId(t)).thenReturn(0); + when(host1.isUp()).thenReturn(true); + when(host2.isUp()).thenReturn(true); + when(host3.isUp()).thenReturn(true); + } + + @Test(groups = "unit") + public void should_prefer_host_with_less_inflight_requests() { + // given + Map> perHostInflightRequests = + new HashMap>() { + { + put( + host1, + new HashMap() { + { + put(0, 6); + } + }); + put( + host2, + new HashMap() { + { + put(0, 2); + } + }); + put( + host3, + new HashMap() { + { + put(0, 4); + } + }); + } + }; + RandomTwoChoicePolicy policy = + new RandomTwoChoicePolicy( + new TokenAwarePolicy( + new RoundRobinPolicy(), TokenAwarePolicy.ReplicaOrdering.TOPOLOGICAL)); + policy.init( + cluster, + new ArrayList() { + + { + add(host1); + add(host2); + add(host3); + } + }); + when(gauge.getValue()).thenReturn(perHostInflightRequests); + + Iterator queryPlan = policy.newQueryPlan("keyspace", statement); + // host2 should appear first in the query plan with fewer inflight requests than host1 + + assertThat(queryPlan.next()).isEqualTo(host2); + assertThat(queryPlan.next()).isEqualTo(host1); + assertThat(queryPlan.next()).isEqualTo(host3); + } +}