@@ -49,16 +49,26 @@ def test_split_embedding_by_shape_fails_with_shape_value_error():
49
49
)
50
50
51
51
52
- def test_completion_triton_generate_api ():
52
+ @pytest .mark .parametrize ("stream" , [True , False ])
53
+ def test_completion_triton_generate_api (stream ):
53
54
try :
54
55
mock_response = MagicMock ()
55
-
56
- def return_val ():
57
- return {
58
- "text_output" : "I am an AI assistant" ,
59
- }
60
-
61
- mock_response .json = return_val
56
+ if stream :
57
+ def mock_iter_lines ():
58
+ mock_output = '' .join ([
59
+ 'data: {"model_name":"ensemble","model_version":"1","sequence_end":false,"sequence_id":0,"sequence_start":false,"text_output":"' + t + '"}\n \n '
60
+ for t in ["I" , " am" , " an" , " AI" , " assistant" ]
61
+ ])
62
+ for out in mock_output .split ('\n ' ):
63
+ yield out
64
+ mock_response .iter_lines = mock_iter_lines
65
+ else :
66
+ def return_val ():
67
+ return {
68
+ "text_output" : "I am an AI assistant" ,
69
+ }
70
+
71
+ mock_response .json = return_val
62
72
mock_response .status_code = 200
63
73
64
74
with patch (
@@ -71,6 +81,7 @@ def return_val():
71
81
max_tokens = 10 ,
72
82
timeout = 5 ,
73
83
api_base = "http://localhost:8000/generate" ,
84
+ stream = stream ,
74
85
)
75
86
76
87
# Verify the call was made
@@ -81,7 +92,10 @@ def return_val():
81
92
call_kwargs = mock_post .call_args .kwargs # Access kwargs directly
82
93
83
94
# Verify URL
84
- assert call_kwargs ["url" ] == "http://localhost:8000/generate"
95
+ if stream :
96
+ assert call_kwargs ["url" ] == "http://localhost:8000/generate_stream"
97
+ else :
98
+ assert call_kwargs ["url" ] == "http://localhost:8000/generate"
85
99
86
100
# Parse the request data from the JSON string
87
101
request_data = json .loads (call_kwargs ["data" ])
@@ -91,7 +105,15 @@ def return_val():
91
105
assert request_data ["parameters" ]["max_tokens" ] == 10
92
106
93
107
# Verify response
94
- assert response .choices [0 ].message .content == "I am an AI assistant"
108
+ if stream :
109
+ tokens = ["I" , " am" , " an" , " AI" , " assistant" , None ]
110
+ idx = 0
111
+ for chunk in response :
112
+ assert chunk .choices [0 ].delta .content == tokens [idx ]
113
+ idx += 1
114
+ assert idx == len (tokens )
115
+ else :
116
+ assert response .choices [0 ].message .content == "I am an AI assistant"
95
117
96
118
except Exception as e :
97
119
print ("exception" , e )
0 commit comments