-
Notifications
You must be signed in to change notification settings - Fork 428
/
Copy path_utils.py
194 lines (162 loc) · 7.53 KB
/
_utils.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
import json
from typing import Dict
from typing import List
from typing import Optional
from typing import Union
from ddtrace import config
from ddtrace.ext import SpanTypes
from ddtrace.internal.logger import get_logger
from ddtrace.llmobs._constants import GEMINI_APM_SPAN_NAME
from ddtrace.llmobs._constants import INTERNAL_CONTEXT_VARIABLE_KEYS
from ddtrace.llmobs._constants import INTERNAL_QUERY_VARIABLE_KEYS
from ddtrace.llmobs._constants import IS_EVALUATION_SPAN
from ddtrace.llmobs._constants import LANGCHAIN_APM_SPAN_NAME
from ddtrace.llmobs._constants import ML_APP
from ddtrace.llmobs._constants import NAME
from ddtrace.llmobs._constants import OPENAI_APM_SPAN_NAME
from ddtrace.llmobs._constants import SESSION_ID
from ddtrace.llmobs._constants import VERTEXAI_APM_SPAN_NAME
from ddtrace.trace import Span
# Module-level logger; used by the JSON serialization helpers in this file.
log = get_logger(__name__)
def validate_prompt(prompt: dict) -> Dict[str, Union[str, dict, List[str]]]:
    """Validate a user-supplied prompt dictionary and return a new, checked dictionary.

    Recognized keys of ``prompt`` (all optional):

    - ``variables``: a dict whose keys and values are all strings (may be empty).
    - ``template``, ``version``, ``id``: strings.
    - ``rag_context_variables``: list of strings, stored under
      ``INTERNAL_CONTEXT_VARIABLE_KEYS``; defaults to ``["context"]`` when absent.
    - ``rag_query_variables``: list of strings, stored under
      ``INTERNAL_QUERY_VARIABLE_KEYS``; defaults to ``["question"]`` when absent.

    Keys that are missing or ``None`` are omitted from the result, except for the
    two RAG key lists, which always receive a value.

    :param prompt: the prompt description to validate.
    :returns: a new dictionary containing only the validated fields.
    :raises TypeError: if ``prompt`` is not a dict or any field has the wrong type.
    """
    if not isinstance(prompt, dict):
        raise TypeError("Prompt must be a dictionary")

    validated_prompt = {}  # type: Dict[str, Union[str, dict, List[str]]]

    variables = prompt.get("variables")
    if variables is not None:
        if not isinstance(variables, dict):
            raise TypeError("Prompt variables must be a dictionary.")
        # Every key AND every value must be a string. ``all`` also accepts an
        # empty dict, which the earlier ``not any(... or ...)`` check rejected.
        if not all(isinstance(k, str) and isinstance(v, str) for k, v in variables.items()):
            raise TypeError("Prompt variable keys and values must be strings.")
        validated_prompt["variables"] = variables

    # Plain string fields, validated in the same order as before.
    string_fields = (
        ("template", "Prompt template must be a string"),
        ("version", "Prompt version must be a string."),
        ("id", "Prompt id must be a string."),
    )
    for field, error_message in string_fields:
        value = prompt.get(field)
        if value is None:
            continue
        if not isinstance(value, str):
            raise TypeError(error_message)
        validated_prompt[field] = value

    # List-of-string fields; each falls back to a default key list when absent.
    # The error message names the key the caller actually supplied.
    list_fields = (
        ("rag_context_variables", INTERNAL_CONTEXT_VARIABLE_KEYS, ["context"]),
        ("rag_query_variables", INTERNAL_QUERY_VARIABLE_KEYS, ["question"]),
    )
    for field, internal_key, default_keys in list_fields:
        keys = prompt.get(field)
        if keys is None:
            validated_prompt[internal_key] = default_keys
            continue
        if not isinstance(keys, list) or not all(isinstance(k, str) for k in keys):
            raise TypeError("Prompt field `{}` must be a list of strings.".format(field))
        validated_prompt[internal_key] = keys

    return validated_prompt
class LinkTracker:
    """Keep lists of span links keyed by an identifier derived from an object."""

    def __init__(self, object_span_links=None):
        # A falsy argument (None or an empty mapping) is replaced by a fresh dict.
        self._object_span_links = object_span_links if object_span_links else {}

    def get_object_id(self, obj):
        """Return the tracking key for ``obj``: its type name joined with ``id(obj)``."""
        type_name = type(obj).__name__
        return f"{type_name}_{id(obj)}"

    def add_span_links_to_object(self, obj, span_links):
        """Append ``span_links`` to the links already recorded for ``obj``."""
        key = self.get_object_id(obj)
        recorded = self._object_span_links.setdefault(key, [])
        recorded += span_links
        self._object_span_links[key] = recorded

    def get_span_links_from_object(self, obj):
        """Return the links recorded for ``obj``, or an empty list if there are none."""
        key = self.get_object_id(obj)
        return self._object_span_links.get(key, [])
class AnnotationContext:
    """Sync and async context manager that runs a register callback on entry
    and a deregister callback on exit.
    """

    def __init__(self, _register_annotator, _deregister_annotator):
        # Both callbacks are stored as-is and invoked with no arguments.
        self._register_annotator = _register_annotator
        self._deregister_annotator = _deregister_annotator

    def __enter__(self):
        """Invoke the register callback; nothing is bound by ``with ... as``."""
        self._register_annotator()

    def __exit__(self, exc_type, exc_val, exc_tb):
        """Invoke the deregister callback; exceptions from the body are not suppressed."""
        self._deregister_annotator()

    async def __aenter__(self):
        """Async counterpart of ``__enter__``; the callback is called synchronously."""
        self._register_annotator()

    async def __aexit__(self, exc_type, exc_val, exc_tb):
        """Async counterpart of ``__exit__``; the callback is called synchronously."""
        self._deregister_annotator()
def _get_attr(o: object, attr: str, default: object):
# Convenience method to get an attribute from an object or dict
if isinstance(o, dict):
return o.get(attr, default)
return getattr(o, attr, default)
def _get_nearest_llmobs_ancestor(span: Span) -> Optional[Span]:
    """Walk up the ``_parent`` chain of ``span`` and return the first ancestor
    whose ``span_type`` is ``SpanTypes.LLM``, or ``None`` if there is none.
    """
    ancestor = span._parent
    # Skip over ancestors that are not LLM-typed spans.
    while ancestor and ancestor.span_type != SpanTypes.LLM:
        ancestor = ancestor._parent
    return ancestor or None
def _get_span_name(span: Span) -> str:
    """Return the name to report for ``span``.

    When the span has a non-empty resource and its name is one of the
    LangChain / Gemini / Vertex AI APM span names, the resource is used. For the
    OpenAI APM span name the resource is prefixed with the
    ``openai.request.client`` tag (``"OpenAI"`` when the tag is unset). In every
    other case the ``NAME`` context item is used, falling back to ``span.name``.
    """
    if span.resource != "":
        if span.name in (LANGCHAIN_APM_SPAN_NAME, GEMINI_APM_SPAN_NAME, VERTEXAI_APM_SPAN_NAME):
            return span.resource
        if span.name == OPENAI_APM_SPAN_NAME:
            client_name = span.get_tag("openai.request.client") or "OpenAI"
            return f"{client_name}.{span.resource}"
    return span._get_ctx_item(NAME) or span.name
def _is_evaluation_span(span: Span) -> bool:
    """
    Tell whether ``span`` is an evaluation span.

    The ``IS_EVALUATION_SPAN`` context item is read from the span itself and then
    from each successive nearest LLMObs ancestor; the first truthy value found is
    returned. If none is found the result is ``False``.
    """
    current = span
    while True:
        flag = current._get_ctx_item(IS_EVALUATION_SPAN)
        if flag:
            return flag
        current = _get_nearest_llmobs_ancestor(current)
        if not current:
            # Ran out of LLMObs ancestors without finding the marker.
            return False
def _get_ml_app(span: Span) -> str:
    """
    Return the ML app name for a given span.

    The ``ML_APP`` context item is read from the span itself and then from each
    successive nearest LLMObs ancestor; the first non-empty value wins. If no
    span in the chain carries one, fall back to the global ``config._llmobs_ml_app``
    and finally to ``"unknown-ml-app"``.
    """
    current = span
    while current:
        ml_app = current._get_ctx_item(ML_APP)
        # Use truthiness for the span AND its ancestors so that an empty value
        # on an ancestor falls through to the fallbacks instead of being
        # returned as the app name.
        if ml_app:
            return ml_app
        current = _get_nearest_llmobs_ancestor(current)
    return config._llmobs_ml_app or "unknown-ml-app"
def _get_session_id(span: Span) -> Optional[str]:
    """Return the session ID for ``span``.

    A truthy ``SESSION_ID`` context item on the span itself is returned directly.
    Otherwise each successive nearest LLMObs ancestor is consulted and the first
    value that is not ``None`` is returned. When nothing is found, the last value
    read is returned.
    """
    found = span._get_ctx_item(SESSION_ID)
    if found:
        return found
    ancestor = _get_nearest_llmobs_ancestor(span)
    while ancestor:
        found = ancestor._get_ctx_item(SESSION_ID)
        if found is not None:
            break
        ancestor = _get_nearest_llmobs_ancestor(ancestor)
    return found
def _unserializable_default_repr(obj):
    """Build a placeholder string embedding ``repr(obj)`` and log a warning.

    Passed as the ``default`` hook to ``json.dumps`` in ``safe_json``.
    """
    placeholder = f"[Unserializable object: {obj!r}]"
    log.warning("I/O object is not JSON serializable. Defaulting to placeholder value instead.")
    return placeholder
def safe_json(obj, ensure_ascii=True):
    """Serialize ``obj`` to a JSON string on a best-effort basis.

    Strings are returned unchanged. Other objects go through ``json.dumps`` with
    ``skipkeys=True`` and ``_unserializable_default_repr`` as the ``default``
    hook. If serialization still fails, the error is logged and ``None`` is
    returned.
    """
    if not isinstance(obj, str):
        try:
            return json.dumps(
                obj,
                ensure_ascii=ensure_ascii,
                skipkeys=True,
                default=_unserializable_default_repr,
            )
        except Exception:
            log.error("Failed to serialize object to JSON.", exc_info=True)
            return None
    return obj