Skip to content

Commit 706497c

Browse files
authored
fix: GPT Streams with tool calls (#214)
* fix: GPT Streams with tool calls * Fix streamed tool calls are not aware of already given answers
1 parent 9bdb696 commit 706497c

File tree

4 files changed

+161
-40
lines changed

4 files changed

+161
-40
lines changed

examples/stream-tools-gpt-openai.php

Lines changed: 43 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,43 @@
1+
<?php
2+
3+
use PhpLlm\LlmChain\Bridge\OpenAI\GPT;
4+
use PhpLlm\LlmChain\Bridge\OpenAI\PlatformFactory;
5+
use PhpLlm\LlmChain\Chain;
6+
use PhpLlm\LlmChain\Chain\ToolBox\ChainProcessor;
7+
use PhpLlm\LlmChain\Chain\ToolBox\Tool\Wikipedia;
8+
use PhpLlm\LlmChain\Chain\ToolBox\ToolAnalyzer;
9+
use PhpLlm\LlmChain\Chain\ToolBox\ToolBox;
10+
use PhpLlm\LlmChain\Model\Message\Message;
11+
use PhpLlm\LlmChain\Model\Message\MessageBag;
12+
use Symfony\Component\Dotenv\Dotenv;
13+
use Symfony\Component\HttpClient\HttpClient;
14+
15+
require_once dirname(__DIR__).'/vendor/autoload.php';
16+
(new Dotenv())->loadEnv(dirname(__DIR__).'/.env');
17+
18+
if (empty($_ENV['OPENAI_API_KEY'])) {
19+
echo 'Please set the OPENAI_API_KEY environment variable.'.PHP_EOL;
20+
exit(1);
21+
}
22+
23+
$platform = PlatformFactory::create($_ENV['OPENAI_API_KEY']);
24+
$llm = new GPT(GPT::GPT_4O_MINI);
25+
26+
$wikipedia = new Wikipedia(HttpClient::create());
27+
$toolBox = new ToolBox(new ToolAnalyzer(), [$wikipedia]);
28+
$processor = new ChainProcessor($toolBox);
29+
$chain = new Chain($platform, $llm, [$processor], [$processor]);
30+
$messages = new MessageBag(Message::ofUser(<<<TXT
31+
First, define unicorn in 30 words.
32+
Then lookup at Wikipedia what the irish history looks like in 2 sentences.
33+
Please tell me before you call tools.
34+
TXT));
35+
$response = $chain->call($messages, [
36+
'stream' => true, // enable streaming of response text
37+
]);
38+
39+
foreach ($response->getContent() as $word) {
40+
echo $word;
41+
}
42+
43+
echo PHP_EOL;

src/Bridge/OpenAI/GPT/ResponseConverter.php

Lines changed: 42 additions & 27 deletions
Original file line numberDiff line numberDiff line change
@@ -52,7 +52,7 @@ public function convert(HttpResponse $response, array $options = []): LlmRespons
5252
}
5353

5454
/** @var Choice[] $choices */
55-
$choices = array_map([$this, 'convertChoice'], $data['choices']);
55+
$choices = \array_map([$this, 'convertChoice'], $data['choices']);
5656

5757
if (1 !== count($choices)) {
5858
return new ChoiceResponse(...$choices);
@@ -65,14 +65,10 @@ public function convert(HttpResponse $response, array $options = []): LlmRespons
6565
return new TextResponse($choices[0]->getContent());
6666
}
6767

68-
private function convertStream(HttpResponse $response): ToolCallResponse|StreamResponse
68+
private function convertStream(HttpResponse $response): StreamResponse
6969
{
7070
$stream = $this->streamResponse($response);
7171

72-
if ($this->streamIsToolCall($stream)) {
73-
return new ToolCallResponse(...$this->convertStreamToToolCalls($stream));
74-
}
75-
7672
return new StreamResponse($this->convertStreamContent($stream));
7773
}
7874

@@ -84,7 +80,9 @@ private function streamResponse(HttpResponse $response): \Generator
8480
}
8581

8682
try {
87-
yield $chunk->getArrayData();
83+
$data = $chunk->getArrayData();
84+
85+
yield $data;
8886
} catch (JsonException) {
8987
// try catch only needed for Symfony 6.4
9088
continue;
@@ -100,37 +98,46 @@ private function streamIsToolCall(\Generator $response): bool
10098
}
10199

102100
/**
103-
* @return ToolCall[]
101+
* @param array<string, mixed> $toolCalls
102+
* @param array<string, mixed> $data
103+
*
104+
* @return array<string, mixed>
104105
*/
105-
private function convertStreamToToolCalls(\Generator $response): array
106+
private function convertStreamToToolCalls(array $toolCalls, array $data): array
106107
{
107-
$toolCalls = [];
108-
foreach ($response as $data) {
109-
if (!isset($data['choices'][0]['delta']['tool_calls'])) {
108+
if (!isset($data['choices'][0]['delta']['tool_calls'])) {
109+
return $toolCalls;
110+
}
111+
112+
foreach ($data['choices'][0]['delta']['tool_calls'] as $i => $toolCall) {
113+
if (isset($toolCall['id'])) {
114+
// initialize tool call
115+
$toolCalls[$i] = [
116+
'id' => $toolCall['id'],
117+
'function' => $toolCall['function'],
118+
];
110119
continue;
111120
}
112121

113-
foreach ($data['choices'][0]['delta']['tool_calls'] as $i => $toolCall) {
114-
if (isset($toolCall['id'])) {
115-
// initialize tool call
116-
$toolCalls[$i] = [
117-
'id' => $toolCall['id'],
118-
'function' => $toolCall['function'],
119-
];
120-
continue;
121-
}
122-
123-
// add arguments delta to tool call
124-
$toolCalls[$i]['function']['arguments'] .= $toolCall['function']['arguments'];
125-
}
122+
// add arguments delta to tool call
123+
$toolCalls[$i]['function']['arguments'] .= $toolCall['function']['arguments'];
126124
}
127125

128-
return array_map([$this, 'convertToolCall'], $toolCalls);
126+
return $toolCalls;
129127
}
130128

131129
private function convertStreamContent(\Generator $generator): \Generator
132130
{
131+
$toolCalls = [];
133132
foreach ($generator as $data) {
133+
if ($this->streamIsToolCall($generator)) {
134+
$toolCalls = $this->convertStreamToToolCalls($toolCalls, $data);
135+
}
136+
137+
if ([] !== $toolCalls && $this->isToolCallsStreamFinished($data)) {
138+
yield new ToolCallResponse(...\array_map([$this, 'convertToolCall'], $toolCalls));
139+
}
140+
134141
if (!isset($data['choices'][0]['delta']['content'])) {
135142
continue;
136143
}
@@ -139,6 +146,14 @@ private function convertStreamContent(\Generator $generator): \Generator
139146
}
140147
}
141148

149+
/**
150+
* @param array<string, mixed> $data
151+
*/
152+
private function isToolCallsStreamFinished(array $data): bool
153+
{
154+
return isset($data['choices'][0]['finish_reason']) && 'tool_calls' === $data['choices'][0]['finish_reason'];
155+
}
156+
142157
/**
143158
* @param array{
144159
* index: integer,
@@ -162,7 +177,7 @@ private function convertStreamContent(\Generator $generator): \Generator
162177
private function convertChoice(array $choice): Choice
163178
{
164179
if ('tool_calls' === $choice['finish_reason']) {
165-
return new Choice(toolCalls: array_map([$this, 'convertToolCall'], $choice['message']['tool_calls']));
180+
return new Choice(toolCalls: \array_map([$this, 'convertToolCall'], $choice['message']['tool_calls']));
166181
}
167182

168183
if (in_array($choice['finish_reason'], ['stop', 'length'], true)) {

src/Chain/ToolBox/ChainProcessor.php

Lines changed: 43 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -11,8 +11,12 @@
1111
use PhpLlm\LlmChain\Chain\Output;
1212
use PhpLlm\LlmChain\Chain\OutputProcessor;
1313
use PhpLlm\LlmChain\Chain\ToolBox\Event\ToolCallsExecuted;
14+
use PhpLlm\LlmChain\Chain\ToolBox\StreamResponse as ToolboxStreamResponse;
1415
use PhpLlm\LlmChain\Exception\MissingModelSupport;
16+
use PhpLlm\LlmChain\Model\Message\AssistantMessage;
1517
use PhpLlm\LlmChain\Model\Message\Message;
18+
use PhpLlm\LlmChain\Model\Response\ResponseInterface;
19+
use PhpLlm\LlmChain\Model\Response\StreamResponse as GenericStreamResponse;
1620
use PhpLlm\LlmChain\Model\Response\ToolCallResponse;
1721
use Symfony\Contracts\EventDispatcher\EventDispatcherInterface;
1822

@@ -45,23 +49,49 @@ public function processInput(Input $input): void
4549

4650
public function processOutput(Output $output): void
4751
{
48-
$messages = clone $output->messages;
52+
if ($output->response instanceof GenericStreamResponse) {
53+
$output->response = new ToolboxStreamResponse(
54+
$output->response->getContent(),
55+
$this->handleToolCallsCallback($output),
56+
);
4957

50-
while ($output->response instanceof ToolCallResponse) {
51-
$toolCalls = $output->response->getContent();
52-
$messages->add(Message::ofAssistant(toolCalls: $toolCalls));
58+
return;
59+
}
60+
61+
if (!$output->response instanceof ToolCallResponse) {
62+
return;
63+
}
64+
65+
$output->response = $this->handleToolCallsCallback($output)($output->response);
66+
}
5367

54-
$results = [];
55-
foreach ($toolCalls as $toolCall) {
56-
$result = $this->toolBox->execute($toolCall);
57-
$results[] = new ToolCallResult($toolCall, $result);
58-
$messages->add(Message::ofToolCall($toolCall, $this->resultConverter->convert($result)));
68+
private function handleToolCallsCallback(Output $output): \Closure
69+
{
70+
return function (ToolCallResponse $response, ?AssistantMessage $streamedAssistantResponse = null) use ($output): ResponseInterface {
71+
$messages = clone $output->messages;
72+
73+
if (null !== $streamedAssistantResponse && '' !== $streamedAssistantResponse->content) {
74+
$messages->add($streamedAssistantResponse);
5975
}
6076

61-
$event = new ToolCallsExecuted(...$results);
62-
$this->eventDispatcher?->dispatch($event);
77+
do {
78+
$toolCalls = $response->getContent();
79+
$messages->add(Message::ofAssistant(toolCalls: $toolCalls));
6380

64-
$output->response = $event->hasResponse() ? $event->response : $this->chain->call($messages, $output->options);
65-
}
81+
$results = [];
82+
foreach ($toolCalls as $toolCall) {
83+
$result = $this->toolBox->execute($toolCall);
84+
$results[] = new ToolCallResult($toolCall, $result);
85+
$messages->add(Message::ofToolCall($toolCall, $this->resultConverter->convert($result)));
86+
}
87+
88+
$event = new ToolCallsExecuted(...$results);
89+
$this->eventDispatcher?->dispatch($event);
90+
91+
$response = $event->hasResponse() ? $event->response : $this->chain->call($messages, $output->options);
92+
} while ($response instanceof ToolCallResponse);
93+
94+
return $response;
95+
};
6696
}
6797
}

src/Chain/ToolBox/StreamResponse.php

Lines changed: 33 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,33 @@
1+
<?php
2+
3+
declare(strict_types=1);
4+
5+
namespace PhpLlm\LlmChain\Chain\ToolBox;
6+
7+
use PhpLlm\LlmChain\Model\Message\Message;
8+
use PhpLlm\LlmChain\Model\Response\ResponseInterface;
9+
use PhpLlm\LlmChain\Model\Response\ToolCallResponse;
10+
11+
final readonly class StreamResponse implements ResponseInterface
12+
{
13+
public function __construct(
14+
private \Generator $generator,
15+
private \Closure $handleToolCallsCallback,
16+
) {
17+
}
18+
19+
public function getContent(): \Generator
20+
{
21+
$streamedResponse = '';
22+
foreach ($this->generator as $value) {
23+
if ($value instanceof ToolCallResponse) {
24+
yield from ($this->handleToolCallsCallback)($value, Message::ofAssistant($streamedResponse))->getContent();
25+
26+
break;
27+
}
28+
29+
$streamedResponse .= $value;
30+
yield $value;
31+
}
32+
}
33+
}

0 commit comments

Comments
 (0)