Skip to content

Commit 0aa383b

Browse files
wwang500DonalEvans
andauthored
[ML] Correctly handle empty inputs in chunkedInfer() (#138632) (#138782)
* [ML] Correctly handle empty inputs in chunkedInfer() (#138632) - Add method to allow services that implement SenderService to indicate whether they support chunked inference - Return immediately if the input list is empty for services that support chunked inference - Throw exception if the input list is empty for services that do not support chunked inference, to maintain existing behaviour - Add tests for all services that implement doChunkedInfer() - Update DeepSeekServiceTests for new error message (cherry picked from commit f70dbb8) * fix JinaAI tests due to the conflict --------- Co-authored-by: Donal Evans <[email protected]>
1 parent 2c7856e commit 0aa383b

File tree

28 files changed

+558
-8
lines changed

28 files changed

+558
-8
lines changed

docs/changelog/138632.yaml

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,5 @@
1+
pr: 138632
2+
summary: Correctly handle empty inputs in `chunkedInfer()`
3+
area: Machine Learning
4+
type: bug
5+
issues: []

x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/SenderService.java

Lines changed: 16 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -149,9 +149,18 @@ public void chunkedInfer(
149149
if (validationException.validationErrors().isEmpty() == false) {
150150
throw validationException;
151151
}
152-
153-
// a non-null query is not supported and is dropped by all providers
154-
doChunkedInfer(model, input, taskSettings, inputType, timeout, chunkedInferListener);
152+
if (supportsChunkedInfer()) {
153+
if (input.isEmpty()) {
154+
chunkedInferListener.onResponse(List.of());
155+
} else {
156+
// a non-null query is not supported and is dropped by all providers
157+
doChunkedInfer(model, input, taskSettings, inputType, timeout, chunkedInferListener);
158+
}
159+
} else {
160+
chunkedInferListener.onFailure(
161+
new UnsupportedOperationException(Strings.format("%s service does not support chunked inference", name()))
162+
);
163+
}
155164
}).addListener(listener);
156165
}
157166

@@ -183,6 +192,10 @@ protected abstract void doChunkedInfer(
183192
ActionListener<List<ChunkedInference>> listener
184193
);
185194

195+
protected boolean supportsChunkedInfer() {
196+
return true;
197+
}
198+
186199
public void start(Model model, ActionListener<Boolean> listener) {
187200
SubscribableListener.newForked(this::init)
188201
.<Boolean>andThen((doStartListener) -> doStart(model, doStartListener))

x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/ai21/Ai21Service.java

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -147,9 +147,15 @@ protected void doChunkedInfer(
147147
TimeValue timeout,
148148
ActionListener<List<ChunkedInference>> listener
149149
) {
150+
// Should never be called
150151
throw new UnsupportedOperationException("AI21 service does not support chunked inference");
151152
}
152153

154+
@Override
155+
protected boolean supportsChunkedInfer() {
156+
return false;
157+
}
158+
153159
@Override
154160
public InferenceServiceConfiguration getConfiguration() {
155161
return Configuration.get();

x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/anthropic/AnthropicService.java

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -218,9 +218,15 @@ protected void doChunkedInfer(
218218
TimeValue timeout,
219219
ActionListener<List<ChunkedInference>> listener
220220
) {
221+
// Should never be called
221222
throw new UnsupportedOperationException("Anthropic service does not support chunked inference");
222223
}
223224

225+
@Override
226+
protected boolean supportsChunkedInfer() {
227+
return false;
228+
}
229+
224230
@Override
225231
public TransportVersion getMinimalSupportedVersion() {
226232
return TransportVersions.V_8_15_0;

x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/contextualai/ContextualAiService.java

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -185,9 +185,15 @@ protected void doChunkedInfer(
185185
TimeValue timeout,
186186
ActionListener<List<ChunkedInference>> listener
187187
) {
188+
// Should never be called
188189
listener.onFailure(new ElasticsearchStatusException("Chunked inference is not supported for rerank task", RestStatus.BAD_REQUEST));
189190
}
190191

192+
@Override
193+
protected boolean supportsChunkedInfer() {
194+
return false;
195+
}
196+
191197
@Override
192198
protected void doUnifiedCompletionInfer(
193199
Model model,

x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/deepseek/DeepSeekService.java

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -122,9 +122,15 @@ protected void doChunkedInfer(
122122
TimeValue timeout,
123123
ActionListener<List<ChunkedInference>> listener
124124
) {
125+
// Should never be called
125126
listener.onFailure(new UnsupportedOperationException(Strings.format("The %s service only supports unified completion", NAME)));
126127
}
127128

129+
@Override
130+
protected boolean supportsChunkedInfer() {
131+
return false;
132+
}
133+
128134
@Override
129135
public String name() {
130136
return NAME;

x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/sagemaker/SageMakerService.java

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -271,6 +271,9 @@ public void chunkedInfer(
271271
listener.onFailure(createInvalidModelException(model));
272272
return;
273273
}
274+
if (input.isEmpty()) {
275+
listener.onResponse(List.of());
276+
}
274277
try {
275278
var sageMakerModel = ((SageMakerModel) model).override(taskSettings);
276279
var batchedRequests = new EmbeddingRequestChunker<>(

x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/alibabacloudsearch/AlibabaCloudSearchServiceTests.java

Lines changed: 22 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -71,6 +71,7 @@
7171
import static org.elasticsearch.xpack.inference.services.ServiceComponentsTests.createWithEmptySettings;
7272
import static org.elasticsearch.xpack.inference.services.settings.DefaultSecretSettingsTests.getSecretSettingsMap;
7373
import static org.hamcrest.CoreMatchers.is;
74+
import static org.hamcrest.Matchers.empty;
7475
import static org.hamcrest.Matchers.hasSize;
7576
import static org.hamcrest.Matchers.instanceOf;
7677
import static org.mockito.Mockito.mock;
@@ -491,6 +492,27 @@ public void testChunkedInfer_SparseEmbeddingChunkingSettingsNotSet() throws IOEx
491492
testChunkedInfer(TaskType.SPARSE_EMBEDDING, null);
492493
}
493494

495+
public void testChunkedInfer_noInputs() throws IOException {
496+
var senderFactory = HttpRequestSenderTests.createSenderFactory(threadPool, clientManager);
497+
498+
PlainActionFuture<List<ChunkedInference>> listener = new PlainActionFuture<>();
499+
try (var service = new AlibabaCloudSearchService(senderFactory, createWithEmptySettings(threadPool), mockClusterServiceEmpty())) {
500+
var model = createModelForTaskType(randomFrom(TaskType.SPARSE_EMBEDDING, TaskType.TEXT_EMBEDDING), null);
501+
502+
service.chunkedInfer(
503+
model,
504+
null,
505+
List.of(),
506+
new HashMap<>(),
507+
InputTypeTests.randomWithIngestAndSearch(),
508+
InferenceAction.Request.DEFAULT_TIMEOUT,
509+
listener
510+
);
511+
512+
}
513+
assertThat(listener.actionGet(TIMEOUT), empty());
514+
}
515+
494516
private void testChunkedInfer(TaskType taskType, ChunkingSettings chunkingSettings) throws IOException {
495517
var input = List.of(new ChunkInferenceInput("foo"), new ChunkInferenceInput("bar"));
496518

x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/amazonbedrock/AmazonBedrockServiceTests.java

Lines changed: 45 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -85,6 +85,7 @@
8585
import static org.elasticsearch.xpack.inference.services.amazonbedrock.embeddings.AmazonBedrockEmbeddingsServiceSettingsTests.createEmbeddingsRequestSettingsMap;
8686
import static org.hamcrest.CoreMatchers.is;
8787
import static org.hamcrest.Matchers.containsString;
88+
import static org.hamcrest.Matchers.empty;
8889
import static org.hamcrest.Matchers.hasSize;
8990
import static org.hamcrest.Matchers.instanceOf;
9091
import static org.mockito.ArgumentMatchers.any;
@@ -1323,6 +1324,50 @@ public void testChunkedInfer_ChunkingSettingsNotSet() throws IOException {
13231324
testChunkedInfer(model);
13241325
}
13251326

1327+
public void testChunkedInfer_noInputs() throws IOException {
1328+
var model = AmazonBedrockEmbeddingsModelTests.createModel(
1329+
"id",
1330+
"region",
1331+
"model",
1332+
AmazonBedrockProvider.AMAZONTITAN,
1333+
null,
1334+
"access",
1335+
"secret"
1336+
);
1337+
1338+
var sender = createMockSender();
1339+
var factory = mock(HttpRequestSender.Factory.class);
1340+
when(factory.createSender()).thenReturn(sender);
1341+
1342+
var amazonBedrockFactory = new AmazonBedrockMockRequestSender.Factory(
1343+
ServiceComponentsTests.createWithSettings(threadPool, Settings.EMPTY),
1344+
mockClusterServiceEmpty()
1345+
);
1346+
1347+
try (
1348+
var service = new AmazonBedrockService(
1349+
factory,
1350+
amazonBedrockFactory,
1351+
createWithEmptySettings(threadPool),
1352+
mockClusterServiceEmpty()
1353+
)
1354+
) {
1355+
PlainActionFuture<List<ChunkedInference>> listener = new PlainActionFuture<>();
1356+
service.chunkedInfer(
1357+
model,
1358+
null,
1359+
List.of(),
1360+
new HashMap<>(),
1361+
InputType.INTERNAL_INGEST,
1362+
InferenceAction.Request.DEFAULT_TIMEOUT,
1363+
listener
1364+
);
1365+
1366+
var results = listener.actionGet(TIMEOUT);
1367+
assertThat(results, empty());
1368+
}
1369+
}
1370+
13261371
private void testChunkedInfer(AmazonBedrockEmbeddingsModel model) throws IOException {
13271372
var sender = createMockSender();
13281373
var factory = mock(HttpRequestSender.Factory.class);

x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/azureaistudio/AzureAiStudioServiceTests.java

Lines changed: 22 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -89,6 +89,7 @@
8989
import static org.elasticsearch.xpack.inference.services.azureaistudio.request.AzureAiStudioRequestFields.API_KEY_HEADER;
9090
import static org.hamcrest.CoreMatchers.is;
9191
import static org.hamcrest.Matchers.containsString;
92+
import static org.hamcrest.Matchers.empty;
9293
import static org.hamcrest.Matchers.equalTo;
9394
import static org.hamcrest.Matchers.hasSize;
9495
import static org.hamcrest.Matchers.instanceOf;
@@ -1294,6 +1295,27 @@ public void testChunkedInfer_ChunkingSettingsNotSet() throws IOException {
12941295
testChunkedInfer(model);
12951296
}
12961297

1298+
public void testChunkedInfer_noInputs() throws IOException {
1299+
var model = AzureAiStudioEmbeddingsModelTests.createModel(
1300+
"id",
1301+
getUrl(webServer),
1302+
AzureAiStudioProvider.OPENAI,
1303+
AzureAiStudioEndpointType.TOKEN,
1304+
"apikey"
1305+
);
1306+
var senderFactory = HttpRequestSenderTests.createSenderFactory(threadPool, clientManager);
1307+
1308+
try (var service = new AzureAiStudioService(senderFactory, createWithEmptySettings(threadPool), mockClusterServiceEmpty())) {
1309+
PlainActionFuture<List<ChunkedInference>> listener = new PlainActionFuture<>();
1310+
List<ChunkInferenceInput> input = List.of();
1311+
service.chunkedInfer(model, null, input, new HashMap<>(), InputType.INGEST, InferenceAction.Request.DEFAULT_TIMEOUT, listener);
1312+
1313+
var results = listener.actionGet(TIMEOUT);
1314+
assertThat(results, empty());
1315+
assertThat(webServer.requests(), empty());
1316+
}
1317+
}
1318+
12971319
private void testChunkedInfer(AzureAiStudioEmbeddingsModel model) throws IOException {
12981320
var senderFactory = HttpRequestSenderTests.createSenderFactory(threadPool, clientManager);
12991321

0 commit comments

Comments
 (0)