Skip to content

Commit

Permalink
Handle azure deepseek reasoning response (#8288) (#8366)
Browse files Browse the repository at this point in the history
* Handle azure deepseek reasoning response (#8288)

* Handle deepseek reasoning response

* Add helper method + unit test

* Fix: Follow infinity api url format (#8346)

* Follow infinity api url format

* Update test_infinity.py

* fix(infinity/transformation.py): fix linting error

---------

Co-authored-by: vibhavbhat <[email protected]>
Co-authored-by: Hao Shan <[email protected]>
  • Loading branch information
3 people authored Feb 8, 2025
1 parent f651d51 commit b5850b6
Show file tree
Hide file tree
Showing 6 changed files with 86 additions and 5 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,8 @@
import time
import traceback
import uuid
from typing import Dict, Iterable, List, Literal, Optional, Union
import re
from typing import Dict, Iterable, List, Literal, Optional, Union, Tuple

import litellm
from litellm._logging import verbose_logger
Expand Down Expand Up @@ -220,6 +221,16 @@ def _handle_invalid_parallel_tool_calls(
# if there is a JSONDecodeError, return the original tool_calls
return tool_calls

def _parse_content_for_reasoning(message_text: Optional[str]) -> Tuple[Optional[str], Optional[str]]:
if not message_text:
return None, None

reasoning_match = re.match(r"<think>(.*?)</think>(.*)", message_text, re.DOTALL)

if reasoning_match:
return reasoning_match.group(1), reasoning_match.group(2)

return None, message_text

class LiteLLMResponseObjectHandler:

Expand Down Expand Up @@ -432,8 +443,14 @@ def convert_to_model_response_object( # noqa: PLR0915
for field in choice["message"].keys():
if field not in message_keys:
provider_specific_fields[field] = choice["message"][field]

# Handle reasoning models that display `reasoning_content` within `content`
reasoning_content, content = _parse_content_for_reasoning(choice["message"].get("content", None))
if reasoning_content:
provider_specific_fields["reasoning_content"] = reasoning_content

message = Message(
content=choice["message"].get("content", None),
content=content,
role=choice["message"]["role"] or "assistant",
function_call=choice["message"].get("function_call", None),
tool_calls=tool_calls,
Expand Down
9 changes: 9 additions & 0 deletions litellm/llms/infinity/rerank/transformation.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,15 @@


class InfinityRerankConfig(CohereRerankConfig):
def get_complete_url(self, api_base: Optional[str], model: str) -> str:
if api_base is None:
raise ValueError("api_base is required for Infinity rerank")
# Remove trailing slashes and ensure clean base URL
api_base = api_base.rstrip("/")
if not api_base.endswith("/rerank"):
api_base = f"{api_base}/rerank"
return api_base

def validate_environment(
self,
headers: dict,
Expand Down
1 change: 1 addition & 0 deletions litellm/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -89,6 +89,7 @@
convert_to_model_response_object,
convert_to_streaming_response,
convert_to_streaming_response_async,
_parse_content_for_reasoning,
)
from litellm.litellm_core_utils.llm_response_utils.get_api_base import get_api_base
from litellm.litellm_core_utils.llm_response_utils.get_formatted_prompt import (
Expand Down
12 changes: 12 additions & 0 deletions tests/litellm_utils_tests/test_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -864,6 +864,18 @@ def test_convert_model_response_object():
== '{"type":"error","error":{"type":"invalid_request_error","message":"Output blocked by content filtering policy"}}'
)

@pytest.mark.parametrize(
    "content, expected_reasoning, expected_content",
    [
        (None, None, None),
        (
            "<think>I am thinking here</think>The sky is a canvas of blue",
            "I am thinking here",
            "The sky is a canvas of blue",
        ),
        ("I am a regular response", None, "I am a regular response"),
    ],
)
def test_parse_content_for_reasoning(content, expected_reasoning, expected_content):
    # Helper should peel a leading <think> block into reasoning, or pass
    # plain content through untouched.
    parsed = litellm.utils._parse_content_for_reasoning(content)
    assert parsed == (expected_reasoning, expected_content)


@pytest.mark.parametrize(
"model, expected_bool",
Expand Down
44 changes: 43 additions & 1 deletion tests/llm_translation/test_azure_ai.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@
from litellm.llms.anthropic.chat import ModelResponseIterator
import httpx
import json
from respx import MockRouter
from litellm.llms.custom_httpx.http_handler import HTTPHandler

load_dotenv()
import io
Expand Down Expand Up @@ -184,3 +184,45 @@ def test_completion_azure_ai_command_r():
pass
except Exception as e:
pytest.fail(f"Error occurred: {e}")

def test_azure_deepseek_reasoning_content():
    """Azure-hosted DeepSeek returns its chain-of-thought inline as a
    ``<think>...</think>`` prefix inside ``message.content``; verify that
    litellm splits it into ``reasoning_content`` and the visible ``content``.

    Fix: drop the redundant function-local ``import json`` (the module
    already imports ``json`` at top level) and use statement-style asserts.
    """
    client = HTTPHandler()

    with patch.object(client, "post") as mock_post:
        # Build a minimal chat-completions payload the handler will parse.
        mock_response = MagicMock()
        mock_response.status_code = 200
        mock_response.headers = {"Content-Type": "application/json"}
        mock_response.text = json.dumps(
            {
                "choices": [
                    {
                        "finish_reason": "stop",
                        "index": 0,
                        "message": {
                            "content": "<think>I am thinking here</think>\n\nThe sky is a canvas of blue",
                            "role": "assistant",
                        },
                    }
                ],
            }
        )
        mock_response.json = lambda: json.loads(mock_response.text)
        mock_post.return_value = mock_response

        response = litellm.completion(
            model="azure_ai/deepseek-r1",
            messages=[{"role": "user", "content": "Hello, world!"}],
            api_base="https://litellm8397336933.services.ai.azure.com/models/chat/completions",
            api_key="my-fake-api-key",
            client=client,
        )

        print(response)
        # The <think> block becomes reasoning_content; the remainder stays in content.
        assert response.choices[0].message.reasoning_content == "I am thinking here"
        assert response.choices[0].message.content == "\n\nThe sky is a canvas of blue"
4 changes: 2 additions & 2 deletions tests/llm_translation/test_infinity.py
Original file line number Diff line number Diff line change
Expand Up @@ -69,7 +69,7 @@ def return_val():
_url = mock_post.call_args.kwargs["url"]
print("Arguments passed to API=", args_to_api)
print("url = ", _url)
assert _url == "https://api.infinity.ai/v1/rerank"
assert _url == "https://api.infinity.ai/rerank"

request_data = json.loads(args_to_api)
assert request_data["query"] == expected_payload["query"]
Expand Down Expand Up @@ -133,7 +133,7 @@ def return_val():
_url = mock_post.call_args.kwargs["url"]
print("Arguments passed to API=", args_to_api)
print("url = ", _url)
assert _url == "https://env.infinity.ai/v1/rerank"
assert _url == "https://env.infinity.ai/rerank"

request_data = json.loads(args_to_api)
assert request_data["query"] == expected_payload["query"]
Expand Down

0 comments on commit b5850b6

Please sign in to comment.