Skip to content

Commit 611d9e6

Browse files
authored
Merge pull request #131 from mindsdb/completion-streaming
Added Streaming to Agents
2 parents ffac45a + fa5a7e8 commit 611d9e6

File tree

3 files changed

+33
-1
lines changed

3 files changed

+33
-1
lines changed

mindsdb_sdk/agents.py

+21-1
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,5 @@
11
from requests.exceptions import HTTPError
2-
from typing import List, Union
2+
from typing import Iterable, List, Union
33
from urllib.parse import urlparse
44
from uuid import uuid4
55
import datetime
@@ -37,6 +37,12 @@ class Agent:
3737
>>> completion = agent.completion([{'question': 'What is your name?', 'answer': None}])
3838
>>> print(completion.content)
3939
40+
Query an agent with streaming:
41+
42+
>>> completion = agent.completion_stream([{'question': 'What is your name?', 'answer': None}])
43+
>>> for chunk in completion:
44+
print(chunk.choices[0].delta.content)
45+
4046
List all agents:
4147
4248
>>> agents = agents.list()
@@ -81,6 +87,9 @@ def __init__(
8187
def completion(self, messages: List[dict]) -> AgentCompletion:
8288
return self.collection.completion(self.name, messages)
8389

90+
def completion_stream(self, messages: List[dict]) -> Iterable[object]:
91+
return self.collection.completion_stream(self.name, messages)
92+
8493
def add_files(self, file_paths: List[str], description: str, knowledge_base: str = None):
8594
"""
8695
Add a list of files to the agent for retrieval.
@@ -195,6 +204,17 @@ def completion(self, name: str, messages: List[dict]) -> AgentCompletion:
195204
data = self.api.agent_completion(self.project, name, messages)
196205
return AgentCompletion(data['message']['content'])
197206

207+
def completion_stream(self, name, messages: List[dict]) -> Iterable[object]:
208+
"""
209+
Queries the agent for a completion and streams the response as an iterable object.
210+
211+
:param name: Name of the agent
212+
:param messageS: List of messages to be sent to the agent
213+
214+
:return: iterable of completion chunks from querying the agent.
215+
"""
216+
return self.api.agent_completion_stream(self.project, name, messages)
217+
198218
def _create_default_knowledge_base(self, agent: Agent, name: str) -> KnowledgeBase:
199219
# Make sure default ML engine for embeddings exists.
200220
try:

mindsdb_sdk/connectors/rest_api.py

+11
Original file line numberDiff line numberDiff line change
@@ -1,11 +1,13 @@
11
from functools import wraps
22
from typing import List, Union
33
import io
4+
import json
45

56
import requests
67
import pandas as pd
78

89
from mindsdb_sdk import __about__
10+
from sseclient import SSEClient
911

1012

1113
def _try_relogin(fnc):
@@ -260,6 +262,15 @@ def agent_completion(self, project: str, name: str, messages: List[dict]):
260262

261263
return r.json()
262264

265+
@_try_relogin
266+
def agent_completion_stream(self, project: str, name: str, messages: List[dict]):
267+
url = self.url + f'/api/projects/{project}/agents/{name}/completions/stream'
268+
stream = requests.post(url, json={'messages': messages}, stream=True)
269+
client = SSEClient(stream)
270+
for chunk in client.events():
271+
# Stream objects loaded from SSE events 'data' param.
272+
yield json.loads(chunk.data)
273+
263274
@_try_relogin
264275
def create_agent(self, project: str, name: str, model: str, skills: List[str] = None, params: dict = None):
265276
url = self.url + f'/api/projects/{project}/agents'

requirements.txt

+1
Original file line numberDiff line numberDiff line change
@@ -4,3 +4,4 @@ mindsdb-sql >= 0.17.0, < 1.0.0
44
docstring-parser >= 0.7.3
55
tenacity >= 8.0.1
66
openai >= 1.15.0
7+
sseclient-py >= 1.8.0

0 commit comments

Comments
 (0)