Skip to content

Commit

Permalink
feat(model_providers): Support deepseek for Azure AI Foundry (#13267)
Browse files Browse the repository at this point in the history
Signed-off-by: -LAN- <[email protected]>
  • Loading branch information
laipz8200 authored Feb 6, 2025
1 parent f6c44ca commit 87763fc
Show file tree
Hide file tree
Showing 3 changed files with 24 additions and 58 deletions.
27 changes: 17 additions & 10 deletions api/core/model_runtime/model_providers/azure_ai_studio/llm/llm.py
Original file line number Diff line number Diff line change
@@ -1,9 +1,9 @@
import logging
from collections.abc import Generator
from collections.abc import Generator, Sequence
from typing import Any, Optional, Union

from azure.ai.inference import ChatCompletionsClient
from azure.ai.inference.models import StreamingChatCompletionsUpdate
from azure.ai.inference.models import StreamingChatCompletionsUpdate, SystemMessage, UserMessage
from azure.core.credentials import AzureKeyCredential
from azure.core.exceptions import (
ClientAuthenticationError,
Expand Down Expand Up @@ -60,10 +60,10 @@ def _invoke(
self,
model: str,
credentials: dict,
prompt_messages: list[PromptMessage],
prompt_messages: Sequence[PromptMessage],
model_parameters: dict,
tools: Optional[list[PromptMessageTool]] = None,
stop: Optional[list[str]] = None,
tools: Optional[Sequence[PromptMessageTool]] = None,
stop: Optional[Sequence[str]] = None,
stream: bool = True,
user: Optional[str] = None,
) -> Union[LLMResult, Generator]:
Expand All @@ -82,8 +82,8 @@ def _invoke(
"""

if not self.client:
endpoint = credentials.get("endpoint")
api_key = credentials.get("api_key")
endpoint = str(credentials.get("endpoint"))
api_key = str(credentials.get("api_key"))
self.client = ChatCompletionsClient(endpoint=endpoint, credential=AzureKeyCredential(api_key))

messages = [{"role": msg.role.value, "content": msg.content} for msg in prompt_messages]
Expand All @@ -94,6 +94,7 @@ def _invoke(
"temperature": model_parameters.get("temperature", 0),
"top_p": model_parameters.get("top_p", 1),
"stream": stream,
"model": model,
}

if stop:
Expand Down Expand Up @@ -255,10 +256,16 @@ def validate_credentials(self, model: str, credentials: dict) -> None:
:return:
"""
try:
endpoint = credentials.get("endpoint")
api_key = credentials.get("api_key")
endpoint = str(credentials.get("endpoint"))
api_key = str(credentials.get("api_key"))
client = ChatCompletionsClient(endpoint=endpoint, credential=AzureKeyCredential(api_key))
client.get_model_info()
client.complete(
messages=[
SystemMessage(content="I say 'ping', you say 'pong'"),
UserMessage(content="ping"),
],
model=model,
)
except Exception as ex:
raise CredentialsValidateFailedError(str(ex))

Expand Down
53 changes: 6 additions & 47 deletions api/poetry.lock

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

Loading

0 comments on commit 87763fc

Please sign in to comment.