Commit 472bc7c

some formalization of streaming
1 parent 241e5c4 commit 472bc7c

File tree

5 files changed: +43 / -29 lines

tale/llm_config.yaml

Lines changed: 3 additions & 1 deletion

```diff
@@ -1,6 +1,8 @@
 URL: "http://localhost:5001"
 ENDPOINT: "/api/v1/generate"
-WORD_LIMIT: 300
+STREAM_ENDPOINT: "/api/extra/generate/stream"
+DATA_ENDPOINT: "/api/extra/generate/check"
+WORD_LIMIT: 500
 DEFAULT_BODY: '{"stop_sequence": "", "max_length":300, "max_context_length":4096, "temperature":1.0, "top_k":120, "top_a":0.0, "top_p":0.85, "typical_p":1.0, "tfs":1.0, "rep_pen":1.2, "rep_pen_range":256, "mirostat":2, "mirostat_tau":5.0, "mirostat_eta":0.1, "sampler_order":[6,0,1,3,4,2,5], "seed":-1}'
 ANALYSIS_BODY: '{"stop_sequence": "\n\n", "max_length":300, "max_context_length":4096, "temperature":0.15, "top_k":120, "top_a":0.0, "top_p":0.85, "typical_p":1.0, "tfs":1.0, "rep_pen":1.2, "rep_pen_range":256, "mirostat":2, "mirostat_tau":5.0, "mirostat_eta":0.1, "sampler_order":[6,0,1,3,4,2,5], "seed":-1}'
 MEMORY_SIZE: 1024
```
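
For context: the two new endpoints follow the KoboldAI/KoboldCpp-style HTTP API that this commit targets. A minimal sketch of a plain non-streaming request using the config above (assumes a compatible backend running at URL; the "prompt" key and the response shape match what synchronous_request in tale/llm_io.py expects):

```python
import json
import requests
import yaml

# Load the config shown above (path assumed relative to the repo root).
with open("tale/llm_config.yaml") as f:
    config = yaml.safe_load(f)

body = json.loads(config["DEFAULT_BODY"])
body["prompt"] = "Describe a rainy village square."  # illustrative prompt

# Non-streaming request against the classic generate endpoint.
response = requests.post(config["URL"] + config["ENDPOINT"], data=json.dumps(body))
print(json.loads(response.text)["results"][0]["text"])
```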

tale/llm_ext.py

Lines changed: 2 additions & 0 deletions

```diff
@@ -4,6 +4,8 @@
 from tale.player import Player
 
 class LivingNpc(Living):
+    """An NPC with extra fields to define personality and help LLM generate dialogue"""
+
     def __init__(self, name: str, gender: str, *,
                  title: str="", descr: str="", short_descr: str="", age: int, personality: str, occupation: str=""):
         super(LivingNpc, self).__init__(name=name, gender=gender, title=title, descr=descr, short_descr=short_descr)
```
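
A usage sketch of the documented class; the NPC values below are made up for illustration, only the keyword fields mirror the constructor in the diff:

```python
from tale.llm_ext import LivingNpc

# Hypothetical NPC built with the extra personality fields.
barkeep = LivingNpc(
    "norhardt", "m",
    title="Norhardt the barkeep",
    descr="A stout man with a greying beard.",
    age=54,
    personality="Gruff but fair; talkative when business is slow.",
    occupation="barkeep",
)
```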

tale/llm_io.py

Lines changed: 22 additions & 19 deletions

```diff
@@ -2,47 +2,50 @@
 import time
 import aiohttp
 import asyncio
-import threading
 import json
 import tale.parse_utils as parse_utils
 from tale.player_utils import TextBuffer
-from .tio.iobase import IoAdapterBase
+
 class IoUtil():
+    """ Handles connection and data retrieval from backend """
 
     def synchronous_request(self, url: str, request_body: dict):
+        """ Send request to backend and return the result """
         response = requests.post(url, data=json.dumps(request_body))
         text = parse_utils.trim_response(json.loads(response.text)['results'][0]['text'])
         return text
 
-    def stream_request(self, player_io: TextBuffer, url: str, request_body: dict, io: IoAdapterBase) -> str:
-        result = asyncio.run(self._do_stream_request(url, request_body))
+    def stream_request(self, stream_url: str, data_url: str, request_body: dict, player_io: TextBuffer, io) -> str:
+        result = asyncio.run(self._do_stream_request(stream_url, request_body))
         if result:
-            return self._do_process_result(url, player_io, io)
+            return self._do_process_result(data_url, player_io, io)
         return ''
 
     async def _do_stream_request(self, url: str, request_body: dict,) -> bool:
-        sub_endpt = "http://localhost:5001/api/extra/generate/stream"
-
+        """ Send request to stream endpoint async to not block the main thread"""
         async with aiohttp.ClientSession() as session:
-            async with session.post(sub_endpt, data=json.dumps(request_body)) as response:
+            async with session.post(url, data=json.dumps(request_body)) as response:
                 if response.status == 200:
                     return True
-
                 else:
                     # Handle errors
                     print("Error occurred:", response.status)
 
-    def _do_process_result(self, url, player_io: TextBuffer, io: IoAdapterBase) -> str:
+    def _do_process_result(self, url, player_io: TextBuffer, io) -> str:
+        """ Process the result from the stream endpoint """
         tries = 0
-        old_data = ''
+        old_text = ''
         while tries < 2:
-            data = requests.post("http://localhost:5001/api/extra/generate/check")
+            time.sleep(0.5)
+            data = requests.post(url)
             text = json.loads(data.text)['results'][0]['text']
-            new_text = text[len(old_data):]
-            player_io.print(new_text, end=False, format=False, line_breaks=False)
-            io.write_output()
-            if len(text) == len(old_data):
+
+            if len(text) == len(old_text):
                 tries += 1
-            old_data = text
-            time.sleep(1)
-        return old_data
+                continue
+            new_text = text[len(old_text):]
+            player_io.print(new_text, end=False, format=True, line_breaks=False)
+            io.write_output()
+            old_text = text
+
+        return old_text
```
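
The core of the change is the poll-and-diff loop in _do_process_result: fetch the accumulated text, print only the suffix that is new, and stop once the response has been seen unchanged twice. A standalone sketch of that pattern (check_url and the response shape are assumed to match the KoboldCpp-style check endpoint used above):

```python
import time
import requests

def poll_stream(check_url: str, interval: float = 0.5, max_idle: int = 2) -> str:
    """Poll an endpoint that returns all text generated so far,
    emitting only the newly added suffix each round."""
    old_text = ""
    idle = 0
    while idle < max_idle:
        time.sleep(interval)
        text = requests.post(check_url).json()["results"][0]["text"]
        if len(text) == len(old_text):
            idle += 1  # unchanged response; count toward completion
            continue
        print(text[len(old_text):], end="", flush=True)  # emit only the delta
        old_text = text
    return old_text
```

Note that, as in the commit's loop, the idle counter is never reset, so two stalled polls anywhere in the stream terminate it.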

tale/llm_utils.py

Lines changed: 9 additions & 4 deletions

```diff
@@ -6,16 +6,20 @@
 from tale.llm_io import IoUtil
 import tale.parse_utils as parse_utils
 from tale.player_utils import TextBuffer
-from .tio.iobase import IoAdapterBase
 
 class LlmUtil():
+    """ Prepares prompts for various LLM requests"""
+
     def __init__(self):
         with open(os.path.realpath(os.path.join(os.path.dirname(__file__), "llm_config.yaml")), "r") as stream:
             try:
                 config_file = yaml.safe_load(stream)
             except yaml.YAMLError as exc:
                 print(exc)
-        self.url = config_file['URL'] + config_file['ENDPOINT']
+        self.url = config_file['URL']
+        self.endpoint = config_file['ENDPOINT']
+        self.stream_endpoint = config_file['STREAM_ENDPOINT']
+        self.data_endpoint = config_file['DATA_ENDPOINT']
         self.default_body = json.loads(config_file['DEFAULT_BODY'])
         self.analysis_body = json.loads(config_file['ANALYSIS_BODY'])
         self.memory_size = config_file['MEMORY_SIZE']
@@ -46,11 +50,12 @@ def evoke(self, player_io: TextBuffer, message: str, max_length : bool=False, ro
         request_body['max_length'] = amount
 
         if not self.stream:
-            text = self.io_util.synchronous_request(self.url, request_body)
+            text = self.io_util.synchronous_request(self.url + self.endpoint, request_body)
             rolling_prompt = self.update_memory(rolling_prompt, text)
            return f'Original:[ {message} ]\nGenerated:\n{text}', rolling_prompt
         else:
-            text = self.io_util.stream_request(player_io, self.url, request_body, self.connection)
+            player_io.print(f'Original:[ {message} ]\nGenerated:\n', end=False, format=True, line_breaks=False)
+            text = self.io_util.stream_request(self.url + self.stream_endpoint, self.url + self.data_endpoint, request_body, player_io, self.connection)
             rolling_prompt = self.update_memory(rolling_prompt, text)
             return '\n', rolling_prompt
         return str(message), rolling_prompt
```
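
With the base URL now stored separately from the endpoints, each call site composes the route it needs. A quick sketch of the resulting URLs, given the config shipped in this commit:

```python
from tale.llm_utils import LlmUtil

llm = LlmUtil()
print(llm.url + llm.endpoint)         # http://localhost:5001/api/v1/generate
print(llm.url + llm.stream_endpoint)  # http://localhost:5001/api/extra/generate/stream
print(llm.url + llm.data_endpoint)    # http://localhost:5001/api/extra/generate/check
```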

tale/player_utils.py

Lines changed: 7 additions & 5 deletions

```diff
@@ -13,14 +13,16 @@ def __init__(self, format: bool=True, line_breaks=True) -> None:
         self.line_breaks = line_breaks
 
     def add(self, line: str) -> None:
-        self.lines.append(line)
-
-    def text(self) -> str:
         if self.line_breaks:
-            return "\n".join(self.lines) + "\n"
+            self.lines.append(line)
+        elif len(self.lines) > 0:
+            self.lines[-1] += line
         else:
-            return "".join(self.lines)
+            self.lines.append(line)
 
+    def text(self) -> str:
+        return "\n".join(self.lines) + "\n"
+
     def __init__(self) -> None:
         self.init()
 
```
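
The effect of the change: with line_breaks disabled, add() appends to the last buffered line instead of opening a new one, which is what lets streamed fragments join into continuous prose while text() keeps a single join path. A standalone re-creation of just the changed pair, since the real TextBuffer carries more state:

```python
class _Buffer:
    """Illustrative re-creation of the add()/text() pair from the diff."""
    def __init__(self, line_breaks: bool = True) -> None:
        self.lines: list[str] = []
        self.line_breaks = line_breaks

    def add(self, line: str) -> None:
        if self.line_breaks:
            self.lines.append(line)   # one buffered line per add()
        elif len(self.lines) > 0:
            self.lines[-1] += line    # glue streamed fragments together
        else:
            self.lines.append(line)

    def text(self) -> str:
        return "\n".join(self.lines) + "\n"

buf = _Buffer(line_breaks=False)
for chunk in ("The barkeep ", "nods ", "slowly."):
    buf.add(chunk)
print(buf.text(), end="")  # -> The barkeep nods slowly.
```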
