From 02736ac8b5d776a58c4552b9549f8dc9ff471ebc Mon Sep 17 00:00:00 2001 From: Ishaan Jaff Date: Tue, 30 Jul 2024 16:03:31 -0700 Subject: [PATCH] feat FT cancel and LIST endpoints for Azure --- litellm/fine_tuning/main.py | 145 ++++++++++++++++++------- litellm/llms/fine_tuning_apis/azure.py | 9 +- litellm/tests/test_fine_tuning_api.py | 32 +++--- 3 files changed, 132 insertions(+), 54 deletions(-) diff --git a/litellm/fine_tuning/main.py b/litellm/fine_tuning/main.py index 72119185f222..5206cb789726 100644 --- a/litellm/fine_tuning/main.py +++ b/litellm/fine_tuning/main.py @@ -279,6 +279,25 @@ def cancel_fine_tuning_job( """ try: optional_params = GenericLiteLLMParams(**kwargs) + ### TIMEOUT LOGIC ### + timeout = optional_params.timeout or kwargs.get("request_timeout", 600) or 600 + # set timeout for 10 minutes by default + + if ( + timeout is not None + and isinstance(timeout, httpx.Timeout) + and supports_httpx_timeout(custom_llm_provider) == False + ): + read_timeout = timeout.read or 600 + timeout = read_timeout # default 10 min timeout + elif timeout is not None and not isinstance(timeout, httpx.Timeout): + timeout = float(timeout) # type: ignore + elif timeout is None: + timeout = 600.0 + + _is_async = kwargs.pop("acancel_fine_tuning_job", False) is True + + # OpenAI if custom_llm_provider == "openai": # for deepinfra/perplexity/anyscale/groq we check in get_llm_provider and pass in the api base from there @@ -301,25 +320,6 @@ def cancel_fine_tuning_job( or litellm.openai_key or os.getenv("OPENAI_API_KEY") ) - ### TIMEOUT LOGIC ### - timeout = ( - optional_params.timeout or kwargs.get("request_timeout", 600) or 600 - ) - # set timeout for 10 minutes by default - - if ( - timeout is not None - and isinstance(timeout, httpx.Timeout) - and supports_httpx_timeout(custom_llm_provider) == False - ): - read_timeout = timeout.read or 600 - timeout = read_timeout # default 10 min timeout - elif timeout is not None and not isinstance(timeout, httpx.Timeout): - timeout = float(timeout) # type: ignore - elif timeout is None: - timeout = 600.0 - - _is_async = kwargs.pop("acancel_fine_tuning_job", False) is True response = openai_fine_tuning_apis_instance.cancel_fine_tuning_job( api_base=api_base, @@ -330,6 +330,40 @@ def cancel_fine_tuning_job( max_retries=optional_params.max_retries, _is_async=_is_async, ) + # Azure OpenAI + elif custom_llm_provider == "azure": + api_base = optional_params.api_base or litellm.api_base or get_secret("AZURE_API_BASE") # type: ignore + + api_version = ( + optional_params.api_version + or litellm.api_version + or get_secret("AZURE_API_VERSION") + ) # type: ignore + + api_key = ( + optional_params.api_key + or litellm.api_key + or litellm.azure_key + or get_secret("AZURE_OPENAI_API_KEY") + or get_secret("AZURE_API_KEY") + ) # type: ignore + + extra_body = optional_params.get("extra_body", {}) + azure_ad_token: Optional[str] = None + if extra_body is not None: + azure_ad_token = extra_body.pop("azure_ad_token", None) + else: + azure_ad_token = get_secret("AZURE_AD_TOKEN") # type: ignore + + response = azure_fine_tuning_apis_instance.cancel_fine_tuning_job( + api_base=api_base, + api_key=api_key, + api_version=api_version, + fine_tuning_job_id=fine_tuning_job_id, + timeout=timeout, + max_retries=optional_params.max_retries, + _is_async=_is_async, + ) else: raise litellm.exceptions.BadRequestError( message="LiteLLM doesn't support {} for 'create_batch'. Only 'openai' is supported.".format( @@ -405,6 +439,25 @@ def list_fine_tuning_jobs( """ try: optional_params = GenericLiteLLMParams(**kwargs) + ### TIMEOUT LOGIC ### + timeout = optional_params.timeout or kwargs.get("request_timeout", 600) or 600 + # set timeout for 10 minutes by default + + if ( + timeout is not None + and isinstance(timeout, httpx.Timeout) + and supports_httpx_timeout(custom_llm_provider) == False + ): + read_timeout = timeout.read or 600 + timeout = read_timeout # default 10 min timeout + elif timeout is not None and not isinstance(timeout, httpx.Timeout): + timeout = float(timeout) # type: ignore + elif timeout is None: + timeout = 600.0 + + _is_async = kwargs.pop("alist_fine_tuning_jobs", False) is True + + # OpenAI if custom_llm_provider == "openai": # for deepinfra/perplexity/anyscale/groq we check in get_llm_provider and pass in the api base from there @@ -427,25 +480,6 @@ def list_fine_tuning_jobs( or litellm.openai_key or os.getenv("OPENAI_API_KEY") ) - ### TIMEOUT LOGIC ### - timeout = ( - optional_params.timeout or kwargs.get("request_timeout", 600) or 600 - ) - # set timeout for 10 minutes by default - - if ( - timeout is not None - and isinstance(timeout, httpx.Timeout) - and supports_httpx_timeout(custom_llm_provider) == False - ): - read_timeout = timeout.read or 600 - timeout = read_timeout # default 10 min timeout - elif timeout is not None and not isinstance(timeout, httpx.Timeout): - timeout = float(timeout) # type: ignore - elif timeout is None: - timeout = 600.0 - - _is_async = kwargs.pop("alist_fine_tuning_jobs", False) is True response = openai_fine_tuning_apis_instance.list_fine_tuning_jobs( api_base=api_base, @@ -457,6 +491,41 @@ def list_fine_tuning_jobs( max_retries=optional_params.max_retries, _is_async=_is_async, ) + # Azure OpenAI + elif custom_llm_provider == "azure": + api_base = optional_params.api_base or litellm.api_base or get_secret("AZURE_API_BASE") # type: ignore + + api_version = ( + optional_params.api_version + or litellm.api_version + or get_secret("AZURE_API_VERSION") + ) # type: ignore + + api_key = ( + optional_params.api_key + or litellm.api_key + or litellm.azure_key + or get_secret("AZURE_OPENAI_API_KEY") + or get_secret("AZURE_API_KEY") + ) # type: ignore + + extra_body = optional_params.get("extra_body", {}) + azure_ad_token: Optional[str] = None + if extra_body is not None: + azure_ad_token = extra_body.pop("azure_ad_token", None) + else: + azure_ad_token = get_secret("AZURE_AD_TOKEN") # type: ignore + + response = azure_fine_tuning_apis_instance.list_fine_tuning_jobs( + api_base=api_base, + api_key=api_key, + api_version=api_version, + after=after, + limit=limit, + timeout=timeout, + max_retries=optional_params.max_retries, + _is_async=_is_async, + ) else: raise litellm.exceptions.BadRequestError( message="LiteLLM doesn't support {} for 'create_batch'. Only 'openai' is supported.".format( diff --git a/litellm/llms/fine_tuning_apis/azure.py b/litellm/llms/fine_tuning_apis/azure.py index 6c32e2ac7881..0e6e0e66d317 100644 --- a/litellm/llms/fine_tuning_apis/azure.py +++ b/litellm/llms/fine_tuning_apis/azure.py @@ -91,13 +91,15 @@ def cancel_fine_tuning_job( api_base: Optional[str], timeout: Union[float, httpx.Timeout], max_retries: Optional[int], - organization: Optional[str], + organization: Optional[str] = None, + api_version: Optional[str] = None, client: Optional[Union[AzureOpenAI, AsyncAzureOpenAI]] = None, ): openai_client: Optional[Union[AzureOpenAI, AsyncAzureOpenAI]] = ( get_azure_openai_client( api_key=api_key, api_base=api_base, + api_version=api_version, timeout=timeout, max_retries=max_retries, organization=organization, @@ -141,8 +143,9 @@ def list_fine_tuning_jobs( api_base: Optional[str], timeout: Union[float, httpx.Timeout], max_retries: Optional[int], - organization: Optional[str], + organization: Optional[str] = None, client: Optional[Union[AzureOpenAI, AsyncAzureOpenAI]] = None, + api_version: Optional[str] = None, after: Optional[str] = None, limit: Optional[int] = None, ): @@ -150,6 +153,7 @@ def list_fine_tuning_jobs( get_azure_openai_client( api_key=api_key, api_base=api_base, + api_version=api_version, timeout=timeout, max_retries=max_retries, organization=organization, @@ -175,4 +179,3 @@ def list_fine_tuning_jobs( verbose_logger.debug("list fine tuning job, after= %s, limit= %s", after, limit) response = openai_client.fine_tuning.jobs.list(after=after, limit=limit) # type: ignore return response - pass diff --git a/litellm/tests/test_fine_tuning_api.py b/litellm/tests/test_fine_tuning_api.py index d35cd661ab88..1f99d63582a7 100644 --- a/litellm/tests/test_fine_tuning_api.py +++ b/litellm/tests/test_fine_tuning_api.py @@ -146,11 +146,15 @@ async def test_azure_create_fine_tune_jobs_async(): assert create_fine_tuning_response.id is not None assert create_fine_tuning_response.model == "gpt-35-turbo-1106" - # # list fine tuning jobs - # print("listing ft jobs") - # ft_jobs = await litellm.alist_fine_tuning_jobs(limit=2) - # print("response from litellm.list_fine_tuning_jobs=", ft_jobs) - # assert len(list(ft_jobs)) > 0 + # list fine tuning jobs + print("listing ft jobs") + ft_jobs = await litellm.alist_fine_tuning_jobs( + limit=2, + custom_llm_provider="azure", + api_key=os.getenv("AZURE_SWEDEN_API_KEY"), + api_base="https://my-endpoint-sweden-berri992.openai.azure.com/", + ) + print("response from litellm.list_fine_tuning_jobs=", ft_jobs) # # delete file @@ -158,13 +162,15 @@ async def test_azure_create_fine_tune_jobs_async(): # file_id=file_obj.id, # ) - # # cancel ft job - # response = await litellm.acancel_fine_tuning_job( - # fine_tuning_job_id=create_fine_tuning_response.id, - # ) + # cancel ft job + response = await litellm.acancel_fine_tuning_job( + fine_tuning_job_id=create_fine_tuning_response.id, + custom_llm_provider="azure", + api_key=os.getenv("AZURE_SWEDEN_API_KEY"), + api_base="https://my-endpoint-sweden-berri992.openai.azure.com/", + ) - # print("response from litellm.cancel_fine_tuning_job=", response) + print("response from litellm.cancel_fine_tuning_job=", response) - # assert response.status == "cancelled" - # assert response.id == create_fine_tuning_response.id - # pass + assert response.status == "cancelled" + assert response.id == create_fine_tuning_response.id