Skip to content

Commit cda73c9

Browse files
feat(anthropic.py): support citations api with new user document message format
Resolves #7970
1 parent 0c87441 commit cda73c9

File tree

4 files changed

+81
-0
lines changed

4 files changed

+81
-0
lines changed

litellm/litellm_core_utils/prompt_templates/factory.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1421,6 +1421,8 @@ def anthropic_messages_pt( # noqa: PLR0915
14211421
)
14221422

14231423
user_content.append(_content_element)
1424+
elif m.get("type", "") == "document":
1425+
user_content.append(cast(AnthropicMessagesDocumentParam, m))
14241426
elif isinstance(user_message_types_block["content"], str):
14251427
_anthropic_content_text_element: AnthropicMessagesTextParam = {
14261428
"type": "text",

litellm/types/llms/anthropic.py

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -92,10 +92,17 @@ class AnthropicMessagesImageParam(TypedDict, total=False):
9292
cache_control: Optional[Union[dict, ChatCompletionCachedContent]]
9393

9494

95+
class CitationsObject(TypedDict):
96+
enabled: bool
97+
98+
9599
class AnthropicMessagesDocumentParam(TypedDict, total=False):
96100
type: Required[Literal["document"]]
97101
source: Required[AnthropicContentParamSource]
98102
cache_control: Optional[Union[dict, ChatCompletionCachedContent]]
103+
title: str
104+
context: str
105+
citations: Optional[CitationsObject]
99106

100107

101108
class AnthropicMessagesToolResultContent(TypedDict):

litellm/types/llms/openai.py

Lines changed: 20 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -382,10 +382,29 @@ class ChatCompletionAudioObject(ChatCompletionContentPartInputAudioParam):
382382
pass
383383

384384

385+
class DocumentObject(TypedDict):
386+
type: Literal["text"]
387+
media_type: str
388+
data: str
389+
390+
391+
class CitationsObject(TypedDict):
392+
enabled: bool
393+
394+
395+
class ChatCompletionDocumentObject(TypedDict):
396+
type: Literal["document"]
397+
source: DocumentObject
398+
title: str
399+
context: str
400+
citations: Optional[CitationsObject]
401+
402+
385403
OpenAIMessageContentListBlock = Union[
386404
ChatCompletionTextObject,
387405
ChatCompletionImageObject,
388406
ChatCompletionAudioObject,
407+
ChatCompletionDocumentObject,
389408
]
390409

391410
OpenAIMessageContent = Union[
@@ -460,6 +479,7 @@ class ChatCompletionDeveloperMessage(OpenAIChatCompletionDeveloperMessage, total
460479
"text",
461480
"image_url",
462481
"input_audio",
482+
"document",
463483
] # used for validating user messages. Prevent users from accidentally sending anthropic messages.
464484

465485
AllMessageValues = Union[

tests/llm_translation/test_anthropic_completion.py

Lines changed: 52 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1082,3 +1082,55 @@ async def test_anthropic_structured_output():
10821082
assert response is not None
10831083

10841084
print(response)
1085+
1086+
1087+
def test_anthropic_citations_api():
1088+
"""
1089+
Test the citations API
1090+
"""
1091+
from litellm import completion
1092+
from litellm.llms.custom_httpx.http_handler import HTTPHandler
1093+
import json
1094+
1095+
client = HTTPHandler()
1096+
1097+
with patch.object(client, "post") as mock_post:
1098+
try:
1099+
resp = completion(
1100+
model="claude-3-5-sonnet-20241022",
1101+
messages=[
1102+
{
1103+
"role": "user",
1104+
"content": [
1105+
{
1106+
"type": "document",
1107+
"source": {
1108+
"type": "text",
1109+
"media_type": "text/plain",
1110+
"data": "The grass is green. The sky is blue.",
1111+
},
1112+
"title": "My Document",
1113+
"context": "This is a trustworthy document.",
1114+
"citations": {"enabled": True},
1115+
},
1116+
{
1117+
"type": "text",
1118+
"text": "What color is the grass and sky?",
1119+
},
1120+
],
1121+
}
1122+
],
1123+
client=client,
1124+
)
1125+
1126+
print(resp)
1127+
except Exception as e:
1128+
print(e)
1129+
1130+
mock_post.assert_called_once()
1131+
1132+
print(mock_post.call_args.kwargs)
1133+
1134+
json_data = json.loads(mock_post.call_args.kwargs["data"])
1135+
1136+
assert json_data["messages"][0]["content"][0]["citations"]["enabled"]

0 commit comments

Comments
 (0)