Commit 94667e1

Merge pull request #8386 from minwhoo/triton-completions-streaming-fix
Fix triton streaming completions bug
2 parents: 398020a + c62be18

2 files changed (+58, -12 lines)
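The change routes stream=True requests on a /generate api_base to /generate_stream and decodes the server-sent events incrementally. A minimal usage sketch of the call path this fix enables — the model name, prompt, and server address below are illustrative assumptions, not values from this PR:

import litellm

# Hypothetical Triton server and model name, for illustration only.
response = litellm.completion(
    model="triton/ensemble",
    messages=[{"role": "user", "content": "Who are you?"}],
    api_base="http://localhost:8000/generate",  # rewritten to /generate_stream when streaming
    max_tokens=10,
    stream=True,
)

for chunk in response:
    delta = chunk.choices[0].delta.content
    if delta is not None:  # the final chunk's delta is None (per the test below)
        print(delta, end="")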

litellm/llms/triton/completion/transformation.py

Lines changed: 26 additions & 2 deletions
@@ -3,7 +3,7 @@
 """
 
 import json
-from typing import Any, Dict, List, Literal, Optional, Union
+from typing import Any, AsyncIterator, Dict, Iterator, List, Literal, Optional, Union
 
 from httpx import Headers, Response
 
@@ -67,6 +67,18 @@ def map_openai_params(
                 optional_params[param] = value
         return optional_params
 
+    def get_complete_url(
+        self,
+        api_base: str,
+        model: str,
+        optional_params: dict,
+        stream: Optional[bool] = None,
+    ) -> str:
+        llm_type = self._get_triton_llm_type(api_base)
+        if llm_type == "generate" and stream:
+            return api_base + "_stream"
+        return api_base
+
     def transform_response(
         self,
         model: str,
@@ -149,6 +161,18 @@ def _get_triton_llm_type(self, api_base: str) -> Literal["generate", "infer"]:
         else:
             raise ValueError(f"Invalid Triton API base: {api_base}")
 
+    def get_model_response_iterator(
+        self,
+        streaming_response: Union[Iterator[str], AsyncIterator[str], ModelResponse],
+        sync_stream: bool,
+        json_mode: Optional[bool] = False,
+    ) -> Any:
+        return TritonResponseIterator(
+            streaming_response=streaming_response,
+            sync_stream=sync_stream,
+            json_mode=json_mode,
+        )
+
 
 class TritonGenerateConfig(TritonConfig):
     """
@@ -204,7 +228,7 @@ def transform_response(
         return model_response
 
 
-class TritonInferConfig(TritonGenerateConfig):
+class TritonInferConfig(TritonConfig):
     """
     Transformations for triton /infer endpoint (this is an infer model with a custom model on triton)
     """

tests/llm_translation/test_triton.py

Lines changed: 32 additions & 10 deletions
@@ -49,16 +49,26 @@ def test_split_embedding_by_shape_fails_with_shape_value_error():
     )
 
 
-def test_completion_triton_generate_api():
+@pytest.mark.parametrize("stream", [True, False])
+def test_completion_triton_generate_api(stream):
     try:
        mock_response = MagicMock()
-
-        def return_val():
-            return {
-                "text_output": "I am an AI assistant",
-            }
-
-        mock_response.json = return_val
+        if stream:
+            def mock_iter_lines():
+                mock_output = ''.join([
+                    'data: {"model_name":"ensemble","model_version":"1","sequence_end":false,"sequence_id":0,"sequence_start":false,"text_output":"' + t + '"}\n\n'
+                    for t in ["I", " am", " an", " AI", " assistant"]
+                ])
+                for out in mock_output.split('\n'):
+                    yield out
+            mock_response.iter_lines = mock_iter_lines
+        else:
+            def return_val():
+                return {
+                    "text_output": "I am an AI assistant",
+                }
+
+            mock_response.json = return_val
        mock_response.status_code = 200
 
        with patch(
@@ -71,6 +81,7 @@ def return_val():
                max_tokens=10,
                timeout=5,
                api_base="http://localhost:8000/generate",
+                stream=stream,
            )
 
            # Verify the call was made
@@ -81,7 +92,10 @@ def return_val():
            call_kwargs = mock_post.call_args.kwargs  # Access kwargs directly
 
            # Verify URL
-            assert call_kwargs["url"] == "http://localhost:8000/generate"
+            if stream:
+                assert call_kwargs["url"] == "http://localhost:8000/generate_stream"
+            else:
+                assert call_kwargs["url"] == "http://localhost:8000/generate"
 
            # Parse the request data from the JSON string
            request_data = json.loads(call_kwargs["data"])
@@ -91,7 +105,15 @@ def return_val():
            assert request_data["parameters"]["max_tokens"] == 10
 
            # Verify response
-            assert response.choices[0].message.content == "I am an AI assistant"
+            if stream:
+                tokens = ["I", " am", " an", " AI", " assistant", None]
+                idx = 0
+                for chunk in response:
+                    assert chunk.choices[0].delta.content == tokens[idx]
+                    idx += 1
+                assert idx == len(tokens)
+            else:
+                assert response.choices[0].message.content == "I am an AI assistant"
 
    except Exception as e:
        print("exception", e)
