12
12
from typing import Dict , List , Optional , Union , cast
13
13
from urllib .parse import urlparse
14
14
from string import Template
15
+ from azure .ai .evaluation ._common .onedp ._client import AIProjectClient
16
+ from azure .core .exceptions import HttpResponseError
15
17
16
18
import jwt
17
19
18
20
from azure .ai .evaluation ._legacy ._adapters ._errors import MissingRequiredPackage
19
21
from azure .ai .evaluation ._exceptions import ErrorBlame , ErrorCategory , ErrorTarget , EvaluationException
20
22
from azure .ai .evaluation ._http_utils import AsyncHttpPipeline , get_async_http_client
21
23
from azure .ai .evaluation ._model_configurations import AzureAIProject
24
+ from azure .ai .evaluation ._common .utils import is_onedp_project
22
25
from azure .core .credentials import TokenCredential
23
26
from azure .core .exceptions import HttpResponseError
24
27
from azure .core .pipeline .policies import AsyncRetryPolicy
41
44
USER_TEXT_TEMPLATE_DICT : Dict [str , Template ] = {
42
45
"DEFAULT" : Template ("<Human>{$query}</><System>{$response}</>" ),
43
46
}
47
+ ML_WORKSPACE = "https://management.azure.com/.default"
48
+ COG_SRV_WORKSPACE = "https://cognitiveservices.azure.com/.default"
44
49
45
50
INFERENCE_OF_SENSITIVE_ATTRIBUTES = "inference_sensitive_attributes"
46
51
@@ -99,11 +104,7 @@ def get_common_headers(token: str, evaluator_name: Optional[str] = None) -> Dict
99
104
user_agent = f"{ USER_AGENT } (type=evaluator; subtype={ evaluator_name } )" if evaluator_name else USER_AGENT
100
105
return {
101
106
"Authorization" : f"Bearer { token } " ,
102
- "Content-Type" : "application/json" ,
103
107
"User-Agent" : user_agent ,
104
- # Handle "RuntimeError: Event loop is closed" from httpx AsyncClient
105
- # https://github.com/encode/httpx/discussions/2959
106
- "Connection" : "close" ,
107
108
}
108
109
109
110
@@ -112,7 +113,31 @@ def get_async_http_client_with_timeout() -> AsyncHttpPipeline:
112
113
retry_policy = AsyncRetryPolicy (timeout = CommonConstants .DEFAULT_HTTP_TIMEOUT )
113
114
)
114
115
116
+ async def ensure_service_availability_onedp (client : AIProjectClient , token : str , capability : Optional [str ] = None ) -> None :
117
+ """Check if the Responsible AI service is available in the region and has the required capability, if relevant.
115
118
119
+ :param client: The AI project client.
120
+ :type client: AIProjectClient
121
+ :param token: The Azure authentication token.
122
+ :type token: str
123
+ :param capability: The capability to check. Default is None.
124
+ :type capability: str
125
+ :raises Exception: If the service is not available in the region or the capability is not available.
126
+ """
127
+ headers = get_common_headers (token )
128
+ capabilities = client .evaluations .check_annotation (headers = headers )
129
+
130
+ if capability and capability not in capabilities :
131
+ msg = f"The needed capability '{ capability } ' is not supported by the RAI service in this region."
132
+ raise EvaluationException (
133
+ message = msg ,
134
+ internal_message = msg ,
135
+ target = ErrorTarget .RAI_CLIENT ,
136
+ category = ErrorCategory .SERVICE_UNAVAILABLE ,
137
+ blame = ErrorBlame .USER_ERROR ,
138
+ tsg_link = "https://aka.ms/azsdk/python/evaluation/safetyevaluator/troubleshoot" ,
139
+ )
140
+
116
141
async def ensure_service_availability (rai_svc_url : str , token : str , capability : Optional [str ] = None ) -> None :
117
142
"""Check if the Responsible AI service is available in the region and has the required capability, if relevant.
118
143
@@ -231,6 +256,40 @@ async def submit_request(
231
256
return operation_id
232
257
233
258
259
+ async def submit_request_onedp (
260
+ client : AIProjectClient ,
261
+ data : dict ,
262
+ metric : str ,
263
+ token : str ,
264
+ annotation_task : str ,
265
+ evaluator_name : str
266
+ ) -> str :
267
+ """Submit request to Responsible AI service for evaluation and return operation ID
268
+
269
+ :param client: The AI project client.
270
+ :type client: AIProjectClient
271
+ :param data: The data to evaluate.
272
+ :type data: dict
273
+ :param metric: The evaluation metric to use.
274
+ :type metric: str
275
+ :param token: The Azure authentication token.
276
+ :type token: str
277
+ :param annotation_task: The annotation task to use.
278
+ :type annotation_task: str
279
+ :param evaluator_name: The evaluator name.
280
+ :type evaluator_name: str
281
+ :return: The operation ID.
282
+ :rtype: str
283
+ """
284
+ normalized_user_text = get_formatted_template (data , annotation_task )
285
+ payload = generate_payload (normalized_user_text , metric , annotation_task = annotation_task )
286
+ headers = get_common_headers (token , evaluator_name )
287
+ response = client .evaluations .submit_annotation (payload , headers = headers )
288
+ result = json .loads (response )
289
+ operation_id = result ["location" ].split ("/" )[- 1 ]
290
+ return operation_id
291
+
292
+
234
293
async def fetch_result (operation_id : str , rai_svc_url : str , credential : TokenCredential , token : str ) -> Dict :
235
294
"""Fetch the annotation result from Responsible AI service
236
295
@@ -267,6 +326,34 @@ async def fetch_result(operation_id: str, rai_svc_url: str, credential: TokenCre
267
326
sleep_time = RAIService .SLEEP_TIME ** request_count
268
327
await asyncio .sleep (sleep_time )
269
328
329
+ async def fetch_result_onedp (client : AIProjectClient , operation_id : str , token : str ) -> Dict :
330
+ """Fetch the annotation result from Responsible AI service
331
+
332
+ :param client: The AI project client.
333
+ :type client: AIProjectClient
334
+ :param operation_id: The operation ID.
335
+ :type operation_id: str
336
+ :param token: The Azure authentication token.
337
+ :type token: str
338
+ :return: The annotation result.
339
+ :rtype: Dict
340
+ """
341
+ start = time .time ()
342
+ request_count = 0
343
+
344
+ while True :
345
+ headers = get_common_headers (token )
346
+ try :
347
+ return client .evaluations .operation_results (operation_id , headers = headers )
348
+ except HttpResponseError :
349
+ request_count += 1
350
+ time_elapsed = time .time () - start
351
+ if time_elapsed > RAIService .TIMEOUT :
352
+ raise TimeoutError (f"Fetching annotation result { request_count } times out after { time_elapsed :.2f} seconds" )
353
+
354
+ sleep_time = RAIService .SLEEP_TIME ** request_count
355
+ await asyncio .sleep (sleep_time )
356
+
270
357
def parse_response ( # pylint: disable=too-many-branches,too-many-statements
271
358
batch_response : List [Dict ], metric_name : str , metric_display_name : Optional [str ] = None
272
359
) -> Dict [str , Union [str , float ]]:
@@ -500,7 +587,7 @@ async def get_rai_svc_url(project_scope: AzureAIProject, token: str) -> str:
500
587
return rai_url
501
588
502
589
503
- async def fetch_or_reuse_token (credential : TokenCredential , token : Optional [str ] = None ) -> str :
590
+ async def fetch_or_reuse_token (credential : TokenCredential , token : Optional [str ] = None , workspace : Optional [ str ] = ML_WORKSPACE ) -> str :
504
591
"""Get token. Fetch a new token if the current token is near expiry
505
592
506
593
:param credential: The Azure authentication credential.
@@ -524,13 +611,13 @@ async def fetch_or_reuse_token(credential: TokenCredential, token: Optional[str]
524
611
if (exp_time - current_time ) >= 300 :
525
612
return token
526
613
527
- return credential .get_token ("https://management.azure.com/.default" ).token
614
+ return credential .get_token (workspace ).token
528
615
529
616
530
617
async def evaluate_with_rai_service (
531
618
data : dict ,
532
619
metric_name : str ,
533
- project_scope : AzureAIProject ,
620
+ project_scope : Union [ str , AzureAIProject ] ,
534
621
credential : TokenCredential ,
535
622
annotation_task : str = Tasks .CONTENT_HARM ,
536
623
metric_display_name = None ,
@@ -556,18 +643,26 @@ async def evaluate_with_rai_service(
556
643
:rtype: Dict[str, Union[str, float]]
557
644
"""
558
645
559
- # Get RAI service URL from discovery service and check service availability
560
- token = await fetch_or_reuse_token (credential )
561
- rai_svc_url = await get_rai_svc_url (project_scope , token )
562
- await ensure_service_availability (rai_svc_url , token , annotation_task )
563
-
564
- # Submit annotation request and fetch result
565
- operation_id = await submit_request (data , metric_name , rai_svc_url , token , annotation_task , evaluator_name )
566
- annotation_response = cast (List [Dict ], await fetch_result (operation_id , rai_svc_url , credential , token ))
567
- result = parse_response (annotation_response , metric_name , metric_display_name )
646
+ if is_onedp_project (project_scope ):
647
+ client = AIProjectClient (endpoint = project_scope , credential = credential )
648
+ token = await fetch_or_reuse_token (credential = credential , workspace = COG_SRV_WORKSPACE )
649
+ await ensure_service_availability_onedp (client , token , annotation_task )
650
+ operation_id = await submit_request_onedp (client , data , metric_name , token , annotation_task , evaluator_name )
651
+ annotation_response = cast (List [Dict ], await fetch_result_onedp (client , operation_id , token ))
652
+ result = parse_response (annotation_response , metric_name , metric_display_name )
653
+ return result
654
+ else :
655
+ # Get RAI service URL from discovery service and check service availability
656
+ token = await fetch_or_reuse_token (credential )
657
+ rai_svc_url = await get_rai_svc_url (project_scope , token )
658
+ await ensure_service_availability (rai_svc_url , token , annotation_task )
568
659
569
- return result
660
+ # Submit annotation request and fetch result
661
+ operation_id = await submit_request (data , metric_name , rai_svc_url , token , annotation_task , evaluator_name )
662
+ annotation_response = cast (List [Dict ], await fetch_result (operation_id , rai_svc_url , credential , token ))
663
+ result = parse_response (annotation_response , metric_name , metric_display_name )
570
664
665
+ return result
571
666
572
667
def generate_payload_multimodal (content_type : str , messages , metric : str ) -> Dict :
573
668
"""Generate the payload for the annotation request
@@ -600,7 +695,6 @@ def generate_payload_multimodal(content_type: str, messages, metric: str) -> Dic
600
695
"AnnotationTask" : task ,
601
696
}
602
697
603
-
604
698
async def submit_multimodal_request (messages , metric : str , rai_svc_url : str , token : str ) -> str :
605
699
"""Submit request to Responsible AI service for evaluation and return operation ID
606
700
:param messages: The normalized list of messages to be entered as the "Contents" in the payload.
@@ -646,9 +740,37 @@ async def submit_multimodal_request(messages, metric: str, rai_svc_url: str, tok
646
740
operation_id = result ["location" ].split ("/" )[- 1 ]
647
741
return operation_id
648
742
743
+ async def submit_multimodal_request_onedp (client : AIProjectClient , messages , metric : str , token : str ) -> str :
744
+
745
+ # handle inference sdk strongly type messages
746
+ if len (messages ) > 0 and not isinstance (messages [0 ], dict ):
747
+ try :
748
+ from azure .ai .inference .models import ChatRequestMessage
749
+ except ImportError as ex :
750
+ error_message = (
751
+ "Please install 'azure-ai-inference' package to use SystemMessage, UserMessage, AssistantMessage"
752
+ )
753
+ raise MissingRequiredPackage (message = error_message ) from ex
754
+ if len (messages ) > 0 and isinstance (messages [0 ], ChatRequestMessage ):
755
+ messages = [message .as_dict () for message in messages ]
756
+
757
+ ## fetch system and assistant messages from the list of messages
758
+ filtered_messages = [message for message in messages if message ["role" ] != "system" ]
759
+ assistant_messages = [message for message in messages if message ["role" ] == "assistant" ]
760
+
761
+ ## prepare for request
762
+ content_type = retrieve_content_type (assistant_messages , metric )
763
+ payload = generate_payload_multimodal (content_type , filtered_messages , metric )
764
+ headers = get_common_headers (token )
765
+
766
+ response = client .evaluations .submit_annotation (payload , headers = headers )
767
+
768
+ result = json .loads (response )
769
+ operation_id = result ["location" ].split ("/" )[- 1 ]
770
+ return operation_id
649
771
650
772
async def evaluate_with_rai_service_multimodal (
651
- messages , metric_name : str , project_scope : AzureAIProject , credential : TokenCredential
773
+ messages , metric_name : str , project_scope : Union [ str , AzureAIProject ] , credential : TokenCredential
652
774
):
653
775
""" "Evaluate the content safety of the response using Responsible AI service
654
776
:param messages: The normalized list of messages.
@@ -664,12 +786,20 @@ async def evaluate_with_rai_service_multimodal(
664
786
:rtype: List[List[Dict]]
665
787
"""
666
788
667
- # Get RAI service URL from discovery service and check service availability
668
- token = await fetch_or_reuse_token (credential )
669
- rai_svc_url = await get_rai_svc_url (project_scope , token )
670
- await ensure_service_availability (rai_svc_url , token , Tasks .CONTENT_HARM )
671
- # Submit annotation request and fetch result
672
- operation_id = await submit_multimodal_request (messages , metric_name , rai_svc_url , token )
673
- annotation_response = cast (List [Dict ], await fetch_result (operation_id , rai_svc_url , credential , token ))
674
- result = parse_response (annotation_response , metric_name )
675
- return result
789
+ if is_onedp_project (project_scope ):
790
+ client = AIProjectClient (endpoint = project_scope , credential = credential )
791
+ token = await fetch_or_reuse_token (credential = credential , workspace = COG_SRV_WORKSPACE )
792
+ await ensure_service_availability_onedp (client , token , Tasks .CONTENT_HARM )
793
+ operation_id = await submit_multimodal_request_onedp (client , messages , metric_name , token )
794
+ annotation_response = cast (List [Dict ], await fetch_result_onedp (client , operation_id , token ))
795
+ result = parse_response (annotation_response , metric_name )
796
+ return result
797
+ else :
798
+ token = await fetch_or_reuse_token (credential )
799
+ rai_svc_url = await get_rai_svc_url (project_scope , token )
800
+ await ensure_service_availability (rai_svc_url , token , Tasks .CONTENT_HARM )
801
+ # Submit annotation request and fetch result
802
+ operation_id = await submit_multimodal_request (messages , metric_name , rai_svc_url , token )
803
+ annotation_response = cast (List [Dict ], await fetch_result (operation_id , rai_svc_url , credential , token ))
804
+ result = parse_response (annotation_response , metric_name )
805
+ return result
0 commit comments