Skip to content

Commit 0b696cb

Browse files
committed
Updating get token to work with autogenerated client
1 parent e973ccc commit 0b696cb

File tree

4 files changed

+14
-14
lines changed

4 files changed

+14
-14
lines changed

sdk/evaluation/azure-ai-evaluation/azure/ai/evaluation/_azure/_clients.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -201,4 +201,4 @@ def _generate_path(self, *paths: str) -> str:
201201
return url
202202

203203
def _get_headers(self) -> Dict[str, str]:
204-
return {"Authorization": f"Bearer {self.get_token()}", "Content-Type": "application/json"}
204+
return {"Authorization": f"Bearer {self.get_token().token}", "Content-Type": "application/json"}

sdk/evaluation/azure-ai-evaluation/azure/ai/evaluation/_azure/_token_manager.py

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -5,7 +5,7 @@
55
import logging
66
import time
77
import inspect
8-
from typing import cast, Optional, Union
8+
from typing import cast, Optional, Union, Any
99

1010
from azure.core.credentials import TokenCredential, AccessToken
1111
from azure.identity import AzureCliCredential, DefaultAzureCredential, ManagedIdentityCredential
@@ -71,7 +71,7 @@ def get_aad_credential(self) -> Union[DefaultAzureCredential, ManagedIdentityCre
7171
# Fall back to using the parent implementation
7272
return super().get_aad_credential()
7373

74-
def get_token(self) -> str:
74+
def get_token(self, *scopes: str, claims: str | None = None, tenant_id: str | None = None, enable_cae: bool = False, **kwargs: Any) -> AccessToken:
7575
"""Get the API token. If the token is not available or has expired, refresh the token.
7676
7777
:return: API token
@@ -82,7 +82,7 @@ def get_token(self) -> str:
8282
access_token = credential.get_token(self.token_scope)
8383
self._update_token(access_token)
8484

85-
return cast(str, self.token) # check for none is hidden in the _token_needs_update method
85+
return self.token # check for none is hidden in the _token_needs_update method
8686

8787
async def get_token_async(self) -> str:
8888
"""Get the API token asynchronously. If the token is not available or has expired, refresh it.
@@ -112,7 +112,7 @@ def _token_needs_update(self) -> bool:
112112
)
113113

114114
def _update_token(self, access_token: AccessToken) -> None:
115-
self.token = cast(str, access_token.token)
115+
self.token = access_token
116116
self.token_expiry_time = access_token.expires_on
117117
self.last_refresh_time = time.time()
118118
self.logger.info("Refreshed Azure management token.")

sdk/evaluation/azure-ai-evaluation/azure/ai/evaluation/_evaluate/_eval_run.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -295,7 +295,7 @@ def get_metrics_url(self):
295295
return f"https://{self._url_base}" "/mlflow/v2.0" f"{self._get_scope()}" f"/api/2.0/mlflow/runs/log-metric"
296296

297297
def _get_token(self) -> str:
298-
return self._management_client.get_token()
298+
return self._management_client.get_token().token
299299

300300
def request_with_retry(
301301
self, url: str, method: str, json_dict: Dict[str, Any], headers: Optional[Dict[str, str]] = None

sdk/evaluation/azure-ai-evaluation/azure/ai/evaluation/_evaluate/_utils.py

Lines changed: 8 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -140,17 +140,11 @@ def _log_metrics_and_instance_results_onedp(
140140
from azure.ai.evaluation._common import EvaluationServiceOneDPClient, EvaluationUpload
141141

142142
credentials = AzureMLTokenManager(
143-
TokenScope.CONGNITIVE_SERVICES.value, LOGGER, credential=kwargs.get("credential")
143+
TokenScope.COGNITIVE_SERVICES.value, LOGGER, credential=kwargs.get("credential")
144144
)
145145
client = EvaluationServiceOneDPClient(
146146
endpoint=project_url,
147-
credentials=credentials
148-
)
149-
150-
upload_run_response = client.start_evaluation_run(
151-
evaluation=EvaluationUpload(
152-
display_name=evaluation_name,
153-
)
147+
credential=credentials
154148
)
155149

156150
# Massaging before artifacts are put on disk
@@ -187,6 +181,12 @@ def _log_metrics_and_instance_results_onedp(
187181
metrics=metrics
188182
)
189183

184+
upload_run_response = client.start_evaluation_run(
185+
evaluation=EvaluationUpload(
186+
display_name=evaluation_name,
187+
)
188+
)
189+
190190
update_run_response = client.update_evaluation_run(
191191
name=upload_run_response.id,
192192
evaluation=EvaluationUpload(

0 commit comments

Comments
 (0)