Skip to content

Commit 360737b

Browse files
prwhelanymao1
andauthored
Fixing bug with bedrock client caching (#118177) (#125840)
* Fixing bug with bedrock client caching * Update docs/changelog/118177.yaml * PR feedback Co-authored-by: Ying Mao <[email protected]>
1 parent 1fc6e46 commit 360737b

File tree

5 files changed

+65
-12
lines changed

5 files changed

+65
-12
lines changed

docs/changelog/118177.yaml

+6
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,6 @@
1+
pr: 118177
2+
summary: Fixing bedrock event executor terminated cache issue
3+
area: Machine Learning
4+
type: bug
5+
issues:
6+
- 117916

x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/amazonbedrock/AmazonBedrockBaseClient.java

+7
Original file line numberDiff line numberDiff line change
@@ -12,11 +12,13 @@
1212
import org.elasticsearch.xpack.inference.services.amazonbedrock.AmazonBedrockModel;
1313

1414
import java.time.Clock;
15+
import java.time.Instant;
1516
import java.util.Objects;
1617

1718
public abstract class AmazonBedrockBaseClient implements AmazonBedrockClient {
1819
protected final Integer modelKeysAndRegionHashcode;
1920
protected Clock clock = Clock.systemUTC();
21+
protected volatile Instant expiryTimestamp;
2022

2123
protected AmazonBedrockBaseClient(AmazonBedrockModel model, @Nullable TimeValue timeout) {
2224
Objects.requireNonNull(model);
@@ -33,5 +35,10 @@ public final void setClock(Clock clock) {
3335
this.clock = clock;
3436
}
3537

38+
// used for testing
39+
Instant getExpiryTimestamp() {
40+
return this.expiryTimestamp;
41+
}
42+
3643
abstract void close();
3744
}

x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/amazonbedrock/AmazonBedrockInferenceClient.java

-1
Original file line numberDiff line numberDiff line change
@@ -61,7 +61,6 @@ public class AmazonBedrockInferenceClient extends AmazonBedrockBaseClient {
6161

6262
private final BedrockRuntimeAsyncClient internalClient;
6363
private final ThreadPool threadPool;
64-
private volatile Instant expiryTimestamp;
6564

6665
public static AmazonBedrockBaseClient create(AmazonBedrockModel model, @Nullable TimeValue timeout, ThreadPool threadPool) {
6766
try {

x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/amazonbedrock/AmazonBedrockInferenceClientCache.java

+11-7
Original file line numberDiff line numberDiff line change
@@ -35,20 +35,24 @@ public AmazonBedrockInferenceClientCache(BiFunction<AmazonBedrockModel, TimeValu
3535
}
3636

3737
public AmazonBedrockBaseClient getOrCreateClient(AmazonBedrockModel model, @Nullable TimeValue timeout) {
38-
var returnClient = internalGetOrCreateClient(model, timeout);
3938
flushExpiredClients();
40-
return returnClient;
39+
return internalGetOrCreateClient(model, timeout);
4140
}
4241

4342
private AmazonBedrockBaseClient internalGetOrCreateClient(AmazonBedrockModel model, @Nullable TimeValue timeout) {
4443
final Integer modelHash = AmazonBedrockInferenceClient.getModelKeysAndRegionHashcode(model, timeout);
4544
cacheLock.readLock().lock();
4645
try {
47-
return clientsCache.computeIfAbsent(modelHash, hashKey -> {
48-
final AmazonBedrockBaseClient builtClient = creator.apply(model, timeout);
49-
builtClient.setClock(clock);
50-
builtClient.resetExpiration();
51-
return builtClient;
46+
return clientsCache.compute(modelHash, (hashKey, client) -> {
47+
AmazonBedrockBaseClient clientToUse = client;
48+
if (clientToUse == null) {
49+
clientToUse = creator.apply(model, timeout);
50+
}
51+
52+
// for testing - would be nice to refactor client factory in the future to take clock as parameter
53+
clientToUse.setClock(clock);
54+
clientToUse.resetExpiration();
55+
return clientToUse;
5256
});
5357
} finally {
5458
cacheLock.readLock().unlock();

x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/external/amazonbedrock/AmazonBedrockInferenceClientCacheTests.java

+41-4
Original file line numberDiff line numberDiff line change
@@ -60,6 +60,36 @@ public void testCache_ReturnsSameObject() throws IOException {
6060
assertThat(cacheInstance.clientCount(), is(0));
6161
}
6262

63+
public void testCache_ItUpdatesExpirationForExistingClients() throws IOException {
64+
var clock = Clock.fixed(Instant.now(), ZoneId.systemDefault());
65+
AmazonBedrockInferenceClientCache cacheInstance;
66+
try (var cache = new AmazonBedrockInferenceClientCache(AmazonBedrockMockInferenceClient::create, clock)) {
67+
cacheInstance = cache;
68+
69+
var model = AmazonBedrockEmbeddingsModelTests.createModel(
70+
"inferenceId",
71+
"testregion",
72+
"model",
73+
AmazonBedrockProvider.AMAZONTITAN,
74+
"access_key",
75+
"secret_key"
76+
);
77+
78+
var client = cache.getOrCreateClient(model, null);
79+
var expiryTimestamp = client.getExpiryTimestamp();
80+
assertThat(cache.clientCount(), is(1));
81+
82+
// set clock to clock + 1 minutes so cache hasn't expired
83+
cache.setClock(Clock.fixed(clock.instant().plus(Duration.ofMinutes(1)), ZoneId.systemDefault()));
84+
85+
var regetClient = cache.getOrCreateClient(model, null);
86+
87+
assertThat(client, sameInstance(regetClient));
88+
assertNotEquals(expiryTimestamp, regetClient.getExpiryTimestamp());
89+
}
90+
assertThat(cacheInstance.clientCount(), is(0));
91+
}
92+
6393
public void testCache_ItEvictsExpiredClients() throws IOException {
6494
var clock = Clock.fixed(Instant.now(), ZoneId.systemDefault());
6595
AmazonBedrockInferenceClientCache cacheInstance;
@@ -76,6 +106,10 @@ public void testCache_ItEvictsExpiredClients() throws IOException {
76106
);
77107

78108
var client = cache.getOrCreateClient(model, null);
109+
assertThat(cache.clientCount(), is(1));
110+
111+
// set clock to clock + 2 minutes
112+
cache.setClock(Clock.fixed(clock.instant().plus(Duration.ofMinutes(2)), ZoneId.systemDefault()));
79113

80114
var secondModel = AmazonBedrockEmbeddingsModelTests.createModel(
81115
"inferenceId_two",
@@ -86,22 +120,25 @@ public void testCache_ItEvictsExpiredClients() throws IOException {
86120
"other_secret_key"
87121
);
88122

89-
assertThat(cache.clientCount(), is(1));
90-
91123
var secondClient = cache.getOrCreateClient(secondModel, null);
92124
assertThat(client, not(sameInstance(secondClient)));
93125

94126
assertThat(cache.clientCount(), is(2));
95127

96-
// set clock to after expiry
128+
// set clock to after expiry of first client but not after expiry of second client
97129
cache.setClock(Clock.fixed(clock.instant().plus(Duration.ofMinutes(CLIENT_CACHE_EXPIRY_MINUTES + 1)), ZoneId.systemDefault()));
98130

99-
// get another client, this will ensure flushExpiredClients is called
131+
// retrieve the second client, this will ensure flushExpiredClients is called
100132
var regetSecondClient = cache.getOrCreateClient(secondModel, null);
101133
assertThat(secondClient, sameInstance(regetSecondClient));
102134

135+
// expired first client should have been flushed
136+
assertThat(cache.clientCount(), is(1));
137+
103138
var regetFirstClient = cache.getOrCreateClient(model, null);
104139
assertThat(client, not(sameInstance(regetFirstClient)));
140+
141+
assertThat(cache.clientCount(), is(2));
105142
}
106143
assertThat(cacheInstance.clientCount(), is(0));
107144
}

0 commit comments

Comments
 (0)