Skip to content

Commit b5850b6

Browse files
krrishdholakia, vibhavbhat, haoshan98
authored
Handle azure deepseek reasoning response (#8288) (#8366)
* Handle azure deepseek reasoning response (#8288) * Handle deepseek reasoning response * Add helper method + unit test * Fix: Follow infinity api url format (#8346) * Follow infinity api url format * Update test_infinity.py * fix(infinity/transformation.py): fix linting error --------- Co-authored-by: vibhavbhat <[email protected]> Co-authored-by: Hao Shan <[email protected]>
1 parent f651d51 commit b5850b6

File tree

6 files changed

+86
-5
lines changed

6 files changed

+86
-5
lines changed

litellm/litellm_core_utils/llm_response_utils/convert_dict_to_response.py

Lines changed: 19 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,8 @@
33
import time
44
import traceback
55
import uuid
6-
from typing import Dict, Iterable, List, Literal, Optional, Union
6+
import re
7+
from typing import Dict, Iterable, List, Literal, Optional, Union, Tuple
78

89
import litellm
910
from litellm._logging import verbose_logger
@@ -220,6 +221,16 @@ def _handle_invalid_parallel_tool_calls(
220221
# if there is a JSONDecodeError, return the original tool_calls
221222
return tool_calls
222223

224+
def _parse_content_for_reasoning(message_text: Optional[str]) -> Tuple[Optional[str], Optional[str]]:
225+
if not message_text:
226+
return None, None
227+
228+
reasoning_match = re.match(r"<think>(.*?)</think>(.*)", message_text, re.DOTALL)
229+
230+
if reasoning_match:
231+
return reasoning_match.group(1), reasoning_match.group(2)
232+
233+
return None, message_text
223234

224235
class LiteLLMResponseObjectHandler:
225236

@@ -432,8 +443,14 @@ def convert_to_model_response_object( # noqa: PLR0915
432443
for field in choice["message"].keys():
433444
if field not in message_keys:
434445
provider_specific_fields[field] = choice["message"][field]
446+
447+
# Handle reasoning models that display `reasoning_content` within `content`
448+
reasoning_content, content = _parse_content_for_reasoning(choice["message"].get("content", None))
449+
if reasoning_content:
450+
provider_specific_fields["reasoning_content"] = reasoning_content
451+
435452
message = Message(
436-
content=choice["message"].get("content", None),
453+
content=content,
437454
role=choice["message"]["role"] or "assistant",
438455
function_call=choice["message"].get("function_call", None),
439456
tool_calls=tool_calls,

litellm/llms/infinity/rerank/transformation.py

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -20,6 +20,15 @@
2020

2121

2222
class InfinityRerankConfig(CohereRerankConfig):
23+
def get_complete_url(self, api_base: Optional[str], model: str) -> str:
    """Build the full Infinity rerank endpoint URL.

    Infinity expects requests at ``<api_base>/rerank`` (no ``/v1`` prefix).

    Raises:
        ValueError: when no ``api_base`` was supplied.
    """
    if api_base is None:
        raise ValueError("api_base is required for Infinity rerank")
    # Normalize: drop trailing slashes, then append /rerank exactly once.
    base = api_base.rstrip("/")
    return base if base.endswith("/rerank") else f"{base}/rerank"
31+
2332
def validate_environment(
2433
self,
2534
headers: dict,

litellm/utils.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -89,6 +89,7 @@
8989
convert_to_model_response_object,
9090
convert_to_streaming_response,
9191
convert_to_streaming_response_async,
92+
_parse_content_for_reasoning,
9293
)
9394
from litellm.litellm_core_utils.llm_response_utils.get_api_base import get_api_base
9495
from litellm.litellm_core_utils.llm_response_utils.get_formatted_prompt import (

tests/litellm_utils_tests/test_utils.py

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -864,6 +864,18 @@ def test_convert_model_response_object():
864864
== '{"type":"error","error":{"type":"invalid_request_error","message":"Output blocked by content filtering policy"}}'
865865
)
866866

867+
@pytest.mark.parametrize(
    "content, expected_reasoning, expected_content",
    [
        # No content at all -> nothing to split.
        (None, None, None),
        # Leading <think> block is extracted as the reasoning part.
        (
            "<think>I am thinking here</think>The sky is a canvas of blue",
            "I am thinking here",
            "The sky is a canvas of blue",
        ),
        # Plain responses pass through untouched.
        ("I am a regular response", None, "I am a regular response"),
    ],
)
def test_parse_content_for_reasoning(content, expected_reasoning, expected_content):
    """_parse_content_for_reasoning splits <think> reasoning from the answer."""
    assert litellm.utils._parse_content_for_reasoning(content) == (
        expected_reasoning,
        expected_content,
    )
878+
867879

868880
@pytest.mark.parametrize(
869881
"model, expected_bool",

tests/llm_translation/test_azure_ai.py

Lines changed: 43 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -13,7 +13,7 @@
1313
from litellm.llms.anthropic.chat import ModelResponseIterator
1414
import httpx
1515
import json
16-
from respx import MockRouter
16+
from litellm.llms.custom_httpx.http_handler import HTTPHandler
1717

1818
load_dotenv()
1919
import io
@@ -184,3 +184,45 @@ def test_completion_azure_ai_command_r():
184184
pass
185185
except Exception as e:
186186
pytest.fail(f"Error occurred: {e}")
187+
188+
def test_azure_deepseek_reasoning_content():
    """Azure-hosted deepseek-r1 returns its chain of thought inside
    ``<think>...</think>`` tags in ``content``; litellm must surface it as
    ``message.reasoning_content`` and strip it from ``message.content``.
    """
    client = HTTPHandler()

    with patch.object(client, "post") as mock_post:
        # Minimal fake HTTP response carrying a <think>-prefixed completion.
        mock_response = MagicMock()
        mock_response.text = json.dumps(
            {
                "choices": [
                    {
                        "finish_reason": "stop",
                        "index": 0,
                        "message": {
                            "content": "<think>I am thinking here</think>\n\nThe sky is a canvas of blue",
                            "role": "assistant",
                        },
                    }
                ],
            }
        )
        mock_response.status_code = 200
        # Required response attributes for litellm's response handling.
        mock_response.headers = {"Content-Type": "application/json"}
        mock_response.json = lambda: json.loads(mock_response.text)
        mock_post.return_value = mock_response

        response = litellm.completion(
            model="azure_ai/deepseek-r1",
            messages=[{"role": "user", "content": "Hello, world!"}],
            api_base="https://litellm8397336933.services.ai.azure.com/models/chat/completions",
            api_key="my-fake-api-key",
            client=client,
        )

        print(response)
        # Reasoning is split out; only the final answer remains in content.
        assert response.choices[0].message.reasoning_content == "I am thinking here"
        assert response.choices[0].message.content == "\n\nThe sky is a canvas of blue"

tests/llm_translation/test_infinity.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -69,7 +69,7 @@ def return_val():
6969
_url = mock_post.call_args.kwargs["url"]
7070
print("Arguments passed to API=", args_to_api)
7171
print("url = ", _url)
72-
assert _url == "https://api.infinity.ai/v1/rerank"
72+
assert _url == "https://api.infinity.ai/rerank"
7373

7474
request_data = json.loads(args_to_api)
7575
assert request_data["query"] == expected_payload["query"]
@@ -133,7 +133,7 @@ def return_val():
133133
_url = mock_post.call_args.kwargs["url"]
134134
print("Arguments passed to API=", args_to_api)
135135
print("url = ", _url)
136-
assert _url == "https://env.infinity.ai/v1/rerank"
136+
assert _url == "https://env.infinity.ai/rerank"
137137

138138
request_data = json.loads(args_to_api)
139139
assert request_data["query"] == expected_payload["query"]

0 commit comments

Comments
 (0)