Skip to content

Commit 02736ac

Browse files
committed
feat: add fine-tuning (FT) cancel and list endpoints for Azure
1 parent c6bff32 commit 02736ac

File tree

3 files changed

+132
-54
lines changed

3 files changed

+132
-54
lines changed

litellm/fine_tuning/main.py

Lines changed: 107 additions & 38 deletions
Original file line numberDiff line numberDiff line change
@@ -279,6 +279,25 @@ def cancel_fine_tuning_job(
279279
"""
280280
try:
281281
optional_params = GenericLiteLLMParams(**kwargs)
282+
### TIMEOUT LOGIC ###
283+
timeout = optional_params.timeout or kwargs.get("request_timeout", 600) or 600
284+
# set timeout for 10 minutes by default
285+
286+
if (
287+
timeout is not None
288+
and isinstance(timeout, httpx.Timeout)
289+
and supports_httpx_timeout(custom_llm_provider) == False
290+
):
291+
read_timeout = timeout.read or 600
292+
timeout = read_timeout # default 10 min timeout
293+
elif timeout is not None and not isinstance(timeout, httpx.Timeout):
294+
timeout = float(timeout) # type: ignore
295+
elif timeout is None:
296+
timeout = 600.0
297+
298+
_is_async = kwargs.pop("acancel_fine_tuning_job", False) is True
299+
300+
# OpenAI
282301
if custom_llm_provider == "openai":
283302

284303
# for deepinfra/perplexity/anyscale/groq we check in get_llm_provider and pass in the api base from there
@@ -301,25 +320,6 @@ def cancel_fine_tuning_job(
301320
or litellm.openai_key
302321
or os.getenv("OPENAI_API_KEY")
303322
)
304-
### TIMEOUT LOGIC ###
305-
timeout = (
306-
optional_params.timeout or kwargs.get("request_timeout", 600) or 600
307-
)
308-
# set timeout for 10 minutes by default
309-
310-
if (
311-
timeout is not None
312-
and isinstance(timeout, httpx.Timeout)
313-
and supports_httpx_timeout(custom_llm_provider) == False
314-
):
315-
read_timeout = timeout.read or 600
316-
timeout = read_timeout # default 10 min timeout
317-
elif timeout is not None and not isinstance(timeout, httpx.Timeout):
318-
timeout = float(timeout) # type: ignore
319-
elif timeout is None:
320-
timeout = 600.0
321-
322-
_is_async = kwargs.pop("acancel_fine_tuning_job", False) is True
323323

324324
response = openai_fine_tuning_apis_instance.cancel_fine_tuning_job(
325325
api_base=api_base,
@@ -330,6 +330,40 @@ def cancel_fine_tuning_job(
330330
max_retries=optional_params.max_retries,
331331
_is_async=_is_async,
332332
)
333+
# Azure OpenAI
334+
elif custom_llm_provider == "azure":
335+
api_base = optional_params.api_base or litellm.api_base or get_secret("AZURE_API_BASE") # type: ignore
336+
337+
api_version = (
338+
optional_params.api_version
339+
or litellm.api_version
340+
or get_secret("AZURE_API_VERSION")
341+
) # type: ignore
342+
343+
api_key = (
344+
optional_params.api_key
345+
or litellm.api_key
346+
or litellm.azure_key
347+
or get_secret("AZURE_OPENAI_API_KEY")
348+
or get_secret("AZURE_API_KEY")
349+
) # type: ignore
350+
351+
extra_body = optional_params.get("extra_body", {})
352+
azure_ad_token: Optional[str] = None
353+
if extra_body is not None:
354+
azure_ad_token = extra_body.pop("azure_ad_token", None)
355+
else:
356+
azure_ad_token = get_secret("AZURE_AD_TOKEN") # type: ignore
357+
358+
response = azure_fine_tuning_apis_instance.cancel_fine_tuning_job(
359+
api_base=api_base,
360+
api_key=api_key,
361+
api_version=api_version,
362+
fine_tuning_job_id=fine_tuning_job_id,
363+
timeout=timeout,
364+
max_retries=optional_params.max_retries,
365+
_is_async=_is_async,
366+
)
333367
else:
334368
raise litellm.exceptions.BadRequestError(
335369
message="LiteLLM doesn't support {} for 'create_batch'. Only 'openai' is supported.".format(
@@ -405,6 +439,25 @@ def list_fine_tuning_jobs(
405439
"""
406440
try:
407441
optional_params = GenericLiteLLMParams(**kwargs)
442+
### TIMEOUT LOGIC ###
443+
timeout = optional_params.timeout or kwargs.get("request_timeout", 600) or 600
444+
# set timeout for 10 minutes by default
445+
446+
if (
447+
timeout is not None
448+
and isinstance(timeout, httpx.Timeout)
449+
and supports_httpx_timeout(custom_llm_provider) == False
450+
):
451+
read_timeout = timeout.read or 600
452+
timeout = read_timeout # default 10 min timeout
453+
elif timeout is not None and not isinstance(timeout, httpx.Timeout):
454+
timeout = float(timeout) # type: ignore
455+
elif timeout is None:
456+
timeout = 600.0
457+
458+
_is_async = kwargs.pop("alist_fine_tuning_jobs", False) is True
459+
460+
# OpenAI
408461
if custom_llm_provider == "openai":
409462

410463
# for deepinfra/perplexity/anyscale/groq we check in get_llm_provider and pass in the api base from there
@@ -427,25 +480,6 @@ def list_fine_tuning_jobs(
427480
or litellm.openai_key
428481
or os.getenv("OPENAI_API_KEY")
429482
)
430-
### TIMEOUT LOGIC ###
431-
timeout = (
432-
optional_params.timeout or kwargs.get("request_timeout", 600) or 600
433-
)
434-
# set timeout for 10 minutes by default
435-
436-
if (
437-
timeout is not None
438-
and isinstance(timeout, httpx.Timeout)
439-
and supports_httpx_timeout(custom_llm_provider) == False
440-
):
441-
read_timeout = timeout.read or 600
442-
timeout = read_timeout # default 10 min timeout
443-
elif timeout is not None and not isinstance(timeout, httpx.Timeout):
444-
timeout = float(timeout) # type: ignore
445-
elif timeout is None:
446-
timeout = 600.0
447-
448-
_is_async = kwargs.pop("alist_fine_tuning_jobs", False) is True
449483

450484
response = openai_fine_tuning_apis_instance.list_fine_tuning_jobs(
451485
api_base=api_base,
@@ -457,6 +491,41 @@ def list_fine_tuning_jobs(
457491
max_retries=optional_params.max_retries,
458492
_is_async=_is_async,
459493
)
494+
# Azure OpenAI
495+
elif custom_llm_provider == "azure":
496+
api_base = optional_params.api_base or litellm.api_base or get_secret("AZURE_API_BASE") # type: ignore
497+
498+
api_version = (
499+
optional_params.api_version
500+
or litellm.api_version
501+
or get_secret("AZURE_API_VERSION")
502+
) # type: ignore
503+
504+
api_key = (
505+
optional_params.api_key
506+
or litellm.api_key
507+
or litellm.azure_key
508+
or get_secret("AZURE_OPENAI_API_KEY")
509+
or get_secret("AZURE_API_KEY")
510+
) # type: ignore
511+
512+
extra_body = optional_params.get("extra_body", {})
513+
azure_ad_token: Optional[str] = None
514+
if extra_body is not None:
515+
azure_ad_token = extra_body.pop("azure_ad_token", None)
516+
else:
517+
azure_ad_token = get_secret("AZURE_AD_TOKEN") # type: ignore
518+
519+
response = azure_fine_tuning_apis_instance.list_fine_tuning_jobs(
520+
api_base=api_base,
521+
api_key=api_key,
522+
api_version=api_version,
523+
after=after,
524+
limit=limit,
525+
timeout=timeout,
526+
max_retries=optional_params.max_retries,
527+
_is_async=_is_async,
528+
)
460529
else:
461530
raise litellm.exceptions.BadRequestError(
462531
message="LiteLLM doesn't support {} for 'create_batch'. Only 'openai' is supported.".format(

litellm/llms/fine_tuning_apis/azure.py

Lines changed: 6 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -91,13 +91,15 @@ def cancel_fine_tuning_job(
9191
api_base: Optional[str],
9292
timeout: Union[float, httpx.Timeout],
9393
max_retries: Optional[int],
94-
organization: Optional[str],
94+
organization: Optional[str] = None,
95+
api_version: Optional[str] = None,
9596
client: Optional[Union[AzureOpenAI, AsyncAzureOpenAI]] = None,
9697
):
9798
openai_client: Optional[Union[AzureOpenAI, AsyncAzureOpenAI]] = (
9899
get_azure_openai_client(
99100
api_key=api_key,
100101
api_base=api_base,
102+
api_version=api_version,
101103
timeout=timeout,
102104
max_retries=max_retries,
103105
organization=organization,
@@ -141,15 +143,17 @@ def list_fine_tuning_jobs(
141143
api_base: Optional[str],
142144
timeout: Union[float, httpx.Timeout],
143145
max_retries: Optional[int],
144-
organization: Optional[str],
146+
organization: Optional[str] = None,
145147
client: Optional[Union[AzureOpenAI, AsyncAzureOpenAI]] = None,
148+
api_version: Optional[str] = None,
146149
after: Optional[str] = None,
147150
limit: Optional[int] = None,
148151
):
149152
openai_client: Optional[Union[AzureOpenAI, AsyncAzureOpenAI]] = (
150153
get_azure_openai_client(
151154
api_key=api_key,
152155
api_base=api_base,
156+
api_version=api_version,
153157
timeout=timeout,
154158
max_retries=max_retries,
155159
organization=organization,
@@ -175,4 +179,3 @@ def list_fine_tuning_jobs(
175179
verbose_logger.debug("list fine tuning job, after= %s, limit= %s", after, limit)
176180
response = openai_client.fine_tuning.jobs.list(after=after, limit=limit) # type: ignore
177181
return response
178-
pass

litellm/tests/test_fine_tuning_api.py

Lines changed: 19 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -146,25 +146,31 @@ async def test_azure_create_fine_tune_jobs_async():
146146
assert create_fine_tuning_response.id is not None
147147
assert create_fine_tuning_response.model == "gpt-35-turbo-1106"
148148

149-
# # list fine tuning jobs
150-
# print("listing ft jobs")
151-
# ft_jobs = await litellm.alist_fine_tuning_jobs(limit=2)
152-
# print("response from litellm.list_fine_tuning_jobs=", ft_jobs)
153-
# assert len(list(ft_jobs)) > 0
149+
# list fine tuning jobs
150+
print("listing ft jobs")
151+
ft_jobs = await litellm.alist_fine_tuning_jobs(
152+
limit=2,
153+
custom_llm_provider="azure",
154+
api_key=os.getenv("AZURE_SWEDEN_API_KEY"),
155+
api_base="https://my-endpoint-sweden-berri992.openai.azure.com/",
156+
)
157+
print("response from litellm.list_fine_tuning_jobs=", ft_jobs)
154158

155159
# # delete file
156160

157161
# await litellm.afile_delete(
158162
# file_id=file_obj.id,
159163
# )
160164

161-
# # cancel ft job
162-
# response = await litellm.acancel_fine_tuning_job(
163-
# fine_tuning_job_id=create_fine_tuning_response.id,
164-
# )
165+
# cancel ft job
166+
response = await litellm.acancel_fine_tuning_job(
167+
fine_tuning_job_id=create_fine_tuning_response.id,
168+
custom_llm_provider="azure",
169+
api_key=os.getenv("AZURE_SWEDEN_API_KEY"),
170+
api_base="https://my-endpoint-sweden-berri992.openai.azure.com/",
171+
)
165172

166-
# print("response from litellm.cancel_fine_tuning_job=", response)
173+
print("response from litellm.cancel_fine_tuning_job=", response)
167174

168-
# assert response.status == "cancelled"
169-
# assert response.id == create_fine_tuning_response.id
170-
# pass
175+
assert response.status == "cancelled"
176+
assert response.id == create_fine_tuning_response.id

0 commit comments

Comments
 (0)