+import json
 from collections.abc import Generator
 from typing import Optional, Union

+import requests
 from yarl import URL

-from core.model_runtime.entities.llm_entities import LLMMode, LLMResult
+from core.model_runtime.entities.llm_entities import LLMMode, LLMResult, LLMResultChunk, LLMResultChunkDelta
 from core.model_runtime.entities.message_entities import (
+    AssistantPromptMessage,
     PromptMessage,
     PromptMessageTool,
 )
@@ -36,3 +39,208 @@ def _add_custom_parameters(credentials) -> None:
         credentials["mode"] = LLMMode.CHAT.value
         credentials["function_calling_type"] = "tool_call"
         credentials["stream_function_calling"] = "support"
+
+    def _handle_generate_stream_response(
+        self, model: str, credentials: dict, response: requests.Response, prompt_messages: list[PromptMessage]
+    ) -> Generator:
+        """
+        Handle llm stream response
+
+        :param model: model name
+        :param credentials: model credentials
+        :param response: streamed response
+        :param prompt_messages: prompt messages
+        :return: llm response chunk generator
+        """
+        full_assistant_content = ""
+        chunk_index = 0
+        is_reasoning_started = False  # flag to track whether reasoning content has started
+
+        def create_final_llm_result_chunk(
+            id: Optional[str], index: int, message: AssistantPromptMessage, finish_reason: str, usage: dict
+        ) -> LLMResultChunk:
+            # calculate num tokens
+            prompt_tokens = usage and usage.get("prompt_tokens")
+            if prompt_tokens is None:
+                prompt_tokens = self._num_tokens_from_string(model, prompt_messages[0].content)
+            completion_tokens = usage and usage.get("completion_tokens")
+            if completion_tokens is None:
+                completion_tokens = self._num_tokens_from_string(model, full_assistant_content)
+
+            # transform usage
+            usage = self._calc_response_usage(model, credentials, prompt_tokens, completion_tokens)
+
+            return LLMResultChunk(
+                id=id,
+                model=model,
+                prompt_messages=prompt_messages,
+                delta=LLMResultChunkDelta(index=index, message=message, finish_reason=finish_reason, usage=usage),
+            )
+
+        # delimiter for stream response; needs unicode_escape decoding
+        import codecs
+
+        delimiter = credentials.get("stream_mode_delimiter", "\n\n")
+        delimiter = codecs.decode(delimiter, "unicode_escape")
+
+        tools_calls: list[AssistantPromptMessage.ToolCall] = []
+
+        def increase_tool_call(new_tool_calls: list[AssistantPromptMessage.ToolCall]):
+            def get_tool_call(tool_call_id: str):
+                if not tool_call_id:
+                    return tools_calls[-1]
+
+                tool_call = next((tool_call for tool_call in tools_calls if tool_call.id == tool_call_id), None)
+                if tool_call is None:
+                    tool_call = AssistantPromptMessage.ToolCall(
+                        id=tool_call_id,
+                        type="function",
+                        function=AssistantPromptMessage.ToolCall.ToolCallFunction(name="", arguments=""),
+                    )
+                    tools_calls.append(tool_call)
+
+                return tool_call
+
+            for new_tool_call in new_tool_calls:
+                # get tool call
+                tool_call = get_tool_call(new_tool_call.function.name)
+                # update tool call
+                if new_tool_call.id:
+                    tool_call.id = new_tool_call.id
+                if new_tool_call.type:
+                    tool_call.type = new_tool_call.type
+                if new_tool_call.function.name:
+                    tool_call.function.name = new_tool_call.function.name
+                if new_tool_call.function.arguments:
+                    tool_call.function.arguments += new_tool_call.function.arguments
+
+        finish_reason = None  # The default value of finish_reason is None
+        message_id, usage = None, None
+        for chunk in response.iter_lines(decode_unicode=True, delimiter=delimiter):
+            chunk = chunk.strip()
+            if chunk:
+                # ignore sse comments
+                if chunk.startswith(":"):
+                    continue
+                decoded_chunk = chunk.strip().removeprefix("data:").lstrip()
+                if decoded_chunk == "[DONE]":  # Some provider returns "data: [DONE]"
+                    continue
+
+                try:
+                    chunk_json: dict = json.loads(decoded_chunk)
+                # stream ended
+                except json.JSONDecodeError as e:
+                    yield create_final_llm_result_chunk(
+                        id=message_id,
+                        index=chunk_index + 1,
+                        message=AssistantPromptMessage(content=""),
+                        finish_reason="Non-JSON encountered.",
+                        usage=usage,
+                    )
+                    break
+                # handle error responses that carry no choices (see issue #11629)
+                if chunk_json.get("error") and chunk_json.get("choices") is None:
+                    raise ValueError(chunk_json.get("error"))
+
+                if chunk_json:
+                    if u := chunk_json.get("usage"):
+                        usage = u
+                if not chunk_json or len(chunk_json["choices"]) == 0:
+                    continue
+
+                choice = chunk_json["choices"][0]
+                finish_reason = chunk_json["choices"][0].get("finish_reason")
+                message_id = chunk_json.get("id")
+                chunk_index += 1
+
+                if "delta" in choice:
+                    delta = choice["delta"]
+                    is_reasoning = delta.get("reasoning_content")
+                    delta_content = delta.get("content") or delta.get("reasoning_content")
+
+                    assistant_message_tool_calls = None
+
+                    if "tool_calls" in delta and credentials.get("function_calling_type", "no_call") == "tool_call":
+                        assistant_message_tool_calls = delta.get("tool_calls", None)
+                    elif (
+                        "function_call" in delta
+                        and credentials.get("function_calling_type", "no_call") == "function_call"
+                    ):
+                        assistant_message_tool_calls = [
+                            {"id": "tool_call_id", "type": "function", "function": delta.get("function_call", {})}
+                        ]
+
+                    # assistant_message_function_call = delta.delta.function_call
+
+                    # extract tool calls from response
+                    if assistant_message_tool_calls:
+                        tool_calls = self._extract_response_tool_calls(assistant_message_tool_calls)
+                        increase_tool_call(tool_calls)
+
+                    if delta_content is None or delta_content == "":
+                        continue
+
+                    # Add markdown quote markers for reasoning content
+                    if is_reasoning:
+                        if not is_reasoning_started:
+                            delta_content = "> 💭 " + delta_content
+                            is_reasoning_started = True
+                        elif "\n\n" in delta_content:
+                            delta_content = delta_content.replace("\n\n", "\n> ")
+                        elif "\n" in delta_content:
+                            delta_content = delta_content.replace("\n", "\n> ")
+                    elif is_reasoning_started:
+                        # If we were in reasoning mode but now getting regular content,
+                        # add \n\n to close the reasoning block
+                        delta_content = "\n\n" + delta_content
+                        is_reasoning_started = False
+
+                    # transform assistant message to prompt message
+                    assistant_prompt_message = AssistantPromptMessage(
+                        content=delta_content,
+                    )
+
+                    # reset tool calls
+                    tool_calls = []
+                    full_assistant_content += delta_content
+                elif "text" in choice:
+                    choice_text = choice.get("text", "")
+                    if choice_text == "":
+                        continue
+
+                    # transform assistant message to prompt message
+                    assistant_prompt_message = AssistantPromptMessage(content=choice_text)
+                    full_assistant_content += choice_text
+                else:
+                    continue
+
+                yield LLMResultChunk(
+                    id=message_id,
+                    model=model,
+                    prompt_messages=prompt_messages,
+                    delta=LLMResultChunkDelta(
+                        index=chunk_index,
+                        message=assistant_prompt_message,
+                    ),
+                )
+
+            chunk_index += 1
+
+        if tools_calls:
+            yield LLMResultChunk(
+                id=message_id,
+                model=model,
+                prompt_messages=prompt_messages,
+                delta=LLMResultChunkDelta(
+                    index=chunk_index,
+                    message=AssistantPromptMessage(tool_calls=tools_calls, content=""),
+                ),
+            )
+
+        yield create_final_llm_result_chunk(
+            id=message_id,
+            index=chunk_index,
+            message=AssistantPromptMessage(content=""),
+            finish_reason=finish_reason,
+            usage=usage,
+        )
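
For readers following the stream handling above, here is a minimal standalone sketch of the chunk-parsing step: the escaped stream_mode_delimiter from the credentials is unicode-unescaped, the body is split on it, SSE comment lines and the "[DONE]" sentinel are skipped, and the "data:" prefix is stripped before JSON decoding. The parse_stream_chunks helper and the sample payload are illustrative only, not part of this change.

# Standalone sketch of the chunk parsing performed inside the handler above.
# parse_stream_chunks and the sample payload are illustrative, not part of the PR.
import codecs
import json


def parse_stream_chunks(raw_body: str, configured_delimiter: str = "\\n\\n") -> list[dict]:
    """Split an SSE-style body the way the handler does and return the parsed JSON chunks."""
    # the delimiter is stored escaped in credentials, so it must be unicode-unescaped first
    delimiter = codecs.decode(configured_delimiter, "unicode_escape")

    chunks = []
    for line in raw_body.split(delimiter):
        line = line.strip()
        if not line or line.startswith(":"):  # skip blanks and SSE comments
            continue
        payload = line.removeprefix("data:").lstrip()
        if payload == "[DONE]":  # some providers terminate the stream with "data: [DONE]"
            continue
        chunks.append(json.loads(payload))
    return chunks


sample = 'data: {"id": "1", "choices": [{"delta": {"content": "Hel"}}]}\n\ndata: [DONE]'
print(parse_stream_chunks(sample))  # [{'id': '1', 'choices': [{'delta': {'content': 'Hel'}}]}]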
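
The increase_tool_call helper above merges fragmented tool-call deltas as they stream in. Below is a simplified, self-contained mirror of that accumulation using plain dicts instead of AssistantPromptMessage.ToolCall: it keys lookups on the fragment id and falls back to the most recent entry when the id is empty, whereas the handler above keys the lookup on the fragment's function name. merge_tool_call_fragment and the sample fragments are illustrative, not part of this change.

# Simplified, self-contained mirror of the tool-call accumulation above, using plain
# dicts instead of AssistantPromptMessage.ToolCall. Fragments with an empty id are
# merged into the most recent entry; argument strings are concatenated across chunks.
accumulated: list[dict] = []


def merge_tool_call_fragment(fragment: dict) -> None:
    call_id = fragment.get("id") or ""
    if not call_id and accumulated:
        target = accumulated[-1]
    else:
        target = next((c for c in accumulated if c["id"] == call_id), None)
    if target is None:
        target = {"id": call_id, "type": "function", "function": {"name": "", "arguments": ""}}
        accumulated.append(target)

    if fragment.get("type"):
        target["type"] = fragment["type"]
    fn = fragment.get("function", {})
    if fn.get("name"):
        target["function"]["name"] = fn["name"]
    if fn.get("arguments"):
        target["function"]["arguments"] += fn["arguments"]


for frag in [
    {"id": "call_1", "type": "function", "function": {"name": "get_weather", "arguments": '{"city": "'}},
    {"function": {"arguments": 'Paris"}'}},
]:
    merge_tool_call_fragment(frag)

print(accumulated[0]["function"]["arguments"])  # {"city": "Paris"}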
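
The reasoning_content branch above rewrites reasoning deltas into a Markdown blockquote while streaming. Here is a small self-contained mirror of just that transformation, assuming deltas arrive as plain dicts with content / reasoning_content keys; format_deltas and the sample deltas are illustrative, not part of this change.

# Standalone sketch of the reasoning-content quoting applied above: the first
# reasoning delta is prefixed with "> 💭 ", newlines inside reasoning keep the
# "> " quote marker, and the first regular delta closes the quote block.
def format_deltas(deltas: list[dict]) -> str:
    output = ""
    reasoning_started = False
    for delta in deltas:
        is_reasoning = delta.get("reasoning_content")
        content = delta.get("content") or delta.get("reasoning_content")
        if not content:
            continue
        if is_reasoning:
            if not reasoning_started:
                content = "> 💭 " + content
                reasoning_started = True
            elif "\n\n" in content:
                content = content.replace("\n\n", "\n> ")
            elif "\n" in content:
                content = content.replace("\n", "\n> ")
        elif reasoning_started:
            content = "\n\n" + content  # close the quoted reasoning block
            reasoning_started = False
        output += content
    return output


print(format_deltas([
    {"reasoning_content": "Let me think."},
    {"reasoning_content": "\nStill thinking."},
    {"content": "Final answer."},
]))
# > 💭 Let me think.
# > Still thinking.
#
# Final answer.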