Skip to content

Commit

Permalink
Introduce WeightedRandomDistributionSelector (xDS-endpoint pt 1) (#…
Browse files Browse the repository at this point in the history
…5501)

Motivation:

`EdfLoadBalancer` is a basic load balancer which distributes entries based on weight using [Earliest Deadline First Scheduling](https://en.wikipedia.org/wiki/Earliest_deadline_first_scheduling).

ref: https://github.com/envoyproxy/envoy/blob/b7818b0df716af47ec22982c5a1cbbace5f2ae15/source/common/upstream/load_balancer_impl.h#L508

This is used not only for request load balancing, but also load balancing based on `Locality`

ref: https://github.com/envoyproxy/envoy/blob/b7818b0df716af47ec22982c5a1cbbace5f2ae15/source/common/upstream/upstream_impl.cc#L625

Within our codebase, it seems like we could use `WeightedRandomDistributionEndpointSelector` easily for this purpose. I propose that we extract a `WeightedRandomDistributionSelector` and make it publicly available.

See the following for a sample use-case in #5450
ref: https://github.com/line/armeria/blob/bd4968fda63089b2c309583b21201dff1a1119bf/xds/src/main/java/com/linecorp/armeria/xds/client/endpoint/HostSet.java#L38-L39

POC: #5450

Modifications:

- Extract `WeightedRandomDistributionSelector` and move it to the `internal` package.

Result:

- `WeightedRandomDistributionSelector` is now ready to be reused.

<!--
Visit this URL to learn more about how to write a pull request description:
https://armeria.dev/community/developer-guide#how-to-write-pull-request-description
-->
  • Loading branch information
jrhee17 authored Mar 22, 2024
1 parent 4dd8c84 commit c6b8906
Show file tree
Hide file tree
Showing 2 changed files with 146 additions and 93 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -15,103 +15,40 @@
*/
package com.linecorp.armeria.client.endpoint;

import java.util.ArrayList;
import java.util.Iterator;
import java.util.List;
import java.util.concurrent.ThreadLocalRandom;
import java.util.concurrent.locks.ReentrantLock;

import com.google.common.annotations.VisibleForTesting;
import com.google.common.collect.ImmutableList;
import com.google.errorprone.annotations.concurrent.GuardedBy;

import com.linecorp.armeria.client.Endpoint;
import com.linecorp.armeria.client.endpoint.WeightedRandomDistributionEndpointSelector.Entry;
import com.linecorp.armeria.common.annotation.Nullable;
import com.linecorp.armeria.internal.common.util.ReentrantShortLock;
import com.linecorp.armeria.internal.client.endpoint.WeightedRandomDistributionSelector;

/**
* This selector selects an {@link Endpoint} using random and the weight of the {@link Endpoint}. If there are
* A(weight 10), B(weight 4) and C(weight 6) {@link Endpoint}s, the chances that {@link Endpoint}s are selected
* are 10/20, 4/20 and 6/20, respectively. If {@link Endpoint} A is selected 10 times and B and C are not
* selected as much as their weight, then A is removed temporarily and the chances that B and C are selected
* are 4/10 and 6/10.
*/
final class WeightedRandomDistributionEndpointSelector {

private final ReentrantLock lock = new ReentrantShortLock();
private final List<Entry> allEntries;
@GuardedBy("lock")
private final List<Entry> currentEntries;
private final long total;
private long remaining;
final class WeightedRandomDistributionEndpointSelector
extends WeightedRandomDistributionSelector<Entry> {

WeightedRandomDistributionEndpointSelector(List<Endpoint> endpoints) {
final ImmutableList.Builder<Entry> builder = ImmutableList.builderWithExpectedSize(endpoints.size());

long total = 0;
for (Endpoint endpoint : endpoints) {
if (endpoint.weight() <= 0) {
continue;
}
builder.add(new Entry(endpoint));
total += endpoint.weight();
}
this.total = total;
remaining = total;
allEntries = builder.build();
currentEntries = new ArrayList<>(allEntries);
super(mapEndpoints(endpoints));
}

@VisibleForTesting
List<Entry> entries() {
return allEntries;
private static List<Entry> mapEndpoints(List<Endpoint> endpoints) {
return endpoints.stream().map(Entry::new).collect(ImmutableList.toImmutableList());
}

@Nullable
Endpoint selectEndpoint() {
if (allEntries.isEmpty()) {
final Entry entry = select();
if (entry == null) {
return null;
}

final ThreadLocalRandom threadLocalRandom = ThreadLocalRandom.current();
lock.lock();
try {
long target = threadLocalRandom.nextLong(remaining);
final Iterator<Entry> it = currentEntries.iterator();
while (it.hasNext()) {
final Entry entry = it.next();
final int weight = entry.weight();
target -= weight;
if (target < 0) {
entry.increment();
if (entry.isFull()) {
it.remove();
entry.reset();
remaining -= weight;
if (remaining == 0) {
// As all entries are full, reset `currentEntries` and `remaining`.
currentEntries.addAll(allEntries);
remaining = total;
} else {
assert remaining > 0 : remaining;
}
}
return entry.endpoint();
}
}
} finally {
lock.unlock();
}

// Since `allEntries` is not empty, should select one Endpoint from `allEntries`.
throw new Error("Should never reach here");
return entry.endpoint();
}

@VisibleForTesting
static final class Entry {
static final class Entry extends AbstractEntry {

private final Endpoint endpoint;
private int counter;

Entry(Endpoint endpoint) {
this.endpoint = endpoint;
Expand All @@ -121,26 +58,9 @@ Endpoint endpoint() {
return endpoint;
}

void increment() {
assert counter < endpoint().weight();
counter++;
}

int weight() {
@Override
public int weight() {
return endpoint().weight();
}

void reset() {
counter = 0;
}

@VisibleForTesting
int counter() {
return counter;
}

boolean isFull() {
return counter >= endpoint.weight();
}
}
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,133 @@
/*
* Copyright 2020 LINE Corporation
*
* LINE Corporation licenses this file to you under the Apache License,
* version 2.0 (the "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at:
*
* https://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
* WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
* License for the specific language governing permissions and limitations
* under the License.
*/
package com.linecorp.armeria.internal.client.endpoint;

import java.util.ArrayList;
import java.util.Iterator;
import java.util.List;
import java.util.concurrent.ThreadLocalRandom;
import java.util.concurrent.locks.ReentrantLock;

import com.google.common.annotations.VisibleForTesting;
import com.google.common.collect.ImmutableList;
import com.google.errorprone.annotations.concurrent.GuardedBy;

import com.linecorp.armeria.common.annotation.Nullable;
import com.linecorp.armeria.internal.client.endpoint.WeightedRandomDistributionSelector.AbstractEntry;
import com.linecorp.armeria.internal.common.util.ReentrantShortLock;

/**
* This selector selects an {@link AbstractEntry} using random and the weight of the {@link AbstractEntry}.
* If there are A(weight 10), B(weight 4) and C(weight 6) {@link AbstractEntry}s, the chances that
* {@link AbstractEntry}s are selected are 10/20, 4/20 and 6/20, respectively. If {@link AbstractEntry}
* A is selected 10 times and B and C are not selected as much as their weight, then A is removed temporarily
* and the chances that B and C are selected are 4/10 and 6/10.
*/
public class WeightedRandomDistributionSelector<T extends AbstractEntry> {

private final ReentrantLock lock = new ReentrantShortLock();
private final List<T> allEntries;
@GuardedBy("lock")
private final List<T> currentEntries;
private final long total;
private long remaining;

public WeightedRandomDistributionSelector(List<T> endpoints) {
final ImmutableList.Builder<T> builder = ImmutableList.builderWithExpectedSize(endpoints.size());

long total = 0;
for (T entry : endpoints) {
if (entry.weight() <= 0) {
continue;
}
builder.add(entry);
total += entry.weight();
}
this.total = total;
remaining = total;
allEntries = builder.build();
currentEntries = new ArrayList<>(allEntries);
}

@VisibleForTesting
public List<T> entries() {
return allEntries;
}

@Nullable
public T select() {
if (allEntries.isEmpty()) {
return null;
}

final ThreadLocalRandom threadLocalRandom = ThreadLocalRandom.current();
lock.lock();
try {
long target = threadLocalRandom.nextLong(remaining);
final Iterator<T> it = currentEntries.iterator();
while (it.hasNext()) {
final T entry = it.next();
final int weight = entry.weight();
target -= weight;
if (target < 0) {
entry.increment();
if (entry.isFull()) {
it.remove();
entry.reset();
remaining -= weight;
if (remaining == 0) {
// As all entries are full, reset `currentEntries` and `remaining`.
currentEntries.addAll(allEntries);
remaining = total;
} else {
assert remaining > 0 : remaining;
}
}
return entry;
}
}
} finally {
lock.unlock();
}

// Since `allEntries` is not empty, should subselect one Endpoint from `allEntries`.
throw new Error("Should never reach here");
}

public abstract static class AbstractEntry {

private int counter;

public final void increment() {
assert counter < weight();
counter++;
}

public abstract int weight();

public final void reset() {
counter = 0;
}

public final int counter() {
return counter;
}

public final boolean isFull() {
return counter >= weight();
}
}
}

0 comments on commit c6b8906

Please sign in to comment.