Skip to content

Commit 8580981

Browse files
authored
[ML] Append all data to Chat Completion buffer (#127658) (#128136)
Moved the Chat Completion buffer into the StreamingUnifiedChatCompletionResults so that all Chat Completion responses can benefit from it. Chat Completions is meant to adhere to OpenAI as much as possible, and OpenAI only sends one response chunk at a time. All implementations of Chat Completions will now buffer. This fixes a bug where more than two chunks in a single item would be dropped, instead they are all added to the buffer. This fixes a bug where onComplete would omit trailing items in the buffer.
1 parent 2254dd4 commit 8580981

File tree

5 files changed

+155
-61
lines changed

5 files changed

+155
-61
lines changed

docs/changelog/127658.yaml

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,5 @@
1+
pr: 127658
2+
summary: Append all data to Chat Completion buffer
3+
area: Machine Learning
4+
type: bug
5+
issues: []

x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/inference/results/StreamingUnifiedChatCompletionResults.java

Lines changed: 61 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -16,13 +16,16 @@
1616
import org.elasticsearch.common.xcontent.ChunkedToXContentObject;
1717
import org.elasticsearch.inference.InferenceServiceResults;
1818
import org.elasticsearch.xcontent.ToXContent;
19+
import org.elasticsearch.xpack.core.inference.DequeUtils;
1920

2021
import java.io.IOException;
2122
import java.util.Collections;
2223
import java.util.Deque;
2324
import java.util.Iterator;
2425
import java.util.List;
2526
import java.util.concurrent.Flow;
27+
import java.util.concurrent.LinkedBlockingDeque;
28+
import java.util.concurrent.atomic.AtomicBoolean;
2629

2730
import static org.elasticsearch.xpack.core.inference.DequeUtils.dequeEquals;
2831
import static org.elasticsearch.xpack.core.inference.DequeUtils.dequeHashCode;
@@ -31,9 +34,7 @@
3134
/**
3235
* Chat Completion results that only contain a Flow.Publisher.
3336
*/
34-
public record StreamingUnifiedChatCompletionResults(Flow.Publisher<? extends InferenceServiceResults.Result> publisher)
35-
implements
36-
InferenceServiceResults {
37+
public record StreamingUnifiedChatCompletionResults(Flow.Publisher<Results> publisher) implements InferenceServiceResults {
3738

3839
public static final String NAME = "chat_completion_chunk";
3940
public static final String MODEL_FIELD = "model";
@@ -56,6 +57,63 @@ public record StreamingUnifiedChatCompletionResults(Flow.Publisher<? extends Inf
5657
public static final String PROMPT_TOKENS_FIELD = "prompt_tokens";
5758
public static final String TYPE_FIELD = "type";
5859

60+
/**
61+
* OpenAI Spec only returns one result at a time, and Chat Completion adheres to that spec as much as possible.
62+
* So we will insert a buffer in between the upstream data and the downstream client so that we only send one request at a time.
63+
*/
64+
public StreamingUnifiedChatCompletionResults(Flow.Publisher<Results> publisher) {
65+
Deque<StreamingUnifiedChatCompletionResults.ChatCompletionChunk> buffer = new LinkedBlockingDeque<>();
66+
AtomicBoolean onComplete = new AtomicBoolean();
67+
this.publisher = downstream -> {
68+
publisher.subscribe(new Flow.Subscriber<>() {
69+
@Override
70+
public void onSubscribe(Flow.Subscription subscription) {
71+
downstream.onSubscribe(new Flow.Subscription() {
72+
@Override
73+
public void request(long n) {
74+
var nextItem = buffer.poll();
75+
if (nextItem != null) {
76+
downstream.onNext(new Results(DequeUtils.of(nextItem)));
77+
} else if (onComplete.get()) {
78+
downstream.onComplete();
79+
} else {
80+
subscription.request(n);
81+
}
82+
}
83+
84+
@Override
85+
public void cancel() {
86+
subscription.cancel();
87+
}
88+
});
89+
}
90+
91+
@Override
92+
public void onNext(Results item) {
93+
var chunks = item.chunks();
94+
var firstItem = chunks.poll();
95+
chunks.forEach(buffer::offer);
96+
downstream.onNext(new Results(DequeUtils.of(firstItem)));
97+
}
98+
99+
@Override
100+
public void onError(Throwable throwable) {
101+
downstream.onError(throwable);
102+
}
103+
104+
@Override
105+
public void onComplete() {
106+
// only complete if the buffer is empty, so that the client has a chance to drain the buffer
107+
if (onComplete.compareAndSet(false, true)) {
108+
if (buffer.isEmpty()) {
109+
downstream.onComplete();
110+
}
111+
}
112+
}
113+
});
114+
};
115+
}
116+
59117
@Override
60118
public boolean isStreaming() {
61119
return true;

x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/inference/results/StreamingUnifiedChatCompletionResultsTests.java

Lines changed: 69 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -19,8 +19,17 @@
1919
import java.util.ArrayDeque;
2020
import java.util.Deque;
2121
import java.util.List;
22+
import java.util.concurrent.Flow;
23+
import java.util.concurrent.atomic.AtomicBoolean;
24+
import java.util.concurrent.atomic.AtomicInteger;
25+
import java.util.concurrent.atomic.AtomicReference;
2226
import java.util.function.Supplier;
2327

28+
import static org.mockito.ArgumentMatchers.any;
29+
import static org.mockito.Mockito.spy;
30+
import static org.mockito.Mockito.times;
31+
import static org.mockito.Mockito.verify;
32+
2433
public class StreamingUnifiedChatCompletionResultsTests extends AbstractWireSerializingTestCase<
2534
StreamingUnifiedChatCompletionResults.Results> {
2635

@@ -198,6 +207,66 @@ public void testToolCallToXContentChunked() throws IOException {
198207
assertEquals(expected.replaceAll("\\s+", ""), Strings.toString(builder.prettyPrint()).trim());
199208
}
200209

210+
public void testBufferedPublishing() {
211+
var results = new ArrayDeque<StreamingUnifiedChatCompletionResults.ChatCompletionChunk>();
212+
results.offer(randomChatCompletionChunk());
213+
results.offer(randomChatCompletionChunk());
214+
var completed = new AtomicBoolean();
215+
var streamingResults = new StreamingUnifiedChatCompletionResults(downstream -> {
216+
downstream.onSubscribe(new Flow.Subscription() {
217+
@Override
218+
public void request(long n) {
219+
if (completed.compareAndSet(false, true)) {
220+
downstream.onNext(new StreamingUnifiedChatCompletionResults.Results(results));
221+
} else {
222+
downstream.onComplete();
223+
}
224+
}
225+
226+
@Override
227+
public void cancel() {
228+
fail("Cancel should never be called.");
229+
}
230+
});
231+
});
232+
233+
AtomicInteger counter = new AtomicInteger(0);
234+
AtomicReference<Flow.Subscription> upstream = new AtomicReference<>(null);
235+
Flow.Subscriber<StreamingUnifiedChatCompletionResults.Results> subscriber = spy(
236+
new Flow.Subscriber<StreamingUnifiedChatCompletionResults.Results>() {
237+
@Override
238+
public void onSubscribe(Flow.Subscription subscription) {
239+
if (upstream.compareAndSet(null, subscription) == false) {
240+
fail("Upstream already set?!");
241+
}
242+
subscription.request(1);
243+
}
244+
245+
@Override
246+
public void onNext(StreamingUnifiedChatCompletionResults.Results item) {
247+
assertNotNull(item);
248+
counter.incrementAndGet();
249+
var sub = upstream.get();
250+
if (sub != null) {
251+
sub.request(1);
252+
} else {
253+
fail("Upstream not yet set?!");
254+
}
255+
}
256+
257+
@Override
258+
public void onError(Throwable throwable) {
259+
fail(throwable);
260+
}
261+
262+
@Override
263+
public void onComplete() {}
264+
}
265+
);
266+
streamingResults.publisher().subscribe(subscriber);
267+
verify(subscriber, times(2)).onNext(any());
268+
}
269+
201270
@Override
202271
protected Writeable.Reader<StreamingUnifiedChatCompletionResults.Results> instanceReader() {
203272
return StreamingUnifiedChatCompletionResults.Results::new;

x-pack/plugin/inference/qa/test-service-plugin/src/main/java/org/elasticsearch/xpack/inference/mock/TestStreamingCompletionServiceExtension.java

Lines changed: 19 additions & 31 deletions
Original file line numberDiff line numberDiff line change
@@ -34,6 +34,7 @@
3434
import org.elasticsearch.xcontent.ToXContent;
3535
import org.elasticsearch.xcontent.ToXContentObject;
3636
import org.elasticsearch.xcontent.XContentBuilder;
37+
import org.elasticsearch.xpack.core.inference.DequeUtils;
3738
import org.elasticsearch.xpack.core.inference.results.StreamingChatCompletionResults;
3839
import org.elasticsearch.xpack.core.inference.results.StreamingUnifiedChatCompletionResults;
3940
import org.elasticsearch.xpack.core.inference.results.TextEmbeddingFloatResults;
@@ -256,37 +257,24 @@ public void cancel() {}
256257
"object": "chat.completion.chunk"
257258
}
258259
*/
259-
private InferenceServiceResults.Result unifiedCompletionChunk(String delta) {
260-
return new InferenceServiceResults.Result() {
261-
@Override
262-
public String getWriteableName() {
263-
return "test_unifiedCompletionChunk";
264-
}
265-
266-
@Override
267-
public void writeTo(StreamOutput out) throws IOException {
268-
out.writeString(delta);
269-
}
270-
271-
@Override
272-
public Iterator<? extends ToXContent> toXContentChunked(ToXContent.Params params) {
273-
return ChunkedToXContentHelper.singleChunk(
274-
(b, p) -> b.startObject()
275-
.field("id", "id")
276-
.startArray("choices")
277-
.startObject()
278-
.startObject("delta")
279-
.field("content", delta)
280-
.endObject()
281-
.field("index", 0)
282-
.endObject()
283-
.endArray()
284-
.field("model", "gpt-4o-2024-08-06")
285-
.field("object", "chat.completion.chunk")
286-
.endObject()
287-
);
288-
}
289-
};
260+
private StreamingUnifiedChatCompletionResults.Results unifiedCompletionChunk(String delta) {
261+
return new StreamingUnifiedChatCompletionResults.Results(
262+
DequeUtils.of(
263+
new StreamingUnifiedChatCompletionResults.ChatCompletionChunk(
264+
"id",
265+
List.of(
266+
new StreamingUnifiedChatCompletionResults.ChatCompletionChunk.Choice(
267+
new StreamingUnifiedChatCompletionResults.ChatCompletionChunk.Choice.Delta(delta, null, null, null),
268+
null,
269+
0
270+
)
271+
),
272+
"gpt-4o-2024-08-06",
273+
"chat.completion.chunk",
274+
null
275+
)
276+
)
277+
);
290278
}
291279

292280
@Override

x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/openai/OpenAiUnifiedStreamingProcessor.java

Lines changed: 1 addition & 27 deletions
Original file line numberDiff line numberDiff line change
@@ -26,7 +26,6 @@
2626
import java.util.Deque;
2727
import java.util.Iterator;
2828
import java.util.List;
29-
import java.util.concurrent.LinkedBlockingDeque;
3029
import java.util.function.BiFunction;
3130

3231
import static org.elasticsearch.common.xcontent.XContentParserUtils.ensureExpectedToken;
@@ -60,21 +59,11 @@ public class OpenAiUnifiedStreamingProcessor extends DelegatingProcessor<
6059
public static final String TOTAL_TOKENS_FIELD = "total_tokens";
6160

6261
private final BiFunction<String, Exception, Exception> errorParser;
63-
private final Deque<StreamingUnifiedChatCompletionResults.ChatCompletionChunk> buffer = new LinkedBlockingDeque<>();
6462

6563
public OpenAiUnifiedStreamingProcessor(BiFunction<String, Exception, Exception> errorParser) {
6664
this.errorParser = errorParser;
6765
}
6866

69-
@Override
70-
protected void upstreamRequest(long n) {
71-
if (buffer.isEmpty()) {
72-
super.upstreamRequest(n);
73-
} else {
74-
downstream().onNext(new StreamingUnifiedChatCompletionResults.Results(singleItem(buffer.poll())));
75-
}
76-
}
77-
7867
@Override
7968
protected void next(Deque<ServerSentEvent> item) throws Exception {
8069
var parserConfig = XContentParserConfiguration.EMPTY.withDeprecationHandler(LoggingDeprecationHandler.INSTANCE);
@@ -96,15 +85,8 @@ protected void next(Deque<ServerSentEvent> item) throws Exception {
9685

9786
if (results.isEmpty()) {
9887
upstream().request(1);
99-
} else if (results.size() == 1) {
100-
downstream().onNext(new StreamingUnifiedChatCompletionResults.Results(results));
10188
} else {
102-
// results > 1, but openai spec only wants 1 chunk per SSE event
103-
var firstItem = singleItem(results.poll());
104-
while (results.isEmpty() == false) {
105-
buffer.offer(results.poll());
106-
}
107-
downstream().onNext(new StreamingUnifiedChatCompletionResults.Results(firstItem));
89+
downstream().onNext(new StreamingUnifiedChatCompletionResults.Results(results));
10890
}
10991
}
11092

@@ -297,12 +279,4 @@ public static StreamingUnifiedChatCompletionResults.ChatCompletionChunk.Usage pa
297279
}
298280
}
299281
}
300-
301-
private Deque<StreamingUnifiedChatCompletionResults.ChatCompletionChunk> singleItem(
302-
StreamingUnifiedChatCompletionResults.ChatCompletionChunk result
303-
) {
304-
var deque = new ArrayDeque<StreamingUnifiedChatCompletionResults.ChatCompletionChunk>(1);
305-
deque.offer(result);
306-
return deque;
307-
}
308282
}

0 commit comments

Comments
 (0)